wavetrainer 0.0.9__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wavetrainer-0.0.9/wavetrainer.egg-info → wavetrainer-0.0.11}/PKG-INFO +1 -1
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/setup.py +1 -1
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/tests/trainer_test.py +22 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/__init__.py +1 -1
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/create.py +4 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/catboost_model.py +14 -6
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/base_selector_reducer.py +13 -1
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/correlation_reducer.py +4 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/selector/selector.py +4 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/trainer.py +42 -11
- {wavetrainer-0.0.9 → wavetrainer-0.0.11/wavetrainer.egg-info}/PKG-INFO +1 -1
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/LICENSE +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/MANIFEST.in +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/README.md +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/requirements.txt +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/setup.cfg +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/tests/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/tests/model/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/tests/model/catboost_kwargs_test.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/calibrator/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/calibrator/calibrator.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/calibrator/calibrator_router.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/calibrator/mapie_calibrator.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/calibrator/vennabers_calibrator.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/exceptions.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/fit.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/catboost_classifier_wrap.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/catboost_kwargs.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/catboost_regressor_wrap.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/model.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model/model_router.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/model_type.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/params.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/combined_reducer.py +1 -1
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/constant_reducer.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/duplicate_reducer.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/nonnumeric_reducer.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/reducer/reducer.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/selector/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/class_weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/combined_weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/exponential_weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/linear_weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/noop_weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/sigmoid_weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/weights.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/weights/weights_router.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/windower/__init__.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer/windower/windower.py +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer.egg-info/SOURCES.txt +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer.egg-info/dependency_links.txt +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer.egg-info/not-zip-safe +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer.egg-info/requires.txt +0 -0
- {wavetrainer-0.0.9 → wavetrainer-0.0.11}/wavetrainer.egg-info/top_level.txt +0 -0
@@ -23,7 +23,7 @@ def install_requires() -> typing.List[str]:
|
|
23
23
|
|
24
24
|
setup(
|
25
25
|
name='wavetrainer',
|
26
|
-
version='0.0.
|
26
|
+
version='0.0.11',
|
27
27
|
description='A library for automatically finding the optimal model within feature and hyperparameter space.',
|
28
28
|
long_description=long_description,
|
29
29
|
long_description_content_type='text/markdown',
|
@@ -37,3 +37,25 @@ class TestTrainer(unittest.TestCase):
|
|
37
37
|
df = trainer.transform(df)
|
38
38
|
print("df:")
|
39
39
|
print(df)
|
40
|
+
|
41
|
+
def test_trainer_dt_column(self):
|
42
|
+
with tempfile.TemporaryDirectory() as tmpdir:
|
43
|
+
trainer = Trainer(tmpdir, walkforward_timedelta=datetime.timedelta(days=7), trials=1, dt_column="dt_column")
|
44
|
+
x_data = [i for i in range(100)]
|
45
|
+
x_index = [datetime.datetime(2022, 1, 1) + datetime.timedelta(days=i) for i in range(len(x_data))]
|
46
|
+
df = pd.DataFrame(
|
47
|
+
data={
|
48
|
+
"column1": x_data,
|
49
|
+
"dt_column": x_index,
|
50
|
+
},
|
51
|
+
)
|
52
|
+
y = pd.DataFrame(
|
53
|
+
data={
|
54
|
+
"y": [x % 2 == 0 for x in x_data],
|
55
|
+
},
|
56
|
+
index=df.index,
|
57
|
+
)
|
58
|
+
trainer.fit(df, y=y)
|
59
|
+
df = trainer.transform(df)
|
60
|
+
print("df:")
|
61
|
+
print(df)
|
@@ -1,5 +1,7 @@
|
|
1
1
|
"""A function for creating a new trainer."""
|
2
2
|
|
3
|
+
# pylint: disable=too-many-arguments,too-many-positional-arguments
|
4
|
+
|
3
5
|
import datetime
|
4
6
|
|
5
7
|
from .trainer import Trainer
|
@@ -11,6 +13,7 @@ def create(
|
|
11
13
|
test_size: float | datetime.timedelta | None = None,
|
12
14
|
validation_size: float | datetime.timedelta | None = None,
|
13
15
|
dt_column: str | None = None,
|
16
|
+
max_train_timeout: datetime.timedelta | None = None,
|
14
17
|
) -> Trainer:
|
15
18
|
"""Create a trainer."""
|
16
19
|
return Trainer(
|
@@ -19,4 +22,5 @@ def create(
|
|
19
22
|
test_size=test_size,
|
20
23
|
validation_size=validation_size,
|
21
24
|
dt_column=dt_column,
|
25
|
+
max_train_timeout=max_train_timeout,
|
22
26
|
)
|
@@ -23,12 +23,13 @@ _DEPTH_KEY = "depth"
|
|
23
23
|
_L2_LEAF_REG_KEY = "l2_leaf_reg"
|
24
24
|
_BOOSTING_TYPE_KEY = "boosting_type"
|
25
25
|
_MODEL_TYPE_KEY = "model_type"
|
26
|
+
_EARLY_STOPPING_ROUNDS = "early_stopping_rounds"
|
26
27
|
|
27
28
|
|
28
29
|
class CatboostModel(Model):
|
29
30
|
"""A class that uses Catboost as a model."""
|
30
31
|
|
31
|
-
# pylint: disable=too-many-positional-arguments,too-many-arguments
|
32
|
+
# pylint: disable=too-many-positional-arguments,too-many-arguments,too-many-instance-attributes
|
32
33
|
|
33
34
|
_catboost: CatBoost | None
|
34
35
|
_iterations: None | int
|
@@ -37,6 +38,7 @@ class CatboostModel(Model):
|
|
37
38
|
_l2_leaf_reg: None | float
|
38
39
|
_boosting_type: None | str
|
39
40
|
_model_type: None | ModelType
|
41
|
+
_early_stopping_rounds: None | int
|
40
42
|
|
41
43
|
@classmethod
|
42
44
|
def name(cls) -> str:
|
@@ -51,6 +53,7 @@ class CatboostModel(Model):
|
|
51
53
|
self._l2_leaf_reg = None
|
52
54
|
self._boosting_type = None
|
53
55
|
self._model_type = None
|
56
|
+
self._early_stopping_rounds = None
|
54
57
|
|
55
58
|
@property
|
56
59
|
def estimator(self) -> Any:
|
@@ -80,6 +83,9 @@ class CatboostModel(Model):
|
|
80
83
|
self._boosting_type = trial.suggest_categorical(
|
81
84
|
_BOOSTING_TYPE_KEY, ["Ordered", "Plain"]
|
82
85
|
)
|
86
|
+
self._early_stopping_rounds = trial.suggest_int(
|
87
|
+
_EARLY_STOPPING_ROUNDS, 10, 1000
|
88
|
+
)
|
83
89
|
|
84
90
|
def load(self, folder: str) -> None:
|
85
91
|
with open(
|
@@ -92,6 +98,7 @@ class CatboostModel(Model):
|
|
92
98
|
self._l2_leaf_reg = params[_L2_LEAF_REG_KEY]
|
93
99
|
self._boosting_type = params[_BOOSTING_TYPE_KEY]
|
94
100
|
self._model_type = ModelType(params[_MODEL_TYPE_KEY])
|
101
|
+
self._early_stopping_rounds = params[_EARLY_STOPPING_ROUNDS]
|
95
102
|
catboost = self._provide_catboost()
|
96
103
|
catboost.load_model(os.path.join(folder, _MODEL_FILENAME))
|
97
104
|
|
@@ -107,6 +114,7 @@ class CatboostModel(Model):
|
|
107
114
|
_L2_LEAF_REG_KEY: self._l2_leaf_reg,
|
108
115
|
_BOOSTING_TYPE_KEY: self._boosting_type,
|
109
116
|
_MODEL_TYPE_KEY: str(self._model_type),
|
117
|
+
_EARLY_STOPPING_ROUNDS: self._early_stopping_rounds,
|
110
118
|
},
|
111
119
|
handle,
|
112
120
|
)
|
@@ -141,7 +149,7 @@ class CatboostModel(Model):
|
|
141
149
|
)
|
142
150
|
catboost.fit(
|
143
151
|
train_pool,
|
144
|
-
early_stopping_rounds=
|
152
|
+
early_stopping_rounds=self._early_stopping_rounds,
|
145
153
|
verbose=False,
|
146
154
|
metric_period=100,
|
147
155
|
eval_set=eval_pool,
|
@@ -178,7 +186,7 @@ class CatboostModel(Model):
|
|
178
186
|
depth=self._depth,
|
179
187
|
l2_leaf_reg=self._l2_leaf_reg,
|
180
188
|
boosting_type=self._boosting_type,
|
181
|
-
early_stopping_rounds=
|
189
|
+
early_stopping_rounds=self._early_stopping_rounds,
|
182
190
|
metric_period=100,
|
183
191
|
)
|
184
192
|
case ModelType.REGRESSION:
|
@@ -188,7 +196,7 @@ class CatboostModel(Model):
|
|
188
196
|
depth=self._depth,
|
189
197
|
l2_leaf_reg=self._l2_leaf_reg,
|
190
198
|
boosting_type=self._boosting_type,
|
191
|
-
early_stopping_rounds=
|
199
|
+
early_stopping_rounds=self._early_stopping_rounds,
|
192
200
|
metric_period=100,
|
193
201
|
)
|
194
202
|
case ModelType.BINNED_BINARY:
|
@@ -198,7 +206,7 @@ class CatboostModel(Model):
|
|
198
206
|
depth=self._depth,
|
199
207
|
l2_leaf_reg=self._l2_leaf_reg,
|
200
208
|
boosting_type=self._boosting_type,
|
201
|
-
early_stopping_rounds=
|
209
|
+
early_stopping_rounds=self._early_stopping_rounds,
|
202
210
|
metric_period=100,
|
203
211
|
)
|
204
212
|
case ModelType.MULTI_CLASSIFICATION:
|
@@ -208,7 +216,7 @@ class CatboostModel(Model):
|
|
208
216
|
depth=self._depth,
|
209
217
|
l2_leaf_reg=self._l2_leaf_reg,
|
210
218
|
boosting_type=self._boosting_type,
|
211
|
-
early_stopping_rounds=
|
219
|
+
early_stopping_rounds=self._early_stopping_rounds,
|
212
220
|
metric_period=100,
|
213
221
|
)
|
214
222
|
self._catboost = catboost
|
@@ -1,5 +1,6 @@
|
|
1
1
|
"""A reducer that uses a base selector from the feature engine."""
|
2
2
|
|
3
|
+
import logging
|
3
4
|
import os
|
4
5
|
from typing import Self
|
5
6
|
|
@@ -26,6 +27,11 @@ class BaseSelectorReducer(Reducer):
|
|
26
27
|
def name(cls) -> str:
|
27
28
|
raise NotImplementedError("name not implemented in parent class.")
|
28
29
|
|
30
|
+
@classmethod
|
31
|
+
def should_raise(cls) -> bool:
|
32
|
+
"""Whether the class should raise its exception if it encounters it."""
|
33
|
+
return True
|
34
|
+
|
29
35
|
def set_options(self, trial: optuna.Trial | optuna.trial.FrozenTrial) -> None:
|
30
36
|
pass
|
31
37
|
|
@@ -45,11 +51,17 @@ class BaseSelectorReducer(Reducer):
|
|
45
51
|
eval_x: pd.DataFrame | None = None,
|
46
52
|
eval_y: pd.Series | pd.DataFrame | None = None,
|
47
53
|
) -> Self:
|
54
|
+
if len(df.columns) <= 1:
|
55
|
+
return self
|
48
56
|
try:
|
49
57
|
self._base_selector.fit(df) # type: ignore
|
50
58
|
except ValueError as exc:
|
51
|
-
|
59
|
+
logging.warning(str(exc))
|
60
|
+
if self.should_raise():
|
61
|
+
raise WavetrainException() from exc
|
52
62
|
return self
|
53
63
|
|
54
64
|
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
65
|
+
if len(df.columns) <= 1:
|
66
|
+
return df
|
55
67
|
return self._base_selector.transform(df)
|
@@ -53,6 +53,8 @@ class Selector(Params, Fit):
|
|
53
53
|
model_kwargs = self._model.pre_fit(df, y=y, eval_x=eval_x, eval_y=eval_y)
|
54
54
|
if not isinstance(y, pd.Series):
|
55
55
|
raise ValueError("y is not a series.")
|
56
|
+
if len(df.columns) <= 1:
|
57
|
+
return self
|
56
58
|
n_features_to_select = max(1, int(len(df.columns) * self._feature_ratio))
|
57
59
|
self._selector = RFE(
|
58
60
|
self._model.estimator,
|
@@ -70,6 +72,8 @@ class Selector(Params, Fit):
|
|
70
72
|
return self
|
71
73
|
|
72
74
|
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
75
|
+
if len(df.columns) <= 1:
|
76
|
+
return df
|
73
77
|
selector = self._selector
|
74
78
|
if selector is None:
|
75
79
|
raise ValueError("selector is null.")
|
@@ -51,6 +51,7 @@ class Trainer(Fit):
|
|
51
51
|
test_size: float | datetime.timedelta | None = None,
|
52
52
|
validation_size: float | datetime.timedelta | None = None,
|
53
53
|
dt_column: str | None = None,
|
54
|
+
max_train_timeout: datetime.timedelta | None = None,
|
54
55
|
):
|
55
56
|
tqdm.tqdm.pandas()
|
56
57
|
|
@@ -137,6 +138,7 @@ class Trainer(Fit):
|
|
137
138
|
self._test_size = test_size
|
138
139
|
self._validation_size = validation_size
|
139
140
|
self._dt_column = dt_column
|
141
|
+
self._max_train_timeout = max_train_timeout
|
140
142
|
|
141
143
|
def _provide_study(self, column: str) -> optuna.Study:
|
142
144
|
storage_name = f"sqlite:///{self._folder}/{column}/{_STUDYDB_FILENAME}"
|
@@ -165,7 +167,11 @@ class Trainer(Fit):
|
|
165
167
|
if y is None:
|
166
168
|
return self
|
167
169
|
|
168
|
-
dt_index =
|
170
|
+
dt_index = (
|
171
|
+
df.index
|
172
|
+
if self._dt_column is None
|
173
|
+
else pd.DatetimeIndex(pd.to_datetime(df[self._dt_column]))
|
174
|
+
)
|
169
175
|
|
170
176
|
def _fit_column(y_series: pd.Series):
|
171
177
|
column_dir = os.path.join(self._folder, str(y_series.name))
|
@@ -184,10 +190,10 @@ class Trainer(Fit):
|
|
184
190
|
trial.set_user_attr(_IDX_USR_ATTR_KEY, split_idx.isoformat())
|
185
191
|
|
186
192
|
train_dt_index = dt_index[: len(x)]
|
187
|
-
x_train = x[train_dt_index < split_idx]
|
188
|
-
x_test = x[train_dt_index >= split_idx]
|
189
|
-
y_train = y_series[train_dt_index < split_idx]
|
190
|
-
y_test = y_series[train_dt_index >= split_idx]
|
193
|
+
x_train = x[train_dt_index < split_idx] # type: ignore
|
194
|
+
x_test = x[train_dt_index >= split_idx] # type: ignore
|
195
|
+
y_train = y_series[train_dt_index < split_idx] # type: ignore
|
196
|
+
y_test = y_series[train_dt_index >= split_idx] # type: ignore
|
191
197
|
|
192
198
|
try:
|
193
199
|
# Window the data
|
@@ -250,14 +256,15 @@ class Trainer(Fit):
|
|
250
256
|
return float(r2_score(y_test, y_pred[[PREDICTION_COLUMN]]))
|
251
257
|
return float(f1_score(y_test, y_pred[[PREDICTION_COLUMN]]))
|
252
258
|
except WavetrainException as exc:
|
259
|
+
logging.warning("WE DID NOT END UP TRAINING ANYTHING!!!!!")
|
253
260
|
logging.warning(str(exc))
|
254
261
|
return -1.0
|
255
262
|
|
256
263
|
start_validation_index = (
|
257
|
-
dt_index[-int(len(dt_index) * self._validation_size) - 1]
|
264
|
+
dt_index.to_list()[-int(len(dt_index) * self._validation_size) - 1]
|
258
265
|
if isinstance(self._validation_size, float)
|
259
266
|
else dt_index[
|
260
|
-
dt_index >= (dt_index.to_list()[-1] - self._validation_size)
|
267
|
+
dt_index >= (dt_index.to_list()[-1] - self._validation_size) # type: ignore
|
261
268
|
].to_list()[0]
|
262
269
|
)
|
263
270
|
test_df = df[dt_index < start_validation_index]
|
@@ -284,11 +291,21 @@ class Trainer(Fit):
|
|
284
291
|
initial_trials = max(self._trials - len(study.trials), 0)
|
285
292
|
if initial_trials > 0:
|
286
293
|
study.optimize(
|
287
|
-
test_objective,
|
294
|
+
test_objective,
|
295
|
+
n_trials=initial_trials,
|
296
|
+
show_progress_bar=True,
|
297
|
+
timeout=None
|
298
|
+
if self._max_train_timeout is None
|
299
|
+
else self._max_train_timeout.total_seconds(),
|
288
300
|
)
|
289
301
|
|
290
302
|
train_len = len(df[dt_index < start_test_index])
|
291
|
-
test_len = len(
|
303
|
+
test_len = len(
|
304
|
+
dt_index[
|
305
|
+
(dt_index >= start_test_index)
|
306
|
+
& (dt_index <= start_validation_index)
|
307
|
+
]
|
308
|
+
)
|
292
309
|
|
293
310
|
last_processed_dt = None
|
294
311
|
for count, test_idx in tqdm.tqdm(
|
@@ -326,6 +343,9 @@ class Trainer(Fit):
|
|
326
343
|
validate_objctive, idx=test_idx, series=test_series
|
327
344
|
),
|
328
345
|
n_trials=1,
|
346
|
+
timeout=None
|
347
|
+
if self._max_train_timeout is None
|
348
|
+
else self._max_train_timeout.total_seconds(),
|
329
349
|
)
|
330
350
|
|
331
351
|
_fit(study.best_trial, test_df, test_series, True, test_idx)
|
@@ -341,7 +361,11 @@ class Trainer(Fit):
|
|
341
361
|
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
|
342
362
|
"""Predict the expected values of the data."""
|
343
363
|
feature_columns = df.columns.values
|
344
|
-
dt_index =
|
364
|
+
dt_index = (
|
365
|
+
df.index
|
366
|
+
if self._dt_column is None
|
367
|
+
else pd.DatetimeIndex(pd.to_datetime(df[self._dt_column]))
|
368
|
+
)
|
345
369
|
|
346
370
|
for column in os.listdir(self._folder):
|
347
371
|
column_path = os.path.join(self._folder, column)
|
@@ -353,6 +377,8 @@ class Trainer(Fit):
|
|
353
377
|
if not os.path.isdir(date_path):
|
354
378
|
continue
|
355
379
|
dates.append(datetime.datetime.fromisoformat(date_str))
|
380
|
+
if not dates:
|
381
|
+
raise ValueError(f"no dates found for {column}.")
|
356
382
|
bins: list[datetime.datetime] = sorted(
|
357
383
|
[dt_index.min().to_pydatetime()]
|
358
384
|
+ dates
|
@@ -371,7 +397,12 @@ class Trainer(Fit):
|
|
371
397
|
column: str,
|
372
398
|
dates: list[datetime.datetime],
|
373
399
|
) -> pd.DataFrame:
|
374
|
-
|
400
|
+
group_dt_index = (
|
401
|
+
group.index
|
402
|
+
if self._dt_column is None
|
403
|
+
else pd.DatetimeIndex(pd.to_datetime(group[self._dt_column]))
|
404
|
+
)
|
405
|
+
filtered_dates = [x for x in dates if x < group_dt_index.max()]
|
375
406
|
if not filtered_dates:
|
376
407
|
filtered_dates = [dates[-1]]
|
377
408
|
date_str = dates[-1].isoformat()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|