lecrapaud 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lecrapaud might be problematic. Click here for more details.
- lecrapaud/__init__.py +1 -0
- lecrapaud/api.py +277 -0
- lecrapaud/config.py +10 -0
- lecrapaud/db/__init__.py +1 -0
- lecrapaud/db/alembic/env.py +2 -2
- lecrapaud/db/alembic/versions/2025_05_31_1834-52b809a34371_make_nullablee.py +24 -12
- lecrapaud/db/alembic/versions/2025_06_17_1652-c45f5e49fa2c_make_fields_nullable.py +89 -0
- lecrapaud/db/alembic.ini +116 -0
- lecrapaud/db/models/__init__.py +10 -10
- lecrapaud/db/models/base.py +176 -1
- lecrapaud/db/models/dataset.py +25 -20
- lecrapaud/db/models/feature.py +5 -6
- lecrapaud/db/models/feature_selection.py +3 -4
- lecrapaud/db/models/feature_selection_rank.py +3 -4
- lecrapaud/db/models/model.py +3 -4
- lecrapaud/db/models/model_selection.py +15 -8
- lecrapaud/db/models/model_training.py +15 -7
- lecrapaud/db/models/score.py +9 -6
- lecrapaud/db/models/target.py +16 -8
- lecrapaud/db/session.py +66 -0
- lecrapaud/experiment.py +64 -0
- lecrapaud/feature_engineering.py +747 -1022
- lecrapaud/feature_selection.py +915 -998
- lecrapaud/integrations/openai_integration.py +225 -0
- lecrapaud/jobs/__init__.py +2 -2
- lecrapaud/jobs/config.py +1 -1
- lecrapaud/jobs/scheduler.py +1 -1
- lecrapaud/jobs/tasks.py +6 -6
- lecrapaud/model_selection.py +1060 -960
- lecrapaud/search_space.py +4 -0
- lecrapaud/utils.py +2 -2
- lecrapaud-0.4.1.dist-info/METADATA +171 -0
- {lecrapaud-0.4.0.dist-info → lecrapaud-0.4.1.dist-info}/RECORD +36 -35
- {lecrapaud-0.4.0.dist-info → lecrapaud-0.4.1.dist-info}/WHEEL +1 -1
- lecrapaud/db/crud.py +0 -179
- lecrapaud/db/services.py +0 -0
- lecrapaud/db/setup.py +0 -58
- lecrapaud/predictions.py +0 -292
- lecrapaud/training.py +0 -151
- lecrapaud-0.4.0.dist-info/METADATA +0 -103
- /lecrapaud/{directory_management.py → directories.py} +0 -0
- {lecrapaud-0.4.0.dist-info → lecrapaud-0.4.1.dist-info}/LICENSE +0 -0
lecrapaud/db/models/base.py
CHANGED
|
@@ -1,6 +1,181 @@
|
|
|
1
|
+
"""Base SQLAlchemy model with CRUD operations."""
|
|
2
|
+
|
|
3
|
+
from functools import wraps
|
|
4
|
+
|
|
1
5
|
from sqlalchemy.orm import DeclarativeBase
|
|
6
|
+
from sqlalchemy import desc, asc, and_, delete
|
|
7
|
+
from sqlalchemy.inspection import inspect
|
|
8
|
+
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
9
|
+
from lecrapaud.db.session import get_db
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def with_db(func):
    """Let the wrapped callable receive an optional ``db`` session keyword.

    A truthy ``db`` keyword supplied by the caller is forwarded unchanged.
    Otherwise a session is taken from ``get_db()`` for the duration of the
    call and handed to ``func`` as ``db``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        session = kwargs.pop("db", None)
        if not session:
            # No session supplied: borrow one for this call only.
            with get_db() as session:
                return func(*args, db=session, **kwargs)
        return func(*args, db=session, **kwargs)

    return wrapper
|
|
2
24
|
|
|
3
25
|
|
|
4
26
|
# declarative base class
|
|
5
27
|
class Base(DeclarativeBase):
    """Declarative base class that gives every model session-aware CRUD helpers.

    Every helper that touches the database is wrapped with ``with_db``: pass
    an existing session as the ``db`` keyword to reuse it, or omit it and
    ``with_db`` supplies one from ``get_db()``.
    """

    @classmethod
    @with_db
    def create(cls, db=None, **kwargs):
        """Insert a row built from ``kwargs`` and return the refreshed instance."""
        instance = cls(**kwargs)
        db.add(instance)
        db.commit()
        db.refresh(instance)
        return instance

    @classmethod
    @with_db
    def get(cls, id: int, db=None):
        """Look up a row by primary key via ``db.get``."""
        return db.get(cls, id)

    @classmethod
    @with_db
    def find_by(cls, db=None, **kwargs):
        """Return the first row matching the ``kwargs`` equality filters."""
        return db.query(cls).filter_by(**kwargs).first()

    @classmethod
    @with_db
    def get_all(
        cls, raw=False, db=None, limit: int = 100, order: str = "desc", **kwargs
    ):
        """Return up to ``limit`` rows ordered by ``created_at``.

        :param raw: when true, return plain dicts keyed by column name
        :param limit: maximum number of rows returned (defaults to 100)
        :param order: ``"desc"`` for newest first; any other value sorts ascending
        :param kwargs: equality filters; keys that are not attributes of the
            model are skipped
        """
        # NOTE(review): relies on the model defining ``created_at`` — confirm
        # that every subclass has this column.
        order_by_field = (
            desc(cls.created_at) if order == "desc" else asc(cls.created_at)
        )

        query = db.query(cls)

        # Apply filters from kwargs
        for key, value in kwargs.items():
            if hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)

        results = query.order_by(order_by_field).limit(limit).all()

        if raw:
            return [
                {
                    column.name: getattr(row, column.name)
                    for column in cls.__table__.columns
                }
                for row in results
            ]

        return results

    @classmethod
    @with_db
    def filter(cls, db=None, **kwargs):
        """Return all rows matching every ``field`` / ``field__op`` condition.

        Supported operators: ``eq`` (default), ``in``, ``gt``, ``lt``, ``gte``
        and ``lte``.

        :raises ValueError: for an unknown field or an unsupported operator
        """
        comparators = {
            "eq": lambda column, value: column == value,
            "in": lambda column, value: column.in_(value),
            "gt": lambda column, value: column > value,
            "lt": lambda column, value: column < value,
            "gte": lambda column, value: column >= value,
            "lte": lambda column, value: column <= value,
        }
        conditions = []

        for key, value in kwargs.items():
            field, op = key.split("__", 1) if "__" in key else (key, "eq")

            if not hasattr(cls, field):
                raise ValueError(f"{field} is not a valid field on {cls.__name__}")
            if op not in comparators:
                raise ValueError(f"Unsupported operator: {op}")

            column: InstrumentedAttribute = getattr(cls, field)
            conditions.append(comparators[op](column, value))

        # Several criteria passed to ``filter`` are AND-ed together, and passing
        # none is a no-op; this avoids calling ``and_()`` with zero clauses.
        return db.query(cls).filter(*conditions).all()

    @classmethod
    @with_db
    def update(cls, id: int, db=None, **kwargs):
        """Set ``kwargs`` on the row with primary key ``id``.

        :return: the refreshed instance, or ``None`` when no such row exists
        """
        instance = db.get(cls, id)
        if not instance:
            return None
        for key, value in kwargs.items():
            setattr(instance, key, value)
        db.commit()
        db.refresh(instance)
        return instance

    @classmethod
    @with_db
    def upsert(cls, match_fields: list[str], db=None, **kwargs):
        """
        Upsert an instance of the model: update if found, else create.

        :param match_fields: list of field names to use for matching
        :param kwargs: all fields for creation or update
        :raises ValueError: when none of ``match_fields`` is present in ``kwargs``
        """
        conditions = [
            getattr(cls, field) == kwargs[field]
            for field in match_fields
            if field in kwargs
        ]
        if not conditions:
            # With no condition the query would match an arbitrary first row and
            # overwrite it, so refuse instead.
            raise ValueError(
                f"upsert on {cls.__name__} needs at least one of match_fields "
                f"{match_fields!r} in kwargs"
            )

        instance = db.query(cls).filter(*conditions).first()

        if instance:
            for key, value in kwargs.items():
                setattr(instance, key, value)
        else:
            instance = cls(**kwargs)
            db.add(instance)

        db.commit()
        db.refresh(instance)
        return instance

    @classmethod
    @with_db
    def delete(cls, id: int, db=None):
        """Delete the row with primary key ``id``; return whether one was found."""
        instance = db.get(cls, id)
        if instance:
            db.delete(instance)
            db.commit()
            return True
        return False

    @classmethod
    @with_db
    def delete_all(cls, db=None, **kwargs):
        """Delete every row matching the ``kwargs`` equality filters.

        Called without filters, this deletes all rows of the table.

        :raises ValueError: when a filter key is not an attribute of the model
        :return: always ``True``
        """
        stmt = delete(cls)

        for key, value in kwargs.items():
            if not hasattr(cls, key):
                # Dropping an unknown filter would widen the DELETE, possibly
                # to the whole table, so fail loudly instead.
                raise ValueError(f"{key} is not a valid field on {cls.__name__}")
            stmt = stmt.where(getattr(cls, key) == value)

        db.execute(stmt)
        db.commit()
        return True

    @with_db
    def save(self, db=None):
        """Merge this instance into the session, commit, and return the merged copy."""
        merged = db.merge(self)
        db.add(merged)
        db.commit()
        db.refresh(merged)
        return merged

    def to_json(self):
        """Return a dict mapping each mapped column attribute key to its value."""
        return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}
|
lecrapaud/db/models/dataset.py
CHANGED
|
@@ -17,9 +17,8 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
17
17
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
18
18
|
from itertools import chain
|
|
19
19
|
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from src.db.crud import CRUDMixin, with_db
|
|
20
|
+
from lecrapaud.db.session import get_db
|
|
21
|
+
from lecrapaud.db.models.base import Base
|
|
23
22
|
|
|
24
23
|
# jointures
|
|
25
24
|
dataset_target_association = Table(
|
|
@@ -40,7 +39,7 @@ dataset_target_association = Table(
|
|
|
40
39
|
)
|
|
41
40
|
|
|
42
41
|
|
|
43
|
-
class Dataset(Base
|
|
42
|
+
class Dataset(Base):
|
|
44
43
|
__tablename__ = "datasets"
|
|
45
44
|
|
|
46
45
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -57,22 +56,22 @@ class Dataset(Base, CRUDMixin):
|
|
|
57
56
|
path = Column(String(255)) # we do not have this at creation time
|
|
58
57
|
type = Column(String(50), nullable=False)
|
|
59
58
|
size = Column(Integer, nullable=False)
|
|
60
|
-
train_size = Column(Integer
|
|
59
|
+
train_size = Column(Integer)
|
|
61
60
|
val_size = Column(Integer)
|
|
62
|
-
test_size = Column(Integer
|
|
63
|
-
number_of_groups = Column(Integer, nullable=False)
|
|
64
|
-
list_of_groups = Column(JSON, nullable=False)
|
|
61
|
+
test_size = Column(Integer)
|
|
65
62
|
corr_threshold = Column(Float, nullable=False)
|
|
66
63
|
max_features = Column(Integer, nullable=False)
|
|
67
64
|
percentile = Column(Float, nullable=False)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
65
|
+
number_of_groups = Column(Integer)
|
|
66
|
+
list_of_groups = Column(JSON)
|
|
67
|
+
start_date = Column(DateTime)
|
|
68
|
+
end_date = Column(DateTime)
|
|
69
|
+
train_start_date = Column(DateTime)
|
|
70
|
+
train_end_date = Column(DateTime)
|
|
72
71
|
val_start_date = Column(DateTime)
|
|
73
72
|
val_end_date = Column(DateTime)
|
|
74
|
-
test_start_date = Column(DateTime
|
|
75
|
-
test_end_date = Column(DateTime
|
|
73
|
+
test_start_date = Column(DateTime)
|
|
74
|
+
test_end_date = Column(DateTime)
|
|
76
75
|
|
|
77
76
|
feature_selections = relationship(
|
|
78
77
|
"FeatureSelection",
|
|
@@ -111,14 +110,20 @@ class Dataset(Base, CRUDMixin):
|
|
|
111
110
|
feature = [f.name for f in feature_selection.features]
|
|
112
111
|
return feature
|
|
113
112
|
|
|
114
|
-
def get_all_features(self):
|
|
113
|
+
def get_all_features(self, date_column: str = None, group_column: str = None):
|
|
115
114
|
target_idx = [target.id for target in self.targets]
|
|
116
|
-
all_features =
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
115
|
+
all_features = []
|
|
116
|
+
if date_column:
|
|
117
|
+
all_features.append(date_column)
|
|
118
|
+
if group_column:
|
|
119
|
+
all_features.append(group_column)
|
|
120
|
+
all_features += list(
|
|
121
|
+
chain.from_iterable(
|
|
122
|
+
[f.name for f in fs.features]
|
|
123
|
+
for fs in self.feature_selections
|
|
124
|
+
if fs.target_id in target_idx
|
|
125
|
+
)
|
|
120
126
|
)
|
|
121
|
-
all_features = ["DATE", "STOCK"] + list(all_features)
|
|
122
127
|
all_features = list(dict.fromkeys(all_features))
|
|
123
128
|
|
|
124
129
|
return all_features
|
lecrapaud/db/models/feature.py
CHANGED
|
@@ -16,13 +16,12 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
16
16
|
|
|
17
17
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
18
18
|
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from src.db.crud import CRUDMixin
|
|
19
|
+
from lecrapaud.db.session import get_db
|
|
20
|
+
from lecrapaud.db.models.base import Base
|
|
21
|
+
from lecrapaud.db.models.feature_selection import feature_selection_association
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
class Feature(Base
|
|
24
|
+
class Feature(Base):
|
|
26
25
|
__tablename__ = "features"
|
|
27
26
|
|
|
28
27
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -42,5 +41,5 @@ class Feature(Base, CRUDMixin):
|
|
|
42
41
|
"FeatureSelection",
|
|
43
42
|
secondary=feature_selection_association,
|
|
44
43
|
back_populates="features",
|
|
45
|
-
lazy="selectin"
|
|
44
|
+
lazy="selectin",
|
|
46
45
|
)
|
|
@@ -24,9 +24,8 @@ from sqlalchemy.orm import (
|
|
|
24
24
|
)
|
|
25
25
|
from collections.abc import Iterable
|
|
26
26
|
|
|
27
|
-
from
|
|
28
|
-
from
|
|
29
|
-
from src.db.crud import CRUDMixin, with_db
|
|
27
|
+
from lecrapaud.db.session import get_db
|
|
28
|
+
from lecrapaud.db.models.base import Base, with_db
|
|
30
29
|
|
|
31
30
|
# jointures
|
|
32
31
|
feature_selection_association = Table(
|
|
@@ -47,7 +46,7 @@ feature_selection_association = Table(
|
|
|
47
46
|
)
|
|
48
47
|
|
|
49
48
|
|
|
50
|
-
class FeatureSelection(Base
|
|
49
|
+
class FeatureSelection(Base):
|
|
51
50
|
__tablename__ = "feature_selections"
|
|
52
51
|
|
|
53
52
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -18,12 +18,11 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
18
18
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
19
19
|
from sqlalchemy.dialects.mysql import insert
|
|
20
20
|
|
|
21
|
-
from
|
|
22
|
-
from
|
|
23
|
-
from src.db.crud import CRUDMixin, with_db
|
|
21
|
+
from lecrapaud.db.session import get_db
|
|
22
|
+
from lecrapaud.db.models.base import Base, with_db
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
class FeatureSelectionRank(Base
|
|
25
|
+
class FeatureSelectionRank(Base):
|
|
27
26
|
__tablename__ = "feature_selection_ranks"
|
|
28
27
|
|
|
29
28
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
lecrapaud/db/models/model.py
CHANGED
|
@@ -17,12 +17,11 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
17
17
|
|
|
18
18
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
19
19
|
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from src.db.crud import CRUDMixin
|
|
20
|
+
from lecrapaud.db.session import get_db
|
|
21
|
+
from lecrapaud.db.models.base import Base
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
class Model(Base
|
|
24
|
+
class Model(Base):
|
|
26
25
|
__tablename__ = "models"
|
|
27
26
|
|
|
28
27
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -17,12 +17,11 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
17
17
|
|
|
18
18
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
19
19
|
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from src.db.crud import CRUDMixin
|
|
20
|
+
from lecrapaud.db.session import get_db
|
|
21
|
+
from lecrapaud.db.models.base import Base
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
class ModelSelection(Base
|
|
24
|
+
class ModelSelection(Base):
|
|
26
25
|
__tablename__ = "model_selections"
|
|
27
26
|
|
|
28
27
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -38,11 +37,20 @@ class ModelSelection(Base, CRUDMixin):
|
|
|
38
37
|
best_model_params = Column(JSON)
|
|
39
38
|
best_model_path = Column(String(255))
|
|
40
39
|
best_model_id = Column(BigInteger, ForeignKey("models.id", ondelete="CASCADE"))
|
|
41
|
-
target_id = Column(
|
|
42
|
-
|
|
40
|
+
target_id = Column(
|
|
41
|
+
BigInteger, ForeignKey("targets.id", ondelete="CASCADE"), nullable=False
|
|
42
|
+
)
|
|
43
|
+
dataset_id = Column(
|
|
44
|
+
BigInteger, ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False
|
|
45
|
+
)
|
|
43
46
|
|
|
44
47
|
best_model = relationship("Model", lazy="selectin")
|
|
45
|
-
model_trainings = relationship(
|
|
48
|
+
model_trainings = relationship(
|
|
49
|
+
"ModelTraining",
|
|
50
|
+
back_populates="model_selection",
|
|
51
|
+
cascade="all, delete-orphan",
|
|
52
|
+
lazy="selectin",
|
|
53
|
+
)
|
|
46
54
|
dataset = relationship(
|
|
47
55
|
"Dataset", back_populates="model_selections", lazy="selectin"
|
|
48
56
|
)
|
|
@@ -53,4 +61,3 @@ class ModelSelection(Base, CRUDMixin):
|
|
|
53
61
|
"target_id", "dataset_id", name="uq_model_selection_composite"
|
|
54
62
|
),
|
|
55
63
|
)
|
|
56
|
-
|
|
@@ -17,12 +17,11 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
17
17
|
|
|
18
18
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
19
19
|
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from src.db.crud import CRUDMixin
|
|
20
|
+
from lecrapaud.db.session import get_db
|
|
21
|
+
from lecrapaud.db.models.base import Base
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
class ModelTraining(Base
|
|
24
|
+
class ModelTraining(Base):
|
|
26
25
|
__tablename__ = "model_trainings"
|
|
27
26
|
|
|
28
27
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -40,12 +39,21 @@ class ModelTraining(Base, CRUDMixin):
|
|
|
40
39
|
training_time = Column(Integer)
|
|
41
40
|
model_id = Column(BigInteger, ForeignKey("models.id"), nullable=False)
|
|
42
41
|
model_selection_id = Column(
|
|
43
|
-
BigInteger,
|
|
42
|
+
BigInteger,
|
|
43
|
+
ForeignKey("model_selections.id", ondelete="CASCADE"),
|
|
44
|
+
nullable=False,
|
|
44
45
|
)
|
|
45
46
|
|
|
46
47
|
model = relationship("Model", lazy="selectin")
|
|
47
|
-
model_selection = relationship(
|
|
48
|
-
|
|
48
|
+
model_selection = relationship(
|
|
49
|
+
"ModelSelection", back_populates="model_trainings", lazy="selectin"
|
|
50
|
+
)
|
|
51
|
+
score = relationship(
|
|
52
|
+
"Score",
|
|
53
|
+
back_populates="model_trainings",
|
|
54
|
+
cascade="all, delete-orphan",
|
|
55
|
+
lazy="selectin",
|
|
56
|
+
)
|
|
49
57
|
|
|
50
58
|
__table_args__ = (
|
|
51
59
|
UniqueConstraint(
|
lecrapaud/db/models/score.py
CHANGED
|
@@ -16,12 +16,11 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
16
16
|
|
|
17
17
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
18
18
|
|
|
19
|
-
from
|
|
20
|
-
from
|
|
21
|
-
from src.db.crud import CRUDMixin
|
|
19
|
+
from lecrapaud.db.session import get_db
|
|
20
|
+
from lecrapaud.db.models.base import Base
|
|
22
21
|
|
|
23
22
|
|
|
24
|
-
class Score(Base
|
|
23
|
+
class Score(Base):
|
|
25
24
|
__tablename__ = "scores"
|
|
26
25
|
|
|
27
26
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -57,6 +56,10 @@ class Score(Base, CRUDMixin):
|
|
|
57
56
|
threshold = Column(Float)
|
|
58
57
|
precision_at_threshold = Column(Float)
|
|
59
58
|
recall_at_threshold = Column(Float)
|
|
60
|
-
model_training_id = Column(
|
|
59
|
+
model_training_id = Column(
|
|
60
|
+
BigInteger, ForeignKey("model_trainings.id", ondelete="CASCADE"), nullable=False
|
|
61
|
+
)
|
|
61
62
|
|
|
62
|
-
model_trainings = relationship(
|
|
63
|
+
model_trainings = relationship(
|
|
64
|
+
"ModelTraining", back_populates="score", lazy="selectin"
|
|
65
|
+
)
|
lecrapaud/db/models/target.py
CHANGED
|
@@ -17,13 +17,12 @@ from sqlalchemy import desc, asc, cast, text, func
|
|
|
17
17
|
|
|
18
18
|
from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
|
|
19
19
|
|
|
20
|
-
from
|
|
21
|
-
from
|
|
22
|
-
from
|
|
23
|
-
from src.db.models.dataset import dataset_target_association
|
|
20
|
+
from lecrapaud.db.session import get_db
|
|
21
|
+
from lecrapaud.db.models.base import Base
|
|
22
|
+
from lecrapaud.db.models.dataset import dataset_target_association
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
class Target(Base
|
|
25
|
+
class Target(Base):
|
|
27
26
|
__tablename__ = "targets"
|
|
28
27
|
|
|
29
28
|
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
|
|
@@ -41,13 +40,22 @@ class Target(Base, CRUDMixin):
|
|
|
41
40
|
description = Column(String(255))
|
|
42
41
|
|
|
43
42
|
datasets = relationship(
|
|
44
|
-
"Dataset",
|
|
43
|
+
"Dataset",
|
|
44
|
+
secondary=dataset_target_association,
|
|
45
|
+
back_populates="targets",
|
|
46
|
+
lazy="selectin",
|
|
45
47
|
)
|
|
46
48
|
feature_selections = relationship(
|
|
47
|
-
"FeatureSelection",
|
|
49
|
+
"FeatureSelection",
|
|
50
|
+
back_populates="target",
|
|
51
|
+
cascade="all, delete-orphan",
|
|
52
|
+
lazy="selectin",
|
|
48
53
|
)
|
|
49
54
|
model_selections = relationship(
|
|
50
|
-
"ModelSelection",
|
|
55
|
+
"ModelSelection",
|
|
56
|
+
back_populates="target",
|
|
57
|
+
cascade="all, delete-orphan",
|
|
58
|
+
lazy="selectin",
|
|
51
59
|
)
|
|
52
60
|
|
|
53
61
|
__table_args__ = (
|
lecrapaud/db/session.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"""Database session management and initialization module."""
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from sqlalchemy import create_engine, text
|
|
5
|
+
from sqlalchemy.orm import sessionmaker
|
|
6
|
+
from urllib.parse import urlparse
|
|
7
|
+
from alembic.config import Config
|
|
8
|
+
from alembic import command
|
|
9
|
+
import os
|
|
10
|
+
|
|
11
|
+
from lecrapaud.config import DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_URI
|
|
12
|
+
|
|
13
|
+
# Populated by init_db(); get_db() refuses to run while the factory is None.
_engine = None
_SessionLocal = None

# Default connection URL assembled from the individual DB_* settings.
DATABASE_URL = (
    f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
# The previous `f"..." or DB_URI` could never reach DB_URI because the f-string
# is always non-empty. Fall back to DB_URI when a component is missing.
# NOTE(review): assumes unset DB_* settings are None — confirm in lecrapaud.config.
if DB_URI and None in (DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME):
    DATABASE_URL = DB_URI
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def init_db(uri: str = None):
    """Create the database if needed, bind the session factory, run migrations.

    :param uri: database URI; when omitted, ``DATABASE_URL`` and then the
        ``DATABASE_URL`` environment variable are tried
    :raises ValueError: when no database URI can be determined
    """
    global _engine, _SessionLocal

    uri = uri or DATABASE_URL or os.getenv("DATABASE_URL")
    if not uri:
        raise ValueError("No database URI configured for init_db().")

    # Extract DB name from URI to connect without it
    parsed = urlparse(uri)
    db_name = parsed.path.lstrip("/") or DB_NAME  # remove leading slash

    # Build root engine (no database in URI). Only the path component is
    # replaced: a plain str.replace would also strip a username or host that
    # happens to equal the database name.
    root_uri = parsed._replace(path="/").geturl()

    # Step 1: Connect to MySQL without a database
    root_engine = create_engine(root_uri)

    # Step 2: Create database if it doesn't exist. The name comes from the URI
    # actually in use and is backtick-quoted so unusual names stay valid.
    quoted_name = "`" + db_name.replace("`", "``") + "`"
    try:
        with root_engine.connect() as conn:
            conn.execute(text(f"CREATE DATABASE IF NOT EXISTS {quoted_name}"))
            conn.commit()
    finally:
        # The root engine is only needed for this one statement.
        root_engine.dispose()

    # Step 3: Connect to the newly created database
    _engine = create_engine(uri, echo=False)

    # Step 4: Create session factory
    _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)

    # Step 5: Apply Alembic migrations programmatically
    project_root = os.path.abspath(os.path.dirname(__file__))
    alembic_cfg_path = os.path.join(project_root, "alembic.ini")

    alembic_cfg = Config(alembic_cfg_path)
    # Alembic stores options through ConfigParser, which treats "%" as an
    # interpolation marker; escape it so URL-encoded credentials survive.
    alembic_cfg.set_main_option("sqlalchemy.url", uri.replace("%", "%%"))
    command.upgrade(alembic_cfg, "head")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Dependency to get a session instance
@contextmanager
def get_db():
    """Yield a session from the configured factory and always close it.

    If the body of the ``with`` block raises, the session is rolled back and
    the exception is re-raised as-is.

    :raises RuntimeError: when ``init_db()`` has not set up the session factory
    """
    if _SessionLocal is None:
        raise RuntimeError("Database not initialized. Call `init_db()` first.")
    db = _SessionLocal()
    try:
        yield db
    except Exception:
        db.rollback()
        # Re-raise the original exception: wrapping it in a bare Exception
        # would hide its type from callers that catch specific errors.
        raise
    finally:
        db.close()
|
lecrapaud/experiment.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
# Import-time side effect: export COVERAGE_FILE as the absolute path of
# ".coverage", resolved against the working directory at import time.
os.environ["COVERAGE_FILE"] = str(Path(".coverage").resolve())
|
|
6
|
+
|
|
7
|
+
# Internal
|
|
8
|
+
from lecrapaud.directories import tmp_dir
|
|
9
|
+
from lecrapaud.utils import logger
|
|
10
|
+
from lecrapaud.config import PYTHON_ENV
|
|
11
|
+
from lecrapaud.db import (
|
|
12
|
+
Dataset,
|
|
13
|
+
Target,
|
|
14
|
+
)
|
|
15
|
+
from lecrapaud.db.session import get_db
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def create_dataset(
    data: pd.DataFrame,
    corr_threshold,
    percentile,
    max_features,
    date_column,
    group_column,
    **kwargs,
):
    """Create or update the ``Dataset`` row describing ``data`` and its folders.

    The dataset name is derived from the group count, the three selection
    parameters and the date range. Working directories are created under
    ``tmp_dir`` and the row is upserted by name.

    :param data: source frame; columns named after known targets are linked
    :param corr_threshold: stored on the dataset and used in its name
    :param percentile: stored on the dataset and used in its name
    :param max_features: stored on the dataset and used in its name
    :param date_column: name of the date column, or a falsy value for none
    :param group_column: name of the group column, or a falsy value for none
    :param kwargs: accepted but not used here
    :return: the upserted ``Dataset`` instance
    """
    dates = {}
    if date_column:
        # First and last rows give the range bounds — presumably ``data`` is
        # already sorted by date; TODO confirm with callers.
        dates["start_date"] = pd.to_datetime(data[date_column].iat[0])
        dates["end_date"] = pd.to_datetime(data[date_column].iat[-1])

    groups = {}
    if group_column:
        groups["number_of_groups"] = data[group_column].nunique()
        groups["list_of_groups"] = data[group_column].unique().tolist()

    with get_db() as db:
        # NOTE(review): Target.get_all returns at most 100 rows by default.
        all_targets = Target.get_all(db=db)
        targets = [target for target in all_targets if target.name in data.columns]

        # Name parts are computed first: reusing double quotes inside a
        # double-quoted f-string is a SyntaxError before Python 3.12.
        groups_label = groups["number_of_groups"] if group_column else "ng"
        start_label = dates["start_date"].date() if date_column else "nd"
        end_label = dates["end_date"].date() if date_column else "nd"
        dataset_name = (
            f"data_{groups_label}_{corr_threshold}_{percentile}_{max_features}"
            f"_{start_label}_{end_label}"
        )

        dataset_dir = f"{tmp_dir}/{dataset_name}"
        preprocessing_dir = f"{dataset_dir}/preprocessing"
        data_dir = f"{dataset_dir}/data"
        for directory in (dataset_dir, preprocessing_dir, data_dir):
            os.makedirs(directory, exist_ok=True)

        dataset = Dataset.upsert(
            match_fields=["name"],
            db=db,
            name=dataset_name,
            # ``path`` is a String column, so store text rather than a Path.
            path=str(Path(dataset_dir).resolve()),
            type="training",
            size=data.shape[0],
            corr_threshold=corr_threshold,
            percentile=percentile,
            max_features=max_features,
            **groups,
            **dates,
            targets=targets,
        )

        return dataset
|