likelihood 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- likelihood/graph/nn.py +8 -2
- likelihood/models/deep/autoencoders.py +312 -109
- likelihood/models/simulation.py +9 -9
- likelihood/tools/figures.py +348 -0
- likelihood/tools/impute.py +279 -0
- likelihood/tools/models_tools.py +161 -9
- likelihood/tools/numeric_tools.py +21 -0
- likelihood/tools/tools.py +46 -92
- {likelihood-1.4.1.dist-info → likelihood-1.5.1.dist-info}/METADATA +3 -2
- likelihood-1.5.1.dist-info/RECORD +23 -0
- {likelihood-1.4.1.dist-info → likelihood-1.5.1.dist-info}/WHEEL +1 -1
- likelihood-1.4.1.dist-info/RECORD +0 -21
- {likelihood-1.4.1.dist-info → likelihood-1.5.1.dist-info/licenses}/LICENSE +0 -0
- {likelihood-1.4.1.dist-info → likelihood-1.5.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
from matplotlib.ticker import AutoMinorLocator
|
|
8
|
+
from scipy import stats
|
|
9
|
+
|
|
10
|
+
# Module-wide default: bump the base font size for every figure produced here.
plt.rcParams.update({"font.size": 14})
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def act_pred(
    y_act: np.ndarray,
    y_pred: np.ndarray,
    name: str = "example",
    x_hist: bool = True,
    y_hist: bool = True,
    reg_line: bool = True,
    save_dir: Optional[str] = None,
) -> None:
    """
    Creates a scatter plot of actual vs predicted values along with histograms and a regression line.

    Parameters
    ----------
    y_act : `np.ndarray`
        The actual values (ground truth) as a 1D numpy array.
    y_pred : `np.ndarray`
        The predicted values as a 1D numpy array.
    name : `str`, optional
        The name for saving the plot. Default is "example".
    x_hist : `bool`, optional
        Whether to display the histogram for the actual values (y_act). Default is True.
    y_hist : `bool`, optional
        Whether to display the histogram for the predicted values (y_pred). Default is True.
    reg_line : `bool`, optional
        Whether to plot a regression line (best-fit line) in the scatter plot. Default is True.
    save_dir : `Optional[str]`, optional
        The directory to save the figure. If None, the figure will not be saved. Default is None.

    Returns
    -------
    `None` : The function doesn't return anything. It generates and optionally saves a plot.

    Raises
    ------
    ValueError
        If the inputs are not numpy arrays or their (flattened) shapes differ.
    """
    # Validate types *before* calling ndarray methods; previously .flatten()
    # ran first, so a non-array input raised AttributeError instead of the
    # documented ValueError.
    if not isinstance(y_act, np.ndarray) or not isinstance(y_pred, np.ndarray):
        raise ValueError("y_act and y_pred must be numpy arrays.")

    y_pred, y_act = y_pred.flatten(), y_act.flatten()

    if y_act.shape != y_pred.shape:
        raise ValueError("y_act and y_pred must have the same shape.")

    mec = "#2F4F4F"  # marker edge color
    mfc = "#C0C0C0"  # marker face color

    fig = plt.figure(figsize=(6, 6))

    # Layout: main axes plus two marginal panels (top for y_act, right for y_pred).
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    bottom_h = left + width
    left_h = left + width + 0.05

    ax2 = fig.add_axes([left, bottom, width, height])
    ax2.tick_params(direction="in", length=7, top=True, right=True)
    ax2.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax2.yaxis.set_minor_locator(AutoMinorLocator(2))

    ax2.scatter(y_act, y_pred, color=mfc, edgecolor=mec, alpha=0.5, s=35, lw=1.2)
    ax2.plot(
        [y_act.min(), y_act.max()], [y_act.min(), y_act.max()], "k--", alpha=0.8, label="Ideal"
    )

    ax2.set_xlabel("Actual value")
    ax2.set_ylabel("Predicted value")
    # NOTE(review): the 1.05 factor widens the limits only for positive
    # extremes; a negative minimum gets pulled toward zero — confirm intent.
    ax2.set_xlim([y_act.min() * 1.05, y_act.max() * 1.05])
    ax2.set_ylim([y_act.min() * 1.05, y_act.max() * 1.05])

    # The marginal histograms are now only drawn when requested; previously
    # x_hist / y_hist merely toggled a no-op set_alpha(1.0) call, so the
    # panels appeared regardless of the flags.
    if x_hist:
        ax1 = fig.add_axes([left, bottom_h, width, 0.15])
        ax1.hist(y_act, bins=31, density=True, color=mfc, edgecolor=mec, alpha=0.6)
        ax1.set_xticks([])
        ax1.set_yticks([])
        ax1.set_xlim(ax2.get_xlim())

    if y_hist:
        ax3 = fig.add_axes([left_h, bottom, 0.15, height])
        ax3.hist(
            y_pred, bins=31, density=True, color=mfc, edgecolor=mec, orientation="horizontal", alpha=0.6
        )
        ax3.set_xticks([])
        ax3.set_yticks([])
        ax3.set_ylim(ax2.get_ylim())

    if reg_line:
        # Least-squares linear fit of predictions against actual values.
        polyfit = np.polyfit(y_act, y_pred, deg=1)
        reg_line_vals = np.poly1d(polyfit)(np.unique(y_act))
        ax2.plot(np.unique(y_act), reg_line_vals, "r-", label="Regression Line", alpha=0.8)

    ax2.legend(loc="upper left", framealpha=0.35, handlelength=1.5)

    plt.tight_layout()

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, f"{name}_act_pred.png")
        plt.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.show()
    plt.close(fig)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def residual(
    y_act: np.ndarray, y_pred: np.ndarray, name: str = "example", save_dir: Optional[str] = None
) -> None:
    """
    Plots the residual errors between the actual and predicted values.

    This function generates a residual plot by calculating the difference between the
    actual values (y_act) and predicted values (y_pred). The plot shows the residuals
    (y_pred - y_act) against the actual values. Optionally, the plot can be saved to a file.

    Parameters
    ----------
    y_act : `np.ndarray`
        The actual values, typically the ground truth values.
    y_pred : `np.ndarray`
        The predicted values that are compared against the actual values.
    name : `str`, optional
        The name of the plot file (without extension) used when saving the plot. Default is "example".
    save_dir : `Optional[str]`, optional
        The directory where the plot will be saved. If None, the plot is not saved. Default is None.

    Returns
    -------
    `None` : This function does not return any value. It generates and optionally saves a plot.

    Notes
    -----
    - The plot is shown with the residuals (y_pred - y_act) on the y-axis and the actual values (y_act)
      on the x-axis. The plot includes a horizontal line representing the ideal case where the residual
      is zero (i.e., perfect predictions).
    - The plot will be saved as a PNG image if a valid `save_dir` is provided.
    """
    mec = "#2F4F4F"  # marker edge color
    mfc = "#C0C0C0"  # marker face color

    # Accept array-likes (lists, Series) by coercing to ndarrays up front.
    y_act = np.array(y_act)
    y_pred = np.array(y_pred)

    # NOTE(review): the *0.9 / /0.9 padding only widens the limits when the
    # corresponding extreme is positive; negative extremes get pulled toward
    # zero instead — confirm this framing is intended.
    xmin = np.min([y_act]) * 0.9
    xmax = np.max([y_act]) / 0.9
    y_err = y_pred - y_act
    ymin = np.min([y_err]) * 0.9
    ymax = np.max([y_err]) / 0.9

    fig, ax = plt.subplots(figsize=(4, 4))

    ax.plot(y_act, y_err, "o", mec=mec, mfc=mfc, alpha=0.5, label=None, mew=1.2, ms=5.2)
    # Zero-residual reference line: points on it were predicted perfectly.
    ax.plot([xmin, xmax], [0, 0], "k--", alpha=0.8, label="ideal")

    ax.set_ylabel("Residual error")
    ax.set_xlabel("Actual value")
    ax.legend(loc="lower right")

    minor_locator_x = AutoMinorLocator(2)
    minor_locator_y = AutoMinorLocator(2)
    ax.get_xaxis().set_minor_locator(minor_locator_x)
    ax.get_yaxis().set_minor_locator(minor_locator_y)

    ax.tick_params(right=True, top=True, direction="in", length=7)
    ax.tick_params(which="minor", right=True, top=True, direction="in", length=4)

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    if save_dir is not None:
        # os.path.join for platform-safe paths, consistent with act_pred.
        os.makedirs(save_dir, exist_ok=True)
        fig_name = os.path.join(save_dir, f"{name}_residual.png")
        plt.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.draw()
    plt.pause(0.001)

    plt.close()
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def residual_hist(
    y_act: np.ndarray, y_pred: np.ndarray, name: str = "example", save_dir: Optional[str] = None
) -> None:
    """
    Generates a residual error histogram with kernel density estimate (KDE) for the given true and predicted values.
    Optionally saves the plot to a specified directory.

    Parameters
    ----------
    y_act : `np.ndarray`
        Array of true (actual) values.
    y_pred : `np.ndarray`
        Array of predicted values.
    name : `str`, optional, default="example"
        The name used for the saved plot filename.
    save_dir : `Optional[str]`, optional, default=None
        Directory path to save the generated plot. If None, the plot is not saved.

    Returns
    -------
    `None` : This function generates and optionally saves a plot but does not return any value.

    Warns
    -----
    `UserWarning` : If the KDE cannot be computed because the data has (near-)singular
        covariance, suggesting dimensionality reduction.
    """
    mec = "#2F4F4F"  # marker edge color
    mfc = "#C0C0C0"  # marker face color
    y_pred, y_act = y_pred.flatten(), y_act.flatten()

    fig, ax = plt.subplots(figsize=(4, 4))
    y_err = y_pred - y_act
    x_range = np.linspace(min(y_err), max(y_err), 1000)

    # gaussian_kde raises LinAlgError on degenerate (zero-variance) data;
    # in that case we skip the KDE overlay and warn instead of failing.
    kde_plotted = False
    try:
        kde_act = stats.gaussian_kde(y_err)
        ax.plot(x_range, kde_act(x_range), "-", lw=1.2, color="k", label="kde")
        kde_plotted = True
    except np.linalg.LinAlgError:
        warnings.warn(
            "The data has very high correlation among variables. Consider dimensionality reduction.",
            UserWarning,
        )

    ax.hist(y_err, color=mfc, bins=35, alpha=1, edgecolor=mec, density=True)

    ax.set_xlabel("Residual error")
    ax.set_ylabel("Relative frequency")
    # Only draw the legend when the KDE line (the sole labeled artist) exists;
    # otherwise matplotlib emits a "no artists with labels" warning.
    if kde_plotted:
        plt.legend(loc=2, framealpha=0.35, handlelength=1.5)

    ax.tick_params(direction="in", length=7, top=True, right=True)

    minor_locator_x = AutoMinorLocator(2)
    minor_locator_y = AutoMinorLocator(2)
    ax.get_xaxis().set_minor_locator(minor_locator_x)
    ax.get_yaxis().set_minor_locator(minor_locator_y)
    plt.tick_params(which="minor", direction="in", length=4, right=True, top=True)

    if save_dir is not None:
        fig_name = f"{save_dir}/{name}_residual_hist.png"
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(fig_name, bbox_inches="tight", dpi=300)

    plt.draw()
    plt.pause(0.001)
    plt.close()
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def loss_curve(
    x_data: np.ndarray,
    train_err: np.ndarray,
    val_err: np.ndarray,
    name: str = "example",
    save_dir: Optional[str] = None,
) -> None:
    """
    Draw the training and validation loss history over epochs and optionally
    save the figure as a PNG image.

    Parameters
    ----------
    x_data : `np.ndarray`
        X-axis values (usually the epoch numbers).
    train_err : `np.ndarray`
        Training error per epoch.
    val_err : `np.ndarray`
        Validation error per epoch.
    name : `str`, optional
        Base name used for the saved file. Default is "example".
    save_dir : `Optional[str]`, optional
        Directory to write the figure into; when None the figure is not
        saved. Default is None.

    Returns
    -------
    `None` : The figure is rendered (and possibly saved) as a side effect.
    """
    fig, ax = plt.subplots(figsize=(4, 4))

    # (series, line style, edge color, marker, face color, legend label).
    # Train is drawn first so the validation curve sits on top of it.
    for series, style, edge, marker, face, label in (
        (train_err, "-", "#2F4F4F", "o", "#C0C0C0", "train"),
        (val_err, "--", "maroon", "s", "pink", "validation"),
    ):
        ax.plot(
            x_data,
            series,
            style,
            color=edge,
            marker=marker,
            mec=edge,
            mfc=face,
            ms=4,
            alpha=0.5,
            label=label,
        )

    # Horizontal reference at the worst validation loss observed.
    ax.axhline(max(val_err), color="b", linestyle="--", alpha=0.3)

    ax.set_xlabel("Number of training epochs")
    ax.set_ylabel("Loss (Units)")
    ax.set_ylim(0, 2 * np.mean(val_err))

    ax.legend(loc=1, framealpha=0.35, handlelength=1.5)

    ax.get_xaxis().set_minor_locator(AutoMinorLocator(2))
    ax.get_yaxis().set_minor_locator(AutoMinorLocator(2))

    ax.tick_params(right=True, top=True, direction="in", length=7)
    ax.tick_params(which="minor", right=True, top=True, direction="in", length=4)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(f"{save_dir}/{name}_loss_curve.png", bbox_inches="tight", dpi=300)

    plt.draw()
    plt.pause(0.001)
    plt.close()
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
import pickle
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import seaborn as sns
|
|
9
|
+
|
|
10
|
+
from likelihood.models import SimulationEngine
|
|
11
|
+
from likelihood.tools.numeric_tools import find_multiples
|
|
12
|
+
|
|
13
|
+
# Silence FutureWarnings module-wide (e.g. pandas deprecation chatter raised
# during the imputation loops). NOTE(review): this filter is global and also
# affects code that imports this module — confirm that is acceptable.
warnings.simplefilter(action="ignore", category=FutureWarning)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SimpleImputer:
    """Multiple imputation using simulation engine.

    Missing numeric values are filled column by column: a ``SimulationEngine``
    is trained on the cleaned data, and each missing cell is predicted from
    the remaining values of its row (with other missing cells zeroed out).
    """

    def __init__(self, n_features: int | None = None, use_scaler: bool = False):
        """
        Initialize the imputer.

        Parameters
        ----------
        n_features: int | None
            Number of features to be used in the imputer. If None, it is
            inferred in :meth:`fit` as one less than the number of columns
            after cleaning. Default is None.
        use_scaler: bool
            Whether to use a scaler. Default is False.
        """
        self.n_features = n_features
        self.sim = SimulationEngine(use_scaler=use_scaler)
        # Per-column stats gathered in fit(); see _get_dict_params.
        self.params = {}
        # Columns actually imputed; replaced with an Index in transform().
        self.cols_transf = pd.Series([])

    def fit(self, X: pd.DataFrame) -> None:
        """
        Fit the imputer to the data.

        Parameters
        ----------
        X: pd.DataFrame
            Dataframe to fit the imputer to.

        Raises
        ------
        ValueError
            If the dataframe is empty after the engine's cleaning step.
        """
        X_impute = X.copy()
        # Capture min/max/sample dtype info *before* cleaning drops rows.
        self.params = self._get_dict_params(X_impute)
        X_impute = self.sim._clean_data(X_impute)

        if X_impute.empty:
            raise ValueError(
                "The dataframe is empty after cleaning, it is not possible to train the imputer."
            )
        # `or` keeps a user-supplied n_features; otherwise default to
        # (number of columns - 1), i.e. all features except the target.
        self.n_features = self.n_features or X_impute.shape[1] - 1
        self.sim.fit(X_impute, self.n_features)

    def transform(
        self, X: pd.DataFrame, boundary: bool = True, inplace: bool = True
    ) -> pd.DataFrame:
        """
        Impute missing values in the data.

        Parameters
        -----------
        X: pd.DataFrame
            Dataframe to impute missing values.
        boundary: bool
            Whether to clip imputed values to the [min, max] range observed
            during fit. Default is True.
        inplace: bool
            If True, missing values are filled directly in the returned copy
            of ``X``. If False, the imputed columns are returned alongside
            the originals with an ``_imputed`` suffix, interleaved in the
            original column order. Default is True.
        """
        X_impute = X.copy()
        self.cols_transf = X_impute.columns
        for column in X_impute.columns:
            if X_impute[column].isnull().sum() > 0:

                # Only numeric (non-object) columns are imputed.
                if not X_impute[column].dtype == "object":
                    min_value = self.params[column]["min"]
                    max_value = self.params[column]["max"]
                    to_compare = self.params[column]["to_compare"]
                    for row in X_impute.index:
                        if pd.isnull(X_impute.loc[row, column]):
                            # Predict the cell from its row (other NaNs set
                            # to 0), then match the column's observed dtype/
                            # precision. Note: earlier imputed cells in this
                            # frame feed into later predictions.
                            value_impute = self._check_dtype_convert(
                                self.sim.predict(
                                    self._set_zero(X_impute.loc[row, :], column),
                                    column,
                                )[0],
                                to_compare,
                            )
                            if not X_impute[column].dtype == "object" and boundary:
                                # Clip to the range seen at fit time.
                                if value_impute < min_value:
                                    value_impute = min_value
                                if value_impute > max_value:
                                    value_impute = max_value
                            X_impute.loc[row, column] = value_impute
            else:
                # Column had no missing values: exclude it from cols_transf.
                self.cols_transf = self.cols_transf.drop(column)
        if not inplace:
            # Keep only the imputed columns, suffix them, and join back onto
            # the original frame so each "<col>" is followed by
            # "<col>_imputed" in the output.
            X_impute = X_impute[self.cols_transf].copy()
            X_impute = X_impute.rename(
                columns={column: column + "_imputed" for column in self.cols_transf}
            )
            X_impute = X.join(X_impute, rsuffix="_imputed")
            order_cols = []
            for column in X.columns:
                if column + "_imputed" in X_impute.columns:
                    order_cols.append(column)
                    order_cols.append(column + "_imputed")
                else:
                    order_cols.append(column)
            X_impute = X_impute[order_cols]
        return X_impute

    def fit_transform(
        self, X: pd.DataFrame, boundary: bool = True, inplace: bool = True
    ) -> pd.DataFrame:
        """
        Fit and transform the data.

        Parameters
        -----------
        X: pd.DataFrame
            Dataframe to fit and transform.
        boundary: bool
            Whether to clip imputed values to the [min, max] range observed
            during fit. Default is True.
        inplace: bool
            Whether to modify the columns of the original dataframe or return new ones. Default is True.
        """
        X_train = X.copy()
        self.fit(X_train)
        return self.transform(X, boundary, inplace)

    def _set_zero(self, X: pd.Series, column_exception) -> pd.DataFrame:
        """
        Set missing values to zero, except for `column_exception`.

        Parameters
        -----------
        X: pd.Series
            Row (as a Series) whose missing values are zeroed.
        column_exception
            Column label left untouched (the one being predicted).

        Returns
        -------
        pd.DataFrame
            The row reshaped as a single-row dataframe for the engine.
        """
        X = X.copy()
        for column in X.index:
            if pd.isnull(X[column]) and column != column_exception:
                X[column] = 0
        # .to_frame().T turns the Series into a 1-row DataFrame.
        data = X.to_frame().T
        return data

    def _check_dtype_convert(self, value: Union[int, float], to_compare: Union[int, float]) -> Union[int, float]:
        """
        Match the predicted value's type/precision to a sampled column value.

        If `to_compare` is an int, round `value` to the nearest integer; if
        `to_compare` is a float, round `value` to the same number of decimal
        places as `to_compare`.

        Parameters
        -----------
        value: Union[int, float]
            Predicted value to convert.
        to_compare: Union[int, float]
            Representative value sampled from the original column.
        """
        if isinstance(to_compare, int) and isinstance(value, float):
            value = int(round(value, 0))

        if isinstance(to_compare, float) and isinstance(value, float):
            # Mirror to_compare's decimal precision via its string form.
            value = round(value, len(str(to_compare).split(".")[1]))
        return value

    def _get_dict_params(self, df: pd.DataFrame) -> dict:
        """
        Get the parameters for the imputer.

        Collects, for every numeric column that has missing values: its min,
        its max, and one randomly sampled non-null value (used later to match
        dtype/precision of imputed values).

        Parameters
        -----------
        df: pd.DataFrame
            Dataframe to get the parameters from.
        """
        params = {}
        for column in df.columns:
            if df[column].isnull().sum() > 0:
                if not df[column].dtype == "object":
                    # NOTE(review): .sample() is random, so the reference
                    # precision can vary between fits — confirm acceptable.
                    to_compare = df[column].dropna().sample().values[0]
                    params[column] = {
                        "min": df[column].min(),
                        "to_compare": to_compare,
                        "max": df[column].max(),
                    }
        return params

    def eval(self, X: pd.DataFrame) -> None:
        """
        Create a histogram of the imputed values.

        Plots, for every "<col>_imputed" column, the distribution of the
        original column against the imputed one (expects the output of
        ``transform(..., inplace=False)``).

        Parameters
        -----------
        X: pd.DataFrame
            Dataframe to create the histogram from.

        Raises
        ------
        ValueError
            If `X` is not a pandas DataFrame.
        """

        if not isinstance(X, pd.DataFrame):
            raise ValueError("Input X must be a pandas DataFrame.")

        df = X.copy()

        imputed_cols = [col for col in df.columns if col.endswith("_imputed")]
        num_impute = len(imputed_cols)

        if num_impute == 0:
            print("No imputed columns found in the DataFrame.")
            return

        # Choose a grid layout; fall back to a single column of subplots.
        try:
            ncols, nrows = find_multiples(num_impute)
        except ValueError as e:
            print(f"Error finding multiples for {num_impute}: {e}")
            ncols = 1
            nrows = num_impute

        _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows))
        # A single subplot is returned as a bare Axes; normalize to a list.
        axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]

        for i, col in enumerate(imputed_cols):
            original_col = col.replace("_imputed", "")

            if original_col in df.columns:
                original_col_data = df[original_col].dropna()
                ax = axes[i]

                # Plot the original data
                sns.histplot(
                    original_col_data,
                    kde=True,
                    color="blue",
                    label=f"Original",
                    bins=10,
                    ax=ax,
                )

                # Plot the imputed data
                sns.histplot(
                    df[col],
                    kde=True,
                    color="red",
                    label=f"Imputed",
                    bins=10,
                    ax=ax,
                )

                ax.set_xlabel(original_col)
                # Only label the y-axis on the first subplot of each row.
                ax.set_ylabel("Frequency" if i % ncols == 0 else "")
                ax.legend(loc="upper right")

        plt.suptitle("Histogram Comparison", fontsize=16, fontweight="bold")
        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        plt.show()

    def save(self, filename: str = "./imputer") -> None:
        """
        Save the state of the SimpleImputer to a file.

        Parameters
        -----------
        filename: str
            Name of the file to save the imputer to; a ".pkl" extension is
            appended when missing. Default is "./imputer".
        """
        filename = filename if filename.endswith(".pkl") else filename + ".pkl"
        with open(filename, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load(filename: str = "./imputer"):
        """
        Load the state of a SimpleImputer from a file.

        NOTE: uses pickle — only load files from trusted sources.

        Parameters
        -----------
        filename: str
            Name of the file to load the imputer from; a ".pkl" extension is
            appended when missing. Default is "./imputer".
        """
        filename = filename + ".pkl" if not filename.endswith(".pkl") else filename
        with open(filename, "rb") as f:
            return pickle.load(f)
|