sqlshell-0.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlshell/__init__.py +84 -0
- sqlshell/__main__.py +4926 -0
- sqlshell/ai_autocomplete.py +392 -0
- sqlshell/ai_settings_dialog.py +337 -0
- sqlshell/context_suggester.py +768 -0
- sqlshell/create_test_data.py +152 -0
- sqlshell/data/create_test_data.py +137 -0
- sqlshell/db/__init__.py +6 -0
- sqlshell/db/database_manager.py +1318 -0
- sqlshell/db/export_manager.py +188 -0
- sqlshell/editor.py +1166 -0
- sqlshell/editor_integration.py +127 -0
- sqlshell/execution_handler.py +421 -0
- sqlshell/menus.py +262 -0
- sqlshell/notification_manager.py +370 -0
- sqlshell/query_tab.py +904 -0
- sqlshell/resources/__init__.py +1 -0
- sqlshell/resources/icon.png +0 -0
- sqlshell/resources/logo_large.png +0 -0
- sqlshell/resources/logo_medium.png +0 -0
- sqlshell/resources/logo_small.png +0 -0
- sqlshell/resources/splash_screen.gif +0 -0
- sqlshell/space_invaders.py +501 -0
- sqlshell/splash_screen.py +405 -0
- sqlshell/sqlshell/__init__.py +5 -0
- sqlshell/sqlshell/create_test_data.py +118 -0
- sqlshell/sqlshell/create_test_databases.py +96 -0
- sqlshell/sqlshell_demo.png +0 -0
- sqlshell/styles.py +257 -0
- sqlshell/suggester_integration.py +330 -0
- sqlshell/syntax_highlighter.py +124 -0
- sqlshell/table_list.py +996 -0
- sqlshell/ui/__init__.py +6 -0
- sqlshell/ui/bar_chart_delegate.py +49 -0
- sqlshell/ui/filter_header.py +469 -0
- sqlshell/utils/__init__.py +16 -0
- sqlshell/utils/profile_cn2.py +1661 -0
- sqlshell/utils/profile_column.py +2635 -0
- sqlshell/utils/profile_distributions.py +616 -0
- sqlshell/utils/profile_entropy.py +347 -0
- sqlshell/utils/profile_foreign_keys.py +779 -0
- sqlshell/utils/profile_keys.py +2834 -0
- sqlshell/utils/profile_ohe.py +934 -0
- sqlshell/utils/profile_ohe_advanced.py +754 -0
- sqlshell/utils/profile_ohe_comparison.py +237 -0
- sqlshell/utils/profile_prediction.py +926 -0
- sqlshell/utils/profile_similarity.py +876 -0
- sqlshell/utils/search_in_df.py +90 -0
- sqlshell/widgets.py +400 -0
- sqlshell-0.4.4.dist-info/METADATA +441 -0
- sqlshell-0.4.4.dist-info/RECORD +54 -0
- sqlshell-0.4.4.dist-info/WHEEL +5 -0
- sqlshell-0.4.4.dist-info/entry_points.txt +2 -0
- sqlshell-0.4.4.dist-info/top_level.txt +1 -0
sqlshell/utils/profile_cn2.py
@@ -0,0 +1,1661 @@

```python
"""
CN2 Rule Induction Algorithm for Classification

This module implements the classic CN2 rule-induction algorithm inspired by
Clark & Niblett (1989). It learns interpretable IF-THEN classification rules
using a separate-and-conquer (cover-and-remove) strategy with beam search.

The algorithm supports:
- Mixed categorical and numeric features
- Supervised discretization for numeric features
- Multiple quality measures (likelihood_ratio, entropy)
- Laplace-smoothed probability estimates
- Automatic discretization of numeric target variables using academic methods

Example usage:
    from sqlshell.utils.profile_cn2 import CN2Classifier, visualize_cn2_rules

    clf = CN2Classifier(beam_width=5, min_covered_examples=5)
    clf.fit(X, y)
    predictions = clf.predict(X)
    rules = clf.get_rules()
"""

import pandas as pd
import numpy as np
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any, Union
from collections import Counter
import warnings


# =============================================================================
# Numeric Target Discretization
# =============================================================================

class NumericTargetDiscretizer:
    """
    Discretizes continuous numeric target variables into meaningful categorical bins.

    This class implements several academically-grounded binning methods for converting
    continuous numeric variables into discrete categories suitable for classification.

    Supported Methods:
    -----------------
    1. **jenks** (Fisher-Jenks Natural Breaks):
       - Minimizes within-class variance while maximizing between-class variance
       - Based on Fisher (1958) and Jenks (1967)
       - Best for data with natural clusters or multimodal distributions
       - Reference: Jenks, G.F. (1967). "The Data Model Concept in Statistical Mapping"

    2. **quantile** (Equal-Frequency Binning):
       - Creates bins with equal number of observations
       - Robust to outliers and skewed distributions
       - Based on standard statistical quantile theory

    3. **freedman_diaconis** (Freedman-Diaconis Rule):
       - Bin width = 2 * IQR / n^(1/3)
       - Optimal for normally distributed data
       - Reference: Freedman, D. & Diaconis, P. (1981). "On the histogram as a
         density estimator: L2 theory"

    4. **sturges** (Sturges' Rule):
       - Number of bins = 1 + log2(n)
       - Classic rule, best for symmetric distributions
       - Reference: Sturges, H.A. (1926). "The Choice of a Class Interval"

    5. **equal_width** (Equal-Width Binning):
       - Creates bins of equal range
       - Simple but may create empty bins with skewed data

    6. **auto** (Automatic Selection):
       - Automatically selects the best method based on data characteristics
       - Uses skewness and distribution analysis to choose optimal method

    Parameters:
    ----------
    method : str, default='auto'
        The binning method to use. One of: 'auto', 'jenks', 'quantile',
        'freedman_diaconis', 'sturges', 'equal_width'

    n_bins : int, optional
        Number of bins to create. If None, determined automatically based on method.
        For 'freedman_diaconis' and 'sturges', this is computed from the data.

    min_bin_size : int, default=5
        Minimum number of samples per bin. Bins smaller than this are merged.

    max_bins : int, default=10
        Maximum number of bins to create (for automatic methods).

    Attributes:
    ----------
    bin_edges_ : np.ndarray
        The computed bin edges after fitting.

    bin_labels_ : List[str]
        Human-readable labels for each bin (e.g., "1-1000", "1001-5000").

    method_used_ : str
        The actual method used (relevant when method='auto').

    n_bins_ : int
        The number of bins created.

    Example:
    -------
    >>> discretizer = NumericTargetDiscretizer(method='jenks', n_bins=5)
    >>> df['income_category'] = discretizer.fit_transform(df['income'])
    >>> print(discretizer.bin_labels_)
    ['0-25000', '25001-50000', '50001-100000', '100001-250000', '250001+']
    """

    def __init__(
        self,
        method: str = 'auto',
        n_bins: Optional[int] = None,
        min_bin_size: int = 5,
        max_bins: int = 10
    ):
        valid_methods = {'auto', 'jenks', 'quantile', 'freedman_diaconis',
                         'sturges', 'equal_width'}
        if method not in valid_methods:
            raise ValueError(f"method must be one of {valid_methods}, got '{method}'")

        self.method = method
        self.n_bins = n_bins
        self.min_bin_size = min_bin_size
        self.max_bins = max_bins

        # Set after fitting
        self.bin_edges_: np.ndarray = None
        self.bin_labels_: List[str] = []
        self.method_used_: str = None
        self.n_bins_: int = 0
        self._is_fitted: bool = False

    def fit(self, data: Union[pd.Series, np.ndarray]) -> "NumericTargetDiscretizer":
        """
        Compute bin edges from the data.

        Parameters:
        ----------
        data : array-like
            The numeric data to fit.

        Returns:
        -------
        self : NumericTargetDiscretizer
        """
        # Convert to numpy array and remove NaN
        if isinstance(data, pd.Series):
            values = data.dropna().values
        else:
            values = np.asarray(data)
            values = values[~np.isnan(values)]

        if len(values) == 0:
            raise ValueError("Cannot fit on empty data")

        n_unique = len(np.unique(values))
        if n_unique <= 1:
            raise ValueError("Need at least 2 unique values to discretize")

        # Select method if auto
        if self.method == 'auto':
            self.method_used_ = self._select_method(values)
        else:
            self.method_used_ = self.method

        # Compute number of bins if not specified
        if self.n_bins is not None:
            n_bins = min(self.n_bins, n_unique, self.max_bins)
        else:
            n_bins = self._compute_n_bins(values, self.method_used_)
            n_bins = min(n_bins, n_unique, self.max_bins)

        # Ensure we have at least 2 bins
        n_bins = max(2, n_bins)

        # Compute bin edges using selected method
        if self.method_used_ == 'jenks':
            self.bin_edges_ = self._jenks_breaks(values, n_bins)
        elif self.method_used_ == 'quantile':
            self.bin_edges_ = self._quantile_breaks(values, n_bins)
        elif self.method_used_ == 'freedman_diaconis':
            self.bin_edges_ = self._freedman_diaconis_breaks(values, n_bins)
        elif self.method_used_ == 'sturges':
            self.bin_edges_ = self._equal_width_breaks(values, n_bins)
        elif self.method_used_ == 'equal_width':
            self.bin_edges_ = self._equal_width_breaks(values, n_bins)

        # Merge small bins
        self.bin_edges_ = self._merge_small_bins(values, self.bin_edges_)

        # Generate human-readable labels
        self.bin_labels_ = self._generate_labels(self.bin_edges_, values)
        self.n_bins_ = len(self.bin_labels_)
        self._is_fitted = True

        return self

    def transform(self, data: Union[pd.Series, np.ndarray]) -> np.ndarray:
        """
        Transform numeric values to categorical bin labels.

        Parameters:
        ----------
        data : array-like
            The numeric data to transform.

        Returns:
        -------
        labels : np.ndarray
            Array of string labels for each value.
        """
        if not self._is_fitted:
            raise RuntimeError("NumericTargetDiscretizer not fitted. Call fit() first.")

        if isinstance(data, pd.Series):
            values = data.values
        else:
            values = np.asarray(data)

        # Digitize values into bins
        # np.digitize returns indices 1 to n_bins, we want 0 to n_bins-1
        bin_indices = np.digitize(values, self.bin_edges_[1:-1])

        # Map indices to labels, handle NaN
        labels = np.empty(len(values), dtype=object)
        for i, (val, idx) in enumerate(zip(values, bin_indices)):
            if np.isnan(val) if isinstance(val, float) else pd.isna(val):
                labels[i] = "Missing"
            else:
                labels[i] = self.bin_labels_[min(idx, len(self.bin_labels_) - 1)]

        return labels

    def fit_transform(self, data: Union[pd.Series, np.ndarray]) -> np.ndarray:
        """Fit and transform in one step."""
        return self.fit(data).transform(data)
```
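The discretizer follows the scikit-learn fit/transform convention documented in the class docstring. A minimal usage sketch on synthetic data (the array and parameter choices below are illustrative, not from the package):

```python
import numpy as np
from sqlshell.utils.profile_cn2 import NumericTargetDiscretizer

rng = np.random.default_rng(0)
income = rng.lognormal(mean=10, sigma=0.6, size=500)  # right-skewed synthetic data

disc = NumericTargetDiscretizer(method='auto', max_bins=5)
labels = disc.fit_transform(income)

print(disc.method_used_)  # chosen by _select_method from skewness/multimodality
print(disc.bin_labels_)   # human-readable ranges, e.g. "≤ 21,842"
print(labels[:5])         # one string label per input value
```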
```python
    def _select_method(self, values: np.ndarray) -> str:
        """
        Automatically select the best binning method based on data characteristics.

        Selection criteria (based on statistical theory):
        - Jenks: For multimodal or clustered distributions
        - Quantile: For highly skewed distributions (|skewness| > 1)
        - Freedman-Diaconis: For moderate distributions
        """
        from scipy import stats

        n = len(values)

        # Calculate skewness
        skewness = stats.skew(values)

        # Calculate kurtosis (excess kurtosis)
        kurtosis = stats.kurtosis(values)

        # Check for multimodality using dip test approximation
        # Simple heuristic: check if histogram has multiple peaks
        try:
            hist, _ = np.histogram(values, bins='auto')
            # Count local maxima
            peaks = 0
            for i in range(1, len(hist) - 1):
                if hist[i] > hist[i-1] and hist[i] > hist[i+1]:
                    peaks += 1
            is_multimodal = peaks >= 2
        except Exception:
            is_multimodal = False

        # Selection logic
        if is_multimodal:
            # Jenks is best for multimodal distributions
            return 'jenks'
        elif abs(skewness) > 1.5:
            # Highly skewed: quantile binning handles outliers better
            return 'quantile'
        elif abs(skewness) > 0.5:
            # Moderately skewed: Freedman-Diaconis is robust
            return 'freedman_diaconis'
        else:
            # Symmetric distribution: quantile works well
            return 'quantile'

    def _compute_n_bins(self, values: np.ndarray, method: str) -> int:
        """Compute optimal number of bins based on method."""
        n = len(values)

        if method == 'sturges':
            # Sturges' rule: k = 1 + log2(n)
            return int(1 + np.log2(n))

        elif method == 'freedman_diaconis':
            # Freedman-Diaconis: bin width = 2 * IQR / n^(1/3)
            q75, q25 = np.percentile(values, [75, 25])
            iqr = q75 - q25
            if iqr == 0:
                return 5  # Default if no spread
            bin_width = 2 * iqr / (n ** (1/3))
            data_range = values.max() - values.min()
            return max(1, int(np.ceil(data_range / bin_width)))

        elif method in ('jenks', 'quantile', 'equal_width'):
            # Use square root rule as default: k = sqrt(n)
            k = int(np.sqrt(n))
            return min(max(3, k), self.max_bins)

        return 5  # Default
```
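The two closed-form rules above reduce to simple arithmetic. A worked check, with the IQR and range values assumed purely for illustration:

```python
import numpy as np

n = 1000

# Sturges' rule: k = 1 + log2(n) = 1 + log2(1000) ≈ 10.97, truncated to 10
assert int(1 + np.log2(n)) == 10

# Freedman-Diaconis: bin width h = 2*IQR / n^(1/3), bin count = ceil(range / h)
iqr, data_range = 1.35, 6.5          # illustrative values for a standard normal sample
h = 2 * iqr / (n ** (1 / 3))         # 2.7 / 10 = 0.27
print(int(np.ceil(data_range / h)))  # ceil(6.5 / 0.27) = 25, before the max_bins cap applies
```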
```python
    def _jenks_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
        """
        Compute Jenks natural breaks (Fisher-Jenks algorithm).

        This minimizes the sum of squared deviations within each class.
        Based on: Fisher, W.D. (1958). "On Grouping for Maximum Homogeneity"
        """
        # Sort values
        sorted_vals = np.sort(values)
        n = len(sorted_vals)

        # Limit to n_bins - 1 breaks for n_bins bins
        n_classes = min(n_bins, n)

        if n_classes <= 1:
            return np.array([sorted_vals[0], sorted_vals[-1]])

        # Initialize matrices for dynamic programming
        # mat1[i][j] = minimum sum of squared deviations for first i values in j classes
        # mat2[i][j] = index of last break for optimal solution

        # For efficiency, we use a simplified version
        # that finds breaks by minimizing within-class variance

        # Use k-means style approach for large datasets
        if n > 500:
            return self._jenks_kmeans_approx(sorted_vals, n_classes)

        # Full Jenks algorithm for smaller datasets
        mat1 = np.zeros((n + 1, n_classes + 1))
        mat2 = np.zeros((n + 1, n_classes + 1), dtype=int)

        # Initialize
        mat1[:, 0] = np.inf
        mat1[0, :] = np.inf
        mat1[0, 0] = 0

        # Compute variance for all intervals [i, j]
        variance_combinations = {}
        for i in range(n):
            sums = 0.0
            sum_squares = 0.0
            for j in range(i, n):
                sums += sorted_vals[j]
                sum_squares += sorted_vals[j] ** 2
                count = j - i + 1
                mean = sums / count
                variance = sum_squares / count - mean ** 2
                variance_combinations[(i, j)] = variance * count

        # Fill DP table
        for i in range(1, n + 1):
            for j in range(1, min(i, n_classes) + 1):
                if j == 1:
                    mat1[i, j] = variance_combinations[(0, i - 1)]
                    mat2[i, j] = 0
                else:
                    min_cost = np.inf
                    min_idx = 0
                    for k in range(j - 1, i):
                        cost = mat1[k, j - 1] + variance_combinations[(k, i - 1)]
                        if cost < min_cost:
                            min_cost = cost
                            min_idx = k
                    mat1[i, j] = min_cost
                    mat2[i, j] = min_idx

        # Backtrack to find breaks
        breaks = [n]
        k = n_classes
        while k > 1:
            breaks.append(mat2[breaks[-1], k])
            k -= 1
        breaks.reverse()

        # Convert to actual values
        bin_edges = [sorted_vals[0]]
        for b in breaks[1:]:
            if b > 0 and b < n:
                bin_edges.append(sorted_vals[b - 1])
        bin_edges.append(sorted_vals[-1])

        return np.unique(bin_edges)
```
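A quick way to see what the natural-breaks criterion buys: on well-separated clusters the interior break lands in the gap, where an equal-width grid would not necessarily place one. A sketch on synthetic bimodal data:

```python
import numpy as np
from sqlshell.utils.profile_cn2 import NumericTargetDiscretizer

rng = np.random.default_rng(1)
values = np.concatenate([rng.normal(10, 1, 100),   # cluster around 10
                         rng.normal(50, 1, 100)])  # cluster around 50

disc = NumericTargetDiscretizer(method='jenks', n_bins=2).fit(values)
print(disc.bin_edges_)   # the interior edge should fall near the end of the first cluster
print(disc.bin_labels_)
```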
```python
    def _jenks_kmeans_approx(self, sorted_vals: np.ndarray, n_bins: int) -> np.ndarray:
        """
        Approximates Jenks using 1D k-means clustering.
        More efficient for large datasets.
        """
        try:
            from scipy.cluster.vq import kmeans

            # Run k-means on 1D data
            centroids, _ = kmeans(sorted_vals.astype(float), n_bins)
            centroids = np.sort(centroids)

            # Find breakpoints between centroids
            breaks = [sorted_vals[0]]
            for i in range(len(centroids) - 1):
                midpoint = (centroids[i] + centroids[i + 1]) / 2
                breaks.append(midpoint)
            breaks.append(sorted_vals[-1])

            return np.array(breaks)
        except Exception:
            # Fallback to quantile
            return self._quantile_breaks(sorted_vals, n_bins)

    def _quantile_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
        """Compute quantile-based breaks (equal frequency binning)."""
        quantiles = np.linspace(0, 100, n_bins + 1)
        breaks = np.percentile(values, quantiles)
        return np.unique(breaks)

    def _equal_width_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
        """Compute equal-width breaks."""
        return np.linspace(values.min(), values.max(), n_bins + 1)

    def _freedman_diaconis_breaks(self, values: np.ndarray, n_bins: int) -> np.ndarray:
        """Compute breaks using Freedman-Diaconis rule for bin width."""
        q75, q25 = np.percentile(values, [75, 25])
        iqr = q75 - q25

        if iqr == 0:
            # No spread, fall back to equal width
            return self._equal_width_breaks(values, n_bins)

        bin_width = 2 * iqr / (len(values) ** (1/3))
        min_val = values.min()
        max_val = values.max()

        # Generate breaks
        breaks = [min_val]
        current = min_val + bin_width
        while current < max_val:
            breaks.append(current)
            current += bin_width
        breaks.append(max_val)

        # Limit to max_bins
        if len(breaks) - 1 > n_bins:
            return self._equal_width_breaks(values, n_bins)

        return np.array(breaks)

    def _merge_small_bins(self, values: np.ndarray, breaks: np.ndarray) -> np.ndarray:
        """Merge bins that have fewer than min_bin_size samples."""
        if len(breaks) <= 2:
            return breaks

        # Count samples per bin
        bin_indices = np.digitize(values, breaks[1:-1])
        counts = np.bincount(bin_indices, minlength=len(breaks) - 1)

        # Iteratively merge small bins
        new_breaks = list(breaks)
        merged = True
        while merged and len(new_breaks) > 2:
            merged = False
            bin_indices = np.digitize(values, new_breaks[1:-1])
            counts = np.bincount(bin_indices, minlength=len(new_breaks) - 1)

            for i in range(len(counts)):
                if counts[i] < self.min_bin_size and len(new_breaks) > 2:
                    # Merge with neighbor that has fewer samples
                    if i == 0:
                        # Merge with next bin (remove break at index 1)
                        if len(new_breaks) > 2:
                            new_breaks.pop(1)
                            merged = True
                            break
                    elif i == len(counts) - 1:
                        # Merge with previous bin (remove second to last break)
                        if len(new_breaks) > 2:
                            new_breaks.pop(-2)
                            merged = True
                            break
                    else:
                        # Merge with smaller neighbor
                        if counts[i - 1] <= counts[i + 1]:
                            new_breaks.pop(i)
                        else:
                            new_breaks.pop(i + 1)
                        merged = True
                        break

        return np.array(new_breaks)

    def _generate_labels(self, breaks: np.ndarray, values: np.ndarray) -> List[str]:
        """Generate human-readable bin labels."""
        labels = []

        # Determine formatting based on value range
        min_val, max_val = values.min(), values.max()
        val_range = max_val - min_val

        # Determine number of decimal places
        if val_range > 1000:
            fmt = lambda x: f"{int(round(x)):,}"
        elif val_range > 10:
            fmt = lambda x: f"{x:.1f}"
        elif val_range > 1:
            fmt = lambda x: f"{x:.2f}"
        else:
            fmt = lambda x: f"{x:.3f}"

        for i in range(len(breaks) - 1):
            low = breaks[i]
            high = breaks[i + 1]

            if i == 0:
                # First bin: "≤ X"
                labels.append(f"≤ {fmt(high)}")
            elif i == len(breaks) - 2:
                # Last bin: "> X"
                labels.append(f"> {fmt(low)}")
            else:
                # Middle bins: "X - Y"
                labels.append(f"{fmt(low)} - {fmt(high)}")

        # If we only have 2 breaks, create clearer labels
        if len(labels) == 1:
            mid = (breaks[0] + breaks[1]) / 2
            labels = [f"≤ {fmt(mid)}", f"> {fmt(mid)}"]
            self.bin_edges_ = np.array([breaks[0], mid, breaks[1]])

        return labels

    def get_bin_summary(self, data: Union[pd.Series, np.ndarray]) -> pd.DataFrame:
        """
        Get a summary of the discretization.

        Parameters:
        ----------
        data : array-like
            The original numeric data.

        Returns:
        -------
        summary : pd.DataFrame
            DataFrame with bin ranges, counts, and percentages.
        """
        if not self._is_fitted:
            raise RuntimeError("Not fitted. Call fit() first.")

        labels = self.transform(data)
        counts = Counter(labels)

        records = []
        for label in self.bin_labels_:
            count = counts.get(label, 0)
            pct = count / len(labels) * 100 if len(labels) > 0 else 0
            records.append({
                'bin': label,
                'count': count,
                'percentage': f"{pct:.1f}%"
            })

        return pd.DataFrame(records)
```
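`get_bin_summary` pairs each label with its count and share of the data. A sketch on synthetic values:

```python
import numpy as np
from sqlshell.utils.profile_cn2 import NumericTargetDiscretizer

values = np.arange(1, 101, dtype=float)  # 1..100, uniformly spread
disc = NumericTargetDiscretizer(method='quantile', n_bins=4).fit(values)
print(disc.get_bin_summary(values))
# Four equal-frequency bins of roughly 25 values (25.0%) each
```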
```python
def discretize_numeric_target(
    df: pd.DataFrame,
    target_column: str,
    method: str = 'auto',
    n_bins: Optional[int] = None,
    inplace: bool = False
) -> Tuple[pd.DataFrame, NumericTargetDiscretizer]:
    """
    Discretize a numeric target column for use with CN2 classification.

    This is a convenience function that creates a new categorical column
    from a numeric column using intelligent binning methods.

    Parameters:
    ----------
    df : pd.DataFrame
        The dataframe containing the target column.

    target_column : str
        Name of the numeric column to discretize.

    method : str, default='auto'
        Binning method. One of: 'auto', 'jenks', 'quantile',
        'freedman_diaconis', 'sturges', 'equal_width'

    n_bins : int, optional
        Number of bins. If None, determined automatically.

    inplace : bool, default=False
        If True, modify the dataframe in place. Otherwise, return a copy.

    Returns:
    -------
    df : pd.DataFrame
        DataFrame with the target column replaced by discretized values.

    discretizer : NumericTargetDiscretizer
        The fitted discretizer (useful for inspecting bin edges).

    Example:
    -------
    >>> df_discrete, disc = discretize_numeric_target(df, 'income', method='jenks')
    >>> print(disc.bin_labels_)
    ['≤ 25000', '25001 - 50000', '50001 - 100000', '> 100000']
    """
    if not inplace:
        df = df.copy()

    if target_column not in df.columns:
        raise ValueError(f"Column '{target_column}' not found in DataFrame")

    discretizer = NumericTargetDiscretizer(method=method, n_bins=n_bins)
    df[target_column] = discretizer.fit_transform(df[target_column])

    return df, discretizer
```
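The convenience wrapper replaces the column (in place or in a copy) and hands back the fitted discretizer. A sketch with a made-up frame:

```python
import numpy as np
import pandas as pd
from sqlshell.utils.profile_cn2 import discretize_numeric_target

rng = np.random.default_rng(2)
df = pd.DataFrame({'income': rng.lognormal(10, 0.5, 300),
                   'region': rng.choice(['N', 'S'], 300)})

df_discrete, disc = discretize_numeric_target(df, 'income', method='quantile', n_bins=3)
print(disc.bin_labels_)                      # three range labels
print(df_discrete['income'].value_counts())  # roughly equal counts per bin
print(df['income'].dtype)                    # original frame untouched (inplace=False)
```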
```python
# =============================================================================
# Data Structures
# =============================================================================

@dataclass
class Condition:
    """
    A single condition in a rule.

    Attributes:
        feature: Name of the feature being tested
        operator: Comparison operator ('==', '<=', '>')
        value: Value to compare against
        is_numeric: Whether this is a numeric condition
    """
    feature: str
    operator: str  # '==' for categorical, '<=' or '>' for numeric
    value: Any
    is_numeric: bool = False

    def __str__(self) -> str:
        if self.is_numeric:
            return f"{self.feature} {self.operator} {self.value:.4g}"
        return f"{self.feature} == '{self.value}'"

    def __hash__(self) -> int:
        return hash((self.feature, self.operator, str(self.value)))

    def __eq__(self, other) -> bool:
        if not isinstance(other, Condition):
            return False
        return (self.feature == other.feature and
                self.operator == other.operator and
                str(self.value) == str(other.value))

    def evaluate(self, row: pd.Series) -> bool:
        """Evaluate this condition on a single data row."""
        val = row[self.feature]
        if pd.isna(val):
            return False
        if self.operator == '==':
            return val == self.value
        elif self.operator == '<=':
            return val <= self.value
        elif self.operator == '>':
            return val > self.value
        return False
```
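`Condition` is the atomic predicate the learner composes into rules; `evaluate` applies it to a single pandas row. A small sketch (names invented for illustration):

```python
import pandas as pd
from sqlshell.utils.profile_cn2 import Condition

row = pd.Series({'age': 42, 'city': 'Oslo'})

c_num = Condition('age', '<=', 50, is_numeric=True)
c_cat = Condition('city', '==', 'Oslo')

print(c_num, '->', c_num.evaluate(row))  # age <= 50 -> True
print(c_cat, '->', c_cat.evaluate(row))  # city == 'Oslo' -> True
```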
```python
@dataclass
class Rule:
    """
    A classification rule consisting of conditions and a predicted class.

    Attributes:
        conditions: List of conditions forming the rule antecedent
        predicted_class: The class label predicted by this rule
        coverage: Number of training examples covered by this rule
        accuracy: Accuracy of this rule on covered examples
        class_distribution: Distribution of classes among covered examples
        quality_score: Quality score used during rule learning
    """
    conditions: List[Condition] = field(default_factory=list)
    predicted_class: Any = None
    coverage: int = 0
    accuracy: float = 0.0
    class_distribution: Dict[Any, int] = field(default_factory=dict)
    quality_score: float = 0.0

    def __str__(self) -> str:
        if not self.conditions:
            return f"IF True THEN class = {self.predicted_class}"
        cond_str = " AND ".join(str(c) for c in self.conditions)
        return (f"IF {cond_str} THEN class = {self.predicted_class} "
                f"[cov={self.coverage}, acc={self.accuracy:.2%}]")

    def __repr__(self) -> str:
        return self.__str__()

    def covers(self, row: pd.Series) -> bool:
        """Check if this rule covers (applies to) a data row."""
        if not self.conditions:
            return True
        return all(cond.evaluate(row) for cond in self.conditions)

    def covers_mask(self, X: pd.DataFrame) -> np.ndarray:
        """Return a boolean mask of rows covered by this rule."""
        if not self.conditions:
            return np.ones(len(X), dtype=bool)
        mask = np.ones(len(X), dtype=bool)
        for cond in self.conditions:
            col_vals = X[cond.feature]
            if cond.operator == '==':
                mask &= (col_vals == cond.value).values
            elif cond.operator == '<=':
                mask &= (col_vals <= cond.value).values
            elif cond.operator == '>':
                mask &= (col_vals > cond.value).values
            # Handle NaN values
            mask &= ~col_vals.isna().values
        return mask

    def to_dict(self) -> Dict:
        """Convert rule to a dictionary representation."""
        return {
            'conditions': [str(c) for c in self.conditions],
            'conditions_detailed': [
                {'feature': c.feature, 'operator': c.operator,
                 'value': c.value, 'is_numeric': c.is_numeric}
                for c in self.conditions
            ],
            'predicted_class': self.predicted_class,
            'coverage': self.coverage,
            'accuracy': self.accuracy,
            'class_distribution': dict(self.class_distribution),
            'quality_score': self.quality_score
        }
```
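A rule is the conjunction of its conditions; `covers_mask` vectorizes that over a frame and treats NaN as not covered. A sketch with hypothetical data:

```python
import pandas as pd
from sqlshell.utils.profile_cn2 import Condition, Rule

X = pd.DataFrame({'age': [30, 60, None], 'city': ['Oslo', 'Oslo', 'Bergen']})
rule = Rule(conditions=[Condition('age', '<=', 50, is_numeric=True),
                        Condition('city', '==', 'Oslo')],
            predicted_class='yes')

print(rule.covers_mask(X))     # [ True False False]; the NaN row is never covered
print(rule.covers(X.iloc[0]))  # True: the row-wise check agrees with the mask
```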
```python
# =============================================================================
# CN2 Classifier (Optimized)
# =============================================================================

class CN2Classifier:
    """
    CN2 Rule Induction Classifier (Optimized Implementation).

    Implements the classic CN2 algorithm for learning classification rules
    using a separate-and-conquer strategy with beam search.

    Parameters:
        max_rules: Maximum number of rules to learn (default: 10)
        beam_width: Width of the beam in beam search (default: 3)
        min_covered_examples: Minimum examples a rule must cover (default: 5)
        max_rule_length: Maximum number of conditions per rule (default: 3)
        quality_measure: Quality heuristic - 'likelihood_ratio' or 'entropy'
        random_state: Random seed for reproducibility
        discretization_bins: Number of bins for numeric feature discretization (default: 4)
        laplace_smoothing: Whether to use Laplace smoothing for probabilities

    Attributes:
        rules_: List of learned Rule objects (after fitting)
        default_class_: Default class for examples not covered by any rule
        classes_: Unique class labels
        feature_names_: Names of input features
        n_features_: Number of input features
    """

    def __init__(
        self,
        max_rules: Optional[int] = 10,
        beam_width: int = 3,
        min_covered_examples: int = 5,
        max_rule_length: Optional[int] = 3,
        quality_measure: str = "likelihood_ratio",
        random_state: Optional[int] = None,
        discretization_bins: int = 4,
        laplace_smoothing: bool = True
    ):
        self.max_rules = max_rules
        self.beam_width = beam_width
        self.min_covered_examples = min_covered_examples
        self.max_rule_length = max_rule_length if max_rule_length else 3
        self.quality_measure = quality_measure
        self.random_state = random_state
        self.discretization_bins = discretization_bins
        self.laplace_smoothing = laplace_smoothing

        # Validate quality measure
        if quality_measure not in ('likelihood_ratio', 'entropy'):
            raise ValueError(
                f"quality_measure must be 'likelihood_ratio' or 'entropy', "
                f"got '{quality_measure}'"
            )

        # Will be set during fit
        self.rules_: List[Rule] = []
        self.default_class_: Any = None
        self.classes_: np.ndarray = None
        self.feature_names_: List[str] = []
        self.n_features_: int = 0
        self._is_fitted: bool = False

        # Cached data for fast computation
        self._X_array: np.ndarray = None
        self._y_array: np.ndarray = None
        self._condition_masks: Dict = {}  # Pre-computed masks for all conditions
        self._feature_conditions: List = []  # List of (feature_idx, operator, value, is_numeric)

    def fit(self, X, y) -> "CN2Classifier":
        """
        Learn a rule list from training data.

        Parameters:
            X: Feature matrix (pandas DataFrame or 2D numpy array)
            y: Target labels (1D array-like)

        Returns:
            self: The fitted classifier
        """
        # Convert inputs to proper format
        X_df, y = self._validate_input(X, y)

        # Store class information
        self.classes_ = np.unique(y)
        self.feature_names_ = list(X_df.columns)
        self.n_features_ = len(self.feature_names_)

        # Store default class (majority class in training data)
        class_counts = Counter(y)
        self.default_class_ = class_counts.most_common(1)[0][0]

        # Convert class labels to integers for faster processing
        self._class_to_idx = {c: i for i, c in enumerate(self.classes_)}
        self._y_encoded = np.array([self._class_to_idx[c] for c in y])
        n_classes = len(self.classes_)

        # Pre-compute all condition masks (this is the key optimization)
        self._precompute_condition_masks(X_df)

        # Check if we have any valid conditions
        if len(self._feature_conditions) == 0:
            # No valid conditions - can't learn rules, just use default class
            self._is_fitted = True
            return self

        # Main CN2 loop: separate-and-conquer
        self.rules_ = []
        remaining_mask = np.ones(len(X_df), dtype=bool)
        n_samples = len(X_df)

        while remaining_mask.sum() >= self.min_covered_examples:
            # Check max rules limit
            if self.max_rules is not None and len(self.rules_) >= self.max_rules:
                break

            # Find the best rule for remaining examples
            best_rule = self._find_best_rule_fast(remaining_mask)

            if best_rule is None:
                break

            # Add rule to rule list
            self.rules_.append(best_rule)

            # Remove covered examples
            rule_mask = self._compute_rule_mask(best_rule._condition_indices)
            remaining_mask &= ~rule_mask

        # Store dataframe for prediction
        self._X_df = X_df
        self._is_fitted = True
        return self
```
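End to end, `fit` mirrors the module docstring's example: validate the inputs, pre-compute condition masks, then repeatedly find the best rule and remove what it covers. A runnable sketch on toy data (column names and sizes invented):

```python
import numpy as np
import pandas as pd
from sqlshell.utils.profile_cn2 import CN2Classifier

rng = np.random.default_rng(42)
X = pd.DataFrame({'tenure': rng.integers(0, 72, 400),
                  'plan': rng.choice(['basic', 'pro'], 400)})
y = np.where((X['tenure'] < 12) & (X['plan'] == 'basic'), 'churn', 'stay')

clf = CN2Classifier(beam_width=5, min_covered_examples=5)
clf.fit(X, y)

for rule in clf.get_rules():
    print(rule)  # e.g. IF tenure <= ... AND plan == 'basic' THEN class = churn [cov=..., acc=...]
print('train accuracy:', clf.score(X, y))
```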
```python
    def _precompute_condition_masks(self, X: pd.DataFrame):
        """Pre-compute boolean masks for all possible conditions."""
        self._condition_masks = {}
        self._feature_conditions = []
        n_samples = len(X)

        for feat_idx, col in enumerate(X.columns):
            col_data = X[col].values
            dtype = X[col].dtype
            is_numeric = np.issubdtype(dtype, np.number)

            # Handle NaN mask
            nan_mask = pd.isna(col_data)

            if is_numeric:
                # Get quantile thresholds (fewer bins = faster)
                valid_data = col_data[~nan_mask]
                if len(valid_data) == 0:
                    continue

                n_unique = len(np.unique(valid_data))
                if n_unique <= self.discretization_bins:
                    # Treat as categorical
                    for val in np.unique(valid_data):
                        cond_idx = len(self._feature_conditions)
                        mask = (col_data == val) & ~nan_mask
                        self._condition_masks[cond_idx] = mask
                        self._feature_conditions.append((feat_idx, '==', val, False))
                else:
                    # Use quantile thresholds
                    quantiles = np.linspace(0, 1, self.discretization_bins + 1)[1:-1]
                    thresholds = np.unique(np.quantile(valid_data, quantiles))

                    for thresh in thresholds:
                        # <= condition
                        cond_idx = len(self._feature_conditions)
                        mask = (col_data <= thresh) & ~nan_mask
                        self._condition_masks[cond_idx] = mask
                        self._feature_conditions.append((feat_idx, '<=', thresh, True))

                        # > condition
                        cond_idx = len(self._feature_conditions)
                        mask = (col_data > thresh) & ~nan_mask
                        self._condition_masks[cond_idx] = mask
                        self._feature_conditions.append((feat_idx, '>', thresh, True))
            else:
                # Categorical feature - limit number of values
                unique_vals = pd.Series(col_data).dropna().unique()
                # Limit to top N most frequent values for performance
                if len(unique_vals) > 10:
                    value_counts = pd.Series(col_data).value_counts()
                    unique_vals = value_counts.head(10).index.tolist()

                for val in unique_vals:
                    cond_idx = len(self._feature_conditions)
                    mask = (col_data == val) & ~nan_mask
                    self._condition_masks[cond_idx] = mask
                    self._feature_conditions.append((feat_idx, '==', val, False))
```
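This pre-computation is the optimization that makes the beam search cheap: one boolean mask per primitive condition, so any candidate conjunction is evaluated with bitwise ANDs of cached arrays. The idea in isolation (plain numpy, values invented):

```python
import numpy as np

age = np.array([25, 40, 61, 33, 70])
plan = np.array(['basic', 'pro', 'basic', 'basic', 'pro'])

# One mask per primitive condition, computed once over the whole column
mask_age_le_50 = age <= 50          # "age <= 50"
mask_plan_basic = plan == 'basic'   # "plan == 'basic'"

# Evaluating the conjunction during search is a single AND plus a popcount
rule_mask = mask_age_le_50 & mask_plan_basic
print(rule_mask, rule_mask.sum())   # [ True False False  True False] 2
```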
```python
    def _compute_rule_mask(self, condition_indices: List[int]) -> np.ndarray:
        """Compute mask for a rule given its condition indices."""
        if not condition_indices:
            return np.ones(len(self._y_encoded), dtype=bool)

        mask = self._condition_masks[condition_indices[0]].copy()
        for idx in condition_indices[1:]:
            mask &= self._condition_masks[idx]
        return mask

    def _find_best_rule_fast(self, remaining_mask: np.ndarray) -> Optional[Rule]:
        """
        Find the best rule using optimized beam search.
        """
        n_samples = remaining_mask.sum()
        if n_samples == 0:
            return None

        n_classes = len(self.classes_)
        if n_classes == 0:
            return None

        # Check if there are any valid conditions to try
        if len(self._feature_conditions) == 0:
            return None

        y_remaining = self._y_encoded[remaining_mask]

        # Track best rule found
        best_rule_info = None
        best_quality = float('-inf')

        # Beam: list of (condition_indices, quality)
        beam = [([], 0.0)]

        for depth in range(self.max_rule_length):
            new_beam = []
            seen_masks = set()  # Avoid duplicate rules

            for cond_indices, _ in beam:
                # Get current mask
                if cond_indices:
                    current_mask = self._compute_rule_mask(cond_indices) & remaining_mask
                else:
                    current_mask = remaining_mask.copy()

                current_coverage = current_mask.sum()
                if current_coverage < self.min_covered_examples:
                    continue

                # Get features already used
                used_features = {self._feature_conditions[i][0] for i in cond_indices}

                # Try adding each condition
                for cond_idx, (feat_idx, op, val, is_num) in enumerate(self._feature_conditions):
                    # Skip if feature already used
                    if feat_idx in used_features:
                        continue

                    # Compute combined mask
                    new_mask = current_mask & self._condition_masks[cond_idx]
                    coverage = new_mask.sum()

                    # Skip if insufficient coverage
                    if coverage < self.min_covered_examples:
                        continue

                    # Create hashable key to avoid duplicates
                    mask_key = tuple(sorted(cond_indices + [cond_idx]))
                    if mask_key in seen_masks:
                        continue
                    seen_masks.add(mask_key)

                    # Compute class distribution efficiently
                    y_covered = self._y_encoded[new_mask]
                    class_counts = np.bincount(y_covered, minlength=n_classes)

                    # Compute quality
                    quality = self._compute_quality_fast(class_counts, coverage, n_samples)

                    # Compute accuracy
                    majority_class_count = class_counts.max()
                    accuracy = majority_class_count / coverage

                    # Track best overall
                    if quality > best_quality and accuracy >= 0.5:
                        best_quality = quality
                        best_rule_info = (cond_indices + [cond_idx], class_counts, coverage, accuracy, quality)

                    new_beam.append((cond_indices + [cond_idx], quality))

            if not new_beam:
                break

            # Keep top beam_width candidates (filter out NaN/inf quality scores)
            new_beam = [(c, q) for c, q in new_beam if not (np.isnan(q) or np.isinf(q) and q < 0)]
            if not new_beam:
                break
            new_beam.sort(key=lambda x: x[1], reverse=True)
            beam = new_beam[:self.beam_width]

        # Convert best rule info to Rule object
        if best_rule_info is None:
            return None

        cond_indices, class_counts, coverage, accuracy, quality = best_rule_info

        # Build Rule object
        conditions = []
        for idx in cond_indices:
            feat_idx, op, val, is_num = self._feature_conditions[idx]
            feat_name = self.feature_names_[feat_idx]
            conditions.append(Condition(feat_name, op, val, is_num))

        # Get predicted class and distribution
        majority_idx = class_counts.argmax()
        predicted_class = self.classes_[majority_idx]
        class_distribution = {self.classes_[i]: int(class_counts[i])
                              for i in range(len(self.classes_)) if class_counts[i] > 0}

        rule = Rule(
            conditions=conditions,
            predicted_class=predicted_class,
            coverage=int(coverage),
            accuracy=accuracy,
            class_distribution=class_distribution,
            quality_score=quality
        )
        rule._condition_indices = cond_indices  # Store for mask computation

        return rule

    def _compute_quality_fast(self, class_counts: np.ndarray, coverage: int, total: int) -> float:
        """Compute quality score efficiently using numpy."""
        if coverage == 0 or total == 0:
            return float('-inf')

        try:
            if self.quality_measure == 'likelihood_ratio':
                # Likelihood ratio
                n_classes = len(self.classes_)
                if n_classes == 0:
                    return float('-inf')
                expected = coverage / n_classes
                if expected <= 0:
                    return float('-inf')
                lr = 0.0
                for count in class_counts:
                    if count > 0:
                        lr += count * np.log(count / expected + 1e-10)
                result = 2 * lr
            else:
                # Entropy-based
                probs = class_counts / coverage
                probs = probs[probs > 0]
                if len(probs) == 0:
                    return float('-inf')
                entropy = -np.sum(probs * np.log2(probs + 1e-10))
                max_entropy = np.log2(len(self.classes_))
                if max_entropy > 0:
                    entropy /= max_entropy
                result = (coverage / total) * (1.0 - entropy)

            # Guard against NaN
            if np.isnan(result) or np.isinf(result):
                return float('-inf')
            return float(result)
        except Exception:
            return float('-inf')
```
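For the likelihood-ratio measure this computes LRS = 2 · Σ_c f_c · ln(f_c / e_c), where e_c = coverage / n_classes is the count expected under a uniform class split; pure rules score high. A worked check with hypothetical counts:

```python
import numpy as np

class_counts = np.array([18, 2])                   # 20 covered examples, two classes
expected = class_counts.sum() / len(class_counts)  # uniform expectation: 10 per class

lrs = 2 * sum(c * np.log(c / expected) for c in class_counts if c > 0)
print(round(lrs, 2))  # 2*(18*ln(1.8) + 2*ln(0.2)) ≈ 14.72; a pure [20, 0] rule gives ≈ 27.73
```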
```python
    def predict(self, X) -> np.ndarray:
        """
        Predict class labels for samples in X.
        """
        self._check_is_fitted()
        X = self._ensure_dataframe(X)

        predictions = np.full(len(X), self.default_class_, dtype=object)
        predicted = np.zeros(len(X), dtype=bool)

        for rule in self.rules_:
            mask = rule.covers_mask(X)
            new_preds = mask & ~predicted
            predictions[new_preds] = rule.predicted_class
            predicted |= mask

        return predictions

    def predict_proba(self, X) -> np.ndarray:
        """
        Return class probability estimates for samples in X.
        """
        self._check_is_fitted()
        X = self._ensure_dataframe(X)

        n_samples = len(X)
        n_classes = len(self.classes_)
        proba = np.zeros((n_samples, n_classes))

        # Default probabilities
        default_proba = self._compute_proba_from_distribution(
            Counter({c: 1 for c in self.classes_})
        )

        assigned = np.zeros(n_samples, dtype=bool)

        for rule in self.rules_:
            mask = rule.covers_mask(X)
            new_assignments = mask & ~assigned

            if new_assignments.sum() > 0:
                rule_proba = self._compute_proba_from_distribution(rule.class_distribution)
                proba[new_assignments] = rule_proba
                assigned[new_assignments] = True

        proba[~assigned] = default_proba
        return proba

    def _compute_proba_from_distribution(self, class_dist: Dict[Any, int]) -> np.ndarray:
        """Compute class probabilities from distribution."""
        n_classes = len(self.classes_)
        proba = np.zeros(n_classes)
        total = sum(class_dist.values())

        if self.laplace_smoothing:
            smoothed_total = total + n_classes
            for i, cls in enumerate(self.classes_):
                count = class_dist.get(cls, 0)
                proba[i] = (count + 1) / smoothed_total
        else:
            for i, cls in enumerate(self.classes_):
                count = class_dist.get(cls, 0)
                proba[i] = count / total if total > 0 else 1 / n_classes

        return proba
```
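Laplace smoothing adds one pseudo-count per class, so a rule that covered only one class still assigns nonzero probability to the others. Worked with a hypothetical distribution:

```python
class_dist = {'yes': 8, 'no': 0}   # a rule covered 8 examples, all 'yes'
n_classes = 2
total = sum(class_dist.values())   # 8

p_yes = (class_dist['yes'] + 1) / (total + n_classes)  # 9 / 10 = 0.9
p_no = (class_dist['no'] + 1) / (total + n_classes)    # 1 / 10 = 0.1
print(p_yes, p_no)  # without smoothing this would be 1.0 and 0.0
```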
```python
    def get_rules(self) -> List[Rule]:
        """Return the learned rules."""
        self._check_is_fitted()
        return self.rules_.copy()

    def get_rules_as_df(self) -> pd.DataFrame:
        """Return rules as a pandas DataFrame."""
        self._check_is_fitted()
        records = []
        for i, rule in enumerate(self.rules_):
            records.append({
                'rule_id': i + 1,
                'conditions': ' AND '.join(str(c) for c in rule.conditions) or 'True',
                'predicted_class': rule.predicted_class,
                'coverage': rule.coverage,
                'accuracy': rule.accuracy,
                'quality_score': rule.quality_score,
                'n_conditions': len(rule.conditions)
            })
        return pd.DataFrame(records)

    def score(self, X, y) -> float:
        """Return accuracy on test data."""
        predictions = self.predict(X)
        return np.mean(predictions == np.asarray(y))

    def _validate_input(self, X, y) -> Tuple[pd.DataFrame, np.ndarray]:
        """Validate and convert input data."""
        X = self._ensure_dataframe(X)
        y = np.asarray(y, dtype=object)  # Keep as object to preserve types

        if len(X) != len(y):
            raise ValueError(f"X and y must have same length. Got {len(X)} and {len(y)}")

        if len(X) == 0:
            raise ValueError("Cannot fit on empty dataset")

        # Handle NaN values in target - remove rows with NaN target
        # Check for NaN: works for float NaN and None
        valid_mask = np.array([
            not (v is None or (isinstance(v, float) and np.isnan(v)) or
                 (isinstance(v, str) and v.lower() == 'nan') or
                 pd.isna(v))
            for v in y
        ])

        if valid_mask.sum() == 0:
            raise ValueError("All target values are NaN or missing")

        if valid_mask.sum() < len(y):
            # Filter out rows with NaN target
            X = X.loc[valid_mask].reset_index(drop=True)
            y = y[valid_mask]

        # Convert target to string to avoid mixed type issues during np.unique
        y = np.array([str(v) for v in y])

        return X, y

    def _ensure_dataframe(self, X) -> pd.DataFrame:
        """Convert X to DataFrame."""
        if isinstance(X, pd.DataFrame):
            return X.copy()
        elif isinstance(X, np.ndarray):
            return pd.DataFrame(X, columns=[f"feature_{i}" for i in range(X.shape[1])])
        else:
            raise TypeError(f"X must be DataFrame or ndarray, got {type(X)}")

    def _check_is_fitted(self):
        """Check if fitted."""
        if not self._is_fitted:
            raise RuntimeError("CN2Classifier not fitted. Call fit() first.")


# =============================================================================
# Visualization (PyQt6)
# =============================================================================

from PyQt6.QtWidgets import (
    QMainWindow, QVBoxLayout, QHBoxLayout, QWidget,
    QTableWidget, QTableWidgetItem, QLabel, QPushButton,
    QComboBox, QSplitter, QTabWidget, QScrollArea,
    QFrame, QHeaderView, QTextEdit, QSpinBox,
    QDoubleSpinBox, QGroupBox, QFormLayout, QMessageBox
)
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtGui import QFont, QColor


class CN2RulesVisualization(QMainWindow):
    """
    Window to visualize CN2 classification rules.
    """

    rulesApplied = pyqtSignal(object)

    def __init__(
        self,
        classifier: CN2Classifier,
        X: pd.DataFrame,
        y: np.ndarray,
        target_column: str = "target",
        parent=None
    ):
        super().__init__(parent)
        self.classifier = classifier
        self.X = X
        self.y = y
        self.target_column = target_column

        self.setWindowTitle(f"CN2 Rule Induction - {target_column}")
        self.setGeometry(100, 100, 1200, 800)

        self._setup_ui()
        self._populate_data()

    def _setup_ui(self):
        """Set up the user interface."""
        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        main_layout = QVBoxLayout(main_widget)

        # Title
        title_label = QLabel(f"CN2 Rule Induction Analysis: {self.target_column}")
        title_font = QFont()
        title_font.setBold(True)
        title_font.setPointSize(14)
        title_label.setFont(title_font)
        title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        main_layout.addWidget(title_label)

        # Description
        desc_label = QLabel(
            "CN2 learns interpretable IF-THEN classification rules. "
            "Rules are applied in order - the first matching rule determines the prediction."
        )
        desc_label.setWordWrap(True)
        main_layout.addWidget(desc_label)

        # Summary stats
        self._create_summary_section(main_layout)

        # Tabs
        tab_widget = QTabWidget()

        # Rules Table Tab
        rules_tab = QWidget()
        rules_layout = QVBoxLayout(rules_tab)
        self.rules_table = self._create_rules_table()
        rules_layout.addWidget(self.rules_table)
        tab_widget.addTab(rules_tab, "Rules")

        # Rule Details Tab
        details_tab = QWidget()
        details_layout = QVBoxLayout(details_tab)
        self.details_text = QTextEdit()
        self.details_text.setReadOnly(True)
        self.details_text.setFont(QFont("Courier New", 10))
        details_layout.addWidget(self.details_text)
        tab_widget.addTab(details_tab, "Rule Details")

        # Algorithm Info Tab
        info_tab = QWidget()
        info_layout = QVBoxLayout(info_tab)
        self.info_text = QTextEdit()
        self.info_text.setReadOnly(True)
        info_layout.addWidget(self.info_text)
        tab_widget.addTab(info_tab, "Algorithm Info")

        main_layout.addWidget(tab_widget, 1)

        # Apply button
        button_layout = QHBoxLayout()
        button_layout.addStretch()

        self.apply_button = QPushButton("Apply Rules")
        self.apply_button.setStyleSheet("""
            QPushButton {
                background-color: #3498DB;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 4px;
                font-weight: bold;
            }
            QPushButton:hover { background-color: #2980B9; }
        """)
        self.apply_button.clicked.connect(self._apply_rules)
        button_layout.addWidget(self.apply_button)

        main_layout.addLayout(button_layout)

    def _create_summary_section(self, parent_layout):
        """Create summary statistics section."""
        summary_frame = QFrame()
        summary_frame.setFrameShape(QFrame.Shape.StyledPanel)
        summary_layout = QHBoxLayout(summary_frame)

        self.n_rules_label = QLabel("Rules: -")
        self.n_rules_label.setStyleSheet("font-weight: bold;")
        summary_layout.addWidget(self.n_rules_label)

        summary_layout.addWidget(QLabel("|"))

        self.accuracy_label = QLabel("Training Accuracy: -")
        self.accuracy_label.setStyleSheet("font-weight: bold;")
        summary_layout.addWidget(self.accuracy_label)

        summary_layout.addWidget(QLabel("|"))

        self.coverage_label = QLabel("Total Coverage: -")
        self.coverage_label.setStyleSheet("font-weight: bold;")
        summary_layout.addWidget(self.coverage_label)

        summary_layout.addWidget(QLabel("|"))

        self.classes_label = QLabel("Classes: -")
        summary_layout.addWidget(self.classes_label)

        summary_layout.addStretch()
        parent_layout.addWidget(summary_frame)

    def _create_rules_table(self) -> QTableWidget:
        """Create rules table widget."""
        table = QTableWidget()
        table.setColumnCount(6)
        table.setHorizontalHeaderLabels([
            "Rule #", "Conditions", "Predicted Class",
            "Coverage", "Accuracy", "Quality"
        ])

        header = table.horizontalHeader()
        header.setSectionResizeMode(0, QHeaderView.ResizeMode.Fixed)
        header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
        header.setSectionResizeMode(2, QHeaderView.ResizeMode.Fixed)
        header.setSectionResizeMode(3, QHeaderView.ResizeMode.Fixed)
        header.setSectionResizeMode(4, QHeaderView.ResizeMode.Fixed)
        header.setSectionResizeMode(5, QHeaderView.ResizeMode.Fixed)

        table.setColumnWidth(0, 60)
        table.setColumnWidth(2, 120)
        table.setColumnWidth(3, 80)
        table.setColumnWidth(4, 80)
        table.setColumnWidth(5, 80)

        table.setAlternatingRowColors(True)
        table.setSelectionBehavior(QTableWidget.SelectionBehavior.SelectRows)
        table.itemSelectionChanged.connect(self._on_rule_selected)

        return table

    def _populate_data(self):
        """Populate visualization with classifier data."""
        rules = self.classifier.get_rules()

        self.n_rules_label.setText(f"Rules: {len(rules)}")

        accuracy = self.classifier.score(self.X, self.y)
```
|
|
1438
|
+
self.accuracy_label.setText(f"Training Accuracy: {accuracy:.1%}")
|
|
1439
|
+
|
|
1440
|
+
total_coverage = sum(r.coverage for r in rules)
|
|
1441
|
+
self.coverage_label.setText(f"Total Coverage: {total_coverage}")
|
|
1442
|
+
|
|
1443
|
+
classes_str = ", ".join(str(c) for c in self.classifier.classes_)
|
|
1444
|
+
self.classes_label.setText(f"Classes: {classes_str}")
|
|
1445
|
+
|
|
1446
|
+
# Populate rules table
|
|
1447
|
+
self.rules_table.setRowCount(len(rules))
|
|
1448
|
+
for i, rule in enumerate(rules):
|
|
1449
|
+
self.rules_table.setItem(i, 0, QTableWidgetItem(str(i + 1)))
|
|
1450
|
+
|
|
1451
|
+
conditions_str = " AND ".join(str(c) for c in rule.conditions) or "True"
|
|
1452
|
+
self.rules_table.setItem(i, 1, QTableWidgetItem(conditions_str))
|
|
1453
|
+
|
|
1454
|
+
self.rules_table.setItem(i, 2, QTableWidgetItem(str(rule.predicted_class)))
|
|
1455
|
+
self.rules_table.setItem(i, 3, QTableWidgetItem(str(rule.coverage)))
|
|
1456
|
+
self.rules_table.setItem(i, 4, QTableWidgetItem(f"{rule.accuracy:.1%}"))
|
|
1457
|
+
self.rules_table.setItem(i, 5, QTableWidgetItem(f"{rule.quality_score:.2f}"))
|
|
1458
|
+
|
|
1459
|
+
# Color by accuracy
|
|
1460
|
+
if rule.accuracy >= 0.9:
|
|
1461
|
+
color = QColor(200, 255, 200)
|
|
1462
|
+
elif rule.accuracy >= 0.7:
|
|
1463
|
+
color = QColor(255, 255, 200)
|
|
1464
|
+
else:
|
|
1465
|
+
color = QColor(255, 200, 200)
|
|
1466
|
+
|
|
1467
|
+
for j in range(6):
|
|
1468
|
+
item = self.rules_table.item(i, j)
|
|
1469
|
+
if item:
|
|
1470
|
+
item.setBackground(color)
|
|
1471
|
+
|
|
1472
|
+
self._populate_algorithm_info()
|
|
1473
|
+
|
|
1474
|
+
if rules:
|
|
1475
|
+
self.rules_table.selectRow(0)
|
|
1476
|
+
|
|
1477
|
+
def _populate_algorithm_info(self):
|
|
1478
|
+
"""Populate algorithm info tab."""
|
|
1479
|
+
clf = self.classifier
|
|
1480
|
+
info = f"""
|
|
1481
|
+
=== CN2 Rule Induction Algorithm ===
|
|
1482
|
+
|
|
1483
|
+
Parameters:
|
|
1484
|
+
• Beam Width: {clf.beam_width}
|
|
1485
|
+
• Min Covered Examples: {clf.min_covered_examples}
|
|
1486
|
+
• Max Rule Length: {clf.max_rule_length}
|
|
1487
|
+
• Max Rules: {clf.max_rules}
|
|
1488
|
+
• Quality Measure: {clf.quality_measure}
|
|
1489
|
+
• Discretization Bins: {clf.discretization_bins}
|
|
1490
|
+
|
|
1491
|
+
Dataset:
|
|
1492
|
+
• Total Samples: {len(self.X)}
|
|
1493
|
+
• Features: {clf.n_features_}
|
|
1494
|
+
• Classes: {len(clf.classes_)}
|
|
1495
|
+
|
|
1496
|
+
Results:
|
|
1497
|
+
• Rules Learned: {len(clf.rules_)}
|
|
1498
|
+
• Default Class: {clf.default_class_}
|
|
1499
|
+
• Training Accuracy: {clf.score(self.X, self.y):.2%}
|
|
1500
|
+
|
|
1501
|
+
How CN2 Works:
|
|
1502
|
+
1. Pre-compute all possible condition masks (for speed)
|
|
1503
|
+
2. Use beam search to find the best rule
|
|
1504
|
+
3. Add rule to rule list
|
|
1505
|
+
4. Remove covered examples
|
|
1506
|
+
5. Repeat until stopping criteria met
|
|
1507
|
+
"""
|
|
1508
|
+
self.info_text.setPlainText(info)
|
|
1509
|
+
|
|
1510
|
+
def _on_rule_selected(self):
|
|
1511
|
+
"""Handle rule selection."""
|
|
1512
|
+
selected = self.rules_table.selectedItems()
|
|
1513
|
+
if not selected:
|
|
1514
|
+
return
|
|
1515
|
+
|
|
1516
|
+
row = selected[0].row()
|
|
1517
|
+
rules = self.classifier.get_rules()
|
|
1518
|
+
if row < len(rules):
|
|
1519
|
+
self._show_rule_details(rules[row], row + 1)
|
|
1520
|
+
|
|
1521
|
+
def _show_rule_details(self, rule: Rule, rule_num: int):
|
|
1522
|
+
"""Show rule details."""
|
|
1523
|
+
details = f"""
|
|
1524
|
+
=== Rule {rule_num} Details ===
|
|
1525
|
+
|
|
1526
|
+
Full Rule:
|
|
1527
|
+
{str(rule)}
|
|
1528
|
+
|
|
1529
|
+
Conditions ({len(rule.conditions)}):
|
|
1530
|
+
"""
|
|
1531
|
+
if rule.conditions:
|
|
1532
|
+
for i, cond in enumerate(rule.conditions, 1):
|
|
1533
|
+
details += f" {i}. {cond}\n"
|
|
1534
|
+
else:
|
|
1535
|
+
details += " (No conditions - default rule)\n"
|
|
1536
|
+
|
|
1537
|
+
details += f"""
|
|
1538
|
+
Prediction:
|
|
1539
|
+
Predicted Class: {rule.predicted_class}
|
|
1540
|
+
|
|
1541
|
+
Metrics:
|
|
1542
|
+
Coverage: {rule.coverage} examples
|
|
1543
|
+
Accuracy: {rule.accuracy:.2%}
|
|
1544
|
+
Quality Score: {rule.quality_score:.4f}
|
|
1545
|
+
|
|
1546
|
+
Class Distribution:
|
|
1547
|
+
"""
|
|
1548
|
+
for cls, count in sorted(rule.class_distribution.items(), key=lambda x: -x[1]):
|
|
1549
|
+
pct = count / rule.coverage * 100 if rule.coverage > 0 else 0
|
|
1550
|
+
bar = "█" * int(pct / 5)
|
|
1551
|
+
details += f" {cls}: {count} ({pct:.1f}%) {bar}\n"
|
|
1552
|
+
|
|
1553
|
+
self.details_text.setPlainText(details)
|
|
1554
|
+
|
|
1555
|
+
def _apply_rules(self):
|
|
1556
|
+
"""Apply rules."""
|
|
1557
|
+
reply = QMessageBox.question(
|
|
1558
|
+
self,
|
|
1559
|
+
"Apply Rules",
|
|
1560
|
+
"Apply these rules to generate predictions?",
|
|
1561
|
+
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
|
|
1562
|
+
QMessageBox.StandardButton.No
|
|
1563
|
+
)
|
|
1564
|
+
|
|
1565
|
+
if reply == QMessageBox.StandardButton.Yes:
|
|
1566
|
+
self.rulesApplied.emit(self.classifier)
|
|
1567
|
+
QMessageBox.information(self, "Rules Applied", "CN2 classifier applied successfully.")
|
|
1568
|
+
|
|
1569
|
+
|
|
1570
|
+
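
# --- Editor's sketch (illustrative, not part of the packaged module) --------
# How a host application might consume the rulesApplied signal that the
# "Apply Rules" button emits above. The helper name, the "predicted" column,
# and the choice to annotate the training frame X are assumptions made for
# this example; only CN2RulesVisualization, rulesApplied, and
# CN2Classifier.predict come from this module. A running QApplication event
# loop is assumed.
def _example_connect_rules_applied(clf, X, y, target_column="target"):
    """Open the rules window and react when the user confirms the rules."""
    window = CN2RulesVisualization(clf, X, y, target_column)

    def on_rules_applied(classifier):
        # Annotate the original frame with predictions from the confirmed rules.
        annotated = X.copy()
        annotated["predicted"] = classifier.predict(X)
        print(annotated.head())

    window.rulesApplied.connect(on_rules_applied)
    window.show()
    return window
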
# =============================================================================
# Convenience Functions
# =============================================================================

def fit_cn2(df: pd.DataFrame, target_column: str, **kwargs) -> CN2Classifier:
    """
    Fit a CN2 classifier on a DataFrame.
    """
    if target_column not in df.columns:
        raise ValueError(f"Target column '{target_column}' not found")

    X = df.drop(columns=[target_column])
    y = df[target_column].values

    clf = CN2Classifier(**kwargs)
    clf.fit(X, y)

    return clf


def visualize_cn2_rules(df: pd.DataFrame, target_column: str, **kwargs) -> CN2RulesVisualization:
    """
    Fit CN2 and create visualization window.
    """
    if target_column not in df.columns:
        raise ValueError(f"Target column '{target_column}' not found")

    X = df.drop(columns=[target_column])
    y = df[target_column].values

    clf = CN2Classifier(**kwargs)
    clf.fit(X, y)

    vis = CN2RulesVisualization(clf, X, y, target_column)
    vis.show()

    return vis

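
# --- Editor's sketch (illustrative, not part of the packaged module) --------
# fit_cn2 returns the fitted classifier, so the learned rules can score rows
# that were not in the training frame. The toy column names and values below
# are hypothetical; only fit_cn2 and CN2Classifier.predict come from this
# module.
def _example_fit_then_predict():
    train = pd.DataFrame({
        "colour": ["red", "red", "blue", "blue", "blue", "red"] * 5,
        "size": [1.0, 1.2, 3.1, 2.9, 3.0, 0.9] * 5,
        "label": ["a", "a", "b", "b", "b", "a"] * 5,
    })
    clf = fit_cn2(train, target_column="label", min_covered_examples=3)
    unseen = pd.DataFrame({"colour": ["blue"], "size": [3.2]})
    return clf.predict(unseen)  # expected to favour class "b"
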
# =============================================================================
# Testing
# =============================================================================

def test_cn2():
    """Test the CN2 classifier."""
    print("\n===== Testing CN2 Rule Induction =====\n")

    np.random.seed(42)
    n_samples = 150

    data = {
        'sepal_length': np.concatenate([
            np.random.normal(5.0, 0.3, 50),
            np.random.normal(6.0, 0.4, 50),
            np.random.normal(6.5, 0.4, 50)
        ]),
        'petal_length': np.concatenate([
            np.random.normal(1.4, 0.2, 50),
            np.random.normal(4.2, 0.4, 50),
            np.random.normal(5.5, 0.5, 50)
        ]),
        'species': ['setosa'] * 50 + ['versicolor'] * 50 + ['virginica'] * 50
    }

    df = pd.DataFrame(data)

    print("Dataset shape:", df.shape)
    print("Class distribution:")
    print(df['species'].value_counts())
    print()

    import time
    start = time.time()
    clf = fit_cn2(df, target_column='species', beam_width=3, min_covered_examples=5)
    elapsed = time.time() - start

    rules = clf.get_rules()
    print(f"\nLearned {len(rules)} rules in {elapsed:.3f}s:\n")
    for i, rule in enumerate(rules, 1):
        print(f"Rule {i}: {rule}")

    X = df.drop(columns=['species'])
    y = df['species'].values

    accuracy = clf.score(X, y)
    print(f"\nTraining accuracy: {accuracy:.2%}")

    print("\n===== CN2 Test Complete =====\n")


if __name__ == "__main__":
    test_cn2()
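

# --- Editor's sketch (illustrative, not part of the packaged module) --------
# A minimal separate-and-conquer loop matching the five steps listed in the
# "Algorithm Info" tab, stripped of beam search, discretization, and quality
# measures: each rule here is a single equality test on a column, scored only
# by majority-class purity and coverage. It is a reading aid for the covering
# strategy, not a substitute for CN2Classifier; the function name and the
# assumption of mostly categorical columns are this example's own.
def _sketch_covering_loop(df, target, min_covered=5):
    """Return (column, value, predicted_class) rules via cover-and-remove."""
    remaining = df
    rules = []
    while len(remaining) >= min_covered:
        best = None  # (purity, coverage, column, value, prediction)
        for col in remaining.columns.drop(target):
            for val in remaining[col].unique():
                covered = remaining[remaining[col] == val]
                if len(covered) < min_covered:
                    continue
                counts = covered[target].value_counts()
                purity = counts.iloc[0] / len(covered)
                candidate = (purity, len(covered), col, val, counts.index[0])
                if best is None or candidate[:2] > best[:2]:
                    best = candidate
        if best is None:
            break  # no condition covers enough examples; stop
        _, _, col, val, prediction = best
        rules.append((col, val, prediction))
        remaining = remaining[remaining[col] != val]  # remove covered examples
    return rules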