sqlshell 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlshell might be problematic. Click here for more details.
- sqlshell/README.md +5 -1
- sqlshell/__init__.py +6 -2
- sqlshell/create_test_data.py +29 -0
- sqlshell/main.py +214 -3
- sqlshell/table_list.py +90 -1
- sqlshell/ui/filter_header.py +14 -0
- sqlshell/utils/profile_column.py +1099 -0
- sqlshell/utils/profile_distributions.py +613 -0
- sqlshell/utils/profile_foreign_keys.py +455 -0
- sqlshell-0.2.3.dist-info/METADATA +281 -0
- {sqlshell-0.2.1.dist-info → sqlshell-0.2.3.dist-info}/RECORD +14 -11
- {sqlshell-0.2.1.dist-info → sqlshell-0.2.3.dist-info}/WHEEL +1 -1
- sqlshell-0.2.1.dist-info/METADATA +0 -198
- {sqlshell-0.2.1.dist-info → sqlshell-0.2.3.dist-info}/entry_points.txt +0 -0
- {sqlshell-0.2.1.dist-info → sqlshell-0.2.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1099 @@
|
|
|
1
|
+
import shap
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import xgboost as xgb
|
|
4
|
+
import numpy as np
|
|
5
|
+
from sklearn.model_selection import train_test_split
|
|
6
|
+
from sklearn.preprocessing import LabelEncoder
|
|
7
|
+
import sys
|
|
8
|
+
import time
|
|
9
|
+
import hashlib
|
|
10
|
+
import os
|
|
11
|
+
import pickle
|
|
12
|
+
import gc
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from PyQt6.QtWidgets import (QApplication, QMainWindow, QTableWidget, QTableWidgetItem,
|
|
15
|
+
QVBoxLayout, QHBoxLayout, QLabel, QWidget, QComboBox,
|
|
16
|
+
QPushButton, QSplitter, QHeaderView, QFrame, QProgressBar,
|
|
17
|
+
QMessageBox, QDialog)
|
|
18
|
+
from PyQt6.QtCore import Qt, QAbstractTableModel, QModelIndex, QThread, pyqtSignal, QTimer
|
|
19
|
+
from PyQt6.QtGui import QPalette, QColor, QBrush, QPainter, QPen
|
|
20
|
+
|
|
21
|
+
# Import matplotlib at the top level
|
|
22
|
+
import matplotlib
|
|
23
|
+
matplotlib.use('QtAgg')
|
|
24
|
+
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
|
|
25
|
+
from matplotlib.figure import Figure
|
|
26
|
+
import seaborn as sns
|
|
27
|
+
|
|
28
|
+
# Disk cache location, inside the user's home directory.
# Creating it is best-effort: the cache read/write helpers in this module
# already swallow their own I/O errors, so an unwritable home directory must
# not make the whole module unimportable.
CACHE_DIR = os.path.join(Path.home(), '.sqlshell_cache')
try:
    os.makedirs(CACHE_DIR, exist_ok=True)
except OSError as e:
    print(f"Cache directory unavailable ({CACHE_DIR}): {e}")
|
|
31
|
+
|
|
32
|
+
def get_cache_key(df, column):
    """Return an MD5 hex digest identifying a (dataframe, target column) pair.

    The key mixes the frame's structure (column labels, shape, dtypes), the
    target ``column`` and a content fingerprint of the first rows.  It is used
    both as a dictionary key and as the stem of the cache file name.

    Args:
        df: pandas DataFrame being profiled.
        column: Label of the target column of the analysis.

    Returns:
        A 32-character hexadecimal string.
    """
    # str() every label: column labels are not guaranteed to be strings
    # (integer labels are common) and ','.join() raises TypeError on them.
    columns = ','.join(str(label) for label in df.columns)
    shapes = f"{df.shape[0]}x{df.shape[1]}"
    col_types = ','.join(str(dtype) for dtype in df.dtypes)

    # Fingerprint only the leading rows so the cost stays bounded.
    sample_rows = min(50, len(df))
    # Hash the cell *values*.  ndarray.tobytes() on an object-dtype array
    # serialises object pointers, which differ from one process to the next,
    # so a key built that way never matches a file written by an earlier run.
    values_sample = pd.util.hash_pandas_object(
        df.head(sample_rows), index=False
    ).values.tobytes()

    hash_input = f"{columns}|{shapes}|{col_types}|{column}|{len(df)}"
    m = hashlib.md5()  # fingerprint only, not used for security
    m.update(hash_input.encode())
    m.update(values_sample)
    return m.hexdigest()
|
|
49
|
+
|
|
50
|
+
def cache_results(df, column, results):
    """Pickle ``results`` into the disk cache entry for ``(df, column)``.

    Returns:
        True when the file was written; False when anything failed (the
        error is printed and never propagated).
    """
    try:
        target_path = os.path.join(CACHE_DIR, get_cache_key(df, column) + ".pkl")
        with open(target_path, 'wb') as handle:
            pickle.dump(results, handle)
    except Exception as e:
        print(f"Cache write error: {e}")
        return False
    return True
|
|
61
|
+
|
|
62
|
+
def get_cached_results(df, column):
    """Return the disk-cached result for ``(df, column)``, or None.

    None is returned when no cache file exists, when the file is 24 hours
    old or older, or when reading it fails (the error is printed).
    """
    max_age_seconds = 86400  # 24 hours
    try:
        cache_file = os.path.join(CACHE_DIR, get_cache_key(df, column) + ".pkl")
        if not os.path.exists(cache_file):
            return None

        modified_at = os.path.getmtime(cache_file)
        if time.time() - modified_at >= max_age_seconds:
            return None

        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # acceptable while the cache directory is writable by the user alone.
        with open(cache_file, 'rb') as handle:
            return pickle.load(handle)
    except Exception as e:
        print(f"Cache read error: {e}")
        return None
|
|
77
|
+
|
|
78
|
+
# Worker thread for background processing
|
|
79
|
+
class ExplainerThread(QThread):
    """Background worker ranking how strongly each column predicts a target.

    The worker trains a very small XGBoost regressor on (a sample of) the
    dataframe and reports the model's ``feature_importances_``.  The result
    column is named ``mean_abs_shap`` although no SHAP values are computed.

    Signals:
        progress(int, str): percentage (0-100) and a status message.
        result(object): DataFrame with columns ``feature`` and
            ``mean_abs_shap``, sorted by descending importance.
        error(str): message of an exception that aborted the run.
    """

    # Signals for progress updates and results
    progress = pyqtSignal(int, str)
    result = pyqtSignal(object)
    error = pyqtSignal(str)

    def __init__(self, df, column):
        """Store a private copy of ``df`` and the target ``column`` label."""
        super().__init__()
        # Private copy so later changes by the caller cannot affect the run
        self.df = df.copy()
        self.column = column
        self._is_canceled = False

    def cancel(self):
        """Mark the thread as canceled; ``run`` polls this flag between steps."""
        self._is_canceled = True

    def run(self):
        """Compute (or load from the disk cache) the importance table.

        Emits ``result`` on success, ``error`` on failure, and nothing once
        cancellation has been requested.
        """
        try:
            if self._is_canceled:
                return

            # Check disk cache first
            self.progress.emit(0, "Checking for cached results...")
            cached_results = get_cached_results(self.df, self.column)
            if cached_results is not None:
                if self._is_canceled:
                    return

                self.progress.emit(95, "Found cached results, loading...")
                time.sleep(0.5)  # Brief pause to show the user we found a cache

                if self._is_canceled:
                    return

                self.progress.emit(100, "Loaded from cache")
                self.result.emit(cached_results)
                return

            # Clean up memory before intensive computation
            gc.collect()

            if self._is_canceled:
                return

            # No cache found, proceed with computation
            self.progress.emit(5, "Computing new analysis...")

            # Work on a frame separate from self.df: self.df must stay
            # untouched because the cache key is derived from it at the end.
            # Sampling already returns a new frame, so the full copy is only
            # needed for small inputs.
            if len(self.df) > 500:
                sample_size = 500
                self.progress.emit(10, f"Sampling dataset (using {sample_size} rows from {len(self.df)} total)...")
                df = self.df.sample(n=sample_size, random_state=42)
                gc.collect()
            else:
                df = self.df.copy()

            if self._is_canceled:
                return

            # Drop columns with too many unique values (likely IDs) or excessive NaNs
            self.progress.emit(15, "Analyzing columns for preprocessing...")
            cols_to_drop = []
            for col in df.columns:
                if col == self.column:  # Don't drop target column
                    continue
                try:
                    if df[col].nunique() / len(df) > 0.95:
                        # More than 95% unique values (likely ID column)
                        cols_to_drop.append(col)
                    elif df[col].isna().mean() > 0.5:
                        # More than 50% missing values
                        cols_to_drop.append(col)
                except Exception:
                    # If we can't analyze the column (unhashable cells,
                    # empty frame, ...), drop it
                    cols_to_drop.append(col)

            if cols_to_drop:
                self.progress.emit(20, f"Removing {len(cols_to_drop)} low-information columns...")
                df = df.drop(columns=cols_to_drop)

            # Ensure target column is still in the dataframe
            if self.column not in df.columns:
                raise ValueError(f"Target column '{self.column}' not found in dataframe after preprocessing")

            # Separate features and target
            self.progress.emit(25, "Preparing features and target...")
            X = df.drop(columns=[self.column])
            y = df[self.column]

            # Label-encode low-cardinality text columns; drop the rest for speed
            self.progress.emit(30, "Encoding categorical features...")
            categorical_cols = X.select_dtypes(include='object').columns
            high_cardinality_threshold = 10

            for col in categorical_cols:
                if X[col].nunique() <= high_cardinality_threshold:
                    X[col] = X[col].fillna('_MISSING_').astype('category').cat.codes
                else:
                    X = X.drop(columns=[col])

            if y.dtype == 'object':
                # For categorical targets, use simple category codes
                y = y.fillna('_MISSING_').astype('category').cat.codes
            elif pd.api.types.is_numeric_dtype(y):
                y = y.fillna(y.mean())
            else:
                # mode() is empty for an all-NaN target; indexing it would raise
                modes = y.mode()
                if not modes.empty:
                    y = y.fillna(modes.iloc[0])

            # Train/test split (only the training part is used below)
            self.progress.emit(40, "Splitting data into train/test sets...")
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

            if self._is_canceled:
                return

            # Train a deliberately tiny tree-based model
            self.progress.emit(50, "Training XGBoost model...")
            model = xgb.XGBRegressor(
                n_estimators=5,           # Absolute minimum number of trees
                max_depth=2,              # Very shallow trees
                learning_rate=0.3,        # Higher learning rate to compensate for fewer trees
                tree_method='hist',       # Fast histogram method
                subsample=0.7,            # Use 70% of data per tree
                grow_policy='depthwise',  # Simple growth policy
                n_jobs=1,                 # Single thread to avoid overhead
                random_state=42,
                verbosity=0               # Suppress output
            )

            if X_train.shape[1] > 100:  # If there are many features
                self.progress.emit(55, "Large feature set detected, using memory-efficient training...")
                model.set_params(grow_policy='lossguide', max_leaves=64)

            try:
                model.fit(X_train, y_train)
            except Exception:
                # Retry once with an even smaller and simpler model
                self.progress.emit(55, "Adjusting model parameters due to computational constraints...")
                model = xgb.XGBRegressor(
                    n_estimators=5,
                    max_depth=2,
                    subsample=0.5,
                    colsample_bytree=0.5,
                    n_jobs=1
                )
                model.fit(X_train, y_train)

            if self._is_canceled:
                return

            # Use the model's own feature importance (no SHAP computation)
            self.progress.emit(80, "Calculating feature importance...")

            # Captured up front so the fallback below never depends on a
            # local that has already been deleted.
            feature_names = list(X.columns)
            used_fallback = False
            try:
                shap_importance = pd.DataFrame({
                    'feature': feature_names,
                    'mean_abs_shap': model.feature_importances_
                }).sort_values(by='mean_abs_shap', ascending=False)
            except Exception as e:
                print(f"Error in feature importance calculation: {e}")
                import traceback
                traceback.print_exc()

                # Last resort: equal importance for all features
                used_fallback = True
                importance_values = np.ones(len(feature_names)) / len(feature_names)
                shap_importance = pd.DataFrame({
                    'feature': feature_names,
                    'mean_abs_shap': importance_values
                }).sort_values(by='mean_abs_shap', ascending=False)

            if not used_fallback:
                # Only genuine results go to disk; persisting the placeholder
                # would serve made-up numbers from the cache for 24 hours.
                self.progress.emit(95, "Caching results for future use...")
                cache_results(self.df, self.column, shap_importance)

            # Clean up after computation
            del df, X, y, X_train, X_test, y_train, y_test, model
            gc.collect()

            if self._is_canceled:
                return

            if used_fallback:
                self.progress.emit(100, "Analysis complete (with default values)")
            else:
                self.progress.emit(100, "Analysis complete")
            self.result.emit(shap_importance)

        except Exception as e:
            if not self._is_canceled:  # Only emit error if not canceled
                import traceback
                print(f"Error in ExplainerThread: {str(e)}")
                print(traceback.format_exc())  # Print full stack trace to help debug
                self.error.emit(str(e))
|
|
313
|
+
|
|
314
|
+
def analyze_column(self):
    """Show the importance analysis for the currently selected column.

    Results are served from the per-window cache or the application-wide
    cache when available; otherwise an ExplainerThread is started (it also
    consults the disk cache).  A worker that is still running is stopped
    first.
    """
    if self.df is None or self.column_selector.currentText() == "":
        return

    # Cancel any existing worker thread
    worker = self.worker_thread
    if worker and worker.isRunning():
        worker.cancel()

        # Disconnect each signal on its own so one failure (signal already
        # disconnected) does not leave the remaining ones connected.
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # NOTE(review): QThread.terminate() can stop the thread at any point
        # in its code path; kept because run() only polls the cancel flag
        # between long steps.
        worker.terminate()
        worker.wait(1000)  # Wait up to 1 second
        self.worker_thread = None  # Clear reference

        # `finished` was disconnected above, so the old worker can no longer
        # restore the controls.  Do it here, otherwise a cache hit below
        # leaves the Analyze button disabled.
        self.analyze_button.setEnabled(True)
        self.cancel_button.hide()

    target_column = self.column_selector.currentText()

    # Check in-memory cache first (fastest)
    if target_column in self.result_cache:
        self.handle_results(self.result_cache[target_column])
        return

    # Check global application-wide cache second (still fast)
    global_key = get_cache_key(self.df, target_column)
    if global_key in ColumnProfilerApp.global_cache:
        self.result_cache[target_column] = ColumnProfilerApp.global_cache[global_key]
        self.handle_results(self.result_cache[target_column])
        return

    # Disk cache will be checked in the worker thread

    # Disable the analyze button while processing
    self.analyze_button.setEnabled(False)

    # Show progress indicators
    self.progress_bar.setValue(0)
    self.progress_bar.show()
    self.progress_label.setText("Starting analysis...")
    self.progress_label.show()
    self.cancel_button.show()

    # Create and start the worker thread
    self.worker_thread = ExplainerThread(self.df, target_column)
    self.worker_thread.progress.connect(self.update_progress)
    self.worker_thread.result.connect(self.cache_and_display_results)
    self.worker_thread.error.connect(self.handle_error)
    self.worker_thread.finished.connect(self.on_analysis_finished)
    self.worker_thread.start()
|
|
370
|
+
|
|
371
|
+
def update_progress(self, value, message):
    """Mirror a worker progress report in the progress bar and its label.

    Args:
        value: Value passed to the progress bar.
        message: Status text shown next to the bar.
    """
    bar, label = self.progress_bar, self.progress_label
    bar.setValue(value)
    label.setText(message)
|
|
374
|
+
|
|
375
|
+
def cache_and_display_results(self, importance_df):
    """Store a freshly computed importance table in both caches and show it.

    Args:
        importance_df: DataFrame emitted by the worker's ``result`` signal.
    """
    # Key the entry by the column the worker actually analysed.  The combo
    # box may have been changed while the analysis was running, and using its
    # current text would file the result under the wrong column.
    target_column = getattr(self.worker_thread, 'column', None)
    if target_column is None:
        target_column = self.column_selector.currentText()

    # Cache the results for this window
    self.result_cache[target_column] = importance_df

    # Also cache in the global application cache
    global_key = get_cache_key(self.df, target_column)
    ColumnProfilerApp.global_cache[global_key] = importance_df

    # Display the results
    self.handle_results(importance_df)
|
|
386
|
+
|
|
387
|
+
def on_analysis_finished(self):
    """Restore the controls once the worker has finished (completed or cancelled)."""
    # Cancel is only meaningful while a worker runs; Analyze becomes usable again.
    self.cancel_button.hide()
    self.analyze_button.setEnabled(True)
|
|
391
|
+
|
|
392
|
+
def handle_results(self, importance_df):
    """Display an importance table: fill the table in batches, then draw the chart.

    Args:
        importance_df: DataFrame with ``feature`` and ``mean_abs_shap`` columns.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Stop a render pass that is still in flight.  Replacing the timer
    # without stopping it would orphan the old one, which would then keep
    # firing (and redrawing the chart) every 10 ms.
    previous_timer = getattr(self, 'render_timer', None)
    if previous_timer is not None and previous_timer.isActive():
        previous_timer.stop()

    # handle_error() collapses the table to a single column; restore the
    # two-column layout so the importance values are visible again.
    self.importance_table.setColumnCount(2)
    self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance"])
    self.importance_table.setRowCount(len(importance_df))

    # Using a timer for incremental updates
    self.importance_df = importance_df  # Store for incremental rendering
    self.current_row = 0
    self.render_timer = QTimer()
    self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
    self.render_timer.start(10)  # Update every 10ms
|
|
407
|
+
|
|
408
|
+
def render_next_batch(self, batch_size):
    """Write the next ``batch_size`` rows of ``self.importance_df`` into the table.

    Called on every tick of ``self.render_timer``.  Once all rows have been
    written the timer is stopped and the chart is rendered.
    """
    total_rows = len(self.importance_df)
    if self.current_row >= total_rows:
        # Stop first so a slow chart render cannot overlap another tick
        self.render_timer.stop()
        self.render_chart()
        return

    end_row = min(self.current_row + batch_size, total_rows)
    for row in range(self.current_row, end_row):
        record = self.importance_df.iloc[row]
        # str(): column labels may be non-strings; an int would select the
        # QTableWidgetItem(int type) overload and produce an empty cell.
        feature_item = QTableWidgetItem(str(record['feature']))
        value_item = QTableWidgetItem(str(round(record['mean_abs_shap'], 4)))
        self.importance_table.setItem(row, 0, feature_item)
        self.importance_table.setItem(row, 1, value_item)

    self.current_row = end_row
    # No QApplication.processEvents() here: returning from the slot hands
    # control back to the event loop, and pumping events inside a timer slot
    # allows re-entrant ticks.
|
|
425
|
+
|
|
426
|
+
def render_chart(self):
    """Draw a horizontal bar chart of the top 20 rows of ``self.importance_df``."""
    axes = self.chart_view.axes
    axes.clear()

    # Limit to top 20 features for better visualization
    top_features = self.importance_df.head(20)

    bars = axes.barh(
        top_features['feature'],
        top_features['mean_abs_shap'],
        color='skyblue'
    )

    # Print each value just past the end of its bar
    for bar in bars:
        bar_length = bar.get_width()
        label_y = bar.get_y() + bar.get_height()/2
        axes.text(bar_length * 1.05, label_y, f'{bar_length:.2f}', va='center')

    axes.set_title(f'Feature Importance for Predicting {self.column_selector.currentText()}')
    axes.set_xlabel('Mean Absolute SHAP Value')
    self.chart_view.figure.tight_layout()
    self.chart_view.draw()
|
|
454
|
+
|
|
455
|
+
def handle_error(self, error_message):
    """Report a failed analysis: console, message box, table and chart.

    Args:
        error_message: Text emitted by the worker's ``error`` signal.
    """
    # Tear down the progress UI and hand the Analyze button back
    for widget in (self.progress_bar, self.progress_label, self.cancel_button):
        widget.hide()
    self.analyze_button.setEnabled(True)

    # Console output for debugging, then a modal dialog for the user
    print(f"Error in column profiler: {error_message}")
    QMessageBox.critical(self, "Error", f"An error occurred during analysis:\n\n{error_message}")

    # Replace the table contents with a single error cell
    table = self.importance_table
    table.setRowCount(1)
    table.setColumnCount(1)
    table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message}"))
    table.resizeColumnsToContents()

    # Replace the chart with the error text
    axes = self.chart_view.axes
    axes.clear()
    axes.text(
        0.5, 0.5, f"Error calculating importance:\n{error_message}",
        ha='center', va='center', fontsize=12, color='red',
        wrap=True
    )
    axes.set_axis_off()
    self.chart_view.draw()
|
|
484
|
+
|
|
485
|
+
def closeEvent(self, event):
    """Stop timers and workers, drop cached results and accept the close."""
    # Stop any running timer
    timer = self.render_timer
    if timer and timer.isActive():
        timer.stop()

    # Clean up any background threads
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Disconnect all signals to avoid callbacks during termination
        try:
            worker.progress.disconnect()
            worker.result.disconnect()
            worker.error.disconnect()
            worker.finished.disconnect()
        except Exception:
            pass  # Already disconnected

        worker.terminate()
        worker.wait(1000)  # Wait up to 1 second

    # NOTE(review): the original indentation is lost in this view; the
    # reference is assumed to be cleared whether or not a worker was running.
    self.worker_thread = None

    # Clean up memory
    self.result_cache.clear()

    event.accept()

    # Suggest garbage collection
    gc.collect()
|
|
517
|
+
|
|
518
|
+
def cancel_analysis(self):
    """Cancel the current analysis and reset the progress UI."""
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Signal the thread to cancel first
        worker.cancel()

        # Disconnect all signals to avoid callbacks during termination
        try:
            worker.progress.disconnect()
            worker.result.disconnect()
            worker.error.disconnect()
            worker.finished.disconnect()
        except Exception:
            pass  # Already disconnected

        worker.terminate()
        worker.wait(1000)  # Wait up to 1 second

    # NOTE(review): the original indentation is lost in this view; clearing
    # the reference and resetting the UI are assumed to happen whether or
    # not a worker was running.
    self.worker_thread = None

    # Update UI
    self.progress_bar.hide()
    self.progress_label.setText("Analysis cancelled")
    self.progress_label.show()
    self.cancel_button.hide()
    self.analyze_button.setEnabled(True)

    # Hide the progress label after 2 seconds
    QTimer.singleShot(2000, self.progress_label.hide)
|
|
549
|
+
|
|
550
|
+
# Custom matplotlib canvas for embedding in Qt
|
|
551
|
+
class MatplotlibCanvas(FigureCanvasQTAgg):
    """Qt widget wrapping a matplotlib figure with a single subplot.

    Attributes set by the constructor:
        figure: the Figure, ``width`` x ``height`` inches at ``dpi``.
        axes: the figure's only Axes (subplot 111).
    """

    def __init__(self, width=5, height=4, dpi=100):
        size_inches = (width, height)
        self.figure = Figure(figsize=size_inches, dpi=dpi)
        self.axes = self.figure.add_subplot(111)
        super().__init__(self.figure)
|
|
556
|
+
|
|
557
|
+
# Main application class
|
|
558
|
+
class ColumnProfilerApp(QMainWindow):
|
|
559
|
+
# Global application-wide cache to prevent redundant computations
|
|
560
|
+
global_cache = {}
|
|
561
|
+
|
|
562
|
+
def __init__(self, df):
    """Build the profiler window for ``df``.

    Layout, top to bottom: a control row (column selector, Analyze button,
    progress bar, progress label, Cancel button), an instruction label, and
    a vertical splitter holding the importance table above the chart.

    Args:
        df: DataFrame to profile; its columns populate the selector.
    """
    super().__init__()

    # Store reference to data
    self.df = df

    # Initialize cache for results
    self.result_cache = {}

    # Initialize thread variable
    self.worker_thread = None

    # Variables for incremental rendering
    self.importance_df = None
    self.current_row = 0
    self.render_timer = None

    # Set window properties
    self.setWindowTitle("Column Profiler")
    self.setMinimumSize(900, 600)

    # Create central widget and main layout
    central_widget = QWidget()
    main_layout = QVBoxLayout(central_widget)

    # Create top control panel
    control_panel = QWidget()
    control_layout = QHBoxLayout(control_panel)

    # Column selector
    self.column_selector = QComboBox()
    self.column_selector.addItems([col for col in df.columns])
    control_layout.addWidget(QLabel("Select Column to Analyze:"))
    control_layout.addWidget(self.column_selector)

    # Analyze button
    self.analyze_button = QPushButton("Analyze")
    self.analyze_button.clicked.connect(self.analyze_column)
    control_layout.addWidget(self.analyze_button)

    # Progress indicators (hidden until an analysis starts)
    self.progress_bar = QProgressBar()
    self.progress_bar.setRange(0, 100)
    self.progress_bar.hide()
    self.progress_label = QLabel()
    self.progress_label.hide()

    # Cancel button (hidden until an analysis starts)
    self.cancel_button = QPushButton("Cancel")
    self.cancel_button.clicked.connect(self.cancel_analysis)
    self.cancel_button.hide()

    control_layout.addWidget(self.progress_bar)
    control_layout.addWidget(self.progress_label)
    control_layout.addWidget(self.cancel_button)

    # Add control panel to main layout
    main_layout.addWidget(control_panel)

    # Add a splitter for results area
    results_splitter = QSplitter(Qt.Orientation.Vertical)

    # Create table for showing importance values
    self.importance_table = QTableWidget()
    self.importance_table.setColumnCount(2)
    self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance"])
    self.importance_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
    # NOTE(review): show_relationship_visualization is not defined in the
    # part of the file visible here - confirm it exists on this class.
    self.importance_table.cellDoubleClicked.connect(self.show_relationship_visualization)
    results_splitter.addWidget(self.importance_table)

    # Add instruction label for double-click functionality
    # (added to the main layout before the splitter, so it sits above it)
    instruction_label = QLabel("Double-click on any feature to view detailed relationship visualization with the target column")
    instruction_label.setStyleSheet("color: #666; font-style: italic;")
    instruction_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
    main_layout.addWidget(instruction_label)

    # Create matplotlib canvas for the chart
    self.chart_view = MatplotlibCanvas(width=8, height=5, dpi=100)
    results_splitter.addWidget(self.chart_view)

    # Set initial splitter sizes
    results_splitter.setSizes([300, 300])

    # Add the splitter to the main layout
    main_layout.addWidget(results_splitter)

    # Set the central widget
    self.setCentralWidget(central_widget)
|
|
650
|
+
|
|
651
|
+
def analyze_column(self):
    """Show the importance analysis for the currently selected column.

    Results are served from the per-window cache or the application-wide
    cache when available; otherwise an ExplainerThread is started (it also
    consults the disk cache).  A worker that is still running is stopped
    first.
    """
    if self.df is None or self.column_selector.currentText() == "":
        return

    # Cancel any existing worker thread
    worker = self.worker_thread
    if worker and worker.isRunning():
        worker.cancel()

        # Disconnect each signal on its own so one failure (signal already
        # disconnected) does not leave the remaining ones connected.
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # NOTE(review): QThread.terminate() can stop the thread at any point
        # in its code path; kept because run() only polls the cancel flag
        # between long steps.
        worker.terminate()
        worker.wait(1000)  # Wait up to 1 second
        self.worker_thread = None  # Clear reference

        # `finished` was disconnected above, so the old worker can no longer
        # restore the controls.  Do it here, otherwise a cache hit below
        # leaves the Analyze button disabled.
        self.analyze_button.setEnabled(True)
        self.cancel_button.hide()

    target_column = self.column_selector.currentText()

    # Check in-memory cache first (fastest)
    if target_column in self.result_cache:
        self.handle_results(self.result_cache[target_column])
        return

    # Check global application-wide cache second (still fast)
    global_key = get_cache_key(self.df, target_column)
    if global_key in ColumnProfilerApp.global_cache:
        self.result_cache[target_column] = ColumnProfilerApp.global_cache[global_key]
        self.handle_results(self.result_cache[target_column])
        return

    # Disk cache will be checked in the worker thread

    # Disable the analyze button while processing
    self.analyze_button.setEnabled(False)

    # Show progress indicators
    self.progress_bar.setValue(0)
    self.progress_bar.show()
    self.progress_label.setText("Starting analysis...")
    self.progress_label.show()
    self.cancel_button.show()

    # Create and start the worker thread
    self.worker_thread = ExplainerThread(self.df, target_column)
    self.worker_thread.progress.connect(self.update_progress)
    self.worker_thread.result.connect(self.cache_and_display_results)
    self.worker_thread.error.connect(self.handle_error)
    self.worker_thread.finished.connect(self.on_analysis_finished)
    self.worker_thread.start()
|
|
707
|
+
|
|
708
|
+
def update_progress(self, value, message):
    """Mirror a worker progress report in the progress bar and its label.

    Args:
        value: Value passed to the progress bar.
        message: Status text shown next to the bar.
    """
    bar, label = self.progress_bar, self.progress_label
    bar.setValue(value)
    label.setText(message)
|
|
711
|
+
|
|
712
|
+
def cache_and_display_results(self, importance_df):
    """Store a freshly computed importance table in both caches and show it.

    Args:
        importance_df: DataFrame emitted by the worker's ``result`` signal.
    """
    # Key the entry by the column the worker actually analysed.  The combo
    # box may have been changed while the analysis was running, and using its
    # current text would file the result under the wrong column.
    target_column = getattr(self.worker_thread, 'column', None)
    if target_column is None:
        target_column = self.column_selector.currentText()

    # Cache the results for this window
    self.result_cache[target_column] = importance_df

    # Also cache in the global application cache
    global_key = get_cache_key(self.df, target_column)
    ColumnProfilerApp.global_cache[global_key] = importance_df

    # Display the results
    self.handle_results(importance_df)
|
|
723
|
+
|
|
724
|
+
def on_analysis_finished(self):
    """Restore the controls once the worker has finished (completed or cancelled)."""
    # Cancel is only meaningful while a worker runs; Analyze becomes usable again.
    self.cancel_button.hide()
    self.analyze_button.setEnabled(True)
|
|
728
|
+
|
|
729
|
+
def handle_results(self, importance_df):
    """Hide the progress widgets and start filling the importance table.

    Rows are not written here. A 10 ms timer calls ``render_next_batch(10)``
    repeatedly; that method writes the rows and draws the chart once all of
    them are in the table.

    Args:
        importance_df: Result table. Its length becomes the table's row
            count and it is stored on ``self.importance_df`` for the batch
            renderer.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Stop a render pass that is still in flight so two timers never write
    # rows from different results into the table.
    previous_timer = getattr(self, "render_timer", None)
    if previous_timer is not None and previous_timer.isActive():
        previous_timer.stop()

    # render_next_batch() writes to columns 0 and 1; restore the second
    # column if an earlier error display collapsed the table to one.
    # Header labels are configured outside this method and are not touched.
    if self.importance_table.columnCount() < 2:
        self.importance_table.setColumnCount(2)
    self.importance_table.setRowCount(len(importance_df))

    # Using a timer for incremental updates
    self.importance_df = importance_df  # Store for incremental rendering
    self.current_row = 0
    self.render_timer = QTimer()
    self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
    self.render_timer.start(10)  # Update every 10ms
|
|
744
|
+
|
|
745
|
+
def render_next_batch(self, batch_size):
    """Write the next ``batch_size`` result rows into the importance table.

    Called on every tick of ``self.render_timer``. Once every row of
    ``self.importance_df`` has been written, the timer is stopped and the
    chart is drawn.

    Args:
        batch_size: Maximum number of rows to add during this call.
    """
    total_rows = len(self.importance_df)
    if self.current_row >= total_rows:
        # Stop first: if drawing the chart raises, the timer must not keep
        # firing (and raising again) every 10 ms.
        self.render_timer.stop()
        self.render_chart()
        return

    # Render a batch of rows
    end_row = min(self.current_row + batch_size, total_rows)
    for row in range(self.current_row, end_row):
        record = self.importance_df.iloc[row]
        # str(): QTableWidgetItem(int) selects the "item type" constructor,
        # so a non-string column label would otherwise give an empty cell.
        feature_item = QTableWidgetItem(str(record['feature']))
        value_item = QTableWidgetItem(str(round(record['mean_abs_shap'], 4)))
        self.importance_table.setItem(row, 0, feature_item)
        self.importance_table.setItem(row, 1, value_item)

    self.current_row = end_row
    # No QApplication.processEvents() here: control returns to the event
    # loop between timer ticks, so the UI repaints anyway, and a nested
    # event loop inside a timer slot could re-enter this method.
|
|
762
|
+
|
|
763
|
+
def render_chart(self):
    """Draw a horizontal bar chart for the first 20 rows of the result.

    Every bar is annotated with its value, and the chart is titled after the
    column currently chosen in the selector.
    """
    ax = self.chart_view.axes
    ax.clear()

    # Only the first 20 rows are plotted to keep the chart readable
    shown = self.importance_df.head(20)

    for patch in ax.barh(shown['feature'], shown['mean_abs_shap'], color='skyblue'):
        bar_length = patch.get_width()
        mid_height = patch.get_y() + patch.get_height()/2
        # Value label just beyond the tip of the bar
        ax.text(bar_length * 1.05, mid_height, f'{bar_length:.2f}', va='center')

    chosen_column = self.column_selector.currentText()
    ax.set_title(f'Feature Importance for Predicting {chosen_column}')
    ax.set_xlabel('Mean Absolute SHAP Value')

    self.chart_view.figure.tight_layout()
    self.chart_view.draw()
|
|
791
|
+
|
|
792
|
+
def handle_error(self, error_message):
    """Report a failed analysis in a dialog, in the table and in the chart.

    The table keeps its column layout (only its contents are replaced by a
    single error row), so a later successful analysis can still fill both
    columns.

    Args:
        error_message: Text describing the failure; it is inserted into the
            dialog, the table cell and the chart.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Re-enable analyze button
    self.analyze_button.setEnabled(True)

    # A render pass from an earlier result must not keep writing rows over
    # the error display (the message box below runs a nested event loop).
    render_timer = getattr(self, "render_timer", None)
    if render_timer is not None and render_timer.isActive():
        render_timer.stop()

    # Forget the previous result so a double-click on the error row cannot
    # open a visualization for stale data.
    self.importance_df = None

    # Print error to console for debugging
    print(f"Error in column profiler: {error_message}")

    # Show error message
    QMessageBox.critical(self, "Error", f"An error occurred during analysis:\n\n{error_message}")

    # Show a message in the UI as well. The column count is left alone;
    # shrinking it here would make later results lose their value column.
    self.importance_table.clearContents()
    self.importance_table.setRowCount(1)
    if self.importance_table.columnCount() < 1:
        self.importance_table.setColumnCount(1)
    self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message}"))
    self.importance_table.resizeColumnsToContents()

    # Update the chart to show error
    self.chart_view.axes.clear()
    self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message}",
                              ha='center', va='center', fontsize=12, color='red',
                              wrap=True)
    self.chart_view.axes.set_axis_off()
    self.chart_view.draw()
|
|
821
|
+
|
|
822
|
+
def closeEvent(self, event):
    """Stop the timer and the worker, drop cached results, accept the close.

    The worker is first asked to stop through ``cancel()`` and given up to
    one second; ``terminate()`` is used only if it is still running after
    that.

    Args:
        event: Close event delivered by Qt; it is always accepted.
    """
    # Stop any running timer
    if self.render_timer and self.render_timer.isActive():
        self.render_timer.stop()

    # Clean up any background threads
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Ask the worker to stop on its own, as cancel_analysis() does.
        worker.cancel()

        # Disconnect every signal on its own: with a single try block the
        # first failing disconnect() would leave the rest connected.
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # terminate() can stop the thread at any point in its code, so it
        # is only the fallback when the cooperative stop times out.
        if not worker.wait(1000):  # Wait up to 1 second
            worker.terminate()
            worker.wait(1000)  # Wait up to 1 second

    # Clear references to prevent thread issues
    self.worker_thread = None

    # Clean up memory
    self.result_cache.clear()

    # Accept the close event
    event.accept()

    # Suggest garbage collection
    gc.collect()
|
|
854
|
+
|
|
855
|
+
def cancel_analysis(self):
    """Cancel the running analysis, if any, and reset the progress UI.

    Does nothing when no worker thread is running. Otherwise the worker is
    asked to stop through ``cancel()`` and given up to one second;
    ``terminate()`` is used only if it is still running after that.
    """
    worker = self.worker_thread
    if not (worker and worker.isRunning()):
        return

    # Signal the thread to cancel first
    worker.cancel()

    # Disconnect every signal on its own: with a single try block the first
    # failing disconnect() would leave the rest connected.
    for signal in (worker.progress, worker.result, worker.error, worker.finished):
        try:
            signal.disconnect()
        except (TypeError, RuntimeError):
            pass  # Already disconnected

    # Give the cooperative cancel a chance; terminate() can stop the thread
    # at any point in its code, so it is only the fallback.
    if not worker.wait(1000):  # Wait up to 1 second
        worker.terminate()
        worker.wait(1000)  # Wait up to 1 second

    # Clear reference
    self.worker_thread = None

    # Update UI (the finished signal was disconnected above, so
    # on_analysis_finished() will not do this for us)
    self.progress_bar.hide()
    self.progress_label.setText("Analysis cancelled")
    self.progress_label.show()
    self.cancel_button.hide()
    self.analyze_button.setEnabled(True)

    # Hide the progress label after 2 seconds
    QTimer.singleShot(2000, self.progress_label.hide)
|
|
886
|
+
|
|
887
|
+
def show_relationship_visualization(self, row, column):
    """Open a modal dialog plotting one feature against the target column.

    The plot type follows the dtypes of the two columns in ``self.df``:
    scatter plot with a regression line (both numeric), box plot (numeric
    feature, non-numeric target), bar plot (non-numeric feature, numeric
    target) or a row-normalised crosstab heatmap (both non-numeric).

    A failure while looking up the columns or plotting is reported in a
    message box instead of propagating out of the Qt slot.

    Args:
        row: Row index into ``self.importance_df`` of the chosen feature.
        column: Not used by this method; presumably the clicked table
            column supplied by the double-click signal -- TODO confirm.
    """
    if self.importance_df is None or row >= len(self.importance_df):
        return

    # Get the feature name and target column
    feature = self.importance_df.iloc[row]['feature']
    target = self.column_selector.currentText()

    # Create a dialog to show the visualization
    dialog = QDialog(self)
    dialog.setWindowTitle(f"Relationship: {feature} vs {target}")
    dialog.resize(800, 600)

    # Create layout
    layout = QVBoxLayout(dialog)

    # Create canvas for the plot
    canvas = MatplotlibCanvas(width=8, height=6, dpi=100)
    layout.addWidget(canvas)

    try:
        # Determine the data types
        feature_is_numeric = pd.api.types.is_numeric_dtype(self.df[feature])
        target_is_numeric = pd.api.types.is_numeric_dtype(self.df[target])

        # Clear the figure
        canvas.axes.clear()

        # Create appropriate visualization based on data types
        if feature_is_numeric and target_is_numeric:
            # Scatter plot for numeric vs numeric
            sns.scatterplot(x=feature, y=target, data=self.df, ax=canvas.axes)
            # Add regression line
            sns.regplot(x=feature, y=target, data=self.df, ax=canvas.axes,
                        scatter=False, line_kws={"color": "red"})
            canvas.axes.set_title(f"Scatter Plot: {feature} vs {target}")

        elif feature_is_numeric and not target_is_numeric:
            # Box plot for numeric vs categorical
            sns.boxplot(x=target, y=feature, data=self.df, ax=canvas.axes)
            canvas.axes.set_title(f"Box Plot: {feature} by {target}")

        elif not feature_is_numeric and target_is_numeric:
            # Bar plot for categorical vs numeric
            sns.barplot(x=feature, y=target, data=self.df, ax=canvas.axes)
            canvas.axes.set_title(f"Bar Plot: Average {target} by {feature}")
            # Rotate x-axis labels if there are many categories
            if self.df[feature].nunique() > 5:
                canvas.axes.set_xticklabels(canvas.axes.get_xticklabels(), rotation=45, ha='right')

        else:
            # Heatmap for categorical vs categorical
            # Create a crosstab of the two categorical variables
            crosstab = pd.crosstab(self.df[feature], self.df[target], normalize='index')
            sns.heatmap(crosstab, annot=True, cmap="YlGnBu", ax=canvas.axes)
            canvas.axes.set_title(f"Heatmap: {feature} vs {target}")

        # Adjust layout and draw
        canvas.figure.tight_layout()
        canvas.draw()
    except Exception as exc:
        # Slot boundary: an exception escaping from here would take the
        # whole application down, so report it and drop the dialog.
        print(f"Error in relationship visualization: {exc}")
        dialog.deleteLater()
        QMessageBox.critical(self, "Visualization Error",
                             f"Could not visualize {feature} vs {target}:\n\n{exc}")
        return

    # Add a close button
    close_button = QPushButton("Close")
    close_button.clicked.connect(dialog.accept)
    layout.addWidget(close_button)

    # Show the dialog
    dialog.exec()

    # The dialog is parented to this window and would otherwise live until
    # the window itself is destroyed, one instance per double-click.
    dialog.deleteLater()
|
|
955
|
+
|
|
956
|
+
def visualize_profile(df: pd.DataFrame, column: "str | None" = None) -> "ColumnProfilerApp | None":
    """
    Launch a PyQt6 UI for visualizing column importance.

    When no ``QApplication`` exists yet, one is created, styled (Fusion style
    and a dark palette) and its event loop is run; the function then leaves
    through ``sys.exit``. When a ``QApplication`` already exists, its style
    and palette are left untouched and the new window is returned so the
    caller can keep a reference to it.

    Args:
        df: DataFrame containing the data. Frames with more than 500 rows
            are sampled down to 500 rows (``random_state=42``).
        column: Optional target column to analyze immediately

    Returns:
        The ``ColumnProfilerApp`` window when running inside an existing Qt
        application, or ``None`` if an error occurred.

    Raises:
        SystemExit: In standalone mode, when the event loop ends.
    """
    try:
        # Check if dataset is too small for meaningful analysis
        row_count = len(df)
        if row_count <= 5:
            print(f"WARNING: Dataset only has {row_count} rows. Feature importance analysis requires more data for meaningful results.")
            if QApplication.instance():
                QMessageBox.warning(None, "Insufficient Data",
                                    f"The dataset only contains {row_count} rows. Feature importance analysis requires more data for meaningful results.")

        # For large datasets, sample up to 500 rows for better statistical significance
        elif row_count > 500:
            print(f"Sampling 500 rows from dataset ({row_count:,} total rows)")
            df = df.sample(n=500, random_state=42)

        # Check if we're already in a Qt application
        existing_app = QApplication.instance()
        standalone_mode = existing_app is None

        # Create and style the app only in standalone mode: style and palette
        # are application-wide, so changing them here would restyle a host
        # application that merely embeds the profiler.
        if standalone_mode:
            app = QApplication(sys.argv)
            app.setStyle('Fusion')  # Modern look

            # Set modern dark theme
            palette = QPalette()
            palette.setColor(QPalette.ColorRole.Window, QColor(53, 53, 53))
            palette.setColor(QPalette.ColorRole.WindowText, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.Base, QColor(25, 25, 25))
            palette.setColor(QPalette.ColorRole.AlternateBase, QColor(53, 53, 53))
            palette.setColor(QPalette.ColorRole.ToolTipBase, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.ToolTipText, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.Text, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.Button, QColor(53, 53, 53))
            palette.setColor(QPalette.ColorRole.ButtonText, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.BrightText, Qt.GlobalColor.red)
            palette.setColor(QPalette.ColorRole.Link, QColor(42, 130, 218))
            palette.setColor(QPalette.ColorRole.Highlight, QColor(42, 130, 218))
            palette.setColor(QPalette.ColorRole.HighlightedText, Qt.GlobalColor.black)
            app.setPalette(palette)
        else:
            app = existing_app

        window = ColumnProfilerApp(df)
        window.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose)  # Ensure cleanup on close
        window.show()

        # Add tooltip to explain double-click functionality
        window.importance_table.setToolTip("Double-click on a feature to visualize its relationship with the target column")

        # If a specific column is provided, analyze it immediately
        if column is not None and column in df.columns:
            window.column_selector.setCurrentText(column)

            # Wrap the analysis in a try/except to prevent crashes
            def safe_analyze():
                try:
                    window.analyze_column()
                except Exception as e:
                    print(f"Error during column analysis: {e}")
                    import traceback
                    traceback.print_exc()
                    QMessageBox.critical(window, "Analysis Error",
                                         f"Error analyzing column:\n\n{str(e)}")

            QTimer.singleShot(100, safe_analyze)  # Use timer to avoid immediate thread issues

            # Set a watchdog timer to cancel analysis if it takes too long (30 seconds)
            def check_progress():
                if window.worker_thread and window.worker_thread.isRunning():
                    # If still running after 30 seconds, cancel the operation
                    QMessageBox.warning(window, "Analysis Timeout",
                                        "The analysis is taking longer than expected. It will be canceled to prevent hanging.")
                    try:
                        window.cancel_analysis()
                    except Exception as e:
                        print(f"Error canceling analysis: {e}")

            QTimer.singleShot(30000, check_progress)  # 30 seconds timeout

        # Only enter event loop in standalone mode
        if standalone_mode:
            sys.exit(app.exec())
        else:
            # Return the window for parent app to track
            return window
    except Exception as e:
        # Handle any exceptions to prevent crashes
        print(f"Error in visualize_profile: {e}")
        import traceback
        traceback.print_exc()

        # Show error to user
        if QApplication.instance():
            QMessageBox.critical(None, "Profile Error", f"Error creating column profile:\n\n{str(e)}")
        return None
|
|
1060
|
+
|
|
1061
|
+
def test_profile():
    """
    Build a synthetic employee dataset and open the profiler on ``Salary``.

    The data has relationships built in: experience is derived from age,
    salary from experience, and performance from education and experience.
    The random generator is seeded, so the dataset is the same on every run.
    """
    np.random.seed(42)
    sample_size = 1000

    # Age, and experience derived from it (never negative)
    ages = np.random.normal(35, 10, sample_size).astype(int)
    years_worked = np.maximum(0, ages - np.random.randint(18, 25, sample_size))

    # Salary grows with experience, plus noise
    salaries = 30000 + 2000 * years_worked + np.random.normal(0, 10000, sample_size)

    # Categorical columns
    department_col = np.random.choice(
        ['Engineering', 'Marketing', 'Sales', 'HR', 'Finance'], sample_size
    )
    education_col = np.random.choice(
        ['High School', 'Bachelor', 'Master', 'PhD'], sample_size, p=[0.2, 0.5, 0.2, 0.1]
    )

    # Performance: noise, boosted by education and slightly by experience
    score = np.random.normal(0, 1, sample_size)
    score += 0.5 * (education_col == 'Master') + 0.8 * (education_col == 'PhD')
    score += 0.01 * years_worked
    # Rescale to the 0-5 range
    score = (score - score.min()) / (score.max() - score.min()) * 5

    columns = {
        'Age': ages,
        'Experience': years_worked,
        'Department': department_col,
        'Education': education_col,
        'Performance': score,
        'Salary': salaries,
    }
    sample_df = pd.DataFrame(columns)

    print("Launching PyQt6 Column Profiler application...")
    visualize_profile(sample_df, 'Salary')  # Start with Salary analysis
|
|
1097
|
+
|
|
1098
|
+
# Running this module as a script launches the profiler on generated sample data.
if __name__ == "__main__":
    test_profile()
|