mintstate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mintstate-0.1.0.dist-info/METADATA +371 -0
- mintstate-0.1.0.dist-info/RECORD +33 -0
- mintstate-0.1.0.dist-info/WHEEL +4 -0
- stateMINT/__init__.py +0 -0
- stateMINT/common/__init__.py +0 -0
- stateMINT/common/dataclasses.py +11 -0
- stateMINT/common/utils.py +74 -0
- stateMINT/conf/export_config.yaml +25 -0
- stateMINT/conf/sweeps/cases.yaml +46 -0
- stateMINT/conf/sweeps/prevalence.yaml +46 -0
- stateMINT/conf/target/cases.yaml +18 -0
- stateMINT/conf/target/prevalence.yaml +18 -0
- stateMINT/conf/train_config.yaml +46 -0
- stateMINT/conf/viz_config.yaml +31 -0
- stateMINT/data/__init__.py +21 -0
- stateMINT/data/dataset.py +77 -0
- stateMINT/data/features.py +104 -0
- stateMINT/data/fetch.py +194 -0
- stateMINT/data/preprocessing.py +504 -0
- stateMINT/eval/__init__.py +0 -0
- stateMINT/eval/metrics.py +180 -0
- stateMINT/eval/viz_preds_truth.py +151 -0
- stateMINT/filter_raw_data.py +45 -0
- stateMINT/model/__init__.py +3 -0
- stateMINT/model/hub.py +144 -0
- stateMINT/model/mamba2.py +171 -0
- stateMINT/model_export.py +90 -0
- stateMINT/train.py +158 -0
- stateMINT/training/__init__.py +0 -0
- stateMINT/training/checkpoint.py +147 -0
- stateMINT/training/loss.py +35 -0
- stateMINT/training/train_step.py +167 -0
- stateMINT/visualise_predictions.py +55 -0
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mintstate
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: StateMINT is a state space based neural network emulator for malariasimulation.
|
|
5
|
+
Project-URL: Homepage, https://github.com/mrc-ide/stateMINT
|
|
6
|
+
Project-URL: Repository, https://github.com/mrc-ide/stateMINT
|
|
7
|
+
Project-URL: Weights, https://huggingface.co/dide-ic/stateMINT
|
|
8
|
+
Author-email: Anmol Thapar <mr.anmolthapar@gmail.com>
|
|
9
|
+
Requires-Python: >=3.12
|
|
10
|
+
Requires-Dist: etils>=1.0
|
|
11
|
+
Requires-Dist: flax>=0.12.7
|
|
12
|
+
Requires-Dist: huggingface-hub>=1.19.0
|
|
13
|
+
Requires-Dist: jax>=0.10.1
|
|
14
|
+
Requires-Dist: jaxtyping>=0.3.10
|
|
15
|
+
Requires-Dist: mamba2-jax>=1.0.1
|
|
16
|
+
Requires-Dist: numpy>=2.4.6
|
|
17
|
+
Requires-Dist: omegaconf>=2.3
|
|
18
|
+
Requires-Dist: orbax-checkpoint>=0.12.0
|
|
19
|
+
Requires-Dist: pandas>=3.0.3
|
|
20
|
+
Provides-Extra: all
|
|
21
|
+
Requires-Dist: duckdb>=1.5.3; extra == 'all'
|
|
22
|
+
Requires-Dist: grain>=0.2.16; extra == 'all'
|
|
23
|
+
Requires-Dist: hydra-core>=1.3.2; extra == 'all'
|
|
24
|
+
Requires-Dist: jax[cuda12]>=0.10.1; extra == 'all'
|
|
25
|
+
Requires-Dist: matplotlib>=3.10.9; extra == 'all'
|
|
26
|
+
Requires-Dist: optax>=0.2.8; extra == 'all'
|
|
27
|
+
Requires-Dist: tqdm>=4.67.3; extra == 'all'
|
|
28
|
+
Requires-Dist: wandb>=0.27.0; extra == 'all'
|
|
29
|
+
Provides-Extra: gpu
|
|
30
|
+
Requires-Dist: jax[cuda12]>=0.10.1; extra == 'gpu'
|
|
31
|
+
Provides-Extra: plot
|
|
32
|
+
Requires-Dist: matplotlib>=3.10.9; extra == 'plot'
|
|
33
|
+
Provides-Extra: train
|
|
34
|
+
Requires-Dist: duckdb>=1.5.3; extra == 'train'
|
|
35
|
+
Requires-Dist: grain>=0.2.16; extra == 'train'
|
|
36
|
+
Requires-Dist: hydra-core>=1.3.2; extra == 'train'
|
|
37
|
+
Requires-Dist: optax>=0.2.8; extra == 'train'
|
|
38
|
+
Requires-Dist: tqdm>=4.67.3; extra == 'train'
|
|
39
|
+
Requires-Dist: wandb>=0.27.0; extra == 'train'
|
|
40
|
+
Description-Content-Type: text/markdown
|
|
41
|
+
|
|
42
|
+
# StateMINT
|
|
43
|
+
|
|
44
|
+
StateMINT is a JAX/Flax neural emulator for
|
|
45
|
+
[`malariasimulation`](https://github.com/mrc-ide/malariasimulation) outputs. It
|
|
46
|
+
uses a Mamba2 state-space sequence model to predict malaria trajectories from
|
|
47
|
+
static scenario covariates and intervention timing features. It supersedes the
|
|
48
|
+
earlier
|
|
49
|
+
[`MINTelligence`](https://github.com/CosmoNaught/MINTelligence) RNN emulator.
|
|
50
|
+
|
|
51
|
+
## What StateMINT Provides
|
|
52
|
+
|
|
53
|
+
- Mamba2-based sequence regressors for malaria prevalence and case-count
|
|
54
|
+
trajectories.
|
|
55
|
+
- Data extraction utilities for aggregating raw `malariasimulation` DuckDB
|
|
56
|
+
outputs into model-ready parquet files.
|
|
57
|
+
- Preprocessing with target transforms, covariate scaling, and
|
|
58
|
+
intervention-aware feature construction.
|
|
59
|
+
- Training, evaluation, visualization, checkpointing, and export workflows.
|
|
60
|
+
- Hugging Face Hub loading utilities for exported inference artifacts.
|
|
61
|
+
|
|
62
|
+
## Installation
|
|
63
|
+
|
|
64
|
+
StateMINT requires Python 3.12 or newer and uses
|
|
65
|
+
[`uv`](https://github.com/astral-sh/uv) for dependency management.
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
git clone https://github.com/mrc-ide/stateMINT.git
|
|
69
|
+
cd stateMINT
|
|
70
|
+
uv sync
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
For development dependencies and optional extras:
|
|
74
|
+
|
|
75
|
+
```bash
|
|
76
|
+
uv sync --all-extras --dev
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
Or install extras individually:
|
|
80
|
+
|
|
81
|
+
```bash
|
|
82
|
+
uv sync --extra plot
|
|
83
|
+
uv sync --extra gpu
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Quick Start: Inference
|
|
87
|
+
|
|
88
|
+
Load an exported artifact from the Hugging Face Hub or a local directory with
|
|
89
|
+
`Mamba2Regressor.from_pretrained`.
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from stateMINT.model import Mamba2Regressor
|
|
93
|
+
|
|
94
|
+
artifact = Mamba2Regressor.from_pretrained(
|
|
95
|
+
"dide-ic/stateMINT",
|
|
96
|
+
predictor="prevalence",
|
|
97
|
+
revision="v1.0.0",
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
static_covars = [{
|
|
101
|
+
"eir": 50.0,
|
|
102
|
+
"dn0_use": 0.3,
|
|
103
|
+
"dn0_future": 0.4,
|
|
104
|
+
"Q0": 0.8,
|
|
105
|
+
"phi_bednets": 0.7,
|
|
106
|
+
"seasonal": 1.0,
|
|
107
|
+
"routine": 0.5,
|
|
108
|
+
"itn_use": 0.2,
|
|
109
|
+
"irs_use": 0.1,
|
|
110
|
+
"itn_future": 0.3,
|
|
111
|
+
"irs_future": 0.2,
|
|
112
|
+
"lsm": 0.0,
|
|
113
|
+
}]
|
|
114
|
+
|
|
115
|
+
predicted_prevalence = artifact.predict(static_covars)
|
|
116
|
+
|
|
117
|
+
print(predicted_prevalence.shape) # (batch, timesteps)
|
|
118
|
+
print(predicted_prevalence[0]) # first trajectory
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
For cases, load the cases artifact and use the same input format:
|
|
122
|
+
|
|
123
|
+
```python
|
|
124
|
+
artifact = Mamba2Regressor.from_pretrained(
|
|
125
|
+
"dide-ic/stateMINT",
|
|
126
|
+
predictor="cases",
|
|
127
|
+
revision="v1.0.0",
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
predicted_cases = artifact.predict(static_covars)
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
By default, predictions are returned on the original target scale: prevalence as
|
|
134
|
+
probabilities and cases on the scale used by the training data. Pass
|
|
135
|
+
`transformed=True` to return model-space outputs (logit-transformed prevalence or `log1p`-transformed cases).
|
|
136
|
+
|
|
137
|
+
```python
|
|
138
|
+
raw_model_space = artifact.predict(static_covars, transformed=True)
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
For local artifacts, pass the target artifact directory:
|
|
142
|
+
|
|
143
|
+
```python
|
|
144
|
+
artifact = Mamba2Regressor.from_pretrained(
|
|
145
|
+
"artifacts/prevalence",
|
|
146
|
+
predictor="prevalence",
|
|
147
|
+
)
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
## Static Covariates
|
|
151
|
+
|
|
152
|
+
Inference inputs need one dictionary per scenario with these static covariates:
|
|
153
|
+
|
|
154
|
+
```text
|
|
155
|
+
eir
|
|
156
|
+
dn0_use
|
|
157
|
+
dn0_future
|
|
158
|
+
Q0
|
|
159
|
+
phi_bednets
|
|
160
|
+
seasonal
|
|
161
|
+
routine
|
|
162
|
+
itn_use
|
|
163
|
+
irs_use
|
|
164
|
+
itn_future
|
|
165
|
+
irs_future
|
|
166
|
+
lsm
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
Artifacts include the fitted static scaler, timestep grid, intervention day,
|
|
170
|
+
target transform, and other preprocessing metadata needed for inference.
|
|
171
|
+
|
|
172
|
+
## Training Workflow
|
|
173
|
+
|
|
174
|
+
Typical workflow:
|
|
175
|
+
|
|
176
|
+
1. Fetch and aggregate simulation data from DuckDB.
|
|
177
|
+
2. Train a target-specific model (optionally tuning hyperparameters with W&B sweeps).
|
|
178
|
+
3. Evaluate or visualize test-set predictions.
|
|
179
|
+
4. Export the checkpoint into a portable inference artifact.
|
|
180
|
+
5. Upload the artifact to the Hugging Face Hub, if needed.
|
|
181
|
+
|
|
182
|
+
### 1. Fetch Filtered Data
|
|
183
|
+
|
|
184
|
+
`stateMINT.filter_raw_data` reads raw DuckDB simulation rows, filters burn-in,
|
|
185
|
+
aggregates fixed windows, and writes `filtered_data_<predictor>.parquet`.
|
|
186
|
+
|
|
187
|
+
```bash
|
|
188
|
+
uv run python -m stateMINT.filter_raw_data \
|
|
189
|
+
--db-path /path/to/simulations.duckdb \
|
|
190
|
+
--table-name simulation_results \
|
|
191
|
+
--predictor prevalence \
|
|
192
|
+
--window-size 14 \
|
|
193
|
+
--output-folder data
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
Useful options:
|
|
197
|
+
|
|
198
|
+
- `--predictor prevalence` or `--predictor cases`
|
|
199
|
+
- `--param-limit N` to keep only the first `N` parameter indices
|
|
200
|
+
- `--sim-limit N` to sample up to `N` simulations per parameter
|
|
201
|
+
|
|
202
|
+
The raw table should include identifiers (`parameter_index`, `simulation_index`,
|
|
203
|
+
`global_index`), daily timesteps, the static covariates above, and output
|
|
204
|
+
columns for prevalence or cases.
|
|
205
|
+
|
|
206
|
+
### 2. Train a Model
|
|
207
|
+
|
|
208
|
+
Training uses Hydra; the default target is prevalence.
|
|
209
|
+
|
|
210
|
+
```bash
|
|
211
|
+
uv run python -m stateMINT.train
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
Train the cases model:
|
|
215
|
+
|
|
216
|
+
```bash
|
|
217
|
+
uv run python -m stateMINT.train target=cases
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
Common overrides:
|
|
221
|
+
|
|
222
|
+
```bash
|
|
223
|
+
uv run python -m stateMINT.train \
|
|
224
|
+
target=prevalence \
|
|
225
|
+
data_file=data/filtered_data_prevalence.parquet \
|
|
226
|
+
output_dir=train_outputs/prevalence \
|
|
227
|
+
use_wandb=false
|
|
228
|
+
```
|
|
229
|
+
|
|
230
|
+
Training writes checkpoints under `checkpoint_dir`, saves
|
|
231
|
+
`static_scaler.pkl` in `output_dir`, and can reuse a saved split assignment file (`split_file`, enabled with `use_existing_split=true`) for
|
|
232
|
+
consistent train/validation/test splits.
|
|
233
|
+
|
|
234
|
+
### 3. W&B Sweeps
|
|
235
|
+
|
|
236
|
+
Sweep definitions live in `stateMINT/conf/sweeps`. Create a sweep, then run one
|
|
237
|
+
or more agents with the sweep ID returned by W&B:
|
|
238
|
+
|
|
239
|
+
```bash
|
|
240
|
+
uv run wandb sweep stateMINT/conf/sweeps/prevalence.yaml
|
|
241
|
+
uv run wandb agent <entity>/stateMINT-sweep/<sweep-id>
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
Use `stateMINT/conf/sweeps/cases.yaml` for the cases target. Sweep commands set
|
|
245
|
+
`use_wandb=true` and pass Hydra overrides through `${args_no_hyphens}`.
|
|
246
|
+
|
|
247
|
+
### 4. Visualize Predictions
|
|
248
|
+
|
|
249
|
+
Compare predictions with test-set targets. `checkpoint_dir` is required.
|
|
250
|
+
|
|
251
|
+
```bash
|
|
252
|
+
uv run python -m stateMINT.visualise_predictions \
|
|
253
|
+
target=prevalence \
|
|
254
|
+
checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
|
|
255
|
+
data_file=data/filtered_data_prevalence.parquet
|
|
256
|
+
```
|
|
257
|
+
|
|
258
|
+
The default output path is `viz_outputs/<predictor>/preds-vs-targets.pdf`.
|
|
259
|
+
|
|
260
|
+
### 5. Export an Artifact
|
|
261
|
+
|
|
262
|
+
Export converts a trained Orbax checkpoint and preprocessing metadata into a
|
|
263
|
+
self-contained artifact.
|
|
264
|
+
|
|
265
|
+
```bash
|
|
266
|
+
uv run python -m stateMINT.model_export \
|
|
267
|
+
predictor=prevalence \
|
|
268
|
+
checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
|
|
269
|
+
scaler_file=train_outputs/prevalence/static_scaler.pkl \
|
|
270
|
+
artifact_dir=artifacts/prevalence
|
|
271
|
+
```
|
|
272
|
+
|
|
273
|
+
Export config architecture values must match the checkpoint, including
|
|
274
|
+
`d_model`, `d_state`, `n_layers`, `dropout`, and related Mamba2 settings.
|
|
275
|
+
|
|
276
|
+
An exported artifact contains:
|
|
277
|
+
|
|
278
|
+
```text
|
|
279
|
+
artifact_dir/
|
|
280
|
+
|-- checkpoint/
|
|
281
|
+
|-- model_config.json
|
|
282
|
+
`-- preprocessing_config.json
|
|
283
|
+
```
|
|
284
|
+
|
|
285
|
+
`model_config.json` stores architecture settings; `preprocessing_config.json`
|
|
286
|
+
stores feature order, target transform, intervention timing, timestep
|
|
287
|
+
construction, and static scaler parameters.
|
|
288
|
+
|
|
289
|
+
### 6. Upload to Hugging Face
|
|
290
|
+
|
|
291
|
+
Authenticate first:
|
|
292
|
+
|
|
293
|
+
```bash
|
|
294
|
+
hf auth login
|
|
295
|
+
```
|
|
296
|
+
|
|
297
|
+
Upload an artifact:
|
|
298
|
+
|
|
299
|
+
```bash
|
|
300
|
+
hf upload dide-ic/stateMINT artifacts/prevalence prevalence/ \
|
|
301
|
+
--commit-message "Add prevalence model artifact"
|
|
302
|
+
```
|
|
303
|
+
|
|
304
|
+
Create a release tag:
|
|
305
|
+
|
|
306
|
+
```bash
|
|
307
|
+
hf repos tag create dide-ic/stateMINT v1.0.0 \
|
|
308
|
+
--revision main \
|
|
309
|
+
--message "Release v1.0.0"
|
|
310
|
+
```
|
|
311
|
+
|
|
312
|
+
## Configuration
|
|
313
|
+
|
|
314
|
+
Main Hydra configs in `stateMINT/conf`:
|
|
315
|
+
|
|
316
|
+
- `train_config.yaml` for training.
|
|
317
|
+
- `viz_config.yaml` for prediction visualizations.
|
|
318
|
+
- `export_config.yaml` for artifact export.
|
|
319
|
+
- `target/prevalence.yaml` and `target/cases.yaml` for target-specific defaults.
|
|
320
|
+
- `sweeps/*.yaml` for Weights & Biases sweep definitions.
|
|
321
|
+
|
|
322
|
+
Select a target with `target=prevalence` or `target=cases`; override config
|
|
323
|
+
values from the command line with Hydra syntax.
|
|
324
|
+
|
|
325
|
+
## Development
|
|
326
|
+
|
|
327
|
+
Run the test suite:
|
|
328
|
+
|
|
329
|
+
```bash
|
|
330
|
+
uv run pytest tests/
|
|
331
|
+
```
|
|
332
|
+
|
|
333
|
+
Skip slow or local-only tests:
|
|
334
|
+
|
|
335
|
+
```bash
|
|
336
|
+
uv run pytest tests/ -m "not slow"
|
|
337
|
+
uv run pytest tests/ -m "not local"
|
|
338
|
+
uv run pytest tests/ -m "not slow and not local" # skip both
|
|
339
|
+
```
|
|
340
|
+
|
|
341
|
+
Run linting and formatting:
|
|
342
|
+
|
|
343
|
+
```bash
|
|
344
|
+
uv run ruff check
|
|
345
|
+
uv run ruff format
|
|
346
|
+
```
|
|
347
|
+
|
|
348
|
+
## Repository Layout
|
|
349
|
+
|
|
350
|
+
```text
|
|
351
|
+
stateMINT/
|
|
352
|
+
|-- common/ # shared dataclasses, transforms, and model helpers
|
|
353
|
+
|-- conf/ # Hydra configs for training, export, viz, and sweeps
|
|
354
|
+
|-- data/ # DuckDB fetch, preprocessing, features, and loaders
|
|
355
|
+
|-- eval/ # metrics and prediction/target visualization helpers
|
|
356
|
+
|-- model/ # Mamba2 regressor and artifact loading utilities
|
|
357
|
+
|-- training/ # optimizer, train/eval steps, loss, and checkpointing
|
|
358
|
+
|-- filter_raw_data.py # CLI for building filtered parquet datasets
|
|
359
|
+
|-- train.py # Hydra training entry point
|
|
360
|
+
|-- visualise_predictions.py # Hydra prediction visualization entry point
|
|
361
|
+
`-- model_export.py # Hydra artifact export entry point
|
|
362
|
+
|
|
363
|
+
tests/ # unit tests
|
|
364
|
+
artifacts/ # exported model artifact examples/metadata
|
|
365
|
+
viz_outputs/ # generated prediction visualization outputs
|
|
366
|
+
```
|
|
367
|
+
|
|
368
|
+
## Contributing
|
|
369
|
+
|
|
370
|
+
See [CONTRIBUTING.md](https://github.com/mrc-ide/stateMINT/blob/HEAD/CONTRIBUTING.md) for contribution guidelines and the
|
|
371
|
+
development workflow.
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
stateMINT/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
stateMINT/filter_raw_data.py,sha256=KjOC7GeZ19lOl4i4PTqkC17cCCYdMPrrizfQftlwKLM,1623
|
|
3
|
+
stateMINT/model_export.py,sha256=JynbP6gTHgV7ZrSDpWWf4l_lZbHl89zd5R_ZTaTupL4,2903
|
|
4
|
+
stateMINT/train.py,sha256=eI_f2gjCXe-PYXZd1Gm_cv404GezX793Cbb5g0n859E,5696
|
|
5
|
+
stateMINT/visualise_predictions.py,sha256=4MrwoutGevqHilSEAWToCchuxHcykZuPkFrX5d-bBlo,1853
|
|
6
|
+
stateMINT/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
+
stateMINT/common/dataclasses.py,sha256=YYVby4nOZQehS5qiYngVkdomDL8BUkZQ6xoTa5gAr98,260
|
|
8
|
+
stateMINT/common/utils.py,sha256=oLmz5HWAb5QwXP26_4WmeJ_AefBCW6wAT5pzQ7n0X7w,1748
|
|
9
|
+
stateMINT/conf/export_config.yaml,sha256=aPmSsxJMJvZiigfjrpwavVNidl2tR9yJVJTcdMtoJmM,426
|
|
10
|
+
stateMINT/conf/train_config.yaml,sha256=lszJtBir7GEvVKjlfFUeBzgtHnKnAUiJ5V7DFyuZgQQ,911
|
|
11
|
+
stateMINT/conf/viz_config.yaml,sha256=q74YLMK7QWoBm_uXJTWZW2GcAx5AaMTjI_tLMD_YyrI,517
|
|
12
|
+
stateMINT/conf/sweeps/cases.yaml,sha256=P2w1a3VDwP_6h7bBVRkt7DG9rUe4EDl--cfSeAbkhe0,699
|
|
13
|
+
stateMINT/conf/sweeps/prevalence.yaml,sha256=cQNg8YqldLuzYwwUERjjQvZaK3EMfiOpoWDGY92k-Ig,711
|
|
14
|
+
stateMINT/conf/target/cases.yaml,sha256=9LkY6jjAoC2tD1QObJlghTO0t1xl95Q3x9ypyTfgVh4,199
|
|
15
|
+
stateMINT/conf/target/prevalence.yaml,sha256=nUKOpVnqutYpCuEchz5rhbVr_bZ6tpFZK_pPbBqBIao,198
|
|
16
|
+
stateMINT/data/__init__.py,sha256=FXunf0wXzF1jDJUHo2AogDqvzmbryc7bLo9Mg6VjLmQ,403
|
|
17
|
+
stateMINT/data/dataset.py,sha256=idTvwz0i3ualujXp-Bh8F67ROASNT1vr5ygHqf_RbDQ,1689
|
|
18
|
+
stateMINT/data/features.py,sha256=W64e3PE4ZRM7ufLyPB4sJGWDnSpmpJZxkMNgWqfuQxc,2714
|
|
19
|
+
stateMINT/data/fetch.py,sha256=MZelwGuOkGwx_5b0rKWryIoUJPNF7GgxoissA2Ebyvk,6110
|
|
20
|
+
stateMINT/data/preprocessing.py,sha256=KBbH-lIMUgixac3YjfENBr5YagHvRy-uA7pzFJb3OFo,17428
|
|
21
|
+
stateMINT/eval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
+
stateMINT/eval/metrics.py,sha256=J79ANLzGGb3s_ARFExF4AsR0YRqPULQsKwVL0MmZKRk,4335
|
|
23
|
+
stateMINT/eval/viz_preds_truth.py,sha256=CTDT58MkRZgxVRaZDUCTofCW6VOWZQtVPN07Od-bSzo,6104
|
|
24
|
+
stateMINT/model/__init__.py,sha256=xL_8A4TghG_bYEYuU8qNl_ov5yFtB8WcvYzkPp239vY,67
|
|
25
|
+
stateMINT/model/hub.py,sha256=OHveRm0qvXpuAd6f-4QOg7MTkO-7tJ3BGfC6npWSmEg,4802
|
|
26
|
+
stateMINT/model/mamba2.py,sha256=IIo3VDuSq1UbcBnDWpuOCB-s6wCJMRhL7FOe_pMIAek,5581
|
|
27
|
+
stateMINT/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
28
|
+
stateMINT/training/checkpoint.py,sha256=jKQmeP28aCY_C7koHL3P8OmQHsmDhVA0QjmL77ryO6A,4372
|
|
29
|
+
stateMINT/training/loss.py,sha256=ZmP62C8i-xRB_pNMTGl4-LkT5esrtgwOmpJozbbL64Q,746
|
|
30
|
+
stateMINT/training/train_step.py,sha256=CMYt2kcmpKXI5qgsk2oN52AKgr-emyPwjTj9OU70sPA,4469
|
|
31
|
+
mintstate-0.1.0.dist-info/METADATA,sha256=Z_ZP7Okvm8OFnrYKB2iBlb1SAQVyNzm9mlQN_YQVmEE,10137
|
|
32
|
+
mintstate-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
|
|
33
|
+
mintstate-0.1.0.dist-info/RECORD,,
|
stateMINT/__init__.py
ADDED
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from typing import Literal, Protocol
|
|
2
|
+
|
|
3
|
+
from flax import nnx
|
|
4
|
+
from omegaconf import DictConfig
|
|
5
|
+
|
|
6
|
+
# The two regression targets the package knows how to train and serve.
Predictor = Literal["prevalence", "cases"]


class ModelFactory(Protocol):
    """
    Structural interface for model classes that can build themselves from a config.

    Any class exposing a compatible ``from_cfg`` classmethod satisfies this
    protocol; no inheritance is required.
    """

    @classmethod
    def from_cfg(cls, cfg: DictConfig, input_size: int) -> nnx.Module:
        """
        Construct a model from an OmegaConf/Hydra config and an input size.

        Args:
            cfg: Configuration holding the model hyperparameters.
            input_size: Input size handed to the model constructor
                (presumably the per-timestep feature count; confirm with
                implementers).

        Returns:
            The constructed Flax NNX module.
        """
        ...
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import jax
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
from .dataclasses import Predictor
|
|
5
|
+
from jaxtyping import Array
|
|
6
|
+
from flax import nnx
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def transform_targets_np(y: np.ndarray, predictor: Predictor, eps: float = 1e-5) -> np.ndarray:
|
|
10
|
+
"""
|
|
11
|
+
Apply train-time transform to targets.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
y: Target values.
|
|
15
|
+
predictor: Target type.
|
|
16
|
+
eps: Prevalence clipping epsilon.
|
|
17
|
+
|
|
18
|
+
Returns:
|
|
19
|
+
Transformed target values.
|
|
20
|
+
"""
|
|
21
|
+
if predictor == "prevalence":
|
|
22
|
+
y = np.clip(y, eps, 1.0 - eps)
|
|
23
|
+
return np.log(y / (1.0 - y)) # logit transform
|
|
24
|
+
else:
|
|
25
|
+
return np.log1p(np.maximum(y, 0.0)) # log1p transform for counts/rates
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def inverse_transform_np(y: np.ndarray, predictor: Predictor) -> np.ndarray:
    """
    Invert the train-time target transform for metrics/plots.

    Prevalence goes through a sigmoid, the inverse of the logit applied by
    ``transform_targets_np``. Any other predictor goes through ``expm1``, the
    inverse of ``log1p``.

    The sigmoid uses a numerically stable form. The naive
    ``1 / (1 + exp(-y))`` overflows inside ``exp`` for large negative logits.
    That emits ``RuntimeWarning``s, or raises ``FloatingPointError`` under
    ``np.errstate(over="raise")``. This also matches the stable
    ``jax.nn.sigmoid`` used by ``inverse_transform_jax``.

    Args:
        y: Transformed target values.
        predictor: Target type.

    Returns:
        Values in the original target scale.
    """
    if predictor == "prevalence":
        y = np.asarray(y)
        # exp(-|y|) lies in [0, 1], so neither branch below can overflow.
        e = np.exp(-np.abs(y))
        out = np.where(y >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
        # Preserve scalar-in / scalar-out behaviour of the naive expression.
        return out[()] if out.ndim == 0 else out
    else:
        return np.expm1(y)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def inverse_transform_jax(y: jax.Array, predictor: Predictor) -> jax.Array:
    """
    JAX counterpart of the NumPy inverse target transform.

    Uses ``jax.nn.sigmoid`` for prevalence (undoing the logit) and
    ``jnp.expm1`` for every other predictor (undoing ``log1p``).

    Args:
        y: Transformed target values.
        predictor: Target type.

    Returns:
        Values in the original target scale.
    """
    invert = jax.nn.sigmoid if predictor == "prevalence" else jnp.expm1
    return invert(y)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@nnx.jit
def forward(model: nnx.Module, x: Array) -> Array:
    """
    Run ``model`` on a batch and drop its trailing singleton output axis.

    The call is compiled with ``nnx.jit``.

    Args:
        model: Model to evaluate.
        x: Input batch.

    Returns:
        Predictions of shape (B, T), squeezed from the model's (B, T, 1) output.
    """
    preds = model(x)
    return preds.squeeze(axis=-1)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
predictor: prevalence
|
|
2
|
+
|
|
3
|
+
# Data
|
|
4
|
+
eps_prevalence: 1.0e-5
|
|
5
|
+
scaler_file: train_outputs/${predictor}/static_scaler.pkl
|
|
6
|
+
window_size: 14
|
|
7
|
+
use_cyclical_time: true
|
|
8
|
+
|
|
9
|
+
# Model - ensure these match the checkpointed model's parameters
|
|
10
|
+
n_layers: 4
|
|
11
|
+
d_conv: 4
|
|
12
|
+
expand: 2
|
|
13
|
+
head_dim: 64
|
|
14
|
+
chunk_size: 256
|
|
15
|
+
output_dim: 1
|
|
16
|
+
d_model: 256
|
|
17
|
+
d_state: 128
|
|
18
|
+
dropout: 0.3
|
|
19
|
+
|
|
20
|
+
# Checkpointing
|
|
21
|
+
checkpoint_dir: ???
|
|
22
|
+
|
|
23
|
+
# General
|
|
24
|
+
seed: 42
|
|
25
|
+
artifact_dir: "artifacts/${predictor}"
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
program: stateMINT/train.py
|
|
2
|
+
project: stateMINT-sweep
|
|
3
|
+
name: sweep-cases
|
|
4
|
+
method: bayes
|
|
5
|
+
metric:
|
|
6
|
+
goal: minimize
|
|
7
|
+
name: val/loss
|
|
8
|
+
|
|
9
|
+
# early_terminate:
|
|
10
|
+
# type: hyperband
|
|
11
|
+
# min_iter: 10
|
|
12
|
+
# eta: 3
|
|
13
|
+
|
|
14
|
+
parameters:
|
|
15
|
+
lr:
|
|
16
|
+
distribution: log_uniform_values
|
|
17
|
+
min: 0.001
|
|
18
|
+
max: 0.01
|
|
19
|
+
dropout:
|
|
20
|
+
distribution: uniform
|
|
21
|
+
min: 0.25
|
|
22
|
+
max: 0.6
|
|
23
|
+
batch_size:
|
|
24
|
+
values: [128]
|
|
25
|
+
|
|
26
|
+
# model-specific
|
|
27
|
+
d_state:
|
|
28
|
+
values: [128]
|
|
29
|
+
d_model:
|
|
30
|
+
values: [256]
|
|
31
|
+
n_layers:
|
|
32
|
+
values: [1, 2, 3, 4]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
command:
|
|
36
|
+
- ${env}
|
|
37
|
+
- python
|
|
38
|
+
- -m
|
|
39
|
+
- stateMINT.train
|
|
40
|
+
- "hydra.output_subdir=null"
|
|
41
|
+
- "hydra.run.dir=."
|
|
42
|
+
- ${args_no_hyphens}
|
|
43
|
+
- target=cases
|
|
44
|
+
- use_wandb=true
|
|
45
|
+
- num_epochs=300
|
|
46
|
+
- patience=100
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
program: stateMINT/train.py
|
|
2
|
+
project: stateMINT-sweep
|
|
3
|
+
name: sweep-prevalence
|
|
4
|
+
method: bayes
|
|
5
|
+
metric:
|
|
6
|
+
goal: minimize
|
|
7
|
+
name: val/loss
|
|
8
|
+
|
|
9
|
+
# early_terminate:
|
|
10
|
+
# type: hyperband
|
|
11
|
+
# min_iter: 40
|
|
12
|
+
# eta: 3
|
|
13
|
+
|
|
14
|
+
parameters:
|
|
15
|
+
lr:
|
|
16
|
+
distribution: log_uniform_values
|
|
17
|
+
min: 0.0001
|
|
18
|
+
max: 0.001
|
|
19
|
+
dropout:
|
|
20
|
+
distribution: uniform
|
|
21
|
+
min: 0.25
|
|
22
|
+
max: 0.6
|
|
23
|
+
batch_size:
|
|
24
|
+
values: [128]
|
|
25
|
+
|
|
26
|
+
# model-specific
|
|
27
|
+
d_state:
|
|
28
|
+
values: [128]
|
|
29
|
+
d_model:
|
|
30
|
+
values: [256]
|
|
31
|
+
n_layers:
|
|
32
|
+
values: [1, 2, 3, 4]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
command:
|
|
36
|
+
- ${env}
|
|
37
|
+
- python
|
|
38
|
+
- -m
|
|
39
|
+
- stateMINT.train
|
|
40
|
+
- "hydra.output_subdir=null"
|
|
41
|
+
- "hydra.run.dir=."
|
|
42
|
+
- ${args_no_hyphens}
|
|
43
|
+
- target=prevalence
|
|
44
|
+
- use_wandb=true
|
|
45
|
+
- num_epochs=300
|
|
46
|
+
- patience=100
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- _self_  # Compose this file first; defaults listed after it (e.g. target) override its values
|
|
3
|
+
- target: prevalence
|
|
4
|
+
|
|
5
|
+
# Data
|
|
6
|
+
data_file: "data/filtered_data_${predictor}.parquet"
|
|
7
|
+
split_file: "data/split_${predictor}.json"
|
|
8
|
+
num_workers: 0
|
|
9
|
+
use_existing_split: false
|
|
10
|
+
eps_prevalence: 1.0e-5
|
|
11
|
+
use_cyclical_time: true
|
|
12
|
+
|
|
13
|
+
# Model
|
|
14
|
+
loss_method: stateMINT.training.loss.weighted_mse
|
|
15
|
+
n_layers: 4
|
|
16
|
+
d_conv: 4
|
|
17
|
+
expand: 2
|
|
18
|
+
head_dim: 64
|
|
19
|
+
chunk_size: 256
|
|
20
|
+
output_dim: 1
|
|
21
|
+
d_model: 256
|
|
22
|
+
d_state: 128
|
|
23
|
+
|
|
24
|
+
# Hyperparameters
|
|
25
|
+
num_epochs: 200
|
|
26
|
+
min_epochs: 100
|
|
27
|
+
patience: 50
|
|
28
|
+
diff_loss_alpha: 0.05
|
|
29
|
+
lr: 1e-3
|
|
30
|
+
batch_size: 128
|
|
31
|
+
dropout: 0.3
|
|
32
|
+
weight_decay: 1e-4
|
|
33
|
+
|
|
34
|
+
# Checkpointing
|
|
35
|
+
checkpoint_dir: "${output_dir}/ckpts-${cur_time}"
|
|
36
|
+
restore_checkpoint: false
|
|
37
|
+
max_checkpoints_to_keep: 1
|
|
38
|
+
|
|
39
|
+
# General
|
|
40
|
+
cur_time: ${now:%Y-%m-%dT%H:%M:%S}
|
|
41
|
+
seed: 42
|
|
42
|
+
use_wandb: false
|
|
43
|
+
output_dir: "train_outputs/${predictor}"
|
|
44
|
+
wandb:
|
|
45
|
+
project: "stateMINT-${predictor}"
|
|
46
|
+
name: "train-${predictor}-${cur_time}"
|