FASTEN-cli 1.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FASTEN/__init__.py +7 -0
- FASTEN/cli.py +95 -0
- FASTEN/common.py +5 -0
- FASTEN/config.py +110 -0
- FASTEN/data.py +118 -0
- FASTEN/estimate.py +138 -0
- FASTEN/learn.py +165 -0
- FASTEN/model.py +152 -0
- FASTEN/param.py +120 -0
- FASTEN/plot.py +215 -0
- FASTEN/predict.py +66 -0
- FASTEN/train.py +87 -0
- FASTEN/tune.py +92 -0
- FASTEN/utils.py +67 -0
- fasten_cli-1.0.0.dist-info/METADATA +89 -0
- fasten_cli-1.0.0.dist-info/RECORD +19 -0
- fasten_cli-1.0.0.dist-info/WHEEL +5 -0
- fasten_cli-1.0.0.dist-info/entry_points.txt +2 -0
- fasten_cli-1.0.0.dist-info/licenses/LICENSE.md +21 -0
FASTEN/model.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from .common import np, torch, nn, F, os, json
|
|
2
|
+
from .config import ModelArgs, ModelInput, ModelOutput
|
|
3
|
+
from .param import ModelParam, Constraint
|
|
4
|
+
from .utils import Scaler, Encoder
|
|
5
|
+
from pydantic import ValidationError
|
|
6
|
+
from zipfile import ZipFile
|
|
7
|
+
import joblib, random
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Network(nn.Module):
    """Fully connected network whose outputs are squashed into constrained ranges.

    The layer stack is built from ``args.hidden_layers``, ``args.hidden_size``
    and ``args.architecture``. Constraint masks/values and scaler statistics
    are attached afterwards as buffers through :meth:`load_constraints` and
    :meth:`load_scaler`; :meth:`forward` reads all of them.
    """

    def __init__(self, args: ModelArgs, input_size: int, output_size: int):
        """Build the linear layers.

        Args:
            args: Object exposing ``hidden_layers``, ``hidden_size`` and
                ``architecture`` ("pyramidal" or "rectangular").
            input_size: Width of the first layer's input.
            output_size: Width of the last layer's output.

        Raises:
            ValueError: If hidden layers are requested with an architecture
                other than "pyramidal" or "rectangular".
        """
        super().__init__()
        sizes = [input_size]
        if not args.hidden_layers:
            # No hidden layers: a single input -> output projection.
            sizes.append(output_size)
        elif args.architecture == "pyramidal":
            # Widths shrink linearly from hidden_size towards output_size.
            step = (args.hidden_size - output_size) / args.hidden_layers
            for i in range(args.hidden_layers):
                current_size = int(args.hidden_size - (i * step))
                sizes.append(max(current_size, output_size))
            sizes.append(output_size)
        elif args.architecture == "rectangular":
            sizes.extend([args.hidden_size] * args.hidden_layers)
            sizes.append(output_size)
        else:
            # Previously this fell through and produced a network without any
            # layer, which only failed later with an unrelated error.
            raise ValueError(f"Unknown architecture: {args.architecture!r}")
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(sizes, sizes[1:])
        )

    @staticmethod
    def _label_mask(params: dict[str, ModelParam], label: str) -> torch.Tensor:
        """Return a boolean mask marking the parameters whose base equals *label*."""
        return torch.tensor(
            [param.base == label for param in params.values()], dtype=torch.bool
        )

    def load_constraints(self, params: dict[str, ModelParam]):
        """Register one buffer per constraint rule/value plus "low"/"high" masks.

        Each buffer has one entry per parameter, in the iteration order of
        *params*. Rule and value entries are read from ``param.priors``.
        """
        count = len(params)
        for rule in Constraint.RULES:
            self.register_buffer(rule, torch.zeros(count, dtype=bool))
        for value in Constraint.VALUES:
            self.register_buffer(value, torch.zeros(count, dtype=float))
        for i, param in enumerate(params.values()):
            for rule in Constraint.RULES:
                self.get_buffer(rule)[i] = param.priors.get_rule(rule)
            for value in Constraint.VALUES:
                self.get_buffer(value)[i] = param.priors.get_value(value)
        for label in ("low", "high"):
            # Compare against the loop variable; the previous code compared
            # against the literal string "label", so both masks were always
            # empty and the ordering step in forward() never ran.
            self.register_buffer(label, self._label_mask(params, label))
        # Move the freshly registered buffers next to the weights.
        self.to(next(self.parameters()).device)

    def load_scaler(self, scaler: Scaler):
        """Register ``scaler.min`` and ``scaler.range`` as buffers.

        Nothing is registered while ``scaler.fitted`` is falsy; forward()
        reads ``self.min`` and ``self.range``, so it needs a fitted scaler to
        have been loaded first.
        """
        if not scaler.fitted:
            return
        self.register_buffer("min", torch.from_numpy(scaler.min))
        self.register_buffer("range", torch.from_numpy(scaler.range))
        self.to(next(self.parameters()).device)

    def forward(self, x):
        """Run the layer stack and apply the per-parameter output transforms.

        ReLU follows every layer except the last. Columns of the final output
        are then rewritten in place according to the registered masks:
        softplus above a scaled lower bound, softplus below a scaled upper
        bound, sigmoid inside a scaled interval, and finally "high" columns
        are set to "low" plus a softplus of their difference.
        """
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        x = self.layers[-1](x)
        x[:, self.greater_than] = F.softplus(x[:, self.greater_than]) + (self.lower[self.greater_than] - self.min[self.greater_than]) / self.range[self.greater_than]
        x[:, self.less_than] = (self.upper[self.less_than] - self.min[self.less_than]) / self.range[self.less_than] - F.softplus(x[:, self.less_than])
        x[:, self.between] = F.sigmoid(x[:, self.between]) * self.interval[self.between] / self.range[self.between] + (self.lower[self.between] - self.min[self.between]) / self.range[self.between]
        x[:, self.high] = x[:, self.low] + F.softplus(x[:, self.high] - x[:, self.low])
        return x
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Model():
    """Bundle of configuration, validated inputs/outputs, scalers and network.

    A model is either created from a JSON configuration (``create``) or
    restored from a ``model.zip`` archive written by ``dump`` (``load``).
    """

    def __init__(self, config_file: str = None, model_file: str = None):
        """Read the optional JSON config, then create or load the model.

        Args:
            config_file: Path to a JSON configuration. When omitted the
                configuration starts as ``{"train": {}}``.
            model_file: Path to an archive produced by ``dump``. When omitted
                a new model is created from the configuration.
        """
        # Process-wide side effect: every tensor created afterwards defaults
        # to float64.
        torch.set_default_dtype(torch.float64)
        self.inputs: dict[str, ModelInput] = dict()
        self.outputs: dict[str, ModelOutput] = dict()
        self.params: dict[str, ModelParam] = dict()
        self.args: ModelArgs = None
        if config_file:
            with open(config_file, "r", encoding="utf-8") as file:
                self.config = json.load(file)
        else:
            self.config = {"train": {}}
        if not model_file:
            self.create()
        else:
            self.load(model_file)

    def validate_data(self, inputs: dict, outputs: dict):
        """Build ``ModelInput``/``ModelOutput`` objects and collect parameters.

        A pydantic ``ValidationError`` raised by either constructor propagates
        to the caller unchanged.
        """
        for label, config in inputs.items():
            self.inputs[label] = ModelInput(**config, label=label)
        for label, config in outputs.items():
            output = ModelOutput(**config, label=label)
            params = output.dist.load_parameters(output)
            self.params.update(params)
            self.outputs[label] = output

    def validate_args(self, model: dict, train: dict):
        """Validate the merged model/train settings and build the network.

        List-valued entries are dropped before validation.
        NOTE(review): lists presumably describe hyperparameter search spaces
        handled elsewhere — confirm against the tuner.
        """
        model = {k: v for k, v in model.items() if not isinstance(v, list)}
        train = {k: v for k, v in train.items() if not isinstance(v, list)}
        self.args = ModelArgs(**{**model, **train})
        if self.args.rand_seed is not None:
            # Seed before the network is built so weight init is reproducible.
            self.set_seed(self.args.rand_seed)
        self.network = Network(self.args, len(self.inputs), len(self.params))
        self.network = self.network.to(self.args.device)
        self.network.load_constraints(self.params)
        self.network.load_scaler(self.param_scaler)

    @staticmethod
    def set_seed(seed: int):
        """Seed torch, CUDA, ``random`` and NumPy, and make cuDNN deterministic."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)

    def create(self):
        """Create a fresh model from ``self.config``.

        Requires the "inputs", "outputs", "model" and "train" config keys.
        """
        self.input_scaler, self.param_scaler = Scaler(), Scaler()
        self.validate_data(self.config["inputs"], self.config["outputs"])
        self.validate_args(self.config["model"], self.config["train"])
        self.encoder = Encoder(self.inputs)

    def load(self, model_file):
        """Restore scalers, encoder, configuration and weights from an archive.

        The archive's "inputs", "outputs" and "model" sections replace those
        of the current configuration; the "train" section is kept.
        """
        with ZipFile(model_file, "r") as model:
            # SECURITY: joblib.load unpickles arbitrary objects. Only open
            # archives that come from a trusted source.
            for attribute in ("input_scaler", "param_scaler", "encoder"):
                with model.open(f"{attribute}.pkl", "r") as file:
                    setattr(self, attribute, joblib.load(file))
            with model.open("config.json", "r") as file:
                model_config = json.load(file)
            for section in ("inputs", "outputs", "model"):
                self.config[section] = model_config[section]
            self.validate_data(self.config["inputs"], self.config["outputs"])
            self.validate_args(self.config["model"], self.config["train"])
            with model.open("model.pth", "r") as file:
                # map_location lets weights saved on another device (e.g. a
                # GPU) be restored on the device configured for this run.
                state = torch.load(file, map_location=self.args.device)
            self.network.load_state_dict(state)
            self.network.load_scaler(self.param_scaler)

    def dump(self, output):
        """Write scalers, encoder, weights and config to ``{output}/model.zip``."""
        model_file = f"{output}/model.zip"
        pickled = (
            ("input_scaler.pkl", self.input_scaler),
            ("param_scaler.pkl", self.param_scaler),
            ("encoder.pkl", self.encoder),
        )
        with ZipFile(model_file, "w") as model:
            for member, payload in pickled:
                with model.open(member, "w") as file:
                    joblib.dump(payload, file)
            with model.open("model.pth", "w") as file:
                torch.save(self.network.state_dict(), file)
            config = json.dumps(self.config, indent=4)
            model.writestr("config.json", config)
|
|
144
|
+
|
|
145
|
+
# def extract_config(self) -> dict:
|
|
146
|
+
# train, model = self.args.model_dump_json(), dict()
|
|
147
|
+
# train["device"], train["optimizer"] = train["device"].type, train["optimizer"].__name__
|
|
148
|
+
# for key in ["architecture", "hidden_layers", "hidden_size"]: model[key] = train.pop(key)
|
|
149
|
+
# inputs = {label: value.model_dump_json() for label, value in self.inputs.items()}
|
|
150
|
+
# outputs = {label: value.model_dump_json() for label, value in self.outputs.items()}
|
|
151
|
+
# return {"inputs": inputs, "outputs": outputs, "train": train, "model": model}
|
|
152
|
+
|
FASTEN/param.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from .common import torch, pd
|
|
3
|
+
from torch.distributions import *
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from .config import ModelOutput
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ModelParam():
    """One distribution parameter of a model output, with its constraints."""

    def __init__(self, output: ModelOutput, label: str):
        """Derive names and prior constraints for parameter *label* of *output*.

        Args:
            output: Owning output; ``output.label`` and ``output.name`` are
                used to build this parameter's label and display name.
            label: Parameter name as used by the distribution, e.g. "loc",
                "concentration1" or "df2".
        """
        self.base: str = self._strip_index(label)
        self.label: str = f"{output.label}_{label}"
        self.name: str = f"{output.name} {label.replace('_', ' ').title()}"
        self.priors = Constraint(output, self.base)
        # Independent copy so that data-driven bounds never alter the priors.
        self.constraints = deepcopy(self.priors)

    @staticmethod
    def _strip_index(label: str) -> str:
        """Drop a trailing numeric index ("df2" -> "df").

        The previous implementation removed only the characters "0" and "1",
        so "df2" kept its digit and no longer matched the "df" entries used
        for constraints and plot colours.
        """
        return label.rstrip("0123456789")

    def load_constraints(self, support: Support, min_val: float, max_val: float):
        """Tighten the working constraints using the observed data range.

        When this parameter is the support's upper bound it must be at least
        *max_val*; when it is the support's lower bound it must be at most
        *min_val*.
        """
        if support.upper == self.base:
            self.constraints.add_lower(max_val)
        if support.lower == self.base:
            self.constraints.add_upper(min_val)
        self.constraints.add_intervals()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ModelDist():
    """Named probability distribution together with its parameters and support.

    Distributions that need a non-default construction are provided as static
    factories on this class; every other name is resolved directly from
    ``torch.distributions``.
    """

    TYPES = {"Bernoulli", "Beta", "Binomial", "Cauchy", "Chi2", "ContinuousBernoulli", "Exponential",
             "FisherSnedecor", "Gamma", "Geometric", "Gumbel", "HalfCauchy", "HalfNormal", "InverseGamma",
             "Laplace", "LogNormal", "NegativeBinomial", "Normal", "Pareto", "Poisson", "Uniform"}

    def __init__(self, name: str):
        """Resolve the constructor and support for the distribution *name*."""
        self.name: str = name
        self.params: dict[str, ModelParam] = dict()
        self.support = Support(self.name)
        # A factory defined on this class takes precedence over torch's class.
        owner = ModelDist if hasattr(ModelDist, name) else torch.distributions
        self.base = getattr(owner, name)

    def load_parameters(self, output: ModelOutput) -> dict[str, ModelParam]:
        """Create a ``ModelParam`` per distribution argument and return them all.

        "probs" is skipped. Keys have the form ``"<output.label>_<argument>"``.
        """
        torch_dist = getattr(torch.distributions, self.name)
        if self.name == "Uniform":
            labels = ["low", "high"]
        else:
            labels = list(torch_dist.arg_constraints.keys())
        for label in (candidate for candidate in labels if candidate != "probs"):
            self.params[f"{output.label}_{label}"] = ModelParam(output, label)
        return self.params

    @staticmethod
    def Bernoulli(logits: float | torch.Tensor) -> Bernoulli:
        """Bernoulli parameterised by logits, without argument validation."""
        return Bernoulli(validate_args=False, logits=logits)

    @staticmethod
    def ContinuousBernoulli(logits: float | torch.Tensor) -> ContinuousBernoulli:
        """Continuous Bernoulli parameterised by logits, without argument validation."""
        return ContinuousBernoulli(validate_args=False, logits=logits)

    @staticmethod
    def Geometric(logits: float | torch.Tensor) -> Geometric:
        """Geometric parameterised by logits, without argument validation."""
        return Geometric(validate_args=False, logits=logits)

    @staticmethod
    def Binomial(total_count: float | torch.Tensor, logits: float | torch.Tensor) -> Binomial:
        """Binomial with a rounded trial count, parameterised by logits.

        For tensors the rounding offset is detached, so the value is rounded
        while gradients still flow to ``total_count`` itself.
        """
        if not isinstance(total_count, torch.Tensor):
            return Binomial(round(total_count), logits=logits, validate_args=False)
        offset = (total_count.round() - total_count).detach()
        return Binomial(total_count + offset, logits=logits, validate_args=False)

    @staticmethod
    def NegativeBinomial(total_count: int | torch.Tensor, logits: float | torch.Tensor) -> NegativeBinomial:
        """Negative binomial parameterised by logits, without argument validation."""
        return NegativeBinomial(total_count, validate_args=False, logits=logits)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class Constraint():
    """Bounds on a single parameter: a lower bound, an upper bound, or both.

    Exactly one of ``greater_than``, ``less_than`` and ``between`` is meant to
    be true once :meth:`add_intervals` has run (or none, when unbounded).
    """

    POSITIVE = {"concentration", "total_count", "scale", "df", "rate", "alpha"}
    RULES = ["greater_than", "less_than", "between"]
    VALUES = ["lower", "upper", "interval"]

    def __init__(self, output: ModelOutput, base: str, eps: float = 1e-16):
        """Collect the bounds that apply to parameter *base* of *output*.

        Args:
            output: Object that may carry ``<base>_min`` / ``<base>_max``.
            base: Parameter base name, e.g. "scale".
            eps: Lower bound used for parameters listed in ``POSITIVE``.
        """
        self.greater_than = self.less_than = self.between = False
        self.lower = self.upper = self.interval = 0.0
        if base in self.POSITIVE:
            self.add_lower(eps)
        # A bound attribute that exists but holds None is treated as absent;
        # otherwise the interval arithmetic below would fail on None.
        # NOTE(review): assumes None means "not configured" — confirm against
        # the ModelOutput schema.
        minimum = getattr(output, f"{base}_min", None)
        if minimum is not None:
            self.add_lower(minimum)
        maximum = getattr(output, f"{base}_max", None)
        if maximum is not None:
            self.add_upper(maximum)
        self.add_intervals()

    def add_upper(self, upper):
        """Record an upper bound."""
        self.less_than, self.upper = True, upper

    def add_lower(self, lower):
        """Record a lower bound."""
        self.greater_than, self.lower = True, lower

    def add_intervals(self):
        """Collapse a lower and an upper bound into a single ``between`` rule.

        An existing ``between`` rule counts as having both bounds, so adding
        a new bound to it refreshes ``interval`` instead of leaving a stale
        width with two rules active at once.
        """
        has_lower = self.greater_than or self.between
        has_upper = self.less_than or self.between
        if has_lower and has_upper:
            self.greater_than = self.less_than = False
            self.between, self.interval = True, self.upper - self.lower

    def get_rule(self, rule: str) -> bool:
        """Return the flag named *rule* (one of ``RULES``)."""
        return getattr(self, rule)

    def get_value(self, value: str) -> float:
        """Return the number named *value* (one of ``VALUES``)."""
        return getattr(self, value)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class Support:
    """Range of values a named distribution can produce.

    ``lower`` / ``upper`` are ``None`` (unbounded), a float (fixed bound) or a
    string naming the distribution parameter that acts as the bound.
    """

    DISCRETE = {"Bernoulli", "Binomial", "Geometric", "NegativeBinomial", "Poisson"}
    POSITIVE = {"FisherSnedecor", "InverseGamma", "LogNormal"}
    NON_NEGATIVE = {"Binomial", "Poisson", "Geometric", "NegativeBinomial", "Chi2", "Exponential", "Gamma", "HalfCauchy", "HalfNormal"}
    PROBABILITY = {"Bernoulli", "Beta", "ContinuousBernoulli"}
    DEPENDENT = {"Uniform", "Binomial", "Pareto"}

    def __init__(self, name, eps = 1e-16):
        """Look up the bounds for distribution *name*; *eps* is the strict-positive floor."""
        self.lower = self.upper = None
        self.discrete = name in self.DISCRETE
        # Fixed numeric bounds.
        if name in self.POSITIVE:
            self.lower = eps
        if name in self.NON_NEGATIVE:
            self.lower = 0.0
        if name in self.PROBABILITY:
            self.lower, self.upper = 0.0, 1.0
        # Bounds that are themselves parameters of the distribution; these
        # override any numeric bound assigned above.
        self.dependent = name in self.DEPENDENT
        if name == "Uniform":
            self.lower, self.upper = "low", "high"
        if name == "Binomial":
            self.upper = "total_count"
        if name == "Pareto":
            self.lower = "scale"

    def validate(self, data: pd.Series) -> bool:
        """Return False when *data* violates a fixed numeric bound, else True.

        Parameter-dependent (string) bounds and missing bounds are not checked.
        """
        floor, ceiling = self.lower, self.upper
        if isinstance(floor, float) and (data < floor).any():
            return False
        if isinstance(ceiling, float) and (data > ceiling).any():
            return False
        return True

    def get_bound(self, attr: str, params: dict[str, ModelParam], stats: torch.Tensor, fit: Distribution, std_devs: int = 3) -> torch.Tensor:
        """Return the "lower" or "upper" bound as a column per row of *stats*.

        Unbounded sides use the fitted mean shifted by *std_devs* standard
        deviations; parameter-dependent sides select the matching column(s)
        of *stats*; fixed sides repeat the constant.
        """
        bound = getattr(self, attr)
        if bound is None:
            direction = -1 if attr == "lower" else 1
            spread = torch.sqrt(fit.variance)
            return (fit.mean + direction * std_devs * spread).unsqueeze(-1)
        if isinstance(bound, str):
            selected = [param.base == bound for param in params.values()]
            return stats[:, selected]
        rows = stats.shape[0]
        return torch.tensor([bound]).repeat(rows).unsqueeze(-1)
|
FASTEN/plot.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from .common import torch, pd, np, os
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import optuna
|
|
4
|
+
|
|
5
|
+
# Matplotlib colour-cycle entry used for each parameter base name.
COLORS = dict(
    logits="C6",
    total_count="C4",
    rate="C3",
    concentration="C2",
    loc="C0",
    scale="C1",
    df="C7",
    alpha="C5",
    high="C8",
    low="C9",
)
|
|
8
|
+
|
|
9
|
+
def plot_train(trainer, figure_dir):
    """Save the training (and, if present, validation) loss curve.

    Writes ``loss_curve_plot.png`` into *figure_dir*, creating the directory
    (including missing parents) when needed. The y axis switches to log
    scale when the largest loss exceeds 10.
    """
    # exist_ok avoids the check-then-create race and also handles nested paths.
    os.makedirs(figure_dir, exist_ok=True)
    plt.figure(figsize=(6, 4))
    plt.plot(range(len(trainer.train.loss)), trainer.train.loss,
             color="blue", alpha=0.5, label="Training", zorder=2)
    max_val = max(trainer.train.loss)
    if trainer.valid:
        plt.plot(range(len(trainer.valid.loss)), trainer.valid.loss,
                 color="red", alpha=0.5, label="Validation", zorder=2)
        max_val = max(max_val, max(trainer.valid.loss))
    if max_val > 10:
        plt.yscale("log")
    plt.xlabel("Epoch")
    plt.legend()
    plt.ylabel("Average Loss")
    # NOTE(review): two empty, hidden axes are added in slots 2 and 3 of a
    # 1x3 grid; presumably this only influences the layout — confirm it is
    # still wanted.
    for i in range(2, 4):
        plt.subplot(1, 3, i)
        plt.plot([], [])
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/loss_curve_plot.png", dpi=200)
    plt.close()
|
|
30
|
+
|
|
31
|
+
def _running_best(values):
    """Return, for each trial, the lowest value seen up to and including it.

    Trials without a value (NaN) carry the previous best forward.
    """
    return values.cummin().ffill().tolist()


def plot_tune(tuner, figure_dir):
    """Save hyperparameter importance, slice and convergence plots for a study.

    Writes ``importance_plot.png``, ``slices_plot.png`` and
    ``convergence_plot.png`` into *figure_dir* using ``tuner.study``.
    """
    # Importance of each hyperparameter, least important first.
    plt.figure(figsize=(4, 4))
    rank = optuna.importance.get_param_importances(tuner.study)
    importances = np.array(list(rank.values()))
    params = np.array(list(rank.keys()))
    index = np.argsort(importances)
    plt.barh(params[index], importances[index])
    plt.xlabel("Importance")
    plt.ylabel("Hyperparameter")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/importance_plot.png", dpi=200)
    plt.close()

    # One panel per hyperparameter: loss of every trial grouped by the value tried.
    trials = tuner.study.trials_dataframe()
    plt.figure(figsize=(3 * len(params), 3.5))
    for i, param in enumerate(params):
        plt.subplot(1, len(params), i + 1)
        for value, group in trials.groupby(f"params_{param}"):
            x, y = np.repeat(str(value), len(group)), group["value"]
            plt.scatter(x, y, alpha=0.5, linewidths=0, color="C0")
            # NOTE(review): the original nesting of this check is not
            # recoverable from the source; it is applied per value group here.
            if y.max() - y.min() > 100:
                plt.yscale("log")
        plt.xlabel(param)
        plt.ylabel("Loss")
        plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/slices_plot.png", dpi=200)
    plt.close()

    # Convergence: every trial plus the best loss reached so far. The earlier
    # label-based slice `.loc[:i+1]` is inclusive and therefore peeked at the
    # following trial; a cumulative minimum avoids that and is linear time.
    best = _running_best(trials["value"])
    plt.figure(figsize=(8, 4))
    plt.scatter(range(len(trials)), trials["value"], color="C0")
    if trials["value"].max() - trials["value"].min() > 100:
        plt.yscale("log")
    plt.plot(range(len(trials)), best, color="C3")
    plt.xlabel("Trial")
    plt.ylabel("Loss")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/convergence_plot.png", dpi=200)
    plt.close()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def plot_predict(predictor, figure_dir):
    """Run the predictor, evaluate it and save all prediction plots.

    Returns:
        ``(mse, kld, nll)`` as produced by ``predictor.evaluate()``, or
        ``(None, None, None)`` when ``predictor.true`` is None.
    """
    if predictor.true is None:
        return None, None, None
    predictor.execute()
    mse, kld, nll = predictor.evaluate()
    # exist_ok avoids the check-then-create race and also handles nested paths.
    os.makedirs(figure_dir, exist_ok=True)
    plot_statistics(predictor, figure_dir)
    plot_kl_divergence(predictor, kld, figure_dir)
    plot_mean_squares(predictor, mse, figure_dir)
    plot_samples(predictor, figure_dir)
    return mse, kld, nll
|
|
82
|
+
|
|
83
|
+
def plot_samples(predictor, figure_dir, points = 10000, n_samples = 10):
    """Plot empirical versus predicted densities for a few representative groups.

    One figure per selected group is written to ``{figure_dir}/samples``,
    with one panel per model output: a histogram of the group's samples, the
    density fitted from the true statistics and the predicted density.

    Args:
        predictor: Object exposing ``model``, ``true`` and ``pred``.
        figure_dir: Directory that receives the ``samples`` sub-directory.
        points: Number of x positions at which each density is evaluated.
        n_samples: Upper limit on the number of groups plotted; at most one
            fifth of the available groups is used.
    """
    sample_dir = f"{figure_dir}/samples"
    os.makedirs(sample_dir, exist_ok=True)
    groups = predictor.true.samples.outputs.groupby(predictor.true.samples.group)
    n_samples = min(n_samples, len(groups) // 5)
    # Pick one group at random from each cluster of the true statistics.
    clusters = predictor.true.stats.cluster_data(n_clusters=n_samples)
    samples = pd.Series(np.arange(len(groups))).groupby(clusters).sample()
    rows, cols = n_samples, len(predictor.model.outputs)
    true = torch.empty((cols, points, rows)), torch.empty((cols, points, rows))
    pred = torch.empty((cols, points, rows)), torch.empty((cols, points, rows))
    (x_true, y_true), (x_pred, y_pred) = true, pred

    # Evaluate both densities on a grid spanning each distribution's bounds.
    for j, output in enumerate(predictor.model.outputs.values()):
        mask = [param in output.dist.params for param in predictor.model.params]
        for label, (x, y) in {"true": true, "pred": pred}.items():
            dataset = getattr(predictor, label)
            stats = dataset.stats.outputs.iloc[samples, mask]
            params = torch.tensor(stats.values)
            fit = output.dist.base(*params.unbind(dim=1))
            lower = output.dist.support.get_bound("lower", output.dist.params, params, fit)
            upper = output.dist.support.get_bound("upper", output.dist.params, params, fit)
            dtype = int if output.dist.support.discrete else float
            x[j] = torch.lerp(lower, upper, torch.linspace(0, 1, points)).to(dtype).t()
            y[j] = torch.exp(fit.log_prob(x[j]))

    for i, sample in enumerate(samples):
        plt.figure(figsize=(2.5 * cols + 0.5, 2))
        group = groups.get_group(sample)
        for j, output in enumerate(predictor.model.outputs.values()):
            plt.subplot(1, cols, j + 1)
            plt.hist(group.iloc[:, j], density=True, color="darkgray")
            # Keep finite densities only, then drop the tails below 1% of the peak.
            real_true, real_pred = ~torch.isinf(y_true[j, :, i]), ~torch.isinf(y_pred[j, :, i])
            mask_true, mask_pred = real_true.clone(), real_pred.clone()
            mask_true[real_true] = y_true[j, real_true, i] >= 0.01 * y_true[j, real_true, i].max()
            mask_pred[real_pred] = y_pred[j, real_pred, i] >= 0.01 * y_pred[j, real_pred, i].max()
            plt.plot(x_true[j, mask_true, i], y_true[j, mask_true, i], color="green",
                     alpha=1, linewidth=2, label=f"Empirical (N = {len(group)})")
            plt.plot(x_pred[j, mask_pred, i], y_pred[j, mask_pred, i], color="orange",
                     alpha=1, linewidth=2, linestyle="dashed", label="Predicted")
            if j == cols - 1:
                plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            plt.xlabel(output.name)
            plt.ylabel("Density")
            plt.xticks(fontsize=8)
            plt.yticks(fontsize=8)
        plt.tight_layout()
        plt.savefig(f"{sample_dir}/sample_plot_{sample}.png", dpi=200)
        plt.close()
|
|
130
|
+
|
|
131
|
+
def plot_statistics(predictor, figure_dir):
    """Scatter predicted against empirical mean and variance for every output.

    The figure has one column per output: means on the top row, variances on
    the bottom row, each with an identity line. It is saved as
    ``statistics_plot.png`` in *figure_dir*.
    """
    outputs = predictor.model.outputs
    cols, rows = len(outputs), 2
    groups = predictor.true.samples.outputs.groupby(predictor.true.samples.group)
    plt.figure(figsize=(3 * cols, 3 * rows))

    def draw_panel(position, empirical, predicted, color, ylabel, xlabel):
        # Square panel padded by 10% of the data range, with a y = x guide.
        plt.subplot(rows, cols, position)
        plt.scatter(empirical, predicted, color=color)
        max_val = max(predicted.max().max(), empirical.max().max())
        min_val = min(predicted.min().min(), empirical.min().min())
        delta = (max_val - min_val) * 0.1
        max_val, min_val = max_val + delta, min_val - delta
        plt.axline((min_val, min_val), slope=1, color="darkgray", linestyle="--")
        plt.xlim(min_val, max_val)
        plt.ylim(min_val, max_val)
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)

    for j, output in enumerate(outputs.values()):
        # Predicted moments come from the distribution built on predicted parameters.
        mask = [param in output.dist.params for param in predictor.model.params]
        stats = predictor.pred.stats.outputs.iloc[:, mask]
        params = torch.tensor(stats.values)
        fit = output.dist.base(*params.unbind(dim=1))
        pred_means, pred_vars = fit.mean, fit.variance

        # Empirical moments per group; the group key is used as the position.
        true_means = torch.zeros(groups.ngroups)
        true_vars = torch.zeros(groups.ngroups)
        for key, group in groups:
            column = group[output.label]
            true_means[key] = column.mean()
            true_vars[key] = column.var()

        draw_panel(j + 1, true_means, pred_means, "C0", "Predicted Mean", "Empirical Mean")
        plt.title(output.name, pad=10, fontweight="bold")
        draw_panel(cols + j + 1, true_vars, pred_vars, "C1", "Predicted Variance", "Empirical Variance")

    plt.tight_layout(rect=[0, 0, 0.99, 1])
    plt.savefig(f"{figure_dir}/statistics_plot.png", dpi=200)
    plt.close()
|
|
175
|
+
|
|
176
|
+
def plot_kl_divergence(predictor, kld, figure_dir):
    """Save a box plot of the KL divergence, one box per model output.

    Whiskers span the full range of the data and the y axis uses a symmetric
    log scale. The figure is written to ``kl_divergence_plot.png``.
    """
    tick_labels = [output.name for output in predictor.model.outputs.values()]
    tick_positions = range(1, kld.shape[1] + 1)

    plt.figure(figsize=(6, 6))
    boxes = plt.boxplot(kld, whis=[0, 100], patch_artist=True)
    for box in boxes["boxes"]:
        box.set_facecolor("darkgray")
    for median in boxes["medians"]:
        median.set_color("black")

    upper_limit = kld.max().max() * 5
    plt.ylim(-0.5, upper_limit)
    plt.ylabel("KL Divergence")
    plt.xticks(tick_positions, tick_labels, rotation=45, ha="right")
    plt.yscale("symlog")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/kl_divergence_plot.png", dpi=200)
    plt.close()
|
|
190
|
+
|
|
191
|
+
def plot_mean_squares(predictor, mse, figure_dir):
    """Save a box plot of the mean squared error, one box per parameter.

    Boxes are grouped by the output they belong to and coloured by parameter
    base name; the legend lists each base once. The figure is written to
    ``mean_squares_plot.png``.
    """
    names = [output.name for output in predictor.model.outputs.values()]
    # Fall back to gray for a base that has no palette entry instead of
    # failing with KeyError.
    colors = [COLORS.get(param.base, "gray") for param in predictor.model.params.values()]
    labels = [param.base for param in predictor.model.params.values()]
    positions = [0.5 * i for i in range(mse.shape[1])]
    ticks = [0 for _ in range(len(predictor.model.outputs))]
    for i, output in enumerate(predictor.model.outputs.values()):
        for j, param in enumerate(predictor.model.params):
            if param in output.dist.params:
                # Shift this output's boxes to open a gap between groups, and
                # place the tick at the mean position of the group.
                positions[j] += 0.5 * i + 1
                ticks[i] += positions[j] / len(output.dist.params)

    plt.figure(figsize=(10, 6))
    bplot = plt.boxplot(mse, whis=[0, 100], widths=0.5, positions=positions, patch_artist=True, label=labels)
    for patch, color in zip(bplot["boxes"], colors):
        patch.set_facecolor(color)
    for patch in bplot["medians"]:
        patch.set_color("black")
    plt.ylabel("Mean Squared Error")
    plt.xticks(ticks, names, rotation=45, ha="right")
    plt.yscale("log")
    # Keep one legend entry per distinct label.
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    plt.legend(by_label.values(), by_label.keys(), bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/mean_squares_plot.png", dpi=200)
    plt.close()
|
FASTEN/predict.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from .common import pd, np, torch
|
|
3
|
+
from .data import Samples, Statistics, Dataset
|
|
4
|
+
from .learn import Loss
|
|
5
|
+
from .model import Model
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Predictor():
    """Runs a trained model on inputs and compares predictions with the truth."""

    # Marks "no dataset supplied". An explicit ``None`` is still stored as-is.
    _UNSET = object()

    def __init__(self, model: Model, dataset: Dataset = _UNSET):
        """Store the model and the reference dataset.

        Args:
            model: Trained model used for prediction.
            dataset: Reference data. When omitted, every predictor gets its
                own empty ``Dataset`` rather than one instance created at
                definition time and shared (and mutated) by all predictors.
        """
        self.model: Model = model
        self.true: Dataset = Dataset() if dataset is Predictor._UNSET else dataset
        self.pred: Dataset = Dataset()
        self.error: pd.DataFrame = None
        self.criterion: Loss = Loss(model)

    def load_inputs(self, inputs_file: str, num_runs: int):
        """Load and scale inputs, repeating each row ``max(num_runs, 1)`` times."""
        self.true.stats.load_data(inputs_file, self.model.inputs, None)
        self.true.stats.scale_inputs(self.model.input_scaler)
        num_runs = max(num_runs, 1)
        index = np.arange(self.true.stats.inputs.shape[0]).repeat(num_runs)
        self.true.samples.inputs = self.true.stats.inputs.loc[index]
        self.true.samples.inputs = self.true.samples.inputs.reset_index(drop=True)
        self.true.samples.group_data()

    def dump_statistics(self, outputs_file: str):
        """Write the predicted statistics to *outputs_file*."""
        self.pred.stats.dump_data(outputs_file)

    def dump_samples(self, outputs_file: str):
        """Write the predicted samples to *outputs_file*."""
        self.pred.samples.dump_data(outputs_file)

    def execute(self):
        """Predict statistics, undo the scaling, then draw samples from them."""
        self.pred.stats = self.predict_stats()
        self.pred.stats.unscale_inputs(self.model.input_scaler)
        self.pred.stats.unscale_outputs(self.model.param_scaler)
        if isinstance(self.true.stats.outputs, pd.DataFrame):
            self.true.stats.unscale_outputs(self.model.param_scaler)
        self.pred.samples = self.predict_samples()

    def evaluate(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Return ``(mse, kld, nll)`` data frames for the current prediction.

        ``nll`` is ``None`` when any output's support is parameter dependent.
        """
        dependent = [output.dist.support.dependent for output in self.model.outputs.values()]
        mse, kld, nll = self.criterion.evaluate(self.true, self.pred, any(dependent))
        mse = pd.DataFrame(mse.detach().cpu().numpy(), columns=self.model.params)
        kld = pd.DataFrame(kld.detach().cpu().numpy(), columns=self.model.outputs)
        if any(dependent):
            return mse, kld, None
        nll = pd.DataFrame(nll.detach().cpu().numpy(), columns=self.model.outputs)
        return mse, kld, nll

    def predict_stats(self) -> Statistics:
        """Run the network in eval mode on the true inputs, without gradients."""
        self.model.network.eval()
        inputs = self.true.stats.inputs.values
        x = torch.tensor(inputs).to(self.model.args.device)
        with torch.no_grad():
            y = self.model.network(x)
        outputs = y.detach().cpu().numpy()
        params = pd.DataFrame(outputs, columns=self.model.params)
        return Statistics(self.true.stats.inputs, params)

    def predict_samples(self) -> Samples:
        """Draw one sample per row of the true samples from each predicted distribution."""
        outputs = pd.DataFrame()
        index = self.true.samples.group.values
        for label, output in self.model.outputs.items():
            # Select the columns holding this output's parameters.
            mask = [param in output.dist.params for param in self.model.params]
            stats = self.pred.stats.outputs.iloc[index, mask]
            params = torch.tensor(stats.values)
            fit = output.dist.base(*params.unbind(dim=1))
            outputs[label] = fit.sample().numpy()
        return Samples(self.true.samples.inputs, outputs)
|
FASTEN/train.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from .common import torch, os, shutil
|
|
2
|
+
from .learn import *
|
|
3
|
+
from .data import Dataset
|
|
4
|
+
from .model import Model
|
|
5
|
+
from rich import progress
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Trainer():
    """Splits a dataset, trains the model's network and records the losses."""

    def __init__(self, model: Model):
        """Attach *model*; partitions stay empty until ``load_data`` is called."""
        self.model: Model = model
        self.dataset: Dataset = Dataset()
        self.train: Partition = None
        self.valid: Partition = None
        self.test: Partition = None

    def load_data(self, samples_file: str):
        """Load samples, derive their statistics and split them into partitions."""
        self.dataset.load_samples(samples_file, self.model)
        self.dataset.load_stats(self.model, self.model.args.estimator)
        self.split_data()

    def split_data(self):
        """Carve test and validation partitions out of the training data.

        The validation share is rescaled because it is taken from what is
        left after the test split.
        """
        args = self.model.args
        self.train = Partition(self.dataset, args.loss_func)
        test = self.train.dataset.split(args.test_split, args.rand_seed)
        split = args.valid_split / (1 - args.test_split)
        valid = self.train.dataset.split(split, args.rand_seed)
        if test:
            self.test = Partition(test, args.loss_func)
        if valid:
            self.valid = Partition(valid, args.loss_func)

    def load_args(self):
        """Prepare data loaders, optimizer, early stopping and the loss."""
        # os.cpu_count() may return None when the count cannot be determined.
        torch.set_num_threads(os.cpu_count() or 1)
        for partition in [self.train, self.valid, self.test]:
            if partition:
                partition.load(self.model.args.batch_size, self.model.args.device)
        args = {"lr": self.model.args.learn_rate, "weight_decay": self.model.args.weight_decay}
        if self.model.args.optimizer.__name__ == "SGD":
            args["momentum"] = self.model.args.momentum
        self.optimizer = self.model.args.optimizer(self.model.network.parameters(), **args)
        patience, min_delta = self.model.args.patience, self.model.args.min_delta
        if self.model.args.early_stop:
            self.early_stop = EarlyStop(patience, min_delta)
        self.criterion = Loss(self.model).to(self.model.args.device)

    def execute(self):
        """Train for up to ``num_epochs`` epochs, then score the test partition.

        Training stops early when early stopping is enabled and triggers on
        the latest validation loss.
        """
        self.load_args()
        loadbar = progress.Progress(progress.TextColumn("{task.description}"), progress.BarColumn(),
                                    progress.MofNCompleteColumn(), progress.TimeRemainingColumn())
        with loadbar as load:
            task = load.add_task("", total=self.model.args.num_epochs)
            for _ in range(self.model.args.num_epochs):
                if self.train:
                    self.train_model()
                if self.valid:
                    self.test_model(valid=True)
                if self.valid and self.model.args.early_stop:
                    loss = self.valid.loss[-1]
                    if self.early_stop(loss):
                        break
                load.update(task, advance=1)
        if self.test:
            self.test_model(valid=False)

    def train_model(self):
        """Run one training epoch and append its mean batch loss.

        Raises:
            ValueError: If the loss computation raises ``ValueError``, which
                is reported as diverged training.
        """
        self.model.network.train()
        train_loss = 0
        for x, y in self.train.dataloader:
            self.optimizer.zero_grad()
            y_pred = self.model.network(x)
            try:
                loss = self.criterion(y_pred, y)
            except ValueError as error:
                raise ValueError("Training diverged: loss is NaN (possible exploding gradients)") from error
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        train_loss /= len(self.train.dataloader)
        self.train.loss.append(train_loss)

    def test_model(self, valid: bool):
        """Score the validation (``valid=True``) or test partition without gradients."""
        partition = self.valid if valid else self.test
        self.model.network.eval()
        test_loss = 0
        with torch.no_grad():
            for x, y in partition.dataloader:
                y_pred = self.model.network(x)
                loss = self.criterion(y_pred, y)
                test_loss += loss.item()
        test_loss /= len(partition.dataloader)
        partition.loss.append(test_loss)

    def dump_model(self, output):
        """Recreate directory *output* with a ``plots`` folder and save the model.

        Any existing directory at *output* is deleted first.
        """
        if os.path.exists(output):
            shutil.rmtree(output)
        # makedirs also creates missing parent directories of *output*.
        os.makedirs(f"{output}/plots")
        self.model.dump(output)
|