lecrapaud 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lecrapaud might be problematic. Click here for more details.

Files changed (42) hide show
  1. lecrapaud/__init__.py +1 -0
  2. lecrapaud/api.py +277 -0
  3. lecrapaud/config.py +10 -0
  4. lecrapaud/db/__init__.py +1 -0
  5. lecrapaud/db/alembic/env.py +2 -2
  6. lecrapaud/db/alembic/versions/2025_05_31_1834-52b809a34371_make_nullablee.py +24 -12
  7. lecrapaud/db/alembic/versions/2025_06_17_1652-c45f5e49fa2c_make_fields_nullable.py +89 -0
  8. lecrapaud/db/alembic.ini +116 -0
  9. lecrapaud/db/models/__init__.py +10 -10
  10. lecrapaud/db/models/base.py +176 -1
  11. lecrapaud/db/models/dataset.py +25 -20
  12. lecrapaud/db/models/feature.py +5 -6
  13. lecrapaud/db/models/feature_selection.py +3 -4
  14. lecrapaud/db/models/feature_selection_rank.py +3 -4
  15. lecrapaud/db/models/model.py +3 -4
  16. lecrapaud/db/models/model_selection.py +15 -8
  17. lecrapaud/db/models/model_training.py +15 -7
  18. lecrapaud/db/models/score.py +9 -6
  19. lecrapaud/db/models/target.py +16 -8
  20. lecrapaud/db/session.py +68 -0
  21. lecrapaud/experiment.py +64 -0
  22. lecrapaud/feature_engineering.py +747 -1022
  23. lecrapaud/feature_selection.py +915 -998
  24. lecrapaud/integrations/openai_integration.py +225 -0
  25. lecrapaud/jobs/__init__.py +2 -2
  26. lecrapaud/jobs/config.py +1 -1
  27. lecrapaud/jobs/scheduler.py +1 -1
  28. lecrapaud/jobs/tasks.py +6 -6
  29. lecrapaud/model_selection.py +1060 -960
  30. lecrapaud/search_space.py +4 -0
  31. lecrapaud/utils.py +2 -2
  32. lecrapaud-0.4.2.dist-info/METADATA +177 -0
  33. {lecrapaud-0.4.0.dist-info → lecrapaud-0.4.2.dist-info}/RECORD +36 -35
  34. {lecrapaud-0.4.0.dist-info → lecrapaud-0.4.2.dist-info}/WHEEL +1 -1
  35. lecrapaud/db/crud.py +0 -179
  36. lecrapaud/db/services.py +0 -0
  37. lecrapaud/db/setup.py +0 -58
  38. lecrapaud/predictions.py +0 -292
  39. lecrapaud/training.py +0 -151
  40. lecrapaud-0.4.0.dist-info/METADATA +0 -103
  41. /lecrapaud/{directory_management.py → directories.py} +0 -0
  42. {lecrapaud-0.4.0.dist-info → lecrapaud-0.4.2.dist-info}/LICENSE +0 -0
@@ -1,6 +1,181 @@
1
+ """Base SQLAlchemy model with CRUD operations."""
2
+
3
+ from functools import wraps
4
+
1
5
  from sqlalchemy.orm import DeclarativeBase
6
+ from sqlalchemy import desc, asc, and_, delete
7
+ from sqlalchemy.inspection import inspect
8
+ from sqlalchemy.orm.attributes import InstrumentedAttribute
9
+ from lecrapaud.db.session import get_db
10
+
11
+
12
def with_db(func):
    """Decorator to allow passing an optional db session"""

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Honour a caller-supplied session; otherwise open a short-lived one.
        session = kwargs.pop("db", None)
        if not session:
            with get_db() as session:
                return func(*args, db=session, **kwargs)
        return func(*args, db=session, **kwargs)

    return wrapper
2
24
 
3
25
 
4
26
  # declarative base class
5
27
# declarative base class
class Base(DeclarativeBase):
    """Declarative base providing generic CRUD helpers for all models.

    Every helper accepts an optional ``db`` session (via the ``with_db``
    decorator); when omitted, a short-lived session is opened for the call.
    """

    @classmethod
    @with_db
    def create(cls, db=None, **kwargs):
        """Create, persist and return a new instance.

        NOTE: ``db`` now defaults to ``None`` like every other helper; the
        decorator always supplies it as a keyword, so callers are unaffected.
        """
        instance = cls(**kwargs)
        db.add(instance)
        db.commit()
        db.refresh(instance)
        return instance

    @classmethod
    @with_db
    def get(cls, id: int, db=None):
        """Return the instance with primary key ``id``, or ``None``."""
        return db.get(cls, id)

    @classmethod
    @with_db
    def find_by(cls, db=None, **kwargs):
        """Return the first instance matching the equality filters, or ``None``."""
        return db.query(cls).filter_by(**kwargs).first()

    @classmethod
    @with_db
    def get_all(
        cls, raw=False, db=None, limit: int = 100, order: str = "desc", **kwargs
    ):
        """Return up to ``limit`` rows ordered by ``created_at``.

        :param raw: when True, return plain dicts of column values instead
            of ORM instances.
        :param order: "desc" (default) or anything else for ascending.
        :param kwargs: equality filters; keys that are not mapped attributes
            are silently ignored (original behavior, kept for compatibility).
        """
        order_by_field = (
            desc(cls.created_at) if order == "desc" else asc(cls.created_at)
        )

        query = db.query(cls)

        # Apply filters from kwargs
        for key, value in kwargs.items():
            if hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)

        results = query.order_by(order_by_field).limit(limit).all()

        if raw:
            return [
                {
                    column.name: getattr(row, column.name)
                    for column in cls.__table__.columns
                }
                for row in results
            ]

        return results

    @classmethod
    @with_db
    def filter(cls, db=None, **kwargs):
        """Query with Django-style lookups, e.g. ``id__in=[1, 2]``.

        Supported suffixes: ``eq`` (default), ``in``, ``gt``, ``lt``,
        ``gte``, ``lte``.

        :raises ValueError: for unknown fields or unsupported operators.
        """
        filters = []

        for key, value in kwargs.items():
            if "__" in key:
                field, op = key.split("__", 1)
            else:
                field, op = key, "eq"

            if not hasattr(cls, field):
                raise ValueError(f"{field} is not a valid field on {cls.__name__}")

            column: InstrumentedAttribute = getattr(cls, field)

            if op == "eq":
                filters.append(column == value)
            elif op == "in":
                filters.append(column.in_(value))
            elif op == "gt":
                filters.append(column > value)
            elif op == "lt":
                filters.append(column < value)
            elif op == "gte":
                filters.append(column >= value)
            elif op == "lte":
                filters.append(column <= value)
            else:
                raise ValueError(f"Unsupported operator: {op}")

        return db.query(cls).filter(and_(*filters)).all()

    @classmethod
    @with_db
    def update(cls, id: int, db=None, **kwargs):
        """Update the instance with primary key ``id``; return it, or ``None``
        when no such row exists."""
        instance = db.get(cls, id)
        if not instance:
            return None
        for key, value in kwargs.items():
            setattr(instance, key, value)
        db.commit()
        db.refresh(instance)
        return instance

    @classmethod
    @with_db
    def upsert(cls, match_fields: list[str], db=None, **kwargs):
        """
        Upsert an instance of the model: update if found, else create.

        :param match_fields: list of field names to use for matching
        :param kwargs: all fields for creation or update
        :raises ValueError: if none of ``match_fields`` is present in kwargs —
            an empty filter would otherwise silently match (and overwrite) an
            arbitrary existing row.
        """
        filters = [
            getattr(cls, field) == kwargs[field]
            for field in match_fields
            if field in kwargs
        ]
        if not filters:
            raise ValueError(
                f"upsert on {cls.__name__} requires at least one of "
                f"{match_fields} to be present in kwargs"
            )

        instance = db.query(cls).filter(*filters).first()

        if instance:
            for key, value in kwargs.items():
                setattr(instance, key, value)
        else:
            instance = cls(**kwargs)
            db.add(instance)

        db.commit()
        db.refresh(instance)
        return instance

    @classmethod
    @with_db
    def delete(cls, id: int, db=None):
        """Delete the instance with primary key ``id``.

        :return: True if a row was deleted, False otherwise.
        """
        instance = db.get(cls, id)
        if instance:
            db.delete(instance)
            db.commit()
            return True
        return False

    @classmethod
    @with_db
    def delete_all(cls, db=None, **kwargs):
        """Bulk-delete rows matching the equality filters in ``kwargs``
        (all rows when no filter is given). Always returns True."""
        stmt = delete(cls)

        for key, value in kwargs.items():
            if hasattr(cls, key):
                stmt = stmt.where(getattr(cls, key) == value)

        db.execute(stmt)
        db.commit()
        return True

    @with_db
    def save(self, db=None):
        """Merge this (possibly detached) instance into the session, commit,
        and return the refreshed persistent instance."""
        # merge() already attaches the returned instance to the session, so
        # the original follow-up db.add(self) was redundant.
        self = db.merge(self)
        db.commit()
        db.refresh(self)
        return self

    def to_json(self):
        """Return a dict of this instance's mapped column values."""
        return {c.key: getattr(self, c.key) for c in inspect(self).mapper.column_attrs}
@@ -17,9 +17,8 @@ from sqlalchemy import desc, asc, cast, text, func
17
17
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
18
18
  from itertools import chain
19
19
 
20
- from src.db.setup import get_db
21
- from src.db.models.base import Base
22
- from src.db.crud import CRUDMixin, with_db
20
+ from lecrapaud.db.session import get_db
21
+ from lecrapaud.db.models.base import Base
23
22
 
24
23
  # jointures
25
24
  dataset_target_association = Table(
@@ -40,7 +39,7 @@ dataset_target_association = Table(
40
39
  )
41
40
 
42
41
 
43
- class Dataset(Base, CRUDMixin):
42
+ class Dataset(Base):
44
43
  __tablename__ = "datasets"
45
44
 
46
45
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -57,22 +56,22 @@ class Dataset(Base, CRUDMixin):
57
56
  path = Column(String(255)) # we do not have this at creation time
58
57
  type = Column(String(50), nullable=False)
59
58
  size = Column(Integer, nullable=False)
60
- train_size = Column(Integer, nullable=False)
59
+ train_size = Column(Integer)
61
60
  val_size = Column(Integer)
62
- test_size = Column(Integer, nullable=False)
63
- number_of_groups = Column(Integer, nullable=False)
64
- list_of_groups = Column(JSON, nullable=False)
61
+ test_size = Column(Integer)
65
62
  corr_threshold = Column(Float, nullable=False)
66
63
  max_features = Column(Integer, nullable=False)
67
64
  percentile = Column(Float, nullable=False)
68
- start_date = Column(DateTime, nullable=False)
69
- end_date = Column(DateTime, nullable=False)
70
- train_start_date = Column(DateTime, nullable=False)
71
- train_end_date = Column(DateTime, nullable=False)
65
+ number_of_groups = Column(Integer)
66
+ list_of_groups = Column(JSON)
67
+ start_date = Column(DateTime)
68
+ end_date = Column(DateTime)
69
+ train_start_date = Column(DateTime)
70
+ train_end_date = Column(DateTime)
72
71
  val_start_date = Column(DateTime)
73
72
  val_end_date = Column(DateTime)
74
- test_start_date = Column(DateTime, nullable=False)
75
- test_end_date = Column(DateTime, nullable=False)
73
+ test_start_date = Column(DateTime)
74
+ test_end_date = Column(DateTime)
76
75
 
77
76
  feature_selections = relationship(
78
77
  "FeatureSelection",
@@ -111,14 +110,20 @@ class Dataset(Base, CRUDMixin):
111
110
  feature = [f.name for f in feature_selection.features]
112
111
  return feature
113
112
 
114
- def get_all_features(self):
113
+ def get_all_features(self, date_column: str = None, group_column: str = None):
115
114
  target_idx = [target.id for target in self.targets]
116
- all_features = chain.from_iterable(
117
- [f.name for f in fs.features]
118
- for fs in self.feature_selections
119
- if fs.target_id in target_idx
115
+ all_features = []
116
+ if date_column:
117
+ all_features.append(date_column)
118
+ if group_column:
119
+ all_features.append(group_column)
120
+ all_features += list(
121
+ chain.from_iterable(
122
+ [f.name for f in fs.features]
123
+ for fs in self.feature_selections
124
+ if fs.target_id in target_idx
125
+ )
120
126
  )
121
- all_features = ["DATE", "STOCK"] + list(all_features)
122
127
  all_features = list(dict.fromkeys(all_features))
123
128
 
124
129
  return all_features
@@ -16,13 +16,12 @@ from sqlalchemy import desc, asc, cast, text, func
16
16
 
17
17
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
18
18
 
19
- from src.db.setup import get_db
20
- from src.db.models.base import Base
21
- from src.db.models.feature_selection import feature_selection_association
22
- from src.db.crud import CRUDMixin
19
+ from lecrapaud.db.session import get_db
20
+ from lecrapaud.db.models.base import Base
21
+ from lecrapaud.db.models.feature_selection import feature_selection_association
23
22
 
24
23
 
25
- class Feature(Base, CRUDMixin):
24
+ class Feature(Base):
26
25
  __tablename__ = "features"
27
26
 
28
27
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -42,5 +41,5 @@ class Feature(Base, CRUDMixin):
42
41
  "FeatureSelection",
43
42
  secondary=feature_selection_association,
44
43
  back_populates="features",
45
- lazy="selectin"
44
+ lazy="selectin",
46
45
  )
@@ -24,9 +24,8 @@ from sqlalchemy.orm import (
24
24
  )
25
25
  from collections.abc import Iterable
26
26
 
27
- from src.db.setup import get_db
28
- from src.db.models.base import Base
29
- from src.db.crud import CRUDMixin, with_db
27
+ from lecrapaud.db.session import get_db
28
+ from lecrapaud.db.models.base import Base, with_db
30
29
 
31
30
  # jointures
32
31
  feature_selection_association = Table(
@@ -47,7 +46,7 @@ feature_selection_association = Table(
47
46
  )
48
47
 
49
48
 
50
- class FeatureSelection(Base, CRUDMixin):
49
+ class FeatureSelection(Base):
51
50
  __tablename__ = "feature_selections"
52
51
 
53
52
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -18,12 +18,11 @@ from sqlalchemy import desc, asc, cast, text, func
18
18
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
19
19
  from sqlalchemy.dialects.mysql import insert
20
20
 
21
- from src.db.setup import get_db
22
- from src.db.models.base import Base
23
- from src.db.crud import CRUDMixin, with_db
21
+ from lecrapaud.db.session import get_db
22
+ from lecrapaud.db.models.base import Base, with_db
24
23
 
25
24
 
26
- class FeatureSelectionRank(Base, CRUDMixin):
25
+ class FeatureSelectionRank(Base):
27
26
  __tablename__ = "feature_selection_ranks"
28
27
 
29
28
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -17,12 +17,11 @@ from sqlalchemy import desc, asc, cast, text, func
17
17
 
18
18
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
19
19
 
20
- from src.db.setup import get_db
21
- from src.db.models.base import Base
22
- from src.db.crud import CRUDMixin
20
+ from lecrapaud.db.session import get_db
21
+ from lecrapaud.db.models.base import Base
23
22
 
24
23
 
25
- class Model(Base, CRUDMixin):
24
+ class Model(Base):
26
25
  __tablename__ = "models"
27
26
 
28
27
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -17,12 +17,11 @@ from sqlalchemy import desc, asc, cast, text, func
17
17
 
18
18
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
19
19
 
20
- from src.db.setup import get_db
21
- from src.db.models.base import Base
22
- from src.db.crud import CRUDMixin
20
+ from lecrapaud.db.session import get_db
21
+ from lecrapaud.db.models.base import Base
23
22
 
24
23
 
25
- class ModelSelection(Base, CRUDMixin):
24
+ class ModelSelection(Base):
26
25
  __tablename__ = "model_selections"
27
26
 
28
27
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -38,11 +37,20 @@ class ModelSelection(Base, CRUDMixin):
38
37
  best_model_params = Column(JSON)
39
38
  best_model_path = Column(String(255))
40
39
  best_model_id = Column(BigInteger, ForeignKey("models.id", ondelete="CASCADE"))
41
- target_id = Column(BigInteger, ForeignKey("targets.id", ondelete="CASCADE"), nullable=False)
42
- dataset_id = Column(BigInteger, ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False)
40
+ target_id = Column(
41
+ BigInteger, ForeignKey("targets.id", ondelete="CASCADE"), nullable=False
42
+ )
43
+ dataset_id = Column(
44
+ BigInteger, ForeignKey("datasets.id", ondelete="CASCADE"), nullable=False
45
+ )
43
46
 
44
47
  best_model = relationship("Model", lazy="selectin")
45
- model_trainings = relationship("ModelTraining", back_populates="model_selection", cascade="all, delete-orphan", lazy="selectin")
48
+ model_trainings = relationship(
49
+ "ModelTraining",
50
+ back_populates="model_selection",
51
+ cascade="all, delete-orphan",
52
+ lazy="selectin",
53
+ )
46
54
  dataset = relationship(
47
55
  "Dataset", back_populates="model_selections", lazy="selectin"
48
56
  )
@@ -53,4 +61,3 @@ class ModelSelection(Base, CRUDMixin):
53
61
  "target_id", "dataset_id", name="uq_model_selection_composite"
54
62
  ),
55
63
  )
56
-
@@ -17,12 +17,11 @@ from sqlalchemy import desc, asc, cast, text, func
17
17
 
18
18
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
19
19
 
20
- from src.db.setup import get_db
21
- from src.db.models.base import Base
22
- from src.db.crud import CRUDMixin
20
+ from lecrapaud.db.session import get_db
21
+ from lecrapaud.db.models.base import Base
23
22
 
24
23
 
25
- class ModelTraining(Base, CRUDMixin):
24
+ class ModelTraining(Base):
26
25
  __tablename__ = "model_trainings"
27
26
 
28
27
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -40,12 +39,21 @@ class ModelTraining(Base, CRUDMixin):
40
39
  training_time = Column(Integer)
41
40
  model_id = Column(BigInteger, ForeignKey("models.id"), nullable=False)
42
41
  model_selection_id = Column(
43
- BigInteger, ForeignKey("model_selections.id", ondelete="CASCADE"), nullable=False
42
+ BigInteger,
43
+ ForeignKey("model_selections.id", ondelete="CASCADE"),
44
+ nullable=False,
44
45
  )
45
46
 
46
47
  model = relationship("Model", lazy="selectin")
47
- model_selection = relationship("ModelSelection", back_populates="model_trainings", lazy="selectin")
48
- score = relationship("Score", back_populates="model_trainings", cascade="all, delete-orphan", lazy="selectin")
48
+ model_selection = relationship(
49
+ "ModelSelection", back_populates="model_trainings", lazy="selectin"
50
+ )
51
+ score = relationship(
52
+ "Score",
53
+ back_populates="model_trainings",
54
+ cascade="all, delete-orphan",
55
+ lazy="selectin",
56
+ )
49
57
 
50
58
  __table_args__ = (
51
59
  UniqueConstraint(
@@ -16,12 +16,11 @@ from sqlalchemy import desc, asc, cast, text, func
16
16
 
17
17
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
18
18
 
19
- from src.db.setup import get_db
20
- from src.db.models.base import Base
21
- from src.db.crud import CRUDMixin
19
+ from lecrapaud.db.session import get_db
20
+ from lecrapaud.db.models.base import Base
22
21
 
23
22
 
24
- class Score(Base, CRUDMixin):
23
+ class Score(Base):
25
24
  __tablename__ = "scores"
26
25
 
27
26
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -57,6 +56,10 @@ class Score(Base, CRUDMixin):
57
56
  threshold = Column(Float)
58
57
  precision_at_threshold = Column(Float)
59
58
  recall_at_threshold = Column(Float)
60
- model_training_id = Column(BigInteger, ForeignKey("model_trainings.id", ondelete="CASCADE"), nullable=False)
59
+ model_training_id = Column(
60
+ BigInteger, ForeignKey("model_trainings.id", ondelete="CASCADE"), nullable=False
61
+ )
61
62
 
62
- model_trainings = relationship("ModelTraining", back_populates="score", lazy="selectin")
63
+ model_trainings = relationship(
64
+ "ModelTraining", back_populates="score", lazy="selectin"
65
+ )
@@ -17,13 +17,12 @@ from sqlalchemy import desc, asc, cast, text, func
17
17
 
18
18
  from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
19
19
 
20
- from src.db.setup import get_db
21
- from src.db.models.base import Base
22
- from src.db.crud import CRUDMixin
23
- from src.db.models.dataset import dataset_target_association
20
+ from lecrapaud.db.session import get_db
21
+ from lecrapaud.db.models.base import Base
22
+ from lecrapaud.db.models.dataset import dataset_target_association
24
23
 
25
24
 
26
- class Target(Base, CRUDMixin):
25
+ class Target(Base):
27
26
  __tablename__ = "targets"
28
27
 
29
28
  id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
@@ -41,13 +40,22 @@ class Target(Base, CRUDMixin):
41
40
  description = Column(String(255))
42
41
 
43
42
  datasets = relationship(
44
- "Dataset", secondary=dataset_target_association, back_populates="targets", lazy="selectin"
43
+ "Dataset",
44
+ secondary=dataset_target_association,
45
+ back_populates="targets",
46
+ lazy="selectin",
45
47
  )
46
48
  feature_selections = relationship(
47
- "FeatureSelection", back_populates="target", cascade="all, delete-orphan", lazy="selectin"
49
+ "FeatureSelection",
50
+ back_populates="target",
51
+ cascade="all, delete-orphan",
52
+ lazy="selectin",
48
53
  )
49
54
  model_selections = relationship(
50
- "ModelSelection", back_populates="target", cascade="all, delete-orphan", lazy="selectin"
55
+ "ModelSelection",
56
+ back_populates="target",
57
+ cascade="all, delete-orphan",
58
+ lazy="selectin",
51
59
  )
52
60
 
53
61
  __table_args__ = (
@@ -0,0 +1,68 @@
1
+ """Database session management and initialization module."""
2
+
3
+ from contextlib import contextmanager
4
+ from sqlalchemy import create_engine, text
5
+ from sqlalchemy.orm import sessionmaker
6
+ from urllib.parse import urlparse
7
+ from alembic.config import Config
8
+ from alembic import command
9
+ import os
10
+
11
+ from lecrapaud.config import DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME, DB_URI
12
+
13
# Module-level engine/session factory, populated by init_db().
_engine = None
_SessionLocal = None
# Prefer an explicitly configured DB_URI, falling back to a URL assembled
# from the individual settings. The original `f"..." or DB_URI` never used
# DB_URI, because a non-empty f-string is always truthy.
# NOTE(review): precedence assumed to be "explicit URI wins" — confirm.
DATABASE_URL = DB_URI or (
    f"mysql+pymysql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)
18
+
19
+
20
def init_db(uri: str = None):
    """Initialise the engine/session factory and run Alembic migrations.

    Creates the target database if it does not exist, binds the module-level
    engine and session factory, then upgrades the schema to head.

    :param uri: optional SQLAlchemy URL; defaults to ``DATABASE_URL``.
    """
    from urllib.parse import urlunparse

    global _engine, _SessionLocal

    uri = uri if uri else DATABASE_URL
    # Extract DB name from URI to connect without it
    parsed = urlparse(uri)
    db_name = parsed.path.lstrip("/")  # remove leading slash

    # Build root engine (no database in URI). Rebuilding from the parsed
    # parts avoids str.replace() accidentally matching text inside the
    # credentials or host portion of the URL.
    root_uri = urlunparse(parsed._replace(path="/"))

    # Step 1: Connect to MySQL without a database
    root_engine = create_engine(root_uri)

    # Step 2: Create database if it doesn't exist. Use the name parsed from
    # the URI — the original used the configured DB_NAME, which is wrong when
    # a custom `uri` is passed. Backtick-quote the MySQL identifier.
    with root_engine.connect() as conn:
        conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{db_name}`"))
        conn.commit()

    # Step 3: Connect to the newly created database
    _engine = create_engine(uri, echo=False)

    # Step 4: Create session factory
    _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)

    # Step 5: Apply Alembic migrations programmatically
    current_dir = os.path.dirname(__file__)  # -> lecrapaud/db
    alembic_ini_path = os.path.join(current_dir, "alembic.ini")

    alembic_cfg = Config(alembic_ini_path)
    alembic_cfg.set_main_option("script_location", "lecrapaud.db.alembic")
    # `uri` is always set by this point, so no env-var fallback is needed.
    alembic_cfg.set_main_option("sqlalchemy.url", uri)

    command.upgrade(alembic_cfg, "head")
54
+
55
+
56
# Dependency to get a session instance
@contextmanager
def get_db():
    """Yield a database session, rolling back on error and always closing.

    :raises RuntimeError: if ``init_db()`` has not been called yet.
    """
    if _SessionLocal is None:
        raise RuntimeError("Database not initialized. Call `init_db()` first.")
    db = _SessionLocal()
    try:
        yield db
    except Exception:
        db.rollback()
        # Re-raise the original exception unchanged. The previous
        # `raise Exception(e) from e` collapsed every error into a bare
        # Exception, breaking callers that catch specific exception types.
        raise
    finally:
        db.close()
@@ -0,0 +1,64 @@
1
+ import pandas as pd
2
+ import os
3
+ from pathlib import Path
4
+
5
+ os.environ["COVERAGE_FILE"] = str(Path(".coverage").resolve())
6
+
7
+ # Internal
8
+ from lecrapaud.directories import tmp_dir
9
+ from lecrapaud.utils import logger
10
+ from lecrapaud.config import PYTHON_ENV
11
+ from lecrapaud.db import (
12
+ Dataset,
13
+ Target,
14
+ )
15
+ from lecrapaud.db.session import get_db
16
+
17
+
18
def create_dataset(
    data: pd.DataFrame,
    corr_threshold,
    percentile,
    max_features,
    date_column,
    group_column,
    **kwargs,
):
    """Create (or update) a Dataset row and its working directories.

    :param data: input frame; targets are detected from its columns.
    :param corr_threshold: correlation threshold stored on the dataset.
    :param percentile: feature-selection percentile stored on the dataset.
    :param max_features: max feature count stored on the dataset.
    :param date_column: optional date column name; drives start/end dates.
    :param group_column: optional group column name; drives group stats.
    :return: the upserted Dataset instance.
    """
    dates = {}
    if date_column:
        dates["start_date"] = pd.to_datetime(data[date_column].iat[0])
        dates["end_date"] = pd.to_datetime(data[date_column].iat[-1])

    groups = {}
    if group_column:
        groups["number_of_groups"] = data[group_column].nunique()
        groups["list_of_groups"] = data[group_column].unique().tolist()

    # Build the name from intermediate parts. The original inlined
    # groups["number_of_groups"] with double quotes inside a double-quoted
    # f-string, which is a SyntaxError on Python < 3.12.
    n_groups = groups["number_of_groups"] if group_column else "ng"
    start = dates["start_date"].date() if date_column else "nd"
    end = dates["end_date"].date() if date_column else "nd"
    dataset_name = (
        f"data_{n_groups}_{corr_threshold}_{percentile}_{max_features}_{start}_{end}"
    )

    with get_db() as db:
        all_targets = Target.get_all(db=db)
        targets = [target for target in all_targets if target.name in data.columns]

        dataset_dir = f"{tmp_dir}/{dataset_name}"
        preprocessing_dir = f"{dataset_dir}/preprocessing"
        data_dir = f"{dataset_dir}/data"
        os.makedirs(dataset_dir, exist_ok=True)
        os.makedirs(preprocessing_dir, exist_ok=True)
        os.makedirs(data_dir, exist_ok=True)

        dataset = Dataset.upsert(
            match_fields=["name"],
            db=db,
            name=dataset_name,
            # The `path` column is String(255); store a string, not a Path.
            path=str(Path(dataset_dir).resolve()),
            type="training",
            size=data.shape[0],
            corr_threshold=corr_threshold,
            percentile=percentile,
            max_features=max_features,
            **groups,
            **dates,
            targets=targets,
        )

    return dataset