polarfrost 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
polarfrost/mondrian.py CHANGED
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
     from pyspark.sql import DataFrame as SparkDataFrame
     from pyspark.sql.types import StructType
 
+
 # ------------------------- POLARS VERSION -------------------------
 def mondrian_k_anonymity_polars(
     df: "pl.DataFrame | pl.LazyFrame",
@@ -24,29 +25,40 @@ def mondrian_k_anonymity_polars(
     """
     if categorical is None:
         categorical = []
-
+
+    # Input validation
+    if not isinstance(df, (pl.DataFrame, pl.LazyFrame)):
+        raise ValueError("Input must be a Polars DataFrame or LazyFrame")
+
     # Convert to LazyFrame if not already
     if isinstance(df, pl.DataFrame):
         df = df.lazy()
-    elif not isinstance(df, pl.LazyFrame):
-        raise ValueError("Input must be a Polars DataFrame or LazyFrame")
-
+
+    # Check for empty DataFrame by collecting a sample
+    if df.select(pl.len()).collect().item(0, 0) == 0:
+        raise ValueError("Input DataFrame cannot be empty")
+
+    # Validate k is a positive integer
+    if not isinstance(k, (int, str)) or (isinstance(k, str) and not k.isdigit()) or int(k) < 1:
+        raise ValueError("k must be a positive integer")
+    k = int(k)  # Convert to int if it's a string of digits
+
     # Initialize partitions with the full dataset
     partitions = [df]
     result = []
-
+
     # Process partitions until none left
     while partitions:
         part = partitions.pop()
-
+
         # Get partition size (lazy evaluation)
         n_rows = part.select(pl.len()).collect().item(0, 0)
-
+
         # If partition is too small to split, add to results
         if n_rows < 2 * k:
             result.append(part)
             continue
-
+
         # Compute spans for each quasi-identifier
         spans: Dict[str, Any] = {}
         for col in quasi_identifiers:
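The hunk above adds up-front input validation to the Polars entry point: non-Polars inputs and empty frames are rejected, and k may be given either as an int or as a string of digits. A minimal sketch of calling the validated function, assuming the module path polarfrost.mondrian shown in this diff and hypothetical column names:

import polars as pl

from polarfrost.mondrian import mondrian_k_anonymity_polars

# Hypothetical toy data: "age" is numeric, "zip" categorical, "diagnosis" sensitive.
df = pl.DataFrame({
    "age": [25, 31, 44, 52, 37, 29],
    "zip": ["12345", "12345", "67890", "67890", "12345", "67890"],
    "diagnosis": ["A", "B", "A", "C", "B", "A"],
})

# k=2 here; a digit string such as "2" would also be accepted and coerced to int
# by the new validation, while an empty frame or a non-Polars input raises ValueError.
anon = mondrian_k_anonymity_polars(
    df,
    quasi_identifiers=["age", "zip"],
    sensitive_column="diagnosis",
    k=2,
    categorical=["zip"],
)
print(anon)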
@@ -56,26 +68,39 @@ def mondrian_k_anonymity_polars(
                 spans[col] = n_unique
             else:
                 # For numerical, use range as span
-                stats = part.select([
-                    pl.col(col).min().alias("min"),
-                    pl.col(col).max().alias("max")
-                ]).collect()
+                stats = part.select(
+                    [pl.col(col).min().alias("min"), pl.col(col).max().alias("max")]
+                ).collect()
                 col_min = stats[0, "min"]
                 col_max = stats[0, "max"]
-                spans[col] = col_max - col_min if col_max is not None and col_min is not None else 0
-
+
+                # Handle string comparison by converting to float if possible
+                if col_min is not None and col_max is not None:
+                    try:
+                        # Try to convert to float for comparison
+                        min_val = float(col_min) if not isinstance(col_min, (int, float)) else col_min
+                        max_val = float(col_max) if not isinstance(col_max, (int, float)) else col_max
+                        spans[col] = max_val - min_val
+                    except (ValueError, TypeError):
+                        # If conversion fails, use string length difference
+                        spans[col] = abs(len(str(col_max)) - len(str(col_min)))
+                else:
+                    spans[col] = 0
+
         # Find the attribute with maximum span
         split_col = max(spans, key=spans.get)  # type: ignore
-
+
         # If no split possible, add to results
         if spans[split_col] == 0:
             result.append(part)
             continue
-
+
         # Split the partition
         if split_col in categorical:
             # For categorical, split on unique values
-            uniq_vals = part.select(pl.col(split_col).unique()).collect().to_series().to_list()
+            uniq_vals = (
+                part.select(pl.col(split_col).unique()).collect().to_series().to_list()
+            )
             mid = len(uniq_vals) // 2
             left_vals = set(uniq_vals[:mid])
             right_vals = set(uniq_vals[mid:])
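The hunk above also changes how spans are computed for non-categorical columns whose min/max are not already numeric: a float conversion is attempted first, and if that fails the span falls back to the difference in string lengths. A standalone sketch of that heuristic (the helper name _span is illustrative, not part of the package):

def _span(col_min, col_max):
    # Mirrors the new span logic from the hunk above.
    if col_min is None or col_max is None:
        return 0
    try:
        min_val = float(col_min) if not isinstance(col_min, (int, float)) else col_min
        max_val = float(col_max) if not isinstance(col_max, (int, float)) else col_max
        return max_val - min_val
    except (ValueError, TypeError):
        # Non-numeric strings: fall back to string-length difference
        return abs(len(str(col_max)) - len(str(col_min)))

print(_span(18, 65))              # numeric span: 47
print(_span("18", "65"))          # digit strings still yield a numeric span: 47.0
print(_span("NY", "California"))  # non-numeric strings: length difference, 8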
@@ -86,41 +111,57 @@ def mondrian_k_anonymity_polars(
             median = part.select(pl.col(split_col).median()).collect().item()
             left = part.filter(pl.col(split_col) <= median)
             right = part.filter(pl.col(split_col) > median)
-
+
         # Check if both partitions satisfy k-anonymity
         left_n = left.select(pl.len()).collect().item(0, 0)
         right_n = right.select(pl.len()).collect().item(0, 0)
-
+
         if left_n >= k and right_n >= k:
             # Both partitions are valid, continue splitting
             partitions.extend([left, right])
         else:
             # At least one partition is too small, keep as is
             result.append(part)
-
+
     # Aggregate each partition
     agg_rows = []
     for part in result:
         # Collect only the columns we need
         part_df = part.select(quasi_identifiers + [sensitive_column]).collect()
         row = {}
-
+
         # Generalize quasi-identifiers
         for col in quasi_identifiers:
             if col in categorical:
                 # For categorical, use set of unique values
-                row[col] = ','.join(sorted(map(str, part_df[col].unique())))
+                unique_vals = part_df[col].unique()
+                row[col] = ",".join(sorted(str(v) for v in unique_vals))
             else:
                 # For numerical, use range
-                row[col] = f"{part_df[col].min()}-{part_df[col].max()}"
-
+                min_val = part_df[col].min()
+                max_val = part_df[col].max()
+
+                # Ensure we have valid numeric values
+                if min_val is None or max_val is None:
+                    row[col] = "*"  # Handle null values
+                else:
+                    # Convert to string, handling bytes and other types
+                    min_str = min_val.decode("utf-8") if isinstance(min_val, bytes) else str(min_val)
+                    max_str = max_val.decode("utf-8") if isinstance(max_val, bytes) else str(max_val)
+
+                    # Store as string range
+                    row[col] = f"{min_str}-{max_str}"
+
         # Add sensitive values and count
-        row[sensitive_column] = ','.join(sorted(map(str, part_df[sensitive_column].unique())))
-        row['count'] = part_df.height
+        sensitive_vals = part_df[sensitive_column].unique()
+        row[sensitive_column] = ",".join(sorted(str(v) for v in sensitive_vals))
+        # Store count as integer
+        row["count"] = int(part_df.height)
         agg_rows.append(row)
-
+
     return pl.DataFrame(agg_rows)
 
+
 # ------------------------- PYSPARK VERSION -------------------------
 def mondrian_k_anonymity_spark(
     df: "SparkDataFrame",
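The aggregation changes in the hunk above make the numeric-range output null-safe and tolerant of bytes values. A standalone sketch of the new formatting rule (the helper name _format_range is illustrative, not part of the package):

def _format_range(min_val, max_val):
    # Mirrors the new numeric-QI formatting from the hunk above.
    if min_val is None or max_val is None:
        return "*"  # null min or max collapses to a wildcard
    min_str = min_val.decode("utf-8") if isinstance(min_val, bytes) else str(min_val)
    max_str = max_val.decode("utf-8") if isinstance(max_val, bytes) else str(max_val)
    return f"{min_str}-{max_str}"

print(_format_range(25, 52))        # "25-52"
print(_format_range(b"25", b"52"))  # bytes are decoded first: "25-52"
print(_format_range(None, 52))      # "*"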
@@ -132,26 +173,50 @@ def mondrian_k_anonymity_spark(
 ) -> "SparkDataFrame":
     """
     Perform Mondrian k-anonymity using PySpark for distributed processing.
+
+    Args:
+        df: Input PySpark DataFrame
+        quasi_identifiers: List of column names that are quasi-identifiers
+        sensitive_column: Name of the sensitive column
+        k: Anonymity parameter (minimum group size), must be a positive integer
+        categorical: List of categorical column names
+        schema: Schema for the output DataFrame
+
+    Returns:
+        Anonymized DataFrame with generalized quasi-identifiers
     """
     import pandas as pd
     from pyspark.sql.functions import pandas_udf, PandasUDFType
 
+    # Validate k parameter first
+    if not isinstance(k, int) or k <= 0:
+        raise ValueError("k must be a positive integer")
+
+    # Validate schema
+    if schema is None:
+        raise ValueError("Schema must be provided for PySpark UDF")
+
+    # Check for empty DataFrame
+    if df.rdd.isEmpty():
+        raise ValueError("Input DataFrame cannot be empty")
+
     if categorical is None:
         categorical = []
-
-    @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
+
+    # Define the UDF with proper type hints
+    @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
     def mondrian_partition(pdf: pd.DataFrame) -> pd.DataFrame:
         partitions = [pdf]
         result = []
-
+
         while partitions:
             part = partitions.pop()
-
+
             # If partition is too small to split, add to results
             if len(part) < 2 * k:
                 result.append(part)
                 continue
-
+
             # Compute spans for each quasi-identifier
             spans = {}
             for col in quasi_identifiers:
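The hunk above makes the output schema mandatory for the PySpark version and rejects empty inputs before the grouped-map UDF is defined. A minimal sketch of building such a schema and calling the function, assuming a local SparkSession and hypothetical column names; the schema must describe the aggregated output (generalized quasi-identifiers and the sensitive column as strings, plus an integer count):

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, LongType

from polarfrost.mondrian import mondrian_k_anonymity_spark

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(
    [(25, "12345", "A"), (31, "12345", "B"), (44, "67890", "A"),
     (52, "67890", "C"), (37, "12345", "B"), (29, "67890", "A")],
    ["age", "zip", "diagnosis"],
)

# Output schema for the aggregated, generalized rows (not the input schema).
out_schema = StructType([
    StructField("age", StringType(), True),
    StructField("zip", StringType(), True),
    StructField("diagnosis", StringType(), True),
    StructField("count", LongType(), True),
])

anon_sdf = mondrian_k_anonymity_spark(
    sdf,
    quasi_identifiers=["age", "zip"],
    sensitive_column="diagnosis",
    k=2,
    categorical=["zip"],
    schema=out_schema,
)
anon_sdf.show()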
@@ -160,31 +225,42 @@ def mondrian_k_anonymity_spark(
                 else:
                     col_min = part[col].min()
                     col_max = part[col].max()
-                    spans[col] = col_max - col_min if pd.notnull(col_max) and pd.notnull(col_min) else 0
-
+                    spans[col] = (
+                        col_max - col_min
+                        if pd.notnull(col_max) and pd.notnull(col_min)
+                        else 0
+                    )
+
             # Find the attribute with maximum span
-            split_col = max(spans, key=spans.get)
-
+            split_col = max(spans.items(), key=lambda x: x[1])[0]  # type: ignore
+
             # If no split possible, add to results
-            if spans[split_col] == 0:
+            if spans.get(split_col, 0) <= 0:
                 result.append(part)
                 continue
-
-            # Split the partition
+
+            # Split on the chosen column
             if split_col in categorical:
-                # For categorical, split on unique values
-                uniq_vals = part[split_col].unique()
-                mid = len(uniq_vals) // 2
-                left_vals = set(uniq_vals[:mid])
-                right_vals = set(uniq_vals[mid:])
-                left = part[part[split_col].isin(left_vals)]
-                right = part[part[split_col].isin(right_vals)]
+                # For categorical, split on median value
+                value_counts = part[split_col].value_counts()
+                if len(value_counts) > 0:
+                    split_val = value_counts.index[len(value_counts) // 2]
+                    mask = part[split_col] == split_val
+                    left = part[mask]
+                    right = part[~mask]
+                else:
+                    result.append(part)
+                    continue
             else:
                 # For numerical, split on median
-                median = part[split_col].median()
-                left = part[part[split_col] <= median]
-                right = part[part[split_col] > median]
-
+                median_val = part[split_col].median()
+                if pd.notna(median_val):
+                    left = part[part[split_col] <= median_val]
+                    right = part[part[split_col] > median_val]
+                else:
+                    result.append(part)
+                    continue
+
             # Check if both partitions satisfy k-anonymity
             if len(left) >= k and len(right) >= k:
                 # Both partitions are valid, continue splitting
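The categorical branch in the hunk above no longer halves the list of unique values; it now picks the middle value from value_counts() and separates rows that equal it from rows that do not. A standalone pandas sketch of that split (not package code):

import pandas as pd

part = pd.DataFrame({"city": ["A", "A", "B", "B", "B", "C"]})

# Pick the middle value by frequency rank, then split equal / not-equal.
value_counts = part["city"].value_counts()
split_val = value_counts.index[len(value_counts) // 2]
mask = part["city"] == split_val
left, right = part[mask], part[~mask]
print(split_val, len(left), len(right))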
@@ -192,33 +268,37 @@ def mondrian_k_anonymity_spark(
             else:
                 # At least one partition is too small, keep as is
                 result.append(part)
-
+
         # Aggregate the results
         agg_rows = []
         for part in result:
             row = {}
-
+
             # Generalize quasi-identifiers
             for col in quasi_identifiers:
                 if col in categorical:
                     # For categorical, use set of unique values
-                    row[col] = ','.join(sorted(map(str, part[col].unique())))
+                    row[col] = ",".join(sorted(map(str, part[col].unique())))
                 else:
                     # For numerical, use range
                     row[col] = f"{part[col].min()}-{part[col].max()}"
-
+
             # Add sensitive values and count
-            row[sensitive_column] = ','.join(sorted(map(str, part[sensitive_column].unique())))
-            row['count'] = len(part)
+            row[sensitive_column] = ",".join(
+                sorted(str(v) for v in part[sensitive_column].unique())
+            )
+            # Store count as integer
+            row["count"] = int(len(part))
             agg_rows.append(row)
-
+
         return pd.DataFrame(agg_rows)
-
-    # Apply the function to the entire DataFrame
-    if schema is not None:
-        return df.groupBy().applyInPandas(mondrian_partition, schema=schema)
-    else:
-        return df.groupBy().applyInPandas(mondrian_partition)
+
+    # Apply the UDF with explicit schema
+    result_df = df.groupBy().applyInPandas(
+        mondrian_partition, schema=schema  # type: ignore
+    )
+    return result_df
 
 
 # ------------------------- DISPATCHER -------------------------
 def mondrian_k_anonymity(
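The hunk above also applies the UDF through groupBy().applyInPandas with an explicit schema in every case. A generic illustration of that grouped-map pattern (not package code): grouping with no keys yields a single group, so the pandas function receives the whole DataFrame and must return rows matching the declared schema.

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(1, "x"), (2, "y")], ["id", "label"])

def summarize(pdf: pd.DataFrame) -> pd.DataFrame:
    # Return one aggregated row for the whole (single) group.
    return pd.DataFrame([{"n_rows": len(pdf)}])

sdf.groupBy().applyInPandas(summarize, schema="n_rows long").show()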
@@ -231,7 +311,7 @@ def mondrian_k_anonymity(
 ) -> Union[pl.DataFrame, "SparkDataFrame"]:
     """
     Dispatcher: Use Polars or PySpark Mondrian k-anonymity depending on input type.
-
+
     Args:
         df: Input DataFrame (Polars or PySpark)
         quasi_identifiers: List of column names that are quasi-identifiers
@@ -239,18 +319,178 @@ def mondrian_k_anonymity(
         k: Anonymity parameter (minimum group size)
         categorical: List of categorical column names
         schema: Schema for PySpark output (required for PySpark)
-
+
     Returns:
         Anonymized DataFrame with generalized quasi-identifiers
     """
     try:
         from pyspark.sql import DataFrame as SparkDataFrame
+
         if isinstance(df, SparkDataFrame):
-            return mondrian_k_anonymity_spark(df, quasi_identifiers, sensitive_column, k, categorical, schema)
+            return mondrian_k_anonymity_spark(
+                df, quasi_identifiers, sensitive_column, k, categorical, schema
+            )
     except ImportError:
         pass
-
+
     if isinstance(df, (pl.DataFrame, pl.LazyFrame)):
-        return mondrian_k_anonymity_polars(df, quasi_identifiers, sensitive_column, k, categorical)
+        return mondrian_k_anonymity_polars(
+            df, quasi_identifiers, sensitive_column, k, categorical
+        )
+
+    raise ValueError(
+        "Input df must be a polars.DataFrame, polars.LazyFrame, or pyspark.sql.DataFrame"
+    )
+
+
+def _generalize_partition(
+    partition: pl.DataFrame,
+    quasi_identifiers: List[str],
+    categorical: List[str],
+    mask_value: str = "masked"
+) -> pl.DataFrame:
+    """Generalize a partition by applying Mondrian-style generalization."""
+    result = partition.clone()
+
+    for col in quasi_identifiers:
+        is_cat = col in categorical
+        if is_cat:
+            # For categoricals, use mask if multiple values exist
+            if result[col].n_unique() > 1:
+                result = result.with_columns(pl.lit(mask_value).alias(col))
+        else:
+            # For numerical, create a range
+            min_val = result[col].min()
+            max_val = result[col].max()
+            if min_val == max_val:
+                result = result.with_columns(pl.lit(min_val).alias(col))
+            else:
+                result = result.with_columns(pl.lit(f"[{min_val}-{max_val}]").alias(col))
+
+    return result
+
+
+def mondrian_k_anonymity_alt(
+    df: pl.LazyFrame,
+    quasi_identifiers: List[str],
+    sensitive_column: str,
+    k: int,
+    categorical: Optional[List[str]] = None,
+    mask_value: str = "masked",
+    group_columns: Optional[List[str]] = None,
+) -> pl.LazyFrame:
+    """
+    Alternative Mondrian k-anonymity that preserves the original row count.
+
+    Args:
+        df: Input LazyFrame
+        quasi_identifiers: List of column names that are quasi-identifiers
+        sensitive_column: Name of the sensitive column
+        k: Anonymity parameter (minimum group size)
+        categorical: List of categorical column names
+        mask_value: Value to use for masking small groups
+        group_columns: Additional columns to use for grouping but keep unchanged
+
+    Returns:
+        Anonymized LazyFrame with same row count as input
+    """
+    if not isinstance(df, pl.LazyFrame):
+        raise ValueError("Input must be a Polars LazyFrame")
+
+    # Get schema to preserve column order
+    schema = df.schema
+    all_columns = list(schema.keys())
+
+    # Initialize parameters
+    categorical = categorical or []
+    group_columns = group_columns or []
+
+    # Validate inputs
+    if k < 1:
+        raise ValueError("k must be a positive integer")
+
+    # Check if all specified columns exist
+    for col in set(quasi_identifiers + [sensitive_column] + group_columns + categorical):
+        if col not in schema:
+            raise ValueError(f"Column '{col}' not found in DataFrame")
+
+    # Ensure no overlap between group_columns and QIs
+    if any(col in quasi_identifiers for col in group_columns):
+        raise ValueError("group_columns cannot overlap with quasi_identifiers")
+
+    # Collect the data once
+    df_collected = df.collect()
+
+    # Process each group
+    if group_columns:
+        # Get unique group combinations
+        groups = df_collected.select(group_columns).unique()
+
+        results = []
+
+        for group in groups.rows(named=True):
+            # Filter current group
+            condition = pl.lit(True)
+            for col, val in group.items():
+                condition = condition & (pl.col(col) == val)
+
+            group_df = df_collected.filter(condition)
+            group_size = len(group_df)
+
+            if group_size < k:
+                # Mask QIs and sensitive column for small groups
+                masked_cols = {}
+                # Mask all QIs
+                for col in quasi_identifiers:
+                    if col in categorical:
+                        masked_cols[col] = pl.lit(mask_value)
+                # Always mask the sensitive column for small groups
+                masked_cols[sensitive_column] = pl.lit(mask_value)
+
+                if masked_cols:
+                    group_df = group_df.with_columns(**masked_cols)
+
+                results.append(group_df)
+            else:
+                # Apply generalization to QIs
+                if quasi_identifiers:
+                    group_df = _generalize_partition(
+                        group_df,
+                        quasi_identifiers,
+                        categorical,
+                        mask_value
+                    )
+                results.append(group_df)
+
+        # Combine results
+        result_df = pl.concat(results)
+    else:
+        # Process entire dataset as one group
+        if len(df_collected) < k:
+            # Mask all QIs and sensitive column
+            masked_cols = {}
+            # Mask all QIs
+            for col in quasi_identifiers:
+                if col in categorical:
+                    masked_cols[col] = pl.lit(mask_value)
+            # Always mask the sensitive column for small groups
+            masked_cols[sensitive_column] = pl.lit(mask_value)
+
+            if masked_cols:
+                result_df = df_collected.with_columns(**masked_cols)
+            else:
+                result_df = df_collected
        else:
+            # Apply generalization to QIs
+            if quasi_identifiers:
+                result_df = _generalize_partition(
+                    df_collected,
+                    quasi_identifiers,
+                    categorical,
+                    mask_value
+                )
+            else:
+                result_df = df_collected
 
-    raise ValueError("Input df must be a polars.DataFrame, polars.LazyFrame, or pyspark.sql.DataFrame")
+    # Ensure original column order and return as LazyFrame
+    return result_df.select(all_columns).lazy()
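The newly added mondrian_k_anonymity_alt keeps every input row: groups smaller than k have their categorical quasi-identifiers and the sensitive column masked, while larger groups are generalized in place via _generalize_partition. A minimal usage sketch, assuming the module path polarfrost.mondrian and hypothetical column names:

import polars as pl

from polarfrost.mondrian import mondrian_k_anonymity_alt

lf = pl.DataFrame({
    "age": [25, 31, 44, 52, 37, 29],
    "zip": ["12345", "12345", "67890", "67890", "12345", "67890"],
    "diagnosis": ["A", "B", "A", "C", "B", "A"],
    "region": ["east", "east", "east", "west", "west", "west"],
}).lazy()

anon_lf = mondrian_k_anonymity_alt(
    lf,
    quasi_identifiers=["age", "zip"],
    sensitive_column="diagnosis",
    k=2,
    categorical=["zip"],
    group_columns=["region"],  # grouped on, but returned unchanged
)
# Same number of rows as the input; original column order is preserved.
print(anon_lf.collect())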