soda-sparkdf 4.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: soda-sparkdf
3
+ Version: 4.0.5
4
+ Requires-Dist: soda-core==4.0.5
5
+ Requires-Dist: freezegun
6
+ Requires-Dist: pyspark>=3.5.0
7
+ Requires-Dist: soda-databricks==4.0.5
8
+ Dynamic: requires-dist
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,22 @@
1
#!/usr/bin/env python
"""Packaging script for the soda-sparkdf plugin (Spark DataFrame data source for Soda Core)."""

from setuptools import setup

package_name = "soda-sparkdf"
package_version = "4.0.5"
description = "Soda SparkDF V4"

# The Soda packages are pinned to this plugin's own version so the releases stay in lockstep.
requires = [
    f"soda-core=={package_version}",
    "freezegun",
    "pyspark>=3.5.0",
    f"soda-databricks=={package_version}",
]
# Note: some java runtime (>17) is required to run the tests.

setup(
    name=package_name,
    version=package_version,
    # Pass the description through; otherwise it is unused and PKG-INFO carries no Summary.
    description=description,
    install_requires=requires,
    package_dir={"": "src"},
    entry_points={
        "soda.plugins.data_source.sparkdf": [
            "SparkDataFrameDataSourceImpl = soda_sparkdf.common.data_sources.sparkdf_data_source:SparkDataFrameDataSourceImpl",
        ],
    },
)
@@ -0,0 +1,7 @@
1
+ from soda_sparkdf.common.data_sources.sparkdf_data_source import (
2
+ SparkDataFrameDataSourceImpl as SparkDataFrameDataSource,
3
+ )
4
+
5
# Public API of the package: the Spark DataFrame data source implementation,
# re-exported under its shorter alias by the import above.
__all__ = ["SparkDataFrameDataSource"]
@@ -0,0 +1,243 @@
1
+ from datetime import datetime, timezone
2
+ from typing import Optional
3
+
4
+ from freezegun import freeze_time
5
+ from pyspark.sql import DataFrame, SparkSession
6
+ from pyspark.sql.types import Row
7
+ from soda_core.common.data_source_connection import DataSourceConnection
8
+ from soda_core.common.data_source_impl import DataSourceImpl, MetadataTablesQuery
9
+ from soda_core.common.data_source_results import QueryResult
10
+ from soda_core.common.metadata_types import ColumnMetadata, SodaDataTypeName
11
+ from soda_core.common.sql_dialect import SqlDialect
12
+ from soda_databricks.common.data_sources.databricks_data_source import (
13
+ DatabricksSqlDialect,
14
+ )
15
+ from soda_databricks.common.statements.hive_metadata_tables_query import (
16
+ HiveMetadataTablesQuery,
17
+ )
18
+ from soda_sparkdf.common.data_sources.sparkdf_data_source_connection import (
19
+ SparkDataFrameConnectionProperties,
20
+ )
21
+ from soda_sparkdf.common.data_sources.sparkdf_data_source_connection import (
22
+ SparkDataFrameDataSource as SparkDataFrameDataSourceModel,
23
+ )
24
+ from soda_sparkdf.common.data_sources.sparkdf_data_source_connection import (
25
+ SparkDataFrameExistingSessionProperties,
26
+ SparkDataFrameNewSessionProperties,
27
+ )
28
+
29
+ _in_memory_connection = None
30
+
31
+
32
+ class SparkDataFrameCursor:
33
+ # Copy from v3 of the cursor implementation
34
+ def __init__(self, spark_session: SparkSession, test_dir: Optional[str] = None):
35
+ self.spark_session = spark_session
36
+ self.df: DataFrame | None = None
37
+ self.description: tuple[tuple] | None = None
38
+ self.rowcount: int = -1
39
+ self.cursor_index: int = -1
40
+
41
+ def execute(self, sql: str):
42
+ self.df = self.spark_session.sql(sqlQuery=sql)
43
+ self.description = self.convert_spark_df_schema_to_dbapi_description(self.df)
44
+ self.cursor_index = 0
45
+
46
+ def fetchall(self) -> tuple[tuple]:
47
+ rows = []
48
+ with freeze_time(
49
+ datetime.now(timezone.utc)
50
+ ): # We need to freeze the time to UTC at the time of collecting to avoid issues with timestamps
51
+ # Spark stores the timestamps in UTC (verified by reading the parquet file that spark creates), but when querying it converts it to the (python) session-local timezone.
52
+ # By using freeze_time, we can ensure that the timestamps are collected in UTC, regardless of the session-local timezone. This plays nice with the freshness checks.
53
+ spark_rows: list[Row] = self.df.collect()
54
+ # Alternative approach: convert to PyArrow. This will set the timestamps correctly, but introduces memory and time overhead (for the conversion).
55
+ # This also requires more changes regarding the cursor implementation: spark_rows: list[Row] = self.df.toArrow().to_pylist()
56
+ self.rowcount = len(spark_rows)
57
+ for spark_row in spark_rows:
58
+ row = self.convert_spark_row_to_dbapi_row(spark_row)
59
+ rows.append(row)
60
+ return tuple(rows)
61
+
62
+ def fetchmany(self, size: int) -> tuple[tuple]:
63
+ rows = []
64
+ self.rowcount = self.df.count()
65
+ with freeze_time(
66
+ datetime.now(timezone.utc)
67
+ ): # We need to freeze the time to UTC at the time of collecting to avoid issues with timestamps. See the comment in fetchall() for more details.
68
+ spark_rows: list[Row] = self.df.offset(self.cursor_index).limit(size).collect()
69
+ self.cursor_index += len(spark_rows)
70
+ for spark_row in spark_rows:
71
+ row = self.convert_spark_row_to_dbapi_row(spark_row)
72
+ rows.append(row)
73
+ return tuple(rows)
74
+
75
+ def fetchone(self) -> tuple:
76
+ with freeze_time(
77
+ datetime.now(timezone.utc)
78
+ ): # We need to freeze the time to UTC at the time of collecting to avoid issues with timestamps. See the comment in fetchall() for more details.
79
+ spark_rows: list[Row] = self.df.collect()
80
+ self.rowcount = len(spark_rows)
81
+ spark_row = spark_rows[0]
82
+ row = self.convert_spark_row_to_dbapi_row(spark_row)
83
+ return tuple(row)
84
+
85
+ @staticmethod
86
+ def convert_spark_row_to_dbapi_row(spark_row):
87
+ return [spark_row[field] for field in spark_row.__fields__]
88
+
89
+ def close(self):
90
+ pass # No-op
91
+
92
+ @staticmethod
93
+ def convert_spark_df_schema_to_dbapi_description(df) -> tuple[tuple]:
94
+ return tuple((field.name, type(field.dataType).__name__) for field in df.schema.fields)
95
+
96
+
97
+ class SparkDataFrameDataSourceConnectionWrapper:
98
+ def __init__(self, session: SparkSession):
99
+ self._session = session
100
+
101
+ def __getattr__(self, attr):
102
+ if attr in self.__dict__:
103
+ return getattr(self, attr)
104
+ return getattr(self._session, attr)
105
+
106
+ def commit(self):
107
+ pass # Do nothing, Spark does not have a commit concept
108
+
109
+ def cursor(self):
110
+ return SparkDataFrameCursor(self._session)
111
+
112
+
113
+ class SparkDataFrameSqlDialect(DatabricksSqlDialect):
114
+ SODA_DATA_TYPE_SYNONYMS = (
115
+ (SodaDataTypeName.TEXT, SodaDataTypeName.VARCHAR, SodaDataTypeName.CHAR),
116
+ (SodaDataTypeName.NUMERIC, SodaDataTypeName.DECIMAL),
117
+ (SodaDataTypeName.TIMESTAMP_TZ, SodaDataTypeName.TIMESTAMP),
118
+ )
119
+
120
+ def get_database_prefix_index(self) -> int | None:
121
+ return None
122
+
123
+ def get_schema_prefix_index(self) -> int | None:
124
+ return 0
125
+
126
+ def create_schema_if_not_exists_sql(self, prefixes: list[str], add_semicolon: bool = True) -> str:
127
+ schema_name: str = prefixes[0]
128
+ quoted_schema_name: str = self.quote_default(schema_name)
129
+ return f"CREATE SCHEMA IF NOT EXISTS {quoted_schema_name}" + (";" if add_semicolon else "")
130
+
131
+ def post_schema_create_sql(self, prefixes: list[str]) -> Optional[list[str]]:
132
+ pass # Do nothing, Spark does not have a post-schema-create concept
133
+
134
+ def build_column_metadatas_from_query_result(self, query_result: QueryResult) -> list[ColumnMetadata]:
135
+ # Filter out dataset description rows (first such line starts with #, ignore the rest) or empty
136
+ filtered_rows = []
137
+ for row in query_result.rows:
138
+ if row[0].startswith("#"): # ignore all description rows
139
+ break
140
+ if not row[0] and not row[1]: # empty row
141
+ continue
142
+
143
+ # Trim data type details, e.g. decimal(10,0) -> decimal. Only decimal supports it anyway.
144
+ data_type = row[1]
145
+ if "(" in data_type:
146
+ data_type = data_type[: data_type.index("(")].strip()
147
+ row = (row[0], data_type)
148
+ filtered_rows.append(row)
149
+
150
+ return super().build_column_metadatas_from_query_result(
151
+ QueryResult(rows=filtered_rows, columns=query_result.columns)
152
+ )
153
+
154
+ def literal_datetime_with_tz(self, datetime: datetime):
155
+ # Always convert the timestamp to utc when we insert. Spark is not aware of the timezones, so we need to do this conversion so it's ready to be extracted as UTC.
156
+ return f"to_utc_timestamp('{datetime.isoformat()}', 'UTC')"
157
+
158
+ def literal_datetime(self, datetime: datetime):
159
+ # Always convert the timestamp to utc when we insert. Spark is not aware of the timezones, so we need to do this conversion so it's ready to be extracted as UTC.
160
+ return f"to_utc_timestamp('{datetime.isoformat()}', 'UTC')"
161
+
162
+
163
+ class SparkDataFrameDataSourceConnection(DataSourceConnection):
164
+ def __init__(self, name: str, connection_properties: dict, connection: Optional[object] = None):
165
+ super().__init__(name, connection_properties, connection)
166
+
167
+ def _create_connection(
168
+ self,
169
+ config: SparkDataFrameConnectionProperties,
170
+ ):
171
+ session = None
172
+ if isinstance(config, SparkDataFrameExistingSessionProperties):
173
+ session = config.spark_session
174
+ elif isinstance(config, SparkDataFrameNewSessionProperties):
175
+ session = (
176
+ SparkSession.builder.master("local")
177
+ .appName(self.name)
178
+ .config("spark.sql.warehouse.dir", config.test_dir)
179
+ .getOrCreate()
180
+ )
181
+ session.sql("SET spark.sql.session.timeZone = +00:00;")
182
+ session.sql("SET TIME ZONE 'UTC';")
183
+ if session is None:
184
+ raise ValueError("No session provided")
185
+ self.session = session
186
+ return SparkDataFrameDataSourceConnectionWrapper(session=session)
187
+
188
+ def close_connection(self) -> None:
189
+ "This is a no-op for SparkDataFrameDataSourceConnection, there is no connection to close."
190
+
191
+ def _execute_query_get_result_row_column_name(self, column) -> str:
192
+ return column[0] # The first element of the tuple is the column name
193
+
194
+
195
+ class SparkDataFrameDataSourceImpl(DataSourceImpl, model_class=SparkDataFrameDataSourceModel):
196
+ def _create_sql_dialect(self) -> SqlDialect:
197
+ return SparkDataFrameSqlDialect(data_source_impl=self)
198
+
199
+ def _create_data_source_connection(self) -> DataSourceConnection:
200
+ return SparkDataFrameDataSourceConnection(
201
+ name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
202
+ )
203
+
204
+ def create_metadata_tables_query(self) -> MetadataTablesQuery:
205
+ return HiveMetadataTablesQuery(sql_dialect=self.sql_dialect, data_source_connection=self.data_source_connection)
206
+
207
+ @classmethod
208
+ def from_existing_session(cls, session: SparkSession, name: str) -> DataSourceImpl:
209
+ connection_properties = {"spark_session": session, "schema_": name}
210
+ ds_model = SparkDataFrameDataSourceModel(
211
+ name=name,
212
+ connection_properties=connection_properties,
213
+ )
214
+ soda_connection = SparkDataFrameDataSourceConnection(
215
+ name=name,
216
+ connection_properties=connection_properties,
217
+ connection=SparkDataFrameDataSourceConnectionWrapper(session),
218
+ )
219
+ return cls(data_source_model=ds_model, connection=soda_connection)
220
+
221
+ def build_columns_metadata_query_str(self, dataset_prefixes: list[str], dataset_name: str) -> str:
222
+ if len(dataset_prefixes) == 0:
223
+ return f"DESCRIBE {dataset_name}"
224
+ elif len(dataset_prefixes) == 1:
225
+ schema_name: str = dataset_prefixes[0]
226
+ return f"DESCRIBE {schema_name}.{dataset_name}"
227
+ elif len(dataset_prefixes) == 2:
228
+ database_name: str = dataset_prefixes[0]
229
+ schema_name: str = dataset_prefixes[1]
230
+ return f"DESCRIBE {database_name}.{schema_name}.{dataset_name}"
231
+ else:
232
+ raise ValueError(f"Invalid number of dataset prefixes: {len(dataset_prefixes)}")
233
+
234
+ def test_schema_exists(self, prefixes: list[str]) -> bool:
235
+ result = self.connection.session.sql(f"SHOW SCHEMAS LIKE '{prefixes[0]}'").collect()
236
+ for row in result:
237
+ if row[0] and row[0].lower() == prefixes[0].lower():
238
+ return True
239
+ return False
240
+
241
+
242
+ # Alias to make the import and usage cleaner
243
+ SparkDataFrameDataSource = SparkDataFrameDataSourceImpl
@@ -0,0 +1,44 @@
1
+ from abc import ABC
2
+ from typing import Any, Literal, Optional, Union
3
+
4
+ from pydantic import Field, field_validator
5
+ from soda_core.model.data_source.data_source import DataSourceBase
6
+ from soda_core.model.data_source.data_source_connection_properties import (
7
+ DataSourceConnectionProperties,
8
+ )
9
+
10
+
11
+ class SparkDataFrameConnectionProperties(DataSourceConnectionProperties, ABC):
12
+ schema_: Optional[str] = Field(
13
+ "main", description="Optional schema name to use for the SparkDataFrame connection", alias="schema"
14
+ )
15
+ test_dir: Optional[str] = Field(None, description="The directory to use for the test")
16
+
17
+
18
+ class SparkDataFrameNewSessionProperties(SparkDataFrameConnectionProperties):
19
+ new_session: bool = Field(True, description="Whether to create a new Spark session")
20
+
21
+
22
+ class SparkDataFrameExistingSessionProperties(SparkDataFrameConnectionProperties, arbitrary_types_allowed=True):
23
+ # We set the type to Any to avoid type errors when the SparkSession is not a SparkSession object
24
+ # This could be the case on Databricks serverless, where the SparkSession is imported as a different object
25
+ spark_session: Any = Field(..., description="The existing Spark session to use")
26
+
27
+
28
+ class SparkDataFrameDataSource(DataSourceBase, ABC):
29
+ type: Literal["sparkdf"] = Field("sparkdf")
30
+ connection_properties: Union[SparkDataFrameExistingSessionProperties, SparkDataFrameNewSessionProperties] = Field(
31
+ ..., alias="connection", description="SparkDataFrame connection configuration"
32
+ )
33
+
34
+ @field_validator("connection_properties", mode="before")
35
+ @classmethod
36
+ def infer_connection_type(cls, value):
37
+ if isinstance(value, SparkDataFrameNewSessionProperties):
38
+ return value
39
+
40
+ if "spark_session" in value:
41
+ return SparkDataFrameExistingSessionProperties(**value)
42
+ elif "new_session" in value:
43
+ return SparkDataFrameNewSessionProperties(**value)
44
+ raise ValueError("Could not infer SparkDataFrame connection type from input")
@@ -0,0 +1,38 @@
1
+ from __future__ import annotations
2
+
3
+ import tempfile
4
+ from typing import Optional
5
+
6
+ from helpers.data_source_test_helper import DataSourceTestHelper
7
+
8
+
9
+ class SparkDataFrameDataSourceTestHelper(DataSourceTestHelper):
10
+ def _create_data_source_yaml_str(self) -> str:
11
+ """
12
+ Called in _create_data_source_impl to initialized self.data_source_impl
13
+ self.database_name and self.schema_name are available if appropriate for the data source type
14
+ """
15
+ self.test_dir = tempfile.mkdtemp(prefix=f"soda_test_sparkdf_{self.name}_")
16
+ return f"""
17
+ type: sparkdf
18
+ name: {self.name}
19
+ connection:
20
+ new_session: true
21
+ test_dir: {self.test_dir}
22
+ """
23
+
24
+ # We need these methods to comply with the rest of the test helper infrastructure
25
+ def _create_database_name(self) -> Optional[str]:
26
+ return None
27
+
28
+ def _create_schema_name(self) -> Optional[str]:
29
+ return "main"
30
+
31
+ def _create_dataset_prefix(self) -> list[str]:
32
+ schema_name: str = self._create_schema_name()
33
+ return [schema_name]
34
+
35
+ def drop_test_schema_if_exists(self) -> None:
36
+ """
37
+ In-memory SparkDF does not support schemas, so this is a no-op.
38
+ """
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: soda-sparkdf
3
+ Version: 4.0.5
4
+ Requires-Dist: soda-core==4.0.5
5
+ Requires-Dist: freezegun
6
+ Requires-Dist: pyspark>=3.5.0
7
+ Requires-Dist: soda-databricks==4.0.5
8
+ Dynamic: requires-dist
@@ -0,0 +1,11 @@
1
+ setup.py
2
+ src/soda_sparkdf/__init__.py
3
+ src/soda_sparkdf.egg-info/PKG-INFO
4
+ src/soda_sparkdf.egg-info/SOURCES.txt
5
+ src/soda_sparkdf.egg-info/dependency_links.txt
6
+ src/soda_sparkdf.egg-info/entry_points.txt
7
+ src/soda_sparkdf.egg-info/requires.txt
8
+ src/soda_sparkdf.egg-info/top_level.txt
9
+ src/soda_sparkdf/common/data_sources/sparkdf_data_source.py
10
+ src/soda_sparkdf/common/data_sources/sparkdf_data_source_connection.py
11
+ src/soda_sparkdf/test_helpers/sparkdf_data_source_test_helper.py
@@ -0,0 +1,2 @@
1
+ [soda.plugins.data_source.sparkdf]
2
+ SparkDataFrameDataSourceImpl = soda_sparkdf.common.data_sources.sparkdf_data_source:SparkDataFrameDataSourceImpl
@@ -0,0 +1,4 @@
1
+ soda-core==4.0.5
2
+ freezegun
3
+ pyspark>=3.5.0
4
+ soda-databricks==4.0.5
@@ -0,0 +1 @@
1
+ soda_sparkdf