mintstate 0.1.0__py3-none-any.whl

@@ -0,0 +1,371 @@
1
+ Metadata-Version: 2.4
2
+ Name: mintstate
3
+ Version: 0.1.0
4
+ Summary: StateMINT is a state-space neural network emulator for malariasimulation.
5
+ Project-URL: Homepage, https://github.com/mrc-ide/stateMINT
6
+ Project-URL: Repository, https://github.com/mrc-ide/stateMINT
7
+ Project-URL: Weights, https://huggingface.co/dide-ic/stateMINT
8
+ Author-email: Anmol Thapar <mr.anmolthapar@gmail.com>
9
+ Requires-Python: >=3.12
10
+ Requires-Dist: etils>=1.0
11
+ Requires-Dist: flax>=0.12.7
12
+ Requires-Dist: huggingface-hub>=1.19.0
13
+ Requires-Dist: jax>=0.10.1
14
+ Requires-Dist: jaxtyping>=0.3.10
15
+ Requires-Dist: mamba2-jax>=1.0.1
16
+ Requires-Dist: numpy>=2.4.6
17
+ Requires-Dist: omegaconf>=2.3
18
+ Requires-Dist: orbax-checkpoint>=0.12.0
19
+ Requires-Dist: pandas>=3.0.3
20
+ Provides-Extra: all
21
+ Requires-Dist: duckdb>=1.5.3; extra == 'all'
22
+ Requires-Dist: grain>=0.2.16; extra == 'all'
23
+ Requires-Dist: hydra-core>=1.3.2; extra == 'all'
24
+ Requires-Dist: jax[cuda12]>=0.10.1; extra == 'all'
25
+ Requires-Dist: matplotlib>=3.10.9; extra == 'all'
26
+ Requires-Dist: optax>=0.2.8; extra == 'all'
27
+ Requires-Dist: tqdm>=4.67.3; extra == 'all'
28
+ Requires-Dist: wandb>=0.27.0; extra == 'all'
29
+ Provides-Extra: gpu
30
+ Requires-Dist: jax[cuda12]>=0.10.1; extra == 'gpu'
31
+ Provides-Extra: plot
32
+ Requires-Dist: matplotlib>=3.10.9; extra == 'plot'
33
+ Provides-Extra: train
34
+ Requires-Dist: duckdb>=1.5.3; extra == 'train'
35
+ Requires-Dist: grain>=0.2.16; extra == 'train'
36
+ Requires-Dist: hydra-core>=1.3.2; extra == 'train'
37
+ Requires-Dist: optax>=0.2.8; extra == 'train'
38
+ Requires-Dist: tqdm>=4.67.3; extra == 'train'
39
+ Requires-Dist: wandb>=0.27.0; extra == 'train'
40
+ Description-Content-Type: text/markdown
41
+
42
+ # StateMINT
43
+
44
+ StateMINT is a JAX/Flax neural emulator for
45
+ [`malariasimulation`](https://github.com/mrc-ide/malariasimulation) outputs. It
46
+ uses a Mamba2 state-space sequence model to predict malaria trajectories from
47
+ static scenario covariates and intervention timing features. It supersedes the
48
+ earlier
49
+ [`MINTelligence`](https://github.com/CosmoNaught/MINTelligence) RNN emulator.
50
+
51
+ ## What StateMINT Provides
52
+
53
+ - Mamba2-based sequence regressors for malaria prevalence and case-count
54
+ trajectories.
55
+ - Data extraction utilities for aggregating raw `malariasimulation` DuckDB
56
+ outputs into model-ready parquet files.
57
+ - Preprocessing with target transforms, covariate scaling, and
58
+ intervention-aware feature construction.
59
+ - Training, evaluation, visualization, checkpointing, and export workflows.
60
+ - Hugging Face Hub loading utilities for exported inference artifacts.
61
+
62
+ ## Installation
63
+
64
+ StateMINT requires Python 3.12 or newer and uses
65
+ [`uv`](https://github.com/astral-sh/uv) to manage its environment and dependencies.
66
+
67
+ ```bash
68
+ git clone https://github.com/mrc-ide/stateMINT.git
69
+ cd stateMINT
70
+ uv sync
71
+ ```
72
+
73
+ To install the development dependencies and all optional extras:
74
+
75
+ ```bash
76
+ uv sync --all-extras --dev
77
+ ```
78
+
79
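+ Or install extras individually (`plot`, `gpu`, `train`, or `all`). `uv sync` removes packages that were not requested, so pass every extra you need in a single command:
80
+
81
+ ```bash
82
+ uv sync --extra plot                # plotting support only
83
+ uv sync --extra plot --extra gpu    # plotting plus CUDA 12 JAX
84
+ ```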
85
+
86
+ ## Quick Start: Inference
87
+
88
+ Load an exported artifact from the Hugging Face Hub or a local directory with
89
+ `Mamba2Regressor.from_pretrained`.
90
+
91
+ ```python
92
+ from stateMINT.model import Mamba2Regressor
93
+
94
+ artifact = Mamba2Regressor.from_pretrained(
95
+ "dide-ic/stateMINT",
96
+ predictor="prevalence",
97
+ revision="v1.0.0",
98
+ )
99
+
100
+ static_covars = [{
101
+ "eir": 50.0,
102
+ "dn0_use": 0.3,
103
+ "dn0_future": 0.4,
104
+ "Q0": 0.8,
105
+ "phi_bednets": 0.7,
106
+ "seasonal": 1.0,
107
+ "routine": 0.5,
108
+ "itn_use": 0.2,
109
+ "irs_use": 0.1,
110
+ "itn_future": 0.3,
111
+ "irs_future": 0.2,
112
+ "lsm": 0.0,
113
+ }]
114
+
115
+ predicted_prevalence = artifact.predict(static_covars)
116
+
117
+ print(predicted_prevalence.shape) # (batch, timesteps)
118
+ print(predicted_prevalence[0]) # first trajectory
119
+ ```
120
+
121
+ For cases, load the cases artifact and use the same input format:
122
+
123
+ ```python
124
+ artifact = Mamba2Regressor.from_pretrained(
125
+ "dide-ic/stateMINT",
126
+ predictor="cases",
127
+ revision="v1.0.0",
128
+ )
129
+
130
+ predicted_cases = artifact.predict(static_covars)
131
+ ```
132
+
133
+ By default, predictions are returned on the original target scale: prevalence as proportions in [0, 1]
134
+ and cases on the training-data scale (labelled "Cases per 1000 per day" in `conf/target/cases.yaml`).
135
+ Pass `transformed=True` to return model-space outputs instead: logit prevalence or `log1p` cases.
136
+
137
+ ```python
138
+ raw_model_space = artifact.predict(static_covars, transformed=True)
139
+ ```
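With `transformed=True`, outputs stay in the space the model was trained in. `stateMINT/common/utils.py` in this package defines that space as a logit of clipped prevalence and `log1p` of cases; the sketch below restates those transforms in plain NumPy to show the round trip. The helper names `to_model_space` and `from_model_space` are hypothetical, and in practice `stateMINT.common.utils.inverse_transform_np(y, predictor)` performs the inverse step. It assumes `predict(..., transformed=True)` uses the same transforms as training, which the utilities suggest but this README does not state outright.

```python
import numpy as np

EPS = 1e-5  # same value as eps_prevalence in the training config


def to_model_space(y: np.ndarray, predictor: str, eps: float = EPS) -> np.ndarray:
    """Forward transform: logit for prevalence, log1p for cases."""
    if predictor == "prevalence":
        y = np.clip(y, eps, 1.0 - eps)  # keeps the logit finite at 0 and 1
        return np.log(y / (1.0 - y))
    return np.log1p(np.maximum(y, 0.0))  # negative values are floored at 0


def from_model_space(z: np.ndarray, predictor: str) -> np.ndarray:
    """Inverse transform: sigmoid for prevalence, expm1 for cases."""
    if predictor == "prevalence":
        return 1.0 / (1.0 + np.exp(-z))
    return np.expm1(z)


prevalence = np.array([[0.0, 0.25, 0.5, 1.0]])
cases = np.array([[0.0, 1.5, 12.0]])

z_prev = to_model_space(prevalence, "prevalence")
z_cases = to_model_space(cases, "cases")

# Cases round-trip exactly; prevalence values of 0 and 1 come back as
# eps and 1 - eps because of the clipping step.
print(from_model_space(z_prev, "prevalence"))
print(np.allclose(from_model_space(z_cases, "cases"), cases))
```

The clipping is what makes the prevalence transform lossy at the boundaries, so trajectories that sit exactly at 0 or 1 cannot be recovered exactly from model space.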
140
+
141
+ For local artifacts, pass the target artifact directory:
142
+
143
+ ```python
144
+ artifact = Mamba2Regressor.from_pretrained(
145
+ "artifacts/prevalence",
146
+ predictor="prevalence",
147
+ )
148
+ ```
149
+
150
+ ## Static Covariates
151
+
152
+ Inference inputs need one dictionary per scenario with these static covariates:
153
+
154
+ ```text
155
+ eir
156
+ dn0_use
157
+ dn0_future
158
+ Q0
159
+ phi_bednets
160
+ seasonal
161
+ routine
162
+ itn_use
163
+ irs_use
164
+ itn_future
165
+ irs_future
166
+ lsm
167
+ ```
168
+
169
+ Artifacts include the fitted static scaler, timestep grid, intervention day,
170
+ target transform, and other preprocessing metadata needed for inference.
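Because each scenario must supply exactly these twelve keys, a small check before calling `predict` can catch typos such as `q0` for `Q0`. Below is a minimal sketch with two hypothetical helpers: `validate_static_covars` rejects missing or unexpected keys, and `sweep_covariate` builds a batch that varies one covariate around a base scenario. Whether `predict` itself rejects unknown keys is not documented here, so treat this as an optional guard.

```python
from __future__ import annotations

from typing import Any

# The static covariates listed above (order here is only for readability).
STATIC_COVARIATES = (
    "eir", "dn0_use", "dn0_future", "Q0", "phi_bednets", "seasonal",
    "routine", "itn_use", "irs_use", "itn_future", "irs_future", "lsm",
)


def validate_static_covars(scenarios: list[dict[str, Any]]) -> None:
    """Raise ValueError if any scenario is missing a covariate or has an unknown key."""
    expected = set(STATIC_COVARIATES)
    for i, scenario in enumerate(scenarios):
        missing = expected - set(scenario)
        extra = set(scenario) - expected
        if missing or extra:
            raise ValueError(
                f"scenario {i}: missing={sorted(missing)} unexpected={sorted(extra)}"
            )


def sweep_covariate(base: dict[str, Any], name: str, values: list[float]) -> list[dict[str, Any]]:
    """Return one scenario per value, copying base and overriding a single covariate."""
    return [{**base, name: float(v)} for v in values]


base = {
    "eir": 50.0, "dn0_use": 0.3, "dn0_future": 0.4, "Q0": 0.8,
    "phi_bednets": 0.7, "seasonal": 1.0, "routine": 0.5, "itn_use": 0.2,
    "irs_use": 0.1, "itn_future": 0.3, "irs_future": 0.2, "lsm": 0.0,
}
scenarios = sweep_covariate(base, "eir", [5.0, 50.0, 200.0])
validate_static_covars(scenarios)
print(len(scenarios), [s["eir"] for s in scenarios])
```

The resulting `scenarios` list can be passed as `artifact.predict(scenarios)`; per the shape comment in the quick start, the output should then have one row per scenario.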
171
+
172
+ ## Training Workflow
173
+
174
+ Typical workflow:
175
+
176
+ 1. Fetch and aggregate simulation data from DuckDB.
177
+ 2. Train a target-specific model, optionally tuning hyperparameters with W&B sweeps.
178
+ 3. Evaluate or visualize test-set predictions.
179
+ 4. Export the checkpoint into a portable inference artifact.
180
+ 5. Upload the artifact to the Hugging Face Hub, if needed.
181
+
182
+ ### 1. Fetch Filtered Data
183
+
184
+ `stateMINT.filter_raw_data` reads raw DuckDB simulation rows, filters burn-in,
185
+ aggregates fixed windows, and writes `filtered_data_<predictor>.parquet`.
186
+
187
+ ```bash
188
+ uv run python -m stateMINT.filter_raw_data \
189
+ --db-path /path/to/simulations.duckdb \
190
+ --table-name simulation_results \
191
+ --predictor prevalence \
192
+ --window-size 14 \
193
+ --output-folder data
194
+ ```
195
+
196
+ Useful options:
197
+
198
+ - `--predictor prevalence` or `--predictor cases`
199
+ - `--param-limit N` to keep only the first `N` parameter indices
200
+ - `--sim-limit N` to sample up to `N` simulations per parameter set (`parameter_index`)
201
+
202
+ The raw table should include identifiers (`parameter_index`, `simulation_index`,
203
+ `global_index`), daily timesteps, the static covariates above, and output
204
+ columns for prevalence or cases.
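The burn-in filter and fixed-window aggregation can be pictured on a single simulation's daily series. The real burn-in length and aggregation rule live in `stateMINT/data/fetch.py`, which is not shown here, so this sketch assumes a hypothetical `burn_in_days` argument and averages each complete window of `--window-size` days (summing might suit case counts better). Trailing days that do not fill a whole window are dropped.

```python
import numpy as np


def aggregate_windows(daily: np.ndarray, window_size: int = 14, burn_in_days: int = 0) -> np.ndarray:
    """Drop burn-in days, then average consecutive non-overlapping windows."""
    series = np.asarray(daily, dtype=float)[burn_in_days:]
    n_windows = len(series) // window_size
    trimmed = series[: n_windows * window_size]  # discard an incomplete final window
    return trimmed.reshape(n_windows, window_size).mean(axis=1)


# 10 burn-in days, then 30 days: two full 14-day windows plus 2 leftover days.
daily_prevalence = np.concatenate([np.full(10, 0.9), np.linspace(0.30, 0.59, 30)])
print(aggregate_windows(daily_prevalence, window_size=14, burn_in_days=10))
```

Reshaping into `(n_windows, window_size)` keeps the aggregation vectorised; applying it per `global_index` group would give one aggregated trajectory per simulation.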
205
+
206
+ ### 2. Train a Model
207
+
208
+ Training uses Hydra; the default target is prevalence.
209
+
210
+ ```bash
211
+ uv run python -m stateMINT.train
212
+ ```
213
+
214
+ Train the cases model:
215
+
216
+ ```bash
217
+ uv run python -m stateMINT.train target=cases
218
+ ```
219
+
220
+ Common overrides:
221
+
222
+ ```bash
223
+ uv run python -m stateMINT.train \
224
+ target=prevalence \
225
+ data_file=data/filtered_data_prevalence.parquet \
226
+ output_dir=train_outputs/prevalence \
227
+ use_wandb=false
228
+ ```
229
+
230
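+ Training writes checkpoints under `checkpoint_dir` (default `<output_dir>/ckpts-<timestamp>`), saves
231
+ `static_scaler.pkl` in `output_dir`, and uses a split assignment file (`split_file`, default `data/split_<predictor>.json`)
232
+ for consistent train/validation/test splits; set `use_existing_split=true` to reuse an existing assignment.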
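A fixed split assignment keeps test metrics comparable across runs. The actual split logic is in `stateMINT/data/preprocessing.py`, which is not shown, so the following is a minimal sketch under stated assumptions: splits are assigned per `parameter_index` (so all simulations of one parameter set stay together), the 80/10/10 fractions are hypothetical, and the assignment is stored as JSON as the `split_file` name suggests. `assign_splits` and `load_or_create_split` are illustrative names, not the package API.

```python
from __future__ import annotations

import json
import random
from collections.abc import Iterable
from pathlib import Path


def assign_splits(
    parameter_indices: Iterable[int],
    seed: int = 42,
    val_frac: float = 0.1,
    test_frac: float = 0.1,
) -> dict[str, list[int]]:
    """Deterministically shuffle unique parameter indices and cut them into splits."""
    unique = sorted({int(i) for i in parameter_indices})
    random.Random(seed).shuffle(unique)  # same seed -> same assignment
    n_test = int(len(unique) * test_frac)
    n_val = int(len(unique) * val_frac)
    return {
        "test": sorted(unique[:n_test]),
        "val": sorted(unique[n_test : n_test + n_val]),
        "train": sorted(unique[n_test + n_val :]),
    }


def load_or_create_split(
    split_file: Path,
    parameter_indices: Iterable[int],
    use_existing: bool,
    seed: int = 42,
) -> dict[str, list[int]]:
    """Reuse split_file when requested and present; otherwise create and save a new split."""
    if use_existing and split_file.exists():
        return json.loads(split_file.read_text())
    splits = assign_splits(parameter_indices, seed=seed)
    split_file.parent.mkdir(parents=True, exist_ok=True)
    split_file.write_text(json.dumps(splits))
    return splits


splits = assign_splits(range(100), seed=42)
print({name: len(ids) for name, ids in splits.items()})  # {'test': 10, 'val': 10, 'train': 80}
```

Splitting on `parameter_index` rather than on individual simulations avoids leaking near-identical stochastic replicates of one scenario into both training and test sets.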
233
+
234
+ ### 3. W&B Sweeps
235
+
236
+ Sweep definitions live in `stateMINT/conf/sweeps`. Create a sweep, then run one
237
+ or more agents with the sweep ID returned by W&B:
238
+
239
+ ```bash
240
+ uv run wandb sweep stateMINT/conf/sweeps/prevalence.yaml
241
+ uv run wandb agent <entity>/stateMINT-sweep/<sweep-id>
242
+ ```
243
+
244
+ Use `stateMINT/conf/sweeps/cases.yaml` for the cases target. Sweep commands set
245
+ `use_wandb=true` and pass Hydra overrides through `${args_no_hyphens}`.
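W&B expands `${args_no_hyphens}` into one `key=value` argument per sampled hyperparameter, which is already Hydra's override syntax. The hypothetical `sweep_command` helper below rebuilds the `command:` list from the sweep files so you can preview what an agent would run; `${env}` is omitted, and W&B's exact formatting of sampled values (for example floats) may differ.

```python
def sweep_command(sampled: dict, target: str) -> list:
    """Mimic the sweep `command:` list with ${args_no_hyphens} expanded."""
    args_no_hyphens = [f"{key}={value}" for key, value in sampled.items()]
    return [
        "python", "-m", "stateMINT.train",
        "hydra.output_subdir=null",
        "hydra.run.dir=.",
        *args_no_hyphens,  # e.g. lr=0.0003 dropout=0.35 ...
        f"target={target}",
        "use_wandb=true",
        "num_epochs=300",
        "patience=100",
    ]


sampled = {"lr": 0.0003, "dropout": 0.35, "batch_size": 128,
           "d_state": 128, "d_model": 256, "n_layers": 2}
print(" ".join(sweep_command(sampled, "prevalence")))
```

Because Hydra applies `key=value` overrides after composing the defaults list, a sampled `dropout` or `lr` takes precedence over the value in `conf/target/<target>.yaml`.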
246
+
247
+ ### 4. Visualize Predictions
248
+
249
+ Compare predictions with test-set targets. `checkpoint_dir` is required.
250
+
251
+ ```bash
252
+ uv run python -m stateMINT.visualise_predictions \
253
+ target=prevalence \
254
+ checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
255
+ data_file=data/filtered_data_prevalence.parquet
256
+ ```
257
+
258
+ The default output path is `viz_outputs/<predictor>/preds-vs-targets.pdf`.
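Checkpoint directories are named `ckpts-${cur_time}`, with `cur_time` formatted as `%Y-%m-%dT%H:%M:%S` in `train_config.yaml`, so sorting their names as strings also sorts them chronologically. A small hypothetical helper can therefore locate the newest one to pass as `checkpoint_dir`:

```python
from __future__ import annotations

from pathlib import Path


def latest_checkpoint_dir(output_dir: str | Path) -> Path:
    """Return the newest ckpts-<timestamp> directory directly under output_dir."""
    candidates = [p for p in Path(output_dir).glob("ckpts-*") if p.is_dir()]
    if not candidates:
        raise FileNotFoundError(f"no ckpts-* directories found in {output_dir}")
    # Timestamps use %Y-%m-%dT%H:%M:%S, so string order is chronological order.
    return max(candidates, key=lambda p: p.name)
```

For example, `latest_checkpoint_dir("train_outputs/prevalence")` returns the directory to use as `checkpoint_dir=...` in the visualization and export commands.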
259
+
260
+ ### 5. Export an Artifact
261
+
262
+ Export converts a trained Orbax checkpoint and preprocessing metadata into a
263
+ self-contained artifact.
264
+
265
+ ```bash
266
+ uv run python -m stateMINT.model_export \
267
+ predictor=prevalence \
268
+ checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
269
+ scaler_file=train_outputs/prevalence/static_scaler.pkl \
270
+ artifact_dir=artifacts/prevalence
271
+ ```
272
+
273
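+ The architecture values used for export (from `export_config.yaml` or command-line overrides) must match the checkpoint: `d_model`, `d_state`, `n_layers`, `dropout`, and the related Mamba2 settings (`d_conv`, `expand`, `head_dim`, `chunk_size`, `output_dim`).
274
+ Note that `conf/target/prevalence.yaml` trains with `dropout: 0.4` while `export_config.yaml` defaults to `dropout: 0.3`, so add `dropout=0.4` when exporting a prevalence model trained with the defaults.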
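Before exporting, it can help to compare the architecture keys a training run used with those the export will use. The hypothetical `architecture_mismatches` helper below works on plain dictionaries; with the real files you could build them with `OmegaConf.to_container(OmegaConf.load(path))`, remembering that a training run's effective values are the composed ones (target files override `train_config.yaml`). The example dictionaries restate the default values from the configs in this package, and the key list follows the model section of `export_config.yaml`.

```python
ARCH_KEYS = ("n_layers", "d_conv", "expand", "head_dim", "chunk_size",
             "output_dim", "d_model", "d_state", "dropout")


def architecture_mismatches(train_cfg: dict, export_cfg: dict) -> dict:
    """Return {key: (train_value, export_value)} for each architecture key that differs."""
    return {
        key: (train_cfg.get(key), export_cfg.get(key))
        for key in ARCH_KEYS
        if train_cfg.get(key) != export_cfg.get(key)
    }


# Effective values of a default prevalence run (train_config.yaml + target/prevalence.yaml).
trained = {"n_layers": 4, "d_conv": 4, "expand": 2, "head_dim": 64, "chunk_size": 256,
           "output_dim": 1, "d_model": 256, "d_state": 128, "dropout": 0.4}
# Defaults from export_config.yaml.
export_defaults = {"n_layers": 4, "d_conv": 4, "expand": 2, "head_dim": 64, "chunk_size": 256,
                   "output_dim": 1, "d_model": 256, "d_state": 128, "dropout": 0.3}

print(architecture_mismatches(trained, export_defaults))  # {'dropout': (0.4, 0.3)}
```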
275
+
276
+ An exported artifact contains:
277
+
278
+ ```text
279
+ artifact_dir/
280
+ |-- checkpoint/
281
+ |-- model_config.json
282
+ `-- preprocessing_config.json
283
+ ```
284
+
285
+ `model_config.json` stores architecture settings; `preprocessing_config.json`
286
+ stores feature order, target transform, intervention timing, timestep
287
+ construction, and static scaler parameters.
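A quick sanity check on an exported directory can confirm the three entries above exist before uploading it. This is a minimal sketch with a hypothetical helper, `load_artifact_configs`; it only checks the layout and parses the two JSON files, and it makes no assumption about the keys inside them, since their schema is not documented here.

```python
from __future__ import annotations

import json
from pathlib import Path

REQUIRED = ("checkpoint", "model_config.json", "preprocessing_config.json")


def load_artifact_configs(artifact_dir: str | Path) -> tuple[dict, dict]:
    """Check the expected artifact layout and return (model_config, preprocessing_config)."""
    root = Path(artifact_dir)
    missing = [name for name in REQUIRED if not (root / name).exists()]
    if missing:
        raise FileNotFoundError(f"{root} is missing: {', '.join(missing)}")
    if not (root / "checkpoint").is_dir():
        raise NotADirectoryError(f"{root / 'checkpoint'} should be a directory")
    model_cfg = json.loads((root / "model_config.json").read_text())
    prep_cfg = json.loads((root / "preprocessing_config.json").read_text())
    return model_cfg, prep_cfg
```

For example, `model_cfg, prep_cfg = load_artifact_configs("artifacts/prevalence")` followed by `print(sorted(model_cfg))` lists the stored architecture settings.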
288
+
289
+ ### 6. Upload to Hugging Face
290
+
291
+ Authenticate first:
292
+
293
+ ```bash
294
+ hf auth login
295
+ ```
296
+
297
+ Upload an artifact:
298
+
299
+ ```bash
300
+ hf upload dide-ic/stateMINT artifacts/prevalence prevalence/ \
301
+ --commit-message "Add prevalence model artifact"
302
+ ```
303
+
304
+ Create a release tag:
305
+
306
+ ```bash
307
+ hf repos tag create dide-ic/stateMINT v1.0.0 \
308
+ --revision main \
309
+ --message "Release v1.0.0"
310
+ ```
311
+
312
+ ## Configuration
313
+
314
+ Main Hydra configs in `stateMINT/conf`:
315
+
316
+ - `train_config.yaml` for training.
317
+ - `viz_config.yaml` for prediction visualizations.
318
+ - `export_config.yaml` for artifact export.
319
+ - `target/prevalence.yaml` and `target/cases.yaml` for target-specific defaults.
320
+ - `sweeps/*.yaml` for Weights & Biases sweep definitions.
321
+
322
+ Select a target with `target=prevalence` or `target=cases`; override config
323
+ values from the command line with Hydra syntax.
324
+
325
+ ## Development
326
+
327
+ Run the test suite:
328
+
329
+ ```bash
330
+ uv run pytest tests/
331
+ ```
332
+
333
+ Skip slow or local-only tests:
334
+
335
+ ```bash
336
+ uv run pytest tests/ -m "not slow"
337
+ uv run pytest tests/ -m "not local"
338
+ uv run pytest tests/ -m "not slow and not local" # skip both
339
+ ```
340
+
341
+ Run linting and formatting:
342
+
343
+ ```bash
344
+ uv run ruff check
345
+ uv run ruff format
346
+ ```
347
+
348
+ ## Repository Layout
349
+
350
+ ```text
351
+ stateMINT/
352
+ |-- common/ # shared dataclasses, transforms, and model helpers
353
+ |-- conf/ # Hydra configs for training, export, viz, and sweeps
354
+ |-- data/ # DuckDB fetch, preprocessing, features, and loaders
355
+ |-- eval/ # metrics and prediction/target visualization helpers
356
+ |-- model/ # Mamba2 regressor and artifact loading utilities
357
+ |-- training/ # optimizer, train/eval steps, loss, and checkpointing
358
+ |-- filter_raw_data.py # CLI for building filtered parquet datasets
359
+ |-- train.py # Hydra training entry point
360
+ |-- visualise_predictions.py # Hydra prediction visualization entry point
361
+ `-- model_export.py # Hydra artifact export entry point
362
+
363
+ tests/ # unit tests
364
+ artifacts/ # exported model artifact examples/metadata
365
+ viz_outputs/ # generated prediction visualization outputs
366
+ ```
367
+
368
+ ## Contributing
369
+
370
+ See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution guidelines and the
371
+ development workflow.
@@ -0,0 +1,33 @@
1
+ stateMINT/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ stateMINT/filter_raw_data.py,sha256=KjOC7GeZ19lOl4i4PTqkC17cCCYdMPrrizfQftlwKLM,1623
3
+ stateMINT/model_export.py,sha256=JynbP6gTHgV7ZrSDpWWf4l_lZbHl89zd5R_ZTaTupL4,2903
4
+ stateMINT/train.py,sha256=eI_f2gjCXe-PYXZd1Gm_cv404GezX793Cbb5g0n859E,5696
5
+ stateMINT/visualise_predictions.py,sha256=4MrwoutGevqHilSEAWToCchuxHcykZuPkFrX5d-bBlo,1853
6
+ stateMINT/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ stateMINT/common/dataclasses.py,sha256=YYVby4nOZQehS5qiYngVkdomDL8BUkZQ6xoTa5gAr98,260
8
+ stateMINT/common/utils.py,sha256=oLmz5HWAb5QwXP26_4WmeJ_AefBCW6wAT5pzQ7n0X7w,1748
9
+ stateMINT/conf/export_config.yaml,sha256=aPmSsxJMJvZiigfjrpwavVNidl2tR9yJVJTcdMtoJmM,426
10
+ stateMINT/conf/train_config.yaml,sha256=lszJtBir7GEvVKjlfFUeBzgtHnKnAUiJ5V7DFyuZgQQ,911
11
+ stateMINT/conf/viz_config.yaml,sha256=q74YLMK7QWoBm_uXJTWZW2GcAx5AaMTjI_tLMD_YyrI,517
12
+ stateMINT/conf/sweeps/cases.yaml,sha256=P2w1a3VDwP_6h7bBVRkt7DG9rUe4EDl--cfSeAbkhe0,699
13
+ stateMINT/conf/sweeps/prevalence.yaml,sha256=cQNg8YqldLuzYwwUERjjQvZaK3EMfiOpoWDGY92k-Ig,711
14
+ stateMINT/conf/target/cases.yaml,sha256=9LkY6jjAoC2tD1QObJlghTO0t1xl95Q3x9ypyTfgVh4,199
15
+ stateMINT/conf/target/prevalence.yaml,sha256=nUKOpVnqutYpCuEchz5rhbVr_bZ6tpFZK_pPbBqBIao,198
16
+ stateMINT/data/__init__.py,sha256=FXunf0wXzF1jDJUHo2AogDqvzmbryc7bLo9Mg6VjLmQ,403
17
+ stateMINT/data/dataset.py,sha256=idTvwz0i3ualujXp-Bh8F67ROASNT1vr5ygHqf_RbDQ,1689
18
+ stateMINT/data/features.py,sha256=W64e3PE4ZRM7ufLyPB4sJGWDnSpmpJZxkMNgWqfuQxc,2714
19
+ stateMINT/data/fetch.py,sha256=MZelwGuOkGwx_5b0rKWryIoUJPNF7GgxoissA2Ebyvk,6110
20
+ stateMINT/data/preprocessing.py,sha256=KBbH-lIMUgixac3YjfENBr5YagHvRy-uA7pzFJb3OFo,17428
21
+ stateMINT/eval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ stateMINT/eval/metrics.py,sha256=J79ANLzGGb3s_ARFExF4AsR0YRqPULQsKwVL0MmZKRk,4335
23
+ stateMINT/eval/viz_preds_truth.py,sha256=CTDT58MkRZgxVRaZDUCTofCW6VOWZQtVPN07Od-bSzo,6104
24
+ stateMINT/model/__init__.py,sha256=xL_8A4TghG_bYEYuU8qNl_ov5yFtB8WcvYzkPp239vY,67
25
+ stateMINT/model/hub.py,sha256=OHveRm0qvXpuAd6f-4QOg7MTkO-7tJ3BGfC6npWSmEg,4802
26
+ stateMINT/model/mamba2.py,sha256=IIo3VDuSq1UbcBnDWpuOCB-s6wCJMRhL7FOe_pMIAek,5581
27
+ stateMINT/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
+ stateMINT/training/checkpoint.py,sha256=jKQmeP28aCY_C7koHL3P8OmQHsmDhVA0QjmL77ryO6A,4372
29
+ stateMINT/training/loss.py,sha256=ZmP62C8i-xRB_pNMTGl4-LkT5esrtgwOmpJozbbL64Q,746
30
+ stateMINT/training/train_step.py,sha256=CMYt2kcmpKXI5qgsk2oN52AKgr-emyPwjTj9OU70sPA,4469
31
+ mintstate-0.1.0.dist-info/METADATA,sha256=Z_ZP7Okvm8OFnrYKB2iBlb1SAQVyNzm9mlQN_YQVmEE,10137
32
+ mintstate-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
33
+ mintstate-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
stateMINT/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,11 @@
1
+ from typing import Literal, Protocol
2
+
3
+ from flax import nnx
4
+ from omegaconf import DictConfig
5
+
6
+ Predictor = Literal["prevalence", "cases"]
7
+
8
+
9
+ class ModelFactory(Protocol):
10
+ @classmethod
11
+ def from_cfg(cls, cfg: DictConfig, input_size: int) -> nnx.Module: ...
@@ -0,0 +1,74 @@
1
+ import numpy as np
2
+ import jax
3
+ import jax.numpy as jnp
4
+ from .dataclasses import Predictor
5
+ from jaxtyping import Array
6
+ from flax import nnx
7
+
8
+
9
+ def transform_targets_np(y: np.ndarray, predictor: Predictor, eps: float = 1e-5) -> np.ndarray:
10
+ """
11
+ Apply train-time transform to targets.
12
+
13
+ Args:
14
+ y: Target values.
15
+ predictor: Target type.
16
+ eps: Prevalence clipping epsilon.
17
+
18
+ Returns:
19
+ Transformed target values.
20
+ """
21
+ if predictor == "prevalence":
22
+ y = np.clip(y, eps, 1.0 - eps)
23
+ return np.log(y / (1.0 - y)) # logit transform
24
+ else:
25
+ return np.log1p(np.maximum(y, 0.0)) # log1p transform for counts/rates
26
+
27
+
28
+ def inverse_transform_np(y: np.ndarray, predictor: Predictor) -> np.ndarray:
29
+ """
30
+ Invert transform for metrics/plots.
31
+
32
+ Args:
33
+ y: Transformed target values.
34
+ predictor: Target type.
35
+
36
+ Returns:
37
+ Values in the original target scale.
38
+ """
39
+ if predictor == "prevalence":
40
+ return 1.0 / (1.0 + np.exp(-y)) # sigmoid
41
+ else:
42
+ return np.expm1(y)
43
+
44
+
45
+ def inverse_transform_jax(y: jax.Array, predictor: Predictor) -> jax.Array:
46
+ """
47
+ Invert transformed targets with JAX.
48
+
49
+ Args:
50
+ y: Transformed target values.
51
+ predictor: Target type.
52
+
53
+ Returns:
54
+ Values in the original target scale.
55
+ """
56
+ if predictor == "prevalence":
57
+ return jax.nn.sigmoid(y)
58
+ else:
59
+ return jnp.expm1(y)
60
+
61
+
62
+ @nnx.jit
63
+ def forward(model: nnx.Module, x: Array) -> Array:
64
+ """
65
+ Forward pass through the model.
66
+
67
+ Args:
68
+ model: Model to evaluate.
69
+ x: Input batch.
70
+
71
+ Returns:
72
+ Model predictions with shape (B, T).
73
+ """
74
+ return model(x).squeeze(-1) # (B, T, 1) -> (B, T)
@@ -0,0 +1,25 @@
1
+ predictor: prevalence
2
+
3
+ # Data
4
+ eps_prevalence: 1.0e-5
5
+ scaler_file: train_outputs/${predictor}/static_scaler.pkl
6
+ window_size: 14
7
+ use_cyclical_time: true
8
+
9
+ # Model - ensure these match the checkpointed model's parameters
10
+ n_layers: 4
11
+ d_conv: 4
12
+ expand: 2
13
+ head_dim: 64
14
+ chunk_size: 256
15
+ output_dim: 1
16
+ d_model: 256
17
+ d_state: 128
18
+ dropout: 0.3
19
+
20
+ # Checkpointing
21
+ checkpoint_dir: ???
22
+
23
+ # General
24
+ seed: 42
25
+ artifact_dir: "artifacts/${predictor}"
@@ -0,0 +1,46 @@
1
+ program: stateMINT/train.py
2
+ project: stateMINT-sweep
3
+ name: sweep-cases
4
+ method: bayes
5
+ metric:
6
+ goal: minimize
7
+ name: val/loss
8
+
9
+ # early_terminate:
10
+ # type: hyperband
11
+ # min_iter: 10
12
+ # eta: 3
13
+
14
+ parameters:
15
+ lr:
16
+ distribution: log_uniform_values
17
+ min: 0.001
18
+ max: 0.01
19
+ dropout:
20
+ distribution: uniform
21
+ min: 0.25
22
+ max: 0.6
23
+ batch_size:
24
+ values: [128]
25
+
26
+ # model-specific
27
+ d_state:
28
+ values: [128]
29
+ d_model:
30
+ values: [256]
31
+ n_layers:
32
+ values: [1, 2, 3, 4]
33
+
34
+
35
+ command:
36
+ - ${env}
37
+ - python
38
+ - -m
39
+ - stateMINT.train
40
+ - "hydra.output_subdir=null"
41
+ - "hydra.run.dir=."
42
+ - ${args_no_hyphens}
43
+ - target=cases
44
+ - use_wandb=true
45
+ - num_epochs=300
46
+ - patience=100
@@ -0,0 +1,46 @@
1
+ program: stateMINT/train.py
2
+ project: stateMINT-sweep
3
+ name: sweep-prevalence
4
+ method: bayes
5
+ metric:
6
+ goal: minimize
7
+ name: val/loss
8
+
9
+ # early_terminate:
10
+ # type: hyperband
11
+ # min_iter: 40
12
+ # eta: 3
13
+
14
+ parameters:
15
+ lr:
16
+ distribution: log_uniform_values
17
+ min: 0.0001
18
+ max: 0.001
19
+ dropout:
20
+ distribution: uniform
21
+ min: 0.25
22
+ max: 0.6
23
+ batch_size:
24
+ values: [128]
25
+
26
+ # model-specific
27
+ d_state:
28
+ values: [128]
29
+ d_model:
30
+ values: [256]
31
+ n_layers:
32
+ values: [1, 2, 3, 4]
33
+
34
+
35
+ command:
36
+ - ${env}
37
+ - python
38
+ - -m
39
+ - stateMINT.train
40
+ - "hydra.output_subdir=null"
41
+ - "hydra.run.dir=."
42
+ - ${args_no_hyphens}
43
+ - target=prevalence
44
+ - use_wandb=true
45
+ - num_epochs=300
46
+ - patience=100
@@ -0,0 +1,18 @@
1
+ # @package _global_
2
+
3
+ predictor: cases
4
+
5
+ # Data
6
+ min_cases: 0.1
7
+ ylabel: "Cases per 1000 per day"
8
+
9
+ # Model
10
+ n_layers: 4
11
+ d_model: 256
12
+ d_state: 128
13
+ dropout: 0.3
14
+
15
+ # Hyperparameters
16
+ lr: 8e-3
17
+ batch_size: 128
18
+
@@ -0,0 +1,18 @@
1
+ # @package _global_
2
+
3
+ predictor: prevalence
4
+
5
+ # Data
6
+ min_prevalence: 0.01
7
+ ylabel: "Prevalence"
8
+
9
+ # Model
10
+ n_layers: 4
11
+ d_model: 256
12
+ d_state: 128
13
+ dropout: 0.4
14
+
15
+ # Hyperparameters
16
+ lr: 4e-4
17
+ batch_size: 128
18
+
@@ -0,0 +1,46 @@
1
+ defaults:
2
+ - _self_ # Load current and override with the following defaults
3
+ - target: prevalence
4
+
5
+ # Data
6
+ data_file: "data/filtered_data_${predictor}.parquet"
7
+ split_file: "data/split_${predictor}.json"
8
+ num_workers: 0
9
+ use_existing_split: false
10
+ eps_prevalence: 1.0e-5
11
+ use_cyclical_time: true
12
+
13
+ # Model
14
+ loss_method: stateMINT.training.loss.weighted_mse
15
+ n_layers: 4
16
+ d_conv: 4
17
+ expand: 2
18
+ head_dim: 64
19
+ chunk_size: 256
20
+ output_dim: 1
21
+ d_model: 256
22
+ d_state: 128
23
+
24
+ # Hyperparameters
25
+ num_epochs: 200
26
+ min_epochs: 100
27
+ patience: 50
28
+ diff_loss_alpha: 0.05
29
+ lr: 1e-3
30
+ batch_size: 128
31
+ dropout: 0.3
32
+ weight_decay: 1e-4
33
+
34
+ # Checkpointing
35
+ checkpoint_dir: "${output_dir}/ckpts-${cur_time}"
36
+ restore_checkpoint: false
37
+ max_checkpoints_to_keep: 1
38
+
39
+ # General
40
+ cur_time: ${now:%Y-%m-%dT%H:%M:%S}
41
+ seed: 42
42
+ use_wandb: false
43
+ output_dir: "train_outputs/${predictor}"
44
+ wandb:
45
+ project: "stateMINT-${predictor}"
46
+ name: "train-${predictor}-${cur_time}"