lecrapaud 0.19.3-py3-none-any.whl → 0.20.1-py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.
lecrapaud/db/models/base.py CHANGED
@@ -10,24 +10,27 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute
 from lecrapaud.db.session import get_db
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.dialects.mysql import insert as mysql_insert
+from sqlalchemy import UniqueConstraint
+from sqlalchemy.inspection import inspect as sqlalchemy_inspect
 from lecrapaud.config import LECRAPAUD_TABLE_PREFIX


 def with_db(func):
     """Decorator to provide a database session to the wrapped function.
-
+
     If a db parameter is already provided, it will be used. Otherwise,
     a new session will be created and automatically managed.
     """
+
    @wraps(func)
    def wrapper(*args, **kwargs):
        if "db" in kwargs and kwargs["db"] is not None:
            return func(*args, **kwargs)
-
+
        with get_db() as db:
            kwargs["db"] = db
            return func(*args, **kwargs)
-
+
    return wrapper
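
A usage sketch of the decorator (not from the diff; `count_rows`, `SomeModel`, and `session` are hypothetical names): a session is injected only when the caller omits the `db` keyword.

    from lecrapaud.db.models.base import with_db

    @with_db
    def count_rows(model, db=None):
        # db is injected by with_db when the caller does not pass one
        return db.query(model).count()

    n = count_rows(SomeModel)              # session opened and closed automatically
    n = count_rows(SomeModel, db=session)  # an existing session is reused as-is

Note that the wrapper only inspects `kwargs["db"]`, so a session passed positionally would not be detected; callers in this codebase pass `db=` by keyword.
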
@@ -106,51 +109,6 @@ class Base(DeclarativeBase):
         ]
         return results

-    @classmethod
-    @with_db
-    def upsert_bulk(cls, db=None, match_fields: list[str] = None, **kwargs):
-        """
-        Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.
-
-        Args:
-            db (Session): SQLAlchemy DB session
-            match_fields (list[str]): Fields to match on for deduplication
-            **kwargs: Column-wise keyword arguments (field_name=[...])
-        """
-        # Ensure all provided fields have values of equal length
-        value_lengths = [len(v) for v in kwargs.values()]
-        if not value_lengths or len(set(value_lengths)) != 1:
-            raise ValueError(
-                "All field values must be non-empty lists of the same length."
-            )
-
-        # Convert column-wise kwargs to row-wise list of dicts
-        items = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]
-        if not items:
-            return
-
-        stmt = mysql_insert(cls.__table__).values(items)
-
-        # Default to primary keys if match_fields not provided
-        if not match_fields:
-            match_fields = [col.name for col in cls.__table__.primary_key.columns]
-
-        # Ensure all columns to be updated are in the insert
-        update_dict = {
-            c.name: stmt.inserted[c.name]
-            for c in cls.__table__.columns
-            if c.name not in match_fields and c.name in items[0]
-        }
-
-        if not update_dict:
-            # Avoid triggering ON DUPLICATE KEY UPDATE with empty dict
-            db.execute(stmt.prefix_with("IGNORE"))
-        else:
-            upsert_stmt = stmt.on_duplicate_key_update(**update_dict)
-            db.execute(upsert_stmt)
-
-        db.commit()
-
     @classmethod
     @with_db
     def filter(cls, db=None, **kwargs):
@@ -198,33 +156,113 @@ class Base(DeclarativeBase):

     @classmethod
     @with_db
-    def upsert(cls, match_fields: list[str], db=None, **kwargs):
+    def upsert(cls, db=None, **kwargs):
         """
-        Upsert an instance of the model: update if found, else create.
+        Upsert an instance of the model using MySQL's ON DUPLICATE KEY UPDATE.

-        :param match_fields: list of field names to use for matching
         :param kwargs: all fields for creation or update
         """
-        filters = [
-            getattr(cls, field) == kwargs[field]
-            for field in match_fields
-            if field in kwargs
-        ]
+        # Use INSERT ... ON DUPLICATE KEY UPDATE
+        stmt = mysql_insert(cls.__table__).values(**kwargs)
+        stmt = stmt.on_duplicate_key_update(
+            **{k: v for k, v in kwargs.items() if k != "id"}
+        )

-        instance = db.query(cls).filter(*filters).first()
+        result = db.execute(stmt)
+        db.commit()

-        if instance:
-            for key, value in kwargs.items():
-                if key != "id":
-                    setattr(instance, key, value)
+        # Get the instance - either the newly inserted or updated one
+        # If updated, lastrowid is 0, so we need to query
+        if result.lastrowid and result.lastrowid > 0:
+            # New insert
+            instance = db.get(cls, result.lastrowid)
         else:
-            instance = cls(**kwargs)
-            db.add(instance)
+            # Updated - need to find it using unique constraint fields
+            mapper = sqlalchemy_inspect(cls)
+            instance = None
+
+            for constraint in mapper.mapped_table.constraints:
+                if isinstance(constraint, UniqueConstraint):
+                    col_names = [col.name for col in constraint.columns]
+                    if all(name in kwargs for name in col_names):
+                        filters = [
+                            getattr(cls, col_name) == kwargs[col_name]
+                            for col_name in col_names
+                        ]
+                        instance = db.query(cls).filter(*filters).first()
+                        if instance:
+                            break
+
+            # Check for single column unique constraints
+            if not instance:
+                for col in mapper.mapped_table.columns:
+                    if col.unique and col.name in kwargs:
+                        instance = (
+                            db.query(cls)
+                            .filter(getattr(cls, col.name) == kwargs[col.name])
+                            .first()
+                        )
+                        if instance:
+                            break
+
+            # If still not found, try to find by all kwargs (excluding None values)
+            if not instance:
+                instance = (
+                    db.query(cls)
+                    .filter_by(
+                        **{
+                            k: v
+                            for k, v in kwargs.items()
+                            if v is not None and k != "id"
+                        }
+                    )
+                    .first()
+                )
+
+            if instance:
+                db.refresh(instance)

-        db.commit()
-        db.refresh(instance)
         return instance

+    @classmethod
+    @with_db
+    def bulk_upsert(cls, rows: list[dict] = None, db=None, **kwargs):
+        """
+        Performs a bulk upsert into the database using ON DUPLICATE KEY UPDATE.
+
+        Args:
+            rows (list[dict]): List of dictionaries representing rows to upsert
+            db (Session): SQLAlchemy DB session
+            **kwargs: Column-wise keyword arguments (field_name=[...]) for backwards compatibility
+        """
+        # Handle both new format (rows) and legacy format (kwargs)
+        if rows is None and kwargs:
+            # Legacy format: convert column-wise kwargs to row-wise list of dicts
+            value_lengths = [len(v) for v in kwargs.values()]
+            if not value_lengths or len(set(value_lengths)) != 1:
+                raise ValueError(
+                    "All field values must be non-empty lists of the same length."
+                )
+            rows = [dict(zip(kwargs.keys(), row)) for row in zip(*kwargs.values())]
+
+        if not rows:
+            return 0
+
+        BATCH_SIZE = 200
+        total_affected = 0
+
+        for i in range(0, len(rows), BATCH_SIZE):
+            batch = rows[i : i + BATCH_SIZE]
+            stmt = mysql_insert(cls.__table__).values(batch)
+            stmt = stmt.on_duplicate_key_update(
+                **{key: stmt.inserted[key] for key in batch[0] if key != "id"}
+            )
+            result = db.execute(stmt)
+            total_affected += result.rowcount
+
+        db.commit()
+        return total_affected
+
     @classmethod
     @with_db
     def delete(cls, id: int, db=None):
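
To illustrate the new API (a hedged sketch; the `User` model with a unique `email` column is hypothetical, not part of this package): `upsert` now matches on the table's primary key and unique constraints instead of an explicit `match_fields` argument, and `bulk_upsert` accepts row dicts while still supporting the old column-wise keyword format.

    # Single row: INSERT ... ON DUPLICATE KEY UPDATE, matched on any
    # primary-key or unique-constraint collision.
    user = User.upsert(email="a@example.com", name="Alice")

    # New row-wise format, executed in batches of 200 rows per statement.
    affected = User.bulk_upsert(rows=[
        {"email": "a@example.com", "name": "Alice"},
        {"email": "b@example.com", "name": "Bob"},
    ])

    # Legacy column-wise format is still accepted.
    affected = User.bulk_upsert(
        email=["a@example.com", "b@example.com"],
        name=["Alice", "Bob"],
    )

Note that MySQL typically counts a row updated by ON DUPLICATE KEY UPDATE as 2 affected rows (and an unchanged row as 0), so the returned total is not a plain count of rows touched.
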

lecrapaud/db/models/experiment.py CHANGED
@@ -20,8 +20,7 @@ from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy import func
 from statistics import fmean as mean
 from lecrapaud.db.models.model_selection import ModelSelection
-from lecrapaud.db.models.model_training import ModelTraining
-from lecrapaud.db.models.score import Score
+from lecrapaud.db.models.model_selection_score import ModelSelectionScore

 from lecrapaud.db.models.base import Base, with_db
 from lecrapaud.db.models.utils import create_association_table
@@ -51,10 +50,43 @@ class Experiment(Base):
     )
     name = Column(String(255), nullable=False)
     path = Column(String(255))  # we do not have this at creation time
-    type = Column(String(50), nullable=False)
     size = Column(Integer, nullable=False)
     train_size = Column(Integer)
     val_size = Column(Integer)
+    test_size = Column(Integer)
+    number_of_groups = Column(Integer)
+    list_of_groups = Column(JSON)
+    number_of_targets = Column(Integer)
+    start_date = Column(DateTime)
+    end_date = Column(DateTime)
+    train_start_date = Column(DateTime)
+    train_end_date = Column(DateTime)
+    val_start_date = Column(DateTime)
+    val_end_date = Column(DateTime)
+    test_start_date = Column(DateTime)
+    test_end_date = Column(DateTime)
+    context = Column(JSON)
+
+    feature_selections = relationship(
+        "FeatureSelection",
+        back_populates="experiment",
+        cascade="all, delete-orphan",
+        lazy="selectin",
+    )
+
+    targets = relationship(
+        "Target",
+        secondary=lecrapaud_experiment_target_association,
+        back_populates="experiments",
+        lazy="selectin",
+    )
+
+    __table_args__ = (
+        UniqueConstraint(
+            "name",
+            name="uq_experiments_composite",
+        ),
+    )

     # Relationships
     model_selections = relationship(
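
Because `uq_experiments_composite` now covers only `name`, a second `upsert` with an existing name lands on the same row via ON DUPLICATE KEY UPDATE. A minimal sketch (field values are illustrative):

    exp1 = Experiment.upsert(db=db, name="exp_a", size=100)
    exp2 = Experiment.upsert(db=db, name="exp_a", size=200)
    assert exp1.id == exp2.id  # the second call updated the first row
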
@@ -69,18 +101,9 @@ class Experiment(Base):
         """Best RMSE score across all model selections and trainings."""
         # Get the minimum RMSE for each model selection
         min_scores = [
-            min(
-                score.rmse
-                for mt in ms.model_trainings
-                for score in mt.score
-                if score.rmse is not None
-            )
+            min(mss.rmse for mss in ms.model_selection_scores if mss.rmse is not None)
             for ms in self.model_selections
-            if any(
-                score.rmse is not None
-                for mt in ms.model_trainings
-                for score in mt.score
-            )
+            if any(mss.rmse is not None for mss in ms.model_selection_scores)
         ]
         return min(min_scores) if min_scores else None

@@ -90,17 +113,12 @@ class Experiment(Base):
         # Get the minimum LogLoss for each model selection
         min_scores = [
             min(
-                score.logloss
-                for mt in ms.model_trainings
-                for score in mt.score
-                if score.logloss is not None
+                mss.logloss
+                for mss in ms.model_selection_scores
+                if mss.logloss is not None
             )
             for ms in self.model_selections
-            if any(
-                score.logloss is not None
-                for mt in ms.model_trainings
-                for score in mt.score
-            )
+            if any(mss.logloss is not None for mss in ms.model_selection_scores)
         ]
         return min(min_scores) if min_scores else None

@@ -109,18 +127,9 @@ class Experiment(Base):
         """Average RMSE score across all model selections and trainings."""
         # Get the minimum RMSE for each model selection
         min_scores = [
-            min(
-                score.rmse
-                for mt in ms.model_trainings
-                for score in mt.score
-                if score.rmse is not None
-            )
+            min(mss.rmse for mss in ms.model_selection_scores if mss.rmse is not None)
             for ms in self.model_selections
-            if any(
-                score.rmse is not None
-                for mt in ms.model_trainings
-                for score in mt.score
-            )
+            if any(mss.rmse is not None for mss in ms.model_selection_scores)
         ]
         return mean(min_scores) if min_scores else None

@@ -130,57 +139,15 @@ class Experiment(Base):
         # Get the minimum LogLoss for each model selection
         min_scores = [
             min(
-                score.logloss
-                for mt in ms.model_trainings
-                for score in mt.score
-                if score.logloss is not None
+                mss.logloss
+                for mss in ms.model_selection_scores
+                if mss.logloss is not None
             )
             for ms in self.model_selections
-            if any(
-                score.logloss is not None
-                for mt in ms.model_trainings
-                for score in mt.score
-            )
+            if any(mss.logloss is not None for mss in ms.model_selection_scores)
         ]
         return mean(min_scores) if min_scores else None

-    test_size = Column(Integer)
-    corr_threshold = Column(Float, nullable=False)
-    max_features = Column(Integer, nullable=False)
-    percentile = Column(Float, nullable=False)
-    number_of_groups = Column(Integer)
-    list_of_groups = Column(JSON)
-    start_date = Column(DateTime)
-    end_date = Column(DateTime)
-    train_start_date = Column(DateTime)
-    train_end_date = Column(DateTime)
-    val_start_date = Column(DateTime)
-    val_end_date = Column(DateTime)
-    test_start_date = Column(DateTime)
-    test_end_date = Column(DateTime)
-    context = Column(JSON)
-
-    feature_selections = relationship(
-        "FeatureSelection",
-        back_populates="experiment",
-        cascade="all, delete-orphan",
-        lazy="selectin",
-    )
-
-    targets = relationship(
-        "Target",
-        secondary=lecrapaud_experiment_target_association,
-        back_populates="experiments",
-        lazy="selectin",
-    )
-
-    __table_args__ = (
-        UniqueConstraint(
-            "name",
-            name="uq_experiments_composite",
-        ),
-    )
-
     @classmethod
     @with_db
     def get_all_by_name(cls, name: str | None = None, limit: int = 1000, db=None):
@@ -353,7 +320,7 @@ class Experiment(Base):
             (ms for ms in self.model_selections if ms.target_id == target.id), None
         )

-        if not best_model_selection or not best_model_selection.model_trainings:
+        if not best_model_selection or not best_model_selection.model_selection_scores:
             return {
                 "experiment_name": self.name,
                 "target_number": target_number,
@@ -361,22 +328,31 @@ class Experiment(Base):
                 "scores": {},
             }

-        # Get the best model training (assuming the first one is the best)
-        best_training = best_model_selection.model_trainings[0]
-
-        # Get the validation score for this training
-        validation_scores = [s for s in best_training.score if s.type == "validation"]
+        # Get the best model score based on lowest logloss or rmse
+        model_scores = best_model_selection.model_selection_scores

-        if not validation_scores:
+        # Determine if we should use logloss or rmse based on what's available
+        if any(ms.logloss is not None for ms in model_scores):
+            # Classification: find lowest logloss
+            best_score = min(
+                (ms for ms in model_scores if ms.logloss is not None),
+                key=lambda x: x.logloss,
+            )
+        elif any(ms.rmse is not None for ms in model_scores):
+            # Regression: find lowest rmse
+            best_score = min(
+                (ms for ms in model_scores if ms.rmse is not None), key=lambda x: x.rmse
+            )
+        else:
             return {
                 "experiment_name": self.name,
                 "target_number": target_number,
-                "error": "No validation scores found for the best model",
+                "error": "No scores found for the best model",
                 "scores": {},
             }

-        # Get all available metrics from the first validation score
-        score = validation_scores[0]
+        # Use the best score found
+        score = best_score
         available_metrics = [
             "rmse",
             "mae",
@@ -397,13 +373,9 @@ class Experiment(Base):

         # Get the model info
         model_info = {
-            "model_type": (
-                best_training.model.model_type if best_training.model else "unknown"
-            ),
-            "model_name": (
-                best_training.model.name if best_training.model else "unknown"
-            ),
-            "training_time_seconds": best_training.training_time,
+            "model_type": (score.model.model_type if score.model else "unknown"),
+            "model_name": (score.model.name if score.model else "unknown"),
+            "training_time_seconds": score.training_time,
         }

         return {
@@ -413,7 +385,10 @@ class Experiment(Base):
             "scores": scores,
         }

-    def get_features(self, target_number: int):
+    @with_db
+    def get_features(self, target_number: int, db=None):
+        # Ensure we have a fresh instance attached to the session
+        self = db.merge(self)
         targets = [t for t in self.targets if t.name == f"TARGET_{target_number}"]
         if targets:
             target_id = targets[0].id
@@ -429,7 +404,12 @@ class Experiment(Base):
         features = joblib.load(f"{self.path}/TARGET_{target_number}/features.pkl")
         return features

-    def get_all_features(self, date_column: str = None, group_column: str = None):
+    @with_db
+    def get_all_features(
+        self, date_column: str = None, group_column: str = None, db=None
+    ):
+        # Ensure we have a fresh instance attached to the session
+        self = db.merge(self)
         target_idx = [target.id for target in self.targets]
         _all_features = chain.from_iterable(
             [f.name for f in fs.features]

lecrapaud/db/models/feature_selection.py CHANGED
@@ -115,7 +115,4 @@ class FeatureSelection(Base):
             if feature not in self.features:
                 self.features.append(feature)

-        db.flush()
-        db.refresh(self)
-        print(self.features)
         return self

lecrapaud/db/models/feature_selection_rank.py CHANGED
@@ -65,21 +65,3 @@ class FeatureSelectionRank(Base):
             name="uq_feature_selection_rank_composite",
         ),
     )
-
-    @classmethod
-    @with_db
-    def bulk_upsert(cls, rows, db=None):
-        stmt = insert(cls).values(rows)
-
-        update_fields = {
-            key: stmt.inserted[key]
-            for key in rows[0]
-            if key not in ("feature_selection_id", "feature_id", "method")
-        }
-
-        stmt = stmt.on_duplicate_key_update(**update_fields)
-
-        db.execute(stmt)
-        db.commit()
-
-        return len(rows)
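
The model-specific `bulk_upsert` override is removed because the generic `Base.bulk_upsert` introduced in this release covers the same case, with MySQL deduplicating on `uq_feature_selection_rank_composite`. A hedged sketch (the `rank` column is hypothetical; the updatable columns are not visible in this diff, and the constraint is assumed to span the three fields the old update-exclusion list suggests):

    FeatureSelectionRank.bulk_upsert(rows=[
        {"feature_selection_id": 1, "feature_id": 7, "method": "rfe", "rank": 3},
    ])
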

lecrapaud/db/models/model_selection.py CHANGED
@@ -54,8 +54,8 @@ class ModelSelection(Base):
     )

     best_model = relationship("Model", lazy="selectin")
-    model_trainings = relationship(
-        "ModelTraining",
+    model_selection_scores = relationship(
+        "ModelSelectionScore",
         back_populates="model_selection",
         cascade="all, delete-orphan",
         lazy="selectin",

lecrapaud/db/models/score.py → lecrapaud/db/models/model_selection_score.py RENAMED
@@ -3,10 +3,11 @@ from sqlalchemy import (
     Integer,
     String,
     Float,
+    JSON,
     ForeignKey,
     BigInteger,
     TIMESTAMP,
-    JSON,
+    UniqueConstraint,
 )
 from sqlalchemy import func
 from sqlalchemy.orm import relationship
@@ -14,7 +15,9 @@ from lecrapaud.db.models.base import Base
 from lecrapaud.config import LECRAPAUD_TABLE_PREFIX


-class Score(Base):
+class ModelSelectionScore(Base):
+    __tablename__ = f"{LECRAPAUD_TABLE_PREFIX}_model_selection_scores"
+
     id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
     created_at = Column(
         TIMESTAMP(timezone=True), server_default=func.now(), nullable=False
@@ -25,10 +28,21 @@ class Score(Base):
         onupdate=func.now(),
         nullable=False,
     )
-    type = Column(
-        String(50), nullable=False
-    )  # either hyperopts or validation or crossval
+
+    # From ModelTraining
+    best_params = Column(JSON)
+    model_path = Column(String(255))
     training_time = Column(Integer)
+    model_id = Column(
+        BigInteger, ForeignKey(f"{LECRAPAUD_TABLE_PREFIX}_models.id"), nullable=False
+    )
+    model_selection_id = Column(
+        BigInteger,
+        ForeignKey(f"{LECRAPAUD_TABLE_PREFIX}_model_selections.id", ondelete="CASCADE"),
+        nullable=False,
+    )
+
+    # From Score (excluding type and training_time which is already in ModelTraining)
     eval_data_std = Column(Float)
     rmse = Column(Float)
     rmse_std_ratio = Column(Float)
@@ -50,12 +64,15 @@ class Score(Base):
     precision_at_threshold = Column(Float)
     recall_at_threshold = Column(Float)
     f1_at_threshold = Column(Float)
-    model_training_id = Column(
-        BigInteger,
-        ForeignKey(f"{LECRAPAUD_TABLE_PREFIX}_model_trainings.id", ondelete="CASCADE"),
-        nullable=False,
-    )

-    model_trainings = relationship(
-        "ModelTraining", back_populates="score", lazy="selectin"
+    # Relationships
+    model = relationship("Model", lazy="selectin")
+    model_selection = relationship(
+        "ModelSelection", back_populates="model_selection_scores", lazy="selectin"
     )
+
+    __table_args__ = (
+        UniqueConstraint(
+            "model_id", "model_selection_id", name="uq_model_selection_score_composite"
+        ),
+    )

lecrapaud/db/session.py CHANGED
@@ -73,6 +73,7 @@ def init_db(uri: str = None):
         autocommit=False,
         autoflush=False,
         bind=_engine,
+        expire_on_commit=False,  # Prevent detached instance errors
     )

     # Step 5: Apply Alembic migrations programmatically
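
With SQLAlchemy's default `expire_on_commit=True`, every instance's attributes are expired on commit and reload lazily on next access, which raises `DetachedInstanceError` once the session is closed. Setting it to `False` keeps already-loaded attributes readable after the session ends. A minimal sketch (the `engine` binding and `Experiment` lookup are illustrative):

    from sqlalchemy.orm import sessionmaker

    Session = sessionmaker(bind=engine, expire_on_commit=False)
    with Session() as db:
        exp = db.get(Experiment, 1)
        db.commit()
    print(exp.name)  # attributes survive the commit and the session close

The trade-off is that attributes read after the session closes may be stale relative to the database.
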
lecrapaud/experiment.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path

 import pandas as pd
 import joblib
+from datetime import datetime

 # Set up coverage file path
 os.environ["COVERAGE_FILE"] = str(Path(".coverage").resolve())
@@ -15,9 +16,6 @@ from lecrapaud.db.session import get_db

 def create_experiment(
     data: pd.DataFrame | str,
-    corr_threshold,
-    percentile,
-    max_features,
     date_column,
     group_column,
     experiment_name,
@@ -42,7 +40,10 @@
     targets = [
         target for target in all_targets if target.name in data.columns.str.upper()
     ]
-    experiment_name = f"{experiment_name}_{groups["number_of_groups"] if group_column else 'ng'}_{corr_threshold}_{percentile}_{max_features}_{dates['start_date'].date() if date_column else 'nd'}_{dates['end_date'].date() if date_column else 'nd'}"
+    experiment_name = (
+        f"{experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+    )
+    number_of_targets = len(targets)

     experiment_dir = f"{tmp_dir}/{experiment_name}"
     preprocessing_dir = f"{experiment_dir}/preprocessing"
@@ -50,23 +51,16 @@
     os.makedirs(preprocessing_dir, exist_ok=True)
     os.makedirs(data_dir, exist_ok=True)

+    # Create or update experiment (without targets relation)
     experiment = Experiment.upsert(
-        match_fields=["name"],
         db=db,
         name=experiment_name,
         path=Path(experiment_dir).resolve(),
-        type="training",
         size=data.shape[0],
-        corr_threshold=corr_threshold,
-        percentile=percentile,
-        max_features=max_features,
+        number_of_targets=number_of_targets,
         **groups,
         **dates,
-        targets=targets,
         context={
-            "corr_threshold": corr_threshold,
-            "percentile": percentile,
-            "max_features": max_features,
             "date_column": date_column,
             "group_column": group_column,
             "experiment_name": experiment_name,
@@ -74,4 +68,8 @@
         },
     )

+    # Set targets relationship after creation/update
+    experiment.targets = targets
+    experiment.save(db=db)
+
     return experiment
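
Experiment names are now timestamped rather than derived from the removed feature-selection parameters, so each run gets a distinct `name` and the upsert always creates a new row. A usage sketch of the new signature (column names are illustrative, and parameters not shown in this diff are assumed to keep their defaults):

    import pandas as pd
    from lecrapaud.experiment import create_experiment

    df = pd.DataFrame({"FEATURE_1": [1.0, 2.0], "TARGET_1": [0, 1]})
    experiment = create_experiment(
        data=df,
        date_column=None,
        group_column=None,
        experiment_name="demo",  # stored as e.g. demo_20250101_120000
    )
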