factrainer-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- factrainer_core-0.1.0/.gitignore +173 -0
- factrainer_core-0.1.0/PKG-INFO +7 -0
- factrainer_core-0.1.0/README.md +0 -0
- factrainer_core-0.1.0/pyproject.toml +18 -0
- factrainer_core-0.1.0/src/factrainer/core/__init__.py +7 -0
- factrainer_core-0.1.0/src/factrainer/core/cv/__init__.py +0 -0
- factrainer_core-0.1.0/src/factrainer/core/cv/config.py +128 -0
- factrainer_core-0.1.0/src/factrainer/core/cv/cv.py +78 -0
- factrainer_core-0.1.0/src/factrainer/core/cv/dataset.py +99 -0
- factrainer_core-0.1.0/src/factrainer/core/cv/raw_model.py +7 -0
- factrainer_core-0.1.0/src/factrainer/core/py.typed +0 -0
- factrainer_core-0.1.0/src/factrainer/core/single.py +55 -0
- factrainer_core-0.1.0/src/factrainer/core/trait.py +58 -0
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
notebooks/
|
|
2
|
+
|
|
3
|
+
# pytest-profiling
|
|
4
|
+
prof/
|
|
5
|
+
|
|
6
|
+
# Byte-compiled / optimized / DLL files
|
|
7
|
+
__pycache__/
|
|
8
|
+
*.py[cod]
|
|
9
|
+
*$py.class
|
|
10
|
+
|
|
11
|
+
# C extensions
|
|
12
|
+
*.so
|
|
13
|
+
|
|
14
|
+
# Distribution / packaging
|
|
15
|
+
.Python
|
|
16
|
+
build/
|
|
17
|
+
develop-eggs/
|
|
18
|
+
dist/
|
|
19
|
+
downloads/
|
|
20
|
+
eggs/
|
|
21
|
+
.eggs/
|
|
22
|
+
lib/
|
|
23
|
+
lib64/
|
|
24
|
+
parts/
|
|
25
|
+
sdist/
|
|
26
|
+
var/
|
|
27
|
+
wheels/
|
|
28
|
+
share/python-wheels/
|
|
29
|
+
*.egg-info/
|
|
30
|
+
.installed.cfg
|
|
31
|
+
*.egg
|
|
32
|
+
MANIFEST
|
|
33
|
+
|
|
34
|
+
# PyInstaller
|
|
35
|
+
# Usually these files are written by a python script from a template
|
|
36
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
37
|
+
*.manifest
|
|
38
|
+
*.spec
|
|
39
|
+
|
|
40
|
+
# Installer logs
|
|
41
|
+
pip-log.txt
|
|
42
|
+
pip-delete-this-directory.txt
|
|
43
|
+
|
|
44
|
+
# Unit test / coverage reports
|
|
45
|
+
htmlcov/
|
|
46
|
+
.tox/
|
|
47
|
+
.nox/
|
|
48
|
+
.coverage
|
|
49
|
+
.coverage.*
|
|
50
|
+
.cache
|
|
51
|
+
nosetests.xml
|
|
52
|
+
coverage.xml
|
|
53
|
+
*.cover
|
|
54
|
+
*.py,cover
|
|
55
|
+
.hypothesis/
|
|
56
|
+
.pytest_cache/
|
|
57
|
+
cover/
|
|
58
|
+
|
|
59
|
+
# Translations
|
|
60
|
+
*.mo
|
|
61
|
+
*.pot
|
|
62
|
+
|
|
63
|
+
# Django stuff:
|
|
64
|
+
*.log
|
|
65
|
+
local_settings.py
|
|
66
|
+
db.sqlite3
|
|
67
|
+
db.sqlite3-journal
|
|
68
|
+
|
|
69
|
+
# Flask stuff:
|
|
70
|
+
instance/
|
|
71
|
+
.webassets-cache
|
|
72
|
+
|
|
73
|
+
# Scrapy stuff:
|
|
74
|
+
.scrapy
|
|
75
|
+
|
|
76
|
+
# Sphinx documentation
|
|
77
|
+
docs/_build/
|
|
78
|
+
|
|
79
|
+
# PyBuilder
|
|
80
|
+
.pybuilder/
|
|
81
|
+
target/
|
|
82
|
+
|
|
83
|
+
# Jupyter Notebook
|
|
84
|
+
.ipynb_checkpoints
|
|
85
|
+
|
|
86
|
+
# IPython
|
|
87
|
+
profile_default/
|
|
88
|
+
ipython_config.py
|
|
89
|
+
|
|
90
|
+
# pyenv
|
|
91
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
92
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
93
|
+
# .python-version
|
|
94
|
+
|
|
95
|
+
# pipenv
|
|
96
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
97
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
98
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
99
|
+
# install all needed dependencies.
|
|
100
|
+
#Pipfile.lock
|
|
101
|
+
|
|
102
|
+
# UV
|
|
103
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
104
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
105
|
+
# commonly ignored for libraries.
|
|
106
|
+
#uv.lock
|
|
107
|
+
|
|
108
|
+
# poetry
|
|
109
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
110
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
111
|
+
# commonly ignored for libraries.
|
|
112
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
113
|
+
#poetry.lock
|
|
114
|
+
|
|
115
|
+
# pdm
|
|
116
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
117
|
+
#pdm.lock
|
|
118
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
119
|
+
# in version control.
|
|
120
|
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
|
121
|
+
.pdm.toml
|
|
122
|
+
.pdm-python
|
|
123
|
+
.pdm-build/
|
|
124
|
+
|
|
125
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
126
|
+
__pypackages__/
|
|
127
|
+
|
|
128
|
+
# Celery stuff
|
|
129
|
+
celerybeat-schedule
|
|
130
|
+
celerybeat.pid
|
|
131
|
+
|
|
132
|
+
# SageMath parsed files
|
|
133
|
+
*.sage.py
|
|
134
|
+
|
|
135
|
+
# Environments
|
|
136
|
+
.env
|
|
137
|
+
.venv
|
|
138
|
+
env/
|
|
139
|
+
venv/
|
|
140
|
+
ENV/
|
|
141
|
+
env.bak/
|
|
142
|
+
venv.bak/
|
|
143
|
+
|
|
144
|
+
# Spyder project settings
|
|
145
|
+
.spyderproject
|
|
146
|
+
.spyproject
|
|
147
|
+
|
|
148
|
+
# Rope project settings
|
|
149
|
+
.ropeproject
|
|
150
|
+
|
|
151
|
+
# mkdocs documentation
|
|
152
|
+
/site
|
|
153
|
+
|
|
154
|
+
# mypy
|
|
155
|
+
.mypy_cache/
|
|
156
|
+
.dmypy.json
|
|
157
|
+
dmypy.json
|
|
158
|
+
|
|
159
|
+
# Pyre type checker
|
|
160
|
+
.pyre/
|
|
161
|
+
|
|
162
|
+
# pytype static type analyzer
|
|
163
|
+
.pytype/
|
|
164
|
+
|
|
165
|
+
# Cython debug symbols
|
|
166
|
+
cython_debug/
|
|
167
|
+
|
|
168
|
+
# PyCharm
|
|
169
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
170
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
171
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
172
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
173
|
+
#.idea/
|
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "factrainer-core"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Core abstractions of factrainer: single-model and cross-validation training and prediction"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [{ name = "ritsuki1227", email = "ritsuki1227@gmail.com" }]
|
|
7
|
+
requires-python = ">=3.12"
|
|
8
|
+
dependencies = ["factrainer-base"]
|
|
9
|
+
|
|
10
|
+
[build-system]
|
|
11
|
+
requires = ["hatchling"]
|
|
12
|
+
build-backend = "hatchling.build"
|
|
13
|
+
|
|
14
|
+
[tool.hatch.build.targets.wheel]
|
|
15
|
+
packages = ["src/factrainer"]
|
|
16
|
+
|
|
17
|
+
[tool.uv.sources]
|
|
18
|
+
factrainer-base = { workspace = true }
|
|
File without changes
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from typing import Any, Self
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import numpy.typing as npt
|
|
5
|
+
from factrainer.base.config import (
|
|
6
|
+
BaseLearner,
|
|
7
|
+
BaseMlModelConfig,
|
|
8
|
+
BasePredictConfig,
|
|
9
|
+
BasePredictor,
|
|
10
|
+
BaseTrainConfig,
|
|
11
|
+
)
|
|
12
|
+
from factrainer.base.dataset import IndexableDataset, Prediction
|
|
13
|
+
from factrainer.base.raw_model import RawModel
|
|
14
|
+
from joblib import Parallel, delayed
|
|
15
|
+
|
|
16
|
+
from .dataset import IndexedDatasets
|
|
17
|
+
from .raw_model import CvRawModels
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CvLearner[T: IndexableDataset, U: RawModel, V: BaseTrainConfig](
|
|
21
|
+
BaseLearner[IndexedDatasets[T], CvRawModels[U], V]
|
|
22
|
+
):
|
|
23
|
+
def __init__(
|
|
24
|
+
self, learner: BaseLearner[T, U, V], n_jobs: int | None = None
|
|
25
|
+
) -> None:
|
|
26
|
+
self._learner = learner
|
|
27
|
+
self._n_jobs = n_jobs
|
|
28
|
+
|
|
29
|
+
def train(
|
|
30
|
+
self,
|
|
31
|
+
train_dataset: IndexedDatasets[T],
|
|
32
|
+
val_dataset: IndexedDatasets[T] | None,
|
|
33
|
+
config: V,
|
|
34
|
+
) -> CvRawModels[U]:
|
|
35
|
+
if val_dataset is not None:
|
|
36
|
+
models = Parallel(n_jobs=self.n_jobs)(
|
|
37
|
+
delayed(self._learner.train)(train.data, val.data, config)
|
|
38
|
+
for train, val in zip(train_dataset.datasets, val_dataset.datasets)
|
|
39
|
+
)
|
|
40
|
+
else:
|
|
41
|
+
models = Parallel(n_jobs=self.n_jobs)(
|
|
42
|
+
delayed(self._learner.train)(train.data, None, config)
|
|
43
|
+
for train in train_dataset.datasets
|
|
44
|
+
)
|
|
45
|
+
return CvRawModels(models=models)
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def n_jobs(self) -> int | None:
|
|
49
|
+
return self._n_jobs
|
|
50
|
+
|
|
51
|
+
@n_jobs.setter
|
|
52
|
+
def n_jobs(self, n_jobs: int | None) -> None:
|
|
53
|
+
self._n_jobs = n_jobs
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class CvPredictor[T: IndexableDataset, U: RawModel, W: BasePredictConfig](
|
|
57
|
+
BasePredictor[IndexedDatasets[T], CvRawModels[U], W]
|
|
58
|
+
):
|
|
59
|
+
def __init__(
|
|
60
|
+
self, predictor: BasePredictor[T, U, W], n_jobs: int | None = None
|
|
61
|
+
) -> None:
|
|
62
|
+
self._predictor = predictor
|
|
63
|
+
self._n_jobs = n_jobs
|
|
64
|
+
|
|
65
|
+
def predict(
|
|
66
|
+
self, dataset: IndexedDatasets[T], model: CvRawModels[U], config: W | None
|
|
67
|
+
) -> Prediction:
|
|
68
|
+
y_preds = Parallel(n_jobs=self.n_jobs)(
|
|
69
|
+
delayed(self._predictor.predict)(_dataset.data, _model, config)
|
|
70
|
+
for _model, _dataset in zip(model.models, dataset.datasets)
|
|
71
|
+
)
|
|
72
|
+
y_oof_pred = self._init_pred(len(dataset), y_preds[0])
|
|
73
|
+
for y_pred, _dataset in zip(y_preds, dataset.datasets):
|
|
74
|
+
y_oof_pred[_dataset.index] = y_pred
|
|
75
|
+
return y_oof_pred
|
|
76
|
+
|
|
77
|
+
def _init_pred(self, total_length: int, y_pred: npt.NDArray[Any]) -> Prediction:
|
|
78
|
+
return np.empty(
|
|
79
|
+
tuple([total_length] + list(y_pred.shape[1:])), dtype=y_pred.dtype
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def n_jobs(self) -> int | None:
|
|
84
|
+
return self._n_jobs
|
|
85
|
+
|
|
86
|
+
@n_jobs.setter
|
|
87
|
+
def n_jobs(self, n_jobs: int | None) -> None:
|
|
88
|
+
self._n_jobs = n_jobs
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class CvMlModelConfig[
    T: IndexableDataset,
    U: RawModel,
    V: BaseTrainConfig,
    W: BasePredictConfig,
](BaseMlModelConfig[IndexedDatasets[T], CvRawModels[U], V, W]):
    """Model configuration whose learner/predictor operate fold-wise.

    Lifts a single-fold ``BaseMlModelConfig`` so that training and
    prediction run over ``IndexedDatasets``/``CvRawModels`` collections.
    """

    # Fold-wise wrappers around the single-fold learner/predictor.
    learner: CvLearner[T, U, V]
    predictor: CvPredictor[T, U, W]

    @classmethod
    def from_config(
        cls,
        config: BaseMlModelConfig[T, U, V, W],
        n_jobs_train: int | None = None,
        n_jobs_predict: int | None = None,
    ) -> Self:
        """Build a CV config from a single-fold config.

        ``n_jobs_train`` / ``n_jobs_predict`` set the joblib parallelism of
        the wrapped learner/predictor (None means joblib's default). The
        train/predict configs are carried over unchanged.
        """
        return cls(
            learner=CvLearner(config.learner, n_jobs_train),
            predictor=CvPredictor(config.predictor, n_jobs_predict),
            train_config=config.train_config,
            pred_config=config.pred_config,
        )

    @property
    def n_jobs_train(self) -> int | None:
        """Parallelism used when training folds (delegates to the learner)."""
        return self.learner.n_jobs

    @n_jobs_train.setter
    def n_jobs_train(self, n_jobs: int | None) -> None:
        self.learner.n_jobs = n_jobs

    @property
    def n_jobs_predict(self) -> int | None:
        """Parallelism used when predicting folds (delegates to the predictor)."""
        return self.predictor.n_jobs

    @n_jobs_predict.setter
    def n_jobs_predict(self, n_jobs: int | None) -> None:
        self.predictor.n_jobs = n_jobs
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from factrainer.base.config import BaseMlModelConfig, BasePredictConfig, BaseTrainConfig
|
|
2
|
+
from factrainer.base.dataset import IndexableDataset, Prediction
|
|
3
|
+
from factrainer.base.raw_model import RawModel
|
|
4
|
+
from sklearn.model_selection._split import _BaseKFold
|
|
5
|
+
|
|
6
|
+
from ..single import SingleMlModel
|
|
7
|
+
from ..trait import PredictorTrait, TrainerTrait
|
|
8
|
+
from .config import CvMlModelConfig
|
|
9
|
+
from .dataset import IndexedDatasets, SplittedDatasets, SplittedDatasetsIndices
|
|
10
|
+
from .raw_model import CvRawModels
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CvMlModel[
    T: IndexableDataset,
    U: RawModel,
    V: BaseTrainConfig,
    W: BasePredictConfig,
](TrainerTrait[T, V], PredictorTrait[T, CvRawModels[U], W]):
    """Cross-validation model: splits, trains per fold, predicts out-of-fold.

    Internally reuses ``SingleMlModel`` with a ``CvMlModelConfig`` whose
    learner/predictor act on whole collections of per-fold datasets/models.
    """

    def __init__(
        self,
        model_config: BaseMlModelConfig[T, U, V, W],
        k_fold: _BaseKFold,
        n_jobs_train: int | None = None,
        n_jobs_predict: int | None = None,
    ) -> None:
        self._cv_model = SingleMlModel(
            CvMlModelConfig.from_config(model_config, n_jobs_train, n_jobs_predict)
        )
        self._k_fold = k_fold

    def train(self, dataset: T) -> None:
        """Split ``dataset`` with the configured splitter and train every fold.

        Side effect: stores the fold indices (``cv_indices``) that
        ``predict`` later uses to build out-of-fold predictions.
        """
        datasets = SplittedDatasets.create(dataset, self._k_fold)
        self._cv_indices = datasets.indices
        self._cv_model.train(datasets.train, datasets.test)

    def predict(self, dataset: T) -> Prediction:
        """Return out-of-fold predictions for ``dataset``.

        Must be called after ``train``: ``cv_indices`` is only set there,
        so calling this first raises ``AttributeError``.
        """
        datasets = IndexedDatasets.create(dataset, self.cv_indices.test)
        return self._cv_model.predict(datasets)

    @property
    def raw_model(self) -> CvRawModels[U]:
        """The per-fold trained models (available only after ``train``)."""
        return self._cv_model.raw_model

    @property
    def cv_indices(self) -> SplittedDatasetsIndices:
        """Train/val/test row indices of each fold, recorded by ``train``."""
        return self._cv_indices

    @property
    def n_jobs_train(self) -> int | None:
        """Parallelism for fold training, delegated to the CV config."""
        return self._cv_model.model_config.n_jobs_train  # type: ignore

    @n_jobs_train.setter
    def n_jobs_train(self, n_jobs: int | None) -> None:
        self._cv_model.model_config.n_jobs_train = n_jobs  # type: ignore

    @property
    def n_jobs_predict(self) -> int | None:
        """Parallelism for fold prediction, delegated to the CV config."""
        return self._cv_model.model_config.n_jobs_predict  # type: ignore

    @n_jobs_predict.setter
    def n_jobs_predict(self, n_jobs: int | None) -> None:
        self._cv_model.model_config.n_jobs_predict = n_jobs  # type: ignore

    @property
    def train_config(self) -> V:
        """Training configuration, delegated to the wrapped model."""
        return self._cv_model.train_config

    @train_config.setter
    def train_config(self, config: V) -> None:
        self._cv_model.train_config = config

    @property
    def pred_config(self) -> W | None:
        """Prediction configuration, delegated to the wrapped model."""
        return self._cv_model.pred_config

    @pred_config.setter
    def pred_config(self, config: W | None) -> None:
        self._cv_model.pred_config = config
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Self
|
|
3
|
+
|
|
4
|
+
from factrainer.base.dataset import (
|
|
5
|
+
BaseDataset,
|
|
6
|
+
IndexableDataset,
|
|
7
|
+
RowIndex,
|
|
8
|
+
RowIndices,
|
|
9
|
+
)
|
|
10
|
+
from sklearn.model_selection._split import _BaseKFold
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class IndexedDataset[T: IndexableDataset](BaseDataset):
    """A dataset slice paired with the row index it was taken from."""

    # Row positions in the source dataset that `data` corresponds to.
    index: RowIndex
    # The sliced dataset itself.
    data: T

    def __len__(self) -> int:
        """Number of rows in this slice (one per index entry)."""
        return len(self.index)
+
|
|
20
|
+
|
|
21
|
+
class IndexedDatasets[T: IndexableDataset](BaseDataset):
|
|
22
|
+
datasets: Sequence[IndexedDataset[T]]
|
|
23
|
+
|
|
24
|
+
def __len__(self) -> int:
|
|
25
|
+
return sum([len(dataset) for dataset in self.datasets])
|
|
26
|
+
|
|
27
|
+
@classmethod
|
|
28
|
+
def create(cls, dataset: T, k_fold: _BaseKFold | RowIndices) -> Self:
|
|
29
|
+
if isinstance(k_fold, _BaseKFold):
|
|
30
|
+
raise NotImplementedError
|
|
31
|
+
return cls(
|
|
32
|
+
datasets=[
|
|
33
|
+
IndexedDataset(index=index, data=dataset[index]) for index in k_fold
|
|
34
|
+
]
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def indices(self) -> RowIndices:
|
|
39
|
+
return [dataset.index for dataset in self.datasets]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SplittedDataset[T: IndexableDataset](BaseDataset):
    """One CV fold: train/val/test slices of the same source dataset."""

    # Rows used to fit the model for this fold.
    train: IndexedDataset[T]
    # Validation rows for this fold (None when the fold has none).
    val: IndexedDataset[T] | None
    # Held-out rows for this fold.
    test: IndexedDataset[T]
+
|
|
47
|
+
|
|
48
|
+
class SplittedDatasetsIndices(BaseDataset):
    """Row indices of every fold's train/val/test portions."""

    # One index list per fold, training rows.
    train: RowIndices
    # One index list per fold, validation rows (None when no fold has them).
    val: RowIndices | None
    # One index list per fold, held-out rows.
    test: RowIndices
|
|
53
|
+
|
|
54
|
+
class SplittedDatasets[T: IndexableDataset](BaseDataset):
|
|
55
|
+
datasets: Sequence[SplittedDataset[T]]
|
|
56
|
+
|
|
57
|
+
@property
|
|
58
|
+
def train(self) -> IndexedDatasets[T]:
|
|
59
|
+
return IndexedDatasets(datasets=[dataset.train for dataset in self.datasets])
|
|
60
|
+
|
|
61
|
+
@property
|
|
62
|
+
def val(self) -> IndexedDatasets[T] | None:
|
|
63
|
+
vals = []
|
|
64
|
+
for dataset in self.datasets:
|
|
65
|
+
if dataset.val is None:
|
|
66
|
+
return None
|
|
67
|
+
vals.append(dataset.val)
|
|
68
|
+
return IndexedDatasets(datasets=vals)
|
|
69
|
+
|
|
70
|
+
@property
|
|
71
|
+
def test(self) -> IndexedDatasets[T]:
|
|
72
|
+
return IndexedDatasets(datasets=[dataset.test for dataset in self.datasets])
|
|
73
|
+
|
|
74
|
+
@property
|
|
75
|
+
def indices(self) -> SplittedDatasetsIndices:
|
|
76
|
+
return SplittedDatasetsIndices(
|
|
77
|
+
train=self.train.indices,
|
|
78
|
+
val=self.val.indices if self.val is not None else None,
|
|
79
|
+
test=self.test.indices,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def create(
|
|
84
|
+
cls, dataset: T, k_fold: _BaseKFold, share_holdouts: bool = True
|
|
85
|
+
) -> Self:
|
|
86
|
+
datasets = []
|
|
87
|
+
for train_index, val_index in dataset.get_index(k_fold):
|
|
88
|
+
if share_holdouts:
|
|
89
|
+
test_index = val_index
|
|
90
|
+
else:
|
|
91
|
+
raise NotImplementedError
|
|
92
|
+
datasets.append(
|
|
93
|
+
SplittedDataset(
|
|
94
|
+
train=IndexedDataset(index=train_index, data=dataset[train_index]),
|
|
95
|
+
val=IndexedDataset(index=val_index, data=dataset[val_index]),
|
|
96
|
+
test=IndexedDataset(index=test_index, data=dataset[test_index]),
|
|
97
|
+
)
|
|
98
|
+
)
|
|
99
|
+
return cls(datasets=datasets)
|
|
File without changes
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from factrainer.base.config import (
|
|
2
|
+
BaseMlModelConfig,
|
|
3
|
+
BasePredictConfig,
|
|
4
|
+
BaseTrainConfig,
|
|
5
|
+
)
|
|
6
|
+
from factrainer.base.dataset import BaseDataset, Prediction
|
|
7
|
+
from factrainer.base.raw_model import RawModel
|
|
8
|
+
|
|
9
|
+
from .trait import (
|
|
10
|
+
PredictorTrait,
|
|
11
|
+
ValidatableTrainerTrait,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SingleMlModel[
    T: BaseDataset,
    U: RawModel,
    V: BaseTrainConfig,
    W: BasePredictConfig,
](ValidatableTrainerTrait[T, V], PredictorTrait[T, U, W]):
    """A single (non-CV) model: thin facade over a learner/predictor config.

    Delegates training and prediction to the supplied
    ``BaseMlModelConfig``'s learner and predictor.
    """

    def __init__(
        self,
        model_config: BaseMlModelConfig[T, U, V, W],
    ) -> None:
        self.model_config = model_config

    def train(self, train_dataset: T, val_dataset: T | None = None) -> None:
        """Train via the configured learner and store the resulting raw model.

        The raw model is only available through ``raw_model`` after this
        runs; before that, ``raw_model`` raises ``AttributeError``.
        """
        self._model = self.model_config.learner.train(
            train_dataset, val_dataset, self.model_config.train_config
        )

    def predict(self, dataset: T) -> Prediction:
        """Predict with the configured predictor using the stored raw model."""
        return self.model_config.predictor.predict(
            dataset, self.raw_model, self.model_config.pred_config
        )

    @property
    def raw_model(self) -> U:
        """The trained model produced by ``train`` (set only there)."""
        return self._model

    @property
    def train_config(self) -> V:
        """Training configuration, delegated to the model config."""
        return self.model_config.train_config

    @train_config.setter
    def train_config(self, config: V) -> None:
        self.model_config.train_config = config

    @property
    def pred_config(self) -> W | None:
        """Prediction configuration, delegated to the model config."""
        return self.model_config.pred_config

    @pred_config.setter
    def pred_config(self, config: W | None) -> None:
        self.model_config.pred_config = config
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from factrainer.base.config import BasePredictConfig, BaseTrainConfig
|
|
4
|
+
from factrainer.base.dataset import BaseDataset, Prediction
|
|
5
|
+
from factrainer.base.raw_model import RawModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PredictorTrait[T: BaseDataset, U: RawModel, W: BasePredictConfig](ABC):
    """Interface for objects that hold a trained model and can predict."""

    @abstractmethod
    def predict(self, dataset: T) -> Prediction:
        """Return predictions for ``dataset``."""
        raise NotImplementedError

    @property
    @abstractmethod
    def raw_model(self) -> U:
        """The underlying trained model object."""
        raise NotImplementedError

    @property
    @abstractmethod
    def pred_config(self) -> W | None:
        """Prediction-time configuration (None when not set)."""
        raise NotImplementedError

    @pred_config.setter
    @abstractmethod
    def pred_config(self, config: W | None) -> None:
        raise NotImplementedError
|
+
|
|
28
|
+
|
|
29
|
+
class TrainerTrait[T: BaseDataset, V: BaseTrainConfig](ABC):
    """Interface for objects trained from a single dataset (no validation set)."""

    @abstractmethod
    def train(self, dataset: T) -> None:
        """Train on ``dataset``."""
        raise NotImplementedError

    @property
    @abstractmethod
    def train_config(self) -> V:
        """Training configuration."""
        raise NotImplementedError

    @train_config.setter
    @abstractmethod
    def train_config(self, config: V) -> None:
        raise NotImplementedError
|
44
|
+
|
|
45
|
+
class ValidatableTrainerTrait[T: BaseDataset, V: BaseTrainConfig](ABC):
    """Interface for objects trained with an optional validation dataset."""

    @abstractmethod
    def train(self, train_dataset: T, val_dataset: T | None = None) -> None:
        """Train on ``train_dataset``, optionally using ``val_dataset``."""
        raise NotImplementedError

    @property
    @abstractmethod
    def train_config(self) -> V:
        """Training configuration."""
        raise NotImplementedError

    @train_config.setter
    @abstractmethod
    def train_config(self, config: V) -> None:
        raise NotImplementedError