FASTEN-cli 1.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FASTEN/__init__.py +7 -0
- FASTEN/cli.py +95 -0
- FASTEN/common.py +5 -0
- FASTEN/config.py +110 -0
- FASTEN/data.py +118 -0
- FASTEN/estimate.py +138 -0
- FASTEN/learn.py +165 -0
- FASTEN/model.py +152 -0
- FASTEN/param.py +120 -0
- FASTEN/plot.py +215 -0
- FASTEN/predict.py +66 -0
- FASTEN/train.py +87 -0
- FASTEN/tune.py +92 -0
- FASTEN/utils.py +67 -0
- fasten_cli-1.0.0.dist-info/METADATA +89 -0
- fasten_cli-1.0.0.dist-info/RECORD +19 -0
- fasten_cli-1.0.0.dist-info/WHEEL +5 -0
- fasten_cli-1.0.0.dist-info/entry_points.txt +2 -0
- fasten_cli-1.0.0.dist-info/licenses/LICENSE.md +21 -0
FASTEN/__init__.py
ADDED
FASTEN/cli.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from .model import Model
|
|
2
|
+
from .train import Trainer
|
|
3
|
+
from .predict import Predictor
|
|
4
|
+
from .tune import Tuner
|
|
5
|
+
from .plot import plot_train, plot_predict, plot_tune
|
|
6
|
+
from .common import pd
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
import argparse, time
|
|
9
|
+
|
|
10
|
+
def parse_args():
    """Parse the ``fasten`` command line.

    Sub-commands:
        train    trains an emulator on simulation data
        tune     tunes emulator hyperparameters on simulation data
        predict  predicts simulation data from inputs with a trained emulator

    A sub-command is mandatory; argparse prints usage and exits when none is given.

    Returns:
        argparse.Namespace: parsed options; ``command`` holds the sub-command name.
    """
    superparser = argparse.ArgumentParser(
        # Implicit string concatenation keeps source indentation out of the help text.
        description = "A flexible software framework to approximate computationally "
                      "intensive simulations using neural-network-based emulators",
        formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    # required=True: without a sub-command main() would otherwise do nothing at all.
    subparsers = superparser.add_subparsers(dest = "command", required = True)

    train_parser = subparsers.add_parser("train", description = "Trains emulator on simulation data",
                                         formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    train_parser.add_argument("-c", "--config", required = True, help = "JSON file defining configuration parameters")
    train_parser.add_argument("-i", "--input", required = True, help = "TSV file with simulation data")
    train_parser.add_argument("-o", "--output", default = "outputs", help = "Folder to output model and figures")
    train_parser.add_argument("-m", "--model", default = None, help = "ZIP file containing initial model")

    tune_parser = subparsers.add_parser("tune", description = "Tunes hyperparameters for emulator on simulation data",
                                        formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    tune_parser.add_argument("-c", "--config", required = True, help = "JSON file defining configuration parameters")
    tune_parser.add_argument("-i", "--input", required = True, help = "TSV file with simulation data")
    tune_parser.add_argument("-o", "--output", default = "outputs", help = "Folder to output optimal configs and figures")
    tune_parser.add_argument("-n", "--trials", type = int, default = 100, help = "Total number of optimization trials")
    tune_parser.add_argument("--unique", action = "store_true", help = "Prevents re-training with duplicate hyperparameter sets")

    predict_parser = subparsers.add_parser("predict", description = "Predicts simulation data from inputs with emulator",
                                           formatter_class = argparse.ArgumentDefaultsHelpFormatter)
    predict_parser.add_argument("-m", "--model", required = True, help = "ZIP file containing model")
    predict_parser.add_argument("-i", "--input", required = True, help = "TSV file with simulation inputs")
    predict_parser.add_argument("-o", "--output", default = "outputs.tsv", help = "TSV file to output predicted simulation data")
    predict_parser.add_argument("-n", "--runs", default = 0, type = int, help = "Number of simulation runs per input")

    return superparser.parse_args()
|
|
35
|
+
|
|
36
|
+
def train(args, console):
    """Train an emulator on simulation data, save it and report fit quality.

    Uses the configuration in ``args.config``, an optional starting model in
    ``args.model`` and the simulation TSV in ``args.input``.  The trained
    model and plots are written under ``args.output``.  Training-set error
    averages are always printed; testing-set averages only when the trainer
    has a test partition and ``plot_predict`` returned a DataFrame.
    """
    model = Model(config_file = args.config,
                  model_file = args.model)
    trainer = Trainer(model)

    console.log("Estimating distribution parameters...")
    trainer.load_data(args.input)

    console.log("Training neural network...")
    trainer.execute()

    console.log("Writing outputs...")
    trainer.dump_model(args.output)

    training_dir = f"{args.output}/plots/training"
    plot_train(trainer, training_dir)
    train_mse, train_kld, _ = plot_predict(Predictor(model, trainer.train.dataset), training_dir)
    console.print(f"Average Training MSE = {train_mse.mean().mean():.3g}\n"
                  f"Average Training KL Divergence = {train_kld.mean().mean():.3g}")

    if trainer.test:
        test_mse, test_kld, _ = plot_predict(Predictor(model, trainer.test.dataset),
                                             f"{args.output}/plots/testing")
        # NOTE(review): only the testing metrics are guarded by a DataFrame check;
        # presumably plot_predict can return something else for this split — confirm.
        if isinstance(test_mse, pd.DataFrame):
            console.print(f"Average Testing MSE = {test_mse.mean().mean():.3g}\n"
                          f"Average Testing KL Divergence = {test_kld.mean().mean():.3g}")
|
|
56
|
+
|
|
57
|
+
def predict(args, console):
    """Predict simulation outputs for the inputs in ``args.input`` with a saved emulator.

    When ``args.runs`` is non-zero, that many simulated runs per input are
    written to ``args.output``; otherwise summary statistics are written.
    """
    model = Model(model_file = args.model)
    console.log("Predicting outputs...")

    predictor = Predictor(model)
    predictor.load_inputs(args.input, args.runs)
    predictor.execute()

    write_results = predictor.dump_samples if args.runs else predictor.dump_statistics
    write_results(args.output)
|
|
66
|
+
|
|
67
|
+
def tune(args, console):
    """Run a hyperparameter search and write the trials and plots to ``args.output``.

    Simulation data from ``args.input`` is loaded only when ``args.trials`` is
    non-zero.  The study stored in ``args.output`` is loaded before executing
    ``args.trials`` more trials.
    """
    trainer = Trainer(Model(config_file = args.config))

    # NOTE(review): with zero trials no data is loaded; presumably a zero-trial
    # run only re-reads and re-plots an existing study — confirm.
    if args.trials:
        console.log("Estimating distribution parameters...")
        trainer.load_data(args.input)

    console.log("Tuning hyperparameters...")
    tuner = Tuner(trainer)
    tuner.load_study(args.output)
    tuner.execute(args.trials, args.unique)

    console.log("Writing outputs...")
    tuner.dump_trials(args.output)
    plot_tune(tuner, f"{args.output}/plots")
|
|
82
|
+
|
|
83
|
+
def main():
    """Console-script entry point: dispatch the chosen sub-command and log its runtime."""
    args = parse_args()
    start = time.perf_counter()
    console = Console()

    handlers = {"train": train, "predict": predict, "tune": tune}
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args, console)

    elapsed = time.perf_counter() - start
    # Report seconds below a minute, minutes below an hour, hours otherwise.
    if elapsed < 60:
        runtime = f"{elapsed:.2f} s"
    elif elapsed < 60 * 60:
        runtime = f"{elapsed / 60:.2f} m"
    else:
        runtime = f"{elapsed / (60 * 60):.2f} h"
    console.log(f"Done in {runtime}")
|
FASTEN/common.py
ADDED
FASTEN/config.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from .common import pd, torch
|
|
2
|
+
from .param import ModelDist
|
|
3
|
+
from typing import Literal
|
|
4
|
+
from warnings import warn
|
|
5
|
+
import pydantic as pdc
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ModelArgs(pdc.BaseModel):
    """Hyperparameters read from the JSON configuration file.

    After validation ``device`` holds a ``torch.device`` and ``optimizer``
    holds the matching ``torch.optim`` class instead of the configured names.
    Model validators require a non-empty training set and, when early
    stopping is enabled, a non-empty validation set.
    """
    # Pydantic skips validation of defaults unless asked. Without this, an unset
    # ``device``/``optimizer`` would stay a plain string while an explicitly set
    # one is converted by the validators below.
    model_config = pdc.ConfigDict(validate_default = True)

    # Data partitioning and distribution-parameter estimation
    test_split: float = pdc.Field(default = 0.1, ge = 0.0, lt = 1.0)
    valid_split: float = pdc.Field(default = 0.1, ge = 0.0, lt = 1.0)
    estimator: Literal["MoM", "MLE"] = "MLE"
    rand_seed: int | None = None

    # Network shape
    architecture: Literal["rectangular", "pyramidal"] = "pyramidal"
    hidden_layers: int = pdc.Field(default = 2, ge = 0)
    hidden_size: int = pdc.Field(default = 64, gt = 0)

    # Training loop
    device: Literal["cpu", "cuda"] = "cpu"
    batch_size: int = pdc.Field(default = 32, gt = 0)
    # Integer literal: the previous default ``1e5`` was a float.
    num_epochs: int = pdc.Field(default = 100_000, gt = 0)

    # Early stopping
    early_stop: bool = True
    patience: int = pdc.Field(default = 20, ge = 0)
    min_delta: float = pdc.Field(default = 0.0)

    # Optimisation
    optimizer: Literal["SGD", "Adam", "AdamW"] = "AdamW"
    loss_func: Literal["MSE", "KLD", "NLL"] = "NLL"
    learn_rate: float = pdc.Field(default = 1e-3, gt = 0.0)
    weight_decay: float = pdc.Field(default = 0.0, ge = 0.0)
    momentum: float = pdc.Field(default = 0.0, ge = 0.0)

    @pdc.field_validator("device", mode = "after")
    @classmethod
    def validate_device(cls, value):
        """Return ``torch.device(value)``; CUDA without an available GPU warns and falls back to CPU."""
        if value == "cuda" and not torch.cuda.is_available():
            warn("PyTorch cannot find a compatible GPU. Defaulting to CPU.")
            return torch.device("cpu")
        return torch.device(value)

    @pdc.field_validator("optimizer", mode = "after")
    @classmethod
    def validate_optimizer(cls, value: str):
        """Replace the optimizer name with the ``torch.optim`` class of the same name."""
        return getattr(torch.optim, value)

    @pdc.model_validator(mode = "after")
    def validate_early_stop(self):
        """Early stopping monitors the validation set, so it must not be empty.

        Raises:
            ValueError: if ``early_stop`` is set while ``valid_split`` is 0.
        """
        if not self.valid_split and self.early_stop:
            raise ValueError("Non-empty validation set required for early stopping.")
        return self

    @pdc.model_validator(mode = "after")
    def validate_splits(self):
        """Ensure the validation and test fractions leave data for training.

        Raises:
            ValueError: if ``valid_split + test_split >= 1``.
        """
        if self.valid_split + self.test_split >= 1:
            raise ValueError("Non-empty training set required. Decrease size of validation or testing set.")
        return self
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ModelInput(pdc.BaseModel):
    """Configuration of one simulation input column.

    ``name`` is the display name used in error messages and defaults to ``label``.
    """
    label: str
    name: str = pdc.Field(default_factory = lambda data: data['label'])
    type: Literal["float", "integer", "string"] = "float"

    def validate_data(self, data: pd.DataFrame, label: str):
        """Check column ``label`` of ``data`` against the declared ``type``, in place.

        Numeric columns are converted to float.  Integer-typed columns holding
        fractional values are rounded, with a warning.

        Raises:
            ValueError: if the column has missing values or values incompatible with ``type``.
        """
        column = data[label]
        if column.isna().any():
            raise ValueError(f"Training data contains missing or undefined values: {self.name}.")

        if self.type == "string":
            if not pd.api.types.is_string_dtype(column):
                raise ValueError(f"Training data has invalid values: {self.name}.")
        elif self.type in ("integer", "float"):
            if not pd.api.types.is_numeric_dtype(column):
                raise ValueError(f"Training data has invalid values: {self.name}.")
            data[label] = column.astype(float)
            if self.type == "integer" and (data[label] % 1 != 0).any():
                warn(f"Integer type specified for non-integer training data: {self.name}. Rounding to nearest integer.")
                data[label] = data[label].round()
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ModelOutput(pdc.BaseModel): # validate priors
    """Configuration of one simulation output and the distribution fitted to it.

    Unknown configuration keys are kept on the model (``extra = "allow"``).
    """
    model_config = pdc.ConfigDict(arbitrary_types_allowed = True, extra = "allow")

    label: str
    dist: str | ModelDist
    name: str = pdc.Field(default_factory = lambda data: data['label'])
    type: Literal["float", "integer"] = "float"
    # When loading training samples, rows are kept only if the value is > min_thresh and < max_thresh.
    min_thresh: float | None = None
    max_thresh: float | None = None

    @pdc.field_validator("dist", mode = "after")
    @classmethod
    def validate_distribution(cls, value: str) -> ModelDist:
        """Resolve the configured distribution name to a ``ModelDist``.

        Raises:
            ValueError: if ``ModelDist`` rejects the name (it signals this with
                AttributeError). The offending value is included and the
                original error chained.
        """
        try:
            return ModelDist(value)
        except AttributeError as err:
            raise ValueError(f"Invalid distribution specified: {value}.") from err

    @pdc.model_validator(mode = "after")
    def validate_discrete(self):
        """A discrete distribution requires ``type == "integer"``.

        Raises:
            ValueError: on a discrete distribution for non-integer data.
        """
        if self.dist.support.discrete and self.type != "integer":
            raise ValueError(f"Discrete distribution specified for non-integer training data: {self.name}.")
        return self

    def validate_data(self, data: pd.DataFrame, label: str):
        """Check column ``label`` of ``data`` against this output's type and distribution, in place.

        The column is converted to float.  Integer outputs holding fractional
        values are rounded, with a warning.

        Raises:
            ValueError: on missing or non-numeric values.
            AssertionError: if values fall outside the distribution's support.
        """
        if data[label].isna().any():
            raise ValueError(f"Training data contains missing or undefined values: {self.name}.")
        if not pd.api.types.is_numeric_dtype(data[label]):
            raise ValueError(f"Training data has invalid values: {self.name}.")
        data[label] = data[label].astype(float)
        if self.type == "integer" and (data[label] % 1 != 0).any():
            warn(f"Integer type specified for non-integer training data: {self.name}. Rounding to nearest integer.")
            data[label] = data[label].round()
        # NOTE(review): AssertionError is kept in case callers catch it; ValueError
        # would match the other checks here.
        if not self.dist.support.validate(data[label]):
            raise AssertionError(f"Training data contains values outside domain of distribution: {self.name}")
|
FASTEN/data.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from .common import pd, os, np
|
|
3
|
+
from .config import ModelOutput
|
|
4
|
+
from .utils import Scaler, Encoder
|
|
5
|
+
from .estimate import Estimator
|
|
6
|
+
from .model import Model
|
|
7
|
+
from sklearn.model_selection import train_test_split
|
|
8
|
+
from k_means_constrained import KMeansConstrained
|
|
9
|
+
from typing import Any
|
|
10
|
+
from abc import ABC
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Dataset():
    """Simulation data at two granularities: raw ``samples`` and per-group ``stats``.

    ``samples`` holds one row per simulation run.  ``stats`` holds one row per
    sample group (a distinct input combination), with the distribution
    parameters estimated from that group's outputs.
    """
    def __init__(self, samples: Samples = None, stats: Statistics = None):
        self.samples = samples if samples is not None else Samples()
        self.stats = stats if stats is not None else Statistics()

    def load_samples(self, data_file: str, model: Model):
        """Read, filter, validate, encode and scale the samples in ``data_file``, then group them."""
        self.samples.load_data(data_file, model.inputs, model.outputs)
        self.samples.filter_data(model.outputs)
        for label, config in model.outputs.items():
            config.validate_data(self.samples.outputs, label)
        for label, config in model.inputs.items():
            config.validate_data(self.samples.inputs, label)
        self.samples.encode_inputs(model.encoder)
        self.samples.scale_inputs(model.input_scaler)
        # Group ids assigned here become the row positions of ``stats`` (see estimate_stats).
        self.samples.group_data()

    def load_stats(self, model: Model, estimator: str):
        """Estimate per-group parameters, scale them and hand the scaler to the network."""
        self.estimate_stats(Estimator(estimator, model.outputs))
        self.stats.scale_outputs(model.param_scaler)
        model.network.load_scaler(model.param_scaler)

    def estimate_stats(self, estimator: Estimator):
        """Fill ``stats``: the inputs of each group plus the parameters ``estimator`` fits to its outputs."""
        input_data = self.samples.inputs.groupby(self.samples.group)
        output_data = self.samples.outputs.groupby(self.samples.group)
        self.stats.inputs = input_data.first().reset_index(drop = True)
        self.stats.outputs = estimator.execute(output_data)

    def split(self, split_prop: float, rand_seed: int):
        """Move a fraction of the sample groups into a new Dataset.

        The split is drawn on ``stats`` rows (one per group).  It is stratified
        by k-means clusters of the inputs when there are enough rows.  Samples
        follow their group, and ``self`` keeps the remainder.

        Args:
            split_prop: fraction of groups to move out; a falsy value means no split.
            rand_seed: seed for both the clustering and ``train_test_split``.

        Returns:
            Dataset | None: the held-out part, or None when ``split_prop`` is falsy.
        """
        if not split_prop: return None
        # The clustering defines the strata, so it must be seeded too; otherwise the
        # split changes between runs even with a fixed rand_seed.
        groups = self.stats.cluster_data(split_prop, rand_seed = rand_seed)
        index = self.stats.inputs.index
        train_index, test_index = train_test_split(index, test_size = split_prop, stratify = groups,
                                                   random_state = rand_seed)
        # Sorting keeps the original (input-sorted) row order inside each part.
        train_index, test_index = sorted(train_index), sorted(test_index)
        train_mask = self.samples.group.isin(train_index)
        test_mask = self.samples.group.isin(test_index)

        stats_inputs, stats_outputs = self.stats.inputs, self.stats.outputs
        test_stats = Statistics(stats_inputs.loc[test_index].reset_index(drop = True),
                                None if stats_outputs is None else stats_outputs.loc[test_index].reset_index(drop = True))
        self.stats = Statistics(stats_inputs.loc[train_index].reset_index(drop = True),
                                None if stats_outputs is None else stats_outputs.loc[train_index].reset_index(drop = True))

        sample_inputs, sample_outputs = self.samples.inputs, self.samples.outputs
        test_samples = Samples(sample_inputs.loc[test_mask].reset_index(drop = True),
                               sample_outputs.loc[test_mask].reset_index(drop = True))
        self.samples = Samples(sample_inputs.loc[train_mask].reset_index(drop = True),
                               sample_outputs.loc[train_mask].reset_index(drop = True))
        return Dataset(samples = test_samples, stats = test_stats)


class Data(ABC):
    """Paired ``inputs``/``outputs`` DataFrames with shared display and TSV I/O helpers."""
    def __init__(self, inputs: pd.DataFrame = None, outputs: pd.DataFrame = None):
        self.inputs, self.outputs = inputs, outputs

    def _repr_html_(self):
        return pd.concat([self.inputs, self.outputs], axis = 1)._repr_html_()

    def __str__(self):
        return pd.concat([self.inputs, self.outputs], axis = 1).__str__()

    def dump_data(self, data_file: str):
        """Write inputs and outputs side by side to a tab-separated file, without the index."""
        data = pd.concat([self.inputs, self.outputs], axis = 1)
        data.to_csv(data_file, sep = '\t', index = False)

    def load_data(self, data_file: str, inputs: dict[str, Any], outputs: dict[str, Any]):
        """Read the ``inputs`` (and, if given, ``outputs``) columns of a TSV file.

        Rows are sorted by the input columns, so runs with identical inputs are
        adjacent (``Samples.group_data`` relies on this), and re-indexed from 0.
        Outputs are reordered to stay aligned with their inputs.
        """
        inputs = list(inputs)
        outputs = list(outputs) if outputs else None
        # Read the file once and slice both column sets from the same table.
        columns = list(dict.fromkeys(inputs + (outputs or [])))
        table = pd.read_csv(data_file, sep = "\t", usecols = columns)
        self.inputs = table[inputs].sort_values(by = inputs)
        if outputs:
            self.outputs = table[outputs].loc[self.inputs.index].reset_index(drop = True)
        self.inputs = self.inputs.reset_index(drop = True)


class Samples(Data):
    """Raw per-run data plus ``group``: the sample-group id of every row."""
    def __init__(self, inputs: pd.DataFrame = None, outputs: pd.DataFrame = None):
        super().__init__(inputs, outputs)
        if isinstance(inputs, pd.DataFrame): self.group_data()
        else: self.group: pd.Series = None

    def filter_data(self, outputs: dict[str, ModelOutput]):
        """Drop runs whose outputs fall outside the configured thresholds.

        A row survives only if every thresholded output is strictly greater
        than its ``min_thresh`` and strictly less than its ``max_thresh``.
        Inputs and outputs are filtered together and re-indexed from 0.
        """
        masks = []
        for label, output in outputs.items():
            if output.min_thresh is not None:
                masks.append((self.outputs[label] > output.min_thresh).to_numpy())
            if output.max_thresh is not None:
                masks.append((self.outputs[label] < output.max_thresh).to_numpy())
        if not masks: return
        # One combined mask so each frame is copied only once.
        keep = np.logical_and.reduce(masks)
        self.outputs = self.outputs[keep].reset_index(drop = True)
        self.inputs = self.inputs[keep].reset_index(drop = True)

    def group_data(self):
        """Assign consecutive rows with identical inputs the same 0-based group id.

        Assumes equal inputs are adjacent, as produced by ``load_data``'s sort.
        """
        matches = (self.inputs != self.inputs.shift())
        self.group = matches.any(axis = 1).cumsum() - 1

    def scale_inputs(self, scaler: Scaler): scaler.transform(self.inputs)
    def unscale_inputs(self, scaler: Scaler): scaler.inverse_transform(self.inputs)
    def encode_inputs(self, encoder: Encoder): encoder.transform(self.inputs)


class Statistics(Data):
    """Per-group data: group inputs and the distribution parameters fitted to the group."""
    def cluster_data(self, split_prop: float = 1, n_clusters: int | None = None,
                     rand_seed: int | None = None) -> np.ndarray | None:
        """Cluster the input rows into strata for a stratified split.

        Args:
            split_prop: fraction being split off; the default cluster count is
                ``int(rows * split_prop / 5)``, and every cluster has at least 5 rows.
            n_clusters: explicit cluster count overriding the default.
            rand_seed: ``random_state`` for KMeansConstrained, for reproducible labels.

        Returns:
            One cluster label per row, or None when fewer than one cluster is possible.
        """
        if n_clusters is None:
            n_clusters = int(self.inputs.shape[0] * split_prop / 5)
        if n_clusters < 1: return None
        kmeans = KMeansConstrained(n_clusters = n_clusters, size_min = 5, random_state = rand_seed)
        return kmeans.fit_predict(self.inputs)

    def scale_outputs(self, scaler: Scaler): scaler.transform(self.outputs)
    def unscale_outputs(self, scaler: Scaler): scaler.inverse_transform(self.outputs)
    def scale_inputs(self, scaler: Scaler): scaler.transform(self.inputs)
    def unscale_inputs(self, scaler: Scaler): scaler.inverse_transform(self.inputs)
|
FASTEN/estimate.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from .common import np, pd, torch, F, os
|
|
2
|
+
from .param import ModelDist, Constraint
|
|
3
|
+
from .config import ModelOutput
|
|
4
|
+
from pandas.api.typing import DataFrameGroupBy
|
|
5
|
+
from rich import progress
|
|
6
|
+
|
|
7
|
+
# Columns of the shared progress display used by Estimator.execute: task
# description, bar, "completed/total" counter and estimated time remaining.
_PROGRESS_COLUMNS = (
    progress.TextColumn("{task.description}"),
    progress.BarColumn(),
    progress.MofNCompleteColumn(),
    progress.TimeRemainingColumn(),
)
PROGRESS = progress.Progress(*_PROGRESS_COLUMNS)
|
|
13
|
+
|
|
14
|
+
class Estimator():
    """Estimate distribution parameters for every output of every sample group.

    With ``estimator == "MoM"``, the closed-form method of moments in
    ``Estimator.Moments`` is tried first.  It falls back to maximum
    likelihood when no estimator exists for the distribution or the moments
    are degenerate.  Any other value uses maximum likelihood directly.
    """

    class Moments():
        """Method-of-moments estimators, looked up by ``dist.name``.

        Each takes one group's observations and returns a 1-D tensor of
        parameters.  They must be in the same order as ``dist.params``, because
        ``iterate`` uses those names as column labels.  Degenerate data raises
        AssertionError so that ``Estimator.estimate`` falls back to maximum
        likelihood.
        """

        @staticmethod
        def Exponential(data: pd.Series) -> torch.Tensor:
            # rate = 1 / mean
            mean = data.mean()
            if not mean: raise AssertionError()
            return torch.tensor([1 / mean])

        @staticmethod
        def Normal(data: pd.Series) -> torch.Tensor:
            # (mean, standard deviation)
            mean, var = data.mean(), data.var()
            if np.isnan(var) or not var: raise AssertionError()
            return torch.tensor([mean, np.sqrt(var)])

        @staticmethod
        def HalfNormal(data: pd.Series) -> torch.Tensor:
            # scale = root mean square of the data
            scale = np.sqrt(data.pow(2).mean())
            if not scale: raise AssertionError()
            return torch.tensor([scale])

        @staticmethod
        def LogNormal(data: pd.Series) -> torch.Tensor:
            # (mean, standard deviation) of log(data)
            log_data = np.log(data)
            mean, std = log_data.mean(), np.sqrt(log_data.var())
            if np.isnan(std) or not std: raise AssertionError()
            return torch.tensor([mean, std])

        @staticmethod
        def Uniform(data: pd.Series) -> torch.Tensor:
            # (low, high) = observed range.
            # NOTE(review): constant data gives low == high; an AssertionError here
            # (MLE fallback) was present but disabled in the original — confirm intent.
            return torch.tensor([data.min(), data.max()])

        @staticmethod
        def Geometric(data: pd.Series) -> torch.Tensor:
            # p = 1 / (1 + mean); presumably the success probability of a
            # {0, 1, 2, ...}-supported geometric — TODO confirm against ModelDist.
            return torch.tensor([1 / (1 + data.mean())])

        @staticmethod
        def Poisson(data: pd.Series) -> torch.Tensor:
            # rate = mean
            mean = data.mean()
            if not mean: raise AssertionError()
            return torch.tensor([mean])

        @staticmethod
        def Bernoulli(data: pd.Series) -> torch.Tensor:
            # probability = mean
            return torch.tensor([data.mean()])

        @staticmethod
        def Laplace(data: pd.Series) -> torch.Tensor:
            # location = median, scale = mean absolute deviation from the median
            median = data.median()
            mad = (data - median).abs().mean()
            if mad <= 0: raise AssertionError()
            return torch.tensor([median, mad])

        @staticmethod
        def Pareto(data: pd.Series) -> torch.Tensor:
            # scale = minimum, alpha = n / sum(log(x / scale))
            x_min = data.min()
            log_sum = np.log(data / x_min).sum()
            if not log_sum: raise AssertionError()
            return torch.tensor([x_min, data.shape[0] / log_sum])

        @staticmethod
        def Binomial(data: pd.Series) -> torch.Tensor:
            # p = 1 - var / mean, n = mean / p (at least the largest observation);
            # returned as (total_count, log(p / (1 - p))).  Needs 0 < var < mean.
            mean, var = data.mean(), data.var()
            if not var or np.isnan(var): raise AssertionError()
            if not mean or var >= mean: raise AssertionError()
            probs = 1 - var / mean
            total_counts = max(mean / probs, data.max())
            return torch.tensor([total_counts, np.log(probs / (1 - probs))])

        @staticmethod
        def NegativeBinomial(data: pd.Series) -> torch.Tensor:
            # Assuming torch's parameterisation (mean = r·p/(1−p), var = mean/(1−p)):
            # r = mean² / (var − mean), p = 1 − mean / var; needs var > mean.
            # Returned as (total_count, log(p / (1 - p))).
            mean, var = data.mean(), data.var()
            if not var or np.isnan(var): raise AssertionError()
            if not mean or var <= mean: raise AssertionError()
            total_counts, probs = mean**2 / (var - mean), 1 - mean / var
            # Explicit float64: Python floats would otherwise yield a float32 tensor,
            # unlike every other estimator here.
            return torch.tensor([total_counts, np.log(probs / (1 - probs))], dtype = torch.float64)

    def __init__(self, estimator: str, outputs: dict[str, ModelOutput]):
        self.estimator: str = estimator
        self.outputs: dict[str, ModelOutput] = outputs

    def execute(self, groups: DataFrameGroupBy) -> pd.DataFrame:
        """Estimate parameters for every output; returns one row per group, one column per parameter."""
        torch.set_num_threads(os.cpu_count())
        with PROGRESS as progress:
            params = [self.iterate(groups, label, progress) for label in self.outputs]
        return pd.concat(params, axis = 1)

    def iterate(self, total_groups: DataFrameGroupBy, label: str, progress: progress.Progress) -> pd.DataFrame:
        """Estimate the parameters of output ``label`` for each group, advancing a progress bar."""
        output, groups, params = self.outputs[label], total_groups[label], dict()
        task = progress.add_task(output.name, total = groups.ngroups)
        for group, data in groups:
            params[group] = self.estimate(output, data)
            progress.update(task, advance = 1)
        return pd.DataFrame.from_dict(params, orient = "index", columns = output.dist.params).sort_index()

    def estimate(self, output: ModelOutput, data: pd.Series) -> np.ndarray:
        """Fit one group's observations: method of moments when requested and possible, else MLE."""
        if self.estimator == "MoM" and hasattr(self.Moments, output.dist.name):
            method = getattr(self.Moments, output.dist.name)
            try: return method(data).numpy()
            except AssertionError: pass  # degenerate moments: fall back to maximum likelihood
        return self.max_likelihood(output.dist, torch.tensor(data.values))

    def max_likelihood(self, dist: ModelDist, data: torch.Tensor) -> np.ndarray:
        """Maximise the mean log-likelihood of ``data`` under ``dist`` with L-BFGS.

        Optimisation runs on unconstrained weights that ``apply_constraints``
        maps into each parameter's valid range.  The weights are randomly
        initialised (``torch.randn``), so results depend on torch's RNG state.
        """
        self.load_constraints(dist, data)
        if dist.support.discrete: data = data.to(int)
        weights = torch.randn(len(dist.params), requires_grad = True, dtype = float)
        optimizer = torch.optim.LBFGS([weights], max_iter = 200, line_search_fn = "strong_wolfe",
                                      tolerance_grad = 1e-12, tolerance_change = 1e-12)

        def closure():
            optimizer.zero_grad()
            params = self.apply_constraints(weights)
            params.nan_to_num_(nan = 1e-16)
            fit = dist.base(*params)
            loss = -1 * fit.log_prob(data).mean()
            loss.backward()
            return loss

        optimizer.step(closure)
        with torch.no_grad():
            return self.apply_constraints(weights).numpy()

    def load_constraints(self, dist: ModelDist, data: torch.Tensor):
        """Collect per-parameter constraint masks and bounds for ``dist`` given the data range.

        Sets one boolean tensor per name in ``Constraint.RULES`` and one float
        tensor per name in ``Constraint.VALUES`` as attributes of ``self``;
        ``apply_constraints`` reads them.
        """
        n_params = len(dist.params)
        for rule in Constraint.RULES: setattr(self, rule, torch.zeros(n_params, dtype = bool))
        for value in Constraint.VALUES: setattr(self, value, torch.zeros(n_params, dtype = float))
        # The observed range does not depend on the parameter; compute it once.
        min_val, max_val = data.min().item(), data.max().item()
        for i, param in enumerate(dist.params.values()):
            param.load_constraints(dist.support, min_val, max_val)
            for rule in Constraint.RULES: getattr(self, rule)[i] = param.constraints.get_rule(rule)
            for value in Constraint.VALUES: getattr(self, value)[i] = param.constraints.get_value(value)

    def apply_constraints(self, weights: torch.Tensor) -> torch.Tensor:
        """Map unconstrained ``weights`` to valid parameters.

        Parameters flagged ``greater_than`` become ``lower + softplus(w)``;
        ``less_than`` become ``upper - softplus(w)``; ``between`` become
        ``lower + interval * sigmoid(w)``.  The rest pass through unchanged.
        """
        params = weights.clone()
        params[self.greater_than] = F.softplus(params[self.greater_than]) + self.lower[self.greater_than]
        params[self.less_than] = self.upper[self.less_than] - F.softplus(params[self.less_than])
        params[self.between] = torch.sigmoid(params[self.between]) * self.interval[self.between] + self.lower[self.between]
        return params
|
FASTEN/learn.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
from .common import torch, nn, F, np
|
|
2
|
+
from .data import Dataset
|
|
3
|
+
from .model import Model
|
|
4
|
+
from torch.distributions import NegativeBinomial, Binomial
|
|
5
|
+
from torch.distributions.kl import register_kl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Maximum number of support points summed by approximate_KL.
MAX_STEPS = 100000
# Maximum number of support points evaluated per chunk in approximate_KL.
MAX_SIZE = 4 * 1024**2

@register_kl(NegativeBinomial, NegativeBinomial)
def KL_negative_binomial(p, q):
    """KL(p || q) between torch NegativeBinomial distributions.

    Closed form where the total counts match; elsewhere a truncated numeric
    sum via ``approximate_KL``.
    """
    log_p, log_q = F.logsigmoid(p.logits), F.logsigmoid(q.logits)
    log_neg_p, log_neg_q = F.logsigmoid(-p.logits), F.logsigmoid(-q.logits)
    d_log, d_log_neg = log_p - log_q, log_neg_p - log_neg_q
    # torch's NegativeBinomial has mean r·p/(1−p) = r·exp(logits).
    mean_p = p.total_count * torch.exp(p.logits)
    kld_exact = mean_p * d_log + p.total_count * d_log_neg
    # variance = mean / (1 − p), taken in log space for stability.
    log_mean_p = torch.log(p.total_count) + p.logits
    std_p = torch.exp((log_mean_p - log_neg_p) / 2)
    return approximate_KL(p, q, mean_p, std_p, kld_exact)

@register_kl(Binomial, Binomial)
def KL_binomial(p, q):
    """KL(p || q) between torch Binomial distributions.

    Closed form where the total counts match (torch's built-in rule refuses
    other cases); elsewhere a truncated numeric sum via ``approximate_KL``.
    """
    log_p, log_q = F.logsigmoid(p.logits), F.logsigmoid(q.logits)
    log_neg_p, log_neg_q = F.logsigmoid(-p.logits), F.logsigmoid(-q.logits)
    d_log, d_log_neg = log_p - log_q, log_neg_p - log_neg_q
    mean_p = p.total_count * torch.sigmoid(p.logits)
    kld_exact = mean_p * d_log + (p.total_count - mean_p) * d_log_neg
    std_p = torch.sqrt(mean_p * (1 - torch.sigmoid(p.logits)))
    return approximate_KL(p, q, mean_p, std_p, kld_exact)

def approximate_KL(p, q, mean_p, std_p, kld_exact, n_stds = 6):
    """KL(p || q) for count distributions whose ``total_count`` may differ.

    Where the total counts match, ``kld_exact`` is returned.  Elsewhere the
    sum over k of p(k)·(log p(k) − log q(k)) is evaluated on the window
    ``mean_p ± n_stds·std_p``, clipped to p's support.  When that window has
    more than MAX_STEPS points, it is evaluated on an evenly spaced grid,
    with each term weighted by the grid spacing.

    Args:
        p, q: Binomial or NegativeBinomial distributions with matching batch shapes.
        mean_p, std_p: mean and standard deviation of ``p``.
        kld_exact: closed-form KL, valid where the total counts match.
        n_stds: half-width of the summation window in standard deviations.
    """
    comparable = torch.isclose(p.total_count, q.total_count)
    if comparable.all(): return kld_exact

    max_k = torch.ceil(mean_p + n_stds * std_p)
    if isinstance(p, Binomial):
        # p has no mass above total_count.  Evaluating log_prob there fails
        # support validation, or yields 0 * -inf = NaN when validation is off.
        max_k = torch.minimum(max_k, p.total_count)
    min_k = torch.clamp(torch.floor(mean_p - n_stds * std_p), min = 0.0)
    zeros, ones = torch.zeros_like(std_p), torch.ones_like(std_p)
    total_size = torch.where(comparable, zeros, max_k - min_k)
    integer_size = torch.floor(total_size) + 1
    n_steps = int(max(1, min(integer_size.max().item(), MAX_STEPS)))

    # Windows wider than n_steps are sampled on a coarser, evenly spaced grid.
    max_width = total_size / max(n_steps - 1, 1)
    width = torch.where(total_size > n_steps, max_width, ones)
    full = torch.full_like(total_size, n_steps)
    valid = torch.where(total_size > n_steps, full, integer_size)
    size = int(max(1, min(MAX_SIZE, n_steps)))

    kld_approx = torch.zeros_like(mean_p)
    min_k, valid = min_k.unsqueeze(0), valid.unsqueeze(0)
    width = width.unsqueeze(0)
    for i in range(0, n_steps, size):
        j = min(i + size, n_steps)
        shape = [-1] + [1] * len(mean_p.shape)
        delta = torch.arange(i, j, device = p.logits.device).view(shape)
        # Out-of-window steps are evaluated at min_k (always in support) and masked out below.
        k = torch.where(delta < valid, min_k + delta * width, min_k)
        log_prob_p, log_prob_q = p.log_prob(k), q.log_prob(k)
        kld = torch.exp(log_prob_p) * (log_prob_p - log_prob_q) * width
        kld = torch.where(delta < valid, kld, torch.zeros_like(kld))
        kld_approx += torch.sum(kld, dim = 0)
    return torch.where(comparable, kld_exact, kld_approx)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Partition():
    """One data split with its DataLoader and a per-split loss history.

    With the ``"NLL"`` loss the split feeds raw samples to the network;
    otherwise it feeds the per-group statistics.
    """
    def __init__(self, dataset: Dataset, loss_func: str):
        self.dataset: Dataset = dataset
        self.by_sample: bool = loss_func == "NLL"
        self.dataloader: torch.utils.data.DataLoader = None
        self.loss: list = []

    def load(self, batch_size: int, device: torch.device):
        """Rebuild a shuffled DataLoader of (input, output) tensors on ``device`` and clear the loss history."""
        source = self.dataset.samples if self.by_sample else self.dataset.stats
        tensors = [torch.tensor(frame.values).to(device) for frame in (source.inputs, source.outputs)]
        self.dataloader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(*tensors),
                                                      batch_size, shuffle = True)
        self.loss = []
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class EarlyStop:
    """Patience-based early stopping on an exponentially smoothed loss.

    Each call folds ``loss`` into a moving average that gives weight
    ``multiplier`` to the new value.  If the average is below the best so
    far minus ``min_delta``, the patience counter resets; otherwise it grows
    by one.  The call returns True once the counter reaches ``patience``.
    """
    def __init__(self, patience = 20, min_delta = 0, multiplier = 0.1):
        self.multiplier: float = multiplier
        self.patience: float = patience
        self.min_delta: float = min_delta
        self.counter: int = 0
        self.best_loss: float = float('inf')
        self.avg_loss: float = None

    def __call__(self, loss):
        if self.avg_loss is None:
            self.avg_loss = loss
        else:
            # NOTE(review): in-place updates; if tensors were passed, best_loss (which may
            # alias avg_loss) would change too. Presumably callers pass floats — confirm.
            self.avg_loss *= 1 - self.multiplier
            self.avg_loss += self.multiplier * loss

        improved = self.avg_loss < self.best_loss - self.min_delta
        if not improved:
            self.counter += 1
        else:
            # Only differs from `improved` when min_delta is negative.
            if self.avg_loss < self.best_loss:
                self.best_loss = self.avg_loss
            self.counter = 0
        return self.counter >= self.patience
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class Loss(nn.Module):
|
|
103
|
+
def __init__(self, model: Model):
|
|
104
|
+
super().__init__()
|
|
105
|
+
self.loss_func = model.args.loss_func
|
|
106
|
+
self.mean_squared_error = nn.MSELoss(reduction = "none")
|
|
107
|
+
self.register_buffer("min", torch.from_numpy(model.param_scaler.min))
|
|
108
|
+
self.register_buffer("range", torch.from_numpy(model.param_scaler.range))
|
|
109
|
+
self.load_params(model)
|
|
110
|
+
|
|
111
|
+
def load_params(self, model: Model):
|
|
112
|
+
self.num_outputs = len(model.outputs)
|
|
113
|
+
self.dists = [output.dist.base for output in model.outputs.values()]
|
|
114
|
+
masks = torch.zeros((self.num_outputs, len(model.params)), dtype = bool)
|
|
115
|
+
for i, output in enumerate(model.outputs.values()):
|
|
116
|
+
loop = (param in output.dist.params for param in model.params)
|
|
117
|
+
mask = np.fromiter(loop, dtype = bool, count = len(model.params))
|
|
118
|
+
masks[i] = torch.tensor(mask, dtype = bool)
|
|
119
|
+
self.register_buffer("outputs", masks)
|
|
120
|
+
|
|
121
|
+
def forward(self, pred: torch.Tensor, true: torch.Tensor):
|
|
122
|
+
pred_scaled, true_scaled = pred, true
|
|
123
|
+
pred_unscaled = pred * self.range + self.min
|
|
124
|
+
if self.loss_func != "NLL":
|
|
125
|
+
true_unscaled = true * self.range + self.min
|
|
126
|
+
match self.loss_func:
|
|
127
|
+
case "MSE": return self.mean_squared_error(pred_scaled, true_scaled).mean()
|
|
128
|
+
case "KLD": return self.kl_divergence(pred_unscaled, true_unscaled).mean()
|
|
129
|
+
case "NLL": return self.neg_log_likelihood(pred_unscaled, true_scaled).mean()
|
|
130
|
+
|
|
131
|
+
def evaluate(self, pred: Dataset, true: Dataset, dependent: bool) -> list[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
132
|
+
pred_stats_unscaled = torch.tensor(pred.stats.outputs.values)
|
|
133
|
+
true_stats_unscaled = torch.tensor(true.stats.outputs.values)
|
|
134
|
+
pred_stats_scaled = (pred_stats_unscaled - self.min) / self.range
|
|
135
|
+
true_stats_scaled = (true_stats_unscaled - self.min) / self.range
|
|
136
|
+
mse = self.mean_squared_error(pred_stats_scaled, true_stats_scaled)
|
|
137
|
+
kld = self.kl_divergence(pred_stats_unscaled, true_stats_unscaled)
|
|
138
|
+
if dependent: return mse, kld, None
|
|
139
|
+
sample_groups = true.samples.outputs.groupby(true.samples.group)
|
|
140
|
+
nll = torch.zeros((sample_groups.ngroups, pred.samples.outputs.shape[1]), dtype = float)
|
|
141
|
+
for i, sample_group in sample_groups:
|
|
142
|
+
pred_stats = pred_stats_unscaled[i].unsqueeze(0).repeat(len(sample_group), 1)
|
|
143
|
+
true_samples = torch.tensor(sample_group.values)
|
|
144
|
+
nll[i] = self.neg_log_likelihood(pred_stats, true_samples)
|
|
145
|
+
return mse, kld, nll
|
|
146
|
+
|
|
147
|
+
def neg_log_likelihood(self, pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
|
|
148
|
+
neg_log_likelihoods = []
|
|
149
|
+
for i in range(self.num_outputs):
|
|
150
|
+
pred_params = pred[:, self.outputs[i]]
|
|
151
|
+
pred_fit = self.dists[i](*pred_params.unbind(dim = 1))
|
|
152
|
+
loss = -1 * pred_fit.log_prob(true[:,i])
|
|
153
|
+
neg_log_likelihoods.append(loss.mean())
|
|
154
|
+
return torch.column_stack(neg_log_likelihoods)
|
|
155
|
+
|
|
156
|
+
def kl_divergence(self, pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
|
|
157
|
+
kl_divergences = []
|
|
158
|
+
for i in range(self.num_outputs):
|
|
159
|
+
pred_params = pred[:, self.outputs[i]]
|
|
160
|
+
true_params = true[:, self.outputs[i]]
|
|
161
|
+
pred_fit = self.dists[i](*pred_params.unbind(dim = 1))
|
|
162
|
+
true_fit = self.dists[i](*true_params.unbind(dim = 1))
|
|
163
|
+
loss = torch.distributions.kl_divergence(true_fit, pred_fit)
|
|
164
|
+
kl_divergences.append(loss)
|
|
165
|
+
return torch.column_stack(kl_divergences)
|