pointblank 0.11.6__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
pointblank/validate.py CHANGED
@@ -87,6 +87,7 @@ from pointblank._utils_check_args import (
 from pointblank._utils_html import _create_table_dims_html, _create_table_type_html
 from pointblank.column import Column, ColumnLiteral, ColumnSelector, ColumnSelectorNarwhals, col
 from pointblank.schema import Schema, _get_schema_validation_info
+from pointblank.segments import Segment
 from pointblank.thresholds import (
     Actions,
     FinalActions,
@@ -1194,6 +1195,7 @@ def preview(
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -1201,7 +1203,6 @@ def preview(
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
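
With PySpark moved out of the Ibis-only list, `preview()` accepts a native PySpark DataFrame directly. A minimal sketch of that usage, assuming a local Spark session and an illustrative table (not the package's example data):

```python
import pointblank as pb
from pyspark.sql import SparkSession

# Start (or reuse) a local Spark session; assumes pyspark is installed
spark = SparkSession.builder.master("local[*]").appName("pb-preview").getOrCreate()

# A small illustrative PySpark DataFrame
sdf = spark.createDataFrame(
    [(1, "apple", 3.5), (2, "banana", None), (3, "cherry", 1.25)],
    schema=["id", "fruit", "price"],
)

# Preview the head and tail of the native PySpark DataFrame; only the sampled
# rows are converted to pandas internally for rendering with Great Tables
pb.preview(sdf, n_head=2, n_tail=1)
```
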
@@ -1396,7 +1397,10 @@ def _generate_display_table(
     row_number_list: list[int] | None = None,
 ) -> GT:
     # Make a copy of the data to avoid modifying the original
-    data = copy.deepcopy(data)
+    # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
+    tbl_type = _get_tbl_type(data=data)
+    if "pyspark" not in tbl_type:
+        data = copy.deepcopy(data)
 
     # Does the data table already have a leading row number column?
     if "_row_num_" in data.columns:
@@ -1422,22 +1426,31 @@ def _generate_display_table(
     # Determine if the table is a DataFrame or an Ibis table
     tbl_type = _get_tbl_type(data=data)
     ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
-    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
+    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type
 
     # Select the DataFrame library to use for displaying the Ibis table
     df_lib_gt = _select_df_lib(preference="polars")
     df_lib_name_gt = df_lib_gt.__name__
 
-    # If the table is a DataFrame (Pandas or Polars), set `df_lib_name_gt` to the name of the
-    # library (e.g., "polars" or "pandas")
+    # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
+    # library (e.g., "polars", "pandas", or "pyspark")
     if pl_pb_tbl:
-        df_lib_name_gt = "polars" if "polars" in tbl_type else "pandas"
+        if "polars" in tbl_type:
+            df_lib_name_gt = "polars"
+        elif "pandas" in tbl_type:
+            df_lib_name_gt = "pandas"
+        elif "pyspark" in tbl_type:
+            df_lib_name_gt = "pyspark"
 
-        # Handle imports of Polars or Pandas here
+        # Handle imports of Polars, Pandas, or PySpark here
         if df_lib_name_gt == "polars":
             import polars as pl
-        else:
+        elif df_lib_name_gt == "pandas":
             import pandas as pd
+        elif df_lib_name_gt == "pyspark":
+            # Import pandas for conversion since Great Tables needs pandas DataFrame
+            import pandas as pd
+            # Note: PySpark import is handled as needed, typically already imported in user's environment
 
     # Get the initial column count for the table
     n_columns = len(data.columns)
@@ -1547,6 +1560,42 @@ def _generate_display_table(
                 range(n_rows - n_tail + 1, n_rows + 1)
             )
 
+    if tbl_type == "pyspark":
+        n_rows = data.count()
+
+        # If n_head + n_tail is greater than the row count, display the entire table
+        if n_head + n_tail >= n_rows:
+            full_dataset = True
+            # Convert to pandas for Great Tables compatibility
+            data = data.toPandas()
+
+            row_number_list = range(1, n_rows + 1)
+        else:
+            # Get head and tail samples, then convert to pandas
+            head_data = data.limit(n_head).toPandas()
+
+            # PySpark tail() returns a list of Row objects, need to convert to DataFrame
+            tail_rows = data.tail(n_tail)
+            if tail_rows:
+                # Convert list of Row objects back to DataFrame, then to pandas
+                tail_df = data.sparkSession.createDataFrame(tail_rows, data.schema)
+                tail_data = tail_df.toPandas()
+            else:
+                # If no tail data, create empty DataFrame with same schema
+                import pandas as pd
+
+                tail_data = pd.DataFrame(columns=head_data.columns)
+
+            data = pd.concat([head_data, tail_data])
+
+            row_number_list = list(range(1, n_head + 1)) + list(
+                range(n_rows - n_tail + 1, n_rows + 1)
+            )
+
+    # For PySpark, update schema after conversion to pandas
+    if tbl_type == "pyspark":
+        tbl_schema = Schema(tbl=data)
+
     # From the table schema, get a list of tuples containing column names and data types
     col_dtype_dict = tbl_schema.columns
 
@@ -1566,6 +1615,23 @@ def _generate_display_table(
     # This is used to highlight these values in the table
     if df_lib_name_gt == "polars":
         none_values = {k: data[k].is_null().to_list() for k in col_names}
+    elif df_lib_name_gt == "pyspark":
+        # For PySpark, check if data has been converted to pandas already
+        if hasattr(data, "isnull"):
+            # Data has been converted to pandas
+            none_values = {k: data[k].isnull() for k in col_names}
+        else:
+            # Data is still a PySpark DataFrame - use narwhals
+            import narwhals as nw
+
+            df_nw = nw.from_native(data)
+            none_values = {}
+            for col in col_names:
+                # Get null mask, collect to pandas, then convert to list
+                null_mask = (
+                    df_nw.select(nw.col(col).is_null()).collect().to_pandas().iloc[:, 0].tolist()
+                )
+                none_values[col] = null_mask
     else:
         none_values = {k: data[k].isnull() for k in col_names}
 
@@ -1579,7 +1645,13 @@ def _generate_display_table(
 
     for column in col_dtype_dict.keys():
         # Select a single column of values
-        data_col = data[[column]] if df_lib_name_gt == "pandas" else data.select([column])
+        if df_lib_name_gt == "pandas":
+            data_col = data[[column]]
+        elif df_lib_name_gt == "pyspark":
+            # PySpark data should have been converted to pandas by now
+            data_col = data[[column]]
+        else:
+            data_col = data.select([column])
 
         # Using Great Tables, render the columns and get the list of values as formatted strings
         built_gt = GT(data=data_col).fmt_markdown(columns=column)._build_data(context="html")
@@ -1658,6 +1730,10 @@ def _generate_display_table(
     if df_lib_name_gt == "pandas":
         data.insert(0, "_row_num_", row_number_list)
 
+    if df_lib_name_gt == "pyspark":
+        # For PySpark converted to pandas, use pandas method
+        data.insert(0, "_row_num_", row_number_list)
+
     # Get the highest number in the `row_number_list` and calculate a width that will
     # safely fit a number of that magnitude
     if row_number_list:  # Check if list is not empty
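
The PySpark branch added to `_generate_display_table()` above pulls only the head and tail rows into pandas instead of collecting the whole table. A standalone sketch of that pattern (the helper name is mine, and row order in Spark is only meaningful relative to the current physical layout):

```python
import pandas as pd
from pyspark.sql import SparkSession

def head_tail_to_pandas(sdf, n_head: int = 5, n_tail: int = 5) -> pd.DataFrame:
    """Collect only the first and last rows of a PySpark DataFrame into pandas."""
    n_rows = sdf.count()
    if n_head + n_tail >= n_rows:
        # Small table: just collect everything
        return sdf.toPandas()

    head_pdf = sdf.limit(n_head).toPandas()

    # tail() returns a list of Row objects, so rebuild a DataFrame before converting
    tail_rows = sdf.tail(n_tail)
    tail_pdf = (
        sdf.sparkSession.createDataFrame(tail_rows, sdf.schema).toPandas()
        if tail_rows
        else pd.DataFrame(columns=head_pdf.columns)
    )
    return pd.concat([head_pdf, tail_pdf], ignore_index=True)

spark = SparkSession.builder.master("local[*]").getOrCreate()
print(head_tail_to_pandas(spark.range(100), n_head=3, n_tail=2))
```
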
@@ -1791,6 +1867,7 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -1798,7 +1875,6 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -1851,7 +1927,10 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     data = _process_data(data)
 
     # Make a copy of the data to avoid modifying the original
-    data = copy.deepcopy(data)
+    # Note: PySpark DataFrames cannot be deep copied due to SparkContext serialization issues
+    tbl_type = _get_tbl_type(data=data)
+    if "pyspark" not in tbl_type:
+        data = copy.deepcopy(data)
 
     # Get the number of rows in the table
     n_rows = get_row_count(data)
@@ -1868,22 +1947,28 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
     # Determine if the table is a DataFrame or an Ibis table
     tbl_type = _get_tbl_type(data=data)
     ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
-    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type
+    pl_pb_tbl = "polars" in tbl_type or "pandas" in tbl_type or "pyspark" in tbl_type
 
     # Select the DataFrame library to use for displaying the Ibis table
     df_lib_gt = _select_df_lib(preference="polars")
     df_lib_name_gt = df_lib_gt.__name__
 
-    # If the table is a DataFrame (Pandas or Polars), set `df_lib_name_gt` to the name of the
-    # library (e.g., "polars" or "pandas")
+    # If the table is a DataFrame (Pandas, Polars, or PySpark), set `df_lib_name_gt` to the name of the
+    # library (e.g., "polars", "pandas", or "pyspark")
     if pl_pb_tbl:
-        df_lib_name_gt = "polars" if "polars" in tbl_type else "pandas"
+        if "polars" in tbl_type:
+            df_lib_name_gt = "polars"
+        elif "pandas" in tbl_type:
+            df_lib_name_gt = "pandas"
+        elif "pyspark" in tbl_type:
+            df_lib_name_gt = "pyspark"
 
-        # Handle imports of Polars or Pandas here
+        # Handle imports of Polars, Pandas, or PySpark here
         if df_lib_name_gt == "polars":
             import polars as pl
-        else:
+        elif df_lib_name_gt == "pandas":
             import pandas as pd
+        # Note: PySpark import is handled as needed, typically already imported in user's environment
 
     # From an Ibis table:
     # - get the row count
@@ -2047,6 +2132,77 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
         # Get a dictionary of counts of missing values in each column
         missing_val_counts = {col: data[col].isnull().sum() for col in data.columns}
 
+    if "pyspark" in tbl_type:
+        from pyspark.sql.functions import col as pyspark_col
+
+        # PySpark implementation for missing values calculation
+        missing_vals = {}
+        for col_name in data.columns:
+            col_missing_props = []
+
+            # Calculate missing value proportions for each sector
+            for i in range(len(cut_points)):
+                start_row = cut_points[i - 1] if i > 0 else 0
+                end_row = cut_points[i]
+                sector_size = end_row - start_row
+
+                if sector_size > 0:
+                    # Use row_number() to filter rows by range
+                    from pyspark.sql.functions import row_number
+                    from pyspark.sql.window import Window
+
+                    window = Window.orderBy(
+                        pyspark_col(data.columns[0])
+                    )  # Order by first column
+                    sector_data = data.withColumn("row_num", row_number().over(window)).filter(
+                        (pyspark_col("row_num") > start_row)
+                        & (pyspark_col("row_num") <= end_row)
+                    )
+
+                    # Count nulls in this sector
+                    null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
+                    missing_prop = (null_count / sector_size) * 100
+                    col_missing_props.append(missing_prop)
+                else:
+                    col_missing_props.append(0)
+
+            # Handle the final sector (after last cut point)
+            if n_rows > cut_points[-1]:
+                start_row = cut_points[-1]
+                end_row = n_rows
+                sector_size = end_row - start_row
+
+                from pyspark.sql.functions import row_number
+                from pyspark.sql.window import Window
+
+                window = Window.orderBy(pyspark_col(data.columns[0]))
+                sector_data = data.withColumn("row_num", row_number().over(window)).filter(
+                    pyspark_col("row_num") > start_row
+                )
+
+                null_count = sector_data.filter(pyspark_col(col_name).isNull()).count()
+                missing_prop = (null_count / sector_size) * 100
+                col_missing_props.append(missing_prop)
+            else:
+                col_missing_props.append(0)
+
+            missing_vals[col_name] = col_missing_props
+
+        # Pivot the `missing_vals` dictionary to create a table with the missing value proportions
+        missing_vals = {
+            "columns": list(missing_vals.keys()),
+            **{
+                str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()]
+                for i in range(len(cut_points) + 1)
+            },
+        }
+
+        # Get a dictionary of counts of missing values in each column
+        missing_val_counts = {}
+        for col_name in data.columns:
+            null_count = data.filter(pyspark_col(col_name).isNull()).count()
+            missing_val_counts[col_name] = null_count
+
     # From `missing_vals`, create the DataFrame with the missing value proportions
     if df_lib_name_gt == "polars":
         import polars as pl
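
The sector loop above reproduces, for PySpark, what the pandas/Polars paths do by positional slicing: assign a row number over a window and compute the share of nulls inside each row range. A condensed sketch of that per-sector calculation (function and variable names are mine, not the package's):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

def sector_null_props(sdf, column: str, cut_points: list[int], n_rows: int) -> list[float]:
    """Percent of nulls in `column` for each row range delimited by `cut_points`."""
    # Assign 1-based row numbers; ordering by the first column mirrors the diff above,
    # though row order in Spark is not guaranteed without a meaningful sort key
    window = Window.orderBy(col(sdf.columns[0]))
    numbered = sdf.withColumn("row_num", row_number().over(window))

    bounds = [0, *cut_points, n_rows]
    props = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        size = end - start
        if size <= 0:
            props.append(0.0)
            continue
        sector = numbered.filter((col("row_num") > start) & (col("row_num") <= end))
        nulls = sector.filter(col(column).isNull()).count()
        props.append(100.0 * nulls / size)
    return props

spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.createDataFrame([(i, None if i % 3 == 0 else i) for i in range(1, 21)], ["a", "b"])
print(sector_null_props(sdf, "b", cut_points=[5, 10, 15], n_rows=sdf.count()))
```
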
@@ -2333,6 +2489,7 @@ def get_column_count(data: FrameT | Any) -> int:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -2340,7 +2497,6 @@ def get_column_count(data: FrameT | Any) -> int:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2458,20 +2614,18 @@ def get_column_count(data: FrameT | Any) -> int:
         # Handle list of file paths (likely Parquet files)
         data = _process_parquet_input(data)
 
-    if "ibis.expr.types.relations.Table" in str(type(data)):
-        return len(data.columns)
-
-    elif "polars" in str(type(data)):
-        return len(data.columns)
-
-    elif "pandas" in str(type(data)):
-        return data.shape[1]
-
-    elif "narwhals" in str(type(data)):
-        return len(data.columns)
+    # Use Narwhals to handle all DataFrame types (including Ibis) uniformly
+    try:
+        import narwhals as nw
 
-    else:
-        raise ValueError("The input table type supplied in `data=` is not supported.")
+        df_nw = nw.from_native(data)
+        return len(df_nw.columns)
+    except Exception:
+        # Fallback for unsupported types
+        if "pandas" in str(type(data)):
+            return data.shape[1]
+        else:
+            raise ValueError("The input table type supplied in `data=` is not supported.")
 
 
 def get_row_count(data: FrameT | Any) -> int:
@@ -2501,6 +2655,7 @@ def get_row_count(data: FrameT | Any) -> int:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -2508,7 +2663,6 @@ def get_row_count(data: FrameT | Any) -> int:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -2627,30 +2781,29 @@ def get_row_count(data: FrameT | Any) -> int:
         # Handle list of file paths (likely Parquet files)
         data = _process_parquet_input(data)
 
-    if "ibis.expr.types.relations.Table" in str(type(data)):
-        # Determine whether Pandas or Polars is available to get the row count
-        _check_any_df_lib(method_used="get_row_count")
-
-        # Select the DataFrame library to use for displaying the Ibis table
-        df_lib = _select_df_lib(preference="polars")
-        df_lib_name = df_lib.__name__
-
-        if df_lib_name == "pandas":
-            return int(data.count().to_pandas())
+    # Use Narwhals to handle all DataFrame types (including Ibis) uniformly
+    try:
+        import narwhals as nw
+
+        df_nw = nw.from_native(data)
+        # Handle LazyFrames by collecting them first
+        if hasattr(df_nw, "collect"):
+            df_nw = df_nw.collect()
+        # Try different ways to get row count
+        if hasattr(df_nw, "shape"):
+            return df_nw.shape[0]
+        elif hasattr(df_nw, "height"):
+            return df_nw.height
         else:
-            return int(data.count().to_polars())
-
-    elif "polars" in str(type(data)):
-        return int(data.height)
-
-    elif "pandas" in str(type(data)):
-        return data.shape[0]
-
-    elif "narwhals" in str(type(data)):
-        return data.shape[0]
-
-    else:
-        raise ValueError("The input table type supplied in `data=` is not supported.")
+            raise ValueError("Unable to determine row count from Narwhals DataFrame")
+    except Exception:
+        # Fallback for types that don't work with Narwhals
+        if "pandas" in str(type(data)):
+            return data.shape[0]
+        elif "pyspark" in str(type(data)):
+            return data.count()
+        else:
+            raise ValueError("The input table type supplied in `data=` is not supported.")
 
 
 @dataclass
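
Both `get_column_count()` and `get_row_count()` now go through Narwhals instead of branching per backend, collecting lazy inputs (such as PySpark) before measuring. A small sketch of the same idea, assuming `narwhals`, `pandas`, and `polars` are installed:

```python
import narwhals as nw
import pandas as pd
import polars as pl

def nw_dimensions(native_tbl) -> tuple[int, int]:
    """Return (rows, columns) for any table Narwhals can wrap."""
    df = nw.from_native(native_tbl)
    # Lazy backends (e.g. PySpark) must be collected before counting rows,
    # mirroring the hasattr() guard used in the diff above
    if hasattr(df, "collect"):
        df = df.collect()
    return df.shape[0], len(df.columns)

print(nw_dimensions(pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})))
print(nw_dimensions(pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})))
```
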
@@ -3098,6 +3251,7 @@ class Validate:
 
     - Polars DataFrame (`"polars"`)
     - Pandas DataFrame (`"pandas"`)
+    - PySpark table (`"pyspark"`)
     - DuckDB table (`"duckdb"`)*
     - MySQL table (`"mysql"`)*
     - PostgreSQL table (`"postgresql"`)*
@@ -3105,7 +3259,6 @@ class Validate:
     - Microsoft SQL Server table (`"mssql"`)*
     - Snowflake table (`"snowflake"`)*
     - Databricks table (`"databricks"`)*
-    - PySpark table (`"pyspark"`)*
     - BigQuery table (`"bigquery"`)*
     - Parquet table (`"parquet"`)*
     - CSV files (string path or `pathlib.Path` object with `.csv` extension)
@@ -9983,12 +10136,22 @@ class Validate:
             and tbl_type not in IBIS_BACKENDS
         ):
             # Add row numbers to the results table
-            validation_extract_nw = (
-                nw.from_native(results_tbl)
-                .with_row_index(name="_row_num_")
-                .filter(nw.col("pb_is_good_") == False)  # noqa
-                .drop("pb_is_good_")
-            )
+            validation_extract_nw = nw.from_native(results_tbl)
+
+            # Handle LazyFrame row indexing which requires order_by parameter
+            try:
+                # Try without order_by first (for DataFrames)
+                validation_extract_nw = validation_extract_nw.with_row_index(name="_row_num_")
+            except TypeError:
+                # LazyFrames require order_by parameter - use first column for ordering
+                first_col = validation_extract_nw.columns[0]
+                validation_extract_nw = validation_extract_nw.with_row_index(
+                    name="_row_num_", order_by=first_col
+                )
+
+            validation_extract_nw = validation_extract_nw.filter(~nw.col("pb_is_good_")).drop(
+                "pb_is_good_"
+            )  # noqa
 
             # Add 1 to the row numbers to make them 1-indexed
             validation_extract_nw = validation_extract_nw.with_columns(nw.col("_row_num_") + 1)
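
The `try`/`except TypeError` around `with_row_index()` exists because, per the change above, Narwhals LazyFrames (the PySpark path) demand an `order_by=` column while eager DataFrames are indexed without one. A minimal sketch of that fallback, using Polars for the eager case:

```python
import narwhals as nw
import polars as pl

def add_row_index(native_tbl, name: str = "_row_num_"):
    """Add a row-index column, supplying order_by only when the backend demands it."""
    df = nw.from_native(native_tbl)
    try:
        # Eager DataFrames accept a bare call
        df = df.with_row_index(name=name)
    except TypeError:
        # LazyFrame backends (e.g. PySpark) need an explicit ordering column
        df = df.with_row_index(name=name, order_by=df.columns[0])
    return df.to_native()

print(add_row_index(pl.DataFrame({"a": [10, 20, 30]})))
```
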
@@ -9997,12 +10160,52 @@ class Validate:
             if get_first_n is not None:
                 validation_extract_nw = validation_extract_nw.head(get_first_n)
             elif sample_n is not None:
-                validation_extract_nw = validation_extract_nw.sample(n=sample_n)
+                # Narwhals LazyFrame doesn't have sample method, use head after shuffling
+                try:
+                    validation_extract_nw = validation_extract_nw.sample(n=sample_n)
+                except AttributeError:
+                    # For LazyFrames without sample method, collect first then sample
+                    validation_extract_native = validation_extract_nw.collect().to_native()
+                    if hasattr(validation_extract_native, "sample"):
+                        # PySpark DataFrame has sample method
+                        validation_extract_native = validation_extract_native.sample(
+                            fraction=min(1.0, sample_n / validation_extract_native.count())
+                        ).limit(sample_n)
+                        validation_extract_nw = nw.from_native(validation_extract_native)
+                    else:
+                        # Fallback: just take first n rows after collecting
+                        validation_extract_nw = validation_extract_nw.collect().head(sample_n)
             elif sample_frac is not None:
-                validation_extract_nw = validation_extract_nw.sample(fraction=sample_frac)
+                try:
+                    validation_extract_nw = validation_extract_nw.sample(fraction=sample_frac)
+                except AttributeError:
+                    # For LazyFrames without sample method, collect first then sample
+                    validation_extract_native = validation_extract_nw.collect().to_native()
+                    if hasattr(validation_extract_native, "sample"):
+                        # PySpark DataFrame has sample method
+                        validation_extract_native = validation_extract_native.sample(
+                            fraction=sample_frac
+                        )
+                        validation_extract_nw = nw.from_native(validation_extract_native)
+                    else:
+                        # Fallback: use fraction to calculate head size
+                        collected = validation_extract_nw.collect()
+                        sample_size = max(1, int(len(collected) * sample_frac))
+                        validation_extract_nw = collected.head(sample_size)
 
             # Ensure a limit is set on the number of rows to extract
-            if len(validation_extract_nw) > extract_limit:
+            try:
+                # For DataFrames, use len()
+                extract_length = len(validation_extract_nw)
+            except TypeError:
+                # For LazyFrames, collect to get length (or use a reasonable default)
+                try:
+                    extract_length = len(validation_extract_nw.collect())
+                except Exception:
+                    # If collection fails, apply limit anyway as a safety measure
+                    extract_length = extract_limit + 1  # Force limiting
+
+            if extract_length > extract_limit:
                 validation_extract_nw = validation_extract_nw.head(extract_limit)
 
             # If a 'rows_distinct' validation step, then the extract should have the
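
The sampling fallback above leans on the fact that Spark's `sample()` draws by fraction, not by row count, so a fixed `sample_n` is approximated by sampling at a computed fraction and capping the result with `limit()`. Roughly, assuming a local Spark session:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
sdf = spark.range(1000)

n = 10
# Sampling by fraction is approximate: the expected size is n, so limit() caps
# the result at n, but an individual draw can also come in under n
approx = sdf.sample(fraction=min(1.0, n / sdf.count())).limit(n)
print(approx.count())  # at most 10
```
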
@@ -10030,7 +10233,10 @@ class Validate:
                 .drop("group_min_row")
             )
 
-            # Ensure that the extract is set to its native format
+            # Ensure that the extract is collected and set to its native format
+            # For LazyFrames (like PySpark), we need to collect before converting to native
+            if hasattr(validation_extract_nw, "collect"):
+                validation_extract_nw = validation_extract_nw.collect()
             validation.extract = nw.to_native(validation_extract_nw)
 
             # Get the end time for this step
@@ -11656,7 +11862,16 @@ class Validate:
         # TODO: add argument for user to specify the index column name
         index_name = "pb_index_"
 
-        data_nw = nw.from_native(self.data).with_row_index(name=index_name)
+        data_nw = nw.from_native(self.data)
+
+        # Handle LazyFrame row indexing which requires order_by parameter
+        try:
+            # Try without order_by first (for DataFrames)
+            data_nw = data_nw.with_row_index(name=index_name)
+        except TypeError:
+            # LazyFrames require order_by parameter - use first column for ordering
+            first_col = data_nw.columns[0]
+            data_nw = data_nw.with_row_index(name=index_name, order_by=first_col)
 
         # Get all validation step result tables and join together the `pb_is_good_` columns
         # ensuring that the columns are named uniquely (e.g., `pb_is_good_1`, `pb_is_good_2`, ...)
@@ -11665,7 +11880,13 @@ class Validate:
             results_tbl = nw.from_native(validation.tbl_checked)
 
             # Add row numbers to the results table
-            results_tbl = results_tbl.with_row_index(name=index_name)
+            try:
+                # Try without order_by first (for DataFrames)
+                results_tbl = results_tbl.with_row_index(name=index_name)
+            except TypeError:
+                # LazyFrames require order_by parameter - use first column for ordering
+                first_col = results_tbl.columns[0]
+                results_tbl = results_tbl.with_row_index(name=index_name, order_by=first_col)
 
             # Add numerical suffix to the `pb_is_good_` column to make it unique
             results_tbl = results_tbl.select([index_name, "pb_is_good_"]).rename(
@@ -12284,15 +12505,21 @@ class Validate:
             # Transform to Narwhals DataFrame
             extract_nw = nw.from_native(extract)
 
-            # Get the number of rows in the extract
-            n_rows = len(extract_nw)
+            # Get the number of rows in the extract (safe for LazyFrames)
+            try:
+                n_rows = len(extract_nw)
+            except TypeError:
+                # For LazyFrames, collect() first to get length
+                n_rows = len(extract_nw.collect()) if hasattr(extract_nw, "collect") else 0
 
             # If the number of rows is zero, then produce an em dash then go to the next iteration
             if n_rows == 0:
                 extract_upd.append("&mdash;")
                 continue
 
-            # Write the CSV text
+            # Write the CSV text (ensure LazyFrames are collected first)
+            if hasattr(extract_nw, "collect"):
+                extract_nw = extract_nw.collect()
             csv_text = extract_nw.write_csv()
 
             # Use Base64 encoding to encode the CSV text
@@ -13856,7 +14083,7 @@ def _prep_values_text(
     return values_str
 
 
-def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
+def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> list[tuple[str, str]]:
     """
     Obtain the segmentation categories from a table column.
 
@@ -13881,22 +14108,27 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
     list[tuple[str, str]]
         A list of tuples representing pairings of a column name and a value in the column.
     """
+    import narwhals as nw
+
     # Determine if the table is a DataFrame or a DB table
     tbl_type = _get_tbl_type(data=data_tbl)
 
     # Obtain the segmentation categories from the table column given as `segments_expr`
-    if tbl_type == "polars":
-        seg_categories = data_tbl[segments_expr].unique().to_list()
-    elif tbl_type == "pandas":
-        seg_categories = data_tbl[segments_expr].unique().tolist()
+    if tbl_type in ["polars", "pandas", "pyspark"]:
+        # Use Narwhals for supported DataFrame types
+        data_nw = nw.from_native(data_tbl)
+        unique_vals = data_nw.select(nw.col(segments_expr)).unique()
+
+        # Convert to list of values
+        seg_categories = unique_vals[segments_expr].to_list()
     elif tbl_type in IBIS_BACKENDS:
         distinct_col_vals = data_tbl.select(segments_expr).distinct()
         seg_categories = distinct_col_vals[segments_expr].to_list()
     else:  # pragma: no cover
         raise ValueError(f"Unsupported table type: {tbl_type}")
 
-    # Ensure that the categories are sorted
-    seg_categories.sort()
+    # Ensure that the categories are sorted, and allow for None values
+    seg_categories.sort(key=lambda x: (x is None, x))
 
     # Place each category and each value in a list of tuples as: `(column, value)`
     seg_tuples = [(segments_expr, category) for category in seg_categories]
@@ -13904,7 +14136,7 @@ def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> tuple[str, str]:
     return seg_tuples
 
 
-def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
+def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, Any]]:
     """
     Normalize the segments expression to a list of tuples, given a single tuple.
 
@@ -13930,17 +14162,23 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
 
     Returns
     -------
-    list[tuple[str, str]]
+    list[tuple[str, Any]]
         A list of tuples representing pairings of a column name and a value in the column.
+        Values can be any type, including None.
     """
+    # Unpack the segments expression tuple for more convenient and explicit variable names
+    column, segment = segments_expr
+
     # Check if the first element is a string
-    if isinstance(segments_expr[0], str):
-        # If the second element is a list, create a list of tuples
-        if isinstance(segments_expr[1], list):
-            seg_tuples = [(segments_expr[0], value) for value in segments_expr[1]]
+    if isinstance(column, str):
+        if isinstance(segment, Segment):
+            seg_tuples = [(column, seg) for seg in segment.segments]
+        # If the second element is a collection, expand into a list of tuples
+        elif isinstance(segment, (list, set, tuple)):
+            seg_tuples = [(column, seg) for seg in segment]
         # If the second element is not a list, create a single tuple
         else:
-            seg_tuples = [(segments_expr[0], segments_expr[1])]
+            seg_tuples = [(column, segment)]
     # If the first element is not a string, raise an error
     else:  # pragma: no cover
         raise ValueError("The first element of the segments expression must be a string.")
@@ -13948,7 +14186,7 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, str]]:
     return seg_tuples
 
 
-def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
+def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
     """
     Apply the segments expression to the data table.
 
@@ -13971,15 +14209,24 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
     # Get the table type
     tbl_type = _get_tbl_type(data=data_tbl)
 
-    if tbl_type in ["pandas", "polars"]:
-        # If the table is a Pandas or Polars DataFrame, transforming to a Narwhals table
+    # Unpack the segments expression tuple for more convenient and explicit variable names
+    column, segment = segments_expr
+
+    if tbl_type in ["pandas", "polars", "pyspark"]:
+        # If the table is a Pandas, Polars, or PySpark DataFrame, transforming to a Narwhals table
         # and perform the filtering operation
 
         # Transform to Narwhals table if a DataFrame
         data_tbl_nw = nw.from_native(data_tbl)
 
-        # Filter the data table based on the column name and value
-        data_tbl_nw = data_tbl_nw.filter(nw.col(segments_expr[0]) == segments_expr[1])
+        # Filter the data table based on the column name and segment
+        if segment is None:
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_null())
+        # Check if the segment is a segment group
+        elif isinstance(segment, list):
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column).is_in(segment))
+        else:
+            data_tbl_nw = data_tbl_nw.filter(nw.col(column) == segment)
 
         # Transform back to the original table type
         data_tbl = data_tbl_nw.to_native()
@@ -13987,8 +14234,13 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, str]) -> any:
     elif tbl_type in IBIS_BACKENDS:
         # If the table is an Ibis backend table, perform the filtering operation directly
 
-        # Filter the data table based on the column name and value
-        data_tbl = data_tbl[data_tbl[segments_expr[0]] == segments_expr[1]]
+        # Filter the data table based on the column name and segment
+        if segment is None:
+            data_tbl = data_tbl[data_tbl[column].isnull()]
+        elif isinstance(segment, list):
+            data_tbl = data_tbl[data_tbl[column].isin(segment)]
+        else:
+            data_tbl = data_tbl[data_tbl[column] == segment]
 
     return data_tbl
 
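
Taken together, the segment helpers now expand a `Segment` group, plain collections, and single values, and a `None` segment is matched with an `is_null()` filter. A hedged sketch of what segmented validation might look like from the user side; the data and thresholds are illustrative, and the exact signatures should be checked against the pointblank docs:

```python
import polars as pl
import pointblank as pb

tbl = pl.DataFrame(
    {
        "region": ["east", "west", None, "east", "south"],
        "revenue": [100, 250, 75, 310, 90],
    }
)

validation = (
    pb.Validate(data=tbl)
    # One step per distinct `region` value, including a step for the null segment
    .col_vals_gt(columns="revenue", value=50, segments="region")
    # A tuple with a list expands into one segment per listed value ("east", "west")
    .col_vals_gt(columns="revenue", value=80, segments=("region", ["east", "west"]))
    .interrogate()
)
```
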