atlan-application-sdk 0.1.1rc34__py3-none-any.whl → 0.1.1rc36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. application_sdk/activities/__init__.py +3 -2
  2. application_sdk/activities/common/utils.py +21 -1
  3. application_sdk/activities/lock_management.py +110 -0
  4. application_sdk/activities/metadata_extraction/base.py +4 -2
  5. application_sdk/activities/metadata_extraction/sql.py +13 -12
  6. application_sdk/activities/query_extraction/sql.py +24 -20
  7. application_sdk/clients/atlan_auth.py +2 -2
  8. application_sdk/clients/redis.py +443 -0
  9. application_sdk/clients/temporal.py +36 -196
  10. application_sdk/common/error_codes.py +24 -3
  11. application_sdk/constants.py +18 -1
  12. application_sdk/decorators/__init__.py +0 -0
  13. application_sdk/decorators/locks.py +42 -0
  14. application_sdk/handlers/base.py +18 -1
  15. application_sdk/inputs/json.py +6 -4
  16. application_sdk/inputs/parquet.py +16 -13
  17. application_sdk/interceptors/__init__.py +0 -0
  18. application_sdk/interceptors/events.py +193 -0
  19. application_sdk/interceptors/lock.py +139 -0
  20. application_sdk/outputs/__init__.py +6 -3
  21. application_sdk/outputs/json.py +9 -6
  22. application_sdk/outputs/parquet.py +10 -36
  23. application_sdk/server/fastapi/__init__.py +4 -5
  24. application_sdk/services/__init__.py +18 -0
  25. application_sdk/{outputs → services}/atlan_storage.py +64 -16
  26. application_sdk/{outputs → services}/eventstore.py +68 -6
  27. application_sdk/services/objectstore.py +407 -0
  28. application_sdk/services/secretstore.py +344 -0
  29. application_sdk/services/statestore.py +267 -0
  30. application_sdk/version.py +1 -1
  31. application_sdk/worker.py +1 -1
  32. {atlan_application_sdk-0.1.1rc34.dist-info → atlan_application_sdk-0.1.1rc36.dist-info}/METADATA +4 -2
  33. {atlan_application_sdk-0.1.1rc34.dist-info → atlan_application_sdk-0.1.1rc36.dist-info}/RECORD +36 -32
  34. application_sdk/common/credential_utils.py +0 -85
  35. application_sdk/inputs/objectstore.py +0 -238
  36. application_sdk/inputs/secretstore.py +0 -130
  37. application_sdk/inputs/statestore.py +0 -101
  38. application_sdk/outputs/objectstore.py +0 -125
  39. application_sdk/outputs/secretstore.py +0 -38
  40. application_sdk/outputs/statestore.py +0 -113
  41. {atlan_application_sdk-0.1.1rc34.dist-info → atlan_application_sdk-0.1.1rc36.dist-info}/WHEEL +0 -0
  42. {atlan_application_sdk-0.1.1rc34.dist-info → atlan_application_sdk-0.1.1rc36.dist-info}/licenses/LICENSE +0 -0
  43. {atlan_application_sdk-0.1.1rc34.dist-info → atlan_application_sdk-0.1.1rc36.dist-info}/licenses/NOTICE +0 -0
application_sdk/clients/temporal.py

@@ -1,22 +1,12 @@
 import asyncio
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from datetime import timedelta
 from typing import Any, Dict, Optional, Sequence, Type
 
 from temporalio import activity, workflow
 from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError
-from temporalio.common import RetryPolicy
 from temporalio.types import CallableType, ClassType
-from temporalio.worker import (
-    ActivityInboundInterceptor,
-    ExecuteActivityInput,
-    ExecuteWorkflowInput,
-    Interceptor,
-    Worker,
-    WorkflowInboundInterceptor,
-    WorkflowInterceptorClassInput,
-)
+from temporalio.worker import Worker
 from temporalio.worker.workflow_sandbox import (
     SandboxedWorkflowRunner,
     SandboxRestrictions,
@@ -28,6 +18,7 @@ from application_sdk.constants import (
     APPLICATION_NAME,
     DEPLOYMENT_NAME,
     DEPLOYMENT_NAME_KEY,
+    IS_LOCKING_DISABLED,
     MAX_CONCURRENT_ACTIVITIES,
     WORKFLOW_HOST,
     WORKFLOW_MAX_TIMEOUT_HOURS,
@@ -35,19 +26,11 @@ from application_sdk.constants import (
     WORKFLOW_PORT,
     WORKFLOW_TLS_ENABLED_KEY,
 )
-from application_sdk.events.models import (
-    ApplicationEventNames,
-    Event,
-    EventMetadata,
-    EventTypes,
-    WorkflowStates,
-)
-from application_sdk.inputs.secretstore import SecretStoreInput
-from application_sdk.inputs.statestore import StateType
+from application_sdk.interceptors.events import EventInterceptor, publish_event
+from application_sdk.interceptors.lock import RedisLockInterceptor
 from application_sdk.observability.logger_adaptor import get_logger
-from application_sdk.outputs.eventstore import EventStore
-from application_sdk.outputs.secretstore import SecretStoreOutput
-from application_sdk.outputs.statestore import StateStoreOutput
+from application_sdk.services.secretstore import SecretStore
+from application_sdk.services.statestore import StateStore, StateType
 from application_sdk.workflows import WorkflowInterface
 
 logger = get_logger(__name__)
@@ -57,170 +40,6 @@ TEMPORAL_NOT_FOUND_FAILURE = (
 )
 
 
-# Activity for publishing events (runs outside sandbox)
-@activity.defn
-async def publish_event(event_data: dict) -> None:
-    """Activity to publish events outside the workflow sandbox.
-
-    Args:
-        event_data (dict): Event data to publish containing event_type, event_name,
-            metadata, and data fields.
-    """
-    try:
-        event = Event(**event_data)
-        await EventStore.publish_event(event)
-        activity.logger.info(f"Published event: {event_data.get('event_name','')}")
-    except Exception as e:
-        activity.logger.error(f"Failed to publish event: {e}")
-        raise
-
-
-class EventActivityInboundInterceptor(ActivityInboundInterceptor):
-    """Interceptor for tracking activity execution events.
-
-    This interceptor captures the start and end of activity executions,
-    creating events that can be used for monitoring and tracking.
-    Activities run outside the sandbox so they can directly call EventStore.
-    """
-
-    async def execute_activity(self, input: ExecuteActivityInput) -> Any:
-        """Execute an activity with event tracking.
-
-        Args:
-            input (ExecuteActivityInput): The activity execution input.
-
-        Returns:
-            Any: The result of the activity execution.
-        """
-        # Extract activity information for tracking
-
-        start_event = Event(
-            event_type=EventTypes.APPLICATION_EVENT.value,
-            event_name=ApplicationEventNames.ACTIVITY_START.value,
-            data={},
-        )
-        await EventStore.publish_event(start_event)
-
-        output = None
-        try:
-            output = await super().execute_activity(input)
-        except Exception:
-            raise
-        finally:
-            end_event = Event(
-                event_type=EventTypes.APPLICATION_EVENT.value,
-                event_name=ApplicationEventNames.ACTIVITY_END.value,
-                data={},
-            )
-            await EventStore.publish_event(end_event)
-
-        return output
-
-
-class EventWorkflowInboundInterceptor(WorkflowInboundInterceptor):
-    """Interceptor for tracking workflow execution events.
-
-    This interceptor captures the start and end of workflow executions,
-    creating events that can be used for monitoring and tracking.
-    Uses activities to publish events to avoid sandbox restrictions.
-    """
-
-    async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
-        """Execute a workflow with event tracking.
-
-        Args:
-            input (ExecuteWorkflowInput): The workflow execution input.
-
-        Returns:
-            Any: The result of the workflow execution.
-        """
-
-        # Publish workflow start event via activity
-        try:
-            await workflow.execute_activity(
-                publish_event,
-                {
-                    "metadata": EventMetadata(
-                        workflow_state=WorkflowStates.RUNNING.value
-                    ),
-                    "event_type": EventTypes.APPLICATION_EVENT.value,
-                    "event_name": ApplicationEventNames.WORKFLOW_START.value,
-                    "data": {},
-                },
-                schedule_to_close_timeout=timedelta(seconds=30),
-                retry_policy=RetryPolicy(maximum_attempts=3),
-            )
-        except Exception as e:
-            workflow.logger.warning(f"Failed to publish workflow start event: {e}")
-            # Don't fail the workflow if event publishing fails
-
-        output = None
-        workflow_state = WorkflowStates.FAILED.value  # Default to failed
-
-        try:
-            output = await super().execute_workflow(input)
-            workflow_state = (
-                WorkflowStates.COMPLETED.value
-            )  # Update to completed on success
-        except Exception:
-            workflow_state = WorkflowStates.FAILED.value  # Keep as failed
-            raise
-        finally:
-            # Always publish workflow end event
-            try:
-                await workflow.execute_activity(
-                    publish_event,
-                    {
-                        "metadata": EventMetadata(workflow_state=workflow_state),
-                        "event_type": EventTypes.APPLICATION_EVENT.value,
-                        "event_name": ApplicationEventNames.WORKFLOW_END.value,
-                        "data": {},
-                    },
-                    schedule_to_close_timeout=timedelta(seconds=30),
-                    retry_policy=RetryPolicy(maximum_attempts=3),
-                )
-            except Exception as publish_error:
-                workflow.logger.warning(
-                    f"Failed to publish workflow end event: {publish_error}"
-                )
-
-        return output
-
-
-class EventInterceptor(Interceptor):
-    """Temporal interceptor for event tracking.
-
-    This interceptor provides event tracking capabilities for both
-    workflow and activity executions.
-    """
-
-    def intercept_activity(
-        self, next: ActivityInboundInterceptor
-    ) -> ActivityInboundInterceptor:
-        """Intercept activity executions.
-
-        Args:
-            next (ActivityInboundInterceptor): The next interceptor in the chain.
-
-        Returns:
-            ActivityInboundInterceptor: The activity interceptor.
-        """
-        return EventActivityInboundInterceptor(super().intercept_activity(next))
-
-    def workflow_interceptor_class(
-        self, input: WorkflowInterceptorClassInput
-    ) -> Optional[Type[WorkflowInboundInterceptor]]:
-        """Get the workflow interceptor class.
-
-        Args:
-            input (WorkflowInterceptorClassInput): The interceptor input.
-
-        Returns:
-            Optional[Type[WorkflowInboundInterceptor]]: The workflow interceptor class.
-        """
-        return EventWorkflowInboundInterceptor
-
-
 class TemporalWorkflowClient(WorkflowClient):
     """Temporal-specific implementation of WorkflowClient with simple token refresh.
 
@@ -269,9 +88,7 @@ class TemporalWorkflowClient(WorkflowClient):
         self.port = port if port else WORKFLOW_PORT
         self.namespace = namespace if namespace else WORKFLOW_NAMESPACE
 
-        self.deployment_config: Dict[str, Any] = (
-            SecretStoreInput.get_deployment_secret()
-        )
+        self.deployment_config: Dict[str, Any] = SecretStore.get_deployment_secret()
         self.worker_task_queue = self.get_worker_task_queue()
         self.auth_manager = AtlanAuthClient()
 
@@ -426,7 +243,7 @@ class TemporalWorkflowClient(WorkflowClient):
         """
         if "credentials" in workflow_args:
             # remove credentials from workflow_args and add reference to credentials
-            workflow_args["credential_guid"] = await SecretStoreOutput.save_secret(
+            workflow_args["credential_guid"] = await SecretStore.save_secret(
                 workflow_args["credentials"]
             )
             del workflow_args["credentials"]
@@ -442,7 +259,7 @@ class TemporalWorkflowClient(WorkflowClient):
             }
         )
 
-        await StateStoreOutput.save_state_object(
+        await StateStore.save_state_object(
            id=workflow_id, value=workflow_args, type=StateType.WORKFLOWS
         )
         logger.info(f"Created workflow config with ID: {workflow_id}")
@@ -541,14 +358,34 @@ class TemporalWorkflowClient(WorkflowClient):
             f"Started token refresh loop with dynamic interval (initial: {self._token_refresh_interval}s)"
         )
 
-        # Add the publish_event to the activities list
-        extended_activities = list(activities) + [publish_event]
+        # Start with provided activities and add system activities
+        final_activities = list(activities) + [publish_event]
+
+        # Add lock management activities if needed
+        if not IS_LOCKING_DISABLED:
+            from application_sdk.activities.lock_management import (
+                acquire_distributed_lock,
+                release_distributed_lock,
+            )
+
+            final_activities.extend(
+                [
+                    acquire_distributed_lock,
+                    release_distributed_lock,
+                ]
+            )
+            logger.info(
+                "Auto-registered lock management activities for @needs_lock decorated activities"
+            )
+
+        # Create activities lookup dict for interceptors
+        activities_dict = {getattr(a, "__name__", str(a)): a for a in final_activities}
 
         return Worker(
             self.client,
             task_queue=self.worker_task_queue,
             workflows=workflow_classes,
-            activities=extended_activities,  # Use extended activities list
+            activities=final_activities,
             workflow_runner=SandboxedWorkflowRunner(
                 restrictions=SandboxRestrictions.default.with_passthrough_modules(
                     *passthrough_modules
@@ -556,7 +393,10 @@ class TemporalWorkflowClient(WorkflowClient):
             ),
             max_concurrent_activities=max_concurrent_activities,
             activity_executor=activity_executor,
-            interceptors=[EventInterceptor()],
+            interceptors=[
+                EventInterceptor(),
+                RedisLockInterceptor(activities_dict),
+            ],
         )
 
     async def get_workflow_run_status(
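
Editor's note (not part of the diff): the activities_dict built in create_worker gives interceptors a name-to-callable map, which is presumably how RedisLockInterceptor finds the @needs_lock metadata attached in application_sdk/decorators/locks.py (shown below); the interceptor's internals live in application_sdk/interceptors/lock.py and are not included in this diff. A minimal lookup sketch, using only names visible here:

    from typing import Any, Callable, Dict, Optional

    from application_sdk.constants import LOCK_METADATA_KEY


    def lock_config_for(
        activity_name: str, activities_dict: Dict[str, Callable[..., Any]]
    ) -> Optional[dict]:
        """Return the @needs_lock metadata for a registered activity, if any."""
        fn = activities_dict.get(activity_name)
        # needs_lock attaches its config via setattr(func, LOCK_METADATA_KEY, {...})
        return getattr(fn, LOCK_METADATA_KEY, None) if fn is not None else None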
application_sdk/common/error_codes.py

@@ -81,6 +81,18 @@ class ClientError(AtlanError):
     AUTH_CONFIG_ERROR = ErrorCode(
         ErrorComponent.CLIENT, "400", "00", "Authentication configuration error"
     )
+    REDIS_CONNECTION_ERROR = ErrorCode(
+        ErrorComponent.CLIENT, "503", "00", "Redis connection failed"
+    )
+    REDIS_TIMEOUT_ERROR = ErrorCode(
+        ErrorComponent.CLIENT, "408", "00", "Redis operation timeout"
+    )
+    REDIS_AUTH_ERROR = ErrorCode(
+        ErrorComponent.CLIENT, "401", "05", "Redis authentication failed"
+    )
+    REDIS_PROTOCOL_ERROR = ErrorCode(
+        ErrorComponent.CLIENT, "502", "00", "Redis protocol error"
+    )
 
 
 class ApiError(AtlanError):
@@ -174,7 +186,7 @@ class IOError(AtlanError):
     INPUT_PROCESSING_ERROR = ErrorCode(
         ErrorComponent.IO, "500", "01", "Input processing error"
     )
-    SQL_QUERY_ERROR = ErrorCode(ErrorComponent.IO, "400", "00", "SQL query error")
+    SQL_QUERY_ERROR = ErrorCode(ErrorComponent.IO, "400", "01", "SQL query error")
     SQL_QUERY_BATCH_ERROR = ErrorCode(
         ErrorComponent.IO, "500", "02", "SQL query batch error"
     )
@@ -268,10 +280,10 @@ class CommonError(AtlanError):
         ErrorComponent.COMMON, "400", "01", "Query preparation error"
     )
     FILTER_PREPARATION_ERROR = ErrorCode(
-        ErrorComponent.COMMON, "400", "00", "Filter preparation error"
+        ErrorComponent.COMMON, "400", "02", "Filter preparation error"
     )
     CREDENTIALS_PARSE_ERROR = ErrorCode(
-        ErrorComponent.COMMON, "400", "02", "Credentials parse error"
+        ErrorComponent.COMMON, "400", "03", "Credentials parse error"
     )
     CREDENTIALS_RESOLUTION_ERROR = ErrorCode(
         ErrorComponent.COMMON, "401", "03", "Credentials resolution error"
@@ -367,3 +379,12 @@ class ActivityError(AtlanError):
     ATLAN_UPLOAD_ERROR = ErrorCode(
         ErrorComponent.ACTIVITY, "500", "08", "Atlan upload error"
     )
+    LOCK_ACQUISITION_ERROR = ErrorCode(
+        ErrorComponent.ACTIVITY, "503", "01", "Distributed lock acquisition error"
+    )
+    LOCK_RELEASE_ERROR = ErrorCode(
+        ErrorComponent.ACTIVITY, "500", "09", "Distributed lock release error"
+    )
+    LOCK_TIMEOUT_ERROR = ErrorCode(
+        ErrorComponent.ACTIVITY, "408", "00", "Lock acquisition timeout"
+    )
application_sdk/constants.py

@@ -146,7 +146,6 @@ DEPLOYMENT_SECRET_STORE_NAME = os.getenv(
     "DEPLOYMENT_SECRET_STORE_NAME", "deployment-secret-store"
 )
 
-
 # Logger Constants
 #: Log level for the application (DEBUG, INFO, WARNING, ERROR, CRITICAL)
 LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
@@ -230,3 +229,21 @@ ATLAN_BASE_URL = os.getenv("ATLAN_BASE_URL")
 ATLAN_API_KEY = os.getenv("ATLAN_API_KEY")
 ATLAN_CLIENT_ID = os.getenv("CLIENT_ID")
 ATLAN_CLIENT_SECRET = os.getenv("CLIENT_SECRET")
+# Lock Configuration
+LOCK_METADATA_KEY = "__lock_metadata__"
+
+# Redis Lock Configuration
+#: Redis host for direct connection (when not using Sentinel)
+REDIS_HOST = os.getenv("REDIS_HOST", "")
+#: Redis port for direct connection (when not using Sentinel)
+REDIS_PORT = os.getenv("REDIS_PORT", "")
+#: Redis password (required for authenticated Redis instances)
+REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")
+#: Redis Sentinel service name (default: mymaster)
+REDIS_SENTINEL_SERVICE_NAME = os.getenv("REDIS_SENTINEL_SERVICE_NAME", "mymaster")
+#: Redis Sentinel hosts (comma-separated host:port pairs)
+REDIS_SENTINEL_HOSTS = os.getenv("REDIS_SENTINEL_HOSTS", "")
+#: Whether to enable strict locking
+IS_LOCKING_DISABLED = os.getenv("IS_LOCKING_DISABLED", "true").lower() == "true"
+#: Retry interval for lock acquisition
+LOCK_RETRY_INTERVAL = int(os.getenv("LOCK_RETRY_INTERVAL", "5"))
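
Locking stays off unless IS_LOCKING_DISABLED is explicitly set to something other than "true". A minimal opt-in sketch (not part of the diff; host and password values are placeholders), relying only on the variable names and parsing shown in this hunk:

    import os

    # Set these before application_sdk.constants is first imported,
    # since the module reads the environment at import time.
    os.environ["IS_LOCKING_DISABLED"] = "false"
    os.environ["REDIS_SENTINEL_HOSTS"] = "sentinel-0:26379,sentinel-1:26379"
    os.environ["REDIS_SENTINEL_SERVICE_NAME"] = "mymaster"
    os.environ["REDIS_PASSWORD"] = "example-password"
    os.environ["LOCK_RETRY_INTERVAL"] = "5"  # seconds between lock acquisition retries

    from application_sdk.constants import IS_LOCKING_DISABLED

    assert IS_LOCKING_DISABLED is False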
application_sdk/decorators/__init__.py: File without changes
application_sdk/decorators/locks.py (new file)

@@ -0,0 +1,42 @@
+from typing import Any, Callable, Optional
+
+from application_sdk.constants import LOCK_METADATA_KEY
+from application_sdk.observability.logger_adaptor import get_logger
+
+logger = get_logger(__name__)
+
+
+def needs_lock(max_locks: int = 5, lock_name: Optional[str] = None):
+    """Decorator to mark activities that require distributed locking.
+
+    This decorator attaches lock configuration directly to the activity
+    definition that will be used by the workflow interceptor to acquire
+    locks before executing activities.
+
+    Note:
+        Activities decorated with ``needs_lock`` must be called with
+        ``schedule_to_close_timeout`` to ensure proper lock TTL calculation
+        that covers retries.
+
+    Args:
+        max_locks (int): Maximum number of concurrent locks allowed.
+        lock_name (str | None): Optional custom name for the lock (defaults to activity name).
+
+    Raises:
+        WorkflowError: If activity is called without ``schedule_to_close_timeout``.
+    """
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        # Store lock metadata directly on the function object
+        metadata = {
+            "is_needs_lock": True,
+            "max_locks": max_locks,
+            "lock_name": lock_name or func.__name__,
+        }
+
+        # Attach metadata to the function
+        setattr(func, LOCK_METADATA_KEY, metadata)
+
+        return func
+
+    return decorator
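
A usage sketch for the new decorator (not part of the diff; the activity, workflow, and timeout are hypothetical, and the placement of needs_lock relative to activity.defn is an assumption). needs_lock only attaches metadata; per the docstring above, the call site must pass schedule_to_close_timeout so a lock TTL covering retries can be derived:

    from datetime import timedelta

    from temporalio import activity, workflow

    from application_sdk.decorators.locks import needs_lock


    @needs_lock(max_locks=2, lock_name="warehouse-scan")  # at most 2 concurrent holders
    @activity.defn
    async def scan_warehouse(batch_id: str) -> None:
        ...  # work whose concurrency must be bounded across workers


    @workflow.defn
    class ScanWorkflow:
        @workflow.run
        async def run(self, batch_id: str) -> None:
            await workflow.execute_activity(
                scan_warehouse,
                batch_id,
                # Required for @needs_lock activities (see docstring above).
                schedule_to_close_timeout=timedelta(minutes=10),
            )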
application_sdk/handlers/base.py

@@ -44,7 +44,24 @@ class BaseHandler(HandlerInterface):
 
     # The following methods are inherited from HandlerInterface and should be implemented
     # by subclasses to handle calls from their respective FastAPI endpoints:
-    #
     # - test_auth(**kwargs) -> bool: Called by /workflow/v1/auth endpoint
     # - preflight_check(**kwargs) -> Any: Called by /workflow/v1/check endpoint
     # - fetch_metadata(**kwargs) -> Any: Called by /workflow/v1/metadata endpoint
+
+    async def test_auth(self, **kwargs: Any) -> bool:
+        """
+        Test the authentication of the handler.
+        """
+        raise NotImplementedError("test_auth is not implemented")
+
+    async def preflight_check(self, **kwargs: Any) -> Any:
+        """
+        Check the preflight of the handler.
+        """
+        raise NotImplementedError("preflight_check is not implemented")
+
+    async def fetch_metadata(self, **kwargs: Any) -> Any:
+        """
+        Fetch the metadata of the handler.
+        """
+        raise NotImplementedError("fetch_metadata is not implemented")
application_sdk/inputs/json.py

@@ -1,10 +1,11 @@
 import os
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional, Union
 
+from application_sdk.activities.common.utils import get_object_store_prefix
 from application_sdk.common.error_codes import IOError
 from application_sdk.inputs import Input
-from application_sdk.inputs.objectstore import ObjectStoreInput
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.services.objectstore import ObjectStore
 
 if TYPE_CHECKING:
     import daft
@@ -51,9 +52,10 @@ class JsonInput(Input):
             if self.download_file_prefix is not None and not os.path.exists(
                 os.path.join(self.path, file_name)
             ):
-                ObjectStoreInput.download_file_from_object_store(
-                    os.path.join(self.download_file_prefix, file_name),
-                    os.path.join(self.path, file_name),
+                destination_file_path = os.path.join(self.path, file_name)
+                await ObjectStore.download_file(
+                    source=get_object_store_prefix(destination_file_path),
+                    destination=destination_file_path,
                 )
         except IOError as e:
             logger.error(
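
The same migration pattern repeats across the inputs: the removed ObjectStoreInput class gives way to an awaitable ObjectStore service, with object store keys derived from local paths via get_object_store_prefix. A condensed sketch of the pattern (not part of the diff; the wrapper function and its file-vs-directory heuristic are hypothetical, while the ObjectStore calls are the ones visible here):

    import os

    from application_sdk.activities.common.utils import get_object_store_prefix
    from application_sdk.services.objectstore import ObjectStore


    async def ensure_local(path: str) -> str:
        """Download path from the object store if it is not already present."""
        if not os.path.exists(path):
            if os.path.splitext(path)[1]:  # has an extension: treat as a single file
                await ObjectStore.download_file(
                    source=get_object_store_prefix(path), destination=path
                )
            else:  # treat as a directory prefix
                await ObjectStore.download_prefix(
                    source=get_object_store_prefix(path), destination=path
                )
        return path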
application_sdk/inputs/parquet.py

@@ -2,9 +2,10 @@ import glob
 import os
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional, Union
 
+from application_sdk.activities.common.utils import get_object_store_prefix
 from application_sdk.inputs import Input
-from application_sdk.inputs.objectstore import ObjectStoreInput
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.services.objectstore import ObjectStore
 
 logger = get_logger(__name__)
 
@@ -42,37 +43,39 @@ class ParquetInput(Input):
         self.input_prefix = input_prefix
         self.file_names = file_names
 
-    async def download_files(self, local_file_path: str) -> Optional[str]:
+    async def download_files(self, local_path: str) -> Optional[str]:
         """Read a file from the object store.
 
         Args:
-            local_file_path (str): Path to the local file in the temp directory.
+            local_path (str): Path to the local data in the temp directory.
 
         Returns:
             Optional[str]: Path to the downloaded local file.
         """
         # if the path is a directory, then check if the directory has any parquet files
         parquet_files = []
-        if os.path.isdir(local_file_path):
-            parquet_files = glob.glob(os.path.join(local_file_path, "*.parquet"))
+        if os.path.isdir(local_path):
+            parquet_files = glob.glob(os.path.join(local_path, "*.parquet"))
         else:
-            parquet_files = glob.glob(local_file_path)
+            parquet_files = glob.glob(local_path)
         if not parquet_files:
             if self.input_prefix:
                 logger.info(
-                    f"Reading file from object store: {local_file_path} from {self.input_prefix}"
+                    f"Reading file from object store: {local_path} from {self.input_prefix}"
                 )
-                if os.path.isdir(local_file_path):
-                    ObjectStoreInput.download_files_from_object_store(
-                        self.input_prefix, local_file_path
+                if os.path.isdir(local_path):
+                    await ObjectStore.download_prefix(
+                        source=get_object_store_prefix(local_path),
+                        destination=local_path,
                     )
                 else:
-                    ObjectStoreInput.download_file_from_object_store(
-                        self.input_prefix, local_file_path
+                    await ObjectStore.download_file(
+                        source=get_object_store_prefix(local_path),
+                        destination=local_path,
                     )
             else:
                 raise ValueError(
-                    f"No parquet files found in {local_file_path} and no input prefix provided"
+                    f"No parquet files found in {local_path} and no input prefix provided"
                 )
 
     async def get_dataframe(self) -> "pd.DataFrame":
application_sdk/interceptors/__init__.py: File without changes