FASTEN-cli 1.0.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
FASTEN/tune.py ADDED
@@ -0,0 +1,92 @@
1
+ from .common import os, json
2
+ from .plot import plot_train, plot_predict
3
+ from .train import Trainer
4
+ from .predict import Predictor
5
+ from copy import deepcopy
6
+ import optuna, warnings
7
+
8
+ warnings.filterwarnings("ignore", category = optuna.exceptions.ExperimentalWarning)
9
+
10
+
11
class Tuner:
    """Drive an Optuna hyperparameter search over a FASTEN training configuration.

    The search space is every entry of ``config["model"]`` / ``config["train"]``
    whose value is a list; each list is offered to Optuna as a categorical choice
    (see ``Objective.sample``).
    """

    def __init__(self, trainer: Trainer):
        """Wrap *trainer*; call :meth:`load_study` before :meth:`execute`."""
        self.trainer = trainer
        self.model = trainer.model
        # Shared reference, not a copy: dump_trials writes the best values back into it.
        self.config = trainer.model.config
        self.study: optuna.Study = None

    def dump_trials(self, output_dir: str):
        """Write all trials to ``trials.tsv`` and the best configuration to ``config.json``.

        Every list-valued entry of ``config["model"]`` and ``config["train"]`` is
        replaced in place by the value chosen in the best trial, so the written
        config describes a single concrete training run.

        Raises:
            ValueError: presumably propagated from Optuna's ``best_trial`` when the
                study has no completed trial.
        """
        trials_file = os.path.join(output_dir, "trials.tsv")
        trial_data = self.study.trials_dataframe().sort_values(by = "value")
        trial_data.to_csv(trials_file, sep = "\t", index = False)

        best_params = self.study.best_trial.params
        for section in ("model", "train"):
            for arg, value in self.config[section].items():
                # Only list values were tuned; scalar settings are kept as-is.
                if isinstance(value, list):
                    self.config[section][arg] = best_params[arg]
        with open(os.path.join(output_dir, "config.json"), "w", encoding = "utf-8") as file:
            json.dump(self.config, file, indent = 4)

    def load_study(self, output_dir: str, study_name: str = "tune_data"):
        """Create, or resume, the SQLite-backed Optuna study stored in *output_dir*.

        Also makes sure ``<output_dir>/plots/trials`` exists, because every trial
        writes its figures there.
        """
        self.output_dir = output_dir
        # makedirs(exist_ok=True) also creates plots/trials when output_dir already
        # exists (e.g. it was made by the user or by another command) but lacks them.
        os.makedirs(os.path.join(output_dir, "plots", "trials"), exist_ok = True)
        storage_name = f"sqlite:///{output_dir}/{study_name}.db"
        sampler = optuna.samplers.TPESampler(n_startup_trials = 20, multivariate = True)
        self.study = optuna.create_study(study_name = study_name, storage = storage_name,
                                         load_if_exists = True, direction = "minimize",
                                         sampler = sampler)

    def execute(self, n_trials: int, duplicates: bool):
        """Run trials one at a time until the study holds at least *n_trials* trials.

        Counting the trials already stored (instead of passing ``n_trials`` straight
        to ``optimize``) lets a resumed study top up to the requested total.

        Args:
            n_trials: Target total number of trials in the study, including resumed ones.
            duplicates: Forwarded as ``Objective.unique``; when True, a parameter set
                that was already evaluated reuses the earlier value instead of retraining.
        """
        # The objective holds no per-trial state, so one instance serves every trial.
        objective = Objective(self, duplicates)
        while len(self.study.get_trials(deepcopy = False)) < n_trials:
            self.study.optimize(objective, n_trials = 1)
46
+
47
+
48
+ class Objective:
49
+ def __init__(self, tuner: Tuner, unique: bool):
50
+ self.tuner, self.trainer = tuner, tuner.trainer
51
+ self.unique = unique
52
+ self.trial_dir = f"{self.tuner.output_dir}/plots/trials"
53
+
54
+ def sample(self, trial) -> tuple[dict, dict]:
55
+ trial_model, trial_train = dict(), dict()
56
+ for arg, value in self.tuner.config["model"].items():
57
+ if not isinstance(value, list): trial_model[arg] = value
58
+ else: trial_model[arg] = trial.suggest_categorical(arg, value)
59
+ for arg, value in self.tuner.config["train"].items():
60
+ if not isinstance(value, list): trial_train[arg] = value
61
+ else: trial_train[arg] = trial.suggest_categorical(arg, value)
62
+ if "rand_seed" not in trial_train or trial_train["rand_seed"] is None:
63
+ trial_train["rand_seed"] = trial.number
64
+ return trial_model, trial_train
65
+
66
+ def get_duplicate(self, trial) -> tuple[bool, int, float]:
67
+ if not self.unique: return False, None, None
68
+ for prev in trial.study.trials:
69
+ if prev.number != trial.number and prev.params == trial.params:
70
+ return True, prev.number, prev.value
71
+ return False, None, None
72
+
73
+ def __call__(self, trial) -> float:
74
+ trial_model, trial_train = self.sample(trial)
75
+ duplicate, number, value = self.get_duplicate(trial)
76
+ if duplicate:
77
+ print(f"Trial {trial.number} is a duplicate of trial {number} with value {value}.")
78
+ return value
79
+
80
+ self.tuner.model.validate_args(trial_model, trial_train)
81
+ try: self.trainer.execute()
82
+ except ValueError:
83
+ message = "Training diverged: loss is NaN (possible exploding gradients)"
84
+ raise optuna.exceptions.TrialPruned(message)
85
+
86
+ plot_train(self.trainer, f"{self.trial_dir}/trial_{trial.number}")
87
+ predictor = Predictor(self.tuner.model, deepcopy(self.trainer.test.dataset))
88
+ mse, kld, nll = plot_predict(predictor, f"{self.trial_dir}/trial_{trial.number}")
89
+ match self.tuner.model.args.loss_func:
90
+ case "MSE": return mse.mean().mean()
91
+ case "KLD": return kld.mean().mean()
92
+ case _: return nll.mean().mean()
FASTEN/utils.py ADDED
@@ -0,0 +1,67 @@
1
+ from __future__ import annotations
2
+ from .common import np, pd, torch
3
+ from .config import ModelInput
4
+
5
+
6
class Scaler():
    """Per-column min-max scaler that rescales data in place to [0, 1].

    Works on ``pandas.DataFrame`` and ``torch.Tensor`` inputs, reducing over
    axis/dim 0. Statistics are stored as NumPy arrays, so a scaler fitted on one
    type can be applied to the other.
    """

    def __init__(self):
        self.min = None      # per-column minimum (np.ndarray) once fitted
        self.max = None      # per-column maximum (np.ndarray) once fitted
        self.range = None    # max - min, with zero spans set to 1 to avoid division by zero
        self.fitted = False

    def fit(self, data):
        """Record the per-column min, max and range of *data*.

        Raises:
            TypeError: if *data* is neither a DataFrame nor a Tensor.
        """
        if isinstance(data, pd.DataFrame):
            self.min = np.min(data.values, axis = 0)
            self.max = np.max(data.values, axis = 0)
        elif isinstance(data, torch.Tensor):
            # detach/cpu so tensors that require grad or live on an accelerator can be fitted.
            self.min = data.min(dim = 0).values.detach().cpu().numpy()
            self.max = data.max(dim = 0).values.detach().cpu().numpy()
        else:
            raise TypeError("Scaler expects a pandas DataFrame or torch Tensor, "
                            f"got {type(data).__name__}")
        self.range = self.max - self.min
        self.range[self.range == 0] = 1.0
        self.fitted = True

    @staticmethod
    def _stats_for(data, stats):
        """Return *stats* in a form that broadcasts against *data* (same device for tensors)."""
        if isinstance(data, torch.Tensor):
            return torch.as_tensor(stats, device = data.device)
        return stats

    def transform(self, data):
        """Scale *data* in place to [0, 1], fitting on it first if not yet fitted.

        Values outside the fitted range (e.g. unseen test data) are clipped to [0, 1].
        """
        if not self.fitted: self.fit(data)
        data -= self._stats_for(data, self.min)
        data /= self._stats_for(data, self.range)
        # Clip in place; a non-inplace clip()/clamp() returns a new object that would be discarded.
        if isinstance(data, pd.DataFrame): data.clip(0, 1, inplace = True)
        if isinstance(data, torch.Tensor): data.clamp_(0, 1)

    def inverse_transform(self, data):
        """Undo :meth:`transform` in place (values clipped by transform are not recovered)."""
        data *= self._stats_for(data, self.range)
        data += self._stats_for(data, self.min)
34
+
35
+
36
class Encoder():
    """One-hot encoder for the ``"string"``-typed entries of a model-input mapping.

    ``transform`` rewrites both the DataFrame and ``inputs`` in place: each string
    column is replaced by one ``<label>_<value>`` indicator column per distinct
    value, and each indicator is registered in ``inputs`` as an ``"integer"`` input.
    """

    def __init__(self, inputs: dict[str, ModelInput]):
        self.inputs = inputs

    def fit(self, data: pd.DataFrame) -> bool:
        """Collect the distinct values of every string input column in *data*.

        Populates the index-aligned lists ``labels`` (new column labels), ``names``
        (display names), ``origins`` (position of the source input within
        ``inputs``) and ``strings`` (the raw values).

        Returns:
            True if at least one categorical value was found.

        Raises:
            AssertionError: if a value appears under two different inputs, since
                indicator columns are later looked up by the value alone.
        """
        self.labels, self.names = [], []
        self.origins, self.strings = [], []
        for i, origin in enumerate(self.inputs.values()):
            if origin.type != "string": continue
            for string in data[origin.label].unique():
                self.labels.append(f"{origin.label}_{string}")
                self.names.append(f"{origin.name}: {string.capitalize()}")
                self.origins.append(i)
                self.strings.append(string)
        if len(self.strings) != len(set(self.strings)):
            raise AssertionError("Set of strings must be disjoint for all categorical inputs.")
        return len(self.strings) > 0

    def transform(self, data: pd.DataFrame):
        """One-hot encode every string input column of *data* in place (no-op if none)."""
        if not self.fit(data): return
        col = {string: i for i, string in enumerate(self.strings)}
        values = np.zeros((data.shape[0], len(self.labels)))
        rows = np.arange(data.shape[0])
        # Iterate over a snapshot: the body inserts and pops keys of self.inputs, which
        # raises "dictionary changed size during iteration" on a live values() view.
        for origin in list(self.inputs.values()):
            if origin.type != "string": continue
            column = data[origin.label]
            values[rows, [col[string] for string in column]] = 1
            # Register each indicator once, in first-appearance order of its value.
            for string in column.unique():
                index = col[string]
                label = self.labels[index]
                self.inputs[label] = ModelInput(label, self.names[index], "integer")
            data.drop(columns = origin.label, inplace = True)
            self.inputs.pop(origin.label)
        data[self.labels] = pd.DataFrame(values, columns = self.labels, dtype = float,
                                         index = data.index)
@@ -0,0 +1,89 @@
1
+ Metadata-Version: 2.4
2
+ Name: FASTEN-cli
3
+ Version: 1.0.0
4
+ Project-URL: Homepage, https://github.com/k1jackson/FASTEN
5
+ Project-URL: Bug Tracker, https://github.com/k1jackson/FASTEN/issues
6
+ Project-URL: Config Designer, https://k1jackson.github.io/FASTEN/
7
+ License-File: LICENSE.md
8
+ Requires-Dist: k-means-constrained
9
+ Requires-Dist: matplotlib
10
+ Requires-Dist: numpy>=2.1.1
11
+ Requires-Dist: optuna
12
+ Requires-Dist: pandas
13
+ Requires-Dist: pydantic>=2.0
14
+ Requires-Dist: rich
15
+ Requires-Dist: scikit-learn
16
+ Requires-Dist: scipy>=1.14.0
17
+ Requires-Dist: torch
18
+ Description-Content-Type: text/markdown
19
+
20
+ # FASTEN
21
+
22
+ [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md)
23
+ [![Version](https://img.shields.io/badge/version-1.0.0-brightgreen.svg)](CHANGELOG.md)
24
+
25
+
26
+ <!-- Explain the *what* and *why* of your project in 2–4 sentences.
27
+ Answer: What problem does it solve? Who is it for? Why does it exist? -->
28
+
29
+ FASTEN is a flexible and user-friendly framework for building PyTorch-based deep learning emulators for epidemic simulations with stochastic outputs. FASTEN provides three intuitive modules to (1) train deep neural networks on simulation data, (2) select optimal hyperparameters, and (3) generate predictions from unseen inputs.
30
+
31
+
32
+ ## Installation
33
+
34
+
35
+ FASTEN can be installed via `pip`:
36
+
37
+ ```bash
38
+ pip install fasten-cli
39
+ ```
40
+
41
+ ## Usage
42
+
43
+ There are two primary components to FASTEN: the [configuration file designer](https://k1jackson.github.io/FASTEN/) and the command line tool.
44
+
45
+ ### Configuration File Designer
46
+ Before running FASTEN, the user must construct a workflow configuration file that outlines the simulation data format and training hyperparameters. The [configuration file designer](https://k1jackson.github.io/FASTEN/) guides users through this process with contextual instructions and validation checks.
47
+
48
+ ### Command Line Tool
49
+ The FASTEN workflow decomposes the model emulation process into three phases: (1) training, (2) hyperparameter tuning, and (3) output prediction. Each phase is invoked through a dedicated command line module, with a shared configuration file governing the underlying behavior. The command line tool can be used as follows:
50
+
51
+ **Training:**
52
+
53
+ ```bash
54
+ usage: FASTEN train [-h] -c CONFIG -i INPUT [-o OUTPUT] [-m MODEL]
55
+
56
+ options:
57
+ -h, --help show this help message and exit
58
+ -c, --config CONFIG JSON file defining configuration parameters
59
+ -i, --input INPUT TSV file with simulation data
60
+ -o, --output OUTPUT Folder to output model and figures (default: outputs)
61
+ -m, --model MODEL ZIP file containing initial model (default: None)
62
+ ```
63
+
64
+ **Hyperparameter Tuning:**
65
+
66
+ ```bash
67
+ usage: FASTEN tune [-h] -c CONFIG -i INPUT [-o OUTPUT] [-n TRIALS] [--unique]
68
+
69
+ options:
70
+ -h, --help show this help message and exit
71
+ -c, --config CONFIG JSON file defining configuration parameters
72
+ -i, --input INPUT TSV file with simulation data
73
+ -o, --output OUTPUT Folder to output optimal configs and figures (default: outputs)
74
+ -n, --trials TRIALS Total number of optimization trials (default: 100)
75
+ --unique Prevents re-training with duplicate hyperparameter sets (default: False)
76
+ ```
77
+
78
+ **Output Prediction:**
79
+
80
+ ```bash
81
+ usage: FASTEN predict [-h] -m MODEL -i INPUT [-o OUTPUT] [-n RUNS]
82
+
83
+ options:
84
+ -h, --help show this help message and exit
85
+ -m, --model MODEL ZIP file containing model
86
+ -i, --input INPUT TSV file with simulation inputs
87
+ -o, --output OUTPUT TSV file to output predicted simulation data (default: outputs.tsv)
88
+ -n, --runs RUNS Number of simulation runs per input (default: 0)
89
+ ```
@@ -0,0 +1,19 @@
1
+ FASTEN/__init__.py,sha256=e-T0h8j5pOz4g3c9sp-sQVnMGzuWSm6Dv4h_UAXigLE,188
2
+ FASTEN/cli.py,sha256=ncH8WWFEoEYiUlmGKSuxEJxfuFinQeATeJuLRaoNYYs,5139
3
+ FASTEN/common.py,sha256=HZ4iB68x2oq7yKz_gwpKNACizCKdEKkvXnPk2CQBtnU,123
4
+ FASTEN/config.py,sha256=hh6GpUTavZlqzipaXykidjcb0PgIjTL9bOYwbCezHD4,4858
5
+ FASTEN/data.py,sha256=r4kSVvkpO1abNObzZaef2AaGktc--jYYcZjZW2AOd4Q,6141
6
+ FASTEN/estimate.py,sha256=svQVENGB0M_vxX_jh4WdusJXibEs9_QE28KchzfIMSY,6711
7
+ FASTEN/learn.py,sha256=aTazbtU4-W4pEBy8qCMRDYkOKmENgu16yiA4LQrDSf0,7933
8
+ FASTEN/model.py,sha256=-aS3qBci-LOIn1MjKLQxr-R4rcUxbN4_hqCI8RKE5iM,7864
9
+ FASTEN/param.py,sha256=AtF0zuSsA_Dn4H89cN6BRZqv0NdNSXWaupv3qCJVahQ,5951
10
+ FASTEN/plot.py,sha256=-blUaho6X_4gY5AmZK4FTeMAzVr-Veds4ABug-VbQFE,10261
11
+ FASTEN/predict.py,sha256=b0LJDiVnfA4ga3BVOljAkyj5Xs8oe1_EMX6BTRTq6MY,3084
12
+ FASTEN/train.py,sha256=K6hsDpSGWYZi-jsISSZuTr-keaE0gjrjN7b80_Oc-uc,3940
13
+ FASTEN/tune.py,sha256=ltZB9CweGP3AkzFbn4SEQAVYsXTZ1wCfW9wcyKgk7UM,4345
14
+ FASTEN/utils.py,sha256=MMq7Bs9VsuEcYjpDvkZUTh2KFl9-_7OmSOik95Wtj-w,2615
15
+ fasten_cli-1.0.0.dist-info/METADATA,sha256=XTMnePdGxEuIxIlxviuNg0sKEjIHLzyAtm3UMMDGGUw,3586
16
+ fasten_cli-1.0.0.dist-info/WHEEL,sha256=VX-VJ7c6dw9Ge3EqJIbA6W3pOUbz24SnnGGFNr55jY4,105
17
+ fasten_cli-1.0.0.dist-info/entry_points.txt,sha256=brVh-WK-yszMifsiSc3uCLQs3srzXjlvObGR0Xq1IqE,43
18
+ fasten_cli-1.0.0.dist-info/licenses/LICENSE.md,sha256=u5ofZFWS6UIMv9Qaf1Yjvuouqd4xatYjLqA4Ejtxv7Q,1069
19
+ fasten_cli-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py2-none-any
5
+ Tag: py3-none-any
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ FASTEN = FASTEN.cli:main
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Kate Jackson
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.