mintstate 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mintstate-0.1.0/.github/workflows/lint.yaml +30 -0
- mintstate-0.1.0/.github/workflows/release.yaml +33 -0
- mintstate-0.1.0/.github/workflows/test.yaml +28 -0
- mintstate-0.1.0/.gitignore +135 -0
- mintstate-0.1.0/.python-version +1 -0
- mintstate-0.1.0/.vscode/launch.json +46 -0
- mintstate-0.1.0/CONTRIBUTING.md +27 -0
- mintstate-0.1.0/PKG-INFO +371 -0
- mintstate-0.1.0/README.md +330 -0
- mintstate-0.1.0/pyproject.toml +65 -0
- mintstate-0.1.0/scenarios_prevalence.pdf +0 -0
- mintstate-0.1.0/stateMINT/__init__.py +0 -0
- mintstate-0.1.0/stateMINT/common/__init__.py +0 -0
- mintstate-0.1.0/stateMINT/common/dataclasses.py +11 -0
- mintstate-0.1.0/stateMINT/common/utils.py +74 -0
- mintstate-0.1.0/stateMINT/conf/export_config.yaml +25 -0
- mintstate-0.1.0/stateMINT/conf/sweeps/cases.yaml +46 -0
- mintstate-0.1.0/stateMINT/conf/sweeps/prevalence.yaml +46 -0
- mintstate-0.1.0/stateMINT/conf/target/cases.yaml +18 -0
- mintstate-0.1.0/stateMINT/conf/target/prevalence.yaml +18 -0
- mintstate-0.1.0/stateMINT/conf/train_config.yaml +46 -0
- mintstate-0.1.0/stateMINT/conf/viz_config.yaml +31 -0
- mintstate-0.1.0/stateMINT/data/__init__.py +21 -0
- mintstate-0.1.0/stateMINT/data/dataset.py +77 -0
- mintstate-0.1.0/stateMINT/data/features.py +104 -0
- mintstate-0.1.0/stateMINT/data/fetch.py +194 -0
- mintstate-0.1.0/stateMINT/data/preprocessing.py +504 -0
- mintstate-0.1.0/stateMINT/eval/__init__.py +0 -0
- mintstate-0.1.0/stateMINT/eval/metrics.py +180 -0
- mintstate-0.1.0/stateMINT/eval/viz_preds_truth.py +151 -0
- mintstate-0.1.0/stateMINT/filter_raw_data.py +45 -0
- mintstate-0.1.0/stateMINT/model/__init__.py +3 -0
- mintstate-0.1.0/stateMINT/model/hub.py +144 -0
- mintstate-0.1.0/stateMINT/model/mamba2.py +171 -0
- mintstate-0.1.0/stateMINT/model_export.py +90 -0
- mintstate-0.1.0/stateMINT/train.py +158 -0
- mintstate-0.1.0/stateMINT/training/__init__.py +0 -0
- mintstate-0.1.0/stateMINT/training/checkpoint.py +147 -0
- mintstate-0.1.0/stateMINT/training/loss.py +35 -0
- mintstate-0.1.0/stateMINT/training/train_step.py +167 -0
- mintstate-0.1.0/stateMINT/visualise_predictions.py +55 -0
- mintstate-0.1.0/tests/__init__.py +0 -0
- mintstate-0.1.0/tests/common/__init__.py +0 -0
- mintstate-0.1.0/tests/common/test_utils.py +49 -0
- mintstate-0.1.0/tests/conftest.py +230 -0
- mintstate-0.1.0/tests/data/__init__.py +0 -0
- mintstate-0.1.0/tests/data/test_dataset.py +21 -0
- mintstate-0.1.0/tests/data/test_features.py +10 -0
- mintstate-0.1.0/tests/data/test_fetch.py +154 -0
- mintstate-0.1.0/tests/data/test_preprocessing.py +310 -0
- mintstate-0.1.0/tests/eval/__init__.py +0 -0
- mintstate-0.1.0/tests/eval/test_metrics.py +34 -0
- mintstate-0.1.0/tests/eval/test_viz_preds_truth.py +166 -0
- mintstate-0.1.0/tests/model/__init__.py +0 -0
- mintstate-0.1.0/tests/model/test_hub.py +196 -0
- mintstate-0.1.0/tests/model/test_mamba2.py +109 -0
- mintstate-0.1.0/tests/smoke_test.py +3 -0
- mintstate-0.1.0/tests/training/__init__.py +0 -0
- mintstate-0.1.0/tests/training/test_checkpoint.py +40 -0
- mintstate-0.1.0/tests/training/test_loss.py +24 -0
- mintstate-0.1.0/tests/training/test_train_step.py +50 -0
- mintstate-0.1.0/uv.lock +2084 -0
- mintstate-0.1.0/viz_outputs/cases/preds-vs-targets.pdf +0 -0
- mintstate-0.1.0/viz_outputs/cases/static_scaler.pkl +0 -0
- mintstate-0.1.0/viz_outputs/prevalence/preds-vs-targets.pdf +0 -0
- mintstate-0.1.0/viz_outputs/prevalence/static_scaler.pkl +0 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
name: Lint
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
pull_request:
|
|
8
|
+
branches:
|
|
9
|
+
- "*"
|
|
10
|
+
|
|
11
|
+
jobs:
|
|
12
|
+
lint-and-format:
|
|
13
|
+
name: Run lint and format checks
|
|
14
|
+
runs-on: ubuntu-latest
|
|
15
|
+
steps:
|
|
16
|
+
- name: Checkout repository
|
|
17
|
+
uses: actions/checkout@v6
|
|
18
|
+
- name: Install uv
|
|
19
|
+
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b
|
|
20
|
+
with:
|
|
21
|
+
enable-cache: true
|
|
22
|
+
version: "0.11.18"
|
|
23
|
+
- name: Set up Python
|
|
24
|
+
run: uv python install
|
|
25
|
+
- name: Install the project
|
|
26
|
+
run: uv sync --locked --all-extras --dev
|
|
27
|
+
- name: Run lint checks
|
|
28
|
+
run: uv run ruff check --output-format=github
|
|
29
|
+
- name: Run format checks
|
|
30
|
+
run: uv run ruff format --check --diff
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
name: Publish Release to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags:
|
|
6
|
+
- 'v*.*.*'
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
run:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
environment:
|
|
12
|
+
name: pypi
|
|
13
|
+
permissions:
|
|
14
|
+
id-token: write
|
|
15
|
+
contents: read
|
|
16
|
+
steps:
|
|
17
|
+
- name: Checkout code
|
|
18
|
+
uses: actions/checkout@v6
|
|
19
|
+
- name: Install uv
|
|
20
|
+
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b
|
|
21
|
+
with:
|
|
22
|
+
enable-cache: true
|
|
23
|
+
version: "0.11.18"
|
|
24
|
+
- name: Set up Python
|
|
25
|
+
run: uv python install
|
|
26
|
+
- name: Build
|
|
27
|
+
run: uv build
|
|
28
|
+
- name: Smoke test (wheel)
|
|
29
|
+
run: uv run --isolated --no-project --with dist/*.whl tests/smoke_test.py
|
|
30
|
+
- name: Smoke test (source distribution)
|
|
31
|
+
run: uv run --isolated --no-project --with dist/*.tar.gz tests/smoke_test.py
|
|
32
|
+
- name: Publish
|
|
33
|
+
run: uv publish
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
name: Test
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches:
|
|
6
|
+
- main
|
|
7
|
+
pull_request:
|
|
8
|
+
branches:
|
|
9
|
+
- "*"
|
|
10
|
+
|
|
11
|
+
jobs:
|
|
12
|
+
test:
|
|
13
|
+
name: Run tests
|
|
14
|
+
runs-on: ubuntu-latest
|
|
15
|
+
steps:
|
|
16
|
+
- name: Checkout repository
|
|
17
|
+
uses: actions/checkout@v6
|
|
18
|
+
- name: Install uv
|
|
19
|
+
uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b
|
|
20
|
+
with:
|
|
21
|
+
enable-cache: true
|
|
22
|
+
version: "0.11.18"
|
|
23
|
+
- name: Set up Python
|
|
24
|
+
run: uv python install
|
|
25
|
+
- name: Install the project
|
|
26
|
+
run: uv sync --locked --all-extras --dev
|
|
27
|
+
- name: Run tests
|
|
28
|
+
run: uv run pytest -m "not local and not slow"
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[codz]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
artifacts/
|
|
11
|
+
.Python
|
|
12
|
+
build/
|
|
13
|
+
develop-eggs/
|
|
14
|
+
dist/
|
|
15
|
+
downloads/
|
|
16
|
+
eggs/
|
|
17
|
+
.eggs/
|
|
18
|
+
lib/
|
|
19
|
+
lib64/
|
|
20
|
+
parts/
|
|
21
|
+
sdist/
|
|
22
|
+
var/
|
|
23
|
+
wheels/
|
|
24
|
+
share/python-wheels/
|
|
25
|
+
*.egg-info/
|
|
26
|
+
.installed.cfg
|
|
27
|
+
*.egg
|
|
28
|
+
MANIFEST
|
|
29
|
+
|
|
30
|
+
# PyInstaller
|
|
31
|
+
# Usually these files are written by a python script from a template
|
|
32
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
33
|
+
*.manifest
|
|
34
|
+
*.spec
|
|
35
|
+
|
|
36
|
+
# Installer logs
|
|
37
|
+
pip-log.txt
|
|
38
|
+
pip-delete-this-directory.txt
|
|
39
|
+
|
|
40
|
+
# Unit test / coverage reports
|
|
41
|
+
htmlcov/
|
|
42
|
+
.tox/
|
|
43
|
+
.nox/
|
|
44
|
+
.coverage
|
|
45
|
+
.coverage.*
|
|
46
|
+
.cache
|
|
47
|
+
nosetests.xml
|
|
48
|
+
coverage.xml
|
|
49
|
+
*.cover
|
|
50
|
+
*.py.cover
|
|
51
|
+
.hypothesis/
|
|
52
|
+
.pytest_cache/
|
|
53
|
+
cover/
|
|
54
|
+
|
|
55
|
+
# Translations
|
|
56
|
+
*.mo
|
|
57
|
+
*.pot
|
|
58
|
+
|
|
59
|
+
# Django stuff:
|
|
60
|
+
*.log
|
|
61
|
+
local_settings.py
|
|
62
|
+
db.sqlite3
|
|
63
|
+
db.sqlite3-journal
|
|
64
|
+
|
|
65
|
+
# Flask stuff:
|
|
66
|
+
instance/
|
|
67
|
+
.webassets-cache
|
|
68
|
+
|
|
69
|
+
# Scrapy stuff:
|
|
70
|
+
.scrapy
|
|
71
|
+
|
|
72
|
+
# Sphinx documentation
|
|
73
|
+
docs/_build/
|
|
74
|
+
|
|
75
|
+
# Jupyter Notebook
|
|
76
|
+
.ipynb_checkpoints
|
|
77
|
+
|
|
78
|
+
# IPython
|
|
79
|
+
profile_default/
|
|
80
|
+
ipython_config.py
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Environments
|
|
84
|
+
.env
|
|
85
|
+
.envrc
|
|
86
|
+
.venv
|
|
87
|
+
env/
|
|
88
|
+
venv/
|
|
89
|
+
ENV/
|
|
90
|
+
env.bak/
|
|
91
|
+
venv.bak/
|
|
92
|
+
|
|
93
|
+
# Spyder project settings
|
|
94
|
+
.spyderproject
|
|
95
|
+
.spyproject
|
|
96
|
+
|
|
97
|
+
# Rope project settings
|
|
98
|
+
.ropeproject
|
|
99
|
+
|
|
100
|
+
# mkdocs documentation
|
|
101
|
+
/site
|
|
102
|
+
|
|
103
|
+
# mypy
|
|
104
|
+
.mypy_cache/
|
|
105
|
+
.dmypy.json
|
|
106
|
+
dmypy.json
|
|
107
|
+
|
|
108
|
+
# Pyre type checker
|
|
109
|
+
.pyre/
|
|
110
|
+
|
|
111
|
+
# pytype static type analyzer
|
|
112
|
+
.pytype/
|
|
113
|
+
|
|
114
|
+
# Cython debug symbols
|
|
115
|
+
cython_debug/
|
|
116
|
+
|
|
117
|
+
# Ruff stuff:
|
|
118
|
+
.ruff_cache/
|
|
119
|
+
|
|
120
|
+
# PyPI configuration file
|
|
121
|
+
.pypirc
|
|
122
|
+
|
|
123
|
+
# testing scripts
|
|
124
|
+
*.ipynb
|
|
125
|
+
x.py
|
|
126
|
+
|
|
127
|
+
# hydra logs
|
|
128
|
+
outputs
|
|
129
|
+
|
|
130
|
+
# misc
|
|
131
|
+
inputs
|
|
132
|
+
/data/
|
|
133
|
+
ref.py
|
|
134
|
+
train_outputs/
|
|
135
|
+
wandb/
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.12
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
{
|
|
2
|
+
// Use IntelliSense to learn about possible attributes.
|
|
3
|
+
// Hover to view descriptions of existing attributes.
|
|
4
|
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
|
5
|
+
"version": "0.2.0",
|
|
6
|
+
"configurations": [
|
|
7
|
+
{
|
|
8
|
+
"name": "Debug run train",
|
|
9
|
+
"type": "debugpy",
|
|
10
|
+
"request": "launch",
|
|
11
|
+
"python": "/home/anmol/code/stateMINT/.venv/bin/python",
|
|
12
|
+
"module": "stateMINT.viz",
|
|
13
|
+
// "args": [
|
|
14
|
+
// "num_epochs=10",
|
|
15
|
+
// ],
|
|
16
|
+
"env": {
|
|
17
|
+
// "JAX_DISABLE_JIT": "true",
|
|
18
|
+
// "UV_MANAGED_PYTHON": "true"
|
|
19
|
+
"CUDA_VISIBLE_DEVICES": "1"
|
|
20
|
+
},
|
|
21
|
+
"console": "integratedTerminal",
|
|
22
|
+
"cwd": "${workspaceFolder}",
|
|
23
|
+
"justMyCode": false
|
|
24
|
+
},
|
|
25
|
+
{
|
|
26
|
+
"name": "debug test function",
|
|
27
|
+
"type": "debugpy",
|
|
28
|
+
"request": "launch",
|
|
29
|
+
"python": "/home/anmol/code/stateMINT/.venv/bin/python",
|
|
30
|
+
"module": "pytest",
|
|
31
|
+
"args": [
|
|
32
|
+
"tests/model/test_mamba2.py",
|
|
33
|
+
"-k",
|
|
34
|
+
"test_prevalence_from_pretrained_from_hub"
|
|
35
|
+
],
|
|
36
|
+
"env": {
|
|
37
|
+
// "JAX_DISABLE_JIT": "true",
|
|
38
|
+
// "UV_MANAGED_PYTHON": "true"
|
|
39
|
+
// "CUDA_VISIBLE_DEVICES": "1"
|
|
40
|
+
},
|
|
41
|
+
"console": "integratedTerminal",
|
|
42
|
+
"cwd": "${workspaceFolder}",
|
|
43
|
+
"justMyCode": false
|
|
44
|
+
},
|
|
45
|
+
]
|
|
46
|
+
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# Contributing to stateMINT
|
|
2
|
+
|
|
3
|
+
Thank you for your interest in contributing to stateMINT! This document
|
|
4
|
+
outlines the development workflow and contribution standards.
|
|
5
|
+
|
|
6
|
+
## Contribution Standards
|
|
7
|
+
|
|
8
|
+
All contributions must meet the following requirements:
|
|
9
|
+
|
|
10
|
+
- **Tests**: All tests must pass (`uv run pytest tests/`)
|
|
11
|
+
- **Formatting and Linting**: Code must pass formatting and linting checks (`uv run ruff format && uv run ruff check`)
|
|
12
|
+
|
|
13
|
+
## Standard Contribution Process
|
|
14
|
+
|
|
15
|
+
1. Create a feature branch
|
|
16
|
+
1. Make your changes
|
|
17
|
+
1. Ensure tests pass and formatting/linting succeeds
|
|
18
|
+
1. Submit a pull request
|
|
19
|
+
|
|
20
|
+
## Development Setup
|
|
21
|
+
|
|
22
|
+
See README.md for installation and development environment setup
|
|
23
|
+
instructions.
|
|
24
|
+
|
|
25
|
+
## Questions?
|
|
26
|
+
|
|
27
|
+
Open an issue for questions about contributing or development workflow.
|
mintstate-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mintstate
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: StateMINT is a state space based neural network emulator for malariasimulation.
|
|
5
|
+
Project-URL: Homepage, https://github.com/mrc-ide/stateMINT
|
|
6
|
+
Project-URL: Repository, https://github.com/mrc-ide/stateMINT
|
|
7
|
+
Project-URL: Weights, https://huggingface.co/dide-ic/stateMINT
|
|
8
|
+
Author-email: Anmol Thapar <mr.anmolthapar@gmail.com>
|
|
9
|
+
Requires-Python: >=3.12
|
|
10
|
+
Requires-Dist: etils>=1.0
|
|
11
|
+
Requires-Dist: flax>=0.12.7
|
|
12
|
+
Requires-Dist: huggingface-hub>=1.19.0
|
|
13
|
+
Requires-Dist: jax>=0.10.1
|
|
14
|
+
Requires-Dist: jaxtyping>=0.3.10
|
|
15
|
+
Requires-Dist: mamba2-jax>=1.0.1
|
|
16
|
+
Requires-Dist: numpy>=2.4.6
|
|
17
|
+
Requires-Dist: omegaconf>=2.3
|
|
18
|
+
Requires-Dist: orbax-checkpoint>=0.12.0
|
|
19
|
+
Requires-Dist: pandas>=3.0.3
|
|
20
|
+
Provides-Extra: all
|
|
21
|
+
Requires-Dist: duckdb>=1.5.3; extra == 'all'
|
|
22
|
+
Requires-Dist: grain>=0.2.16; extra == 'all'
|
|
23
|
+
Requires-Dist: hydra-core>=1.3.2; extra == 'all'
|
|
24
|
+
Requires-Dist: jax[cuda12]>=0.10.1; extra == 'all'
|
|
25
|
+
Requires-Dist: matplotlib>=3.10.9; extra == 'all'
|
|
26
|
+
Requires-Dist: optax>=0.2.8; extra == 'all'
|
|
27
|
+
Requires-Dist: tqdm>=4.67.3; extra == 'all'
|
|
28
|
+
Requires-Dist: wandb>=0.27.0; extra == 'all'
|
|
29
|
+
Provides-Extra: gpu
|
|
30
|
+
Requires-Dist: jax[cuda12]>=0.10.1; extra == 'gpu'
|
|
31
|
+
Provides-Extra: plot
|
|
32
|
+
Requires-Dist: matplotlib>=3.10.9; extra == 'plot'
|
|
33
|
+
Provides-Extra: train
|
|
34
|
+
Requires-Dist: duckdb>=1.5.3; extra == 'train'
|
|
35
|
+
Requires-Dist: grain>=0.2.16; extra == 'train'
|
|
36
|
+
Requires-Dist: hydra-core>=1.3.2; extra == 'train'
|
|
37
|
+
Requires-Dist: optax>=0.2.8; extra == 'train'
|
|
38
|
+
Requires-Dist: tqdm>=4.67.3; extra == 'train'
|
|
39
|
+
Requires-Dist: wandb>=0.27.0; extra == 'train'
|
|
40
|
+
Description-Content-Type: text/markdown
|
|
41
|
+
|
|
42
|
+
# StateMINT
|
|
43
|
+
|
|
44
|
+
StateMINT is a JAX/Flax neural emulator for
|
|
45
|
+
[`malariasimulation`](https://github.com/mrc-ide/malariasimulation) outputs. It
|
|
46
|
+
uses a Mamba2 state-space sequence model to predict malaria trajectories from
|
|
47
|
+
static scenario covariates and intervention timing features. It supersedes the
|
|
48
|
+
earlier
|
|
49
|
+
[`MINTelligence`](https://github.com/CosmoNaught/MINTelligence) RNN emulator.
|
|
50
|
+
|
|
51
|
+
## What StateMINT Provides
|
|
52
|
+
|
|
53
|
+
- Mamba2-based sequence regressors for malaria prevalence and case-count
|
|
54
|
+
trajectories.
|
|
55
|
+
- Data extraction utilities for aggregating raw `malariasimulation` DuckDB
|
|
56
|
+
outputs into model-ready parquet files.
|
|
57
|
+
- Preprocessing with target transforms, covariate scaling, and
|
|
58
|
+
intervention-aware feature construction.
|
|
59
|
+
- Training, evaluation, visualization, checkpointing, and export workflows.
|
|
60
|
+
- Hugging Face Hub loading utilities for exported inference artifacts.
|
|
61
|
+
|
|
62
|
+
## Installation
|
|
63
|
+
|
|
64
|
+
StateMINT requires Python 3.12 or newer and uses
|
|
65
|
+
[`uv`](https://github.com/astral-sh/uv) for environment and dependency management.
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
git clone https://github.com/mrc-ide/stateMINT.git
|
|
69
|
+
cd stateMINT
|
|
70
|
+
uv sync
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
For development dependencies and optional extras:
|
|
74
|
+
|
|
75
|
+
```bash
|
|
76
|
+
uv sync --all-extras --dev
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
Or install extras individually:
|
|
80
|
+
|
|
81
|
+
```bash
|
|
82
|
+
uv sync --extra plot
|
|
83
|
+
uv sync --extra gpu
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
## Quick Start: Inference
|
|
87
|
+
|
|
88
|
+
Load an exported artifact from the Hugging Face Hub or a local directory with
|
|
89
|
+
`Mamba2Regressor.from_pretrained`.
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
from stateMINT.model import Mamba2Regressor
|
|
93
|
+
|
|
94
|
+
artifact = Mamba2Regressor.from_pretrained(
|
|
95
|
+
"dide-ic/stateMINT",
|
|
96
|
+
predictor="prevalence",
|
|
97
|
+
revision="v1.0.0",
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
static_covars = [{
|
|
101
|
+
"eir": 50.0,
|
|
102
|
+
"dn0_use": 0.3,
|
|
103
|
+
"dn0_future": 0.4,
|
|
104
|
+
"Q0": 0.8,
|
|
105
|
+
"phi_bednets": 0.7,
|
|
106
|
+
"seasonal": 1.0,
|
|
107
|
+
"routine": 0.5,
|
|
108
|
+
"itn_use": 0.2,
|
|
109
|
+
"irs_use": 0.1,
|
|
110
|
+
"itn_future": 0.3,
|
|
111
|
+
"irs_future": 0.2,
|
|
112
|
+
"lsm": 0.0,
|
|
113
|
+
}]
|
|
114
|
+
|
|
115
|
+
predicted_prevalence = artifact.predict(static_covars)
|
|
116
|
+
|
|
117
|
+
print(predicted_prevalence.shape) # (batch, timesteps)
|
|
118
|
+
print(predicted_prevalence[0]) # first trajectory
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
For cases, load the cases artifact and use the same input format:
|
|
122
|
+
|
|
123
|
+
```python
|
|
124
|
+
artifact = Mamba2Regressor.from_pretrained(
|
|
125
|
+
"dide-ic/stateMINT",
|
|
126
|
+
predictor="cases",
|
|
127
|
+
revision="v1.0.0",
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
predicted_cases = artifact.predict(static_covars)
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
By default, predictions are returned on the original target scale: prevalence as
|
|
134
|
+
probabilities and cases on the scale used by the training data. Pass
|
|
135
|
+
`transformed=True` to return model-space outputs.
|
|
136
|
+
|
|
137
|
+
```python
|
|
138
|
+
raw_model_space = artifact.predict(static_covars, transformed=True)
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
For local artifacts, pass the target artifact directory:
|
|
142
|
+
|
|
143
|
+
```python
|
|
144
|
+
artifact = Mamba2Regressor.from_pretrained(
|
|
145
|
+
"artifacts/prevalence",
|
|
146
|
+
predictor="prevalence",
|
|
147
|
+
)
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
## Static Covariates
|
|
151
|
+
|
|
152
|
+
Inference inputs need one dictionary per scenario with these static covariates:
|
|
153
|
+
|
|
154
|
+
```text
|
|
155
|
+
eir
|
|
156
|
+
dn0_use
|
|
157
|
+
dn0_future
|
|
158
|
+
Q0
|
|
159
|
+
phi_bednets
|
|
160
|
+
seasonal
|
|
161
|
+
routine
|
|
162
|
+
itn_use
|
|
163
|
+
irs_use
|
|
164
|
+
itn_future
|
|
165
|
+
irs_future
|
|
166
|
+
lsm
|
|
167
|
+
```
|
|
168
|
+
|
|
169
|
+
Artifacts include the fitted static scaler, timestep grid, intervention day,
|
|
170
|
+
target transform, and other preprocessing metadata needed for inference.
|
|
171
|
+
|
|
172
|
+
## Training Workflow
|
|
173
|
+
|
|
174
|
+
Typical workflow:
|
|
175
|
+
|
|
176
|
+
1. Fetch and aggregate simulation data from DuckDB.
|
|
177
|
+
2. Train a target-specific model.
|
|
178
|
+
3. Evaluate or visualize test-set predictions.
|
|
179
|
+
4. Export the checkpoint into a portable inference artifact.
|
|
180
|
+
5. Upload the artifact to the Hugging Face Hub, if needed.
|
|
181
|
+
|
|
182
|
+
### 1. Fetch Filtered Data
|
|
183
|
+
|
|
184
|
+
`stateMINT.filter_raw_data` reads raw DuckDB simulation rows, filters burn-in,
|
|
185
|
+
aggregates fixed windows, and writes `filtered_data_<predictor>.parquet`.
|
|
186
|
+
|
|
187
|
+
```bash
|
|
188
|
+
uv run python -m stateMINT.filter_raw_data \
|
|
189
|
+
--db-path /path/to/simulations.duckdb \
|
|
190
|
+
--table-name simulation_results \
|
|
191
|
+
--predictor prevalence \
|
|
192
|
+
--window-size 14 \
|
|
193
|
+
--output-folder data
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
Useful options:
|
|
197
|
+
|
|
198
|
+
- `--predictor prevalence` or `--predictor cases`
|
|
199
|
+
- `--param-limit N` to keep only the first `N` parameter indices
|
|
200
|
+
- `--sim-limit N` to sample up to `N` simulations per parameter
|
|
201
|
+
|
|
202
|
+
The raw table should include identifiers (`parameter_index`, `simulation_index`,
|
|
203
|
+
`global_index`), daily timesteps, the static covariates above, and output
|
|
204
|
+
columns for prevalence or cases.
|
|
205
|
+
|
|
206
|
+
### 2. Train a Model
|
|
207
|
+
|
|
208
|
+
Training uses Hydra; the default target is prevalence.
|
|
209
|
+
|
|
210
|
+
```bash
|
|
211
|
+
uv run python -m stateMINT.train
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
Train the cases model:
|
|
215
|
+
|
|
216
|
+
```bash
|
|
217
|
+
uv run python -m stateMINT.train target=cases
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
Common overrides:
|
|
221
|
+
|
|
222
|
+
```bash
|
|
223
|
+
uv run python -m stateMINT.train \
|
|
224
|
+
target=prevalence \
|
|
225
|
+
data_file=data/filtered_data_prevalence.parquet \
|
|
226
|
+
output_dir=train_outputs/prevalence \
|
|
227
|
+
use_wandb=false
|
|
228
|
+
```
|
|
229
|
+
|
|
230
|
+
Training writes checkpoints under `checkpoint_dir`, saves
|
|
231
|
+
`static_scaler.pkl` in `output_dir`, and reuses a split assignment file for
|
|
232
|
+
consistent train/validation/test splits.
|
|
233
|
+
|
|
234
|
+
### 3. W&B Sweeps
|
|
235
|
+
|
|
236
|
+
Sweep definitions live in `stateMINT/conf/sweeps`. Create a sweep, then run one
|
|
237
|
+
or more agents with the sweep ID returned by W&B:
|
|
238
|
+
|
|
239
|
+
```bash
|
|
240
|
+
uv run wandb sweep stateMINT/conf/sweeps/prevalence.yaml
|
|
241
|
+
uv run wandb agent <entity>/stateMINT-sweep/<sweep-id>
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
Use `stateMINT/conf/sweeps/cases.yaml` for the cases target. Sweep commands set
|
|
245
|
+
`use_wandb=true` and pass Hydra overrides through `${args_no_hyphens}`.
|
|
246
|
+
|
|
247
|
+
### 4. Visualize Predictions
|
|
248
|
+
|
|
249
|
+
Compare predictions with test-set targets. `checkpoint_dir` is required.
|
|
250
|
+
|
|
251
|
+
```bash
|
|
252
|
+
uv run python -m stateMINT.visualise_predictions \
|
|
253
|
+
target=prevalence \
|
|
254
|
+
checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
|
|
255
|
+
data_file=data/filtered_data_prevalence.parquet
|
|
256
|
+
```
|
|
257
|
+
|
|
258
|
+
The default output path is `viz_outputs/<predictor>/preds-vs-targets.pdf`.
|
|
259
|
+
|
|
260
|
+
### 5. Export an Artifact
|
|
261
|
+
|
|
262
|
+
Export converts a trained Orbax checkpoint and preprocessing metadata into a
|
|
263
|
+
self-contained artifact.
|
|
264
|
+
|
|
265
|
+
```bash
|
|
266
|
+
uv run python -m stateMINT.model_export \
|
|
267
|
+
predictor=prevalence \
|
|
268
|
+
checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
|
|
269
|
+
scaler_file=train_outputs/prevalence/static_scaler.pkl \
|
|
270
|
+
artifact_dir=artifacts/prevalence
|
|
271
|
+
```
|
|
272
|
+
|
|
273
|
+
Export config architecture values must match the checkpoint, including
|
|
274
|
+
`d_model`, `d_state`, `n_layers`, `dropout`, and related Mamba2 settings.
|
|
275
|
+
|
|
276
|
+
An exported artifact contains:
|
|
277
|
+
|
|
278
|
+
```text
|
|
279
|
+
artifact_dir/
|
|
280
|
+
|-- checkpoint/
|
|
281
|
+
|-- model_config.json
|
|
282
|
+
`-- preprocessing_config.json
|
|
283
|
+
```
|
|
284
|
+
|
|
285
|
+
`model_config.json` stores architecture settings; `preprocessing_config.json`
|
|
286
|
+
stores feature order, target transform, intervention timing, timestep
|
|
287
|
+
construction, and static scaler parameters.
|
|
288
|
+
|
|
289
|
+
### 6. Upload To Hugging Face
|
|
290
|
+
|
|
291
|
+
Authenticate first:
|
|
292
|
+
|
|
293
|
+
```bash
|
|
294
|
+
hf auth login
|
|
295
|
+
```
|
|
296
|
+
|
|
297
|
+
Upload an artifact:
|
|
298
|
+
|
|
299
|
+
```bash
|
|
300
|
+
hf upload dide-ic/stateMINT artifacts/prevalence prevalence/ \
|
|
301
|
+
--commit-message "Add prevalence model artifact"
|
|
302
|
+
```
|
|
303
|
+
|
|
304
|
+
Create a release tag:
|
|
305
|
+
|
|
306
|
+
```bash
|
|
307
|
+
hf repos tag create dide-ic/stateMINT v1.0.0 \
|
|
308
|
+
--revision main \
|
|
309
|
+
--message "Release v1.0.0"
|
|
310
|
+
```
|
|
311
|
+
|
|
312
|
+
## Configuration
|
|
313
|
+
|
|
314
|
+
Main Hydra configs in `stateMINT/conf`:
|
|
315
|
+
|
|
316
|
+
- `train_config.yaml` for training.
|
|
317
|
+
- `viz_config.yaml` for prediction visualizations.
|
|
318
|
+
- `export_config.yaml` for artifact export.
|
|
319
|
+
- `target/prevalence.yaml` and `target/cases.yaml` for target-specific defaults.
|
|
320
|
+
- `sweeps/*.yaml` for Weights & Biases sweep definitions.
|
|
321
|
+
|
|
322
|
+
Select a target with `target=prevalence` or `target=cases`; override config
|
|
323
|
+
values from the command line with Hydra syntax.
|
|
324
|
+
|
|
325
|
+
## Development
|
|
326
|
+
|
|
327
|
+
Run the test suite:
|
|
328
|
+
|
|
329
|
+
```bash
|
|
330
|
+
uv run pytest tests/
|
|
331
|
+
```
|
|
332
|
+
|
|
333
|
+
Skip slow or local-only tests:
|
|
334
|
+
|
|
335
|
+
```bash
|
|
336
|
+
uv run pytest tests/ -m "not slow"
|
|
337
|
+
uv run pytest tests/ -m "not local"
|
|
338
|
+
uv run pytest tests/ -m "not slow and not local" # skip both
|
|
339
|
+
```
|
|
340
|
+
|
|
341
|
+
Run linting and formatting:
|
|
342
|
+
|
|
343
|
+
```bash
|
|
344
|
+
uv run ruff check
|
|
345
|
+
uv run ruff format
|
|
346
|
+
```
|
|
347
|
+
|
|
348
|
+
## Repository Layout
|
|
349
|
+
|
|
350
|
+
```text
|
|
351
|
+
stateMINT/
|
|
352
|
+
|-- common/ # shared dataclasses, transforms, and model helpers
|
|
353
|
+
|-- conf/ # Hydra configs for training, export, viz, and sweeps
|
|
354
|
+
|-- data/ # DuckDB fetch, preprocessing, features, and loaders
|
|
355
|
+
|-- eval/ # metrics and prediction/target visualization helpers
|
|
356
|
+
|-- model/ # Mamba2 regressor and artifact loading utilities
|
|
357
|
+
|-- training/ # optimizer, train/eval steps, loss, and checkpointing
|
|
358
|
+
|-- filter_raw_data.py # CLI for building filtered parquet datasets
|
|
359
|
+
|-- train.py # Hydra training entry point
|
|
360
|
+
|-- visualise_predictions.py
|
|
361
|
+
`-- model_export.py # Hydra artifact export entry point
|
|
362
|
+
|
|
363
|
+
tests/ # unit tests
|
|
364
|
+
artifacts/ # exported model artifact examples/metadata
|
|
365
|
+
viz_outputs/ # generated prediction visualization outputs
|
|
366
|
+
```
|
|
367
|
+
|
|
368
|
+
## Contributing
|
|
369
|
+
|
|
370
|
+
See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution guidelines and the
|
|
371
|
+
development workflow.
|