lecrapaud-0.11.2-py3-none-any.whl → lecrapaud-0.11.4-py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lecrapaud might be problematic. Click here for more details.
- lecrapaud/api.py +15 -2
- lecrapaud/db/alembic/env.py +1 -1
- lecrapaud/db/alembic.ini +8 -5
- lecrapaud/model_selection.py +116 -31
- lecrapaud/utils.py +1 -1
- {lecrapaud-0.11.2.dist-info → lecrapaud-0.11.4.dist-info}/METADATA +1 -1
- {lecrapaud-0.11.2.dist-info → lecrapaud-0.11.4.dist-info}/RECORD +9 -9
- {lecrapaud-0.11.2.dist-info → lecrapaud-0.11.4.dist-info}/LICENSE +0 -0
- {lecrapaud-0.11.2.dist-info → lecrapaud-0.11.4.dist-info}/WHEEL +0 -0
lecrapaud/api.py
CHANGED
|
@@ -27,11 +27,12 @@ Basic Usage:
|
|
|
27
27
|
|
|
28
28
|
import joblib
|
|
29
29
|
import pandas as pd
|
|
30
|
+
import ast
|
|
31
|
+
import os
|
|
30
32
|
import logging
|
|
31
33
|
import seaborn as sns
|
|
32
34
|
import numpy as np
|
|
33
35
|
import matplotlib.pyplot as plt
|
|
34
|
-
from lecrapaud.utils import logger
|
|
35
36
|
from lecrapaud.db.session import init_db
|
|
36
37
|
from lecrapaud.feature_selection import FeatureSelectionEngine, PreprocessModel
|
|
37
38
|
from lecrapaud.model_selection import (
|
|
@@ -46,6 +47,8 @@ from lecrapaud.feature_engineering import FeatureEngineeringEngine, PreprocessFe
|
|
|
46
47
|
from lecrapaud.experiment import create_experiment
|
|
47
48
|
from lecrapaud.db import Experiment
|
|
48
49
|
from lecrapaud.search_space import normalize_models_idx
|
|
50
|
+
from lecrapaud.utils import logger
|
|
51
|
+
from lecrapaud.directories import tmp_dir
|
|
49
52
|
|
|
50
53
|
|
|
51
54
|
class LeCrapaud:
|
|
@@ -108,6 +111,12 @@ class ExperimentEngine:
|
|
|
108
111
|
if id:
|
|
109
112
|
self.experiment = Experiment.get(id)
|
|
110
113
|
kwargs.update(self.experiment.context)
|
|
114
|
+
experiment_dir = f"{tmp_dir}/{self.experiment.name}"
|
|
115
|
+
preprocessing_dir = f"{experiment_dir}/preprocessing"
|
|
116
|
+
data_dir = f"{experiment_dir}/data"
|
|
117
|
+
os.makedirs(experiment_dir, exist_ok=True)
|
|
118
|
+
os.makedirs(preprocessing_dir, exist_ok=True)
|
|
119
|
+
os.makedirs(data_dir, exist_ok=True)
|
|
111
120
|
else:
|
|
112
121
|
if data is None:
|
|
113
122
|
raise ValueError("Either id or data must be provided")
|
|
@@ -344,9 +353,13 @@ class ExperimentEngine:
|
|
|
344
353
|
return pd.read_csv(f"{self.experiment.path}/feature_summary.csv")
|
|
345
354
|
|
|
346
355
|
def get_threshold(self, target_number: int):
|
|
347
|
-
|
|
356
|
+
thresholds = joblib.load(
|
|
348
357
|
f"{self.experiment.path}/TARGET_{target_number}/thresholds.pkl"
|
|
349
358
|
)
|
|
359
|
+
if isinstance(thresholds, str):
|
|
360
|
+
thresholds = ast.literal_eval(thresholds)
|
|
361
|
+
|
|
362
|
+
return thresholds
|
|
350
363
|
|
|
351
364
|
def load_model(self, target_number: int, model_name: str = None):
|
|
352
365
|
|
lecrapaud/db/alembic/env.py
CHANGED
|
@@ -15,7 +15,7 @@ config.set_main_option("sqlalchemy.url", DATABASE_URL)
|
|
|
15
15
|
# Interpret the config file for Python logging.
|
|
16
16
|
# This line sets up loggers basically.
|
|
17
17
|
if config.config_file_name is not None:
|
|
18
|
-
fileConfig(config.config_file_name)
|
|
18
|
+
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
|
19
19
|
|
|
20
20
|
# add your model's MetaData object here
|
|
21
21
|
# for 'autogenerate' support
|
lecrapaud/db/alembic.ini
CHANGED
|
@@ -84,11 +84,14 @@ sqlalchemy.url = %(DATABASE_URL)s
|
|
|
84
84
|
[loggers]
|
|
85
85
|
keys = root,sqlalchemy,alembic
|
|
86
86
|
|
|
87
|
+
[loggers_root]
|
|
88
|
+
disable_existing_loggers = False
|
|
89
|
+
|
|
87
90
|
[handlers]
|
|
88
91
|
keys = console
|
|
89
92
|
|
|
90
93
|
[formatters]
|
|
91
|
-
keys =
|
|
94
|
+
keys = lecrapaud_format
|
|
92
95
|
|
|
93
96
|
[logger_root]
|
|
94
97
|
level = WARN
|
|
@@ -109,8 +112,8 @@ qualname = alembic
|
|
|
109
112
|
class = StreamHandler
|
|
110
113
|
args = (sys.stderr,)
|
|
111
114
|
level = NOTSET
|
|
112
|
-
formatter =
|
|
115
|
+
formatter = lecrapaud_format
|
|
113
116
|
|
|
114
|
-
[
|
|
115
|
-
format = %(
|
|
116
|
-
datefmt = %H:%M:%S
|
|
117
|
+
[formatter_lecrapaud_format]
|
|
118
|
+
format = %(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s
|
|
119
|
+
datefmt = %Y-%m-%d %H:%M:%S
|
lecrapaud/model_selection.py
CHANGED
|
@@ -1328,46 +1328,131 @@ def load_model(target_dir: str):
|
|
|
1328
1328
|
# plots
|
|
1329
1329
|
def plot_evaluation_for_classification(prediction: dict):
|
|
1330
1330
|
"""
|
|
1331
|
-
|
|
1332
|
-
|
|
1331
|
+
Plot evaluation metrics for classification tasks (both binary and multiclass).
|
|
1332
|
+
|
|
1333
|
+
Args:
|
|
1334
|
+
prediction (pd.DataFrame): Should be a df with:
|
|
1335
|
+
- TARGET: true labels
|
|
1336
|
+
- PRED: predicted labels
|
|
1337
|
+
- For binary: column '1' or 1 for positive class probabilities
|
|
1338
|
+
- For multiclass: columns 2 onwards for class probabilities
|
|
1333
1339
|
"""
|
|
1334
1340
|
y_true = prediction["TARGET"]
|
|
1335
1341
|
y_pred = prediction["PRED"]
|
|
1336
|
-
y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
|
|
1337
1342
|
|
|
1338
1343
|
# Plot confusion matrix
|
|
1339
1344
|
plot_confusion_matrix(y_true, y_pred)
|
|
1340
1345
|
|
|
1341
|
-
#
|
|
1342
|
-
|
|
1343
|
-
|
|
1346
|
+
# Determine if binary or multiclass
|
|
1347
|
+
unique_labels = np.unique(y_true)
|
|
1348
|
+
unique_labels = np.sort(unique_labels)
|
|
1349
|
+
n_classes = len(unique_labels)
|
|
1350
|
+
|
|
1351
|
+
if n_classes <= 2:
|
|
1352
|
+
# Binary classification
|
|
1353
|
+
y_pred_proba = prediction[1] if 1 in prediction.columns else prediction["1"]
|
|
1354
|
+
|
|
1355
|
+
# Compute and plot ROC curve
|
|
1356
|
+
fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
|
|
1357
|
+
roc_auc = auc(fpr, tpr)
|
|
1358
|
+
|
|
1359
|
+
plt.figure(figsize=(8, 8))
|
|
1360
|
+
plt.plot(
|
|
1361
|
+
fpr,
|
|
1362
|
+
tpr,
|
|
1363
|
+
color="darkorange",
|
|
1364
|
+
lw=2,
|
|
1365
|
+
label=f"ROC curve (area = {roc_auc:0.2f})",
|
|
1366
|
+
)
|
|
1367
|
+
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
|
|
1368
|
+
plt.xlim([0.0, 1.0])
|
|
1369
|
+
plt.ylim([0.0, 1.05])
|
|
1370
|
+
plt.xlabel("False Positive Rate")
|
|
1371
|
+
plt.ylabel("True Positive Rate")
|
|
1372
|
+
plt.title("ROC Curve")
|
|
1373
|
+
plt.legend(loc="lower right")
|
|
1374
|
+
plt.show()
|
|
1375
|
+
|
|
1376
|
+
# Compute and plot precision-recall curve
|
|
1377
|
+
precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
|
|
1378
|
+
average_precision = average_precision_score(y_true, y_pred_proba)
|
|
1379
|
+
|
|
1380
|
+
plt.figure(figsize=(8, 8))
|
|
1381
|
+
plt.step(recall, precision, color="b", alpha=0.2, where="post")
|
|
1382
|
+
plt.fill_between(recall, precision, step="post", alpha=0.2, color="b")
|
|
1383
|
+
plt.xlabel("Recall")
|
|
1384
|
+
plt.ylabel("Precision")
|
|
1385
|
+
plt.ylim([0.0, 1.05])
|
|
1386
|
+
plt.xlim([0.0, 1.0])
|
|
1387
|
+
plt.title(f"Precision-Recall Curve: AP={average_precision:0.2f}")
|
|
1388
|
+
plt.show()
|
|
1344
1389
|
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
plt.xlabel("False Positive Rate")
|
|
1353
|
-
plt.ylabel("True Positive Rate")
|
|
1354
|
-
plt.title("ROC Curve")
|
|
1355
|
-
plt.legend(loc="lower right")
|
|
1356
|
-
plt.show()
|
|
1390
|
+
else:
|
|
1391
|
+
# Multiclass classification
|
|
1392
|
+
# Get class probabilities
|
|
1393
|
+
pred_cols = [
|
|
1394
|
+
col for col in prediction.columns if col not in ["ID", "TARGET", "PRED"]
|
|
1395
|
+
]
|
|
1396
|
+
y_pred_proba = prediction[pred_cols].values
|
|
1357
1397
|
|
|
1358
|
-
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1398
|
+
# Compute ROC curve and ROC area for each class
|
|
1399
|
+
fpr = dict()
|
|
1400
|
+
tpr = dict()
|
|
1401
|
+
roc_auc = dict()
|
|
1402
|
+
|
|
1403
|
+
plt.figure(figsize=(10, 8))
|
|
1404
|
+
colors = plt.cm.get_cmap("tab10")(np.linspace(0, 1, n_classes))
|
|
1405
|
+
|
|
1406
|
+
for i, (label, color) in enumerate(zip(unique_labels, colors)):
|
|
1407
|
+
y_true_binary = (y_true == label).astype(int)
|
|
1408
|
+
y_score = y_pred_proba[:, i]
|
|
1409
|
+
|
|
1410
|
+
fpr[i], tpr[i], _ = roc_curve(y_true_binary, y_score)
|
|
1411
|
+
roc_auc[i] = auc(fpr[i], tpr[i])
|
|
1412
|
+
|
|
1413
|
+
plt.plot(
|
|
1414
|
+
fpr[i],
|
|
1415
|
+
tpr[i],
|
|
1416
|
+
color=color,
|
|
1417
|
+
lw=2,
|
|
1418
|
+
label=f"Class {label} (area = {roc_auc[i]:0.2f})",
|
|
1419
|
+
)
|
|
1420
|
+
|
|
1421
|
+
plt.plot([0, 1], [0, 1], "k--", lw=2)
|
|
1422
|
+
plt.xlim([0.0, 1.0])
|
|
1423
|
+
plt.ylim([0.0, 1.05])
|
|
1424
|
+
plt.xlabel("False Positive Rate")
|
|
1425
|
+
plt.ylabel("True Positive Rate")
|
|
1426
|
+
plt.title("Multiclass ROC Curves (One-vs-Rest)")
|
|
1427
|
+
plt.legend(loc="lower right")
|
|
1428
|
+
plt.show()
|
|
1429
|
+
|
|
1430
|
+
# Compute PR curve for each class
|
|
1431
|
+
plt.figure(figsize=(10, 8))
|
|
1432
|
+
|
|
1433
|
+
for i, (label, color) in enumerate(zip(unique_labels, colors)):
|
|
1434
|
+
y_true_binary = (y_true == label).astype(int)
|
|
1435
|
+
y_score = y_pred_proba[:, i]
|
|
1436
|
+
|
|
1437
|
+
precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
|
|
1438
|
+
average_precision = average_precision_score(y_true_binary, y_score)
|
|
1439
|
+
|
|
1440
|
+
plt.step(
|
|
1441
|
+
recall,
|
|
1442
|
+
precision,
|
|
1443
|
+
color=color,
|
|
1444
|
+
alpha=0.8,
|
|
1445
|
+
where="post",
|
|
1446
|
+
label=f"Class {label} (AP = {average_precision:0.2f})",
|
|
1447
|
+
)
|
|
1448
|
+
|
|
1449
|
+
plt.xlabel("Recall")
|
|
1450
|
+
plt.ylabel("Precision")
|
|
1451
|
+
plt.ylim([0.0, 1.05])
|
|
1452
|
+
plt.xlim([0.0, 1.0])
|
|
1453
|
+
plt.title("Multiclass Precision-Recall Curves")
|
|
1454
|
+
plt.legend(loc="lower left")
|
|
1455
|
+
plt.show()
|
|
1371
1456
|
|
|
1372
1457
|
|
|
1373
1458
|
def plot_confusion_matrix(y_true, y_pred):
|
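lecrapaud/utils.py
CHANGED
[Note: the diff bodies for lecrapaud/utils.py (+1 -1) and METADATA (+1 -1) are missing from this capture; the hunks that follow show the changes to the RECORD file.]
{lecrapaud-0.11.2.dist-info → lecrapaud-0.11.4.dist-info}/RECORD
CHANGED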
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
lecrapaud/__init__.py,sha256=oCxbtw_nk8rlOXbXbWo0RRMlsh6w-hTiZ6e5PRG_wp0,28
|
|
2
|
-
lecrapaud/api.py,sha256=
|
|
2
|
+
lecrapaud/api.py,sha256=nh1dRcqDpEnyOMjvayUNg_DR1D26gXCQ7hZpsYENqk0,17178
|
|
3
3
|
lecrapaud/config.py,sha256=eYnrktVq457xMIMGcUSilJdNxCsaGP_gRAlzCSwd6Vo,1047
|
|
4
4
|
lecrapaud/db/__init__.py,sha256=82o9fMfaqKXPh2_rt44EzNRVZV1R4LScEnQYvj_TjK0,34
|
|
5
5
|
lecrapaud/db/alembic/README,sha256=MVlc9TYmr57RbhXET6QxgyCcwWP7w-vLkEsirENqiIQ,38
|
|
6
|
-
lecrapaud/db/alembic/env.py,sha256=
|
|
6
|
+
lecrapaud/db/alembic/env.py,sha256=0VdxHNIxhPCgUHnx6EwlVZLUMLlbqZ_eV7i0Ho2XqeI,2337
|
|
7
7
|
lecrapaud/db/alembic/script.py.mako,sha256=MEqL-2qATlST9TAOeYgscMn1uy6HUS9NFvDgl93dMj8,635
|
|
8
8
|
lecrapaud/db/alembic/versions/2025_06_23_1748-f089dfb7e3ba_.py,sha256=MNPyqWaQSHNV8zljD1G9f-LzrVz-nOKlgOhHEE0U8Oo,13060
|
|
9
9
|
lecrapaud/db/alembic/versions/2025_06_24_1216-c62251b129ed_.py,sha256=g6aLRV6jAKXkPUEcs9FAeGfsYpe9rMTxfqbNib3U0-U,809
|
|
10
10
|
lecrapaud/db/alembic/versions/2025_06_24_1711-86457e2f333f_.py,sha256=dl6tfvcqErgJ6NKvjve0euu7l0BWyEAKSS-ychsEAl8,1139
|
|
11
11
|
lecrapaud/db/alembic/versions/2025_06_25_1759-72aa496ca65b_.py,sha256=sBgPLvvqI_HmPqQ0Kime1ZL1AHSeuYJHlmFJOnXWeuU,835
|
|
12
|
-
lecrapaud/db/alembic.ini,sha256=
|
|
12
|
+
lecrapaud/db/alembic.ini,sha256=TXrZB4pWVLn2EUg867yp6paA_19vGeirO95mTPA3nbs,3699
|
|
13
13
|
lecrapaud/db/models/__init__.py,sha256=Lhyw9fVLdom0Fc6yIP-ip8FjkU1EwVwjae5q2VM815Q,740
|
|
14
14
|
lecrapaud/db/models/base.py,sha256=CYtof_UjFwX3C7XUifequh_UtLHJ25bU7LCwT501uGE,7508
|
|
15
15
|
lecrapaud/db/models/experiment.py,sha256=IeS-TWPT-4l9xCMIdR2S2O-foXNt3Ru6WmtPMWToK7c,4035
|
|
@@ -35,10 +35,10 @@ lecrapaud/misc/tabpfn_tests.ipynb,sha256=VkgsCUJ30d8jaL2VaWtQAgb8ngHPNtPgnXLs7QQ
|
|
|
35
35
|
lecrapaud/misc/test-gpu-bilstm.ipynb,sha256=4nLuZRJVe2kn6kEmauhRiz5wkWT9AVrYhI9CEk_dYUY,9608
|
|
36
36
|
lecrapaud/misc/test-gpu-resnet.ipynb,sha256=27Vu7nYwujYeh3fOxBNCnKJn3MXNPKZU-U8oDDUbymg,4944
|
|
37
37
|
lecrapaud/misc/test-gpu-transformers.ipynb,sha256=k6MBSs_Um1h4PykvE-LTBcdpbWLbIFST_xl_AFW2jgI,8444
|
|
38
|
-
lecrapaud/model_selection.py,sha256=
|
|
38
|
+
lecrapaud/model_selection.py,sha256=h4WPtGCUeuWIXDJ8L2-i1I7RwrZlnxAresGW5l8bGwE,63195
|
|
39
39
|
lecrapaud/search_space.py,sha256=-JkzuMhaomdwiWi4HvVQY5hiw3-oREemJA16tbwEIp4,34854
|
|
40
|
-
lecrapaud/utils.py,sha256=
|
|
41
|
-
lecrapaud-0.11.
|
|
42
|
-
lecrapaud-0.11.
|
|
43
|
-
lecrapaud-0.11.
|
|
44
|
-
lecrapaud-0.11.
|
|
40
|
+
lecrapaud/utils.py,sha256=JdBB1NvbNIx4y0Una-kSZdo1_ZEocc5hwyYFIZKHmGg,8305
|
|
41
|
+
lecrapaud-0.11.4.dist-info/LICENSE,sha256=MImCryu0AnqhJE_uAZD-PIDKXDKb8sT7v0i1NOYeHTM,11350
|
|
42
|
+
lecrapaud-0.11.4.dist-info/METADATA,sha256=3O_6bcQaCfragLXVyAsZWVpP9lvxEziUo9DqWzRE1r4,11017
|
|
43
|
+
lecrapaud-0.11.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
44
|
+
lecrapaud-0.11.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|