FASTEN-cli 1.0.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- FASTEN/__init__.py +7 -0
- FASTEN/cli.py +95 -0
- FASTEN/common.py +5 -0
- FASTEN/config.py +110 -0
- FASTEN/data.py +118 -0
- FASTEN/estimate.py +138 -0
- FASTEN/learn.py +165 -0
- FASTEN/model.py +152 -0
- FASTEN/param.py +120 -0
- FASTEN/plot.py +215 -0
- FASTEN/predict.py +66 -0
- FASTEN/train.py +87 -0
- FASTEN/tune.py +92 -0
- FASTEN/utils.py +67 -0
- fasten_cli-1.0.0.dist-info/METADATA +89 -0
- fasten_cli-1.0.0.dist-info/RECORD +19 -0
- fasten_cli-1.0.0.dist-info/WHEEL +5 -0
- fasten_cli-1.0.0.dist-info/entry_points.txt +2 -0
- fasten_cli-1.0.0.dist-info/licenses/LICENSE.md +21 -0
FASTEN/tune.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from .common import os, json
|
|
2
|
+
from .plot import plot_train, plot_predict
|
|
3
|
+
from .train import Trainer
|
|
4
|
+
from .predict import Predictor
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
import optuna, warnings
|
|
7
|
+
|
|
8
|
+
# Optuna flags some features (e.g. TPESampler's multivariate mode) as experimental;
# silence those ExperimentalWarning messages for the whole process.
warnings.filterwarnings(action = "ignore", category = optuna.exceptions.ExperimentalWarning)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Tuner:
    """Hyperparameter search over the list-valued entries of a trainer's config.

    Every ``config["model"]`` / ``config["train"]`` entry whose value is a list
    is treated as a categorical search space; scalar entries stay fixed.
    """

    def __init__(self, trainer: Trainer):
        self.trainer = trainer
        self.model = trainer.model
        self.config = trainer.model.config   # same object as the model's config, not a copy
        self.study: optuna.Study = None      # created by load_study()

    def dump_trials(self, output_dir: str):
        """Write ``trials.tsv`` (all trials sorted by value) and the best ``config.json``.

        NOTE: each list-valued entry of ``self.config`` (shared with
        ``trainer.model.config``) is replaced in place by the best trial's choice.
        """
        trials_file = f"{output_dir}/trials.tsv"
        trial_data = self.study.trials_dataframe().sort_values(by = "value")
        trial_data.to_csv(trials_file, sep = "\t", index = False)

        best_params = self.study.best_trial.params
        for section in ("model", "train"):
            settings = self.config[section]
            # re-assigning existing keys while iterating is safe: the size never changes
            for arg, value in settings.items():
                if isinstance(value, list): settings[arg] = best_params[arg]
        with open(f"{output_dir}/config.json", "w") as file:
            json.dump(self.config, file, indent = 4)

    def load_study(self, output_dir: str, study_name: str = "tune_data"):
        """Create, or resume, the SQLite-backed study stored in *output_dir*.

        Also ensures ``output_dir/plots/trials`` exists for per-trial figures.
        """
        self.output_dir = output_dir
        # makedirs(exist_ok=True) also completes a partially created tree,
        # e.g. an existing output_dir that lacks plots/trials
        os.makedirs(f"{output_dir}/plots/trials", exist_ok = True)
        storage_name = f"sqlite:///{output_dir}/{study_name}.db"
        # the first 20 trials are sampled randomly before TPE takes over
        sampler = optuna.samplers.TPESampler(n_startup_trials = 20, multivariate = True)
        self.study = optuna.create_study(study_name = study_name, storage = storage_name,
                                         load_if_exists = True, direction = "minimize", sampler = sampler)

    def execute(self, n_trials: int, duplicates: bool):
        """Run trials one at a time until the study holds *n_trials* trials.

        The count includes trials resumed from storage and pruned/failed ones.
        *duplicates* is forwarded as ``Objective``'s ``unique`` flag: when True,
        a repeated parameter set reuses an earlier score instead of retraining.
        """
        objective = Objective(self, duplicates)   # holds no per-trial state, so build it once
        while len(self.study.get_trials(deepcopy = False)) < n_trials:
            self.study.optimize(objective, n_trials = 1)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Objective:
|
|
49
|
+
def __init__(self, tuner: Tuner, unique: bool):
|
|
50
|
+
self.tuner, self.trainer = tuner, tuner.trainer
|
|
51
|
+
self.unique = unique
|
|
52
|
+
self.trial_dir = f"{self.tuner.output_dir}/plots/trials"
|
|
53
|
+
|
|
54
|
+
def sample(self, trial) -> tuple[dict, dict]:
|
|
55
|
+
trial_model, trial_train = dict(), dict()
|
|
56
|
+
for arg, value in self.tuner.config["model"].items():
|
|
57
|
+
if not isinstance(value, list): trial_model[arg] = value
|
|
58
|
+
else: trial_model[arg] = trial.suggest_categorical(arg, value)
|
|
59
|
+
for arg, value in self.tuner.config["train"].items():
|
|
60
|
+
if not isinstance(value, list): trial_train[arg] = value
|
|
61
|
+
else: trial_train[arg] = trial.suggest_categorical(arg, value)
|
|
62
|
+
if "rand_seed" not in trial_train or trial_train["rand_seed"] is None:
|
|
63
|
+
trial_train["rand_seed"] = trial.number
|
|
64
|
+
return trial_model, trial_train
|
|
65
|
+
|
|
66
|
+
def get_duplicate(self, trial) -> tuple[bool, int, float]:
|
|
67
|
+
if not self.unique: return False, None, None
|
|
68
|
+
for prev in trial.study.trials:
|
|
69
|
+
if prev.number != trial.number and prev.params == trial.params:
|
|
70
|
+
return True, prev.number, prev.value
|
|
71
|
+
return False, None, None
|
|
72
|
+
|
|
73
|
+
def __call__(self, trial) -> float:
|
|
74
|
+
trial_model, trial_train = self.sample(trial)
|
|
75
|
+
duplicate, number, value = self.get_duplicate(trial)
|
|
76
|
+
if duplicate:
|
|
77
|
+
print(f"Trial {trial.number} is a duplicate of trial {number} with value {value}.")
|
|
78
|
+
return value
|
|
79
|
+
|
|
80
|
+
self.tuner.model.validate_args(trial_model, trial_train)
|
|
81
|
+
try: self.trainer.execute()
|
|
82
|
+
except ValueError:
|
|
83
|
+
message = "Training diverged: loss is NaN (possible exploding gradients)"
|
|
84
|
+
raise optuna.exceptions.TrialPruned(message)
|
|
85
|
+
|
|
86
|
+
plot_train(self.trainer, f"{self.trial_dir}/trial_{trial.number}")
|
|
87
|
+
predictor = Predictor(self.tuner.model, deepcopy(self.trainer.test.dataset))
|
|
88
|
+
mse, kld, nll = plot_predict(predictor, f"{self.trial_dir}/trial_{trial.number}")
|
|
89
|
+
match self.tuner.model.args.loss_func:
|
|
90
|
+
case "MSE": return mse.mean().mean()
|
|
91
|
+
case "KLD": return kld.mean().mean()
|
|
92
|
+
case _: return nll.mean().mean()
|
FASTEN/utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from .common import np, pd, torch
|
|
3
|
+
from .config import ModelInput
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Scaler():
    """Per-column min-max scaler that maps fitted data onto [0, 1].

    Works on ``pandas.DataFrame`` and ``torch.Tensor``; when fitting, anything
    ``numpy.asarray`` accepts is also supported. ``transform`` and
    ``inverse_transform`` modify their argument in place and return None.
    """

    def __init__(self):
        self.min = None     # per-column minimum (numpy), set by fit
        self.max = None     # per-column maximum (numpy), set by fit
        self.range = None   # max - min, with zero spans replaced by 1.0
        self.fitted = False

    def fit(self, data):
        """Record the per-column (axis 0) minimum, maximum and range of *data*."""
        if isinstance(data, pd.DataFrame):
            self.min = np.min(data.values, axis = 0)
            self.max = np.max(data.values, axis = 0)
        elif isinstance(data, torch.Tensor):
            # detach/cpu so tensors that require grad or live on a GPU can be fitted
            self.min = data.detach().min(dim = 0).values.cpu().numpy()
            self.max = data.detach().max(dim = 0).values.cpu().numpy()
        else:
            values = np.asarray(data)
            self.min = np.min(values, axis = 0)
            self.max = np.max(values, axis = 0)
        span = self.max - self.min
        # constant columns would divide by zero; np.where also copes with 0-d (scalar) stats
        self.range = np.where(span == 0, 1.0, span)
        self.fitted = True

    @staticmethod
    def _operand(stat, data):
        """Return *stat* in a form usable for in-place arithmetic on *data*."""
        if isinstance(data, torch.Tensor):
            # place the statistic on the tensor's device; dtype promotion is left to torch
            return torch.as_tensor(stat, device = data.device)
        return stat

    def transform(self, data):
        """Scale *data* in place onto [0, 1], fitting on it first if not yet fitted.

        Values outside the fitted range are clipped to the bounds 0 and 1.
        """
        if not self.fitted: self.fit(data)
        data -= self._operand(self.min, data)
        data /= self._operand(self.range, data)
        # clip in place: clip()/clamp() without inplace would only return a discarded copy
        if isinstance(data, (pd.DataFrame, pd.Series)): data.clip(0, 1, inplace = True)
        elif isinstance(data, torch.Tensor): data.clamp_(0, 1)
        elif isinstance(data, np.ndarray): np.clip(data, 0, 1, out = data)

    def inverse_transform(self, data):
        """Undo ``transform`` in place (values clipped by transform are not restored)."""
        data *= self._operand(self.range, data)
        data += self._operand(self.min, data)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Encoder():
    """One-hot encoder for the categorical ("string"-typed) model inputs of a table.

    NOTE(review): ``transform`` re-fits on whatever table it is given and
    rewrites ``self.inputs`` in place (string inputs are replaced by their
    one-hot entries), so a second ``transform`` call on the same instance
    finds no string inputs left to encode. Confirm callers expect this.
    """

    def __init__(self, inputs: dict[str, ModelInput]):
        # label -> ModelInput; transform() replaces string inputs in this dict
        self.inputs = inputs

    def fit(self, data: pd.DataFrame) -> bool:
        """Collect one output column per distinct value of every string input.

        Returns True when at least one categorical value was found.

        Raises:
            AssertionError: if the same value occurs in two different string
                inputs (values double as column keys, so they must be unique).
        """
        self.labels, self.names = [], []      # one-hot column labels / display names
        self.origins, self.strings = [], []   # source-input position / raw category value
        for i, origin in enumerate(self.inputs.values()):
            if origin.type != "string": continue
            for string in data[origin.label].unique():
                self.labels.append(f"{origin.label}_{string}")
                self.names.append(f"{origin.name}: {string.capitalize()}")
                self.origins.append(i)
                self.strings.append(string)
        if len(self.strings) != len(set(self.strings)):
            raise AssertionError("Set of strings must be disjoint for all categorical inputs.")
        return len(self.strings) > 0

    def transform(self, data: pd.DataFrame):
        """Replace every string column of *data* with float one-hot columns, in place."""
        if not self.fit(data): return
        col = {string: i for i, string in enumerate(self.strings)}
        values = np.zeros((data.shape[0], len(self.labels)))
        # iterate over a snapshot: the body inserts and pops entries of self.inputs,
        # which raises "dictionary changed size during iteration" on the live view
        for origin in list(self.inputs.values()):
            if origin.type != "string": continue
            column = data[origin.label]
            for row, string in enumerate(column):
                values[row, col[string]] = 1
            # register each one-hot input once, in order of first appearance
            for string in column.unique():
                j = col[string]
                self.inputs[self.labels[j]] = ModelInput(self.labels[j], self.names[j], "integer")
            data.drop(columns = origin.label, inplace = True)
            self.inputs.pop(origin.label)
        data[self.labels] = pd.DataFrame(values, dtype = float, index = data.index)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: FASTEN-cli
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Project-URL: Homepage, https://github.com/k1jackson/FASTEN
|
|
5
|
+
Project-URL: Bug Tracker, https://github.com/k1jackson/FASTEN/issues
|
|
6
|
+
Project-URL: Config Designer, https://k1jackson.github.io/FASTEN/
|
|
7
|
+
License-File: LICENSE.md
|
|
8
|
+
Requires-Dist: k-means-constrained
|
|
9
|
+
Requires-Dist: matplotlib
|
|
10
|
+
Requires-Dist: numpy>=2.1.1
|
|
11
|
+
Requires-Dist: optuna
|
|
12
|
+
Requires-Dist: pandas
|
|
13
|
+
Requires-Dist: pydantic>=2.0
|
|
14
|
+
Requires-Dist: rich
|
|
15
|
+
Requires-Dist: scikit-learn
|
|
16
|
+
Requires-Dist: scipy>=1.14.0
|
|
17
|
+
Requires-Dist: torch
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
|
|
20
|
+
# FASTEN
|
|
21
|
+
|
|
22
|
+
[](LICENSE.md)
|
|
23
|
+
[](CHANGELOG.md)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
<!-- Explain the *what* and *why* of your project in 2–4 sentences.
|
|
27
|
+
Answer: What problem does it solve? Who is it for? Why does it exist? -->
|
|
28
|
+
|
|
29
|
+
FASTEN is a flexible and user-friendly framework for building PyTorch-based deep learning emulators for epidemic simulations with stochastic outputs. FASTEN provides three intuitive modules to (1) train deep neural networks on simulation data, (2) select optimal hyperparameters, and (3) generate predictions from unseen inputs.
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
FASTEN can be installed via `pip`:
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
pip install fasten-cli
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
## Usage
|
|
42
|
+
|
|
43
|
+
There are two primary components to FASTEN: the [configuration file designer](https://k1jackson.github.io/FASTEN/) and the command line tool.
|
|
44
|
+
|
|
45
|
+
### Configuration File Designer
|
|
46
|
+
Before running FASTEN, the user must construct a workflow configuration file that outlines the simulation data format and training hyperparameters. The [configuration file designer](https://k1jackson.github.io/FASTEN/) guides users through this process with contextual instructions and validation checks.
|
|
47
|
+
|
|
48
|
+
### Command Line Tool
|
|
49
|
+
The FASTEN workflow decomposes the model emulation process into three phases: (1) training, (2) hyperparameter tuning, and (3) output prediction. Each phase is invoked through a dedicated command line module, with a shared configuration file governing the underlying behavior. The command line tool can be used as follows:
|
|
50
|
+
|
|
51
|
+
**Training:**
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
usage: FASTEN train [-h] -c CONFIG -i INPUT [-o OUTPUT] [-m MODEL]
|
|
55
|
+
|
|
56
|
+
options:
|
|
57
|
+
-h, --help show this help message and exit
|
|
58
|
+
-c, --config CONFIG JSON file defining configuration parameters
|
|
59
|
+
-i, --input INPUT TSV file with simulation data
|
|
60
|
+
-o, --output OUTPUT Folder to output model and figures (default: outputs)
|
|
61
|
+
-m, --model MODEL ZIP file containing initial model (default: None)
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
**Hyperparameter Tuning:**
|
|
65
|
+
|
|
66
|
+
```bash
|
|
67
|
+
usage: FASTEN tune [-h] -c CONFIG -i INPUT [-o OUTPUT] [-n TRIALS] [--unique]
|
|
68
|
+
|
|
69
|
+
options:
|
|
70
|
+
-h, --help show this help message and exit
|
|
71
|
+
-c, --config CONFIG JSON file defining configuration parameters
|
|
72
|
+
-i, --input INPUT TSV file with simulation data
|
|
73
|
+
-o, --output OUTPUT Folder to output optimal configs and figures (default: outputs)
|
|
74
|
+
-n, --trials TRIALS Total number of optimization trials (default: 100)
|
|
75
|
+
--unique Prevents re-training with duplicate hyperparameter sets (default: False)
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
**Output Prediction:**
|
|
79
|
+
|
|
80
|
+
```bash
|
|
81
|
+
usage: FASTEN predict [-h] -m MODEL -i INPUT [-o OUTPUT] [-n RUNS]
|
|
82
|
+
|
|
83
|
+
options:
|
|
84
|
+
-h, --help show this help message and exit
|
|
85
|
+
-m, --model MODEL ZIP file containing model
|
|
86
|
+
-i, --input INPUT TSV file with simulation inputs
|
|
87
|
+
-o, --output OUTPUT TSV file to output predicted simulation data (default: outputs.tsv)
|
|
88
|
+
-n, --runs RUNS Number of simulation runs per input (default: 0)
|
|
89
|
+
```
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
FASTEN/__init__.py,sha256=e-T0h8j5pOz4g3c9sp-sQVnMGzuWSm6Dv4h_UAXigLE,188
|
|
2
|
+
FASTEN/cli.py,sha256=ncH8WWFEoEYiUlmGKSuxEJxfuFinQeATeJuLRaoNYYs,5139
|
|
3
|
+
FASTEN/common.py,sha256=HZ4iB68x2oq7yKz_gwpKNACizCKdEKkvXnPk2CQBtnU,123
|
|
4
|
+
FASTEN/config.py,sha256=hh6GpUTavZlqzipaXykidjcb0PgIjTL9bOYwbCezHD4,4858
|
|
5
|
+
FASTEN/data.py,sha256=r4kSVvkpO1abNObzZaef2AaGktc--jYYcZjZW2AOd4Q,6141
|
|
6
|
+
FASTEN/estimate.py,sha256=svQVENGB0M_vxX_jh4WdusJXibEs9_QE28KchzfIMSY,6711
|
|
7
|
+
FASTEN/learn.py,sha256=aTazbtU4-W4pEBy8qCMRDYkOKmENgu16yiA4LQrDSf0,7933
|
|
8
|
+
FASTEN/model.py,sha256=-aS3qBci-LOIn1MjKLQxr-R4rcUxbN4_hqCI8RKE5iM,7864
|
|
9
|
+
FASTEN/param.py,sha256=AtF0zuSsA_Dn4H89cN6BRZqv0NdNSXWaupv3qCJVahQ,5951
|
|
10
|
+
FASTEN/plot.py,sha256=-blUaho6X_4gY5AmZK4FTeMAzVr-Veds4ABug-VbQFE,10261
|
|
11
|
+
FASTEN/predict.py,sha256=b0LJDiVnfA4ga3BVOljAkyj5Xs8oe1_EMX6BTRTq6MY,3084
|
|
12
|
+
FASTEN/train.py,sha256=K6hsDpSGWYZi-jsISSZuTr-keaE0gjrjN7b80_Oc-uc,3940
|
|
13
|
+
FASTEN/tune.py,sha256=ltZB9CweGP3AkzFbn4SEQAVYsXTZ1wCfW9wcyKgk7UM,4345
|
|
14
|
+
FASTEN/utils.py,sha256=MMq7Bs9VsuEcYjpDvkZUTh2KFl9-_7OmSOik95Wtj-w,2615
|
|
15
|
+
fasten_cli-1.0.0.dist-info/METADATA,sha256=XTMnePdGxEuIxIlxviuNg0sKEjIHLzyAtm3UMMDGGUw,3586
|
|
16
|
+
fasten_cli-1.0.0.dist-info/WHEEL,sha256=VX-VJ7c6dw9Ge3EqJIbA6W3pOUbz24SnnGGFNr55jY4,105
|
|
17
|
+
fasten_cli-1.0.0.dist-info/entry_points.txt,sha256=brVh-WK-yszMifsiSc3uCLQs3srzXjlvObGR0Xq1IqE,43
|
|
18
|
+
fasten_cli-1.0.0.dist-info/licenses/LICENSE.md,sha256=u5ofZFWS6UIMv9Qaf1Yjvuouqd4xatYjLqA4Ejtxv7Q,1069
|
|
19
|
+
fasten_cli-1.0.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Kate Jackson
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|