FASTEN-cli 1.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
FASTEN/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .model import Model
2
+ from .train import Trainer
3
+ from .predict import Predictor
4
+ from .tune import Tuner
5
+
6
# Package version string.
# NOTE(review): the wheel that ships this file is labelled 1.0.0 -- confirm which value is intended.
__version__ = "0.1.0"
# Names exported by `from FASTEN import *`; all four are imported above.
__all__ = ["Model", "Trainer", "Tuner", "Predictor"]
FASTEN/cli.py ADDED
@@ -0,0 +1,95 @@
1
+ from .model import Model
2
+ from .train import Trainer
3
+ from .predict import Predictor
4
+ from .tune import Tuner
5
+ from .plot import plot_train, plot_predict, plot_tune
6
+ from .common import pd
7
+ from rich.console import Console
8
+ import argparse, time
9
+
10
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Three sub-commands are defined (``train``, ``tune`` and ``predict``); the
    selected one is stored on the returned namespace as ``command``.

    A sub-command is mandatory: invoking the program without one makes argparse
    print a usage error and exit, instead of returning ``command = None`` and
    letting the caller silently do nothing.

    Returns:
        argparse.Namespace holding ``command`` and that command's options.
    """
    formatter = argparse.ArgumentDefaultsHelpFormatter
    superparser = argparse.ArgumentParser(
        description = "A flexible software framework to approximate computationally "
                      "intensive simulations using neural-network-based emulators",
        formatter_class = formatter)
    subparsers = superparser.add_subparsers(dest = "command", required = True)

    train = subparsers.add_parser("train", description = "Trains emulator on simulation data",
                                  formatter_class = formatter)
    train.add_argument("-c", "--config", required = True, help = "JSON file defining configuration parameters")
    train.add_argument("-i", "--input", required = True, help = "TSV file with simulation data")
    train.add_argument("-o", "--output", default = "outputs", help = "Folder to output model and figures")
    train.add_argument("-m", "--model", default = None, help = "ZIP file containing initial model")

    tune = subparsers.add_parser("tune", description = "Tunes hyperparameters for emulator on simulation data",
                                 formatter_class = formatter)
    tune.add_argument("-c", "--config", required = True, help = "JSON file defining configuration parameters")
    tune.add_argument("-i", "--input", required = True, help = "TSV file with simulation data")
    tune.add_argument("-o", "--output", default = "outputs", help = "Folder to output optimal configs and figures")
    tune.add_argument("-n", "--trials", type = int, default = 100, help = "Total number of optimization trials")
    tune.add_argument("--unique", action = "store_true", help = "Prevents re-training with duplicate hyperparameter sets")

    predict = subparsers.add_parser("predict", description = "Predicts simulation data from inputs with emulator",
                                    formatter_class = formatter)
    predict.add_argument("-m", "--model", required = True, help = "ZIP file containing model")
    predict.add_argument("-i", "--input", required = True, help = "TSV file with simulation inputs")
    predict.add_argument("-o", "--output", default = "outputs.tsv", help = "TSV file to output predicted simulation data")
    predict.add_argument("-n", "--runs", default = 0, type = int, help = "Number of simulation runs per input")
    return superparser.parse_args()
35
+
36
def train(args, console):
    """Run the ``train`` sub-command.

    Builds a model from ``args.config`` / ``args.model``, trains it on
    ``args.input``, writes the model under ``args.output`` and then plots and
    prints prediction metrics for the training partition and, when a test
    partition exists, for the test partition as well.
    """
    emulator = Model(config_file = args.config,
                     model_file = args.model)
    trainer = Trainer(emulator)

    console.log("Estimating distribution parameters...")
    trainer.load_data(args.input)
    console.log("Training neural network...")
    trainer.execute()
    console.log("Writing outputs...")
    trainer.dump_model(args.output)

    training_dir = f"{args.output}/plots/training"
    plot_train(trainer, training_dir)
    mse, kld, _ = plot_predict(Predictor(emulator, trainer.train.dataset), training_dir)
    console.print(f"Average Training MSE = {mse.mean().mean():.3g}\nAverage Training KL Divergence = {kld.mean().mean():.3g}")

    # Nothing more to report when no test partition was held out.
    if not trainer.test:
        return
    mse, kld, _ = plot_predict(Predictor(emulator, trainer.test.dataset), f"{args.output}/plots/testing")
    if isinstance(mse, pd.DataFrame):
        console.print(f"Average Testing MSE = {mse.mean().mean():.3g}\nAverage Testing KL Divergence = {kld.mean().mean():.3g}")
56
+
57
def predict(args, console):
    """Run the ``predict`` sub-command.

    Loads the model from ``args.model``, predicts for the inputs in
    ``args.input`` and writes either samples (when ``args.runs`` is non-zero)
    or statistics to ``args.output``.
    """
    emulator = Model(model_file = args.model)
    console.log("Predicting outputs...")
    predictor = Predictor(emulator)
    predictor.load_inputs(args.input, args.runs)
    predictor.execute()

    # A non-zero run count selects the sample writer, otherwise the statistics writer.
    writer = predictor.dump_samples if args.runs else predictor.dump_statistics
    writer(args.output)
66
+
67
def tune(args, console):
    """Run the ``tune`` sub-command.

    Loads the training data (skipped when ``args.trials`` is zero), resumes the
    study stored under ``args.output``, runs ``args.trials`` trials and writes
    the trial results and plots back to ``args.output``.
    """
    trainer = Trainer(Model(config_file = args.config))
    if args.trials:
        console.log("Estimating distribution parameters...")
        trainer.load_data(args.input)

    console.log("Tuning hyperparameters...")
    tuner = Tuner(trainer)
    tuner.load_study(args.output)
    tuner.execute(args.trials, args.unique)

    console.log("Writing outputs...")
    tuner.dump_trials(args.output)
    plot_tune(tuner, f"{args.output}/plots")
82
+
83
def main():
    """Program entry point: parse arguments, dispatch the sub-command, log the runtime."""
    args = parse_args()
    start = time.perf_counter()
    console = Console()

    # Dispatch table in place of a constant-only match statement; unknown
    # commands fall through without running anything, as before.
    handlers = {"train": train, "predict": predict, "tune": tune}
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args, console)

    elapsed = time.perf_counter() - start
    # Report in the largest unit that keeps the number readable.
    if elapsed < 60:
        runtime = f"{elapsed:.2f} s"
    elif elapsed < 60 * 60:
        runtime = f"{elapsed / 60:.2f} m"
    else:
        runtime = f"{elapsed / (60 * 60):.2f} h"
    console.log(f"Done in {runtime}")
FASTEN/common.py ADDED
@@ -0,0 +1,5 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch, os, shutil, json
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
FASTEN/config.py ADDED
@@ -0,0 +1,110 @@
1
+ from .common import pd, torch
2
+ from .param import ModelDist
3
+ from typing import Literal
4
+ from warnings import warn
5
+ import pydantic as pdc
6
+
7
+
8
class ModelArgs(pdc.BaseModel):
    """Validated training, architecture and optimisation settings for a model.

    After validation, an explicitly supplied ``device`` is replaced by a
    ``torch.device`` and an explicitly supplied ``optimizer`` by the matching
    class from ``torch.optim``.

    NOTE(review): pydantic does not run validators on field defaults unless
    ``validate_default`` is enabled, so omitted ``device`` / ``optimizer``
    values presumably stay plain strings -- confirm how callers consume them.
    """

    # Data partitioning and parameter estimation.
    test_split: float = pdc.Field(default = 0.1, ge = 0.0, lt = 1.0)
    valid_split: float = pdc.Field(default = 0.1, ge = 0.0, lt = 1.0)
    estimator: Literal["MoM", "MLE"] = "MLE"
    rand_seed: int | None = None

    # Network shape.
    architecture: Literal["rectangular", "pyramidal"] = "pyramidal"
    hidden_layers: int = pdc.Field(default = 2, ge = 0)
    hidden_size: int = pdc.Field(default = 64, gt = 0)

    # Training loop.
    device: Literal["cpu", "cuda"] = "cpu"
    batch_size: int = pdc.Field(default = 32, gt = 0)
    # Integer literal: defaults are not coerced, so the former `1e5` left a
    # float on an int-typed field.
    num_epochs: int = pdc.Field(default = 100_000, gt = 0)

    # Early stopping.
    early_stop: bool = True
    patience: int = pdc.Field(default = 20, ge = 0)
    min_delta: float = pdc.Field(default = 0.0)

    # Optimisation.
    optimizer: Literal["SGD", "Adam", "AdamW"] = "AdamW"
    loss_func: Literal["MSE", "KLD", "NLL"] = "NLL"
    learn_rate: float = pdc.Field(default = 1e-3, gt = 0.0)
    weight_decay: float = pdc.Field(default = 0.0, ge = 0.0)
    momentum: float = pdc.Field(default = 0.0, ge = 0.0)

    @pdc.field_validator("device", mode = "after")
    @classmethod
    def validate_device(cls, value):
        """Convert the device name to a ``torch.device``, falling back to CPU with a warning when CUDA is unavailable."""
        if value == "cuda" and not torch.cuda.is_available():
            warn("PyTorch cannot find a compatible GPU. Defaulting to CPU.")
            return torch.device("cpu")
        return torch.device(value)

    @pdc.field_validator("optimizer", mode = "after")
    @classmethod
    def validate_optimizer(cls, value: str):
        """Resolve the optimizer name to the attribute of the same name on ``torch.optim``."""
        return getattr(torch.optim, value)

    @pdc.model_validator(mode = "after")
    def validate_early_stop(self):
        """Reject early stopping when the validation split is zero."""
        if not self.valid_split and self.early_stop:
            raise ValueError("Non-empty validation set required for early stopping.")
        return self

    @pdc.model_validator(mode = "after")
    def validate_splits(self):
        """Reject split sizes that leave no data for training."""
        if self.valid_split + self.test_split >= 1:
            raise ValueError("Non-empty training set required. Decrease size of validation or testing set.")
        return self
56
+
57
+
58
class ModelInput(pdc.BaseModel):
    """Configuration of a single model input column.

    ``name`` defaults to ``label``; ``type`` selects the checks applied by
    ``validate_data``.
    """

    label: str
    name: str = pdc.Field(default_factory = lambda data: data['label'])
    type: Literal["float", "integer", "string"] = "float"

    def validate_data(self, data: pd.DataFrame, label: str):
        """Check column ``label`` of ``data`` against this input and normalise it in place.

        Numeric columns are cast to float; integer-typed columns holding
        fractional values are rounded after a warning.

        Raises:
            ValueError: on missing values or a dtype that does not match ``type``.
        """
        column = data[label]
        if column.isna().any():
            raise ValueError(f"Training data contains missing or undefined values: {self.name}.")

        # String inputs only need a dtype check; the numeric steps do not apply.
        if self.type == "string":
            if not pd.api.types.is_string_dtype(column):
                raise ValueError(f"Training data has invalid values: {self.name}.")
            return

        if not pd.api.types.is_numeric_dtype(column):
            raise ValueError(f"Training data has invalid values: {self.name}.")
        data[label] = column.astype(float)

        if self.type == "integer" and (data[label] % 1 != 0).any():
            warn(f"Integer type specified for non-integer training data: {self.name}. Rounding to nearest integer.")
            data[label] = data[label].round()
75
+
76
+
77
class ModelOutput(pdc.BaseModel): # validate priors
    """Configuration of a single model output column and its distribution.

    ``dist`` is converted to a ``ModelDist`` during validation; ``min_thresh``
    and ``max_thresh`` are optional bounds stored alongside it. Extra fields
    are accepted.
    """

    model_config = pdc.ConfigDict(arbitrary_types_allowed = True, extra = "allow")

    label: str
    dist: str | ModelDist
    name: str = pdc.Field(default_factory = lambda data: data['label'])
    type: Literal["float", "integer"] = "float"
    min_thresh: float | None = None
    max_thresh: float | None = None

    @pdc.field_validator("dist", mode = "after")
    @classmethod
    def validate_distribution(cls, value: str) -> ModelDist:
        """Wrap ``value`` in ``ModelDist``; an AttributeError from it is reported as an invalid distribution."""
        try:
            return ModelDist(value)
        except AttributeError:
            raise ValueError("Invalid distribution specified.")

    @pdc.model_validator(mode = "after")
    def validate_discrete(self):
        """Require ``type == "integer"`` whenever the distribution's support is discrete."""
        if self.dist.support.discrete and self.type != "integer":
            raise ValueError(f"Discrete distribution specified for non-integer training data: {self.name}.")
        return self

    def validate_data(self, data: pd.DataFrame, label: str):
        """Check column ``label`` of ``data`` against this output and normalise it in place.

        The column is cast to float, rounded (with a warning) when an integer
        type holds fractional values, and finally checked against the
        distribution's support.

        Raises:
            ValueError: on missing values or a non-numeric column.
            AssertionError: when values fall outside the distribution's support.
        """
        column = data[label]
        if column.isna().any():
            raise ValueError(f"Training data contains missing or undefined values: {self.name}.")
        if not pd.api.types.is_numeric_dtype(column):
            raise ValueError(f"Training data has invalid values: {self.name}.")
        data[label] = column.astype(float)

        if self.type == "integer" and (data[label] % 1 != 0).any():
            warn(f"Integer type specified for non-integer training data: {self.name}. Rounding to nearest integer.")
            data[label] = data[label].round()

        if not self.dist.support.validate(data[label]):
            raise AssertionError(f"Training data contains values outside domain of distribution: {self.name}")
FASTEN/data.py ADDED
@@ -0,0 +1,118 @@
1
+ from __future__ import annotations
2
+ from .common import pd, os, np
3
+ from .config import ModelOutput
4
+ from .utils import Scaler, Encoder
5
+ from .estimate import Estimator
6
+ from .model import Model
7
+ from sklearn.model_selection import train_test_split
8
+ from k_means_constrained import KMeansConstrained
9
+ from typing import Any
10
+ from abc import ABC
11
+
12
+
13
class Dataset():
    """Pairs raw ``Samples`` with the per-group ``Statistics`` estimated from them."""

    def __init__(self, samples: Samples = None, stats: Statistics = None):
        # Missing (or falsy) arguments are replaced by empty containers.
        self.samples = samples or Samples()
        self.stats = stats or Statistics()

    def load_samples(self, data_file: str, model: Model):
        """Read, filter, validate, encode, scale and group the samples in ``data_file``."""
        samples = self.samples
        samples.load_data(data_file, model.inputs, model.outputs)
        samples.filter_data(model.outputs)
        # Outputs are validated before inputs, each by its own configuration.
        for label, config in model.outputs.items():
            config.validate_data(samples.outputs, label)
        for label, config in model.inputs.items():
            config.validate_data(samples.inputs, label)
        samples.encode_inputs(model.encoder)
        samples.scale_inputs(model.input_scaler)
        samples.group_data()

    def load_stats(self, model: Model, estimator: str):
        """Estimate per-group parameters, scale them and hand the scaler to the network."""
        self.estimate_stats(Estimator(estimator, model.outputs))
        self.stats.scale_outputs(model.param_scaler)
        model.network.load_scaler(model.param_scaler)

    def estimate_stats(self, estimator: Estimator):
        """Fill ``stats`` with one input row per sample group and that group's estimated parameters."""
        group_ids = self.samples.group
        grouped_inputs = self.samples.inputs.groupby(group_ids)
        grouped_outputs = self.samples.outputs.groupby(group_ids)
        self.stats.inputs = grouped_inputs.first().reset_index(drop = True)
        self.stats.outputs = estimator.execute(grouped_outputs)

    def split(self, split_prop: float, rand_seed: int):
        """Move a ``split_prop`` share of the groups into a new ``Dataset``.

        This dataset keeps the remaining groups. Returns ``None`` without
        changing anything when ``split_prop`` is falsy.
        """
        if not split_prop:
            return None

        strata = self.stats.cluster_data(split_prop)
        kept_ids, held_ids = train_test_split(self.stats.inputs.index, test_size = split_prop,
                                              stratify = strata, random_state = rand_seed)
        kept_ids, held_ids = sorted(kept_ids), sorted(held_ids)
        kept_rows = self.samples.group.isin(kept_ids)
        held_rows = self.samples.group.isin(held_ids)

        def take(frame, selector):
            # Select rows and renumber them from zero; absent frames stay None.
            if frame is None:
                return None
            return frame.loc[selector].reset_index(drop = True)

        stats, samples = self.stats, self.samples
        held_stats = Statistics(take(stats.inputs, held_ids), take(stats.outputs, held_ids))
        self.stats = Statistics(take(stats.inputs, kept_ids), take(stats.outputs, kept_ids))
        held_samples = Samples(samples.inputs.loc[held_rows].reset_index(drop = True),
                               samples.outputs.loc[held_rows].reset_index(drop = True))
        self.samples = Samples(samples.inputs.loc[kept_rows].reset_index(drop = True),
                               samples.outputs.loc[kept_rows].reset_index(drop = True))
        return Dataset(samples = held_samples, stats = held_stats)
55
+
56
+
57
class Data(ABC):
    """Container for an ``inputs`` table and an optional ``outputs`` table with matching rows."""

    def __init__(self, inputs: pd.DataFrame = None, outputs: pd.DataFrame = None):
        self.inputs, self.outputs = inputs, outputs

    def _combined(self) -> pd.DataFrame:
        """Return inputs and outputs side by side; an empty frame when neither is set."""
        frames = [frame for frame in (self.inputs, self.outputs) if frame is not None]
        return pd.concat(frames, axis = 1) if frames else pd.DataFrame()

    def _repr_html_(self):
        return self._combined()._repr_html_()

    def __str__(self):
        return self._combined().__str__()

    def dump_data(self, data_file: str):
        """Write inputs and outputs to ``data_file`` as tab-separated text.

        Raises:
            ValueError: when neither table has been loaded.
        """
        if self.inputs is None and self.outputs is None:
            raise ValueError("All objects passed were None")
        self._combined().to_csv(data_file, sep = '\t', index = False)

    def load_data(self, data_file: str, inputs: dict[str, Any], outputs: dict[str, Any]):
        """Load the named columns from a tab-separated file, sorted by the input columns.

        The file is parsed once, so ``data_file`` may also be a file-like
        object. ``self.outputs`` is left untouched when ``outputs`` is empty.

        Args:
            data_file: path (or buffer) of the TSV file.
            inputs: mapping whose keys name the input columns.
            outputs: mapping whose keys name the output columns, or a falsy value.
        """
        input_cols = list(inputs)
        output_cols = list(outputs) if outputs else []
        # dict.fromkeys removes duplicate labels while preserving order.
        wanted = list(dict.fromkeys(input_cols + output_cols))
        table = pd.read_csv(data_file, sep = "\t", usecols = wanted)

        ordered = table[input_cols].sort_values(by = input_cols)
        if output_cols:
            # Reorder outputs with the sorted inputs before both are renumbered.
            self.outputs = table[output_cols].loc[ordered.index].reset_index(drop = True)
        self.inputs = ordered.reset_index(drop = True)
79
+
80
+
81
class Samples(Data):
    """Row-level sample data plus a ``group`` series numbering runs of identical input rows."""

    def __init__(self, inputs: pd.DataFrame = None, outputs: pd.DataFrame = None):
        super().__init__(inputs, outputs)
        # Groups can only be derived once an inputs table is available.
        if not isinstance(inputs, pd.DataFrame):
            self.group: pd.Series = None
        else:
            self.group_data()

    def filter_data(self, outputs: dict[str, ModelOutput]):
        """Drop rows whose outputs are not strictly inside the configured thresholds."""
        def keep(mask):
            # Apply the same row mask to both tables and renumber them.
            self.outputs = self.outputs[mask].reset_index(drop = True)
            self.inputs = self.inputs[mask].reset_index(drop = True)

        for label, spec in outputs.items():
            if spec.min_thresh is not None:
                keep(self.outputs[label] > spec.min_thresh)
            if spec.max_thresh is not None:
                keep(self.outputs[label] < spec.max_thresh)

    def group_data(self):
        """Number consecutive runs of identical input rows, starting at zero."""
        # A row opens a new group when any column differs from the row above.
        differs = (self.inputs != self.inputs.shift()).any(axis = 1)
        self.group = differs.cumsum() - 1

    def scale_inputs(self, scaler: Scaler):
        scaler.transform(self.inputs)

    def unscale_inputs(self, scaler: Scaler):
        scaler.inverse_transform(self.inputs)

    def encode_inputs(self, encoder: Encoder):
        encoder.transform(self.inputs)
105
+
106
+
107
class Statistics(Data):
    """Per-group input rows together with the parameters estimated for each group."""

    def cluster_data(self, split_prop: float = 1, n_clusters: int | None = None) -> np.ndarray:
        """Cluster the input rows with size-constrained k-means and return the cluster labels.

        When ``n_clusters`` is omitted it is derived from the row count and
        ``split_prop``. Returns ``None`` when fewer than one cluster results.
        """
        count = n_clusters
        if count is None:
            count = int(self.inputs.shape[0] * split_prop / 5)
        if count < 1:
            return None
        return KMeansConstrained(n_clusters = count, size_min = 5).fit_predict(self.inputs)

    def scale_outputs(self, scaler: Scaler):
        scaler.transform(self.outputs)

    def unscale_outputs(self, scaler: Scaler):
        scaler.inverse_transform(self.outputs)

    def scale_inputs(self, scaler: Scaler):
        scaler.transform(self.inputs)

    def unscale_inputs(self, scaler: Scaler):
        scaler.inverse_transform(self.inputs)
FASTEN/estimate.py ADDED
@@ -0,0 +1,138 @@
1
+ from .common import np, pd, torch, F, os
2
+ from .param import ModelDist, Constraint
3
+ from .config import ModelOutput
4
+ from pandas.api.typing import DataFrameGroupBy
5
+ from rich import progress
6
+
7
# Module-level progress display used as a context manager by Estimator.execute.
# Columns: task description, bar, "M/N" completed count, estimated time remaining.
PROGRESS = progress.Progress(
    progress.TextColumn("{task.description}"),
    progress.BarColumn(),
    progress.MofNCompleteColumn(),
    progress.TimeRemainingColumn(),
)
13
+
14
class Estimator():
    """Estimates distribution parameters for every group of output samples.

    With ``estimator == "MoM"`` the closed-form formulas in ``Moments`` are
    tried first; in every other case, or when a formula rejects the sample,
    parameters are fitted by numerical maximum likelihood.
    """

    class Moments():
        """Closed-form estimators, looked up by distribution name.

        Each method returns the parameters as a tensor, or raises
        ``AssertionError`` for a sample the formula cannot handle so that the
        caller falls back to maximum likelihood.
        """

        @staticmethod
        def Exponential(data: pd.Series) -> torch.Tensor:
            mean = data.mean()
            if not mean: raise AssertionError()
            return torch.tensor([1 / mean])

        @staticmethod
        def Normal(data: pd.Series) -> torch.Tensor:
            mean, var = data.mean(), data.var()
            if np.isnan(var) or not var: raise AssertionError()
            return torch.tensor([mean, np.sqrt(var)])

        @staticmethod
        def HalfNormal(data: pd.Series) -> torch.Tensor:
            # Root mean square of the sample.
            mean = np.sqrt(data.pow(2).mean())
            if not mean: raise AssertionError()
            return torch.tensor([mean])

        @staticmethod
        def LogNormal(data: pd.Series) -> torch.Tensor:
            # `var` holds the standard deviation of the log-data.
            mean, var = np.log(data).mean(), np.sqrt(np.log(data).var())
            if np.isnan(var) or not var: raise AssertionError()
            return torch.tensor([mean, var])

        @staticmethod
        def Uniform(data: pd.Series) -> torch.Tensor:
            # A degenerate sample (min == max) is returned as-is, not rejected.
            return torch.tensor([data.min(), data.max()])

        @staticmethod
        def Geometric(data: pd.Series) -> torch.Tensor:
            return torch.tensor([1 / (1 + data.mean())])

        @staticmethod
        def Poisson(data: pd.Series) -> torch.Tensor:
            mean = data.mean()
            if not mean: raise AssertionError()
            return torch.tensor([mean])

        @staticmethod
        def Bernoulli(data: pd.Series) -> torch.Tensor:
            return torch.tensor([data.mean()])

        @staticmethod
        def Laplace(data: pd.Series) -> torch.Tensor:
            # Mean absolute deviation around the median.
            mad = (data - data.median()).abs().mean()
            if mad <= 0: raise AssertionError()
            return torch.tensor([data.median(), mad])

        @staticmethod
        def Pareto(data: pd.Series) -> torch.Tensor:
            log_sum = np.log(data / data.min()).sum()
            if not log_sum: raise AssertionError()
            return torch.tensor([data.min(), data.shape[0] / log_sum])

        @staticmethod
        def Binomial(data: pd.Series) -> torch.Tensor:
            mean, var = data.mean(), data.var()
            if not var or np.isnan(var): raise AssertionError()
            # The formula needs variance strictly below the mean.
            if not mean or var >= mean: raise AssertionError()
            probs = 1 - var / mean
            total_counts = max(mean / probs, data.max())
            # Second parameter is returned as a logit.
            return torch.tensor([total_counts, np.log(probs / (1 - probs))])

        @staticmethod
        def NegativeBinomial(data: pd.Series) -> torch.Tensor:
            mean, var = data.mean().item(), data.var().item()
            if not var or np.isnan(var): raise AssertionError()
            # The formula needs variance strictly above the mean.
            if not mean or var <= mean: raise AssertionError()
            total_counts, probs = mean**2 / (var - mean), 1 - mean / var
            # Second parameter is returned as a logit.
            return torch.tensor([total_counts, np.log(probs / (1 - probs))])

    def __init__(self, estimator: str, outputs: dict[str, ModelOutput]):
        self.estimator: str = estimator
        self.outputs: dict[str, ModelOutput] = outputs

    def execute(self, groups: DataFrameGroupBy) -> pd.DataFrame:
        """Estimate parameters for every output and concatenate them column-wise."""
        # os.cpu_count() may return None when the count cannot be determined.
        torch.set_num_threads(os.cpu_count() or 1)
        with PROGRESS as bar:
            params = [self.iterate(groups, label, bar) for label in self.outputs]
        return pd.concat(params, axis = 1)

    def iterate(self, total_groups: DataFrameGroupBy, label: str, progress: progress.Progress) -> pd.DataFrame:
        """Estimate the parameters of output ``label`` for each group.

        Returns a frame indexed by group key, with one column per parameter of
        the output's distribution.
        """
        output, groups, params = self.outputs[label], total_groups[label], dict()
        task = progress.add_task(output.name, total = groups.ngroups)
        for group, data in groups:
            params[group] = self.estimate(output, data)
            progress.update(task, advance = 1)
        return pd.DataFrame.from_dict(params, "index", None, output.dist.params).sort_index()

    def estimate(self, output: ModelOutput, data: pd.Series) -> np.ndarray:
        """Estimate the parameters of one group, preferring a closed form when requested."""
        if self.estimator == "MoM" and hasattr(self.Moments, output.dist.name):
            method = getattr(self.Moments, output.dist.name)
            # AssertionError marks a sample the closed form cannot handle.
            try: return method(data).numpy()
            except AssertionError: pass
        return self.max_likelihood(output.dist, torch.tensor(data.values))

    def max_likelihood(self, dist: ModelDist, data: torch.Tensor) -> np.ndarray:
        """Fit ``dist`` to ``data`` by minimising the mean negative log-likelihood with L-BFGS.

        Optimisation runs on unconstrained weights that ``apply_constraints``
        maps into the valid parameter region.
        """
        self.load_constraints(dist, data)
        if dist.support.discrete: data = data.to(int)
        weights = torch.randn(len(dist.params), requires_grad = True, dtype = float)
        optimizer = torch.optim.LBFGS([weights], max_iter = 200, line_search_fn = "strong_wolfe",
                                      tolerance_grad = 1e-12, tolerance_change = 1e-12)

        def closure():
            optimizer.zero_grad()
            params = self.apply_constraints(weights)
            # Replace NaNs so the distribution can still be constructed.
            params.nan_to_num_(nan = 1e-16)
            fit = dist.base(*params)
            loss = -1 * fit.log_prob(data).mean()
            loss.backward()
            return loss

        optimizer.step(closure)
        with torch.no_grad():
            return self.apply_constraints(weights).numpy()

    def load_constraints(self, dist: ModelDist, data: torch.Tensor):
        """Store per-parameter constraint masks and bounds on ``self`` for ``apply_constraints``.

        One boolean vector is created per name in ``Constraint.RULES`` and one
        float vector per name in ``Constraint.VALUES``, each with one entry
        per distribution parameter.
        """
        size = len(dist.params)
        for rule in Constraint.RULES: setattr(self, rule, torch.zeros(size, dtype = bool))
        for value in Constraint.VALUES: setattr(self, value, torch.zeros(size, dtype = float))
        # The data range is the same for every parameter, so compute it once.
        min_val, max_val = data.min().item(), data.max().item()
        for i, param in enumerate(dist.params.values()):
            param.load_constraints(dist.support, min_val, max_val)
            for rule in Constraint.RULES: getattr(self, rule)[i] = param.constraints.get_rule(rule)
            for value in Constraint.VALUES: getattr(self, value)[i] = param.constraints.get_value(value)

    def apply_constraints(self, weights: torch.Tensor) -> torch.Tensor:
        """Map unconstrained weights to parameters satisfying the loaded constraints.

        Reads the masks ``greater_than`` / ``less_than`` / ``between`` and the
        vectors ``lower`` / ``upper`` / ``interval`` set by ``load_constraints``.
        """
        params = weights.clone()
        # softplus keeps the offset positive; sigmoid maps into (lower, lower + interval).
        params[self.greater_than] = F.softplus(params[self.greater_than]) + self.lower[self.greater_than]
        params[self.less_than] = self.upper[self.less_than] - F.softplus(params[self.less_than])
        params[self.between] = F.sigmoid(params[self.between]) * self.interval[self.between] + self.lower[self.between]
        return params
FASTEN/learn.py ADDED
@@ -0,0 +1,165 @@
1
+ from .common import torch, nn, F, np
2
+ from .data import Dataset
3
+ from .model import Model
4
+ from torch.distributions import NegativeBinomial, Binomial
5
+ from torch.distributions.kl import register_kl
6
+
7
+
8
# Upper bound on the number of support points approximate_KL evaluates per element.
MAX_STEPS = 100000
# Upper bound on the number of support points approximate_KL processes per chunk.
MAX_SIZE = 4 * 1024**2
10
+
11
@register_kl(NegativeBinomial, NegativeBinomial)
def KL_negative_binomial(p, q):
    """KL divergence between two NegativeBinomial distributions, registered with torch.

    Computes the closed form that assumes equal ``total_count`` and passes it,
    with the mean and standard deviation of ``p``, to ``approximate_KL``.
    """
    log_comp_p = F.logsigmoid(-p.logits)
    delta_log = F.logsigmoid(p.logits) - F.logsigmoid(q.logits)
    delta_log_comp = log_comp_p - F.logsigmoid(-q.logits)

    mean_p = p.total_count * torch.exp(p.logits)
    kld_exact = mean_p * delta_log + p.total_count * delta_log_comp

    # Standard deviation computed in log space.
    log_var_p = torch.log(p.total_count) + p.logits - log_comp_p
    std_p = torch.exp(log_var_p / 2)
    return approximate_KL(p, q, mean_p, std_p, kld_exact)
21
+
22
@register_kl(Binomial, Binomial)
def KL_binomial(p, q):
    """KL divergence between two Binomial distributions, registered with torch.

    Computes the closed form that assumes equal ``total_count`` and passes it,
    with the mean and standard deviation of ``p``, to ``approximate_KL``.
    """
    delta_log = F.logsigmoid(p.logits) - F.logsigmoid(q.logits)
    delta_log_comp = F.logsigmoid(-p.logits) - F.logsigmoid(-q.logits)

    mean_p = p.total_count * torch.sigmoid(p.logits)
    kld_exact = mean_p * delta_log + (p.total_count - mean_p) * delta_log_comp

    std_p = torch.sqrt(mean_p * (1 - torch.sigmoid(p.logits)))
    return approximate_KL(p, q, mean_p, std_p, kld_exact)
31
+
32
def approximate_KL(p, q, mean_p, std_p, kld_exact, n_stds = 6):
    """Return KL(p || q), summing numerically where the total counts differ.

    Elements whose ``total_count`` values are close use ``kld_exact``. The
    others use a sum of ``p(k) * (log p(k) - log q(k)) * width`` over points k
    in a window of ``n_stds`` standard deviations around the mean of ``p``.

    Args:
        p, q: distributions exposing ``total_count``, ``logits`` and ``log_prob``.
        mean_p, std_p: mean and standard deviation of ``p``.
        kld_exact: closed-form divergence, valid where the total counts match.
        n_stds: half-width of the summation window, in standard deviations.
    """
    comparable = torch.isclose(p.total_count, q.total_count)
    # Fast path: the closed form applies to every element.
    if comparable.all(): return kld_exact

    # Summation window [min_k, max_k], clamped at zero from below.
    max_k = torch.ceil(mean_p + n_stds * std_p)
    min_k = torch.clamp(torch.floor(mean_p - n_stds * std_p), min = 0.0)
    zeros, ones = torch.zeros_like(std_p), torch.ones_like(std_p)
    # Comparable elements get a zero-size window; their sum is discarded below.
    total_size = torch.where(comparable, zeros, max_k - min_k)
    integer_size = torch.floor(total_size) + 1
    # Points evaluated per element: the widest window, capped at MAX_STEPS.
    n_steps = int(max(1, min(integer_size.max().item(), MAX_STEPS)))

    # Windows wider than n_steps are sampled on a coarser grid of n_steps
    # points; narrower ones use unit spacing and only `integer_size` points.
    max_width = total_size / max(n_steps - 1, 1)
    width = torch.where(total_size > n_steps, max_width, ones)
    full = torch.full_like(total_size, n_steps)
    valid = torch.where(total_size > n_steps, full, integer_size)
    # NOTE(review): n_steps is already capped at MAX_STEPS, which is below
    # MAX_SIZE with the current constants, so size == n_steps and the loop
    # below runs once -- confirm whether MAX_SIZE was meant to bound the total
    # element count (steps times batch) instead.
    size = int(max(1, min(MAX_SIZE, n_steps)))

    kld_approx = torch.zeros_like(mean_p)
    # Leading axis added so the step index broadcasts against the batch.
    min_k, valid = min_k.unsqueeze(0), valid.unsqueeze(0)
    width = width.unsqueeze(0)
    for i in range(0, n_steps, size):
        j = min(i + size, n_steps)
        shape = [-1] + [1] * len(mean_p.shape)
        delta = torch.arange(i, j, device = p.logits.device).view(shape)
        # Steps beyond an element's valid count are evaluated at min_k as a
        # placeholder and zeroed out afterwards.
        k = torch.where(delta < valid, min_k + delta * width, min_k)
        log_prob_p, log_prob_q = p.log_prob(k), q.log_prob(k)
        kld = torch.exp(log_prob_p) * (log_prob_p - log_prob_q) * width
        kld = torch.where(delta < valid, kld, torch.zeros_like(kld))
        kld_approx += torch.sum(kld, dim = 0)
    return torch.where(comparable, kld_exact, kld_approx)
62
+
63
+
64
class Partition():
    """A dataset with the data loader built from it and the losses recorded for it."""

    def __init__(self, dataset: Dataset, loss_func: str):
        self.dataset: Dataset = dataset
        self.dataloader: torch.utils.data.DataLoader = None
        # The "NLL" loss is fed raw samples; other losses use per-group statistics.
        self.by_sample: bool = (loss_func == "NLL")
        self.loss: list = []

    def load(self, batch_size: int, device: torch.device):
        """Build a shuffling data loader on ``device`` and reset the recorded losses."""
        source = self.dataset.samples if self.by_sample else self.dataset.stats
        features = torch.tensor(source.inputs.values).to(device)
        targets = torch.tensor(source.outputs.values).to(device)
        pairs = torch.utils.data.TensorDataset(features, targets)
        self.dataloader = torch.utils.data.DataLoader(pairs, batch_size, shuffle = True)
        self.loss = []
78
+
79
+
80
class EarlyStop:
    """Early-stopping criterion based on an exponentially smoothed loss.

    Each call folds a new loss into a running average with weight
    ``multiplier``. The stall counter resets whenever the average is below
    ``best_loss - min_delta`` and increments otherwise; the call returns True
    once the counter reaches ``patience``.
    """

    def __init__(self, patience = 20, min_delta = 0, multiplier = 0.1):
        self.multiplier: float = multiplier
        self.patience: int = patience
        self.min_delta: float = min_delta
        self.counter: int = 0
        self.best_loss: float = float('inf')
        self.avg_loss: float = None

    def __call__(self, loss):
        """Record ``loss`` and report whether training should stop.

        Args:
            loss: latest loss value (number, or a scalar tensor/array).

        Returns:
            bool: True when ``counter`` has reached ``patience``.
        """
        if self.avg_loss is None:
            self.avg_loss = loss
        else:
            # Built out of place: `best_loss` (and the caller's first loss) may
            # be the same object as `avg_loss`, so an in-place update on a
            # tensor or array would overwrite them as well.
            self.avg_loss = self.avg_loss * (1 - self.multiplier) + self.multiplier * loss

        if self.avg_loss < self.best_loss - self.min_delta:
            # With a negative min_delta this branch can be reached without a
            # new best; only lower the recorded best on a real improvement.
            if self.avg_loss < self.best_loss:
                self.best_loss = self.avg_loss
            self.counter = 0
        else:
            self.counter += 1
        return (self.counter >= self.patience)
100
+
101
+
102
class Loss(nn.Module):
    """Training loss and evaluation metrics over predicted distribution parameters.

    ``min`` and ``range`` (from the model's parameter scaler) convert between
    scaled and unscaled parameters; ``outputs`` is a boolean mask with one row
    per output marking which parameter columns belong to it.
    """

    def __init__(self, model: Model):
        super().__init__()
        self.loss_func = model.args.loss_func
        self.mean_squared_error = nn.MSELoss(reduction = "none")
        for name in ("min", "range"):
            self.register_buffer(name, torch.from_numpy(getattr(model.param_scaler, name)))
        self.load_params(model)

    def load_params(self, model: Model):
        """Record each output's distribution class and its parameter-column mask."""
        self.num_outputs = len(model.outputs)
        self.dists = [output.dist.base for output in model.outputs.values()]
        masks = torch.zeros((self.num_outputs, len(model.params)), dtype = bool)
        for row, output in enumerate(model.outputs.values()):
            owned = [param in output.dist.params for param in model.params]
            masks[row] = torch.tensor(owned, dtype = bool)
        self.register_buffer("outputs", masks)

    def forward(self, pred: torch.Tensor, true: torch.Tensor):
        """Return the mean loss selected by ``loss_func`` for scaled predictions ``pred``.

        "MSE" compares scaled values, "KLD" compares unscaled parameters and
        "NLL" scores ``true`` under the unscaled predicted parameters.
        """
        pred_unscaled = pred * self.range + self.min
        if self.loss_func == "MSE":
            return self.mean_squared_error(pred, true).mean()
        if self.loss_func == "KLD":
            true_unscaled = true * self.range + self.min
            return self.kl_divergence(pred_unscaled, true_unscaled).mean()
        if self.loss_func == "NLL":
            return self.neg_log_likelihood(pred_unscaled, true).mean()

    def evaluate(self, pred: Dataset, true: Dataset, dependent: bool) -> list[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute MSE, KL divergence and per-group NLL between two datasets.

        The NLL entry is ``None`` when ``dependent`` is true.
        """
        pred_params = torch.tensor(pred.stats.outputs.values)
        true_params = torch.tensor(true.stats.outputs.values)
        pred_scaled = (pred_params - self.min) / self.range
        true_scaled = (true_params - self.min) / self.range
        mse = self.mean_squared_error(pred_scaled, true_scaled)
        kld = self.kl_divergence(pred_params, true_params)
        if dependent:
            return mse, kld, None

        grouped = true.samples.outputs.groupby(true.samples.group)
        nll = torch.zeros((grouped.ngroups, pred.samples.outputs.shape[1]), dtype = float)
        for key, frame in grouped:
            # Repeat the group's predicted parameters once per observed sample.
            tiled = pred_params[key].unsqueeze(0).repeat(len(frame), 1)
            nll[key] = self.neg_log_likelihood(tiled, torch.tensor(frame.values))
        return mse, kld, nll

    def neg_log_likelihood(self, pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        """Mean negative log-likelihood of each column of ``true``, one column per output."""
        per_output = []
        for column, (dist, mask) in enumerate(zip(self.dists, self.outputs)):
            fitted = dist(*pred[:, mask].unbind(dim = 1))
            per_output.append((-1 * fitted.log_prob(true[:, column])).mean())
        return torch.column_stack(per_output)

    def kl_divergence(self, pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        """Row-wise KL(true || pred) for each output, one column per output."""
        per_output = []
        for dist, mask in zip(self.dists, self.outputs):
            pred_fit = dist(*pred[:, mask].unbind(dim = 1))
            true_fit = dist(*true[:, mask].unbind(dim = 1))
            per_output.append(torch.distributions.kl_divergence(true_fit, pred_fit))
        return torch.column_stack(per_output)