tabimpute 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tabimpute-0.0.2/.gitignore +4 -0
- tabimpute-0.0.2/PKG-INFO +58 -0
- tabimpute-0.0.2/README.md +5 -0
- tabimpute-0.0.2/about.py +1 -0
- tabimpute-0.0.2/benchmark/README.md +28 -0
- tabimpute-0.0.2/benchmark/analyze_imputation_times.py +324 -0
- tabimpute-0.0.2/benchmark/auc_table.tex +53 -0
- tabimpute-0.0.2/benchmark/cacti_wrapper.py +259 -0
- tabimpute-0.0.2/benchmark/calculate_imputation_variance.py +290 -0
- tabimpute-0.0.2/benchmark/create_openml_categorical_missingness_datasets.py +38 -0
- tabimpute-0.0.2/benchmark/create_openml_missingness_datasets.py +145 -0
- tabimpute-0.0.2/benchmark/create_uci_missingness.py +105 -0
- tabimpute-0.0.2/benchmark/dataset_sizes.txt +166 -0
- tabimpute-0.0.2/benchmark/dataset_sizes_missing.txt +7 -0
- tabimpute-0.0.2/benchmark/datasets/.gitignore +4 -0
- tabimpute-0.0.2/benchmark/datasets_df.csv +87 -0
- tabimpute-0.0.2/benchmark/diffputer_wrapper.py +240 -0
- tabimpute-0.0.2/benchmark/download_openml.py +146 -0
- tabimpute-0.0.2/benchmark/download_uci_datasets.py +298 -0
- tabimpute-0.0.2/benchmark/generate_dataset_table.py +71 -0
- tabimpute-0.0.2/benchmark/get_openml_categorical_errors.py +278 -0
- tabimpute-0.0.2/benchmark/get_openml_errors.py +486 -0
- tabimpute-0.0.2/benchmark/get_uci_errors.py +470 -0
- tabimpute-0.0.2/benchmark/notmiwae_wrapper.py +289 -0
- tabimpute-0.0.2/benchmark/plot_accuracy.py +288 -0
- tabimpute-0.0.2/benchmark/plot_critical_difference.py +347 -0
- tabimpute-0.0.2/benchmark/plot_mcar_line.py +193 -0
- tabimpute-0.0.2/benchmark/plot_negative_mae.py +52 -0
- tabimpute-0.0.2/benchmark/plot_negative_rmse.py +845 -0
- tabimpute-0.0.2/benchmark/plot_options.py +234 -0
- tabimpute-0.0.2/benchmark/plot_pairwise_by_pattern.py +127 -0
- tabimpute-0.0.2/benchmark/plot_r_squared.py +733 -0
- tabimpute-0.0.2/benchmark/plot_uci_negative_rmse.py +509 -0
- tabimpute-0.0.2/benchmark/plot_wasserstein.py +723 -0
- tabimpute-0.0.2/benchmark/plot_win_rate.py +247 -0
- tabimpute-0.0.2/benchmark/remasker_wrapper.py +310 -0
- tabimpute-0.0.2/benchmark/runtime_benchmark/README.md +47 -0
- tabimpute-0.0.2/benchmark/runtime_benchmark/get_runtime_models.py +153 -0
- tabimpute-0.0.2/benchmark/runtime_benchmark/plot_runtime_benchmark.py +110 -0
- tabimpute-0.0.2/benchmark/runtime_benchmark/runtime_benchmark_results.csv +61 -0
- tabimpute-0.0.2/config.pkl +0 -0
- tabimpute-0.0.2/pyproject.toml +120 -0
- tabimpute-0.0.2/scripts/gen_multiple.sh +14 -0
- tabimpute-0.0.2/scripts/generate_data.sh +65 -0
- tabimpute-0.0.2/scripts/train.sh +70 -0
- tabimpute-0.0.2/scripts/train_mar.sh +70 -0
- tabimpute-0.0.2/scripts/train_mcar.sh +71 -0
- tabimpute-0.0.2/scripts/train_mcar_nonlinear.sh +71 -0
- tabimpute-0.0.2/scripts/train_mnar.sh +70 -0
- tabimpute-0.0.2/src/tabimpute/__about__.py +1 -0
- tabimpute-0.0.2/src/tabimpute/__init__.py +1 -0
- tabimpute-0.0.2/src/tabimpute/data/borders.pt +0 -0
- tabimpute-0.0.2/src/tabimpute/data/encoder.pth +0 -0
- tabimpute-0.0.2/src/tabimpute/diffusion/__init__.py +13 -0
- tabimpute-0.0.2/src/tabimpute/diffusion/mar_diffusion_row10_30_col10_30_mar0_3_marblock0_3_marbandit0_4_epoch100_bs32_samples30k_lr1e_3.py +850 -0
- tabimpute-0.0.2/src/tabimpute/interface.py +595 -0
- tabimpute-0.0.2/src/tabimpute/misc/_sklearn_compat.py +869 -0
- tabimpute-0.0.2/src/tabimpute/misc/debug_versions.py +702 -0
- tabimpute-0.0.2/src/tabimpute/model/__init__.py +0 -0
- tabimpute-0.0.2/src/tabimpute/model/bar_distribution.py +863 -0
- tabimpute-0.0.2/src/tabimpute/model/config.py +150 -0
- tabimpute-0.0.2/src/tabimpute/model/encoders.py +1078 -0
- tabimpute-0.0.2/src/tabimpute/model/full_attention.py +1565 -0
- tabimpute-0.0.2/src/tabimpute/model/inference.py +683 -0
- tabimpute-0.0.2/src/tabimpute/model/inference_config.py +228 -0
- tabimpute-0.0.2/src/tabimpute/model/layer.py +472 -0
- tabimpute-0.0.2/src/tabimpute/model/mcpfn.py +138 -0
- tabimpute-0.0.2/src/tabimpute/model/memory.py +452 -0
- tabimpute-0.0.2/src/tabimpute/model/mlp.py +138 -0
- tabimpute-0.0.2/src/tabimpute/model/model.py +403 -0
- tabimpute-0.0.2/src/tabimpute/model/positional.py +169 -0
- tabimpute-0.0.2/src/tabimpute/model/transformer.py +870 -0
- tabimpute-0.0.2/src/tabimpute/prepreocess.py +575 -0
- tabimpute-0.0.2/src/tabimpute/prior/__init__.py +0 -0
- tabimpute-0.0.2/src/tabimpute/prior/activations.py +289 -0
- tabimpute-0.0.2/src/tabimpute/prior/base_prior.py +392 -0
- tabimpute-0.0.2/src/tabimpute/prior/dataset.py +773 -0
- tabimpute-0.0.2/src/tabimpute/prior/genload.py +758 -0
- tabimpute-0.0.2/src/tabimpute/prior/hp_sampling.py +301 -0
- tabimpute-0.0.2/src/tabimpute/prior/mar_block_missing.py +142 -0
- tabimpute-0.0.2/src/tabimpute/prior/mar_missing.py +355 -0
- tabimpute-0.0.2/src/tabimpute/prior/mar_onesided_missing.py +356 -0
- tabimpute-0.0.2/src/tabimpute/prior/mar_sequential_missing.py +584 -0
- tabimpute-0.0.2/src/tabimpute/prior/mlp_scm.py +344 -0
- tabimpute-0.0.2/src/tabimpute/prior/prior_config.py +94 -0
- tabimpute-0.0.2/src/tabimpute/prior/reg2cls.py +390 -0
- tabimpute-0.0.2/src/tabimpute/prior/scm_prior.py +383 -0
- tabimpute-0.0.2/src/tabimpute/prior/splits.py +45 -0
- tabimpute-0.0.2/src/tabimpute/prior/training_set_generation.py +1295 -0
- tabimpute-0.0.2/src/tabimpute/prior/tree_scm.py +401 -0
- tabimpute-0.0.2/src/tabimpute/prior/utils.py +165 -0
- tabimpute-0.0.2/src/tabimpute/tabimpute_v2.py +88 -0
- tabimpute-0.0.2/src/tabimpute/tabpfn_extensions_interface.py +73 -0
- tabimpute-0.0.2/src/tabimpute/train/__init__.py +0 -0
- tabimpute-0.0.2/src/tabimpute/train/callbacks.py +92 -0
- tabimpute-0.0.2/src/tabimpute/train/optim.py +356 -0
- tabimpute-0.0.2/src/tabimpute/train/run.py +1141 -0
- tabimpute-0.0.2/src/tabimpute/train/train_config.py +431 -0
- tabimpute-0.0.2/src/tabimpute/train.py +232 -0
- tabimpute-0.0.2/test_bar_distribution_shape.py +200 -0
- tabimpute-0.0.2/tests.ipynb +12524 -0
- tabimpute-0.0.2/train.md +40 -0
- tabimpute-0.0.2/utils.py +50 -0
tabimpute-0.0.2/PKG-INFO
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tabimpute
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: TabImpute: A Pre-trained Transformer for Missing Data Imputation
|
|
5
|
+
Author: Jacob Feitelberg, Dwaipayan Saha, Zaid Ahmad, Kyuseong Choi, Anish Agarwal, Raaz Dwivedi
|
|
6
|
+
Keywords: foundation model,in-context learning,missing data imputation,tabular data
|
|
7
|
+
Classifier: Development Status :: 4 - Beta
|
|
8
|
+
Classifier: Intended Audience :: Developers
|
|
9
|
+
Classifier: Intended Audience :: Science/Research
|
|
10
|
+
Classifier: Programming Language :: Python
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering
|
|
17
|
+
Requires-Python: <3.13,>=3.9
|
|
18
|
+
Requires-Dist: einops>=0.7
|
|
19
|
+
Requires-Dist: huggingface-hub
|
|
20
|
+
Requires-Dist: networkx
|
|
21
|
+
Requires-Dist: numpy
|
|
22
|
+
Requires-Dist: tabpfn==6.0.0
|
|
23
|
+
Requires-Dist: torch<3,>=2.2
|
|
24
|
+
Provides-Extra: benchmark
|
|
25
|
+
Requires-Dist: joblib; extra == 'benchmark'
|
|
26
|
+
Requires-Dist: matplotlib; extra == 'benchmark'
|
|
27
|
+
Requires-Dist: pandas; extra == 'benchmark'
|
|
28
|
+
Requires-Dist: psutil; extra == 'benchmark'
|
|
29
|
+
Requires-Dist: scikit-learn==1.4.2; extra == 'benchmark'
|
|
30
|
+
Requires-Dist: scipy; extra == 'benchmark'
|
|
31
|
+
Requires-Dist: tabpfn; extra == 'benchmark'
|
|
32
|
+
Requires-Dist: tqdm>=4.64.0; extra == 'benchmark'
|
|
33
|
+
Requires-Dist: transformers; extra == 'benchmark'
|
|
34
|
+
Requires-Dist: wandb; extra == 'benchmark'
|
|
35
|
+
Requires-Dist: xgboost; extra == 'benchmark'
|
|
36
|
+
Provides-Extra: categorical
|
|
37
|
+
Requires-Dist: scipy; extra == 'categorical'
|
|
38
|
+
Provides-Extra: preprocessing
|
|
39
|
+
Requires-Dist: scipy; extra == 'preprocessing'
|
|
40
|
+
Provides-Extra: tabpfn-extensions
|
|
41
|
+
Requires-Dist: tabpfn-extensions; extra == 'tabpfn-extensions'
|
|
42
|
+
Provides-Extra: training
|
|
43
|
+
Requires-Dist: joblib; extra == 'training'
|
|
44
|
+
Requires-Dist: pandas; extra == 'training'
|
|
45
|
+
Requires-Dist: psutil; extra == 'training'
|
|
46
|
+
Requires-Dist: scikit-learn==1.4.2; extra == 'training'
|
|
47
|
+
Requires-Dist: scipy; extra == 'training'
|
|
48
|
+
Requires-Dist: tqdm>=4.64.0; extra == 'training'
|
|
49
|
+
Requires-Dist: transformers; extra == 'training'
|
|
50
|
+
Requires-Dist: wandb; extra == 'training'
|
|
51
|
+
Requires-Dist: xgboost; extra == 'training'
|
|
52
|
+
Description-Content-Type: text/markdown
|
|
53
|
+
|
|
54
|
+
# TabImpute
|
|
55
|
+
|
|
56
|
+
TabImpute is a pre-trained transformer for missing data imputation on tabular data.
|
|
57
|
+
|
|
58
|
+
This code is based on the TabPFN and TabICL codebases, both available on GitHub.
|
tabimpute-0.0.2/about.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.0.2"
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
This folder contains the code for benchmarking the imputers on the OpenML and UCI datasets.
|
|
2
|
+
|
|
3
|
+
The code is organized as follows:
|
|
4
|
+
|
|
5
|
+
- `create_openml_missingness_datasets.py`: This script is used to create the missingness datasets from the OpenML datasets.
|
|
6
|
+
- `get_openml_errors.py`: This script is used to get the errors of the imputers on the OpenML datasets.
|
|
7
|
+
- `plot_error_violinplots.py`: This script is used to plot the violin plots of the errors of the imputers on the OpenML datasets.
|
|
8
|
+
- `plot_error_boxplots.py`: This script is used to plot the box plots of the errors of the imputers on the OpenML datasets.
|
|
9
|
+
- `plot_negative_rmse.py`: This script is used to plot the negative RMSE of the imputers on the OpenML datasets.
|
|
10
|
+
|
|
11
|
+
The datasets are stored in the `datasets` folder. The figures are stored in the `figures` folder.
|
|
12
|
+
|
|
13
|
+
## UCI Datasets
|
|
14
|
+
|
|
15
|
+
We test on the same datasets as in the HyperImpute paper:
|
|
16
|
+
|
|
17
|
+
- Airfoil Self-Noise
|
|
18
|
+
- Blood Transfusion
|
|
19
|
+
- California Housing
|
|
20
|
+
- Concrete Compression
|
|
21
|
+
- Diabetes
|
|
22
|
+
- Ionosphere
|
|
23
|
+
- Iris
|
|
24
|
+
- Letter Recognition
|
|
25
|
+
- Libras Movement
|
|
26
|
+
- Spam Base
|
|
27
|
+
- Wine Quality (Red)
|
|
28
|
+
- Wine Quality (White)
|
|
@@ -0,0 +1,324 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Script to analyze imputation times from MCAR_0.4 folders and plot them against dataset sizes.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
import glob
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import matplotlib
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import numpy as np
|
|
13
|
+
import seaborn as sns
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from scipy import stats
|
|
16
|
+
from plot_options import (
|
|
17
|
+
setup_latex_fonts,
|
|
18
|
+
METHOD_NAMES,
|
|
19
|
+
METHOD_COLORS,
|
|
20
|
+
HIGHLIGHT_COLOR,
|
|
21
|
+
NEUTRAL_COLOR,
|
|
22
|
+
FIGURE_SIZES,
|
|
23
|
+
BARPLOT_STYLE,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
def parse_dataset_sizes(file_path):
    r"""Read dataset dimensions from a ``name | rows \times cols`` listing.

    A line is used only when it contains exactly one ``|`` separating the
    dataset name from a dimension string that begins with ``<rows> \times <cols>``
    (whitespace around ``\times`` is optional). Blank lines, lines with zero or
    several separators, and lines whose dimension part does not match are ignored.
    When a name appears twice, the later line wins.

    Returns:
        dict mapping dataset name -> {'rows': int, 'cols': int, 'size': rows * cols}.
    """
    sizes = {}
    with open(file_path, 'r') as handle:
        for raw in handle:
            entry = raw.strip()
            # Exactly one separator is required; this also excludes blank lines.
            if entry.count('|') != 1:
                continue
            name, _, dims = entry.partition('|')
            found = re.match(r'(\d+)\s*\\times\s*(\d+)', dims.strip())
            if found is None:
                continue
            n_rows, n_cols = (int(group) for group in found.groups())
            sizes[name.strip()] = {'rows': n_rows, 'cols': n_cols, 'size': n_rows * n_cols}
    return sizes
|
|
48
|
+
|
|
49
|
+
def find_imputation_times(base_path):
    """Collect per-method imputation times from every ``MCAR_0.4`` folder under *base_path*.

    The dataset name is the path component immediately after the first
    ``openml`` directory, e.g. ``<base>/openml/<dataset>/.../MCAR_0.4``. Inside
    each such folder, every ``<method>_imputation_time.txt`` file contributes one
    record whose time is the float on the file's first line.

    Args:
        base_path: root directory that is searched recursively.

    Returns:
        List of dicts with keys ``'dataset'``, ``'method'``, ``'time'`` and
        ``'file_path'``, ordered by folder and file path. Folders without an
        ``openml`` ancestor are reported and skipped. Empty files are skipped
        silently. Unreadable or non-numeric files are reported and skipped
        rather than aborting the whole scan.
    """
    imputation_data = []

    # Sorted so the record order does not depend on filesystem listing order.
    mcar_folders = sorted(glob.glob(os.path.join(base_path, "**", "MCAR_0.4"), recursive=True))
    print(f"Found {len(mcar_folders)} MCAR_0.4 folders")

    for folder in mcar_folders:
        # The dataset name is the component right after the first 'openml' directory.
        path_parts = Path(folder).parts
        dataset_name = None
        for idx, part in enumerate(path_parts):
            if part == 'openml':
                if idx + 1 < len(path_parts):
                    dataset_name = path_parts[idx + 1]
                break

        if not dataset_name:
            print(f"Could not extract dataset name from {folder}")
            continue

        for time_file in sorted(glob.glob(os.path.join(folder, "*_imputation_time.txt"))):
            try:
                with open(time_file, 'r') as f:
                    first_line = f.readline().strip()
                if not first_line:
                    # Nothing recorded for this run.
                    continue
                imputation_time = float(first_line)
            except (ValueError, OSError) as e:
                # OSError covers missing, unreadable, or directory entries matching
                # the glob; ValueError covers non-numeric content and decode errors.
                print(f"Error reading {time_file}: {e}")
                continue

            # The method key is the filename without the '_imputation_time.txt' suffix.
            method_name = os.path.basename(time_file).replace('_imputation_time.txt', '')
            imputation_data.append({
                'dataset': dataset_name,
                'method': method_name,
                'time': imputation_time,
                'file_path': time_file,
            })

    return imputation_data
|
|
101
|
+
|
|
102
|
+
# Human-readable labels keyed by the ``<method>`` prefix of each
# ``<method>_imputation_time.txt`` file. Starts from the shared plot_options
# mapping and adds/overrides the labels specific to this runtime analysis.
method_names = METHOD_NAMES.copy()
method_names.update(
    mcpfn="TabImpute (GPU)",
    mcpfn_cpu="TabImpute (CPU)",
    tabpfn_unsupervised="Col-TabPFN (GPU)",
    tabpfn="EWF-TabPFN (GPU)",
    hyperimpute_hyperimpute="HyperImpute (GPU)",
    hyperimpute_ot_sinkhorn="OT",
    hyperimpute_hyperimpute_missforest="MissForest",
    hyperimpute_hyperimpute_ice="ICE",
    hyperimpute_hyperimpute_mice="MICE",
    hyperimpute_hyperimpute_gain="GAIN (GPU)",
    hyperimpute_hyperimpute_miwae="MIWAE (GPU)",
    remasker="ReMasker (GPU)",
    cacti="CACTI (GPU)",
    # tabimpute_mcar_lin="TabImpute (Lin. Emb.)",
    tabimpute_dynamic_cls="TabImpute (New)",
)

# Bar colours come from plot_options. The TabImpute variants listed here use
# the highlight colour; the other listed methods use the neutral colour.
neutral_color = NEUTRAL_COLOR
highlight_color = HIGHLIGHT_COLOR
# Very dark gray for non-TabImpute x-axis label text (not bars), matching
# plot_negative_rmse.py.
darker_neutral_color = "#333333"
method_colors = METHOD_COLORS.copy()
method_colors.update({label: highlight_color for label in ("TabImpute (GPU)", "TabImpute (CPU)")})
method_colors.update({
    label: neutral_color
    for label in (
        "EWF-TabPFN (GPU)",
        "HyperImpute (GPU)",
        "GAIN (GPU)",
        "MIWAE (GPU)",
        "Col-TabPFN (GPU)",
        "ReMasker (GPU)",
        "CACTI (GPU)",
    )
})
# method_colors["TabImpute (Lin. Emb.)"] = highlight_color
method_colors["TabImpute (New)"] = highlight_color

# Method keys (filename prefixes) kept for the efficiency plot, in this order.
# Commented-out entries are variants that are currently excluded.
include_methods = [
    "mcpfn",
    "mcpfn_cpu",
    # "tabimpute_large_mcar",
    # "mcpfn_ensemble",
    # "mcpfn_ensemble_cpu",
    "tabpfn_unsupervised",
    # "masters_mcar",
    "tabpfn",
    # "tabimpute_mcar_lin",
    "tabimpute_large_mcar_rank_1_11",
    # "tabpfn_impute",
    "hyperimpute_hyperimpute",
    "hyperimpute_hyperimpute_missforest",
    "hyperimpute_ot_sinkhorn",
    "hyperimpute_hyperimpute_ice",
    "hyperimpute_hyperimpute_mice",
    "hyperimpute_hyperimpute_gain",
    "hyperimpute_hyperimpute_miwae",
    # "column_mean",
    "knn",
    "softimpute",
    "forestdiffusion",
    "remasker",
    # "diffputer",
    "cacti",
]
|
|
172
|
+
|
|
173
|
+
def _plot_label(method_key):
    """Return the x-axis label for *method_key*, with a LaTeX-safe fallback for unknown keys."""
    if method_key in method_names:
        # Registered display names are used verbatim (they may contain LaTeX markup).
        return method_names[method_key]
    # Unregistered keys are shown raw; underscores are escaped because text.usetex is enabled.
    return method_key.replace('_', r'\_')


def _plot_efficiency_barplot(df_filtered, output_path):
    """Draw and save a log-scale bar plot of mean seconds-per-entry for each method in *df_filtered*."""
    # Configure LaTeX rendering for all text in the plot.
    setup_latex_fonts()
    matplotlib.rcParams['text.usetex'] = True
    matplotlib.rcParams['font.family'] = 'serif'

    plt.figure(figsize=FIGURE_SIZES['standard'])

    df_plot = df_filtered.assign(Method=df_filtered['method'].map(_plot_label))
    # Fastest method first: ascending mean seconds per entry.
    order = list(df_plot.groupby('Method')['efficiency'].mean().sort_values(ascending=True).index)

    def bar_color(label):
        # Labels without a registered colour fall back to the highlight/neutral scheme
        # instead of making seaborn reject a palette that lacks a hue level.
        default = highlight_color if "TabImpute" in label else neutral_color
        return method_colors.get(label, default)

    palette = {label: bar_color(label) for label in order}

    ax = sns.barplot(data=df_plot, x='Method', y='efficiency', hue='Method',
                     order=order,
                     palette=palette,
                     **BARPLOT_STYLE,
                     legend=False)

    # Rotated tick labels; TabImpute variants are bold (LaTeX \textbf), highlighted and slightly larger.
    ax.set_xticks(range(len(order)))
    ax.set_xticklabels([r"\textbf{" + label + "}" if "TabImpute" in label else label for label in order],
                       rotation=45, ha='right', fontsize=14)
    ax.set_xlabel('')
    for tick_label, label in zip(ax.get_xticklabels(), order):
        if "TabImpute" in label:
            tick_label.set_color(bar_color(label))
            tick_label.set_fontsize(tick_label.get_fontsize() * 1.1)
        else:
            tick_label.set_color(darker_neutral_color)

    # Efficiency is stored in seconds per entry; the tick formatter displays milliseconds.
    ax.set_ylabel(r'Milliseconds per entry', fontsize=18)
    ax.set_yscale('log')
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x*1000:.2f}'))
    # Show both major and minor grid lines on the log-scale y-axis.
    ax.yaxis.grid(True, which='major', alpha=0.3, linestyle='-')
    ax.yaxis.grid(True, which='minor', alpha=0.3, linestyle='--')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()


def create_plots(imputation_data, dataset_info,
                 output_path='/home/jacobf18/tabular/mcpfn/benchmark/imputation_efficiency_barplot.pdf'):
    """Plot per-entry imputation runtime by method and print summary statistics.

    Args:
        imputation_data: list of dicts with keys ``'dataset'``, ``'method'``,
            ``'time'`` (seconds) and ``'file_path'``, as returned by
            ``find_imputation_times``.
        dataset_info: mapping dataset name -> dict with ``'rows'``, ``'cols'``
            and ``'size'``, as returned by ``parse_dataset_sizes``.
        output_path: destination of the bar-plot PDF (defaults to the
            original hard-coded location).

    Returns:
        DataFrame of the records whose dataset has size information, with added
        ``dataset_size``, ``rows``, ``cols`` and ``efficiency`` (seconds per
        entry) columns. Records for every method are kept, not only those in
        ``include_methods``. If there are no records, the (empty) frame is
        returned without plotting.
    """
    df = pd.DataFrame(imputation_data)
    if df.empty:
        print("No imputation time records found; nothing to plot.")
        return df

    # Attach dataset dimensions. Datasets absent from dataset_info get 0 and are dropped next.
    df['dataset_size'] = df['dataset'].map(lambda x: dataset_info.get(x, {}).get('size', 0))
    df['rows'] = df['dataset'].map(lambda x: dataset_info.get(x, {}).get('rows', 0))
    df['cols'] = df['dataset'].map(lambda x: dataset_info.get(x, {}).get('cols', 0))
    # .copy() so adding the efficiency column writes to a real frame, not a view.
    df = df[df['dataset_size'] > 0].copy()
    # Runtime normalised by the number of entries (rows * cols): seconds per entry.
    df['efficiency'] = df['time'] / df['dataset_size']

    print(f"Found {len(df)} imputation time records")
    print(f"Unique methods: {df['method'].unique()}")
    print(f"Unique datasets: {df['dataset'].unique()}")

    if df.empty:
        print("No records with dataset size information; nothing to plot.")
        return df

    df_filtered = df[df['method'].isin(include_methods)].copy()

    # Speedup of TabImpute (GPU) over EWF-TabPFN (GPU), based on mean wall-clock time.
    baseline_method_key = 'tabpfn'
    speed_up_method_key = 'mcpfn'
    baseline_display_name = method_names.get(baseline_method_key, baseline_method_key)
    speed_up_display_name = method_names.get(speed_up_method_key, speed_up_method_key)
    baseline_times = df_filtered.loc[df_filtered['method'] == baseline_method_key, 'time']
    speed_up_times = df_filtered.loc[df_filtered['method'] == speed_up_method_key, 'time']
    if len(baseline_times) > 0 and len(speed_up_times) > 0:
        baseline_mean_time = baseline_times.mean()
        speed_up_mean_time = speed_up_times.mean()
        speedup = baseline_mean_time / speed_up_mean_time
        print(f"\nSpeedup of {speed_up_display_name} compared to {baseline_display_name}: {speedup:.2f}x")
        print(f"{speed_up_display_name} mean time: {speed_up_mean_time:.3f} seconds")
        print(f"{baseline_display_name} mean time: {baseline_mean_time:.3f} seconds")
    else:
        print(f"\nWarning: Could not calculate speedup - missing data for "
              f"{baseline_display_name} or {speed_up_display_name}")

    if df_filtered.empty:
        print("\nWarning: no timing data for the methods in include_methods; skipping plot.")
    else:
        _plot_efficiency_barplot(df_filtered, output_path)

    print("\n" + "="*60)
    print("EFFICIENCY ANALYSIS (Runtime per Dataset Size)")
    print("="*60)
    print("Lower values indicate better efficiency:")

    efficiency_by_method = df_filtered.groupby('method')['efficiency'].mean().sort_values()
    for method, efficiency in efficiency_by_method.items():
        # Fall back to the raw key for methods without a registered display name.
        print(f"{method_names.get(method, method):<25}: {efficiency:.2e} seconds per data point")

    # Summary over all methods found (not only include_methods).
    summary_stats = df.groupby('method').agg({
        'time': ['count', 'mean', 'std', 'min', 'max'],
        'dataset_size': ['mean', 'std']
    }).round(3)

    print("\nSummary Statistics:")
    print(summary_stats)

    return df
|
|
302
|
+
|
|
303
|
+
def main():
    """Run the runtime analysis: read dataset sizes and timing files, then plot and summarise."""
    benchmark_dir = "/home/jacobf18/tabular/mcpfn/benchmark"
    dataset_sizes_file = f"{benchmark_dir}/dataset_sizes.txt"
    base_path = f"{benchmark_dir}/datasets"

    print("Parsing dataset sizes...")
    dataset_info = parse_dataset_sizes(dataset_sizes_file)
    print(f"Found {len(dataset_info)} datasets with size information")

    print("\nFinding imputation times...")
    imputation_data = find_imputation_times(base_path)
    print(f"Found {len(imputation_data)} imputation time records")

    print("\nCreating plots...")
    create_plots(imputation_data, dataset_info)

    print("\nAnalysis complete! Results saved to:")
    print(f"- {benchmark_dir}/imputation_efficiency_barplot.pdf")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
ChonicKidneyDisease & 0.602 & 0.582 & 0.525 & 0.500 \\
|
|
2
|
+
Dog Breeds Ranked & 0.566 & 0.579 & 0.539 & 0.500 \\
|
|
3
|
+
HappinessRank 2015 & 0.486 & 0.466 & 0.677 & 0.500 \\
|
|
4
|
+
MY DB & 0.480 & 0.514 & 0.521 & 0.500 \\
|
|
5
|
+
Online Sales & 0.790 & 0.809 & 0.869 & 0.500 \\
|
|
6
|
+
Parkinson Dataset & 0.665 & 0.640 & 0.523 & 0.500 \\
|
|
7
|
+
acute-inflammations & 0.761 & 0.756 & 0.723 & 0.500 \\
|
|
8
|
+
aids & 0.511 & 0.573 & 0.509 & 0.500 \\
|
|
9
|
+
analcatdata creditscore & 0.678 & 0.672 & 0.491 & 0.500 \\
|
|
10
|
+
analcatdata cyyoung8092 & 0.683 & 0.649 & 0.611 & 0.500 \\
|
|
11
|
+
analcatdata cyyoung9302 & 0.713 & 0.700 & 0.629 & 0.500 \\
|
|
12
|
+
analcatdata impeach & 0.751 & 0.749 & 0.706 & 0.500 \\
|
|
13
|
+
analcatdata ncaa & 0.549 & 0.547 & 0.500 & 0.500 \\
|
|
14
|
+
analcatdata wildcat & 0.749 & 0.682 & 0.674 & 0.500 \\
|
|
15
|
+
auto price & 0.749 & 0.802 & 0.761 & 0.500 \\
|
|
16
|
+
backache & 0.529 & 0.516 & 0.502 & 0.500 \\
|
|
17
|
+
blogger & 0.558 & 0.573 & 0.499 & 0.500 \\
|
|
18
|
+
caesarian-section & 0.534 & 0.499 & 0.518 & 0.500 \\
|
|
19
|
+
cloud & 0.436 & 0.409 & 0.440 & 0.500 \\
|
|
20
|
+
cm1 req & 0.677 & 0.658 & 0.602 & 0.500 \\
|
|
21
|
+
cocomo numeric & 0.629 & 0.618 & 0.655 & 0.500 \\
|
|
22
|
+
conference attendance & 0.500 & 0.512 & 0.499 & 0.500 \\
|
|
23
|
+
corral & 0.553 & 0.558 & 0.575 & 0.500 \\
|
|
24
|
+
cpu & 0.694 & 0.675 & 0.546 & 0.500 \\
|
|
25
|
+
fl2000 & 0.608 & 0.468 & 0.533 & 0.500 \\
|
|
26
|
+
flags & 0.584 & 0.583 & 0.517 & 0.500 \\
|
|
27
|
+
fruitfly & 0.602 & 0.604 & 0.593 & 0.500 \\
|
|
28
|
+
grub-damage & 0.620 & 0.584 & 0.541 & 0.500 \\
|
|
29
|
+
hutsof99 logis & 0.569 & 0.568 & 0.583 & 0.500 \\
|
|
30
|
+
iris & 0.885 & 0.885 & 0.827 & 0.500 \\
|
|
31
|
+
kidney & 0.599 & 0.460 & 0.498 & 0.500 \\
|
|
32
|
+
lowbwt & 0.587 & 0.582 & 0.547 & 0.500 \\
|
|
33
|
+
lung & 0.558 & 0.483 & 0.496 & 0.500 \\
|
|
34
|
+
lungcancer GSE31210 & 0.647 & 0.527 & 0.567 & 0.500 \\
|
|
35
|
+
lymph & 0.641 & 0.594 & 0.547 & 0.500 \\
|
|
36
|
+
molecular-biology promoters & 0.508 & 0.505 & 0.504 & 0.500 \\
|
|
37
|
+
mux6 & 0.507 & 0.481 & 0.502 & 0.500 \\
|
|
38
|
+
nadeem & 0.594 & 0.521 & 0.542 & 0.500 \\
|
|
39
|
+
nasa numeric & 0.596 & 0.594 & 0.561 & 0.500 \\
|
|
40
|
+
postoperative-patient-data & 0.468 & 0.494 & 0.496 & 0.500 \\
|
|
41
|
+
prnn crabs & 0.840 & 0.687 & 0.737 & 0.500 \\
|
|
42
|
+
prnn viruses & 0.563 & 0.588 & 0.551 & 0.500 \\
|
|
43
|
+
qualitative-bankruptcy & 0.652 & 0.694 & 0.645 & 0.500 \\
|
|
44
|
+
servo & 0.528 & 0.516 & 0.480 & 0.500 \\
|
|
45
|
+
sleuth case1202 & 0.544 & 0.590 & 0.521 & 0.500 \\
|
|
46
|
+
sleuth case2002 & 0.609 & 0.564 & 0.526 & 0.500 \\
|
|
47
|
+
sleuth ex2015 & 0.589 & 0.599 & 0.733 & 0.500 \\
|
|
48
|
+
sleuth ex2016 & 0.562 & 0.588 & 0.532 & 0.500 \\
|
|
49
|
+
tae & 0.544 & 0.592 & 0.533 & 0.500 \\
|
|
50
|
+
teachingAssistant & 0.533 & 0.528 & 0.499 & 0.500 \\
|
|
51
|
+
veteran & 0.520 & 0.520 & 0.502 & 0.500 \\
|
|
52
|
+
white-clover & 0.626 & 0.598 & 0.622 & 0.500 \\
|
|
53
|
+
zoo & 0.754 & 0.710 & 0.764 & 0.500
|