mltrainer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mltrainer-0.1.0/PKG-INFO +35 -0
- mltrainer-0.1.0/README.md +3 -0
- mltrainer-0.1.0/mltrainer/__init__.py +6 -0
- mltrainer-0.1.0/mltrainer/imagemodels.py +56 -0
- mltrainer-0.1.0/mltrainer/metrics.py +61 -0
- mltrainer-0.1.0/mltrainer/rnn_models.py +151 -0
- mltrainer-0.1.0/mltrainer/settings.py +61 -0
- mltrainer-0.1.0/mltrainer/tokenizer.py +83 -0
- mltrainer-0.1.0/mltrainer/trainer.py +235 -0
- mltrainer-0.1.0/mltrainer/vae.py +92 -0
- mltrainer-0.1.0/pyproject.toml +78 -0
mltrainer-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mltrainer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: toolkit for training pytorch models
|
|
5
|
+
Author-Email: R.Grouls <Raoul.Grouls@han.nl>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.10
|
|
8
|
+
Requires-Dist: gin-config>=0.5.0
|
|
9
|
+
Requires-Dist: numpy>=1.25.0
|
|
10
|
+
Requires-Dist: torch>=2.0.1
|
|
11
|
+
Requires-Dist: loguru>=0.7.0
|
|
12
|
+
Requires-Dist: ray[tune]>=2.5.1
|
|
13
|
+
Requires-Dist: tqdm>=4.65.0
|
|
14
|
+
Requires-Dist: pydantic>=1.10.9
|
|
15
|
+
Requires-Dist: torchvision>=0.15.2
|
|
16
|
+
Requires-Dist: torchtext>=0.15.2
|
|
17
|
+
Requires-Dist: torch-tb-profiler>=0.4.1
|
|
18
|
+
Requires-Dist: Flake8-pyproject>=1.2.3; extra == "lint"
|
|
19
|
+
Requires-Dist: pep8-naming>=0.13.3; extra == "lint"
|
|
20
|
+
Requires-Dist: flake8-annotations>=3.0.1; extra == "lint"
|
|
21
|
+
Requires-Dist: black>=23.3.0; extra == "lint"
|
|
22
|
+
Requires-Dist: flake8>=6.0.0; extra == "lint"
|
|
23
|
+
Requires-Dist: isort>=5.12.0; extra == "lint"
|
|
24
|
+
Requires-Dist: mypy>=1.4.1; extra == "lint"
|
|
25
|
+
Requires-Dist: mlflow>=2.4.1; extra == "tuning"
|
|
26
|
+
Requires-Dist: bayesian-optimization>=1.4.3; extra == "tuning"
|
|
27
|
+
Requires-Dist: hpbandster>=0.7.4; extra == "tuning"
|
|
28
|
+
Requires-Dist: configspace>=0.7.1; extra == "tuning"
|
|
29
|
+
Provides-Extra: lint
|
|
30
|
+
Provides-Extra: tuning
|
|
31
|
+
Description-Content-Type: text/markdown
|
|
32
|
+
|
|
33
|
+
# ml-trainer
|
|
34
|
+
|
|
35
|
+
toolkit for training pytorch models
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import gin
|
|
2
|
+
import torch
|
|
3
|
+
from torch import nn
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@gin.configurable
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten -> Linear/ReLU -> Linear/ReLU -> Linear.

    Args:
        num_classes: number of output logits.
        units1: width of the first hidden layer.
        units2: width of the second hidden layer.
        input_size: number of features per sample after flattening. Defaults
            to ``28 * 28``, the value that used to be hard-coded, so existing
            gin configs keep working unchanged.
    """

    def __init__(
        self,
        num_classes: int,
        units1: int,
        units2: int,
        input_size: int = 28 * 28,
    ) -> None:
        super().__init__()
        # nn.Flatten keeps dim 0 (the batch) and flattens everything after it.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, units1),
            nn.ReLU(),
            nn.Linear(units1, units2),
            nn.ReLU(),
            nn.Linear(units2, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores for the batch ``x``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@gin.configurable
class CNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a small dense head.

    Args:
        num_classes: number of output logits.
        kernel_size: kernel size shared by all three convolutions.
        filter1: output channels of the first convolution.
        filter2: output channels of the second convolution.
        input_size: ``(height, width)`` of the single-channel input images.
            Defaults to ``(28, 28)``.

    The first dense layer used to be hard-coded to 128 inputs. That value is
    only right for 28x28 images with ``kernel_size`` 2 or 3; other kernel
    sizes failed with a shape mismatch in ``forward``. The width is now
    measured by pushing a dummy image through the convolutions. This yields
    the same 128 for the settings that already worked.
    """

    def __init__(
        self,
        num_classes: int,
        kernel_size: int,
        filter1: int,
        filter2: int,
        input_size: tuple = (28, 28),
    ) -> None:
        super().__init__()

        # Conv2d expects (batch, channels, height, width); the first layer
        # takes a single input channel.
        self.convolutions = nn.Sequential(
            nn.Conv2d(1, filter1, kernel_size=kernel_size, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(filter1, filter2, kernel_size=kernel_size, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(filter2, 32, kernel_size=kernel_size, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Measure the flattened feature count instead of hard-coding it.
        # A forward pass over zeros creates no parameters and draws no random
        # numbers, so the dense head below is initialised exactly as before.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, *input_size)
            n_features = self.convolutions(dummy).flatten(start_dim=1).shape[1]

        self.dense = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores for the image batch ``x``."""
        x = self.convolutions(x)
        logits = self.dense(x)
        return logits
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Iterator
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
# Short alias for torch.Tensor, used in this module's type annotations.
Tensor = torch.Tensor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Metric:
    """Base class for evaluation metrics.

    Subclasses implement ``__repr__`` (the name the trainer reports the metric
    under, via ``str(metric)``) and ``__call__(y, yhat)`` returning a tensor.
    """

    def __repr__(self) -> str:
        raise NotImplementedError

    def __call__(self, y: Tensor, yhat: Tensor) -> Tensor:
        raise NotImplementedError


class MASE(Metric):
    """Mean absolute scaled error.

    The model's MAE is divided by ``self.scale``, the mean MAE of a naive
    forecast (repeat the last ``horizon`` inputs), computed once over ``train``.

    Args:
        train: object exposing ``stream()`` (yielding ``(x, y)`` batches) and
            ``__len__`` (how many batches to draw).
        horizon: number of time steps being forecast.

    Raises:
        ValueError: if ``len(train)`` is 0.
    """

    def __init__(self, train: Iterator, horizon: int) -> None:
        self.scale = self.naivenorm(train, horizon)

    def __repr__(self) -> str:
        return f"MASE(scale={self.scale:.3f})"

    def naivenorm(self, train: Iterator, horizon: int) -> Tensor:
        """Return the mean naive-forecast MAE over ``len(train)`` batches."""
        # TODO fix ignore: ``train`` is a dataset with stream()/len(), not an Iterator.
        streamer = train.stream()  # type: ignore
        n_batches = len(train)  # type: ignore
        if n_batches == 0:
            # torch.mean of an empty tensor is NaN, which would poison every score.
            raise ValueError("Cannot compute the MASE scale from an empty dataset.")
        # Create the iterator once. For generators this equals the old
        # ``next(iter(streamer))`` per step. For re-iterable streams (lists,
        # DataLoaders), the old form restarted every step and only ever
        # scored the first batch.
        batches = iter(streamer)
        elist = []
        for _ in range(n_batches):
            x, y = next(batches)
            yhat = self.naivepredict(x, horizon)
            elist.append(self.mae(y, yhat))
        return torch.mean(torch.tensor(elist))

    def naivepredict(self, x: Tensor, horizon: int) -> Tensor:
        """Return the last ``horizon`` steps of ``x`` as the forecast.

        Slices dim -2 and squeezes dim -1. This assumes ``x`` is shaped
        (..., seq, 1) — TODO confirm with the dataset.
        """
        assert horizon > 0
        yhat = x[..., -horizon:, :].squeeze(-1)
        return yhat

    def mae(self, y: Tensor, yhat: Tensor) -> Tensor:
        """Mean absolute error between ``y`` and ``yhat``."""
        return torch.mean(torch.abs(y - yhat))

    def __call__(self, y: Tensor, yhat: Tensor) -> Tensor:
        return self.mae(y, yhat) / self.scale


class MAE(Metric):
    """Mean absolute error over all elements."""

    def __repr__(self) -> str:
        return "MAE"

    def __call__(self, y: Tensor, yhat: Tensor) -> Tensor:
        return torch.mean(torch.abs(y - yhat))


class Accuracy(Metric):
    """Fraction of rows whose argmax over dim 1 of ``yhat`` equals ``y``."""

    def __repr__(self) -> str:
        return "Accuracy"

    def __call__(self, y: Tensor, yhat: Tensor) -> Tensor:
        return (yhat.argmax(dim=1) == y).sum() / len(yhat)
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
import gin
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
# Short alias for torch.Tensor, used in this module's type annotations.
Tensor = torch.Tensor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseModel(nn.Module):
    """Linear baseline: flatten each sample and map it to ``horizon`` outputs.

    Args:
        observations: number of values per sample after flattening.
        horizon: number of outputs to predict.
    """

    def __init__(self, observations: int, horizon: int) -> None:
        super().__init__()
        # Flatten (batch, ...) into (batch, features) so nn.Linear accepts it.
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(observations, horizon)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``linear(flatten(x))``."""
        flat = self.flatten(x)
        return self.linear(flat)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@gin.configurable
class BaseRNN(nn.Module):
    """Vanilla ``nn.RNN`` whose last time step is projected to ``horizon`` outputs.

    Args:
        input_size: features per time step.
        hidden_size: RNN hidden state size.
        num_layers: number of stacked RNN layers.
        horizon: number of outputs produced by the final linear layer.
    """

    def __init__(
        self, input_size: int, hidden_size: int, num_layers: int, horizon: int
    ) -> None:
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, horizon)
        self.horizon = horizon

    def forward(self, x: Tensor) -> Tensor:
        """Run the RNN and map the final step's output to predictions."""
        # batch_first=True: the sequence axis is dim 1.
        outputs, _hidden = self.rnn(x)
        return self.linear(outputs[:, -1, :])
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@gin.configurable
class GRUmodel(nn.Module):
    """GRU whose last time step is projected by a linear layer.

    ``config`` keys: ``input_size``, ``hidden_size``, ``dropout``,
    ``num_layers``, ``output_size``.
    """

    def __init__(
        self,
        config: Dict,
    ) -> None:
        super().__init__()
        hidden = config["hidden_size"]
        self.rnn = nn.GRU(
            input_size=config["input_size"],
            hidden_size=hidden,
            dropout=config["dropout"],
            batch_first=True,
            num_layers=config["num_layers"],
        )
        self.linear = nn.Linear(hidden, config["output_size"])

    def forward(self, x: Tensor) -> Tensor:
        """Return predictions computed from the GRU output at the final step."""
        # batch_first=True: outputs are indexed (batch, step, hidden).
        outputs, _hidden = self.rnn(x)
        final = outputs[:, -1, :]
        return self.linear(final)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@gin.configurable
class AttentionGRU(nn.Module):
    """GRU followed by multi-head self-attention; the last step is projected.

    ``config`` keys: ``input_size``, ``hidden_size``, ``dropout``,
    ``num_layers``, ``output_size`` and optionally ``num_heads`` (default 4,
    the previously hard-coded value). ``hidden_size`` must be divisible by
    ``num_heads`` (an ``nn.MultiheadAttention`` requirement).
    """

    def __init__(
        self,
        config: Dict,
    ) -> None:
        super().__init__()
        self.rnn = nn.GRU(
            input_size=config["input_size"],
            hidden_size=config["hidden_size"],
            dropout=config["dropout"],
            batch_first=True,
            num_layers=config["num_layers"],
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=config["hidden_size"],
            num_heads=config.get("num_heads", 4),
            dropout=config["dropout"],
            batch_first=True,
        )
        self.linear = nn.Linear(config["hidden_size"], config["output_size"])

    def forward(self, x: Tensor) -> Tensor:
        """Return predictions from the attended GRU output at the final step."""
        x, _ = self.rnn(x)
        # Query and key are cloned so q, k and v are distinct tensor objects.
        # NOTE(review): presumably this avoids MultiheadAttention's
        # identity-based self-attention paths — confirm before removing.
        x, _ = self.attention(x.clone(), x.clone(), x)
        last_step = x[:, -1, :]
        yhat = self.linear(last_step)
        return yhat
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@gin.configurable
class NLPmodel(nn.Module):
    """Embedding + GRU text model; the last time step is projected.

    ``config`` keys: ``vocab`` (embedding rows), ``hidden_size``, ``dropout``,
    ``num_layers``, ``output_size``.
    """

    def __init__(
        self,
        config: Dict,
    ) -> None:
        super().__init__()
        width = config["hidden_size"]
        self.emb = nn.Embedding(config["vocab"], width)
        # The GRU consumes embeddings, so its input size equals the hidden size.
        self.rnn = nn.GRU(
            input_size=width,
            hidden_size=width,
            dropout=config["dropout"],
            batch_first=True,
            num_layers=config["num_layers"],
        )
        self.linear = nn.Linear(width, config["output_size"])

    def forward(self, x: Tensor) -> Tensor:
        """Embed token ids ``x``, run the GRU and project the final step."""
        embedded = self.emb(x)
        outputs, _hidden = self.rnn(embedded)
        return self.linear(outputs[:, -1, :])
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@gin.configurable
class AttentionNLP(nn.Module):
    """Embedding + GRU + multi-head self-attention; the last step is projected.

    ``config`` keys: ``vocab``, ``hidden_size``, ``dropout``, ``num_layers``,
    ``output_size`` and optionally ``num_heads`` (default 4, the previously
    hard-coded value). ``hidden_size`` must be divisible by ``num_heads``.
    """

    def __init__(
        self,
        config: Dict,
    ) -> None:
        super().__init__()
        self.emb = nn.Embedding(config["vocab"], config["hidden_size"])
        self.rnn = nn.GRU(
            input_size=config["hidden_size"],
            hidden_size=config["hidden_size"],
            dropout=config["dropout"],
            batch_first=True,
            num_layers=config["num_layers"],
        )
        self.attention = nn.MultiheadAttention(
            embed_dim=config["hidden_size"],
            num_heads=config.get("num_heads", 4),
            dropout=config["dropout"],
            batch_first=True,
        )
        self.linear = nn.Linear(config["hidden_size"], config["output_size"])

    def forward(self, x: Tensor) -> Tensor:
        """Embed token ids ``x``, encode, attend and project the final step."""
        x = self.emb(x)
        x, _ = self.rnn(x)
        # Query and key are cloned so q, k and v are distinct tensor objects.
        # NOTE(review): presumably this avoids MultiheadAttention's
        # identity-based self-attention paths — confirm before removing.
        x, _ = self.attention(x.clone(), x.clone(), x)
        last_step = x[:, -1, :]
        yhat = self.linear(last_step)
        return yhat
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from pydantic import BaseModel, root_validator
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FileTypes(Enum):
    """File extensions; each member's value includes the leading dot."""

    JPG = ".jpg"
    PNG = ".png"
    TXT = ".txt"
    ZIP = ".zip"
    TGZ = ".tgz"
    TAR = ".tar.gz"  # compound extension, although the member is named TAR
    GZ = ".gz"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ReportTypes(Enum):
    """Reporting backends ``Trainer`` writes to when listed in ``reporttypes``."""

    GIN = 1  # Trainer writes gin.config_str() to <log_dir>/saved_config.gin
    TENSORBOARD = 2  # losses, metrics and learning rate via SummaryWriter
    MLFLOW = 3  # losses, metrics and learning rate via mlflow.log_metric
    RAY = 4  # losses and metrics via ray tune.report
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FormattedBase(BaseModel):
    """Pydantic model that renders as one ``key: value`` line per ``__dict__`` entry."""

    def __str__(self) -> str:
        lines = [f"{key}: {value}" for key, value in self.__dict__.items()]
        return "\n".join(lines)

    # repr and str are intentionally identical.
    __repr__ = __str__
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TrainerSettings(FormattedBase):
    """Validated configuration for ``Trainer``.

    Fields:
        epochs: number of epochs ``Trainer.loop`` runs.
        metrics: callables ``metric(y, yhat)`` evaluated on validation batches.
        logdir: existing directory under which timestamped run dirs are made.
        train_steps / valid_steps: batches drawn per epoch.
        reporttypes: backends to report to.
        optimizer_kwargs / scheduler_kwargs / earlystop_kwargs: keyword
            arguments forwarded to the optimizer, scheduler and EarlyStopping.
            Set ``earlystop_kwargs`` to None to disable early stopping.

    Raises:
        FileNotFoundError: if ``logdir`` does not exist.
    """

    epochs: int
    metrics: List[Callable]
    logdir: Path
    train_steps: int
    valid_steps: int
    reporttypes: List[ReportTypes]
    optimizer_kwargs: Dict[str, Any] = {"lr": 1e-3, "weight_decay": 1e-5}
    scheduler_kwargs: Optional[Dict[str, Any]] = {"factor": 0.1, "patience": 10}
    earlystop_kwargs: Optional[Dict[str, Any]] = {
        "save": False,
        "verbose": True,
        "patience": 10,
    }

    class Config:
        arbitrary_types_allowed = True

    # skip_on_failure=True fixes two problems:
    #  * pydantic v2, which the ">=1.10.9" pin allows, refuses a bare
    #    @root_validator at class creation;
    #  * on v1 the validator also ran after field errors, when "logdir" is
    #    missing from ``values``, and crashed with an AttributeError on None
    #    instead of the ValidationError.
    @root_validator(skip_on_failure=True)
    def check_path(cls, values: Dict) -> Dict:  # noqa: N805
        datadir = values["logdir"].resolve()
        if not datadir.exists():
            raise FileNotFoundError(
                f"Make sure the datadir exists.\n Found {datadir} to be non-existing."
            )
        return values
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import string
|
|
3
|
+
from collections import Counter, OrderedDict
|
|
4
|
+
from typing import Callable, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import gin
|
|
7
|
+
import torch
|
|
8
|
+
from loguru import logger
|
|
9
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
10
|
+
from torchtext.vocab import Vocab, vocab
|
|
11
|
+
|
|
12
|
+
# Short alias for torch.Tensor, used in this module's type annotations.
Tensor = torch.Tensor
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def split_and_flat(corpus: List[str]) -> List[str]:
    """Whitespace-split every string in ``corpus`` and return all tokens in order."""
    return [token for line in corpus for token in line.split()]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@gin.configurable
def build_vocab(
    corpus: List[str], max: int, oov: str = "<OOV>", pad: str = "<PAD>"
) -> Vocab:
    """Build a torchtext vocabulary from the most frequent corpus tokens.

    Tokens come from whitespace-splitting each string in ``corpus``. The
    ``max - 2`` most frequent tokens are kept, and ``pad`` and ``oov`` are
    added as specials. Lookups of unknown tokens return the ``oov`` index.

    Args:
        corpus: list of texts.
        max: intended vocabulary size including the two specials (the name
            shadows the builtin but is part of the gin-configurable API).
        oov: out-of-vocabulary token.
        pad: padding token.
    """
    data = split_and_flat(corpus)
    counter = Counter(data).most_common()
    logger.info(f"Found {len(counter)} tokens")
    # Reserve two slots for the specials. Clamp at zero. For max < 2, the old
    # slice ``[: max - 2]`` had a negative bound and kept almost every token.
    keep = max - 2 if max > 2 else 0
    counter = counter[:keep]
    ordered_dict = OrderedDict(counter)
    v1 = vocab(ordered_dict, specials=[pad, oov])
    v1.set_default_index(v1[oov])
    return v1
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def tokenize(corpus: List[str], v: Vocab) -> Tensor:
    """Map each whitespace-split sentence to vocabulary indices and pad.

    Args:
        corpus: list of sentences.
        v: vocabulary; ``v[word]`` must return an integer index.

    Returns:
        An int64 tensor of shape ``(len(corpus), longest_sentence)``,
        right-padded with 0. An empty corpus gives a ``(0, 0)`` tensor
        (previously ``pad_sequence`` raised on the empty list).
    """
    if not corpus:
        return torch.zeros((0, 0), dtype=torch.long)
    # The dtype is explicit because an empty sentence would otherwise become
    # a float32 ``torch.tensor([])``. pad_sequence takes the output dtype from
    # the first sequence, so that would turn the whole batch into floats.
    batch = [
        torch.tensor([v[word] for word in sent.split()], dtype=torch.long)
        for sent in corpus
    ]
    return pad_sequence(batch, batch_first=True)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def clean(text: str) -> str:
    """Normalise raw text for whitespace tokenisation.

    Steps, in order:
    1. Lower-case the text.
    2. Drop apostrophes ("don't" -> "dont").
    3. Replace the ``<br />`` tag and every ASCII punctuation character with a
       space.
    4. Collapse runs of spaces into one.

    Leading and trailing spaces are not stripped.
    """
    # re.escape makes every punctuation character literal inside the class.
    # In the unescaped class, the "\]" sequence meant an escaped bracket, so
    # the backslash itself was never removed. The bare "[" also triggered
    # a "possible nested set" FutureWarning.
    punctuation = f"[{re.escape(string.punctuation)}]"
    # remove CaPiTaLs
    lowercase = text.lower()
    # change don't and isn't into dont and isnt
    neg = re.sub("\\'", "", lowercase)
    # swap html tags for spaces
    html = re.sub("<br />", " ", neg)
    # swap punctuation for spaces
    stripped = re.sub(punctuation, " ", html)
    # remove extra spaces
    spaces = re.sub(" +", " ", stripped)
    return spaces
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@gin.configurable
class Preprocessor:
    """Turn a batch of ``(text, label)`` pairs into padded token/label tensors.

    Presumably used as a DataLoader ``collate_fn`` — confirm with callers.

    Args:
        max: maximum number of tokens kept per text (longer texts are cut).
        vocab: vocabulary; ``vocab[word]`` returns an integer index.
        clean: optional function applied to each text before splitting.
    """

    def __init__(
        self, max: int, vocab: Vocab, clean: Optional[Callable] = None
    ) -> None:
        self.max = max
        self.vocab = vocab
        self.clean = clean

    def cast_label(self, label: str) -> int:
        """Return 0 for ``"neg"`` and 1 for any other label."""
        if label == "neg":
            return 0
        else:
            return 1

    def __call__(self, batch: List) -> Tuple[Tensor, Tensor]:
        """Return ``(token_ids, labels)``.

        ``token_ids`` holds int32 indices, right-padded with 0 to the longest
        text in the batch.
        """
        labels, text = [], []
        for x, y in batch:
            # Test the configured cleaner, not the module-level ``clean``
            # function. That function is never None, so with ``clean=None`` the
            # old check always passed and then tried to call None.
            if self.clean is not None:
                x = self.clean(x)
            x = x.split()[: self.max]
            tokens = torch.tensor([self.vocab[word] for word in x], dtype=torch.int32)
            text.append(tokens)
            labels.append(self.cast_label(y))

        text_ = pad_sequence(text, batch_first=True, padding_value=0)
        return text_, torch.tensor(labels)
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Callable, Dict, Iterator, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import gin
|
|
6
|
+
import mlflow
|
|
7
|
+
# needed to make summarywriter load without error
|
|
8
|
+
import torch
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from ray import tune
|
|
11
|
+
from torch.utils.tensorboard import SummaryWriter
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
|
|
14
|
+
from deeptoolkit.settings import ReportTypes, TrainerSettings
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def write_gin(dir: Path, txt: str) -> None:
    """Write ``txt`` to ``saved_config.gin`` inside the directory ``dir``.

    An existing file of that name is overwritten.
    """
    target = dir / "saved_config.gin"
    target.write_text(txt)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def dir_add_timestamp(log_dir: Optional[Path] = None) -> Path:
    """Create and return ``log_dir / "<YYYYmmdd-HHMM>"``.

    Args:
        log_dir: parent directory; the current directory when None.

    Returns:
        The timestamped directory, created with its parents if needed.
    """
    parent = Path(".") if log_dir is None else Path(log_dir)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    log_dir = parent / timestamp
    logger.info(f"Logging to {log_dir}")
    # exist_ok replaces the exists()/mkdir() pair. With minute resolution,
    # runs started in the same minute share this name. The old
    # check-then-create could then fail with FileExistsError if another run
    # created the directory in between.
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Trainer:
    """Train a model for ``settings.epochs`` epochs and report progress.

    Args:
        model: network to train; called as ``model(x)``.
        settings: epochs, steps per epoch, metrics, report targets and the
            keyword arguments for optimizer, scheduler and early stopping.
        loss_fn: called as ``loss_fn(yhat, y)``; must return a scalar tensor.
        optimizer: optimizer class (or factory), instantiated here as
            ``optimizer(model.parameters(), **settings.optimizer_kwargs)``.
        traindataloader: source of ``(x, y)`` training batches.
        validdataloader: source of ``(x, y)`` validation batches.
        scheduler: optional scheduler class (or factory), instantiated as
            ``scheduler(optimizer, **settings.scheduler_kwargs)`` and stepped
            once per epoch.

    Raises:
        ValueError: if ``scheduler`` is given but ``settings.scheduler_kwargs``
            is None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        settings: TrainerSettings,
        loss_fn: Callable,
        optimizer: Callable[..., torch.optim.Optimizer],
        traindataloader: Iterator,
        validdataloader: Iterator,
        scheduler: Optional[Callable],
    ) -> None:
        self.model = model
        self.settings = settings
        self.log_dir = dir_add_timestamp(settings.logdir)
        self.loss_fn = loss_fn
        self.traindataloader = traindataloader
        self.validdataloader = validdataloader

        # ``optimizer`` is a class/factory; bind it to the model parameters.
        self.optimizer = optimizer(model.parameters(), **settings.optimizer_kwargs)
        # Offset added to the epoch number in ``report``. ``loop`` updates it
        # when it finishes, so a later ``loop`` call reports offset epochs.
        self.last_epoch = 0

        # Always define the attribute: it used to be missing without a scheduler.
        self.scheduler = None
        if scheduler:
            if settings.scheduler_kwargs is None:
                raise ValueError("Missing 'scheduler_kwargs' in TrainerSettings.")
            self.scheduler = scheduler(self.optimizer, **settings.scheduler_kwargs)

        if settings.earlystop_kwargs is not None:
            logger.info(
                "Found earlystop_kwargs in settings."
                "Set to None if you dont want earlystopping."
            )
            self.early_stopping: Optional[EarlyStopping] = EarlyStopping(
                self.log_dir, **settings.earlystop_kwargs
            )
        else:
            self.early_stopping = None

        if ReportTypes.TENSORBOARD in self.settings.reporttypes:
            self.writer = SummaryWriter(log_dir=self.log_dir)

        if ReportTypes.GIN in self.settings.reporttypes:
            write_gin(self.log_dir, gin.config_str())

    def loop(self) -> None:
        """Train, evaluate, report, step the scheduler and check early stopping.

        ``self.last_epoch`` is set to the last finished epoch only when the
        loop ends (normally, through early stopping, or through an exception).
        Previously it was overwritten every iteration, and ``report`` adds it
        to the epoch number, so reported steps drifted as 0, 1, 3, 5, ...
        """
        completed: Optional[int] = None
        try:
            for epoch in tqdm(range(self.settings.epochs), colour="#1e4706"):
                train_loss = self.trainbatches()
                metric_dict, test_loss = self.evalbatches()
                self.report(epoch, train_loss, test_loss, metric_dict)
                self._step_scheduler(test_loss)

                if self.early_stopping is not None:
                    self.early_stopping(test_loss, self.model)
                    if self.early_stopping.early_stop:
                        logger.info("Interrupting loop due to early stopping patience.")
                        completed = epoch
                        if self.early_stopping.save:
                            logger.info("retrieving best model.")
                            self.model = self.early_stopping.get_best()
                        else:
                            logger.info(
                                "early_stopping_save was false, using latest model."
                                "Set to true to retrieve best model."
                            )
                        break
                completed = epoch
        finally:
            if completed is not None:
                self.last_epoch = completed

    def _step_scheduler(self, test_loss: float) -> None:
        """Advance the LR scheduler once per epoch; a no-op without a scheduler.

        The scheduler used to be created but never stepped. ReduceLROnPlateau
        (which the default scheduler_kwargs fit) needs the monitored loss;
        other schedulers take no argument.
        """
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(test_loss)
        else:
            self.scheduler.step()

    def trainbatches(self) -> float:
        """Run ``settings.train_steps`` optimisation steps; return the mean loss."""
        self.model.train()
        train_loss: float = 0.0
        train_steps = self.settings.train_steps
        for _ in tqdm(range(train_steps), colour="#1e4706"):
            # NOTE(review): iter() is re-applied each step. For generator
            # streams this just continues the generator. For re-iterables such
            # as a DataLoader, it restarts them every step — confirm intended.
            x, y = next(iter(self.traindataloader))
            self.optimizer.zero_grad()
            yhat = self.model(x)
            loss = self.loss_fn(yhat, y)
            loss.backward()
            self.optimizer.step()
            # .item() gives a Python float and, unlike .numpy(), also works
            # for tensors on a GPU.
            train_loss += loss.item()
        train_loss /= train_steps
        return train_loss

    def evalbatches(self) -> Tuple[Dict[str, float], float]:
        """Evaluate ``settings.valid_steps`` batches without building a graph.

        Returns:
            ``(metric_dict, test_loss)``, both averaged over the steps. Metrics
            are keyed by ``str(metric)``.
        """
        self.model.eval()
        valid_steps = self.settings.valid_steps
        test_loss: float = 0.0
        metric_dict: Dict[str, float] = {}
        # no_grad: evaluation needs no autograd graph, which saves memory and time.
        with torch.no_grad():
            for _ in range(valid_steps):
                x, y = next(iter(self.validdataloader))
                yhat = self.model(x)
                test_loss += self.loss_fn(yhat, y).item()
                for m in self.settings.metrics:
                    key = str(m)
                    metric_dict[key] = metric_dict.get(key, 0.0) + m(y, yhat).item()

        test_loss /= valid_steps
        averaged = {key: total / valid_steps for key, total in metric_dict.items()}
        return averaged, test_loss

    def report(
        self, epoch: int, train_loss: float, test_loss: float, metric_dict: Dict
    ) -> None:
        """Send losses, metrics and the learning rate to the configured backends.

        The reported step is ``epoch + self.last_epoch``. ``test_loss`` is also
        stored on ``self.test_loss``.
        """
        epoch = epoch + self.last_epoch
        reporttypes = self.settings.reporttypes
        self.test_loss = test_loss
        lr = self.optimizer.param_groups[0]["lr"]

        if ReportTypes.RAY in reporttypes:
            tune.report(
                iterations=epoch,
                train_loss=train_loss,
                test_loss=test_loss,
                **metric_dict,
            )

        if ReportTypes.MLFLOW in reporttypes:
            mlflow.log_metric("Loss/train", train_loss, step=epoch)
            mlflow.log_metric("Loss/test", test_loss, step=epoch)
            for m in metric_dict:
                mlflow.log_metric(f"metric/{m}", metric_dict[m], step=epoch)
            mlflow.log_metric("learning_rate", lr, step=epoch)

        if ReportTypes.TENSORBOARD in reporttypes:
            self.writer.add_scalar("Loss/train", train_loss, epoch)
            self.writer.add_scalar("Loss/test", test_loss, epoch)
            for m in metric_dict:
                self.writer.add_scalar(f"metric/{m}", metric_dict[m], epoch)
            self.writer.add_scalar("learning_rate", lr, epoch)

        metric_scores = [f"{v:.4f}" for v in metric_dict.values()]
        logger.info(
            f"Epoch {epoch} train {train_loss:.4f} test {test_loss:.4f} metric {metric_scores}"  # noqa E501
        )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with the validation loss. After
    ``patience`` consecutive epochs without improvement, ``early_stop``
    becomes True.
    """

    def __init__(
        self,
        log_dir: Path,
        patience: int = 7,
        verbose: bool = False,
        delta: float = 0.0,
        save: bool = False,
    ) -> None:
        """
        Args:
            log_dir (Path): location to save checkpoint to.
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0.0
            save (bool): If True, the whole model is saved to
                            ``log_dir/checkpoint.pt`` on every improvement.
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss: Optional[float] = None
        self.early_stop = False
        self.delta = delta
        self.path = Path(log_dir) / "checkpoint.pt"
        self.save = save
        # Loss of the last saved checkpoint; shown as the "old" value in the
        # verbose log line.
        self.val_loss_min = float("inf")

    def __call__(self, val_loss: float, model: torch.nn.Module) -> None:
        """Record ``val_loss`` for one epoch; checkpoint ``model`` on improvement."""
        # An improvement must beat the best loss by more than ``delta``, as the
        # docstring says. The old test ``val_loss >= best + delta`` did the
        # opposite for delta > 0: slightly *worse* losses counted as
        # improvements and reset the best loss upward.
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            if self.save:
                self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            logger.info(
                f"best loss: {self.best_loss:.4f}, current loss {val_loss:.4f}. Counter {self.counter}/{self.patience}."
            )
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss: float, model: torch.nn.Module) -> None:
        """Saves model when validation loss decrease."""
        if self.verbose:
            # Use the previous checkpoint's loss. ``best_loss`` has already
            # been overwritten with ``val_loss`` at this point, so it always
            # logged "x --> x".
            logger.info(
                f"Validation loss ({self.val_loss_min:.4f} --> {val_loss:.4f}). Saving {self.path} ..."
            )
        torch.save(model, self.path)
        self.val_loss_min = val_loss

    def get_best(self) -> torch.nn.Module:
        """Load and return the model saved by the last checkpoint."""
        # The checkpoint is a whole pickled module, not a state_dict. Since
        # torch 2.6, torch.load defaults to weights_only=True and rejects that.
        # The file is written by this class; never point this at untrusted data.
        return torch.load(self.path, weights_only=False)
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Encoder(nn.Module):
    """Fully connected encoder: flatten, then ``insize -> h1 -> h2 -> latent``.

    ``config`` must provide the layer sizes ``insize``, ``h1``, ``h2`` and
    ``latent``.
    """

    def __init__(self, config: Dict) -> None:
        super().__init__()
        insize, h1, h2, latent = (config[k] for k in ("insize", "h1", "h2", "latent"))
        self.flatten = nn.Flatten()
        self.encode = nn.Sequential(
            nn.Linear(insize, h1),
            nn.ReLU(),
            nn.Linear(h1, h2),
            nn.ReLU(),
            nn.Linear(h2, latent),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent code for the batch ``x``."""
        return self.encode(self.flatten(x))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Decoder(nn.Module):
    """Mirror of the encoder layers: ``latent -> h2 -> h1 -> insize``.

    The output is reshaped to ``(-1, 28, 28, 1)`` (channels last), so
    ``config["insize"]`` must be a multiple of 784 — presumably exactly 784,
    i.e. 28x28 images.
    """

    def __init__(self, config: Dict) -> None:
        super().__init__()
        latent, h2, h1, insize = (config[k] for k in ("latent", "h2", "h1", "insize"))
        self.decode = nn.Sequential(
            nn.Linear(latent, h2),
            nn.ReLU(),
            nn.Linear(h2, h1),
            nn.ReLU(),
            nn.Linear(h1, insize),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latent codes ``x`` into 28x28x1 images."""
        flat = self.decode(x)
        return flat.reshape((-1, 28, 28, 1))
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RecostructionLoss:
    """Per-sample sum of squared errors, averaged over the batch.

    Inputs must be 4-D; the sum runs over dimensions 1, 2 and 3.
    """

    def __call__(self, y, yhat):
        per_sample = ((y - yhat) ** 2).sum(dim=(1, 2, 3))
        return per_sample.mean()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AutoEncoder(nn.Module):
    """``Encoder`` followed by ``Decoder``, both built from the same ``config``."""

    def __init__(self, config: Dict) -> None:
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into the latent space and decode it back."""
        return self.decoder(self.encoder(x))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def sample_range(encoder, stream, k: int = 10):
    """Estimate the range of encoded values from ``k`` batches.

    Args:
        encoder: callable mapping a batch ``X`` to a tensor.
        stream: iterator yielding ``(X, label)`` pairs; ``k`` items are consumed.
        k: number of batches to sample. This was previously ignored, and
            exactly 10 batches were always drawn.

    Returns:
        ``(min, max)`` over all encoded values of the sampled batches.

    Raises:
        ValueError: if ``k`` is not positive (there would be nothing to reduce).
    """
    if k <= 0:
        raise ValueError("k must be a positive number of batches")
    minmax_list = []
    for _ in range(k):
        X, _ = next(stream)
        y = encoder(X).detach().numpy()
        minmax_list.append(y.min())
        minmax_list.append(y.max())
    minmax = np.array(minmax_list)
    return minmax.min(), minmax.max()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def build_latent_grid(decoder, minimum: int, maximum: int, k: int = 20):
    """Decode a ``k`` x ``k`` grid of 2-D latent points and return a NumPy array.

    Both latent coordinates run over ``np.linspace(minimum, maximum, k)``. The
    ``k * k`` points (first coordinate varying fastest) are passed to
    ``decoder`` as one float32 tensor of shape ``(k * k, 2)``.
    """
    axis = np.linspace(minimum, maximum, k)
    xx, yy = np.meshgrid(axis, axis)
    points = np.stack((xx.ravel(), yy.ravel()), axis=1)
    decoded = decoder(torch.tensor(points, dtype=torch.float32))
    return decoded.detach().numpy()
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def select_n_random(data, labels, n=300):
    """
    Return ``n`` randomly chosen items of ``data`` with their matching ``labels``.

    Both arguments must have the same length and support indexing with a
    tensor of indices. The same random permutation is applied to both.
    """
    assert len(data) == len(labels)

    chosen = torch.randperm(len(data))[:n]
    return data[chosen], labels[chosen]
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "mltrainer"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "toolkit for training pytorch models"
|
|
5
|
+
authors = [
|
|
6
|
+
{ name = "R.Grouls", email = "Raoul.Grouls@han.nl" },
|
|
7
|
+
]
|
|
8
|
+
dependencies = [
|
|
9
|
+
"gin-config>=0.5.0",
|
|
10
|
+
"numpy>=1.25.0",
|
|
11
|
+
"torch>=2.0.1",
|
|
12
|
+
"loguru>=0.7.0",
|
|
13
|
+
"ray[tune]>=2.5.1",
|
|
14
|
+
"tqdm>=4.65.0",
|
|
15
|
+
"pydantic>=1.10.9",
|
|
16
|
+
"torchvision>=0.15.2",
|
|
17
|
+
"torchtext>=0.15.2",
|
|
18
|
+
"torch-tb-profiler>=0.4.1",
|
|
19
|
+
]
|
|
20
|
+
requires-python = ">=3.10"
|
|
21
|
+
readme = "README.md"
|
|
22
|
+
|
|
23
|
+
[project.license]
|
|
24
|
+
text = "MIT"
|
|
25
|
+
|
|
26
|
+
[project.optional-dependencies]
|
|
27
|
+
lint = [
|
|
28
|
+
"Flake8-pyproject>=1.2.3",
|
|
29
|
+
"pep8-naming>=0.13.3",
|
|
30
|
+
"flake8-annotations>=3.0.1",
|
|
31
|
+
"black>=23.3.0",
|
|
32
|
+
"flake8>=6.0.0",
|
|
33
|
+
"isort>=5.12.0",
|
|
34
|
+
"mypy>=1.4.1",
|
|
35
|
+
]
|
|
36
|
+
tuning = [
|
|
37
|
+
"mlflow>=2.4.1",
|
|
38
|
+
"bayesian-optimization>=1.4.3",
|
|
39
|
+
"hpbandster>=0.7.4",
|
|
40
|
+
"configspace>=0.7.1",
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
[build-system]
|
|
44
|
+
requires = [
|
|
45
|
+
"pdm-backend",
|
|
46
|
+
]
|
|
47
|
+
build-backend = "pdm.backend"
|
|
48
|
+
|
|
49
|
+
[tool.flake8]
|
|
50
|
+
ignore = [
|
|
51
|
+
"W503",
|
|
52
|
+
"ANN101",
|
|
53
|
+
"ANN002",
|
|
54
|
+
"ANN003",
|
|
55
|
+
]
|
|
56
|
+
max-line-length = 88
|
|
57
|
+
max-complexity = 18
|
|
58
|
+
exclude = [
|
|
59
|
+
"__init__.py",
|
|
60
|
+
"tests/*",
|
|
61
|
+
"docs/*",
|
|
62
|
+
"examples/*",
|
|
63
|
+
"scripts/*",
|
|
64
|
+
]
|
|
65
|
+
select = "C,E,F,W,B,B950"
|
|
66
|
+
extend-ignore = "E203, E501"
|
|
67
|
+
|
|
68
|
+
[tool.isort]
|
|
69
|
+
multi_line_output = 3
|
|
70
|
+
include_trailing_comma = true
|
|
71
|
+
use_parentheses = true
|
|
72
|
+
line_length = 88
|
|
73
|
+
|
|
74
|
+
[tool.mypy]
|
|
75
|
+
ignore_missing_imports = true
|
|
76
|
+
strict_optional = true
|
|
77
|
+
warn_unreachable = true
|
|
78
|
+
pretty = true
|