sqlshell 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlshell might be problematic. Click here for more details.

@@ -0,0 +1,1099 @@
1
+ import shap
2
+ import pandas as pd
3
+ import xgboost as xgb
4
+ import numpy as np
5
+ from sklearn.model_selection import train_test_split
6
+ from sklearn.preprocessing import LabelEncoder
7
+ import sys
8
+ import time
9
+ import hashlib
10
+ import os
11
+ import pickle
12
+ import gc
13
+ from pathlib import Path
14
+ from PyQt6.QtWidgets import (QApplication, QMainWindow, QTableWidget, QTableWidgetItem,
15
+ QVBoxLayout, QHBoxLayout, QLabel, QWidget, QComboBox,
16
+ QPushButton, QSplitter, QHeaderView, QFrame, QProgressBar,
17
+ QMessageBox, QDialog)
18
+ from PyQt6.QtCore import Qt, QAbstractTableModel, QModelIndex, QThread, pyqtSignal, QTimer
19
+ from PyQt6.QtGui import QPalette, QColor, QBrush, QPainter, QPen
20
+
21
+ # Import matplotlib at the top level
22
+ import matplotlib
23
+ matplotlib.use('QtAgg')
24
+ from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
25
+ from matplotlib.figure import Figure
26
+ import seaborn as sns
27
+
28
+ # Create a cache directory in user's home directory
29
+ CACHE_DIR = os.path.join(Path.home(), '.sqlshell_cache')
30
+ os.makedirs(CACHE_DIR, exist_ok=True)
31
+
32
def get_cache_key(df, column):
    """Generate a cache key based on dataframe content and column.

    The key is the MD5 hex digest of the frame's column labels, shape,
    dtypes, the target ``column`` and a content fingerprint of the first
    rows (at most 50), so it stays cheap to compute for large frames.

    Args:
        df: pandas DataFrame being analyzed.
        column: Label of the target column the cached result belongs to.

    Returns:
        A 32-character hexadecimal string.
    """
    # str() every label so frames with non-string column names (e.g. ints)
    # do not make str.join raise TypeError.
    columns = ','.join(str(name) for name in df.columns)
    shapes = f"{df.shape[0]}x{df.shape[1]}"
    col_types = ','.join(str(dtype) for dtype in df.dtypes)

    # Fingerprint a small sample instead of the whole dataframe.
    # hash_pandas_object hashes the cell *values*; calling .tobytes() on the
    # raw ndarray would hash object pointers for object-dtype frames, which
    # is not stable between calls or processes.
    sample = df.head(min(50, len(df)))
    try:
        row_hashes = pd.util.hash_pandas_object(sample, index=False)
    except TypeError:
        # Cells that cannot be hashed directly: fall back to their text form.
        row_hashes = pd.util.hash_pandas_object(sample.astype(str), index=False)
    values_sample = row_hashes.values.tobytes()

    hash_input = f"{columns}|{shapes}|{col_types}|{column}|{len(df)}"
    m = hashlib.md5()  # used as a fingerprint only, not for security
    m.update(hash_input.encode())
    m.update(values_sample)
    return m.hexdigest()
49
+
50
def cache_results(df, column, results):
    """Save results to disk cache.

    The pickle is written to a temporary sibling file and then renamed over
    the final path, so a reader never observes a half-written cache file.

    Args:
        df: DataFrame the results were computed from (used for the key).
        column: Target column the results belong to.
        results: Picklable object to store.

    Returns:
        True when the cache file was written, False on any failure
        (caching is best-effort and never raises).
    """
    tmp_file = None
    try:
        cache_key = get_cache_key(df, column)
        cache_file = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
        # pid + timestamp keeps concurrent writers from sharing a temp file.
        tmp_file = f"{cache_file}.{os.getpid()}.{time.time_ns()}.tmp"
        with open(tmp_file, 'wb') as f:
            pickle.dump(results, f)
        os.replace(tmp_file, cache_file)
        return True
    except Exception as e:
        print(f"Cache write error: {e}")
        if tmp_file is not None:
            try:
                os.remove(tmp_file)
            except OSError:
                pass  # temp file was never created or is already gone
        return False
61
+
62
def get_cached_results(df, column):
    """Return the disk-cached result for ``(df, column)`` if a fresh one exists.

    A cache file counts as fresh when it was modified less than 24 hours ago.

    Args:
        df: DataFrame the result would have been computed from.
        column: Target column of the analysis.

    Returns:
        The unpickled object, or None when there is no cache file, the file
        is too old, or reading it failed (the failure is printed).
    """
    max_age_seconds = 86400  # 24 hours in seconds
    try:
        key = get_cache_key(df, column)
        path = os.path.join(CACHE_DIR, f"{key}.pkl")
        if not os.path.exists(path):
            return None

        modified_at = os.path.getmtime(path)
        if not (time.time() - modified_at < max_age_seconds):
            return None

        # NOTE: pickle.load executes whatever the file encodes; this relies on
        # the cache directory being writable only by the current user.
        with open(path, 'rb') as handle:
            return pickle.load(handle)
    except Exception as exc:
        print(f"Cache read error: {exc}")
        return None
77
+
78
+ # Worker thread for background processing
79
class ExplainerThread(QThread):
    """Background worker that ranks columns by how well they predict a target.

    The thread works on a private copy of the dataframe. It first asks the
    disk cache for a result; otherwise it samples at most 500 rows, drops
    ID-like, mostly-missing and high-cardinality columns, label-encodes the
    remaining text columns, fits a very small XGBoost regressor and reports
    the model's ``feature_importances_``.

    Cancellation is cooperative: ``cancel()`` sets a flag that is checked
    between the processing stages.
    """

    # Signals for progress updates and results
    progress = pyqtSignal(int, str)  # (percent 0-100, status message)
    # DataFrame with columns 'feature' and 'mean_abs_shap', sorted descending.
    # Despite the column name, the values are XGBoost feature importances.
    result = pyqtSignal(object)
    error = pyqtSignal(str)  # str() of the exception that aborted the run

    def __init__(self, df, column):
        """Store a copy of ``df`` and the name of the target ``column``."""
        super().__init__()
        # Make a copy of the dataframe to avoid reference issues
        self.df = df.copy()
        self.column = column
        self._is_canceled = False

    def cancel(self):
        """Mark the thread as canceled"""
        self._is_canceled = True

    def run(self):
        """Thread entry point: run the analysis, report failures via ``error``."""
        try:
            self._analyze()
        except Exception as e:
            if not self._is_canceled:  # Only emit error if not canceled
                import traceback
                print(f"Error in ExplainerThread: {str(e)}")
                print(traceback.format_exc())  # Full stack trace to help debug
                self.error.emit(str(e))

    def _analyze(self):
        """Run all stages and emit ``result`` unless canceled along the way."""
        if self._is_canceled:
            return

        # Check disk cache first
        self.progress.emit(0, "Checking for cached results...")
        cached_results = get_cached_results(self.df, self.column)
        if cached_results is not None:
            if self._is_canceled:
                return
            self.progress.emit(95, "Found cached results, loading...")
            time.sleep(0.5)  # Brief pause to show the user we found a cache
            if self._is_canceled:
                return
            self.progress.emit(100, "Loaded from cache")
            self.result.emit(cached_results)
            return

        # Clean up memory before intensive computation
        gc.collect()
        if self._is_canceled:
            return

        # No cache found, proceed with computation
        self.progress.emit(5, "Computing new analysis...")

        if len(self.df) == 0:
            raise ValueError("Cannot analyze an empty dataframe")

        # Work on a private frame. sample() already returns a new object, so
        # the full copy is only needed when the frame is used unsampled.
        if len(self.df) > 500:
            sample_size = 500
            self.progress.emit(
                10,
                f"Sampling dataset (using {sample_size} rows from {len(self.df)} total)...",
            )
            df = self.df.sample(n=sample_size, random_state=42)
            gc.collect()
        else:
            df = self.df.copy()

        if self._is_canceled:
            return

        X, y = self._prepare_features(df)
        del df

        # Train/test split (only the training part is used for fitting)
        self.progress.emit(40, "Splitting data into train/test sets...")
        X_train, _X_test, y_train, _y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        if self._is_canceled:
            return

        model = self._fit_model(X_train, y_train)
        if self._is_canceled:
            return

        # Skip SHAP and use model feature importance directly
        self.progress.emit(80, "Calculating feature importance...")
        feature_names = list(X.columns)
        try:
            shap_importance = pd.DataFrame({
                'feature': feature_names,
                'mean_abs_shap': model.feature_importances_,
            }).sort_values(by='mean_abs_shap', ascending=False)
            is_real_result = True
        except Exception as e:
            print(f"Error in feature importance calculation: {e}")
            import traceback
            traceback.print_exc()

            # Last resort: equal importance for all features
            shap_importance = pd.DataFrame({
                'feature': feature_names,
                'mean_abs_shap': np.ones(len(feature_names)) / len(feature_names),
            }).sort_values(by='mean_abs_shap', ascending=False)
            is_real_result = False

        # Only genuine results go to the disk cache; caching the placeholder
        # values would serve them as if they were real for the next 24 hours.
        if is_real_result:
            self.progress.emit(95, "Caching results for future use...")
            cache_results(self.df, self.column, shap_importance)

        # Clean up after computation
        del X, y, X_train, _X_test, y_train, _y_test, model
        gc.collect()

        if self._is_canceled:
            return

        if is_real_result:
            self.progress.emit(100, "Analysis complete")
        else:
            self.progress.emit(100, "Analysis complete (with default values)")
        self.result.emit(shap_importance)

    def _prepare_features(self, df):
        """Drop unusable columns and encode the rest; return ``(X, y)``."""
        # Drop columns with too many unique values (likely IDs) or excessive NaNs
        self.progress.emit(15, "Analyzing columns for preprocessing...")
        cols_to_drop = []
        for col in df.columns:
            if col == self.column:  # Don't drop target column
                continue
            try:
                series = df[col]
                # More than 95% unique values (likely an ID column) or more
                # than 50% missing values.
                if series.nunique() / len(df) > 0.95 or series.isna().mean() > 0.5:
                    cols_to_drop.append(col)
            except Exception:
                # If we can't analyze the column, drop it
                cols_to_drop.append(col)

        if cols_to_drop:
            self.progress.emit(20, f"Removing {len(cols_to_drop)} low-information columns...")
            df = df.drop(columns=cols_to_drop)

        # Ensure target column is still in the dataframe
        if self.column not in df.columns:
            raise ValueError(
                f"Target column '{self.column}' not found in dataframe after preprocessing"
            )

        # Separate features and target
        self.progress.emit(25, "Preparing features and target...")
        X = df.drop(columns=[self.column])
        y = df[self.column]

        # Label-encode text columns with few distinct values and drop the
        # high-cardinality ones completely for speed.
        self.progress.emit(30, "Encoding categorical features...")
        high_cardinality_threshold = 10
        too_wide = []
        for col in X.select_dtypes(include='object').columns:
            if X[col].nunique() <= high_cardinality_threshold:
                X[col] = X[col].fillna('_MISSING_').astype('category').cat.codes
            else:
                too_wide.append(col)
        if too_wide:
            X = X.drop(columns=too_wide)

        if X.shape[1] == 0:
            raise ValueError("No usable feature columns remain after preprocessing")

        if y.dtype == 'object':
            # For categorical targets, use simple category codes
            y = y.fillna('_MISSING_').astype('category').cat.codes
        else:
            # Numeric targets: fill NaNs with the mean; anything else with the mode
            y = y.fillna(y.mean() if pd.api.types.is_numeric_dtype(y) else y.mode()[0])

        return X, y

    def _fit_model(self, X_train, y_train):
        """Fit a deliberately tiny XGBoost regressor and return it."""
        self.progress.emit(50, "Training XGBoost model...")
        model = xgb.XGBRegressor(
            n_estimators=5,           # Absolute minimum number of trees
            max_depth=2,              # Very shallow trees
            learning_rate=0.3,        # Higher learning rate to compensate for fewer trees
            tree_method='hist',       # Fast histogram method
            subsample=0.7,            # Use 70% of data per tree
            grow_policy='depthwise',  # Simple growth policy
            n_jobs=1,                 # Single thread to avoid overhead
            random_state=42,
            verbosity=0,              # Suppress output
        )

        # Set memory conservation parameter for many features
        if X_train.shape[1] > 100:
            self.progress.emit(55, "Large feature set detected, using memory-efficient training...")
            model.set_params(grow_policy='lossguide', max_leaves=64)

        try:
            model.fit(X_train, y_train)
        except Exception:
            # Retry once with an even smaller and simpler model
            self.progress.emit(55, "Adjusting model parameters due to computational constraints...")
            model = xgb.XGBRegressor(
                n_estimators=5,
                max_depth=2,
                subsample=0.5,
                colsample_bytree=0.5,
                n_jobs=1,
            )
            model.fit(X_train, y_train)
        return model
313
+
314
def analyze_column(self):
    """Show the feature-importance analysis for the selected column.

    Results are served from the per-window cache or the application-wide
    cache when available; otherwise an ``ExplainerThread`` is started (it
    checks the disk cache itself). Does nothing when there is no dataframe
    or no column is selected.

    NOTE(review): an identical definition of this method appears a second
    time later in the file; confirm which copy is the live one.
    """
    if self.df is None or self.column_selector.currentText() == "":
        return

    # Stop any analysis that is still running before starting a new one.
    previous = self.worker_thread
    if previous and previous.isRunning():
        previous.cancel()

        # Disconnect each signal on its own so one failure (e.g. a signal
        # that has no connections) does not leave the others attached.
        for signal in (previous.progress, previous.result,
                       previous.error, previous.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # Give the cooperative cancel a chance first; forcibly terminating a
        # thread is dangerous and is kept only as a last resort.
        if not previous.wait(1000):
            previous.terminate()
            previous.wait(1000)  # Wait up to 1 second
        self.worker_thread = None  # Clear reference

        # 'finished' was disconnected above, so restore the button here.
        self.analyze_button.setEnabled(True)

    target_column = self.column_selector.currentText()

    # Check in-memory cache first (fastest)
    if target_column in self.result_cache:
        self.handle_results(self.result_cache[target_column])
        return

    # Check global application-wide cache second (still fast)
    global_key = get_cache_key(self.df, target_column)
    if global_key in ColumnProfilerApp.global_cache:
        self.result_cache[target_column] = ColumnProfilerApp.global_cache[global_key]
        self.handle_results(self.result_cache[target_column])
        return

    # Disk cache will be checked in the worker thread

    # Disable the analyze button while processing
    self.analyze_button.setEnabled(False)

    # Show progress indicators
    self.progress_bar.setValue(0)
    self.progress_bar.show()
    self.progress_label.setText("Starting analysis...")
    self.progress_label.show()
    self.cancel_button.show()

    # Create and start the worker thread
    self.worker_thread = ExplainerThread(self.df, target_column)
    self.worker_thread.progress.connect(self.update_progress)
    self.worker_thread.result.connect(self.cache_and_display_results)
    self.worker_thread.error.connect(self.handle_error)
    self.worker_thread.finished.connect(self.on_analysis_finished)
    self.worker_thread.start()
370
+
371
def update_progress(self, value, message):
    """Mirror a worker progress report in the progress bar and its label.

    Args:
        value: Value passed to the progress bar.
        message: Status text shown next to the bar.
    """
    bar, label = self.progress_bar, self.progress_label
    bar.setValue(value)
    label.setText(message)
374
+
375
def cache_and_display_results(self, importance_df):
    """Store a finished analysis in both in-memory caches and display it.

    Args:
        importance_df: DataFrame emitted by the worker thread's ``result``
            signal.
    """
    # Key the caches by the column the worker actually analyzed. The combo
    # box may have been changed while the analysis was running, in which
    # case its current text would file the result under the wrong column.
    target_column = getattr(self.worker_thread, 'column', None)
    if target_column is None:
        target_column = self.column_selector.currentText()

    # Cache the results for this window
    self.result_cache[target_column] = importance_df

    # Also cache in the global application cache
    global_key = get_cache_key(self.df, target_column)
    ColumnProfilerApp.global_cache[global_key] = importance_df

    # Display the results
    self.handle_results(importance_df)
386
+
387
def on_analysis_finished(self):
    """Restore the controls once the worker's ``finished`` signal fires."""
    analyze_btn, cancel_btn = self.analyze_button, self.cancel_button
    analyze_btn.setEnabled(True)
    cancel_btn.hide()
391
+
392
def handle_results(self, importance_df):
    """Display an importance table, filling it incrementally via a timer.

    Args:
        importance_df: DataFrame with 'feature' and 'mean_abs_shap' columns.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Reuse a single timer. Creating a new one per call while the previous
    # one is still active would leave the old timer firing forever, because
    # only the newest timer is ever stopped.
    if self.render_timer is None:
        self.render_timer = QTimer(self)
        self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
    else:
        self.render_timer.stop()

    # handle_error() may have reduced the table to a single column; restore
    # the two-column layout so the importance values are visible again.
    self.importance_table.clearContents()
    self.importance_table.setColumnCount(2)
    self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance"])
    self.importance_table.setRowCount(len(importance_df))

    # State consumed by render_next_batch()
    self.importance_df = importance_df
    self.current_row = 0
    self.render_timer.start(10)  # Update every 10ms
407
+
408
def render_next_batch(self, batch_size):
    """Fill the next ``batch_size`` table rows; draw the chart when done.

    Called from the render timer. When every row of ``self.importance_df``
    has been written, the timer is stopped and the chart is rendered.

    Args:
        batch_size: Maximum number of rows to write in this call.
    """
    importance_df = self.importance_df
    if importance_df is None:
        # Nothing to render; make sure the timer does not keep firing.
        if self.render_timer is not None:
            self.render_timer.stop()
        return

    total_rows = len(importance_df)
    if self.current_row >= total_rows:
        # Stop first so an exception while drawing cannot leave the timer
        # re-triggering the failing draw every few milliseconds.
        self.render_timer.stop()
        self.render_chart()
        return

    # Render a batch of rows (always advance by at least one row)
    end_row = min(self.current_row + max(1, batch_size), total_rows)
    batch = importance_df.iloc[self.current_row:end_row]
    rows = range(self.current_row, end_row)
    for row, feature, mean_abs_shap in zip(rows, batch['feature'], batch['mean_abs_shap']):
        # str(): a non-string label (e.g. an int column name) would select
        # the QTableWidgetItem(type) overload and produce an empty cell.
        self.importance_table.setItem(row, 0, QTableWidgetItem(str(feature)))
        self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(mean_abs_shap, 4))))

    self.current_row = end_row
    QApplication.processEvents()  # Allow UI to update
425
+
426
def render_chart(self):
    """Draw the first 20 rows of ``self.importance_df`` as horizontal bars.

    Each bar is annotated with its value (two decimals) just past its end;
    the title names the column currently selected in the combo box.
    """
    ax = self.chart_view.axes
    ax.clear()

    # Limit to top 20 features for better visualization
    top_rows = self.importance_df.head(20)
    rects = ax.barh(top_rows['feature'], top_rows['mean_abs_shap'], color='skyblue')

    # Value label slightly to the right of each bar, vertically centered
    for rect in rects:
        value = rect.get_width()
        y_center = rect.get_y() + rect.get_height()/2
        ax.text(value * 1.05, y_center, f'{value:.2f}', va='center')

    ax.set_title(f'Feature Importance for Predicting {self.column_selector.currentText()}')
    ax.set_xlabel('Mean Absolute SHAP Value')
    self.chart_view.figure.tight_layout()
    self.chart_view.draw()
454
+
455
def handle_error(self, error_message):
    """Handle errors during analysis.

    Resets the progress widgets, shows the message in the table and the
    chart, and finally raises a modal error dialog.

    Args:
        error_message: Text describing the failure.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Re-enable analyze button
    self.analyze_button.setEnabled(True)

    # Print error to console for debugging
    print(f"Error in column profiler: {error_message}")

    # A render still in progress would overwrite the error row below.
    if self.render_timer is not None and self.render_timer.isActive():
        self.render_timer.stop()

    # Show the message in the table. The two-column layout is kept: reducing
    # the table to one column would hide the importance values of every
    # later successful analysis.
    self.importance_table.setRowCount(1)
    self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message}"))
    self.importance_table.setItem(0, 1, QTableWidgetItem(""))
    self.importance_table.resizeColumnsToContents()

    # Update the chart to show error
    self.chart_view.axes.clear()
    self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message}",
                              ha='center', va='center', fontsize=12, color='red',
                              wrap=True)
    self.chart_view.axes.set_axis_off()
    self.chart_view.draw()

    # The dialog is modal, so it comes last: the window behind it already
    # reflects the error while the dialog is open.
    QMessageBox.critical(self, "Error", f"An error occurred during analysis:\n\n{error_message}")
484
+
485
def closeEvent(self, event):
    """Clean up when the window is closed.

    Stops the render timer, shuts down a running worker thread, clears the
    per-window result cache and accepts the close event.

    Args:
        event: The close event delivered by Qt.
    """
    # Stop any running timer
    if self.render_timer and self.render_timer.isActive():
        self.render_timer.stop()

    # Clean up any background threads
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Ask the worker to stop on its own, as the other stop paths do.
        worker.cancel()

        # Disconnect all signals to avoid callbacks into a closing window
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # Forced termination is dangerous; use it only if the worker did not
        # stop by itself within a second.
        if not worker.wait(1000):
            worker.terminate()
            worker.wait(1000)  # Wait up to 1 second

        # Clear references to prevent thread issues
        self.worker_thread = None

    # Clean up memory
    self.result_cache.clear()

    # Accept the close event
    event.accept()

    # Suggest garbage collection
    gc.collect()
517
+
518
def cancel_analysis(self):
    """Cancel the current analysis and reset the progress widgets.

    Shows "Analysis cancelled" for two seconds afterwards.
    """
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Signal the thread to cancel first
        worker.cancel()

        # Disconnect all signals to avoid callbacks during shutdown
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # Let the cooperative cancel take effect; forcibly terminating a
        # thread is dangerous and is kept only as a last resort.
        if not worker.wait(1000):
            worker.terminate()
            worker.wait(1000)  # Wait up to 1 second

        # Clear reference
        self.worker_thread = None

    # Update UI
    self.progress_bar.hide()
    self.progress_label.setText("Analysis cancelled")
    self.progress_label.show()
    self.cancel_button.hide()
    self.analyze_button.setEnabled(True)

    # Hide the progress label after 2 seconds
    QTimer.singleShot(2000, self.progress_label.hide)
549
+
550
+ # Custom matplotlib canvas for embedding in Qt
551
class MatplotlibCanvas(FigureCanvasQTAgg):
    """Qt canvas holding one matplotlib figure with a single subplot.

    Attributes set by the constructor:
        figure: The ``Figure`` that is drawn on this canvas.
        axes: The only subplot of ``figure`` (position 111).
    """

    def __init__(self, width=5, height=4, dpi=100):
        """Create the figure.

        Args:
            width: Figure width passed to ``figsize``.
            height: Figure height passed to ``figsize``.
            dpi: Figure resolution.
        """
        fig = Figure(figsize=(width, height), dpi=dpi)
        self.figure = fig
        self.axes = fig.add_subplot(111)
        super().__init__(fig)
556
+
557
+ # Main application class
558
+ class ColumnProfilerApp(QMainWindow):
559
+ # Global application-wide cache to prevent redundant computations
560
+ global_cache = {}
561
+
562
def __init__(self, df):
    """Build the profiler window for ``df``.

    Layout, top to bottom: a control row (column selector, Analyze button,
    progress bar, progress label, Cancel button), a hint label, and a
    vertical splitter with the importance table above the chart.

    Args:
        df: DataFrame whose columns are offered for analysis. Only a
            reference is stored.
    """
    super().__init__()

    # --- state -----------------------------------------------------------
    self.df = df                # Store reference to data
    self.result_cache = {}      # Per-window results, keyed by target column
    self.worker_thread = None   # Currently running ExplainerThread, if any
    # Variables for incremental rendering
    self.importance_df = None
    self.current_row = 0
    self.render_timer = None

    # --- window ----------------------------------------------------------
    self.setWindowTitle("Column Profiler")
    self.setMinimumSize(900, 600)

    central_widget = QWidget()
    main_layout = QVBoxLayout(central_widget)

    # --- control row -----------------------------------------------------
    control_panel = QWidget()
    control_layout = QHBoxLayout(control_panel)

    selector = QComboBox()
    selector.addItems(list(df.columns))
    self.column_selector = selector
    control_layout.addWidget(QLabel("Select Column to Analyze:"))
    control_layout.addWidget(selector)

    analyze_button = QPushButton("Analyze")
    analyze_button.clicked.connect(self.analyze_column)
    self.analyze_button = analyze_button
    control_layout.addWidget(analyze_button)

    # Progress widgets and the Cancel button start hidden; they are shown
    # only while an analysis is running.
    progress_bar = QProgressBar()
    progress_bar.setRange(0, 100)
    progress_bar.hide()
    self.progress_bar = progress_bar

    progress_label = QLabel()
    progress_label.hide()
    self.progress_label = progress_label

    cancel_button = QPushButton("Cancel")
    cancel_button.clicked.connect(self.cancel_analysis)
    cancel_button.hide()
    self.cancel_button = cancel_button

    for widget in (progress_bar, progress_label, cancel_button):
        control_layout.addWidget(widget)

    main_layout.addWidget(control_panel)

    # --- results area ----------------------------------------------------
    results_splitter = QSplitter(Qt.Orientation.Vertical)

    # Table of importance values; double-clicking a cell opens the
    # relationship visualization.
    table = QTableWidget()
    table.setColumnCount(2)
    table.setHorizontalHeaderLabels(["Feature", "Importance"])
    table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
    table.cellDoubleClicked.connect(self.show_relationship_visualization)
    self.importance_table = table
    results_splitter.addWidget(table)

    # Hint for the double-click functionality (added above the splitter)
    instruction_label = QLabel("Double-click on any feature to view detailed relationship visualization with the target column")
    instruction_label.setStyleSheet("color: #666; font-style: italic;")
    instruction_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
    main_layout.addWidget(instruction_label)

    # Matplotlib canvas for the chart
    self.chart_view = MatplotlibCanvas(width=8, height=5, dpi=100)
    results_splitter.addWidget(self.chart_view)

    # Initial split between table and chart
    results_splitter.setSizes([300, 300])

    main_layout.addWidget(results_splitter)
    self.setCentralWidget(central_widget)
650
+
651
def analyze_column(self):
    """Show the feature-importance analysis for the selected column.

    Results are served from the per-window cache or the application-wide
    cache when available; otherwise an ``ExplainerThread`` is started (it
    checks the disk cache itself). Does nothing when there is no dataframe
    or no column is selected.
    """
    if self.df is None or self.column_selector.currentText() == "":
        return

    # Stop any analysis that is still running before starting a new one.
    previous = self.worker_thread
    if previous and previous.isRunning():
        previous.cancel()

        # Disconnect each signal on its own so one failure (e.g. a signal
        # that has no connections) does not leave the others attached.
        for signal in (previous.progress, previous.result,
                       previous.error, previous.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # Give the cooperative cancel a chance first; forcibly terminating a
        # thread is dangerous and is kept only as a last resort.
        if not previous.wait(1000):
            previous.terminate()
            previous.wait(1000)  # Wait up to 1 second
        self.worker_thread = None  # Clear reference

        # 'finished' was disconnected above, so restore the button here.
        self.analyze_button.setEnabled(True)

    target_column = self.column_selector.currentText()

    # Check in-memory cache first (fastest)
    if target_column in self.result_cache:
        self.handle_results(self.result_cache[target_column])
        return

    # Check global application-wide cache second (still fast)
    global_key = get_cache_key(self.df, target_column)
    if global_key in ColumnProfilerApp.global_cache:
        self.result_cache[target_column] = ColumnProfilerApp.global_cache[global_key]
        self.handle_results(self.result_cache[target_column])
        return

    # Disk cache will be checked in the worker thread

    # Disable the analyze button while processing
    self.analyze_button.setEnabled(False)

    # Show progress indicators
    self.progress_bar.setValue(0)
    self.progress_bar.show()
    self.progress_label.setText("Starting analysis...")
    self.progress_label.show()
    self.cancel_button.show()

    # Create and start the worker thread
    self.worker_thread = ExplainerThread(self.df, target_column)
    self.worker_thread.progress.connect(self.update_progress)
    self.worker_thread.result.connect(self.cache_and_display_results)
    self.worker_thread.error.connect(self.handle_error)
    self.worker_thread.finished.connect(self.on_analysis_finished)
    self.worker_thread.start()
707
+
708
def update_progress(self, value, message):
    """Mirror a worker progress report in the progress bar and its label.

    Args:
        value: Value passed to the progress bar.
        message: Status text shown next to the bar.
    """
    bar, label = self.progress_bar, self.progress_label
    bar.setValue(value)
    label.setText(message)
711
+
712
def cache_and_display_results(self, importance_df):
    """Store a finished analysis in both in-memory caches and display it.

    Args:
        importance_df: DataFrame emitted by the worker thread's ``result``
            signal.
    """
    # Key the caches by the column the worker actually analyzed. The combo
    # box may have been changed while the analysis was running, in which
    # case its current text would file the result under the wrong column.
    target_column = getattr(self.worker_thread, 'column', None)
    if target_column is None:
        target_column = self.column_selector.currentText()

    # Cache the results for this window
    self.result_cache[target_column] = importance_df

    # Also cache in the global application cache
    global_key = get_cache_key(self.df, target_column)
    ColumnProfilerApp.global_cache[global_key] = importance_df

    # Display the results
    self.handle_results(importance_df)
723
+
724
def on_analysis_finished(self):
    """Restore the controls once the worker's ``finished`` signal fires."""
    analyze_btn, cancel_btn = self.analyze_button, self.cancel_button
    analyze_btn.setEnabled(True)
    cancel_btn.hide()
728
+
729
def handle_results(self, importance_df):
    """Display an importance table, filling it incrementally via a timer.

    Args:
        importance_df: DataFrame with 'feature' and 'mean_abs_shap' columns.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Reuse a single timer. Creating a new one per call while the previous
    # one is still active would leave the old timer firing forever, because
    # only the newest timer is ever stopped.
    if self.render_timer is None:
        self.render_timer = QTimer(self)
        self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
    else:
        self.render_timer.stop()

    # handle_error() may have reduced the table to a single column; restore
    # the two-column layout so the importance values are visible again.
    self.importance_table.clearContents()
    self.importance_table.setColumnCount(2)
    self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance"])
    self.importance_table.setRowCount(len(importance_df))

    # State consumed by render_next_batch()
    self.importance_df = importance_df
    self.current_row = 0
    self.render_timer.start(10)  # Update every 10ms
744
+
745
def render_next_batch(self, batch_size):
    """Fill the next ``batch_size`` table rows; draw the chart when done.

    Called from the render timer. When every row of ``self.importance_df``
    has been written, the timer is stopped and the chart is rendered.

    Args:
        batch_size: Maximum number of rows to write in this call.
    """
    importance_df = self.importance_df
    if importance_df is None:
        # Nothing to render; make sure the timer does not keep firing.
        if self.render_timer is not None:
            self.render_timer.stop()
        return

    total_rows = len(importance_df)
    if self.current_row >= total_rows:
        # Stop first so an exception while drawing cannot leave the timer
        # re-triggering the failing draw every few milliseconds.
        self.render_timer.stop()
        self.render_chart()
        return

    # Render a batch of rows (always advance by at least one row)
    end_row = min(self.current_row + max(1, batch_size), total_rows)
    batch = importance_df.iloc[self.current_row:end_row]
    rows = range(self.current_row, end_row)
    for row, feature, mean_abs_shap in zip(rows, batch['feature'], batch['mean_abs_shap']):
        # str(): a non-string label (e.g. an int column name) would select
        # the QTableWidgetItem(type) overload and produce an empty cell.
        self.importance_table.setItem(row, 0, QTableWidgetItem(str(feature)))
        self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(mean_abs_shap, 4))))

    self.current_row = end_row
    QApplication.processEvents()  # Allow UI to update
762
+
763
def render_chart(self):
    """Draw the first 20 rows of ``self.importance_df`` as horizontal bars.

    Each bar is annotated with its value (two decimals) just past its end;
    the title names the column currently selected in the combo box.
    """
    ax = self.chart_view.axes
    ax.clear()

    # Limit to top 20 features for better visualization
    top_rows = self.importance_df.head(20)
    rects = ax.barh(top_rows['feature'], top_rows['mean_abs_shap'], color='skyblue')

    # Value label slightly to the right of each bar, vertically centered
    for rect in rects:
        value = rect.get_width()
        y_center = rect.get_y() + rect.get_height()/2
        ax.text(value * 1.05, y_center, f'{value:.2f}', va='center')

    ax.set_title(f'Feature Importance for Predicting {self.column_selector.currentText()}')
    ax.set_xlabel('Mean Absolute SHAP Value')
    self.chart_view.figure.tight_layout()
    self.chart_view.draw()
791
+
792
def handle_error(self, error_message):
    """Handle errors during analysis.

    Resets the progress widgets, shows the message in the table and the
    chart, and finally raises a modal error dialog.

    Args:
        error_message: Text describing the failure.
    """
    # Hide progress indicators
    self.progress_bar.hide()
    self.progress_label.hide()
    self.cancel_button.hide()

    # Re-enable analyze button
    self.analyze_button.setEnabled(True)

    # Print error to console for debugging
    print(f"Error in column profiler: {error_message}")

    # A render still in progress would overwrite the error row below.
    if self.render_timer is not None and self.render_timer.isActive():
        self.render_timer.stop()

    # Show the message in the table. The two-column layout is kept: reducing
    # the table to one column would hide the importance values of every
    # later successful analysis.
    self.importance_table.setRowCount(1)
    self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message}"))
    self.importance_table.setItem(0, 1, QTableWidgetItem(""))
    self.importance_table.resizeColumnsToContents()

    # Update the chart to show error
    self.chart_view.axes.clear()
    self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message}",
                              ha='center', va='center', fontsize=12, color='red',
                              wrap=True)
    self.chart_view.axes.set_axis_off()
    self.chart_view.draw()

    # The dialog is modal, so it comes last: the window behind it already
    # reflects the error while the dialog is open.
    QMessageBox.critical(self, "Error", f"An error occurred during analysis:\n\n{error_message}")
821
+
822
def closeEvent(self, event):
    """Clean up when the window is closed.

    Stops the render timer, shuts down a running worker thread, clears the
    per-window result cache and accepts the close event.

    Args:
        event: The close event delivered by Qt.
    """
    # Stop any running timer
    if self.render_timer and self.render_timer.isActive():
        self.render_timer.stop()

    # Clean up any background threads
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Ask the worker to stop on its own, as the other stop paths do.
        worker.cancel()

        # Disconnect all signals to avoid callbacks into a closing window
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # Forced termination is dangerous; use it only if the worker did not
        # stop by itself within a second.
        if not worker.wait(1000):
            worker.terminate()
            worker.wait(1000)  # Wait up to 1 second

        # Clear references to prevent thread issues
        self.worker_thread = None

    # Clean up memory
    self.result_cache.clear()

    # Accept the close event
    event.accept()

    # Suggest garbage collection
    gc.collect()
854
+
855
def cancel_analysis(self):
    """Cancel the current analysis and reset the progress widgets.

    Shows "Analysis cancelled" for two seconds afterwards.
    """
    worker = self.worker_thread
    if worker and worker.isRunning():
        # Signal the thread to cancel first
        worker.cancel()

        # Disconnect all signals to avoid callbacks during shutdown
        for signal in (worker.progress, worker.result, worker.error, worker.finished):
            try:
                signal.disconnect()
            except (TypeError, RuntimeError):
                pass  # Already disconnected

        # Let the cooperative cancel take effect; forcibly terminating a
        # thread is dangerous and is kept only as a last resort.
        if not worker.wait(1000):
            worker.terminate()
            worker.wait(1000)  # Wait up to 1 second

        # Clear reference
        self.worker_thread = None

    # Update UI
    self.progress_bar.hide()
    self.progress_label.setText("Analysis cancelled")
    self.progress_label.show()
    self.cancel_button.hide()
    self.analyze_button.setEnabled(True)

    # Hide the progress label after 2 seconds
    QTimer.singleShot(2000, self.progress_label.hide)
886
+
887
def show_relationship_visualization(self, row, column):
    """Open a modal dialog plotting the selected feature against the target column.

    The plot type is chosen from the dtypes of the two columns in ``self.df``:
    numeric vs numeric gives a scatter plot with a regression line, a numeric
    feature with a non-numeric target gives a box plot, a non-numeric feature
    with a numeric target gives a bar plot, and anything else gives a heatmap
    of the row-normalized crosstab.

    Args:
        row: Row index into ``self.importance_df`` identifying the feature.
        column: Column index of the activated cell; not used.
    """
    # Reject out-of-range rows (a negative index would silently wrap in iloc)
    if self.importance_df is None or not 0 <= row < len(self.importance_df):
        return

    # Get the feature name and target column
    feature = self.importance_df.iloc[row]['feature']
    target = self.column_selector.currentText()

    # Create a dialog to show the visualization
    dialog = QDialog(self)
    dialog.setWindowTitle(f"Relationship: {feature} vs {target}")
    dialog.resize(800, 600)

    layout = QVBoxLayout(dialog)

    # Create canvas for the plot
    canvas = MatplotlibCanvas(width=8, height=6, dpi=100)
    layout.addWidget(canvas)

    axes = canvas.axes
    axes.clear()

    # This runs as a Qt slot, so a failure in column lookup or plotting must
    # not propagate; show the problem inside the dialog instead.
    try:
        feature_is_numeric = pd.api.types.is_numeric_dtype(self.df[feature])
        target_is_numeric = pd.api.types.is_numeric_dtype(self.df[target])

        if feature_is_numeric and target_is_numeric:
            # Scatter plot for numeric vs numeric
            sns.scatterplot(x=feature, y=target, data=self.df, ax=axes)
            # Add regression line
            sns.regplot(x=feature, y=target, data=self.df, ax=axes,
                        scatter=False, line_kws={"color": "red"})
            axes.set_title(f"Scatter Plot: {feature} vs {target}")

        elif feature_is_numeric and not target_is_numeric:
            # Box plot for numeric vs categorical
            sns.boxplot(x=target, y=feature, data=self.df, ax=axes)
            axes.set_title(f"Box Plot: {feature} by {target}")

        elif not feature_is_numeric and target_is_numeric:
            # Bar plot for categorical vs numeric
            sns.barplot(x=feature, y=target, data=self.df, ax=axes)
            axes.set_title(f"Bar Plot: Average {target} by {feature}")
            # Rotate x-axis labels if there are many categories
            if self.df[feature].nunique() > 5:
                axes.set_xticklabels(axes.get_xticklabels(), rotation=45, ha='right')

        else:
            # Heatmap for categorical vs categorical, from a crosstab whose
            # rows are normalized to proportions
            crosstab = pd.crosstab(self.df[feature], self.df[target], normalize='index')
            sns.heatmap(crosstab, annot=True, cmap="YlGnBu", ax=axes)
            axes.set_title(f"Heatmap: {feature} vs {target}")

        canvas.figure.tight_layout()
    except Exception as exc:
        print(f"Error creating relationship visualization: {exc}")
        axes.clear()
        axes.text(0.5, 0.5, f"Error creating visualization:\n{exc}",
                  ha='center', va='center', fontsize=12, color='red',
                  wrap=True)
        axes.set_axis_off()

    canvas.draw()

    # Add a close button
    close_button = QPushButton("Close")
    close_button.clicked.connect(dialog.accept)
    layout.addWidget(close_button)

    # Show the dialog modally, then release it: it is parented to this window
    # and would otherwise stay alive (with its canvas) until the window goes.
    dialog.exec()
    dialog.deleteLater()
955
+
956
def visualize_profile(df: pd.DataFrame, column: str = None) -> "ColumnProfilerApp | None":
    """
    Launch a PyQt6 UI for visualizing column importance.

    When no QApplication exists yet, one is created, themed, and its event
    loop is run; the process exits via ``sys.exit`` when the loop ends. When a
    QApplication already exists, the window is shown inside it and returned,
    and the host application's style and palette are left untouched.

    Args:
        df: DataFrame containing the data. More than 500 rows are sampled
            down to 500 (fixed random_state) before profiling.
        column: Optional target column to analyze immediately

    Returns:
        The profiler window when running inside an existing QApplication,
        or None if an error occurred while creating it.
    """
    try:
        # Check if dataset is too small for meaningful analysis
        row_count = len(df)
        if row_count <= 5:
            print(f"WARNING: Dataset only has {row_count} rows. Feature importance analysis requires more data for meaningful results.")
            if QApplication.instance():
                QMessageBox.warning(None, "Insufficient Data",
                                    f"The dataset only contains {row_count} rows. Feature importance analysis requires more data for meaningful results.")

        # For large datasets, sample up to 500 rows for better statistical significance
        elif row_count > 500:
            print(f"Sampling 500 rows from dataset ({row_count:,} total rows)")
            df = df.sample(n=500, random_state=42)

        # Check if we're already in a Qt application
        standalone_mode = QApplication.instance() is None

        # Style and palette are application-wide settings, so they are applied
        # only to an application created here, never to a host application.
        if standalone_mode:
            app = QApplication(sys.argv)
            app.setStyle('Fusion')  # Modern look

            # Modern dark theme
            palette = QPalette()
            palette.setColor(QPalette.ColorRole.Window, QColor(53, 53, 53))
            palette.setColor(QPalette.ColorRole.WindowText, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.Base, QColor(25, 25, 25))
            palette.setColor(QPalette.ColorRole.AlternateBase, QColor(53, 53, 53))
            palette.setColor(QPalette.ColorRole.ToolTipBase, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.ToolTipText, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.Text, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.Button, QColor(53, 53, 53))
            palette.setColor(QPalette.ColorRole.ButtonText, Qt.GlobalColor.white)
            palette.setColor(QPalette.ColorRole.BrightText, Qt.GlobalColor.red)
            palette.setColor(QPalette.ColorRole.Link, QColor(42, 130, 218))
            palette.setColor(QPalette.ColorRole.Highlight, QColor(42, 130, 218))
            palette.setColor(QPalette.ColorRole.HighlightedText, Qt.GlobalColor.black)
            app.setPalette(palette)

        window = ColumnProfilerApp(df)
        window.setAttribute(Qt.WidgetAttribute.WA_DeleteOnClose)  # Ensure cleanup on close
        window.show()

        # Add tooltip to explain double-click functionality
        window.importance_table.setToolTip("Double-click on a feature to visualize its relationship with the target column")

        def run_once_after(delay_ms, callback):
            # Parent the timer to the window so a pending callback is dropped
            # if the window is closed (and deleted) before the timer fires.
            timer = QTimer(window)
            timer.setSingleShot(True)
            timer.timeout.connect(callback)
            timer.timeout.connect(timer.deleteLater)
            timer.start(delay_ms)

        # If a specific column is provided, analyze it immediately
        if column is not None and column in df.columns:
            window.column_selector.setCurrentText(column)

            # Wrap the analysis in a try/except to prevent crashes
            def safe_analyze():
                try:
                    window.analyze_column()
                except Exception as e:
                    print(f"Error during column analysis: {e}")
                    import traceback
                    traceback.print_exc()
                    QMessageBox.critical(window, "Analysis Error",
                                         f"Error analyzing column:\n\n{str(e)}")

            run_once_after(100, safe_analyze)  # Deferred to avoid immediate thread issues

            # Watchdog: cancel the analysis if it is still running after 30 seconds
            def check_progress():
                if window.worker_thread and window.worker_thread.isRunning():
                    QMessageBox.warning(window, "Analysis Timeout",
                                        "The analysis is taking longer than expected. It will be canceled to prevent hanging.")
                    try:
                        window.cancel_analysis()
                    except Exception as e:
                        print(f"Error canceling analysis: {e}")

            run_once_after(30000, check_progress)  # 30 seconds timeout

        # Only enter event loop in standalone mode
        if standalone_mode:
            sys.exit(app.exec())
        else:
            # Return the window for parent app to track
            return window
    except Exception as e:
        # Handle any exceptions to prevent crashes
        print(f"Error in visualize_profile: {e}")
        import traceback
        traceback.print_exc()

        # Show error to user
        if QApplication.instance():
            QMessageBox.critical(None, "Profile Error", f"Error creating column profile:\n\n{str(e)}")
        return None
1060
+
1061
def test_profile():
    """
    Build a synthetic employee dataset and open the profiler on its Salary column.

    The data is seeded so every run produces the same 1000 rows, with
    experience derived from age, salary derived from experience, and
    performance nudged by education and experience.
    """
    np.random.seed(42)
    sample_size = 1000

    # Age, and an experience figure derived from it (never negative)
    ages = np.random.normal(35, 10, sample_size).astype(int)
    years_worked = np.maximum(0, ages - np.random.randint(18, 25, sample_size))

    # Salary grows linearly with experience plus noise
    salaries = 30000 + 2000 * years_worked + np.random.normal(0, 10000, sample_size)

    # Categorical columns
    dept_column = np.random.choice(
        ['Engineering', 'Marketing', 'Sales', 'HR', 'Finance'], sample_size)
    edu_column = np.random.choice(
        ['High School', 'Bachelor', 'Master', 'PhD'], sample_size,
        p=[0.2, 0.5, 0.2, 0.1])

    # Performance: noise, lifted by higher education and slightly by experience
    rating = np.random.normal(0, 1, sample_size)
    education_bonus = 0.5 * (edu_column == 'Master') + 0.8 * (edu_column == 'PhD')
    rating = rating + education_bonus
    rating = rating + 0.01 * years_worked

    # Rescale to the 0-5 range
    lowest, highest = rating.min(), rating.max()
    rating = (rating - lowest) / (highest - lowest) * 5

    columns = {
        'Age': ages,
        'Experience': years_worked,
        'Department': dept_column,
        'Education': edu_column,
        'Performance': rating,
        'Salary': salaries,
    }
    df = pd.DataFrame(columns)

    print("Launching PyQt6 Column Profiler application...")
    visualize_profile(df, 'Salary')  # Start with Salary analysis
1097
+
1098
# Running this module as a script launches the demo profiler on sample data.
if __name__ == "__main__":
    test_profile()