pointblank 0.11.6__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pointblank/__init__.py +2 -0
- pointblank/_constants.py +0 -1
- pointblank/_interrogation.py +181 -38
- pointblank/_utils.py +29 -2
- pointblank/assistant.py +9 -0
- pointblank/cli.py +39 -24
- pointblank/data/api-docs.txt +658 -29
- pointblank/schema.py +17 -0
- pointblank/segments.py +163 -0
- pointblank/validate.py +317 -56
- {pointblank-0.11.6.dist-info → pointblank-0.12.0.dist-info}/METADATA +58 -5
- {pointblank-0.11.6.dist-info → pointblank-0.12.0.dist-info}/RECORD +16 -15
- {pointblank-0.11.6.dist-info → pointblank-0.12.0.dist-info}/WHEEL +0 -0
- {pointblank-0.11.6.dist-info → pointblank-0.12.0.dist-info}/entry_points.txt +0 -0
- {pointblank-0.11.6.dist-info → pointblank-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {pointblank-0.11.6.dist-info → pointblank-0.12.0.dist-info}/top_level.txt +0 -0
pointblank/validate.py CHANGED
```diff
@@ -87,6 +87,7 @@ from pointblank._utils_check_args import (
 from pointblank._utils_html import _create_table_dims_html, _create_table_type_html
 from pointblank.column import Column, ColumnLiteral, ColumnSelector, ColumnSelectorNarwhals, col
 from pointblank.schema import Schema, _get_schema_validation_info
+from pointblank.segments import Segment
 from pointblank.thresholds import (
     Actions,
     FinalActions,
```
```diff
@@ -1194,6 +1195,7 @@ def preview(
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -1201,7 +1203,6 @@ def preview(
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
```
```diff
@@ -1396,7 +1397,10 @@ def _generate_display_table(
     row_number_list: list[int] | None = None,
 ) -> GT:
     # Make a copy of the data to avoid modifying the original
-    data = copy.deepcopy(data)
+    # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
+    tbl_type = _get_tbl_type(data=data)
+    if "pyspark" not in tbl_type:
+        data = copy.deepcopy(data)
 
     # Does the data table already have a leading row number column?
     if "_row_num_" in data.columns:
```
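The guard above skips `copy.deepcopy()` whenever the table is PySpark-backed, since a Spark DataFrame holds a reference to the JVM-bound SparkContext and fails to pickle. A minimal sketch of the same pattern, with `_tbl_type()` as a stand-in for pointblank's internal `_get_tbl_type()` helper (the stand-in and its heuristics are assumptions, not the library's code):

```python
import copy
from typing import Any


def _tbl_type(data: Any) -> str:
    # Stand-in for pointblank's _get_tbl_type(): infer a short backend name
    # from the fully qualified class name of the table object.
    cls = f"{type(data).__module__}.{type(data).__name__}"
    for backend in ("pyspark", "polars", "pandas"):
        if backend in cls:
            return backend
    return "unknown"


def copy_unless_spark(data: Any) -> Any:
    # Deep copy only in-memory tables; PySpark DataFrames wrap a SparkContext
    # and raise when deepcopy tries to serialize it.
    if "pyspark" not in _tbl_type(data):
        return copy.deepcopy(data)
    return data
```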
```diff
@@ -1422,22 +1426,31 @@ def _generate_display_table(
     # Determine if the table is a DataFrame or an Ibis table
     tbl_type = _get_tbl_type(data=data)
     ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
-    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
+    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type
 
     # Select the DataFrame library to use for displaying the Ibis table
     df_lib_gt = _select_df_lib(preference="polars")
     df_lib_name_gt = df_lib_gt.__name__
 
-    # If the table is a DataFrame (Pandas or
-    # library (e.g., "polars" or "
+    # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
+    # library (e.g., "polars", "pandas", or "pyspark")
     if pl_pb_tbl:
-
+        if "polars" in tbl_type:
+            df_lib_name_gt = "polars"
+        elif "pandas" in tbl_type:
+            df_lib_name_gt = "pandas"
+        elif "pyspark" in tbl_type:
+            df_lib_name_gt = "pyspark"
 
-        # Handle imports of Polars or
+        # Handle imports of Polars, Pandas, or PySpark here
         if df_lib_name_gt == "polars":
             import polars as pl
-
+        elif df_lib_name_gt == "pandas":
             import pandas as pd
+        elif df_lib_name_gt == "pyspark":
+            # Import pandas for conversion since Great Tables needs pandas DataFrame
+            import pandas as pd
+            # Note: PySpark import is handled as needed, typically already imported in user's environment
 
     # Get the initial column count for the table
     n_columns = len(data.columns)
```
```diff
@@ -1547,6 +1560,42 @@ def _generate_display_table(
                 range(n_rows - n_tail + 1, n_rows + 1)
             )
 
+    if tbl_type == "pyspark":
+        n_rows = data.count()
+
+        # If n_head + n_tail is greater than the row count, display the entire table
+        if n_head + n_tail >= n_rows:
+            full_dataset = True
+            # Convert to pandas for Great Tables compatibility
+            data = data.toPandas()
+
+            row_number_list = range(1, n_rows + 1)
+        else:
+            # Get head and tail samples, then convert to pandas
+            head_data = data.limit(n_head).toPandas()
+
+            # PySpark tail() returns a list of Row objects, need to convert to DataFrame
+            tail_rows = data.tail(n_tail)
+            if tail_rows:
+                # Convert list of Row objects back to DataFrame, then to pandas
+                tail_df = data.sparkSession.createDataFrame(tail_rows, data.schema)
+                tail_data = tail_df.toPandas()
+            else:
+                # If no tail data, create empty DataFrame with same schema
+                import pandas as pd
+
+                tail_data = pd.DataFrame(columns=head_data.columns)
+
+            data = pd.concat([head_data, tail_data])
+
+            row_number_list = list(range(1, n_head + 1)) + list(
+                range(n_rows - n_tail + 1, n_rows + 1)
+            )
+
+    # For PySpark, update schema after conversion to pandas
+    if tbl_type == "pyspark":
+        tbl_schema = Schema(tbl=data)
+
     # From the table schema, get a list of tuples containing column names and data types
     col_dtype_dict = tbl_schema.columns
 
```
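The block above implements the PySpark preview path: take the first `n_head` rows with `limit()`, the last `n_tail` rows with `tail()` (which returns a list of `Row` objects rather than a DataFrame), and stitch both into one pandas frame for Great Tables. A self-contained sketch of that strategy follows; the function name is mine, and `DataFrame.sparkSession` assumes a reasonably recent PySpark (3.3+).

```python
import pandas as pd
from pyspark.sql import DataFrame, SparkSession


def head_tail_preview(df: DataFrame, n_head: int = 5, n_tail: int = 5) -> pd.DataFrame:
    """Collect a small head/tail sample of a PySpark DataFrame as pandas."""
    n_rows = df.count()

    # Small tables: bring everything over in one shot.
    if n_head + n_tail >= n_rows:
        return df.toPandas()

    # limit() returns a DataFrame; tail() returns a list of Row objects,
    # so the tail has to be rebuilt into a DataFrame before conversion.
    head_pd = df.limit(n_head).toPandas()
    tail_rows = df.tail(n_tail)
    tail_pd = (
        df.sparkSession.createDataFrame(tail_rows, df.schema).toPandas()
        if tail_rows
        else pd.DataFrame(columns=head_pd.columns)
    )
    return pd.concat([head_pd, tail_pd], ignore_index=True)


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    sdf = spark.createDataFrame([(i, f"row {i}") for i in range(100)], ["id", "label"])
    print(head_tail_preview(sdf, n_head=3, n_tail=3))
```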
```diff
@@ -1566,6 +1615,23 @@ def _generate_display_table(
     # This is used to highlight these values in the table
     if df_lib_name_gt == "polars":
         none_values = {k: data[k].is_null().to_list() for k in col_names}
+    elif df_lib_name_gt == "pyspark":
+        # For PySpark, check if data has been converted to pandas already
+        if hasattr(data, "isnull"):
+            # Data has been converted to pandas
+            none_values = {k: data[k].isnull() for k in col_names}
+        else:
+            # Data is still a PySpark DataFrame - use narwhals
+            import narwhals as nw
+
+            df_nw = nw.from_native(data)
+            none_values = {}
+            for col in col_names:
+                # Get null mask, collect to pandas, then convert to list
+                null_mask = (
+                    df_nw.select(nw.col(col).is_null()).collect().to_pandas().iloc[:, 0].tolist()
+                )
+                none_values[col] = null_mask
     else:
         none_values = {k: data[k].isnull() for k in col_names}
 
```
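When the data is still a native PySpark frame, the null-mask branch above goes through narwhals and collects each mask back to Python lists. A rough, backend-agnostic sketch of the same idea; the helper name is mine, it reuses the `hasattr(..., "collect")` test the diff itself relies on, and the exact collect-to-list path may differ from pointblank's.

```python
import narwhals as nw
import pandas as pd


def null_masks(native_df, columns: list[str]) -> dict[str, list[bool]]:
    """Return {column: [row is null, ...]} for any narwhals-supported frame."""
    df = nw.from_native(native_df)

    # Lazy backends (e.g. a PySpark-backed frame) expose collect(); eager
    # DataFrames can be read directly.
    if hasattr(df, "collect"):
        df = df.collect()

    return {col: df.select(nw.col(col).is_null())[col].to_list() for col in columns}


if __name__ == "__main__":
    pdf = pd.DataFrame({"a": [1, None, 3], "b": ["x", "y", None]})
    print(null_masks(pdf, ["a", "b"]))
```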
```diff
@@ -1579,7 +1645,13 @@ def _generate_display_table(
 
     for column in col_dtype_dict.keys():
         # Select a single column of values
-
+        if df_lib_name_gt == "pandas":
+            data_col = data[[column]]
+        elif df_lib_name_gt == "pyspark":
+            # PySpark data should have been converted to pandas by now
+            data_col = data[[column]]
+        else:
+            data_col = data.select([column])
 
         # Using Great Tables, render the columns and get the list of values as formatted strings
         built_gt = GT(data=data_col).fmt_markdown(columns=column)._build_data(context="html")
@@ -1658,6 +1730,10 @@ def _generate_display_table(
     if df_lib_name_gt == "pandas":
         data.insert(0, "_row_num_", row_number_list)
 
+    if df_lib_name_gt == "pyspark":
+        # For PySpark converted to pandas, use pandas method
+        data.insert(0, "_row_num_", row_number_list)
+
     # Get the highest number in the `row_number_list` and calculate a width that will
     # safely fit a number of that magnitude
     if row_number_list:  # Check if list is not empty
```
```diff
@@ -1791,6 +1867,7 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -1798,7 +1875,6 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -1851,7 +1927,10 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     data = _process_data(data)
 
     # Make a copy of the data to avoid modifying the original
-    data = copy.deepcopy(data)
+    # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
+    tbl_type = _get_tbl_type(data=data)
+    if "pyspark" not in tbl_type:
+        data = copy.deepcopy(data)
 
     # Get the number of rows in the table
     n_rows = get_row_count(data)
```
```diff
@@ -1868,22 +1947,28 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     # Determine if the table is a DataFrame or an Ibis table
     tbl_type = _get_tbl_type(data=data)
     ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
-    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
+    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type
 
     # Select the DataFrame library to use for displaying the Ibis table
     df_lib_gt = _select_df_lib(preference="polars")
     df_lib_name_gt = df_lib_gt.__name__
 
-    # If the table is a DataFrame (Pandas or
-    # library (e.g., "polars" or "
+    # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
+    # library (e.g., "polars", "pandas", or "pyspark")
     if pl_pb_tbl:
-
+        if "polars" in tbl_type:
+            df_lib_name_gt = "polars"
+        elif "pandas" in tbl_type:
+            df_lib_name_gt = "pandas"
+        elif "pyspark" in tbl_type:
+            df_lib_name_gt = "pyspark"
 
-        # Handle imports of Polars or
+        # Handle imports of Polars, Pandas, or PySpark here
         if df_lib_name_gt == "polars":
             import polars as pl
-
+        elif df_lib_name_gt == "pandas":
             import pandas as pd
+        # Note: PySpark import is handled as needed, typically already imported in user's environment
 
     # From an Ibis table:
     # - get the row count
```
```diff
@@ -2047,6 +2132,77 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
         # Get a dictionary of counts of missing values in each column
         missing_val_counts = {col: data[col].isnull().sum() for col in data.columns}
 
+    if "pyspark" in tbl_type:
+        from pyspark.sql.functions import col as pyspark_col
+
+        # PySpark implementation for missing values calculation
+        missing_vals = {}
+        for col_name in data.columns:
+            col_missing_props = []
+
+            # Calculate missing value proportions for each sector
+            for i in range(len(cut_points)):
+                start_row = cut_points[i - 1] if i > 0 else 0
+                end_row = cut_points[i]
+                sector_size = end_row - start_row
+
+                if sector_size > 0:
+                    # Use row_number() to filter rows by range
+                    from pyspark.sql.functions import row_number
+                    from pyspark.sql.window import Window
+
+                    window = Window.orderBy(
+                        pyspark_col(data.columns[0])
+                    )  # Order by first column
+                    sector_data = data.withColumn("row_num", row_number().over(window)).filter(
+                        (pyspark_col("row_num") > start_row)
+                        & (pyspark_col("row_num") <= end_row)
+                    )
+
+                    # Count nulls in this sector
+                    null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
+                    missing_prop = (null_count / sector_size) * 100
+                    col_missing_props.append(missing_prop)
+                else:
+                    col_missing_props.append(0)
+
+            # Handle the final sector (after last cut point)
+            if n_rows > cut_points[-1]:
+                start_row = cut_points[-1]
+                end_row = n_rows
+                sector_size = end_row - start_row
+
+                from pyspark.sql.functions import row_number
+                from pyspark.sql.window import Window
+
+                window = Window.orderBy(pyspark_col(data.columns[0]))
+                sector_data = data.withColumn("row_num", row_number().over(window)).filter(
+                    pyspark_col("row_num") > start_row
+                )
+
+                null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
+                missing_prop = (null_count / sector_size) * 100
+                col_missing_props.append(missing_prop)
+            else:
+                col_missing_props.append(0)
+
+            missing_vals[col_name] = col_missing_props
+
+        # Pivot the `missing_vals` dictionary to create a table with the missing value proportions
+        missing_vals = {
+            "columns": list(missing_vals.keys()),
+            **{
+                str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()]
+                for i in range(len(cut_points) + 1)
+            },
+        }
+
+        # Get a dictionary of counts of missing values in each column
+        missing_val_counts = {}
+        for col_name in data.columns:
+            null_count = data.filter(pyspark_col(col_name).isNull()).count()
+            missing_val_counts[col_name] = null_count
+
     # From `missing_vals`, create the DataFrame with the missing value proportions
     if df_lib_name_gt == "polars":
         import polars as pl
```
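The PySpark branch above splits the table into row "sectors" at `cut_points` and computes the share of NULLs per sector, using `row_number()` over a window ordered by the first column because Spark rows carry no inherent order. A condensed sketch of that calculation; function and variable names here are illustrative, not pointblank's.

```python
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window


def sector_missing_props(df: DataFrame, column: str, cut_points: list[int]) -> list[float]:
    """Percent of NULLs in `column` for each row sector delimited by cut_points."""
    n_rows = df.count()

    # Spark rows have no stable order, so impose one with a window ordered by
    # the first column before assigning 1-based row numbers.
    window = Window.orderBy(col(df.columns[0]))
    numbered = df.withColumn("row_num", row_number().over(window))

    # Sector boundaries: (0, c1], (c1, c2], ..., (c_last, n_rows]
    bounds = list(zip([0] + cut_points, cut_points + [n_rows]))

    props = []
    for start, end in bounds:
        size = end - start
        if size <= 0:
            props.append(0.0)
            continue
        sector = numbered.filter((col("row_num") > start) & (col("row_num") <= end))
        nulls = sector.filter(col(column).isNull()).count()
        props.append(100.0 * nulls / size)
    return props


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    sdf = spark.createDataFrame(
        [(i, None if i % 4 == 0 else float(i)) for i in range(1, 21)], ["id", "value"]
    )
    print(sector_missing_props(sdf, "value", cut_points=[5, 10, 15]))
```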
```diff
@@ -2333,6 +2489,7 @@ def get_column_count(data: FrameT | Any) -> int:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -2340,7 +2497,6 @@ def get_column_count(data: FrameT | Any) -> int:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2467,6 +2623,9 @@ def get_column_count(data: FrameT | Any) -> int:
     elif "pandas" in str(type(data)):
         return data.shape[1]
 
+    elif "pyspark" in str(type(data)):
+        return len(data.columns)
+
     elif "narwhals" in str(type(data)):
         return len(data.columns)
 
@@ -2501,6 +2660,7 @@ def get_row_count(data: FrameT | Any) -> int:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -2508,7 +2668,6 @@ def get_row_count(data: FrameT | Any) -> int:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2646,6 +2805,9 @@ def get_row_count(data: FrameT | Any) -> int:
     elif "pandas" in str(type(data)):
         return data.shape[0]
 
+    elif "pyspark" in str(type(data)):
+        return data.count()
+
     elif "narwhals" in str(type(data)):
         return data.shape[0]
 
```
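`get_row_count()` and `get_column_count()` now branch on the native type string: for PySpark, rows come from `DataFrame.count()` (an action that runs a Spark job) and columns from `len(df.columns)`. A stripped-down sketch of that dispatch, not the library's exact code path:

```python
from typing import Any


def row_count(data: Any) -> int:
    # Dispatch on the class path, as the diff does with str(type(data)).
    cls = str(type(data))
    if "pandas" in cls or "polars" in cls:
        return data.shape[0]   # eager frames know their length
    if "pyspark" in cls:
        return data.count()    # triggers a Spark job
    raise TypeError(f"Unsupported table type: {cls}")


def column_count(data: Any) -> int:
    cls = str(type(data))
    if "pandas" in cls:
        return data.shape[1]
    if "pyspark" in cls or "polars" in cls:
        return len(data.columns)
    raise TypeError(f"Unsupported table type: {cls}")
```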
```diff
@@ -3098,6 +3260,7 @@ class Validate:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -3105,7 +3268,6 @@ class Validate:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
```
```diff
@@ -9983,12 +10145,22 @@ class Validate:
             and tbl_type not in IBIS_BACKENDS
         ):
             # Add row numbers to the results table
-            validation_extract_nw = (
-
-
-
-
-
+            validation_extract_nw = nw.from_native(results_tbl)
+
+            # Handle LazyFrame row indexing which requires order_by parameter
+            try:
+                # Try without order_by first (for DataFrames)
+                validation_extract_nw = validation_extract_nw.with_row_index(name="_row_num_")
+            except TypeError:
+                # LazyFrames require order_by parameter - use first column for ordering
+                first_col = validation_extract_nw.columns[0]
+                validation_extract_nw = validation_extract_nw.with_row_index(
+                    name="_row_num_", order_by=first_col
+                )
+
+            validation_extract_nw = validation_extract_nw.filter(~nw.col("pb_is_good_")).drop(
+                "pb_is_good_"
+            )  # noqa
 
             # Add 1 to the row numbers to make them 1-indexed
             validation_extract_nw = validation_extract_nw.with_columns(nw.col("_row_num_") + 1)
```
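The try/except around `with_row_index()` above is the recurring pattern for LazyFrames: eager narwhals DataFrames accept the bare call, while lazy (e.g. PySpark-backed) frames raise `TypeError` until an `order_by=` column is supplied. A reusable sketch of that fallback, with a helper name of my own choosing:

```python
import narwhals as nw


def add_row_index(frame, name: str = "_row_num_"):
    """Add a row-index column, supplying order_by only when the backend demands it."""
    try:
        # Eager DataFrames accept the bare call.
        return frame.with_row_index(name=name)
    except TypeError:
        # LazyFrames (e.g. PySpark) need an explicit ordering column; fall back
        # to the first column, mirroring the diff.
        return frame.with_row_index(name=name, order_by=frame.columns[0])


# Usage: indexed = add_row_index(nw.from_native(native_df))
```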
```diff
@@ -9997,12 +10169,52 @@ class Validate:
             if get_first_n is not None:
                 validation_extract_nw = validation_extract_nw.head(get_first_n)
             elif sample_n is not None:
-
+                # Narwhals LazyFrame doesn't have sample method, use head after shuffling
+                try:
+                    validation_extract_nw = validation_extract_nw.sample(n=sample_n)
+                except AttributeError:
+                    # For LazyFrames without sample method, collect first then sample
+                    validation_extract_native = validation_extract_nw.collect().to_native()
+                    if hasattr(validation_extract_native, "sample"):
+                        # PySpark DataFrame has sample method
+                        validation_extract_native = validation_extract_native.sample(
+                            fraction=min(1.0, sample_n / validation_extract_native.count())
+                        ).limit(sample_n)
+                        validation_extract_nw = nw.from_native(validation_extract_native)
+                    else:
+                        # Fallback: just take first n rows after collecting
+                        validation_extract_nw = validation_extract_nw.collect().head(sample_n)
             elif sample_frac is not None:
-
+                try:
+                    validation_extract_nw = validation_extract_nw.sample(fraction=sample_frac)
+                except AttributeError:
+                    # For LazyFrames without sample method, collect first then sample
+                    validation_extract_native = validation_extract_nw.collect().to_native()
+                    if hasattr(validation_extract_native, "sample"):
+                        # PySpark DataFrame has sample method
+                        validation_extract_native = validation_extract_native.sample(
+                            fraction=sample_frac
+                        )
+                        validation_extract_nw = nw.from_native(validation_extract_native)
+                    else:
+                        # Fallback: use fraction to calculate head size
+                        collected = validation_extract_nw.collect()
+                        sample_size = max(1, int(len(collected) * sample_frac))
+                        validation_extract_nw = collected.head(sample_size)
 
             # Ensure a limit is set on the number of rows to extract
-
+            try:
+                # For DataFrames, use len()
+                extract_length = len(validation_extract_nw)
+            except TypeError:
+                # For LazyFrames, collect to get length (or use a reasonable default)
+                try:
+                    extract_length = len(validation_extract_nw.collect())
+                except Exception:
+                    # If collection fails, apply limit anyway as a safety measure
+                    extract_length = extract_limit + 1  # Force limiting
+
+            if extract_length > extract_limit:
                 validation_extract_nw = validation_extract_nw.head(extract_limit)
 
             # If a 'rows_distinct' validation step, then the extract should have the
```
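The sampling fallback above drops to the native object because narwhals LazyFrames expose no `sample()`. PySpark's `DataFrame.sample()` takes a fraction and is only approximate, so the result is capped with `.limit(n)`. A standalone sketch of sampling at most `n` rows from a Spark DataFrame; the wrapper name is mine, and because the fraction targets `n` rows on average, fewer may come back.

```python
from pyspark.sql import DataFrame, SparkSession


def sample_about_n(df: DataFrame, n: int, seed: int | None = None) -> DataFrame:
    """Return at most n rows sampled from df.

    sample() is fraction-based and approximate: with fraction = n / total the
    expected size is n, and limit(n) only trims any overshoot.
    """
    total = df.count()
    if total == 0 or n >= total:
        return df
    return df.sample(fraction=n / total, seed=seed).limit(n)


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[1]").getOrCreate()
    print(sample_about_n(spark.range(1_000), 10, seed=42).count())
```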
```diff
@@ -10030,7 +10242,10 @@ class Validate:
                 .drop("group_min_row")
             )
 
-            # Ensure that the extract is set to its native format
+            # Ensure that the extract is collected and set to its native format
+            # For LazyFrames (like PySpark), we need to collect before converting to native
+            if hasattr(validation_extract_nw, "collect"):
+                validation_extract_nw = validation_extract_nw.collect()
             validation.extract = nw.to_native(validation_extract_nw)
 
             # Get the end time for this step
@@ -11656,7 +11871,16 @@ class Validate:
         # TODO: add argument for user to specify the index column name
         index_name = "pb_index_"
 
-        data_nw = nw.from_native(self.data)
+        data_nw = nw.from_native(self.data)
+
+        # Handle LazyFrame row indexing which requires order_by parameter
+        try:
+            # Try without order_by first (for DataFrames)
+            data_nw = data_nw.with_row_index(name=index_name)
+        except TypeError:
+            # LazyFrames require order_by parameter - use first column for ordering
+            first_col = data_nw.columns[0]
+            data_nw = data_nw.with_row_index(name=index_name, order_by=first_col)
 
         # Get all validation step result tables and join together the `pb_is_good_` columns
         # ensuring that the columns are named uniquely (e.g., `pb_is_good_1`, `pb_is_good_2`, ...)
```
```diff
@@ -11665,7 +11889,13 @@ class Validate:
             results_tbl = nw.from_native(validation.tbl_checked)
 
             # Add row numbers to the results table
-
+            try:
+                # Try without order_by first (for DataFrames)
+                results_tbl = results_tbl.with_row_index(name=index_name)
+            except TypeError:
+                # LazyFrames require order_by parameter - use first column for ordering
+                first_col = results_tbl.columns[0]
+                results_tbl = results_tbl.with_row_index(name=index_name, order_by=first_col)
 
             # Add numerical suffix to the `pb_is_good_` column to make it unique
             results_tbl = results_tbl.select([index_name, "pb_is_good_"]).rename(
```
```diff
@@ -12284,15 +12514,21 @@ class Validate:
             # Transform to Narwhals DataFrame
             extract_nw = nw.from_native(extract)
 
-            # Get the number of rows in the extract
-
+            # Get the number of rows in the extract (safe for LazyFrames)
+            try:
+                n_rows = len(extract_nw)
+            except TypeError:
+                # For LazyFrames, collect() first to get length
+                n_rows = len(extract_nw.collect()) if hasattr(extract_nw, "collect") else 0
 
             # If the number of rows is zero, then produce an em dash then go to the next iteration
             if n_rows == 0:
                 extract_upd.append("—")
                 continue
 
-            # Write the CSV text
+            # Write the CSV text (ensure LazyFrames are collected first)
+            if hasattr(extract_nw, "collect"):
+                extract_nw = extract_nw.collect()
             csv_text = extract_nw.write_csv()
 
             # Use Base64 encoding to encode the CSV text
```
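The extract handling above collects LazyFrames before `write_csv()` and then Base64-encodes the CSV for embedding in the report. A hedged sketch of that encode step with narwhals; the data-URI prefix and helper name are illustrative rather than pointblank's exact output format.

```python
import base64

import narwhals as nw
import pandas as pd


def extract_to_data_uri(native_frame) -> str:
    """Encode a (possibly lazy) frame as a Base64 CSV data URI."""
    frame = nw.from_native(native_frame)

    # Lazy backends expose collect(); eager DataFrames can be written directly.
    if hasattr(frame, "collect"):
        frame = frame.collect()

    csv_text = frame.write_csv()
    encoded = base64.b64encode(csv_text.encode("utf-8")).decode("utf-8")
    return f"data:text/csv;base64,{encoded}"


if __name__ == "__main__":
    print(extract_to_data_uri(pd.DataFrame({"a": [1, 2], "b": ["x", "y"]}))[:60])
```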
```diff
@@ -13856,7 +14092,7 @@ def _prep_values_text(
     return values_str
 
 
-def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
+def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> list[tuple[str, str]]:
     """
     Obtain the segmentation categories from a table column.
 
@@ -13881,22 +14117,27 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
     list[tuple[str, str]]
         A list of tuples representing pairings of a column name and a value in the column.
     """
+    import narwhals as nw
+
     # Determine if the table is a DataFrame or a DB table
     tbl_type = _get_tbl_type(data=data_tbl)
 
     # Obtain the segmentation categories from the table column given as `segments_expr`
-    if tbl_type
-
-
-
+    if tbl_type in ["polars", "pandas", "pyspark"]:
+        # Use Narwhals for supported DataFrame types
+        data_nw = nw.from_native(data_tbl)
+        unique_vals = data_nw.select(nw.col(segments_expr)).unique()
+
+        # Convert to list of values
+        seg_categories = unique_vals[segments_expr].to_list()
     elif tbl_type in IBIS_BACKENDS:
         distinct_col_vals = data_tbl.select(segments_expr).distinct()
         seg_categories = distinct_col_vals[segments_expr].to_list()
     else:  # pragma: no cover
         raise ValueError(f"Unsupported table type: {tbl_type}")
 
-    # Ensure that the categories are sorted
-    seg_categories.sort()
+    # Ensure that the categories are sorted, and allow for None values
+    seg_categories.sort(key=lambda x: (x is None, x))
 
     # Place each category and each value in a list of tuples as: `(column, value)`
     seg_tuples = [(segments_expr, category) for category in seg_categories]
```
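`_seg_expr_from_string()` now collects the distinct values of the segmentation column through narwhals and sorts them with a key that tolerates `None` (a plain `sort()` would raise `TypeError` on a mix of `None` and strings). A compact sketch of that idea with an illustrative helper name:

```python
import narwhals as nw
import pandas as pd


def segment_categories(native_df, column: str) -> list[tuple[str, object]]:
    """Return [(column, value), ...] for each distinct value, with None sorted last."""
    df = nw.from_native(native_df)
    values = df.select(nw.col(column)).unique()[column].to_list()

    # (x is None, x) places None after real values and avoids None < str errors.
    values.sort(key=lambda x: (x is None, x))
    return [(column, value) for value in values]


if __name__ == "__main__":
    pdf = pd.DataFrame({"region": ["east", None, "west", "east"]})
    # [('region', 'east'), ('region', 'west'), ('region', None)]
    print(segment_categories(pdf, "region"))
```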
```diff
@@ -13904,7 +14145,7 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
     return seg_tuples
 
 
-def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
+def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, Any]]:
     """
     Normalize the segments expression to a list of tuples, given a single tuple.
 
@@ -13930,17 +14171,23 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
 
     Returns
     -------
-    list[tuple[str,
+    list[tuple[str, Any]]
         A list of tuples representing pairings of a column name and a value in the column.
+        Values can be any type, including None.
     """
+    # Unpack the segments expression tuple for more convenient and explicit variable names
+    column, segment = segments_expr
+
     # Check if the first element is a string
-    if isinstance(
-
-
-
+    if isinstance(column, str):
+        if isinstance(segment, Segment):
+            seg_tuples = [(column, seg) for seg in segment.segments]
+        # If the second element is a collection, expand into a list of tuples
+        elif isinstance(segment, (list, set, tuple)):
+            seg_tuples = [(column, seg) for seg in segment]
         # If the second element is not a list, create a single tuple
         else:
-            seg_tuples = [(
+            seg_tuples = [(column, segment)]
     # If the first element is not a string, raise an error
     else:  # pragma: no cover
         raise ValueError("The first element of the segments expression must be a string.")
```
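`_seg_expr_from_tuple()` expands a `("column", values)` pair into a list of `(column, value)` tuples, where `values` may be a `Segment` group, a plain collection, or a scalar. A sketch of that normalization using `SegmentGroup` as a stand-in for `pointblank.segments.Segment` (the diff shows only that the real class exposes a `.segments` attribute):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class SegmentGroup:
    # Stand-in for pointblank.segments.Segment: resolved groups live in `.segments`.
    segments: list[Any]


def normalize_segments(expr: tuple[str, Any]) -> list[tuple[str, Any]]:
    column, segment = expr
    if not isinstance(column, str):
        raise ValueError("The first element of the segments expression must be a string.")

    if isinstance(segment, SegmentGroup):
        return [(column, seg) for seg in segment.segments]
    if isinstance(segment, (list, set, tuple)):
        return [(column, seg) for seg in segment]
    return [(column, segment)]


print(normalize_segments(("region", ["east", "west"])))  # [('region', 'east'), ('region', 'west')]
print(normalize_segments(("region", None)))              # [('region', None)]
```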
```diff
@@ -13948,7 +14195,7 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
     return seg_tuples
 
 
-def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
+def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
     """
     Apply the segments expression to the data table.
 
@@ -13971,15 +14218,24 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
     # Get the table type
     tbl_type = _get_tbl_type(data=data_tbl)
 
-
-
+    # Unpack the segments expression tuple for more convenient and explicit variable names
+    column, segment = segments_expr
+
+    if tbl_type in ["pandas", "polars", "pyspark"]:
+        # If the table is a Pandas, Polars, or PySpark DataFrame, transforming to a Narwhals table
         # and perform the filtering operation
 
         # Transform to Narwhals table if a DataFrame
         data_tbl_nw = nw.from_native(data_tbl)
 
-        # Filter the data table based on the column name and
-
+        # Filter the data table based on the column name and segment
+        if segment is None:
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_null())
+        # Check if the segment is a segment group
+        elif isinstance(segment, list):
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_in(segment))
+        else:
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column) == segment)
 
         # Transform back to the original table type
         data_tbl = data_tbl_nw.to_native()
```
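`_apply_segments()` filters on three cases: `None` keeps the null rows, a list keeps members, and anything else is an equality test, all expressed once through narwhals so pandas, Polars, and PySpark share the code path. A small runnable sketch of the same dispatch (helper name mine):

```python
import narwhals as nw
import pandas as pd


def apply_segment(native_df, column: str, segment):
    """Filter a frame to one segment: None -> nulls, list -> membership, scalar -> equality."""
    df = nw.from_native(native_df)

    if segment is None:
        df = df.filter(nw.col(column).is_null())
    elif isinstance(segment, list):
        df = df.filter(nw.col(column).is_in(segment))
    else:
        df = df.filter(nw.col(column) == segment)

    return df.to_native()


if __name__ == "__main__":
    pdf = pd.DataFrame({"region": ["east", "west", None, "east"], "x": [1, 2, 3, 4]})
    print(apply_segment(pdf, "region", ["east"]))
    print(apply_segment(pdf, "region", None))
```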
```diff
@@ -13987,8 +14243,13 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
     elif tbl_type in IBIS_BACKENDS:
         # If the table is an Ibis backend table, perform the filtering operation directly
 
-        # Filter the data table based on the column name and
-
+        # Filter the data table based on the column name and segment
+        if segment is None:
+            data_tbl = data_tbl[data_tbl[column].isnull()]
+        elif isinstance(segment, list):
+            data_tbl = data_tbl[data_tbl[column].isin(segment)]
+        else:
+            data_tbl = data_tbl[data_tbl[column] == segment]
 
     return data_tbl
 
```