sqlframe 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,16 +7,17 @@ import typing as t
 
 from sqlglot import exp, parse_one
 
-from sqlframe.base.catalog import Function, _BaseCatalog
+from sqlframe.base.catalog import Column, Function, _BaseCatalog
+from sqlframe.base.decorators import normalize
 from sqlframe.base.mixins.catalog_mixins import (
     GetCurrentCatalogFromFunctionMixin,
     GetCurrentDatabaseFromFunctionMixin,
     ListCatalogsFromInfoSchemaMixin,
-    ListColumnsFromInfoSchemaMixin,
     ListDatabasesFromInfoSchemaMixin,
     ListTablesFromInfoSchemaMixin,
     SetCurrentDatabaseFromSearchPathMixin,
 )
+from sqlframe.base.util import to_schema
 
 if t.TYPE_CHECKING:
     from sqlframe.postgres.session import PostgresSession  # noqa
@@ -30,12 +31,131 @@ class PostgresCatalog(
     ListCatalogsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
     SetCurrentDatabaseFromSearchPathMixin["PostgresSession", "PostgresDataFrame"],
     ListTablesFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
-    ListColumnsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
     _BaseCatalog["PostgresSession", "PostgresDataFrame"],
 ):
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
     TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")
 
+    @normalize(["tableName", "dbName"])
+    def listColumns(
+        self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
+    ) -> t.List[Column]:
+        """Returns a t.List of columns for the given table/view in the specified database.
+
+        .. versionadded:: 2.0.0
+
+        Parameters
+        ----------
+        tableName : str
+            name of the table to t.List columns.
+
+            .. versionchanged:: 3.4.0
+                Allow ``tableName`` to be qualified with catalog name when ``dbName`` is None.
+
+        dbName : str, t.Optional
+            name of the database to find the table to t.List columns.
+
+        Returns
+        -------
+        t.List
+            A t.List of :class:`Column`.
+
+        Notes
+        -----
+        The order of arguments here is different from that of its JVM counterpart
+        because Python does not support method overloading.
+
+        If no database is specified, the current database and catalog
+        are used. This API includes all temporary views.
+
+        Examples
+        --------
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
+        >>> spark.catalog.t.listColumns("tblA")
+        [Column(name='name', description=None, dataType='string', nullable=True, ...
+        >>> _ = spark.sql("DROP TABLE tblA")
+        """
+        if df := self.session.temp_views.get(tableName):
+            return [
+                Column(
+                    name=x,
+                    description=None,
+                    dataType="",
+                    nullable=True,
+                    isPartition=False,
+                    isBucket=False,
+                )
+                for x in df.columns
+            ]
+
+        table = exp.to_table(tableName, dialect=self.session.input_dialect)
+        schema = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
+        if not table.db:
+            if schema and schema.db:
+                table.set("db", schema.args["db"])
+            else:
+                table.set(
+                    "db",
+                    exp.parse_identifier(
+                        self.currentDatabase(), dialect=self.session.input_dialect
+                    ),
+                )
+        if not table.catalog:
+            if schema and schema.catalog:
+                table.set("catalog", schema.args["catalog"])
+            else:
+                table.set(
+                    "catalog",
+                    exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
+                )
+        source_table = self._get_info_schema_table("columns", database=table.db)
+        select = parse_one(
+            f"""
+            SELECT
+                att.attname AS column_name,
+                pg_catalog.format_type(att.atttypid, NULL) AS data_type,
+                col.is_nullable
+            FROM
+                pg_catalog.pg_attribute att
+            JOIN
+                pg_catalog.pg_class cls ON cls.oid = att.attrelid
+            JOIN
+                pg_catalog.pg_namespace nsp ON nsp.oid = cls.relnamespace
+            JOIN
+                information_schema.columns col ON col.table_schema = nsp.nspname AND col.table_name = cls.relname AND col.column_name = att.attname
+            WHERE
+                cls.relname = '{table.name}' AND -- replace with your table name
+                att.attnum > 0 AND
+                NOT att.attisdropped
+            ORDER BY
+                att.attnum;
+            """,
+            dialect="postgres",
+        )
+        if table.db:
+            schema_filter: exp.Expression = exp.column("table_schema").eq(table.db)
+            if include_temp and self.TEMP_SCHEMA_FILTER:
+                schema_filter = exp.Or(this=schema_filter, expression=self.TEMP_SCHEMA_FILTER)
+            select = select.where(schema_filter)  # type: ignore
+        if table.catalog:
+            catalog_filter: exp.Expression = exp.column("table_catalog").eq(table.catalog)
+            if include_temp and self.TEMP_CATALOG_FILTER:
+                catalog_filter = exp.Or(this=catalog_filter, expression=self.TEMP_CATALOG_FILTER)
+            select = select.where(catalog_filter)  # type: ignore
+        results = self.session._fetch_rows(select)
+        return [
+            Column(
+                name=x["column_name"],
+                description=None,
+                dataType=x["data_type"],
+                nullable=x["is_nullable"] == "YES",
+                isPartition=False,
+                isBucket=False,
+            )
+            for x in results
+        ]
+
     def listFunctions(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
     ) -> t.List[Function]:
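
The new PostgresCatalog.listColumns answers from registered temp views when possible, resolves the table against the current database and catalog otherwise, and types the columns from pg_catalog metadata. A minimal usage sketch follows; the psycopg2 connection values and the PostgresSession(conn=...) constructor are assumptions, not taken from this diff:

    import psycopg2
    from sqlframe.postgres.session import PostgresSession  # import path shown in the hunk above

    # Assumed connection details; replace with your own.
    conn = psycopg2.connect(dbname="postgres", user="postgres", password="secret", host="localhost")
    session = PostgresSession(conn=conn)  # constructor signature assumed

    # Each entry is a sqlframe.base.catalog.Column, typed via pg_catalog per the hunk above.
    for col in session.catalog.listColumns("public.my_table"):
        print(col.name, col.dataType, col.nullable)
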
@@ -9,7 +9,10 @@ from sqlframe.base.dataframe import (
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
-from sqlframe.base.mixins.dataframe_mixins import PrintSchemaFromTempObjectsMixin
+from sqlframe.base.mixins.dataframe_mixins import (
+    NoCachePersistSupportMixin,
+    TypedColumnsFromTempViewMixin,
+)
 from sqlframe.postgres.group import PostgresGroupedData
 
 if sys.version_info >= (3, 11):
@@ -34,7 +37,8 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr
 
 
 class PostgresDataFrame(
-    PrintSchemaFromTempObjectsMixin,
+    NoCachePersistSupportMixin,
+    TypedColumnsFromTempViewMixin,
     _BaseDataFrame[
         "PostgresSession",
         "PostgresDataFrameWriter",
@@ -46,11 +50,3 @@ class PostgresDataFrame(
     _na = PostgresDataFrameNaFunctions
     _stat = PostgresDataFrameStatFunctions
     _group_data = PostgresGroupedData
-
-    def cache(self) -> Self:
-        logger.warning("Postgres does not support caching. Ignoring cache() call.")
-        return self
-
-    def persist(self) -> Self:
-        logger.warning("Postgres does not support persist. Ignoring persist() call.")
-        return self
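
Across the Postgres, Redshift, Snowflake, and Spark DataFrames in this diff, the hand-written cache()/persist() warnings are replaced by NoCachePersistSupportMixin. The mixin body is not part of this diff; a minimal sketch consistent with the removed per-engine methods would be:

    import logging

    logger = logging.getLogger(__name__)


    class NoCachePersistSupportMixin:
        """Sketch only: engines without caching warn and return self unchanged."""

        def cache(self):
            logger.warning("This engine does not support caching. Ignoring cache() call.")
            return self

        def persist(self):
            logger.warning("This engine does not support persist. Ignoring persist() call.")
            return self
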
@@ -9,13 +9,9 @@ from sqlframe.base.dataframe import (
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
 from sqlframe.redshift.group import RedshiftGroupedData
 
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 if t.TYPE_CHECKING:
     from sqlframe.redshift.readwriter import RedshiftDataFrameWriter
     from sqlframe.redshift.session import RedshiftSession
@@ -33,22 +29,15 @@ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFr
 
 
 class RedshiftDataFrame(
+    NoCachePersistSupportMixin,
     _BaseDataFrame[
         "RedshiftSession",
         "RedshiftDataFrameWriter",
        "RedshiftDataFrameNaFunctions",
         "RedshiftDataFrameStatFunctions",
         "RedshiftGroupedData",
-    ]
+    ],
 ):
     _na = RedshiftDataFrameNaFunctions
     _stat = RedshiftDataFrameStatFunctions
     _group_data = RedshiftGroupedData
-
-    def cache(self) -> Self:
-        logger.warning("Redshift does not support caching. Ignoring cache() call.")
-        return self
-
-    def persist(self) -> Self:
-        logger.warning("Redshift does not support persist. Ignoring persist() call.")
-        return self
@@ -4,18 +4,15 @@ import logging
 import sys
 import typing as t
 
+from sqlframe.base.catalog import Column as CatalogColumn
 from sqlframe.base.dataframe import (
     _BaseDataFrame,
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
 from sqlframe.snowflake.group import SnowflakeGroupedData
 
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 if t.TYPE_CHECKING:
     from sqlframe.snowflake.readwriter import SnowflakeDataFrameWriter
     from sqlframe.snowflake.session import SnowflakeSession
@@ -33,22 +30,35 @@ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeData
 
 
 class SnowflakeDataFrame(
+    NoCachePersistSupportMixin,
     _BaseDataFrame[
         "SnowflakeSession",
         "SnowflakeDataFrameWriter",
         "SnowflakeDataFrameNaFunctions",
         "SnowflakeDataFrameStatFunctions",
         "SnowflakeGroupedData",
-    ]
+    ],
 ):
     _na = SnowflakeDataFrameNaFunctions
     _stat = SnowflakeDataFrameStatFunctions
     _group_data = SnowflakeGroupedData
 
-    def cache(self) -> Self:
-        logger.warning("Snowflake does not support caching. Ignoring cache() call.")
-        return self
-
-    def persist(self) -> Self:
-        logger.warning("Snowflake does not support persist. Ignoring persist() call.")
-        return self
+    @property
+    def _typed_columns(self) -> t.List[CatalogColumn]:
+        df = self._convert_leaf_to_cte()
+        df = df.limit(0)
+        self.session._execute(df.expression)
+        query_id = self.session._cur.sfqid
+        columns = []
+        for row in self.session._fetch_rows(f"DESCRIBE RESULT '{query_id}'"):
+            columns.append(
+                CatalogColumn(
+                    name=row.name,
+                    dataType=row.type,
+                    nullable=row["null?"] == "Y",
+                    description=row.comment,
+                    isPartition=False,
+                    isBucket=False,
+                )
+            )
+        return columns
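
SnowflakeDataFrame._typed_columns executes the query with LIMIT 0, captures the statement's query id from the cursor, and asks Snowflake to describe that result set. The same pattern with the raw snowflake-connector-python cursor, as a hedged illustration (connection parameters are placeholders, not from this diff):

    import snowflake.connector

    conn = snowflake.connector.connect(account="my_account", user="me", password="secret")  # placeholders
    cur = conn.cursor()

    cur.execute("SELECT 1 AS id, 'a' AS name LIMIT 0")  # run for metadata only, no rows
    query_id = cur.sfqid  # capture before issuing the next statement on this cursor

    # Each returned row describes one output column: name, type, "null?", comment, ...
    for row in cur.execute(f"DESCRIBE RESULT '{query_id}'"):
        print(row)
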
@@ -1,26 +1,23 @@
 from __future__ import annotations
 
 import logging
-import sys
 import typing as t
 
+from sqlglot import exp
+
+from sqlframe.base.catalog import Column
 from sqlframe.base.dataframe import (
     _BaseDataFrame,
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
 from sqlframe.spark.group import SparkGroupedData
 
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 if t.TYPE_CHECKING:
     from sqlframe.spark.readwriter import SparkDataFrameWriter
     from sqlframe.spark.session import SparkSession
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -33,22 +30,35 @@ class SparkDataFrameStatFunctions(_BaseDataFrameStatFunctions["SparkDataFrame"])
 
 
 class SparkDataFrame(
+    NoCachePersistSupportMixin,
     _BaseDataFrame[
         "SparkSession",
         "SparkDataFrameWriter",
         "SparkDataFrameNaFunctions",
         "SparkDataFrameStatFunctions",
         "SparkGroupedData",
-    ]
+    ],
 ):
     _na = SparkDataFrameNaFunctions
     _stat = SparkDataFrameStatFunctions
     _group_data = SparkGroupedData
 
-    def cache(self) -> Self:
-        logger.warning("Spark does not support caching. Ignoring cache() call.")
-        return self
-
-    def persist(self) -> Self:
-        logger.warning("Spark does not support persist. Ignoring persist() call.")
-        return self
+    @property
+    def _typed_columns(self) -> t.List[Column]:
+        columns = []
+        for field in self.session.spark_session.sql(
+            self.session._to_sql(self.expression)
+        ).schema.fields:
+            columns.append(
+                Column(
+                    name=field.name,
+                    dataType=exp.DataType.build(field.dataType.simpleString(), dialect="spark").sql(
+                        dialect="spark"
+                    ),
+                    nullable=field.nullable,
+                    description=None,
+                    isPartition=False,
+                    isBucket=False,
+                )
+            )
+        return columns
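
For Spark, _typed_columns reads field types from the live SparkSession schema and normalizes them through sqlglot. The sqlglot call used above can be exercised on its own; the type strings here are illustrative:

    from sqlglot import exp

    # Parse a Spark type string and render it back as Spark SQL.
    for type_str in ["int", "decimal(10,2)", "array<string>"]:
        print(exp.DataType.build(type_str, dialect="spark").sql(dialect="spark"))
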
@@ -0,0 +1,3 @@
+from sqlframe.testing.utils import assertDataFrameEqual, assertSchemaEqual
+
+__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
@@ -0,0 +1,320 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+from __future__ import annotations
+
+import difflib
+import os
+import typing as t
+from itertools import zip_longest
+
+from sqlframe.base import types
+from sqlframe.base.dataframe import _BaseDataFrame
+from sqlframe.base.exceptions import (
+    DataFrameDiffError,
+    SchemaDiffError,
+    SQLFrameException,
+)
+from sqlframe.base.util import verify_pandas_installed
+
+if t.TYPE_CHECKING:
+    import pandas as pd
+
+
+def _terminal_color_support():
+    try:
+        # determine if environment supports color
+        script = "$(test $(tput colors)) && $(test $(tput colors) -ge 8) && echo true || echo false"
+        return os.popen(script).read()
+    except Exception:
+        return False
+
+
+def _context_diff(actual: t.List[str], expected: t.List[str], n: int = 3):
+    """
+    Modified from difflib context_diff API,
+    see original code here: https://github.com/python/cpython/blob/main/Lib/difflib.py#L1180
+    """
+
+    def red(s: str) -> str:
+        red_color = "\033[31m"
+        no_color = "\033[0m"
+        return red_color + str(s) + no_color
+
+    prefix = dict(insert="+ ", delete="- ", replace="! ", equal=" ")
+    for group in difflib.SequenceMatcher(None, actual, expected).get_grouped_opcodes(n):
+        yield "*** actual ***"
+        if any(tag in {"replace", "delete"} for tag, _, _, _, _ in group):
+            for tag, i1, i2, _, _ in group:
+                for line in actual[i1:i2]:
+                    if tag != "equal" and _terminal_color_support():
+                        yield red(prefix[tag] + str(line))
+                    else:
+                        yield prefix[tag] + str(line)
+
+        yield "\n"
+
+        yield "*** expected ***"
+        if any(tag in {"replace", "insert"} for tag, _, _, _, _ in group):
+            for tag, _, _, j1, j2 in group:
+                for line in expected[j1:j2]:
+                    if tag != "equal" and _terminal_color_support():
+                        yield red(prefix[tag] + str(line))
+                    else:
+                        yield prefix[tag] + str(line)
+
+
+# Source: https://github.com/apache/spark/blob/master/python/pyspark/testing/utils.py#L519
+def assertDataFrameEqual(
+    actual: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]],
+    expected: t.Union[_BaseDataFrame, pd.DataFrame, t.List[types.Row]],
+    checkRowOrder: bool = False,
+    rtol: float = 1e-5,
+    atol: float = 1e-8,
+):
+    r"""
+    A util function to assert equality between `actual` and `expected`
+    (DataFrames or lists of Rows), with optional parameters `checkRowOrder`, `rtol`, and `atol`.
+
+    Supports Spark, Spark Connect, pandas, and pandas-on-Spark DataFrames.
+    For more information about pandas-on-Spark DataFrame equality, see the docs for
+    `assertPandasOnSparkEqual`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
+        The DataFrame that is being compared or tested.
+    expected : DataFrame (Spark, Spark Connect, pandas, or pandas-on-Spark) or list of Rows
+        The expected result of the operation, for comparison with the actual result.
+    checkRowOrder : bool, optional
+        A flag indicating whether the order of rows should be considered in the comparison.
+        If set to `False` (default), the row order is not taken into account.
+        If set to `True`, the order of rows is important and will be checked during comparison.
+        (See Notes)
+    rtol : float, optional
+        The relative tolerance, used in asserting approximate equality for float values in actual
+        and expected. Set to 1e-5 by default. (See Notes)
+    atol : float, optional
+        The absolute tolerance, used in asserting approximate equality for float values in actual
+        and expected. Set to 1e-8 by default. (See Notes)
+
+    Notes
+    -----
+    When `assertDataFrameEqual` fails, the error message uses the Python `difflib` library to
+    display a diff log of each row that differs in `actual` and `expected`.
+
+    For `checkRowOrder`, note that PySpark DataFrame ordering is non-deterministic, unless
+    explicitly sorted.
+
+    Note that schema equality is checked only when `expected` is a DataFrame (not a list of Rows).
+
+    For DataFrames with float values, assertDataFrame asserts approximate equality.
+    Two float values a and b are approximately equal if the following equation is True:
+
+    ``absolute(a - b) <= (atol + rtol * absolute(b))``.
+
+    Examples
+    --------
+    >>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
+
+    >>> df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol
+
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "amount"])
+    >>> list_of_rows = [Row(1, 1000), Row(2, 3000)]
+    >>> assertDataFrameEqual(df1, list_of_rows)  # pass, actual and expected data are equal
+
+    >>> import pyspark.pandas as ps
+    >>> df1 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> df2 = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]})
+    >>> assertDataFrameEqual(df1, df2)  # pass, pandas-on-Spark DataFrames are equal
+
+    >>> df1 = spark.createDataFrame(
+    ...     data=[("1", 1000.00), ("2", 3000.00), ("3", 2000.00)], schema=["id", "amount"])
+    >>> df2 = spark.createDataFrame(
+    ...     data=[("1", 1001.00), ("2", 3000.00), ("3", 2003.00)], schema=["id", "amount"])
+    >>> assertDataFrameEqual(df1, df2)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
+    *** actual ***
+    ! Row(id='1', amount=1000.0)
+    Row(id='2', amount=3000.0)
+    ! Row(id='3', amount=2000.0)
+    *** expected ***
+    ! Row(id='1', amount=1001.0)
+    Row(id='2', amount=3000.0)
+    ! Row(id='3', amount=2003.0)
+    """
+    import pandas as pd
+
+    if actual is None and expected is None:
+        return True
+    elif actual is None or expected is None:
+        raise SQLFrameException("Missing required arguments: actual and expected")
+
+    def compare_rows(r1: types.Row, r2: types.Row):
+        def compare_vals(val1, val2):
+            if isinstance(val1, list) and isinstance(val2, list):
+                return len(val1) == len(val2) and all(
+                    compare_vals(x, y) for x, y in zip(val1, val2)
+                )
+            elif isinstance(val1, types.Row) and isinstance(val2, types.Row):
+                return all(compare_vals(x, y) for x, y in zip(val1, val2))
+            elif isinstance(val1, dict) and isinstance(val2, dict):
+                return (
+                    len(val1.keys()) == len(val2.keys())
+                    and val1.keys() == val2.keys()
+                    and all(compare_vals(val1[k], val2[k]) for k in val1.keys())
+                )
+            elif isinstance(val1, float) and isinstance(val2, float):
+                if abs(val1 - val2) > (atol + rtol * abs(val2)):
+                    return False
+            else:
+                if val1 != val2:
+                    return False
+            return True
+
+        if r1 is None and r2 is None:
+            return True
+        elif r1 is None or r2 is None:
+            return False
+
+        return compare_vals(r1, r2)
+
+    def assert_rows_equal(rows1: t.List[types.Row], rows2: t.List[types.Row]):
+        zipped = list(zip_longest(rows1, rows2))
+        diff_rows_cnt = 0
+        diff_rows = False
+
+        rows_str1 = ""
+        rows_str2 = ""
+
+        # count different rows
+        for r1, r2 in zipped:
+            rows_str1 += str(r1) + "\n"
+            rows_str2 += str(r2) + "\n"
+            if not compare_rows(r1, r2):
+                diff_rows_cnt += 1
+                diff_rows = True
+
+        generated_diff = _context_diff(
+            actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
+        )
+
+        if diff_rows:
+            error_msg = "Results do not match: "
+            percent_diff = (diff_rows_cnt / len(zipped)) * 100
+            error_msg += "( %.5f %% )" % percent_diff
+            error_msg += "\n" + "\n".join(generated_diff)
+            raise DataFrameDiffError("Rows are different:\n%s" % error_msg)
+
+    # convert actual and expected to list
+    if not isinstance(actual, list) and not isinstance(expected, list):
+        # only compare schema if expected is not a List
+        assertSchemaEqual(actual.schema, expected.schema)  # type: ignore
+
+    if not isinstance(actual, list):
+        actual_list = actual.collect()  # type: ignore
+    else:
+        actual_list = actual
+
+    if not isinstance(expected, list):
+        expected_list = expected.collect()  # type: ignore
+    else:
+        expected_list = expected
+
+    if not checkRowOrder:
+        # rename duplicate columns for sorting
+        actual_list = sorted(actual_list, key=lambda x: str(x))
+        expected_list = sorted(expected_list, key=lambda x: str(x))
+
+    assert_rows_equal(actual_list, expected_list)
+
+
+def assertSchemaEqual(actual: types.StructType, expected: types.StructType):
+    r"""
+    A util function to assert equality between DataFrame schemas `actual` and `expected`.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    actual : StructType
+        The DataFrame schema that is being compared or tested.
+    expected : StructType
+        The expected schema, for comparison with the actual schema.
+
+    Notes
+    -----
+    When assertSchemaEqual fails, the error message uses the Python `difflib` library to display
+    a diff log of the `actual` and `expected` schemas.
+
+    Examples
+    --------
+    >>> from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType, DoubleType
+    >>> s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+    >>> s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
+    >>> assertSchemaEqual(s1, s2)  # pass, schemas are identical
+
+    >>> df1 = spark.createDataFrame(data=[(1, 1000), (2, 3000)], schema=["id", "number"])
+    >>> df2 = spark.createDataFrame(data=[("1", 1000), ("2", 5000)], schema=["id", "amount"])
+    >>> assertSchemaEqual(df1.schema, df2.schema)  # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+    ...
+    PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+    --- actual
+    +++ expected
+    - StructType([StructField('id', LongType(), True), StructField('number', LongType(), True)])
+    ? ^^ ^^^^^
+    + StructType([StructField('id', StringType(), True), StructField('amount', LongType(), True)])
+    ? ^^^^ ++++ ^
+    """
+    if not isinstance(actual, types.StructType):
+        raise RuntimeError("actual must be a StructType")
+    if not isinstance(expected, types.StructType):
+        raise RuntimeError("expected must be a StructType")
+
+    def compare_schemas_ignore_nullable(s1: types.StructType, s2: types.StructType):
+        if len(s1) != len(s2):
+            return False
+        zipped = zip_longest(s1, s2)
+        for sf1, sf2 in zipped:
+            if not compare_structfields_ignore_nullable(sf1, sf2):
+                return False
+        return True
+
+    def compare_structfields_ignore_nullable(
+        actualSF: types.StructField, expectedSF: types.StructField
+    ):
+        if actualSF is None and expectedSF is None:
+            return True
+        elif actualSF is None or expectedSF is None:
+            return False
+        if actualSF.name != expectedSF.name:
+            return False
+        else:
+            return compare_datatypes_ignore_nullable(actualSF.dataType, expectedSF.dataType)
+
+    def compare_datatypes_ignore_nullable(dt1: t.Any, dt2: t.Any):
+        # checks datatype equality, using recursion to ignore nullable
+        if dt1.typeName() == dt2.typeName():
+            if dt1.typeName() == "array":
+                return compare_datatypes_ignore_nullable(dt1.elementType, dt2.elementType)
+            elif dt1.typeName() == "struct":
+                return compare_schemas_ignore_nullable(dt1, dt2)
+            else:
+                return True
+        else:
+            return False
+
+    # ignore nullable flag by default
+    if not compare_schemas_ignore_nullable(actual, expected):
+        generated_diff = difflib.ndiff(str(actual).splitlines(), str(expected).splitlines())
+
+        error_msg = "\n".join(generated_diff)
+
+        raise SchemaDiffError(error_msg)
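
The new sqlframe.testing helpers mirror pyspark.testing: assertDataFrameEqual compares DataFrames or lists of Rows (ignoring row order by default) and raises DataFrameDiffError with a difflib-style report on mismatch, while assertSchemaEqual compares StructTypes while ignoring nullability. A small usage sketch; Row keyword construction is assumed to mirror PySpark's Row:

    from sqlframe.base import types
    from sqlframe.testing import assertDataFrameEqual

    actual = [types.Row(id="1", amount=1000.0), types.Row(id="2", amount=3000.0)]
    expected = [types.Row(id="2", amount=3000.0), types.Row(id="1", amount=1000.0)]

    # Passes: rows match once order is ignored (the default, checkRowOrder=False).
    assertDataFrameEqual(actual, expected)

    # Float columns are compared approximately: abs(a - b) <= atol + rtol * abs(b).
    assertDataFrameEqual(
        actual,
        [types.Row(id="1", amount=1000.000001), types.Row(id="2", amount=3000.0)],
    )
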