pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
pgsui/utils/misc.py CHANGED
@@ -66,8 +66,206 @@ def validate_input_type(
66
66
  if isinstance(X, torch.Tensor):
67
67
  return X
68
68
  elif isinstance(X, np.ndarray):
69
- return torch.tensor(X, dtype=torch.float32)
69
+ return torch.tensor(X, dtype=torch.long)
70
70
  elif isinstance(X, pd.DataFrame):
71
- return torch.tensor(X.to_numpy(), dtype=torch.float32)
71
+ return torch.tensor(X.to_numpy(), dtype=torch.long)
72
72
  elif isinstance(X, list):
73
- return torch.tensor(X, dtype=torch.float32)
73
+ return torch.tensor(X, dtype=torch.long)
74
+
75
+
76
def detect_computing_device(
    *, force_cpu: bool = False, verbose: bool = False
) -> torch.device:
    """Pick the preferred PyTorch compute device available on this machine.

    Preference order is CUDA (NVIDIA), then MPS (Apple Silicon), then CPU.
    Hardware probing is skipped entirely when ``force_cpu`` is set.

    Args:
        force_cpu (bool): If True, always return the CPU device without probing for accelerators. Defaults to False.
        verbose (bool): If True, print the chosen device to stdout. Defaults to False.

    Returns:
        torch.device: The chosen computing device.
    """
    # Start from the universal fallback and upgrade only if allowed/available.
    backend = "cpu"
    if not force_cpu:
        if torch.cuda.is_available():
            backend = "cuda"
        elif torch.backends.mps.is_available():
            backend = "mps"

    chosen = torch.device(backend)

    if verbose:
        print(f"Selected compute device: {chosen}")

    return chosen
103
+
104
+
105
+ def get_missing_mask(
106
+ X: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
107
+ ) -> pd.DataFrame | pd.Series | np.ndarray | torch.Tensor:
108
+ """Returns a boolean mask indicating missing values (NaN, None).
109
+
110
+ Notes:
111
+ Lists are converted to numpy arrays to compute the mask.
112
+
113
+ Args:
114
+ X: Input data.
115
+
116
+ Returns:
117
+ pd.DataFrame | pd.Series | np.ndarray | torch.Tensor: Boolean mask of the same shape as X (returned as DF, Array, or Tensor).
118
+
119
+ Raises:
120
+ TypeError: If input type is not supported.
121
+ """
122
+ if isinstance(X, pd.DataFrame):
123
+ return X.isna()
124
+
125
+ elif isinstance(X, pd.Series):
126
+ return pd.isna(X)
127
+
128
+ elif isinstance(X, np.ndarray):
129
+ # np.isnan fails on object arrays (e.g. strings)
130
+ # so we check generically first
131
+ if X.dtype.kind in {"U", "S", "O"}: # String/Object
132
+ return pd.isnull(X)
133
+ return np.isnan(X)
134
+
135
+ elif isinstance(X, torch.Tensor):
136
+ return torch.isnan(X)
137
+
138
+ elif isinstance(X, list):
139
+ arr = np.array(X)
140
+ # Handle mixed types in lists
141
+ if arr.dtype.kind in {"U", "S", "O"}:
142
+ return pd.isnull(arr)
143
+ return np.isnan(arr)
144
+
145
+ else:
146
+ raise TypeError(
147
+ f"Unsupported type for missing value detection. Expected pandas.DataFrame, pandas.Series, numpy.ndarray, list, or torch.Tensor but got {type(X)}"
148
+ )
149
+
150
+
151
+ def ensure_2d(
152
+ X: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
153
+ ) -> pd.DataFrame | np.ndarray | list | torch.Tensor:
154
+ """Ensures the input is at least 2-dimensional.
155
+
156
+ If input is 1D (e.g., shape (N,)), it is reshaped to (N, 1). Already 2D+ inputs are returned unchanged.
157
+
158
+ Args:
159
+ X (pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor): Input data.
160
+
161
+ Returns:
162
+ pd.DataFrame | np.ndarray | list | torch.Tensor: Input data transformed to be at least 2D.
163
+
164
+ Raises:
165
+ TypeError: If input type is not supported.
166
+ """
167
+ if isinstance(X, pd.DataFrame):
168
+ return X # DataFrames are always 2D
169
+
170
+ elif isinstance(X, pd.Series):
171
+ return X.to_frame() # Convert Series to DataFrame (2D)
172
+
173
+ elif isinstance(X, np.ndarray):
174
+ if X.ndim == 1:
175
+ return X.reshape(-1, 1)
176
+ return X
177
+
178
+ elif isinstance(X, torch.Tensor):
179
+ if X.dim() == 1:
180
+ return X.unsqueeze(1)
181
+ return X
182
+
183
+ elif isinstance(X, list):
184
+ # Check depth of list
185
+ if not X:
186
+ return X
187
+ if not isinstance(X[0], list):
188
+ return [[x] for x in X]
189
+ return X
190
+
191
+ else:
192
+ msg = f"X must be of type pandas.DataFrame, pd.Series, numpy.ndarray, list, or torch.Tensor, but got {type(X)}"
193
+ raise TypeError(msg)
194
+
195
+
196
+ def flatten_1d(
197
+ y: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
198
+ ) -> pd.Series | np.ndarray | list | torch.Tensor:
199
+ """
200
+ Flattens input to a 1D structure.
201
+
202
+ Args:
203
+ y (pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor): Input data.
204
+
205
+ Returns:
206
+ pd.Series | np.ndarray | list | torch.Tensor: 1D representation of the input.
207
+
208
+ Notes:
209
+ Inputs with multiple columns (e.g., DataFrame with >1 column) are flattened into a single 1D structure.
210
+
211
+ Raises:
212
+ TypeError: If input type is not supported.
213
+ """
214
+ if isinstance(y, pd.DataFrame):
215
+ if y.shape[1] == 1:
216
+ return y.iloc[:, 0]
217
+ else:
218
+ return pd.Series(y.to_numpy().flatten())
219
+
220
+ elif isinstance(y, np.ndarray):
221
+ return y.flatten()
222
+
223
+ elif isinstance(y, torch.Tensor):
224
+ return y.view(-1)
225
+
226
+ elif isinstance(y, list):
227
+ # Recursively flatten list if needed, or simple comprehension if just 2D
228
+ if not y:
229
+ return y
230
+ if isinstance(y[0], list):
231
+ return [item for sublist in y for item in sublist]
232
+ return y
233
+
234
+ else:
235
+ msg = f"Input must be of type pandas.DataFrame, pandas.Series, numpy.ndarray, list, or torch.Tensor, but got {type(y)}"
236
+ raise TypeError(msg)
237
+
238
+
239
+ def safe_shape(
240
+ X: pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor,
241
+ ) -> tuple[int, ...]:
242
+ """Returns the shape of the input container as a tuple.
243
+
244
+ Args:
245
+ X (pd.DataFrame | pd.Series | np.ndarray | list | torch.Tensor): Input data.
246
+
247
+ Returns:
248
+ tuple[int, ...]: Dimensions of the data (rows, cols, etc.).
249
+ """
250
+ if isinstance(X, (pd.DataFrame, np.ndarray)):
251
+ return X.shape
252
+
253
+ elif isinstance(X, pd.Series):
254
+ return (X.shape[0],)
255
+
256
+ elif isinstance(X, torch.Tensor):
257
+ return tuple(X.shape)
258
+
259
+ elif isinstance(X, list):
260
+ if not X:
261
+ return (0,)
262
+ rows = len(X)
263
+
264
+ # Check if 2D list
265
+ if isinstance(X[0], list):
266
+ return (rows, len(X[0]))
267
+ return (rows,)
268
+
269
+ else:
270
+ msg = f"X must be of type pandas.DataFrame, pd.Series, numpy.ndarray, list, or torch.Tensor, but got {type(X)}"
271
+ raise TypeError(msg)
pgsui/utils/plotting.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  import warnings
3
3
  from pathlib import Path
4
- from typing import Dict, List, Literal, Optional, Sequence, cast
4
+ from typing import Dict, List, Literal, Optional, Sequence, Mapping, cast
5
5
 
6
6
  import matplotlib as mpl
7
7
 
@@ -294,6 +294,10 @@ class Plotting:
294
294
  ValueError: If model_name is not recognized (legacy guard).
295
295
  """
296
296
  num_classes = y_pred_proba.shape[1]
297
+ if num_classes < 2:
298
+ msg = "plot_metrics: num_classes must be >= 2 for ROC/PR curves."
299
+ self.logger.error(msg)
300
+ raise ValueError(msg)
297
301
 
298
302
  # Validate/normalize label names
299
303
  if label_names is not None and len(label_names) != num_classes:
@@ -391,7 +395,7 @@ class Plotting:
391
395
  ncol=2,
392
396
  )
393
397
 
394
- # PR
398
+ # Precision-recall
395
399
  axes[1].plot(
396
400
  recall["micro"],
397
401
  precision["micro"],
@@ -433,7 +437,9 @@ class Plotting:
433
437
  )
434
438
  fig.savefig(self.output_dir / out_name, bbox_inches="tight")
435
439
  if self.show_plots:
436
- plt.show()
440
+ with warnings.catch_warnings():
441
+ warnings.simplefilter("ignore", UserWarning)
442
+ plt.show()
437
443
  plt.close(fig)
438
444
 
439
445
  # ---- MultiQC: metrics table + per-class AUC/AP heatmap ------------
@@ -465,73 +471,70 @@ class Plotting:
465
471
  except Exception as exc: # pragma: no cover - defensive
466
472
  self.logger.warning(f"Failed to queue MultiQC ROC/PR curves: {exc}")
467
473
 
468
- def plot_history(
469
- self,
470
- history: Dict[str, List[float] | Dict[str, List[float]] | None] | None,
471
- ) -> None:
474
def _series_from_history(self, vals: list[float]) -> pd.Series:
    """Build a float64 Series from a history trace with non-finite entries as NaN.

    Args:
        vals (list[float]): Per-epoch history values.

    Returns:
        pd.Series: float64 Series in which every +inf, -inf, or NaN entry is NaN.
    """
    series = pd.Series(vals, dtype="float64")
    finite = np.isfinite(series.to_numpy())
    # Keep finite values; replace everything else with NaN.
    return series.where(finite, np.nan)
479
+
480
def _interp_sparse(self, s: pd.Series) -> pd.Series:
    """Linearly fill interior gaps of ``s`` while leaving edge NaNs untouched.

    Args:
        s (pd.Series): Series that may contain NaN gaps.

    Returns:
        pd.Series: Interpolated copy when ``s`` has at least two non-NaN
        points; otherwise ``s`` itself, unchanged.
    """
    # A line needs two observed endpoints; with fewer there is nothing to fill.
    if s.count() >= 2:
        return s.interpolate(method="linear", limit_area="inside")
    return s
486
+
487
+ def plot_history(self, history: dict[str, list[float]]) -> None:
472
488
  """Plot model history traces. Will be saved to file.
473
489
 
474
490
  This method plots the deep learning model history traces. The plot is saved to disk as a ``<plot_format>`` file.
475
491
 
476
492
  Args:
477
- history (Dict[str, List[float]]): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
493
+ history (dict[str, list[float]]): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
478
494
 
479
495
  Raises:
480
- ValueError: nn_method must be either 'ImputeNLPCA', 'ImputeUBP', 'ImputeAutoencoder', 'ImputeVAE'.
496
+ ValueError: self.model_name must be either 'ImputeAutoencoder' or 'ImputeVAE'.
481
497
  """
482
- if self.model_name not in {
483
- "ImputeNLPCA",
484
- "ImputeVAE",
485
- "ImputeAutoencoder",
486
- "ImputeUBP",
487
- }:
488
- msg = "nn_method must be either 'ImputeNLPCA', 'ImputeVAE', 'ImputeAutoencoder', 'ImputeUBP'."
498
+ if self.model_name not in {"ImputeVAE", "ImputeAutoencoder"}:
499
+ msg = f"model_name must be 'ImputeVAE' or 'ImputeAutoencoder', but got: {self.model_name}."
489
500
  self.logger.error(msg)
490
501
  raise ValueError(msg)
491
502
 
492
- if self.model_name != "ImputeUBP":
493
- fig, ax = plt.subplots(1, 1, figsize=(12, 8))
494
- df = pd.DataFrame(history)
495
- df = df.iloc[1:]
496
-
497
- # Plot train accuracy
498
- ax.plot(df["Train"], c="blue", lw=3)
503
+ if not history:
504
+ msg = "history object passed to plot_history is empty."
505
+ self.logger.error(msg)
506
+ raise ValueError(msg)
499
507
 
500
- ax.set_title(f"{self.model_name} Loss per Epoch")
501
- ax.set_ylabel("Loss")
502
- ax.set_xlabel("Epoch")
503
- ax.legend(["Train"], loc="best", shadow=True, fancybox=True)
508
+ if (
509
+ not isinstance(history, dict)
510
+ or "Train" not in history
511
+ or "Val" not in history
512
+ ):
513
+ msg = "history must be of type dict and contain 'Train' and 'Val' keys."
514
+ self.logger.error(msg)
515
+ raise TypeError(msg)
504
516
 
505
- else:
506
- fig, ax = plt.subplots(3, 1, figsize=(12, 8))
517
+ fig, ax = plt.subplots(1, 1, figsize=(12, 8))
507
518
 
508
- # Ensure history is the nested dictionary type for ImputeUBP
509
- if not (
510
- isinstance(history, dict)
511
- and "Train" in history
512
- and isinstance(history["Train"], dict)
513
- ):
514
- msg = "For ImputeUBP, history must be a nested dictionary with phases."
515
- self.logger.error(msg)
516
- raise TypeError(msg)
519
+ train = self._series_from_history(history["Train"]).iloc[1:]
520
+ val = self._series_from_history(history["Val"]).iloc[1:]
517
521
 
518
- for i, phase in enumerate(range(1, 4)):
519
- train = pd.Series(history["Train"][f"Phase {phase}"])
520
- train = train.iloc[1:] # ignore first epoch
522
+ ax.plot(train.index, train.to_numpy(), c="blue", lw=3, linestyle="-")
523
+ ax.plot(val.index, val.to_numpy(), c="orange", lw=3, linestyle="-")
521
524
 
522
- # Plot train accuracy
523
- ax[i].plot(train, c="blue", lw=3)
524
- ax[i].set_title(f"{self.model_name}: Phase {phase} Loss per Epoch")
525
- ax[i].set_ylabel("Loss")
526
- ax[i].set_xlabel("Epoch")
527
- ax[i].legend([f"Phase {phase}"], loc="best", shadow=True, fancybox=True)
525
+ ax.set_title(f"{self.model_name} Loss per Epoch")
526
+ ax.set_ylabel("Loss")
527
+ ax.set_xlabel("Epoch")
528
+ ax.legend(["Train", "Validation"], loc="best", shadow=True, fancybox=True)
528
529
 
529
530
  fn = f"{self.model_name.lower()}_history_plot.{self.plot_format}"
530
531
  fn = self.output_dir / fn
531
532
  fig.savefig(fn)
532
533
 
533
534
  if self.show_plots:
534
- plt.show()
535
+ with warnings.catch_warnings():
536
+ warnings.simplefilter("ignore", UserWarning)
537
+ plt.show()
535
538
  plt.close(fig)
536
539
 
537
540
  # ---- MultiQC: training-loss vs epoch linegraphs -------------------
@@ -606,7 +609,7 @@ class Plotting:
606
609
  panel_suffix = f"{prefix}_" if prefix else ""
607
610
  panel_id = f"{self.model_name.lower()}_{panel_suffix}confusion_matrix"
608
611
 
609
- if prefix != "":
612
+ if prefix != "" and not prefix.endswith("_"):
610
613
  prefix = f"{prefix}_"
611
614
 
612
615
  out_name = (
@@ -614,7 +617,9 @@ class Plotting:
614
617
  )
615
618
  fig.savefig(self.output_dir / out_name, bbox_inches="tight")
616
619
  if self.show_plots:
617
- plt.show()
620
+ with warnings.catch_warnings():
621
+ warnings.simplefilter("ignore", UserWarning)
622
+ plt.show()
618
623
  plt.close(fig)
619
624
 
620
625
  # ---- MultiQC: confusion-matrix heatmap ----------------------------
@@ -715,7 +720,9 @@ class Plotting:
715
720
  fig.savefig(fn, dpi=300)
716
721
 
717
722
  if self.show_plots:
718
- plt.show()
723
+ with warnings.catch_warnings():
724
+ warnings.simplefilter("ignore", UserWarning)
725
+ plt.show()
719
726
  plt.close(fig)
720
727
 
721
728
  # ---- MultiQC: genotype-distribution barplot -----------------------
@@ -763,19 +770,19 @@ class Plotting:
763
770
  if df_trials.empty or "value" not in df_trials:
764
771
  return
765
772
 
766
- data: Dict[str, Dict[int, int]] = {
773
+ history_data: Dict[str, Dict[int, float]] = {
767
774
  model_name: {
768
- row["number"]: row["value"]
775
+ int(row["number"]): float(row["value"])
769
776
  for _, row in df_trials.iterrows()
770
777
  if row["value"] is not None
771
778
  }
772
779
  }
773
780
 
774
- if not data[model_name]:
781
+ if not history_data[model_name]:
775
782
  return
776
783
 
777
784
  SNPioMultiQC.queue_linegraph(
778
- data=data,
785
+ data=cast(Dict[str, Dict[int, int]], history_data),
779
786
  panel_id=f"{self.model_name}_optuna_history",
780
787
  section=self.multiqc_section,
781
788
  title=f"{self.model_name} Optuna Optimization History",
@@ -792,8 +799,16 @@ class Plotting:
792
799
  return
793
800
 
794
801
  if best_params:
795
- series = pd.Series(best_params, name="Best Value")
796
- series["objective"] = best_value
802
+ # Build a single dict so static type checkers don't infer a
803
+ # mismatched dtype for the Series and complain about assigning
804
+ # a float value after creation.
805
+ best_param_data: Dict[str, float | int | str] = {
806
+ **{str(k): cast(float | int | str, v) for k, v in best_params.items()},
807
+ "objective": float(best_value),
808
+ }
809
+
810
+ series = pd.Series(best_param_data, name="Best Value")
811
+
797
812
  SNPioMultiQC.queue_table(
798
813
  df=series,
799
814
  panel_id=f"{self.model_name}_optuna_best_params",
@@ -992,7 +1007,7 @@ class Plotting:
992
1007
  def _queue_multiqc_history(
993
1008
  self,
994
1009
  *,
995
- history: Dict[str, List[float] | Dict[str, List[float]] | None] | None,
1010
+ history: Mapping[str, List[float] | Dict[str, List[float]] | None] | None,
996
1011
  ) -> None:
997
1012
  """Queue training history (loss vs epoch) for MultiQC.
998
1013
 
pgsui/utils/pretty_metrics.py CHANGED
@@ -1,12 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import json
3
4
  import math
4
5
  from typing import Any, Iterable, List, Mapping, Optional, Sequence, Tuple
5
6
 
6
- try:
7
- import numpy as np
8
- except Exception:
9
- np = None # type: ignore
7
+ import numpy as np
10
8
 
11
9
  # Optional Rich console; falls back to ASCII if not installed.
12
10
  try:
@@ -152,8 +150,6 @@ class PrettyMetrics:
152
150
  Returns:
153
151
  str: Compact JSON representation, suitable for logging artifacts.
154
152
  """
155
- import json
156
-
157
153
  return json.dumps(self.metrics, separators=(",", ":"), ensure_ascii=False)
158
154
 
159
155
  # ----------------------- Internal helpers -----------------------------
pgsui/utils/scorers.py CHANGED
@@ -5,6 +5,8 @@ from sklearn.metrics import (
5
5
  accuracy_score,
6
6
  average_precision_score,
7
7
  f1_score,
8
+ jaccard_score,
9
+ matthews_corrcoef,
8
10
  precision_score,
9
11
  recall_score,
10
12
  roc_auc_score,
@@ -164,6 +166,8 @@ class Scorer:
164
166
  "f1",
165
167
  "precision",
166
168
  "recall",
169
+ "mcc",
170
+ "jaccard",
167
171
  ] = "pr_macro",
168
172
  ) -> Dict[str, float] | None:
169
173
  """Evaluate the model using various metrics.
@@ -228,6 +232,10 @@ class Scorer:
228
232
  metrics = {"precision": self.precision(y_true, y_pred)}
229
233
  elif tune_metric == "recall":
230
234
  metrics = {"recall": self.recall(y_true, y_pred)}
235
+ elif tune_metric == "jaccard":
236
+ metrics = {"jaccard": self.jaccard(y_true, y_pred)}
237
+ elif tune_metric == "mcc":
238
+ metrics = {"mcc": self.mcc(y_true, y_pred)}
231
239
  else:
232
240
  msg = f"Invalid tune_metric provided: '{tune_metric}'."
233
241
  self.logger.error(msg)
@@ -241,10 +249,41 @@ class Scorer:
241
249
  "roc_auc": self.roc_auc(y_true, y_pred_proba),
242
250
  "average_precision": self.average_precision(y_true, y_pred_proba),
243
251
  "pr_macro": self.pr_macro(y_true_ohe, y_pred_proba),
252
+ "jaccard": self.jaccard(np.asarray(y_true), np.asarray(y_pred)),
253
+ "mcc": self.mcc(np.asarray(y_true), np.asarray(y_pred)),
244
254
  }
245
255
 
246
256
  return {k: float(v) for k, v in metrics.items()}
247
257
 
258
def jaccard(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the Jaccard similarity coefficient (intersection over union).

    For a pair of label sets this is |intersection| / |union|. The averaging
    strategy across classes is taken from ``self.average``.

    Args:
        y_true (np.ndarray): Ground truth (correct) target values.
        y_pred (np.ndarray): Predicted target values.

    Returns:
        float: Jaccard similarity coefficient.
    """
    # zero_division=0: degenerate classes score 0 rather than emitting a warning.
    score = jaccard_score(
        y_true,
        y_pred,
        average=self.average,
        zero_division=0,
    )
    return float(score)
272
+
273
def mcc(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Compute the Matthews correlation coefficient (MCC).

    MCC stays informative under heavy class imbalance. It ranges from -1
    (total disagreement) through 0 (no better than chance) to +1 (perfect
    prediction).

    Args:
        y_true (np.ndarray): Ground truth (correct) target values.
        y_pred (np.ndarray): Predicted target values.

    Returns:
        float: Matthews correlation coefficient.
    """
    coefficient = matthews_corrcoef(y_true, y_pred)
    return float(coefficient)
286
+
248
287
  def average_precision(self, y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
249
288
  """Average precision with safe multiclass handling.
250
289