atlan-application-sdk 0.1.1rc40__py3-none-any.whl → 0.1.1rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. application_sdk/activities/common/utils.py +78 -4
  2. application_sdk/activities/metadata_extraction/sql.py +400 -27
  3. application_sdk/application/__init__.py +2 -0
  4. application_sdk/application/metadata_extraction/sql.py +3 -0
  5. application_sdk/clients/models.py +42 -0
  6. application_sdk/clients/sql.py +17 -13
  7. application_sdk/common/aws_utils.py +259 -11
  8. application_sdk/common/utils.py +145 -9
  9. application_sdk/handlers/__init__.py +8 -1
  10. application_sdk/handlers/sql.py +63 -22
  11. application_sdk/inputs/__init__.py +98 -2
  12. application_sdk/inputs/json.py +59 -87
  13. application_sdk/inputs/parquet.py +173 -94
  14. application_sdk/observability/decorators/observability_decorator.py +36 -22
  15. application_sdk/server/fastapi/__init__.py +59 -3
  16. application_sdk/server/fastapi/models.py +27 -0
  17. application_sdk/test_utils/hypothesis/strategies/inputs/json_input.py +10 -5
  18. application_sdk/test_utils/hypothesis/strategies/inputs/parquet_input.py +9 -4
  19. application_sdk/version.py +1 -1
  20. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/METADATA +1 -1
  21. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/RECORD +24 -23
  22. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/WHEEL +0 -0
  23. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/licenses/LICENSE +0 -0
  24. {atlan_application_sdk-0.1.1rc40.dist-info → atlan_application_sdk-0.1.1rc42.dist-info}/licenses/NOTICE +0 -0
application_sdk/handlers/sql.py

@@ -56,9 +56,13 @@ class BaseSQLHandler(HandlerInterface):
     schema_alias_key: str = SQLConstants.SCHEMA_ALIAS_KEY.value
     database_result_key: str = SQLConstants.DATABASE_RESULT_KEY.value
     schema_result_key: str = SQLConstants.SCHEMA_RESULT_KEY.value
+    multidb: bool = False
 
-    def __init__(self, sql_client: BaseSQLClient | None = None):
+    def __init__(
+        self, sql_client: BaseSQLClient | None = None, multidb: Optional[bool] = False
+    ):
         self.sql_client = sql_client
+        self.multidb = multidb
 
     async def load(self, credentials: Dict[str, Any]) -> None:
         """
@@ -294,35 +298,26 @@ class BaseSQLHandler(HandlerInterface):
                 return False, f"{db}.{sch} schema"
         return True, ""
 
-    async def tables_check(
-        self,
-        payload: Dict[str, Any],
-    ) -> Dict[str, Any]:
+    async def tables_check(self, payload: Dict[str, Any]) -> Dict[str, Any]:
         """
         Method to check the count of tables
         """
         logger.info("Starting tables check")
-        query = prepare_query(
-            query=self.tables_check_sql,
-            workflow_args=payload,
-            temp_table_regex_sql=self.extract_temp_table_regex_table_sql,
-        )
-        if not query:
-            raise ValueError("tables_check_sql is not defined")
-        sql_input = SQLQueryInput(
-            engine=self.sql_client.engine, query=query, chunk_size=None
-        )
-        sql_input = await sql_input.get_dataframe()
-        try:
-            result = 0
-            for row in sql_input.to_dict(orient="records"):
-                result += row["count"]
+
+        def _sum_counts_from_records(records_iter) -> int:
+            total = 0
+            for row in records_iter:
+                total += row["count"]
+            return total
+
+        def _build_success(total: int) -> Dict[str, Any]:
             return {
                 "success": True,
-                "successMessage": f"Tables check successful. Table count: {result}",
+                "successMessage": f"Tables check successful. Table count: {total}",
                 "failureMessage": "",
             }
-        except Exception as exc:
+
+        def _build_failure(exc: Exception) -> Dict[str, Any]:
             logger.error("Error during tables check", exc_info=True)
             return {
                 "success": False,
@@ -331,6 +326,52 @@ class BaseSQLHandler(HandlerInterface):
                 "error": str(exc),
             }
 
+        if self.multidb:
+            try:
+                from application_sdk.activities.metadata_extraction.sql import (
+                    BaseSQLMetadataExtractionActivities,
+                )
+
+                # Use the base query executor in multidb mode to get concatenated df
+                activities = BaseSQLMetadataExtractionActivities()
+                activities.multidb = True
+                concatenated_df = await activities.query_executor(
+                    sql_engine=self.sql_client.engine if self.sql_client else None,
+                    sql_query=self.tables_check_sql,
+                    workflow_args=payload,
+                    output_suffix="raw/table",
+                    typename="table",
+                    write_to_file=False,
+                    concatenate=True,
+                    return_dataframe=True,
+                    sql_client=self.sql_client,
+                )
+
+                if concatenated_df is None:
+                    return _build_success(0)
+
+                total = int(concatenated_df["count"].sum())  # type: ignore[index]
+                return _build_success(total)
+            except Exception as exc:
+                return _build_failure(exc)
+        else:
+            query = prepare_query(
+                query=self.tables_check_sql,
+                workflow_args=payload,
+                temp_table_regex_sql=self.extract_temp_table_regex_table_sql,
+            )
+            if not query:
+                raise ValueError("tables_check_sql is not defined")
+            sql_input = SQLQueryInput(
+                engine=self.sql_client.engine, query=query, chunk_size=None
+            )
+            sql_input = await sql_input.get_dataframe()
+            try:
+                total = _sum_counts_from_records(sql_input.to_dict(orient="records"))
+                return _build_success(total)
+            except Exception as exc:
+                return _build_failure(exc)
+
     async def check_client_version(self) -> Dict[str, Any]:
         """
         Check if the client version meets the minimum required version.
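Regardless of which branch runs, tables_check now funnels its result through the shared _build_success/_build_failure helpers, so callers always receive the same dictionary shape. A hedged sketch of a caller (the payload keys are connector-specific and not shown here):

    from typing import Any, Dict

    from application_sdk.handlers.sql import BaseSQLHandler

    async def run_tables_check(handler: BaseSQLHandler, payload: Dict[str, Any]) -> Dict[str, Any]:
        # `payload` carries the workflow arguments (filters, metadata, etc.) that
        # prepare_query / query_executor expect; its exact keys depend on the connector.
        result = await handler.tables_check(payload)
        if result["success"]:
            print(result["successMessage"])  # e.g. "Tables check successful. Table count: 42"
        else:
            print(result.get("failureMessage", ""), result.get("error"))
        return result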
application_sdk/inputs/__init__.py

@@ -1,7 +1,15 @@
+import os
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, Union
-
+from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Union
+
+from application_sdk.activities.common.utils import (
+    find_local_files_by_extension,
+    get_object_store_prefix,
+)
+from application_sdk.common.error_codes import IOError
+from application_sdk.constants import TEMPORARY_PATH
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.services.objectstore import ObjectStore
 
 logger = get_logger(__name__)
 
@@ -15,6 +23,94 @@ class Input(ABC):
     Abstract base class for input data sources.
     """
 
+    async def download_files(self) -> List[str]:
+        """Download files from object store if not available locally.
+
+        Flow:
+        1. Check if files exist locally at self.path
+        2. If not, try to download from object store
+        3. Filter by self.file_names if provided
+        4. Return list of file paths for logging purposes
+
+        Returns:
+            List[str]: List of file paths
+
+        Raises:
+            AttributeError: When the input class doesn't support file operations or _extension
+            IOError: When no files found locally or in object store
+        """
+        # Step 1: Check if files exist locally
+        local_files = find_local_files_by_extension(
+            self.path, self._EXTENSION, self.file_names
+        )
+        if local_files:
+            logger.info(
+                f"Found {len(local_files)} {self._EXTENSION} files locally at: {self.path}"
+            )
+            return local_files
+
+        # Step 2: Try to download from object store
+        logger.info(
+            f"No local {self._EXTENSION} files found at {self.path}, checking object store..."
+        )
+
+        try:
+            # Determine what to download based on path type and filters
+            downloaded_paths = []
+
+            if self.path.endswith(self._EXTENSION):
+                # Single file case (file_names validation already ensures this is valid)
+                source_path = get_object_store_prefix(self.path)
+                destination_path = os.path.join(TEMPORARY_PATH, source_path)
+                await ObjectStore.download_file(
+                    source=source_path,
+                    destination=destination_path,
+                )
+                downloaded_paths.append(destination_path)
+
+            elif self.file_names:
+                # Directory with specific files - download each file individually
+                for file_name in self.file_names:
+                    file_path = os.path.join(self.path, file_name)
+                    source_path = get_object_store_prefix(file_path)
+                    destination_path = os.path.join(TEMPORARY_PATH, source_path)
+                    await ObjectStore.download_file(
+                        source=source_path,
+                        destination=destination_path,
+                    )
+                    downloaded_paths.append(destination_path)
+            else:
+                # Download entire directory
+                source_path = get_object_store_prefix(self.path)
+                destination_path = os.path.join(TEMPORARY_PATH, source_path)
+                await ObjectStore.download_prefix(
+                    source=source_path,
+                    destination=destination_path,
+                )
+                # Find the actual files in the downloaded directory
+                found_files = find_local_files_by_extension(
+                    destination_path, self._EXTENSION, getattr(self, "file_names", None)
+                )
+                downloaded_paths.extend(found_files)
+
+            # Check results
+            if downloaded_paths:
+                logger.info(
+                    f"Successfully downloaded {len(downloaded_paths)} {self._EXTENSION} files from object store"
+                )
+                return downloaded_paths
+            else:
+                raise IOError(
+                    f"{IOError.OBJECT_STORE_READ_ERROR}: Downloaded from object store but no {self._EXTENSION} files found"
+                )
+
+        except Exception as e:
+            logger.error(f"Failed to download from object store: {str(e)}")
+            raise IOError(
+                f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: No {self._EXTENSION} files found locally at '{self.path}' and failed to download from object store. "
+                f"Error: {str(e)}"
+            )
+
     @abstractmethod
     async def get_batched_dataframe(
         self,
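The helper is deliberately generic: it only relies on the concrete input exposing path, file_names, and a class-level _EXTENSION, which is why JsonInput below can drop its own download logic. A hypothetical subclass sketch, for illustration only:

    from typing import List, Optional

    from application_sdk.inputs import Input

    class CsvInput(Input):  # hypothetical subclass, not part of the SDK
        _EXTENSION = ".csv"

        def __init__(self, path: str, file_names: Optional[List[str]] = None):
            self.path = path              # local path or object store path
            self.file_names = file_names  # optional per-file filter
        # ...the abstract get_*_dataframe methods would still need real implementations

    async def resolve_files(inp: Input) -> List[str]:
        # Returns local file paths, downloading from the object store only when
        # nothing matching _EXTENSION is found under inp.path.
        return await inp.download_files()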
application_sdk/inputs/json.py

@@ -1,11 +1,7 @@
-import os
 from typing import TYPE_CHECKING, AsyncIterator, Iterator, List, Optional, Union
 
-from application_sdk.activities.common.utils import get_object_store_prefix
-from application_sdk.common.error_codes import IOError
 from application_sdk.inputs import Input
 from application_sdk.observability.logger_adaptor import get_logger
-from application_sdk.services.objectstore import ObjectStore
 
 if TYPE_CHECKING:
     import daft
@@ -15,56 +11,43 @@ logger = get_logger(__name__)
 
 
 class JsonInput(Input):
-    path: str
-    chunk_size: Optional[int]
-    file_names: Optional[List[str]]
-    download_file_prefix: Optional[str]
+    """
+    JSON Input class to read data from JSON files using daft and pandas.
+    Supports reading both single files and directories containing multiple JSON files.
+    """
+
+    _EXTENSION = ".json"
 
     def __init__(
         self,
         path: str,
         file_names: Optional[List[str]] = None,
-        download_file_prefix: Optional[str] = None,
-        chunk_size: Optional[int] = None,
+        chunk_size: int = 100000,
     ):
         """Initialize the JsonInput class.
 
         Args:
-            path (str): The path to the input directory.
-            file_names (Optional[List[str]]): The list of files to read.
-            download_file_prefix (Optional[str]): The prefix path in object store.
-            chunk_size (Optional[int]): The chunk size to read the data. If None, uses config value.
+            path (str): Path to JSON file or directory containing JSON files.
+                It accepts both types of paths:
+                local path or object store path
+                Wildcards are not supported.
+            file_names (Optional[List[str]]): List of specific file names to read. Defaults to None.
+            chunk_size (int): Number of rows per batch. Defaults to 100000.
+
+        Raises:
+            ValueError: When path is not provided or when single file path is combined with file_names
         """
+
+        # Validate that single file path and file_names are not both specified
+        if path.endswith(self._EXTENSION) and file_names:
+            raise ValueError(
+                f"Cannot specify both a single file path ('{path}') and file_names filter. "
+                f"Either provide a directory path with file_names, or specify the exact file path without file_names."
+            )
+
         self.path = path
-        # If chunk_size is provided, use it; otherwise default to 100,000 rows per batch
-        self.chunk_size = chunk_size if chunk_size is not None else 100000
+        self.chunk_size = chunk_size
         self.file_names = file_names
-        self.download_file_prefix = download_file_prefix
-
-    async def download_files(self):
-        """Download the files from the object store to the local path"""
-        if not self.file_names:
-            logger.debug("No files to download")
-            return
-
-        for file_name in self.file_names or []:
-            try:
-                if self.download_file_prefix is not None and not os.path.exists(
-                    os.path.join(self.path, file_name)
-                ):
-                    destination_file_path = os.path.join(self.path, file_name)
-                    await ObjectStore.download_file(
-                        source=get_object_store_prefix(destination_file_path),
-                        destination=destination_file_path,
-                    )
-            except IOError as e:
-                logger.error(
-                    f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: Error downloading file {file_name}: {str(e)}",
-                    error_code=IOError.OBJECT_STORE_DOWNLOAD_ERROR.code,
-                )
-                raise IOError(
-                    f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: Error downloading file {file_name}: {str(e)}"
-                )
 
     async def get_batched_dataframe(
         self,
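With the class attributes gone, construction is now validated up front: a directory path may be combined with file_names, but a single-file path may not. Usage sketch (the paths are placeholders):

    from application_sdk.inputs.json import JsonInput

    # Directory of JSON files, optionally narrowed to specific file names.
    filtered_input = JsonInput(path="./local/tmp/raw/table", file_names=["chunk-0.json"])

    # A single-file path is also fine on its own...
    single_input = JsonInput(path="./local/tmp/raw/table/chunk-0.json")

    # ...but combining it with file_names raises ValueError:
    # JsonInput(path="./local/tmp/raw/table/chunk-0.json", file_names=["other.json"])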
@@ -76,22 +59,20 @@ class JsonInput(Input):
         try:
             import pandas as pd
 
-            await self.download_files()
+            # Ensure files are available (local or downloaded)
+            json_files = await self.download_files()
+            logger.info(f"Reading {len(json_files)} JSON files in batches")
 
-            for file_name in self.file_names or []:
-                file_path = os.path.join(self.path, file_name)
+            for json_file in json_files:
                 json_reader_obj = pd.read_json(
-                    file_path,
+                    json_file,
                     chunksize=self.chunk_size,
                     lines=True,
                 )
                 for chunk in json_reader_obj:
                     yield chunk
-        except IOError as e:
-            logger.error(
-                f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: Error reading batched data from JSON: {str(e)}",
-                error_code=IOError.OBJECT_STORE_DOWNLOAD_ERROR.code,
-            )
+        except Exception as e:
+            logger.error(f"Error reading batched data from JSON: {str(e)}")
             raise
 
     async def get_dataframe(self) -> "pd.DataFrame":
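Since the batched reader now iterates over whatever download_files() resolves, a caller no longer needs to pass file_names for local data. A short sketch, assuming get_batched_dataframe is invoked with its default arguments:

    from application_sdk.inputs.json import JsonInput

    async def count_rows(path: str) -> int:
        json_input = JsonInput(path=path, chunk_size=50_000)
        total = 0
        # Files are resolved (and downloaded from the object store if needed)
        # on the first read; each chunk is a pandas DataFrame.
        async for chunk in json_input.get_batched_dataframe():
            total += len(chunk)
        return total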
@@ -102,21 +83,17 @@ class JsonInput(Input):
         try:
             import pandas as pd
 
-            dataframes = []
-            await self.download_files()
-            for file_name in self.file_names or []:
-                dataframes.append(
-                    pd.read_json(
-                        os.path.join(self.path, file_name),
-                        lines=True,
-                    )
-                )
-            return pd.concat(dataframes, ignore_index=True)
-        except IOError as e:
-            logger.error(
-                f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: Error reading data from JSON: {str(e)}",
-                error_code=IOError.OBJECT_STORE_DOWNLOAD_ERROR.code,
+            # Ensure files are available (local or downloaded)
+            json_files = await self.download_files()
+            logger.info(f"Reading {len(json_files)} JSON files as pandas dataframe")
+
+            return pd.concat(
+                (pd.read_json(json_file, lines=True) for json_file in json_files),
+                ignore_index=True,
             )
+
+        except Exception as e:
+            logger.error(f"Error reading data from JSON: {str(e)}")
             raise
 
     async def get_batched_daft_dataframe(
@@ -129,18 +106,15 @@ class JsonInput(Input):
         try:
             import daft
 
-            await self.download_files()
-            for file_name in self.file_names or []:
-                json_reader_obj = daft.read_json(
-                    path=os.path.join(self.path, file_name),
-                    _chunk_size=self.chunk_size,
-                )
-                yield json_reader_obj
-        except IOError as e:
-            logger.error(
-                f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: Error reading batched data from JSON: {str(e)}",
-                error_code=IOError.OBJECT_STORE_DOWNLOAD_ERROR.code,
-            )
+            # Ensure files are available (local or downloaded)
+            json_files = await self.download_files()
+            logger.info(f"Reading {len(json_files)} JSON files as daft batches")
+
+            # Yield each discovered file as separate batch with chunking
+            for json_file in json_files:
+                yield daft.read_json(json_file, _chunk_size=self.chunk_size)
+        except Exception as e:
+            logger.error(f"Error reading batched data from JSON using daft: {str(e)}")
             raise
 
     async def get_daft_dataframe(self) -> "daft.DataFrame":  # noqa: F821
@@ -151,14 +125,12 @@ class JsonInput(Input):
         try:
             import daft
 
-            await self.download_files()
-            if not self.file_names or len(self.file_names) == 0:
-                raise ValueError("No files to read")
-            directory = os.path.join(self.path, self.file_names[0].split("/")[0])
-            return daft.read_json(path=f"{directory}/*.json")
-        except IOError as e:
-            logger.error(
-                f"{IOError.OBJECT_STORE_DOWNLOAD_ERROR}: Error reading data from JSON using daft: {str(e)}",
-                error_code=IOError.OBJECT_STORE_DOWNLOAD_ERROR.code,
-            )
+            # Ensure files are available (local or downloaded)
+            json_files = await self.download_files()
+            logger.info(f"Reading {len(json_files)} JSON files with daft")
+
+            # Use the discovered/downloaded files directly
+            return daft.read_json(json_files)
+        except Exception as e:
+            logger.error(f"Error reading data from JSON using daft: {str(e)}")
             raise
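End to end, the daft path now mirrors the pandas one: resolve files once, then hand the concrete list to daft instead of globbing a directory. A final usage sketch (the default path is a placeholder):

    from application_sdk.inputs.json import JsonInput

    async def load_tables(path: str = "./local/tmp/raw/table") -> "daft.DataFrame":
        json_input = JsonInput(path=path)
        # download_files() checks the local path first and falls back to the
        # object store, so the same call works for either kind of path.
        return await json_input.get_daft_dataframe()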