lecrapaud 0.5.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.

lecrapaud/api.py CHANGED
@@ -36,81 +36,66 @@ import logging
 from lecrapaud.utils import logger
 from lecrapaud.db.session import init_db
 from lecrapaud.feature_selection import FeatureSelectionEngine, PreprocessModel
-from lecrapaud.model_selection import ModelSelectionEngine, ModelEngine
+from lecrapaud.model_selection import ModelSelectionEngine, ModelEngine, evaluate
 from lecrapaud.feature_engineering import FeatureEngineeringEngine, PreprocessFeature
-from lecrapaud.experiment import create_dataset
-from lecrapaud.db import Dataset
+from lecrapaud.experiment import create_experiment
+from lecrapaud.db import Experiment
+from lecrapaud.search_space import normalize_models_idx


 class LeCrapaud:
     def __init__(self, uri: str = None):
         init_db(uri=uri)

-    def create_experiment(self, **kwargs):
-        return Experiment(**kwargs)
+    def create_experiment(self, data: pd.DataFrame, **kwargs):
+        return App(data=data, **kwargs)

-    def get_experiment(self, id: int):
-        return Experiment(id)
+    def get_experiment(self, id: int, **kwargs):
+        return App(id=id, **kwargs)


-class Experiment:
-    def __init__(self, id=None, **kwargs):
+class App:
+    def __init__(self, id=None, data=None, **kwargs):
         if id:
-            self.dataset = Dataset.get(id)
+            self.experiment = Experiment.get(id)
+            kwargs.update(self.experiment.context)
         else:
-            self.dataset = create_dataset(**kwargs)
+            self.experiment = create_experiment(data=data, **kwargs)

         for key, value in kwargs.items():
+            if key == "models_idx":
+                value = normalize_models_idx(value)
             setattr(self, key, value)

-        self.context = {
-            # generic
-            "dataset": self.dataset,
-            # for FeatureEngineering
-            "columns_drop": self.columns_drop,
-            "columns_boolean": self.columns_boolean,
-            "columns_date": self.columns_date,
-            "columns_te_groupby": self.columns_te_groupby,
-            "columns_te_target": self.columns_te_target,
-            # for PreprocessFeature
-            "time_series": self.time_series,
-            "date_column": self.date_column,
-            "group_column": self.group_column,
-            "val_size": self.val_size,
-            "test_size": self.test_size,
-            "columns_pca": self.columns_pca,
-            "columns_onehot": self.columns_onehot,
-            "columns_binary": self.columns_binary,
-            "columns_frequency": self.columns_frequency,
-            "columns_ordinal": self.columns_ordinal,
-            "target_numbers": self.target_numbers,
-            "target_clf": self.target_clf,
-            # for PreprocessModel
-            "models_idx": self.models_idx,
-            "max_timesteps": self.max_timesteps,
-            # for ModelSelection
-            "perform_hyperopt": self.perform_hyperopt,
-            "number_of_trials": self.number_of_trials,
-            "perform_crossval": self.perform_crossval,
-            "plot": self.plot,
-            "preserve_model": self.preserve_model,
-            # not yet
-            "target_mclf": self.target_mclf,
-        }
-
     def train(self, data):
+        logger.info("Running training...")
+
         data_eng = self.feature_engineering(data)
+        logger.info("Feature engineering done.")
+
         train, val, test = self.preprocess_feature(data_eng)
-        all_features = self.feature_selection(train)
+        logger.info("Feature preprocessing done.")
+
+        self.feature_selection(train)
+        logger.info("Feature selection done.")
+
         std_data, reshaped_data = self.preprocess_model(train, val, test)
+        logger.info("Model preprocessing done.")
+
         self.model_selection(std_data, reshaped_data)
+        logger.info("Model selection done.")

     def predict(self, new_data, verbose: int = 0):
+        # for scores if TARGET is in columns
+        scores_reg = []
+        scores_clf = []
+
         if verbose == 0:
             logger.setLevel(logging.WARNING)

         logger.warning("Running prediction...")

+        # feature engineering + preprocessing
         data = self.feature_engineering(
             data=new_data,
             for_training=False,
@@ -123,16 +108,16 @@ class Experiment:
         for target_number in self.target_numbers:

             # loading model
-            training_target_dir = f"{self.dataset.path}/TARGET_{target_number}"
-            all_features = self.dataset.get_all_features(
+            training_target_dir = f"{self.experiment.path}/TARGET_{target_number}"
+            all_features = self.experiment.get_all_features(
                 date_column=self.date_column, group_column=self.group_column
             )
-            if self.dataset.name == "data_28_X_X":
+            if self.experiment.name == "data_28_X_X":
                 features = joblib.load(
-                    f"{self.dataset.path}/preprocessing/features_{target_number}.pkl"
+                    f"{self.experiment.path}/preprocessing/features_{target_number}.pkl"
                 )  # we keep this for backward compatibility
             else:
-                features = self.dataset.get_features(target_number)
+                features = self.experiment.get_features(target_number)
             model = ModelEngine(path=training_target_dir)

             # getting data
@@ -151,7 +136,7 @@ class Experiment:
             if model.recurrent:
                 y_pred.index = (
                     new_data.index
-                )  # TODO: not sure this will work for old dataset not aligned with data_for_training for test use case (done, this is why we decode the test set)
+                )  # TODO: not sure this will work for old experiment not aligned with data_for_training for test use case (done, this is why we decode the test set)

             # unscaling prediction
             if (
@@ -165,6 +150,26 @@ class Experiment:
                     ).flatten(),
                     index=new_data.index,
                 )
+                y_pred.name = "PRED"
+
+            # evaluate if TARGET is in columns
+            if f"TARGET_{target_number}" in new_data.columns:
+                y_true = new_data[f"TARGET_{target_number}"]
+                prediction = pd.concat([y_true, y_pred], axis=1)
+                prediction.rename(
+                    columns={f"TARGET_{target_number}": "TARGET"}, inplace=True
+                )
+                print(prediction)
+                score = evaluate(
+                    prediction,
+                    target_type=model.target_type,
+                )
+                score["TARGET"] = f"TARGET_{target_number}"
+
+                if model.target_type == "classification":
+                    scores_clf.append(score)
+                else:
+                    scores_reg.append(score)

             # renaming pred column and concatenating with initial data
             if isinstance(y_pred, pd.DataFrame):
@@ -179,7 +184,11 @@ class Experiment:
                 y_pred.name = f"TARGET_{target_number}_PRED"
             new_data = pd.concat([new_data, y_pred], axis=1)

-        return new_data
+        if len(scores_reg) > 0:
+            scores_reg = pd.DataFrame(scores_reg).set_index("TARGET")
+        if len(scores_clf) > 0:
+            scores_clf = pd.DataFrame(scores_clf).set_index("TARGET")
+        return new_data, scores_reg, scores_clf

     def feature_engineering(self, data, for_training=True):
         app = FeatureEngineeringEngine(
@@ -197,7 +206,7 @@ class Experiment:
     def preprocess_feature(self, data, for_training=True):
         app = PreprocessFeature(
             data=data,
-            dataset=self.dataset,
+            experiment=self.experiment,
             time_series=self.time_series,
             date_column=self.date_column,
             group_column=self.group_column,
@@ -223,12 +232,12 @@ class Experiment:
         app = FeatureSelectionEngine(
             train=train,
             target_number=target_number,
-            dataset=self.dataset,
+            experiment=self.experiment,
             target_clf=self.target_clf,
         )
         app.run()
-        self.dataset = Dataset.get(self.dataset.id)
-        all_features = self.dataset.get_all_features(
+        self.experiment = Experiment.get(self.experiment.id)
+        all_features = self.experiment.get_all_features(
             date_column=self.date_column, group_column=self.group_column
         )
         return all_features
@@ -238,7 +247,7 @@ class Experiment:
             train=train,
             val=val,
             test=test,
-            dataset=self.dataset,
+            experiment=self.experiment,
             target_numbers=self.target_numbers,
             target_clf=self.target_clf,
             models_idx=self.models_idx,
@@ -260,15 +269,16 @@ class Experiment:
             data=data,
             reshaped_data=reshaped_data,
             target_number=target_number,
-            dataset=self.dataset,
+            experiment=self.experiment,
             target_clf=self.target_clf,
             models_idx=self.models_idx,
             time_series=self.time_series,
             date_column=self.date_column,
             group_column=self.group_column,
+            target_clf_thresholds=self.target_clf_thresholds,
         )
         app.run(
-            self.session_name,
+            self.experiment_name,
             perform_hyperopt=self.perform_hyperopt,
             number_of_trials=self.number_of_trials,
             perform_crossval=self.perform_crossval,
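
The visible api.py changes: the experiment context is no longer rebuilt in memory on every run (it now persists in the database via the `context` JSON column added in the migrations below), `get_experiment` rehydrates settings with `kwargs.update(self.experiment.context)`, and `predict` now returns evaluation scores alongside the predictions. A minimal usage sketch of the new surface, assuming hypothetical data columns and settings (the required kwargs are not documented in this diff; the names used here are drawn from the removed `self.context` dict above):

import pandas as pd
from lecrapaud.api import LeCrapaud

app = LeCrapaud()  # init_db() picks up DB_URI from lecrapaud.config

# toy frame, for illustration only
data = pd.DataFrame(
    {
        "DATE": pd.date_range("2025-01-01", periods=200),
        "FEAT_1": range(200),
        "TARGET_1": range(200),
    }
)

exp = app.create_experiment(
    data=data,
    time_series=True,       # hypothetical settings; the accepted kwargs
    date_column="DATE",     # mirror the keys of the old self.context dict
    target_numbers=[1],
    target_clf=[],
    models_idx=[0],
)
exp.train(data)

# predict() now returns a 3-tuple; the score frames are only populated
# when new_data carries the matching TARGET_<n> columns
preds, scores_reg, scores_clf = exp.predict(data.tail(20), verbose=1)

# settings persist in the new JSON `context` column, so reloading by id
# no longer requires re-passing the kwargs
exp2 = app.get_experiment(id=exp.experiment.id)

Note that call sites written against 0.5.1 (`new_data = exp.predict(...)`) need to be updated to unpack the new return tuple.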
lecrapaud/config.py CHANGED
@@ -25,5 +25,9 @@ DB_PORT = (
 DB_NAME = (
     os.getenv("TEST_DB_NAME") if PYTHON_ENV == "Test" else os.getenv("DB_NAME", None)
 )
-DB_URI = os.getenv("TEST_DB_URI") if PYTHON_ENV == "Test" else os.getenv("DB_URI", None)
+DB_URI = (
+    os.getenv("TEST_DB_URI", None)
+    if PYTHON_ENV == "Test"
+    else os.getenv("DB_URI", None)
+)
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -1,8 +1,8 @@
-"""initial_setup
+"""

-Revision ID: 1edada319fd7
+Revision ID: f089dfb7e3ba
 Revises:
-Create Date: 2025-06-20 19:24:25.033055
+Create Date: 2025-06-23 17:48:32.842030

 """
 from typing import Sequence, Union
@@ -12,7 +12,7 @@ import sqlalchemy as sa


 # revision identifiers, used by Alembic.
-revision: str = '1edada319fd7'
+revision: str = 'f089dfb7e3ba'
 down_revision: Union[str, None] = None
 branch_labels: Union[str, Sequence[str], None] = None
 depends_on: Union[str, Sequence[str], None] = None
@@ -20,7 +20,7 @@ depends_on: Union[str, Sequence[str], None] = None

 def upgrade() -> None:
     # ### commands auto generated by Alembic - please adjust! ###
-    op.create_table('lecrapaud_datasets',
+    op.create_table('lecrapaud_experiments',
     sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False),
     sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
     sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
@@ -45,9 +45,9 @@ def upgrade() -> None:
     sa.Column('test_start_date', sa.DateTime(), nullable=True),
     sa.Column('test_end_date', sa.DateTime(), nullable=True),
     sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('name', name='uq_datasets_composite')
+    sa.UniqueConstraint('name', name='uq_experiments_composite')
     )
-    op.create_index(op.f('ix_lecrapaud_datasets_id'), 'lecrapaud_datasets', ['id'], unique=False)
+    op.create_index(op.f('ix_lecrapaud_experiments_id'), 'lecrapaud_experiments', ['id'], unique=False)
     op.create_table('lecrapaud_features',
     sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False),
     sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
@@ -79,12 +79,12 @@ def upgrade() -> None:
     sa.UniqueConstraint('name', 'type', name='uq_target_composite')
     )
     op.create_index(op.f('ix_lecrapaud_targets_id'), 'lecrapaud_targets', ['id'], unique=False)
-    op.create_table('lecrapaud_dataset_target_association',
-    sa.Column('dataset_id', sa.BigInteger(), nullable=False),
+    op.create_table('lecrapaud_experiment_target_association',
+    sa.Column('experiment_id', sa.BigInteger(), nullable=False),
     sa.Column('target_id', sa.BigInteger(), nullable=False),
-    sa.ForeignKeyConstraint(['dataset_id'], ['lecrapaud_datasets.id'], ondelete='CASCADE'),
+    sa.ForeignKeyConstraint(['experiment_id'], ['lecrapaud_experiments.id'], ondelete='CASCADE'),
     sa.ForeignKeyConstraint(['target_id'], ['lecrapaud_targets.id'], ondelete='CASCADE'),
-    sa.PrimaryKeyConstraint('dataset_id', 'target_id')
+    sa.PrimaryKeyConstraint('experiment_id', 'target_id')
     )
     op.create_table('lecrapaud_feature_selections',
     sa.Column('id', sa.BigInteger(), autoincrement=True, nullable=False),
@@ -92,12 +92,12 @@ def upgrade() -> None:
     sa.Column('updated_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False),
     sa.Column('training_time', sa.Integer(), nullable=True),
     sa.Column('best_features_path', sa.String(length=255), nullable=True),
-    sa.Column('dataset_id', sa.BigInteger(), nullable=False),
+    sa.Column('experiment_id', sa.BigInteger(), nullable=False),
     sa.Column('target_id', sa.BigInteger(), nullable=False),
-    sa.ForeignKeyConstraint(['dataset_id'], ['lecrapaud_datasets.id'], ondelete='CASCADE'),
+    sa.ForeignKeyConstraint(['experiment_id'], ['lecrapaud_experiments.id'], ondelete='CASCADE'),
     sa.ForeignKeyConstraint(['target_id'], ['lecrapaud_targets.id'], ondelete='CASCADE'),
     sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('dataset_id', 'target_id', name='uq_feature_selection_composite')
+    sa.UniqueConstraint('experiment_id', 'target_id', name='uq_feature_selection_composite')
     )
     op.create_index(op.f('ix_lecrapaud_feature_selections_id'), 'lecrapaud_feature_selections', ['id'], unique=False)
     op.create_table('lecrapaud_model_selections',
@@ -108,12 +108,12 @@ def upgrade() -> None:
     sa.Column('best_model_path', sa.String(length=255), nullable=True),
     sa.Column('best_model_id', sa.BigInteger(), nullable=True),
     sa.Column('target_id', sa.BigInteger(), nullable=False),
-    sa.Column('dataset_id', sa.BigInteger(), nullable=False),
+    sa.Column('experiment_id', sa.BigInteger(), nullable=False),
     sa.ForeignKeyConstraint(['best_model_id'], ['lecrapaud_models.id'], ondelete='CASCADE'),
-    sa.ForeignKeyConstraint(['dataset_id'], ['lecrapaud_datasets.id'], ondelete='CASCADE'),
+    sa.ForeignKeyConstraint(['experiment_id'], ['lecrapaud_experiments.id'], ondelete='CASCADE'),
     sa.ForeignKeyConstraint(['target_id'], ['lecrapaud_targets.id'], ondelete='CASCADE'),
     sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('target_id', 'dataset_id', name='uq_model_selection_composite')
+    sa.UniqueConstraint('target_id', 'experiment_id', name='uq_model_selection_composite')
     )
     op.create_index(op.f('ix_lecrapaud_model_selections_id'), 'lecrapaud_model_selections', ['id'], unique=False)
     op.create_table('lecrapaud_feature_selection_association',
@@ -202,13 +202,13 @@ def downgrade() -> None:
     op.drop_table('lecrapaud_model_selections')
     op.drop_index(op.f('ix_lecrapaud_feature_selections_id'), table_name='lecrapaud_feature_selections')
     op.drop_table('lecrapaud_feature_selections')
-    op.drop_table('lecrapaud_dataset_target_association')
+    op.drop_table('lecrapaud_experiment_target_association')
     op.drop_index(op.f('ix_lecrapaud_targets_id'), table_name='lecrapaud_targets')
     op.drop_table('lecrapaud_targets')
     op.drop_index(op.f('ix_lecrapaud_models_id'), table_name='lecrapaud_models')
     op.drop_table('lecrapaud_models')
     op.drop_index(op.f('ix_lecrapaud_features_id'), table_name='lecrapaud_features')
     op.drop_table('lecrapaud_features')
-    op.drop_index(op.f('ix_lecrapaud_datasets_id'), table_name='lecrapaud_datasets')
-    op.drop_table('lecrapaud_datasets')
+    op.drop_index(op.f('ix_lecrapaud_experiments_id'), table_name='lecrapaud_experiments')
+    op.drop_table('lecrapaud_experiments')
     # ### end Alembic commands ###
@@ -0,0 +1,30 @@
+"""
+
+Revision ID: c62251b129ed
+Revises: f089dfb7e3ba
+Create Date: 2025-06-24 12:16:21.949079
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'c62251b129ed'
+down_revision: Union[str, None] = 'f089dfb7e3ba'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('lecrapaud_experiments', sa.Column('context', sa.JSON(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('lecrapaud_experiments', 'context')
+    # ### end Alembic commands ###
@@ -0,0 +1,34 @@
+"""
+
+Revision ID: 86457e2f333f
+Revises: c62251b129ed
+Create Date: 2025-06-24 17:11:25.187876
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+# revision identifiers, used by Alembic.
+revision: str = '86457e2f333f'
+down_revision: Union[str, None] = 'c62251b129ed'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('lecrapaud_scores', sa.Column('thresholds', sa.JSON(), nullable=True))
+    op.add_column('lecrapaud_scores', sa.Column('f1_at_threshold', sa.Float(), nullable=True))
+    op.drop_column('lecrapaud_scores', 'threshold')
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('lecrapaud_scores', sa.Column('threshold', mysql.FLOAT(), nullable=True))
+    op.drop_column('lecrapaud_scores', 'f1_at_threshold')
+    op.drop_column('lecrapaud_scores', 'thresholds')
+    # ### end Alembic commands ###
@@ -1,6 +1,5 @@
 from lecrapaud.db.models.base import Base
-
-from lecrapaud.db.models.dataset import Dataset
+from lecrapaud.db.models.experiment import Experiment
 from lecrapaud.db.models.feature_selection_rank import FeatureSelectionRank
 from lecrapaud.db.models.feature_selection import FeatureSelection
 from lecrapaud.db.models.feature import Feature
@@ -9,3 +8,16 @@ from lecrapaud.db.models.model_training import ModelTraining
 from lecrapaud.db.models.model import Model
 from lecrapaud.db.models.score import Score
 from lecrapaud.db.models.target import Target
+
+__all__ = [
+    'Base',
+    'Experiment',
+    'FeatureSelectionRank',
+    'FeatureSelection',
+    'Feature',
+    'ModelSelection',
+    'ModelTraining',
+    'Model',
+    'Score',
+    'Target',
+]
@@ -9,6 +9,7 @@ from sqlalchemy.inspection import inspect
 from sqlalchemy.orm.attributes import InstrumentedAttribute
 from lecrapaud.db.session import get_db
 from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.dialects.mysql import insert as mysql_insert


 def with_db(func):
@@ -98,9 +99,53 @@ class Base(DeclarativeBase):
             }
             for row in results
         ]
-
         return results

+    @classmethod
+    @with_db
+    def upsert_bulk(cls, db=None, match_fields: list[str] = None, **kwargs):
+        """
+        Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.
+
+        Args:
+            db (Session): SQLAlchemy DB session
+            match_fields (list[str]): Fields to match on for deduplication
+            **kwargs: Column-wise keyword arguments (field_name=[...])
+        """
+        # Ensure all provided fields have values of equal length
+        value_lengths = [len(v) for v in kwargs.values()]
+        if not value_lengths or len(set(value_lengths)) != 1:
+            raise ValueError(
+                "All field values must be non-empty lists of the same length."
+            )
+
+        # Convert column-wise kwargs to row-wise list of dicts
+        items = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]
+        if not items:
+            return
+
+        stmt = mysql_insert(cls.__table__).values(items)
+
+        # Default to primary keys if match_fields not provided
+        if not match_fields:
+            match_fields = [col.name for col in cls.__table__.primary_key.columns]
+
+        # Ensure all columns to be updated are in the insert
+        update_dict = {
+            c.name: stmt.inserted[c.name]
+            for c in cls.__table__.columns
+            if c.name not in match_fields and c.name in items[0]
+        }
+
+        if not update_dict:
+            # Avoid triggering ON DUPLICATE KEY UPDATE with empty dict
+            db.execute(stmt.prefix_with("IGNORE"))
+        else:
+            upsert_stmt = stmt.on_duplicate_key_update(**update_dict)
+            db.execute(upsert_stmt)
+
+        db.commit()
+
     @classmethod
     @with_db
     def filter(cls, db=None, **kwargs):
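
The new `upsert_bulk` takes column-wise keyword arguments (one list per field) rather than a list of row dicts, and relies on MySQL's `INSERT ... ON DUPLICATE KEY UPDATE`. A hypothetical call, with illustrative field names rather than the real schema:

from lecrapaud.db import Feature

# three rows expressed column-wise; all lists must have the same length,
# otherwise upsert_bulk raises ValueError
Feature.upsert_bulk(
    match_fields=["name"],                    # dedup key; defaults to the primary key
    name=["rsi_14", "sma_20", "ema_50"],      # hypothetical values
    type=["numeric", "numeric", "numeric"],   # hypothetical column
)

The `INSERT IGNORE` fallback matters: when every inserted column is also a match field there is nothing left to update, and an empty `ON DUPLICATE KEY UPDATE` clause is invalid SQL, so the statement is downgraded to an ignore-on-conflict insert instead.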
@@ -165,7 +210,8 @@ class Base(DeclarativeBase):

         if instance:
             for key, value in kwargs.items():
-                setattr(instance, key, value)
+                if key != "id":
+                    setattr(instance, key, value)
         else:
             instance = cls(**kwargs)
             db.add(instance)
@@ -1,33 +1,31 @@
+from itertools import chain
+
 from sqlalchemy import (
     Column,
     Integer,
     String,
     DateTime,
-    Date,
     Float,
     JSON,
     Table,
     ForeignKey,
     BigInteger,
-    Index,
     TIMESTAMP,
     UniqueConstraint,
+    func,
 )
-from sqlalchemy import desc, asc, cast, text, func
-from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase
-from itertools import chain
+from sqlalchemy.orm import relationship

-from lecrapaud.db.session import get_db
 from lecrapaud.db.models.base import Base

 # jointures
-lecrapaud_dataset_target_association = Table(
-    "lecrapaud_dataset_target_association",
+lecrapaud_experiment_target_association = Table(
+    "lecrapaud_experiment_target_association",
     Base.metadata,
     Column(
-        "dataset_id",
+        "experiment_id",
         BigInteger,
-        ForeignKey("lecrapaud_datasets.id", ondelete="CASCADE"),
+        ForeignKey("lecrapaud_experiments.id", ondelete="CASCADE"),
         primary_key=True,
     ),
     Column(
@@ -39,7 +37,7 @@ lecrapaud_dataset_target_association = Table(
 )


-class Dataset(Base):
+class Experiment(Base):

     id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
     created_at = Column(
@@ -71,30 +69,31 @@ class Dataset(Base):
     val_end_date = Column(DateTime)
     test_start_date = Column(DateTime)
     test_end_date = Column(DateTime)
+    context = Column(JSON)

     feature_selections = relationship(
         "FeatureSelection",
-        back_populates="dataset",
+        back_populates="experiment",
         cascade="all, delete-orphan",
         lazy="selectin",
     )
     model_selections = relationship(
         "ModelSelection",
-        back_populates="dataset",
+        back_populates="experiment",
         cascade="all, delete-orphan",
         lazy="selectin",
     )
     targets = relationship(
         "Target",
-        secondary=lecrapaud_dataset_target_association,
-        back_populates="datasets",
+        secondary=lecrapaud_experiment_target_association,
+        back_populates="experiments",
         lazy="selectin",
     )

     __table_args__ = (
         UniqueConstraint(
             "name",
-            name="uq_datasets_composite",
+            name="uq_experiments_composite",
         ),
     )

@@ -106,23 +105,22 @@ class Dataset(Base):
         feature_selection = [
             fs for fs in feature_selections if fs.target_id == target_id
         ][0]
-        feature = [f.name for f in feature_selection.features]
-        return feature
+        features = [f.name for f in feature_selection.features]
+        return features

     def get_all_features(self, date_column: str = None, group_column: str = None):
         target_idx = [target.id for target in self.targets]
+        _all_features = chain.from_iterable(
+            [f.name for f in fs.features]
+            for fs in self.feature_selections
+            if fs.target_id in target_idx
+        )
         all_features = []
         if date_column:
             all_features.append(date_column)
         if group_column:
             all_features.append(group_column)
-        all_features += list(
-            chain.from_iterable(
-                [f.name for f in fs.features]
-                for fs in self.feature_selections
-                if fs.target_id in target_idx
-            )
-        )
+        all_features += list(_all_features)
         all_features = list(dict.fromkeys(all_features))

         return all_features
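
The `get_all_features` refactor keeps the behavior identical: optional date/group columns first, then each target's selected features flattened with `chain.from_iterable`, then order-preserving deduplication via `dict.fromkeys`. A standalone sketch of that dedup idiom, with hypothetical feature names:

from itertools import chain

# hypothetical per-target feature lists
per_target = [["close", "volume"], ["volume", "rsi_14"]]
ordered = ["date", "ticker"] + list(chain.from_iterable(per_target))

# dict.fromkeys keeps the first occurrence of each name, unlike set(),
# which would lose the ordering
deduped = list(dict.fromkeys(ordered))
assert deduped == ["date", "ticker", "close", "volume", "rsi_14"]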
@@ -60,9 +60,9 @@ class FeatureSelection(Base):
     )
     training_time = Column(Integer)
     best_features_path = Column(String(255))
-    dataset_id = Column(
+    experiment_id = Column(
         BigInteger,
-        ForeignKey("lecrapaud_datasets.id", ondelete="CASCADE"),
+        ForeignKey("lecrapaud_experiments.id", ondelete="CASCADE"),
         nullable=False,
     )
     target_id = Column(
@@ -71,8 +71,8 @@ class FeatureSelection(Base):
         nullable=False,
     )

-    dataset = relationship(
-        "Dataset", back_populates="feature_selections", lazy="selectin"
+    experiment = relationship(
+        "Experiment", back_populates="feature_selections", lazy="selectin"
     )
     target = relationship(
         "Target", back_populates="feature_selections", lazy="selectin"
@@ -92,7 +92,7 @@ class FeatureSelection(Base):

     __table_args__ = (
         UniqueConstraint(
-            "dataset_id", "target_id", name="uq_feature_selection_composite"
+            "experiment_id", "target_id", name="uq_feature_selection_composite"
         ),
     )