lecrapaud-0.11.1-py3-none-any.whl → lecrapaud-0.11.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lecrapaud might be problematic. Click here for more details.

lecrapaud/api.py CHANGED
@@ -27,11 +27,11 @@ Basic Usage:
27
27
 
28
28
  import joblib
29
29
  import pandas as pd
30
+ import ast
30
31
  import logging
31
32
  import seaborn as sns
32
33
  import numpy as np
33
34
  import matplotlib.pyplot as plt
34
- from lecrapaud.utils import logger
35
35
  from lecrapaud.db.session import init_db
36
36
  from lecrapaud.feature_selection import FeatureSelectionEngine, PreprocessModel
37
37
  from lecrapaud.model_selection import (
@@ -46,6 +46,7 @@ from lecrapaud.feature_engineering import FeatureEngineeringEngine, PreprocessFe
46
46
  from lecrapaud.experiment import create_experiment
47
47
  from lecrapaud.db import Experiment
48
48
  from lecrapaud.search_space import normalize_models_idx
49
+ from lecrapaud.utils import logger
49
50
 
50
51
 
51
52
  class LeCrapaud:
@@ -85,7 +86,7 @@ class LeCrapaud:
85
86
  """
86
87
  return ExperimentEngine(id=id, **kwargs)
87
88
 
88
- def list_experiments(self, limit=1000) -> list[ExperimentEngine]:
89
+ def list_experiments(self, limit=1000) -> list["ExperimentEngine"]:
89
90
  """List all experiments in the database."""
90
91
  return [ExperimentEngine(id=exp.id) for exp in Experiment.get_all(limit=limit)]
91
92
 
@@ -344,9 +345,13 @@ class ExperimentEngine:
344
345
  return pd.read_csv(f"{self.experiment.path}/feature_summary.csv")
345
346
 
346
347
  def get_threshold(self, target_number: int):
347
- return joblib.load(
348
+ thresholds = joblib.load(
348
349
  f"{self.experiment.path}/TARGET_{target_number}/thresholds.pkl"
349
350
  )
351
+ if isinstance(thresholds, str):
352
+ thresholds = ast.literal_eval(thresholds)
353
+
354
+ return thresholds
350
355
 
351
356
  def load_model(self, target_number: int, model_name: str = None):
352
357
 
@@ -1328,46 +1328,131 @@ def load_model(target_dir: str):
1328
1328
  # plots
1329
1329
  def plot_evaluation_for_classification(prediction: dict):
1330
1330
  """
1331
- Args
1332
- prediction (pd.DataFrame): Should be a df with TARGET, PRED, 0, 1 columns for y_true value (TARGET), y_pred (PRED), and probabilities (for classification only : 0 and 1)
1331
+ Plot evaluation metrics for classification tasks (both binary and multiclass).
1332
+
1333
+ Args:
1334
+ prediction (pd.DataFrame): Should be a df with:
1335
+ - TARGET: true labels
1336
+ - PRED: predicted labels
1337
+ - For binary: column '1' or 1 for positive class probabilities
1338
+ - For multiclass: columns 2 onwards for class probabilities
1333
1339
  """
1334
1340
  y_true = prediction["TARGET"]
1335
1341
  y_pred = prediction["PRED"]
1336
- y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
1337
1342
 
1338
1343
  # Plot confusion matrix
1339
1344
  plot_confusion_matrix(y_true, y_pred)
1340
1345
 
1341
- # Compute ROC curve and ROC area
1342
- fpr, tpr, thresholds = roc_curve(y_true, y_pred_proba)
1343
- roc_auc = auc(fpr, tpr)
1346
+ # Determine if binary or multiclass
1347
+ unique_labels = np.unique(y_true)
1348
+ unique_labels = np.sort(unique_labels)
1349
+ n_classes = len(unique_labels)
1350
+
1351
+ if n_classes <= 2:
1352
+ # Binary classification
1353
+ y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
1354
+
1355
+ # Compute and plot ROC curve
1356
+ fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
1357
+ roc_auc = auc(fpr, tpr)
1358
+
1359
+ plt.figure(figsize=(8, 8))
1360
+ plt.plot(
1361
+ fpr,
1362
+ tpr,
1363
+ color="darkorange",
1364
+ lw=2,
1365
+ label=f"ROC curve (area = {roc_auc:0.2f})",
1366
+ )
1367
+ plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
1368
+ plt.xlim([0.0, 1.0])
1369
+ plt.ylim([0.0, 1.05])
1370
+ plt.xlabel("False Positive Rate")
1371
+ plt.ylabel("True Positive Rate")
1372
+ plt.title("ROC Curve")
1373
+ plt.legend(loc="lower right")
1374
+ plt.show()
1375
+
1376
+ # Compute and plot precision-recall curve
1377
+ precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
1378
+ average_precision = average_precision_score(y_true, y_pred_proba)
1379
+
1380
+ plt.figure(figsize=(8, 8))
1381
+ plt.step(recall, precision, color="b", alpha=0.2, where="post")
1382
+ plt.fill_between(recall, precision, step="post", alpha=0.2, color="b")
1383
+ plt.xlabel("Recall")
1384
+ plt.ylabel("Precision")
1385
+ plt.ylim([0.0, 1.05])
1386
+ plt.xlim([0.0, 1.0])
1387
+ plt.title(f"Precision-Recall Curve: AP={average_precision:0.2f}")
1388
+ plt.show()
1344
1389
 
1345
- plt.figure(figsize=(8, 8))
1346
- plt.plot(
1347
- fpr, tpr, color="darkorange", lw=2, label="ROC curve (area = %0.2f)" % roc_auc
1348
- )
1349
- plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
1350
- plt.xlim([0.0, 1.0])
1351
- plt.ylim([0.0, 1.05])
1352
- plt.xlabel("False Positive Rate")
1353
- plt.ylabel("True Positive Rate")
1354
- plt.title("ROC Curve")
1355
- plt.legend(loc="lower right")
1356
- plt.show()
1390
+ else:
1391
+ # Multiclass classification
1392
+ # Get class probabilities
1393
+ pred_cols = [
1394
+ col for col in prediction.columns if col not in ["ID", "TARGET", "PRED"]
1395
+ ]
1396
+ y_pred_proba = prediction[pred_cols].values
1357
1397
 
1358
- # Compute precision-recall curve
1359
- precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
1360
- average_precision = average_precision_score(y_true, y_pred_proba)
1361
-
1362
- plt.figure(figsize=(8, 8))
1363
- plt.step(recall, precision, color="b", alpha=0.2, where="post")
1364
- plt.fill_between(recall, precision, step="post", alpha=0.2, color="b")
1365
- plt.xlabel("Recall")
1366
- plt.ylabel("Precision")
1367
- plt.ylim([0.0, 1.05])
1368
- plt.xlim([0.0, 1.0])
1369
- plt.title("Precision-Recall Curve: AP={0:0.2f}".format(average_precision))
1370
- plt.show()
1398
+ # Compute ROC curve and ROC area for each class
1399
+ fpr = dict()
1400
+ tpr = dict()
1401
+ roc_auc = dict()
1402
+
1403
+ plt.figure(figsize=(10, 8))
1404
+ colors = plt.cm.get_cmap("tab10")(np.linspace(0, 1, n_classes))
1405
+
1406
+ for i, (label, color) in enumerate(zip(unique_labels, colors)):
1407
+ y_true_binary = (y_true == label).astype(int)
1408
+ y_score = y_pred_proba[:, i]
1409
+
1410
+ fpr[i], tpr[i], _ = roc_curve(y_true_binary, y_score)
1411
+ roc_auc[i] = auc(fpr[i], tpr[i])
1412
+
1413
+ plt.plot(
1414
+ fpr[i],
1415
+ tpr[i],
1416
+ color=color,
1417
+ lw=2,
1418
+ label=f"Class {label} (area = {roc_auc[i]:0.2f})",
1419
+ )
1420
+
1421
+ plt.plot([0, 1], [0, 1], "k--", lw=2)
1422
+ plt.xlim([0.0, 1.0])
1423
+ plt.ylim([0.0, 1.05])
1424
+ plt.xlabel("False Positive Rate")
1425
+ plt.ylabel("True Positive Rate")
1426
+ plt.title("Multiclass ROC Curves (One-vs-Rest)")
1427
+ plt.legend(loc="lower right")
1428
+ plt.show()
1429
+
1430
+ # Compute PR curve for each class
1431
+ plt.figure(figsize=(10, 8))
1432
+
1433
+ for i, (label, color) in enumerate(zip(unique_labels, colors)):
1434
+ y_true_binary = (y_true == label).astype(int)
1435
+ y_score = y_pred_proba[:, i]
1436
+
1437
+ precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
1438
+ average_precision = average_precision_score(y_true_binary, y_score)
1439
+
1440
+ plt.step(
1441
+ recall,
1442
+ precision,
1443
+ color=color,
1444
+ alpha=0.8,
1445
+ where="post",
1446
+ label=f"Class {label} (AP = {average_precision:0.2f})",
1447
+ )
1448
+
1449
+ plt.xlabel("Recall")
1450
+ plt.ylabel("Precision")
1451
+ plt.ylim([0.0, 1.05])
1452
+ plt.xlim([0.0, 1.0])
1453
+ plt.title("Multiclass Precision-Recall Curves")
1454
+ plt.legend(loc="lower left")
1455
+ plt.show()
1371
1456
 
1372
1457
 
1373
1458
  def plot_confusion_matrix(y_true, y_pred):
lecrapaud/utils.py CHANGED
@@ -9,7 +9,7 @@ from ftfy import fix_text
9
9
  import unicodedata
10
10
  import re
11
11
  import string
12
-
12
+ import sys
13
13
  from lecrapaud.directories import logger_dir
14
14
  from lecrapaud.config import LOGGING_LEVEL, PYTHON_ENV, LECRAPAUD_LOCAL
15
15
 
@@ -57,6 +57,11 @@ def setup_logger():
57
57
  file_handler.setLevel(log_level)
58
58
  logger.addHandler(file_handler)
59
59
 
60
+ stream_handler = logging.StreamHandler(sys.stdout)
61
+ stream_handler.setFormatter(formatter)
62
+ stream_handler.setLevel(log_level)
63
+ logger.addHandler(stream_handler)
64
+
60
65
  _LECRAPAUD_LOGGER_ALREADY_CONFIGURED = True
61
66
  return logger
62
67
 
lecrapaud-0.11.1.dist-info/METADATA → lecrapaud-0.11.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: lecrapaud
3
- Version: 0.11.1
3
+ Version: 0.11.3
4
4
  Summary: Framework for machine and deep learning, with regression, classification and time series analysis
5
5
  License: Apache License
6
6
  Author: Pierre H. Gallet
lecrapaud-0.11.1.dist-info/RECORD → lecrapaud-0.11.3.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
1
1
  lecrapaud/__init__.py,sha256=oCxbtw_nk8rlOXbXbWo0RRMlsh6w-hTiZ6e5PRG_wp0,28
2
- lecrapaud/api.py,sha256=wrMc3TaP5qCzGvmN0QsYKxUt2ZPzK3z4nmnetQo23io,16645
2
+ lecrapaud/api.py,sha256=WbaqDW0gybW3BGgCIGJJ929HKyhXFYswjB-3KdHuVNE,16785
3
3
  lecrapaud/config.py,sha256=eYnrktVq457xMIMGcUSilJdNxCsaGP_gRAlzCSwd6Vo,1047
4
4
  lecrapaud/db/__init__.py,sha256=82o9fMfaqKXPh2_rt44EzNRVZV1R4LScEnQYvj_TjK0,34
5
5
  lecrapaud/db/alembic/README,sha256=MVlc9TYmr57RbhXET6QxgyCcwWP7w-vLkEsirENqiIQ,38
@@ -35,10 +35,10 @@ lecrapaud/misc/tabpfn_tests.ipynb,sha256=VkgsCUJ30d8jaL2VaWtQAgb8ngHPNtPgnXLs7QQ
35
35
  lecrapaud/misc/test-gpu-bilstm.ipynb,sha256=4nLuZRJVe2kn6kEmauhRiz5wkWT9AVrYhI9CEk_dYUY,9608
36
36
  lecrapaud/misc/test-gpu-resnet.ipynb,sha256=27Vu7nYwujYeh3fOxBNCnKJn3MXNPKZU-U8oDDUbymg,4944
37
37
  lecrapaud/misc/test-gpu-transformers.ipynb,sha256=k6MBSs_Um1h4PykvE-LTBcdpbWLbIFST_xl_AFW2jgI,8444
38
- lecrapaud/model_selection.py,sha256=PQGEWVWN-4ZeHCqrmXBpHgq1QZi_1nOOeu5gazXGDLQ,60487
38
+ lecrapaud/model_selection.py,sha256=h4WPtGCUeuWIXDJ8L2-i1I7RwrZlnxAresGW5l8bGwE,63195
39
39
  lecrapaud/search_space.py,sha256=-JkzuMhaomdwiWi4HvVQY5hiw3-oREemJA16tbwEIp4,34854
40
- lecrapaud/utils.py,sha256=MUgDoJ31GOF8WRLn_WLzDbHw7OTKxq_ldnZT6dpxdQo,8295
41
- lecrapaud-0.11.1.dist-info/LICENSE,sha256=MImCryu0AnqhJE_uAZD-PIDKXDKb8sT7v0i1NOYeHTM,11350
42
- lecrapaud-0.11.1.dist-info/METADATA,sha256=YDEIQa4j_87wQqkW4SKzeomDWgUjKCFdEYn55nO41MI,11017
43
- lecrapaud-0.11.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
44
- lecrapaud-0.11.1.dist-info/RECORD,,
40
+ lecrapaud/utils.py,sha256=TMsVCFLyli6ww3eFnn1G8AB5LRyU-neG_MeGl-nsTPA,8481
41
+ lecrapaud-0.11.3.dist-info/LICENSE,sha256=MImCryu0AnqhJE_uAZD-PIDKXDKb8sT7v0i1NOYeHTM,11350
42
+ lecrapaud-0.11.3.dist-info/METADATA,sha256=PXjjkqxmNlvhMKgplcTUy8P9ZPeE_Guk7ela4IyUq-c,11017
43
+ lecrapaud-0.11.3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
44
+ lecrapaud-0.11.3.dist-info/RECORD,,