sqlshell-0.2.3-py3-none-any.whl → sqlshell-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of sqlshell might be problematic.

@@ -1,4 +1,3 @@
- import shap
  import pandas as pd
  import xgboost as xgb
  import numpy as np
@@ -17,6 +16,7 @@ from PyQt6.QtWidgets import (QApplication, QMainWindow, QTableWidget, QTableWidg
  QMessageBox, QDialog)
  from PyQt6.QtCore import Qt, QAbstractTableModel, QModelIndex, QThread, pyqtSignal, QTimer
  from PyQt6.QtGui import QPalette, QColor, QBrush, QPainter, QPen
+ from scipy.stats import chi2_contingency, pearsonr

  # Import matplotlib at the top level
  import matplotlib
@@ -24,6 +24,7 @@ matplotlib.use('QtAgg')
  from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
  from matplotlib.figure import Figure
  import seaborn as sns
+ import matplotlib.pyplot as plt

  # Create a cache directory in user's home directory
  CACHE_DIR = os.path.join(Path.home(), '.sqlshell_cache')
@@ -93,6 +94,71 @@ class ExplainerThread(QThread):
  """Mark the thread as canceled"""
  self._is_canceled = True

+ def calculate_correlation(self, x, y):
+ """Calculate correlation between two variables, handling different data types.
+ Returns absolute correlation value between 0 and 1."""
+ try:
+ # Handle missing values
+ mask = ~(pd.isna(x) | pd.isna(y))
+ x_clean = x[mask]
+ y_clean = y[mask]
+
+ # If too few data points, return default
+ if len(x_clean) < 5:
+ return 0.0
+
+ # Check data types
+ x_is_numeric = pd.api.types.is_numeric_dtype(x_clean)
+ y_is_numeric = pd.api.types.is_numeric_dtype(y_clean)
+
+ # Case 1: Both numeric - use Pearson correlation
+ if x_is_numeric and y_is_numeric:
+ corr, _ = pearsonr(x_clean, y_clean)
+ return abs(corr)
+
+ # Case 2: Categorical vs Categorical - use Cramer's V
+ elif not x_is_numeric and not y_is_numeric:
+ # Convert to categorical codes
+ x_cat = pd.Categorical(x_clean).codes
+ y_cat = pd.Categorical(y_clean).codes
+
+ # Create contingency table
+ contingency = pd.crosstab(x_cat, y_cat)
+
+ # Calculate Cramer's V
+ chi2, _, _, _ = chi2_contingency(contingency)
+ n = contingency.sum().sum()
+ phi2 = chi2 / n
+
+ # Get dimensions
+ r, k = contingency.shape
+
+ # Calculate Cramer's V with correction for dimensions
+ cramers_v = np.sqrt(phi2 / min(k-1, r-1)) if min(k-1, r-1) > 0 else 0.0
+ return min(cramers_v, 1.0) # Cap at 1.0
+
+ # Case 3: Mixed types - convert to ranks or categories
+ else:
+ if x_is_numeric and not y_is_numeric:
+ # Convert categorical y to codes
+ y_encoded = pd.Categorical(y_clean).codes
+
+ # Calculate correlation between x and encoded y
+ # Using point-biserial correlation (special case of Pearson)
+ corr, _ = pearsonr(x_clean, y_encoded)
+ return abs(corr)
+ else: # y is numeric, x is categorical
+ # Convert categorical x to codes
+ x_encoded = pd.Categorical(x_clean).codes
+
+ # Calculate correlation
+ corr, _ = pearsonr(x_encoded, y_clean)
+ return abs(corr)
+
+ except Exception as e:
+ print(f"Error calculating correlation: {e}")
+ return 0.0 # Return zero if correlation calculation fails
+
  def run(self):
  try:
  # Check if canceled
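The new `calculate_correlation` helper (whose indentation the diff viewer flattens above) dispatches on dtype: Pearson for numeric/numeric pairs, Cramér's V for categorical/categorical, and Pearson against category codes (point-biserial style) for mixed pairs. A minimal standalone sketch of the same dispatch; the function name and toy columns are illustrative, and a caller would wrap it in try/except as the method does, since `pearsonr` can yield NaN on constant input:

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, pearsonr

def correlation_01(x: pd.Series, y: pd.Series) -> float:
    """Absolute association between two Series, scaled to [0, 1]."""
    mask = ~(pd.isna(x) | pd.isna(y))
    x, y = x[mask], y[mask]
    if len(x) < 5:                     # too few points for a stable estimate
        return 0.0
    x_num = pd.api.types.is_numeric_dtype(x)
    y_num = pd.api.types.is_numeric_dtype(y)
    if x_num and y_num:
        return abs(pearsonr(x, y)[0])  # plain Pearson
    if not x_num and not y_num:
        table = pd.crosstab(x, y)      # contingency table
        r, k = table.shape
        denom = min(k - 1, r - 1)
        if denom == 0:                 # a variable with one level carries no signal
            return 0.0
        phi2 = chi2_contingency(table)[0] / table.to_numpy().sum()
        return min(np.sqrt(phi2 / denom), 1.0)  # Cramér's V, capped at 1
    # Mixed pair: encode the categorical side, then Pearson (point-biserial style)
    if x_num:
        y = pd.Series(pd.Categorical(y).codes)
    else:
        x = pd.Series(pd.Categorical(x).codes)
    return abs(pearsonr(x, y)[0])

df = pd.DataFrame({"price": [1, 2, 3, 4, 5, 6], "tier": list("aabbcc")})
print(correlation_01(df["price"], df["tier"]))  # mixed pair -> ~0.96
```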
@@ -124,13 +190,25 @@ class ExplainerThread(QThread):
  # Check if canceled
  if self._is_canceled:
  return
+
+ # Early check for empty dataframe or no columns
+ if self.df.empty or len(self.df.columns) == 0:
+ raise ValueError("The dataframe is empty or has no columns for analysis")

  # No cache found, proceed with computation
  self.progress.emit(5, "Computing new analysis...")

+ # Validate that the target column exists in the dataframe
+ if self.column not in self.df.columns:
+ raise ValueError(f"Target column '{self.column}' not found in the dataframe")
+
  # Create a copy to avoid modifying the original dataframe
  df = self.df.copy()

+ # Verify we have data to work with
+ if len(df) == 0:
+ raise ValueError("No data available for analysis (empty dataframe)")
+
  # Sample up to 500 rows for better statistical significance while maintaining speed
  if len(df) > 500:
  sample_size = 500 # Increased sample size for better analysis
@@ -150,17 +228,35 @@ class ExplainerThread(QThread):
  if col == self.column: # Don't drop target column
  continue
  try:
- # Drop if more than 95% unique values (likely ID column)
- if df[col].nunique() / len(df) > 0.95:
+ # Only drop columns with extremely high uniqueness (99% instead of 95%)
+ # This ensures we keep more features for analysis
+ if df[col].nunique() / len(df) > 0.99 and len(df) > 100:
  cols_to_drop.append(col)
- # Drop if more than 50% missing values
- elif df[col].isna().mean() > 0.5:
+ # Only drop columns with very high missing values (80% instead of 50%)
+ elif df[col].isna().mean() > 0.8:
  cols_to_drop.append(col)
  except:
  # If we can't analyze the column, drop it
  cols_to_drop.append(col)

- # Drop identified columns
+ # Drop identified columns, but ensure we keep at least some features
+ remaining_cols = [col for col in df.columns if col != self.column and col not in cols_to_drop]
+
+ # If dropping would leave us with no features, keep at least 3 columns (or all if less than 3)
+ if len(remaining_cols) == 0 and len(cols_to_drop) > 0:
+ # Sort dropped columns by uniqueness (keep those with lower uniqueness)
+ col_uniqueness = {}
+ for col in cols_to_drop:
+ try:
+ col_uniqueness[col] = df[col].nunique() / len(df)
+ except:
+ col_uniqueness[col] = 1.0 # Assume high uniqueness for problematic columns
+
+ # Sort by uniqueness and keep the least unique columns
+ cols_to_keep = sorted(col_uniqueness.items(), key=lambda x: x[1])[:min(3, len(cols_to_drop))]
+ cols_to_drop = [col for col in cols_to_drop if col not in [c[0] for c in cols_to_keep]]
+ print(f"Keeping {len(cols_to_keep)} columns to ensure analysis can proceed")
+
  if cols_to_drop:
  self.progress.emit(20, f"Removing {len(cols_to_drop)} low-information columns...")
  df = df.drop(columns=cols_to_drop)
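Read on its own, the relaxed filter drops a column only when it is nearly all-unique (likely an ID, and only on frames with more than 100 rows) or more than 80% missing, then rescues the three least-unique candidates if everything would be dropped. A sketch of that heuristic as a standalone function, assuming the original's bare `except` maps to `TypeError` for unanalyzable columns (the helper name is illustrative):

```python
import pandas as pd

def low_information_columns(df: pd.DataFrame, target: str) -> list[str]:
    """Columns to drop: near-unique (likely IDs) or mostly missing."""
    drop = []
    for col in df.columns:
        if col == target:
            continue
        try:
            too_unique = df[col].nunique() / len(df) > 0.99 and len(df) > 100
            too_missing = df[col].isna().mean() > 0.8
            if too_unique or too_missing:
                drop.append(col)
        except TypeError:
            drop.append(col)  # unanalyzable column (e.g. unhashable values)
    # Rescue: never drop everything; keep the three least-unique candidates
    if len(drop) == len(df.columns) - (target in df.columns):
        uniq = {c: df[c].nunique() / len(df) for c in drop}
        keep = sorted(uniq, key=uniq.get)[:3]
        drop = [c for c in drop if c not in keep]
    return drop

df = pd.DataFrame({"id": range(200), "y": [0, 1] * 100, "x": [1] * 200})
print(low_information_columns(df, target="y"))  # ['id'] on 200 rows
```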
@@ -169,26 +265,56 @@ class ExplainerThread(QThread):
  if self.column not in df.columns:
  raise ValueError(f"Target column '{self.column}' not found in dataframe after preprocessing")

+ # Calculate correlation coefficients first
+ self.progress.emit(25, "Calculating correlation measures...")
+ correlations = {}
+
+ # Get all feature columns (excluding target)
+ feature_cols = [col for col in df.columns if col != self.column]
+
+ # Calculate correlation for each feature
+ for col in feature_cols:
+ try:
+ # Calculate correlation between each feature and target
+ cor_val = self.calculate_correlation(df[col], df[self.column])
+ correlations[col] = cor_val
+ except Exception as e:
+ print(f"Error calculating correlation for {col}: {e}")
+ correlations[col] = 0.0
+
  # Separate features and target
- self.progress.emit(25, "Preparing features and target...")
+ self.progress.emit(30, "Preparing features and target...")
  X = df.drop(columns=[self.column])
  y = df[self.column]

  # Handle high-cardinality categorical features
- self.progress.emit(30, "Encoding categorical features...")
+ self.progress.emit(35, "Encoding categorical features...")
  # Use a simpler approach - just one-hot encode columns with few unique values
- # and drop high-cardinality columns completely for speed
+ # and encode (don't drop) high-cardinality columns for speed
  categorical_cols = X.select_dtypes(include='object').columns
- high_cardinality_threshold = 10 # Lower threshold to drop more columns
+ high_cardinality_threshold = 20 # Higher threshold to keep more columns
+
+ # Keep track of how many columns we've processed
+ columns_processed = 0
+ columns_kept = 0

  for col in categorical_cols:
+ columns_processed += 1
  unique_count = X[col].nunique()
+ # Always keep the column, but use different encoding strategies based on cardinality
  if unique_count <= high_cardinality_threshold:
  # Simple label encoding for low-cardinality features
  X[col] = X[col].fillna('_MISSING_').astype('category').cat.codes
+ columns_kept += 1
  else:
- # Drop high-cardinality features to speed up analysis
- X = X.drop(columns=[col])
+ # For high-cardinality features, still encode them but with a simpler approach
+ # Use label encoding instead of dropping
+ X[col] = X[col].fillna('_MISSING_').astype('category').cat.codes
+ columns_kept += 1
+
+ # Log how many columns were kept
+ if columns_processed > 0:
+ self.progress.emit(40, f"Encoded {columns_kept} categorical columns out of {columns_processed}")

  # Handle target column in a simpler, faster way
  if y.dtype == 'object':
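With the drop branch gone, every categorical column, low or high cardinality, now reduces to the same pandas idiom: fill NaN with a sentinel, cast to category, take the integer codes. A tiny demonstration (the column name is illustrative):

```python
import pandas as pd

X = pd.DataFrame({"city": ["oslo", None, "rome", "oslo"]})
# Sentinel for NaN, then stable integer codes per category
X["city"] = X["city"].fillna("_MISSING_").astype("category").cat.codes
print(X["city"].tolist())  # [1, 0, 2, 1] — codes follow lexical category order
```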
@@ -199,7 +325,12 @@ class ExplainerThread(QThread):
  y = y.fillna(y.mean() if pd.api.types.is_numeric_dtype(y) else y.mode()[0])

  # Train/test split
- self.progress.emit(40, "Splitting data into train/test sets...")
+ self.progress.emit(45, "Splitting data into train/test sets...")
+
+ # Make sure we still have features to work with
+ if X.shape[1] == 0:
+ raise ValueError("No features remain after preprocessing. Try selecting a different target column.")
+
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

  # Check if canceled
@@ -208,9 +339,21 @@ class ExplainerThread(QThread):

  # Train a tree-based model
  self.progress.emit(50, "Training XGBoost model...")
+
+ # Check the number of features left for analysis
+ feature_count = X_train.shape[1]
+
+ # Adjust model complexity based on feature count
+ if feature_count < 3:
+ max_depth = 1 # Very simple trees for few features
+ n_estimators = 10 # Use more trees to compensate
+ else:
+ max_depth = 2 # Still shallow trees
+ n_estimators = 5 # Fewer trees for more features
+
  model = xgb.XGBRegressor(
- n_estimators=5, # Absolute minimum number of trees
- max_depth=2, # Very shallow trees
+ n_estimators=n_estimators,
+ max_depth=max_depth,
  learning_rate=0.3, # Higher learning rate to compensate for fewer trees
  tree_method='hist', # Fast histogram method
  subsample=0.7, # Use 70% of data per tree
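The adaptive settings amount to a pair of presets around a standard `XGBRegressor`. A self-contained sketch with toy data showing the same construction and the `feature_importances_` read-out used later:

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 2))            # toy data with 2 features
y_train = X_train[:, 0] * 3 + rng.normal(size=100)

# Fewer features -> even shallower trees, but more of them
max_depth, n_estimators = (1, 10) if X_train.shape[1] < 3 else (2, 5)

model = xgb.XGBRegressor(
    n_estimators=n_estimators,
    max_depth=max_depth,
    learning_rate=0.3,    # high rate compensates for the tiny ensemble
    tree_method="hist",   # fast histogram-based splits
    subsample=0.7,        # row subsampling per tree
)
model.fit(X_train, y_train)
print(model.feature_importances_)  # normalized gain-based importances
```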
@@ -229,37 +372,106 @@ class ExplainerThread(QThread):
  try:
  model.fit(X_train, y_train)
  except Exception as e:
+ # Log the error for debugging
+ print(f"Initial XGBoost fit failed: {str(e)}")
+
  # If we encounter an error, try with an even smaller and simpler model
  self.progress.emit(55, "Adjusting model parameters due to computational constraints...")
- model = xgb.XGBRegressor(
- n_estimators=5,
- max_depth=2,
- subsample=0.5,
- colsample_bytree=0.5,
- n_jobs=1
- )
- model.fit(X_train, y_train)
+ try:
+ # Try a simpler regressor with more conservative parameters
+ model = xgb.XGBRegressor(
+ n_estimators=3,
+ max_depth=1,
+ subsample=0.5,
+ colsample_bytree=0.5,
+ n_jobs=1,
+ verbosity=0
+ )
+ model.fit(X_train, y_train)
+ except Exception as inner_e:
+ # If even the simpler model fails, resort to a fallback strategy
+ print(f"Even simpler XGBoost failed: {str(inner_e)}")
+ self.progress.emit(60, "Using fallback importance calculation method...")
+
+ # Create a basic feature importance based on correlation with target
+ # This is a simple fallback when model training fails
+ importance = []
+ for col in X.columns:
+ try:
+ # Use pre-calculated correlations for fallback importance
+ corr_value = correlations.get(col, 0.5)
+ # Scale correlation to make a reasonable importance value
+ # Higher correlation = higher importance
+ importance.append(0.5 + corr_value/2 if not pd.isna(corr_value) else 0.5)
+ except:
+ # If correlation fails, use default
+ importance.append(0.5)
+
+ # Normalize to sum to 1
+ importance = np.array(importance)
+ if sum(importance) > 0:
+ importance = importance / sum(importance)
+ else:
+ # Equal importance if everything fails
+ importance = np.ones(len(X.columns)) / len(X.columns)
+
+ # Skip the model-based code path since we calculated importances manually
+ self.progress.emit(80, "Creating importance results...")
+ feature_importance = pd.DataFrame({
+ 'feature': X.columns,
+ 'importance_value': importance,
+ 'correlation': [correlations.get(col, 0.0) for col in X.columns]
+ }).sort_values(by='importance_value', ascending=False)
+
+ # Cache the results for future use
+ self.progress.emit(95, "Caching results for future use...")
+ cache_results(self.df, self.column, feature_importance)
+
+ # Clean up after computation
+ del df, X, y, X_train, X_test, y_train, y_test
+ gc.collect()
+
+ # Check if canceled
+ if self._is_canceled:
+ return
+
+ # Emit the result
+ self.progress.emit(100, "Analysis complete (fallback method)")
+ self.result.emit(feature_importance)
+ return

  # Check if canceled
  if self._is_canceled:
  return

- # Skip SHAP and use model feature importance directly for simplicity and reliability
- self.progress.emit(80, "Calculating feature importance...")
+ # Get feature importance directly from XGBoost
+ self.progress.emit(80, "Calculating feature importance and correlations...")

  try:
+ # Check if we have features to analyze
+ if X.shape[1] == 0:
+ raise ValueError("No features available for importance analysis")
+
  # Get feature importance directly from XGBoost
  importance = model.feature_importances_

- # Create and sort the importance dataframe
- shap_importance = pd.DataFrame({
+ # Verify importance values are valid
+ if np.isnan(importance).any() or np.isinf(importance).any():
+ # Handle NaN or Inf values
+ print("Warning: Invalid importance values detected, using fallback method")
+ # Replace with equal importance
+ importance = np.ones(len(X.columns)) / len(X.columns)
+
+ # Create and sort the importance dataframe with correlations
+ feature_importance = pd.DataFrame({
  'feature': X.columns,
- 'mean_abs_shap': importance
- }).sort_values(by='mean_abs_shap', ascending=False)
+ 'importance_value': importance,
+ 'correlation': [correlations.get(col, 0.0) for col in X.columns]
+ }).sort_values(by='importance_value', ascending=False)

  # Cache the results for future use
  self.progress.emit(95, "Caching results for future use...")
- cache_results(self.df, self.column, shap_importance)
+ cache_results(self.df, self.column, feature_importance)

  # Clean up after computation
  del df, X, y, X_train, X_test, y_train, y_test, model
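The fallback path never touches the model: it maps each pre-computed correlation through `0.5 + corr/2`, defaulting to 0.5, then normalizes. A sketch with toy inputs standing in for the thread's state:

```python
import numpy as np
import pandas as pd

# Toy inputs standing in for the thread's state
feature_cols = ["age", "city", "plan"]
correlations = {"age": 0.62, "city": 0.18}   # "plan" failed -> missing

# 0.5 baseline plus half the correlation, as in the fallback branch
importance = np.array([0.5 + correlations.get(c, 0.5) / 2 for c in feature_cols])
importance = importance / importance.sum()   # normalize to sum to 1

feature_importance = pd.DataFrame({
    "feature": feature_cols,
    "importance_value": importance,
    "correlation": [correlations.get(c, 0.0) for c in feature_cols],
}).sort_values(by="importance_value", ascending=False)
print(feature_importance)
```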
@@ -271,7 +483,7 @@ class ExplainerThread(QThread):

  # Emit the result
  self.progress.emit(100, "Analysis complete")
- self.result.emit(shap_importance)
+ self.result.emit(feature_importance)
  return

  except Exception as e:
@@ -279,37 +491,114 @@ class ExplainerThread(QThread):
  import traceback
  traceback.print_exc()

- # Last resort: create equal importance for all features
- importance_values = np.ones(len(X.columns)) / len(X.columns)
- shap_importance = pd.DataFrame({
+ # Create fallback importance values when model-based approach fails
+ self.progress.emit(85, "Using alternative importance calculation method...")
+
+ try:
+ # Try correlation-based approach first
+ importance = []
+ has_valid_correlations = False
+
+ for col in X.columns:
+ try:
+ # Use pre-calculated correlations
+ corr = correlations.get(col, 0.1)
+ if not pd.isna(corr):
+ importance.append(corr)
+ has_valid_correlations = True
+ else:
+ importance.append(0.1) # Default for failed correlation
+ except:
+ # Default value for any error
+ importance.append(0.1)
+
+ # Normalize importance values
+ importance = np.array(importance)
+ if has_valid_correlations and sum(importance) > 0:
+ # If we have valid correlations, use them normalized
+ importance = importance / max(sum(importance), 0.001)
+ else:
+ # Otherwise use frequency-based heuristic
+ print("Using frequency-based feature importance as fallback")
+ # Count unique values as a proxy for importance
+ importance = []
+ total_rows = len(X)
+
+ for col in X.columns:
+ try:
+ # More unique values could indicate more information content
+ # But we invert the ratio so columns with fewer unique values
+ # (more predictive) get higher importance
+ uniqueness = X[col].nunique() / total_rows
+ # Invert and scale between 0.1 and 1.0
+ val = 1.0 - (0.9 * uniqueness)
+ importance.append(max(0.1, min(1.0, val)))
+ except:
+ importance.append(0.1) # Default value
+
+ # Normalize
+ importance = np.array(importance)
+ importance = importance / max(sum(importance), 0.001)
+
+ except Exception as fallback_error:
+ # Last resort: create equal importance for all features
+ print(f"Fallback error: {fallback_error}, using equal importance")
+ importance_values = np.ones(len(X.columns)) / max(len(X.columns), 1)
+ importance = importance_values
+
+ # Create dataframe with results, including correlations
+ feature_importance = pd.DataFrame({
  'feature': X.columns,
- 'mean_abs_shap': importance_values
- }).sort_values(by='mean_abs_shap', ascending=False)
+ 'importance_value': importance,
+ 'correlation': [correlations.get(col, 0.0) for col in X.columns]
+ }).sort_values(by='importance_value', ascending=False)

  # Cache the results
  try:
- cache_results(self.df, self.column, shap_importance)
+ cache_results(self.df, self.column, feature_importance)
  except:
  pass # Ignore cache errors

  # Clean up
  try:
- del df, X, y, X_train, X_test, y_train, y_test, model
+ del df, X, y, X_train, X_test, y_train, y_test
  gc.collect()
  except:
  pass

  # Emit the result
- self.progress.emit(100, "Analysis complete (with default values)")
- self.result.emit(shap_importance)
+ self.progress.emit(100, "Analysis complete (with fallback methods)")
+ self.result.emit(feature_importance)
  return

+ except IndexError as e:
+ # Handle index errors with more detail
+ import traceback
+ import inspect
+ trace = traceback.format_exc()
+
+ # Get more detailed information
+ frame = inspect.trace()[-1]
+ frame_info = inspect.getframeinfo(frame[0])
+ filename = frame_info.filename
+ lineno = frame_info.lineno
+ function = frame_info.function
+ code_context = frame_info.code_context[0].strip() if frame_info.code_context else "Unknown code context"
+
+ # Format a more detailed error message
+ detail_msg = f"IndexError: {str(e)}\nLocation: {filename}:{lineno} in function '{function}'\nCode: {code_context}\n\n{trace}"
+ print(detail_msg) # Print to console for debugging
+
+ if not self._is_canceled:
+ self.error.emit(f"Index error at line {lineno} in {function}:\n{str(e)}\nCode: {code_context}")
+
  except Exception as e:
  if not self._is_canceled: # Only emit error if not canceled
  import traceback
+ trace = traceback.format_exc()
  print(f"Error in ExplainerThread: {str(e)}")
- print(traceback.format_exc()) # Print full stack trace to help debug
- self.error.emit(str(e))
+ print(trace) # Print full stack trace to help debug
+ self.error.emit(f"{str(e)}\n\nTrace: {trace}")

  def analyze_column(self):
  if self.df is None or self.column_selector.currentText() == "":
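The detailed `IndexError` handler relies on `inspect.trace()` returning the frame records of the exception currently being handled, innermost frame last. A self-contained sketch of that pattern (the helper name is illustrative):

```python
import inspect
import traceback

def describe_active_exception() -> str:
    """Locate the innermost frame of the exception being handled."""
    frame_record = inspect.trace()[-1]            # only valid inside an except block
    info = inspect.getframeinfo(frame_record[0])
    context = info.code_context[0].strip() if info.code_context else "?"
    return f"{info.filename}:{info.lineno} in {info.function}: {context}"

try:
    [][0]
except IndexError as e:
    print(f"IndexError: {e}\nLocation: {describe_active_exception()}")
    print(traceback.format_exc())
```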
@@ -395,7 +684,9 @@ class ExplainerThread(QThread):
  self.progress_label.hide()
  self.cancel_button.hide()

- # Update importance table incrementally
+ # Update importance table to include correlation column
+ self.importance_table.setColumnCount(3)
+ self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance", "Abs. Correlation"])
  self.importance_table.setRowCount(len(importance_df))

  # Using a timer for incremental updates
@@ -404,53 +695,276 @@ class ExplainerThread(QThread):
  self.render_timer = QTimer()
  self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
  self.render_timer.start(10) # Update every 10ms
-
+
  def render_next_batch(self, batch_size):
- if self.current_row >= len(self.importance_df):
- # All rows rendered, now render the chart and stop the timer
- self.render_chart()
- self.render_timer.stop()
- return
+ try:
+ if self.current_row >= len(self.importance_df):
+ # All rows rendered, now render the chart and stop the timer
+ self.render_chart()
+ self.render_timer.stop()
+ return

- # Render a batch of rows
- end_row = min(self.current_row + batch_size, len(self.importance_df))
- for row in range(self.current_row, end_row):
- feature = self.importance_df.iloc[row]['feature']
- mean_abs_shap = self.importance_df.iloc[row]['mean_abs_shap']
- self.importance_table.setItem(row, 0, QTableWidgetItem(feature))
- self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(mean_abs_shap, 4))))
+ # Render a batch of rows
+ end_row = min(self.current_row + batch_size, len(self.importance_df))
+ for row in range(self.current_row, end_row):
+ try:
+ # Check if row exists in dataframe to prevent index errors
+ if row < len(self.importance_df):
+ feature = self.importance_df.iloc[row]['feature']
+ importance_value = self.importance_df.iloc[row]['importance_value']
+
+ # Add correlation if available
+ correlation = self.importance_df.iloc[row].get('correlation', None)
+ if correlation is not None:
+ self.importance_table.setItem(row, 0, QTableWidgetItem(str(feature)))
+ self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(importance_value, 4))))
+ self.importance_table.setItem(row, 2, QTableWidgetItem(str(round(correlation, 4))))
+ else:
+ self.importance_table.setItem(row, 0, QTableWidgetItem(str(feature)))
+ self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(importance_value, 4))))
+ else:
+ # Handle out of range index
+ print(f"Warning: Row {row} is out of range (max: {len(self.importance_df)-1})")
+ self.importance_table.setItem(row, 0, QTableWidgetItem("Error"))
+ self.importance_table.setItem(row, 1, QTableWidgetItem("Out of range"))
+ self.importance_table.setItem(row, 2, QTableWidgetItem("N/A"))
+ except (IndexError, KeyError) as e:
+ # Enhanced error reporting for index and key errors
+ import traceback
+ trace = traceback.format_exc()
+ error_msg = f"Error rendering row {row}: {e.__class__.__name__}: {e}\n{trace}"
+ print(error_msg)
+
+ # Handle missing data in the dataframe gracefully
+ self.importance_table.setItem(row, 0, QTableWidgetItem(f"Error: {e.__class__.__name__}"))
+ self.importance_table.setItem(row, 1, QTableWidgetItem(f"{str(e)[:20]}"))
+ self.importance_table.setItem(row, 2, QTableWidgetItem("Error"))
+ except Exception as e:
+ # Catch any other exceptions
+ print(f"Unexpected error rendering row {row}: {e.__class__.__name__}: {e}")
+ self.importance_table.setItem(row, 0, QTableWidgetItem(f"Error: {e.__class__.__name__}"))
+ self.importance_table.setItem(row, 1, QTableWidgetItem("See console for details"))
+ self.importance_table.setItem(row, 2, QTableWidgetItem("Error"))
+
+ self.current_row = end_row
+ QApplication.processEvents() # Allow UI to update
+ except Exception as e:
+ # Catch any exceptions in the rendering loop itself
+ import traceback
+ trace = traceback.format_exc()
+ error_msg = f"Error in render_next_batch: {e.__class__.__name__}: {e}\n{trace}"
+ print(error_msg)
+
+ # Try to stop the timer to prevent further errors
+ try:
+ if self.render_timer and self.render_timer.isActive():
+ self.render_timer.stop()
+ except:
+ pass

- self.current_row = end_row
- QApplication.processEvents() # Allow UI to update
+ # Show error
+ QMessageBox.critical(self, "Rendering Error",
+ f"Error rendering results: {e.__class__.__name__}: {e}")

  def render_chart(self):
  # Create horizontal bar chart
- self.chart_view.axes.clear()
-
- # Limit to top 20 features for better visualization
- plot_df = self.importance_df.head(20)
-
- # Plot with custom colors
- bars = self.chart_view.axes.barh(
- plot_df['feature'],
- plot_df['mean_abs_shap'],
- color='skyblue'
- )
-
- # Add values at the end of bars
- for bar in bars:
- width = bar.get_width()
- self.chart_view.axes.text(
- width * 1.05,
- bar.get_y() + bar.get_height()/2,
- f'{width:.2f}',
- va='center'
+ try:
+ if self.importance_df is None or len(self.importance_df) == 0:
+ # No data to render
+ self.chart_view.axes.clear()
+ self.chart_view.axes.text(0.5, 0.5, "No data available for chart",
+ ha='center', va='center', fontsize=12, color='gray')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ return
+
+ self.chart_view.axes.clear()
+
+ # Get a sorted copy based on current sort key
+ plot_df = self.importance_df.sort_values(by=self.current_sort, ascending=False).head(20).copy()
+
+ # Verify we have data before proceeding
+ if len(plot_df) == 0:
+ self.chart_view.axes.text(0.5, 0.5, "No features found with importance values",
+ ha='center', va='center', fontsize=12, color='gray')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ return
+
+ # Check required columns exist
+ required_columns = ['feature', 'importance_value']
+ missing_columns = [col for col in required_columns if col not in plot_df.columns]
+ if missing_columns:
+ error_msg = f"Missing required columns: {', '.join(missing_columns)}"
+ self.chart_view.axes.text(0.5, 0.5, error_msg,
+ ha='center', va='center', fontsize=12, color='red')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ print(f"Chart rendering error: {error_msg}")
+ return
+
+ # Truncate long feature names for better display
+ max_feature_length = 30
+ plot_df['display_feature'] = plot_df['feature'].apply(
+ lambda x: (str(x)[:max_feature_length] + '...') if len(str(x)) > max_feature_length else str(x)
  )

- self.chart_view.axes.set_title(f'Feature Importance for Predicting {self.column_selector.currentText()}')
- self.chart_view.axes.set_xlabel('Mean Absolute SHAP Value')
- self.chart_view.figure.tight_layout()
- self.chart_view.draw()
+ # Reverse order for better display (highest at top)
+ plot_df = plot_df.iloc[::-1].reset_index(drop=True)
+
+ # Create a figure with two subplots side by side
+ self.chart_view.figure.clear()
+ gs = self.chart_view.figure.add_gridspec(1, 2, width_ratios=[3, 2])
+
+ # First subplot for importance
+ ax1 = self.chart_view.figure.add_subplot(gs[0, 0])
+
+ # Create a colormap for better visualization
+ cmap = plt.cm.Blues
+ colors = cmap(np.linspace(0.4, 0.8, len(plot_df)))
+
+ # Plot with custom colors
+ bars = ax1.barh(
+ plot_df['display_feature'],
+ plot_df['importance_value'],
+ color=colors,
+ height=0.7, # Thinner bars for more spacing
+ alpha=0.8
+ )
+
+ # Add values at the end of bars
+ for bar in bars:
+ width = bar.get_width()
+ ax1.text(
+ width * 1.05,
+ bar.get_y() + bar.get_height()/2,
+ f'{width:.2f}',
+ va='center',
+ fontsize=9,
+ fontweight='bold'
+ )
+
+ # Add grid for better readability
+ ax1.grid(True, axis='x', linestyle='--', alpha=0.3)
+
+ # Remove unnecessary spines
+ for spine in ['top', 'right']:
+ ax1.spines[spine].set_visible(False)
+
+ # Make labels more readable
+ ax1.tick_params(axis='y', labelsize=9)
+
+ # Set title and labels
+ ax1.set_title(f'Feature Importance for {self.column_selector.currentText()}')
+ ax1.set_xlabel('Importance Value')
+
+ # Add a note about the sorting order
+ sort_label = "Sorted by: " + ("Importance" if self.current_sort == 'importance_value' else "Correlation")
+
+ # Second subplot for correlation if available
+ if 'correlation' in plot_df.columns:
+ ax2 = self.chart_view.figure.add_subplot(gs[0, 1], sharey=ax1)
+
+ # Create a colormap for correlation - use a different color
+ cmap_corr = plt.cm.Reds
+ colors_corr = cmap_corr(np.linspace(0.4, 0.8, len(plot_df)))
+
+ # Plot correlation bars
+ corr_bars = ax2.barh(
+ plot_df['display_feature'],
+ plot_df['correlation'],
+ color=colors_corr,
+ height=0.7,
+ alpha=0.8
+ )
+
+ # Add values at the end of correlation bars
+ for bar in corr_bars:
+ width = bar.get_width()
+ ax2.text(
+ width * 1.05,
+ bar.get_y() + bar.get_height()/2,
+ f'{width:.2f}',
+ va='center',
+ fontsize=9,
+ fontweight='bold'
+ )
+
+ # Add grid and styling
+ ax2.grid(True, axis='x', linestyle='--', alpha=0.3)
+ ax2.set_title('Absolute Correlation')
+ ax2.set_xlabel('Correlation Value')
+
+ # Hide y-axis labels since they're shared with the first plot
+ ax2.set_yticklabels([])
+
+ # Remove unnecessary spines
+ for spine in ['top', 'right']:
+ ax2.spines[spine].set_visible(False)
+
+ # Add a note about the current sort order
+ self.chart_view.figure.text(0.5, 0.01, sort_label, ha='center', fontsize=9, style='italic')
+
+ # Adjust figure size based on number of features
+ feature_count = len(plot_df)
+ self.chart_view.figure.set_figheight(max(5, min(4 + feature_count * 0.3, 12)))
+
+ # Adjust layout and draw
+ self.chart_view.figure.tight_layout(rect=[0, 0.03, 1, 0.97]) # Make room for sort label
+ self.chart_view.draw()
+
+ except IndexError as e:
+ # Special handling for index errors with detailed information
+ import traceback
+ import inspect
+
+ # Get stack trace information
+ trace = traceback.format_exc()
+
+ # Try to get line and context information
+ try:
+ frame = inspect.trace()[-1]
+ frame_info = inspect.getframeinfo(frame[0])
+ filename = frame_info.filename
+ lineno = frame_info.lineno
+ function = frame_info.function
+ code_context = frame_info.code_context[0].strip() if frame_info.code_context else "Unknown code context"
+
+ # Detailed error message
+ detail_msg = f"IndexError at line {lineno} in {function}: {str(e)}\nCode: {code_context}"
+ print(f"Chart rendering error: {detail_msg}\n{trace}")
+
+ # Display error in chart
+ self.chart_view.axes.clear()
+ self.chart_view.axes.text(0.5, 0.5,
+ f"Index Error in chart rendering:\n{str(e)}\nAt line {lineno}: {code_context}",
+ ha='center', va='center', fontsize=12, color='red',
+ wrap=True)
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ except Exception as inner_e:
+ # Fallback if the detailed error reporting fails
+ print(f"Error getting detailed error info: {inner_e}")
+ print(f"Original error: {e}\n{trace}")
+
+ self.chart_view.axes.clear()
+ self.chart_view.axes.text(0.5, 0.5, f"Index Error: {str(e)}",
+ ha='center', va='center', fontsize=12, color='red')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ except Exception as e:
+ # Recover gracefully from any chart rendering errors with detailed information
+ import traceback
+ trace = traceback.format_exc()
+ error_msg = f"Error rendering chart: {e.__class__.__name__}: {str(e)}"
+ print(f"{error_msg}\n{trace}")
+
+ self.chart_view.axes.clear()
+ self.chart_view.axes.text(0.5, 0.5, error_msg,
+ ha='center', va='center', fontsize=12, color='red',
+ wrap=True)
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()

  def handle_error(self, error_message):
  """Handle errors during analysis"""
@@ -465,23 +979,32 @@ class ExplainerThread(QThread):
  # Print error to console for debugging
  print(f"Error in column profiler: {error_message}")

- # Show error message
- QMessageBox.critical(self, "Error", f"An error occurred during analysis:\n\n{error_message}")
+ # Show error message with more details
+ msg_box = QMessageBox(self)
+ msg_box.setIcon(QMessageBox.Icon.Critical)
+ msg_box.setWindowTitle("Error")
+ msg_box.setText("An error occurred during analysis")
+ msg_box.setDetailedText(error_message)
+ msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)
+ msg_box.exec()

  # Show a message in the UI as well
  self.importance_table.setRowCount(1)
- self.importance_table.setColumnCount(1)
- self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message}"))
+ self.importance_table.setColumnCount(3)
+ self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance", "Abs. Correlation"])
+ self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message.split('\n')[0]}"))
+ self.importance_table.setItem(0, 1, QTableWidgetItem(""))
+ self.importance_table.setItem(0, 2, QTableWidgetItem(""))
  self.importance_table.resizeColumnsToContents()

  # Update the chart to show error
  self.chart_view.axes.clear()
- self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message}",
+ self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message.split('\n')[0]}",
  ha='center', va='center', fontsize=12, color='red',
  wrap=True)
  self.chart_view.axes.set_axis_off()
  self.chart_view.draw()
-
+
  def closeEvent(self, event):
  """Clean up when the window is closed"""
  # Stop any running timer
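`QMessageBox.setDetailedText` is what moves the full traceback behind a collapsible "Show Details..." button while the main text stays short. A minimal PyQt6 sketch (needs a display to run):

```python
import sys
from PyQt6.QtWidgets import QApplication, QMessageBox

app = QApplication(sys.argv)
box = QMessageBox()
box.setIcon(QMessageBox.Icon.Critical)
box.setWindowTitle("Error")
box.setText("An error occurred during analysis")                   # short summary
box.setDetailedText("Traceback (most recent call last):\n  ...")   # behind "Show Details..."
box.setStandardButtons(QMessageBox.StandardButton.Ok)
box.exec()
```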
@@ -547,12 +1070,314 @@ class ExplainerThread(QThread):
  # Hide the progress label after 2 seconds
  QTimer.singleShot(2000, self.progress_label.hide)

- # Custom matplotlib canvas for embedding in Qt
- class MatplotlibCanvas(FigureCanvasQTAgg):
- def __init__(self, width=5, height=4, dpi=100):
- self.figure = Figure(figsize=(width, height), dpi=dpi)
- self.axes = self.figure.add_subplot(111)
- super().__init__(self.figure)
+ def show_relationship_visualization(self, row, column):
+ """Show visualization of relationship between selected feature and target column"""
+ if self.importance_df is None or row < 0 or row >= len(self.importance_df):
+ return
+
+ # Get the feature name and target column
+ try:
+ feature = self.importance_df.iloc[row]['feature']
+ target = self.column_selector.currentText()
+
+ # Verify both columns exist in the dataframe
+ if feature not in self.df.columns:
+ QMessageBox.warning(self, "Column Not Found",
+ f"Feature column '{feature}' not found in the dataframe")
+ return
+
+ if target not in self.df.columns:
+ QMessageBox.warning(self, "Column Not Found",
+ f"Target column '{target}' not found in the dataframe")
+ return
+ except Exception as e:
+ QMessageBox.critical(self, "Error", f"Error getting column data: {str(e)}")
+ return
+
+ # Create a dialog to show the visualization
+ dialog = QDialog(self)
+ dialog.setWindowTitle(f"Relationship: {feature} vs {target}")
+ dialog.resize(900, 700)
+
+ # Create layout
+ layout = QVBoxLayout(dialog)
+
+ # Create canvas for the plot
+ canvas = MatplotlibCanvas(width=8, height=6, dpi=100)
+ layout.addWidget(canvas)
+
+ # Determine the data types
+ feature_is_numeric = pd.api.types.is_numeric_dtype(self.df[feature])
+ target_is_numeric = pd.api.types.is_numeric_dtype(self.df[target])
+
+ # Get unique counts to determine if we have high cardinality
+ feature_unique_count = self.df[feature].nunique()
+ target_unique_count = self.df[target].nunique()
+
+ # Define high cardinality threshold
+ high_cardinality_threshold = 10
+
+ # Clear the figure
+ canvas.axes.clear()
+
+ # Create a working copy of the dataframe
+ working_df = self.df.copy()
+
+ # Prepare data for high cardinality columns
+ if not feature_is_numeric and feature_unique_count > high_cardinality_threshold:
+ # Get the top N categories by frequency
+ top_categories = self.df[feature].value_counts().nlargest(high_cardinality_threshold).index.tolist()
+ # Create "Other" category for remaining values
+ working_df[feature] = working_df[feature].apply(lambda x: x if x in top_categories else 'Other')
+
+ if not target_is_numeric and target_unique_count > high_cardinality_threshold:
+ top_categories = self.df[target].value_counts().nlargest(high_cardinality_threshold).index.tolist()
+ working_df[target] = working_df[target].apply(lambda x: x if x in top_categories else 'Other')
+
+ # Create appropriate visualization based on data types and cardinality
+ if feature_is_numeric and target_is_numeric:
+ # Scatter plot for numeric vs numeric
+ # Use hexbin for large datasets to avoid overplotting
+ if len(working_df) > 100:
+ canvas.axes.hexbin(
+ working_df[feature],
+ working_df[target],
+ gridsize=25,
+ cmap='Blues',
+ mincnt=1
+ )
+ canvas.axes.set_title(f"Hexbin Density Plot: {feature} vs {target}")
+ canvas.axes.set_xlabel(feature)
+ canvas.axes.set_ylabel(target)
+ # Add a colorbar
+ cbar = canvas.figure.colorbar(canvas.axes.collections[0], ax=canvas.axes)
+ cbar.set_label('Count')
+ else:
+ # For smaller datasets, use a scatter plot with transparency
+ sns.scatterplot(
+ x=feature,
+ y=target,
+ data=working_df,
+ ax=canvas.axes,
+ alpha=0.6
+ )
+ # Add regression line
+ sns.regplot(
+ x=feature,
+ y=target,
+ data=working_df,
+ ax=canvas.axes,
+ scatter=False,
+ line_kws={"color": "red"}
+ )
+ canvas.axes.set_title(f"Scatter Plot: {feature} vs {target}")
+
+ elif feature_is_numeric and not target_is_numeric:
+ # Box plot for numeric vs categorical
+ if target_unique_count <= high_cardinality_threshold * 2:
+ # Standard boxplot for reasonable number of categories
+ order = working_df[target].value_counts().nlargest(high_cardinality_threshold * 2).index
+ sns.boxplot(
+ x=target,
+ y=feature,
+ data=working_df,
+ ax=canvas.axes,
+ order=order
+ )
+ canvas.axes.set_title(f"Box Plot: {feature} by {target}")
+ # Rotate x-axis labels for better readability
+ canvas.axes.set_xticklabels(
+ canvas.axes.get_xticklabels(),
+ rotation=45,
+ ha='right'
+ )
+ else:
+ # For very high cardinality, use a violin plot with limited categories
+ order = working_df[target].value_counts().nlargest(high_cardinality_threshold).index
+ working_df_filtered = working_df[working_df[target].isin(order)]
+ sns.violinplot(
+ x=target,
+ y=feature,
+ data=working_df_filtered,
+ ax=canvas.axes,
+ inner='quartile',
+ cut=0
+ )
+ canvas.axes.set_title(f"Violin Plot: {feature} by Top {len(order)} {target} Categories")
+ canvas.axes.set_xticklabels(
+ canvas.axes.get_xticklabels(),
+ rotation=45,
+ ha='right'
+ )
+
+ elif not feature_is_numeric and target_is_numeric:
+ # Bar plot for categorical vs numeric
+ if feature_unique_count <= high_cardinality_threshold * 2:
+ # Use standard barplot for reasonable number of categories
+ order = working_df[feature].value_counts().nlargest(high_cardinality_threshold * 2).index
+ sns.barplot(
+ x=feature,
+ y=target,
+ data=working_df,
+ ax=canvas.axes,
+ order=order,
+ estimator=np.mean,
+ errorbar=('ci', 95),
+ capsize=0.2
+ )
+ canvas.axes.set_title(f"Bar Plot: Average {target} by {feature}")
+
+ # Add value labels on top of bars
+ for p in canvas.axes.patches:
+ canvas.axes.annotate(
+ f'{p.get_height():.1f}',
+ (p.get_x() + p.get_width() / 2., p.get_height()),
+ ha='center',
+ va='bottom',
+ fontsize=8,
+ rotation=0
+ )
+
+ # Rotate x-axis labels if needed
+ if feature_unique_count > 5:
+ canvas.axes.set_xticklabels(
+ canvas.axes.get_xticklabels(),
+ rotation=45,
+ ha='right'
+ )
+ else:
+ # For high cardinality, use a horizontal bar plot with top N categories
+ top_n = 15 # Show top 15 categories
+ # Calculate mean of target for each feature category
+ grouped = working_df.groupby(feature)[target].agg(['mean', 'count', 'std']).reset_index()
+ # Sort by mean and take top categories
+ top_groups = grouped.nlargest(top_n, 'mean')
+
+ # Sort by mean value for better visualization
+ sns.barplot(
+ y=feature,
+ x='mean',
+ data=top_groups,
+ ax=canvas.axes,
+ orient='h'
+ )
+ canvas.axes.set_title(f"Top {top_n} Categories by Average {target}")
+ canvas.axes.set_xlabel(f"Average {target}")
+
+ # Add count annotations
+ for i, row in enumerate(top_groups.itertuples()):
+ canvas.axes.text(
+ row.mean + 0.1,
+ i,
+ f'n={row.count}',
+ va='center',
+ fontsize=8
+ )
+
+ else:
+ # Both feature and target are categorical
+ if feature_unique_count <= high_cardinality_threshold and target_unique_count <= high_cardinality_threshold:
+ # Heatmap for categorical vs categorical with manageable cardinality
+ crosstab = pd.crosstab(
+ working_df[feature],
+ working_df[target],
+ normalize='index'
+ )
+
+ # Create heatmap with improved readability
+ sns.heatmap(
+ crosstab,
+ annot=True,
+ cmap="YlGnBu",
+ ax=canvas.axes,
+ fmt='.2f',
+ linewidths=0.5,
+ annot_kws={"size": 9 if crosstab.size < 30 else 7}
+ )
+ canvas.axes.set_title(f"Heatmap: {feature} vs {target} (proportions)")
+ else:
+ # For high cardinality in both, show a count plot of top categories
+ feature_top = working_df[feature].value_counts().nlargest(8).index
+ target_top = working_df[target].value_counts().nlargest(5).index
+
+ # Filter data to only include top categories
+ filtered_df = working_df[
+ working_df[feature].isin(feature_top) &
+ working_df[target].isin(target_top)
+ ]
+
+ # Create a grouped count plot
+ sns.countplot(
+ x=feature,
+ hue=target,
+ data=filtered_df,
+ ax=canvas.axes
+ )
+ canvas.axes.set_title(f"Count Plot: Top {len(feature_top)} {feature} by Top {len(target_top)} {target}")
+
+ # Rotate x-axis labels
+ canvas.axes.set_xticklabels(
+ canvas.axes.get_xticklabels(),
+ rotation=45,
+ ha='right'
+ )
+
+ # Move legend to a better position
+ canvas.axes.legend(title=target, bbox_to_anchor=(1.05, 1), loc='upper left')
+
+ # Add informational text about data reduction if applicable
+ if (not feature_is_numeric and feature_unique_count > high_cardinality_threshold) or \
+ (not target_is_numeric and target_unique_count > high_cardinality_threshold):
+ canvas.figure.text(
+ 0.5, 0.01,
+ f"Note: Visualization simplified to show top categories only. Original data has {feature_unique_count} unique {feature} values and {target_unique_count} unique {target} values.",
+ ha='center',
+ fontsize=8,
+ style='italic'
+ )
+
+ # Adjust layout and draw
+ canvas.figure.tight_layout()
+ canvas.draw()
+
+ # Add a close button
+ close_button = QPushButton("Close")
+ close_button.clicked.connect(dialog.accept)
+ layout.addWidget(close_button)
+
+ # Show the dialog
+ dialog.exec()
+
+ def change_sort(self, sort_key):
+ """Change the sort order of the results"""
+ if self.importance_df is None:
+ return
+
+ # Update button states
+ if sort_key == 'importance_value':
+ self.importance_sort_btn.setChecked(True)
+ self.correlation_sort_btn.setChecked(False)
+ else:
+ self.importance_sort_btn.setChecked(False)
+ self.correlation_sort_btn.setChecked(True)
+
+ # Store the current sort key
+ self.current_sort = sort_key
+
+ # Re-sort the dataframe
+ self.importance_df = self.importance_df.sort_values(by=sort_key, ascending=False)
+
+ # Reset rendering of the table
+ self.importance_table.clearContents()
+ self.importance_table.setRowCount(len(self.importance_df))
+ self.current_row = 0
+
+ # Start incremental rendering with the new sort order
+ if self.render_timer and self.render_timer.isActive():
+ self.render_timer.stop()
+ self.render_timer = QTimer()
+ self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
+ self.render_timer.start(10) # Update every 10ms

  # Main application class
  class ColumnProfilerApp(QMainWindow):
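The relationship dialog's branching reduces to a dtype-pair dispatch over seaborn/matplotlib plot types. A condensed sketch of that dispatch, without the high-cardinality bucketing the method adds on top (the function name and `big` threshold are illustrative):

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def plot_relationship(df, feature, target, ax, big=100):
    """Pick a plot type from the dtype pair, as the dialog above does."""
    f_num = pd.api.types.is_numeric_dtype(df[feature])
    t_num = pd.api.types.is_numeric_dtype(df[target])
    if f_num and t_num:
        if len(df) > big:  # hexbin avoids overplotting on large data
            ax.hexbin(df[feature], df[target], gridsize=25, cmap="Blues", mincnt=1)
        else:
            sns.regplot(x=feature, y=target, data=df, ax=ax)
    elif f_num:            # numeric feature, categorical target
        sns.boxplot(x=target, y=feature, data=df, ax=ax)
    elif t_num:            # categorical feature, numeric target
        sns.barplot(x=feature, y=target, data=df, ax=ax)
    else:                  # both categorical: row-normalized proportions
        sns.heatmap(pd.crosstab(df[feature], df[target], normalize="index"),
                    annot=True, fmt=".2f", ax=ax)

# usage: fig, ax = plt.subplots(); plot_relationship(df, "age", "churned", ax)
```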
@@ -576,6 +1401,9 @@ class ColumnProfilerApp(QMainWindow):
  self.current_row = 0
  self.render_timer = None

+ # Current sort key
+ self.current_sort = 'importance_value'
+
  # Set window properties
  self.setWindowTitle("Column Profiler")
  self.setMinimumSize(900, 600)
@@ -618,13 +1446,38 @@ class ColumnProfilerApp(QMainWindow):
  # Add control panel to main layout
  main_layout.addWidget(control_panel)

+ # Add sorting control
+ sort_panel = QWidget()
+ sort_layout = QHBoxLayout(sort_panel)
+ sort_layout.setContentsMargins(0, 0, 0, 0)
+
+ # Add sort label
+ sort_layout.addWidget(QLabel("Sort by:"))
+
+ # Add sort buttons
+ self.importance_sort_btn = QPushButton("Importance")
+ self.importance_sort_btn.setCheckable(True)
+ self.importance_sort_btn.setChecked(True) # Default sort
+ self.importance_sort_btn.clicked.connect(lambda: self.change_sort('importance_value'))
+
+ self.correlation_sort_btn = QPushButton("Correlation")
+ self.correlation_sort_btn.setCheckable(True)
+ self.correlation_sort_btn.clicked.connect(lambda: self.change_sort('correlation'))
+
+ sort_layout.addWidget(self.importance_sort_btn)
+ sort_layout.addWidget(self.correlation_sort_btn)
+ sort_layout.addStretch()
+
+ # Add buttons to layout
+ main_layout.addWidget(sort_panel)
+
  # Add a splitter for results area
  results_splitter = QSplitter(Qt.Orientation.Vertical)

  # Create table for showing importance values
  self.importance_table = QTableWidget()
- self.importance_table.setColumnCount(2)
- self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance"])
+ self.importance_table.setColumnCount(3)
+ self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance", "Abs. Correlation"])
  self.importance_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
  self.importance_table.cellDoubleClicked.connect(self.show_relationship_visualization)
  results_splitter.addWidget(self.importance_table)
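The two checkable buttons are kept mutually exclusive by hand in `change_sort`; Qt's `QButtonGroup` can enforce the same exclusivity directly, which is one possible simplification. A sketch of that alternative wiring:

```python
import sys
from PyQt6.QtWidgets import QApplication, QButtonGroup, QPushButton

app = QApplication(sys.argv)

importance_btn = QPushButton("Importance")
correlation_btn = QPushButton("Correlation")
for btn in (importance_btn, correlation_btn):
    btn.setCheckable(True)

group = QButtonGroup()           # exclusive by default: one checked at a time
group.addButton(importance_btn)
group.addButton(correlation_btn)
importance_btn.setChecked(True)  # default sort

importance_btn.clicked.connect(lambda: print("sort by importance_value"))
correlation_btn.clicked.connect(lambda: print("sort by correlation"))
```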
@@ -647,7 +1500,7 @@ class ColumnProfilerApp(QMainWindow):

  # Set the central widget
  self.setCentralWidget(central_widget)
-
+
  def analyze_column(self):
  if self.df is None or self.column_selector.currentText() == "":
  return
@@ -732,7 +1585,9 @@ class ColumnProfilerApp(QMainWindow):
  self.progress_label.hide()
  self.cancel_button.hide()

- # Update importance table incrementally
+ # Update importance table to include correlation column
+ self.importance_table.setColumnCount(3)
+ self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance", "Abs. Correlation"])
  self.importance_table.setRowCount(len(importance_df))

  # Using a timer for incremental updates
@@ -741,53 +1596,276 @@ class ColumnProfilerApp(QMainWindow):
  self.render_timer = QTimer()
  self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
  self.render_timer.start(10) # Update every 10ms
-
+
  def render_next_batch(self, batch_size):
- if self.current_row >= len(self.importance_df):
- # All rows rendered, now render the chart and stop the timer
- self.render_chart()
- self.render_timer.stop()
- return
+ try:
+ if self.current_row >= len(self.importance_df):
+ # All rows rendered, now render the chart and stop the timer
+ self.render_chart()
+ self.render_timer.stop()
+ return

- # Render a batch of rows
- end_row = min(self.current_row + batch_size, len(self.importance_df))
- for row in range(self.current_row, end_row):
- feature = self.importance_df.iloc[row]['feature']
- mean_abs_shap = self.importance_df.iloc[row]['mean_abs_shap']
- self.importance_table.setItem(row, 0, QTableWidgetItem(feature))
- self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(mean_abs_shap, 4))))
+ # Render a batch of rows
+ end_row = min(self.current_row + batch_size, len(self.importance_df))
+ for row in range(self.current_row, end_row):
+ try:
+ # Check if row exists in dataframe to prevent index errors
+ if row < len(self.importance_df):
+ feature = self.importance_df.iloc[row]['feature']
+ importance_value = self.importance_df.iloc[row]['importance_value']
+
+ # Add correlation if available
+ correlation = self.importance_df.iloc[row].get('correlation', None)
+ if correlation is not None:
+ self.importance_table.setItem(row, 0, QTableWidgetItem(str(feature)))
+ self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(importance_value, 4))))
+ self.importance_table.setItem(row, 2, QTableWidgetItem(str(round(correlation, 4))))
+ else:
+ self.importance_table.setItem(row, 0, QTableWidgetItem(str(feature)))
+ self.importance_table.setItem(row, 1, QTableWidgetItem(str(round(importance_value, 4))))
+ else:
+ # Handle out of range index
+ print(f"Warning: Row {row} is out of range (max: {len(self.importance_df)-1})")
+ self.importance_table.setItem(row, 0, QTableWidgetItem("Error"))
+ self.importance_table.setItem(row, 1, QTableWidgetItem("Out of range"))
+ self.importance_table.setItem(row, 2, QTableWidgetItem("N/A"))
+ except (IndexError, KeyError) as e:
+ # Enhanced error reporting for index and key errors
+ import traceback
+ trace = traceback.format_exc()
+ error_msg = f"Error rendering row {row}: {e.__class__.__name__}: {e}\n{trace}"
+ print(error_msg)
+
+ # Handle missing data in the dataframe gracefully
+ self.importance_table.setItem(row, 0, QTableWidgetItem(f"Error: {e.__class__.__name__}"))
+ self.importance_table.setItem(row, 1, QTableWidgetItem(f"{str(e)[:20]}"))
+ self.importance_table.setItem(row, 2, QTableWidgetItem("Error"))
+ except Exception as e:
+ # Catch any other exceptions
+ print(f"Unexpected error rendering row {row}: {e.__class__.__name__}: {e}")
+ self.importance_table.setItem(row, 0, QTableWidgetItem(f"Error: {e.__class__.__name__}"))
+ self.importance_table.setItem(row, 1, QTableWidgetItem("See console for details"))
+ self.importance_table.setItem(row, 2, QTableWidgetItem("Error"))
+
+ self.current_row = end_row
+ QApplication.processEvents() # Allow UI to update
+ except Exception as e:
+ # Catch any exceptions in the rendering loop itself
+ import traceback
+ trace = traceback.format_exc()
+ error_msg = f"Error in render_next_batch: {e.__class__.__name__}: {e}\n{trace}"
+ print(error_msg)
+
+ # Try to stop the timer to prevent further errors
+ try:
+ if self.render_timer and self.render_timer.isActive():
+ self.render_timer.stop()
+ except:
+ pass

- self.current_row = end_row
- QApplication.processEvents() # Allow UI to update
+ # Show error
+ QMessageBox.critical(self, "Rendering Error",
+ f"Error rendering results: {e.__class__.__name__}: {e}")

  def render_chart(self):
  # Create horizontal bar chart
- self.chart_view.axes.clear()
-
- # Limit to top 20 features for better visualization
- plot_df = self.importance_df.head(20)
-
- # Plot with custom colors
- bars = self.chart_view.axes.barh(
- plot_df['feature'],
- plot_df['mean_abs_shap'],
- color='skyblue'
- )
-
- # Add values at the end of bars
- for bar in bars:
- width = bar.get_width()
- self.chart_view.axes.text(
- width * 1.05,
- bar.get_y() + bar.get_height()/2,
- f'{width:.2f}',
- va='center'
+ try:
+ if self.importance_df is None or len(self.importance_df) == 0:
+ # No data to render
+ self.chart_view.axes.clear()
+ self.chart_view.axes.text(0.5, 0.5, "No data available for chart",
+ ha='center', va='center', fontsize=12, color='gray')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ return
+
+ self.chart_view.axes.clear()
+
+ # Get a sorted copy based on current sort key
+ plot_df = self.importance_df.sort_values(by=self.current_sort, ascending=False).head(20).copy()
+
+ # Verify we have data before proceeding
+ if len(plot_df) == 0:
+ self.chart_view.axes.text(0.5, 0.5, "No features found with importance values",
+ ha='center', va='center', fontsize=12, color='gray')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ return
+
+ # Check required columns exist
+ required_columns = ['feature', 'importance_value']
+ missing_columns = [col for col in required_columns if col not in plot_df.columns]
+ if missing_columns:
+ error_msg = f"Missing required columns: {', '.join(missing_columns)}"
+ self.chart_view.axes.text(0.5, 0.5, error_msg,
+ ha='center', va='center', fontsize=12, color='red')
+ self.chart_view.axes.set_axis_off()
+ self.chart_view.draw()
+ print(f"Chart rendering error: {error_msg}")
+ return
+
+ # Truncate long feature names for better display
+ max_feature_length = 30
+ plot_df['display_feature'] = plot_df['feature'].apply(
+ lambda x: (str(x)[:max_feature_length] + '...') if len(str(x)) > max_feature_length else str(x)
  )

- self.chart_view.axes.set_title(f'Feature Importance for Predicting {self.column_selector.currentText()}')
- self.chart_view.axes.set_xlabel('Mean Absolute SHAP Value')
- self.chart_view.figure.tight_layout()
- self.chart_view.draw()
+ # Reverse order for better display (highest at top)
+ plot_df = plot_df.iloc[::-1].reset_index(drop=True)
+
+ # Create a figure with two subplots side by side
+ self.chart_view.figure.clear()
+ gs = self.chart_view.figure.add_gridspec(1, 2, width_ratios=[3, 2])
+
+ # First subplot for importance
+ ax1 = self.chart_view.figure.add_subplot(gs[0, 0])
+
+ # Create a colormap for better visualization
+ cmap = plt.cm.Blues
+ colors = cmap(np.linspace(0.4, 0.8, len(plot_df)))
+
+ # Plot with custom colors
+ bars = ax1.barh(
+ plot_df['display_feature'],
+ plot_df['importance_value'],
+ color=colors,
+ height=0.7, # Thinner bars for more spacing
+ alpha=0.8
+ )
+
+ # Add values at the end of bars
+ for bar in bars:
+ width = bar.get_width()
+ ax1.text(
+ width * 1.05,
+ bar.get_y() + bar.get_height()/2,
+ f'{width:.2f}',
+ va='center',
+ fontsize=9,
+ fontweight='bold'
+ )
+
+ # Add grid for better readability
+ ax1.grid(True, axis='x', linestyle='--', alpha=0.3)
+
+ # Remove unnecessary spines
+ for spine in ['top', 'right']:
+ ax1.spines[spine].set_visible(False)
+
+ # Make labels more readable
+ ax1.tick_params(axis='y', labelsize=9)
+
+ # Set title and labels
+ ax1.set_title(f'Feature Importance for {self.column_selector.currentText()}')
+ ax1.set_xlabel('Importance Value')
+
+ # Add a note about the sorting order
+ sort_label = "Sorted by: " + ("Importance" if self.current_sort == 'importance_value' else "Correlation")
+
+ # Second subplot for correlation if available
+ if 'correlation' in plot_df.columns:
+ ax2 = self.chart_view.figure.add_subplot(gs[0, 1], sharey=ax1)
+
+ # Create a colormap for correlation - use a different color
+ cmap_corr = plt.cm.Reds
+ colors_corr = cmap_corr(np.linspace(0.4, 0.8, len(plot_df)))
+
+ # Plot correlation bars
+ corr_bars = ax2.barh(
+ plot_df['display_feature'],
+ plot_df['correlation'],
+ color=colors_corr,
+ height=0.7,
+ alpha=0.8
+ )
+
+ # Add values at the end of correlation bars
+ for bar in corr_bars:
+ width = bar.get_width()
+ ax2.text(
+ width * 1.05,
+ bar.get_y() + bar.get_height()/2,
+ f'{width:.2f}',
+ va='center',
+ fontsize=9,
+ fontweight='bold'
+ )
+
+ # Add grid and styling
1795
+ ax2.grid(True, axis='x', linestyle='--', alpha=0.3)
1796
+ ax2.set_title('Absolute Correlation')
1797
+ ax2.set_xlabel('Correlation Value')
1798
+
1799
+ # Hide y-axis labels since they're shared with the first plot
1800
+ ax2.tick_params(labelleft=False) # set_yticklabels([]) would blank both shared panels
1801
+
1802
+ # Remove unnecessary spines
1803
+ for spine in ['top', 'right']:
1804
+ ax2.spines[spine].set_visible(False)
1805
+
1806
+ # Add a note about the current sort order
1807
+ self.chart_view.figure.text(0.5, 0.01, sort_label, ha='center', fontsize=9, style='italic')
1808
+
1809
+ # Adjust figure size based on number of features
1810
+ feature_count = len(plot_df)
1811
+ self.chart_view.figure.set_figheight(max(5, min(4 + feature_count * 0.3, 12)))
1812
+
1813
+ # Adjust layout and draw
1814
+ self.chart_view.figure.tight_layout(rect=[0, 0.03, 1, 0.97]) # Make room for sort label
1815
+ self.chart_view.draw()
1816
+
1817
+ except IndexError as e:
1818
+ # Special handling for index errors with detailed information
1819
+ import traceback
1820
+ import inspect
1821
+
1822
+ # Get stack trace information
1823
+ trace = traceback.format_exc()
1824
+
1825
+ # Try to get line and context information
1826
+ try:
1827
+ frame = inspect.trace()[-1]
1828
+ frame_info = inspect.getframeinfo(frame[0])
1829
+ filename = frame_info.filename
1830
+ lineno = frame_info.lineno
1831
+ function = frame_info.function
1832
+ code_context = frame_info.code_context[0].strip() if frame_info.code_context else "Unknown code context"
1833
+
1834
+ # Detailed error message
1835
+ detail_msg = f"IndexError at line {lineno} in {function}: {str(e)}\nCode: {code_context}"
1836
+ print(f"Chart rendering error: {detail_msg}\n{trace}")
1837
+
1838
+ # Display error in chart
1839
+ # Re-create the axes: the figure may have been cleared mid-render,
+ # detaching the original axes
+ self.chart_view.figure.clear()
+ self.chart_view.axes = self.chart_view.figure.add_subplot(111)
1840
+ self.chart_view.axes.text(0.5, 0.5,
1841
+ f"Index Error in chart rendering:\n{str(e)}\nAt line {lineno}: {code_context}",
1842
+ ha='center', va='center', fontsize=12, color='red',
1843
+ wrap=True)
1844
+ self.chart_view.axes.set_axis_off()
1845
+ self.chart_view.draw()
1846
+ except Exception as inner_e:
1847
+ # Fallback if the detailed error reporting fails
1848
+ print(f"Error getting detailed error info: {inner_e}")
1849
+ print(f"Original error: {e}\n{trace}")
1850
+
1851
+ self.chart_view.figure.clear()
+ self.chart_view.axes = self.chart_view.figure.add_subplot(111)
1852
+ self.chart_view.axes.text(0.5, 0.5, f"Index Error: {str(e)}",
1853
+ ha='center', va='center', fontsize=12, color='red')
1854
+ self.chart_view.axes.set_axis_off()
1855
+ self.chart_view.draw()
1856
+ except Exception as e:
1857
+ # Recover gracefully from any chart rendering errors with detailed information
1858
+ import traceback
1859
+ trace = traceback.format_exc()
1860
+ error_msg = f"Error rendering chart: {e.__class__.__name__}: {str(e)}"
1861
+ print(f"{error_msg}\n{trace}")
1862
+
1863
+ # Re-create the axes in case the figure was cleared mid-render
+ self.chart_view.figure.clear()
+ self.chart_view.axes = self.chart_view.figure.add_subplot(111)
1864
+ self.chart_view.axes.text(0.5, 0.5, error_msg,
1865
+ ha='center', va='center', fontsize=12, color='red',
1866
+ wrap=True)
1867
+ self.chart_view.axes.set_axis_off()
1868
+ self.chart_view.draw()
791
1869
 
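# [Editor's sketch, not part of the package] The two-panel, shared-y bar
# layout used by render_chart above, reduced to a standalone demo; all data
# and names below are assumptions made for the example:
import numpy as np
import matplotlib.pyplot as plt

features = ['feature_a', 'feature_b', 'feature_c']
importance = np.array([0.9, 0.5, 0.2])
correlation = np.array([0.7, 0.4, 0.1])

fig = plt.figure(figsize=(8, 4))
gs = fig.add_gridspec(1, 2, width_ratios=[3, 2])
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1], sharey=ax1)  # panels share the feature axis
ax1.barh(features, importance, color=plt.cm.Blues(np.linspace(0.4, 0.8, len(features))))
ax2.barh(features, correlation, color=plt.cm.Reds(np.linspace(0.4, 0.8, len(features))))
ax2.tick_params(labelleft=False)  # feature labels already shown on the left
fig.tight_layout()
plt.show()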
792
1870
  def handle_error(self, error_message):
793
1871
  """Handle errors during analysis"""
@@ -802,23 +1880,32 @@ class ColumnProfilerApp(QMainWindow):
802
1880
  # Print error to console for debugging
803
1881
  print(f"Error in column profiler: {error_message}")
804
1882
 
805
- # Show error message
806
- QMessageBox.critical(self, "Error", f"An error occurred during analysis:\n\n{error_message}")
1883
+ # Show error message with more details
1884
+ msg_box = QMessageBox(self)
1885
+ msg_box.setIcon(QMessageBox.Icon.Critical)
1886
+ msg_box.setWindowTitle("Error")
1887
+ msg_box.setText("An error occurred during analysis")
1888
+ msg_box.setDetailedText(error_message)
1889
+ msg_box.setStandardButtons(QMessageBox.StandardButton.Ok)
1890
+ msg_box.exec()
807
1891
 
808
1892
  # Show a message in the UI as well
809
1893
  self.importance_table.setRowCount(1)
810
- self.importance_table.setColumnCount(1)
811
- self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message}"))
1894
+ self.importance_table.setColumnCount(3)
1895
+ self.importance_table.setHorizontalHeaderLabels(["Feature", "Importance", "Abs. Correlation"])
1896
+ self.importance_table.setItem(0, 0, QTableWidgetItem(f"Error: {error_message.split('\n')[0]}"))
1897
+ self.importance_table.setItem(0, 1, QTableWidgetItem(""))
1898
+ self.importance_table.setItem(0, 2, QTableWidgetItem(""))
812
1899
  self.importance_table.resizeColumnsToContents()
813
1900
 
814
1901
  # Update the chart to show error
815
1902
  self.chart_view.axes.clear()
816
- self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message}",
1903
+ self.chart_view.axes.text(0.5, 0.5, f"Error calculating importance:\n{error_message.split('\n')[0]}",
817
1904
  ha='center', va='center', fontsize=12, color='red',
818
1905
  wrap=True)
819
1906
  self.chart_view.axes.set_axis_off()
820
1907
  self.chart_view.draw()
821
-
1908
+
822
1909
  def closeEvent(self, event):
823
1910
  """Clean up when the window is closed"""
824
1911
  # Stop any running timer
@@ -886,17 +1973,32 @@ class ColumnProfilerApp(QMainWindow):
886
1973
 
887
1974
  def show_relationship_visualization(self, row, column):
888
1975
  """Show visualization of relationship between selected feature and target column"""
889
- if self.importance_df is None or row >= len(self.importance_df):
1976
+ if self.importance_df is None or row < 0 or row >= len(self.importance_df):
890
1977
  return
891
1978
 
892
1979
  # Get the feature name and target column
893
- feature = self.importance_df.iloc[row]['feature']
894
- target = self.column_selector.currentText()
1980
+ try:
1981
+ feature = self.importance_df.iloc[row]['feature']
1982
+ target = self.column_selector.currentText()
1983
+
1984
+ # Verify both columns exist in the dataframe
1985
+ if feature not in self.df.columns:
1986
+ QMessageBox.warning(self, "Column Not Found",
1987
+ f"Feature column '{feature}' not found in the dataframe")
1988
+ return
1989
+
1990
+ if target not in self.df.columns:
1991
+ QMessageBox.warning(self, "Column Not Found",
1992
+ f"Target column '{target}' not found in the dataframe")
1993
+ return
1994
+ except Exception as e:
1995
+ QMessageBox.critical(self, "Error", f"Error getting column data: {str(e)}")
1996
+ return
895
1997
 
896
1998
  # Create a dialog to show the visualization
897
1999
  dialog = QDialog(self)
898
2000
  dialog.setWindowTitle(f"Relationship: {feature} vs {target}")
899
- dialog.resize(800, 600)
2001
+ dialog.resize(900, 700)
900
2002
 
901
2003
  # Create layout
902
2004
  layout = QVBoxLayout(dialog)
@@ -909,37 +2011,231 @@ class ColumnProfilerApp(QMainWindow):
909
2011
  feature_is_numeric = pd.api.types.is_numeric_dtype(self.df[feature])
910
2012
  target_is_numeric = pd.api.types.is_numeric_dtype(self.df[target])
911
2013
 
2014
+ # Get unique counts to determine if we have high cardinality
2015
+ feature_unique_count = self.df[feature].nunique()
2016
+ target_unique_count = self.df[target].nunique()
2017
+
2018
+ # Define high cardinality threshold
2019
+ high_cardinality_threshold = 10
2020
+
912
2021
  # Clear the figure
913
2022
  canvas.axes.clear()
914
2023
 
915
- # Create appropriate visualization based on data types
2024
+ # Create a working copy of the dataframe
2025
+ working_df = self.df.copy()
2026
+
2027
+ # Prepare data for high cardinality columns
2028
+ if not feature_is_numeric and feature_unique_count > high_cardinality_threshold:
2029
+ # Get the top N categories by frequency
2030
+ top_categories = self.df[feature].value_counts().nlargest(high_cardinality_threshold).index.tolist()
2031
+ # Create "Other" category for remaining values
2032
+ working_df[feature] = working_df[feature].apply(lambda x: x if x in top_categories else 'Other')
2033
+
2034
+ if not target_is_numeric and target_unique_count > high_cardinality_threshold:
2035
+ top_categories = self.df[target].value_counts().nlargest(high_cardinality_threshold).index.tolist()
2036
+ working_df[target] = working_df[target].apply(lambda x: x if x in top_categories else 'Other')
2037
+
2038
+ # Create appropriate visualization based on data types and cardinality
916
2039
  if feature_is_numeric and target_is_numeric:
917
2040
  # Scatter plot for numeric vs numeric
918
- sns.scatterplot(x=feature, y=target, data=self.df, ax=canvas.axes)
919
- # Add regression line
920
- sns.regplot(x=feature, y=target, data=self.df, ax=canvas.axes,
921
- scatter=False, line_kws={"color": "red"})
922
- canvas.axes.set_title(f"Scatter Plot: {feature} vs {target}")
2041
+ # Use hexbin for large datasets to avoid overplotting
2042
+ if len(working_df) > 100:
2043
+ canvas.axes.hexbin(
2044
+ working_df[feature],
2045
+ working_df[target],
2046
+ gridsize=25,
2047
+ cmap='Blues',
2048
+ mincnt=1
2049
+ )
2050
+ canvas.axes.set_title(f"Hexbin Density Plot: {feature} vs {target}")
2051
+ canvas.axes.set_xlabel(feature)
2052
+ canvas.axes.set_ylabel(target)
2053
+ # Add a colorbar
2054
+ cbar = canvas.figure.colorbar(canvas.axes.collections[0], ax=canvas.axes)
2055
+ cbar.set_label('Count')
2056
+ else:
2057
+ # For smaller datasets, use a scatter plot with transparency
2058
+ sns.scatterplot(
2059
+ x=feature,
2060
+ y=target,
2061
+ data=working_df,
2062
+ ax=canvas.axes,
2063
+ alpha=0.6
2064
+ )
2065
+ # Add regression line
2066
+ sns.regplot(
2067
+ x=feature,
2068
+ y=target,
2069
+ data=working_df,
2070
+ ax=canvas.axes,
2071
+ scatter=False,
2072
+ line_kws={"color": "red"}
2073
+ )
2074
+ canvas.axes.set_title(f"Scatter Plot: {feature} vs {target}")
923
2075
 
924
2076
  elif feature_is_numeric and not target_is_numeric:
925
2077
  # Box plot for numeric vs categorical
926
- sns.boxplot(x=target, y=feature, data=self.df, ax=canvas.axes)
927
- canvas.axes.set_title(f"Box Plot: {feature} by {target}")
2078
+ if target_unique_count <= high_cardinality_threshold * 2:
2079
+ # Standard boxplot for reasonable number of categories
2080
+ order = working_df[target].value_counts().nlargest(high_cardinality_threshold * 2).index
2081
+ sns.boxplot(
2082
+ x=target,
2083
+ y=feature,
2084
+ data=working_df,
2085
+ ax=canvas.axes,
2086
+ order=order
2087
+ )
2088
+ canvas.axes.set_title(f"Box Plot: {feature} by {target}")
2089
+ # Rotate x-axis labels for better readability
2090
+ canvas.axes.set_xticklabels(
2091
+ canvas.axes.get_xticklabels(),
2092
+ rotation=45,
2093
+ ha='right'
2094
+ )
2095
+ else:
2096
+ # For very high cardinality, use a violin plot with limited categories
2097
+ order = working_df[target].value_counts().nlargest(high_cardinality_threshold).index
2098
+ working_df_filtered = working_df[working_df[target].isin(order)]
2099
+ sns.violinplot(
2100
+ x=target,
2101
+ y=feature,
2102
+ data=working_df_filtered,
2103
+ ax=canvas.axes,
2104
+ inner='quartile',
2105
+ cut=0
2106
+ )
2107
+ canvas.axes.set_title(f"Violin Plot: {feature} by Top {len(order)} {target} Categories")
2108
+ canvas.axes.set_xticklabels(
2109
+ canvas.axes.get_xticklabels(),
2110
+ rotation=45,
2111
+ ha='right'
2112
+ )
928
2113
 
929
2114
  elif not feature_is_numeric and target_is_numeric:
930
2115
  # Bar plot for categorical vs numeric
931
- sns.barplot(x=feature, y=target, data=self.df, ax=canvas.axes)
932
- canvas.axes.set_title(f"Bar Plot: Average {target} by {feature}")
933
- # Rotate x-axis labels if there are many categories
934
- if self.df[feature].nunique() > 5:
935
- canvas.axes.set_xticklabels(canvas.axes.get_xticklabels(), rotation=45, ha='right')
2116
+ if feature_unique_count <= high_cardinality_threshold * 2:
2117
+ # Use standard barplot for reasonable number of categories
2118
+ order = working_df[feature].value_counts().nlargest(high_cardinality_threshold * 2).index
2119
+ sns.barplot(
2120
+ x=feature,
2121
+ y=target,
2122
+ data=working_df,
2123
+ ax=canvas.axes,
2124
+ order=order,
2125
+ estimator=np.mean,
2126
+ errorbar=('ci', 95),
2127
+ capsize=0.2
2128
+ )
2129
+ canvas.axes.set_title(f"Bar Plot: Average {target} by {feature}")
2130
+
2131
+ # Add value labels on top of bars
2132
+ for p in canvas.axes.patches:
2133
+ canvas.axes.annotate(
2134
+ f'{p.get_height():.1f}',
2135
+ (p.get_x() + p.get_width() / 2., p.get_height()),
2136
+ ha='center',
2137
+ va='bottom',
2138
+ fontsize=8,
2139
+ rotation=0
2140
+ )
2141
+
2142
+ # Rotate x-axis labels if needed
2143
+ if feature_unique_count > 5:
2144
+ canvas.axes.set_xticklabels(
2145
+ canvas.axes.get_xticklabels(),
2146
+ rotation=45,
2147
+ ha='right'
2148
+ )
2149
+ else:
2150
+ # For high cardinality, use a horizontal bar plot with top N categories
2151
+ top_n = 15 # Show top 15 categories
2152
+ # Calculate mean of target for each feature category
2153
+ grouped = working_df.groupby(feature)[target].agg(['mean', 'count', 'std']).reset_index()
2154
+ # Sort by mean and take top categories
2155
+ top_groups = grouped.nlargest(top_n, 'mean')
2156
+
2157
+ # Sort by mean value for better visualization
2158
+ sns.barplot(
2159
+ y=feature,
2160
+ x='mean',
2161
+ data=top_groups,
2162
+ ax=canvas.axes,
2163
+ orient='h'
2164
+ )
2165
+ canvas.axes.set_title(f"Top {top_n} Categories by Average {target}")
2166
+ canvas.axes.set_xlabel(f"Average {target}")
2167
+
2168
+ # Add count annotations
2169
+ for i, row in enumerate(top_groups.itertuples()):
2170
+ canvas.axes.text(
2171
+ row.mean + 0.1,
2172
+ i,
2173
+ f'n={row.count}',
2174
+ va='center',
2175
+ fontsize=8
2176
+ )
936
2177
 
937
2178
  else:
938
- # Heatmap for categorical vs categorical
939
- # Create a crosstab of the two categorical variables
940
- crosstab = pd.crosstab(self.df[feature], self.df[target], normalize='index')
941
- sns.heatmap(crosstab, annot=True, cmap="YlGnBu", ax=canvas.axes)
942
- canvas.axes.set_title(f"Heatmap: {feature} vs {target}")
2179
+ # Both feature and target are categorical
2180
+ if feature_unique_count <= high_cardinality_threshold and target_unique_count <= high_cardinality_threshold:
2181
+ # Heatmap for categorical vs categorical with manageable cardinality
2182
+ crosstab = pd.crosstab(
2183
+ working_df[feature],
2184
+ working_df[target],
2185
+ normalize='index'
2186
+ )
2187
+
2188
+ # Create heatmap with improved readability
2189
+ sns.heatmap(
2190
+ crosstab,
2191
+ annot=True,
2192
+ cmap="YlGnBu",
2193
+ ax=canvas.axes,
2194
+ fmt='.2f',
2195
+ linewidths=0.5,
2196
+ annot_kws={"size": 9 if crosstab.size < 30 else 7}
2197
+ )
2198
+ canvas.axes.set_title(f"Heatmap: {feature} vs {target} (proportions)")
2199
+ else:
2200
+ # For high cardinality in both, show a count plot of top categories
2201
+ feature_top = working_df[feature].value_counts().nlargest(8).index
2202
+ target_top = working_df[target].value_counts().nlargest(5).index
2203
+
2204
+ # Filter data to only include top categories
2205
+ filtered_df = working_df[
2206
+ working_df[feature].isin(feature_top) &
2207
+ working_df[target].isin(target_top)
2208
+ ]
2209
+
2210
+ # Create a grouped count plot
2211
+ sns.countplot(
2212
+ x=feature,
2213
+ hue=target,
2214
+ data=filtered_df,
2215
+ ax=canvas.axes
2216
+ )
2217
+ canvas.axes.set_title(f"Count Plot: Top {len(feature_top)} {feature} by Top {len(target_top)} {target}")
2218
+
2219
+ # Rotate x-axis labels
2220
+ canvas.axes.set_xticklabels(
2221
+ canvas.axes.get_xticklabels(),
2222
+ rotation=45,
2223
+ ha='right'
2224
+ )
2225
+
2226
+ # Move legend to a better position
2227
+ canvas.axes.legend(title=target, bbox_to_anchor=(1.05, 1), loc='upper left')
2228
+
2229
+ # Add informational text about data reduction if applicable
2230
+ if (not feature_is_numeric and feature_unique_count > high_cardinality_threshold) or \
2231
+ (not target_is_numeric and target_unique_count > high_cardinality_threshold):
2232
+ canvas.figure.text(
2233
+ 0.5, 0.01,
2234
+ f"Note: Visualization simplified to show top categories only. Original data has {feature_unique_count} unique {feature} values and {target_unique_count} unique {target} values.",
2235
+ ha='center',
2236
+ fontsize=8,
2237
+ style='italic'
2238
+ )
943
2239
 
944
2240
  # Adjust layout and draw
945
2241
  canvas.figure.tight_layout()
@@ -952,6 +2248,44 @@ class ColumnProfilerApp(QMainWindow):
952
2248
 
953
2249
  # Show the dialog
954
2250
  dialog.exec()
2251
+
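# [Editor's sketch, not part of the package] The overplotting guard used in
# show_relationship_visualization above (hexbin for large numeric/numeric
# samples, scatter otherwise) as a reusable helper; the function name and
# threshold are assumptions:
def plot_xy(ax, x, y, dense_threshold=100):
    """Scatter small samples; switch to a hexbin density plot for large ones."""
    if len(x) > dense_threshold:
        hb = ax.hexbin(x, y, gridsize=25, cmap='Blues', mincnt=1)
        ax.figure.colorbar(hb, ax=ax, label='Count')
    else:
        ax.scatter(x, y, alpha=0.6)
# usage: fig, ax = plt.subplots(); plot_xy(ax, xs, ys)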
2252
+ def change_sort(self, sort_key):
2253
+ """Change the sort order of the results"""
2254
+ if self.importance_df is None:
2255
+ return
2256
+
2257
+ # Update button states
2258
+ if sort_key == 'importance_value':
2259
+ self.importance_sort_btn.setChecked(True)
2260
+ self.correlation_sort_btn.setChecked(False)
2261
+ else:
2262
+ self.importance_sort_btn.setChecked(False)
2263
+ self.correlation_sort_btn.setChecked(True)
2264
+
2265
+ # Store the current sort key
2266
+ self.current_sort = sort_key
2267
+
2268
+ # Re-sort the dataframe
2269
+ self.importance_df = self.importance_df.sort_values(by=sort_key, ascending=False)
2270
+
2271
+ # Reset rendering of the table
2272
+ self.importance_table.clearContents()
2273
+ self.importance_table.setRowCount(len(self.importance_df))
2274
+ self.current_row = 0
2275
+
2276
+ # Start incremental rendering with the new sort order
2277
+ if self.render_timer and self.render_timer.isActive():
2278
+ self.render_timer.stop()
2279
+ self.render_timer = QTimer()
2280
+ self.render_timer.timeout.connect(lambda: self.render_next_batch(10))
2281
+ self.render_timer.start(10) # Update every 10ms
2282
+
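# [Editor's sketch, not part of the package] Presumed shape of the batched
# renderer the timer drives: fill a few rows per tick so the UI stays
# responsive, then stop. This body is an assumption, not a copy:
#     def render_next_batch(self, batch_size):
#         end_row = min(self.current_row + batch_size, len(self.importance_df))
#         for row in range(self.current_row, end_row):
#             ...  # populate one table row
#         self.current_row = end_row
#         if self.current_row >= len(self.importance_df):
#             self.render_timer.stop()  # all rows rendered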
2283
+ # Custom matplotlib canvas for embedding in Qt
2284
+ class MatplotlibCanvas(FigureCanvasQTAgg):
2285
+ def __init__(self, width=5, height=4, dpi=100):
2286
+ self.figure = Figure(figsize=(width, height), dpi=dpi)
2287
+ self.axes = self.figure.add_subplot(111)
2288
+ super().__init__(self.figure)
955
2289
 
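# [Editor's sketch, not part of the package] Minimal usage of the
# MatplotlibCanvas class above; assumes a running QApplication and an
# existing parent widget:
#     canvas = MatplotlibCanvas(width=6, height=4, dpi=100)
#     canvas.axes.plot([0, 1, 2], [2, 1, 3])
#     layout = QVBoxLayout(parent_widget)
#     layout.addWidget(canvas)
#     canvas.draw()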
956
2290
  def visualize_profile(df: pd.DataFrame, column: str = None) -> None:
957
2291
  """
@@ -962,6 +2296,18 @@ def visualize_profile(df: pd.DataFrame, column: str = None) -> None:
962
2296
  column: Optional target column to analyze immediately
963
2297
  """
964
2298
  try:
2299
+ # Verify df is a valid DataFrame
2300
+ if not isinstance(df, pd.DataFrame):
2301
+ raise ValueError("Input must be a pandas DataFrame")
2302
+
2303
+ # Verify df has data
2304
+ if len(df) == 0:
2305
+ raise ValueError("DataFrame is empty, cannot analyze")
2306
+
2307
+ # Verify columns exist
2308
+ if column is not None and column not in df.columns:
2309
+ raise ValueError(f"Column '{column}' not found in the DataFrame")
2310
+
965
2311
  # Check if dataset is too small for meaningful analysis
966
2312
  row_count = len(df)
967
2313
  if row_count <= 5:
@@ -1062,11 +2408,11 @@ def test_profile():
1062
2408
  """
1063
2409
  Test the profile and visualization functions with sample data.
1064
2410
  """
1065
- # Create a sample DataFrame
2411
+ # Create a sample DataFrame with 40+ columns
1066
2412
  np.random.seed(42)
1067
2413
  n = 1000
1068
2414
 
1069
- # Generate sample data with known relationships
2415
+ # Generate core sample data with known relationships
1070
2416
  age = np.random.normal(35, 10, n).astype(int)
1071
2417
  experience = age - np.random.randint(18, 25, n) # experience correlates with age
1072
2418
  experience = np.maximum(0, experience) # no negative experience
@@ -1082,16 +2428,86 @@ def test_profile():
1082
2428
  performance += 0.01 * experience # experience slightly affects performance
1083
2429
  performance = (performance - performance.min()) / (performance.max() - performance.min()) * 5 # scale to 0-5
1084
2430
 
1085
- # Create the DataFrame
1086
- df = pd.DataFrame({
2431
+ # Create the base DataFrame
2432
+ data = {
1087
2433
  'Age': age,
1088
2434
  'Experience': experience,
1089
2435
  'Department': departments,
1090
2436
  'Education': education,
1091
2437
  'Performance': performance,
1092
2438
  'Salary': salary
1093
- })
2439
+ }
2440
+
2441
+ # Generate additional numeric columns
2442
+ for i in range(1, 15):
2443
+ # Create some columns with relationship to salary
2444
+ if i <= 5:
2445
+ data[f'Metric_{i}'] = salary * (0.01 * i) + np.random.normal(0, 5000, n)
2446
+ # Create columns with relationship to age
2447
+ elif i <= 10:
2448
+ data[f'Metric_{i}'] = age * (i-5) + np.random.normal(0, 10, n)
2449
+ # Create random columns
2450
+ else:
2451
+ data[f'Metric_{i}'] = np.random.normal(100, 50, n)
2452
+
2453
+ # Generate additional categorical columns
2454
+ categories = [
2455
+ ['A', 'B', 'C', 'D'],
2456
+ ['Low', 'Medium', 'High'],
2457
+ ['North', 'South', 'East', 'West'],
2458
+ ['Type1', 'Type2', 'Type3'],
2459
+ ['Yes', 'No', 'Maybe'],
2460
+ ['Red', 'Green', 'Blue', 'Yellow'],
2461
+ ['Small', 'Medium', 'Large']
2462
+ ]
2463
+
2464
+ for i in range(1, 10):
2465
+ # Pick a category list
2466
+ cat_list = categories[i % len(categories)]
2467
+ # Generate random categorical column
2468
+ data[f'Category_{i}'] = np.random.choice(cat_list, n)
2469
+
2470
+ # Generate date and time related columns
2471
+ base_date = np.datetime64('2020-01-01')
2472
+
2473
+ # Instead of datetime objects, convert to days since base date (numeric values)
2474
+ hire_days = np.array([365 * (35 - a) + np.random.randint(0, 30) for a in age])
2475
+ data['Hire_Days_Ago'] = hire_days
2476
+
2477
+ promotion_days = np.array([np.random.randint(0, 1000) for _ in range(n)])
2478
+ data['Last_Promotion_Days_Ago'] = promotion_days
2479
+
2480
+ review_days = np.array([np.random.randint(1000, 1200) for _ in range(n)])
2481
+ data['Next_Review_In_Days'] = review_days
2482
+
2483
+ # For reference, also store the actual dates as strings instead of datetime64
2484
+ data['Hire_Date_Str'] = [str(base_date + np.timedelta64(int(days), 'D')) for days in hire_days]
2485
+ data['Last_Promotion_Date_Str'] = [str(base_date + np.timedelta64(int(days), 'D')) for days in promotion_days]
2486
+ data['Review_Date_Str'] = [str(base_date + np.timedelta64(int(days), 'D')) for days in review_days]
2487
+
2488
+ # Binary columns
2489
+ data['IsManager'] = np.random.choice([0, 1], n, p=[0.8, 0.2])
2490
+ data['RemoteWorker'] = np.random.choice([0, 1], n)
2491
+ data['HasHealthInsurance'] = np.random.choice([0, 1], n, p=[0.1, 0.9])
2492
+ data['HasRetirementPlan'] = np.random.choice([0, 1], n, p=[0.15, 0.85])
2493
+
2494
+ # Columns with missing values
2495
+ data['OptionalMetric_1'] = np.random.normal(50, 10, n)
2496
+ data['OptionalMetric_1'][np.random.choice([True, False], n, p=[0.2, 0.8])] = np.nan
2497
+
2498
+ data['OptionalMetric_2'] = np.random.normal(100, 20, n)
2499
+ data['OptionalMetric_2'][np.random.choice([True, False], n, p=[0.3, 0.7])] = np.nan
2500
+
2501
+ data['OptionalCategory'] = np.random.choice(['Option1', 'Option2', 'Option3', None], n, p=[0.3, 0.3, 0.3, 0.1])
2502
+
2503
+ # High cardinality column (like an ID)
2504
+ data['ID'] = [f"ID_{i:06d}" for i in range(n)]
2505
+
2506
+ # Create the DataFrame (40+ columns in total)
2507
+ df = pd.DataFrame(data)
1094
2508
 
2509
+ print(f"Created sample DataFrame with {len(df.columns)} columns and {len(df)} rows")
2510
+ print("Columns:", ', '.join(df.columns))
1095
2511
  print("Launching PyQt6 Column Profiler application...")
1096
2512
  visualize_profile(df, 'Salary') # Start with Salary analysis
1097
2513