likelihood 2.2.0.dev1__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,278 @@
1
+ import pickle
2
+ import warnings
3
+ from typing import Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import pandas as pd
8
+ import seaborn as sns
9
+
10
+ from likelihood.models import SimulationEngine
11
+ from likelihood.tools.numeric_tools import find_multiples
12
+
13
+ warnings.simplefilter(action="ignore", category=FutureWarning)
14
+
15
+
16
+ class SimpleImputer:
17
+ """Multiple imputation using simulation engine."""
18
+
19
+ def __init__(self, n_features: int | None = None, use_scaler: bool = False):
20
+ """
21
+ Initialize the imputer.
22
+
23
+ Parameters
24
+ ----------
25
+ n_features: int | None
26
+ Number of features to be used in the imputer. Default is None.
27
+ use_scaler: bool
28
+ Whether to use a scaler. Default is False.
29
+ """
30
+ self.n_features = n_features
31
+ self.sim = SimulationEngine(use_scaler=use_scaler)
32
+ self.params = {}
33
+ self.cols_transf = pd.Series([])
34
+
35
+ def fit(self, X: pd.DataFrame) -> None:
36
+ """
37
+ Fit the imputer to the data.
38
+
39
+ Parameters
40
+ ----------
41
+ X: pd.DataFrame
42
+ Dataframe to fit the imputer to.
43
+ """
44
+ X_impute = X.copy()
45
+ self.params = self._get_dict_params(X_impute)
46
+ X_impute = self.sim._clean_data(X_impute)
47
+
48
+ if X_impute.empty:
49
+ raise ValueError(
50
+ "The dataframe is empty after cleaning, it is not possible to train the imputer."
51
+ )
52
+ self.n_features = self.n_features or X_impute.shape[1] - 1
53
+ self.sim.fit(X_impute, self.n_features)
54
+
55
+ def transform(
56
+ self, X: pd.DataFrame, boundary: bool = True, inplace: bool = True
57
+ ) -> pd.DataFrame:
58
+ """
59
+ Impute missing values in the data.
60
+
61
+ Parameters
62
+ ----------
63
+ X: pd.DataFrame
64
+ Dataframe to impute missing values.
65
+ boundary: bool
66
+ Whether to use the boundaries of the data to impute missing values. Default is True.
67
+ inplace: bool
68
+ Whether to modify the columns of the original dataframe or return new ones. Default is True.
69
+ """
70
+ X_impute = X.copy()
71
+ self.cols_transf = X_impute.columns
72
+ for column in X_impute.columns:
73
+ if X_impute[column].isnull().sum() > 0:
74
+ if not X_impute[column].dtype == "object":
75
+ min_value = self.params[column]["min"]
76
+ max_value = self.params[column]["max"]
77
+ to_compare = self.params[column]["to_compare"]
78
+ for row in X_impute.index:
79
+ if pd.isnull(X_impute.loc[row, column]):
80
+ value_impute = self._check_dtype_convert(
81
+ self.sim.predict(
82
+ self._set_zero(X_impute.loc[row, :], column),
83
+ column,
84
+ )[0],
85
+ to_compare,
86
+ )
87
+ if not X_impute[column].dtype == "object" and boundary:
88
+ if value_impute < min_value:
89
+ value_impute = min_value
90
+ if value_impute > max_value:
91
+ value_impute = max_value
92
+ X_impute.loc[row, column] = value_impute
93
+ else:
94
+ self.cols_transf = self.cols_transf.drop(column)
95
+ if not inplace:
96
+ X_impute = X_impute[self.cols_transf].copy()
97
+ X_impute = X_impute.rename(
98
+ columns={column: column + "_imputed" for column in self.cols_transf}
99
+ )
100
+ X_impute = X.join(X_impute, rsuffix="_imputed")
101
+ order_cols = []
102
+ for column in X.columns:
103
+ if column + "_imputed" in X_impute.columns:
104
+ order_cols.append(column)
105
+ order_cols.append(column + "_imputed")
106
+ else:
107
+ order_cols.append(column)
108
+ X_impute = X_impute[order_cols]
109
+ return X_impute
110
+
111
+ def fit_transform(
112
+ self, X: pd.DataFrame, boundary: bool = True, inplace: bool = True
113
+ ) -> pd.DataFrame:
114
+ """
115
+ Fit and transform the data.
116
+
117
+ Parameters
118
+ ----------
119
+ X: pd.DataFrame
120
+ Dataframe to fit and transform.
121
+ boundary: bool
122
+ Whether to use the boundaries of the data to impute missing values. Default is True.
123
+ inplace: bool
124
+ Whether to modify the columns of the original dataframe or return new ones. Default is True.
125
+ """
126
+ X_train = X.copy()
127
+ self.fit(X_train)
128
+ return self.transform(X, boundary, inplace)
129
+
130
+ def _set_zero(self, X: pd.Series, column_exception) -> pd.DataFrame:
131
+ """
132
+ Set missing values to zero, except for `column_exception`.
133
+
134
+ Parameters
135
+ ----------
136
+ X: pd.Series
137
+ Series to set missing values to zero.
138
+ """
139
+ X = X.copy()
140
+ for column in X.index:
141
+ if pd.isnull(X[column]) and column != column_exception:
142
+ X[column] = 0
143
+ data = X.to_frame().T
144
+ return data
145
+
146
+ def _check_dtype_convert(self, value: Union[int, float], to_compare: Union[int, float]) -> None:
147
+ """
148
+ Check if the value is an integer and convert it to float if it is.
149
+
150
+ Parameters
151
+ ----------
152
+ value: Union[int, float]
153
+ Value to check and convert.
154
+ to_compare: Union[int, float]
155
+ Value to compare to.
156
+ """
157
+ if isinstance(to_compare, int) and isinstance(value, float):
158
+ value = int(round(value, 0))
159
+
160
+ if isinstance(to_compare, float) and isinstance(value, float):
161
+ value = round(value, len(str(to_compare).split(".")[1]))
162
+ return value
163
+
164
+ def _get_dict_params(self, df: pd.DataFrame) -> dict:
165
+ """
166
+ Get the parameters for the imputer.
167
+
168
+ Parameters
169
+ ----------
170
+ df: pd.DataFrame
171
+ Dataframe to get the parameters from.
172
+ """
173
+ params = {}
174
+ for column in df.columns:
175
+ if df[column].isnull().sum() > 0:
176
+ if not df[column].dtype == "object":
177
+ to_compare = df[column].dropna().sample().values[0]
178
+ params[column] = {
179
+ "min": df[column].min(),
180
+ "to_compare": to_compare,
181
+ "max": df[column].max(),
182
+ }
183
+ return params
184
+
185
+ def eval(self, X: pd.DataFrame) -> None:
186
+ """
187
+ Create a histogram of the imputed values.
188
+
189
+ Parameters
190
+ ----------
191
+ X: pd.DataFrame
192
+ Dataframe to create the histogram from.
193
+ """
194
+
195
+ if not isinstance(X, pd.DataFrame):
196
+ raise ValueError("Input X must be a pandas DataFrame.")
197
+
198
+ df = X.copy()
199
+
200
+ imputed_cols = [col for col in df.columns if col.endswith("_imputed")]
201
+ num_impute = len(imputed_cols)
202
+
203
+ if num_impute == 0:
204
+ print("No imputed columns found in the DataFrame.")
205
+ return
206
+
207
+ try:
208
+ ncols, nrows = find_multiples(num_impute)
209
+ except ValueError as e:
210
+ print(f"Error finding multiples for {num_impute}: {e}")
211
+ ncols = 1
212
+ nrows = num_impute
213
+
214
+ _, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 5 * nrows))
215
+ axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]
216
+
217
+ for i, col in enumerate(imputed_cols):
218
+ original_col = col.replace("_imputed", "")
219
+
220
+ if original_col in df.columns:
221
+ original_col_data = df[original_col].dropna()
222
+ ax = axes[i]
223
+
224
+ # Plot the original data
225
+ sns.histplot(
226
+ original_col_data,
227
+ kde=True,
228
+ color="blue",
229
+ label=f"Original",
230
+ bins=10,
231
+ ax=ax,
232
+ )
233
+
234
+ # Plot the imputed data
235
+ sns.histplot(
236
+ df[col],
237
+ kde=True,
238
+ color="red",
239
+ label=f"Imputed",
240
+ bins=10,
241
+ ax=ax,
242
+ )
243
+
244
+ ax.set_xlabel(original_col)
245
+ ax.set_ylabel("Frequency" if i % ncols == 0 else "")
246
+ ax.legend(loc="upper right")
247
+
248
+ plt.suptitle("Histogram Comparison", fontsize=16, fontweight="bold")
249
+ plt.tight_layout()
250
+ plt.subplots_adjust(top=0.9)
251
+ plt.show()
252
+
253
+ def save(self, filename: str = "./imputer") -> None:
254
+ """
255
+ Save the state of the SimpleImputer to a file.
256
+
257
+ Parameters
258
+ ----------
259
+ filename: str
260
+ Name of the file to save the imputer to. Default is "./imputer".
261
+ """
262
+ filename = filename if filename.endswith(".pkl") else filename + ".pkl"
263
+ with open(filename, "wb") as f:
264
+ pickle.dump(self, f)
265
+
266
+ @staticmethod
267
+ def load(filename: str = "./imputer"):
268
+ """
269
+ Load the state of a SimpleImputer from a file.
270
+
271
+ Parameters
272
+ ----------
273
+ filename: str
274
+ Name of the file to load the imputer from. Default is "./imputer".
275
+ """
276
+ filename = filename + ".pkl" if not filename.endswith(".pkl") else filename
277
+ with open(filename, "rb") as f:
278
+ return pickle.load(f)