atlan-application-sdk 1.1.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. application_sdk/activities/common/sql_utils.py +308 -0
  2. application_sdk/activities/common/utils.py +1 -45
  3. application_sdk/activities/metadata_extraction/sql.py +110 -353
  4. application_sdk/activities/query_extraction/sql.py +12 -11
  5. application_sdk/application/__init__.py +1 -1
  6. application_sdk/clients/sql.py +167 -1
  7. application_sdk/clients/temporal.py +6 -6
  8. application_sdk/common/types.py +8 -0
  9. application_sdk/common/utils.py +1 -8
  10. application_sdk/constants.py +1 -1
  11. application_sdk/handlers/sql.py +10 -25
  12. application_sdk/interceptors/events.py +1 -1
  13. application_sdk/io/__init__.py +654 -0
  14. application_sdk/io/json.py +429 -0
  15. application_sdk/{outputs → io}/parquet.py +358 -47
  16. application_sdk/io/utils.py +307 -0
  17. application_sdk/observability/observability.py +23 -12
  18. application_sdk/server/fastapi/middleware/logmiddleware.py +23 -17
  19. application_sdk/server/fastapi/middleware/metrics.py +27 -24
  20. application_sdk/server/fastapi/models.py +1 -1
  21. application_sdk/server/fastapi/routers/server.py +1 -1
  22. application_sdk/server/fastapi/utils.py +10 -0
  23. application_sdk/services/eventstore.py +4 -4
  24. application_sdk/services/objectstore.py +30 -7
  25. application_sdk/services/secretstore.py +1 -1
  26. application_sdk/test_utils/hypothesis/strategies/outputs/json_output.py +0 -1
  27. application_sdk/test_utils/hypothesis/strategies/server/fastapi/__init__.py +1 -1
  28. application_sdk/version.py +1 -1
  29. application_sdk/worker.py +1 -1
  30. {atlan_application_sdk-1.1.0.dist-info → atlan_application_sdk-2.0.0.dist-info}/METADATA +9 -11
  31. {atlan_application_sdk-1.1.0.dist-info → atlan_application_sdk-2.0.0.dist-info}/RECORD +36 -43
  32. application_sdk/common/dataframe_utils.py +0 -42
  33. application_sdk/events/__init__.py +0 -5
  34. application_sdk/inputs/.cursor/BUGBOT.md +0 -250
  35. application_sdk/inputs/__init__.py +0 -168
  36. application_sdk/inputs/iceberg.py +0 -75
  37. application_sdk/inputs/json.py +0 -136
  38. application_sdk/inputs/parquet.py +0 -272
  39. application_sdk/inputs/sql_query.py +0 -271
  40. application_sdk/outputs/.cursor/BUGBOT.md +0 -295
  41. application_sdk/outputs/__init__.py +0 -445
  42. application_sdk/outputs/iceberg.py +0 -139
  43. application_sdk/outputs/json.py +0 -268
  44. /application_sdk/{events → interceptors}/models.py +0 -0
  45. /application_sdk/{common/dapr_utils.py → services/_utils.py} +0 -0
  46. {atlan_application_sdk-1.1.0.dist-info → atlan_application_sdk-2.0.0.dist-info}/WHEEL +0 -0
  47. {atlan_application_sdk-1.1.0.dist-info → atlan_application_sdk-2.0.0.dist-info}/licenses/LICENSE +0 -0
  48. {atlan_application_sdk-1.1.0.dist-info → atlan_application_sdk-2.0.0.dist-info}/licenses/NOTICE +0 -0
application_sdk/activities/query_extraction/sql.py
@@ -16,9 +16,8 @@ from application_sdk.clients.sql import BaseSQLClient
 from application_sdk.constants import UPSTREAM_OBJECT_STORE_NAME
 from application_sdk.handlers import HandlerInterface
 from application_sdk.handlers.sql import BaseSQLHandler
-from application_sdk.inputs.sql_query import SQLQueryInput
+from application_sdk.io.parquet import ParquetFileWriter
 from application_sdk.observability.logger_adaptor import get_logger
-from application_sdk.outputs.parquet import ParquetOutput
 from application_sdk.services.objectstore import ObjectStore
 from application_sdk.services.secretstore import SecretStore
 from application_sdk.transformers import TransformerInterface
@@ -202,21 +201,23 @@ class SQLQueryExtractionActivities(ActivitiesInterface):

         try:
             state = await self._get_state(workflow_args)
-            sql_input = SQLQueryInput(
-                engine=state.sql_client.engine,
-                query=self.get_formatted_query(self.fetch_queries_sql, workflow_args),
-                chunk_size=None,
+            sql_client = state.sql_client
+            if not sql_client:
+                logger.error("SQL client not initialized")
+                raise ValueError("SQL client not initialized")
+
+            formatted_query = self.get_formatted_query(
+                self.fetch_queries_sql, workflow_args
             )
-            sql_input = await sql_input.get_dataframe()
+            sql_results = await sql_client.get_results(formatted_query)

-            raw_output = ParquetOutput(
-                output_path=workflow_args["output_path"],
-                output_suffix="raw/query",
+            raw_output = ParquetFileWriter(
+                path=os.path.join(workflow_args["output_path"], "raw/query"),
                 chunk_size=workflow_args["miner_args"].get("chunk_size", 100000),
                 start_marker=workflow_args["start_marker"],
                 end_marker=workflow_args["end_marker"],
             )
-            await raw_output.write_dataframe(sql_input)
+            await raw_output.write(sql_results)
             logger.info(
                 f"Query fetch completed, {raw_output.total_record_count} records processed",
             )
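
For SDK consumers, this hunk doubles as the migration recipe: the SQLQueryInput/ParquetOutput pair is gone, and the same work now flows through the client's get_results() and the new application_sdk.io package. A minimal sketch of the 2.0 pattern, where sql_client, query, base_path, start_marker, and end_marker are placeholders for your own workflow values:

    import os

    from application_sdk.io.parquet import ParquetFileWriter

    # query through the client, write through application_sdk.io
    df = await sql_client.get_results(query)
    writer = ParquetFileWriter(
        path=os.path.join(base_path, "raw/query"),
        chunk_size=100000,
        start_marker=start_marker,
        end_marker=end_marker,
    )
    await writer.write(df)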
application_sdk/application/__init__.py
@@ -5,8 +5,8 @@ from application_sdk.activities import ActivitiesInterface
 from application_sdk.clients.base import BaseClient
 from application_sdk.clients.utils import get_workflow_client
 from application_sdk.constants import ENABLE_MCP
-from application_sdk.events.models import EventRegistration
 from application_sdk.handlers.base import BaseHandler
+from application_sdk.interceptors.models import EventRegistration
 from application_sdk.observability.logger_adaptor import get_logger
 from application_sdk.server import ServerInterface
 from application_sdk.server.fastapi import APIServer, HttpWorkflowTrigger
application_sdk/clients/sql.py
@@ -6,8 +6,19 @@ database operations, supporting batch processing and server-side cursors.
 """

 import asyncio
+import concurrent
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+    cast,
+)
 from urllib.parse import quote_plus

 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
@@ -27,6 +38,11 @@ from application_sdk.observability.logger_adaptor import get_logger
 logger = get_logger(__name__)
 activity.logger = logger

+if TYPE_CHECKING:
+    import daft
+    import pandas as pd
+    from sqlalchemy.orm import Session
+

 class BaseSQLClient(ClientInterface):
     """SQL client for database operations.
@@ -53,6 +69,7 @@ class BaseSQLClient(ClientInterface):
         self,
         use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR,
         credentials: Dict[str, Any] = {},
+        chunk_size: int = 5000,
     ):
         """
         Initialize the SQL client.
@@ -64,6 +81,7 @@ class BaseSQLClient(ClientInterface):
         """
         self.use_server_side_cursor = use_server_side_cursor
         self.credentials = credentials
+        self.chunk_size = chunk_size

     async def load(self, credentials: Dict[str, Any]) -> None:
         """Load credentials and prepare engine for lazy connections.
@@ -383,6 +401,154 @@ class BaseSQLClient(ClientInterface):

         logger.info("Query execution completed")

+    def _execute_pandas_query(
+        self, conn, query, chunksize: Optional[int]
+    ) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
+        """Helper function to execute a SQL query using pandas.
+
+        Uses pandas' import_optional_dependency to check whether SQLAlchemy is
+        importable, which determines whether pandas should use the SQLAlchemy
+        connection object and constructs like text(), or fall back to the
+        underlying DBAPI connection. This keeps connectors that do not support
+        the SQLAlchemy connection object, such as the Redshift connector,
+        compatible with the application-sdk.
+
+        Args:
+            conn: Database connection object.
+
+        Returns:
+            Union["pd.DataFrame", Iterator["pd.DataFrame"]]: Query results as DataFrame
+                or iterator of DataFrames if chunked.
+        """
+        import pandas as pd
+        from pandas.compat._optional import import_optional_dependency
+        from sqlalchemy import text
+
+        if import_optional_dependency("sqlalchemy", errors="ignore"):
+            return pd.read_sql_query(text(query), conn, chunksize=chunksize)
+        else:
+            dbapi_conn = getattr(conn, "connection", None)
+            return pd.read_sql_query(query, dbapi_conn, chunksize=chunksize)
+
+    def _read_sql_query(
+        self, session: "Session", query: str, chunksize: Optional[int]
+    ) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
+        """Execute SQL query using the provided session.
+
+        Args:
+            session: SQLAlchemy session for database operations.
+
+        Returns:
+            Union["pd.DataFrame", Iterator["pd.DataFrame"]]: Query results as DataFrame
+                or iterator of DataFrames if chunked.
+        """
+        conn = session.connection()
+        return self._execute_pandas_query(conn, query, chunksize=chunksize)
+
+    def _execute_query_daft(
+        self, query: str, chunksize: Optional[int]
+    ) -> Union["daft.DataFrame", Iterator["daft.DataFrame"]]:
+        """Execute SQL query using the provided engine and daft.
+
+        Returns:
+            Union["daft.DataFrame", Iterator["daft.DataFrame"]]: Query results as DataFrame
+                or iterator of DataFrames if chunked.
+        """
+        # Daft uses ConnectorX to read data from SQL by default for supported connectors.
+        # If a connection string is passed, it will use ConnectorX to read data.
+        # For unsupported connectors, or if an engine is passed directly, it uses SQLAlchemy.
+        import daft
+
+        if not self.engine:
+            raise ValueError("Engine is not initialized. Call load() first.")
+
+        if isinstance(self.engine, str):
+            return daft.read_sql(query, self.engine, infer_schema_length=chunksize)
+        return daft.read_sql(query, self.engine.connect, infer_schema_length=chunksize)
+
+    def _execute_query(
+        self, query: str, chunksize: Optional[int]
+    ) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
+        """Execute SQL query using the provided engine and pandas.
+
+        Returns:
+            Union["pd.DataFrame", Iterator["pd.DataFrame"]]: Query results as DataFrame
+                or iterator of DataFrames if chunked.
+        """
+        if not self.engine:
+            raise ValueError("Engine is not initialized. Call load() first.")
+
+        with self.engine.connect() as conn:
+            return self._execute_pandas_query(conn, query, chunksize)
+
+    async def _execute_async_read_operation(
+        self, query: str, chunksize: Optional[int]
+    ) -> Union["pd.DataFrame", Iterator["pd.DataFrame"]]:
+        """Helper to execute an async read with either an async session or a thread executor."""
+        if isinstance(self.engine, str):
+            raise ValueError("Engine should be an SQLAlchemy engine object")
+
+        from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
+
+        async_session = None
+        if self.engine and isinstance(self.engine, AsyncEngine):
+            from sqlalchemy.orm import sessionmaker
+
+            async_session = sessionmaker(
+                self.engine, expire_on_commit=False, class_=AsyncSession
+            )
+
+        if async_session:
+            async with async_session() as session:
+                return await session.run_sync(
+                    self._read_sql_query, query, chunksize=chunksize
+                )
+        else:
+            # Run the blocking operation in a thread pool
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                return await asyncio.get_event_loop().run_in_executor(
+                    executor, self._execute_query, query, chunksize
+                )
+
+    async def get_batched_results(
+        self,
+        query: str,
+    ) -> Union[AsyncIterator["pd.DataFrame"], Iterator["pd.DataFrame"]]:  # type: ignore
+        """Get query results as batched pandas DataFrames asynchronously.
+
+        Returns:
+            AsyncIterator["pd.DataFrame"]: Async iterator yielding batches of query results.
+
+        Raises:
+            ValueError: If engine is a string instead of SQLAlchemy engine.
+            Exception: If there's an error executing the query.
+        """
+        try:
+            # We cast to Iterator because passing chunk_size guarantees an Iterator return
+            result = await self._execute_async_read_operation(query, self.chunk_size)
+            return cast(Iterator["pd.DataFrame"], result)
+        except Exception as e:
+            logger.error(f"Error reading batched data(pandas) from SQL: {str(e)}")
+
+    async def get_results(self, query: str) -> "pd.DataFrame":
+        """Get all query results as a single pandas DataFrame asynchronously.
+
+        Returns:
+            pd.DataFrame: Query results as a DataFrame.
+
+        Raises:
+            ValueError: If engine is a string instead of SQLAlchemy engine.
+            Exception: If there's an error executing the query.
+        """
+        try:
+            result = await self._execute_async_read_operation(query, None)
+            import pandas as pd
+
+            if isinstance(result, pd.DataFrame):
+                return result
+            raise Exception("Unable to get pandas dataframe from SQL query results")
+
+        except Exception as e:
+            logger.error(f"Error reading data(pandas) from SQL: {str(e)}")
+            raise e
+

 class AsyncBaseSQLClient(BaseSQLClient):
     """Asynchronous SQL client for database operations.
application_sdk/clients/temporal.py
@@ -26,18 +26,18 @@ from application_sdk.constants import (
     WORKFLOW_PORT,
     WORKFLOW_TLS_ENABLED,
 )
-from application_sdk.events.models import (
-    ApplicationEventNames,
-    Event,
-    EventTypes,
-    WorkerTokenRefreshEventData,
-)
 from application_sdk.interceptors.cleanup import CleanupInterceptor, cleanup
 from application_sdk.interceptors.correlation_context import (
     CorrelationContextInterceptor,
 )
 from application_sdk.interceptors.events import EventInterceptor, publish_event
 from application_sdk.interceptors.lock import RedisLockInterceptor
+from application_sdk.interceptors.models import (
+    ApplicationEventNames,
+    Event,
+    EventTypes,
+    WorkerTokenRefreshEventData,
+)
 from application_sdk.observability.logger_adaptor import get_logger
 from application_sdk.services.eventstore import EventStore
 from application_sdk.services.secretstore import SecretStore
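
The event models themselves are unchanged; only their module moved, as the zero-change rename of events/models.py to interceptors/models.py in the file list (entry 44) confirms. Updating a 1.x import is mechanical:

    # 1.x
    # from application_sdk.events.models import Event, EventTypes

    # 2.0
    from application_sdk.interceptors.models import Event, EventTypes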
application_sdk/common/types.py (new file)
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class DataframeType(Enum):
+    """Enumeration of dataframe types."""
+
+    pandas = "pandas"
+    daft = "daft"
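
The new DataframeType enum gives callers a typed name for the two supported backends instead of bare strings. A hedged dispatch sketch (the branching is illustrative; this hunk does not show how the SDK itself consumes the enum):

    from application_sdk.common.types import DataframeType

    def read_frame(path: str, kind: DataframeType = DataframeType.pandas):
        if kind is DataframeType.daft:
            import daft
            return daft.read_parquet(path)
        import pandas as pd
        return pd.read_parquet(path)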
application_sdk/common/utils.py
@@ -20,7 +20,6 @@ from typing import (
 from application_sdk.activities.common.utils import get_object_store_prefix
 from application_sdk.common.error_codes import CommonError
 from application_sdk.constants import TEMPORARY_PATH
-from application_sdk.inputs.sql_query import SQLQueryInput
 from application_sdk.observability.logger_adaptor import get_logger
 from application_sdk.services.objectstore import ObjectStore

@@ -280,13 +279,7 @@ async def get_database_names(
         temp_table_regex_sql=temp_table_regex_sql,
         use_posix_regex=True,
     )
-    # We'll run the query to get all the database names
-    database_sql_input = SQLQueryInput(
-        engine=sql_client.engine,
-        query=prepared_query,  # type: ignore
-        chunk_size=None,
-    )
-    database_dataframe = await database_sql_input.get_dataframe()
+    database_dataframe = await sql_client.get_results(prepared_query)
     database_names = list(database_dataframe["database_name"])
     return database_names

application_sdk/constants.py
@@ -244,7 +244,7 @@ TRACES_FILE_NAME = "traces.parquet"

 # Dapr Sink Configuration
 ENABLE_OBSERVABILITY_DAPR_SINK = (
-    os.getenv("ATLAN_ENABLE_OBSERVABILITY_DAPR_SINK", "false").lower() == "true"
+    os.getenv("ATLAN_ENABLE_OBSERVABILITY_DAPR_SINK", "true").lower() == "true"
 )

 # atlan_client configuration (non ATLAN_ prefix are rooted in pyatlan SDK, to be revisited)
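
Note the behavioral flip here: the Dapr observability sink is now on by default and must be explicitly disabled. Deployments that relied on the old opt-in default can restore it, provided the variable is set before application_sdk.constants is first imported:

    import os

    # must run before any application_sdk import evaluates constants.py
    os.environ["ATLAN_ENABLE_OBSERVABILITY_DAPR_SINK"] = "false"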
application_sdk/handlers/sql.py
@@ -1,4 +1,5 @@
 import asyncio
+import os
 import re
 from enum import Enum
 from typing import Any, Dict, List, Optional, Set, Tuple
@@ -13,7 +14,6 @@ from application_sdk.common.utils import (
 )
 from application_sdk.constants import SQL_QUERIES_PATH, SQL_SERVER_MIN_VERSION
 from application_sdk.handlers import HandlerInterface
-from application_sdk.inputs.sql_query import SQLQueryInput
 from application_sdk.observability.logger_adaptor import get_logger
 from application_sdk.server.fastapi.models import MetadataType

@@ -77,10 +77,7 @@ class BaseSQLHandler(HandlerInterface):
         if self.metadata_sql is None:
             raise ValueError("metadata_sql is not defined")

-        sql_input = SQLQueryInput(
-            engine=self.sql_client.engine, query=self.metadata_sql, chunk_size=None
-        )
-        df = await sql_input.get_dataframe()
+        df = await self.sql_client.get_results(self.metadata_sql)
         result: List[Dict[Any, Any]] = []
         try:
             for row in df.to_dict(orient="records"):
@@ -103,12 +100,7 @@ class BaseSQLHandler(HandlerInterface):
         :raises Exception: If the credentials are invalid.
         """
         try:
-            sql_input = SQLQueryInput(
-                engine=self.sql_client.engine,
-                query=self.test_authentication_sql,
-                chunk_size=None,
-            )
-            df = await sql_input.get_dataframe()
+            df = await self.sql_client.get_results(self.test_authentication_sql)
             df.to_dict(orient="records")
             return True
         except Exception as exc:
@@ -335,16 +327,16 @@ class BaseSQLHandler(HandlerInterface):
         # Use the base query executor in multidb mode to get concatenated df
         activities = BaseSQLMetadataExtractionActivities()
         activities.multidb = True
+        base_output_path = payload.get("output_path", "")
         concatenated_df = await activities.query_executor(
-            sql_engine=self.sql_client.engine if self.sql_client else None,
+            sql_client=self.sql_client,
             sql_query=self.tables_check_sql,
             workflow_args=payload,
-            output_suffix="raw/table",
+            output_path=os.path.join(base_output_path, "raw", "table"),
             typename="table",
             write_to_file=False,
             concatenate=True,
             return_dataframe=True,
-            sql_client=self.sql_client,
         )

         if concatenated_df is None:
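
Callers of query_executor need two coordinated changes: pass the client instead of its engine, and build the full output path themselves instead of supplying a suffix. A hedged sketch based on the keyword arguments visible in this hunk (activities, handler, query, and payload are placeholders; other parameters are assumed unchanged):

    import os

    df = await activities.query_executor(
        sql_client=handler.sql_client,  # 1.x: sql_engine=handler.sql_client.engine
        sql_query=query,
        workflow_args=payload,
        # 1.x: output_suffix="raw/table"
        output_path=os.path.join(payload.get("output_path", ""), "raw", "table"),
        typename="table",
        write_to_file=False,
        concatenate=True,
        return_dataframe=True,
    )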
@@ -362,12 +354,9 @@ class BaseSQLHandler(HandlerInterface):
         )
         if not query:
             raise ValueError("tables_check_sql is not defined")
-        sql_input = SQLQueryInput(
-            engine=self.sql_client.engine, query=query, chunk_size=None
-        )
-        sql_input = await sql_input.get_dataframe()
+        sql_results = await self.sql_client.get_results(query)
         try:
-            total = _sum_counts_from_records(sql_input.to_dict(orient="records"))
+            total = _sum_counts_from_records(sql_results.to_dict(orient="records"))
             return _build_success(total)
         except Exception as exc:
             return _build_failure(exc)
@@ -404,13 +393,9 @@ class BaseSQLHandler(HandlerInterface):

         # If dialect version not available and client_version_sql is defined, use SQL query
         if not client_version and self.client_version_sql:
-            sql_input = await SQLQueryInput(
-                query=self.client_version_sql,
-                engine=self.sql_client.engine,
-                chunk_size=None,
-            ).get_dataframe()
+            sql_results = await self.sql_client.get_results(self.client_version_sql)
             version_string = next(
-                iter(sql_input.to_dict(orient="records")[0].values())
+                iter(sql_results.to_dict(orient="records")[0].values())
             )
             version_match = re.search(r"(\d+\.\d+(?:\.\d+)?)", version_string)
             if version_match:
application_sdk/interceptors/events.py
@@ -12,7 +12,7 @@ from temporalio.worker import (
     WorkflowInterceptorClassInput,
 )

-from application_sdk.events.models import (
+from application_sdk.interceptors.models import (
     ApplicationEventNames,
     Event,
     EventMetadata,