mintstate 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. mintstate-0.1.0/.github/workflows/lint.yaml +30 -0
  2. mintstate-0.1.0/.github/workflows/release.yaml +33 -0
  3. mintstate-0.1.0/.github/workflows/test.yaml +28 -0
  4. mintstate-0.1.0/.gitignore +135 -0
  5. mintstate-0.1.0/.python-version +1 -0
  6. mintstate-0.1.0/.vscode/launch.json +46 -0
  7. mintstate-0.1.0/CONTRIBUTING.md +27 -0
  8. mintstate-0.1.0/PKG-INFO +371 -0
  9. mintstate-0.1.0/README.md +330 -0
  10. mintstate-0.1.0/pyproject.toml +65 -0
  11. mintstate-0.1.0/scenarios_prevalence.pdf +0 -0
  12. mintstate-0.1.0/stateMINT/__init__.py +0 -0
  13. mintstate-0.1.0/stateMINT/common/__init__.py +0 -0
  14. mintstate-0.1.0/stateMINT/common/dataclasses.py +11 -0
  15. mintstate-0.1.0/stateMINT/common/utils.py +74 -0
  16. mintstate-0.1.0/stateMINT/conf/export_config.yaml +25 -0
  17. mintstate-0.1.0/stateMINT/conf/sweeps/cases.yaml +46 -0
  18. mintstate-0.1.0/stateMINT/conf/sweeps/prevalence.yaml +46 -0
  19. mintstate-0.1.0/stateMINT/conf/target/cases.yaml +18 -0
  20. mintstate-0.1.0/stateMINT/conf/target/prevalence.yaml +18 -0
  21. mintstate-0.1.0/stateMINT/conf/train_config.yaml +46 -0
  22. mintstate-0.1.0/stateMINT/conf/viz_config.yaml +31 -0
  23. mintstate-0.1.0/stateMINT/data/__init__.py +21 -0
  24. mintstate-0.1.0/stateMINT/data/dataset.py +77 -0
  25. mintstate-0.1.0/stateMINT/data/features.py +104 -0
  26. mintstate-0.1.0/stateMINT/data/fetch.py +194 -0
  27. mintstate-0.1.0/stateMINT/data/preprocessing.py +504 -0
  28. mintstate-0.1.0/stateMINT/eval/__init__.py +0 -0
  29. mintstate-0.1.0/stateMINT/eval/metrics.py +180 -0
  30. mintstate-0.1.0/stateMINT/eval/viz_preds_truth.py +151 -0
  31. mintstate-0.1.0/stateMINT/filter_raw_data.py +45 -0
  32. mintstate-0.1.0/stateMINT/model/__init__.py +3 -0
  33. mintstate-0.1.0/stateMINT/model/hub.py +144 -0
  34. mintstate-0.1.0/stateMINT/model/mamba2.py +171 -0
  35. mintstate-0.1.0/stateMINT/model_export.py +90 -0
  36. mintstate-0.1.0/stateMINT/train.py +158 -0
  37. mintstate-0.1.0/stateMINT/training/__init__.py +0 -0
  38. mintstate-0.1.0/stateMINT/training/checkpoint.py +147 -0
  39. mintstate-0.1.0/stateMINT/training/loss.py +35 -0
  40. mintstate-0.1.0/stateMINT/training/train_step.py +167 -0
  41. mintstate-0.1.0/stateMINT/visualise_predictions.py +55 -0
  42. mintstate-0.1.0/tests/__init__.py +0 -0
  43. mintstate-0.1.0/tests/common/__init__.py +0 -0
  44. mintstate-0.1.0/tests/common/test_utils.py +49 -0
  45. mintstate-0.1.0/tests/conftest.py +230 -0
  46. mintstate-0.1.0/tests/data/__init__.py +0 -0
  47. mintstate-0.1.0/tests/data/test_dataset.py +21 -0
  48. mintstate-0.1.0/tests/data/test_features.py +10 -0
  49. mintstate-0.1.0/tests/data/test_fetch.py +154 -0
  50. mintstate-0.1.0/tests/data/test_preprocessing.py +310 -0
  51. mintstate-0.1.0/tests/eval/__init__.py +0 -0
  52. mintstate-0.1.0/tests/eval/test_metrics.py +34 -0
  53. mintstate-0.1.0/tests/eval/test_viz_preds_truth.py +166 -0
  54. mintstate-0.1.0/tests/model/__init__.py +0 -0
  55. mintstate-0.1.0/tests/model/test_hub.py +196 -0
  56. mintstate-0.1.0/tests/model/test_mamba2.py +109 -0
  57. mintstate-0.1.0/tests/smoke_test.py +3 -0
  58. mintstate-0.1.0/tests/training/__init__.py +0 -0
  59. mintstate-0.1.0/tests/training/test_checkpoint.py +40 -0
  60. mintstate-0.1.0/tests/training/test_loss.py +24 -0
  61. mintstate-0.1.0/tests/training/test_train_step.py +50 -0
  62. mintstate-0.1.0/uv.lock +2084 -0
  63. mintstate-0.1.0/viz_outputs/cases/preds-vs-targets.pdf +0 -0
  64. mintstate-0.1.0/viz_outputs/cases/static_scaler.pkl +0 -0
  65. mintstate-0.1.0/viz_outputs/prevalence/preds-vs-targets.pdf +0 -0
  66. mintstate-0.1.0/viz_outputs/prevalence/static_scaler.pkl +0 -0
@@ -0,0 +1,30 @@
1
+ name: Lint
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - "*"
10
+
11
+ jobs:
12
+ lint-and-format:
13
+ name: Run lint and format checks
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - name: Checkout repository
17
+ uses: actions/checkout@v6
18
+ - name: Install uv
19
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b
20
+ with:
21
+ enable-cache: true
22
+ version: "0.11.18"
23
+ - name: Set up Python
24
+ run: uv python install
25
+ - name: Install the project
26
+ run: uv sync --locked --all-extras --dev
27
+ - name: Run lint checks
28
+ run: uv run ruff check --output-format=github
29
+ - name: Run format checks
30
+ run: uv run ruff format --check --diff
@@ -0,0 +1,33 @@
1
+ name: Publish Release to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - 'v*.*.*'
7
+
8
+ jobs:
9
+ run:
10
+ runs-on: ubuntu-latest
11
+ environment:
12
+ name: pypi
13
+ permissions:
14
+ id-token: write
15
+ contents: read
16
+ steps:
17
+ - name: Checkout code
18
+ uses: actions/checkout@v6
19
+ - name: Install uv
20
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b
21
+ with:
22
+ enable-cache: true
23
+ version: "0.11.18"
24
+ - name: Set up Python
25
+ run: uv python install
26
+ - name: Build
27
+ run: uv build
28
+ - name: Smoke test (wheel)
29
+ run: uv run --isolated --no-project --with dist/*.whl tests/smoke_test.py
30
+ - name: Smoke test (source distribution)
31
+ run: uv run --isolated --no-project --with dist/*.tar.gz tests/smoke_test.py
32
+ - name: Publish
33
+ run: uv publish
@@ -0,0 +1,28 @@
1
+ name: Test
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - "*"
10
+
11
+ jobs:
12
+ test:
13
+ name: Run tests
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - name: Checkout repository
17
+ uses: actions/checkout@v6
18
+ - name: Install uv
19
+ uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b
20
+ with:
21
+ enable-cache: true
22
+ version: "0.11.18"
23
+ - name: Set up Python
24
+ run: uv python install
25
+ - name: Install the project
26
+ run: uv sync --locked --all-extras --dev
27
+ - name: Run tests
28
+ run: uv run pytest -m "not local and not slow"
@@ -0,0 +1,135 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ artifacts/
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py.cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # Jupyter Notebook
76
+ .ipynb_checkpoints
77
+
78
+ # IPython
79
+ profile_default/
80
+ ipython_config.py
81
+
82
+
83
+ # Environments
84
+ .env
85
+ .envrc
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+ .dmypy.json
106
+ dmypy.json
107
+
108
+ # Pyre type checker
109
+ .pyre/
110
+
111
+ # pytype static type analyzer
112
+ .pytype/
113
+
114
+ # Cython debug symbols
115
+ cython_debug/
116
+
117
+ # Ruff stuff:
118
+ .ruff_cache/
119
+
120
+ # PyPI configuration file
121
+ .pypirc
122
+
123
+ # testing scripts
124
+ *.ipynb
125
+ x.py
126
+
127
+ # hydra logs
128
+ outputs
129
+
130
+ # misc
131
+ inputs
132
+ /data/
133
+ ref.py
134
+ train_outputs/
135
+ wandb/
@@ -0,0 +1 @@
1
+ 3.12
@@ -0,0 +1,46 @@
1
+ {
2
+ // Use IntelliSense to learn about possible attributes.
3
+ // Hover to view descriptions of existing attributes.
4
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5
+ "version": "0.2.0",
6
+ "configurations": [
7
+ {
8
+ "name": "Debug run train",
9
+ "type": "debugpy",
10
+ "request": "launch",
11
+ "python": "/home/anmol/code/stateMINT/.venv/bin/python",
12
+ "module": "stateMINT.viz",
13
+ // "args": [
14
+ // "num_epochs=10",
15
+ // ],
16
+ "env": {
17
+ // "JAX_DISABLE_JIT": "true",
18
+ // "UV_MANAGED_PYTHON": "true"
19
+ "CUDA_VISIBLE_DEVICES": "1"
20
+ },
21
+ "console": "integratedTerminal",
22
+ "cwd": "${workspaceFolder}",
23
+ "justMyCode": false
24
+ },
25
+ {
26
+ "name": "debug test function",
27
+ "type": "debugpy",
28
+ "request": "launch",
29
+ "python": "/home/anmol/code/stateMINT/.venv/bin/python",
30
+ "module": "pytest",
31
+ "args": [
32
+ "tests/model/test_mamba2.py",
33
+ "-k",
34
+ "test_prevalence_from_pretrained_from_hub"
35
+ ],
36
+ "env": {
37
+ // "JAX_DISABLE_JIT": "true",
38
+ // "UV_MANAGED_PYTHON": "true"
39
+ // "CUDA_VISIBLE_DEVICES": "1"
40
+ },
41
+ "console": "integratedTerminal",
42
+ "cwd": "${workspaceFolder}",
43
+ "justMyCode": false
44
+ },
45
+ ]
46
+ }
@@ -0,0 +1,27 @@
1
+ # Contributing to stateMINT
2
+
3
+ Thank you for your interest in contributing to stateMINT! This document
4
+ outlines the development workflow and contribution standards.
5
+
6
+ ## Contribution Standards
7
+
8
+ All contributions must meet the following requirements:
9
+
10
+ - **Tests**: All tests must pass (`uv run pytest tests/`)
11
+ - **Formatting and Linting**: Code must pass formatting and linting checks (`uv run ruff format && uv run ruff check`)
12
+
13
+ ## Standard Contribution Process
14
+
15
+ 1. Create a feature branch
16
+ 1. Make your changes
17
+ 1. Ensure tests pass and formatting/linting succeeds
18
+ 1. Submit a pull request
19
+
20
+ ## Development Setup
21
+
22
+ See README.md for installation and development environment setup
23
+ instructions.
24
+
25
+ ## Questions?
26
+
27
+ Open an issue for questions about contributing or development workflow.
@@ -0,0 +1,371 @@
1
+ Metadata-Version: 2.4
2
+ Name: mintstate
3
+ Version: 0.1.0
4
+ Summary: StateMINT is a state-space-based neural network emulator for malariasimulation.
5
+ Project-URL: Homepage, https://github.com/mrc-ide/stateMINT
6
+ Project-URL: Repository, https://github.com/mrc-ide/stateMINT
7
+ Project-URL: Weights, https://huggingface.co/dide-ic/stateMINT
8
+ Author-email: Anmol Thapar <mr.anmolthapar@gmail.com>
9
+ Requires-Python: >=3.12
10
+ Requires-Dist: etils>=1.0
11
+ Requires-Dist: flax>=0.12.7
12
+ Requires-Dist: huggingface-hub>=1.19.0
13
+ Requires-Dist: jax>=0.10.1
14
+ Requires-Dist: jaxtyping>=0.3.10
15
+ Requires-Dist: mamba2-jax>=1.0.1
16
+ Requires-Dist: numpy>=2.4.6
17
+ Requires-Dist: omegaconf>=2.3
18
+ Requires-Dist: orbax-checkpoint>=0.12.0
19
+ Requires-Dist: pandas>=3.0.3
20
+ Provides-Extra: all
21
+ Requires-Dist: duckdb>=1.5.3; extra == 'all'
22
+ Requires-Dist: grain>=0.2.16; extra == 'all'
23
+ Requires-Dist: hydra-core>=1.3.2; extra == 'all'
24
+ Requires-Dist: jax[cuda12]>=0.10.1; extra == 'all'
25
+ Requires-Dist: matplotlib>=3.10.9; extra == 'all'
26
+ Requires-Dist: optax>=0.2.8; extra == 'all'
27
+ Requires-Dist: tqdm>=4.67.3; extra == 'all'
28
+ Requires-Dist: wandb>=0.27.0; extra == 'all'
29
+ Provides-Extra: gpu
30
+ Requires-Dist: jax[cuda12]>=0.10.1; extra == 'gpu'
31
+ Provides-Extra: plot
32
+ Requires-Dist: matplotlib>=3.10.9; extra == 'plot'
33
+ Provides-Extra: train
34
+ Requires-Dist: duckdb>=1.5.3; extra == 'train'
35
+ Requires-Dist: grain>=0.2.16; extra == 'train'
36
+ Requires-Dist: hydra-core>=1.3.2; extra == 'train'
37
+ Requires-Dist: optax>=0.2.8; extra == 'train'
38
+ Requires-Dist: tqdm>=4.67.3; extra == 'train'
39
+ Requires-Dist: wandb>=0.27.0; extra == 'train'
40
+ Description-Content-Type: text/markdown
41
+
42
+ # StateMINT
43
+
44
+ StateMINT is a JAX/Flax neural emulator for
45
+ [`malariasimulation`](https://github.com/mrc-ide/malariasimulation) outputs. It
46
+ uses a Mamba2 state-space sequence model to predict malaria trajectories from
47
+ static scenario covariates and intervention timing features. It supersedes the
48
+ earlier
49
+ [`MINTelligence`](https://github.com/CosmoNaught/MINTelligence) RNN emulator.
50
+
51
+ ## What StateMINT Provides
52
+
53
+ - Mamba2-based sequence regressors for malaria prevalence and case-count
54
+ trajectories.
55
+ - Data extraction utilities for aggregating raw `malariasimulation` DuckDB
56
+ outputs into model-ready parquet files.
57
+ - Preprocessing with target transforms, covariate scaling, and
58
+ intervention-aware feature construction.
59
+ - Training, evaluation, visualization, checkpointing, and export workflows.
60
+ - Hugging Face Hub loading utilities for exported inference artifacts.
61
+
62
+ ## Installation
63
+
64
+ StateMINT requires Python 3.12 or newer and uses
65
+ [`uv`](https://github.com/astral-sh/uv) to manage dependencies.
66
+
67
+ ```bash
68
+ git clone https://github.com/mrc-ide/stateMINT.git
69
+ cd stateMINT
70
+ uv sync
71
+ ```
72
+
73
+ For development dependencies and optional extras:
74
+
75
+ ```bash
76
+ uv sync --all-extras --dev
77
+ ```
78
+
79
+ Or install extras individually:
80
+
81
+ ```bash
82
+ uv sync --extra plot
83
+ uv sync --extra gpu
84
+ ```
85
+
86
+ ## Quick Start: Inference
87
+
88
+ Load an exported artifact from the Hugging Face Hub or a local directory with
89
+ `Mamba2Regressor.from_pretrained`.
90
+
91
+ ```python
92
+ from stateMINT.model import Mamba2Regressor
93
+
94
+ artifact = Mamba2Regressor.from_pretrained(
95
+ "dide-ic/stateMINT",
96
+ predictor="prevalence",
97
+ revision="v1.0.0",
98
+ )
99
+
100
+ static_covars = [{
101
+ "eir": 50.0,
102
+ "dn0_use": 0.3,
103
+ "dn0_future": 0.4,
104
+ "Q0": 0.8,
105
+ "phi_bednets": 0.7,
106
+ "seasonal": 1.0,
107
+ "routine": 0.5,
108
+ "itn_use": 0.2,
109
+ "irs_use": 0.1,
110
+ "itn_future": 0.3,
111
+ "irs_future": 0.2,
112
+ "lsm": 0.0,
113
+ }]
114
+
115
+ predicted_prevalence = artifact.predict(static_covars)
116
+
117
+ print(predicted_prevalence.shape) # (batch, timesteps)
118
+ print(predicted_prevalence[0]) # first trajectory
119
+ ```
120
+
121
+ For cases, load the cases artifact and use the same input format:
122
+
123
+ ```python
124
+ artifact = Mamba2Regressor.from_pretrained(
125
+ "dide-ic/stateMINT",
126
+ predictor="cases",
127
+ revision="v1.0.0",
128
+ )
129
+
130
+ predicted_cases = artifact.predict(static_covars)
131
+ ```
132
+
133
+ By default, predictions are returned on the original target scale: prevalence as
134
+ probabilities and cases on the scale used by the training data. Pass
135
+ `transformed=True` to return model-space outputs.
136
+
137
+ ```python
138
+ raw_model_space = artifact.predict(static_covars, transformed=True)
139
+ ```
140
+
141
+ For local artifacts, pass the target artifact directory:
142
+
143
+ ```python
144
+ artifact = Mamba2Regressor.from_pretrained(
145
+ "artifacts/prevalence",
146
+ predictor="prevalence",
147
+ )
148
+ ```
149
+
150
+ ## Static Covariates
151
+
152
+ Inference inputs need one dictionary per scenario with these static covariates:
153
+
154
+ ```text
155
+ eir
156
+ dn0_use
157
+ dn0_future
158
+ Q0
159
+ phi_bednets
160
+ seasonal
161
+ routine
162
+ itn_use
163
+ irs_use
164
+ itn_future
165
+ irs_future
166
+ lsm
167
+ ```
168
+
169
+ Artifacts include the fitted static scaler, timestep grid, intervention day,
170
+ target transform, and other preprocessing metadata needed for inference.
171
+
172
+ ## Training Workflow
173
+
174
+ Typical workflow:
175
+
176
+ 1. Fetch and aggregate simulation data from DuckDB.
177
+ 2. Train a target-specific model.
178
+ 3. Evaluate or visualize test-set predictions.
179
+ 4. Export the checkpoint into a portable inference artifact.
180
+ 5. Upload the artifact to the Hugging Face Hub, if needed.
181
+
182
+ ### 1. Fetch Filtered Data
183
+
184
+ `stateMINT.filter_raw_data` reads raw DuckDB simulation rows, filters out the burn-in period,
185
+ aggregates over fixed-size windows, and writes `filtered_data_<predictor>.parquet`.
186
+
187
+ ```bash
188
+ uv run python -m stateMINT.filter_raw_data \
189
+ --db-path /path/to/simulations.duckdb \
190
+ --table-name simulation_results \
191
+ --predictor prevalence \
192
+ --window-size 14 \
193
+ --output-folder data
194
+ ```
195
+
196
+ Useful options:
197
+
198
+ - `--predictor prevalence` or `--predictor cases`
199
+ - `--param-limit N` to keep only the first `N` parameter indices
200
+ - `--sim-limit N` to sample up to `N` simulations per parameter
201
+
202
+ The raw table should include identifiers (`parameter_index`, `simulation_index`,
203
+ `global_index`), daily timesteps, the static covariates above, and output
204
+ columns for prevalence or cases.
205
+
206
+ ### 2. Train a Model
207
+
208
+ Training uses Hydra; the default target is prevalence.
209
+
210
+ ```bash
211
+ uv run python -m stateMINT.train
212
+ ```
213
+
214
+ Train the cases model:
215
+
216
+ ```bash
217
+ uv run python -m stateMINT.train target=cases
218
+ ```
219
+
220
+ Common overrides:
221
+
222
+ ```bash
223
+ uv run python -m stateMINT.train \
224
+ target=prevalence \
225
+ data_file=data/filtered_data_prevalence.parquet \
226
+ output_dir=train_outputs/prevalence \
227
+ use_wandb=false
228
+ ```
229
+
230
+ Training writes checkpoints under `checkpoint_dir`, saves
231
+ `static_scaler.pkl` in `output_dir`, and reuses a split assignment file for
232
+ consistent train/validation/test splits.
233
+
234
+ ### 3. W&B Sweeps
235
+
236
+ Sweep definitions live in `stateMINT/conf/sweeps`. Create a sweep, then run one
237
+ or more agents with the sweep ID returned by W&B:
238
+
239
+ ```bash
240
+ uv run wandb sweep stateMINT/conf/sweeps/prevalence.yaml
241
+ uv run wandb agent <entity>/stateMINT-sweep/<sweep-id>
242
+ ```
243
+
244
+ Use `stateMINT/conf/sweeps/cases.yaml` for the cases target. Sweep commands set
245
+ `use_wandb=true` and pass Hydra overrides through `${args_no_hyphens}`.
246
+
247
+ ### 4. Visualize Predictions
248
+
249
+ Compare predictions with test-set targets. `checkpoint_dir` is required.
250
+
251
+ ```bash
252
+ uv run python -m stateMINT.visualise_predictions \
253
+ target=prevalence \
254
+ checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
255
+ data_file=data/filtered_data_prevalence.parquet
256
+ ```
257
+
258
+ The default output path is `viz_outputs/<predictor>/preds-vs-targets.pdf`.
259
+
260
+ ### 5. Export an Artifact
261
+
262
+ Export converts a trained Orbax checkpoint and preprocessing metadata into a
263
+ self-contained artifact.
264
+
265
+ ```bash
266
+ uv run python -m stateMINT.model_export \
267
+ predictor=prevalence \
268
+ checkpoint_dir=train_outputs/prevalence/ckpts-YYYY-MM-DDTHH:MM:SS \
269
+ scaler_file=train_outputs/prevalence/static_scaler.pkl \
270
+ artifact_dir=artifacts/prevalence
271
+ ```
272
+
273
+ The architecture values in the export config must match those of the checkpoint, including
274
+ `d_model`, `d_state`, `n_layers`, `dropout`, and related Mamba2 settings.
275
+
276
+ An exported artifact contains:
277
+
278
+ ```text
279
+ artifact_dir/
280
+ |-- checkpoint/
281
+ |-- model_config.json
282
+ `-- preprocessing_config.json
283
+ ```
284
+
285
+ `model_config.json` stores architecture settings; `preprocessing_config.json`
286
+ stores feature order, target transform, intervention timing, timestep
287
+ construction, and static scaler parameters.
288
+
289
+ ### 6. Upload to Hugging Face
290
+
291
+ Authenticate first:
292
+
293
+ ```bash
294
+ hf auth login
295
+ ```
296
+
297
+ Upload an artifact:
298
+
299
+ ```bash
300
+ hf upload dide-ic/stateMINT artifacts/prevalence prevalence/ \
301
+ --commit-message "Add prevalence model artifact"
302
+ ```
303
+
304
+ Create a release tag:
305
+
306
+ ```bash
307
+ hf repos tag create dide-ic/stateMINT v1.0.0 \
308
+ --revision main \
309
+ --message "Release v1.0.0"
310
+ ```
311
+
312
+ ## Configuration
313
+
314
+ Main Hydra configs in `stateMINT/conf`:
315
+
316
+ - `train_config.yaml` for training.
317
+ - `viz_config.yaml` for prediction visualizations.
318
+ - `export_config.yaml` for artifact export.
319
+ - `target/prevalence.yaml` and `target/cases.yaml` for target-specific defaults.
320
+ - `sweeps/*.yaml` for Weights & Biases sweep definitions.
321
+
322
+ Select a target with `target=prevalence` or `target=cases`; override config
323
+ values from the command line with Hydra syntax.
324
+
325
+ ## Development
326
+
327
+ Run the test suite:
328
+
329
+ ```bash
330
+ uv run pytest tests/
331
+ ```
332
+
333
+ Skip slow or local-only tests:
334
+
335
+ ```bash
336
+ uv run pytest tests/ -m "not slow"
337
+ uv run pytest tests/ -m "not local"
338
+ uv run pytest tests/ -m "not slow and not local" # skip both
339
+ ```
340
+
341
+ Run linting and formatting:
342
+
343
+ ```bash
344
+ uv run ruff check
345
+ uv run ruff format
346
+ ```
347
+
348
+ ## Repository Layout
349
+
350
+ ```text
351
+ stateMINT/
352
+ |-- common/ # shared dataclasses, transforms, and model helpers
353
+ |-- conf/ # Hydra configs for training, export, viz, and sweeps
354
+ |-- data/ # DuckDB fetch, preprocessing, features, and loaders
355
+ |-- eval/ # metrics and prediction/target visualization helpers
356
+ |-- model/ # Mamba2 regressor and artifact loading utilities
357
+ |-- training/ # optimizer, train/eval steps, loss, and checkpointing
358
+ |-- filter_raw_data.py # CLI for building filtered parquet datasets
359
+ |-- train.py # Hydra training entry point
360
+ |-- visualise_predictions.py # Hydra prediction visualization entry point
361
+ `-- model_export.py # Hydra artifact export entry point
362
+
363
+ tests/ # unit tests
364
+ artifacts/ # exported model artifact examples/metadata
365
+ viz_outputs/ # generated prediction visualization outputs
366
+ ```
367
+
368
+ ## Contributing
369
+
370
+ See [CONTRIBUTING.md](CONTRIBUTING.md) for contribution guidelines and the
371
+ development workflow.