sqlshell 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlshell might be problematic. Click here for more details.
- sqlshell/__init__.py +35 -5
- sqlshell/db/__init__.py +2 -1
- sqlshell/db/database_manager.py +336 -23
- sqlshell/db/export_manager.py +188 -0
- sqlshell/editor_integration.py +127 -0
- sqlshell/execution_handler.py +421 -0
- sqlshell/main.py +570 -140
- sqlshell/query_tab.py +592 -7
- sqlshell/ui/filter_header.py +22 -1
- sqlshell/utils/profile_column.py +1586 -170
- sqlshell/utils/profile_foreign_keys.py +103 -11
- sqlshell/utils/profile_ohe.py +631 -0
- {sqlshell-0.2.3.dist-info → sqlshell-0.3.1.dist-info}/METADATA +126 -7
- {sqlshell-0.2.3.dist-info → sqlshell-0.3.1.dist-info}/RECORD +17 -13
- {sqlshell-0.2.3.dist-info → sqlshell-0.3.1.dist-info}/WHEEL +1 -1
- {sqlshell-0.2.3.dist-info → sqlshell-0.3.1.dist-info}/entry_points.txt +0 -0
- {sqlshell-0.2.3.dist-info → sqlshell-0.3.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,631 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from nltk.corpus import stopwords
|
|
4
|
+
from nltk.tokenize import word_tokenize
|
|
5
|
+
import nltk
|
|
6
|
+
|
|
7
|
+
# Download required NLTK data
|
|
8
|
+
# Check each NLTK resource this module relies on and download the ones that
# are not installed yet. punkt_tab is requested explicitly as required by the
# punkt tokenizer.
for _resource_path, _package_name in (
    ('tokenizers/punkt', 'punkt'),
    ('corpora/stopwords', 'stopwords'),
    ('tokenizers/punkt_tab/english', 'punkt_tab'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package_name)
# Do not leave the loop variables behind as module attributes.
del _resource_path, _package_name
|
|
21
|
+
|
|
22
|
+
def get_ohe(dataframe: pd.DataFrame, column: str, binary_format: str = "numeric") -> pd.DataFrame:
    """
    Create one-hot encoded columns based on the content of the specified column.

    The column is treated as free text when any non-null value (as a string)
    contains a space or when the average string length exceeds 15 characters;
    otherwise it is treated as categorical and handed to get_categorical_ohe.

    For text, the 10 most frequent alphanumeric, non-stopword tokens each get
    a ``has_<word>`` column. A row is flagged only when the word occurs as a
    whole token in that row - the same tokenization that is used to rank the
    words - so "cat" is not flagged for a row that only contains "category".

    The new columns are added to ``dataframe`` in place and the same object is
    returned; pass a copy if the input must stay untouched.

    Args:
        dataframe (pd.DataFrame): Input dataframe
        column (str): Name of the column to process
        binary_format (str): Format for encoding - "numeric" for 1/0 or "text" for "Yes"/"No"

    Returns:
        pd.DataFrame: The input dataframe with additional one-hot encoded columns

    Raises:
        ValueError: If the column does not exist or binary_format is invalid
    """
    if column not in dataframe.columns:
        raise ValueError(f"Column '{column}' not found in dataframe")

    if binary_format not in ("numeric", "text"):
        raise ValueError("binary_format must be either 'numeric' or 'text'")

    # Only non-null values take part in the text/categorical heuristic.
    string_values = dataframe[column].dropna().astype(str)
    if string_values.empty:
        return dataframe  # Nothing to process

    contains_spaces = string_values.str.contains(" ", regex=False).any()
    avg_length = string_values.str.len().mean()
    is_text = bool(contains_spaces) or avg_length > 15

    if not is_text:
        return get_categorical_ohe(dataframe, column, binary_format)

    stop_words = set(stopwords.words('english'))

    # Tokenize every row exactly once: the token list feeds the frequency
    # count, the token set is reused below for the per-row membership test.
    word_counts = {}
    row_tokens = []
    for text in dataframe[column]:
        if isinstance(text, str):
            tokens = [
                token for token in word_tokenize(text.lower())
                if token not in stop_words and token.isalnum()
            ]
        else:
            # Non-string cells (NaN, numbers, ...) never match any word.
            tokens = []
        for token in tokens:
            word_counts[token] = word_counts.get(token, 0) + 1
        row_tokens.append(set(tokens))

    # Top 10 most frequent words; ties keep first-seen order (stable sort).
    ranked = sorted(word_counts.items(), key=lambda item: item[1], reverse=True)[:10]
    top_words = [word for word, _ in ranked]

    present, absent = (1, 0) if binary_format == "numeric" else ("Yes", "No")
    for word in top_words:
        dataframe[f'has_{word}'] = [
            present if word in tokens else absent for tokens in row_tokens
        ]

    return dataframe
|
|
96
|
+
|
|
97
|
+
def get_categorical_ohe(dataframe: pd.DataFrame, categorical_column: str, binary_format: str = "numeric") -> pd.DataFrame:
    """
    Add one ``is_<category>`` indicator column per distinct non-null value
    found in a categorical column.

    The indicator columns are written onto the dataframe that was passed in,
    and that same dataframe is returned.

    Args:
        dataframe (pd.DataFrame): Input dataframe
        categorical_column (str): Name of the categorical column to process
        binary_format (str): Format for encoding - "numeric" for 1/0 or "text" for "Yes"/"No"

    Returns:
        pd.DataFrame: Original dataframe with additional one-hot encoded columns

    Raises:
        ValueError: If binary_format is neither "numeric" nor "text"
    """
    if binary_format not in ("numeric", "text"):
        raise ValueError("binary_format must be either 'numeric' or 'text'")

    # Decide once which pair of values marks a match / a non-match.
    if binary_format == "numeric":
        hit, miss = 1, 0
    else:
        hit, miss = "Yes", "No"

    # Distinct values are collected up front, before any column is added.
    levels = dataframe[categorical_column].dropna().unique()

    for level in levels:
        dataframe[f'is_{level}'] = dataframe[categorical_column].apply(
            lambda cell: hit if cell == level else miss
        )

    return dataframe
|
|
129
|
+
|
|
130
|
+
def test_ohe():
    """
    Exercise get_ohe on a free-text column and on a categorical column.

    Each column is encoded in both the numeric (1/0) and the text (Yes/No)
    format; the frames are printed and the generated columns are checked.
    """
    def _prefixed(frame, prefix):
        # Names of the generated indicator columns carrying the given prefix.
        return [name for name in frame.columns if name.startswith(prefix)]

    print("\n===== Testing Text Data One-Hot Encoding =====")
    sentences = [
        'The quick brown fox jumps over the lazy dog',
        'A quick brown dog runs in the park',
        'The lazy cat sleeps all day',
        'A brown fox and a lazy dog play together',
        'The quick cat chases the mouse',
        'A lazy dog sleeps in the sun',
        'The brown fox is quick and clever',
        'A cat and a dog are best friends',
        'The quick mouse runs from the cat',
        'A lazy fox sleeps in the shade',
    ]
    text_df = pd.DataFrame({'text': sentences})

    # --- text column, numeric flags ---
    print("\n----- Testing Numeric Format (1/0) -----")
    text_numeric = get_ohe(text_df.copy(), 'text', binary_format="numeric")

    print("\nOriginal Text DataFrame:")
    print(text_df)
    print("\nDataFrame with Numeric One-Hot Encoded Columns (1/0):")
    print(text_numeric)

    # has_* columns only appear when the column was recognised as text.
    word_flags = _prefixed(text_numeric, 'has_')
    assert word_flags, "Text data was not correctly identified"

    for col in word_flags:
        assert set(text_numeric[col].unique()) <= {0, 1}, f"Column {col} contains invalid values"

    # --- text column, Yes/No flags ---
    print("\n----- Testing Text Format (Yes/No) -----")
    text_yes_no = get_ohe(text_df.copy(), 'text', binary_format="text")

    print("\nDataFrame with Text One-Hot Encoded Columns (Yes/No):")
    print(text_yes_no)

    for col in _prefixed(text_yes_no, 'has_'):
        assert set(text_yes_no[col].unique()) <= {"Yes", "No"}, f"Column {col} contains invalid values"

    print("\nText data tests passed successfully!")

    print("\n===== Testing Categorical Data One-Hot Encoding =====")
    colours = [
        'red', 'blue', 'green', 'red', 'yellow',
        'blue', 'green', 'red', 'yellow', 'blue',
    ]
    cat_df = pd.DataFrame({'category': colours})

    # --- categorical column, numeric flags ---
    print("\n----- Testing Numeric Format (1/0) -----")
    cat_numeric = get_ohe(cat_df.copy(), 'category', binary_format="numeric")

    print("\nOriginal Categorical DataFrame:")
    print(cat_df)
    print("\nDataFrame with Numeric One-Hot Encoded Columns (1/0):")
    print(cat_numeric)

    # is_* columns only appear when the column was recognised as categorical.
    level_flags = _prefixed(cat_numeric, 'is_')
    assert level_flags, "Categorical data was not correctly identified"

    # One indicator column is expected per distinct value.
    distinct_levels = cat_df['category'].nunique(dropna=False)
    assert len(level_flags) == distinct_levels, "Incorrect number of categorical columns"

    for col in level_flags:
        assert set(cat_numeric[col].unique()) <= {0, 1}, f"Column {col} contains invalid values"

    # --- categorical column, Yes/No flags ---
    print("\n----- Testing Text Format (Yes/No) -----")
    cat_yes_no = get_ohe(cat_df.copy(), 'category', binary_format="text")

    print("\nDataFrame with Text One-Hot Encoded Columns (Yes/No):")
    print(cat_yes_no)

    for col in _prefixed(cat_yes_no, 'is_'):
        assert set(cat_yes_no[col].unique()) <= {"Yes", "No"}, f"Column {col} contains invalid values"

    print("\nCategorical data tests passed successfully!")
|
|
240
|
+
|
|
241
|
+
def test_categorical_ohe():
    """
    Exercise get_categorical_ohe directly on a small colour column.

    The column is encoded in both the numeric (1/0) and the text (Yes/No)
    format; the frames are printed and the generated columns are checked.
    """
    colours = [
        'red', 'blue', 'green', 'red', 'yellow',
        'blue', 'green', 'red', 'yellow', 'blue',
    ]
    df = pd.DataFrame({'category': colours})

    # --- numeric flags ---
    print("\n----- Testing Numeric Format (1/0) -----")
    numeric_frame = get_categorical_ohe(df.copy(), 'category', binary_format="numeric")

    print("\nOriginal DataFrame:")
    print(df)
    print("\nDataFrame with Numeric One-Hot Encoded Columns (1/0):")
    print(numeric_frame)

    # One indicator column is expected per distinct value.
    distinct_levels = df['category'].nunique(dropna=False)
    numeric_flags = [name for name in numeric_frame.columns if name.startswith('is_')]
    assert len(numeric_flags) == distinct_levels, "Incorrect number of categorical columns"

    for col in numeric_flags:
        assert set(numeric_frame[col].unique()) <= {0, 1}, f"Column {col} contains invalid values"

    # --- Yes/No flags ---
    print("\n----- Testing Text Format (Yes/No) -----")
    yes_no_frame = get_categorical_ohe(df.copy(), 'category', binary_format="text")

    print("\nDataFrame with Text One-Hot Encoded Columns (Yes/No):")
    print(yes_no_frame)

    yes_no_flags = [name for name in yes_no_frame.columns if name.startswith('is_')]
    for col in yes_no_flags:
        assert set(yes_no_frame[col].unique()) <= {"Yes", "No"}, f"Column {col} contains invalid values"

    print("\nAll categorical tests passed successfully!")
|
|
292
|
+
|
|
293
|
+
# Add visualization functionality
|
|
294
|
+
from PyQt6.QtWidgets import (QMainWindow, QVBoxLayout, QHBoxLayout, QWidget,
|
|
295
|
+
QTableWidget, QTableWidgetItem, QLabel, QPushButton,
|
|
296
|
+
QComboBox, QSplitter, QTabWidget, QScrollArea,
|
|
297
|
+
QFrame, QSizePolicy, QButtonGroup, QRadioButton,
|
|
298
|
+
QMessageBox, QHeaderView, QApplication)
|
|
299
|
+
from PyQt6.QtCore import Qt, QSize, pyqtSignal
|
|
300
|
+
from PyQt6.QtGui import QFont
|
|
301
|
+
import matplotlib.pyplot as plt
|
|
302
|
+
import seaborn as sns
|
|
303
|
+
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
|
|
304
|
+
|
|
305
|
+
class OneHotEncodingVisualization(QMainWindow):
    """
    Window that previews a one-hot encoding and lets the user apply it.

    The upper half shows the original and the encoded data in two tabs. The
    lower half shows a matplotlib chart (value counts of the source column or
    a correlation heatmap of the encoded columns) and an "Apply Encoding"
    button. A combo box switches between the 1/0 and Yes/No encodings.
    """

    # Emitted with the encoded DataFrame after the user confirms "Apply Encoding".
    encodingApplied = pyqtSignal(pd.DataFrame)

    # Upper bound on the rows copied into a preview table (for performance).
    _MAX_PREVIEW_ROWS = 100

    def __init__(self, original_df, encoded_df, encoded_column, binary_format="numeric"):
        """
        Build the window and draw the initial chart.

        Args:
            original_df (pd.DataFrame): Data shown in the "Original Data" tab;
                it is also re-encoded when the format selector changes
            encoded_df (pd.DataFrame): Data shown in the "Encoded Data" tab
            encoded_column (str): Name of the column that was encoded
            binary_format (str): "numeric" for 1/0 or "text" for "Yes"/"No"
        """
        # A plain Figure (rather than plt.figure) keeps the embedded figure
        # out of pyplot's global registry, so it is released with the window.
        from matplotlib.figure import Figure

        super().__init__()
        self.original_df = original_df
        self.encoded_df = encoded_df
        self.encoded_column = encoded_column
        self.binary_format = binary_format
        self.setWindowTitle(f"One-Hot Encoding Visualization - {encoded_column}")
        self.setGeometry(100, 100, 1000, 800)

        # Main widget
        main_widget = QWidget()
        self.setCentralWidget(main_widget)

        # Main layout
        main_layout = QVBoxLayout(main_widget)

        # Title
        title_label = QLabel(f"One-Hot Encoding Analysis: {encoded_column}")
        title_font = QFont()
        title_font.setBold(True)
        title_font.setPointSize(14)
        title_label.setFont(title_font)
        title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
        main_layout.addWidget(title_label)

        # Description
        description = "One-hot encoding transforms categorical data into a binary matrix format where each category becomes a separate binary column."
        desc_label = QLabel(description)
        desc_label.setWordWrap(True)
        main_layout.addWidget(desc_label)

        # Format selector. The index is set before the signal is connected so
        # that initialising the combo box does not trigger change_format.
        format_layout = QHBoxLayout()
        format_label = QLabel("Encoding Format:")
        self.format_selector = QComboBox()
        self.format_selector.addItems(["Numeric (1/0)", "Text (Yes/No)"])
        self.format_selector.setCurrentIndex(0 if binary_format == "numeric" else 1)
        self.format_selector.currentIndexChanged.connect(self.change_format)
        format_layout.addWidget(format_label)
        format_layout.addWidget(self.format_selector)
        format_layout.addStretch(1)
        main_layout.addLayout(format_layout)

        # Splitter to divide the screen
        splitter = QSplitter(Qt.Orientation.Vertical)
        main_layout.addWidget(splitter, 1)

        # Top widget: Data view
        top_widget = QWidget()
        top_layout = QVBoxLayout(top_widget)

        # Tab widget for the two data views; kept on the instance so that
        # change_format can refresh the encoded tab without searching for it.
        self._tab_widget = QTabWidget()

        # Tab 1: Original data
        original_tab = QWidget()
        original_layout = QVBoxLayout(original_tab)
        original_layout.addWidget(self.create_table_from_df(self.original_df))
        self._tab_widget.addTab(original_tab, "Original Data")

        # Tab 2: Encoded data
        encoded_tab = QWidget()
        encoded_layout = QVBoxLayout(encoded_tab)
        encoded_layout.addWidget(self.create_table_from_df(self.encoded_df))
        self._tab_widget.addTab(encoded_tab, "Encoded Data")

        top_layout.addWidget(self._tab_widget)
        splitter.addWidget(top_widget)

        # Bottom widget: Visualizations
        bottom_widget = QWidget()
        bottom_layout = QVBoxLayout(bottom_widget)

        # Visualization title
        viz_title = QLabel("Visualization")
        viz_title.setFont(title_font)
        bottom_layout.addWidget(viz_title)

        # Matplotlib figure embedded through a Qt canvas
        self.figure = Figure(figsize=(8, 6))
        self.canvas = FigureCanvas(self.figure)
        bottom_layout.addWidget(self.canvas)

        # Visualization type selector
        viz_selector_layout = QHBoxLayout()
        viz_selector_label = QLabel("Visualization Type:")
        self.viz_selector = QComboBox()
        self.viz_selector.addItems(["Value Counts", "Correlation Heatmap"])
        self.viz_selector.currentIndexChanged.connect(self.update_visualization)
        viz_selector_layout.addWidget(viz_selector_label)
        viz_selector_layout.addWidget(self.viz_selector)
        viz_selector_layout.addStretch(1)
        bottom_layout.addLayout(viz_selector_layout)

        # Apply button, right-aligned by the leading stretch
        apply_layout = QHBoxLayout()
        apply_layout.addStretch(1)

        self.apply_button = QPushButton("Apply Encoding")
        self.apply_button.setStyleSheet("""
            QPushButton {
                background-color: #3498DB;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 4px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #2980B9;
            }
            QPushButton:pressed {
                background-color: #1F618D;
            }
        """)
        self.apply_button.setMinimumWidth(150)
        self.apply_button.clicked.connect(self.apply_encoding)
        apply_layout.addWidget(self.apply_button)

        bottom_layout.addLayout(apply_layout)

        splitter.addWidget(bottom_widget)

        # Set initial splitter sizes
        splitter.setSizes([300, 500])

        # Create initial visualization
        self.update_visualization()

    def create_table_from_df(self, df):
        """
        Create a table widget showing the first rows of a dataframe.

        Args:
            df (pd.DataFrame): Dataframe to display

        Returns:
            QTableWidget: Table holding at most _MAX_PREVIEW_ROWS rows, with
            every cell rendered through str()
        """
        row_count = min(self._MAX_PREVIEW_ROWS, len(df))

        table = QTableWidget()
        table.setRowCount(row_count)
        table.setColumnCount(len(df.columns))

        # Header labels must be strings; column labels may be ints or tuples.
        table.setHorizontalHeaderLabels([str(name) for name in df.columns])

        # Fill data
        for row in range(row_count):
            for col in range(len(df.columns)):
                table.setItem(row, col, QTableWidgetItem(str(df.iloc[row, col])))

        # Optimize appearance
        table.resizeColumnsToContents()
        table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Interactive)
        return table

    def update_visualization(self):
        """Redraw the chart for the visualization type currently selected."""
        viz_type = self.viz_selector.currentText()

        # Clear previous plot
        self.figure.clear()

        # Encoded columns are recognised by their 'is_' / 'has_' prefix.
        # Non-string column labels cannot carry a prefix and are skipped.
        is_columns = [col for col in self.encoded_df.columns
                      if isinstance(col, str) and col.startswith('is_')]
        has_columns = [col for col in self.encoded_df.columns
                       if isinstance(col, str) and col.startswith('has_')]
        encoded_columns = is_columns + has_columns

        if viz_type == "Value Counts":
            ax = self.figure.add_subplot(111)

            # Value counts come from the original (un-encoded) column
            value_counts = self.original_df[self.encoded_column].value_counts()

            if len(value_counts) > 15:
                # For high cardinality, show top 15
                value_counts.nlargest(15).plot(kind='barh', ax=ax)
                ax.set_title(f"Top 15 Values in {self.encoded_column}")
            else:
                value_counts.plot(kind='barh', ax=ax)
                ax.set_title(f"Value Counts in {self.encoded_column}")

            ax.set_xlabel("Count")
            ax.set_ylabel(self.encoded_column)

        elif viz_type == "Correlation Heatmap":
            ax = self.figure.add_subplot(111)

            if encoded_columns:
                # Get subset with just the encoded columns
                encoded_subset = self.encoded_df[encoded_columns]

                # A correlation needs numbers: in text format the flags are
                # the strings "Yes"/"No", so map them onto 1/0 first.
                if self.binary_format == "text":
                    encoded_subset = (encoded_subset == "Yes").astype(int)

                corr_matrix = encoded_subset.corr()

                sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', linewidths=0.5, ax=ax)
                ax.set_title("Correlation Between Encoded Features")
            else:
                # No encoded columns found
                ax.text(0.5, 0.5, "No encoded columns found",
                        horizontalalignment='center', verticalalignment='center',
                        transform=ax.transAxes)
                ax.axis('off')

        # Update the canvas
        self.canvas.draw()

    def apply_encoding(self):
        """Ask for confirmation, then emit encodingApplied with the encoded dataframe."""
        reply = QMessageBox.question(
            self,
            "Apply Encoding",
            "Are you sure you want to apply this encoding to the original table?\n\n"
            "This will add the one-hot encoded columns to the current result table.",
            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
            QMessageBox.StandardButton.No
        )

        if reply == QMessageBox.StandardButton.Yes:
            # Emit signal with the encoded DataFrame
            self.encodingApplied.emit(self.encoded_df)
            QMessageBox.information(
                self,
                "Encoding Applied",
                "The one-hot encoding has been applied to the table."
            )

    def change_format(self):
        """Re-encode with the format chosen in the selector and refresh the views."""
        selected_format = "numeric" if self.format_selector.currentIndex() == 0 else "text"

        # Only update if format has changed
        if selected_format == self.binary_format:
            return

        self.binary_format = selected_format

        # Re-encode from a copy so the stored original dataframe is not modified
        self.encoded_df = get_ohe(self.original_df.copy(), self.encoded_column, self.binary_format)

        # Replace the table in the "Encoded Data" tab (index 1)
        encoded_tab = self._tab_widget.widget(1)
        if encoded_tab:
            layout = encoded_tab.layout()
            for i in reversed(range(layout.count())):
                old_widget = layout.itemAt(i).widget()
                if old_widget is not None:
                    old_widget.setParent(None)
            layout.addWidget(self.create_table_from_df(self.encoded_df))

        # Update visualization
        self.update_visualization()

        # Show confirmation
        QMessageBox.information(
            self,
            "Format Changed",
            f"Encoding format changed to {selected_format}"
        )
|
|
573
|
+
|
|
574
|
+
def visualize_ohe(df, column, binary_format="numeric"):
    """
    Visualize the one-hot encoding of a column in a dataframe.

    The caller's dataframe is never modified. The encoding is computed on a
    separate copy because get_ohe adds its columns to the frame it is given
    and returns that same frame; encoding the "original" copy would make the
    original and encoded views one and the same object.

    Args:
        df (pd.DataFrame): The original dataframe
        column (str): The column to encode and visualize
        binary_format (str): Format for encoding - "numeric" for 1/0 or "text" for "Yes"/"No"

    Returns:
        QMainWindow: The visualization window (already shown)
    """
    # Untouched snapshot for the "Original Data" view
    original_df = df.copy()

    # Independent copy that receives the one-hot encoded columns
    encoded_df = get_ohe(df.copy(), column, binary_format)

    # Create and show the visualization
    vis = OneHotEncodingVisualization(original_df, encoded_df, column, binary_format)
    vis.show()

    return vis
|
|
597
|
+
|
|
598
|
+
if __name__ == "__main__":
    # Run tests
    test_ohe()
    test_categorical_ohe()

    # Test the visualization with both formats
    import sys
    from PyQt6.QtWidgets import QApplication

    # Reuse an existing application object if there is one; otherwise create
    # it. Either way `app` is bound before app.exec() is reached.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    # Create a sample dataframe
    data = {
        'category': ['red', 'blue', 'green', 'red', 'yellow', 'blue'],
        'text': [
            'The quick brown fox',
            'A lazy dog',
            'Brown fox jumps',
            'Quick brown fox',
            'Lazy dog sleeps',
            'Fox and dog'
        ]
    }
    df = pd.DataFrame(data)

    # Keep references to both windows so they are not garbage-collected
    # while the event loop runs.
    # Show visualization with numeric format
    vis_numeric = visualize_ohe(df, 'category', binary_format="numeric")

    # Show visualization with text format
    vis_text = visualize_ohe(df, 'text', binary_format="text")

    # Start the application
    sys.exit(app.exec())
|