sqlshell-0.4.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sqlshell/__init__.py +84 -0
  2. sqlshell/__main__.py +4926 -0
  3. sqlshell/ai_autocomplete.py +392 -0
  4. sqlshell/ai_settings_dialog.py +337 -0
  5. sqlshell/context_suggester.py +768 -0
  6. sqlshell/create_test_data.py +152 -0
  7. sqlshell/data/create_test_data.py +137 -0
  8. sqlshell/db/__init__.py +6 -0
  9. sqlshell/db/database_manager.py +1318 -0
  10. sqlshell/db/export_manager.py +188 -0
  11. sqlshell/editor.py +1166 -0
  12. sqlshell/editor_integration.py +127 -0
  13. sqlshell/execution_handler.py +421 -0
  14. sqlshell/menus.py +262 -0
  15. sqlshell/notification_manager.py +370 -0
  16. sqlshell/query_tab.py +904 -0
  17. sqlshell/resources/__init__.py +1 -0
  18. sqlshell/resources/icon.png +0 -0
  19. sqlshell/resources/logo_large.png +0 -0
  20. sqlshell/resources/logo_medium.png +0 -0
  21. sqlshell/resources/logo_small.png +0 -0
  22. sqlshell/resources/splash_screen.gif +0 -0
  23. sqlshell/space_invaders.py +501 -0
  24. sqlshell/splash_screen.py +405 -0
  25. sqlshell/sqlshell/__init__.py +5 -0
  26. sqlshell/sqlshell/create_test_data.py +118 -0
  27. sqlshell/sqlshell/create_test_databases.py +96 -0
  28. sqlshell/sqlshell_demo.png +0 -0
  29. sqlshell/styles.py +257 -0
  30. sqlshell/suggester_integration.py +330 -0
  31. sqlshell/syntax_highlighter.py +124 -0
  32. sqlshell/table_list.py +996 -0
  33. sqlshell/ui/__init__.py +6 -0
  34. sqlshell/ui/bar_chart_delegate.py +49 -0
  35. sqlshell/ui/filter_header.py +469 -0
  36. sqlshell/utils/__init__.py +16 -0
  37. sqlshell/utils/profile_cn2.py +1661 -0
  38. sqlshell/utils/profile_column.py +2635 -0
  39. sqlshell/utils/profile_distributions.py +616 -0
  40. sqlshell/utils/profile_entropy.py +347 -0
  41. sqlshell/utils/profile_foreign_keys.py +779 -0
  42. sqlshell/utils/profile_keys.py +2834 -0
  43. sqlshell/utils/profile_ohe.py +934 -0
  44. sqlshell/utils/profile_ohe_advanced.py +754 -0
  45. sqlshell/utils/profile_ohe_comparison.py +237 -0
  46. sqlshell/utils/profile_prediction.py +926 -0
  47. sqlshell/utils/profile_similarity.py +876 -0
  48. sqlshell/utils/search_in_df.py +90 -0
  49. sqlshell/widgets.py +400 -0
  50. sqlshell-0.4.4.dist-info/METADATA +441 -0
  51. sqlshell-0.4.4.dist-info/RECORD +54 -0
  52. sqlshell-0.4.4.dist-info/WHEEL +5 -0
  53. sqlshell-0.4.4.dist-info/entry_points.txt +2 -0
  54. sqlshell-0.4.4.dist-info/top_level.txt +1 -0
sqlshell/utils/profile_cn2.py
@@ -0,0 +1,1661 @@
+ """
+ CN2 Rule Induction Algorithm for Classification
+
+ This module implements the classic CN2 rule-induction algorithm inspired by
+ Clark & Niblett (1989). It learns interpretable IF-THEN classification rules
+ using a separate-and-conquer (cover-and-remove) strategy with beam search.
+
+ The algorithm supports:
+ - Mixed categorical and numeric features
+ - Supervised discretization for numeric features
+ - Multiple quality measures (likelihood_ratio, entropy)
+ - Laplace-smoothed probability estimates
+ - Automatic discretization of numeric target variables using academic methods
+
+ Example usage:
+     from sqlshell.utils.profile_cn2 import CN2Classifier, visualize_cn2_rules
+
+     clf = CN2Classifier(beam_width=5, min_covered_examples=5)
+     clf.fit(X, y)
+     predictions = clf.predict(X)
+     rules = clf.get_rules()
+ """
+
+ import pandas as pd
+ import numpy as np
+ from dataclasses import dataclass, field
+ from typing import List, Dict, Tuple, Optional, Any, Union
+ from collections import Counter
+ import warnings
+
+
+ # =============================================================================
+ # Numeric Target Discretization
+ # =============================================================================
+
+ class NumericTargetDiscretizer:
+     """
+     Discretizes continuous numeric target variables into meaningful categorical bins.
+
+     This class implements several academically grounded binning methods for converting
+     continuous numeric variables into discrete categories suitable for classification.
+
+     Supported Methods:
+     -----------------
+     1. **jenks** (Fisher-Jenks Natural Breaks):
+        - Minimizes within-class variance while maximizing between-class variance
+        - Based on Fisher (1958) and Jenks (1967)
+        - Best for data with natural clusters or multimodal distributions
+        - Reference: Jenks, G.F. (1967). "The Data Model Concept in Statistical Mapping"
+
+     2. **quantile** (Equal-Frequency Binning):
+        - Creates bins with an equal number of observations
+        - Robust to outliers and skewed distributions
+        - Based on standard statistical quantile theory
+
+     3. **freedman_diaconis** (Freedman-Diaconis Rule):
+        - Bin width = 2 * IQR / n^(1/3)
+        - Optimal for normally distributed data
+        - Reference: Freedman, D. & Diaconis, P. (1981). "On the histogram as a
+          density estimator: L2 theory"
+
+     4. **sturges** (Sturges' Rule):
+        - Number of bins = 1 + log2(n)
+        - Classic rule, best for symmetric distributions
+        - Reference: Sturges, H.A. (1926). "The Choice of a Class Interval"
+
+     5. **equal_width** (Equal-Width Binning):
+        - Creates bins of equal range
+        - Simple but may create empty bins with skewed data
+
+     6. **auto** (Automatic Selection):
+        - Automatically selects the best method based on data characteristics
+        - Uses skewness and distribution analysis to choose the optimal method
+
+     Parameters:
+     ----------
+     method : str, default='auto'
+         The binning method to use. One of: 'auto', 'jenks', 'quantile',
+         'freedman_diaconis', 'sturges', 'equal_width'
+
+     n_bins : int, optional
+         Number of bins to create. If None, determined automatically based on method.
+         For 'freedman_diaconis' and 'sturges', this is computed from the data.
+
+     min_bin_size : int, default=5
+         Minimum number of samples per bin. Bins smaller than this are merged.
+
+     max_bins : int, default=10
+         Maximum number of bins to create (for automatic methods).
+
+     Attributes:
+     ----------
+     bin_edges_ : np.ndarray
+         The computed bin edges after fitting.
+
+     bin_labels_ : List[str]
+         Human-readable labels for each bin (e.g., "1-1000", "1001-5000").
+
+     method_used_ : str
+         The actual method used (relevant when method='auto').
+
+     n_bins_ : int
+         The number of bins created.
+
+     Example:
+     -------
+     >>> discretizer = NumericTargetDiscretizer(method='jenks', n_bins=5)
+     >>> df['income_category'] = discretizer.fit_transform(df['income'])
+     >>> print(discretizer.bin_labels_)
+     ['0-25000', '25001-50000', '50001-100000', '100001-250000', '250001+']
+     """
+
+     def __init__(
+         self,
+         method: str = 'auto',
+         n_bins: Optional[int] = None,
+         min_bin_size: int = 5,
+         max_bins: int = 10
+     ):
+         valid_methods = {'auto', 'jenks', 'quantile', 'freedman_diaconis',
+                          'sturges', 'equal_width'}
+         if method not in valid_methods:
+             raise ValueError(f"method must be one of {valid_methods}, got '{method}'")
+
+         self.method = method
+         self.n_bins = n_bins
+         self.min_bin_size = min_bin_size
+         self.max_bins = max_bins
+
+         # Set after fitting
+         self.bin_edges_: Optional[np.ndarray] = None
+         self.bin_labels_: List[str] = []
+         self.method_used_: Optional[str] = None
+         self.n_bins_: int = 0
+         self._is_fitted: bool = False
+
+     def fit(self, data: Union[pd.Series, np.ndarray]) -> "NumericTargetDiscretizer":
+         """
+         Compute bin edges from the data.
+
+         Parameters:
+         ----------
+         data : array-like
+             The numeric data to fit.
+
+         Returns:
+         -------
+         self : NumericTargetDiscretizer
+         """
+         # Convert to numpy array and remove NaN
+         if isinstance(data, pd.Series):
+             values = data.dropna().values
+         else:
+             values = np.asarray(data)
+             values = values[~np.isnan(values)]
+
+         if len(values) == 0:
+             raise ValueError("Cannot fit on empty data")
+
+         n_unique = len(np.unique(values))
+         if n_unique <= 1:
+             raise ValueError("Need at least 2 unique values to discretize")
+
+         # Select method if auto
+         if self.method == 'auto':
+             self.method_used_ = self._select_method(values)
+         else:
+             self.method_used_ = self.method
+
+         # Compute number of bins if not specified
+         if self.n_bins is not None:
+             n_bins = min(self.n_bins, n_unique, self.max_bins)
+         else:
+             n_bins = self._compute_n_bins(values, self.method_used_)
+             n_bins = min(n_bins, n_unique, self.max_bins)
+
+         # Ensure we have at least 2 bins
+         n_bins = max(2, n_bins)
+
+         # Compute bin edges using selected method
+         if self.method_used_ == 'jenks':
+             self.bin_edges_ = self._jenks_breaks(values, n_bins)
+         elif self.method_used_ == 'quantile':
+             self.bin_edges_ = self._quantile_breaks(values, n_bins)
+         elif self.method_used_ == 'freedman_diaconis':
+             self.bin_edges_ = self._freedman_diaconis_breaks(values, n_bins)
+         elif self.method_used_ == 'sturges':
+             self.bin_edges_ = self._equal_width_breaks(values, n_bins)
+         elif self.method_used_ == 'equal_width':
+             self.bin_edges_ = self._equal_width_breaks(values, n_bins)
+
+         # Merge small bins
+         self.bin_edges_ = self._merge_small_bins(values, self.bin_edges_)
+
+         # Generate human-readable labels
+         self.bin_labels_ = self._generate_labels(self.bin_edges_, values)
+         self.n_bins_ = len(self.bin_labels_)
+         self._is_fitted = True
+
+         return self
+
+     def transform(self, data: Union[pd.Series, np.ndarray]) -> np.ndarray:
+         """
+         Transform numeric values to categorical bin labels.
+
+         Parameters:
+         ----------
+         data : array-like
+             The numeric data to transform.
+
+         Returns:
+         -------
+         labels : np.ndarray
+             Array of string labels for each value.
+         """
+         if not self._is_fitted:
+             raise RuntimeError("NumericTargetDiscretizer not fitted. Call fit() first.")
+
+         if isinstance(data, pd.Series):
+             values = data.values
+         else:
+             values = np.asarray(data)
+
+         # Digitize values into bins
+         # np.digitize returns indices 1 to n_bins, we want 0 to n_bins-1
+         bin_indices = np.digitize(values, self.bin_edges_[1:-1])
+
+         # Map indices to labels, handle NaN
+         labels = np.empty(len(values), dtype=object)
+         for i, (val, idx) in enumerate(zip(values, bin_indices)):
+             if pd.isna(val):  # pd.isna handles float NaN, None and pd.NA uniformly
+                 labels[i] = "Missing"
+             else:
+                 labels[i] = self.bin_labels_[min(idx, len(self.bin_labels_) - 1)]
+
+         return labels
+
+     def fit_transform(self, data: Union[pd.Series, np.ndarray]) -> np.ndarray:
+         """Fit and transform in one step."""
+         return self.fit(data).transform(data)
+
+     def _select_method(self, values: np.ndarray) -> str:
+         """
+         Automatically select the best binning method based on data characteristics.
+
+         Selection criteria (based on statistical theory):
+         - Jenks: For multimodal or clustered distributions
+         - Quantile: For highly skewed distributions (|skewness| > 1.5)
+         - Freedman-Diaconis: For moderately skewed distributions
+         """
+         from scipy import stats
+
+         n = len(values)
+
+         # Calculate skewness
+         skewness = stats.skew(values)
+
+         # Calculate kurtosis (excess kurtosis)
+         kurtosis = stats.kurtosis(values)
+
+         # Check for multimodality using a dip-test approximation
+         # Simple heuristic: check if the histogram has multiple peaks
+         try:
+             hist, _ = np.histogram(values, bins='auto')
+             # Count local maxima
+             peaks = 0
+             for i in range(1, len(hist) - 1):
+                 if hist[i] > hist[i-1] and hist[i] > hist[i+1]:
+                     peaks += 1
+             is_multimodal = peaks >= 2
+         except Exception:
+             is_multimodal = False
+
+         # Selection logic
+         if is_multimodal:
+             # Jenks is best for multimodal distributions
+             return 'jenks'
+         elif abs(skewness) > 1.5:
+             # Highly skewed: quantile binning handles outliers better
+             return 'quantile'
+         elif abs(skewness) > 0.5:
+             # Moderately skewed: Freedman-Diaconis is robust
+             return 'freedman_diaconis'
+         else:
+             # Symmetric distribution: quantile works well
+             return 'quantile'
+
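[editor's note] A small illustration of the auto-selection heuristic above (a sketch, not part of the released file; note the multimodality check runs before the skewness thresholds):

    import numpy as np
    from sqlshell.utils.profile_cn2 import NumericTargetDiscretizer

    bimodal = np.concatenate([np.random.normal(0, 1, 500),
                              np.random.normal(10, 1, 500)])
    disc = NumericTargetDiscretizer(method='auto').fit(bimodal)
    print(disc.method_used_)  # 'jenks' expected: the histogram shows two peaks,
                              # so the multimodal branch fires before skewness is consulted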
+     def _compute_n_bins(self, values: np.ndarray, method: str) -> int:
+         """Compute optimal number of bins based on method."""
+         n = len(values)
+
+         if method == 'sturges':
+             # Sturges' rule: k = 1 + log2(n)
+             return int(1 + np.log2(n))
+
+         elif method == 'freedman_diaconis':
+             # Freedman-Diaconis: bin width = 2 * IQR / n^(1/3)
+             q75, q25 = np.percentile(values, [75, 25])
+             iqr = q75 - q25
+             if iqr == 0:
+                 return 5  # Default if no spread
+             bin_width = 2 * iqr / (n ** (1/3))
+             data_range = values.max() - values.min()
+             return max(1, int(np.ceil(data_range / bin_width)))
+
+         elif method in ('jenks', 'quantile', 'equal_width'):
+             # Use the square-root rule as a default: k = sqrt(n)
+             k = int(np.sqrt(n))
+             return min(max(3, k), self.max_bins)
+
+         return 5  # Default
+
+     def _jenks_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
+         """
+         Compute Jenks natural breaks (Fisher-Jenks algorithm).
+
+         This minimizes the sum of squared deviations within each class.
+         Based on: Fisher, W.D. (1958). "On Grouping for Maximum Homogeneity"
+         """
+         # Sort values
+         sorted_vals = np.sort(values)
+         n = len(sorted_vals)
+
+         # Limit to n_bins - 1 breaks for n_bins bins
+         n_classes = min(n_bins, n)
+
+         if n_classes <= 1:
+             return np.array([sorted_vals[0], sorted_vals[-1]])
+
+         # Initialize matrices for dynamic programming
+         # mat1[i][j] = minimum sum of squared deviations for first i values in j classes
+         # mat2[i][j] = index of last break for optimal solution
+
+         # For efficiency, we use a simplified version
+         # that finds breaks by minimizing within-class variance
+
+         # Use k-means style approach for large datasets
+         if n > 500:
+             return self._jenks_kmeans_approx(sorted_vals, n_classes)
+
+         # Full Jenks algorithm for smaller datasets
+         mat1 = np.zeros((n + 1, n_classes + 1))
+         mat2 = np.zeros((n + 1, n_classes + 1), dtype=int)
+
+         # Initialize
+         mat1[:, 0] = np.inf
+         mat1[0, :] = np.inf
+         mat1[0, 0] = 0
+
+         # Compute variance for all intervals [i, j]
+         variance_combinations = {}
+         for i in range(n):
+             sums = 0.0
+             sum_squares = 0.0
+             for j in range(i, n):
+                 sums += sorted_vals[j]
+                 sum_squares += sorted_vals[j] ** 2
+                 count = j - i + 1
+                 mean = sums / count
+                 variance = sum_squares / count - mean ** 2
+                 variance_combinations[(i, j)] = variance * count
+
+         # Fill DP table
+         for i in range(1, n + 1):
+             for j in range(1, min(i, n_classes) + 1):
+                 if j == 1:
+                     mat1[i, j] = variance_combinations[(0, i - 1)]
+                     mat2[i, j] = 0
+                 else:
+                     min_cost = np.inf
+                     min_idx = 0
+                     for k in range(j - 1, i):
+                         cost = mat1[k, j - 1] + variance_combinations[(k, i - 1)]
+                         if cost < min_cost:
+                             min_cost = cost
+                             min_idx = k
+                     mat1[i, j] = min_cost
+                     mat2[i, j] = min_idx
+
+         # Backtrack to find breaks
+         breaks = [n]
+         k = n_classes
+         while k > 1:
+             breaks.append(mat2[breaks[-1], k])
+             k -= 1
+         breaks.reverse()
+
+         # Convert to actual values. After the reversal, breaks is
+         # [b_1, ..., b_{k-1}, n]; each interior index b marks the end of a
+         # class, so drop the trailing sentinel n rather than the first break.
+         bin_edges = [sorted_vals[0]]
+         for b in breaks[:-1]:
+             if b > 0 and b < n:
+                 bin_edges.append(sorted_vals[b - 1])
+         bin_edges.append(sorted_vals[-1])
+
+         return np.unique(bin_edges)
+
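[editor's note] A tiny sanity check of the natural-breaks DP (illustrative sketch, not part of the released file): two well-separated clusters should produce a break at the gap.

    import numpy as np
    from sqlshell.utils.profile_cn2 import NumericTargetDiscretizer

    disc = NumericTargetDiscretizer(method='jenks', n_bins=2, min_bin_size=1)
    disc.fit(np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0]))
    print(disc.bin_edges_)  # expected break at the 3 -> 10 gap, e.g. [ 1.  3. 12.]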
+     def _jenks_kmeans_approx(self, sorted_vals: np.ndarray, n_bins: int) -> np.ndarray:
+         """
+         Approximates Jenks using 1D k-means clustering.
+         More efficient for large datasets.
+         """
+         try:
+             from scipy.cluster.vq import kmeans
+
+             # Run k-means on 1D data
+             centroids, _ = kmeans(sorted_vals.astype(float), n_bins)
+             centroids = np.sort(centroids)
+
+             # Find breakpoints between centroids
+             breaks = [sorted_vals[0]]
+             for i in range(len(centroids) - 1):
+                 midpoint = (centroids[i] + centroids[i + 1]) / 2
+                 breaks.append(midpoint)
+             breaks.append(sorted_vals[-1])
+
+             return np.array(breaks)
+         except Exception:
+             # Fallback to quantile
+             return self._quantile_breaks(sorted_vals, n_bins)
+
+     def _quantile_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
+         """Compute quantile-based breaks (equal frequency binning)."""
+         quantiles = np.linspace(0, 100, n_bins + 1)
+         breaks = np.percentile(values, quantiles)
+         return np.unique(breaks)
+
+     def _equal_width_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
+         """Compute equal-width breaks."""
+         return np.linspace(values.min(), values.max(), n_bins + 1)
+
+     def _freedman_diaconis_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
+         """Compute breaks using Freedman-Diaconis rule for bin width."""
+         q75, q25 = np.percentile(values, [75, 25])
+         iqr = q75 - q25
+
+         if iqr == 0:
+             # No spread, fall back to equal width
+             return self._equal_width_breaks(values, n_bins)
+
+         bin_width = 2 * iqr / (len(values) ** (1/3))
+         min_val = values.min()
+         max_val = values.max()
+
+         # Generate breaks
+         breaks = [min_val]
+         current = min_val + bin_width
+         while current < max_val:
+             breaks.append(current)
+             current += bin_width
+         breaks.append(max_val)
+
+         # Limit to max_bins
+         if len(breaks) - 1 > n_bins:
+             return self._equal_width_breaks(values, n_bins)
+
+         return np.array(breaks)
+
+     def _merge_small_bins(self, values: np.ndarray, breaks: np.ndarray) -> np.ndarray:
+         """Merge bins that have fewer than min_bin_size samples."""
+         if len(breaks) <= 2:
+             return breaks
+
+         # Count samples per bin
+         bin_indices = np.digitize(values, breaks[1:-1])
+         counts = np.bincount(bin_indices, minlength=len(breaks) - 1)
+
+         # Iteratively merge small bins
+         new_breaks = list(breaks)
+         merged = True
+         while merged and len(new_breaks) > 2:
+             merged = False
+             bin_indices = np.digitize(values, new_breaks[1:-1])
+             counts = np.bincount(bin_indices, minlength=len(new_breaks) - 1)
+
+             for i in range(len(counts)):
+                 if counts[i] < self.min_bin_size and len(new_breaks) > 2:
+                     # Merge with neighbor that has fewer samples
+                     if i == 0:
+                         # Merge with next bin (remove break at index 1)
+                         if len(new_breaks) > 2:
+                             new_breaks.pop(1)
+                         merged = True
+                         break
+                     elif i == len(counts) - 1:
+                         # Merge with previous bin (remove second to last break)
+                         if len(new_breaks) > 2:
+                             new_breaks.pop(-2)
+                         merged = True
+                         break
+                     else:
+                         # Merge with smaller neighbor
+                         if counts[i - 1] <= counts[i + 1]:
+                             new_breaks.pop(i)
+                         else:
+                             new_breaks.pop(i + 1)
+                         merged = True
+                         break
+
+         return np.array(new_breaks)
+
+     def _generate_labels(self, breaks: np.ndarray, values: np.ndarray) -> List[str]:
+         """Generate human-readable bin labels."""
+         labels = []
+
+         # Determine formatting based on value range
+         min_val, max_val = values.min(), values.max()
+         val_range = max_val - min_val
+
+         # Determine number of decimal places
+         if val_range > 1000:
+             fmt = lambda x: f"{int(round(x)):,}"
+         elif val_range > 10:
+             fmt = lambda x: f"{x:.1f}"
+         elif val_range > 1:
+             fmt = lambda x: f"{x:.2f}"
+         else:
+             fmt = lambda x: f"{x:.3f}"
+
+         for i in range(len(breaks) - 1):
+             low = breaks[i]
+             high = breaks[i + 1]
+
+             if i == 0:
+                 # First bin: "≤ X"
+                 labels.append(f"≤ {fmt(high)}")
+             elif i == len(breaks) - 2:
+                 # Last bin: "> X"
+                 labels.append(f"> {fmt(low)}")
+             else:
+                 # Middle bins: "X - Y"
+                 labels.append(f"{fmt(low)} - {fmt(high)}")
+
+         # If we only have 2 breaks, create clearer labels
+         if len(labels) == 1:
+             mid = (breaks[0] + breaks[1]) / 2
+             labels = [f"≤ {fmt(mid)}", f"> {fmt(mid)}"]
+             self.bin_edges_ = np.array([breaks[0], mid, breaks[1]])
+
+         return labels
+
+     def get_bin_summary(self, data: Union[pd.Series, np.ndarray]) -> pd.DataFrame:
+         """
+         Get a summary of the discretization.
+
+         Parameters:
+         ----------
+         data : array-like
+             The original numeric data.
+
+         Returns:
+         -------
+         summary : pd.DataFrame
+             DataFrame with bin ranges, counts, and percentages.
+         """
+         if not self._is_fitted:
+             raise RuntimeError("Not fitted. Call fit() first.")
+
+         labels = self.transform(data)
+         counts = Counter(labels)
+
+         records = []
+         for label in self.bin_labels_:
+             count = counts.get(label, 0)
+             pct = count / len(labels) * 100 if len(labels) > 0 else 0
+             records.append({
+                 'bin': label,
+                 'count': count,
+                 'percentage': f"{pct:.1f}%"
+             })
+
+         return pd.DataFrame(records)
+
+
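[editor's note] End-to-end use of the discretizer (a sketch, not part of the released file):

    import numpy as np
    import pandas as pd
    from sqlshell.utils.profile_cn2 import NumericTargetDiscretizer

    income = pd.Series(np.random.default_rng(1).lognormal(10, 0.8, 500))
    disc = NumericTargetDiscretizer(method='quantile', n_bins=4)
    categories = disc.fit_transform(income)
    print(disc.bin_labels_)              # four labels: '≤ a', 'a - b', 'b - c', '> c'
    print(disc.get_bin_summary(income))  # per-bin counts and percentages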
+ def discretize_numeric_target(
+     df: pd.DataFrame,
+     target_column: str,
+     method: str = 'auto',
+     n_bins: Optional[int] = None,
+     inplace: bool = False
+ ) -> Tuple[pd.DataFrame, NumericTargetDiscretizer]:
+     """
+     Discretize a numeric target column for use with CN2 classification.
+
+     This is a convenience function that creates a new categorical column
+     from a numeric column using intelligent binning methods.
+
+     Parameters:
+     ----------
+     df : pd.DataFrame
+         The dataframe containing the target column.
+
+     target_column : str
+         Name of the numeric column to discretize.
+
+     method : str, default='auto'
+         Binning method. One of: 'auto', 'jenks', 'quantile',
+         'freedman_diaconis', 'sturges', 'equal_width'
+
+     n_bins : int, optional
+         Number of bins. If None, determined automatically.
+
+     inplace : bool, default=False
+         If True, modify the dataframe in place. Otherwise, return a copy.
+
+     Returns:
+     -------
+     df : pd.DataFrame
+         DataFrame with the target column replaced by discretized values.
+
+     discretizer : NumericTargetDiscretizer
+         The fitted discretizer (useful for inspecting bin edges).
+
+     Example:
+     -------
+     >>> df_discrete, disc = discretize_numeric_target(df, 'income', method='jenks')
+     >>> print(disc.bin_labels_)
+     ['≤ 25000', '25001 - 50000', '50001 - 100000', '> 100000']
+     """
+     if not inplace:
+         df = df.copy()
+
+     if target_column not in df.columns:
+         raise ValueError(f"Column '{target_column}' not found in DataFrame")
+
+     discretizer = NumericTargetDiscretizer(method=method, n_bins=n_bins)
+     df[target_column] = discretizer.fit_transform(df[target_column])
+
+     return df, discretizer
+
+
+ # =============================================================================
+ # Data Structures
+ # =============================================================================
+
+ @dataclass
+ class Condition:
+     """
+     A single condition in a rule.
+
+     Attributes:
+         feature: Name of the feature being tested
+         operator: Comparison operator ('==', '<=', '>')
+         value: Value to compare against
+         is_numeric: Whether this is a numeric condition
+     """
+     feature: str
+     operator: str  # '==' for categorical, '<=' or '>' for numeric
+     value: Any
+     is_numeric: bool = False
+
+     def __str__(self) -> str:
+         if self.is_numeric:
+             return f"{self.feature} {self.operator} {self.value:.4g}"
+         return f"{self.feature} == '{self.value}'"
+
+     def __hash__(self) -> int:
+         return hash((self.feature, self.operator, str(self.value)))
+
+     def __eq__(self, other) -> bool:
+         if not isinstance(other, Condition):
+             return False
+         return (self.feature == other.feature and
+                 self.operator == other.operator and
+                 str(self.value) == str(other.value))
+
+     def evaluate(self, row: pd.Series) -> bool:
+         """Evaluate this condition on a single data row."""
+         val = row[self.feature]
+         if pd.isna(val):
+             return False
+         if self.operator == '==':
+             return val == self.value
+         elif self.operator == '<=':
+             return val <= self.value
+         elif self.operator == '>':
+             return val > self.value
+         return False
+
+
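[editor's note] A Condition is one comparison; a rule ANDs several of them. Minimal sketch (illustrative, not part of the released file):

    import pandas as pd
    from sqlshell.utils.profile_cn2 import Condition

    cond = Condition(feature='petal_length', operator='<=', value=2.0, is_numeric=True)
    row = pd.Series({'petal_length': 1.4, 'species': 'setosa'})
    print(str(cond))           # petal_length <= 2
    print(cond.evaluate(row))  # True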
+ @dataclass
+ class Rule:
+     """
+     A classification rule consisting of conditions and a predicted class.
+
+     Attributes:
+         conditions: List of conditions forming the rule antecedent
+         predicted_class: The class label predicted by this rule
+         coverage: Number of training examples covered by this rule
+         accuracy: Accuracy of this rule on covered examples
+         class_distribution: Distribution of classes among covered examples
+         quality_score: Quality score used during rule learning
+     """
+     conditions: List[Condition] = field(default_factory=list)
+     predicted_class: Any = None
+     coverage: int = 0
+     accuracy: float = 0.0
+     class_distribution: Dict[Any, int] = field(default_factory=dict)
+     quality_score: float = 0.0
+
+     def __str__(self) -> str:
+         if not self.conditions:
+             return f"IF True THEN class = {self.predicted_class}"
+         cond_str = " AND ".join(str(c) for c in self.conditions)
+         return (f"IF {cond_str} THEN class = {self.predicted_class} "
+                 f"[cov={self.coverage}, acc={self.accuracy:.2%}]")
+
+     def __repr__(self) -> str:
+         return self.__str__()
+
+     def covers(self, row: pd.Series) -> bool:
+         """Check if this rule covers (applies to) a data row."""
+         if not self.conditions:
+             return True
+         return all(cond.evaluate(row) for cond in self.conditions)
+
+     def covers_mask(self, X: pd.DataFrame) -> np.ndarray:
+         """Return a boolean mask of rows covered by this rule."""
+         if not self.conditions:
+             return np.ones(len(X), dtype=bool)
+         mask = np.ones(len(X), dtype=bool)
+         for cond in self.conditions:
+             col_vals = X[cond.feature]
+             if cond.operator == '==':
+                 mask &= (col_vals == cond.value).values
+             elif cond.operator == '<=':
+                 mask &= (col_vals <= cond.value).values
+             elif cond.operator == '>':
+                 mask &= (col_vals > cond.value).values
+             # Handle NaN values
+             mask &= ~col_vals.isna().values
+         return mask
+
+     def to_dict(self) -> Dict:
+         """Convert rule to a dictionary representation."""
+         return {
+             'conditions': [str(c) for c in self.conditions],
+             'conditions_detailed': [
+                 {'feature': c.feature, 'operator': c.operator,
+                  'value': c.value, 'is_numeric': c.is_numeric}
+                 for c in self.conditions
+             ],
+             'predicted_class': self.predicted_class,
+             'coverage': self.coverage,
+             'accuracy': self.accuracy,
+             'class_distribution': dict(self.class_distribution),
+             'quality_score': self.quality_score
+         }
+
+
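[editor's note] covers_mask is the vectorized counterpart of covers; rows with NaN in any tested feature are never covered. Sketch (illustrative, not part of the released file):

    import pandas as pd
    from sqlshell.utils.profile_cn2 import Condition, Rule

    rule = Rule(conditions=[Condition('petal_length', '<=', 2.0, True)],
                predicted_class='setosa')
    X = pd.DataFrame({'petal_length': [1.4, 4.2, None]})
    print(rule.covers_mask(X))  # [ True False False]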
750
+ # =============================================================================
751
+ # CN2 Classifier (Optimized)
752
+ # =============================================================================
753
+
754
+ class CN2Classifier:
755
+ """
756
+ CN2 Rule Induction Classifier (Optimized Implementation).
757
+
758
+ Implements the classic CN2 algorithm for learning classification rules
759
+ using a separate-and-conquer strategy with beam search.
760
+
761
+ Parameters:
762
+ max_rules: Maximum number of rules to learn (default: 10)
763
+ beam_width: Width of the beam in beam search (default: 3)
764
+ min_covered_examples: Minimum examples a rule must cover (default: 5)
765
+ max_rule_length: Maximum number of conditions per rule (default: 3)
766
+ quality_measure: Quality heuristic - 'likelihood_ratio' or 'entropy'
767
+ random_state: Random seed for reproducibility
768
+ discretization_bins: Number of bins for numeric feature discretization (default: 4)
769
+ laplace_smoothing: Whether to use Laplace smoothing for probabilities
770
+
771
+ Attributes:
772
+ rules_: List of learned Rule objects (after fitting)
773
+ default_class_: Default class for examples not covered by any rule
774
+ classes_: Unique class labels
775
+ feature_names_: Names of input features
776
+ n_features_: Number of input features
777
+ """
778
+
779
+ def __init__(
780
+ self,
781
+ max_rules: Optional[int] = 10,
782
+ beam_width: int = 3,
783
+ min_covered_examples: int = 5,
784
+ max_rule_length: Optional[int] = 3,
785
+ quality_measure: str = "likelihood_ratio",
786
+ random_state: Optional[int] = None,
787
+ discretization_bins: int = 4,
788
+ laplace_smoothing: bool = True
789
+ ):
790
+ self.max_rules = max_rules
791
+ self.beam_width = beam_width
792
+ self.min_covered_examples = min_covered_examples
793
+ self.max_rule_length = max_rule_length if max_rule_length else 3
794
+ self.quality_measure = quality_measure
795
+ self.random_state = random_state
796
+ self.discretization_bins = discretization_bins
797
+ self.laplace_smoothing = laplace_smoothing
798
+
799
+ # Validate quality measure
800
+ if quality_measure not in ('likelihood_ratio', 'entropy'):
801
+ raise ValueError(
802
+ f"quality_measure must be 'likelihood_ratio' or 'entropy', "
803
+ f"got '{quality_measure}'"
804
+ )
805
+
806
+ # Will be set during fit
807
+ self.rules_: List[Rule] = []
808
+ self.default_class_: Any = None
809
+         self.classes_: Optional[np.ndarray] = None
+         self.feature_names_: List[str] = []
+         self.n_features_: int = 0
+         self._is_fitted: bool = False
+
+         # Cached data for fast computation (populated by fit)
+         self._X_df: Optional[pd.DataFrame] = None
+         self._y_encoded: Optional[np.ndarray] = None
+         self._condition_masks: Dict[int, np.ndarray] = {}  # Pre-computed masks for all conditions
+         self._feature_conditions: List[Tuple[int, str, Any, bool]] = []  # (feature_idx, operator, value, is_numeric)
+
+     def fit(self, X, y) -> "CN2Classifier":
+         """
+         Learn a rule list from training data.
+
+         Parameters:
+             X: Feature matrix (pandas DataFrame or 2D numpy array)
+             y: Target labels (1D array-like)
+
+         Returns:
+             self: The fitted classifier
+         """
+         # Convert inputs to proper format
+         X_df, y = self._validate_input(X, y)
+
+         # Store class information
+         self.classes_ = np.unique(y)
+         self.feature_names_ = list(X_df.columns)
+         self.n_features_ = len(self.feature_names_)
+
+         # Store default class (majority class in training data)
+         class_counts = Counter(y)
+         self.default_class_ = class_counts.most_common(1)[0][0]
+
+         # Convert class labels to integers for faster processing
+         self._class_to_idx = {c: i for i, c in enumerate(self.classes_)}
+         self._y_encoded = np.array([self._class_to_idx[c] for c in y])
+         n_classes = len(self.classes_)
+
+         # Pre-compute all condition masks (this is the key optimization)
+         self._precompute_condition_masks(X_df)
+
+         # Check if we have any valid conditions
+         if len(self._feature_conditions) == 0:
+             # No valid conditions - can't learn rules, just use default class
+             self._is_fitted = True
+             return self
+
+         # Main CN2 loop: separate-and-conquer
+         self.rules_ = []
+         remaining_mask = np.ones(len(X_df), dtype=bool)
+         n_samples = len(X_df)
+
+         while remaining_mask.sum() >= self.min_covered_examples:
+             # Check max rules limit
+             if self.max_rules is not None and len(self.rules_) >= self.max_rules:
+                 break
+
+             # Find the best rule for remaining examples
+             best_rule = self._find_best_rule_fast(remaining_mask)
+
+             if best_rule is None:
+                 break
+
+             # Add rule to rule list
+             self.rules_.append(best_rule)
+
+             # Remove covered examples
+             rule_mask = self._compute_rule_mask(best_rule._condition_indices)
+             remaining_mask &= ~rule_mask
+
+         # Store dataframe for prediction
+         self._X_df = X_df
+         self._is_fitted = True
+         return self
+
+     def _precompute_condition_masks(self, X: pd.DataFrame):
+         """Pre-compute boolean masks for all possible conditions."""
+         self._condition_masks = {}
+         self._feature_conditions = []
+         n_samples = len(X)
+
+         for feat_idx, col in enumerate(X.columns):
+             col_data = X[col].values
+             dtype = X[col].dtype
+             is_numeric = np.issubdtype(dtype, np.number)
+
+             # Handle NaN mask
+             nan_mask = pd.isna(col_data)
+
+             if is_numeric:
+                 # Get quantile thresholds (fewer bins = faster)
+                 valid_data = col_data[~nan_mask]
+                 if len(valid_data) == 0:
+                     continue
+
+                 n_unique = len(np.unique(valid_data))
+                 if n_unique <= self.discretization_bins:
+                     # Treat as categorical
+                     for val in np.unique(valid_data):
+                         cond_idx = len(self._feature_conditions)
+                         mask = (col_data == val) & ~nan_mask
+                         self._condition_masks[cond_idx] = mask
+                         self._feature_conditions.append((feat_idx, '==', val, False))
+                 else:
+                     # Use quantile thresholds
+                     quantiles = np.linspace(0, 1, self.discretization_bins + 1)[1:-1]
+                     thresholds = np.unique(np.quantile(valid_data, quantiles))
+
+                     for thresh in thresholds:
+                         # <= condition
+                         cond_idx = len(self._feature_conditions)
+                         mask = (col_data <= thresh) & ~nan_mask
+                         self._condition_masks[cond_idx] = mask
+                         self._feature_conditions.append((feat_idx, '<=', thresh, True))
+
+                         # > condition
+                         cond_idx = len(self._feature_conditions)
+                         mask = (col_data > thresh) & ~nan_mask
+                         self._condition_masks[cond_idx] = mask
+                         self._feature_conditions.append((feat_idx, '>', thresh, True))
+             else:
+                 # Categorical feature - limit number of values
+                 unique_vals = pd.Series(col_data).dropna().unique()
+                 # Limit to top N most frequent values for performance
+                 if len(unique_vals) > 10:
+                     value_counts = pd.Series(col_data).value_counts()
+                     unique_vals = value_counts.head(10).index.tolist()
+
+                 for val in unique_vals:
+                     cond_idx = len(self._feature_conditions)
+                     mask = (col_data == val) & ~nan_mask
+                     self._condition_masks[cond_idx] = mask
+                     self._feature_conditions.append((feat_idx, '==', val, False))
+
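[editor's note] The optimization above: every candidate condition is materialized once as a boolean mask, so beam search only ANDs cached arrays and never re-scans the DataFrame. The idea in plain numpy (illustrative sketch, not part of the released file):

    import numpy as np

    age = np.array([22, 35, 58, 41])
    mask_le_35 = age <= 35          # cached once: [ True  True False False]
    mask_gt_35 = age > 35           # cached once: [False False  True  True]
    other_cond = np.array([True, False, True, True])  # hypothetical second condition
    coverage = (mask_le_35 & other_cond).sum()  # rule coverage via one bitwise AND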
+     def _compute_rule_mask(self, condition_indices: List[int]) -> np.ndarray:
+         """Compute mask for a rule given its condition indices."""
+         if not condition_indices:
+             return np.ones(len(self._y_encoded), dtype=bool)
+
+         mask = self._condition_masks[condition_indices[0]].copy()
+         for idx in condition_indices[1:]:
+             mask &= self._condition_masks[idx]
+         return mask
+
+     def _find_best_rule_fast(self, remaining_mask: np.ndarray) -> Optional[Rule]:
+         """
+         Find the best rule using optimized beam search.
+         """
+         n_samples = remaining_mask.sum()
+         if n_samples == 0:
+             return None
+
+         n_classes = len(self.classes_)
+         if n_classes == 0:
+             return None
+
+         # Check if there are any valid conditions to try
+         if len(self._feature_conditions) == 0:
+             return None
+
+         y_remaining = self._y_encoded[remaining_mask]
+
+         # Track best rule found
+         best_rule_info = None
+         best_quality = float('-inf')
+
+         # Beam: list of (condition_indices, quality)
+         beam = [([], 0.0)]
+
+         for depth in range(self.max_rule_length):
+             new_beam = []
+             seen_masks = set()  # Avoid duplicate rules
+
+             for cond_indices, _ in beam:
+                 # Get current mask
+                 if cond_indices:
+                     current_mask = self._compute_rule_mask(cond_indices) & remaining_mask
+                 else:
+                     current_mask = remaining_mask.copy()
+
+                 current_coverage = current_mask.sum()
+                 if current_coverage < self.min_covered_examples:
+                     continue
+
+                 # Get features already used
+                 used_features = {self._feature_conditions[i][0] for i in cond_indices}
+
+                 # Try adding each condition
+                 for cond_idx, (feat_idx, op, val, is_num) in enumerate(self._feature_conditions):
+                     # Skip if feature already used
+                     if feat_idx in used_features:
+                         continue
+
+                     # Compute combined mask
+                     new_mask = current_mask & self._condition_masks[cond_idx]
+                     coverage = new_mask.sum()
+
+                     # Skip if insufficient coverage
+                     if coverage < self.min_covered_examples:
+                         continue
+
+                     # Create hashable key to avoid duplicates
+                     mask_key = tuple(sorted(cond_indices + [cond_idx]))
+                     if mask_key in seen_masks:
+                         continue
+                     seen_masks.add(mask_key)
+
+                     # Compute class distribution efficiently
+                     y_covered = self._y_encoded[new_mask]
+                     class_counts = np.bincount(y_covered, minlength=n_classes)
+
+                     # Compute quality
+                     quality = self._compute_quality_fast(class_counts, coverage, n_samples)
+
+                     # Compute accuracy
+                     majority_class_count = class_counts.max()
+                     accuracy = majority_class_count / coverage
+
+                     # Track best overall
+                     if quality > best_quality and accuracy >= 0.5:
+                         best_quality = quality
+                         best_rule_info = (cond_indices + [cond_idx], class_counts, coverage, accuracy, quality)
+
+                     new_beam.append((cond_indices + [cond_idx], quality))
+
+             if not new_beam:
+                 break
+
+             # Keep top beam_width candidates (drop NaN and -inf quality scores)
+             new_beam = [(c, q) for c, q in new_beam
+                         if not (np.isnan(q) or (np.isinf(q) and q < 0))]
+             if not new_beam:
+                 break
+             new_beam.sort(key=lambda x: x[1], reverse=True)
+             beam = new_beam[:self.beam_width]
+
+         # Convert best rule info to Rule object
+         if best_rule_info is None:
+             return None
+
+         cond_indices, class_counts, coverage, accuracy, quality = best_rule_info
+
+         # Build Rule object
+         conditions = []
+         for idx in cond_indices:
+             feat_idx, op, val, is_num = self._feature_conditions[idx]
+             feat_name = self.feature_names_[feat_idx]
+             conditions.append(Condition(feat_name, op, val, is_num))
+
+         # Get predicted class and distribution
+         majority_idx = class_counts.argmax()
+         predicted_class = self.classes_[majority_idx]
+         class_distribution = {self.classes_[i]: int(class_counts[i])
+                               for i in range(len(self.classes_)) if class_counts[i] > 0}
+
+         rule = Rule(
+             conditions=conditions,
+             predicted_class=predicted_class,
+             coverage=int(coverage),
+             accuracy=accuracy,
+             class_distribution=class_distribution,
+             quality_score=quality
+         )
+         rule._condition_indices = cond_indices  # Store for mask computation
+
+         return rule
+
+     def _compute_quality_fast(self, class_counts: np.ndarray, coverage: int, total: int) -> float:
+         """Compute quality score efficiently using numpy."""
+         if coverage == 0 or total == 0:
+             return float('-inf')
+
+         try:
+             if self.quality_measure == 'likelihood_ratio':
+                 # Likelihood ratio
+                 n_classes = len(self.classes_)
+                 if n_classes == 0:
+                     return float('-inf')
+                 expected = coverage / n_classes
+                 if expected <= 0:
+                     return float('-inf')
+                 lr = 0.0
+                 for count in class_counts:
+                     if count > 0:
+                         lr += count * np.log(count / expected + 1e-10)
+                 result = 2 * lr
+             else:
+                 # Entropy-based
+                 probs = class_counts / coverage
+                 probs = probs[probs > 0]
+                 if len(probs) == 0:
+                     return float('-inf')
+                 entropy = -np.sum(probs * np.log2(probs + 1e-10))
+                 max_entropy = np.log2(len(self.classes_))
+                 if max_entropy > 0:
+                     entropy /= max_entropy
+                 result = (coverage / total) * (1.0 - entropy)
+
+             # Guard against NaN
+             if np.isnan(result) or np.isinf(result):
+                 return float('-inf')
+             return float(result)
+         except Exception:
+             return float('-inf')
+
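[editor's note] The likelihood-ratio branch is the G-statistic 2 * Σ f_i * ln(f_i / e_i) with a uniform expectation e_i = coverage / n_classes (Clark & Niblett's formulation, as usually presented, takes e_i from the training class priors). Worked example (a sketch, not part of the released file):

    import numpy as np

    counts = np.array([9, 1])               # covered class counts for a candidate rule
    expected = counts.sum() / len(counts)   # 5.0 under the uniform expectation
    lr = 2 * sum(c * np.log(c / expected) for c in counts if c > 0)
    print(round(float(lr), 2))              # 7.36; a purer class mix scores higher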
+     def predict(self, X) -> np.ndarray:
+         """
+         Predict class labels for samples in X.
+         """
+         self._check_is_fitted()
+         X = self._ensure_dataframe(X)
+
+         predictions = np.full(len(X), self.default_class_, dtype=object)
+         predicted = np.zeros(len(X), dtype=bool)
+
+         for rule in self.rules_:
+             mask = rule.covers_mask(X)
+             new_preds = mask & ~predicted
+             predictions[new_preds] = rule.predicted_class
+             predicted |= mask
+
+         return predictions
+
+     def predict_proba(self, X) -> np.ndarray:
+         """
+         Return class probability estimates for samples in X.
+         """
+         self._check_is_fitted()
+         X = self._ensure_dataframe(X)
+
+         n_samples = len(X)
+         n_classes = len(self.classes_)
+         proba = np.zeros((n_samples, n_classes))
+
+         # Default probabilities
+         default_proba = self._compute_proba_from_distribution(
+             Counter({c: 1 for c in self.classes_})
+         )
+
+         assigned = np.zeros(n_samples, dtype=bool)
+
+         for rule in self.rules_:
+             mask = rule.covers_mask(X)
+             new_assignments = mask & ~assigned
+
+             if new_assignments.sum() > 0:
+                 rule_proba = self._compute_proba_from_distribution(rule.class_distribution)
+                 proba[new_assignments] = rule_proba
+                 assigned[new_assignments] = True
+
+         proba[~assigned] = default_proba
+         return proba
+
+     def _compute_proba_from_distribution(self, class_dist: Dict[Any, int]) -> np.ndarray:
+         """Compute class probabilities from distribution."""
+         n_classes = len(self.classes_)
+         proba = np.zeros(n_classes)
+         total = sum(class_dist.values())
+
+         if self.laplace_smoothing:
+             smoothed_total = total + n_classes
+             for i, cls in enumerate(self.classes_):
+                 count = class_dist.get(cls, 0)
+                 proba[i] = (count + 1) / smoothed_total
+         else:
+             for i, cls in enumerate(self.classes_):
+                 count = class_dist.get(cls, 0)
+                 proba[i] = count / total if total > 0 else 1 / n_classes
+
+         return proba
+
+     def get_rules(self) -> List[Rule]:
+         """Return the learned rules."""
+         self._check_is_fitted()
+         return self.rules_.copy()
+
+     def get_rules_as_df(self) -> pd.DataFrame:
+         """Return rules as a pandas DataFrame."""
+         self._check_is_fitted()
+         records = []
+         for i, rule in enumerate(self.rules_):
+             records.append({
+                 'rule_id': i + 1,
+                 'conditions': ' AND '.join(str(c) for c in rule.conditions) or 'True',
+                 'predicted_class': rule.predicted_class,
+                 'coverage': rule.coverage,
+                 'accuracy': rule.accuracy,
+                 'quality_score': rule.quality_score,
+                 'n_conditions': len(rule.conditions)
+             })
+         return pd.DataFrame(records)
+
+     def score(self, X, y) -> float:
+         """Return accuracy on test data."""
+         predictions = self.predict(X)
+         return np.mean(predictions == np.asarray(y))
+
+     def _validate_input(self, X, y) -> Tuple[pd.DataFrame, np.ndarray]:
+         """Validate and convert input data."""
+         X = self._ensure_dataframe(X)
+         y = np.asarray(y, dtype=object)  # Keep as object to preserve types
+
+         if len(X) != len(y):
+             raise ValueError(f"X and y must have same length. Got {len(X)} and {len(y)}")
+
+         if len(X) == 0:
+             raise ValueError("Cannot fit on empty dataset")
+
+         # Handle NaN values in target - remove rows with NaN target
+         # Check for NaN: works for float NaN and None
+         valid_mask = np.array([
+             not (v is None or (isinstance(v, float) and np.isnan(v)) or
+                  (isinstance(v, str) and v.lower() == 'nan') or
+                  pd.isna(v))
+             for v in y
+         ])
+
+         if valid_mask.sum() == 0:
+             raise ValueError("All target values are NaN or missing")
+
+         if valid_mask.sum() < len(y):
+             # Filter out rows with NaN target
+             X = X.loc[valid_mask].reset_index(drop=True)
+             y = y[valid_mask]
+
+         # Convert target to string to avoid mixed type issues during np.unique
+         y = np.array([str(v) for v in y])
+
+         return X, y
+
+     def _ensure_dataframe(self, X) -> pd.DataFrame:
+         """Convert X to DataFrame."""
+         if isinstance(X, pd.DataFrame):
+             return X.copy()
+         elif isinstance(X, np.ndarray):
+             return pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
+         else:
+             raise TypeError(f"X must be DataFrame or ndarray, got {type(X)}")
+
+     def _check_is_fitted(self):
+         """Check if fitted."""
+         if not self._is_fitted:
+             raise RuntimeError("CN2Classifier not fitted. Call fit() first.")
+
+
+ # =============================================================================
+ # Visualization (PyQt6)
+ # =============================================================================
+
+ from PyQt6.QtWidgets import (
+     QMainWindow, QVBoxLayout, QHBoxLayout, QWidget,
+     QTableWidget, QTableWidgetItem, QLabel, QPushButton,
+     QComboBox, QSplitter, QTabWidget, QScrollArea,
+     QFrame, QHeaderView, QTextEdit, QSpinBox,
+     QDoubleSpinBox, QGroupBox, QFormLayout, QMessageBox
+ )
+ from PyQt6.QtCore import Qt, pyqtSignal
+ from PyQt6.QtGui import QFont, QColor
+
+
+ class CN2RulesVisualization(QMainWindow):
+     """
+     Window to visualize CN2 classification rules.
+     """
+
+     rulesApplied = pyqtSignal(object)
+
+     def __init__(
+         self,
+         classifier: CN2Classifier,
+         X: pd.DataFrame,
+         y: np.ndarray,
+         target_column: str = "target",
+         parent=None
+     ):
+         super().__init__(parent)
+         self.classifier = classifier
+         self.X = X
+         self.y = y
+         self.target_column = target_column
+
+         self.setWindowTitle(f"CN2 Rule Induction - {target_column}")
+         self.setGeometry(100, 100, 1200, 800)
+
+         self._setup_ui()
+         self._populate_data()
+
+     def _setup_ui(self):
+         """Set up the user interface."""
+         main_widget = QWidget()
+         self.setCentralWidget(main_widget)
+         main_layout = QVBoxLayout(main_widget)
+
+         # Title
+         title_label = QLabel(f"CN2 Rule Induction Analysis: {self.target_column}")
+         title_font = QFont()
+         title_font.setBold(True)
+         title_font.setPointSize(14)
+         title_label.setFont(title_font)
+         title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
+         main_layout.addWidget(title_label)
+
+         # Description
+         desc_label = QLabel(
+             "CN2 learns interpretable IF-THEN classification rules. "
+             "Rules are applied in order - the first matching rule determines the prediction."
+         )
+         desc_label.setWordWrap(True)
+         main_layout.addWidget(desc_label)
+
+         # Summary stats
+         self._create_summary_section(main_layout)
+
+         # Tabs
+         tab_widget = QTabWidget()
+
+         # Rules Table Tab
+         rules_tab = QWidget()
+         rules_layout = QVBoxLayout(rules_tab)
+         self.rules_table = self._create_rules_table()
+         rules_layout.addWidget(self.rules_table)
+         tab_widget.addTab(rules_tab, "Rules")
+
+         # Rule Details Tab
+         details_tab = QWidget()
+         details_layout = QVBoxLayout(details_tab)
+         self.details_text = QTextEdit()
+         self.details_text.setReadOnly(True)
+         self.details_text.setFont(QFont("Courier New", 10))
+         details_layout.addWidget(self.details_text)
+         tab_widget.addTab(details_tab, "Rule Details")
+
+         # Algorithm Info Tab
+         info_tab = QWidget()
+         info_layout = QVBoxLayout(info_tab)
+         self.info_text = QTextEdit()
+         self.info_text.setReadOnly(True)
+         info_layout.addWidget(self.info_text)
+         tab_widget.addTab(info_tab, "Algorithm Info")
+
+         main_layout.addWidget(tab_widget, 1)
+
+         # Apply button
+         button_layout = QHBoxLayout()
+         button_layout.addStretch()
+
+         self.apply_button = QPushButton("Apply Rules")
+         self.apply_button.setStyleSheet("""
+             QPushButton {
+                 background-color: #3498DB;
+                 color: white;
+                 border: none;
+                 padding: 8px 16px;
+                 border-radius: 4px;
+                 font-weight: bold;
+             }
+             QPushButton:hover { background-color: #2980B9; }
+         """)
+         self.apply_button.clicked.connect(self._apply_rules)
+         button_layout.addWidget(self.apply_button)
+
+         main_layout.addLayout(button_layout)
+
+     def _create_summary_section(self, parent_layout):
+         """Create summary statistics section."""
+         summary_frame = QFrame()
+         summary_frame.setFrameShape(QFrame.Shape.StyledPanel)
+         summary_layout = QHBoxLayout(summary_frame)
+
+         self.n_rules_label = QLabel("Rules: -")
+         self.n_rules_label.setStyleSheet("font-weight: bold;")
+         summary_layout.addWidget(self.n_rules_label)
+
+         summary_layout.addWidget(QLabel("|"))
+
+         self.accuracy_label = QLabel("Training Accuracy: -")
+         self.accuracy_label.setStyleSheet("font-weight: bold;")
+         summary_layout.addWidget(self.accuracy_label)
+
+         summary_layout.addWidget(QLabel("|"))
+
+         self.coverage_label = QLabel("Total Coverage: -")
+         self.coverage_label.setStyleSheet("font-weight: bold;")
+         summary_layout.addWidget(self.coverage_label)
+
+         summary_layout.addWidget(QLabel("|"))
+
+         self.classes_label = QLabel("Classes: -")
+         summary_layout.addWidget(self.classes_label)
+
+         summary_layout.addStretch()
+         parent_layout.addWidget(summary_frame)
+
+     def _create_rules_table(self) -> QTableWidget:
+         """Create rules table widget."""
+         table = QTableWidget()
+         table.setColumnCount(6)
+         table.setHorizontalHeaderLabels([
+             "Rule #", "Conditions", "Predicted Class",
+             "Coverage", "Accuracy", "Quality"
+         ])
+
+         header = table.horizontalHeader()
+         header.setSectionResizeMode(0, QHeaderView.ResizeMode.Fixed)
+         header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+         header.setSectionResizeMode(2, QHeaderView.ResizeMode.Fixed)
+         header.setSectionResizeMode(3, QHeaderView.ResizeMode.Fixed)
+         header.setSectionResizeMode(4, QHeaderView.ResizeMode.Fixed)
+         header.setSectionResizeMode(5, QHeaderView.ResizeMode.Fixed)
+
+         table.setColumnWidth(0, 60)
+         table.setColumnWidth(2, 120)
+         table.setColumnWidth(3, 80)
+         table.setColumnWidth(4, 80)
+         table.setColumnWidth(5, 80)
+
+         table.setAlternatingRowColors(True)
+         table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
+         table.itemSelectionChanged.connect(self._on_rule_selected)
+
+         return table
+
+     def _populate_data(self):
+         """Populate visualization with classifier data."""
+         rules = self.classifier.get_rules()
+
+         self.n_rules_label.setText(f"Rules: {len(rules)}")
+
+         accuracy = self.classifier.score(self.X, self.y)
+         self.accuracy_label.setText(f"Training Accuracy: {accuracy:.1%}")
+
+         total_coverage = sum(r.coverage for r in rules)
+         self.coverage_label.setText(f"Total Coverage: {total_coverage}")
+
+         classes_str = ", ".join(str(c) for c in self.classifier.classes_)
+         self.classes_label.setText(f"Classes: {classes_str}")
+
+         # Populate rules table
+         self.rules_table.setRowCount(len(rules))
+         for i, rule in enumerate(rules):
+             self.rules_table.setItem(i, 0, QTableWidgetItem(str(i + 1)))
+
+             conditions_str = " AND ".join(str(c) for c in rule.conditions) or "True"
+             self.rules_table.setItem(i, 1, QTableWidgetItem(conditions_str))
+
+             self.rules_table.setItem(i, 2, QTableWidgetItem(str(rule.predicted_class)))
+             self.rules_table.setItem(i, 3, QTableWidgetItem(str(rule.coverage)))
+             self.rules_table.setItem(i, 4, QTableWidgetItem(f"{rule.accuracy:.1%}"))
+             self.rules_table.setItem(i, 5, QTableWidgetItem(f"{rule.quality_score:.2f}"))
+
+             # Color by accuracy
+             if rule.accuracy >= 0.9:
+                 color = QColor(200, 255, 200)
+             elif rule.accuracy >= 0.7:
+                 color = QColor(255, 255, 200)
+             else:
+                 color = QColor(255, 200, 200)
+
+             for j in range(6):
+                 item = self.rules_table.item(i, j)
+                 if item:
+                     item.setBackground(color)
+
+         self._populate_algorithm_info()
+
+         if rules:
+             self.rules_table.selectRow(0)
+
+     def _populate_algorithm_info(self):
+         """Populate algorithm info tab."""
+         clf = self.classifier
+         info = f"""
+ === CN2 Rule Induction Algorithm ===
+
+ Parameters:
+ • Beam Width: {clf.beam_width}
+ • Min Covered Examples: {clf.min_covered_examples}
+ • Max Rule Length: {clf.max_rule_length}
+ • Max Rules: {clf.max_rules}
+ • Quality Measure: {clf.quality_measure}
+ • Discretization Bins: {clf.discretization_bins}
+
+ Dataset:
+ • Total Samples: {len(self.X)}
+ • Features: {clf.n_features_}
+ • Classes: {len(clf.classes_)}
+
+ Results:
+ • Rules Learned: {len(clf.rules_)}
+ • Default Class: {clf.default_class_}
+ • Training Accuracy: {clf.score(self.X, self.y):.2%}
+
+ How CN2 Works:
+ 1. Pre-compute all possible condition masks (for speed)
+ 2. Use beam search to find the best rule
+ 3. Add rule to rule list
+ 4. Remove covered examples
+ 5. Repeat until stopping criteria met
+ """
+         self.info_text.setPlainText(info)
+
+     def _on_rule_selected(self):
+         """Handle rule selection."""
+         selected = self.rules_table.selectedItems()
+         if not selected:
+             return
+
+         row = selected[0].row()
+         rules = self.classifier.get_rules()
+         if row < len(rules):
+             self._show_rule_details(rules[row], row + 1)
+
+     def _show_rule_details(self, rule: Rule, rule_num: int):
+         """Show rule details."""
+         details = f"""
+ === Rule {rule_num} Details ===
+
+ Full Rule:
+ {str(rule)}
+
+ Conditions ({len(rule.conditions)}):
+ """
+         if rule.conditions:
+             for i, cond in enumerate(rule.conditions, 1):
+                 details += f" {i}. {cond}\n"
+         else:
+             details += " (No conditions - default rule)\n"
+
+         details += f"""
+ Prediction:
+ Predicted Class: {rule.predicted_class}
+
+ Metrics:
+ Coverage: {rule.coverage} examples
+ Accuracy: {rule.accuracy:.2%}
+ Quality Score: {rule.quality_score:.4f}
+
+ Class Distribution:
+ """
+         for cls, count in sorted(rule.class_distribution.items(), key=lambda x: -x[1]):
+             pct = count / rule.coverage * 100 if rule.coverage > 0 else 0
+             bar = "█" * int(pct / 5)
+             details += f" {cls}: {count} ({pct:.1f}%) {bar}\n"
+
+         self.details_text.setPlainText(details)
+
+     def _apply_rules(self):
+         """Apply rules."""
+         reply = QMessageBox.question(
+             self,
+             "Apply Rules",
+             "Apply these rules to generate predictions?",
+             QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
+             QMessageBox.StandardButton.No
+         )
+
+         if reply == QMessageBox.StandardButton.Yes:
+             self.rulesApplied.emit(self.classifier)
+             QMessageBox.information(self, "Rules Applied", "CN2 classifier applied successfully.")
+
+
+ # =============================================================================
+ # Convenience Functions
+ # =============================================================================
+
+ def fit_cn2(df: pd.DataFrame, target_column: str, **kwargs) -> CN2Classifier:
+     """
+     Fit a CN2 classifier on a DataFrame.
+     """
+     if target_column not in df.columns:
+         raise ValueError(f"Target column '{target_column}' not found")
+
+     X = df.drop(columns=[target_column])
+     y = df[target_column].values
+
+     clf = CN2Classifier(**kwargs)
+     clf.fit(X, y)
+
+     return clf
+
+
+ def visualize_cn2_rules(df: pd.DataFrame, target_column: str, **kwargs) -> CN2RulesVisualization:
+     """
+     Fit CN2 and create visualization window.
+     """
+     if target_column not in df.columns:
+         raise ValueError(f"Target column '{target_column}' not found")
+
+     X = df.drop(columns=[target_column])
+     y = df[target_column].values
+
+     clf = CN2Classifier(**kwargs)
+     clf.fit(X, y)
+
+     vis = CN2RulesVisualization(clf, X, y, target_column)
+     vis.show()
+
+     return vis
+
+
+ # =============================================================================
+ # Testing
+ # =============================================================================
+
+ def test_cn2():
+     """Test the CN2 classifier."""
+     print("\n===== Testing CN2 Rule Induction =====\n")
+
+     np.random.seed(42)
+     n_samples = 150
+
+     data = {
+         'sepal_length': np.concatenate([
+             np.random.normal(5.0, 0.3, 50),
+             np.random.normal(6.0, 0.4, 50),
+             np.random.normal(6.5, 0.4, 50)
+         ]),
+         'petal_length': np.concatenate([
+             np.random.normal(1.4, 0.2, 50),
+             np.random.normal(4.2, 0.4, 50),
+             np.random.normal(5.5, 0.5, 50)
+         ]),
+         'species': ['setosa'] * 50 + ['versicolor'] * 50 + ['virginica'] * 50
+     }
+
+     df = pd.DataFrame(data)
+
+     print("Dataset shape:", df.shape)
+     print("Class distribution:")
+     print(df['species'].value_counts())
+     print()
+
+     import time
+     start = time.time()
+     clf = fit_cn2(df, target_column='species', beam_width=3, min_covered_examples=5)
+     elapsed = time.time() - start
+
+     rules = clf.get_rules()
+     print(f"\nLearned {len(rules)} rules in {elapsed:.3f}s:\n")
+     for i, rule in enumerate(rules, 1):
+         print(f"Rule {i}: {rule}")
+
+     X = df.drop(columns=['species'])
+     y = df['species'].values
+
+     accuracy = clf.score(X, y)
+     print(f"\nTraining accuracy: {accuracy:.2%}")
+
+     print("\n===== CN2 Test Complete =====\n")
+
+
+ if __name__ == "__main__":
+     test_cn2()