lecrapaud 0.19.0__py3-none-any.whl → 0.22.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lecrapaud/__init__.py +22 -1
- lecrapaud/{api.py → base.py} +331 -241
- lecrapaud/config.py +15 -3
- lecrapaud/db/alembic/versions/2025_10_25_0635-07e303521594_add_unique_constraint_to_score.py +39 -0
- lecrapaud/db/alembic/versions/2025_10_26_1727-033e0f7eca4f_merge_score_and_model_trainings_into_.py +264 -0
- lecrapaud/db/alembic/versions/2025_10_28_2006-0a8fb7826e9b_add_number_of_targets_and_remove_other_.py +75 -0
- lecrapaud/db/models/__init__.py +2 -4
- lecrapaud/db/models/base.py +116 -65
- lecrapaud/db/models/experiment.py +195 -182
- lecrapaud/db/models/feature_selection.py +0 -3
- lecrapaud/db/models/feature_selection_rank.py +0 -18
- lecrapaud/db/models/model_selection.py +2 -2
- lecrapaud/db/models/{score.py → model_selection_score.py} +29 -12
- lecrapaud/db/session.py +4 -0
- lecrapaud/experiment.py +44 -17
- lecrapaud/feature_engineering.py +45 -674
- lecrapaud/feature_preprocessing.py +1202 -0
- lecrapaud/feature_selection.py +145 -332
- lecrapaud/integrations/sentry_integration.py +46 -0
- lecrapaud/misc/tabpfn_tests.ipynb +2 -2
- lecrapaud/mixins.py +247 -0
- lecrapaud/model_preprocessing.py +295 -0
- lecrapaud/model_selection.py +612 -242
- lecrapaud/pipeline.py +548 -0
- lecrapaud/search_space.py +2 -1
- lecrapaud/utils.py +36 -3
- lecrapaud-0.22.6.dist-info/METADATA +423 -0
- lecrapaud-0.22.6.dist-info/RECORD +51 -0
- {lecrapaud-0.19.0.dist-info → lecrapaud-0.22.6.dist-info}/WHEEL +1 -1
- {lecrapaud-0.19.0.dist-info → lecrapaud-0.22.6.dist-info/licenses}/LICENSE +1 -1
- lecrapaud/db/models/model_training.py +0 -64
- lecrapaud/jobs/__init__.py +0 -13
- lecrapaud/jobs/config.py +0 -17
- lecrapaud/jobs/scheduler.py +0 -30
- lecrapaud/jobs/tasks.py +0 -17
- lecrapaud-0.19.0.dist-info/METADATA +0 -249
- lecrapaud-0.19.0.dist-info/RECORD +0 -48
lecrapaud/config.py
CHANGED
|
@@ -4,8 +4,6 @@ from dotenv import load_dotenv
|
|
|
4
4
|
# Pull in variables from a local .env file; values already present in the
# process environment win because override is disabled.
load_dotenv(override=False)

# Runtime environment name (unset -> None) and log verbosity (default "INFO").
PYTHON_ENV = os.environ.get("PYTHON_ENV")
LOGGING_LEVEL = os.environ.get("LOGGING_LEVEL", "INFO")
|
|
10
8
|
|
|
11
9
|
DB_USER = (
|
|
@@ -32,5 +30,19 @@ DB_URI: str = (
|
|
|
32
30
|
)
|
|
33
31
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
LECRAPAUD_LOGFILE = os.getenv("LECRAPAUD_LOGFILE")
LECRAPAUD_TABLE_PREFIX = os.getenv("LECRAPAUD_TABLE_PREFIX", "lecrapaud")
# Hyper-parameter optimisation backend name, normalised to lower case.
LECRAPAUD_OPTIMIZATION_BACKEND = os.getenv(
    "LECRAPAUD_OPTIMIZATION_BACKEND", "hyperopt"
).lower()

SENTRY_DSN = os.getenv("SENTRY_DSN")


def _env_sample_rate(name: str, default: float = 0.0) -> float:
    """Read a sampling rate in ``[0, 1]`` from environment variable *name*.

    Returns *default* when the variable is unset, is not a valid float, or is
    outside ``[0, 1]`` (this includes ``nan`` and ``inf``). A misconfigured rate
    therefore disables sampling instead of propagating an invalid value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        rate = float(raw)
    except ValueError:
        return default
    # The chained comparison is False for NaN, so NaN is rejected here too.
    if not 0.0 <= rate <= 1.0:
        return default
    return rate


SENTRY_TRACES_SAMPLE_RATE = _env_sample_rate("SENTRY_TRACES_SAMPLE_RATE")
SENTRY_PROFILES_SAMPLE_RATE = _env_sample_rate("SENTRY_PROFILES_SAMPLE_RATE")
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""add unique constraint to score
|
|
2
|
+
|
|
3
|
+
Revision ID: 07e303521594
|
|
4
|
+
Revises: 8b11c1ba982e
|
|
5
|
+
Create Date: 2025-10-25 06:35:57.950929
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Sequence, Union
|
|
10
|
+
|
|
11
|
+
from alembic import op
|
|
12
|
+
import sqlalchemy as sa
|
|
13
|
+
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
14
|
+
|
|
15
|
+
# Alembic revision identifiers.
revision: str = "07e303521594"
down_revision: Union[str, None] = "8b11c1ba982e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

_SCORES_TABLE = f"{LECRAPAUD_TABLE_PREFIX}_scores"
_CONSTRAINT_NAME = "unique_score_per_model_training"


def upgrade() -> None:
    """Allow at most one score row per model training.

    NOTE(review): adding the constraint fails if duplicate ``model_training_id``
    values already exist; verify the data before upgrading.
    """
    op.create_unique_constraint(
        _CONSTRAINT_NAME,
        _SCORES_TABLE,
        ["model_training_id"],
    )


def downgrade() -> None:
    """Remove the one-score-per-model-training unique constraint."""
    op.drop_constraint(_CONSTRAINT_NAME, _SCORES_TABLE, type_="unique")
|
lecrapaud/db/alembic/versions/2025_10_26_1727-033e0f7eca4f_merge_score_and_model_trainings_into_.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""merge score and model_trainings into model_selection_scores
|
|
2
|
+
|
|
3
|
+
Revision ID: 033e0f7eca4f
|
|
4
|
+
Revises: 07e303521594
|
|
5
|
+
Create Date: 2025-10-26 17:27:30.400473
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Sequence, Union
|
|
10
|
+
|
|
11
|
+
from alembic import op
|
|
12
|
+
import sqlalchemy as sa
|
|
13
|
+
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
14
|
+
|
|
15
|
+
# revision identifiers, used by Alembic.
revision: str = "033e0f7eca4f"
down_revision: Union[str, None] = "07e303521594"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Merge ``<prefix>_model_trainings`` and ``<prefix>_scores`` into
    ``<prefix>_model_selection_scores``.

    Each training row is copied together with its score row, if any. Revision
    07e303521594 guarantees at most one score per training. When both rows carry
    a training time, the training's ``training_time`` wins.

    The step tolerates a partially applied schema:

    * the merged table is created only when it does not exist yet;
    * the copy and the drop of the legacy tables run only when BOTH legacy
      tables are present, so no legacy data is dropped without being copied.

    NOTE(review): if the merged table already holds rows, the copy may violate
    ``uq_model_selection_score_composite``.
    """
    from sqlalchemy import inspect

    new_table = f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores"
    trainings_table = f"{LECRAPAUD_TABLE_PREFIX}_model_trainings"
    scores_table = f"{LECRAPAUD_TABLE_PREFIX}_scores"

    existing_tables = set(inspect(op.get_bind()).get_table_names())

    if new_table not in existing_tables:
        op.create_table(
            new_table,
            sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
            sa.Column(
                "created_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.Column(
                "updated_at",
                sa.TIMESTAMP(timezone=True),
                server_default=sa.text("now()"),
                nullable=False,
            ),
            sa.Column("best_params", sa.JSON(), nullable=True),
            sa.Column("model_path", sa.String(length=255), nullable=True),
            sa.Column("training_time", sa.Integer(), nullable=True),
            sa.Column("model_id", sa.BigInteger(), nullable=False),
            sa.Column("model_selection_id", sa.BigInteger(), nullable=False),
            sa.Column("eval_data_std", sa.Float(), nullable=True),
            sa.Column("rmse", sa.Float(), nullable=True),
            sa.Column("rmse_std_ratio", sa.Float(), nullable=True),
            sa.Column("mae", sa.Float(), nullable=True),
            sa.Column("mape", sa.Float(), nullable=True),
            sa.Column("mam", sa.Float(), nullable=True),
            sa.Column("mad", sa.Float(), nullable=True),
            sa.Column("mae_mam_ratio", sa.Float(), nullable=True),
            sa.Column("mae_mad_ratio", sa.Float(), nullable=True),
            sa.Column("r2", sa.Float(), nullable=True),
            sa.Column("logloss", sa.Float(), nullable=True),
            sa.Column("accuracy", sa.Float(), nullable=True),
            sa.Column("precision", sa.Float(), nullable=True),
            sa.Column("recall", sa.Float(), nullable=True),
            sa.Column("f1", sa.Float(), nullable=True),
            sa.Column("roc_auc", sa.Float(), nullable=True),
            sa.Column("avg_precision", sa.Float(), nullable=True),
            sa.Column("thresholds", sa.JSON(), nullable=True),
            sa.Column("precision_at_threshold", sa.Float(), nullable=True),
            sa.Column("recall_at_threshold", sa.Float(), nullable=True),
            sa.Column("f1_at_threshold", sa.Float(), nullable=True),
            sa.ForeignKeyConstraint(
                ["model_id"],
                [f"{LECRAPAUD_TABLE_PREFIX}_models.id"],
            ),
            sa.ForeignKeyConstraint(
                ["model_selection_id"],
                [f"{LECRAPAUD_TABLE_PREFIX}_model_selections.id"],
                ondelete="CASCADE",
            ),
            sa.PrimaryKeyConstraint("id"),
            sa.UniqueConstraint(
                "model_id",
                "model_selection_id",
                name="uq_model_selection_score_composite",
            ),
        )
        op.create_index(
            op.f("ix_model_selection_scores_id"),
            new_table,
            ["id"],
            unique=False,
        )

    if trainings_table not in existing_tables or scores_table not in existing_tables:
        # Nothing to merge: already migrated, or a schema bootstrapped from the
        # current models. Any lone legacy table is left untouched rather than
        # dropped without its data being copied.
        return

    # Copy every training with its (optional) score. `precision` is a reserved
    # word in MySQL and must be back-quoted.
    op.execute(
        f"""
        INSERT INTO {new_table} (
            created_at, updated_at, best_params, model_path, training_time,
            model_id, model_selection_id,
            eval_data_std, rmse, rmse_std_ratio, mae, mape, mam, mad,
            mae_mam_ratio, mae_mad_ratio, r2, logloss, accuracy, `precision`,
            recall, f1, roc_auc, avg_precision, thresholds,
            precision_at_threshold, recall_at_threshold, f1_at_threshold
        )
        SELECT
            mt.created_at,
            mt.updated_at,
            mt.best_params,
            mt.model_path,
            COALESCE(mt.training_time, s.training_time) as training_time,
            mt.model_id,
            mt.model_selection_id,
            s.eval_data_std, s.rmse, s.rmse_std_ratio, s.mae, s.mape,
            s.mam, s.mad, s.mae_mam_ratio, s.mae_mad_ratio, s.r2,
            s.logloss, s.accuracy, s.`precision`, s.recall, s.f1,
            s.roc_auc, s.avg_precision, s.thresholds,
            s.precision_at_threshold, s.recall_at_threshold, s.f1_at_threshold
        FROM {trainings_table} mt
        LEFT JOIN {scores_table} s ON s.model_training_id = mt.id
        """
    )

    # Scores hold a foreign key to trainings, so drop the scores table first.
    op.drop_table(scores_table)
    op.drop_table(trainings_table)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def downgrade() -> None:
    """Split ``<prefix>_model_selection_scores`` back into
    ``<prefix>_model_trainings`` and ``<prefix>_scores``.

    This is lossy. The legacy ``scores.type`` column has no counterpart in the
    merged table, so every restored score is tagged ``'testset'``.

    Restored training rows reuse the merged row's ``id``, which lets each
    restored score reference its training through ``model_training_id``. A
    score row is recreated for every merged row, even when all metric columns
    are NULL.
    """
    new_table = f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores"
    trainings_table = f"{LECRAPAUD_TABLE_PREFIX}_model_trainings"
    scores_table = f"{LECRAPAUD_TABLE_PREFIX}_scores"

    # Recreate the legacy tables.
    op.create_table(
        trainings_table,
        sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("best_params", sa.JSON(), nullable=True),
        sa.Column("model_path", sa.String(length=255), nullable=True),
        sa.Column("training_time", sa.Integer(), nullable=True),
        sa.Column("model_id", sa.BigInteger(), nullable=False),
        sa.Column("model_selection_id", sa.BigInteger(), nullable=False),
        sa.ForeignKeyConstraint(
            ["model_id"],
            [f"{LECRAPAUD_TABLE_PREFIX}_models.id"],
        ),
        sa.ForeignKeyConstraint(
            ["model_selection_id"],
            [f"{LECRAPAUD_TABLE_PREFIX}_model_selections.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "model_id", "model_selection_id", name="uq_model_training_composite"
        ),
    )
    op.create_index(
        op.f("ix_model_trainings_id"),
        trainings_table,
        ["id"],
        unique=False,
    )

    op.create_table(
        scores_table,
        sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("type", sa.String(length=50), nullable=False),
        sa.Column("training_time", sa.Integer(), nullable=True),
        sa.Column("eval_data_std", sa.Float(), nullable=True),
        sa.Column("rmse", sa.Float(), nullable=True),
        sa.Column("rmse_std_ratio", sa.Float(), nullable=True),
        sa.Column("mae", sa.Float(), nullable=True),
        sa.Column("mape", sa.Float(), nullable=True),
        sa.Column("mam", sa.Float(), nullable=True),
        sa.Column("mad", sa.Float(), nullable=True),
        sa.Column("mae_mam_ratio", sa.Float(), nullable=True),
        sa.Column("mae_mad_ratio", sa.Float(), nullable=True),
        sa.Column("r2", sa.Float(), nullable=True),
        sa.Column("logloss", sa.Float(), nullable=True),
        sa.Column("accuracy", sa.Float(), nullable=True),
        sa.Column("precision", sa.Float(), nullable=True),
        sa.Column("recall", sa.Float(), nullable=True),
        sa.Column("f1", sa.Float(), nullable=True),
        sa.Column("roc_auc", sa.Float(), nullable=True),
        sa.Column("avg_precision", sa.Float(), nullable=True),
        sa.Column("thresholds", sa.JSON(), nullable=True),
        sa.Column("precision_at_threshold", sa.Float(), nullable=True),
        sa.Column("recall_at_threshold", sa.Float(), nullable=True),
        sa.Column("f1_at_threshold", sa.Float(), nullable=True),
        sa.Column("model_training_id", sa.BigInteger(), nullable=False),
        sa.ForeignKeyConstraint(
            ["model_training_id"],
            [f"{LECRAPAUD_TABLE_PREFIX}_model_trainings.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "model_training_id", name="unique_score_per_model_training"
        ),
    )
    op.create_index(
        op.f("ix_scores_id"), scores_table, ["id"], unique=False
    )

    # Restore trainings, keeping ids so the score rows below can reference them.
    op.execute(
        f"""
        INSERT INTO {trainings_table} (
            id, created_at, updated_at, best_params, model_path,
            training_time, model_id, model_selection_id
        )
        SELECT
            id, created_at, updated_at, best_params, model_path,
            training_time, model_id, model_selection_id
        FROM {new_table}
        """
    )

    # `precision` is a reserved word in MySQL: it must be back-quoted in the
    # SELECT list as well as in the column list, otherwise the statement is a
    # syntax error.
    op.execute(
        f"""
        INSERT INTO {scores_table} (
            created_at, updated_at, type, training_time, eval_data_std,
            rmse, rmse_std_ratio, mae, mape, mam, mad, mae_mam_ratio,
            mae_mad_ratio, r2, logloss, accuracy, `precision`, recall,
            f1, roc_auc, avg_precision, thresholds, precision_at_threshold,
            recall_at_threshold, f1_at_threshold, model_training_id
        )
        SELECT
            created_at, updated_at, 'testset', training_time, eval_data_std,
            rmse, rmse_std_ratio, mae, mape, mam, mad, mae_mam_ratio,
            mae_mad_ratio, r2, logloss, accuracy, `precision`, recall,
            f1, roc_auc, avg_precision, thresholds, precision_at_threshold,
            recall_at_threshold, f1_at_threshold, id
        FROM {new_table}
        """
    )

    op.drop_index(
        op.f("ix_model_selection_scores_id"),
        table_name=new_table,
    )
    op.drop_table(new_table)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""add number_of_targets and remove other fields from experiments
|
|
2
|
+
|
|
3
|
+
Revision ID: 0a8fb7826e9b
|
|
4
|
+
Revises: 033e0f7eca4f
|
|
5
|
+
Create Date: 2025-10-28 20:06:54.792631
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Sequence, Union
|
|
10
|
+
|
|
11
|
+
from alembic import op
|
|
12
|
+
import sqlalchemy as sa
|
|
13
|
+
from sqlalchemy.dialects import mysql
|
|
14
|
+
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
15
|
+
|
|
16
|
+
# Alembic revision identifiers.
revision: str = "0a8fb7826e9b"
down_revision: Union[str, None] = "033e0f7eca4f"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add ``number_of_targets`` to experiments and drop obsolete settings.

    The experiment columns ``corr_threshold``, ``max_features``, ``percentile``
    and ``type`` are removed, in that order.
    """
    experiments = f"{LECRAPAUD_TABLE_PREFIX}_experiments"
    scores = f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores"
    index_name = op.f("ix_model_selection_scores_id")

    op.add_column(
        experiments,
        sa.Column("number_of_targets", sa.Integer(), nullable=True),
    )
    for obsolete in ("corr_threshold", "max_features", "percentile", "type"):
        op.drop_column(experiments, obsolete)

    # NOTE(review): the index is dropped and re-created under the SAME name,
    # so the net schema change is nil.
    op.drop_index(index_name, table_name=scores)
    op.create_index(index_name, scores, ["id"], unique=False)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def downgrade() -> None:
    """Revert revision 0a8fb7826e9b.

    The dropped experiment columns are restored and ``number_of_targets`` is
    removed.

    The index is deliberately left alone. upgrade() drops and re-creates
    ``ix_model_selection_scores_id`` under the same name, so there is no index
    change to undo. Two index operations must not be issued here:

    * dropping ``ix_lecrapaud_model_selection_scores_id``, an index upgrade()
      never creates, fails;
    * re-creating ``ix_model_selection_scores_id``, which still exists, fails
      as a duplicate.

    NOTE(review): the restored columns are NOT NULL without a server default;
    MySQL fills existing rows with the type's implicit default (0 / '').
    """
    experiments = f"{LECRAPAUD_TABLE_PREFIX}_experiments"

    op.add_column(
        experiments,
        sa.Column("type", mysql.VARCHAR(length=50), nullable=False),
    )
    op.add_column(
        experiments,
        sa.Column("percentile", mysql.FLOAT(), nullable=False),
    )
    op.add_column(
        experiments,
        sa.Column("max_features", mysql.INTEGER(), autoincrement=False, nullable=False),
    )
    op.add_column(
        experiments,
        sa.Column("corr_threshold", mysql.FLOAT(), nullable=False),
    )
    op.drop_column(experiments, "number_of_targets")
|
lecrapaud/db/models/__init__.py
CHANGED
|
@@ -4,9 +4,8 @@ from lecrapaud.db.models.feature_selection_rank import FeatureSelectionRank
|
|
|
4
4
|
from lecrapaud.db.models.feature_selection import FeatureSelection
|
|
5
5
|
from lecrapaud.db.models.feature import Feature
|
|
6
6
|
from lecrapaud.db.models.model_selection import ModelSelection
|
|
7
|
-
from lecrapaud.db.models.model_training import ModelTraining
|
|
8
7
|
from lecrapaud.db.models.model import Model
|
|
9
|
-
from lecrapaud.db.models.
|
|
8
|
+
from lecrapaud.db.models.model_selection_score import ModelSelectionScore
|
|
10
9
|
from lecrapaud.db.models.target import Target
|
|
11
10
|
|
|
12
11
|
__all__ = [
|
|
@@ -16,8 +15,7 @@ __all__ = [
|
|
|
16
15
|
'FeatureSelection',
|
|
17
16
|
'Feature',
|
|
18
17
|
'ModelSelection',
|
|
19
|
-
'ModelTraining',
|
|
20
18
|
'Model',
|
|
21
|
-
'
|
|
19
|
+
'ModelSelectionScore',
|
|
22
20
|
'Target',
|
|
23
21
|
]
|
lecrapaud/db/models/base.py
CHANGED
|
@@ -10,24 +10,27 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
|
10
10
|
from lecrapaud.db.session import get_db
|
|
11
11
|
from sqlalchemy.ext.declarative import declared_attr
|
|
12
12
|
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
|
13
|
+
from sqlalchemy import UniqueConstraint
|
|
14
|
+
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
|
|
13
15
|
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def with_db(func):
    """Inject a database session into the decorated function.

    A caller that passes a non-``None`` ``db`` keyword argument keeps control of
    its own session. Otherwise a session is opened with ``get_db()`` for the
    duration of the call and passed in as ``db=``.

    Only the ``db`` *keyword* is inspected; a session passed positionally is not
    detected.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        provided = kwargs.get("db")
        if provided is not None:
            return func(*args, **kwargs)

        with get_db() as session:
            kwargs["db"] = session
            return func(*args, **kwargs)

    return wrapper
|
|
32
35
|
|
|
33
36
|
|
|
@@ -106,51 +109,6 @@ class Base(DeclarativeBase):
|
|
|
106
109
|
]
|
|
107
110
|
return results
|
|
108
111
|
|
|
109
|
-
@classmethod
|
|
110
|
-
@with_db
|
|
111
|
-
def upsert_bulk(cls, db=None, match_fields: list[str] = None, **kwargs):
|
|
112
|
-
"""
|
|
113
|
-
Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.
|
|
114
|
-
|
|
115
|
-
Args:
|
|
116
|
-
db (Session): SQLAlchemy DB session
|
|
117
|
-
match_fields (list[str]): Fields to match on for deduplication
|
|
118
|
-
**kwargs: Column-wise keyword arguments (field_name=[...])
|
|
119
|
-
"""
|
|
120
|
-
# Ensure all provided fields have values of equal length
|
|
121
|
-
value_lengths = [len(v) for v in kwargs.values()]
|
|
122
|
-
if not value_lengths or len(set(value_lengths)) != 1:
|
|
123
|
-
raise ValueError(
|
|
124
|
-
"All field values must be non-empty lists of the same length."
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
# Convert column-wise kwargs to row-wise list of dicts
|
|
128
|
-
items = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]
|
|
129
|
-
if not items:
|
|
130
|
-
return
|
|
131
|
-
|
|
132
|
-
stmt = mysql_insert(cls.__table__).values(items)
|
|
133
|
-
|
|
134
|
-
# Default to primary keys if match_fields not provided
|
|
135
|
-
if not match_fields:
|
|
136
|
-
match_fields = [col.name for col in cls.__table__.primary_key.columns]
|
|
137
|
-
|
|
138
|
-
# Ensure all columns to be updated are in the insert
|
|
139
|
-
update_dict = {
|
|
140
|
-
c.name: stmt.inserted[c.name]
|
|
141
|
-
for c in cls.__table__.columns
|
|
142
|
-
if c.name not in match_fields and c.name in items[0]
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
if not update_dict:
|
|
146
|
-
# Avoid triggering ON DUPLICATE KEY UPDATE with empty dict
|
|
147
|
-
db.execute(stmt.prefix_with("IGNORE"))
|
|
148
|
-
else:
|
|
149
|
-
upsert_stmt = stmt.on_duplicate_key_update(**update_dict)
|
|
150
|
-
db.execute(upsert_stmt)
|
|
151
|
-
|
|
152
|
-
db.commit()
|
|
153
|
-
|
|
154
112
|
@classmethod
|
|
155
113
|
@with_db
|
|
156
114
|
def filter(cls, db=None, **kwargs):
|
|
@@ -198,33 +156,126 @@ class Base(DeclarativeBase):
|
|
|
198
156
|
|
|
199
157
|
@classmethod
@with_db
def upsert(cls, db=None, **kwargs):
    """
    Upsert an instance of the model using MySQL's ON DUPLICATE KEY UPDATE.

    If ``id`` is given and a row with that primary key exists, that row is
    updated through the ORM. Otherwise an ``INSERT ... ON DUPLICATE KEY UPDATE``
    is issued and the affected row is loaded back.

    :param kwargs: all fields for creation or update
    :return: the inserted or updated instance, or ``None`` if the affected row
        could not be located afterwards
    """
    from sqlalchemy import func

    table = cls.__table__

    # If an ID is provided and the row exists, fall back to a standard update.
    instance_id = kwargs.get("id")
    if instance_id is not None:
        instance = db.get(cls, instance_id)
        if instance is not None:
            for key, value in kwargs.items():
                if key != "id":
                    setattr(instance, key, value)
            db.commit()
            db.refresh(instance)
            return instance

    stmt = mysql_insert(table).values(**kwargs)
    update_values = {k: v for k, v in kwargs.items() if k != "id"}
    if "id" in table.c:
        # For an updated row MySQL's insert id is not meaningful unless the
        # update clause assigns id = LAST_INSERT_ID(id). This makes lastrowid
        # point at the affected row on both the insert and update paths.
        update_values["id"] = func.last_insert_id(table.c.id)
    if update_values:
        stmt = stmt.on_duplicate_key_update(**update_values)
    else:
        # SQLAlchemy rejects an empty ON DUPLICATE KEY UPDATE mapping, and with
        # nothing to update a duplicate can simply be skipped.
        stmt = stmt.prefix_with("IGNORE")

    result = db.execute(stmt)
    db.commit()

    instance = None
    if result.lastrowid:
        instance = db.get(cls, result.lastrowid)
    if instance is None:
        instance = cls._find_upserted(db, kwargs)

    if instance is not None:
        db.refresh(instance)
    return instance


@classmethod
def _find_upserted(cls, db, values):
    """Locate the row touched by an upsert when its id was not reported.

    Strategies are tried in order:

    1. Every ``UniqueConstraint`` whose columns are all present in ``values``.
    2. Every ``unique=True`` column present in ``values``.
    3. An equality match on all non-``None`` values except ``id``.

    :param db: SQLAlchemy session used for the lookups
    :param values: the field values passed to the upsert
    :return: the matching instance, or ``None`` if nothing matches or there is
        nothing to match on
    """
    table = cls.__table__

    for constraint in table.constraints:
        if not isinstance(constraint, UniqueConstraint):
            continue
        names = [col.name for col in constraint.columns]
        if names and all(name in values for name in names):
            instance = (
                db.query(cls)
                .filter(*[getattr(cls, name) == values[name] for name in names])
                .first()
            )
            if instance is not None:
                return instance

    for col in table.columns:
        if col.unique and col.name in values:
            instance = (
                db.query(cls)
                .filter(getattr(cls, col.name) == values[col.name])
                .first()
            )
            if instance is not None:
                return instance

    criteria = {k: v for k, v in values.items() if v is not None and k != "id"}
    if not criteria:
        # filter_by() with no criteria would return an arbitrary row.
        return None
    return db.query(cls).filter_by(**criteria).first()
|
|
227
239
|
|
|
240
|
+
@classmethod
@with_db
def bulk_upsert(cls, rows: list[dict] = None, db=None, **kwargs):
    """
    Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.

    Rows are sent in batches of 200. On a duplicate key, every supplied column
    except ``id`` is overwritten with the incoming value. When the rows supply
    no column other than ``id``, duplicates are skipped with ``INSERT IGNORE``
    instead. All batches are committed together at the end.

    The column set of each batch's ``ON DUPLICATE KEY UPDATE`` clause is taken
    from the first row of that batch. NOTE(review): rows are presumably
    expected to share the same keys; confirm with callers.

    Args:
        rows (list[dict]): List of dictionaries representing rows to upsert
        db (Session): SQLAlchemy DB session
        **kwargs: Column-wise keyword arguments (field_name=[...]) for backwards
            compatibility. Used only when ``rows`` is None.

    Returns:
        int: Sum of the driver-reported ``rowcount`` over all batches. This is
        MySQL's ON DUPLICATE KEY UPDATE affected-rows figure, not the number of
        input rows (an updated row typically counts twice).

    Raises:
        ValueError: If column-wise kwargs have lists of differing lengths.
    """
    # Handle both new format (rows) and legacy format (kwargs).
    if rows is None and kwargs:
        # Legacy format: convert column-wise kwargs to a row-wise list of dicts.
        value_lengths = [len(v) for v in kwargs.values()]
        if not value_lengths or len(set(value_lengths)) != 1:
            raise ValueError(
                "All field values must be non-empty lists of the same length."
            )
        rows = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]

    if not rows:
        return 0

    BATCH_SIZE = 200
    total_affected = 0

    for start in range(0, len(rows), BATCH_SIZE):
        batch = rows[start : start + BATCH_SIZE]
        stmt = mysql_insert(cls.__table__).values(batch)
        update_columns = {
            key: stmt.inserted[key] for key in batch[0] if key != "id"
        }
        if update_columns:
            stmt = stmt.on_duplicate_key_update(**update_columns)
        else:
            # Only the primary key was supplied. SQLAlchemy rejects an empty
            # update mapping, so skip duplicates instead.
            stmt = stmt.prefix_with("IGNORE")
        result = db.execute(stmt)
        total_affected += result.rowcount

    db.commit()
    return total_affected
|
|
278
|
+
|
|
228
279
|
@classmethod
|
|
229
280
|
@with_db
|
|
230
281
|
def delete(cls, id: int, db=None):
|