atlan-application-sdk 0.1.1rc43__py3-none-any.whl → 0.1.1rc45__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- application_sdk/activities/metadata_extraction/sql.py +9 -35
- application_sdk/activities/query_extraction/sql.py +0 -2
- application_sdk/application/__init__.py +55 -15
- application_sdk/constants.py +6 -0
- application_sdk/decorators/mcp_tool.py +63 -0
- application_sdk/inputs/parquet.py +15 -3
- application_sdk/inputs/sql_query.py +2 -2
- application_sdk/interceptors/cleanup.py +0 -1
- application_sdk/outputs/__init__.py +176 -12
- application_sdk/outputs/json.py +57 -181
- application_sdk/outputs/parquet.py +230 -161
- application_sdk/server/mcp/__init__.py +4 -0
- application_sdk/server/mcp/models.py +11 -0
- application_sdk/server/mcp/server.py +96 -0
- application_sdk/transformers/query/__init__.py +1 -1
- application_sdk/version.py +1 -1
- application_sdk/workflows/metadata_extraction/sql.py +5 -4
- {atlan_application_sdk-0.1.1rc43.dist-info → atlan_application_sdk-0.1.1rc45.dist-info}/METADATA +1 -1
- {atlan_application_sdk-0.1.1rc43.dist-info → atlan_application_sdk-0.1.1rc45.dist-info}/RECORD +22 -18
- {atlan_application_sdk-0.1.1rc43.dist-info → atlan_application_sdk-0.1.1rc45.dist-info}/WHEEL +0 -0
- {atlan_application_sdk-0.1.1rc43.dist-info → atlan_application_sdk-0.1.1rc45.dist-info}/licenses/LICENSE +0 -0
- {atlan_application_sdk-0.1.1rc43.dist-info → atlan_application_sdk-0.1.1rc45.dist-info}/licenses/NOTICE +0 -0
application_sdk/activities/metadata_extraction/sql.py CHANGED

```diff
@@ -2,10 +2,8 @@ import os
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Dict,
-    Generator,
     Iterator,
     List,
     Optional,
```
```diff
@@ -368,9 +366,9 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 "Output prefix and path must be specified in workflow_args."
             )
         return ParquetOutput(
-            output_prefix=output_prefix,
             output_path=output_path,
             output_suffix=output_suffix,
+            use_consolidation=True,
         )

     def _get_temp_table_regex_sql(self, typename: str) -> str:
```
```diff
@@ -553,7 +551,7 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 )

                 # Execute using helper method
-                success, batched_iter = await self._execute_single_db(
+                success, batched_iterator = await self._execute_single_db(
                     effective_sql_client.engine,
                     prepared_query,
                     parquet_output,
```
```diff
@@ -570,12 +568,12 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 logger.warning(
                     f"Failed to process database '{database_name}': {str(e)}. Skipping to next database."
                 )
-                success, batched_iter = False, None
+                success, batched_iterator = False, None

             if success:
                 successful_databases.append(database_name)
-                if not write_to_file and batched_iter:
-                    dataframe_list.append(batched_iter)
+                if not write_to_file and batched_iterator:
+                    dataframe_list.append(batched_iterator)
             else:
                 failed_databases.append(database_name)

```
```diff
@@ -615,37 +613,13 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):

         try:
             sql_input = SQLQueryInput(engine=sql_engine, query=prepared_query)
-            batched_iter = await sql_input.get_batched_dataframe()
+            batched_iterator = await sql_input.get_batched_dataframe()

             if write_to_file and parquet_output:
-
-                if hasattr(batched_iter, "__anext__"):
-
-                    async def _to_async_gen(
-                        it: AsyncIterator["pd.DataFrame"],
-                    ) -> AsyncGenerator["pd.DataFrame", None]:
-                        async for item in it:
-                            yield item
-
-                    wrapped: AsyncGenerator["pd.DataFrame", None] = _to_async_gen(  # type: ignore
-                        batched_iter  # type: ignore
-                    )
-                    await parquet_output.write_batched_dataframe(wrapped)
-                else:
-
-                    def _to_gen(
-                        it: Iterator["pd.DataFrame"],
-                    ) -> Generator["pd.DataFrame", None, None]:
-                        for item in it:
-                            yield item
-
-                    wrapped_sync: Generator["pd.DataFrame", None, None] = _to_gen(  # type: ignore
-                        batched_iter  # type: ignore
-                    )
-                    await parquet_output.write_batched_dataframe(wrapped_sync)
+                await parquet_output.write_batched_dataframe(batched_iterator)  # type: ignore
                 return True, None

-            return True, batched_iter
+            return True, batched_iterator
         except Exception as e:
             logger.error(
                 f"Error during query execution or output writing: {e}", exc_info=True
```
```diff
@@ -863,10 +837,10 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
             file_names=workflow_args.get("file_names"),
         )
         raw_input = raw_input.get_batched_daft_dataframe()
+
         transformed_output = JsonOutput(
             output_path=output_path,
             output_suffix="transformed",
-            output_prefix=output_prefix,
             typename=typename,
             chunk_start=workflow_args.get("chunk_start"),
         )
```
application_sdk/activities/query_extraction/sql.py CHANGED

```diff
@@ -210,7 +210,6 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
         sql_input = await sql_input.get_dataframe()

         raw_output = ParquetOutput(
-            output_prefix=workflow_args["output_prefix"],
             output_path=workflow_args["output_path"],
             output_suffix="raw/query",
             chunk_size=workflow_args["miner_args"].get("chunk_size", 100000),
```
```diff
@@ -218,7 +217,6 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
             end_marker=workflow_args["end_marker"],
         )
         await raw_output.write_dataframe(sql_input)
-
         logger.info(
             f"Query fetch completed, {raw_output.total_record_count} records processed",
         )
```
application_sdk/application/__init__.py CHANGED

```diff
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 from application_sdk.activities import ActivitiesInterface
 from application_sdk.clients.base import BaseClient
 from application_sdk.clients.utils import get_workflow_client
+from application_sdk.constants import ENABLE_MCP
 from application_sdk.events.models import EventRegistration
 from application_sdk.handlers.base import BaseHandler
 from application_sdk.observability.logger_adaptor import get_logger
```
```diff
@@ -29,7 +30,7 @@ class BaseApplication:
         self,
         name: str,
         server: Optional[ServerInterface] = None,
-        application_manifest: Optional[
+        application_manifest: Optional[Dict[str, Any]] = None,
         client_class: Optional[Type[BaseClient]] = None,
         handler_class: Optional[Type[BaseHandler]] = None,
     ):
```
```diff
@@ -39,6 +40,9 @@ class BaseApplication:
         Args:
             name (str): The name of the application.
             server (ServerInterface): The server class for the application.
+            application_manifest (Optional[Dict[str, Any]]): Application manifest configuration.
+            client_class (Optional[Type[BaseClient]]): Client class for the application.
+            handler_class (Optional[Type[BaseHandler]]): Handler class for the application.
         """
         self.application_name = name

```
```diff
@@ -49,14 +53,21 @@ class BaseApplication:

         self.workflow_client = get_workflow_client(application_name=name)

-        self.application_manifest: Dict[str, Any] = application_manifest
+        self.application_manifest: Optional[Dict[str, Any]] = application_manifest
         self.bootstrap_event_registration()

         self.client_class = client_class or BaseClient
         self.handler_class = handler_class or BaseHandler

+        # MCP configuration
+        self.mcp_server: Optional["MCPServer"] = None
+        if ENABLE_MCP:
+            from application_sdk.server.mcp import MCPServer
+
+            self.mcp_server = MCPServer(application_name=name)
+
     def bootstrap_event_registration(self):
-        self.event_subscriptions = {}
+        self.event_subscriptions: Dict[str, EventWorkflowTrigger] = {}
         if self.application_manifest is None:
             logger.warning("No application manifest found, skipping event registration")
             return
```
```diff
@@ -122,8 +133,8 @@ class BaseApplication:
         ]
         workflow_activities = []
         for workflow_class, activities_class in workflow_and_activities_classes:
-            workflow_activities.extend(
-                workflow_class.get_activities(activities_class())
+            workflow_activities.extend(  # type: ignore
+                workflow_class.get_activities(activities_class())  # type: ignore
             )

         self.worker = Worker(
```
```diff
@@ -134,6 +145,13 @@ class BaseApplication:
             activity_executor=activity_executor,
         )

+        # Register MCP tools if ENABLED_MCP is True and an MCP server is initialized
+        if self.mcp_server:
+            logger.info("Registering MCP tools from workflow and activities classes")
+            await self.mcp_server.register_tools(  # type: ignore
+                workflow_and_activities_classes=workflow_and_activities_classes
+            )
+
     async def start_workflow(self, workflow_args, workflow_class) -> Any:
         """
         Start a new workflow execution.
```
```diff
@@ -147,7 +165,7 @@ class BaseApplication:
         """
         if self.workflow_client is None:
             raise ValueError("Workflow client not initialized")
-        return await self.workflow_client.start_workflow(workflow_args, workflow_class)
+        return await self.workflow_client.start_workflow(workflow_args, workflow_class)  # type: ignore

     async def start_worker(self, daemon: bool = True):
         """
```
```diff
@@ -162,39 +180,61 @@ class BaseApplication:

     async def setup_server(
         self,
-        workflow_class,
+        workflow_class: Type[WorkflowInterface],
         ui_enabled: bool = True,
         has_configmap: bool = False,
     ):
         """
-
+        Set up FastAPI server and automatically mount MCP if enabled.
+
+        Args:
+            workflow_class (WorkflowInterface): The workflow class for the application.
+            ui_enabled (bool): Whether to enable the UI.
+            has_configmap (bool): Whether to enable the configmap.
         """
         if self.workflow_client is None:
             await self.workflow_client.load()

-
+        mcp_http_app: Optional[Any] = None
+        lifespan: Optional[Any] = None
+
+        if self.mcp_server:
+            try:
+                mcp_http_app = await self.mcp_server.get_http_app()
+                lifespan = mcp_http_app.lifespan
+            except Exception as e:
+                logger.warning(f"Failed to get MCP HTTP app: {e}")
+
         self.server = APIServer(
+            lifespan=lifespan,
             workflow_client=self.workflow_client,
             ui_enabled=ui_enabled,
             handler=self.handler_class(client=self.client_class()),
             has_configmap=has_configmap,
         )

+        # Mount MCP at root
+        if mcp_http_app:
+            try:
+                self.server.app.mount("", mcp_http_app)  # Mount at root
+            except Exception as e:
+                logger.warning(f"Failed to mount MCP HTTP app: {e}")
+
+        # Register event-based workflows if any
         if self.event_subscriptions:
             for event_trigger in self.event_subscriptions.values():
-                if event_trigger.workflow_class is None:
+                if event_trigger.workflow_class is None:  # type: ignore
                     raise ValueError(
                         f"Workflow class not set for event trigger {event_trigger.event_id}"
                     )

-                self.server.register_workflow(
-                    workflow_class=event_trigger.workflow_class,
+                self.server.register_workflow(  # type: ignore
+                    workflow_class=event_trigger.workflow_class,  # type: ignore
                     triggers=[event_trigger],
                 )

-        #
-        #
-        self.server.register_workflow(
+        # Register the main workflow (HTTP POST /start endpoint)
+        self.server.register_workflow(  # type: ignore
             workflow_class=workflow_class,
             triggers=[HttpWorkflowTrigger()],
         )
```
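Taken together, these changes make MCP opt-in: the constructor only creates an `MCPServer` when `ENABLE_MCP` is set, worker setup registers decorated activities as tools, and `setup_server` reuses the MCP app's lifespan and mounts it on the existing FastAPI server. A rough wiring sketch; `my_connector.workflows.MyWorkflow` stands in for your own workflow class and is hypothetical:

```python
import asyncio
import os

# ENABLE_MCP is read from the environment at import time (see constants.py below),
# so it must be set before the SDK is imported.
os.environ.setdefault("ENABLE_MCP", "true")

from application_sdk.application import BaseApplication  # noqa: E402
from my_connector.workflows import MyWorkflow  # hypothetical workflow class, for illustration only


async def main() -> None:
    app = BaseApplication(name="my-connector")          # creates the MCP server when ENABLE_MCP is true
    await app.setup_server(workflow_class=MyWorkflow)   # mounts the MCP HTTP app at the server root


asyncio.run(main())
```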
application_sdk/constants.py CHANGED

```diff
@@ -255,3 +255,9 @@ REDIS_SENTINEL_HOSTS = os.getenv("REDIS_SENTINEL_HOSTS", "")
 IS_LOCKING_DISABLED = os.getenv("IS_LOCKING_DISABLED", "true").lower() == "true"
 #: Retry interval for lock acquisition
 LOCK_RETRY_INTERVAL = int(os.getenv("LOCK_RETRY_INTERVAL", "5"))
+
+# MCP Configuration
+#: Flag to indicate if MCP should be enabled or not. Turning this to true will setup an MCP server along
+#: with the application.
+ENABLE_MCP = os.getenv("ENABLE_MCP", "false").lower() == "true"
+MCP_METADATA_KEY = "__atlan_application_sdk_mcp_metadata"
```
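Both new values are module-level constants resolved once at import time, so `ENABLE_MCP` has to be exported before `application_sdk.constants` (or anything that imports it) is loaded. A standalone sketch of the same resolution logic:

```python
import os

# Mirrors the new constant: the env var is read once at import time, and any value
# other than the string "true" (case-insensitive) leaves MCP disabled.
ENABLE_MCP = os.getenv("ENABLE_MCP", "false").lower() == "true"

if __name__ == "__main__":
    print(f"MCP enabled: {ENABLE_MCP}")  # False unless ENABLE_MCP=true is set in the environment
```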
application_sdk/decorators/mcp_tool.py ADDED

```diff
@@ -0,0 +1,63 @@
+"""
+MCP tool decorator for marking activities as MCP tools.
+
+This module provides the @mcp_tool decorator that developers use to mark
+activities for automatic exposure via Model Context Protocol.
+"""
+
+from typing import Any, Callable, Optional
+
+from application_sdk.constants import MCP_METADATA_KEY
+from application_sdk.server.mcp import MCPMetadata
+
+
+def mcp_tool(
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+    visible: bool = True,
+    *args,
+    **kwargs,
+):
+    """
+    Decorator to mark functions as MCP tools.
+
+    Use this decorator to mark any function as an MCP tool. You can additionally use the `visible`
+    parameter to control whether the tool is visible at runtime or not.
+
+    Function parameters that are Pydantic models will be automatically converted into correct JSON schema
+    for the tool specification. This is handled by the underlying FastMCP server implementation.
+
+    Args:
+        name(Optional[str]): The name of the tool. Defaults to the function name.
+        description(Optional[str]): The description of the tool. Defaults to the function docstring.
+        visible(bool): Whether the MCP tool is visible at runtime or not. Defaults to True.
+        *args: Additional arguments to pass to the tool.
+        **kwargs: Additional keyword arguments to pass to the tool.
+
+    Examples:
+        >>> @mcp_tool(name="add_numbers", description="Add two numbers", visible=True)
+        >>> def add_numbers(self, a: int, b: int) -> int:
+        >>>     return a + b
+
+
+        >>> # Use with Temporal activity decorator
+        >>> @activity.defn
+        >>> @mcp_tool(name="get_weather", description="Get the weather for a given city")
+        >>> async def get_weather(self, city: str) -> str:
+        >>>     # ... activity implementation unchanged ...
+    """
+
+    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+        mcp_metadata = MCPMetadata(
+            name=name if name else f.__name__,
+            description=description if description else f.__doc__,
+            visible=visible,
+            args=args,
+            kwargs=kwargs,
+        )
+
+        setattr(f, MCP_METADATA_KEY, mcp_metadata)
+
+        return f
+
+    return decorator
```
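The decorator itself only attaches an `MCPMetadata` object to the function under `MCP_METADATA_KEY`; discovery and registration live in the new `server/mcp` modules, which are listed above but not shown in this section. A hypothetical sketch of how such metadata could be read back; `collect_mcp_tools` is illustrative and not part of the SDK:

```python
import inspect
from typing import Any, Dict

# Same literal value as MCP_METADATA_KEY in application_sdk/constants.py.
MCP_METADATA_KEY = "__atlan_application_sdk_mcp_metadata"


def collect_mcp_tools(activities: Any) -> Dict[str, Any]:
    """Return {tool_name: metadata} for every callable marked with @mcp_tool."""
    tools: Dict[str, Any] = {}
    for _, member in inspect.getmembers(activities, callable):
        metadata = getattr(member, MCP_METADATA_KEY, None)
        if metadata is not None and getattr(metadata, "visible", True):
            tools[getattr(metadata, "name", member.__name__)] = metadata
    return tools
```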
application_sdk/inputs/parquet.py CHANGED

```diff
@@ -22,6 +22,7 @@ class ParquetInput(Input):
         self,
         path: str,
         chunk_size: int = 100000,
+        buffer_size: int = 5000,
         file_names: Optional[List[str]] = None,
     ):
         """Initialize the Parquet input class.
```
```diff
@@ -32,6 +33,7 @@ class ParquetInput(Input):
                 local path or object store path
                 Wildcards are not supported.
             chunk_size (int): Number of rows per batch. Defaults to 100000.
+            buffer_size (int): Number of rows per batch. Defaults to 5000.
             file_names (Optional[List[str]]): List of file names to read. Defaults to None.

         Raises:
```
```diff
@@ -47,6 +49,7 @@ class ParquetInput(Input):

         self.path = path
         self.chunk_size = chunk_size
+        self.buffer_size = buffer_size
         self.file_names = file_names

     async def get_dataframe(self) -> "pd.DataFrame":
```
```diff
@@ -249,9 +252,18 @@ class ParquetInput(Input):
             parquet_files = await self.download_files()
             logger.info(f"Reading {len(parquet_files)} parquet files as daft batches")

-            #
-
-
+            # Create a lazy dataframe without loading data into memory
+            lazy_df = daft.read_parquet(parquet_files)
+
+            # Get total count efficiently
+            total_rows = lazy_df.count_rows()
+
+            # Yield chunks without loading everything into memory
+            for offset in range(0, total_rows, self.buffer_size):
+                chunk = lazy_df.offset(offset).limit(self.buffer_size)
+                yield chunk
+
+            del lazy_df

         except Exception as error:
             logger.error(
```
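The rewritten `get_batched_daft_dataframe` keeps the daft DataFrame lazy and pages through it in `buffer_size`-row windows via `offset`/`limit` instead of materializing every file at once. A consumption sketch, assuming a daft version that provides `DataFrame.offset` (as the new code does) and a hypothetical local directory of parquet files:

```python
import asyncio

from application_sdk.inputs.parquet import ParquetInput


async def main() -> None:
    # buffer_size controls how many rows each yielded daft chunk covers.
    parquet_input = ParquetInput(path="./local/raw", buffer_size=5000)
    async for chunk in parquet_input.get_batched_daft_dataframe():
        # Each chunk is still a lazy daft DataFrame; count_rows() forces only this window.
        print(chunk.count_rows())


asyncio.run(main())
```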
application_sdk/inputs/sql_query.py CHANGED

```diff
@@ -34,7 +34,7 @@ class SQLQueryInput(Input):
         self,
         query: str,
         engine: Union["Engine", str],
-        chunk_size: Optional[int] =
+        chunk_size: Optional[int] = 5000,
     ):
         """Initialize the async SQL query input handler.

```
```diff
@@ -42,7 +42,7 @@ class SQLQueryInput(Input):
             engine (Union[Engine, str]): SQLAlchemy engine or connection string.
             query (str): The SQL query to execute.
             chunk_size (Optional[int], optional): Number of rows per batch.
-                Defaults to
+                Defaults to 5000.
         """
         self.query = query
         self.engine = engine
```
application_sdk/interceptors/cleanup.py CHANGED

```diff
@@ -114,7 +114,6 @@ class CleanupWorkflowInboundInterceptor(WorkflowInboundInterceptor):
                 retry_policy=RetryPolicy(
                     maximum_attempts=3,
                 ),
-                summary="This activity is used to cleanup the local artifacts and the activity state after the workflow is completed.",
             )

             logger.info("Cleanup completed successfully")
```
application_sdk/outputs/__init__.py CHANGED

```diff
@@ -4,8 +4,11 @@ This module provides base classes and utilities for handling various types of data
 in the application, including file outputs and object store interactions.
 """

+import gc
 import inspect
+import os
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
```
```diff
@@ -13,7 +16,6 @@ from typing import (
     Dict,
     Generator,
     List,
-    Literal,
     Optional,
     Union,
     cast,
```
```diff
@@ -26,6 +28,7 @@ from application_sdk.activities.common.models import ActivityStatistics
 from application_sdk.activities.common.utils import get_object_store_prefix
 from application_sdk.common.dataframe_utils import is_empty_dataframe
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.observability.metrics_adaptor import MetricType
 from application_sdk.services.objectstore import ObjectStore

 logger = get_logger(__name__)
```
```diff
@@ -36,6 +39,14 @@ if TYPE_CHECKING:
     import pandas as pd


+class WriteMode(Enum):
+    """Enumeration of write modes for output operations."""
+
+    APPEND = "append"
+    OVERWRITE = "overwrite"
+    OVERWRITE_PARTITIONS = "overwrite-partitions"
+
+
 class Output(ABC):
     """Abstract base class for output handlers.

```
```diff
@@ -53,11 +64,13 @@ class Output(ABC):
     output_prefix: str
     total_record_count: int
     chunk_count: int
-
+    buffer_size: int
+    max_file_size_bytes: int
+    current_buffer_size: int
+    current_buffer_size_bytes: int
+    partitions: List[int]

-    def estimate_dataframe_record_size(
-        self, dataframe: "pd.DataFrame", file_type: Literal["json", "parquet"]
-    ) -> int:
+    def estimate_dataframe_record_size(self, dataframe: "pd.DataFrame") -> int:
         """Estimate File size of a DataFrame by sampling a few records."""
         if len(dataframe) == 0:
             return 0
```
```diff
@@ -65,16 +78,47 @@ class Output(ABC):
         # Sample up to 10 records to estimate average size
         sample_size = min(10, len(dataframe))
         sample = dataframe.head(sample_size)
+        file_type = type(self).__name__.lower().replace("output", "")
+        compression_factor = 1
         if file_type == "json":
             sample_file = sample.to_json(orient="records", lines=True)
         else:
             sample_file = sample.to_parquet(index=False, compression="snappy")
+            compression_factor = 0.01
         if sample_file is not None:
-            avg_record_size = len(sample_file) / sample_size
-            return int(avg_record_size
+            avg_record_size = len(sample_file) / sample_size * compression_factor
+            return int(avg_record_size)

         return 0

+    def path_gen(
+        self,
+        chunk_count: Optional[int] = None,
+        chunk_part: int = 0,
+        start_marker: Optional[str] = None,
+        end_marker: Optional[str] = None,
+    ) -> str:
+        """Generate a file path for a chunk.
+
+        Args:
+            chunk_start (Optional[int]): Starting index of the chunk, or None for single chunk.
+            chunk_count (int): Total number of chunks.
+            start_marker (Optional[str]): Start marker for query extraction.
+            end_marker (Optional[str]): End marker for query extraction.
+
+        Returns:
+            str: Generated file path for the chunk.
+        """
+        # For Query Extraction - use start and end markers without chunk count
+        if start_marker and end_marker:
+            return f"{start_marker}_{end_marker}{self._EXTENSION}"
+
+        # For regular chunking - include chunk count
+        if chunk_count is None:
+            return f"{str(chunk_part)}{self._EXTENSION}"
+        else:
+            return f"chunk-{str(chunk_count)}-part{str(chunk_part)}{self._EXTENSION}"
+
     def process_null_fields(
         self,
         obj: Any,
```
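`path_gen` produces marker-based names for query extraction and `chunk-<count>-part<part>` names otherwise, with the extension supplied by the concrete output class via `self._EXTENSION` (presumably `.json` or `.parquet`). A standalone copy of the same branching, with the extension passed explicitly and illustrative marker values:

```python
from typing import Optional


def path_gen(
    extension: str,
    chunk_count: Optional[int] = None,
    chunk_part: int = 0,
    start_marker: Optional[str] = None,
    end_marker: Optional[str] = None,
) -> str:
    """Stand-alone copy of the branching in Output.path_gen (extension passed explicitly)."""
    if start_marker and end_marker:   # query extraction: marker-based names
        return f"{start_marker}_{end_marker}{extension}"
    if chunk_count is None:           # single chunk: just the part index
        return f"{chunk_part}{extension}"
    return f"chunk-{chunk_count}-part{chunk_part}{extension}"


print(path_gen(".parquet", chunk_count=3, chunk_part=1))  # chunk-3-part1.parquet
print(path_gen(".parquet", chunk_part=0))                 # 0.parquet
print(path_gen(".json", start_marker="1700000000", end_marker="1700003600"))  # 1700000000_1700003600.json
```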
```diff
@@ -146,15 +190,86 @@ class Output(ABC):
                     await self.write_dataframe(dataframe)
         except Exception as e:
             logger.error(f"Error writing batched dataframe: {str(e)}")
+            raise

-    @abstractmethod
     async def write_dataframe(self, dataframe: "pd.DataFrame"):
-        """Write a pandas DataFrame to
+        """Write a pandas DataFrame to Parquet files and upload to object store.

         Args:
             dataframe (pd.DataFrame): The DataFrame to write.
         """
-
+        try:
+            if self.chunk_start is None:
+                self.chunk_part = 0
+            if len(dataframe) == 0:
+                return
+
+            chunk_size_bytes = self.estimate_dataframe_record_size(dataframe)
+
+            for i in range(0, len(dataframe), self.buffer_size):
+                chunk = dataframe[i : i + self.buffer_size]
+
+                if (
+                    self.current_buffer_size_bytes + chunk_size_bytes
+                    > self.max_file_size_bytes
+                ):
+                    output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
+                    if os.path.exists(output_file_name):
+                        await self._upload_file(output_file_name)
+                        self.chunk_part += 1
+
+                self.current_buffer_size += len(chunk)
+                self.current_buffer_size_bytes += chunk_size_bytes * len(chunk)
+                await self._flush_buffer(chunk, self.chunk_part)
+
+                del chunk
+                gc.collect()
+
+            if self.current_buffer_size_bytes > 0:
+                # Finally upload the final file to the object store
+                output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
+                if os.path.exists(output_file_name):
+                    await self._upload_file(output_file_name)
+                    self.chunk_part += 1
+
+            # Record metrics for successful write
+            self.metrics.record_metric(
+                name="write_records",
+                value=len(dataframe),
+                metric_type=MetricType.COUNTER,
+                labels={"type": "pandas", "mode": WriteMode.APPEND.value},
+                description="Number of records written to files from pandas DataFrame",
+            )
+
+            # Record chunk metrics
+            self.metrics.record_metric(
+                name="chunks_written",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={"type": "pandas", "mode": WriteMode.APPEND.value},
+                description="Number of chunks written to files",
+            )
+
+            # If chunk_start is set we don't want to increment the chunk_count
+            # Since it should only increment the chunk_part in this case
+            if self.chunk_start is None:
+                self.chunk_count += 1
+            self.partitions.append(self.chunk_part)
+        except Exception as e:
+            # Record metrics for failed write
+            self.metrics.record_metric(
+                name="write_errors",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={
+                    "type": "pandas",
+                    "mode": WriteMode.APPEND.value,
+                    "error": str(e),
+                },
+                description="Number of errors while writing to files",
+            )
+            logger.error(f"Error writing pandas dataframe to files: {str(e)}")
+            raise

     async def write_batched_daft_dataframe(
         self,
```
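The new base-class `write_dataframe` slices the incoming DataFrame into `buffer_size`-row chunks, keeps a running byte estimate from `estimate_dataframe_record_size`, and rolls over to a new part file (uploading the previous one) whenever the estimate would exceed `max_file_size_bytes`. A minimal sketch of that rollover rule in isolation, using plain lists instead of files:

```python
from typing import List


def split_into_parts(record_sizes: List[int], max_file_size_bytes: int) -> List[List[int]]:
    """Group record sizes into parts, starting a new part when the running byte
    estimate would exceed max_file_size_bytes (mirrors the rollover check in
    Output.write_dataframe, minus the actual file I/O and metrics)."""
    parts: List[List[int]] = [[]]
    current_bytes = 0
    for size in record_sizes:
        if current_bytes + size > max_file_size_bytes and parts[-1]:
            parts.append([])  # roll over to a new part file
            current_bytes = 0
        parts[-1].append(size)
        current_bytes += size
    return parts


# Three records of ~40 bytes with a 100-byte cap -> two parts.
print(split_into_parts([40, 40, 40], max_file_size_bytes=100))  # [[40, 40], [40]]
```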
```diff
@@ -225,6 +340,55 @@ class Output(ABC):
             logger.error(f"Error getting statistics: {str(e)}")
             raise

+    async def _upload_file(self, file_name: str):
+        """Upload a file to the object store."""
+        await ObjectStore.upload_file(
+            source=file_name,
+            destination=get_object_store_prefix(file_name),
+        )
+
+        self.current_buffer_size_bytes = 0
+
+    async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
+        """Flush the current buffer to a JSON file.
+
+        This method combines all DataFrames in the buffer, writes them to a JSON file,
+        and uploads the file to the object store.
+
+        Note:
+            If the buffer is empty or has no records, the method returns without writing.
+        """
+        try:
+            if not is_empty_dataframe(chunk):
+                self.total_record_count += len(chunk)
+                output_file_name = (
+                    f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part)}"
+                )
+                await self.write_chunk(chunk, output_file_name)
+
+            self.current_buffer_size = 0
+
+            # Record chunk metrics
+            self.metrics.record_metric(
+                name="chunks_written",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={"type": "output"},
+                description="Number of chunks written to files",
+            )
+
+        except Exception as e:
+            # Record metrics for failed write
+            self.metrics.record_metric(
+                name="write_errors",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={"type": "output", "error": str(e)},
+                description="Number of errors while writing to files",
+            )
+            logger.error(f"Error flushing buffer to files: {str(e)}")
+            raise e
+
     async def write_statistics(self) -> Optional[Dict[str, Any]]:
         """Write statistics about the output to a JSON file.

```
```diff
@@ -238,8 +402,8 @@ class Output(ABC):
             # prepare the statistics
             statistics = {
                 "total_record_count": self.total_record_count,
-                "chunk_count": self.
-                "partitions": self.
+                "chunk_count": len(self.partitions),
+                "partitions": self.partitions,
             }

             # Write the statistics to a json file
```