pointblank 0.11.6__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the versions as they appear in their public registries.
pointblank/validate.py CHANGED
@@ -87,6 +87,7 @@ from pointblank._utils_check_args import (
87
87
  from pointblank._utils_html import _create_table_dims_html, _create_table_type_html
88
88
  from pointblank.column import Column, ColumnLiteral, ColumnSelector, ColumnSelectorNarwhals, col
89
89
  from pointblank.schema import Schema, _get_schema_validation_info
90
+ from pointblank.segments import Segment
90
91
  from pointblank.thresholds import (
91
92
  Actions,
92
93
  FinalActions,
@@ -1194,6 +1195,7 @@ def preview(
1194
1195
 
1195
1196
  - Polars DataFrame (`"polars"`)
1196
1197
  - Pandas DataFrame (`"pandas"`)
1198
+ - PySpark table (`"pyspark"`)
1197
1199
  - DuckDB table (`"duckdb"`)*
1198
1200
  - MySQL table (`"mysql"`)*
1199
1201
  - PostgreSQL table (`"postgresql"`)*
@@ -1201,7 +1203,6 @@ def preview(
1201
1203
  - Microsoft SQL Server table (`"mssql"`)*
1202
1204
  - Snowflake table (`"snowflake"`)*
1203
1205
  - Databricks table (`"databricks"`)*
1204
- - PySpark table (`"pyspark"`)*
1205
1206
  - BigQuery table (`"bigquery"`)*
1206
1207
  - Parquet table (`"parquet"`)*
1207
1208
  - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -1396,7 +1397,10 @@ def _generate_display_table(
1396
1397
  row_number_list: list[int] | None = None,
1397
1398
  ) -> GT:
1398
1399
  # Make a copy of the data to avoid modifying the original
1399
- data = copy.deepcopy(data)
1400
+ # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
1401
+ tbl_type = _get_tbl_type(data=data)
1402
+ if "pyspark" not in tbl_type:
1403
+ data = copy.deepcopy(data)
1400
1404
 
1401
1405
  # Does the data table already have a leading row number column?
1402
1406
  if "_row_num_" in data.columns:
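
The guard added above skips `copy.deepcopy()` for PySpark inputs because the attached SparkSession cannot be serialized. A minimal sketch of the same pattern outside pointblank, using a plain type-name check in place of the internal `_get_tbl_type()` helper (`safe_copy` is a hypothetical name):

import copy

def safe_copy(data):
    # PySpark DataFrames carry a reference to the JVM-backed SparkSession,
    # which cannot be pickled, so copy.deepcopy() would fail; return as-is.
    if "pyspark" in str(type(data)).lower():
        return data
    return copy.deepcopy(data)
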
@@ -1422,22 +1426,31 @@ def _generate_display_table(
1422
1426
  # Determine if the table is a DataFrame or an Ibis table
1423
1427
  tbl_type = _get_tbl_type(data=data)
1424
1428
  ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
1425
- pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
1429
+ pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type
1426
1430
 
1427
1431
  # Select the DataFrame library to use for displaying the Ibis table
1428
1432
  df_lib_gt = _select_df_lib(preference="polars")
1429
1433
  df_lib_name_gt = df_lib_gt.__name__
1430
1434
 
1431
- # If the table is a DataFrame (Pandas or Polars), set `df_lib_name_gt` to the name of the
1432
- # library (e.g., "polars" or "pandas")
1435
+ # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
1436
+ # library (e.g., "polars", "pandas", or "pyspark")
1433
1437
  if pl_pb_tbl:
1434
- df_lib_name_gt = "polars" if "polars" in tbl_type else "pandas"
1438
+ if "polars" in tbl_type:
1439
+ df_lib_name_gt = "polars"
1440
+ elif "pandas" in tbl_type:
1441
+ df_lib_name_gt = "pandas"
1442
+ elif "pyspark" in tbl_type:
1443
+ df_lib_name_gt = "pyspark"
1435
1444
 
1436
- # Handle imports of Polars or Pandas here
1445
+ # Handle imports of Polars, Pandas, or PySpark here
1437
1446
  if df_lib_name_gt == "polars":
1438
1447
  import polars as pl
1439
- else:
1448
+ elif df_lib_name_gt == "pandas":
1440
1449
  import pandas as pd
1450
+ elif df_lib_name_gt == "pyspark":
1451
+ # Import pandas for conversion since Great Tables needs pandas DataFrame
1452
+ import pandas as pd
1453
+ # Note: PySpark import is handled as needed, typically already imported in user's environment
1441
1454
 
1442
1455
  # Get the initial column count for the table
1443
1456
  n_columns = len(data.columns)
@@ -1547,6 +1560,42 @@ def _generate_display_table(
1547
1560
  range(n_rows - n_tail + 1, n_rows + 1)
1548
1561
  )
1549
1562
 
1563
+ if tbl_type == "pyspark":
1564
+ n_rows = data.count()
1565
+
1566
+ # If n_head + n_tail is greater than the row count, display the entire table
1567
+ if n_head + n_tail >= n_rows:
1568
+ full_dataset = True
1569
+ # Convert to pandas for Great Tables compatibility
1570
+ data = data.toPandas()
1571
+
1572
+ row_number_list = range(1, n_rows + 1)
1573
+ else:
1574
+ # Get head and tail samples, then convert to pandas
1575
+ head_data = data.limit(n_head).toPandas()
1576
+
1577
+ # PySpark tail() returns a list of Row objects, need to convert to DataFrame
1578
+ tail_rows = data.tail(n_tail)
1579
+ if tail_rows:
1580
+ # Convert list of Row objects back to DataFrame, then to pandas
1581
+ tail_df = data.sparkSession.createDataFrame(tail_rows, data.schema)
1582
+ tail_data = tail_df.toPandas()
1583
+ else:
1584
+ # If no tail data, create empty DataFrame with same schema
1585
+ import pandas as pd
1586
+
1587
+ tail_data = pd.DataFrame(columns=head_data.columns)
1588
+
1589
+ data = pd.concat([head_data, tail_data])
1590
+
1591
+ row_number_list = list(range(1, n_head + 1)) + list(
1592
+ range(n_rows - n_tail + 1, n_rows + 1)
1593
+ )
1594
+
1595
+ # For PySpark, update schema after conversion to pandas
1596
+ if tbl_type == "pyspark":
1597
+ tbl_schema = Schema(tbl=data)
1598
+
1550
1599
  # From the table schema, get a list of tuples containing column names and data types
1551
1600
  col_dtype_dict = tbl_schema.columns
1552
1601
 
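
The branch above materializes only the head and tail of a PySpark DataFrame before handing it to Great Tables, which renders from pandas. A standalone sketch of that sampling pattern (`head_tail_preview` is a hypothetical name; `DataFrame.sparkSession` assumes Spark 3.3+, and `toPandas()` needs pandas available on the driver):

import pandas as pd

def head_tail_preview(sdf, n_head=5, n_tail=5):
    # count() triggers a Spark job; do it once.
    n_rows = sdf.count()

    # Small tables: bring everything over in one conversion.
    if n_head + n_tail >= n_rows:
        return sdf.toPandas(), list(range(1, n_rows + 1))

    head_pd = sdf.limit(n_head).toPandas()

    # tail() returns a list of Row objects, so rebuild a DataFrame from it
    # before converting to pandas.
    tail_rows = sdf.tail(n_tail)
    if tail_rows:
        tail_pd = sdf.sparkSession.createDataFrame(tail_rows, sdf.schema).toPandas()
    else:
        tail_pd = pd.DataFrame(columns=head_pd.columns)

    row_numbers = list(range(1, n_head + 1)) + list(range(n_rows - n_tail + 1, n_rows + 1))
    return pd.concat([head_pd, tail_pd]), row_numbers
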
@@ -1566,6 +1615,23 @@ def _generate_display_table(
1566
1615
  # This is used to highlight these values in the table
1567
1616
  if df_lib_name_gt == "polars":
1568
1617
  none_values = {k: data[k].is_null().to_list() for k in col_names}
1618
+ elif df_lib_name_gt == "pyspark":
1619
+ # For PySpark, check if data has been converted to pandas already
1620
+ if hasattr(data, "isnull"):
1621
+ # Data has been converted to pandas
1622
+ none_values = {k: data[k].isnull() for k in col_names}
1623
+ else:
1624
+ # Data is still a PySpark DataFrame - use narwhals
1625
+ import narwhals as nw
1626
+
1627
+ df_nw = nw.from_native(data)
1628
+ none_values = {}
1629
+ for col in col_names:
1630
+ # Get null mask, collect to pandas, then convert to list
1631
+ null_mask = (
1632
+ df_nw.select(nw.col(col).is_null()).collect().to_pandas().iloc[:, 0].tolist()
1633
+ )
1634
+ none_values[col] = null_mask
1569
1635
  else:
1570
1636
  none_values = {k: data[k].isnull() for k in col_names}
1571
1637
 
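
The `pyspark` branch above derives per-column null masks either from the already-converted pandas frame or, as a fallback, through narwhals. A hedged sketch of the narwhals route that also covers other lazy backends (`null_masks` is a hypothetical helper; `to_pandas()` assumes pandas is installed):

import narwhals as nw

def null_masks(native_df, columns):
    nw_df = nw.from_native(native_df)
    masks = {}
    for column in columns:
        selected = nw_df.select(nw.col(column).is_null())
        # Lazy backends (e.g. PySpark) need an explicit collect() before
        # the values can be pulled into Python lists.
        if hasattr(selected, "collect"):
            selected = selected.collect()
        masks[column] = selected.to_pandas().iloc[:, 0].tolist()
    return masks
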
@@ -1579,7 +1645,13 @@ def _generate_display_table(
1579
1645
 
1580
1646
  for column in col_dtype_dict.keys():
1581
1647
  # Select a single column of values
1582
- data_col = data[[column]] if df_lib_name_gt == "pandas" else data.select([column])
1648
+ if df_lib_name_gt == "pandas":
1649
+ data_col = data[[column]]
1650
+ elif df_lib_name_gt == "pyspark":
1651
+ # PySpark data should have been converted to pandas by now
1652
+ data_col = data[[column]]
1653
+ else:
1654
+ data_col = data.select([column])
1583
1655
 
1584
1656
  # Using Great Tables, render the columns and get the list of values as formatted strings
1585
1657
  built_gt = GT(data=data_col).fmt_markdown(columns=column)._build_data(context="html")
@@ -1658,6 +1730,10 @@ def _generate_display_table(
1658
1730
  if df_lib_name_gt == "pandas":
1659
1731
  data.insert(0, "_row_num_", row_number_list)
1660
1732
 
1733
+ if df_lib_name_gt == "pyspark":
1734
+ # For PySpark converted to pandas, use pandas method
1735
+ data.insert(0, "_row_num_", row_number_list)
1736
+
1661
1737
  # Get the highest number in the `row_number_list` and calculate a width that will
1662
1738
  # safely fit a number of that magnitude
1663
1739
  if row_number_list: # Check if list is not empty
@@ -1791,6 +1867,7 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
1791
1867
 
1792
1868
  - Polars DataFrame (`"polars"`)
1793
1869
  - Pandas DataFrame (`"pandas"`)
1870
+ - PySpark table (`"pyspark"`)
1794
1871
  - DuckDB table (`"duckdb"`)*
1795
1872
  - MySQL table (`"mysql"`)*
1796
1873
  - PostgreSQL table (`"postgresql"`)*
@@ -1798,7 +1875,6 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
1798
1875
  - Microsoft SQL Server table (`"mssql"`)*
1799
1876
  - Snowflake table (`"snowflake"`)*
1800
1877
  - Databricks table (`"databricks"`)*
1801
- - PySpark table (`"pyspark"`)*
1802
1878
  - BigQuery table (`"bigquery"`)*
1803
1879
  - Parquet table (`"parquet"`)*
1804
1880
  - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -1851,7 +1927,10 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
1851
1927
  data = _process_data(data)
1852
1928
 
1853
1929
  # Make a copy of the data to avoid modifying the original
1854
- data = copy.deepcopy(data)
1930
+ # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
1931
+ tbl_type = _get_tbl_type(data=data)
1932
+ if "pyspark" not in tbl_type:
1933
+ data = copy.deepcopy(data)
1855
1934
 
1856
1935
  # Get the number of rows in the table
1857
1936
  n_rows = get_row_count(data)
@@ -1868,22 +1947,28 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
1868
1947
  # Determine if the table is a DataFrame or an Ibis table
1869
1948
  tbl_type = _get_tbl_type(data=data)
1870
1949
  ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
1871
- pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
1950
+ pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type
1872
1951
 
1873
1952
  # Select the DataFrame library to use for displaying the Ibis table
1874
1953
  df_lib_gt = _select_df_lib(preference="polars")
1875
1954
  df_lib_name_gt = df_lib_gt.__name__
1876
1955
 
1877
- # If the table is a DataFrame (Pandas or Polars), set `df_lib_name_gt` to the name of the
1878
- # library (e.g., "polars" or "pandas")
1956
+ # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
1957
+ # library (e.g., "polars", "pandas", or "pyspark")
1879
1958
  if pl_pb_tbl:
1880
- df_lib_name_gt = "polars" if "polars" in tbl_type else "pandas"
1959
+ if "polars" in tbl_type:
1960
+ df_lib_name_gt = "polars"
1961
+ elif "pandas" in tbl_type:
1962
+ df_lib_name_gt = "pandas"
1963
+ elif "pyspark" in tbl_type:
1964
+ df_lib_name_gt = "pyspark"
1881
1965
 
1882
- # Handle imports of Polars or Pandas here
1966
+ # Handle imports of Polars, Pandas, or PySpark here
1883
1967
  if df_lib_name_gt == "polars":
1884
1968
  import polars as pl
1885
- else:
1969
+ elif df_lib_name_gt == "pandas":
1886
1970
  import pandas as pd
1971
+ # Note: PySpark import is handled as needed, typically already imported in user's environment
1887
1972
 
1888
1973
  # From an Ibis table:
1889
1974
  # - get the row count
@@ -2047,6 +2132,77 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
2047
2132
  # Get a dictionary of counts of missing values in each column
2048
2133
  missing_val_counts = {col: data[col].isnull().sum() for col in data.columns}
2049
2134
 
2135
+ if "pyspark" in tbl_type:
2136
+ from pyspark.sql.functions import col as pyspark_col
2137
+
2138
+ # PySpark implementation for missing values calculation
2139
+ missing_vals = {}
2140
+ for col_name in data.columns:
2141
+ col_missing_props = []
2142
+
2143
+ # Calculate missing value proportions for each sector
2144
+ for i in range(len(cut_points)):
2145
+ start_row = cut_points[i - 1] if i > 0 else 0
2146
+ end_row = cut_points[i]
2147
+ sector_size = end_row - start_row
2148
+
2149
+ if sector_size > 0:
2150
+ # Use row_number() to filter rows by range
2151
+ from pyspark.sql.functions import row_number
2152
+ from pyspark.sql.window import Window
2153
+
2154
+ window = Window.orderBy(
2155
+ pyspark_col(data.columns[0])
2156
+ ) # Order by first column
2157
+ sector_data = data.withColumn("row_num", row_number().over(window)).filter(
2158
+ (pyspark_col("row_num") > start_row)
2159
+ & (pyspark_col("row_num") <= end_row)
2160
+ )
2161
+
2162
+ # Count nulls in this sector
2163
+ null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
2164
+ missing_prop = (null_count / sector_size) * 100
2165
+ col_missing_props.append(missing_prop)
2166
+ else:
2167
+ col_missing_props.append(0)
2168
+
2169
+ # Handle the final sector (after last cut point)
2170
+ if n_rows > cut_points[-1]:
2171
+ start_row = cut_points[-1]
2172
+ end_row = n_rows
2173
+ sector_size = end_row - start_row
2174
+
2175
+ from pyspark.sql.functions import row_number
2176
+ from pyspark.sql.window import Window
2177
+
2178
+ window = Window.orderBy(pyspark_col(data.columns[0]))
2179
+ sector_data = data.withColumn("row_num", row_number().over(window)).filter(
2180
+ pyspark_col("row_num") > start_row
2181
+ )
2182
+
2183
+ null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
2184
+ missing_prop = (null_count / sector_size) * 100
2185
+ col_missing_props.append(missing_prop)
2186
+ else:
2187
+ col_missing_props.append(0)
2188
+
2189
+ missing_vals[col_name] = col_missing_props
2190
+
2191
+ # Pivot the `missing_vals` dictionary to create a table with the missing value proportions
2192
+ missing_vals = {
2193
+ "columns": list(missing_vals.keys()),
2194
+ **{
2195
+ str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()]
2196
+ for i in range(len(cut_points) + 1)
2197
+ },
2198
+ }
2199
+
2200
+ # Get a dictionary of counts of missing values in each column
2201
+ missing_val_counts = {}
2202
+ for col_name in data.columns:
2203
+ null_count = data.filter(pyspark_col(col_name).isNull()).count()
2204
+ missing_val_counts[col_name] = null_count
2205
+
2050
2206
  # From `missing_vals`, create the DataFrame with the missing value proportions
2051
2207
  if df_lib_name_gt == "polars":
2052
2208
  import polars as pl
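
The sector loop above issues one `filter().count()` job per column and sector, and the final counts dictionary adds one more job per column. For the plain per-column totals, a single aggregation pass is the usual PySpark idiom; the sketch below only illustrates that alternative and is not what the released code does:

from pyspark.sql import functions as F

def missing_value_counts(sdf):
    # F.count() ignores nulls, so counting a when(isNull, ...) expression
    # yields the number of null rows per column in a single Spark job.
    agg_exprs = [F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in sdf.columns]
    return sdf.agg(*agg_exprs).collect()[0].asDict()
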
@@ -2333,6 +2489,7 @@ def get_column_count(data: FrameT | Any) -> int:
2333
2489
 
2334
2490
  - Polars DataFrame (`"polars"`)
2335
2491
  - Pandas DataFrame (`"pandas"`)
2492
+ - PySpark table (`"pyspark"`)
2336
2493
  - DuckDB table (`"duckdb"`)*
2337
2494
  - MySQL table (`"mysql"`)*
2338
2495
  - PostgreSQL table (`"postgresql"`)*
@@ -2340,7 +2497,6 @@ def get_column_count(data: FrameT | Any) -> int:
2340
2497
  - Microsoft SQL Server table (`"mssql"`)*
2341
2498
  - Snowflake table (`"snowflake"`)*
2342
2499
  - Databricks table (`"databricks"`)*
2343
- - PySpark table (`"pyspark"`)*
2344
2500
  - BigQuery table (`"bigquery"`)*
2345
2501
  - Parquet table (`"parquet"`)*
2346
2502
  - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2467,6 +2623,9 @@ def get_column_count(data: FrameT | Any) -> int:
2467
2623
  elif "pandas" in str(type(data)):
2468
2624
  return data.shape[1]
2469
2625
 
2626
+ elif "pyspark" in str(type(data)):
2627
+ return len(data.columns)
2628
+
2470
2629
  elif "narwhals" in str(type(data)):
2471
2630
  return len(data.columns)
2472
2631
 
@@ -2501,6 +2660,7 @@ def get_row_count(data: FrameT | Any) -> int:
2501
2660
 
2502
2661
  - Polars DataFrame (`"polars"`)
2503
2662
  - Pandas DataFrame (`"pandas"`)
2663
+ - PySpark table (`"pyspark"`)
2504
2664
  - DuckDB table (`"duckdb"`)*
2505
2665
  - MySQL table (`"mysql"`)*
2506
2666
  - PostgreSQL table (`"postgresql"`)*
@@ -2508,7 +2668,6 @@ def get_row_count(data: FrameT | Any) -> int:
2508
2668
  - Microsoft SQL Server table (`"mssql"`)*
2509
2669
  - Snowflake table (`"snowflake"`)*
2510
2670
  - Databricks table (`"databricks"`)*
2511
- - PySpark table (`"pyspark"`)*
2512
2671
  - BigQuery table (`"bigquery"`)*
2513
2672
  - Parquet table (`"parquet"`)*
2514
2673
  - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2646,6 +2805,9 @@ def get_row_count(data: FrameT | Any) -> int:
2646
2805
  elif "pandas" in str(type(data)):
2647
2806
  return data.shape[0]
2648
2807
 
2808
+ elif "pyspark" in str(type(data)):
2809
+ return data.count()
2810
+
2649
2811
  elif "narwhals" in str(type(data)):
2650
2812
  return data.shape[0]
2651
2813
 
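
With this branch and the matching one in `get_column_count()` above, both helpers accept PySpark DataFrames directly. A minimal usage sketch, assuming pointblank 0.12.0 exports both functions at the package level and a local SparkSession is available:

import pointblank as pb
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
sdf = spark.createDataFrame([(1, "a"), (2, "b"), (3, None)], ["x", "y"])

print(pb.get_row_count(sdf))     # 3, via DataFrame.count()
print(pb.get_column_count(sdf))  # 2, via len(DataFrame.columns)
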
@@ -3098,6 +3260,7 @@ class Validate:
3098
3260
 
3099
3261
  - Polars DataFrame (`"polars"`)
3100
3262
  - Pandas DataFrame (`"pandas"`)
3263
+ - PySpark table (`"pyspark"`)
3101
3264
  - DuckDB table (`"duckdb"`)*
3102
3265
  - MySQL table (`"mysql"`)*
3103
3266
  - PostgreSQL table (`"postgresql"`)*
@@ -3105,7 +3268,6 @@ class Validate:
3105
3268
  - Microsoft SQL Server table (`"mssql"`)*
3106
3269
  - Snowflake table (`"snowflake"`)*
3107
3270
  - Databricks table (`"databricks"`)*
3108
- - PySpark table (`"pyspark"`)*
3109
3271
  - BigQuery table (`"bigquery"`)*
3110
3272
  - Parquet table (`"parquet"`)*
3111
3273
  - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -9983,12 +10145,22 @@ class Validate:
9983
10145
  and tbl_type not in IBIS_BACKENDS
9984
10146
  ):
9985
10147
  # Add row numbers to the results table
9986
- validation_extract_nw = (
9987
- nw.from_native(results_tbl)
9988
- .with_row_index(name="_row_num_")
9989
- .filter(nw.col("pb_is_good_") == False) # noqa
9990
- .drop("pb_is_good_")
9991
- )
10148
+ validation_extract_nw = nw.from_native(results_tbl)
10149
+
10150
+ # Handle LazyFrame row indexing which requires order_by parameter
10151
+ try:
10152
+ # Try without order_by first (for DataFrames)
10153
+ validation_extract_nw = validation_extract_nw.with_row_index(name="_row_num_")
10154
+ except TypeError:
10155
+ # LazyFrames require order_by parameter - use first column for ordering
10156
+ first_col = validation_extract_nw.columns[0]
10157
+ validation_extract_nw = validation_extract_nw.with_row_index(
10158
+ name="_row_num_", order_by=first_col
10159
+ )
10160
+
10161
+ validation_extract_nw = validation_extract_nw.filter(~nw.col("pb_is_good_")).drop(
10162
+ "pb_is_good_"
10163
+ ) # noqa
9992
10164
 
9993
10165
  # Add 1 to the row numbers to make them 1-indexed
9994
10166
  validation_extract_nw = validation_extract_nw.with_columns(nw.col("_row_num_") + 1)
@@ -9997,12 +10169,52 @@ class Validate:
9997
10169
  if get_first_n is not None:
9998
10170
  validation_extract_nw = validation_extract_nw.head(get_first_n)
9999
10171
  elif sample_n is not None:
10000
- validation_extract_nw = validation_extract_nw.sample(n=sample_n)
10172
+ # Narwhals LazyFrame doesn't have sample method, use head after shuffling
10173
+ try:
10174
+ validation_extract_nw = validation_extract_nw.sample(n=sample_n)
10175
+ except AttributeError:
10176
+ # For LazyFrames without sample method, collect first then sample
10177
+ validation_extract_native = validation_extract_nw.collect().to_native()
10178
+ if hasattr(validation_extract_native, "sample"):
10179
+ # PySpark DataFrame has sample method
10180
+ validation_extract_native = validation_extract_native.sample(
10181
+ fraction=min(1.0, sample_n / validation_extract_native.count())
10182
+ ).limit(sample_n)
10183
+ validation_extract_nw = nw.from_native(validation_extract_native)
10184
+ else:
10185
+ # Fallback: just take first n rows after collecting
10186
+ validation_extract_nw = validation_extract_nw.collect().head(sample_n)
10001
10187
  elif sample_frac is not None:
10002
- validation_extract_nw = validation_extract_nw.sample(fraction=sample_frac)
10188
+ try:
10189
+ validation_extract_nw = validation_extract_nw.sample(fraction=sample_frac)
10190
+ except AttributeError:
10191
+ # For LazyFrames without sample method, collect first then sample
10192
+ validation_extract_native = validation_extract_nw.collect().to_native()
10193
+ if hasattr(validation_extract_native, "sample"):
10194
+ # PySpark DataFrame has sample method
10195
+ validation_extract_native = validation_extract_native.sample(
10196
+ fraction=sample_frac
10197
+ )
10198
+ validation_extract_nw = nw.from_native(validation_extract_native)
10199
+ else:
10200
+ # Fallback: use fraction to calculate head size
10201
+ collected = validation_extract_nw.collect()
10202
+ sample_size = max(1, int(len(collected) * sample_frac))
10203
+ validation_extract_nw = collected.head(sample_size)
10003
10204
 
10004
10205
  # Ensure a limit is set on the number of rows to extract
10005
- if len(validation_extract_nw) > extract_limit:
10206
+ try:
10207
+ # For DataFrames, use len()
10208
+ extract_length = len(validation_extract_nw)
10209
+ except TypeError:
10210
+ # For LazyFrames, collect to get length (or use a reasonable default)
10211
+ try:
10212
+ extract_length = len(validation_extract_nw.collect())
10213
+ except Exception:
10214
+ # If collection fails, apply limit anyway as a safety measure
10215
+ extract_length = extract_limit + 1 # Force limiting
10216
+
10217
+ if extract_length > extract_limit:
10006
10218
  validation_extract_nw = validation_extract_nw.head(extract_limit)
10007
10219
 
10008
10220
  # If a 'rows_distinct' validation step, then the extract should have the
@@ -10030,7 +10242,10 @@ class Validate:
10030
10242
  .drop("group_min_row")
10031
10243
  )
10032
10244
 
10033
- # Ensure that the extract is set to its native format
10245
+ # Ensure that the extract is collected and set to its native format
10246
+ # For LazyFrames (like PySpark), we need to collect before converting to native
10247
+ if hasattr(validation_extract_nw, "collect"):
10248
+ validation_extract_nw = validation_extract_nw.collect()
10034
10249
  validation.extract = nw.to_native(validation_extract_nw)
10035
10250
 
10036
10251
  # Get the end time for this step
@@ -11656,7 +11871,16 @@ class Validate:
11656
11871
  # TODO: add argument for user to specify the index column name
11657
11872
  index_name = "pb_index_"
11658
11873
 
11659
- data_nw = nw.from_native(self.data).with_row_index(name=index_name)
11874
+ data_nw = nw.from_native(self.data)
11875
+
11876
+ # Handle LazyFrame row indexing which requires order_by parameter
11877
+ try:
11878
+ # Try without order_by first (for DataFrames)
11879
+ data_nw = data_nw.with_row_index(name=index_name)
11880
+ except TypeError:
11881
+ # LazyFrames require order_by parameter - use first column for ordering
11882
+ first_col = data_nw.columns[0]
11883
+ data_nw = data_nw.with_row_index(name=index_name, order_by=first_col)
11660
11884
 
11661
11885
  # Get all validation step result tables and join together the `pb_is_good_` columns
11662
11886
  # ensuring that the columns are named uniquely (e.g., `pb_is_good_1`, `pb_is_good_2`, ...)
@@ -11665,7 +11889,13 @@ class Validate:
11665
11889
  results_tbl = nw.from_native(validation.tbl_checked)
11666
11890
 
11667
11891
  # Add row numbers to the results table
11668
- results_tbl = results_tbl.with_row_index(name=index_name)
11892
+ try:
11893
+ # Try without order_by first (for DataFrames)
11894
+ results_tbl = results_tbl.with_row_index(name=index_name)
11895
+ except TypeError:
11896
+ # LazyFrames require order_by parameter - use first column for ordering
11897
+ first_col = results_tbl.columns[0]
11898
+ results_tbl = results_tbl.with_row_index(name=index_name, order_by=first_col)
11669
11899
 
11670
11900
  # Add numerical suffix to the `pb_is_good_` column to make it unique
11671
11901
  results_tbl = results_tbl.select([index_name, "pb_is_good_"]).rename(
@@ -12284,15 +12514,21 @@ class Validate:
12284
12514
  # Transform to Narwhals DataFrame
12285
12515
  extract_nw = nw.from_native(extract)
12286
12516
 
12287
- # Get the number of rows in the extract
12288
- n_rows = len(extract_nw)
12517
+ # Get the number of rows in the extract (safe for LazyFrames)
12518
+ try:
12519
+ n_rows = len(extract_nw)
12520
+ except TypeError:
12521
+ # For LazyFrames, collect() first to get length
12522
+ n_rows = len(extract_nw.collect()) if hasattr(extract_nw, "collect") else 0
12289
12523
 
12290
12524
  # If the number of rows is zero, then produce an em dash then go to the next iteration
12291
12525
  if n_rows == 0:
12292
12526
  extract_upd.append("&mdash;")
12293
12527
  continue
12294
12528
 
12295
- # Write the CSV text
12529
+ # Write the CSV text (ensure LazyFrames are collected first)
12530
+ if hasattr(extract_nw, "collect"):
12531
+ extract_nw = extract_nw.collect()
12296
12532
  csv_text = extract_nw.write_csv()
12297
12533
 
12298
12534
  # Use Base64 encoding to encode the CSV text
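
This hunk guards `len()` and `write_csv()` against lazy extracts by collecting first. A small sketch of the CSV step on its own (assumes narwhals' `write_csv()` returns the CSV text when called without a target file):

import narwhals as nw

def extract_to_csv_text(native_extract):
    nw_extract = nw.from_native(native_extract)
    # Lazy frames (e.g. PySpark) must be materialized before serializing.
    if hasattr(nw_extract, "collect"):
        nw_extract = nw_extract.collect()
    return nw_extract.write_csv()
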
@@ -13856,7 +14092,7 @@ def _prep_values_text(
13856
14092
  return values_str
13857
14093
 
13858
14094
 
13859
- def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
14095
+ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> list[tuple[str, str]]:
13860
14096
  """
13861
14097
  Obtain the segmentation categories from a table column.
13862
14098
 
@@ -13881,22 +14117,27 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
13881
14117
  list[tuple[str, str]]
13882
14118
  A list of tuples representing pairings of a column name and a value in the column.
13883
14119
  """
14120
+ import narwhals as nw
14121
+
13884
14122
  # Determine if the table is a DataFrame or a DB table
13885
14123
  tbl_type = _get_tbl_type(data=data_tbl)
13886
14124
 
13887
14125
  # Obtain the segmentation categories from the table column given as `segments_expr`
13888
- if tbl_type == "polars":
13889
- seg_categories = data_tbl[segments_expr].unique().to_list()
13890
- elif tbl_type == "pandas":
13891
- seg_categories = data_tbl[segments_expr].unique().tolist()
14126
+ if tbl_type in ["polars", "pandas", "pyspark"]:
14127
+ # Use Narwhals for supported DataFrame types
14128
+ data_nw = nw.from_native(data_tbl)
14129
+ unique_vals = data_nw.select(nw.col(segments_expr)).unique()
14130
+
14131
+ # Convert to list of values
14132
+ seg_categories = unique_vals[segments_expr].to_list()
13892
14133
  elif tbl_type in IBIS_BACKENDS:
13893
14134
  distinct_col_vals = data_tbl.select(segments_expr).distinct()
13894
14135
  seg_categories = distinct_col_vals[segments_expr].to_list()
13895
14136
  else: # pragma: no cover
13896
14137
  raise ValueError(f"Unsupported table type: {tbl_type}")
13897
14138
 
13898
- # Ensure that the categories are sorted
13899
- seg_categories.sort()
14139
+ # Ensure that the categories are sorted, and allow for None values
14140
+ seg_categories.sort(key=lambda x: (x is None, x))
13900
14141
 
13901
14142
  # Place each category and each value in a list of tuples as: `(column, value)`
13902
14143
  seg_tuples = [(segments_expr, category) for category in seg_categories]
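
The rewritten `_seg_expr_from_string()` pulls the distinct segment values through narwhals and sorts them with a null-tolerant key. A standalone sketch of that approach (`segment_categories` is a hypothetical name; the extra `collect()` covers lazy backends):

import narwhals as nw

def segment_categories(native_tbl, column):
    frame = nw.from_native(native_tbl).select(nw.col(column)).unique()
    if hasattr(frame, "collect"):
        frame = frame.collect()
    categories = frame[column].to_list()
    # (x is None, x) sorts real values first and None last, without ever
    # comparing None against a non-None value.
    categories.sort(key=lambda x: (x is None, x))
    return [(column, category) for category in categories]
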
@@ -13904,7 +14145,7 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
13904
14145
  return seg_tuples
13905
14146
 
13906
14147
 
13907
- def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
14148
+ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, Any]]:
13908
14149
  """
13909
14150
  Normalize the segments expression to a list of tuples, given a single tuple.
13910
14151
 
@@ -13930,17 +14171,23 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
13930
14171
 
13931
14172
  Returns
13932
14173
  -------
13933
- list[tuple[str, str]]
14174
+ list[tuple[str, Any]]
13934
14175
  A list of tuples representing pairings of a column name and a value in the column.
14176
+ Values can be any type, including None.
13935
14177
  """
14178
+ # Unpack the segments expression tuple for more convenient and explicit variable names
14179
+ column, segment = segments_expr
14180
+
13936
14181
  # Check if the first element is a string
13937
- if isinstance(segments_expr[0], str):
13938
- # If the second element is a list, create a list of tuples
13939
- if isinstance(segments_expr[1], list):
13940
- seg_tuples = [(segments_expr[0], value) for value in segments_expr[1]]
14182
+ if isinstance(column, str):
14183
+ if isinstance(segment, Segment):
14184
+ seg_tuples = [(column, seg) for seg in segment.segments]
14185
+ # If the second element is a collection, expand into a list of tuples
14186
+ elif isinstance(segment, (list, set, tuple)):
14187
+ seg_tuples = [(column, seg) for seg in segment]
13941
14188
  # If the second element is not a list, create a single tuple
13942
14189
  else:
13943
- seg_tuples = [(segments_expr[0], segments_expr[1])]
14190
+ seg_tuples = [(column, segment)]
13944
14191
  # If the first element is not a string, raise an error
13945
14192
  else: # pragma: no cover
13946
14193
  raise ValueError("The first element of the segments expression must be a string.")
@@ -13948,7 +14195,7 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
13948
14195
  return seg_tuples
13949
14196
 
13950
14197
 
13951
- def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
14198
+ def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
13952
14199
  """
13953
14200
  Apply the segments expression to the data table.
13954
14201
 
@@ -13971,15 +14218,24 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
13971
14218
  # Get the table type
13972
14219
  tbl_type = _get_tbl_type(data=data_tbl)
13973
14220
 
13974
- if tbl_type in ["pandas", "polars"]:
13975
- # If the table is a Pandas or Polars DataFrame, transforming to a Narwhals table
14221
+ # Unpack the segments expression tuple for more convenient and explicit variable names
14222
+ column, segment = segments_expr
14223
+
14224
+ if tbl_type in ["pandas", "polars", "pyspark"]:
14225
+ # If the table is a Pandas, Polars, or PySpark DataFrame, transforming to a Narwhals table
13976
14226
  # and perform the filtering operation
13977
14227
 
13978
14228
  # Transform to Narwhals table if a DataFrame
13979
14229
  data_tbl_nw = nw.from_native(data_tbl)
13980
14230
 
13981
- # Filter the data table based on the column name and value
13982
- data_tbl_nw = data_tbl_nw.filter(nw.col(segments_expr[0]) == segments_expr[1])
14231
+ # Filter the data table based on the column name and segment
14232
+ if segment is None:
14233
+ data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_null())
14234
+ # Check if the segment is a segment group
14235
+ elif isinstance(segment, list):
14236
+ data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_in(segment))
14237
+ else:
14238
+ data_tbl_nw = data_tbl_nw.filter(nw.col(column) == segment)
13983
14239
 
13984
14240
  # Transform back to the original table type
13985
14241
  data_tbl = data_tbl_nw.to_native()
@@ -13987,8 +14243,13 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
13987
14243
  elif tbl_type in IBIS_BACKENDS:
13988
14244
  # If the table is an Ibis backend table, perform the filtering operation directly
13989
14245
 
13990
- # Filter the data table based on the column name and value
13991
- data_tbl = data_tbl[data_tbl[segments_expr[0]] == segments_expr[1]]
14246
+ # Filter the data table based on the column name and segment
14247
+ if segment is None:
14248
+ data_tbl = data_tbl[data_tbl[column].isnull()]
14249
+ elif isinstance(segment, list):
14250
+ data_tbl = data_tbl[data_tbl[column].isin(segment)]
14251
+ else:
14252
+ data_tbl = data_tbl[data_tbl[column] == segment]
13992
14253
 
13993
14254
  return data_tbl
13994
14255
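
Taken together, the segmentation changes let one `segments=` tuple expand into several per-segment validation steps, with None handled as an explicit is-null filter. A usage sketch against a validation step (column names and values are illustrative; assumes polars and pointblank 0.12.0 are installed):

import polars as pl
import pointblank as pb

tbl = pl.DataFrame(
    {
        "region": ["east", "west", "east", None],
        "revenue": [100, 250, 175, 90],
    }
)

validation = (
    pb.Validate(data=tbl)
    # Expands into one step per segment: ("region", "east") and ("region", "west")
    .col_vals_gt(columns="revenue", value=0, segments=("region", ["east", "west"]))
    .interrogate()
)
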