pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
- pgsui/__init__.py +0 -8
- pgsui/_version.py +2 -2
- pgsui/cli.py +577 -125
- pgsui/data_processing/config.py +1 -2
- pgsui/data_processing/containers.py +203 -530
- pgsui/data_processing/transformers.py +44 -20
- pgsui/impute/deterministic/imputers/mode.py +475 -182
- pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
- pgsui/impute/supervised/imputers/random_forest.py +3 -2
- pgsui/impute/unsupervised/base.py +1269 -534
- pgsui/impute/unsupervised/callbacks.py +28 -33
- pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
- pgsui/impute/unsupervised/imputers/vae.py +931 -787
- pgsui/impute/unsupervised/loss_functions.py +156 -202
- pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
- pgsui/impute/unsupervised/models/vae_model.py +40 -221
- pgsui/impute/unsupervised/nn_scorers.py +53 -13
- pgsui/utils/classification_viz.py +240 -97
- pgsui/utils/misc.py +201 -3
- pgsui/utils/plotting.py +73 -58
- pgsui/utils/pretty_metrics.py +2 -6
- pgsui/utils/scorers.py +39 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
- pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
- pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
- pgsui/impute/unsupervised/models/ubp_model.py +0 -200
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
- {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/utils/misc.py
CHANGED
|
@@ -66,8 +66,206 @@ def validate_input_type(
|
|
|
66
66
|
if isinstance(X, torch.Tensor):
|
|
67
67
|
return X
|
|
68
68
|
elif isinstance(X, np.ndarray):
|
|
69
|
-
return torch.tensor(X, dtype=torch.
|
|
69
|
+
return torch.tensor(X, dtype=torch.long)
|
|
70
70
|
elif isinstance(X, pd.DataFrame):
|
|
71
|
-
return torch.tensor(X.to_numpy(), dtype=torch.
|
|
71
|
+
return torch.tensor(X.to_numpy(), dtype=torch.long)
|
|
72
72
|
elif isinstance(X, list):
|
|
73
|
-
return torch.tensor(X, dtype=torch.
|
|
73
|
+
return torch.tensor(X, dtype=torch.long)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def detect_computing_device(
    *, force_cpu: bool = False, verbose: bool = False
) -> torch.device:
    """Detect and return the best available PyTorch compute device.

    Devices are tried in the order CUDA (NVIDIA) > MPS (Apple Silicon) > CPU.

    Args:
        force_cpu (bool): If True, return the CPU device without probing any accelerator. Defaults to False.
        verbose (bool): If True, print the selected device to stdout. Defaults to False.

    Returns:
        torch.device: The selected computing device.
    """
    if force_cpu:
        device = torch.device("cpu")  # Explicitly requested by the caller
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        # ``torch.backends.mps`` is not present in every PyTorch build, so it
        # is looked up defensively instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")  # Fallback when no accelerator is usable

    if verbose:
        print(f"Selected compute device: {device}")

    return device
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def get_missing_mask(
|
|
106
|
+
X: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
|
|
107
|
+
) -> pd.DataFrame | pd.Series | np.ndarray | torch.Tensor:
|
|
108
|
+
"""Returns a boolean mask indicating missing values (NaN, None).
|
|
109
|
+
|
|
110
|
+
Notes:
|
|
111
|
+
Lists are converted to numpy arrays to compute the mask.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
X: Input data.
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
pd.DataFrame | pd.Series | np.ndarray | torch.Tensor: Boolean mask of the same shape as X (returned as DF, Array, or Tensor).
|
|
118
|
+
|
|
119
|
+
Raises:
|
|
120
|
+
TypeError: If input type is not supported.
|
|
121
|
+
"""
|
|
122
|
+
if isinstance(X, pd.DataFrame):
|
|
123
|
+
return X.isna()
|
|
124
|
+
|
|
125
|
+
elif isinstance(X, pd.Series):
|
|
126
|
+
return pd.isna(X)
|
|
127
|
+
|
|
128
|
+
elif isinstance(X, np.ndarray):
|
|
129
|
+
# np.isnan fails on object arrays (e.g. strings)
|
|
130
|
+
# so we check generically first
|
|
131
|
+
if X.dtype.kind in {"U", "S", "O"}: # String/Object
|
|
132
|
+
return pd.isnull(X)
|
|
133
|
+
return np.isnan(X)
|
|
134
|
+
|
|
135
|
+
elif isinstance(X, torch.Tensor):
|
|
136
|
+
return torch.isnan(X)
|
|
137
|
+
|
|
138
|
+
elif isinstance(X, list):
|
|
139
|
+
arr = np.array(X)
|
|
140
|
+
# Handle mixed types in lists
|
|
141
|
+
if arr.dtype.kind in {"U", "S", "O"}:
|
|
142
|
+
return pd.isnull(arr)
|
|
143
|
+
return np.isnan(arr)
|
|
144
|
+
|
|
145
|
+
else:
|
|
146
|
+
raise TypeError(
|
|
147
|
+
f"Unsupported type for missing value detection. Expected pandas.DataFrame, pandas.Series, numpy.ndarray, list, or torch.Tensor but got {type(X)}"
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def ensure_2d(
|
|
152
|
+
X: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
|
|
153
|
+
) -> pd.DataFrame | np.ndarray | list | torch.Tensor:
|
|
154
|
+
"""Ensures the input is at least 2-dimensional.
|
|
155
|
+
|
|
156
|
+
If input is 1D (e.g., shape (N,)), it is reshaped to (N, 1). Already 2D+ inputs are returned unchanged.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
X (pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor): Input data.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
pd.DataFrame | np.ndarray | list | torch.Tensor: Input data transformed to be at least 2D.
|
|
163
|
+
|
|
164
|
+
Raises:
|
|
165
|
+
TypeError: If input type is not supported.
|
|
166
|
+
"""
|
|
167
|
+
if isinstance(X, pd.DataFrame):
|
|
168
|
+
return X # DataFrames are always 2D
|
|
169
|
+
|
|
170
|
+
elif isinstance(X, pd.Series):
|
|
171
|
+
return X.to_frame() # Convert Series to DataFrame (2D)
|
|
172
|
+
|
|
173
|
+
elif isinstance(X, np.ndarray):
|
|
174
|
+
if X.ndim == 1:
|
|
175
|
+
return X.reshape(-1, 1)
|
|
176
|
+
return X
|
|
177
|
+
|
|
178
|
+
elif isinstance(X, torch.Tensor):
|
|
179
|
+
if X.dim() == 1:
|
|
180
|
+
return X.unsqueeze(1)
|
|
181
|
+
return X
|
|
182
|
+
|
|
183
|
+
elif isinstance(X, list):
|
|
184
|
+
# Check depth of list
|
|
185
|
+
if not X:
|
|
186
|
+
return X
|
|
187
|
+
if not isinstance(X[0], list):
|
|
188
|
+
return [[x] for x in X]
|
|
189
|
+
return X
|
|
190
|
+
|
|
191
|
+
else:
|
|
192
|
+
msg = f"X must be of type pandas.DataFrame, pd.Series, numpy.ndarray, list, or torch.Tensor, but got {type(X)}"
|
|
193
|
+
raise TypeError(msg)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def flatten_1d(
|
|
197
|
+
y: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
|
|
198
|
+
) -> pd.Series | np.ndarray | list | torch.Tensor:
|
|
199
|
+
"""
|
|
200
|
+
Flattens input to a 1D structure.
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
y (pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor): Input data.
|
|
204
|
+
|
|
205
|
+
Returns:
|
|
206
|
+
pd.Series | np.ndarray | list | torch.Tensor: 1D representation of the input.
|
|
207
|
+
|
|
208
|
+
Notes:
|
|
209
|
+
Inputs with multiple columns (e.g., DataFrame with >1 column) are flattened into a single 1D structure.
|
|
210
|
+
|
|
211
|
+
Raises:
|
|
212
|
+
TypeError: If input type is not supported.
|
|
213
|
+
"""
|
|
214
|
+
if isinstance(y, pd.DataFrame):
|
|
215
|
+
if y.shape[1] == 1:
|
|
216
|
+
return y.iloc[:, 0]
|
|
217
|
+
else:
|
|
218
|
+
return pd.Series(y.to_numpy().flatten())
|
|
219
|
+
|
|
220
|
+
elif isinstance(y, np.ndarray):
|
|
221
|
+
return y.flatten()
|
|
222
|
+
|
|
223
|
+
elif isinstance(y, torch.Tensor):
|
|
224
|
+
return y.view(-1)
|
|
225
|
+
|
|
226
|
+
elif isinstance(y, list):
|
|
227
|
+
# Recursively flatten list if needed, or simple comprehension if just 2D
|
|
228
|
+
if not y:
|
|
229
|
+
return y
|
|
230
|
+
if isinstance(y[0], list):
|
|
231
|
+
return [item for sublist in y for item in sublist]
|
|
232
|
+
return y
|
|
233
|
+
|
|
234
|
+
else:
|
|
235
|
+
msg = f"Input must be of type pandas.DataFrame, pandas.Series, numpy.ndarray, list, or torch.Tensor, but got {type(y)}"
|
|
236
|
+
raise TypeError(msg)
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def safe_shape(
|
|
240
|
+
X: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
|
|
241
|
+
) -> tuple[int, ...]:
|
|
242
|
+
"""Returns the shape of the input container as a tuple.
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
X (pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor): Input data.
|
|
246
|
+
|
|
247
|
+
Returns:
|
|
248
|
+
tuple[int, ...]: Dimensions of the data (rows, cols, etc.).
|
|
249
|
+
"""
|
|
250
|
+
if isinstance(X, (pd.DataFrame, np.ndarray)):
|
|
251
|
+
return X.shape
|
|
252
|
+
|
|
253
|
+
elif isinstance(X, pd.Series):
|
|
254
|
+
return (X.shape[0],)
|
|
255
|
+
|
|
256
|
+
elif isinstance(X, torch.Tensor):
|
|
257
|
+
return tuple(X.shape)
|
|
258
|
+
|
|
259
|
+
elif isinstance(X, list):
|
|
260
|
+
if not X:
|
|
261
|
+
return (0,)
|
|
262
|
+
rows = len(X)
|
|
263
|
+
|
|
264
|
+
# Check if 2D list
|
|
265
|
+
if isinstance(X[0], list):
|
|
266
|
+
return (rows, len(X[0]))
|
|
267
|
+
return (rows,)
|
|
268
|
+
|
|
269
|
+
else:
|
|
270
|
+
msg = f"X must be of type pandas.DataFrame, pd.Series, numpy.ndarray, list, or torch.Tensor, but got {type(X)}"
|
|
271
|
+
raise TypeError(msg)
|
pgsui/utils/plotting.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import warnings
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Dict, List, Literal, Optional, Sequence, cast
|
|
4
|
+
from typing import Dict, List, Literal, Optional, Sequence, Mapping, cast
|
|
5
5
|
|
|
6
6
|
import matplotlib as mpl
|
|
7
7
|
|
|
@@ -294,6 +294,10 @@ class Plotting:
|
|
|
294
294
|
ValueError: If model_name is not recognized (legacy guard).
|
|
295
295
|
"""
|
|
296
296
|
num_classes = y_pred_proba.shape[1]
|
|
297
|
+
if num_classes < 2:
|
|
298
|
+
msg = "plot_metrics: num_classes must be >= 2 for ROC/PR curves."
|
|
299
|
+
self.logger.error(msg)
|
|
300
|
+
raise ValueError(msg)
|
|
297
301
|
|
|
298
302
|
# Validate/normalize label names
|
|
299
303
|
if label_names is not None and len(label_names) != num_classes:
|
|
@@ -391,7 +395,7 @@ class Plotting:
|
|
|
391
395
|
ncol=2,
|
|
392
396
|
)
|
|
393
397
|
|
|
394
|
-
#
|
|
398
|
+
# Precision-recall
|
|
395
399
|
axes[1].plot(
|
|
396
400
|
recall["micro"],
|
|
397
401
|
precision["micro"],
|
|
@@ -433,7 +437,9 @@ class Plotting:
|
|
|
433
437
|
)
|
|
434
438
|
fig.savefig(self.output_dir / out_name, bbox_inches="tight")
|
|
435
439
|
if self.show_plots:
|
|
436
|
-
|
|
440
|
+
with warnings.catch_warnings():
|
|
441
|
+
warnings.simplefilter("ignore", UserWarning)
|
|
442
|
+
plt.show()
|
|
437
443
|
plt.close(fig)
|
|
438
444
|
|
|
439
445
|
# ---- MultiQC: metrics table + per-class AUC/AP heatmap ------------
|
|
@@ -465,73 +471,70 @@ class Plotting:
|
|
|
465
471
|
except Exception as exc: # pragma: no cover - defensive
|
|
466
472
|
self.logger.warning(f"Failed to queue MultiQC ROC/PR curves: {exc}")
|
|
467
473
|
|
|
468
|
-
def
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
474
|
+
def _series_from_history(self, vals: list[float]) -> pd.Series:
|
|
475
|
+
"""Convert to float series and coerce non-finite to NaN."""
|
|
476
|
+
s = pd.Series(vals, dtype="float64")
|
|
477
|
+
s[~np.isfinite(s.to_numpy())] = np.nan
|
|
478
|
+
return s
|
|
479
|
+
|
|
480
|
+
def _interp_sparse(self, s: pd.Series) -> pd.Series:
|
|
481
|
+
"""Interpolate internal gaps; keep leading/trailing NaNs."""
|
|
482
|
+
# Only interpolate if we have enough points to make it meaningful
|
|
483
|
+
if s.notna().sum() < 2:
|
|
484
|
+
return s
|
|
485
|
+
return s.interpolate(method="linear", limit_area="inside")
|
|
486
|
+
|
|
487
|
+
def plot_history(self, history: dict[str, list[float]]) -> None:
|
|
472
488
|
"""Plot model history traces. Will be saved to file.
|
|
473
489
|
|
|
474
490
|
This method plots the deep learning model history traces. The plot is saved to disk as a ``<plot_format>`` file.
|
|
475
491
|
|
|
476
492
|
Args:
|
|
477
|
-
history (
|
|
493
|
+
history (dict[str, list[float]]): Dictionary mapping each split name to its list of per-epoch loss values. Keys must be "Train" and "Val".
|
|
478
494
|
|
|
479
495
|
Raises:
|
|
480
|
-
ValueError:
|
|
496
|
+
ValueError: self.model_name must be either 'ImputeAutoencoder' or 'ImputeVAE'.
|
|
481
497
|
"""
|
|
482
|
-
if self.model_name not in {
|
|
483
|
-
"
|
|
484
|
-
"ImputeVAE",
|
|
485
|
-
"ImputeAutoencoder",
|
|
486
|
-
"ImputeUBP",
|
|
487
|
-
}:
|
|
488
|
-
msg = "nn_method must be either 'ImputeNLPCA', 'ImputeVAE', 'ImputeAutoencoder', 'ImputeUBP'."
|
|
498
|
+
if self.model_name not in {"ImputeVAE", "ImputeAutoencoder"}:
|
|
499
|
+
msg = f"model_name must be 'ImputeVAE' or 'ImputeAutoencoder', but got: {self.model_name}."
|
|
489
500
|
self.logger.error(msg)
|
|
490
501
|
raise ValueError(msg)
|
|
491
502
|
|
|
492
|
-
if
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
# Plot train accuracy
|
|
498
|
-
ax.plot(df["Train"], c="blue", lw=3)
|
|
503
|
+
if not history:
|
|
504
|
+
msg = "history object passed to plot_history is empty."
|
|
505
|
+
self.logger.error(msg)
|
|
506
|
+
raise ValueError(msg)
|
|
499
507
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
508
|
+
if (
|
|
509
|
+
not isinstance(history, dict)
|
|
510
|
+
or "Train" not in history
|
|
511
|
+
or "Val" not in history
|
|
512
|
+
):
|
|
513
|
+
msg = "history must be of type dict and contain 'Train' and 'Val' keys."
|
|
514
|
+
self.logger.error(msg)
|
|
515
|
+
raise TypeError(msg)
|
|
504
516
|
|
|
505
|
-
|
|
506
|
-
fig, ax = plt.subplots(3, 1, figsize=(12, 8))
|
|
517
|
+
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
|
|
507
518
|
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
isinstance(history, dict)
|
|
511
|
-
and "Train" in history
|
|
512
|
-
and isinstance(history["Train"], dict)
|
|
513
|
-
):
|
|
514
|
-
msg = "For ImputeUBP, history must be a nested dictionary with phases."
|
|
515
|
-
self.logger.error(msg)
|
|
516
|
-
raise TypeError(msg)
|
|
519
|
+
train = self._series_from_history(history["Train"]).iloc[1:]
|
|
520
|
+
val = self._series_from_history(history["Val"]).iloc[1:]
|
|
517
521
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
train = train.iloc[1:] # ignore first epoch
|
|
522
|
+
ax.plot(train.index, train.to_numpy(), c="blue", lw=3, linestyle="-")
|
|
523
|
+
ax.plot(val.index, val.to_numpy(), c="orange", lw=3, linestyle="-")
|
|
521
524
|
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
ax[i].set_xlabel("Epoch")
|
|
527
|
-
ax[i].legend([f"Phase {phase}"], loc="best", shadow=True, fancybox=True)
|
|
525
|
+
ax.set_title(f"{self.model_name} Loss per Epoch")
|
|
526
|
+
ax.set_ylabel("Loss")
|
|
527
|
+
ax.set_xlabel("Epoch")
|
|
528
|
+
ax.legend(["Train", "Validation"], loc="best", shadow=True, fancybox=True)
|
|
528
529
|
|
|
529
530
|
fn = f"{self.model_name.lower()}_history_plot.{self.plot_format}"
|
|
530
531
|
fn = self.output_dir / fn
|
|
531
532
|
fig.savefig(fn)
|
|
532
533
|
|
|
533
534
|
if self.show_plots:
|
|
534
|
-
|
|
535
|
+
with warnings.catch_warnings():
|
|
536
|
+
warnings.simplefilter("ignore", UserWarning)
|
|
537
|
+
plt.show()
|
|
535
538
|
plt.close(fig)
|
|
536
539
|
|
|
537
540
|
# ---- MultiQC: training-loss vs epoch linegraphs -------------------
|
|
@@ -606,7 +609,7 @@ class Plotting:
|
|
|
606
609
|
panel_suffix = f"{prefix}_" if prefix else ""
|
|
607
610
|
panel_id = f"{self.model_name.lower()}_{panel_suffix}confusion_matrix"
|
|
608
611
|
|
|
609
|
-
if prefix != "":
|
|
612
|
+
if prefix != "" and not prefix.endswith("_"):
|
|
610
613
|
prefix = f"{prefix}_"
|
|
611
614
|
|
|
612
615
|
out_name = (
|
|
@@ -614,7 +617,9 @@ class Plotting:
|
|
|
614
617
|
)
|
|
615
618
|
fig.savefig(self.output_dir / out_name, bbox_inches="tight")
|
|
616
619
|
if self.show_plots:
|
|
617
|
-
|
|
620
|
+
with warnings.catch_warnings():
|
|
621
|
+
warnings.simplefilter("ignore", UserWarning)
|
|
622
|
+
plt.show()
|
|
618
623
|
plt.close(fig)
|
|
619
624
|
|
|
620
625
|
# ---- MultiQC: confusion-matrix heatmap ----------------------------
|
|
@@ -715,7 +720,9 @@ class Plotting:
|
|
|
715
720
|
fig.savefig(fn, dpi=300)
|
|
716
721
|
|
|
717
722
|
if self.show_plots:
|
|
718
|
-
|
|
723
|
+
with warnings.catch_warnings():
|
|
724
|
+
warnings.simplefilter("ignore", UserWarning)
|
|
725
|
+
plt.show()
|
|
719
726
|
plt.close(fig)
|
|
720
727
|
|
|
721
728
|
# ---- MultiQC: genotype-distribution barplot -----------------------
|
|
@@ -763,19 +770,19 @@ class Plotting:
|
|
|
763
770
|
if df_trials.empty or "value" not in df_trials:
|
|
764
771
|
return
|
|
765
772
|
|
|
766
|
-
|
|
773
|
+
history_data: Dict[str, Dict[int, float]] = {
|
|
767
774
|
model_name: {
|
|
768
|
-
row["number"]: row["value"]
|
|
775
|
+
int(row["number"]): float(row["value"])
|
|
769
776
|
for _, row in df_trials.iterrows()
|
|
770
777
|
if row["value"] is not None
|
|
771
778
|
}
|
|
772
779
|
}
|
|
773
780
|
|
|
774
|
-
if not
|
|
781
|
+
if not history_data[model_name]:
|
|
775
782
|
return
|
|
776
783
|
|
|
777
784
|
SNPioMultiQC.queue_linegraph(
|
|
778
|
-
data=
|
|
785
|
+
data=cast(Dict[str, Dict[int, int]], history_data),
|
|
779
786
|
panel_id=f"{self.model_name}_optuna_history",
|
|
780
787
|
section=self.multiqc_section,
|
|
781
788
|
title=f"{self.model_name} Optuna Optimization History",
|
|
@@ -792,8 +799,16 @@ class Plotting:
|
|
|
792
799
|
return
|
|
793
800
|
|
|
794
801
|
if best_params:
|
|
795
|
-
|
|
796
|
-
|
|
802
|
+
# Build a single dict so static type checkers don't infer a
|
|
803
|
+
# mismatched dtype for the Series and complain about assigning
|
|
804
|
+
# a float value after creation.
|
|
805
|
+
best_param_data: Dict[str, float | int | str] = {
|
|
806
|
+
**{str(k): cast(float | int | str, v) for k, v in best_params.items()},
|
|
807
|
+
"objective": float(best_value),
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
series = pd.Series(best_param_data, name="Best Value")
|
|
811
|
+
|
|
797
812
|
SNPioMultiQC.queue_table(
|
|
798
813
|
df=series,
|
|
799
814
|
panel_id=f"{self.model_name}_optuna_best_params",
|
|
@@ -992,7 +1007,7 @@ class Plotting:
|
|
|
992
1007
|
def _queue_multiqc_history(
|
|
993
1008
|
self,
|
|
994
1009
|
*,
|
|
995
|
-
history:
|
|
1010
|
+
history: Mapping[str, List[float] | Dict[str, List[float]] | None] | None,
|
|
996
1011
|
) -> None:
|
|
997
1012
|
"""Queue training history (loss vs epoch) for MultiQC.
|
|
998
1013
|
|
pgsui/utils/pretty_metrics.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import math
|
|
4
5
|
from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
except Exception:
|
|
9
|
-
np = None # type: ignore
|
|
7
|
+
import numpy as np
|
|
10
8
|
|
|
11
9
|
# Optional Rich console; falls back to ASCII if not installed.
|
|
12
10
|
try:
|
|
@@ -152,8 +150,6 @@ class PrettyMetrics:
|
|
|
152
150
|
Returns:
|
|
153
151
|
str: Compact JSON representation, suitable for logging artifacts.
|
|
154
152
|
"""
|
|
155
|
-
import json
|
|
156
|
-
|
|
157
153
|
return json.dumps(self.metrics, separators=(",", ":"), ensure_ascii=False)
|
|
158
154
|
|
|
159
155
|
# ----------------------- Internal helpers -----------------------------
|
pgsui/utils/scorers.py
CHANGED
|
@@ -5,6 +5,8 @@ from sklearn.metrics import (
|
|
|
5
5
|
accuracy_score,
|
|
6
6
|
average_precision_score,
|
|
7
7
|
f1_score,
|
|
8
|
+
jaccard_score,
|
|
9
|
+
matthews_corrcoef,
|
|
8
10
|
precision_score,
|
|
9
11
|
recall_score,
|
|
10
12
|
roc_auc_score,
|
|
@@ -164,6 +166,8 @@ class Scorer:
|
|
|
164
166
|
"f1",
|
|
165
167
|
"precision",
|
|
166
168
|
"recall",
|
|
169
|
+
"mcc",
|
|
170
|
+
"jaccard",
|
|
167
171
|
] = "pr_macro",
|
|
168
172
|
) -> Dict[str, float] | None:
|
|
169
173
|
"""Evaluate the model using various metrics.
|
|
@@ -228,6 +232,10 @@ class Scorer:
|
|
|
228
232
|
metrics = {"precision": self.precision(y_true, y_pred)}
|
|
229
233
|
elif tune_metric == "recall":
|
|
230
234
|
metrics = {"recall": self.recall(y_true, y_pred)}
|
|
235
|
+
elif tune_metric == "jaccard":
|
|
236
|
+
metrics = {"jaccard": self.jaccard(y_true, y_pred)}
|
|
237
|
+
elif tune_metric == "mcc":
|
|
238
|
+
metrics = {"mcc": self.mcc(y_true, y_pred)}
|
|
231
239
|
else:
|
|
232
240
|
msg = f"Invalid tune_metric provided: '{tune_metric}'."
|
|
233
241
|
self.logger.error(msg)
|
|
@@ -241,10 +249,41 @@ class Scorer:
|
|
|
241
249
|
"roc_auc": self.roc_auc(y_true, y_pred_proba),
|
|
242
250
|
"average_precision": self.average_precision(y_true, y_pred_proba),
|
|
243
251
|
"pr_macro": self.pr_macro(y_true_ohe, y_pred_proba),
|
|
252
|
+
"jaccard": self.jaccard(np.asarray(y_true), np.asarray(y_pred)),
|
|
253
|
+
"mcc": self.mcc(np.asarray(y_true), np.asarray(y_pred)),
|
|
244
254
|
}
|
|
245
255
|
|
|
246
256
|
return {k: float(v) for k, v in metrics.items()}
|
|
247
257
|
|
|
258
|
+
def jaccard(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the Jaccard similarity coefficient.

    The Jaccard coefficient (Intersection over Union) is the size of the
    intersection of two label sets divided by the size of their union. The
    averaging mode is read from ``self.average`` and ``zero_division=0`` is
    passed to ``jaccard_score``.

    Args:
        y_true (np.ndarray): Ground truth (correct) target values.
        y_pred (np.ndarray): Predicted target values.

    Returns:
        float: Jaccard similarity coefficient.
    """
    score = jaccard_score(y_true, y_pred, average=self.average, zero_division=0)
    return float(score)
|
|
272
|
+
|
|
273
|
+
def mcc(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the Matthews correlation coefficient (MCC).

    MCC is a balanced measure that remains usable when class sizes differ
    widely. It ranges from -1 to +1: +1 is a perfect prediction, 0 is no
    better than random, and -1 is total disagreement between prediction and
    observation.

    Args:
        y_true (np.ndarray): Ground truth (correct) target values.
        y_pred (np.ndarray): Predicted target values.

    Returns:
        float: Matthews correlation coefficient.
    """
    coefficient = matthews_corrcoef(y_true, y_pred)
    return float(coefficient)
|
|
286
|
+
|
|
248
287
|
def average_precision(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
|
|
249
288
|
"""Average precision with safe multiclass handling.
|
|
250
289
|
|