tabimpute 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. tabimpute-0.0.2/.gitignore +4 -0
  2. tabimpute-0.0.2/PKG-INFO +58 -0
  3. tabimpute-0.0.2/README.md +5 -0
  4. tabimpute-0.0.2/about.py +1 -0
  5. tabimpute-0.0.2/benchmark/README.md +28 -0
  6. tabimpute-0.0.2/benchmark/analyze_imputation_times.py +324 -0
  7. tabimpute-0.0.2/benchmark/auc_table.tex +53 -0
  8. tabimpute-0.0.2/benchmark/cacti_wrapper.py +259 -0
  9. tabimpute-0.0.2/benchmark/calculate_imputation_variance.py +290 -0
  10. tabimpute-0.0.2/benchmark/create_openml_categorical_missingness_datasets.py +38 -0
  11. tabimpute-0.0.2/benchmark/create_openml_missingness_datasets.py +145 -0
  12. tabimpute-0.0.2/benchmark/create_uci_missingness.py +105 -0
  13. tabimpute-0.0.2/benchmark/dataset_sizes.txt +166 -0
  14. tabimpute-0.0.2/benchmark/dataset_sizes_missing.txt +7 -0
  15. tabimpute-0.0.2/benchmark/datasets/.gitignore +4 -0
  16. tabimpute-0.0.2/benchmark/datasets_df.csv +87 -0
  17. tabimpute-0.0.2/benchmark/diffputer_wrapper.py +240 -0
  18. tabimpute-0.0.2/benchmark/download_openml.py +146 -0
  19. tabimpute-0.0.2/benchmark/download_uci_datasets.py +298 -0
  20. tabimpute-0.0.2/benchmark/generate_dataset_table.py +71 -0
  21. tabimpute-0.0.2/benchmark/get_openml_categorical_errors.py +278 -0
  22. tabimpute-0.0.2/benchmark/get_openml_errors.py +486 -0
  23. tabimpute-0.0.2/benchmark/get_uci_errors.py +470 -0
  24. tabimpute-0.0.2/benchmark/notmiwae_wrapper.py +289 -0
  25. tabimpute-0.0.2/benchmark/plot_accuracy.py +288 -0
  26. tabimpute-0.0.2/benchmark/plot_critical_difference.py +347 -0
  27. tabimpute-0.0.2/benchmark/plot_mcar_line.py +193 -0
  28. tabimpute-0.0.2/benchmark/plot_negative_mae.py +52 -0
  29. tabimpute-0.0.2/benchmark/plot_negative_rmse.py +845 -0
  30. tabimpute-0.0.2/benchmark/plot_options.py +234 -0
  31. tabimpute-0.0.2/benchmark/plot_pairwise_by_pattern.py +127 -0
  32. tabimpute-0.0.2/benchmark/plot_r_squared.py +733 -0
  33. tabimpute-0.0.2/benchmark/plot_uci_negative_rmse.py +509 -0
  34. tabimpute-0.0.2/benchmark/plot_wasserstein.py +723 -0
  35. tabimpute-0.0.2/benchmark/plot_win_rate.py +247 -0
  36. tabimpute-0.0.2/benchmark/remasker_wrapper.py +310 -0
  37. tabimpute-0.0.2/benchmark/runtime_benchmark/README.md +47 -0
  38. tabimpute-0.0.2/benchmark/runtime_benchmark/get_runtime_models.py +153 -0
  39. tabimpute-0.0.2/benchmark/runtime_benchmark/plot_runtime_benchmark.py +110 -0
  40. tabimpute-0.0.2/benchmark/runtime_benchmark/runtime_benchmark_results.csv +61 -0
  41. tabimpute-0.0.2/config.pkl +0 -0
  42. tabimpute-0.0.2/pyproject.toml +120 -0
  43. tabimpute-0.0.2/scripts/gen_multiple.sh +14 -0
  44. tabimpute-0.0.2/scripts/generate_data.sh +65 -0
  45. tabimpute-0.0.2/scripts/train.sh +70 -0
  46. tabimpute-0.0.2/scripts/train_mar.sh +70 -0
  47. tabimpute-0.0.2/scripts/train_mcar.sh +71 -0
  48. tabimpute-0.0.2/scripts/train_mcar_nonlinear.sh +71 -0
  49. tabimpute-0.0.2/scripts/train_mnar.sh +70 -0
  50. tabimpute-0.0.2/src/tabimpute/__about__.py +1 -0
  51. tabimpute-0.0.2/src/tabimpute/__init__.py +1 -0
  52. tabimpute-0.0.2/src/tabimpute/data/borders.pt +0 -0
  53. tabimpute-0.0.2/src/tabimpute/data/encoder.pth +0 -0
  54. tabimpute-0.0.2/src/tabimpute/diffusion/__init__.py +13 -0
  55. tabimpute-0.0.2/src/tabimpute/diffusion/mar_diffusion_row10_30_col10_30_mar0_3_marblock0_3_marbandit0_4_epoch100_bs32_samples30k_lr1e_3.py +850 -0
  56. tabimpute-0.0.2/src/tabimpute/interface.py +595 -0
  57. tabimpute-0.0.2/src/tabimpute/misc/_sklearn_compat.py +869 -0
  58. tabimpute-0.0.2/src/tabimpute/misc/debug_versions.py +702 -0
  59. tabimpute-0.0.2/src/tabimpute/model/__init__.py +0 -0
  60. tabimpute-0.0.2/src/tabimpute/model/bar_distribution.py +863 -0
  61. tabimpute-0.0.2/src/tabimpute/model/config.py +150 -0
  62. tabimpute-0.0.2/src/tabimpute/model/encoders.py +1078 -0
  63. tabimpute-0.0.2/src/tabimpute/model/full_attention.py +1565 -0
  64. tabimpute-0.0.2/src/tabimpute/model/inference.py +683 -0
  65. tabimpute-0.0.2/src/tabimpute/model/inference_config.py +228 -0
  66. tabimpute-0.0.2/src/tabimpute/model/layer.py +472 -0
  67. tabimpute-0.0.2/src/tabimpute/model/mcpfn.py +138 -0
  68. tabimpute-0.0.2/src/tabimpute/model/memory.py +452 -0
  69. tabimpute-0.0.2/src/tabimpute/model/mlp.py +138 -0
  70. tabimpute-0.0.2/src/tabimpute/model/model.py +403 -0
  71. tabimpute-0.0.2/src/tabimpute/model/positional.py +169 -0
  72. tabimpute-0.0.2/src/tabimpute/model/transformer.py +870 -0
  73. tabimpute-0.0.2/src/tabimpute/prepreocess.py +575 -0
  74. tabimpute-0.0.2/src/tabimpute/prior/__init__.py +0 -0
  75. tabimpute-0.0.2/src/tabimpute/prior/activations.py +289 -0
  76. tabimpute-0.0.2/src/tabimpute/prior/base_prior.py +392 -0
  77. tabimpute-0.0.2/src/tabimpute/prior/dataset.py +773 -0
  78. tabimpute-0.0.2/src/tabimpute/prior/genload.py +758 -0
  79. tabimpute-0.0.2/src/tabimpute/prior/hp_sampling.py +301 -0
  80. tabimpute-0.0.2/src/tabimpute/prior/mar_block_missing.py +142 -0
  81. tabimpute-0.0.2/src/tabimpute/prior/mar_missing.py +355 -0
  82. tabimpute-0.0.2/src/tabimpute/prior/mar_onesided_missing.py +356 -0
  83. tabimpute-0.0.2/src/tabimpute/prior/mar_sequential_missing.py +584 -0
  84. tabimpute-0.0.2/src/tabimpute/prior/mlp_scm.py +344 -0
  85. tabimpute-0.0.2/src/tabimpute/prior/prior_config.py +94 -0
  86. tabimpute-0.0.2/src/tabimpute/prior/reg2cls.py +390 -0
  87. tabimpute-0.0.2/src/tabimpute/prior/scm_prior.py +383 -0
  88. tabimpute-0.0.2/src/tabimpute/prior/splits.py +45 -0
  89. tabimpute-0.0.2/src/tabimpute/prior/training_set_generation.py +1295 -0
  90. tabimpute-0.0.2/src/tabimpute/prior/tree_scm.py +401 -0
  91. tabimpute-0.0.2/src/tabimpute/prior/utils.py +165 -0
  92. tabimpute-0.0.2/src/tabimpute/tabimpute_v2.py +88 -0
  93. tabimpute-0.0.2/src/tabimpute/tabpfn_extensions_interface.py +73 -0
  94. tabimpute-0.0.2/src/tabimpute/train/__init__.py +0 -0
  95. tabimpute-0.0.2/src/tabimpute/train/callbacks.py +92 -0
  96. tabimpute-0.0.2/src/tabimpute/train/optim.py +356 -0
  97. tabimpute-0.0.2/src/tabimpute/train/run.py +1141 -0
  98. tabimpute-0.0.2/src/tabimpute/train/train_config.py +431 -0
  99. tabimpute-0.0.2/src/tabimpute/train.py +232 -0
  100. tabimpute-0.0.2/test_bar_distribution_shape.py +200 -0
  101. tabimpute-0.0.2/tests.ipynb +12524 -0
  102. tabimpute-0.0.2/train.md +40 -0
  103. tabimpute-0.0.2/utils.py +50 -0
@@ -0,0 +1,4 @@
1
+ misc_data/
2
+ wandb/
3
+ stage1/
4
+ workdir/
@@ -0,0 +1,58 @@
1
+ Metadata-Version: 2.4
2
+ Name: tabimpute
3
+ Version: 0.0.2
4
+ Summary: TabImpute: A Pre-trained Transformer for Missing Data Imputation
5
+ Author: Jacob Feitelberg, Dwaipayan Saha, Zaid Ahmad, Kyuseong Choi, Anish Agarwal, Raaz Dwivedi
6
+ Keywords: foundation model,in-context learning,missing data imputation,tabular data
7
+ Classifier: Development Status :: 4 - Beta
8
+ Classifier: Intended Audience :: Developers
9
+ Classifier: Intended Audience :: Science/Research
10
+ Classifier: Programming Language :: Python
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.9
13
+ Classifier: Programming Language :: Python :: 3.10
14
+ Classifier: Programming Language :: Python :: 3.11
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering
17
+ Requires-Python: <3.13,>=3.9
18
+ Requires-Dist: einops>=0.7
19
+ Requires-Dist: huggingface-hub
20
+ Requires-Dist: networkx
21
+ Requires-Dist: numpy
22
+ Requires-Dist: tabpfn==6.0.0
23
+ Requires-Dist: torch<3,>=2.2
24
+ Provides-Extra: benchmark
25
+ Requires-Dist: joblib; extra == 'benchmark'
26
+ Requires-Dist: matplotlib; extra == 'benchmark'
27
+ Requires-Dist: pandas; extra == 'benchmark'
28
+ Requires-Dist: psutil; extra == 'benchmark'
29
+ Requires-Dist: scikit-learn==1.4.2; extra == 'benchmark'
30
+ Requires-Dist: scipy; extra == 'benchmark'
31
+ Requires-Dist: tabpfn; extra == 'benchmark'
32
+ Requires-Dist: tqdm>=4.64.0; extra == 'benchmark'
33
+ Requires-Dist: transformers; extra == 'benchmark'
34
+ Requires-Dist: wandb; extra == 'benchmark'
35
+ Requires-Dist: xgboost; extra == 'benchmark'
36
+ Provides-Extra: categorical
37
+ Requires-Dist: scipy; extra == 'categorical'
38
+ Provides-Extra: preprocessing
39
+ Requires-Dist: scipy; extra == 'preprocessing'
40
+ Provides-Extra: tabpfn-extensions
41
+ Requires-Dist: tabpfn-extensions; extra == 'tabpfn-extensions'
42
+ Provides-Extra: training
43
+ Requires-Dist: joblib; extra == 'training'
44
+ Requires-Dist: pandas; extra == 'training'
45
+ Requires-Dist: psutil; extra == 'training'
46
+ Requires-Dist: scikit-learn==1.4.2; extra == 'training'
47
+ Requires-Dist: scipy; extra == 'training'
48
+ Requires-Dist: tqdm>=4.64.0; extra == 'training'
49
+ Requires-Dist: transformers; extra == 'training'
50
+ Requires-Dist: wandb; extra == 'training'
51
+ Requires-Dist: xgboost; extra == 'training'
52
+ Description-Content-Type: text/markdown
53
+
54
+ # TabImpute
55
+
56
+ TabImpute is a pre-trained transformer for missing data imputation on tabular data.
57
+
58
+ This code is based on the TabPFN and TabICL codebases, both available on GitHub.
@@ -0,0 +1,5 @@
1
+ # TabImpute
2
+
3
+ TabImpute is a pre-trained transformer for missing data imputation on tabular data.
4
+
5
+ This code is based on the TabPFN and TabICL codebases, both available on GitHub.
@@ -0,0 +1 @@
1
+ __version__ = "0.0.2"
@@ -0,0 +1,28 @@
1
+ This folder contains the code for the benchmark of the imputers on the OpenML datasets.
2
+
3
+ The code is organized as follows:
4
+
5
+ - `create_openml_missingness_datasets.py`: This script is used to create the missingness datasets from the OpenML datasets.
6
+ - `get_openml_errors.py`: This script is used to get the errors of the imputers on the OpenML datasets.
7
+ - `plot_error_violinplots.py`: This script is used to plot the violin plots of the errors of the imputers on the OpenML datasets.
8
+ - `plot_error_boxplots.py`: This script is used to plot the box plots of the errors of the imputers on the OpenML datasets.
9
+ - `plot_negative_rmse.py`: This script is used to plot the negative RMSE of the imputers on the OpenML datasets.
10
+
11
+ The datasets are stored in the `datasets` folder. The figures are stored in the `figures` folder.
12
+
13
+ ## UCI Datasets
14
+
15
+ We test on the same datasets as in the HyperImpute paper:
16
+
17
+ - Airfoil Self-Noise
18
+ - Blood Transfusion
19
+ - California Housing
20
+ - Concrete Compression
21
+ - Diabetes
22
+ - Ionosphere
23
+ - Iris
24
+ - Letter Recognition
25
+ - Libras Movement
26
+ - Spam Base
27
+ - Wine Quality (Red)
28
+ - Wine Quality (White)
@@ -0,0 +1,324 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to analyze imputation times from MCAR_0.4 folders and plot them against dataset sizes.
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import glob
9
+ import pandas as pd
10
+ import matplotlib
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+ import seaborn as sns
14
+ from pathlib import Path
15
+ from scipy import stats
16
+ from plot_options import (
17
+ setup_latex_fonts,
18
+ METHOD_NAMES,
19
+ METHOD_COLORS,
20
+ HIGHLIGHT_COLOR,
21
+ NEUTRAL_COLOR,
22
+ FIGURE_SIZES,
23
+ BARPLOT_STYLE,
24
+ )
25
+
26
def parse_dataset_sizes(file_path):
    """Read a dataset-sizes text file and return per-dataset dimensions.

    Each useful line has the shape ``dataset_name | rows \\times cols``
    (LaTeX-style ``\\times`` separator). Lines that are blank, lack a
    ``|``, or do not split into exactly two fields are skipped silently.

    Returns a dict mapping dataset name -> {'rows', 'cols', 'size'},
    where 'size' is rows * cols.
    """
    dims_pattern = re.compile(r'(\d+)\s*\\times\s*(\d+)')
    dataset_info = {}

    with open(file_path, 'r') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped or '|' not in stripped:
                continue

            pieces = stripped.split('|')
            if len(pieces) != 2:
                # Lines with extra '|' separators are ignored, same as
                # lines that do not match the expected two-field shape.
                continue

            name = pieces[0].strip()
            found = dims_pattern.match(pieces[1].strip())
            if found:
                n_rows = int(found.group(1))
                n_cols = int(found.group(2))
                dataset_info[name] = {
                    'rows': n_rows,
                    'cols': n_cols,
                    'size': n_rows * n_cols,
                }

    return dataset_info
48
+
49
def find_imputation_times(base_path):
    """Collect imputation timings from every MCAR_0.4 folder under *base_path*.

    Each ``*_imputation_time.txt`` file is expected to hold the elapsed
    seconds on its first line; the method name is the filename prefix.
    The dataset name is taken as the path component immediately after an
    ``openml`` component.

    Returns a list of dicts with keys: 'dataset', 'method', 'time',
    'file_path'. Unreadable or malformed files are reported and skipped.
    """
    records = []

    # Recursively locate all MCAR_0.4 result folders.
    folders = glob.glob(os.path.join(base_path, "**", "MCAR_0.4"), recursive=True)
    print(f"Found {len(folders)} MCAR_0.4 folders")

    for folder in folders:
        parts = Path(folder).parts

        # Dataset name = the component right after the first 'openml'.
        dataset_name = None
        for idx, part in enumerate(parts):
            if part == 'openml':
                if idx + 1 < len(parts):
                    dataset_name = parts[idx + 1]
                break

        if not dataset_name:
            print(f"Could not extract dataset name from {folder}")
            continue

        for time_file in glob.glob(os.path.join(folder, "*_imputation_time.txt")):
            try:
                with open(time_file, 'r') as handle:
                    first_line = handle.readline().strip()
                if not first_line:
                    # Empty timing file: nothing to record.
                    continue
                elapsed = float(first_line)
                method = os.path.basename(time_file).replace('_imputation_time.txt', '')
                records.append({
                    'dataset': dataset_name,
                    'method': method,
                    'time': elapsed,
                    'file_path': time_file,
                })
            except (ValueError, FileNotFoundError) as e:
                print(f"Error reading {time_file}: {e}")
                continue

    return records
101
+
102
# Display-name table: start from the shared plot_options mapping, then layer
# on the method keys that only appear in this script's timing files.
method_names = {
    **METHOD_NAMES,
    "mcpfn": "TabImpute (GPU)",
    "mcpfn_cpu": "TabImpute (CPU)",
    "tabpfn_unsupervised": "Col-TabPFN (GPU)",
    "tabpfn": "EWF-TabPFN (GPU)",
    "hyperimpute_hyperimpute": "HyperImpute (GPU)",
    "hyperimpute_ot_sinkhorn": "OT",
    "hyperimpute_hyperimpute_missforest": "MissForest",
    "hyperimpute_hyperimpute_ice": "ICE",
    "hyperimpute_hyperimpute_mice": "MICE",
    "hyperimpute_hyperimpute_gain": "GAIN (GPU)",
    "hyperimpute_hyperimpute_miwae": "MIWAE (GPU)",
    "remasker": "ReMasker (GPU)",
    "cacti": "CACTI (GPU)",
    # "tabimpute_mcar_lin": "TabImpute (Lin. Emb.)",
    "tabimpute_dynamic_cls": "TabImpute (New)",
}

# Shared colors from plot_options.
neutral_color = NEUTRAL_COLOR
highlight_color = HIGHLIGHT_COLOR
# Very dark gray used only for x-axis label text (not bars), matching the
# styling in plot_negative_rmse.py.
darker_neutral_color = "#333333"

# Bar colors: TabImpute variants are highlighted; every baseline is neutral.
method_colors = {
    **METHOD_COLORS,
    "TabImpute (GPU)": highlight_color,
    "TabImpute (CPU)": highlight_color,
    "EWF-TabPFN (GPU)": neutral_color,
    "HyperImpute (GPU)": neutral_color,
    "GAIN (GPU)": neutral_color,
    "MIWAE (GPU)": neutral_color,
    "Col-TabPFN (GPU)": neutral_color,
    "ReMasker (GPU)": neutral_color,
    "CACTI (GPU)": neutral_color,
    # "TabImpute (Lin. Emb.)": highlight_color,
    "TabImpute (New)": highlight_color,
}

# Method keys (not display names) kept in the efficiency plot; commented-out
# entries are variants that were measured but are excluded from the figure.
include_methods = [
    "mcpfn",
    "mcpfn_cpu",
    # "tabimpute_large_mcar",
    # "mcpfn_ensemble",
    # "mcpfn_ensemble_cpu",
    "tabpfn_unsupervised",
    # "masters_mcar",
    "tabpfn",
    # "tabimpute_mcar_lin",
    "tabimpute_large_mcar_rank_1_11",
    # "tabpfn_impute",
    "hyperimpute_hyperimpute",
    "hyperimpute_hyperimpute_missforest",
    "hyperimpute_ot_sinkhorn",
    "hyperimpute_hyperimpute_ice",
    "hyperimpute_hyperimpute_mice",
    "hyperimpute_hyperimpute_gain",
    "hyperimpute_hyperimpute_miwae",
    # "column_mean",
    "knn",
    "softimpute",
    "forestdiffusion",
    "remasker",
    # "diffputer",
    "cacti",
]
172
+
173
def create_plots(imputation_data, dataset_info):
    """Create plots of imputation times vs dataset sizes.

    Builds a single log-scale seaborn bar plot of per-entry runtime
    ("efficiency" = seconds / (rows * cols)) for the methods listed in the
    module-level ``include_methods``, prints speedup and summary statistics
    to stdout, and saves the figure to a hard-coded PDF path.

    Parameters:
        imputation_data: list of dicts with keys 'dataset', 'method',
            'time', 'file_path' (as produced by find_imputation_times).
        dataset_info: dict mapping dataset name -> {'rows', 'cols', 'size'}
            (as produced by parse_dataset_sizes).

    Returns:
        The pandas DataFrame of all timing records (before the
        include_methods filter) with added size/efficiency columns.
    """

    # Convert to DataFrame for easier manipulation
    df = pd.DataFrame(imputation_data)

    # Add dataset size information; datasets absent from dataset_info get 0
    # and are dropped by the filter below.
    df['dataset_size'] = df['dataset'].map(lambda x: dataset_info.get(x, {}).get('size', 0))
    df['rows'] = df['dataset'].map(lambda x: dataset_info.get(x, {}).get('rows', 0))
    df['cols'] = df['dataset'].map(lambda x: dataset_info.get(x, {}).get('cols', 0))

    # Filter out datasets without size information
    df = df[df['dataset_size'] > 0]

    print(f"Found {len(df)} imputation time records")
    print(f"Unique methods: {df['method'].unique()}")
    print(f"Unique datasets: {df['dataset'].unique()}")


    # Configure LaTeX rendering for all text in plots
    setup_latex_fonts()
    matplotlib.rcParams['text.usetex'] = True
    matplotlib.rcParams['font.family'] = 'serif'

    # Create efficiency bar plot (runtime per dataset size) using seaborn
    plt.figure(figsize=FIGURE_SIZES['standard'])

    # Calculate efficiency metric: time per dataset size (seconds per entry)
    df['efficiency'] = df['time'] / df['dataset_size']

    # Filter to only include methods in include_methods list
    df_filtered = df[df['method'].isin(include_methods)].copy()

    # Calculate and print speedup (using method keys, not display names)
    baseline_method_key = 'tabpfn' # Method key for EWF-TabPFN (GPU)
    speed_up_method_key = 'mcpfn' # Method key for TabImpute (GPU)
    baseline_data = df_filtered[df_filtered['method'] == baseline_method_key]
    speed_up_data = df_filtered[df_filtered['method'] == speed_up_method_key]

    if len(baseline_data) > 0 and len(speed_up_data) > 0:
        # Speedup is the ratio of mean wall-clock times across all datasets.
        baseline_mean_time = baseline_data['time'].mean()
        speed_up_mean_time = speed_up_data['time'].mean()
        speedup = baseline_mean_time / speed_up_mean_time
        baseline_display_name = method_names[baseline_method_key]
        speed_up_display_name = method_names[speed_up_method_key]
        print(f"\nSpeedup of {speed_up_display_name} compared to {baseline_display_name}: {speedup:.2f}x")
        print(f"{speed_up_display_name} mean time: {speed_up_mean_time:.3f} seconds")
        print(f"{baseline_display_name} mean time: {baseline_mean_time:.3f} seconds")
    else:
        print("\nWarning: Could not calculate speedup - missing data for TabPFN (GPU) or TabImpute (GPU)")

    # Add method names for plotting
    df_filtered['Method'] = df_filtered['method'].map(method_names)

    # Calculate mean efficiency to determine sort order (decreasing time = increasing efficiency values)
    efficiency_means = df_filtered.groupby('Method')['efficiency'].mean().sort_values(ascending=True)

    # Create seaborn bar plot with error bars, sorted by efficiency (decreasing time)
    ax = sns.barplot(data=df_filtered, x='Method', y='efficiency', hue='Method',
                     order=efficiency_means.index,
                     palette=method_colors,
                     **BARPLOT_STYLE,
                     legend=False)

    # Set x-axis labels with 45-degree rotation
    # Bold TabImpute methods using LaTeX \textbf{}
    labels_with_bold = [r"\textbf{" + method + "}" if "TabImpute" in method else method for method in efficiency_means.index]
    ax.set_xticks(range(len(efficiency_means.index)))
    ax.set_xticklabels(labels_with_bold, rotation=45, ha='right', fontsize=14)
    ax.set_xlabel('')

    # Set label colors - use darker color for non-TabImpute x-axis labels
    for i, label in enumerate(ax.get_xticklabels()):
        method_name = efficiency_means.index[i]
        if "TabImpute" in method_name:
            # TabImpute methods use highlight color and are larger
            if method_name in method_colors:
                label.set_color(method_colors[method_name])
            # Make TabImpute methods slightly larger for extra boldness
            label.set_fontsize(label.get_fontsize() * 1.1)
        else:
            # Non-TabImpute methods use darker gray for x-axis labels
            label.set_color(darker_neutral_color)

    # Use LaTeX-formatted label
    plt.ylabel(r'Milliseconds per entry', fontsize=18)
    # plt.title('Runtime per entry \n(seconds per number of entries (rows × columns))', fontsize=18.0)
    plt.yscale('log') # Set y-axis to log scale

    # Convert y-axis to milliseconds and format ticks without scientific notation
    ax = plt.gca()

    # fig.subplots_adjust(left=0.2, right=0.95, bottom=0.05, top=0.95)

    # Tick values are in seconds; multiply by 1000 to display milliseconds.
    ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x*1000:.2f}'))

    # Configure grid for log scale - enable both major and minor grid lines
    ax.yaxis.grid(True, which='major', alpha=0.3, linestyle='-')
    ax.yaxis.grid(True, which='minor', alpha=0.3, linestyle='--')

    plt.tight_layout()
    # NOTE(review): hard-coded absolute output path — only valid on the
    # author's machine.
    plt.savefig('/home/jacobf18/tabular/mcpfn/benchmark/imputation_efficiency_barplot.pdf',
                dpi=300, bbox_inches='tight')
    plt.show()

    # Print efficiency statistics
    print("\n" + "="*60)
    print("EFFICIENCY ANALYSIS (Runtime per Dataset Size)")
    print("="*60)
    print("Lower values indicate better efficiency:")

    # Calculate mean efficiency for each method for printing
    efficiency_by_method = df_filtered.groupby('method')['efficiency'].mean().sort_values()
    for method, efficiency in efficiency_by_method.items():
        print(f"{method_names[method]:<25}: {efficiency:.2e} seconds per data point")

    # Create a summary table
    summary_stats = df.groupby('method').agg({
        'time': ['count', 'mean', 'std', 'min', 'max'],
        'dataset_size': ['mean', 'std']
    }).round(3)

    print("\nSummary Statistics:")
    print(summary_stats)

    # Save summary to file
    # summary_stats.to_csv('/home/jacobf18/tabular/mcpfn/benchmark/imputation_times_summary.csv')

    return df
302
+
303
def main(base_path="/home/jacobf18/tabular/mcpfn/benchmark/datasets",
         dataset_sizes_file="/home/jacobf18/tabular/mcpfn/benchmark/dataset_sizes.txt"):
    """Run the full imputation-time analysis pipeline.

    Parses dataset dimensions, collects per-method timing files, and
    produces the efficiency bar plot.

    Parameters:
        base_path: root directory searched recursively for MCAR_0.4
            result folders (defaults to the original hard-coded path, so
            existing ``main()`` callers are unaffected).
        dataset_sizes_file: text file mapping dataset names to
            ``rows \\times cols`` dimensions.
    """
    print("Parsing dataset sizes...")
    dataset_info = parse_dataset_sizes(dataset_sizes_file)
    print(f"Found {len(dataset_info)} datasets with size information")

    print("\nFinding imputation times...")
    imputation_data = find_imputation_times(base_path)
    print(f"Found {len(imputation_data)} imputation time records")

    print("\nCreating plots...")
    # Return value intentionally discarded: create_plots saves/prints its
    # own output and nothing here uses the DataFrame.
    create_plots(imputation_data, dataset_info)

    # Plain strings: the originals were f-strings with no placeholders.
    print("\nAnalysis complete! Results saved to:")
    print("- /home/jacobf18/tabular/mcpfn/benchmark/imputation_efficiency_barplot.pdf")

if __name__ == "__main__":
    main()
@@ -0,0 +1,53 @@
1
+ ChonicKidneyDisease & 0.602 & 0.582 & 0.525 & 0.500 \\
2
+ Dog Breeds Ranked & 0.566 & 0.579 & 0.539 & 0.500 \\
3
+ HappinessRank 2015 & 0.486 & 0.466 & 0.677 & 0.500 \\
4
+ MY DB & 0.480 & 0.514 & 0.521 & 0.500 \\
5
+ Online Sales & 0.790 & 0.809 & 0.869 & 0.500 \\
6
+ Parkinson Dataset & 0.665 & 0.640 & 0.523 & 0.500 \\
7
+ acute-inflammations & 0.761 & 0.756 & 0.723 & 0.500 \\
8
+ aids & 0.511 & 0.573 & 0.509 & 0.500 \\
9
+ analcatdata creditscore & 0.678 & 0.672 & 0.491 & 0.500 \\
10
+ analcatdata cyyoung8092 & 0.683 & 0.649 & 0.611 & 0.500 \\
11
+ analcatdata cyyoung9302 & 0.713 & 0.700 & 0.629 & 0.500 \\
12
+ analcatdata impeach & 0.751 & 0.749 & 0.706 & 0.500 \\
13
+ analcatdata ncaa & 0.549 & 0.547 & 0.500 & 0.500 \\
14
+ analcatdata wildcat & 0.749 & 0.682 & 0.674 & 0.500 \\
15
+ auto price & 0.749 & 0.802 & 0.761 & 0.500 \\
16
+ backache & 0.529 & 0.516 & 0.502 & 0.500 \\
17
+ blogger & 0.558 & 0.573 & 0.499 & 0.500 \\
18
+ caesarian-section & 0.534 & 0.499 & 0.518 & 0.500 \\
19
+ cloud & 0.436 & 0.409 & 0.440 & 0.500 \\
20
+ cm1 req & 0.677 & 0.658 & 0.602 & 0.500 \\
21
+ cocomo numeric & 0.629 & 0.618 & 0.655 & 0.500 \\
22
+ conference attendance & 0.500 & 0.512 & 0.499 & 0.500 \\
23
+ corral & 0.553 & 0.558 & 0.575 & 0.500 \\
24
+ cpu & 0.694 & 0.675 & 0.546 & 0.500 \\
25
+ fl2000 & 0.608 & 0.468 & 0.533 & 0.500 \\
26
+ flags & 0.584 & 0.583 & 0.517 & 0.500 \\
27
+ fruitfly & 0.602 & 0.604 & 0.593 & 0.500 \\
28
+ grub-damage & 0.620 & 0.584 & 0.541 & 0.500 \\
29
+ hutsof99 logis & 0.569 & 0.568 & 0.583 & 0.500 \\
30
+ iris & 0.885 & 0.885 & 0.827 & 0.500 \\
31
+ kidney & 0.599 & 0.460 & 0.498 & 0.500 \\
32
+ lowbwt & 0.587 & 0.582 & 0.547 & 0.500 \\
33
+ lung & 0.558 & 0.483 & 0.496 & 0.500 \\
34
+ lungcancer GSE31210 & 0.647 & 0.527 & 0.567 & 0.500 \\
35
+ lymph & 0.641 & 0.594 & 0.547 & 0.500 \\
36
+ molecular-biology promoters & 0.508 & 0.505 & 0.504 & 0.500 \\
37
+ mux6 & 0.507 & 0.481 & 0.502 & 0.500 \\
38
+ nadeem & 0.594 & 0.521 & 0.542 & 0.500 \\
39
+ nasa numeric & 0.596 & 0.594 & 0.561 & 0.500 \\
40
+ postoperative-patient-data & 0.468 & 0.494 & 0.496 & 0.500 \\
41
+ prnn crabs & 0.840 & 0.687 & 0.737 & 0.500 \\
42
+ prnn viruses & 0.563 & 0.588 & 0.551 & 0.500 \\
43
+ qualitative-bankruptcy & 0.652 & 0.694 & 0.645 & 0.500 \\
44
+ servo & 0.528 & 0.516 & 0.480 & 0.500 \\
45
+ sleuth case1202 & 0.544 & 0.590 & 0.521 & 0.500 \\
46
+ sleuth case2002 & 0.609 & 0.564 & 0.526 & 0.500 \\
47
+ sleuth ex2015 & 0.589 & 0.599 & 0.733 & 0.500 \\
48
+ sleuth ex2016 & 0.562 & 0.588 & 0.532 & 0.500 \\
49
+ tae & 0.544 & 0.592 & 0.533 & 0.500 \\
50
+ teachingAssistant & 0.533 & 0.528 & 0.499 & 0.500 \\
51
+ veteran & 0.520 & 0.520 & 0.502 & 0.500 \\
52
+ white-clover & 0.626 & 0.598 & 0.622 & 0.500 \\
53
+ zoo & 0.754 & 0.710 & 0.764 & 0.500