lecrapaud 0.5.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lecrapaud might be problematic.

@@ -75,7 +75,7 @@ from lecrapaud.db import (
     ModelTraining,
     Score,
     Target,
-    Dataset,
+    Experiment,
 )
 
 # Reproducible result
@@ -144,8 +144,12 @@ class ModelEngine:
         self.plot = plot
         self.log_dir = log_dir
 
-        if self.need_scaling and self.target_type == "regression":
-            self.scaler_y = joblib.load(f"{self.path}/scaler_y.pkl")
+        if self.path and self.need_scaling and self.target_type == "regression":
+            preprocessing_dir = Path(f"{self.path}/../preprocessing")
+            target_number = self.path.split("/")[-1].split("_")[-1]
+            self.scaler_y = joblib.load(
+                preprocessing_dir / f"scaler_y_{target_number}.pkl"
+            )
         else:
             self.scaler_y = None
 
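With this change, `scaler_y` is no longer read from the model's own directory but from a per-target file under the experiment's `preprocessing` folder. A minimal sketch of the new path resolution, assuming `self.path` ends in a `TARGET_<n>` directory (the concrete layout is an assumption, not shown in the diff):

    from pathlib import Path

    # Hypothetical ModelEngine path whose last component carries the target number
    path = "/tmp/lecrapaud/my_experiment/TARGET_3"
    preprocessing_dir = Path(f"{path}/../preprocessing")
    target_number = path.split("/")[-1].split("_")[-1]  # -> "3"

    print(preprocessing_dir / f"scaler_y_{target_number}.pkl")
    # /tmp/lecrapaud/my_experiment/TARGET_3/../preprocessing/scaler_y_3.pkl
    # i.e. the experiment-level preprocessing folder (the ".." is left unresolved by Path)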
@@ -204,7 +208,7 @@ class ModelEngine:
         """
         lightGBM = self.create_model == "lgb"
 
-        # Datasets
+        # Experiments
         boosting_dataset = lgb.Dataset if lightGBM else xgb.DMatrix
         train_data = boosting_dataset(x_train, label=y_train)
         val_data = boosting_dataset(x_val, label=y_val)
@@ -482,7 +486,7 @@ class ModelEngine:
 
     def predict(
         self,
-        data: pd.DataFrame,
+        data: pd.DataFrame | np.ndarray,
         threshold: float = 0.5,
     ):
         """Function to get prediction from model. Support sklearn, keras and boosting models such as xgboost and lgboost
@@ -500,77 +504,58 @@ class ModelEngine:
         if self.threshold and threshold == 0.5:
             threshold = self.threshold
 
-        if self.recurrent or model.model_name in ["lgb", "xgb"]:
-            # keras, lgb & xgb
-            if model.model_name == "lgb":
-                # Direct prediction for LightGBM
-                pred = model.predict(data)
-            elif model.model_name == "xgb":
-                # Convert val_data to DMatrix for XGBoost
-                d_data = xgb.DMatrix(data)
-                pred = model.predict(d_data)
+        # Determine index for output
+        if isinstance(data, pd.DataFrame):
+            index = data.index
+        elif isinstance(data, np.ndarray):
+            index = pd.RangeIndex(start=0, stop=data.shape[0])
+        else:
+            raise ValueError(
+                "Unsupported data type: expected pd.DataFrame or np.ndarray"
+            )
+
+        # Keras, LightGBM, XGBoost
+        if self.recurrent or self.model_name in ["lgb", "xgb"]:
+            if self.model_name == "xgb":
+                data_input = xgb.DMatrix(data)
+                pred_raw = model.predict(data_input)
             else:
-                # Reshape (flatten) for keras if not multiclass
-                pred = model.predict(data)
-                if pred.shape[1] == 1:
-                    pred = pred.reshape(-1)
+                pred_raw = model.predict(data)
 
-            if self.target_type == "classification":
-                num_class = pred.shape[1] if len(pred.shape) > 1 else 2
+            if pred_raw.ndim == 1:
+                pred_raw = pred_raw.reshape(-1, 1)
 
+            if self.target_type == "classification":
+                num_class = pred_raw.shape[1] if pred_raw.ndim > 1 else 2
                 if num_class <= 2:
-                    # For binary classification, concatenate the predicted probabilities for both classes
-                    pred_df = pd.DataFrame(
-                        {
-                            0: 1 - pred,  # Probability of class 0
-                            1: pred,  # Probability of class 1
-                        },
+                    pred_proba = pd.DataFrame(
+                        {0: 1 - pred_raw.ravel(), 1: pred_raw.ravel()}, index=index
                     )
                 else:
-                    # For multi-class classification, use the predicted probabilities for each class
-                    pred_df = pd.DataFrame(pred, columns=range(num_class))
-
-                # Get final predictions (argmax for multi-class, threshold for binary)
-                if num_class == 2:
-                    pred_df["PRED"] = np.where(
-                        pred_df[1] >= threshold, 1, 0
-                    )  # Class 1 if prob >= threshold
-                else:
-                    pred_df["PRED"] = pred_df.idxmax(
-                        axis=1
-                    )  # Class with highest probability for multiclasses
-
-                # Reorder columns to show predicted class first, then probabilities
-                pred = pred_df[["PRED"] + list(range(num_class))]
+                    pred_proba = pd.DataFrame(
+                        pred_raw, columns=range(num_class), index=index
+                    )
 
+                pred_df = apply_thresholds(pred_proba, threshold, pred_proba.columns)
             else:
-                pred = pd.Series(pred, name="PRED")
+                pred_df = pd.Series(pred_raw.ravel(), index=index, name="PRED")
 
-            # set index for lgb and xgb (for keras, as we use np array, we need to set index outside)
-            if model.model_name in ["lgb", "xgb"]:
-                pred.index = data.index
+        # Sklearn
         else:
-            # sk learn
-            pred = pd.Series(model.predict(data), index=data.index, name="PRED")
             if self.target_type == "classification":
                 pred_proba = pd.DataFrame(
                     model.predict_proba(data),
-                    index=data.index,
+                    index=index,
                     columns=[
                         int(c) if isinstance(c, float) and c.is_integer() else c
                         for c in model.classes_
                     ],
                 )
+                pred_df = apply_thresholds(pred_proba, threshold, model.classes_)
+            else:
+                pred_df = pd.Series(model.predict(data), index=index, name="PRED")
 
-                # Apply threshold for binary classification
-                if len(model.classes_) == 2:
-                    positive_class = model.classes_[1]  # Assuming classes are ordered
-                    pred = (pred_proba[positive_class] >= threshold).astype(int)
-                    pred.name = "PRED"
-
-            pred = pd.concat([pred, pred_proba], axis=1)
-
-        return pred
+        return pred_df
 
     def save(self, path):
         if self.recurrent:
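For orientation, here is a minimal sketch (plain pandas/numpy, not the library API) of the output shape the rewritten `predict` produces for a binary classifier with a scalar threshold, assuming the raw model output is a single probability column:

    import numpy as np
    import pandas as pd

    # Raw probabilities for class 1, shaped (n, 1) as keras/lgb typically return
    pred_raw = np.array([0.2, 0.7, 0.55]).reshape(-1, 1)

    # np.ndarray input now gets a RangeIndex instead of requiring a DataFrame
    index = pd.RangeIndex(start=0, stop=pred_raw.shape[0])
    pred_proba = pd.DataFrame({0: 1 - pred_raw.ravel(), 1: pred_raw.ravel()}, index=index)

    # Scalar-threshold case: "PRED" is 1 when P(class 1) >= threshold
    pred = (pred_proba[1] >= 0.5).astype(int).rename("PRED")
    print(pd.concat([pred, pred_proba], axis=1))
    # -> columns ["PRED", 0, 1]; e.g. the second row gives PRED=1, P(0)=0.3, P(1)=0.7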
@@ -640,11 +625,13 @@ def trainable(
     y_val,
     model_name,
     target_type,
-    session_name,
+    experiment_name,
     target_number,
     create_model,
     type_name="hyperopts",
     plot=False,
+    log_dir=None,
+    target_clf_thresholds: dict = None,
 ):
     """Standalone version of train_model that doesn't depend on self"""
     # Create model engine
@@ -653,10 +640,11 @@ def trainable(
         target_type=target_type,
         create_model=create_model,
         plot=plot,
+        log_dir=log_dir,
     )
 
     logger.info(
-        f"TARGET_{target_number} - Training a {model.model_name} at {datetime.now()} : {session_name}, TARGET_{target_number}"
+        f"TARGET_{target_number} - Training a {model.model_name} at {datetime.now()} : {experiment_name}, TARGET_{target_number}"
     )
 
     if model.recurrent:
@@ -696,17 +684,13 @@ def trainable(
     # Evaluate model
     score = {
         "DATE": datetime.now(),
-        "SESSION": session_name,
-        "TRAIN_DATA": x_train.shape[0],
-        "VAL_DATA": x_val.shape[0],
-        "FEATURES": x_train.shape[-1],
         "MODEL_NAME": model.model_name,
         "TYPE": type_name,
         "TRAINING_TIME": stop - start,
         "EVAL_DATA_STD": prediction["TARGET"].std(),
     }
 
-    score.update(evaluate(prediction, target_type))
+    score.update(evaluate(prediction, target_type, target_clf_thresholds))
 
     if type_name == "hyperopts":
         session.report(metrics=score)
@@ -723,41 +707,47 @@ class ModelSelectionEngine:
         reshaped_data,
         target_number,
         target_clf,
-        dataset,
+        experiment,
         models_idx,
         time_series,
         date_column,
         group_column,
+        target_clf_thresholds,
         **kwargs,
     ):
         self.data = data
         self.reshaped_data = reshaped_data
         self.target_number = target_number
-        self.dataset = dataset
+        self.experiment = experiment
         self.target_clf = target_clf
         self.models_idx = models_idx
         self.time_series = time_series
         self.date_column = date_column
         self.group_column = group_column
+        self.target_clf_thresholds = (
+            target_clf_thresholds[target_number]
+            if target_number in target_clf_thresholds.keys()
+            else None
+        )
 
         self.target_type = (
             "classification" if self.target_number in self.target_clf else "regression"
         )
-        self.dataset_dir = self.dataset.path
-        self.dataset_id = self.dataset.id
-        self.data_dir = f"{self.dataset_dir}/data"
-        self.preprocessing_dir = f"{self.dataset_dir}/preprocessing"
-        self.training_target_dir = f"{self.dataset_dir}/TARGET_{self.target_number}"
+        self.experiment_dir = self.experiment.path
+        self.experiment_id = self.experiment.id
+        self.data_dir = f"{self.experiment_dir}/data"
+        self.preprocessing_dir = f"{self.experiment_dir}/preprocessing"
+        self.training_target_dir = f"{self.experiment_dir}/TARGET_{self.target_number}"
        self.metric = "RMSE" if self.target_type == "regression" else "LOGLOSS"
-        self.features = self.dataset.get_features(self.target_number)
-        self.all_features = self.dataset.get_all_features(
+        self.features = self.experiment.get_features(self.target_number)
+        self.all_features = self.experiment.get_all_features(
            date_column=self.date_column, group_column=self.group_column
        )

    # Main training function
    def run(
        self,
-        session_name,
+        experiment_name,
        perform_hyperopt=True,
        number_of_trials=20,
        perform_crossval=False,
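The new `target_clf_thresholds` constructor argument is keyed by target number, each entry holding a single-metric constraint that is later forwarded to `evaluate`. A hypothetical example of the expected shape and of the per-target lookup done in `__init__` (values are illustrative):

    # One constraint per classification target (hypothetical values)
    target_clf_thresholds = {
        1: {"precision": 0.90},  # TARGET_1: thresholds reaching >= 90% precision
        4: {"recall": 0.95},     # TARGET_4: thresholds reaching >= 95% recall
    }

    target_number = 1
    per_target = (
        target_clf_thresholds[target_number]
        if target_number in target_clf_thresholds.keys()
        else None
    )
    print(per_target)  # {'precision': 0.9}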
@@ -769,12 +759,12 @@ class ModelSelectionEngine:
         Selects the best models based on a target variable, optionally performing hyperparameter optimization
         and cross-validation, and manages outputs in a session-specific directory.
         """
-        self.session_name = session_name
+        self.experiment_name = experiment_name
         self.plot = plot
         self.number_of_trials = number_of_trials
 
-        if self.dataset_id is None:
-            raise ValueError("Please provide a dataset.")
+        if self.experiment_id is None:
+            raise ValueError("Please provide a experiment.")
 
         if self.data:
             train = self.data["train"]
@@ -791,7 +781,9 @@ class ModelSelectionEngine:
             train_scaled,
             val_scaled,
             test_scaled,
-        ) = load_train_data(self.dataset_dir, self.target_number, self.target_clf)
+        ) = load_train_data(
+            self.experiment_dir, self.target_number, self.target_clf
+        )
 
         if (
             any(all_models[i].get("recurrent") for i in self.models_idx)
@@ -819,9 +811,9 @@ class ModelSelectionEngine:
         # create model selection in db
         target = Target.find_by(name=f"TARGET_{self.target_number}")
         model_selection = ModelSelection.upsert(
-            match_fields=["target_id", "dataset_id"],
+            match_fields=["target_id", "experiment_id"],
             target_id=target.id,
-            dataset_id=self.dataset_id,
+            experiment_id=self.experiment_id,
         )
 
         # recurrent models starts at 9 # len(list_models)
@@ -994,7 +986,7 @@ class ModelSelectionEngine:
                 self.metric
             ].mean()
             logger.info(
-                f"Best model mean cross-validation score on entire dataset: {cross_validation_mean_score}"
+                f"Best model mean cross-validation score on entire experiment: {cross_validation_mean_score}"
             )
 
             # Retrain on entire training set, but keep score on cross-validation folds
@@ -1023,7 +1015,7 @@ class ModelSelectionEngine:
 
         # Save validation predictions
         best_pred.to_csv(
-            f"{self.results_dir}/pred_val.csv",
+            f"{self.results_dir}/prediction.csv",
             index=True,
             header=True,
             index_label="ID",
@@ -1065,10 +1057,6 @@ class ModelSelectionEngine:
         # Store metrics in DB
         drop_cols = [
             "DATE",
-            "SESSION",
-            "TRAIN_DATA",
-            "VAL_DATA",
-            "FEATURES",
             "MODEL_NAME",
             "MODEL_PATH",
         ]
@@ -1117,6 +1105,9 @@ class ModelSelectionEngine:
 
         logger.info(f"Best model overall is : {best_score_overall}")
 
+        best_model = joblib.load(best_model_path)
+        return best_model
+
     def hyperoptimize(self, x_train, y_train, x_val, y_val, model: ModelEngine):
         self.type_name = "hyperopts"
 
@@ -1149,11 +1140,13 @@ class ModelSelectionEngine:
                 y_val=y_val,
                 model_name=model.model_name,
                 target_type=self.target_type,
-                session_name=self.session_name,
+                experiment_name=self.experiment_name,
                 target_number=self.target_number,
                 create_model=model.create_model,
                 type_name="hyperopts",
                 plot=model.plot,
+                log_dir=model.log_dir,
+                target_clf_thresholds=self.target_clf_thresholds,
             ),
             param_space=model.search_params,
             tune_config=TuneConfig(
@@ -1206,21 +1199,28 @@ class ModelSelectionEngine:
             y_val,
             model.model_name,
             self.target_type,
-            self.session_name,
+            self.experiment_name,
             self.target_number,
             model.create_model,
             self.type_name,
             model.plot,
+            log_dir=model.log_dir,
+            target_clf_thresholds=self.target_clf_thresholds,
         )
 
 
-def evaluate(prediction: pd.DataFrame, target_type: str):
+def evaluate(
+    prediction: pd.DataFrame,
+    target_type: str,
+    target_clf_thresholds: dict = {"precision": 0.80},
+):
     """
     Function to evaluate model performance
 
     Args:
     - prediction: the prediction dataframe containing TARGET and PRED columns, as well as predicted probablities for each class for classification tasks
     - target_type: classification or regression
+    - target_clf_thresholds: thresholds for classification tasks like {"recall": 0.9} or {"precision": 0.9}
     """
     score = {}
     y_true = prediction["TARGET"]
@@ -1286,15 +1286,37 @@ def evaluate(prediction: pd.DataFrame, target_type: str):
             average=("binary" if num_classes == 2 else "macro"),
         )
         score["ROC_AUC"] = float(roc_auc_score(y_true, y_pred_proba, multi_class="ovr"))
-        (
-            score["THRESHOLD"],
-            score["PRECISION_AT_THRESHOLD"],
-            score["RECALL_AT_THRESHOLD"],
-        ) = (
-            find_best_precision_threshold(prediction)
-            if num_classes == 2
-            else (None, None, None)
+
+        # Store the complete thresholds dictionary
+        if len(target_clf_thresholds.keys()) > 1:
+            raise ValueError(
+                f"Only one metric can be specified for threshold optimization. found {target_clf_thresholds.keys()}"
+            )
+        # Get the single key-value pair or use defaults
+        metric, value = (
+            next(iter(target_clf_thresholds.items()))
+            if target_clf_thresholds
+            else ("precision", 0.8)
         )
+
+        score["THRESHOLDS"] = find_best_threshold(prediction, metric, value)
+
+        # Collect valid metrics across all classes (works for both binary and multiclass)
+        valid_metrics = [
+            m for m in score["THRESHOLDS"].values() if m["threshold"] is not None
+        ]
+
+        if valid_metrics:
+            score["PRECISION_AT_THRESHOLD"] = np.mean(
+                [m["precision"] for m in valid_metrics]
+            )
+            score["RECALL_AT_THRESHOLD"] = np.mean([m["recall"] for m in valid_metrics])
+            score["F1_AT_THRESHOLD"] = np.mean([m["f1"] for m in valid_metrics])
+        else:
+            score["PRECISION_AT_THRESHOLD"] = None
+            score["RECALL_AT_THRESHOLD"] = None
+            score["F1_AT_THRESHOLD"] = None
+
     return score
 
 
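The score dictionary now carries a per-class `THRESHOLDS` report, and the `*_AT_THRESHOLD` metrics are averaged over the classes for which a threshold was found. A sketch of that aggregation with illustrative values:

    import numpy as np

    # Illustrative shape of score["THRESHOLDS"] as returned by find_best_threshold
    thresholds = {
        "1": {"threshold": 0.62, "precision": 0.81, "recall": 0.44, "f1": 0.57},
        "2": {"threshold": None, "precision": None, "recall": None, "f1": None},
    }

    valid_metrics = [m for m in thresholds.values() if m["threshold"] is not None]
    precision_at_threshold = (
        np.mean([m["precision"] for m in valid_metrics]) if valid_metrics else None
    )
    print(precision_at_threshold)  # 0.81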
@@ -1380,196 +1402,181 @@ def plot_confusion_matrix(y_true, y_pred):
     plt.show()
 
 
-# thresholds
-def find_max_f1_threshold(prediction):
+def find_best_threshold(
+    prediction: pd.DataFrame, metric: str = "recall", target_value: float | None = None
+) -> dict:
     """
-    Finds the threshold that maximizes the F1 score for a binary classification task.
+    General function to find best threshold optimizing recall, precision, or f1.
 
-    Parameters:
-    - prediction: DataFrame with 'TARGET' and '1' (predicted probabilities) columns
-
-    Returns:
-    - best_threshold: The threshold that maximizes the F1 score
-    - best_precision: The precision at that threshold
-    - best_recall: The recall at that threshold
-    """
-    y_true = prediction["TARGET"]
-    y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
-
-    # Compute precision, recall, and thresholds
-    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
-
-    # Drop the first element to align with thresholds
-    precision = precision[1:]
-    recall = recall[1:]
-
-    # Filter out trivial cases (precision or recall = 0)
-    valid = (precision > 0) & (recall > 0)
-    if not np.any(valid):
-        raise ValueError("No valid threshold with non-zero precision and recall")
-
-    precision = precision[valid]
-    recall = recall[valid]
-    thresholds = thresholds[valid]
-
-    # Compute F1 scores for each threshold
-    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
-
-    best_index = np.argmax(f1_scores)
-
-    best_threshold = thresholds[best_index]
-    best_precision = precision[best_index]
-    best_recall = recall[best_index]
-
-    return best_threshold, best_precision, best_recall
-
-
-def find_best_f1_threshold(prediction, fscore_target: float):
-    """
-    Finds the highest threshold achieving at least the given F1 score target.
+    Supports both binary and multiclass classification.
 
     Parameters:
-    - prediction: DataFrame with 'TARGET' and '1' (or 1 as int) columns
-    - fscore_target: Desired minimum F1 score (between 0 and 1)
+    - prediction (pd.DataFrame): must contain 'TARGET' and class probability columns.
+    - metric (str): 'recall', 'precision', or 'f1'.
+    - target_value (float | None): minimum acceptable value for the chosen metric.
 
     Returns:
-    - best_threshold: The highest threshold meeting the F1 target
-    - best_precision: Precision at that threshold
-    - best_recall: Recall at that threshold
-    - best_f1: Actual F1 score at that threshold
+    - dict: {class_label: {'threshold', 'precision', 'recall', 'f1'}}
     """
+    assert metric in {"recall", "precision", "f1"}, "Invalid metric"
     y_true = prediction["TARGET"]
-    y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
-
-    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
-
-    # Align precision/recall with thresholds
-    precision = precision[1:]
-    recall = recall[1:]
-    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
-
-    # Filter for thresholds meeting F1 target
-    valid_indices = [i for i, f1 in enumerate(f1_scores) if f1 >= fscore_target]
-
-    if not valid_indices:
-        raise ValueError(f"Could not find a threshold with F1 >= {fscore_target:.2f}")
-
-    # Pick the highest threshold among valid ones
-    best_index = valid_indices[-1]
-
-    return (
-        thresholds[best_index],
-        precision[best_index],
-        recall[best_index],
-        f1_scores[best_index],
-    )
-
-
-def find_max_precision_threshold_without_trivial_case(prediction: dict):
-    """
-    Finds the threshold that maximizes precision without reaching a precision of 1,
-    which indicates all predictions are classified as the negative class (0).
-
-    Parameters:
-    - prediction: dict with keys 'TARGET' (true labels) and '1' (predicted probabilities)
-
-    Returns:
-    - threshold: the probability threshold that maximizes precision
-    - actual_recall: the recall achieved at this threshold
-    - actual_precision: the precision achieved at this threshold
-    """
-    y_true = prediction["TARGET"]
-    y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
-
-    # Compute precision, recall, and thresholds
-    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
-
-    # Drop the first element of precision and recall to align with thresholds
-    precision = precision[1:]
-    recall = recall[1:]
-
-    # Filter out precision == 1.0 (which might correspond to predicting only 0s)
-    valid_indices = np.where(precision < 1.0)[0]
-    if len(valid_indices) == 0:
-        raise ValueError("No valid precision values less than 1.0")
+    pred_cols = [
+        col for col in prediction.columns if col not in ["ID", "TARGET", "PRED"]
+    ]
+    classes = [1] if len(pred_cols) <= 2 else sorted(y_true.unique())
+
+    results = {}
+    for cls in classes:
+        cls_str = str(cls)
+        if cls_str not in prediction.columns and cls not in prediction.columns:
+            logger.warning(f"Missing predicted probabilities for class '{cls}'")
+            results[cls_str] = {
+                "threshold": None,
+                "precision": None,
+                "recall": None,
+                "f1": None,
+            }
+            continue
+
+        # Binarize for one-vs-rest
+        y_binary = (y_true == int(cls)).astype(int)
+        y_scores = prediction[cls] if cls in prediction.columns else prediction[cls_str]
+
+        precision, recall, thresholds = precision_recall_curve(y_binary, y_scores)
+        precision, recall = precision[1:], recall[1:]  # Align with thresholds
+        thresholds = thresholds
+
+        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)
+
+        metric_values = {"precision": precision, "recall": recall, "f1": f1}
+
+        values = metric_values[metric]
+
+        if target_value is not None:
+            if metric == "recall":
+                # Only keep recall >= target
+                valid_indices = [i for i, r in enumerate(recall) if r >= target_value]
+                if valid_indices:
+                    # Pick the highest threshold
+                    best_idx = max(valid_indices, key=lambda i: thresholds[i])
+                else:
+                    logger.warning(
+                        f"[Class {cls}] No threshold with recall ≥ {target_value}"
+                    )
+                    best_idx = int(np.argmax(recall))  # fallback
 
-    precision = precision[valid_indices]
-    recall = recall[valid_indices]
-    thresholds = thresholds[valid_indices]
+            elif metric == "precision":
+                # Only keep precision ≥ target and recall > 0
+                valid_indices = [
+                    i
+                    for i, (p, r) in enumerate(zip(precision, recall))
+                    if p >= target_value and r > 0
+                ]
+                if valid_indices:
+                    # Among valid ones, pick the one with highest recall
+                    best_idx = max(valid_indices, key=lambda i: recall[i])
+                else:
+                    logger.warning(
+                        f"[Class {cls}] No threshold with precision ≥ {target_value}"
+                    )
+                    best_idx = int(np.argmax(precision))  # fallback
 
-    # Find the index of the maximum precision
-    best_index = np.argmax(precision)
+            elif metric == "f1":
+                valid_indices = [i for i, val in enumerate(f1) if val >= target_value]
+                if valid_indices:
+                    best_idx = max(valid_indices, key=lambda i: f1[i])
+                else:
+                    logger.warning(
+                        f"[Class {cls}] No threshold with f1 ≥ {target_value}"
+                    )
+                    best_idx = int(np.argmax(f1))  # fallback
+        else:
+            best_idx = int(np.argmax(values))  # no constraint, get best value
 
-    # Return the corresponding threshold, precision, and recall
-    best_threshold = thresholds[best_index]
-    best_precision = precision[best_index]
-    best_recall = recall[best_index]
+        results[cls_str] = {
+            "threshold": float(thresholds[best_idx]),
+            "precision": float(precision[best_idx]),
+            "recall": float(recall[best_idx]),
+            "f1": float(f1[best_idx]),
+        }
 
-    return best_threshold, best_precision, best_recall
+    return results
 
 
-def find_best_precision_threshold(prediction, precision_target: float = 0.80):
+def apply_thresholds(
+    pred_proba: pd.DataFrame, threshold: dict | int | float, classes
+) -> pd.DataFrame:
     """
-    Finds the highest threshold that achieves at least the given precision target.
+    Apply thresholds to predicted probabilities.
 
     Parameters:
-    prediction (pd.DataFrame): DataFrame with columns 'TARGET' and '1' or index 1 for predicted probabilities
-    precision_target (float): Desired minimum precision (between 0 and 1)
+    - pred_proba (pd.DataFrame): Probabilities per class.
+    - threshold (float | dict): Global threshold (float) or per-class dict from `find_best_threshold`.
+    - classes (iterable): List or array of class labels (used for binary classification).
 
     Returns:
-    tuple: (threshold, precision, recall) achieving the desired precision
+    - pd.DataFrame with "PRED" column and original predicted probabilities.
     """
-    y_true = prediction["TARGET"]
-    y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
-
-    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
 
-    # Align lengths: thresholds is N-1 compared to precision/recall
-    thresholds = thresholds
-    precision = precision[1:]  # Shift to match thresholds
-    recall = recall[1:]
+    # Case 1: Per-class thresholds
+    if isinstance(threshold, dict):
+        class_predictions = []
+        class_probabilities = []
 
-    valid_indices = [i for i, p in enumerate(precision) if p >= precision_target]
-
-    if not valid_indices:
-        raise ValueError(
-            f"Could not find a threshold with precision >= {precision_target}"
-        )
-
-    best_idx = valid_indices[-1]  # Highest threshold with precision >= target
-
-    return thresholds[best_idx], precision[best_idx], recall[best_idx]
-
-
-def find_best_recall_threshold(prediction, recall_target: float = 0.98) -> float:
-    """
-    Finds the highest threshold that achieves at least the given recall target.
-
-    Parameters:
-    pred_df (pd.DataFrame): DataFrame with columns 'y_true' and 'y_pred_proba'
-    recall_target (float): Desired minimum recall (between 0 and 1)
-
-    Returns:
-    float: Best threshold achieving the desired recall, or None if not reachable
-    """
-    y_true = prediction["TARGET"]
-    y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
-
-    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
+        for class_label, metrics in threshold.items():
+            # Get threshold from structured dict
+            _threshold = (
+                metrics.get("threshold") if isinstance(metrics, dict) else metrics[0]
+            )
+            if _threshold is not None:
+                if class_label not in pred_proba.columns:
+                    continue  # skip missing class
+                col = pred_proba[class_label]
+                exceeded = col >= _threshold
+                class_predictions.append(
+                    pd.Series(
+                        np.where(exceeded, class_label, -1), index=pred_proba.index
+                    )
+                )
+                class_probabilities.append(
+                    pd.Series(np.where(exceeded, col, -np.inf), index=pred_proba.index)
+                )
 
-    # `thresholds` has length N-1 compared to precision and recall
-    recall = recall[1:]  # Drop first element to align with thresholds
-    precision = precision[1:]
+        if class_predictions:
+            preds_df = pd.concat(class_predictions, axis=1)
+            probs_df = pd.concat(class_probabilities, axis=1)
 
-    valid_indices = [i for i, r in enumerate(recall) if r >= recall_target]
+            def select_class(row_pred, row_prob, row_orig):
+                exceeded = row_pred >= 0
+                if exceeded.any():
+                    return row_pred.iloc[row_prob.argmax()]
+                return row_orig.idxmax()
 
-    if not valid_indices:
-        logger.warning(f"Could not find a threshold with recall >= {recall_target}")
-        return None, None, None
+            pred = pd.Series(
+                [
+                    select_class(
+                        preds_df.loc[idx], probs_df.loc[idx], pred_proba.loc[idx]
+                    )
+                    for idx in pred_proba.index
+                ],
+                index=pred_proba.index,
+                name="PRED",
+            )
+        else:
+            # fallback: take max probability if no thresholds apply
+            pred = pred_proba.idxmax(axis=1).rename("PRED")
 
-    best_idx = valid_indices[-1]  # Highest threshold with recall >= target
+    # Case 2: Global scalar threshold (e.g., 0.5 for binary)
+    else:
+        if len(classes) == 2:
+            # Binary classification: threshold on positive class
+            pos_class = classes[1]
+            pred = (pred_proba[pos_class] >= threshold).astype(int).rename("PRED")
+        else:
+            # Multiclass: default to max probability
+            pred = pred_proba.idxmax(axis=1).rename("PRED")
 
-    return thresholds[best_idx], precision[best_idx], recall[best_idx]
+    return pd.concat([pred, pred_proba], axis=1)
 
 
 def plot_threshold(prediction, threshold, precision, recall):
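A minimal end-to-end sketch of the two new helpers, assuming `find_best_threshold` and `apply_thresholds` are imported from this module (the import path is not shown in the diff); binary case with integer class-probability columns:

    import pandas as pd

    prediction = pd.DataFrame(
        {
            "TARGET": [0, 0, 1, 1, 1],
            0: [0.9, 0.6, 0.4, 0.2, 0.3],
            1: [0.1, 0.4, 0.6, 0.8, 0.7],
        }
    )

    # Per-class threshold report, keyed by class label as a string
    best = find_best_threshold(prediction, metric="precision", target_value=0.80)
    # -> {'1': {'threshold': ..., 'precision': ..., 'recall': ..., 'f1': ...}}

    # Scalar-threshold path: returns a "PRED" column plus the original probabilities
    pred_df = apply_thresholds(prediction[[0, 1]], 0.5, classes=[0, 1])
    print(pred_df["PRED"].tolist())  # [0, 0, 1, 1, 1]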
@@ -1629,7 +1636,7 @@ def get_pred_distribution(training_target_dir: str, model_name="linear"):
     Look at prediction distributions
     """
     prediction = pd.read_csv(
-        f"{training_target_dir}/{model_name}/pred_val.csv",
+        f"{training_target_dir}/{model_name}/prediction.csv",
         index_col="ID",
     )
     prediction.describe()