cloe-nessy 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. cloe_nessy/integration/reader/api_reader.py +4 -2
  2. cloe_nessy/integration/reader/catalog_reader.py +6 -3
  3. cloe_nessy/integration/reader/excel_reader.py +1 -1
  4. cloe_nessy/integration/reader/file_reader.py +78 -5
  5. cloe_nessy/integration/writer/__init__.py +8 -1
  6. cloe_nessy/integration/writer/delta_writer/__init__.py +7 -0
  7. cloe_nessy/integration/writer/delta_writer/delta_append_writer.py +108 -0
  8. cloe_nessy/integration/writer/delta_writer/delta_merge_writer.py +215 -0
  9. cloe_nessy/integration/writer/delta_writer/delta_table_operation_type.py +21 -0
  10. cloe_nessy/integration/writer/delta_writer/delta_writer_base.py +210 -0
  11. cloe_nessy/integration/writer/delta_writer/exceptions.py +4 -0
  12. cloe_nessy/integration/writer/file_writer.py +132 -0
  13. cloe_nessy/integration/writer/writer.py +54 -0
  14. cloe_nessy/models/adapter/unity_catalog_adapter.py +5 -1
  15. cloe_nessy/models/schema.py +1 -1
  16. cloe_nessy/models/table.py +17 -6
  17. cloe_nessy/object_manager/table_manager.py +73 -19
  18. cloe_nessy/pipeline/actions/__init__.py +7 -1
  19. cloe_nessy/pipeline/actions/read_catalog_table.py +1 -4
  20. cloe_nessy/pipeline/actions/write_delta_append.py +69 -0
  21. cloe_nessy/pipeline/actions/write_delta_merge.py +118 -0
  22. cloe_nessy/pipeline/actions/write_file.py +94 -0
  23. {cloe_nessy-0.3.8.dist-info → cloe_nessy-0.3.9.dist-info}/METADATA +28 -4
  24. {cloe_nessy-0.3.8.dist-info → cloe_nessy-0.3.9.dist-info}/RECORD +26 -15
  25. {cloe_nessy-0.3.8.dist-info → cloe_nessy-0.3.9.dist-info}/WHEEL +1 -1
  26. {cloe_nessy-0.3.8.dist-info → cloe_nessy-0.3.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,210 @@
1
+ import logging
2
+ from abc import ABC
3
+ from dataclasses import dataclass, field
4
+
5
+ from pyspark.sql import DataFrame, Row
6
+ from pyspark.sql.functions import col, concat, concat_ws, format_string, lit
7
+
8
+ from ....object_manager import TableManager
9
+ from ....session import SessionManager
10
+ from ..writer import BaseWriter
11
+ from .delta_table_operation_type import DeltaTableOperationType
12
+ from .exceptions import EmptyDataframeError
13
+
14
+
15
@dataclass
class DeltaWriterLogs:
    """Configuration of the tabular log table used by the Delta writers.

    Only the annotated attributes are dataclass fields, so only they show up in
    ``DeltaWriterLogs().__dict__``; ``logger_name`` is a plain class attribute.
    """

    # Unannotated on purpose: a class attribute, not a dataclass field.
    logger_name = "Tabular:DeltaWriter"
    log_type: str = "nessy_simple_logs"
    uc_table_name: str = "nessy_simple_logs"
    # Column name -> Unity Catalog type. A factory is needed because dict defaults are mutable.
    uc_table_columns: dict[str, str] = field(default_factory=lambda: dict(message="STRING"))
27
+
28
+
29
@dataclass
class TableOperationMetricsLogs:
    """Configuration of the tabular log table that stores Delta table operation metrics.

    ``logger_name`` is a plain class attribute; the annotated attributes are the
    dataclass fields that end up in ``TableOperationMetricsLogs().__dict__``.
    """

    # Unannotated on purpose: a class attribute, not a dataclass field.
    logger_name = "Tabular:TableOperationMetrics"
    log_type: str = "nessy_table_operation_metrics"
    uc_table_name: str = "nessy_table_operation_metrics"
    # Column name -> Unity Catalog type; every column after ``timestamp`` is a STRING.
    uc_table_columns: dict[str, str] = field(
        default_factory=lambda: {
            "timestamp": "TIMESTAMP",
            **dict.fromkeys(
                (
                    "table_identifier",
                    "operation_type",
                    "metric",
                    "value",
                    "user_name",
                    "job_id",
                    "job_run_id",
                    "run_id",
                    "notebook_id",
                    "cluster_id",
                ),
                "STRING",
            ),
        }
    )
51
+
52
+
53
class BaseDeltaWriter(BaseWriter, ABC):
    """Base class for writers that write DataFrames to Delta tables.

    Provides helpers shared by concrete Delta writers:
    - reporting the latest operation metrics from ``DESCRIBE HISTORY``;
    - building MERGE match conditions and partition pruning conditions;
    - guarding against empty DataFrames.
    """

    def __init__(
        self,
        tabular_logger: logging.Logger | None = None,
        table_operation_metrics_logger: logging.Logger | None = None,
    ):
        """Initializes the writer.

        Args:
            tabular_logger: Logger for general writer logs. Defaults to a tabular
                logger configured by ``DeltaWriterLogs``.
            table_operation_metrics_logger: Logger for Delta operation metrics.
                Defaults to a tabular logger configured by ``TableOperationMetricsLogs``.
        """
        super().__init__()
        self._spark = SessionManager.get_spark_session()
        self._dbutils = SessionManager.get_utils()
        # Operation metrics have their own log table/schema (TableOperationMetricsLogs);
        # previously this was wired to DeltaWriterLogs by mistake.
        self._table_operation_metrics_logger = table_operation_metrics_logger or self.get_tabular_logger(
            **TableOperationMetricsLogs().__dict__
        )
        self.table_manager = TableManager()
        self._tabular_logger = tabular_logger or self.get_tabular_logger(**DeltaWriterLogs().__dict__)

    def _delta_operation_log(self, table_identifier: str, operation_type: DeltaTableOperationType) -> dict:
        """Returns the most recent Delta history entry of a table for the given operation type.

        Args:
            table_identifier: The identifier of the Delta table in the format 'catalog.schema.table'.
            operation_type: A DeltaTableOperationType object specifying the type of
                operation for which the log should be retrieved. Its name, with
                underscores replaced by spaces, is matched against the history's
                ``operation`` column.

        Returns:
            dict: The latest matching history row as a dictionary, or an empty
            dictionary if the table has no matching history entry.
        """
        delta_history = self._spark.sql(f"DESCRIBE HISTORY {table_identifier}")

        # first() only brings a single row to the driver instead of the whole history.
        latest = (
            delta_history.filter(col("operation") == operation_type.name.replace("_", " "))
            .orderBy("version", ascending=False)
            .first()
        )
        return latest.asDict() if latest is not None else {}

    def _report_delta_table_operation_metrics(
        self, table_identifier: str, operation_type: DeltaTableOperationType
    ) -> None:
        """Logs the most recent metrics of a Delta table for the given operation type.

        One log record is emitted per metric whose key is contained in
        ``operation_type.value``.

        Args:
            table_identifier: The identifier of the Delta table in the format 'catalog.schema.table'.
            operation_type: A DeltaTableOperationType object specifying the type
                of operation for which metrics should be retrieved.
        """
        operation_log = self._delta_operation_log(table_identifier, operation_type)
        timestamp = operation_log.get("timestamp")
        user_name = operation_log.get("userName")
        cluster_id = operation_log.get("clusterId")

        # "job" and "notebook" are struct columns (Rows) that may be null.
        job = operation_log.get("job")
        job_info = job.asDict() if job else {}
        notebook = operation_log.get("notebook")
        notebook_info = notebook.asDict() if notebook else {}
        job_id = job_info.get("jobId")
        job_run_id = job_info.get("jobRunId")
        run_id = job_info.get("runId")
        notebook_id = notebook_info.get("notebookId")

        # operationMetrics may be missing or null; treat both as "no metrics".
        operation_metrics = operation_log.get("operationMetrics") or {}
        affected_rows = {k: v for k, v in operation_metrics.items() if k in operation_type.value}
        for metric, value in affected_rows.items():
            log_message = f"""timestamp: {timestamp} |
            table_identifier: {table_identifier} |
            operation_type: {operation_type.name} |
            metric_name: {metric} |
            metric_value: {value} |
            user_name: {user_name} |
            job_id: {job_id} |
            job_run_id: {job_run_id} |
            run_id: {run_id} |
            notebook_id: {notebook_id} |
            cluster_id: {cluster_id}
            """
            self._table_operation_metrics_logger.info(log_message)

    @staticmethod
    def _merge_match_conditions(columns: list[str]) -> str:
        """Merges match conditions of the given columns into a single string.

        This function is used to generate an SQL query to match rows between two tables based on
        the specified columns. The null-safe equality operator ``<=>`` is used, so
        NULLs on both sides match.

        Args:
            columns: A list of strings representing the names of the columns to match.

        Returns:
            A string containing the match conditions, separated by " AND "

        Example:
            ```python
            _merge_match_conditions(["column1", "column2"]) # "target.`column1` <=> source.`column1` AND target.`column2` <=> source.`column2`"
            ```
        """
        return " AND ".join([f"target.`{c}` <=> source.`{c}`" for c in columns])

    @staticmethod
    def _partition_pruning_conditions(df, partition_cols: list[str] | None) -> str:
        """Generates partition pruning conditions for an SQL query.

        This function is used to optimize the performance of an SQL query by only scanning the
        necessary partitions in a table, based on the specified partition columns and the data
        in a Spark dataframe.

        Args:
            df: A Spark dataframe containing the data to generate the partition pruning
                conditions from.
            partition_cols: A list of strings representing the names of the
                partition columns.

        Returns:
            A string representing the partition pruning conditions, or an empty
            string if no partition columns are given.

        Example:
            ```python
            _partition_pruning_conditions(df, ["column1", "column2"])
            "((target.`column1` = 'value1' AND target.`column2` = 'value2') OR (target.`column1` = 'value3'
            AND target.`column2` = 'value4'))"
            ```
        """
        if not partition_cols:
            return ""
        # One "(c1 = 'v1' AND c2 = 'v2')" string per distinct partition value combination.
        # NOTE(review): values are interpolated verbatim, so a value containing a single
        # quote yields invalid SQL, and NULL renders as 'null' -- confirm partition values are safe.
        rows = (
            df.select(*partition_cols)
            .distinct()
            .select([format_string("target.`%s` = '%s'", lit(c), col(c)).alias(c) for c in partition_cols])
            .withColumn("result", concat(lit("("), concat_ws(" AND ", *partition_cols), lit(")")))
            .select("result")
            .collect()
        )
        pruning_conditions = "(" + " OR ".join(row["result"] for row in rows) + ")"

        return str(pruning_conditions)

    def _empty_dataframe_check(self, df: DataFrame, ignore_empty_df: bool) -> bool | None:
        """Checks if a DataFrame is empty and raises an exception if it is not expected to be empty.

        Args:
            df: The DataFrame to check for emptiness.
            ignore_empty_df: If True, the function returns True instead of raising
                when the DataFrame is empty. If False, an EmptyDataframeError is raised.

        Returns:
            True if the DataFrame is empty and ``ignore_empty_df`` is set, otherwise None.

        Raises:
            EmptyDataframeError: If the DataFrame is empty and ignore_empty_df is False.
        """
        if df.isEmpty():
            if ignore_empty_df:
                return True
            raise EmptyDataframeError(
                "EMPTY DATAFRAME, nothing to write. If this is expected, consider setting `ignore_empty_df` to True.",
            )
        return None
@@ -0,0 +1,4 @@
1
class EmptyDataframeError(Exception):
    """Raised when a DataFrame turns out to be empty although data was expected."""
@@ -0,0 +1,132 @@
1
+ from typing import Any
2
+
3
+ from pyspark.sql import DataFrame, DataFrameWriter
4
+ from pyspark.sql.streaming import DataStreamWriter
5
+
6
+ from .writer import BaseWriter
7
+
8
+
9
class FileWriter(BaseWriter):
    """Utility class for writing a DataFrame to a file location, either as a batch or as a stream."""

    def __init__(self):
        super().__init__()

    def _get_writer(self, df: DataFrame) -> DataFrameWriter:
        """Returns a DataFrameWriter."""
        return df.write

    def _get_stream_writer(self, df: DataFrame) -> DataStreamWriter:
        """Returns a DataStreamWriter."""
        return df.writeStream

    def _log_operation(self, location: str, status: str, error: str | None = None):
        """Logs the status of an operation to the console logger.

        Args:
            location: The target location of the write.
            status: One of "start", "succeeded" or "failed"; other values are not logged.
            error: The error message; only used when status is "failed".
        """
        if status == "start":
            self._console_logger.info(f"Starting to write to {location}")
        elif status == "succeeded":
            self._console_logger.info(f"Successfully wrote to {location}")
        elif status == "failed":
            self._console_logger.error(f"Failed to write to {location}: {error}")

    def _validate_trigger(self, trigger_dict: dict[str, Any]):
        """Validates the trigger configuration.

        Args:
            trigger_dict: Keyword arguments that will be passed to ``DataStreamWriter.trigger``.

        Raises:
            ValueError: If no supported trigger key is present, or if any key is
                not a supported trigger type.
        """
        triggers = ["processingTime", "once", "continuous", "availableNow"]
        if not any(trigger in trigger_dict for trigger in triggers):
            raise ValueError(f"Invalid trigger type. Supported types are: {', '.join(triggers)}")
        # Reject unknown keys up front; otherwise trigger(**trigger_dict) fails with an opaque TypeError.
        unsupported = [str(key) for key in trigger_dict if key not in triggers]
        if unsupported:
            raise ValueError(
                f"Invalid trigger type(s): {', '.join(unsupported)}. Supported types are: {', '.join(triggers)}"
            )

    def write_stream(
        self,
        data_frame: DataFrame | None = None,
        location: str | None = None,
        format: str = "delta",
        checkpoint_location: str | None = None,
        partition_cols: list[str] | None = None,
        mode: str = "append",
        trigger_dict: dict | None = None,
        options: dict[str, Any] | None = None,
        await_termination: bool = False,
        **_: Any,
    ):
        """Writes a dataframe to specified location in specified format as a stream.

        Args:
            data_frame: The DataFrame to write.
            location: The location to write the DataFrame to.
            format: The format to write the DataFrame in.
            checkpoint_location: Location of checkpoint. If None, defaults
                to the location of the table being written, with '_checkpoint_'
                added before the name.
            partition_cols: Columns to partition by.
            mode: The output mode of the streaming query.
            trigger_dict: A dictionary specifying the trigger configuration for the streaming query.
                Supported keys include:

                - "processingTime": Specifies a time interval (e.g., "10 seconds") for micro-batch processing.
                - "once": Processes all available data once and then stops.
                - "continuous": Specifies a time interval (e.g., "1 second") for continuous processing.
                - "availableNow": Processes all available data immediately and then stops.

                If nothing is provided, the default is {"availableNow": True}.
            options: Additional options for writing.
            await_termination: If True, the function will wait for the streaming
                query to finish before returning. This is useful for ensuring that
                the data has been fully written before proceeding with other
                operations.

        Raises:
            ValueError: If location or data_frame is missing, or the trigger configuration is invalid.
        """
        # A Spark DataFrame must be checked with `is None`, not by truthiness.
        if not location or data_frame is None:
            raise ValueError("Location and data_frame are required for streaming.")

        self._log_operation(location, "start")
        try:
            options = options or {}
            trigger_dict = trigger_dict or {"availableNow": True}
            checkpoint_location = self._get_checkpoint_location(location, checkpoint_location)
            self._validate_trigger(trigger_dict)
            stream_writer = self._get_stream_writer(data_frame)

            stream_writer.trigger(**trigger_dict)
            stream_writer.format(format)
            stream_writer.outputMode(mode)
            stream_writer.options(**options).option("checkpointLocation", checkpoint_location)
            if partition_cols:
                stream_writer.partitionBy(partition_cols)

            query = stream_writer.start(location)
            if await_termination is True:
                query.awaitTermination()
        except Exception as e:
            self._log_operation(location, "failed", str(e))
            raise
        else:
            # Without await_termination this is logged as soon as the query has been started.
            self._log_operation(location, "succeeded")

    def write(
        self,
        data_frame: DataFrame,
        location: str | None = None,
        format: str = "delta",
        partition_cols: list[str] | None = None,
        mode: str = "append",
        options: dict[str, Any] | None = None,
        **_: Any,
    ):
        """Writes a dataframe to specified location in specified format.

        Args:
            data_frame: The DataFrame to write.
            location: The location to write the DataFrame to.
            format: The format to write the DataFrame in.
            partition_cols: Columns to partition by.
            mode: The save mode.
            options: Additional options for writing.

        Raises:
            ValueError: If location is missing.
        """
        if not location:
            raise ValueError("Location is required for writing to file.")

        self._log_operation(location, "start")
        try:
            options = options or {}
            df_writer = self._get_writer(data_frame)
            df_writer.format(format)
            df_writer.mode(mode)
            if partition_cols:
                df_writer.partitionBy(partition_cols)
            df_writer.options(**options)
            df_writer.save(str(location))
        except Exception as e:
            self._log_operation(location, "failed", str(e))
            raise
        else:
            self._log_operation(location, "succeeded")
@@ -0,0 +1,54 @@
1
+ from abc import ABC, abstractmethod
2
+ from pathlib import Path
3
+ from typing import Any
4
+
5
+ from pyspark.sql import DataFrame
6
+
7
+ from ...logging import LoggerMixin
8
+
9
+
10
class BaseWriter(ABC, LoggerMixin):
    """Abstract base class for writers that persist DataFrames."""

    def __init__(self):
        self._console_logger = self.get_console_logger()

    @abstractmethod
    def write_stream(self, **kwargs: Any):
        """Writes a DataFrame stream."""
        pass

    @abstractmethod
    def write(
        self,
        data_frame: DataFrame,
        **kwargs: Any,
    ):
        """Writes a DataFrame."""
        pass

    def log_operation(self, operation: str, identifier: str | Path, status: str, error: str = ""):
        """Logs the metrics for one operation on the given identifier.

        Args:
            operation: Describes the type of operation, e.g. 'read_api'.
            identifier: An identifier for the object that's being interacted with.
            status: The status of the operation. Must be one of "start", "failed", "succeeded".
            error: The error message, if any. Defaults to ''.
        """
        self._console_logger.info(
            "operation:%s | identifier:%s | status:%s | error:%s",
            operation,
            identifier,
            status,
            error,
        )

    def _get_checkpoint_location(self, location: str, checkpoint_location: str | None) -> str:
        """Generates the checkpoint location if not provided.

        The default checkpoint is a sibling of ``location`` named
        ``_checkpoint_<name>``, e.g.
        ``abfss://c@acct.dfs.core.windows.net/path/table`` ->
        ``abfss://c@acct.dfs.core.windows.net/path/_checkpoint_table``.

        Args:
            location: The target location of the stream (a local path or a URI).
            checkpoint_location: An explicit checkpoint location; returned unchanged if given.

        Returns:
            The checkpoint location.
        """
        if checkpoint_location is not None:
            return checkpoint_location
        from pathlib import PurePosixPath

        # Split off any "scheme://" prefix so path handling cannot collapse the "//"
        # (previously only abfss URIs were repaired; s3://, gs://, wasbs://... broke).
        scheme, separator, path = location.partition("://")
        if not separator:
            scheme, path = "", location
        location_path = PurePosixPath(path)
        checkpoint_path = (location_path.parent / f"_checkpoint_{location_path.name}").as_posix()
        return f"{scheme}{separator}{checkpoint_path}"
@@ -79,7 +79,11 @@ class UnityCatalogAdapter(LoggerMixin):
79
79
 
80
80
  for schema in schemas_df:
81
81
  schemas.append(
82
- Schema(name=schema["schema_name"], catalog=catalog, comment=schema["comment"]),
82
+ Schema(
83
+ name=schema["schema_name"],
84
+ catalog=catalog,
85
+ comment=schema["comment"],
86
+ ),
83
87
  )
84
88
  return schemas
85
89
 
@@ -49,7 +49,7 @@ class Schema(ReadInstancesMixin):
49
49
  instance_path=processed_instance_path.parents[0] / table_dir_name,
50
50
  catalog_name=schema.catalog,
51
51
  schema_name=schema.name,
52
- schema_storage_path=Path(schema.storage_path),
52
+ schema_storage_path=schema.storage_path,
53
53
  fail_on_missing_subfolder=fail_on_missing_subfolder,
54
54
  )
55
55
  schema.tables = tables
@@ -4,7 +4,13 @@ from typing import Any, Self
4
4
  import yaml
5
5
  import yaml.scanner
6
6
  from jinja2 import TemplateNotFound
7
- from pydantic import Field, ValidationError, ValidationInfo, field_validator, model_validator
7
+ from pydantic import (
8
+ Field,
9
+ ValidationError,
10
+ ValidationInfo,
11
+ field_validator,
12
+ model_validator,
13
+ )
8
14
 
9
15
  from ..logging import LoggerMixin
10
16
  from ..utils.file_and_directory_handler import process_path
@@ -28,7 +34,7 @@ class Table(TemplateLoaderMixin, ReadInstancesMixin, LoggerMixin):
28
34
  properties: dict[str, str] = Field(default_factory=dict)
29
35
  constraints: list[Constraint] = Field(default_factory=list)
30
36
  foreign_keys: list[ForeignKey] = Field(default_factory=list)
31
- storage_path: Path | None = None
37
+ storage_path: str | None = None
32
38
  business_properties: dict[str, str] = Field(default_factory=dict)
33
39
  comment: str | None = None
34
40
  data_source_format: str | None = None
@@ -93,6 +99,7 @@ class Table(TemplateLoaderMixin, ReadInstancesMixin, LoggerMixin):
93
99
  """If is_external is set to True, storage_path has to be set."""
94
100
  if table.is_external and table.storage_path is None:
95
101
  raise ValueError("is_external cannot be true while storage_path is None.")
102
+ return table
96
103
 
97
104
  @classmethod
98
105
  def read_instances_from_directory(
@@ -154,7 +161,7 @@ class Table(TemplateLoaderMixin, ReadInstancesMixin, LoggerMixin):
154
161
  sub_errors: list[ValidationErrorType] = []
155
162
  if instance_file.is_file() and instance_file.suffix in (".yaml", ".yml"):
156
163
  instance, sub_errors = cls.read_instance_from_file(
157
- instance_file, catalog_name, schema_name, schema_storage_path
164
+ instance_file, catalog_name, schema_name, str(schema_storage_path)
158
165
  )
159
166
  instances += [] if instance is None else [instance]
160
167
  errors += sub_errors
@@ -206,9 +213,9 @@ class Table(TemplateLoaderMixin, ReadInstancesMixin, LoggerMixin):
206
213
  data["identifier"] = f"{catalog_name}.{schema_name}.{data['name']}"
207
214
  if data.get("is_external"):
208
215
  if storage_path := data.get("storage_path"):
209
- data["storage_path"] = Path(storage_path)
216
+ data["storage_path"] = storage_path
210
217
  elif schema_storage_path:
211
- data["storage_path"] = schema_storage_path / data["name"]
218
+ data["storage_path"] = (schema_storage_path / data["name"]).as_posix()
212
219
  else:
213
220
  raise ValueError(
214
221
  f"Neither storage path nor schema storage path of table {data['name']} has been provided."
@@ -216,7 +223,11 @@ class Table(TemplateLoaderMixin, ReadInstancesMixin, LoggerMixin):
216
223
 
217
224
  instance, sub_errors = cls.metadata_to_instance(data)
218
225
  errors += sub_errors
219
- except (ValidationError, yaml.parser.ParserError, yaml.scanner.ScannerError) as e:
226
+ except (
227
+ ValidationError,
228
+ yaml.parser.ParserError,
229
+ yaml.scanner.ScannerError,
230
+ ) as e:
220
231
  instance = None
221
232
  errors.append(e)
222
233
  return instance, errors
@@ -48,9 +48,8 @@ def table_log_decorator(operation: str):
48
48
  def inner_decorator(func):
49
49
  @functools.wraps(func)
50
50
  def wrapper(self, *args, **kwargs):
51
- table_identifier = kwargs.get("table_identifier") or kwargs.get("table").identifier or args[0]
52
- if not isinstance(table_identifier, str):
53
- # assume its a Table object
51
+ table_identifier = kwargs.get("table_identifier") or kwargs.get("table") or args[0]
52
+ if isinstance(table_identifier, Table):
54
53
  table_identifier = table_identifier.identifier
55
54
  self._tabular_logger.info(
56
55
  "operation:%s | identifier:%s | status:start | error:''",
@@ -84,7 +83,6 @@ class TableManager(LoggerMixin):
84
83
 
85
84
  def __init__(self, tabular_logger: logging.Logger | None = None):
86
85
  self._spark = SessionManager.get_spark_session()
87
- self._utils = SessionManager.get_utils()
88
86
  self._console_logger = self.get_console_logger()
89
87
  self._console_logger.debug("TableManager initialized...")
90
88
  self._tabular_logger = tabular_logger or self.get_tabular_logger(**TableManagerLogs().__dict__)
@@ -115,51 +113,83 @@ class TableManager(LoggerMixin):
115
113
  if statement and statement != "\n":
116
114
  self._spark.sql(statement)
117
115
 
118
def drop_table(
    self,
    table: Table | None = None,
    storage_location: str | None = None,
    table_identifier: str | None = None,
    delete_physical_data: bool = False,
):
    """Deletes a Table. For security reasons you are forced to pass the table_name.

    If delete_physical_data is True the actual physical data on the ADLS will be deleted.
    Use with caution!

    Args:
        table: The Table object representing the Delta table. For backward
            compatibility with the former ``drop_table(table_identifier, ...)``
            signature, a string passed here is treated as ``table_identifier``.
        storage_location: The location of the Delta table on the ADLS. Only used
            together with ``table_identifier`` when ``delete_physical_data`` is True.
        table_identifier: The table identifier in the catalog. Must be in the format 'catalog.schema.table'.
        delete_physical_data: If set to True, deletes not only the metadata
            within the Catalog but also the physical data.

    Raises:
        ValueError: If neither table nor table_identifier is provided, or if both are provided.
        ValueError: If physical data should be deleted but no storage location is
            known (the table object has no storage path, or storage_location is missing).
    """
    # Legacy positional calls (old signature) pass the identifier string first.
    if isinstance(table, str) and table_identifier is None:
        table, table_identifier = None, table

    if table is not None and (table_identifier is not None or storage_location is not None):
        raise ValueError("Either table or table_identifier and storage_location must be provided, but not both.")
    if table is None and table_identifier is None:
        raise ValueError("Either table or table_identifier must be provided.")

    if table is not None:
        table_identifier = table.identifier
        if delete_physical_data:
            # Avoid str(None) -> "None" being handed to fs.rm.
            if table.storage_path is None:
                raise ValueError("Table storage path must be provided.")
            storage_location = str(table.storage_path)

    # Validate before touching anything, so no catalog/data change happens on bad input.
    if delete_physical_data and storage_location is None:
        raise ValueError("storage_location must be provided to delete physical data.")

    # Logged after resolution so a Table argument shows its real identifier.
    self._console_logger.info(f"Deleting table [ '{table_identifier}' ] ...")
    if delete_physical_data:
        self._delete_physical_data(location=storage_location)
    self.drop_table_from_catalog(table_identifier=table_identifier)
136
148
 
137
def drop_table_from_catalog(self, table_identifier: str | None = None, table: Table | None = None) -> None:
    """Drops a table from the catalog while keeping its physical data.

    Args:
        table_identifier: The table identifier in the catalog. Must be in the format 'catalog.schema.table'.
        table: The Table object representing the Delta table.

    Raises:
        ValueError: If neither table nor table_identifier is provided, or if both are provided.
    """
    # Exactly one of the two must be given: both-None and both-set are rejected alike.
    if (table is None) == (table_identifier is None):
        raise ValueError("Either table or table_identifier must be provided, but not both.")
    identifier = table_identifier if table is None else table.identifier

    self._console_logger.info(f"... deleting table [ '{identifier}' ] from Catalog.")
    self._spark.sql(f"DROP TABLE IF EXISTS {identifier};")
147
165
 
148
def _delete_physical_data(self, table: Table | None = None, location: str | None = None):
    """Recursively removes the physical data stored at a table's location.

    Args:
        table: The Table object whose ``storage_path`` should be deleted.
        location: An explicit location to delete.

    Raises:
        ValueError: If neither table nor location is provided, or if both are provided.
        ValueError: If the table object has no storage path.
    """
    # Exactly one source for the location is accepted.
    if (table is None) == (location is None):
        raise ValueError("Either table or location must be provided, but not both.")

    target = location
    if table is not None:
        storage_path = table.storage_path
        if storage_path is None:
            raise ValueError("Table storage path must be provided.")
        target = str(storage_path)

    utils = SessionManager.get_utils()
    utils.fs.rm(target, recurse=True)
    self._console_logger.info("... deleting physical data.")
156
185
 
157
- def get_delta_table(self, table: Table | None = None, location: str | None = None) -> DeltaTable:
186
+ def get_delta_table(self, table: Table | None = None, location: str | None = None, spark=None) -> DeltaTable:
158
187
  """Get the DeltaTable object from the Table objects location or a location string.
159
188
 
160
189
  Args:
161
190
  table: A Table object representing the Delta table.
162
191
  location: A string representing the table location.
192
+ spark: An optional Spark session. If not provided, the current Spark session will be used.
163
193
 
164
194
  Returns:
165
195
  The DeltaTable object corresponding to the given Table object or location string.
@@ -173,7 +203,7 @@ class TableManager(LoggerMixin):
173
203
  if table is not None:
174
204
  location = str(table.storage_path)
175
205
  self._console_logger.info(f"Getting DeltaTable object for location: {location}")
176
- return DeltaTable.forPath(self._spark, str(location))
206
+ return DeltaTable.forPath(spark or self._spark, str(location))
177
207
 
178
208
  def table_exists(self, table: Table | None = None, table_identifier: str | None = None) -> bool:
179
209
  """Checks if a table exists in the catalog.
@@ -232,3 +262,27 @@ class TableManager(LoggerMixin):
232
262
 
233
263
  self._console_logger.info(f"Refreshing table: {table_identifier}")
234
264
  self._spark.sql(f"REFRESH TABLE {table_identifier};")
265
+
266
@table_log_decorator(operation="truncate")
def truncate_table(
    self,
    table: Table | None = None,
    table_identifier: str | None = None,
):
    """Removes all rows from a table via ``TRUNCATE TABLE``.

    Args:
        table: A Table object representing the Delta table; its escaped identifier is used.
        table_identifier: The identifier of the Delta table in the format 'catalog.schema.table'.

    Raises:
        ValueError: If neither table nor table_identifier is provided, or if both are provided.
    """
    if (table is None) == (table_identifier is None):
        raise ValueError("Either table or table_identifier must be provided, but not both.")

    target = table_identifier if table is None else table.escaped_identifier
    self._console_logger.info(f"Truncating table: {target}")
    self._spark.sql(f"TRUNCATE TABLE {target};")
@@ -22,6 +22,9 @@ from .transform_replace_values import TransformReplaceValuesAction
22
22
  from .transform_select_columns import TransformSelectColumnsAction
23
23
  from .transform_union import TransformUnionAction
24
24
  from .write_catalog_table import WriteCatalogTableAction
25
+ from .write_delta_append import WriteDeltaAppendAction
26
+ from .write_delta_merge import WriteDeltaMergeAction
27
+ from .write_file import WriteFileAction
25
28
 
26
29
  # Get all subclasses of PipelineAction defined in this submodule
27
30
  pipeline_actions = {cls.name: cls for cls in PipelineAction.__subclasses__()}
@@ -36,7 +39,6 @@ __all__ = [
36
39
  "ReadExcelAction",
37
40
  "ReadFilesAction",
38
41
  "ReadMetadataYAMLAction",
39
- "WriteCatalogTableAction",
40
42
  "PipelineActionType",
41
43
  "TransformFilterAction",
42
44
  "TransformUnionAction",
@@ -52,5 +54,9 @@ __all__ = [
52
54
  "TransformRenameColumnsAction",
53
55
  "TransformReplaceValuesAction",
54
56
  "TransformSelectColumnsAction",
57
+ "WriteCatalogTableAction",
58
+ "WriteDeltaAppendAction",
59
+ "WriteDeltaMergeAction",
60
+ "WriteFileAction",
55
61
  "TransformHashColumnsAction",
56
62
  ]