pointblank 0.11.6__py3-none-any.whl → 0.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pointblank/__init__.py +2 -0
- pointblank/_constants.py +0 -1
- pointblank/_interrogation.py +244 -606
- pointblank/_utils.py +65 -3
- pointblank/assistant.py +9 -0
- pointblank/cli.py +39 -24
- pointblank/data/api-docs.txt +658 -29
- pointblank/schema.py +17 -0
- pointblank/segments.py +163 -0
- pointblank/validate.py +344 -92
- {pointblank-0.11.6.dist-info → pointblank-0.12.1.dist-info}/METADATA +59 -6
- {pointblank-0.11.6.dist-info → pointblank-0.12.1.dist-info}/RECORD +16 -15
- {pointblank-0.11.6.dist-info → pointblank-0.12.1.dist-info}/WHEEL +0 -0
- {pointblank-0.11.6.dist-info → pointblank-0.12.1.dist-info}/entry_points.txt +0 -0
- {pointblank-0.11.6.dist-info → pointblank-0.12.1.dist-info}/licenses/LICENSE +0 -0
- {pointblank-0.11.6.dist-info → pointblank-0.12.1.dist-info}/top_level.txt +0 -0
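The bulk of the `validate.py` changes below extend existing entry points (`preview()`, `missing_vals_tbl()`, `get_row_count()`, `get_column_count()`, and the `Validate` class) to accept PySpark DataFrames directly. A minimal sketch of what that enables, assuming a local PySpark session and the documented `Validate` API; the sample data and column names here are made up for illustration:

    import pointblank as pb
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame(
        [(1, "East"), (2, "West"), (None, "East")], ["a", "region"]
    )

    pb.preview(sdf)  # renders a head/tail preview without manual conversion
    validation = (
        pb.Validate(data=sdf)
        .col_vals_not_null(columns="a")
        .interrogate()
    )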
pointblank/validate.py
CHANGED
@@ -87,6 +87,7 @@ from pointblank._utils_check_args import (
 from pointblank._utils_html import _create_table_dims_html, _create_table_type_html
 from pointblank.column import Column, ColumnLiteral, ColumnSelector, ColumnSelectorNarwhals, col
 from pointblank.schema import Schema, _get_schema_validation_info
+from pointblank.segments import Segment
 from pointblank.thresholds import (
     Actions,
     FinalActions,
@@ -1194,6 +1195,7 @@ def preview(

     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -1201,7 +1203,6 @@ def preview(
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -1396,7 +1397,10 @@ def _generate_display_table(
     row_number_list: list[int] | None = None,
 ) -> GT:
     # Make a copy of the data to avoid modifying the original
+    # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
+    tbl_type = _get_tbl_type(data=data)
+    if "pyspark" not in tbl_type:
+        data = copy.deepcopy(data)

     # Does the data table already have a leading row number column?
     if "_row_num_" in data.columns:
@@ -1422,22 +1426,31 @@ def _generate_display_table(
     # Determine if the table is a DataFrame or an Ibis table
     tbl_type = _get_tbl_type(data=data)
     ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
-    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
+    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type

     # Select the DataFrame library to use for displaying the Ibis table
     df_lib_gt = _select_df_lib(preference="polars")
     df_lib_name_gt = df_lib_gt.__name__

-    # If the table is a DataFrame (Pandas or
-    # library (e.g., "polars" or "
+    # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
+    # library (e.g., "polars", "pandas", or "pyspark")
     if pl_pb_tbl:
+        if "polars" in tbl_type:
+            df_lib_name_gt = "polars"
+        elif "pandas" in tbl_type:
+            df_lib_name_gt = "pandas"
+        elif "pyspark" in tbl_type:
+            df_lib_name_gt = "pyspark"

-        # Handle imports of Polars or
+        # Handle imports of Polars, Pandas, or PySpark here
         if df_lib_name_gt == "polars":
             import polars as pl
+        elif df_lib_name_gt == "pandas":
             import pandas as pd
+        elif df_lib_name_gt == "pyspark":
+            # Import pandas for conversion since Great Tables needs pandas DataFrame
+            import pandas as pd
+            # Note: PySpark import is handled as needed, typically already imported in user's environment

     # Get the initial column count for the table
     n_columns = len(data.columns)
@@ -1547,6 +1560,42 @@ def _generate_display_table(
             range(n_rows - n_tail + 1, n_rows + 1)
         )

+    if tbl_type == "pyspark":
+        n_rows = data.count()
+
+        # If n_head + n_tail is greater than the row count, display the entire table
+        if n_head + n_tail >= n_rows:
+            full_dataset = True
+            # Convert to pandas for Great Tables compatibility
+            data = data.toPandas()
+
+            row_number_list = range(1, n_rows + 1)
+        else:
+            # Get head and tail samples, then convert to pandas
+            head_data = data.limit(n_head).toPandas()
+
+            # PySpark tail() returns a list of Row objects, need to convert to DataFrame
+            tail_rows = data.tail(n_tail)
+            if tail_rows:
+                # Convert list of Row objects back to DataFrame, then to pandas
+                tail_df = data.sparkSession.createDataFrame(tail_rows, data.schema)
+                tail_data = tail_df.toPandas()
+            else:
+                # If no tail data, create empty DataFrame with same schema
+                import pandas as pd
+
+                tail_data = pd.DataFrame(columns=head_data.columns)
+
+            data = pd.concat([head_data, tail_data])
+
+            row_number_list = list(range(1, n_head + 1)) + list(
+                range(n_rows - n_tail + 1, n_rows + 1)
+            )
+
+        # For PySpark, update schema after conversion to pandas
+        if tbl_type == "pyspark":
+            tbl_schema = Schema(tbl=data)
+
     # From the table schema, get a list of tuples containing column names and data types
     col_dtype_dict = tbl_schema.columns

@@ -1566,6 +1615,23 @@ def _generate_display_table(
     # This is used to highlight these values in the table
     if df_lib_name_gt == "polars":
         none_values = {k: data[k].is_null().to_list() for k in col_names}
+    elif df_lib_name_gt == "pyspark":
+        # For PySpark, check if data has been converted to pandas already
+        if hasattr(data, "isnull"):
+            # Data has been converted to pandas
+            none_values = {k: data[k].isnull() for k in col_names}
+        else:
+            # Data is still a PySpark DataFrame - use narwhals
+            import narwhals as nw
+
+            df_nw = nw.from_native(data)
+            none_values = {}
+            for col in col_names:
+                # Get null mask, collect to pandas, then convert to list
+                null_mask = (
+                    df_nw.select(nw.col(col).is_null()).collect().to_pandas().iloc[:, 0].tolist()
+                )
+                none_values[col] = null_mask
     else:
         none_values = {k: data[k].isnull() for k in col_names}

@@ -1579,7 +1645,13 @@ def _generate_display_table(

     for column in col_dtype_dict.keys():
         # Select a single column of values
+        if df_lib_name_gt == "pandas":
+            data_col = data[[column]]
+        elif df_lib_name_gt == "pyspark":
+            # PySpark data should have been converted to pandas by now
+            data_col = data[[column]]
+        else:
+            data_col = data.select([column])

         # Using Great Tables, render the columns and get the list of values as formatted strings
         built_gt = GT(data=data_col).fmt_markdown(columns=column)._build_data(context="html")
@@ -1658,6 +1730,10 @@ def _generate_display_table(
     if df_lib_name_gt == "pandas":
         data.insert(0, "_row_num_", row_number_list)

+    if df_lib_name_gt == "pyspark":
+        # For PySpark converted to pandas, use pandas method
+        data.insert(0, "_row_num_", row_number_list)
+
     # Get the highest number in the `row_number_list` and calculate a width that will
     # safely fit a number of that magnitude
     if row_number_list:  # Check if list is not empty
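The preview path above never materializes the full Spark table: it takes a `limit(n_head)` sample and a `tail(n_tail)` sample, converts those to pandas, and hands the concatenated frame to Great Tables. A standalone sketch of that sampling pattern, assuming PySpark and pandas are installed (`head_tail_to_pandas` is a hypothetical helper, not part of the package):

    import pandas as pd

    def head_tail_to_pandas(sdf, n_head: int = 5, n_tail: int = 5) -> pd.DataFrame:
        # Count once; if the requested sample covers the table, bring the whole thing over
        n_rows = sdf.count()
        if n_head + n_tail >= n_rows:
            return sdf.toPandas()

        head_pd = sdf.limit(n_head).toPandas()
        # tail() returns a list of Row objects, so rebuild a Spark DataFrame first
        tail_rows = sdf.tail(n_tail)
        if tail_rows:
            tail_pd = sdf.sparkSession.createDataFrame(tail_rows, sdf.schema).toPandas()
        else:
            tail_pd = pd.DataFrame(columns=head_pd.columns)
        return pd.concat([head_pd, tail_pd])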
@@ -1791,6 +1867,7 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:

     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -1798,7 +1875,6 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -1851,7 +1927,10 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     data = _process_data(data)

     # Make a copy of the data to avoid modifying the original
+    # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
+    tbl_type = _get_tbl_type(data=data)
+    if "pyspark" not in tbl_type:
+        data = copy.deepcopy(data)

     # Get the number of rows in the table
     n_rows = get_row_count(data)
@@ -1868,22 +1947,28 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     # Determine if the table is a DataFrame or an Ibis table
     tbl_type = _get_tbl_type(data=data)
     ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
-    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
+    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type

     # Select the DataFrame library to use for displaying the Ibis table
     df_lib_gt = _select_df_lib(preference="polars")
     df_lib_name_gt = df_lib_gt.__name__

-    # If the table is a DataFrame (Pandas or
-    # library (e.g., "polars" or "
+    # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
+    # library (e.g., "polars", "pandas", or "pyspark")
     if pl_pb_tbl:
+        if "polars" in tbl_type:
+            df_lib_name_gt = "polars"
+        elif "pandas" in tbl_type:
+            df_lib_name_gt = "pandas"
+        elif "pyspark" in tbl_type:
+            df_lib_name_gt = "pyspark"

-        # Handle imports of Polars or
+        # Handle imports of Polars, Pandas, or PySpark here
         if df_lib_name_gt == "polars":
             import polars as pl
+        elif df_lib_name_gt == "pandas":
             import pandas as pd
+        # Note: PySpark import is handled as needed, typically already imported in user's environment

     # From an Ibis table:
     # - get the row count
@@ -2047,6 +2132,77 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
         # Get a dictionary of counts of missing values in each column
         missing_val_counts = {col: data[col].isnull().sum() for col in data.columns}

+    if "pyspark" in tbl_type:
+        from pyspark.sql.functions import col as pyspark_col
+
+        # PySpark implementation for missing values calculation
+        missing_vals = {}
+        for col_name in data.columns:
+            col_missing_props = []
+
+            # Calculate missing value proportions for each sector
+            for i in range(len(cut_points)):
+                start_row = cut_points[i - 1] if i > 0 else 0
+                end_row = cut_points[i]
+                sector_size = end_row - start_row
+
+                if sector_size > 0:
+                    # Use row_number() to filter rows by range
+                    from pyspark.sql.functions import row_number
+                    from pyspark.sql.window import Window
+
+                    window = Window.orderBy(
+                        pyspark_col(data.columns[0])
+                    )  # Order by first column
+                    sector_data = data.withColumn("row_num", row_number().over(window)).filter(
+                        (pyspark_col("row_num") > start_row)
+                        & (pyspark_col("row_num") <= end_row)
+                    )
+
+                    # Count nulls in this sector
+                    null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
+                    missing_prop = (null_count / sector_size) * 100
+                    col_missing_props.append(missing_prop)
+                else:
+                    col_missing_props.append(0)
+
+            # Handle the final sector (after last cut point)
+            if n_rows > cut_points[-1]:
+                start_row = cut_points[-1]
+                end_row = n_rows
+                sector_size = end_row - start_row
+
+                from pyspark.sql.functions import row_number
+                from pyspark.sql.window import Window
+
+                window = Window.orderBy(pyspark_col(data.columns[0]))
+                sector_data = data.withColumn("row_num", row_number().over(window)).filter(
+                    pyspark_col("row_num") > start_row
+                )
+
+                null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
+                missing_prop = (null_count / sector_size) * 100
+                col_missing_props.append(missing_prop)
+            else:
+                col_missing_props.append(0)
+
+            missing_vals[col_name] = col_missing_props
+
+        # Pivot the `missing_vals` dictionary to create a table with the missing value proportions
+        missing_vals = {
+            "columns": list(missing_vals.keys()),
+            **{
+                str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()]
+                for i in range(len(cut_points) + 1)
+            },
+        }
+
+        # Get a dictionary of counts of missing values in each column
+        missing_val_counts = {}
+        for col_name in data.columns:
+            null_count = data.filter(pyspark_col(col_name).isNull()).count()
+            missing_val_counts[col_name] = null_count
+
     # From `missing_vals`, create the DataFrame with the missing value proportions
     if df_lib_name_gt == "polars":
         import polars as pl
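The hunk above buckets Spark rows into sectors with `row_number()` over a window and counts nulls with a filter-and-count loop per column. For the plain per-column totals, a single aggregation pass is also possible; this is only a sketch of that alternative, and `null_counts` is a hypothetical helper rather than anything in the package:

    from pyspark.sql import functions as F

    def null_counts(sdf) -> dict:
        # One aggregation pass: cast each column's null mask to int and sum it
        row = sdf.select(
            [F.sum(F.col(c).isNull().cast("int")).alias(c) for c in sdf.columns]
        ).collect()[0]
        return row.asDict()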
@@ -2333,6 +2489,7 @@ def get_column_count(data: FrameT | Any) -> int:

     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -2340,7 +2497,6 @@ def get_column_count(data: FrameT | Any) -> int:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2458,20 +2614,18 @@ def get_column_count(data: FrameT | Any) -> int:
         # Handle list of file paths (likely Parquet files)
         data = _process_parquet_input(data)

-    elif "polars" in str(type(data)):
-        return len(data.columns)
-    elif "pandas" in str(type(data)):
-        return data.shape[1]
-    elif "narwhals" in str(type(data)):
-        return len(data.columns)
+    # Use Narwhals to handle all DataFrame types (including Ibis) uniformly
+    try:
+        import narwhals as nw

+        df_nw = nw.from_native(data)
+        return len(df_nw.columns)
+    except Exception:
+        # Fallback for unsupported types
+        if "pandas" in str(type(data)):
+            return data.shape[1]
+        else:
+            raise ValueError("The input table type supplied in `data=` is not supported.")


 def get_row_count(data: FrameT | Any) -> int:
@@ -2501,6 +2655,7 @@ def get_row_count(data: FrameT | Any) -> int:

     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -2508,7 +2663,6 @@ def get_row_count(data: FrameT | Any) -> int:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2627,30 +2781,29 @@ def get_row_count(data: FrameT | Any) -> int:
         # Handle list of file paths (likely Parquet files)
         data = _process_parquet_input(data)

+    # Use Narwhals to handle all DataFrame types (including Ibis) uniformly
+    try:
+        import narwhals as nw
+
+        df_nw = nw.from_native(data)
+        # Handle LazyFrames by collecting them first
+        if hasattr(df_nw, "collect"):
+            df_nw = df_nw.collect()
+        # Try different ways to get row count
+        if hasattr(df_nw, "shape"):
+            return df_nw.shape[0]
+        elif hasattr(df_nw, "height"):
+            return df_nw.height
         else:
-            return data.shape[0]
-    else:
-        raise ValueError("The input table type supplied in `data=` is not supported.")
+            raise ValueError("Unable to determine row count from Narwhals DataFrame")
+    except Exception:
+        # Fallback for types that don't work with Narwhals
+        if "pandas" in str(type(data)):
+            return data.shape[0]
+        elif "pyspark" in str(type(data)):
+            return data.count()
+        else:
+            raise ValueError("The input table type supplied in `data=` is not supported.")


 @dataclass
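Both counters now route every input through Narwhals (`nw.from_native`), collect LazyFrames when needed, and only fall back to backend-specific calls (`data.shape` for pandas, `data.count()` for PySpark) when Narwhals cannot handle the object. A hedged usage sketch of the public accessors, using a Polars frame as the example input:

    import pointblank as pb
    import polars as pl

    df = pl.DataFrame({"x": [1, 2, 3]})
    pb.get_row_count(df)     # 3
    pb.get_column_count(df)  # 1

    # The same calls accept a PySpark DataFrame; if Narwhals cannot produce a
    # length directly, the fallback path above calls data.count() instead.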
@@ -3098,6 +3251,7 @@ class Validate:

     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -3105,7 +3259,6 @@ class Validate:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -9983,12 +10136,22 @@ class Validate:
                 and tbl_type not in IBIS_BACKENDS
             ):
                 # Add row numbers to the results table
-                validation_extract_nw = (
+                validation_extract_nw = nw.from_native(results_tbl)
+
+                # Handle LazyFrame row indexing which requires order_by parameter
+                try:
+                    # Try without order_by first (for DataFrames)
+                    validation_extract_nw = validation_extract_nw.with_row_index(name="_row_num_")
+                except TypeError:
+                    # LazyFrames require order_by parameter - use first column for ordering
+                    first_col = validation_extract_nw.columns[0]
+                    validation_extract_nw = validation_extract_nw.with_row_index(
+                        name="_row_num_", order_by=first_col
+                    )
+
+                validation_extract_nw = validation_extract_nw.filter(~nw.col("pb_is_good_")).drop(
+                    "pb_is_good_"
+                )  # noqa

                 # Add 1 to the row numbers to make them 1-indexed
                 validation_extract_nw = validation_extract_nw.with_columns(nw.col("_row_num_") + 1)
@@ -9997,12 +10160,52 @@ class Validate:
                 if get_first_n is not None:
                     validation_extract_nw = validation_extract_nw.head(get_first_n)
                 elif sample_n is not None:
+                    # Narwhals LazyFrame doesn't have sample method, use head after shuffling
+                    try:
+                        validation_extract_nw = validation_extract_nw.sample(n=sample_n)
+                    except AttributeError:
+                        # For LazyFrames without sample method, collect first then sample
+                        validation_extract_native = validation_extract_nw.collect().to_native()
+                        if hasattr(validation_extract_native, "sample"):
+                            # PySpark DataFrame has sample method
+                            validation_extract_native = validation_extract_native.sample(
+                                fraction=min(1.0, sample_n / validation_extract_native.count())
+                            ).limit(sample_n)
+                            validation_extract_nw = nw.from_native(validation_extract_native)
+                        else:
+                            # Fallback: just take first n rows after collecting
+                            validation_extract_nw = validation_extract_nw.collect().head(sample_n)
                 elif sample_frac is not None:
+                    try:
+                        validation_extract_nw = validation_extract_nw.sample(fraction=sample_frac)
+                    except AttributeError:
+                        # For LazyFrames without sample method, collect first then sample
+                        validation_extract_native = validation_extract_nw.collect().to_native()
+                        if hasattr(validation_extract_native, "sample"):
+                            # PySpark DataFrame has sample method
+                            validation_extract_native = validation_extract_native.sample(
+                                fraction=sample_frac
+                            )
+                            validation_extract_nw = nw.from_native(validation_extract_native)
+                        else:
+                            # Fallback: use fraction to calculate head size
+                            collected = validation_extract_nw.collect()
+                            sample_size = max(1, int(len(collected) * sample_frac))
+                            validation_extract_nw = collected.head(sample_size)

                 # Ensure a limit is set on the number of rows to extract
+                try:
+                    # For DataFrames, use len()
+                    extract_length = len(validation_extract_nw)
+                except TypeError:
+                    # For LazyFrames, collect to get length (or use a reasonable default)
+                    try:
+                        extract_length = len(validation_extract_nw.collect())
+                    except Exception:
+                        # If collection fails, apply limit anyway as a safety measure
+                        extract_length = extract_limit + 1  # Force limiting
+
+                if extract_length > extract_limit:
                     validation_extract_nw = validation_extract_nw.head(extract_limit)

                 # If a 'rows_distinct' validation step, then the extract should have the
@@ -10030,7 +10233,10 @@ class Validate:
                     .drop("group_min_row")
                 )

-                # Ensure that the extract is set to its native format
+                # Ensure that the extract is collected and set to its native format
+                # For LazyFrames (like PySpark), we need to collect before converting to native
+                if hasattr(validation_extract_nw, "collect"):
+                    validation_extract_nw = validation_extract_nw.collect()
                 validation.extract = nw.to_native(validation_extract_nw)

                 # Get the end time for this step
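These extract changes are what you hit when pulling failing rows out of an interrogation against a lazy backend: the extract is collected before being handed back in the backend's native format. A hedged usage sketch, assuming the documented `get_data_extracts()` accessor and the `sdf` table from the earlier PySpark sketch:

    validation = (
        pb.Validate(data=sdf)  # `sdf` as in the earlier PySpark sketch
        .col_vals_gt(columns="a", value=0)
        .interrogate()
    )

    # Failing-row extracts are collected (for lazy backends) before being
    # returned in the backend's native format.
    failing_rows = validation.get_data_extracts(i=1, frame=True)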
@@ -11656,7 +11862,16 @@ class Validate:
         # TODO: add argument for user to specify the index column name
         index_name = "pb_index_"

-        data_nw = nw.from_native(self.data)
+        data_nw = nw.from_native(self.data)
+
+        # Handle LazyFrame row indexing which requires order_by parameter
+        try:
+            # Try without order_by first (for DataFrames)
+            data_nw = data_nw.with_row_index(name=index_name)
+        except TypeError:
+            # LazyFrames require order_by parameter - use first column for ordering
+            first_col = data_nw.columns[0]
+            data_nw = data_nw.with_row_index(name=index_name, order_by=first_col)

         # Get all validation step result tables and join together the `pb_is_good_` columns
         # ensuring that the columns are named uniquely (e.g., `pb_is_good_1`, `pb_is_good_2`, ...)
@@ -11665,7 +11880,13 @@ class Validate:
             results_tbl = nw.from_native(validation.tbl_checked)

             # Add row numbers to the results table
+            try:
+                # Try without order_by first (for DataFrames)
+                results_tbl = results_tbl.with_row_index(name=index_name)
+            except TypeError:
+                # LazyFrames require order_by parameter - use first column for ordering
+                first_col = results_tbl.columns[0]
+                results_tbl = results_tbl.with_row_index(name=index_name, order_by=first_col)

             # Add numerical suffix to the `pb_is_good_` column to make it unique
             results_tbl = results_tbl.select([index_name, "pb_is_good_"]).rename(
@@ -12284,15 +12505,21 @@ class Validate:
         # Transform to Narwhals DataFrame
         extract_nw = nw.from_native(extract)

-        # Get the number of rows in the extract
+        # Get the number of rows in the extract (safe for LazyFrames)
+        try:
+            n_rows = len(extract_nw)
+        except TypeError:
+            # For LazyFrames, collect() first to get length
+            n_rows = len(extract_nw.collect()) if hasattr(extract_nw, "collect") else 0

         # If the number of rows is zero, then produce an em dash then go to the next iteration
         if n_rows == 0:
             extract_upd.append("—")
             continue

-        # Write the CSV text
+        # Write the CSV text (ensure LazyFrames are collected first)
+        if hasattr(extract_nw, "collect"):
+            extract_nw = extract_nw.collect()
         csv_text = extract_nw.write_csv()

         # Use Base64 encoding to encode the CSV text
@@ -13856,7 +14083,7 @@ def _prep_values_text(
     return values_str


-def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
+def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> list[tuple[str, str]]:
     """
     Obtain the segmentation categories from a table column.

@@ -13881,22 +14108,27 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
     list[tuple[str, str]]
         A list of tuples representing pairings of a column name and a value in the column.
     """
+    import narwhals as nw
+
     # Determine if the table is a DataFrame or a DB table
     tbl_type = _get_tbl_type(data=data_tbl)

     # Obtain the segmentation categories from the table column given as `segments_expr`
-    if tbl_type
+    if tbl_type in ["polars", "pandas", "pyspark"]:
+        # Use Narwhals for supported DataFrame types
+        data_nw = nw.from_native(data_tbl)
+        unique_vals = data_nw.select(nw.col(segments_expr)).unique()
+
+        # Convert to list of values
+        seg_categories = unique_vals[segments_expr].to_list()
     elif tbl_type in IBIS_BACKENDS:
         distinct_col_vals = data_tbl.select(segments_expr).distinct()
         seg_categories = distinct_col_vals[segments_expr].to_list()
     else:  # pragma: no cover
         raise ValueError(f"Unsupported table type: {tbl_type}")

-    # Ensure that the categories are sorted
-    seg_categories.sort()
+    # Ensure that the categories are sorted, and allow for None values
+    seg_categories.sort(key=lambda x: (x is None, x))

     # Place each category and each value in a list of tuples as: `(column, value)`
     seg_tuples = [(segments_expr, category) for category in seg_categories]
@@ -13904,7 +14136,7 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
     return seg_tuples


-def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str,
+def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, Any]]:
     """
     Normalize the segments expression to a list of tuples, given a single tuple.

@@ -13930,17 +14162,23 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:

     Returns
     -------
-    list[tuple[str,
+    list[tuple[str, Any]]
         A list of tuples representing pairings of a column name and a value in the column.
+        Values can be any type, including None.
     """
+    # Unpack the segments expression tuple for more convenient and explicit variable names
+    column, segment = segments_expr
+
     # Check if the first element is a string
-    if isinstance(
+    if isinstance(column, str):
+        if isinstance(segment, Segment):
+            seg_tuples = [(column, seg) for seg in segment.segments]
+        # If the second element is a collection, expand into a list of tuples
+        elif isinstance(segment, (list, set, tuple)):
+            seg_tuples = [(column, seg) for seg in segment]
         # If the second element is not a list, create a single tuple
         else:
-            seg_tuples = [(
+            seg_tuples = [(column, segment)]
     # If the first element is not a string, raise an error
     else:  # pragma: no cover
         raise ValueError("The first element of the segments expression must be a string.")
@@ -13948,7 +14186,7 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
     return seg_tuples


-def _apply_segments(data_tbl: any, segments_expr: tuple[str,
+def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
     """
     Apply the segments expression to the data table.

@@ -13971,15 +14209,24 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
     # Get the table type
     tbl_type = _get_tbl_type(data=data_tbl)

+    # Unpack the segments expression tuple for more convenient and explicit variable names
+    column, segment = segments_expr
+
+    if tbl_type in ["pandas", "polars", "pyspark"]:
+        # If the table is a Pandas, Polars, or PySpark DataFrame, transforming to a Narwhals table
         # and perform the filtering operation

         # Transform to Narwhals table if a DataFrame
         data_tbl_nw = nw.from_native(data_tbl)

-        # Filter the data table based on the column name and
+        # Filter the data table based on the column name and segment
+        if segment is None:
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_null())
+        # Check if the segment is a segment group
+        elif isinstance(segment, list):
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_in(segment))
+        else:
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column) == segment)

         # Transform back to the original table type
         data_tbl = data_tbl_nw.to_native()
@@ -13987,8 +14234,13 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
     elif tbl_type in IBIS_BACKENDS:
         # If the table is an Ibis backend table, perform the filtering operation directly

-        # Filter the data table based on the column name and
+        # Filter the data table based on the column name and segment
+        if segment is None:
+            data_tbl = data_tbl[data_tbl[column].isnull()]
+        elif isinstance(segment, list):
+            data_tbl = data_tbl[data_tbl[column].isin(segment)]
+        else:
+            data_tbl = data_tbl[data_tbl[column] == segment]

     return data_tbl

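Together with the new `pointblank/segments.py` module, the reworked `_seg_expr_from_string`, `_seg_expr_from_tuple`, and `_apply_segments` helpers expand a segments expression into one `(column, value)` pairing per segment, now tolerating `None` category values and collection-valued segments. A hedged usage sketch, assuming the `segments=` argument accepted by validation steps and the `sdf` table from the first sketch:

    # One validation step per distinct value of "region" (None gets its own segment)
    pb.Validate(data=sdf).col_vals_not_null(columns="a", segments="region").interrogate()

    # Or restrict segmentation to an explicit subset of values
    pb.Validate(data=sdf).col_vals_not_null(
        columns="a", segments=("region", ["East", "West"])
    ).interrogate()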