lecrapaud 0.19.2__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lecrapaud might be problematic. See the release advisory for more details.
- lecrapaud/api.py +3 -0
- lecrapaud/config.py +1 -0
- lecrapaud/db/alembic/versions/2025_10_25_0635-07e303521594_add_unique_constraint_to_score.py +39 -0
- lecrapaud/db/alembic/versions/2025_10_26_1727-033e0f7eca4f_merge_score_and_model_trainings_into_.py +264 -0
- lecrapaud/db/models/__init__.py +2 -4
- lecrapaud/db/models/base.py +103 -65
- lecrapaud/db/models/experiment.py +53 -46
- lecrapaud/db/models/feature_selection.py +0 -3
- lecrapaud/db/models/feature_selection_rank.py +0 -18
- lecrapaud/db/models/model_selection.py +2 -2
- lecrapaud/db/models/{score.py → model_selection_score.py} +29 -12
- lecrapaud/db/session.py +1 -0
- lecrapaud/experiment.py +7 -4
- lecrapaud/feature_engineering.py +6 -9
- lecrapaud/feature_selection.py +0 -1
- lecrapaud/model_selection.py +478 -170
- lecrapaud/search_space.py +2 -1
- lecrapaud/utils.py +22 -2
- {lecrapaud-0.19.2.dist-info → lecrapaud-0.20.0.dist-info}/METADATA +1 -1
- {lecrapaud-0.19.2.dist-info → lecrapaud-0.20.0.dist-info}/RECORD +22 -21
- lecrapaud/db/models/model_training.py +0 -64
- {lecrapaud-0.19.2.dist-info → lecrapaud-0.20.0.dist-info}/WHEEL +0 -0
- {lecrapaud-0.19.2.dist-info → lecrapaud-0.20.0.dist-info}/licenses/LICENSE +0 -0
lecrapaud/api.py
CHANGED
|
@@ -475,6 +475,9 @@ class ExperimentEngine:
|
|
|
475
475
|
# For lightgbm models
|
|
476
476
|
importances = model.feature_importance(importance_type="split")
|
|
477
477
|
importance_type = "Split"
|
|
478
|
+
elif hasattr(model, "get_feature_importance"):
|
|
479
|
+
importances = model.get_feature_importance()
|
|
480
|
+
importance_type = "Feature importance"
|
|
478
481
|
elif hasattr(model, "coef_"):
|
|
479
482
|
# For linear models
|
|
480
483
|
importances = np.abs(model.coef_.flatten())
|
lecrapaud/config.py
CHANGED
|
@@ -34,3 +34,4 @@ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
34
34
|
LECRAPAUD_LOGFILE = os.getenv("LECRAPAUD_LOGFILE")
|
|
35
35
|
LECRAPAUD_LOCAL = os.getenv("LECRAPAUD_LOCAL", False)
|
|
36
36
|
LECRAPAUD_TABLE_PREFIX = os.getenv("LECRAPAUD_TABLE_PREFIX", "lecrapaud")
|
|
37
|
+
LECRAPAUD_OPTIMIZATION_BACKEND = os.getenv("LECRAPAUD_OPTIMIZATION_BACKEND", "ray").lower()
|
|
lecrapaud/db/alembic/versions/2025_10_25_0635-07e303521594_add_unique_constraint_to_score.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""add unique constraint to score
|
|
2
|
+
|
|
3
|
+
Revision ID: 07e303521594
|
|
4
|
+
Revises: 8b11c1ba982e
|
|
5
|
+
Create Date: 2025-10-25 06:35:57.950929
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Sequence, Union
|
|
10
|
+
|
|
11
|
+
from alembic import op
|
|
12
|
+
import sqlalchemy as sa
|
|
13
|
+
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
14
|
+
|
|
15
|
+
# revision identifiers, used by Alembic.
|
|
16
|
+
revision: str = "07e303521594"
|
|
17
|
+
down_revision: Union[str, None] = "8b11c1ba982e"
|
|
18
|
+
branch_labels: Union[str, Sequence[str], None] = None
|
|
19
|
+
depends_on: Union[str, Sequence[str], None] = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def upgrade() -> None:
|
|
23
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
|
24
|
+
op.create_unique_constraint(
|
|
25
|
+
"unique_score_per_model_training",
|
|
26
|
+
f"{LECRAPAUD_TABLE_PREFIX}_scores",
|
|
27
|
+
["model_training_id"],
|
|
28
|
+
)
|
|
29
|
+
# ### end Alembic commands ###
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def downgrade() -> None:
|
|
33
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
|
34
|
+
op.drop_constraint(
|
|
35
|
+
"unique_score_per_model_training",
|
|
36
|
+
f"{LECRAPAUD_TABLE_PREFIX}_scores",
|
|
37
|
+
type_="unique",
|
|
38
|
+
)
|
|
39
|
+
# ### end Alembic commands ###
|
lecrapaud/db/alembic/versions/2025_10_26_1727-033e0f7eca4f_merge_score_and_model_trainings_into_.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""merge score and model_trainings into model_selection_scores
|
|
2
|
+
|
|
3
|
+
Revision ID: 033e0f7eca4f
|
|
4
|
+
Revises: 07e303521594
|
|
5
|
+
Create Date: 2025-10-26 17:27:30.400473
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Sequence, Union
|
|
10
|
+
|
|
11
|
+
from alembic import op
|
|
12
|
+
import sqlalchemy as sa
|
|
13
|
+
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
14
|
+
|
|
15
|
+
# revision identifiers, used by Alembic.
|
|
16
|
+
revision: str = "033e0f7eca4f"
|
|
17
|
+
down_revision: Union[str, None] = "07e303521594"
|
|
18
|
+
branch_labels: Union[str, Sequence[str], None] = None
|
|
19
|
+
depends_on: Union[str, Sequence[str], None] = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def upgrade() -> None:
|
|
23
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
|
24
|
+
# Check if table exists using inspector
|
|
25
|
+
from sqlalchemy import inspect
|
|
26
|
+
inspector = inspect(op.get_bind())
|
|
27
|
+
existing_tables = inspector.get_table_names()
|
|
28
|
+
|
|
29
|
+
if f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores" not in existing_tables:
|
|
30
|
+
op.create_table(
|
|
31
|
+
f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores",
|
|
32
|
+
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
|
33
|
+
sa.Column(
|
|
34
|
+
"created_at",
|
|
35
|
+
sa.TIMESTAMP(timezone=True),
|
|
36
|
+
server_default=sa.text("now()"),
|
|
37
|
+
nullable=False,
|
|
38
|
+
),
|
|
39
|
+
sa.Column(
|
|
40
|
+
"updated_at",
|
|
41
|
+
sa.TIMESTAMP(timezone=True),
|
|
42
|
+
server_default=sa.text("now()"),
|
|
43
|
+
nullable=False,
|
|
44
|
+
),
|
|
45
|
+
sa.Column("best_params", sa.JSON(), nullable=True),
|
|
46
|
+
sa.Column("model_path", sa.String(length=255), nullable=True),
|
|
47
|
+
sa.Column("training_time", sa.Integer(), nullable=True),
|
|
48
|
+
sa.Column("model_id", sa.BigInteger(), nullable=False),
|
|
49
|
+
sa.Column("model_selection_id", sa.BigInteger(), nullable=False),
|
|
50
|
+
sa.Column("eval_data_std", sa.Float(), nullable=True),
|
|
51
|
+
sa.Column("rmse", sa.Float(), nullable=True),
|
|
52
|
+
sa.Column("rmse_std_ratio", sa.Float(), nullable=True),
|
|
53
|
+
sa.Column("mae", sa.Float(), nullable=True),
|
|
54
|
+
sa.Column("mape", sa.Float(), nullable=True),
|
|
55
|
+
sa.Column("mam", sa.Float(), nullable=True),
|
|
56
|
+
sa.Column("mad", sa.Float(), nullable=True),
|
|
57
|
+
sa.Column("mae_mam_ratio", sa.Float(), nullable=True),
|
|
58
|
+
sa.Column("mae_mad_ratio", sa.Float(), nullable=True),
|
|
59
|
+
sa.Column("r2", sa.Float(), nullable=True),
|
|
60
|
+
sa.Column("logloss", sa.Float(), nullable=True),
|
|
61
|
+
sa.Column("accuracy", sa.Float(), nullable=True),
|
|
62
|
+
sa.Column("precision", sa.Float(), nullable=True),
|
|
63
|
+
sa.Column("recall", sa.Float(), nullable=True),
|
|
64
|
+
sa.Column("f1", sa.Float(), nullable=True),
|
|
65
|
+
sa.Column("roc_auc", sa.Float(), nullable=True),
|
|
66
|
+
sa.Column("avg_precision", sa.Float(), nullable=True),
|
|
67
|
+
sa.Column("thresholds", sa.JSON(), nullable=True),
|
|
68
|
+
sa.Column("precision_at_threshold", sa.Float(), nullable=True),
|
|
69
|
+
sa.Column("recall_at_threshold", sa.Float(), nullable=True),
|
|
70
|
+
sa.Column("f1_at_threshold", sa.Float(), nullable=True),
|
|
71
|
+
sa.ForeignKeyConstraint(
|
|
72
|
+
["model_id"],
|
|
73
|
+
[f"{LECRAPAUD_TABLE_PREFIX}_models.id"],
|
|
74
|
+
),
|
|
75
|
+
sa.ForeignKeyConstraint(
|
|
76
|
+
["model_selection_id"],
|
|
77
|
+
[f"{LECRAPAUD_TABLE_PREFIX}_model_selections.id"],
|
|
78
|
+
ondelete="CASCADE",
|
|
79
|
+
),
|
|
80
|
+
sa.PrimaryKeyConstraint("id"),
|
|
81
|
+
sa.UniqueConstraint(
|
|
82
|
+
"model_id",
|
|
83
|
+
"model_selection_id",
|
|
84
|
+
name="uq_model_selection_score_composite",
|
|
85
|
+
),
|
|
86
|
+
)
|
|
87
|
+
op.create_index(
|
|
88
|
+
op.f("ix_model_selection_scores_id"),
|
|
89
|
+
f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores",
|
|
90
|
+
["id"],
|
|
91
|
+
unique=False,
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
# Migrate existing data from model_trainings and scores to model_selection_scores
|
|
95
|
+
op.execute(
|
|
96
|
+
f"""
|
|
97
|
+
INSERT INTO {LECRAPAUD_TABLE_PREFIX}_model_selection_scores (
|
|
98
|
+
created_at, updated_at, best_params, model_path, training_time,
|
|
99
|
+
model_id, model_selection_id,
|
|
100
|
+
eval_data_std, rmse, rmse_std_ratio, mae, mape, mam, mad,
|
|
101
|
+
mae_mam_ratio, mae_mad_ratio, r2, logloss, accuracy, `precision`,
|
|
102
|
+
recall, f1, roc_auc, avg_precision, thresholds,
|
|
103
|
+
precision_at_threshold, recall_at_threshold, f1_at_threshold
|
|
104
|
+
)
|
|
105
|
+
SELECT
|
|
106
|
+
mt.created_at,
|
|
107
|
+
mt.updated_at,
|
|
108
|
+
mt.best_params,
|
|
109
|
+
mt.model_path,
|
|
110
|
+
COALESCE(mt.training_time, s.training_time) as training_time,
|
|
111
|
+
mt.model_id,
|
|
112
|
+
mt.model_selection_id,
|
|
113
|
+
s.eval_data_std, s.rmse, s.rmse_std_ratio, s.mae, s.mape,
|
|
114
|
+
s.mam, s.mad, s.mae_mam_ratio, s.mae_mad_ratio, s.r2,
|
|
115
|
+
s.logloss, s.accuracy, s.`precision`, s.recall, s.f1,
|
|
116
|
+
s.roc_auc, s.avg_precision, s.thresholds,
|
|
117
|
+
s.precision_at_threshold, s.recall_at_threshold, s.f1_at_threshold
|
|
118
|
+
FROM {LECRAPAUD_TABLE_PREFIX}_model_trainings mt
|
|
119
|
+
LEFT JOIN {LECRAPAUD_TABLE_PREFIX}_scores s ON s.model_training_id = mt.id
|
|
120
|
+
"""
|
|
121
|
+
)
|
|
122
|
+
|
|
123
|
+
# Drop the old tables
|
|
124
|
+
op.drop_table(f"{LECRAPAUD_TABLE_PREFIX}_scores")
|
|
125
|
+
op.drop_table(f"{LECRAPAUD_TABLE_PREFIX}_model_trainings")
|
|
126
|
+
# ### end Alembic commands ###
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def downgrade() -> None:
|
|
130
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
|
131
|
+
# Recreate the old tables
|
|
132
|
+
op.create_table(
|
|
133
|
+
f"{LECRAPAUD_TABLE_PREFIX}_model_trainings",
|
|
134
|
+
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
|
135
|
+
sa.Column(
|
|
136
|
+
"created_at",
|
|
137
|
+
sa.TIMESTAMP(timezone=True),
|
|
138
|
+
server_default=sa.text("now()"),
|
|
139
|
+
nullable=False,
|
|
140
|
+
),
|
|
141
|
+
sa.Column(
|
|
142
|
+
"updated_at",
|
|
143
|
+
sa.TIMESTAMP(timezone=True),
|
|
144
|
+
server_default=sa.text("now()"),
|
|
145
|
+
nullable=False,
|
|
146
|
+
),
|
|
147
|
+
sa.Column("best_params", sa.JSON(), nullable=True),
|
|
148
|
+
sa.Column("model_path", sa.String(length=255), nullable=True),
|
|
149
|
+
sa.Column("training_time", sa.Integer(), nullable=True),
|
|
150
|
+
sa.Column("model_id", sa.BigInteger(), nullable=False),
|
|
151
|
+
sa.Column("model_selection_id", sa.BigInteger(), nullable=False),
|
|
152
|
+
sa.ForeignKeyConstraint(
|
|
153
|
+
["model_id"],
|
|
154
|
+
[f"{LECRAPAUD_TABLE_PREFIX}_models.id"],
|
|
155
|
+
),
|
|
156
|
+
sa.ForeignKeyConstraint(
|
|
157
|
+
["model_selection_id"],
|
|
158
|
+
[f"{LECRAPAUD_TABLE_PREFIX}_model_selections.id"],
|
|
159
|
+
ondelete="CASCADE",
|
|
160
|
+
),
|
|
161
|
+
sa.PrimaryKeyConstraint("id"),
|
|
162
|
+
sa.UniqueConstraint(
|
|
163
|
+
"model_id", "model_selection_id", name="uq_model_training_composite"
|
|
164
|
+
),
|
|
165
|
+
)
|
|
166
|
+
op.create_index(
|
|
167
|
+
op.f("ix_model_trainings_id"),
|
|
168
|
+
f"{LECRAPAUD_TABLE_PREFIX}_model_trainings",
|
|
169
|
+
["id"],
|
|
170
|
+
unique=False,
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
op.create_table(
|
|
174
|
+
f"{LECRAPAUD_TABLE_PREFIX}_scores",
|
|
175
|
+
sa.Column("id", sa.BigInteger(), autoincrement=True, nullable=False),
|
|
176
|
+
sa.Column(
|
|
177
|
+
"created_at",
|
|
178
|
+
sa.TIMESTAMP(timezone=True),
|
|
179
|
+
server_default=sa.text("now()"),
|
|
180
|
+
nullable=False,
|
|
181
|
+
),
|
|
182
|
+
sa.Column(
|
|
183
|
+
"updated_at",
|
|
184
|
+
sa.TIMESTAMP(timezone=True),
|
|
185
|
+
server_default=sa.text("now()"),
|
|
186
|
+
nullable=False,
|
|
187
|
+
),
|
|
188
|
+
sa.Column("type", sa.String(length=50), nullable=False),
|
|
189
|
+
sa.Column("training_time", sa.Integer(), nullable=True),
|
|
190
|
+
sa.Column("eval_data_std", sa.Float(), nullable=True),
|
|
191
|
+
sa.Column("rmse", sa.Float(), nullable=True),
|
|
192
|
+
sa.Column("rmse_std_ratio", sa.Float(), nullable=True),
|
|
193
|
+
sa.Column("mae", sa.Float(), nullable=True),
|
|
194
|
+
sa.Column("mape", sa.Float(), nullable=True),
|
|
195
|
+
sa.Column("mam", sa.Float(), nullable=True),
|
|
196
|
+
sa.Column("mad", sa.Float(), nullable=True),
|
|
197
|
+
sa.Column("mae_mam_ratio", sa.Float(), nullable=True),
|
|
198
|
+
sa.Column("mae_mad_ratio", sa.Float(), nullable=True),
|
|
199
|
+
sa.Column("r2", sa.Float(), nullable=True),
|
|
200
|
+
sa.Column("logloss", sa.Float(), nullable=True),
|
|
201
|
+
sa.Column("accuracy", sa.Float(), nullable=True),
|
|
202
|
+
sa.Column("precision", sa.Float(), nullable=True),
|
|
203
|
+
sa.Column("recall", sa.Float(), nullable=True),
|
|
204
|
+
sa.Column("f1", sa.Float(), nullable=True),
|
|
205
|
+
sa.Column("roc_auc", sa.Float(), nullable=True),
|
|
206
|
+
sa.Column("avg_precision", sa.Float(), nullable=True),
|
|
207
|
+
sa.Column("thresholds", sa.JSON(), nullable=True),
|
|
208
|
+
sa.Column("precision_at_threshold", sa.Float(), nullable=True),
|
|
209
|
+
sa.Column("recall_at_threshold", sa.Float(), nullable=True),
|
|
210
|
+
sa.Column("f1_at_threshold", sa.Float(), nullable=True),
|
|
211
|
+
sa.Column("model_training_id", sa.BigInteger(), nullable=False),
|
|
212
|
+
sa.ForeignKeyConstraint(
|
|
213
|
+
["model_training_id"],
|
|
214
|
+
[f"{LECRAPAUD_TABLE_PREFIX}_model_trainings.id"],
|
|
215
|
+
ondelete="CASCADE",
|
|
216
|
+
),
|
|
217
|
+
sa.PrimaryKeyConstraint("id"),
|
|
218
|
+
sa.UniqueConstraint(
|
|
219
|
+
"model_training_id", name="unique_score_per_model_training"
|
|
220
|
+
),
|
|
221
|
+
)
|
|
222
|
+
op.create_index(
|
|
223
|
+
op.f("ix_scores_id"), f"{LECRAPAUD_TABLE_PREFIX}_scores", ["id"], unique=False
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
# Migrate data back (note: we'll lose the type column data, defaulting to 'testset')
|
|
227
|
+
op.execute(
|
|
228
|
+
f"""
|
|
229
|
+
INSERT INTO {LECRAPAUD_TABLE_PREFIX}_model_trainings (
|
|
230
|
+
id, created_at, updated_at, best_params, model_path,
|
|
231
|
+
training_time, model_id, model_selection_id
|
|
232
|
+
)
|
|
233
|
+
SELECT
|
|
234
|
+
id, created_at, updated_at, best_params, model_path,
|
|
235
|
+
training_time, model_id, model_selection_id
|
|
236
|
+
FROM {LECRAPAUD_TABLE_PREFIX}_model_selection_scores
|
|
237
|
+
"""
|
|
238
|
+
)
|
|
239
|
+
|
|
240
|
+
op.execute(
|
|
241
|
+
f"""
|
|
242
|
+
INSERT INTO {LECRAPAUD_TABLE_PREFIX}_scores (
|
|
243
|
+
created_at, updated_at, type, training_time, eval_data_std,
|
|
244
|
+
rmse, rmse_std_ratio, mae, mape, mam, mad, mae_mam_ratio,
|
|
245
|
+
mae_mad_ratio, r2, logloss, accuracy, `precision`, recall,
|
|
246
|
+
f1, roc_auc, avg_precision, thresholds, precision_at_threshold,
|
|
247
|
+
recall_at_threshold, f1_at_threshold, model_training_id
|
|
248
|
+
)
|
|
249
|
+
SELECT
|
|
250
|
+
created_at, updated_at, 'testset', training_time, eval_data_std,
|
|
251
|
+
rmse, rmse_std_ratio, mae, mape, mam, mad, mae_mam_ratio,
|
|
252
|
+
mae_mad_ratio, r2, logloss, accuracy, precision, recall,
|
|
253
|
+
f1, roc_auc, avg_precision, thresholds, precision_at_threshold,
|
|
254
|
+
recall_at_threshold, f1_at_threshold, id
|
|
255
|
+
FROM {LECRAPAUD_TABLE_PREFIX}_model_selection_scores
|
|
256
|
+
"""
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
op.drop_index(
|
|
260
|
+
op.f("ix_model_selection_scores_id"),
|
|
261
|
+
table_name=f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores",
|
|
262
|
+
)
|
|
263
|
+
op.drop_table(f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores")
|
|
264
|
+
# ### end Alembic commands ###
|
lecrapaud/db/models/__init__.py
CHANGED
|
@@ -4,9 +4,8 @@ from lecrapaud.db.models.feature_selection_rank import FeatureSelectionRank
|
|
|
4
4
|
from lecrapaud.db.models.feature_selection import FeatureSelection
|
|
5
5
|
from lecrapaud.db.models.feature import Feature
|
|
6
6
|
from lecrapaud.db.models.model_selection import ModelSelection
|
|
7
|
-
from lecrapaud.db.models.model_training import ModelTraining
|
|
8
7
|
from lecrapaud.db.models.model import Model
|
|
9
|
-
from lecrapaud.db.models.score import Score
|
|
8
|
+
from lecrapaud.db.models.model_selection_score import ModelSelectionScore
|
|
10
9
|
from lecrapaud.db.models.target import Target
|
|
11
10
|
|
|
12
11
|
__all__ = [
|
|
@@ -16,8 +15,7 @@ __all__ = [
|
|
|
16
15
|
'FeatureSelection',
|
|
17
16
|
'Feature',
|
|
18
17
|
'ModelSelection',
|
|
19
|
-
'ModelTraining',
|
|
20
18
|
'Model',
|
|
21
|
-
'Score',
|
|
19
|
+
'ModelSelectionScore',
|
|
22
20
|
'Target',
|
|
23
21
|
]
|
lecrapaud/db/models/base.py
CHANGED
|
@@ -10,24 +10,27 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
|
10
10
|
from lecrapaud.db.session import get_db
|
|
11
11
|
from sqlalchemy.ext.declarative import declared_attr
|
|
12
12
|
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
|
13
|
+
from sqlalchemy import UniqueConstraint
|
|
14
|
+
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
|
|
13
15
|
from lecrapaud.config import LECRAPAUD_TABLE_PREFIX
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def with_db(func):
|
|
17
19
|
"""Decorator to provide a database session to the wrapped function.
|
|
18
|
-
|
|
20
|
+
|
|
19
21
|
If a db parameter is already provided, it will be used. Otherwise,
|
|
20
22
|
a new session will be created and automatically managed.
|
|
21
23
|
"""
|
|
24
|
+
|
|
22
25
|
@wraps(func)
|
|
23
26
|
def wrapper(*args, **kwargs):
|
|
24
27
|
if "db" in kwargs and kwargs["db"] is not None:
|
|
25
28
|
return func(*args, **kwargs)
|
|
26
|
-
|
|
29
|
+
|
|
27
30
|
with get_db() as db:
|
|
28
31
|
kwargs["db"] = db
|
|
29
32
|
return func(*args, **kwargs)
|
|
30
|
-
|
|
33
|
+
|
|
31
34
|
return wrapper
|
|
32
35
|
|
|
33
36
|
|
|
@@ -106,51 +109,6 @@ class Base(DeclarativeBase):
|
|
|
106
109
|
]
|
|
107
110
|
return results
|
|
108
111
|
|
|
109
|
-
@classmethod
|
|
110
|
-
@with_db
|
|
111
|
-
def upsert_bulk(cls, db=None, match_fields: list[str] = None, **kwargs):
|
|
112
|
-
"""
|
|
113
|
-
Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.
|
|
114
|
-
|
|
115
|
-
Args:
|
|
116
|
-
db (Session): SQLAlchemy DB session
|
|
117
|
-
match_fields (list[str]): Fields to match on for deduplication
|
|
118
|
-
**kwargs: Column-wise keyword arguments (field_name=[...])
|
|
119
|
-
"""
|
|
120
|
-
# Ensure all provided fields have values of equal length
|
|
121
|
-
value_lengths = [len(v) for v in kwargs.values()]
|
|
122
|
-
if not value_lengths or len(set(value_lengths)) != 1:
|
|
123
|
-
raise ValueError(
|
|
124
|
-
"All field values must be non-empty lists of the same length."
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
# Convert column-wise kwargs to row-wise list of dicts
|
|
128
|
-
items = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]
|
|
129
|
-
if not items:
|
|
130
|
-
return
|
|
131
|
-
|
|
132
|
-
stmt = mysql_insert(cls.__table__).values(items)
|
|
133
|
-
|
|
134
|
-
# Default to primary keys if match_fields not provided
|
|
135
|
-
if not match_fields:
|
|
136
|
-
match_fields = [col.name for col in cls.__table__.primary_key.columns]
|
|
137
|
-
|
|
138
|
-
# Ensure all columns to be updated are in the insert
|
|
139
|
-
update_dict = {
|
|
140
|
-
c.name: stmt.inserted[c.name]
|
|
141
|
-
for c in cls.__table__.columns
|
|
142
|
-
if c.name not in match_fields and c.name in items[0]
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
if not update_dict:
|
|
146
|
-
# Avoid triggering ON DUPLICATE KEY UPDATE with empty dict
|
|
147
|
-
db.execute(stmt.prefix_with("IGNORE"))
|
|
148
|
-
else:
|
|
149
|
-
upsert_stmt = stmt.on_duplicate_key_update(**update_dict)
|
|
150
|
-
db.execute(upsert_stmt)
|
|
151
|
-
|
|
152
|
-
db.commit()
|
|
153
|
-
|
|
154
112
|
@classmethod
|
|
155
113
|
@with_db
|
|
156
114
|
def filter(cls, db=None, **kwargs):
|
|
@@ -198,33 +156,113 @@ class Base(DeclarativeBase):
|
|
|
198
156
|
|
|
199
157
|
@classmethod
|
|
200
158
|
@with_db
|
|
201
|
-
def upsert(cls,
|
|
159
|
+
def upsert(cls, db=None, **kwargs):
|
|
202
160
|
"""
|
|
203
|
-
Upsert an instance of the model
|
|
161
|
+
Upsert an instance of the model using MySQL's ON DUPLICATE KEY UPDATE.
|
|
204
162
|
|
|
205
|
-
:param match_fields: list of field names to use for matching
|
|
206
163
|
:param kwargs: all fields for creation or update
|
|
207
164
|
"""
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
165
|
+
# Use INSERT ... ON DUPLICATE KEY UPDATE
|
|
166
|
+
stmt = mysql_insert(cls.__table__).values(**kwargs)
|
|
167
|
+
stmt = stmt.on_duplicate_key_update(
|
|
168
|
+
**{k: v for k, v in kwargs.items() if k != "id"}
|
|
169
|
+
)
|
|
213
170
|
|
|
214
|
-
|
|
171
|
+
result = db.execute(stmt)
|
|
172
|
+
db.commit()
|
|
215
173
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
174
|
+
# Get the instance - either the newly inserted or updated one
|
|
175
|
+
# If updated, lastrowid is 0, so we need to query
|
|
176
|
+
if result.lastrowid and result.lastrowid > 0:
|
|
177
|
+
# New insert
|
|
178
|
+
instance = db.get(cls, result.lastrowid)
|
|
220
179
|
else:
|
|
221
|
-
|
|
222
|
-
|
|
180
|
+
# Updated - need to find it using unique constraint fields
|
|
181
|
+
mapper = sqlalchemy_inspect(cls)
|
|
182
|
+
instance = None
|
|
183
|
+
|
|
184
|
+
for constraint in mapper.mapped_table.constraints:
|
|
185
|
+
if isinstance(constraint, UniqueConstraint):
|
|
186
|
+
col_names = [col.name for col in constraint.columns]
|
|
187
|
+
if all(name in kwargs for name in col_names):
|
|
188
|
+
filters = [
|
|
189
|
+
getattr(cls, col_name) == kwargs[col_name]
|
|
190
|
+
for col_name in col_names
|
|
191
|
+
]
|
|
192
|
+
instance = db.query(cls).filter(*filters).first()
|
|
193
|
+
if instance:
|
|
194
|
+
break
|
|
195
|
+
|
|
196
|
+
# Check for single column unique constraints
|
|
197
|
+
if not instance:
|
|
198
|
+
for col in mapper.mapped_table.columns:
|
|
199
|
+
if col.unique and col.name in kwargs:
|
|
200
|
+
instance = (
|
|
201
|
+
db.query(cls)
|
|
202
|
+
.filter(getattr(cls, col.name) == kwargs[col.name])
|
|
203
|
+
.first()
|
|
204
|
+
)
|
|
205
|
+
if instance:
|
|
206
|
+
break
|
|
207
|
+
|
|
208
|
+
# If still not found, try to find by all kwargs (excluding None values)
|
|
209
|
+
if not instance:
|
|
210
|
+
instance = (
|
|
211
|
+
db.query(cls)
|
|
212
|
+
.filter_by(
|
|
213
|
+
**{
|
|
214
|
+
k: v
|
|
215
|
+
for k, v in kwargs.items()
|
|
216
|
+
if v is not None and k != "id"
|
|
217
|
+
}
|
|
218
|
+
)
|
|
219
|
+
.first()
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
if instance:
|
|
223
|
+
db.refresh(instance)
|
|
223
224
|
|
|
224
|
-
db.commit()
|
|
225
|
-
db.refresh(instance)
|
|
226
225
|
return instance
|
|
227
226
|
|
|
227
|
+
@classmethod
|
|
228
|
+
@with_db
|
|
229
|
+
def bulk_upsert(cls, rows: list[dict] = None, db=None, **kwargs):
|
|
230
|
+
"""
|
|
231
|
+
Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.
|
|
232
|
+
|
|
233
|
+
Args:
|
|
234
|
+
rows (list[dict]): List of dictionaries representing rows to upsert
|
|
235
|
+
db (Session): SQLAlchemy DB session
|
|
236
|
+
**kwargs: Column-wise keyword arguments (field_name=[...]) for backwards compatibility
|
|
237
|
+
"""
|
|
238
|
+
# Handle both new format (rows) and legacy format (kwargs)
|
|
239
|
+
if rows is None and kwargs:
|
|
240
|
+
# Legacy format: convert column-wise kwargs to row-wise list of dicts
|
|
241
|
+
value_lengths = [len(v) for v in kwargs.values()]
|
|
242
|
+
if not value_lengths or len(set(value_lengths)) != 1:
|
|
243
|
+
raise ValueError(
|
|
244
|
+
"All field values must be non-empty lists of the same length."
|
|
245
|
+
)
|
|
246
|
+
rows = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]
|
|
247
|
+
|
|
248
|
+
if not rows:
|
|
249
|
+
return 0
|
|
250
|
+
|
|
251
|
+
BATCH_SIZE = 200
|
|
252
|
+
total_affected = 0
|
|
253
|
+
|
|
254
|
+
for i in range(0, len(rows), BATCH_SIZE):
|
|
255
|
+
batch = rows[i : i + BATCH_SIZE]
|
|
256
|
+
stmt = mysql_insert(cls.__table__).values(batch)
|
|
257
|
+
stmt = stmt.on_duplicate_key_update(
|
|
258
|
+
**{key: stmt.inserted[key] for key in batch[0] if key != "id"}
|
|
259
|
+
)
|
|
260
|
+
result = db.execute(stmt)
|
|
261
|
+
total_affected += result.rowcount
|
|
262
|
+
|
|
263
|
+
db.commit()
|
|
264
|
+
return total_affected
|
|
265
|
+
|
|
228
266
|
@classmethod
|
|
229
267
|
@with_db
|
|
230
268
|
def delete(cls, id: int, db=None):
|