sqlshell 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlshell/__init__.py +84 -0
- sqlshell/__main__.py +4926 -0
- sqlshell/ai_autocomplete.py +392 -0
- sqlshell/ai_settings_dialog.py +337 -0
- sqlshell/context_suggester.py +768 -0
- sqlshell/create_test_data.py +152 -0
- sqlshell/data/create_test_data.py +137 -0
- sqlshell/db/__init__.py +6 -0
- sqlshell/db/database_manager.py +1318 -0
- sqlshell/db/export_manager.py +188 -0
- sqlshell/editor.py +1166 -0
- sqlshell/editor_integration.py +127 -0
- sqlshell/execution_handler.py +421 -0
- sqlshell/menus.py +262 -0
- sqlshell/notification_manager.py +370 -0
- sqlshell/query_tab.py +904 -0
- sqlshell/resources/__init__.py +1 -0
- sqlshell/resources/icon.png +0 -0
- sqlshell/resources/logo_large.png +0 -0
- sqlshell/resources/logo_medium.png +0 -0
- sqlshell/resources/logo_small.png +0 -0
- sqlshell/resources/splash_screen.gif +0 -0
- sqlshell/space_invaders.py +501 -0
- sqlshell/splash_screen.py +405 -0
- sqlshell/sqlshell/__init__.py +5 -0
- sqlshell/sqlshell/create_test_data.py +118 -0
- sqlshell/sqlshell/create_test_databases.py +96 -0
- sqlshell/sqlshell_demo.png +0 -0
- sqlshell/styles.py +257 -0
- sqlshell/suggester_integration.py +330 -0
- sqlshell/syntax_highlighter.py +124 -0
- sqlshell/table_list.py +996 -0
- sqlshell/ui/__init__.py +6 -0
- sqlshell/ui/bar_chart_delegate.py +49 -0
- sqlshell/ui/filter_header.py +469 -0
- sqlshell/utils/__init__.py +16 -0
- sqlshell/utils/profile_cn2.py +1661 -0
- sqlshell/utils/profile_column.py +2635 -0
- sqlshell/utils/profile_distributions.py +616 -0
- sqlshell/utils/profile_entropy.py +347 -0
- sqlshell/utils/profile_foreign_keys.py +779 -0
- sqlshell/utils/profile_keys.py +2834 -0
- sqlshell/utils/profile_ohe.py +934 -0
- sqlshell/utils/profile_ohe_advanced.py +754 -0
- sqlshell/utils/profile_ohe_comparison.py +237 -0
- sqlshell/utils/profile_prediction.py +926 -0
- sqlshell/utils/profile_similarity.py +876 -0
- sqlshell/utils/search_in_df.py +90 -0
- sqlshell/widgets.py +400 -0
- sqlshell-0.4.4.dist-info/METADATA +441 -0
- sqlshell-0.4.4.dist-info/RECORD +54 -0
- sqlshell-0.4.4.dist-info/WHEEL +5 -0
- sqlshell-0.4.4.dist-info/entry_points.txt +2 -0
- sqlshell-0.4.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,934 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import numpy as np
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
# Flag to track if NLTK is available
|
|
7
|
+
NLTK_AVAILABLE = False
|
|
8
|
+
|
|
9
|
+
def _setup_nltk_data_path():
    """Make bundled NLTK data discoverable.

    Covers two layouts: a frozen PyInstaller build (data extracted under
    ``sys._MEIPASS``) and a source checkout with an ``nltk_data`` directory
    next to the application package. Found directories are prepended to
    ``nltk.data.path`` so they win over system-wide data.
    """
    import nltk

    # PyInstaller extracts bundled resources under sys._MEIPASS when frozen.
    if getattr(sys, 'frozen', False):
        bundled = os.path.join(sys._MEIPASS, 'nltk_data')
        if os.path.exists(bundled):
            nltk.data.path.insert(0, bundled)

    # Also probe for nltk_data next to (and one level above) this package.
    package_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    candidates = (
        os.path.join(package_root, 'nltk_data'),
        os.path.join(os.path.dirname(package_root), 'nltk_data'),
    )
    for candidate in candidates:
        if os.path.exists(candidate) and candidate not in nltk.data.path:
            nltk.data.path.insert(0, candidate)
|
|
30
|
+
|
|
31
|
+
# One-time NLTK bootstrap at import time.  NLTK_AVAILABLE ends up True only
# when the imports succeed AND the required corpora actually work; every
# failure mode degrades silently to the pure-Python fallbacks defined below.
try:
    import nltk
    _setup_nltk_data_path()
    from nltk.corpus import stopwords
    from nltk.tokenize import word_tokenize

    # Try to find required NLTK data, download if missing.  Downloads are
    # best-effort: no network (or a sandboxed install) must not crash import.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        try:
            nltk.download('punkt', quiet=True)
        except Exception:
            pass  # Download failed silently - NLTK features will be unavailable
    try:
        nltk.data.find('corpora/stopwords')
    except LookupError:
        try:
            nltk.download('stopwords', quiet=True)
        except Exception:
            pass  # Download failed silently - NLTK features will be unavailable
    # Newer NLTK versions need punkt_tab in addition to punkt.
    try:
        nltk.data.find('tokenizers/punkt_tab/english')
    except LookupError:
        try:
            nltk.download('punkt_tab', quiet=True)
        except Exception:
            pass  # Download failed silently - NLTK features will be unavailable

    # Smoke-test both corpora: data.find() succeeding does not guarantee the
    # data is loadable, so exercise the actual calls used later in get_ohe.
    try:
        _ = stopwords.words('english')
        _ = word_tokenize("test")
        NLTK_AVAILABLE = True
    except Exception:
        NLTK_AVAILABLE = False

except ImportError:
    NLTK_AVAILABLE = False
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _simple_tokenize(text):
|
|
73
|
+
"""Simple fallback tokenizer when NLTK is not available"""
|
|
74
|
+
import re
|
|
75
|
+
# Simple word tokenization using regex
|
|
76
|
+
return re.findall(r'\b[a-zA-Z]+\b', text.lower())
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _get_simple_stopwords():
|
|
80
|
+
"""Return a basic set of English stopwords when NLTK is not available"""
|
|
81
|
+
return {
|
|
82
|
+
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
|
|
83
|
+
'of', 'with', 'by', 'from', 'as', 'is', 'was', 'are', 'were', 'been',
|
|
84
|
+
'be', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could',
|
|
85
|
+
'should', 'may', 'might', 'must', 'shall', 'can', 'need', 'dare', 'ought',
|
|
86
|
+
'used', 'it', 'its', 'this', 'that', 'these', 'those', 'i', 'me', 'my',
|
|
87
|
+
'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours',
|
|
88
|
+
'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 'her',
|
|
89
|
+
'hers', 'herself', 'they', 'them', 'their', 'theirs', 'themselves',
|
|
90
|
+
'what', 'which', 'who', 'whom', 'when', 'where', 'why', 'how', 'all',
|
|
91
|
+
'each', 'every', 'both', 'few', 'more', 'most', 'other', 'some', 'such',
|
|
92
|
+
'no', 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very',
|
|
93
|
+
'just', 'also', 'now', 'here', 'there', 'then', 'once', 'if', 'because',
|
|
94
|
+
'while', 'although', 'though', 'after', 'before', 'since', 'until', 'unless'
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
def get_ohe(dataframe: pd.DataFrame, column: str, binary_format: str = "numeric",
            algorithm: str = "basic") -> pd.DataFrame:
    """
    Create one-hot encoded columns based on the content of the specified column.
    Automatically detects whether the column contains text data or categorical data.

    Note: for the basic algorithm the new columns are added to *dataframe* in
    place (the same object is also returned); pass a copy if the caller's
    frame must stay untouched.

    Args:
        dataframe (pd.DataFrame): Input dataframe
        column (str): Name of the column to process
        binary_format (str): Format for encoding - "numeric" for 1/0 or "text" for "Yes"/"No"
        algorithm (str): Algorithm to use - "basic", "advanced", or "comprehensive"

    Returns:
        pd.DataFrame: Original dataframe with additional one-hot encoded columns

    Raises:
        ValueError: If *column* is missing, or *binary_format*/*algorithm*
            is not one of the accepted values.
    """
    # Check if column exists
    if column not in dataframe.columns:
        raise ValueError(f"Column '{column}' not found in dataframe")

    # Check binary format is valid
    if binary_format not in ["numeric", "text"]:
        raise ValueError("binary_format must be either 'numeric' or 'text'")

    # Check algorithm is valid
    if algorithm not in ["basic", "advanced", "comprehensive"]:
        raise ValueError("algorithm must be 'basic', 'advanced', or 'comprehensive'")

    # Use advanced algorithms if requested; delegates to the sibling module.
    if algorithm in ["advanced", "comprehensive"]:
        try:
            # Try relative import first (normal package usage)
            try:
                from .profile_ohe_advanced import get_advanced_ohe
            except ImportError:
                # Fall back to direct import (e.g. module executed as a script)
                import sys
                import os
                sys.path.insert(0, os.path.dirname(__file__))
                from profile_ohe_advanced import get_advanced_ohe

            return get_advanced_ohe(dataframe, column, binary_format,
                                    analysis_type=algorithm, max_features=20)
        except ImportError as e:
            # Degrade gracefully to the basic path below.
            print(f"Advanced algorithms not available ({e}). Using basic approach.")
            algorithm = "basic"

    # Original basic algorithm
    # Check if the column appears to be categorical or text
    # Heuristic: If average string length > 15 or contains spaces, treat as text
    is_text = False

    # Filter out non-string values
    string_values = dataframe[column].dropna().astype(str)
    if not len(string_values):
        return dataframe  # Nothing to process

    # Check for spaces and average length
    contains_spaces = any(' ' in str(val) for val in string_values)
    avg_length = string_values.str.len().mean()

    if contains_spaces or avg_length > 15:
        is_text = True

    # Apply appropriate encoding
    if is_text:
        # Apply text-based one-hot encoding
        # Get stopwords (use NLTK if available, otherwise fallback)
        if NLTK_AVAILABLE:
            stop_words = set(stopwords.words('english'))
        else:
            stop_words = _get_simple_stopwords()

        # Tokenize and count words across the whole column
        word_counts = {}
        for text in dataframe[column]:
            if isinstance(text, str):
                # Tokenize and convert to lowercase (use NLTK if available, otherwise fallback)
                if NLTK_AVAILABLE:
                    words = word_tokenize(text.lower())
                else:
                    words = _simple_tokenize(text)
                # Remove stopwords and count
                words = [word for word in words if word not in stop_words and word.isalnum()]
                for word in words:
                    word_counts[word] = word_counts.get(word, 0) + 1

        # Get top 10 most frequent words
        top_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:10]
        top_words = [word for word, _ in top_words]

        # Create one-hot encoded columns.
        # NOTE(review): membership below is a raw substring test, so a top
        # word like "art" also flags rows containing "part" - confirm this
        # looser matching (vs. token matching) is intended.
        for word in top_words:
            column_name = f'has_{word}'
            if binary_format == "numeric":
                dataframe[column_name] = dataframe[column].apply(
                    lambda x: 1 if isinstance(x, str) and word in str(x).lower() else 0
                )
            else:  # binary_format == "text"
                dataframe[column_name] = dataframe[column].apply(
                    lambda x: "Yes" if isinstance(x, str) and word in str(x).lower() else "No"
                )
    else:
        # Apply categorical one-hot encoding (one is_<category> column each)
        dataframe = get_categorical_ohe(dataframe, column, binary_format)

    return dataframe
|
|
203
|
+
|
|
204
|
+
def get_categorical_ohe(dataframe: pd.DataFrame, categorical_column: str, binary_format: str = "numeric") -> pd.DataFrame:
    """
    Create one-hot encoded columns for each unique category in a categorical column.

    One ``is_<category>`` column is added (in place) per distinct non-null
    value of *categorical_column*; rows with NaN match no category.

    Args:
        dataframe (pd.DataFrame): Input dataframe
        categorical_column (str): Name of the categorical column to process
        binary_format (str): Format for encoding - "numeric" for 1/0 or "text" for "Yes"/"No"

    Returns:
        pd.DataFrame: Original dataframe with additional one-hot encoded columns

    Raises:
        ValueError: If *binary_format* is not "numeric" or "text".
    """
    if binary_format not in ["numeric", "text"]:
        raise ValueError("binary_format must be either 'numeric' or 'text'")

    # Distinct non-null categories, in order of first appearance.
    unique_categories = dataframe[categorical_column].dropna().unique()
    numeric_mode = binary_format == "numeric"

    for cat in unique_categories:
        source = dataframe[categorical_column]
        # Bind the current category as a default arg so each lambda sees its own.
        if numeric_mode:
            dataframe[f'is_{cat}'] = source.apply(lambda value, c=cat: 1 if value == c else 0)
        else:
            dataframe[f'is_{cat}'] = source.apply(lambda value, c=cat: "Yes" if value == c else "No")

    return dataframe
|
|
236
|
+
|
|
237
|
+
# Add visualization functionality
|
|
238
|
+
from PyQt6.QtWidgets import (QMainWindow, QVBoxLayout, QHBoxLayout, QWidget,
|
|
239
|
+
QTableWidget, QTableWidgetItem, QLabel, QPushButton,
|
|
240
|
+
QComboBox, QSplitter, QTabWidget, QScrollArea,
|
|
241
|
+
QFrame, QSizePolicy, QButtonGroup, QRadioButton,
|
|
242
|
+
QMessageBox, QHeaderView, QApplication, QTextEdit)
|
|
243
|
+
from PyQt6.QtCore import Qt, QSize, pyqtSignal
|
|
244
|
+
from PyQt6.QtGui import QFont
|
|
245
|
+
import matplotlib.pyplot as plt
|
|
246
|
+
import seaborn as sns
|
|
247
|
+
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg as FigureCanvas
|
|
248
|
+
|
|
249
|
+
class OneHotEncodingVisualization(QMainWindow):
|
|
250
|
+
# Add signal to notify when encoding should be applied
|
|
251
|
+
encodingApplied = pyqtSignal(pd.DataFrame)
|
|
252
|
+
|
|
253
|
+
def __init__(self, original_df, encoded_df, encoded_column, binary_format="numeric", algorithm="basic"):
    """Build the visualization window.

    Args:
        original_df (pd.DataFrame): Data before encoding.
        encoded_df (pd.DataFrame): Data after encoding (with is_/has_ columns).
        encoded_column (str): Name of the column that was encoded.
        binary_format (str): "numeric" (1/0) or "text" (Yes/No).
        algorithm (str): "basic", "advanced" or "comprehensive".
    """
    super().__init__()
    self.original_df = original_df
    self.encoded_df = encoded_df
    self.encoded_column = encoded_column
    self.binary_format = binary_format
    self.algorithm = algorithm
    self.setWindowTitle(f"One-Hot Encoding Visualization - {encoded_column}")
    self.setGeometry(100, 100, 1200, 900)

    # Main widget
    main_widget = QWidget()
    self.setCentralWidget(main_widget)

    # Main layout
    main_layout = QVBoxLayout(main_widget)

    # Title
    title_label = QLabel(f"One-Hot Encoding Analysis: {encoded_column}")
    title_font = QFont()
    title_font.setBold(True)
    title_font.setPointSize(14)
    title_label.setFont(title_font)
    title_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
    main_layout.addWidget(title_label)

    # Description
    description = "One-hot encoding transforms categorical data into a binary matrix format where each category becomes a separate binary column."
    desc_label = QLabel(description)
    desc_label.setWordWrap(True)
    main_layout.addWidget(desc_label)

    # Control panel
    control_layout = QHBoxLayout()

    # Algorithm selector (index order must match change_algorithm's mapping)
    algorithm_label = QLabel("Algorithm:")
    self.algorithm_selector = QComboBox()
    self.algorithm_selector.addItems(["Basic (Frequency)", "Advanced (Academic)", "Comprehensive (All Methods)"])
    current_index = {"basic": 0, "advanced": 1, "comprehensive": 2}.get(algorithm, 0)
    self.algorithm_selector.setCurrentIndex(current_index)
    self.algorithm_selector.currentIndexChanged.connect(self.change_algorithm)
    control_layout.addWidget(algorithm_label)
    control_layout.addWidget(self.algorithm_selector)

    # Format selector (index 0 = numeric, 1 = text; see change_format)
    format_label = QLabel("Encoding Format:")
    self.format_selector = QComboBox()
    self.format_selector.addItems(["Numeric (1/0)", "Text (Yes/No)"])
    self.format_selector.setCurrentIndex(0 if binary_format == "numeric" else 1)
    self.format_selector.currentIndexChanged.connect(self.change_format)
    control_layout.addWidget(format_label)
    control_layout.addWidget(self.format_selector)
    control_layout.addStretch(1)

    main_layout.addLayout(control_layout)

    # Splitter to divide the screen (tables on top, plots below)
    splitter = QSplitter(Qt.Orientation.Vertical)
    main_layout.addWidget(splitter, 1)

    # Top widget: Data view
    top_widget = QWidget()
    top_layout = QVBoxLayout(top_widget)

    # Create tab widget for different views.  Tab order matters:
    # update_all_tabs addresses the "Encoded Data" tab by index 1.
    tab_widget = QTabWidget()

    # Tab 1: Original data
    original_tab = QWidget()
    original_layout = QVBoxLayout(original_tab)
    original_table = self.create_table_from_df(self.original_df)
    original_layout.addWidget(original_table)
    tab_widget.addTab(original_tab, "Original Data")

    # Tab 2: Encoded data
    encoded_tab = QWidget()
    encoded_layout = QVBoxLayout(encoded_tab)
    encoded_table = self.create_table_from_df(self.encoded_df)
    encoded_layout.addWidget(encoded_table)
    tab_widget.addTab(encoded_tab, "Encoded Data")

    # Tab 3: Algorithm insights
    insights_tab = QWidget()
    insights_layout = QVBoxLayout(insights_tab)
    self.insights_text = QTextEdit()
    self.insights_text.setReadOnly(True)
    insights_layout.addWidget(self.insights_text)
    tab_widget.addTab(insights_tab, "Algorithm Insights")

    top_layout.addWidget(tab_widget)
    splitter.addWidget(top_widget)

    # Bottom widget: Visualizations
    bottom_widget = QWidget()
    bottom_layout = QVBoxLayout(bottom_widget)

    # Visualization title
    viz_title = QLabel("Visualization")
    viz_title.setFont(title_font)
    bottom_layout.addWidget(viz_title)

    # Create matplotlib figure embedded via the Qt Agg canvas
    self.figure = plt.figure(figsize=(10, 6))
    self.canvas = FigureCanvas(self.figure)
    bottom_layout.addWidget(self.canvas)

    # Visualization type selector ("Feature Type Analysis" only exists for
    # the advanced/comprehensive algorithms)
    viz_selector_layout = QHBoxLayout()
    viz_selector_label = QLabel("Visualization Type:")
    self.viz_selector = QComboBox()
    viz_options = ["Value Counts", "Correlation Heatmap"]
    if algorithm in ["advanced", "comprehensive"]:
        viz_options.append("Feature Type Analysis")
    self.viz_selector.addItems(viz_options)
    self.viz_selector.currentIndexChanged.connect(self.update_visualization)
    viz_selector_layout.addWidget(viz_selector_label)
    viz_selector_layout.addWidget(self.viz_selector)
    viz_selector_layout.addStretch(1)
    bottom_layout.addLayout(viz_selector_layout)

    # Add Apply Button
    apply_layout = QHBoxLayout()
    apply_layout.addStretch(1)

    self.apply_button = QPushButton("Apply Encoding")
    self.apply_button.setStyleSheet("""
        QPushButton {
            background-color: #3498DB;
            color: white;
            border: none;
            padding: 8px 16px;
            border-radius: 4px;
            font-weight: bold;
        }
        QPushButton:hover {
            background-color: #2980B9;
        }
        QPushButton:pressed {
            background-color: #1F618D;
        }
    """)
    self.apply_button.setMinimumWidth(150)
    self.apply_button.clicked.connect(self.apply_encoding)
    apply_layout.addWidget(self.apply_button)

    bottom_layout.addLayout(apply_layout)

    splitter.addWidget(bottom_widget)

    # Set initial splitter sizes
    splitter.setSizes([400, 500])

    # Populate insights text and draw the initial plot
    self.update_insights()
    self.update_visualization()
|
|
409
|
+
|
|
410
|
+
def create_table_from_df(self, df):
    """Build a QTableWidget preview of *df*.

    Only the first 100 rows are shown to keep the UI responsive; all
    columns are included and every cell is rendered via ``str()``.

    Args:
        df (pd.DataFrame): DataFrame to display.

    Returns:
        QTableWidget: Populated table widget.
    """
    # Limit to 100 rows for performance; compute the bound once.
    row_count = min(100, len(df))

    table = QTableWidget()
    table.setRowCount(row_count)
    table.setColumnCount(len(df.columns))

    # setHorizontalHeaderLabels requires a list of str; DataFrame columns
    # may be non-string (e.g. numeric labels), so coerce explicitly.
    table.setHorizontalHeaderLabels([str(col) for col in df.columns])

    # Fill data
    for row in range(row_count):
        for col in range(len(df.columns)):
            item = QTableWidgetItem(str(df.iloc[row, col]))
            table.setItem(row, col, item)

    # Optimize appearance
    table.resizeColumnsToContents()
    table.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeMode.Interactive)
    return table
|
|
430
|
+
|
|
431
|
+
def update_visualization(self):
    """Redraw the matplotlib figure for the currently selected visualization type."""
    viz_type = self.viz_selector.currentText()

    # Clear previous plot
    self.figure.clear()

    # Get the encoded columns (those starting with 'is_' or 'has_')
    is_columns = [col for col in self.encoded_df.columns if col.startswith('is_')]
    has_columns = [col for col in self.encoded_df.columns if col.startswith('has_')]
    encoded_columns = is_columns + has_columns

    if viz_type == "Value Counts":
        # Create value counts visualization
        ax = self.figure.add_subplot(111)

        # Get value counts from original column
        value_counts = self.original_df[self.encoded_column].value_counts()

        # Plot
        if len(value_counts) > 15:
            # For high cardinality, show top 15
            value_counts.nlargest(15).plot(kind='barh', ax=ax)
            ax.set_title(f"Top 15 Values in {self.encoded_column}")
        else:
            value_counts.plot(kind='barh', ax=ax)
            ax.set_title(f"Value Counts in {self.encoded_column}")

        ax.set_xlabel("Count")
        ax.set_ylabel(self.encoded_column)

    elif viz_type == "Correlation Heatmap":
        # Create correlation heatmap for one-hot encoded columns
        if len(encoded_columns) > 1:
            ax = self.figure.add_subplot(111)

            # Get subset with just the encoded columns
            encoded_subset = self.encoded_df[encoded_columns]

            # Calculate correlation matrix
            corr_matrix = encoded_subset.corr()

            # Create heatmap
            if len(encoded_columns) > 10:
                # For many features, don't show annotations
                sns.heatmap(corr_matrix, cmap='coolwarm', linewidths=0.5, ax=ax,
                            annot=False, center=0)
            else:
                sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', linewidths=0.5,
                            ax=ax, fmt='.2f', center=0)

            ax.set_title(f"Correlation Between Encoded Features ({self.algorithm.title()} Algorithm)")
        else:
            # Fewer than two encoded columns: show a placeholder message instead
            ax = self.figure.add_subplot(111)
            ax.text(0.5, 0.5, "Need at least 2 features for correlation analysis",
                    horizontalalignment='center', verticalalignment='center',
                    transform=ax.transAxes)
            ax.axis('off')

    elif viz_type == "Feature Type Analysis" and self.algorithm in ["advanced", "comprehensive"]:
        # Create feature type analysis for advanced algorithms
        ax = self.figure.add_subplot(111)

        # Group features by type using the advanced encoder's naming
        # conventions (topic_lda, topic_nmf, semantic_cluster, ...)
        feature_types = {}
        for col in encoded_columns:
            if 'topic_lda' in col:
                feature_types.setdefault('LDA Topics', []).append(col)
            elif 'topic_nmf' in col:
                feature_types.setdefault('NMF Topics', []).append(col)
            elif 'semantic_cluster' in col:
                feature_types.setdefault('Semantic Clusters', []).append(col)
            elif 'domain_' in col:
                feature_types.setdefault('Domain Concepts', []).append(col)
            elif 'ngram_' in col:
                feature_types.setdefault('Key N-grams', []).append(col)
            elif 'entity_' in col:
                feature_types.setdefault('Named Entities', []).append(col)
            else:
                feature_types.setdefault('Basic Features', []).append(col)

        # Create bar chart of feature types
        types = list(feature_types.keys())
        counts = [len(feature_types[t]) for t in types]

        bars = ax.bar(types, counts, color=['#3498DB', '#E74C3C', '#2ECC71', '#F39C12', '#9B59B6', '#1ABC9C', '#34495E'][:len(types)])
        ax.set_title(f"Feature Types Created by {self.algorithm.title()} Algorithm")
        ax.set_ylabel("Number of Features")
        ax.set_xlabel("Feature Type")

        # Add value labels on bars
        for bar, count in zip(bars, counts):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{count}', ha='center', va='bottom')

        # Rotate x-axis labels if needed
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Update the canvas (an unmatched viz_type simply draws the cleared figure)
    plt.tight_layout()
    self.canvas.draw()
|
|
534
|
+
|
|
535
|
+
def apply_encoding(self):
    """Confirm with the user, then emit the encoded DataFrame to listeners."""
    confirmation = QMessageBox.question(
        self,
        "Apply Encoding",
        "Are you sure you want to apply this encoding to the original table?\n\n"
        "This will add the one-hot encoded columns to the current result table.",
        QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
        QMessageBox.StandardButton.No
    )

    # Guard clause: anything but an explicit Yes is a no-op.
    if confirmation != QMessageBox.StandardButton.Yes:
        return

    # Notify the owning window via the encodingApplied signal.
    self.encodingApplied.emit(self.encoded_df)
    QMessageBox.information(
        self,
        "Encoding Applied",
        "The one-hot encoding has been applied to the table."
    )
|
|
554
|
+
|
|
555
|
+
def change_format(self):
    """Re-encode with the newly selected binary format (numeric vs text)."""
    # Selector index 0 maps to numeric, anything else to text.
    chosen = "numeric" if self.format_selector.currentIndex() == 0 else "text"

    # No-op when the selection matches the current format.
    if chosen == self.binary_format:
        return

    self.binary_format = chosen

    # Re-run the encoder on a fresh copy so original_df stays pristine.
    self.encoded_df = get_ohe(self.original_df.copy(), self.encoded_column,
                              self.binary_format, self.algorithm)

    # Refresh tables, insights and plots.
    self.update_all_tabs()

    QMessageBox.information(
        self,
        "Format Changed",
        f"Encoding format changed to {chosen}"
    )
|
|
578
|
+
|
|
579
|
+
def change_algorithm(self):
    """Re-encode with the newly selected algorithm."""
    # Selector index order matches the combo box items set up in __init__.
    algorithm_lookup = {0: "basic", 1: "advanced", 2: "comprehensive"}
    chosen = algorithm_lookup[self.algorithm_selector.currentIndex()]

    # No-op when the selection matches the current algorithm.
    if chosen == self.algorithm:
        return

    self.algorithm = chosen

    # Re-run the encoder on a fresh copy so original_df stays pristine.
    self.encoded_df = get_ohe(self.original_df.copy(), self.encoded_column,
                              self.binary_format, self.algorithm)

    # Refresh tables, insights and plots.
    self.update_all_tabs()

    QMessageBox.information(
        self,
        "Algorithm Changed",
        f"Encoding algorithm changed to {chosen.title()}"
    )
|
|
601
|
+
|
|
602
|
+
def update_all_tabs(self):
    """Refresh the encoded-data table, insights text, visualization options
    and the plot after the encoding (format or algorithm) has changed."""
    # Update encoded data tab
    tab_widget = self.findChild(QTabWidget)
    if tab_widget:
        # Tab index 1 is the "Encoded Data" tab created in __init__
        encoded_tab = tab_widget.widget(1)
        if encoded_tab:
            # Clear old layout.
            # NOTE(review): itemAt(i).widget() returns None for non-widget
            # layout items (spacers/sub-layouts); this tab only ever holds
            # the table widget, so the unguarded call is safe here - confirm
            # if the tab's layout ever gains other item types.
            for i in reversed(range(encoded_tab.layout().count())):
                encoded_tab.layout().itemAt(i).widget().setParent(None)

            # Add new table built from the freshly re-encoded DataFrame
            encoded_table = self.create_table_from_df(self.encoded_df)
            encoded_tab.layout().addWidget(encoded_table)

    # Update insights
    self.update_insights()

    # Update visualization options (the available plot types depend on
    # the selected algorithm)
    self.update_viz_options()

    # Update visualization
    self.update_visualization()
|
|
626
|
+
|
|
627
|
+
def update_viz_options(self):
    """Rebuild the visualization-type dropdown for the current algorithm,
    keeping the previous selection when it is still available."""
    previous_choice = self.viz_selector.currentText()
    self.viz_selector.clear()

    # "Feature Type Analysis" only makes sense for the advanced encoders.
    options = ["Value Counts", "Correlation Heatmap"]
    if self.algorithm in ["advanced", "comprehensive"]:
        options.append("Feature Type Analysis")

    self.viz_selector.addItems(options)

    # Restore the previous selection if it survived the rebuild.
    if previous_choice in options:
        self.viz_selector.setCurrentIndex(options.index(previous_choice))
|
|
643
|
+
|
|
644
|
+
def update_insights(self):
    """Regenerate the text shown in the algorithm-insights tab.

    Builds a plain-text report from ``self.encoded_df`` describing the
    dataset, the active encoding algorithm, the created ``has_*`` features
    (with per-feature coverage), and recommendations, then writes it into
    ``self.insights_text``.
    """
    # Encoded feature columns are identified by the 'has_' prefix convention.
    new_columns = [col for col in self.encoded_df.columns if col.startswith('has_')]

    # Header: dataset overview common to all algorithms.
    insights = f"""
=== {self.algorithm.title()} Algorithm Insights ===

Dataset Overview:
• Total records: {len(self.encoded_df)}
• Original column: {self.encoded_column}
• Features created: {len(new_columns)}
• Binary format: {self.binary_format}

Algorithm Details:
"""

    # Algorithm-specific explanatory text (one branch per supported mode).
    if self.algorithm == "basic":
        insights += """
Basic Frequency Algorithm:
• Uses simple word frequency analysis
• Extracts top 10 most common words/categories
• Good for: Simple categorical data, basic text analysis
• Limitations: Misses semantic relationships, synonyms, themes

How it works:
1. Tokenizes text and removes stopwords
2. Counts word frequencies
3. Creates binary features for most frequent words
4. Fast and lightweight approach
"""
    elif self.algorithm == "advanced":
        insights += """
Advanced Academic Algorithm:
• Uses sophisticated NLP and ML techniques:
- Topic Modeling (LDA & NMF)
- Semantic clustering with TF-IDF
- N-gram extraction
- Named Entity Recognition (if available)
• Good for: Complex text analysis, theme detection
• Benefits: Captures semantic relationships, identifies topics

How it works:
1. Applies multiple academic algorithms in parallel
2. Extracts latent topics using probabilistic models
3. Groups semantically related words into clusters
4. Identifies key phrases and entities
5. Creates features based on conceptual understanding
"""
    elif self.algorithm == "comprehensive":
        insights += """
Comprehensive Analysis:
• Combines ALL available methods:
- Topic Modeling (LDA & NMF)
- Semantic clustering
- N-gram extraction
- Named Entity Recognition
- Domain-specific concept detection
• Best for: Research, detailed analysis, maximum insight
• Benefits: Most complete semantic understanding

How it works:
1. Runs all advanced algorithms simultaneously
2. Extracts maximum number of meaningful features
3. Identifies cross-cutting themes and relationships
4. Provides richest feature representation
5. Ideal for discovering hidden patterns
"""

    # Add feature breakdown
    if new_columns:
        insights += f"""
Features Created ({len(new_columns)} total):
"""

        # Group features by type for advanced algorithms
        if self.algorithm in ["advanced", "comprehensive"]:
            # Bucket each feature by the naming convention of the sub-method
            # that produced it (topic models, clusters, n-grams, entities...).
            feature_types = {}
            for col in new_columns:
                if 'topic_lda' in col:
                    feature_types.setdefault('LDA Topics', []).append(col)
                elif 'topic_nmf' in col:
                    feature_types.setdefault('NMF Topics', []).append(col)
                elif 'semantic_cluster' in col:
                    feature_types.setdefault('Semantic Clusters', []).append(col)
                elif 'domain_' in col:
                    feature_types.setdefault('Domain Concepts', []).append(col)
                elif 'ngram_' in col:
                    feature_types.setdefault('Key N-grams', []).append(col)
                elif 'entity_' in col:
                    feature_types.setdefault('Named Entities', []).append(col)
                else:
                    feature_types.setdefault('Basic Features', []).append(col)

            # List up to five features per bucket with their coverage.
            for ftype, features in feature_types.items():
                insights += f"\n{ftype} ({len(features)}):\n"
                for feature in features[:5]:  # Show first 5
                    coverage = self.calculate_coverage(feature)
                    insights += f" • {feature}: {coverage:.1f}% coverage\n"
                if len(features) > 5:
                    insights += f" ... and {len(features) - 5} more\n"
        else:
            # Basic algorithm - show all features
            for feature in new_columns[:10]:  # Show first 10
                coverage = self.calculate_coverage(feature)
                insights += f"• {feature}: {coverage:.1f}% coverage\n"
            if len(new_columns) > 10:
                insights += f"... and {len(new_columns) - 10} more\n"

    # Add recommendations
    insights += f"""
Recommendations:
"""
    if self.algorithm == "basic":
        insights += """
• Consider upgrading to Advanced for better semantic understanding
• Good for simple categorical data and quick analysis
• May miss important relationships in complex text data
"""
    elif self.algorithm == "advanced":
        insights += """
• Excellent balance of sophistication and performance
• Captures most important semantic relationships
• Good for production use and detailed analysis
"""
    elif self.algorithm == "comprehensive":
        insights += """
• Maximum insight extraction from your data
• Best for research and exploratory analysis
• Use correlation analysis to identify redundant features
• Consider feature selection for production deployment
"""

    # Push the finished report into the read-only insights widget.
    self.insights_text.setPlainText(insights)
|
|
777
|
+
|
|
778
|
+
def calculate_coverage(self, feature_name):
    """Calculate the coverage percentage of a binary feature.

    Coverage is the share of rows where the feature is "on": ``1`` when
    ``self.binary_format`` is ``"numeric"``, the string ``"Yes"`` otherwise.

    Args:
        feature_name (str): Column name of the encoded feature.

    Returns:
        float: Percentage of rows (0-100) where the feature is set;
        0.0 for an empty dataframe.
    """
    total = len(self.encoded_df)
    # Guard: an empty dataframe would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    # The "on" marker depends on the selected binary output format.
    on_value = 1 if self.binary_format == "numeric" else "Yes"
    return (self.encoded_df[feature_name] == on_value).sum() / total * 100
|
|
784
|
+
|
|
785
|
+
def visualize_ohe(df, column, binary_format="numeric", algorithm="basic"):
    """
    Encode a column with one-hot encoding and open a visualization window.

    Args:
        df (pd.DataFrame): The original dataframe
        column (str): The column to encode and visualize
        binary_format (str): Format for encoding - "numeric" for 1/0 or "text" for "Yes"/"No"
        algorithm (str): Algorithm to use - "basic", "advanced", or "comprehensive"

    Returns:
        QMainWindow: The visualization window
    """
    # Work on a copy so the caller's dataframe is never mutated.
    source_frame = df.copy()

    # Run the selected encoding algorithm over the requested column.
    result_frame = get_ohe(source_frame, column, binary_format, algorithm)

    # Build the window, show it, and return it so the caller holds a
    # reference (otherwise the window could be garbage-collected).
    window = OneHotEncodingVisualization(source_frame, result_frame, column,
                                         binary_format, algorithm)
    window.show()
    return window
|
|
809
|
+
|
|
810
|
+
|
|
811
|
+
def test_ohe():
    """
    Smoke-test one-hot encoding on sample text data.

    Runs the basic, advanced, and comprehensive algorithms over a small
    dataframe of sentences and reports how many binary features each creates.
    Both numeric (1/0) output and graceful failure of the optional advanced
    modes are exercised.
    """
    print("\n===== Testing Text Data One-Hot Encoding =====")
    # Ten short sentences sharing recurring words (fox, dog, cat, ...).
    sample_sentences = [
        'The quick brown fox jumps over the lazy dog',
        'A quick brown dog runs in the park',
        'The lazy cat sleeps all day',
        'A brown fox and a lazy dog play together',
        'The quick cat chases the mouse',
        'A lazy dog sleeps in the sun',
        'The brown fox is quick and clever',
        'A cat and a dog are best friends',
        'The quick mouse runs from the cat',
        'A lazy fox sleeps in the shade'
    ]
    sample_df = pd.DataFrame({'text': sample_sentences})

    # The basic algorithm has no optional dependencies, so no guard here.
    print("\n----- Testing Basic Algorithm -----")
    encoded = get_ohe(sample_df.copy(), 'text', binary_format="numeric", algorithm="basic")
    created = [name for name in encoded.columns if name.startswith('has_')]
    print(f"Basic algorithm created {len(created)} features")

    # The richer modes depend on optional NLP libraries, so failures are
    # reported rather than propagated.
    for label in ("Advanced", "Comprehensive"):
        print(f"\n----- Testing {label} Algorithm -----")
        try:
            encoded = get_ohe(sample_df.copy(), 'text', binary_format="numeric",
                              algorithm=label.lower())
            created = [name for name in encoded.columns if name.startswith('has_')]
            print(f"{label} algorithm created {len(created)} features")
        except Exception as exc:
            print(f"{label} algorithm failed: {exc}")

    print("\nText data tests completed!")
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
def test_advanced_ai_example():
    """Demonstrate semantic feature extraction on AI/ML job descriptions."""
    print("\n===== Testing AI/ML Text Analysis =====")

    # Five role descriptions drawn from the AI/ML domain.
    descriptions = [
        "Machine learning engineer developing neural networks for computer vision",
        "AI researcher working on natural language processing and transformers",
        "Data scientist implementing deep learning algorithms for analytics",
        "Software engineer building recommendation systems with collaborative filtering",
        "ML ops engineer deploying artificial intelligence models to production"
    ]
    frame = pd.DataFrame({'description': descriptions})

    print("Testing different algorithms on AI-related text:")

    # Vocabulary used to spot AI-themed feature names in the output.
    ai_terms = ('ai', 'machine', 'learning', 'neural', 'deep')

    # Exercise every algorithm; failures are reported, not raised.
    for algorithm in ["basic", "advanced", "comprehensive"]:
        print(f"\n--- {algorithm.title()} Algorithm ---")
        try:
            encoded = get_ohe(frame.copy(), 'description', algorithm=algorithm)
            created = [name for name in encoded.columns if name.startswith('has_')]
            print(f"Created {len(created)} features")

            # Highlight features whose names mention AI/ML vocabulary.
            matches = [name for name in created
                       if any(term in name.lower() for term in ai_terms)]
            if matches:
                print(f"AI-related features found: {len(matches)}")
                for name in matches[:3]:
                    print(f" • {name}")
            else:
                print("No explicit AI-related features in names (may be captured in topics)")

        except Exception as exc:
            print(f"Failed: {exc}")

    print("\nAI example test completed!")
|
|
902
|
+
|
|
903
|
+
|
|
904
|
+
if __name__ == "__main__":
    # Run the console smoke tests first.
    test_ohe()
    test_advanced_ai_example()

    # Then exercise the Qt visualization with the advanced algorithm.
    import sys
    from PyQt6.QtWidgets import QApplication

    # BUG FIX: the original only assigned `app` inside the "no instance yet"
    # branch, so `app.exec()` below raised NameError whenever a QApplication
    # already existed. Always bind `app`, creating one only if needed.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)

    # Small sample dataframe with a categorical and a free-text column.
    data = {
        'category': ['red', 'blue', 'green', 'red', 'yellow', 'blue'],
        'text': [
            'The quick brown fox',
            'A lazy dog',
            'Brown fox jumps',
            'Quick brown fox',
            'Lazy dog sleeps',
            'Fox and dog'
        ]
    }
    df = pd.DataFrame(data)

    # Show visualization with the advanced algorithm; keep the returned
    # window referenced so it survives until the event loop starts.
    vis = visualize_ohe(df, 'text', binary_format="numeric", algorithm="advanced")

    # Hand control to the Qt event loop.
    sys.exit(app.exec())
|