FASTEN-cli 1.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
FASTEN/model.py ADDED
@@ -0,0 +1,152 @@
1
+ from .common import np, torch, nn, F, os, json
2
+ from .config import ModelArgs, ModelInput, ModelOutput
3
+ from .param import ModelParam, Constraint
4
+ from .utils import Scaler, Encoder
5
+ from pydantic import ValidationError
6
+ from zipfile import ZipFile
7
+ import joblib, random
8
+
9
+
10
class Network(nn.Module):
    """Fully connected network whose raw outputs are squashed into constrained ranges.

    The hidden layers use ReLU activations. After the final linear layer,
    columns flagged by the constraint buffers are transformed so that they
    respect lower bounds, upper bounds, intervals, and the ``high > low``
    ordering, all expressed in the scaler's ``(value - min) / range`` space.
    """

    def __init__(self, args: ModelArgs, input_size: int, output_size: int):
        """Build the stack of linear layers described by ``args``.

        Args:
            args: Provides ``hidden_layers``, ``hidden_size`` and ``architecture``.
            input_size: Width of the input layer.
            output_size: Width of the output layer.

        Raises:
            ValueError: If hidden layers are requested with an architecture
                other than ``"pyramidal"`` or ``"rectangular"``.
        """
        super().__init__()
        sizes = [input_size]
        if not args.hidden_layers:
            pass  # no hidden layers: a single input -> output projection
        elif args.architecture == "pyramidal":
            # Widths shrink linearly from hidden_size towards output_size,
            # never dropping below output_size.
            step = (args.hidden_size - output_size) / args.hidden_layers
            for i in range(args.hidden_layers):
                current_size = int(args.hidden_size - (i * step))
                sizes.append(max(current_size, output_size))
        elif args.architecture == "rectangular":
            sizes.extend([args.hidden_size] * args.hidden_layers)
        else:
            # Previously this fell through and produced a network without
            # layers, which only failed later inside forward().
            raise ValueError(f"Unknown architecture: {args.architecture!r}")
        sizes.append(output_size)
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes, sizes[1:])]
        )

    def load_constraints(self, params: dict[str, ModelParam]):
        """Register one buffer per constraint rule/value plus ``low``/``high`` masks.

        Args:
            params: Ordered mapping of parameters; position ``i`` in every
                buffer corresponds to the i-th entry of ``params``.
        """
        for rule in Constraint.RULES:
            self.register_buffer(rule, torch.zeros(len(params), dtype=bool))
        for value in Constraint.VALUES:
            self.register_buffer(value, torch.zeros(len(params), dtype=float))
        for i, param in enumerate(params.values()):
            for rule in Constraint.RULES:
                self.get_buffer(rule)[i] = param.priors.get_rule(rule)
            for value in Constraint.VALUES:
                self.get_buffer(value)[i] = param.priors.get_value(value)
        for label in ["low", "high"]:
            # Compare against the loop variable; the literal string "label"
            # never matches, which left both masks empty.
            loop = (param.base == label for param in params.values())
            mask = np.fromiter(loop, dtype=bool, count=len(params))
            self.register_buffer(label, torch.from_numpy(mask))
        # Freshly registered buffers live on the CPU; move them next to the weights.
        self.to(next(self.parameters()).device)

    def load_scaler(self, scaler: Scaler):
        """Register the scaler's ``min`` and ``range`` as buffers (no-op if unfitted)."""
        if not scaler.fitted:
            return
        self.register_buffer("min", torch.from_numpy(scaler.min))
        self.register_buffer("range", torch.from_numpy(scaler.range))
        self.to(next(self.parameters()).device)

    def forward(self, x):
        """Run the layers, then apply the constraint transforms column-wise.

        Reads the ``min``/``range`` buffers, so ``load_scaler`` must have been
        called with a fitted scaler beforehand.
        """
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        x = self.layers[-1](x)

        above, below, inside = self.greater_than, self.less_than, self.between
        # value > lower: softplus keeps the offset positive above the scaled bound.
        x[:, above] = F.softplus(x[:, above]) + (self.lower[above] - self.min[above]) / self.range[above]
        # value < upper: subtract a positive offset from the scaled bound.
        x[:, below] = (self.upper[below] - self.min[below]) / self.range[below] - F.softplus(x[:, below])
        # lower < value < upper: sigmoid spans the scaled interval.
        x[:, inside] = (
            torch.sigmoid(x[:, inside]) * self.interval[inside] / self.range[inside]
            + (self.lower[inside] - self.min[inside]) / self.range[inside]
        )
        # Force every "high" column above its "low" counterpart.
        x[:, self.high] = x[:, self.low] + F.softplus(x[:, self.high] - x[:, self.low])
        return x
56
+
57
+
58
class Model():
    """Bundle of configuration, scalers, encoder and network.

    A model is either created from a JSON configuration file or restored from
    a ``model.zip`` archive previously written by :meth:`dump`.
    """

    def __init__(self, config_file: str = None, model_file: str = None):
        """Load the configuration, then create a fresh model or restore one.

        Args:
            config_file: Path to a JSON configuration file. When omitted the
                configuration starts as ``{"train": {}}``.
            model_file: Path to an archive written by :meth:`dump`. When
                omitted a new model is created from the configuration.
        """
        # Note: this changes the process-wide default dtype.
        torch.set_default_dtype(torch.float64)
        self.inputs: dict[str, ModelInput] = dict()
        self.outputs: dict[str, ModelOutput] = dict()
        self.params: dict[str, ModelParam] = dict()
        self.args: ModelArgs = None
        if config_file:
            with open(config_file, "r", encoding="utf-8") as file:
                self.config = json.load(file)
        else:
            self.config = {"train": {}}
        if not model_file:
            self.create()
        else:
            self.load(model_file)

    def validate_data(self, inputs: dict, outputs: dict):
        """Build the input/output specs and collect the distribution parameters.

        Raises:
            pydantic.ValidationError: Propagated unchanged when a spec is invalid.
        """
        for label, config in inputs.items():
            self.inputs[label] = ModelInput(**config, label=label)
        for label, config in outputs.items():
            output = ModelOutput(**config, label=label)
            params = output.dist.load_parameters(output)
            self.params.update(params)
            self.outputs[label] = output

    def validate_args(self, model: dict, train: dict):
        """Validate the hyperparameters and build the network on the target device.

        List-valued entries are dropped from both sections before validation.

        Raises:
            pydantic.ValidationError: Propagated unchanged when the arguments are invalid.
        """
        model = {k: v for k, v in model.items() if not isinstance(v, list)}
        train = {k: v for k, v in train.items() if not isinstance(v, list)}
        self.args = ModelArgs(**{**model, **train})
        if self.args.rand_seed is not None:
            # Seed before the layers are created so initial weights are reproducible.
            self.set_seed(self.args.rand_seed)
        self.network = Network(self.args, len(self.inputs), len(self.params))
        self.network = self.network.to(self.args.device)
        self.network.load_constraints(self.params)
        self.network.load_scaler(self.param_scaler)

    @staticmethod
    def set_seed(seed: int):
        """Seed torch, CUDA, ``random`` and NumPy, and request deterministic cuDNN."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)

    def create(self):
        """Create unfitted scalers, validate the configuration and build the encoder."""
        self.input_scaler, self.param_scaler = Scaler(), Scaler()
        self.validate_data(self.config["inputs"], self.config["outputs"])
        self.validate_args(self.config["model"], self.config["train"])
        self.encoder = Encoder(self.inputs)

    def load(self, model_file):
        """Restore scalers, encoder, configuration and weights from an archive.

        The ``inputs``, ``outputs`` and ``model`` sections stored in the
        archive replace the ones from the configuration file; the ``train``
        section of the configuration file is kept.
        """
        with ZipFile(model_file, "r") as model:
            # SECURITY: joblib.load unpickles the archive members, which can
            # execute arbitrary code. Only open archives from trusted sources.
            with model.open("input_scaler.pkl", "r") as file:
                self.input_scaler = joblib.load(file)
            with model.open("param_scaler.pkl", "r") as file:
                self.param_scaler = joblib.load(file)
            with model.open("encoder.pkl", "r") as file:
                self.encoder = joblib.load(file)
            with model.open("config.json", "r") as file:
                model_config = json.load(file)
            self.config["inputs"] = model_config["inputs"]
            self.config["outputs"] = model_config["outputs"]
            self.config["model"] = model_config["model"]
            # A configuration file without a "train" section is still usable.
            self.config.setdefault("train", {})
            self.validate_data(self.config["inputs"], self.config["outputs"])
            self.validate_args(self.config["model"], self.config["train"])
            with model.open("model.pth", "r") as file:
                # Remap storages so a checkpoint saved on another device
                # (e.g. CUDA) can be restored on the configured one.
                state = torch.load(file, map_location=self.args.device)
            self.network.load_state_dict(state)
            self.network.load_scaler(self.param_scaler)

    def dump(self, output):
        """Write scalers, encoder, weights and configuration to ``{output}/model.zip``."""
        model_file = f"{output}/model.zip"
        with ZipFile(model_file, "w") as model:
            with model.open("input_scaler.pkl", "w") as file:
                joblib.dump(self.input_scaler, file)
            with model.open("param_scaler.pkl", "w") as file:
                joblib.dump(self.param_scaler, file)
            with model.open("encoder.pkl", "w") as file:
                joblib.dump(self.encoder, file)
            with model.open("model.pth", "w") as file:
                torch.save(self.network.state_dict(), file)
            config = json.dumps(self.config, indent=4)
            model.writestr("config.json", config)
144
+
145
+ # def extract_config(self) -> dict:
146
+ # train, model = self.args.model_dump_json(), dict()
147
+ # train["device"], train["optimizer"] = train["device"].type, train["optimizer"].__name__
148
+ # for key in ["architecture", "hidden_layers", "hidden_size"]: model[key] = train.pop(key)
149
+ # inputs = {label: value.model_dump_json() for label, value in self.inputs.items()}
150
+ # outputs = {label: value.model_dump_json() for label, value in self.outputs.items()}
151
+ # return {"inputs": inputs, "outputs": outputs, "train": train, "model": model}
152
+
FASTEN/param.py ADDED
@@ -0,0 +1,120 @@
1
+ from __future__ import annotations
2
+ from .common import torch, pd
3
+ from torch.distributions import *
4
+ from typing import TYPE_CHECKING
5
+ from copy import deepcopy
6
+
7
+ if TYPE_CHECKING:
8
+ from .config import ModelOutput
9
+
10
+
11
class ModelParam():
    """One distribution parameter of a model output, with its constraints."""

    def __init__(self, output: ModelOutput, label: str):
        """Derive identifiers from ``output`` and ``label`` and build the priors.

        ``base`` is the label with every "0" and "1" character removed
        (for example ``df1`` becomes ``df``).
        """
        readable = label.replace("_", " ").title()
        self.base: str = label.translate(str.maketrans("", "", "01"))
        self.label: str = f"{output.label}_{label}"
        self.name: str = f"{output.name} {readable}"
        self.priors = Constraint(output, self.base)
        # Working copy that may be tightened without touching the priors.
        self.constraints = deepcopy(self.priors)

    def load_constraints(self, support: Support, min_val: float, max_val: float):
        """Tighten the working constraints when this parameter bounds the support.

        If this parameter is the support's upper bound it must lie above
        ``max_val``; if it is the support's lower bound it must lie below
        ``min_val``.
        """
        is_upper_bound = support.upper == self.base
        if is_upper_bound:
            self.constraints.add_lower(max_val)
        is_lower_bound = support.lower == self.base
        if is_lower_bound:
            self.constraints.add_upper(min_val)
        self.constraints.add_intervals()
23
+
24
+
25
class ModelDist():
    """Named probability distribution together with its parameters and support.

    The static methods below replace the matching ``torch.distributions``
    constructors so that those distributions are built from ``logits`` with
    argument validation turned off.
    """

    TYPES = {"Bernoulli", "Beta", "Binomial", "Cauchy", "Chi2", "ContinuousBernoulli", "Exponential",
             "FisherSnedecor", "Gamma", "Geometric", "Gumbel", "HalfCauchy", "HalfNormal", "InverseGamma",
             "Laplace", "LogNormal", "NegativeBinomial", "Normal", "Pareto", "Poisson", "Uniform"}

    def __init__(self, name: str):
        """Resolve the constructor for ``name``, preferring an override on this class."""
        self.name: str = name
        self.params: dict[str, ModelParam] = {}
        self.support = Support(self.name)
        source = ModelDist if hasattr(ModelDist, name) else torch.distributions
        self.base = getattr(source, name)

    def load_parameters(self, output: ModelOutput) -> dict[str, ModelParam]:
        """Create a ``ModelParam`` per distribution argument, skipping ``probs``.

        Returns:
            The parameter mapping of this distribution, keyed by
            ``"{output.label}_{argument}"``.
        """
        family = getattr(torch.distributions, self.name)
        if self.name == "Uniform":
            labels = ["low", "high"]
        else:
            labels = list(family.arg_constraints.keys())
        for label in labels:
            if label != "probs":
                self.params[f"{output.label}_{label}"] = ModelParam(output, label)
        return self.params

    @staticmethod
    def Bernoulli(logits: float | torch.Tensor) -> Bernoulli:
        """Bernoulli distribution parameterised by logits."""
        return Bernoulli(validate_args=False, logits=logits)

    @staticmethod
    def ContinuousBernoulli(logits: float | torch.Tensor) -> ContinuousBernoulli:
        """Continuous Bernoulli distribution parameterised by logits."""
        return ContinuousBernoulli(validate_args=False, logits=logits)

    @staticmethod
    def Geometric(logits: float | torch.Tensor) -> Geometric:
        """Geometric distribution parameterised by logits."""
        return Geometric(validate_args=False, logits=logits)

    @staticmethod
    def Binomial(total_count: float | torch.Tensor, logits: float | torch.Tensor) -> Binomial:
        """Binomial distribution with ``total_count`` rounded to the nearest integer."""
        if not isinstance(total_count, torch.Tensor):
            return Binomial(round(total_count), logits=logits, validate_args=False)
        # Add the detached rounding offset so the value is rounded while the
        # gradient still flows through the unrounded tensor.
        offset = (total_count.round() - total_count).detach()
        return Binomial(total_count + offset, logits=logits, validate_args=False)

    @staticmethod
    def NegativeBinomial(total_count: int | torch.Tensor, logits: float | torch.Tensor) -> NegativeBinomial:
        """Negative binomial distribution parameterised by logits."""
        return NegativeBinomial(total_count, validate_args=False, logits=logits)
65
+
66
+
67
class Constraint():
    """Range restriction on a single parameter.

    Exactly one of ``greater_than``, ``less_than`` or ``between`` is set once
    ``add_intervals`` has run (or none, for an unconstrained parameter).
    ``lower``, ``upper`` and ``interval`` hold the matching numeric values.
    """

    POSITIVE = {"concentration", "total_count", "scale", "df", "rate", "alpha"}
    RULES = ["greater_than", "less_than", "between"]
    VALUES = ["lower", "upper", "interval"]

    def __init__(self, output: ModelOutput, base: str, eps: float = 1e-16):
        """Derive the constraint for parameter ``base`` of ``output``.

        Args:
            output: Object that may expose ``{base}_min`` / ``{base}_max`` bounds.
            base: Parameter name without numeric suffix.
            eps: Lower bound applied to parameters listed in ``POSITIVE``.
        """
        self.greater_than = self.less_than = self.between = False
        self.lower = self.upper = self.interval = 0.0
        if base in self.POSITIVE:
            self.greater_than, self.lower = True, eps
        # A bound attribute that exists but is None means "not specified".
        lower = getattr(output, f"{base}_min", None)
        if lower is not None:
            self.add_lower(lower)
        upper = getattr(output, f"{base}_max", None)
        if upper is not None:
            self.add_upper(upper)
        self.add_intervals()

    def add_upper(self, upper):
        """Set the upper bound; call ``add_intervals`` afterwards."""
        self.less_than, self.upper = True, upper

    def add_lower(self, lower):
        """Set the lower bound; call ``add_intervals`` afterwards."""
        self.greater_than, self.lower = True, lower

    def add_intervals(self):
        """Collapse a lower and an upper bound into a single ``between`` rule.

        An existing ``between`` rule counts as both bounds, so adding a new
        bound to an interval keeps one rule active and refreshes ``interval``
        instead of leaving two rules set with a stale width.
        """
        has_lower = self.greater_than or self.between
        has_upper = self.less_than or self.between
        if has_lower and has_upper:
            self.greater_than = self.less_than = False
            self.between, self.interval = True, self.upper - self.lower

    def get_rule(self, rule: str) -> bool:
        """Return the flag named ``rule`` (one of ``RULES``)."""
        return getattr(self, rule)

    def get_value(self, value: str) -> float:
        """Return the number named ``value`` (one of ``VALUES``)."""
        return getattr(self, value)
91
+
92
+
93
class Support:
    """Support of a named distribution.

    ``lower`` / ``upper`` are ``None`` (no bound), a float (fixed bound) or a
    string naming the distribution parameter that acts as the bound.
    """

    DISCRETE = {"Bernoulli", "Binomial", "Geometric", "NegativeBinomial", "Poisson"}
    POSITIVE = {"FisherSnedecor", "InverseGamma", "LogNormal"}
    NON_NEGATIVE = {"Binomial", "Poisson", "Geometric", "NegativeBinomial", "Chi2", "Exponential", "Gamma", "HalfCauchy", "HalfNormal"}
    PROBABILITY = {"Bernoulli", "Beta", "ContinuousBernoulli"}
    DEPENDENT = {"Uniform", "Binomial", "Pareto"}

    def __init__(self, name, eps = 1e-16):
        """Look up the bounds and flags for the distribution called ``name``."""
        lower = upper = None
        # Fixed numeric bounds first ...
        if name in self.POSITIVE:
            lower = eps
        if name in self.NON_NEGATIVE:
            lower = 0.0
        if name in self.PROBABILITY:
            lower, upper = 0.0, 1.0
        # ... then bounds that are themselves distribution parameters.
        if name == "Uniform":
            lower, upper = "low", "high"
        if name == "Binomial":
            upper = "total_count"
        if name == "Pareto":
            lower = "scale"
        self.lower, self.upper = lower, upper
        self.discrete = name in self.DISCRETE
        self.dependent = name in self.DEPENDENT

    def validate(self, data: pd.Series) -> bool:
        """Return ``False`` if any value violates a fixed (float) bound."""
        too_low = isinstance(self.lower, float) and bool((data < self.lower).any())
        if too_low:
            return False
        too_high = isinstance(self.upper, float) and bool((data > self.upper).any())
        return not too_high

    def get_bound(self, attr: str, params: dict[str, ModelParam], stats: torch.Tensor, fit: Distribution, std_devs: int = 3) -> torch.Tensor:
        """Return the ``attr`` ("lower" or "upper") bound as a column tensor.

        * no bound: mean -/+ ``std_devs`` standard deviations of ``fit``;
        * parameter bound: the column(s) of ``stats`` whose parameter has that base;
        * fixed bound: the constant repeated once per row of ``stats``.
        """
        bound = getattr(self, attr)
        if bound is None:
            direction = -1 if attr == "lower" else 1
            spread = std_devs * torch.sqrt(fit.variance)
            return (fit.mean + direction * spread).unsqueeze(-1)
        if isinstance(bound, str):
            columns = [param.base == bound for param in params.values()]
            return stats[:, columns]
        n_rows = stats.shape[0]
        return torch.tensor([bound]).repeat(n_rows).unsqueeze(-1)
FASTEN/plot.py ADDED
@@ -0,0 +1,215 @@
1
+ from .common import torch, pd, np, os
2
+ import matplotlib.pyplot as plt
3
+ import optuna
4
+
5
# Matplotlib colour-cycle entry used for each distribution parameter base name.
COLORS = dict(
    logits="C6",
    total_count="C4",
    rate="C3",
    concentration="C2",
    loc="C0",
    scale="C1",
    df="C7",
    alpha="C5",
    high="C8",
    low="C9",
)
8
+
9
def plot_train(trainer, figure_dir):
    """Save the training (and, if present, validation) loss curve.

    Args:
        trainer: Object exposing ``train.loss`` and an optional ``valid.loss``.
        figure_dir: Directory that receives ``loss_curve_plot.png``; created
            (including parents) when missing.
    """
    # makedirs handles nested paths and an already-existing directory.
    os.makedirs(figure_dir, exist_ok=True)
    plt.figure(figsize=(6, 4))
    train_loss = trainer.train.loss
    plt.plot(range(len(train_loss)), train_loss,
             color="blue", alpha=0.5, label="Training", zorder=2)
    # default=0 keeps an empty loss history from raising.
    max_val = max(train_loss, default=0)
    if trainer.valid:
        valid_loss = trainer.valid.loss
        plt.plot(range(len(valid_loss)), valid_loss,
                 color="red", alpha=0.5, label="Validation", zorder=2)
        max_val = max(max_val, max(valid_loss, default=0))
    if max_val > 10:
        plt.yscale("log")
    plt.xlabel("Epoch")
    plt.legend()
    plt.ylabel("Average Loss")
    # NOTE(review): these two empty, hidden subplots overlap the loss axes.
    # Kept as in the original; their purpose is not evident from this file.
    for i in range(2, 4):
        plt.subplot(1, 3, i)
        plt.plot([], [])
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/loss_curve_plot.png", dpi=200)
    plt.close()
30
+
31
def plot_tune(tuner, figure_dir):
    """Save hyperparameter importance, slice and convergence plots for a study.

    Args:
        tuner: Object exposing the Optuna ``study``.
        figure_dir: Directory that receives ``importance_plot.png``,
            ``slices_plot.png`` and ``convergence_plot.png``; created when missing.
    """
    os.makedirs(figure_dir, exist_ok=True)

    # Importance of each hyperparameter, least important at the bottom.
    plt.figure(figsize=(4, 4))
    rank = optuna.importance.get_param_importances(tuner.study)
    importances = np.array(list(rank.values()))
    params = np.array(list(rank.keys()))
    index = np.argsort(importances)
    plt.barh(params[index], importances[index])
    plt.xlabel("Importance")
    plt.ylabel("Hyperparameter")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/importance_plot.png", dpi=200)
    plt.close()

    # One panel per hyperparameter: objective value against the tried setting.
    trials = tuner.study.trials_dataframe()
    plt.figure(figsize=(3 * len(params), 3.5))
    for i, param in enumerate(params):
        plt.subplot(1, len(params), i + 1)
        for value, group in trials.groupby(f"params_{param}"):
            x, y = np.repeat(str(value), len(group)), group["value"]
            plt.scatter(x, y, alpha=0.5, linewidths=0, color="C0")
            if y.max() - y.min() > 100:
                plt.yscale("log")
        plt.xlabel(param)
        plt.ylabel("Loss")
        plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/slices_plot.png", dpi=200)
    plt.close()

    # Running minimum up to and including each trial. The previous
    # ``.loc[:i+1]`` slice is label-inclusive and peeked one trial ahead.
    best = trials["value"].expanding().min()
    plt.figure(figsize=(8, 4))
    plt.scatter(range(len(trials)), trials["value"], color="C0")
    if trials["value"].max() - trials["value"].min() > 100:
        plt.yscale("log")
    plt.plot(range(len(trials)), best, color="C3")
    plt.xlabel("Trial")
    plt.ylabel("Loss")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/convergence_plot.png", dpi=200)
    plt.close()
69
+
70
+
71
def plot_predict(predictor, figure_dir):
    """Run the predictor, evaluate it and write every evaluation plot.

    Returns:
        ``(mse, kld, nll)`` from ``predictor.evaluate()``, or three ``None``
        values when the predictor has no reference dataset.
    """
    if predictor.true is None:
        return None, None, None

    predictor.execute()
    metrics = predictor.evaluate()
    mse, kld, nll = metrics

    if not os.path.exists(figure_dir):
        os.mkdir(figure_dir)

    plot_statistics(predictor, figure_dir)
    plot_kl_divergence(predictor, kld, figure_dir)
    plot_mean_squares(predictor, mse, figure_dir)
    plot_samples(predictor, figure_dir)
    return mse, kld, nll
82
+
83
def _density_mask(density):
    """Select finite densities that reach at least 1% of the finite maximum."""
    finite = ~torch.isinf(density)
    mask = finite.clone()
    if finite.any():  # nothing to compare against when every value is infinite
        mask[finite] = density[finite] >= 0.01 * density[finite].max()
    return mask


def plot_samples(predictor, figure_dir, points = 10000, n_samples = 10):
    """Plot empirical and predicted densities over histograms of selected groups.

    One group is drawn from each cluster of the reference statistics and one
    figure per drawn group is written to ``{figure_dir}/samples``.

    Args:
        predictor: Predictor whose ``true`` and ``pred`` datasets are populated.
        figure_dir: Parent directory of the ``samples`` folder.
        points: Number of evaluation points per density curve.
        n_samples: Upper limit on the number of groups to plot; capped at one
            fifth of the number of groups.
    """
    sample_dir = f"{figure_dir}/samples"
    os.makedirs(sample_dir, exist_ok=True)
    groups = predictor.true.samples.outputs.groupby(predictor.true.samples.group)
    n_samples = min(n_samples, len(groups) // 5)
    clusters = predictor.true.stats.cluster_data(n_clusters=n_samples)
    samples = pd.Series(np.arange(len(groups))).groupby(clusters).sample()
    rows, cols = n_samples, len(predictor.model.outputs)
    # (x, y) grids indexed as [output, point, sample].
    true = torch.empty((cols, points, rows)), torch.empty((cols, points, rows))
    pred = torch.empty((cols, points, rows)), torch.empty((cols, points, rows))
    (x_true, y_true), (x_pred, y_pred) = true, pred

    for j, output in enumerate(predictor.model.outputs.values()):
        mask = [param in output.dist.params for param in predictor.model.params]
        for label, (x, y) in {"true": true, "pred": pred}.items():
            dataset = getattr(predictor, label)
            stats = dataset.stats.outputs.iloc[samples, mask]
            params = torch.tensor(stats.values)
            fit = output.dist.base(*params.unbind(dim=1))
            lower = output.dist.support.get_bound("lower", output.dist.params, params, fit)
            upper = output.dist.support.get_bound("upper", output.dist.params, params, fit)
            # Discrete supports are evaluated on integer points only.
            dtype = int if output.dist.support.discrete else float
            x[j] = torch.lerp(lower, upper, torch.linspace(0, 1, points)).to(dtype).t()
            y[j] = torch.exp(fit.log_prob(x[j]))

    for i, sample in enumerate(samples):
        plt.figure(figsize=(2.5 * cols + 0.5, 2))
        group = groups.get_group(sample)
        for j, output in enumerate(predictor.model.outputs.values()):
            plt.subplot(1, cols, j + 1)
            plt.hist(group.iloc[:, j], density=True, color="darkgray")
            mask_true = _density_mask(y_true[j, :, i])
            mask_pred = _density_mask(y_pred[j, :, i])
            plt.plot(x_true[j, mask_true, i], y_true[j, mask_true, i], color="green",
                     alpha=1, linewidth=2, label=f"Empirical (N = {len(group)})")
            plt.plot(x_pred[j, mask_pred, i], y_pred[j, mask_pred, i], color="orange",
                     alpha=1, linewidth=2, linestyle="dashed", label="Predicted")
            if j == cols - 1:
                plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
            plt.xlabel(output.name)
            plt.ylabel("Density")
            plt.xticks(fontsize=8)
            plt.yticks(fontsize=8)
        plt.tight_layout()
        plt.savefig(f"{sample_dir}/sample_plot_{sample}.png", dpi=200)
        plt.close()
130
+
131
def plot_statistics(predictor, figure_dir):
    """Scatter predicted against empirical mean and variance for every output.

    The top row shows means and the bottom row variances, one column per
    output, each with a dashed identity line. The figure is written to
    ``{figure_dir}/statistics_plot.png``.
    """
    n_outputs, n_rows = len(predictor.model.outputs), 2
    groups = predictor.true.samples.outputs.groupby(predictor.true.samples.group)
    plt.figure(figsize=(3 * n_outputs, 3 * n_rows))

    def draw_panel(position, observed, predicted, color, statistic, title=None):
        # Square axes padded by 10% of the joint value range.
        plt.subplot(n_rows, n_outputs, position)
        plt.scatter(observed, predicted, color=color)
        max_val = max(predicted.max().max(), observed.max().max())
        min_val = min(predicted.min().min(), observed.min().min())
        delta = (max_val - min_val) * 0.1
        max_val, min_val = max_val + delta, min_val - delta
        plt.axline((min_val, min_val), slope=1, color="darkgray", linestyle="--")
        plt.xlim(min_val, max_val)
        plt.ylim(min_val, max_val)
        plt.ylabel(f"Predicted {statistic}")
        plt.xlabel(f"Empirical {statistic}")
        if title is not None:
            plt.title(title, pad=10, fontweight="bold")

    for column, output in enumerate(predictor.model.outputs.values()):
        selected = [param in output.dist.params for param in predictor.model.params]
        stats = predictor.pred.stats.outputs.iloc[:, selected]
        params = torch.tensor(stats.values)
        fit = output.dist.base(*params.unbind(dim=1))
        pred_means, pred_vars = fit.mean, fit.variance

        true_means = torch.zeros(groups.ngroups)
        true_vars = torch.zeros(groups.ngroups)
        for key, frame in groups:
            series = frame[output.label]
            true_means[key] = series.mean()
            true_vars[key] = series.var()

        draw_panel(column + 1, true_means, pred_means, "C0", "Mean", title=output.name)
        draw_panel(n_outputs + column + 1, true_vars, pred_vars, "C1", "Variance")

    plt.tight_layout(rect=[0, 0, 0.99, 1])
    plt.savefig(f"{figure_dir}/statistics_plot.png", dpi=200)
    plt.close()
175
+
176
def plot_kl_divergence(predictor, kld, figure_dir):
    """Save a box plot of the KL divergence per output on a symlog axis.

    Args:
        predictor: Supplies the output names used as tick labels.
        kld: Table with one column per output.
        figure_dir: Directory that receives ``kl_divergence_plot.png``.
    """
    outputs = predictor.model.outputs.values()
    names = [output.name for output in outputs]
    n_columns = kld.shape[1]
    positions = range(1, n_columns + 1)

    plt.figure(figsize=(6, 6))
    # whis=[0, 100] stretches the whiskers to the minimum and maximum.
    boxes = plt.boxplot(kld, whis=[0, 100], patch_artist=True)
    for box in boxes["boxes"]:
        box.set_facecolor("darkgray")
    for median in boxes["medians"]:
        median.set_color("black")

    upper_limit = kld.max().max() * 5
    plt.ylim(-0.5, upper_limit)
    plt.ylabel("KL Divergence")
    plt.xticks(positions, names, rotation=45, ha="right")
    plt.yscale("symlog")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/kl_divergence_plot.png", dpi=200)
    plt.close()
190
+
191
def plot_mean_squares(predictor, mse, figure_dir):
    """Save a box plot of the mean squared error per parameter, grouped by output.

    Boxes are coloured by parameter base name and the boxes of one output sit
    next to each other under a shared tick. The figure is written to
    ``{figure_dir}/mean_squares_plot.png``.
    """
    model = predictor.model
    names = [output.name for output in model.outputs.values()]
    colors, labels = [], []
    for param in model.params.values():
        colors.append(COLORS[param.base])
        labels.append(param.base)

    positions = [0.5 * column for column in range(mse.shape[1])]
    ticks = [0] * len(model.outputs)
    for i, output in enumerate(model.outputs.values()):
        owned = output.dist.params
        for j, param in enumerate(model.params):
            if param not in owned:
                continue
            # Shift the box into its output's group; the group's tick ends up
            # at the mean position of its boxes.
            positions[j] += 0.5 * i + 1
            ticks[i] += positions[j] / len(owned)

    plt.figure(figsize=(10, 6))
    bplot = plt.boxplot(mse, whis=[0, 100], widths=0.5, positions=positions, patch_artist=True, label=labels)
    for patch, color in zip(bplot["boxes"], colors):
        patch.set_facecolor(color)
    for patch in bplot["medians"]:
        patch.set_color("black")
    plt.ylabel("Mean Squared Error")
    plt.xticks(ticks, names, rotation=45, ha="right")
    plt.yscale("log")
    # Keep a single legend entry per parameter base name.
    handles, legend_labels = plt.gca().get_legend_handles_labels()
    unique = dict(zip(legend_labels, handles))
    plt.legend(unique.values(), unique.keys(), bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    plt.savefig(f"{figure_dir}/mean_squares_plot.png", dpi=200)
    plt.close()
FASTEN/predict.py ADDED
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+ from .common import pd, np, torch
3
+ from .data import Samples, Statistics, Dataset
4
+ from .learn import Loss
5
+ from .model import Model
6
+
7
+
8
class Predictor():
    """Runs a trained model on a dataset and evaluates the predictions."""

    # Marks "no dataset given". An explicit ``None`` is passed through
    # unchanged, because callers test ``predictor.true is None``.
    _NEW_DATASET = object()

    def __init__(self, model: Model, dataset: Dataset = _NEW_DATASET):
        """Store the model and the reference dataset.

        Args:
            model: Trained model used for prediction.
            dataset: Reference dataset. When omitted, a fresh ``Dataset`` is
                created per predictor (the former ``Dataset()`` default was
                built once and shared by every instance).
        """
        self.model: Model = model
        self.true: Dataset = Dataset() if dataset is Predictor._NEW_DATASET else dataset
        self.pred: Dataset = Dataset()
        self.error: pd.DataFrame = None
        self.criterion: Loss = Loss(model)

    def load_inputs(self, inputs_file: str, num_runs: int):
        """Load and scale inputs, repeating every row ``num_runs`` times for sampling.

        Args:
            inputs_file: File handed to ``stats.load_data``.
            num_runs: Samples to draw per input row; values below 1 count as 1.
        """
        self.true.stats.load_data(inputs_file, self.model.inputs, None)
        self.true.stats.scale_inputs(self.model.input_scaler)
        num_runs = max(num_runs, 1)
        index = np.arange(self.true.stats.inputs.shape[0]).repeat(num_runs)
        # ``index`` holds row positions, so select by position rather than label.
        self.true.samples.inputs = self.true.stats.inputs.iloc[index]
        self.true.samples.inputs = self.true.samples.inputs.reset_index(drop=True)
        self.true.samples.group_data()

    def dump_statistics(self, outputs_file: str):
        """Write the predicted statistics to ``outputs_file``."""
        self.pred.stats.dump_data(outputs_file)

    def dump_samples(self, outputs_file: str):
        """Write the predicted samples to ``outputs_file``."""
        self.pred.samples.dump_data(outputs_file)

    def execute(self):
        """Predict statistics, undo the scaling, then draw samples from them."""
        self.pred.stats = self.predict_stats()
        self.pred.stats.unscale_inputs(self.model.input_scaler)
        self.pred.stats.unscale_outputs(self.model.param_scaler)
        if isinstance(self.true.stats.outputs, pd.DataFrame):
            self.true.stats.unscale_outputs(self.model.param_scaler)
        self.pred.samples = self.predict_samples()

    def evaluate(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Compare predictions with the reference data.

        Returns:
            ``(mse, kld, nll)`` data frames; ``nll`` is ``None`` when any
            output distribution has a parameter-dependent support.
        """
        dependent = [output.dist.support.dependent for output in self.model.outputs.values()]
        mse, kld, nll = self.criterion.evaluate(self.true, self.pred, any(dependent))
        mse = pd.DataFrame(mse.detach().cpu().numpy(), columns=self.model.params)
        kld = pd.DataFrame(kld.detach().cpu().numpy(), columns=self.model.outputs)
        if any(dependent):
            return mse, kld, None
        nll = pd.DataFrame(nll.detach().cpu().numpy(), columns=self.model.outputs)
        return mse, kld, nll

    def predict_stats(self) -> Statistics:
        """Run the network on the reference inputs, without gradients."""
        self.model.network.eval()
        inputs = self.true.stats.inputs.values
        x = torch.tensor(inputs).to(self.model.args.device)
        with torch.no_grad():
            y = self.model.network(x)
        outputs = y.detach().cpu().numpy()
        params = pd.DataFrame(outputs, columns=self.model.params)
        return Statistics(self.true.stats.inputs, params)

    def predict_samples(self) -> Samples:
        """Draw one sample per reference sample row from the predicted distributions."""
        outputs = pd.DataFrame()
        index = self.true.samples.group.values
        for label, output in self.model.outputs.items():
            # Columns of the predicted statistics that belong to this output.
            mask = [param in output.dist.params for param in self.model.params]
            stats = self.pred.stats.outputs.iloc[index, mask]
            params = torch.tensor(stats.values)
            fit = output.dist.base(*params.unbind(dim=1))
            outputs[label] = fit.sample().numpy()
        return Samples(self.true.samples.inputs, outputs)
FASTEN/train.py ADDED
@@ -0,0 +1,87 @@
1
+ from .common import torch, os, shutil
2
+ from .learn import *
3
+ from .data import Dataset
4
+ from .model import Model
5
+ from rich import progress
6
+
7
+
8
class Trainer():
    """Trains a model on a dataset split into train/validation/test partitions."""

    def __init__(self, model: Model):
        """Attach ``model`` and start with an empty dataset and no partitions."""
        self.model: Model = model
        self.dataset: Dataset = Dataset()
        self.train: Partition = None
        self.valid: Partition = None
        self.test: Partition = None

    def load_data(self, samples_file: str):
        """Load samples and statistics, then split them into partitions."""
        self.dataset.load_samples(samples_file, self.model)
        self.dataset.load_stats(self.model, self.model.args.estimator)
        self.split_data()

    def split_data(self):
        """Split off the test and validation partitions from the training data."""
        self.train = Partition(self.dataset, self.model.args.loss_func)
        test = self.train.dataset.split(self.model.args.test_split, self.model.args.rand_seed)
        # valid_split is a fraction of the full dataset; rescale it to the
        # share of the data that remains after the test split.
        split = self.model.args.valid_split / (1 - self.model.args.test_split)
        valid = self.train.dataset.split(split, self.model.args.rand_seed)
        if test:
            self.test = Partition(test, self.model.args.loss_func)
        if valid:
            self.valid = Partition(valid, self.model.args.loss_func)

    def load_args(self):
        """Prepare the data loaders, optimizer, early stopping and loss."""
        # os.cpu_count() may return None; fall back to a single thread.
        torch.set_num_threads(os.cpu_count() or 1)
        for partition in [self.train, self.valid, self.test]:
            if partition:
                partition.load(self.model.args.batch_size, self.model.args.device)
        args = {"lr": self.model.args.learn_rate, "weight_decay": self.model.args.weight_decay}
        if self.model.args.optimizer.__name__ == "SGD":
            args["momentum"] = self.model.args.momentum
        self.optimizer = self.model.args.optimizer(self.model.network.parameters(), **args)
        patience, min_delta = self.model.args.patience, self.model.args.min_delta
        if self.model.args.early_stop:
            self.early_stop = EarlyStop(patience, min_delta)
        self.criterion = Loss(self.model).to(self.model.args.device)

    def execute(self):
        """Train for ``num_epochs`` epochs with a progress bar, then run the test pass.

        Training stops early when early stopping is enabled and triggers on
        the latest validation loss.
        """
        self.load_args()
        loadbar = progress.Progress(progress.TextColumn("{task.description}"), progress.BarColumn(),
                                    progress.MofNCompleteColumn(), progress.TimeRemainingColumn())
        with loadbar as load:
            task = load.add_task("", total=self.model.args.num_epochs)
            for _ in range(self.model.args.num_epochs):
                if self.train:
                    self.train_model()
                if self.valid:
                    self.test_model(valid=True)
                if self.valid and self.model.args.early_stop:
                    loss = self.valid.loss[-1]
                    if self.early_stop(loss):
                        break
                load.update(task, advance=1)
        if self.test:
            self.test_model(valid=False)

    def train_model(self):
        """Run one optimisation epoch and record the mean batch loss.

        Raises:
            ValueError: When the loss raises ``ValueError``; the original
                error is chained as the cause.
        """
        self.model.network.train()
        train_loss = 0
        for x, y in self.train.dataloader:
            self.optimizer.zero_grad()
            y_pred = self.model.network(x)
            try:
                loss = self.criterion(y_pred, y)
            except ValueError as error:
                raise ValueError("Training diverged: loss is NaN (possible exploding gradients)") from error
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        train_loss /= len(self.train.dataloader)
        self.train.loss.append(train_loss)

    def test_model(self, valid: bool):
        """Evaluate the validation (``valid=True``) or test partition without gradients."""
        partition = self.valid if valid else self.test
        self.model.network.eval()
        test_loss = 0
        with torch.no_grad():
            for x, y in partition.dataloader:
                y_pred = self.model.network(x)
                loss = self.criterion(y_pred, y)
                test_loss += loss.item()
        test_loss /= len(partition.dataloader)
        partition.loss.append(test_loss)

    def dump_model(self, output):
        """Recreate ``output`` (deleting any existing content) and save the model there."""
        if os.path.exists(output):
            shutil.rmtree(output)
        # makedirs creates ``output`` and any missing parents along with ``plots``.
        os.makedirs(f"{output}/plots")
        self.model.dump(output)