sqlframe 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sqlframe/_version.py +2 -2
- sqlframe/base/column.py +41 -0
- sqlframe/base/dataframe.py +77 -3
- sqlframe/base/exceptions.py +12 -0
- sqlframe/base/function_alternatives.py +5 -7
- sqlframe/base/functions.py +4 -2
- sqlframe/base/mixins/dataframe_mixins.py +24 -33
- sqlframe/base/types.py +12 -2
- sqlframe/base/util.py +51 -0
- sqlframe/bigquery/dataframe.py +33 -13
- sqlframe/bigquery/functions.py +1 -0
- sqlframe/duckdb/dataframe.py +6 -15
- sqlframe/postgres/catalog.py +123 -3
- sqlframe/postgres/dataframe.py +6 -10
- sqlframe/redshift/dataframe.py +3 -14
- sqlframe/snowflake/dataframe.py +23 -13
- sqlframe/spark/dataframe.py +25 -15
- sqlframe/testing/__init__.py +3 -0
- sqlframe/testing/utils.py +320 -0
- {sqlframe-1.10.0.dist-info → sqlframe-1.12.0.dist-info}/METADATA +1 -1
- {sqlframe-1.10.0.dist-info → sqlframe-1.12.0.dist-info}/RECORD +24 -22
- {sqlframe-1.10.0.dist-info → sqlframe-1.12.0.dist-info}/LICENSE +0 -0
- {sqlframe-1.10.0.dist-info → sqlframe-1.12.0.dist-info}/WHEEL +0 -0
- {sqlframe-1.10.0.dist-info → sqlframe-1.12.0.dist-info}/top_level.txt +0 -0
sqlframe/postgres/catalog.py
CHANGED
@@ -7,16 +7,17 @@ import typing as t
 
 from sqlglot import exp, parse_one
 
-from sqlframe.base.catalog import Function, _BaseCatalog
+from sqlframe.base.catalog import Column, Function, _BaseCatalog
+from sqlframe.base.decorators import normalize
 from sqlframe.base.mixins.catalog_mixins import (
     GetCurrentCatalogFromFunctionMixin,
     GetCurrentDatabaseFromFunctionMixin,
     ListCatalogsFromInfoSchemaMixin,
-    ListColumnsFromInfoSchemaMixin,
     ListDatabasesFromInfoSchemaMixin,
     ListTablesFromInfoSchemaMixin,
     SetCurrentDatabaseFromSearchPathMixin,
 )
+from sqlframe.base.util import to_schema
 
 if t.TYPE_CHECKING:
     from sqlframe.postgres.session import PostgresSession  # noqa
@@ -30,12 +31,131 @@ class PostgresCatalog(
     ListCatalogsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
     SetCurrentDatabaseFromSearchPathMixin["PostgresSession", "PostgresDataFrame"],
     ListTablesFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
-    ListColumnsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
     _BaseCatalog["PostgresSession", "PostgresDataFrame"],
 ):
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
     TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")
 
+    @normalize(["tableName", "dbName"])
+    def listColumns(
+        self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
+    ) -> t.List[Column]:
+        """Returns a t.List of columns for the given table/view in the specified database.
+
+        .. versionadded:: 2.0.0
+
+        Parameters
+        ----------
+        tableName : str
+            name of the table to t.List columns.
+
+            .. versionchanged:: 3.4.0
+                Allow ``tableName`` to be qualified with catalog name when ``dbName`` is None.
+
+        dbName : str, t.Optional
+            name of the database to find the table to t.List columns.
+
+        Returns
+        -------
+        t.List
+            A t.List of :class:`Column`.
+
+        Notes
+        -----
+        The order of arguments here is different from that of its JVM counterpart
+        because Python does not support method overloading.
+
+        If no database is specified, the current database and catalog
+        are used. This API includes all temporary views.
+
+        Examples
+        --------
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
+        >>> spark.catalog.t.listColumns("tblA")
+        [Column(name='name', description=None, dataType='string', nullable=True, ...
+        >>> _ = spark.sql("DROP TABLE tblA")
+        """
+        if df := self.session.temp_views.get(tableName):
+            return [
+                Column(
+                    name=x,
+                    description=None,
+                    dataType="",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False,
+                )
+                for x in df.columns
+            ]
+
+        table = exp.to_table(tableName, dialect=self.session.input_dialect)
+        schema = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
+        if not table.db:
+            if schema and schema.db:
+                table.set("db", schema.args["db"])
+            else:
+                table.set(
+                    "db",
+                    exp.parse_identifier(
+                        self.currentDatabase(), dialect=self.session.input_dialect
+                    ),
+                )
+        if not table.catalog:
+            if schema and schema.catalog:
+                table.set("catalog", schema.args["catalog"])
+            else:
+                table.set(
+                    "catalog",
+                    exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
+                )
+        source_table = self._get_info_schema_table("columns", database=table.db)
+        select = parse_one(
+            f"""
+            SELECT
+                att.attname AS column_name,
+                pg_catalog.format_type(att.atttypid, NULL) AS data_type,
+                col.is_nullable
+            FROM
+                pg_catalog.pg_attribute att
+            JOIN
+                pg_catalog.pg_class cls ON cls.oid = att.attrelid
+            JOIN
+                pg_catalog.pg_namespace nsp ON nsp.oid = cls.relnamespace
+            JOIN
+                information_schema.columns col ON col.table_schema = nsp.nspname AND col.table_name = cls.relname AND col.column_name = att.attname
+            WHERE
+                cls.relname = '{table.name}' AND -- replace with your table name
+                att.attnum > 0 AND
+                NOT att.attisdropped
+            ORDER BY
+                att.attnum;
+            """,
+            dialect="postgres",
+        )
+        if table.db:
+            schema_filter: exp.Expression = exp.column("table_schema").eq(table.db)
+            if include_temp and self.TEMP_SCHEMA_FILTER:
+                schema_filter = exp.Or(this=schema_filter, expression=self.TEMP_SCHEMA_FILTER)
+            select = select.where(schema_filter)  # type: ignore
+        if table.catalog:
+            catalog_filter: exp.Expression = exp.column("table_catalog").eq(table.catalog)
+            if include_temp and self.TEMP_CATALOG_FILTER:
+                catalog_filter = exp.Or(this=catalog_filter, expression=self.TEMP_CATALOG_FILTER)
+            select = select.where(catalog_filter)  # type: ignore
+        results = self.session._fetch_rows(select)
+        return [
+            Column(
+                name=x["column_name"],
+                description=None,
+                dataType=x["data_type"],
+                nullable=x["is_nullable"] == "YES",
+                isPartition=False,
+                isBucket=False,
+            )
+            for x in results
+        ]
+
     def listFunctions(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
     ) -> t.List[Function]:
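
PostgresCatalog now resolves columns itself (querying pg_catalog and information_schema) instead of relying on ListColumnsFromInfoSchemaMixin. A minimal usage sketch follows; the connection parameters and table name are placeholders, not part of the diff.

# Hypothetical usage of the new PostgresCatalog.listColumns (connection and table are placeholders).
import psycopg2
from sqlframe.postgres import PostgresSession

conn = psycopg2.connect(dbname="postgres", user="postgres", password="postgres", host="localhost")
session = PostgresSession(conn=conn)

# Returns sqlframe.base.catalog.Column entries with Postgres-reported types and nullability.
for col in session.catalog.listColumns("public.example_table"):
    print(col.name, col.dataType, col.nullable)
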
sqlframe/postgres/dataframe.py
CHANGED
@@ -9,7 +9,10 @@ from sqlframe.base.dataframe import (
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
-from sqlframe.base.mixins.dataframe_mixins import
+from sqlframe.base.mixins.dataframe_mixins import (
+    NoCachePersistSupportMixin,
+    TypedColumnsFromTempViewMixin,
+)
 from sqlframe.postgres.group import PostgresGroupedData
 
 if sys.version_info >= (3, 11):
@@ -34,7 +37,8 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr
 
 
 class PostgresDataFrame(
-
+    NoCachePersistSupportMixin,
+    TypedColumnsFromTempViewMixin,
     _BaseDataFrame[
         "PostgresSession",
         "PostgresDataFrameWriter",
@@ -46,11 +50,3 @@ class PostgresDataFrame(
     _na = PostgresDataFrameNaFunctions
     _stat = PostgresDataFrameStatFunctions
     _group_data = PostgresGroupedData
-
-    def cache(self) -> Self:
-        logger.warning("Postgres does not support caching. Ignoring cache() call.")
-        return self
-
-    def persist(self) -> Self:
-        logger.warning("Postgres does not support persist. Ignoring persist() call.")
-        return self
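
The per-engine cache()/persist() overrides removed here (and in the Redshift, Snowflake, and Spark DataFrames below) are consolidated into NoCachePersistSupportMixin from sqlframe.base.mixins.dataframe_mixins, whose diff is not included in this section. Based on the behavior it replaces, the mixin presumably looks roughly like the sketch below; this is an inference, not the actual implementation.

# Assumed shape of NoCachePersistSupportMixin, inferred from the removed
# cache()/persist() overrides; the real code lives in sqlframe/base/mixins/dataframe_mixins.py.
import logging
import sys

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

logger = logging.getLogger(__name__)


class NoCachePersistSupportMixin:
    def cache(self) -> Self:
        logger.warning("This engine does not support caching. Ignoring cache() call.")
        return self

    def persist(self) -> Self:
        logger.warning("This engine does not support persist. Ignoring persist() call.")
        return self
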
sqlframe/redshift/dataframe.py
CHANGED
@@ -9,13 +9,9 @@ from sqlframe.base.dataframe import (
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
 from sqlframe.redshift.group import RedshiftGroupedData
 
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 if t.TYPE_CHECKING:
     from sqlframe.redshift.readwriter import RedshiftDataFrameWriter
     from sqlframe.redshift.session import RedshiftSession
@@ -33,22 +29,15 @@ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFr
 
 
 class RedshiftDataFrame(
+    NoCachePersistSupportMixin,
    _BaseDataFrame[
        "RedshiftSession",
        "RedshiftDataFrameWriter",
        "RedshiftDataFrameNaFunctions",
        "RedshiftDataFrameStatFunctions",
        "RedshiftGroupedData",
-    ]
+    ],
 ):
     _na = RedshiftDataFrameNaFunctions
     _stat = RedshiftDataFrameStatFunctions
     _group_data = RedshiftGroupedData
-
-    def cache(self) -> Self:
-        logger.warning("Redshift does not support caching. Ignoring cache() call.")
-        return self
-
-    def persist(self) -> Self:
-        logger.warning("Redshift does not support persist. Ignoring persist() call.")
-        return self
sqlframe/snowflake/dataframe.py
CHANGED
@@ -4,18 +4,15 @@ import logging
 import sys
 import typing as t
 
+from sqlframe.base.catalog import Column as CatalogColumn
 from sqlframe.base.dataframe import (
     _BaseDataFrame,
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
 from sqlframe.snowflake.group import SnowflakeGroupedData
 
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 if t.TYPE_CHECKING:
     from sqlframe.snowflake.readwriter import SnowflakeDataFrameWriter
     from sqlframe.snowflake.session import SnowflakeSession
@@ -33,22 +30,35 @@ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeData
 
 
 class SnowflakeDataFrame(
+    NoCachePersistSupportMixin,
    _BaseDataFrame[
        "SnowflakeSession",
        "SnowflakeDataFrameWriter",
        "SnowflakeDataFrameNaFunctions",
        "SnowflakeDataFrameStatFunctions",
        "SnowflakeGroupedData",
-    ]
+    ],
 ):
     _na = SnowflakeDataFrameNaFunctions
     _stat = SnowflakeDataFrameStatFunctions
     _group_data = SnowflakeGroupedData
 
-
-
-
-
-
-
-
+    @property
+    def _typed_columns(self) -> t.List[CatalogColumn]:
+        df = self._convert_leaf_to_cte()
+        df = df.limit(0)
+        self.session._execute(df.expression)
+        query_id = self.session._cur.sfqid
+        columns = []
+        for row in self.session._fetch_rows(f"DESCRIBE RESULT '{query_id}'"):
+            columns.append(
+                CatalogColumn(
+                    name=row.name,
+                    dataType=row.type,
+                    nullable=row["null?"] == "Y",
+                    description=row.comment,
+                    isPartition=False,
+                    isBucket=False,
+                )
+            )
+        return columns
sqlframe/spark/dataframe.py
CHANGED
@@ -1,26 +1,23 @@
 from __future__ import annotations
 
 import logging
-import sys
 import typing as t
 
+from sqlglot import exp
+
+from sqlframe.base.catalog import Column
 from sqlframe.base.dataframe import (
     _BaseDataFrame,
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
 from sqlframe.spark.group import SparkGroupedData
 
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 if t.TYPE_CHECKING:
     from sqlframe.spark.readwriter import SparkDataFrameWriter
     from sqlframe.spark.session import SparkSession
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -33,22 +30,35 @@ class SparkDataFrameStatFunctions(_BaseDataFrameStatFunctions["SparkDataFrame"])
 
 
 class SparkDataFrame(
+    NoCachePersistSupportMixin,
    _BaseDataFrame[
        "SparkSession",
        "SparkDataFrameWriter",
        "SparkDataFrameNaFunctions",
        "SparkDataFrameStatFunctions",
        "SparkGroupedData",
-    ]
+    ],
 ):
     _na = SparkDataFrameNaFunctions
     _stat = SparkDataFrameStatFunctions
     _group_data = SparkGroupedData
 
-
-
-
-
-
-
-
+    @property
+    def _typed_columns(self) -> t.List[Column]:
+        columns = []
+        for field in self.session.spark_session.sql(
+            self.session._to_sql(self.expression)
+        ).schema.fields:
+            columns.append(
+                Column(
+                    name=field.name,
+                    dataType=exp.DataType.build(field.dataType.simpleString(), dialect="spark").sql(
+                        dialect="spark"
+                    ),
+                    nullable=field.nullable,
+                    description=None,
+                    isPartition=False,
+                    isBucket=False,
+                )
+            )
+        return columns
sqlframe/testing/utils.py
ADDED
@@ -0,0 +1,320 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+from __future__ import annotations
+
+import difflib
+import os
+import typing as t
+from itertools import zip_longest
+
+from sqlframe.base import types
+from sqlframe.base.dataframe import _BaseDataFrame
+from sqlframe.base.exceptions import (
+    DataFrameDiffError,
+    SchemaDiffError,
+    SQLFrameException,
+)
+from sqlframe.base.util import verify_pandas_installed
+
+if t.TYPE_CHECKING:
+    import pandas as pd
+
+
+def _terminal_color_support():
+    try:
+        # determine if environment supports color
+        script = "$(test $(tput colors)) && $(test $(tput colors) -ge 8) && echo true || echo false"
+        return os.popen(script).read()
+    except Exception:
+        return False
+
+
+def _context_diff(actual: t.List[str], expected: t.List[str], n: int = 3):
+    """
+    Modified from difflib context_diff API,
+    see original code here: https://github.com/python/cpython/blob/main/Lib/difflib.py#L1180
+    """
+
+    def red(s: str) -> str:
+        red_color = "\033[31m"
+        no_color = "\033[0m"
+        return red_color + str(s) + no_color
+
+    prefix = dict(insert="+ ", delete="- ", replace="! ", equal="  ")
+    for group in difflib.SequenceMatcher(None, actual, expected).get_grouped_opcodes(n):
+        yield "*** actual ***"
+        if any(tag in {"replace", "delete"} for tag, _, _, _, _ in group):
+            for tag, i1, i2, _, _ in group:
+                for line in actual[i1:i2]:
+                    if tag != "equal" and _terminal_color_support():
+                        yield red(prefix[tag] + str(line))
+                    else:
+                        yield prefix[tag] + str(line)
+
+        yield "\n"
+
+        yield "*** expected ***"
+        if any(tag in {"replace", "insert"} for tag, _, _, _, _ in group):
+            for tag, _, _, j1, j2 in group:
+                for line in expected[j1:j2]:
+                    if tag != "equal" and _terminal_color_support():
+                        yield red(prefix[tag] + str(line))
+                    else:
+                        yield prefix[tag] + str(line)
+
+
+# Source: https://github.com/apache/spark/blob/master/python/pyspark/testing/utils.py#L519
+def assertDataFrameEqual(
+    actual: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]],
+    expected: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]],
+    checkRowOrder: bool = False,
+    rtol: float = 1e-5,
+    atol: float = 1e-8,
+):
+    r"""
+    A util function to assert equality between `actual` and `expected`
+    (DataFrames or lists of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`.
+
+    Supports Spark, Spark Connect, pandas, and pandas-on-Spark DataFrames.
+    For more information about pandas-on-Spark DataFrame equality, see the docs for
+    `assertPandasOnSparkEqual`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
+        The DataFrame that is being compared or tested.
+    expected : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
+        The expected result of the operation, for comparison with the actual result.
+    checkRowOrder : bool, optional
+        A flag indicating whether the order of rows should be considered in the comparison.
+        If set to `False` (default), the row order is not taken into account.
+        If set to `True`, the order of rows is important and will be checked during comparison.
+        (See Notes)
+    rtol : float, optional
+        The relative tolerance, used in asserting approximate equality for float values in actual
+        and expected. Set to 1e-5 by default. (See Notes)
+    atol : float, optional
+        The absolute tolerance, used in asserting approximate equality for float values in actual
+        and expected. Set to 1e-8 by default. (See Notes)
+
+    Notes
+    -----
+    When `assertDataFrameEqual` fails, the error message uses the Python `difflib` library to
+    display a diff log of each row that differs in `actual` and `expected`.
+
+    For `checkRowOrder`, note that PySpark DataFrame ordering is non-deterministic, unless
+    explicitly sorted.
+
+    Note that schema equality is checked only when `expected` is a DataFrame (not a list of Rows).
+
+    For DataFrames with float values, assertDataFrame asserts approximate equality.
+    Two float values a and b are approximately equal if the following equation is True:
+
+    ``absolute(a - b) <= (atol + rtol * absolute(b))``.
+
+    Examples
+    --------
+    >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
+
+    >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
+
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"])
+    >>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
+    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected data are equal
+
+    >>> import pyspark.pandas as ps
+    >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> assertDataFrameEqual(df1, df2)  # pass, pandas-on-Spark DataFrames are equal
+
+    >>> df1 = spark.createDataFrame(
+    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(
+    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
+    *** actual ***
+    ! Row(id='1', amount=1000.0)
+      Row(id='2', amount=3000.0)
+    ! Row(id='3', amount=2000.0)
+    *** expected ***
+    ! Row(id='1', amount=1001.0)
+      Row(id='2', amount=3000.0)
+    ! Row(id='3', amount=2003.0)
+    """
+    import pandas as pd
+
+    if actual is None and expected is None:
+        return True
+    elif actual is None or expected is None:
+        raise SQLFrameException("Missing required arguments: actual and expected")
+
+    def compare_rows(r1: types.Row, r2: types.Row):
+        def compare_vals(val1, val2):
+            if isinstance(val1, list) and isinstance(val2, list):
+                return len(val1) == len(val2) and all(
+                    compare_vals(x, y) for x, y in zip(val1, val2)
+                )
+            elif isinstance(val1, types.Row) and isinstance(val2, types.Row):
+                return all(compare_vals(x, y) for x, y in zip(val1, val2))
+            elif isinstance(val1, dict) and isinstance(val2, dict):
+                return (
+                    len(val1.keys()) == len(val2.keys())
+                    and val1.keys() == val2.keys()
+                    and all(compare_vals(val1[k], val2[k]) for k in val1.keys())
+                )
+            elif isinstance(val1, float) and isinstance(val2, float):
+                if abs(val1 - val2) > (atol + rtol * abs(val2)):
+                    return False
+            else:
+                if val1 != val2:
+                    return False
+            return True
+
+        if r1 is None and r2 is None:
+            return True
+        elif r1 is None or r2 is None:
+            return False
+
+        return compare_vals(r1, r2)
+
+    def assert_rows_equal(rows1: t.List[types.Row], rows2: t.List[types.Row]):
+        zipped = list(zip_longest(rows1, rows2))
+        diff_rows_cnt = 0
+        diff_rows = False
+
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
+        for r1, r2 in zipped:
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+            if not compare_rows(r1, r2):
+                diff_rows_cnt += 1
+                diff_rows = True
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
+        )
+
+        if diff_rows:
+            error_msg = "Results do not match: "
+            percent_diff = (diff_rows_cnt / len(zipped)) * 100
+            error_msg += "( %.5f %% )" % percent_diff
+            error_msg += "\n" + "\n".join(generated_diff)
+            raise DataFrameDiffError("Rows are different:\n%s" % error_msg)
+
+    # convert actual and expected to list
+    if not isinstance(actual, list) and not isinstance(expected, list):
+        # only compare schema if expected is not a List
+        assertSchemaEqual(actual.schema, expected.schema)  # type: ignore
+
+    if not isinstance(actual, list):
+        actual_list = actual.collect()  # type: ignore
+    else:
+        actual_list = actual
+
+    if not isinstance(expected, list):
+        expected_list = expected.collect()  # type: ignore
+    else:
+        expected_list = expected
+
+    if not checkRowOrder:
+        # rename duplicate columns for sorting
+        actual_list = sorted(actual_list, key=lambda x: str(x))
+        expected_list = sorted(expected_list, key=lambda x: str(x))
+
+    assert_rows_equal(actual_list, expected_list)
+
+
+def assertSchemaEqual(actual: types.StructType, expected: types.StructType):
+    r"""
+    A util function to assert equality between DataFrame schemas `actual` and `expected`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : StructType
+        The DataFrame schema that is being compared or tested.
+    expected : StructType
+        The expected schema, for comparison with the actual schema.
+
+    Notes
+    -----
+    When assertSchemaEqual fails, the error message uses the Python `difflib` library to display
+    a diff log of the `actual` and `expected` schemas.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType, DoubleType
+    >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+    >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+    >>> assertSchemaEqual(s1, s2)  # pass, schemas are identical
+
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"])
+    >>> assertSchemaEqual(df1.schema, df2.schema)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+    --- actual
+    +++ expected
+    - StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
+    ?                               ^^                               ^^^^^
+    + StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
+    ?                               ^^^^                              ++++ ^
+    """
+    if not isinstance(actual, types.StructType):
+        raise RuntimeError("actual must be a StructType")
+    if not isinstance(expected, types.StructType):
+        raise RuntimeError("expected must be a StructType")
+
+    def compare_schemas_ignore_nullable(s1: types.StructType, s2: types.StructType):
+        if len(s1) != len(s2):
+            return False
+        zipped = zip_longest(s1, s2)
+        for sf1, sf2 in zipped:
+            if not compare_structfields_ignore_nullable(sf1, sf2):
+                return False
+        return True
+
+    def compare_structfields_ignore_nullable(
+        actualSF: types.StructField, expectedSF: types.StructField
+    ):
+        if actualSF is None and expectedSF is None:
+            return True
+        elif actualSF is None or expectedSF is None:
+            return False
+        if actualSF.name != expectedSF.name:
+            return False
+        else:
+            return compare_datatypes_ignore_nullable(actualSF.dataType, expectedSF.dataType)
+
+    def compare_datatypes_ignore_nullable(dt1: t.Any, dt2: t.Any):
+        # checks datatype equality, using recursion to ignore nullable
+        if dt1.typeName() == dt2.typeName():
+            if dt1.typeName() == "array":
+                return compare_datatypes_ignore_nullable(dt1.elementType, dt2.elementType)
+            elif dt1.typeName() == "struct":
+                return compare_schemas_ignore_nullable(dt1, dt2)
+            else:
+                return True
+        else:
+            return False
+
+    # ignore nullable flag by default
+    if not compare_schemas_ignore_nullable(actual, expected):
+        generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())
+
+        error_msg = "\n".join(generated_diff)
+
+        raise SchemaDiffError(error_msg)
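
sqlframe/testing/__init__.py (+3 lines, not shown in this section) presumably re-exports these helpers; the sketch below imports them from sqlframe.testing.utils directly and uses DuckDB purely as an illustrative engine.

# Minimal usage sketch for the new testing helpers (engine and data are illustrative).
from sqlframe.duckdb import DuckDBSession
from sqlframe.testing.utils import assertDataFrameEqual, assertSchemaEqual

session = DuckDBSession()
df1 = session.createDataFrame([("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = session.createDataFrame([("2", 3000), ("1", 1000)], schema=["id", "amount"])

assertDataFrameEqual(df1, df2)             # passes: row order is ignored by default
assertSchemaEqual(df1.schema, df2.schema)  # passes: schemas match; nullability is ignored
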