atlan-application-sdk 0.1.1rc43__py3-none-any.whl → 0.1.1rc45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,10 +2,8 @@ import os
 from typing import (
     TYPE_CHECKING,
     Any,
-    AsyncGenerator,
     AsyncIterator,
     Dict,
-    Generator,
     Iterator,
     List,
     Optional,
@@ -368,9 +366,9 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 "Output prefix and path must be specified in workflow_args."
             )
         return ParquetOutput(
-            output_prefix=output_prefix,
             output_path=output_path,
             output_suffix=output_suffix,
+            use_consolidation=True,
         )

     def _get_temp_table_regex_sql(self, typename: str) -> str:
@@ -553,7 +551,7 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 )

                 # Execute using helper method
-                success, batched_iter = await self._execute_single_db(
+                success, batched_iterator = await self._execute_single_db(
                     effective_sql_client.engine,
                     prepared_query,
                     parquet_output,
@@ -570,12 +568,12 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
                 logger.warning(
                     f"Failed to process database '{database_name}': {str(e)}. Skipping to next database."
                 )
-                success, batched_iter = False, None
+                success, batched_iterator = False, None

             if success:
                 successful_databases.append(database_name)
-                if not write_to_file and batched_iter:
-                    dataframe_list.append(batched_iter)
+                if not write_to_file and batched_iterator:
+                    dataframe_list.append(batched_iterator)
             else:
                 failed_databases.append(database_name)

@@ -615,37 +613,13 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):

         try:
             sql_input = SQLQueryInput(engine=sql_engine, query=prepared_query)
-            batched_iter = await sql_input.get_batched_dataframe()
+            batched_iterator = await sql_input.get_batched_dataframe()

             if write_to_file and parquet_output:
-                # Wrap iterator into a proper (async)generator for type safety
-                if hasattr(batched_iter, "__anext__"):
-
-                    async def _to_async_gen(
-                        it: AsyncIterator["pd.DataFrame"],
-                    ) -> AsyncGenerator["pd.DataFrame", None]:
-                        async for item in it:
-                            yield item
-
-                    wrapped: AsyncGenerator["pd.DataFrame", None] = _to_async_gen(  # type: ignore
-                        batched_iter  # type: ignore
-                    )
-                    await parquet_output.write_batched_dataframe(wrapped)
-                else:
-
-                    def _to_gen(
-                        it: Iterator["pd.DataFrame"],
-                    ) -> Generator["pd.DataFrame", None, None]:
-                        for item in it:
-                            yield item
-
-                    wrapped_sync: Generator["pd.DataFrame", None, None] = _to_gen(  # type: ignore
-                        batched_iter  # type: ignore
-                    )
-                    await parquet_output.write_batched_dataframe(wrapped_sync)
+                await parquet_output.write_batched_dataframe(batched_iterator)  # type: ignore
                 return True, None

-            return True, batched_iter
+            return True, batched_iterator
         except Exception as e:
             logger.error(
                 f"Error during query execution or output writing: {e}", exc_info=True
@@ -863,10 +837,10 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
             file_names=workflow_args.get("file_names"),
         )
         raw_input = raw_input.get_batched_daft_dataframe()
+
         transformed_output = JsonOutput(
             output_path=output_path,
             output_suffix="transformed",
-            output_prefix=output_prefix,
             typename=typename,
             chunk_start=workflow_args.get("chunk_start"),
         )
@@ -210,7 +210,6 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
         sql_input = await sql_input.get_dataframe()

         raw_output = ParquetOutput(
-            output_prefix=workflow_args["output_prefix"],
             output_path=workflow_args["output_path"],
             output_suffix="raw/query",
             chunk_size=workflow_args["miner_args"].get("chunk_size", 100000),
@@ -218,7 +217,6 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
             end_marker=workflow_args["end_marker"],
         )
         await raw_output.write_dataframe(sql_input)
-
         logger.info(
             f"Query fetch completed, {raw_output.total_record_count} records processed",
         )
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 from application_sdk.activities import ActivitiesInterface
 from application_sdk.clients.base import BaseClient
 from application_sdk.clients.utils import get_workflow_client
+from application_sdk.constants import ENABLE_MCP
 from application_sdk.events.models import EventRegistration
 from application_sdk.handlers.base import BaseHandler
 from application_sdk.observability.logger_adaptor import get_logger
@@ -29,7 +30,7 @@ class BaseApplication:
         self,
         name: str,
         server: Optional[ServerInterface] = None,
-        application_manifest: Optional[dict] = None,
+        application_manifest: Optional[Dict[str, Any]] = None,
         client_class: Optional[Type[BaseClient]] = None,
         handler_class: Optional[Type[BaseHandler]] = None,
     ):
@@ -39,6 +40,9 @@ class BaseApplication:
         Args:
             name (str): The name of the application.
             server (ServerInterface): The server class for the application.
+            application_manifest (Optional[Dict[str, Any]]): Application manifest configuration.
+            client_class (Optional[Type[BaseClient]]): Client class for the application.
+            handler_class (Optional[Type[BaseHandler]]): Handler class for the application.
         """
         self.application_name = name

@@ -49,14 +53,21 @@ class BaseApplication:

         self.workflow_client = get_workflow_client(application_name=name)

-        self.application_manifest: Dict[str, Any] = application_manifest
+        self.application_manifest: Optional[Dict[str, Any]] = application_manifest
         self.bootstrap_event_registration()

         self.client_class = client_class or BaseClient
         self.handler_class = handler_class or BaseHandler

+        # MCP configuration
+        self.mcp_server: Optional["MCPServer"] = None
+        if ENABLE_MCP:
+            from application_sdk.server.mcp import MCPServer
+
+            self.mcp_server = MCPServer(application_name=name)
+
     def bootstrap_event_registration(self):
-        self.event_subscriptions = {}
+        self.event_subscriptions: Dict[str, EventWorkflowTrigger] = {}
         if self.application_manifest is None:
             logger.warning("No application manifest found, skipping event registration")
             return
@@ -122,8 +133,8 @@ class BaseApplication:
         ]
         workflow_activities = []
         for workflow_class, activities_class in workflow_and_activities_classes:
-            workflow_activities.extend(
-                workflow_class.get_activities(activities_class())
+            workflow_activities.extend(  # type: ignore
+                workflow_class.get_activities(activities_class())  # type: ignore
             )

         self.worker = Worker(
@@ -134,6 +145,13 @@ class BaseApplication:
             activity_executor=activity_executor,
         )

+        # Register MCP tools if ENABLE_MCP is enabled and an MCP server is initialized
+        if self.mcp_server:
+            logger.info("Registering MCP tools from workflow and activities classes")
+            await self.mcp_server.register_tools(  # type: ignore
+                workflow_and_activities_classes=workflow_and_activities_classes
+            )
+
     async def start_workflow(self, workflow_args, workflow_class) -> Any:
         """
         Start a new workflow execution.
@@ -147,7 +165,7 @@ class BaseApplication:
         """
         if self.workflow_client is None:
             raise ValueError("Workflow client not initialized")
-        return await self.workflow_client.start_workflow(workflow_args, workflow_class)
+        return await self.workflow_client.start_workflow(workflow_args, workflow_class)  # type: ignore

     async def start_worker(self, daemon: bool = True):
         """
@@ -162,39 +180,61 @@ class BaseApplication:

     async def setup_server(
         self,
-        workflow_class,
+        workflow_class: Type[WorkflowInterface],
         ui_enabled: bool = True,
         has_configmap: bool = False,
     ):
         """
-        Optionally set up a server for the application. (No-op by default)
+        Set up FastAPI server and automatically mount MCP if enabled.
+
+        Args:
+            workflow_class (WorkflowInterface): The workflow class for the application.
+            ui_enabled (bool): Whether to enable the UI.
+            has_configmap (bool): Whether to enable the configmap.
         """
         if self.workflow_client is None:
             await self.workflow_client.load()

-        # Overrides the application server. serves the UI, and handles the various triggers
+        mcp_http_app: Optional[Any] = None
+        lifespan: Optional[Any] = None
+
+        if self.mcp_server:
+            try:
+                mcp_http_app = await self.mcp_server.get_http_app()
+                lifespan = mcp_http_app.lifespan
+            except Exception as e:
+                logger.warning(f"Failed to get MCP HTTP app: {e}")
+
         self.server = APIServer(
+            lifespan=lifespan,
             workflow_client=self.workflow_client,
             ui_enabled=ui_enabled,
             handler=self.handler_class(client=self.client_class()),
             has_configmap=has_configmap,
         )

+        # Mount MCP at root
+        if mcp_http_app:
+            try:
+                self.server.app.mount("", mcp_http_app)  # Mount at root
+            except Exception as e:
+                logger.warning(f"Failed to mount MCP HTTP app: {e}")
+
+        # Register event-based workflows if any
         if self.event_subscriptions:
             for event_trigger in self.event_subscriptions.values():
-                if event_trigger.workflow_class is None:
+                if event_trigger.workflow_class is None:  # type: ignore
                     raise ValueError(
                         f"Workflow class not set for event trigger {event_trigger.event_id}"
                     )

-                self.server.register_workflow(
-                    workflow_class=event_trigger.workflow_class,
+                self.server.register_workflow(  # type: ignore
+                    workflow_class=event_trigger.workflow_class,  # type: ignore
                     triggers=[event_trigger],
                 )

-        # register the workflow on the application server
-        # the workflow is by default triggered by an HTTP POST request to the /start endpoint
-        self.server.register_workflow(
+        # Register the main workflow (HTTP POST /start endpoint)
+        self.server.register_workflow(  # type: ignore
             workflow_class=workflow_class,
             triggers=[HttpWorkflowTrigger()],
         )
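
For context on how the ENABLE_MCP wiring above is meant to be used, here is a minimal bootstrap sketch. The workflow class (MyWorkflow), its module, and the application_sdk.application import path are illustrative assumptions, not taken from this diff:

# Hypothetical bootstrap showing the ENABLE_MCP gate; marked names are assumptions.
import asyncio
import os

# ENABLE_MCP is read when application_sdk.constants is imported, so set it first.
os.environ["ENABLE_MCP"] = "true"

from application_sdk.application import BaseApplication  # import path assumed
from my_app.workflows import MyWorkflow  # hypothetical workflow class


async def main() -> None:
    app = BaseApplication(name="my-app")
    # When ENABLE_MCP is true, __init__ creates an MCPServer and setup_server
    # mounts its HTTP app onto the FastAPI server alongside the UI and triggers.
    await app.setup_server(workflow_class=MyWorkflow)


asyncio.run(main())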
@@ -255,3 +255,9 @@ REDIS_SENTINEL_HOSTS = os.getenv("REDIS_SENTINEL_HOSTS", "")
 IS_LOCKING_DISABLED = os.getenv("IS_LOCKING_DISABLED", "true").lower() == "true"
 #: Retry interval for lock acquisition
 LOCK_RETRY_INTERVAL = int(os.getenv("LOCK_RETRY_INTERVAL", "5"))
+
+# MCP Configuration
+#: Flag to indicate if MCP should be enabled or not. Turning this to true will set up an MCP server
+#: along with the application.
+ENABLE_MCP = os.getenv("ENABLE_MCP", "false").lower() == "true"
+MCP_METADATA_KEY = "__atlan_application_sdk_mcp_metadata"
@@ -0,0 +1,63 @@
+"""
+MCP tool decorator for marking activities as MCP tools.
+
+This module provides the @mcp_tool decorator that developers use to mark
+activities for automatic exposure via Model Context Protocol.
+"""
+
+from typing import Any, Callable, Optional
+
+from application_sdk.constants import MCP_METADATA_KEY
+from application_sdk.server.mcp import MCPMetadata
+
+
+def mcp_tool(
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+    visible: bool = True,
+    *args,
+    **kwargs,
+):
+    """
+    Decorator to mark functions as MCP tools.
+
+    Use this decorator to mark any function as an MCP tool. You can additionally use the `visible`
+    parameter to control whether the tool is visible at runtime or not.
+
+    Function parameters that are Pydantic models will be automatically converted into correct JSON schema
+    for the tool specification. This is handled by the underlying FastMCP server implementation.
+
+    Args:
+        name (Optional[str]): The name of the tool. Defaults to the function name.
+        description (Optional[str]): The description of the tool. Defaults to the function docstring.
+        visible (bool): Whether the MCP tool is visible at runtime or not. Defaults to True.
+        *args: Additional arguments to pass to the tool.
+        **kwargs: Additional keyword arguments to pass to the tool.
+
+    Examples:
+        >>> @mcp_tool(name="add_numbers", description="Add two numbers", visible=True)
+        >>> def add_numbers(self, a: int, b: int) -> int:
+        >>>     return a + b
+
+
+        >>> # Use with Temporal activity decorator
+        >>> @activity.defn
+        >>> @mcp_tool(name="get_weather", description="Get the weather for a given city")
+        >>> async def get_weather(self, city: str) -> str:
+        >>>     # ... activity implementation unchanged ...
+    """
+
+    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+        mcp_metadata = MCPMetadata(
+            name=name if name else f.__name__,
+            description=description if description else f.__doc__,
+            visible=visible,
+            args=args,
+            kwargs=kwargs,
+        )
+
+        setattr(f, MCP_METADATA_KEY, mcp_metadata)
+
+        return f
+
+    return decorator
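
The decorator only attaches metadata; registration happens later via MCPServer.register_tools. A minimal sketch of how that lookup could work, assuming the MCP_METADATA_KEY constant above (the discover_mcp_tools helper itself is hypothetical):

# Hypothetical helper illustrating how @mcp_tool metadata can be read back.
import inspect
from typing import Any, Callable, Dict

from application_sdk.constants import MCP_METADATA_KEY


def discover_mcp_tools(activities: Any) -> Dict[str, Callable[..., Any]]:
    """Collect visible MCP tools from an activities instance by inspecting its methods."""
    tools: Dict[str, Callable[..., Any]] = {}
    for _, method in inspect.getmembers(activities, callable):
        metadata = getattr(method, MCP_METADATA_KEY, None)  # attached by @mcp_tool via setattr
        if metadata is not None and metadata.visible:
            tools[metadata.name] = method
    return tools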
@@ -22,6 +22,7 @@ class ParquetInput(Input):
         self,
         path: str,
         chunk_size: int = 100000,
+        buffer_size: int = 5000,
         file_names: Optional[List[str]] = None,
     ):
         """Initialize the Parquet input class.
@@ -32,6 +33,7 @@ class ParquetInput(Input):
                 local path or object store path
                 Wildcards are not supported.
             chunk_size (int): Number of rows per batch. Defaults to 100000.
+            buffer_size (int): Number of rows per chunk yielded when reading daft batches. Defaults to 5000.
             file_names (Optional[List[str]]): List of file names to read. Defaults to None.

         Raises:
@@ -47,6 +49,7 @@ class ParquetInput(Input):

         self.path = path
         self.chunk_size = chunk_size
+        self.buffer_size = buffer_size
         self.file_names = file_names

     async def get_dataframe(self) -> "pd.DataFrame":
@@ -249,9 +252,18 @@ class ParquetInput(Input):
             parquet_files = await self.download_files()
             logger.info(f"Reading {len(parquet_files)} parquet files as daft batches")

-            # Yield each discovered file as separate batch
-            for parquet_file in parquet_files:
-                yield daft.read_parquet(parquet_file)
+            # Create a lazy dataframe without loading data into memory
+            lazy_df = daft.read_parquet(parquet_files)
+
+            # Get total count efficiently
+            total_rows = lazy_df.count_rows()
+
+            # Yield chunks without loading everything into memory
+            for offset in range(0, total_rows, self.buffer_size):
+                chunk = lazy_df.offset(offset).limit(self.buffer_size)
+                yield chunk
+
+            del lazy_df

         except Exception as error:
             logger.error(
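
A small consumer sketch for the new buffered reader, assuming the parquet files are reachable from the configured store; the import path and the example path are illustrative:

# Hypothetical usage of the buffered daft reader; each yielded chunk is a lazy
# daft DataFrame of at most buffer_size rows, materialized only when acted upon.
import asyncio

from application_sdk.inputs.parquet import ParquetInput  # import path assumed


async def count_rows(path: str) -> int:
    parquet_input = ParquetInput(path=path, buffer_size=5000)
    total = 0
    async for chunk in parquet_input.get_batched_daft_dataframe():
        total += chunk.count_rows()  # materializes only this slice
    return total


print(asyncio.run(count_rows("./artifacts/raw/table")))  # path is illustrative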
@@ -34,7 +34,7 @@ class SQLQueryInput(Input):
         self,
         query: str,
         engine: Union["Engine", str],
-        chunk_size: Optional[int] = 100000,
+        chunk_size: Optional[int] = 5000,
     ):
         """Initialize the async SQL query input handler.

@@ -42,7 +42,7 @@ class SQLQueryInput(Input):
             engine (Union[Engine, str]): SQLAlchemy engine or connection string.
             query (str): The SQL query to execute.
             chunk_size (Optional[int], optional): Number of rows per batch.
-                Defaults to 100000.
+                Defaults to 5000.
         """
         self.query = query
         self.engine = engine
@@ -114,7 +114,6 @@ class CleanupWorkflowInboundInterceptor(WorkflowInboundInterceptor):
             retry_policy=RetryPolicy(
                 maximum_attempts=3,
             ),
-            summary="This activity is used to cleanup the local artifacts and the activity state after the workflow is completed.",
         )

         logger.info("Cleanup completed successfully")
@@ -4,8 +4,11 @@ This module provides base classes and utilities for handling various types of da
 in the application, including file outputs and object store interactions.
 """

+import gc
 import inspect
+import os
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -13,7 +16,6 @@ from typing import (
     Dict,
     Generator,
     List,
-    Literal,
     Optional,
     Union,
     cast,
@@ -26,6 +28,7 @@ from application_sdk.activities.common.models import ActivityStatistics
 from application_sdk.activities.common.utils import get_object_store_prefix
 from application_sdk.common.dataframe_utils import is_empty_dataframe
 from application_sdk.observability.logger_adaptor import get_logger
+from application_sdk.observability.metrics_adaptor import MetricType
 from application_sdk.services.objectstore import ObjectStore

 logger = get_logger(__name__)
@@ -36,6 +39,14 @@ if TYPE_CHECKING:
     import pandas as pd


+class WriteMode(Enum):
+    """Enumeration of write modes for output operations."""
+
+    APPEND = "append"
+    OVERWRITE = "overwrite"
+    OVERWRITE_PARTITIONS = "overwrite-partitions"
+
+
 class Output(ABC):
     """Abstract base class for output handlers.

@@ -53,11 +64,13 @@ class Output(ABC):
     output_prefix: str
     total_record_count: int
     chunk_count: int
-    statistics: List[int] = []
+    buffer_size: int
+    max_file_size_bytes: int
+    current_buffer_size: int
+    current_buffer_size_bytes: int
+    partitions: List[int]

-    def estimate_dataframe_file_size(
-        self, dataframe: "pd.DataFrame", file_type: Literal["json", "parquet"]
-    ) -> int:
+    def estimate_dataframe_record_size(self, dataframe: "pd.DataFrame") -> int:
         """Estimate File size of a DataFrame by sampling a few records."""
         if len(dataframe) == 0:
             return 0
@@ -65,16 +78,47 @@ class Output(ABC):
         # Sample up to 10 records to estimate average size
         sample_size = min(10, len(dataframe))
         sample = dataframe.head(sample_size)
+        file_type = type(self).__name__.lower().replace("output", "")
+        compression_factor = 1
         if file_type == "json":
             sample_file = sample.to_json(orient="records", lines=True)
         else:
             sample_file = sample.to_parquet(index=False, compression="snappy")
+            compression_factor = 0.01
         if sample_file is not None:
-            avg_record_size = len(sample_file) / sample_size
-            return int(avg_record_size * len(dataframe))
+            avg_record_size = len(sample_file) / sample_size * compression_factor
+            return int(avg_record_size)

         return 0

+    def path_gen(
+        self,
+        chunk_count: Optional[int] = None,
+        chunk_part: int = 0,
+        start_marker: Optional[str] = None,
+        end_marker: Optional[str] = None,
+    ) -> str:
+        """Generate a file path for a chunk.
+
+        Args:
+            chunk_count (Optional[int]): Total number of chunks, or None for a single chunk.
+            chunk_part (int): Part index within the current chunk.
+            start_marker (Optional[str]): Start marker for query extraction.
+            end_marker (Optional[str]): End marker for query extraction.
+
+        Returns:
+            str: Generated file path for the chunk.
+        """
+        # For Query Extraction - use start and end markers without chunk count
+        if start_marker and end_marker:
+            return f"{start_marker}_{end_marker}{self._EXTENSION}"
+
+        # For regular chunking - include chunk count
+        if chunk_count is None:
+            return f"{str(chunk_part)}{self._EXTENSION}"
+        else:
+            return f"chunk-{str(chunk_count)}-part{str(chunk_part)}{self._EXTENSION}"
+
     def process_null_fields(
         self,
         obj: Any,
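
For reference, the file names path_gen produces, assuming a concrete Output subclass whose _EXTENSION is ".parquet" (the extension is defined by the subclass, not shown in this diff):

# Illustrative outputs only, assuming _EXTENSION = ".parquet":
# path_gen(chunk_count=None, chunk_part=0)                      -> "0.parquet"
# path_gen(chunk_count=3, chunk_part=1)                         -> "chunk-3-part1.parquet"
# path_gen(start_marker="1700000000", end_marker="1700003600")  -> "1700000000_1700003600.parquet"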
@@ -146,15 +190,86 @@ class Output(ABC):
                 await self.write_dataframe(dataframe)
         except Exception as e:
             logger.error(f"Error writing batched dataframe: {str(e)}")
+            raise

-    @abstractmethod
     async def write_dataframe(self, dataframe: "pd.DataFrame"):
-        """Write a pandas DataFrame to the output destination.
+        """Write a pandas DataFrame to Parquet files and upload to object store.

         Args:
             dataframe (pd.DataFrame): The DataFrame to write.
         """
-        pass
+        try:
+            if self.chunk_start is None:
+                self.chunk_part = 0
+            if len(dataframe) == 0:
+                return
+
+            chunk_size_bytes = self.estimate_dataframe_record_size(dataframe)
+
+            for i in range(0, len(dataframe), self.buffer_size):
+                chunk = dataframe[i : i + self.buffer_size]
+
+                if (
+                    self.current_buffer_size_bytes + chunk_size_bytes
+                    > self.max_file_size_bytes
+                ):
+                    output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
+                    if os.path.exists(output_file_name):
+                        await self._upload_file(output_file_name)
+                    self.chunk_part += 1
+
+                self.current_buffer_size += len(chunk)
+                self.current_buffer_size_bytes += chunk_size_bytes * len(chunk)
+                await self._flush_buffer(chunk, self.chunk_part)
+
+                del chunk
+                gc.collect()
+
+            if self.current_buffer_size_bytes > 0:
+                # Finally upload the final file to the object store
+                output_file_name = f"{self.output_path}/{self.path_gen(self.chunk_count, self.chunk_part)}"
+                if os.path.exists(output_file_name):
+                    await self._upload_file(output_file_name)
+                self.chunk_part += 1
+
+            # Record metrics for successful write
+            self.metrics.record_metric(
+                name="write_records",
+                value=len(dataframe),
+                metric_type=MetricType.COUNTER,
+                labels={"type": "pandas", "mode": WriteMode.APPEND.value},
+                description="Number of records written to files from pandas DataFrame",
+            )
+
+            # Record chunk metrics
+            self.metrics.record_metric(
+                name="chunks_written",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={"type": "pandas", "mode": WriteMode.APPEND.value},
+                description="Number of chunks written to files",
+            )
+
+            # If chunk_start is set we don't want to increment the chunk_count
+            # Since it should only increment the chunk_part in this case
+            if self.chunk_start is None:
+                self.chunk_count += 1
+            self.partitions.append(self.chunk_part)
+        except Exception as e:
+            # Record metrics for failed write
+            self.metrics.record_metric(
+                name="write_errors",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={
+                    "type": "pandas",
+                    "mode": WriteMode.APPEND.value,
+                    "error": str(e),
+                },
+                description="Number of errors while writing to files",
+            )
+            logger.error(f"Error writing pandas dataframe to files: {str(e)}")
+            raise

     async def write_batched_daft_dataframe(
         self,
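
To make the size-based rotation above easier to follow, here is a simplified, synchronous model of the same bookkeeping; the numbers are illustrative, and the real defaults for buffer_size and max_file_size_bytes are not part of this diff:

# Simplified model of the rotation in write_dataframe: the per-record size is estimated once,
# bytes are accumulated per buffer, and a new part file starts when the cap is exceeded.
def plan_parts(total_rows: int, buffer_size: int, record_size: int, max_file_size_bytes: int) -> list[int]:
    parts: list[int] = []
    current_bytes, part = 0, 0
    for start in range(0, total_rows, buffer_size):
        rows = min(buffer_size, total_rows - start)
        if current_bytes + record_size > max_file_size_bytes:
            parts.append(part)  # this part file would be uploaded here
            part += 1
            current_bytes = 0
        current_bytes += record_size * rows
    parts.append(part)  # final part file uploaded after the loop
    return parts


# e.g. 25_000 rows in 5_000-row buffers, ~200 bytes per record, 2 MB cap per file
print(plan_parts(25_000, 5_000, 200, 2_000_000))  # -> [0, 1, 2]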
@@ -225,6 +340,55 @@ class Output(ABC):
             logger.error(f"Error getting statistics: {str(e)}")
             raise

+    async def _upload_file(self, file_name: str):
+        """Upload a file to the object store."""
+        await ObjectStore.upload_file(
+            source=file_name,
+            destination=get_object_store_prefix(file_name),
+        )
+
+        self.current_buffer_size_bytes = 0
+
+    async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
+        """Flush a buffered chunk to an output file.
+
+        This method writes the given chunk to the current part file via write_chunk;
+        uploading to the object store is handled separately by _upload_file.
+
+        Note:
+            If the chunk is empty or has no records, the method returns without writing.
+        """
+        try:
+            if not is_empty_dataframe(chunk):
+                self.total_record_count += len(chunk)
+                output_file_name = (
+                    f"{self.output_path}/{self.path_gen(self.chunk_count, chunk_part)}"
+                )
+                await self.write_chunk(chunk, output_file_name)
+
+            self.current_buffer_size = 0
+
+            # Record chunk metrics
+            self.metrics.record_metric(
+                name="chunks_written",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={"type": "output"},
+                description="Number of chunks written to files",
+            )
+
+        except Exception as e:
+            # Record metrics for failed write
+            self.metrics.record_metric(
+                name="write_errors",
+                value=1,
+                metric_type=MetricType.COUNTER,
+                labels={"type": "output", "error": str(e)},
+                description="Number of errors while writing to files",
+            )
+            logger.error(f"Error flushing buffer to files: {str(e)}")
+            raise e
+
     async def write_statistics(self) -> Optional[Dict[str, Any]]:
         """Write statistics about the output to a JSON file.

@@ -238,8 +402,8 @@ class Output(ABC):
         # prepare the statistics
         statistics = {
             "total_record_count": self.total_record_count,
-            "chunk_count": self.chunk_count,
-            "partitions": self.statistics,
+            "chunk_count": len(self.partitions),
+            "partitions": self.partitions,
         }

         # Write the statistics to a json file