cadence-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cadence_core-0.1.0/.github/workflows/publish.yml +40 -0
- cadence_core-0.1.0/.gitignore +25 -0
- cadence_core-0.1.0/LICENSE +21 -0
- cadence_core-0.1.0/PKG-INFO +108 -0
- cadence_core-0.1.0/README.md +73 -0
- cadence_core-0.1.0/cadence/__init__.py +430 -0
- cadence_core-0.1.0/cadence/clustering.py +168 -0
- cadence_core-0.1.0/cadence/config.py +83 -0
- cadence_core-0.1.0/cadence/data.py +171 -0
- cadence_core-0.1.0/cadence/embeddings.py +130 -0
- cadence_core-0.1.0/cadence/features.py +552 -0
- cadence_core-0.1.0/cadence/model.py +209 -0
- cadence_core-0.1.0/cadence/pretrained.py +198 -0
- cadence_core-0.1.0/cadence/trainer.py +387 -0
- cadence_core-0.1.0/examples/quickstart.py +110 -0
- cadence_core-0.1.0/pyproject.toml +44 -0
- cadence_core-0.1.0/tests/__init__.py +0 -0
- cadence_core-0.1.0/tests/test_smoke.py +238 -0
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags:
|
|
6
|
+
- "v*"
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
build-and-publish:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
permissions:
|
|
12
|
+
contents: read
|
|
13
|
+
|
|
14
|
+
steps:
|
|
15
|
+
- uses: actions/checkout@v4
|
|
16
|
+
|
|
17
|
+
- name: Set up Python
|
|
18
|
+
uses: actions/setup-python@v5
|
|
19
|
+
with:
|
|
20
|
+
python-version: "3.11"
|
|
21
|
+
|
|
22
|
+
- name: Install uv
|
|
23
|
+
uses: astral-sh/setup-uv@v4
|
|
24
|
+
with:
|
|
25
|
+
version: "latest"
|
|
26
|
+
|
|
27
|
+
- name: Install build dependencies
|
|
28
|
+
run: uv pip install --system hatchling build twine
|
|
29
|
+
|
|
30
|
+
- name: Build package
|
|
31
|
+
run: python -m build
|
|
32
|
+
|
|
33
|
+
- name: Check distribution
|
|
34
|
+
run: twine check dist/*
|
|
35
|
+
|
|
36
|
+
- name: Publish to PyPI
|
|
37
|
+
env:
|
|
38
|
+
TWINE_USERNAME: __token__
|
|
39
|
+
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
|
40
|
+
run: twine upload dist/*
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
__pycache__/
|
|
2
|
+
*.py[cod]
|
|
3
|
+
*$py.class
|
|
4
|
+
*.so
|
|
5
|
+
.venv/
|
|
6
|
+
venv/
|
|
7
|
+
.uv/
|
|
8
|
+
build/
|
|
9
|
+
dist/
|
|
10
|
+
*.egg-info/
|
|
11
|
+
.eggs/
|
|
12
|
+
.pytest_cache/
|
|
13
|
+
.ruff_cache/
|
|
14
|
+
.mypy_cache/
|
|
15
|
+
*.pt
|
|
16
|
+
*.bin
|
|
17
|
+
*.pth
|
|
18
|
+
*.ckpt
|
|
19
|
+
*.pkl
|
|
20
|
+
local.env
|
|
21
|
+
.env
|
|
22
|
+
*.log
|
|
23
|
+
.DS_Store
|
|
24
|
+
Thumbs.db
|
|
25
|
+
uv.lock
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Amir Rouhollahi
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: cadence-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Flat-MLP with PubMedBERT-enriched self-distillation for clinical next-event prediction
|
|
5
|
+
Project-URL: Homepage, https://github.com/amirrouh/cadence
|
|
6
|
+
Project-URL: Repository, https://github.com/amirrouh/cadence
|
|
7
|
+
Project-URL: Issues, https://github.com/amirrouh/cadence/issues
|
|
8
|
+
Author-email: Amir Rouhollahi <arouhollahi@bwh.harvard.edu>
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: clinical,ehr,healthcare-ml,next-event-prediction,pubmedbert
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
|
+
Requires-Dist: huggingface-hub>=0.23
|
|
22
|
+
Requires-Dist: numpy>=1.24
|
|
23
|
+
Requires-Dist: pandas>=2.0
|
|
24
|
+
Requires-Dist: scikit-learn>=1.3
|
|
25
|
+
Requires-Dist: sentence-transformers>=2.7
|
|
26
|
+
Requires-Dist: torch>=2.1
|
|
27
|
+
Requires-Dist: tqdm>=4.66
|
|
28
|
+
Requires-Dist: transformers>=4.40
|
|
29
|
+
Provides-Extra: dev
|
|
30
|
+
Requires-Dist: build; extra == 'dev'
|
|
31
|
+
Requires-Dist: pytest>=7; extra == 'dev'
|
|
32
|
+
Requires-Dist: ruff>=0.5; extra == 'dev'
|
|
33
|
+
Requires-Dist: twine; extra == 'dev'
|
|
34
|
+
Description-Content-Type: text/markdown
|
|
35
|
+
|
|
36
|
+
# Cadence
|
|
37
|
+
|
|
38
|
+
Clinical next-event prediction: a flat-MLP with PubMedBERT-enriched features and self-knowledge distillation, trained on EHR event sequences.
|
|
39
|
+
|
|
40
|
+
## Install
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
pip install cadence-core
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
## Quickstart
|
|
47
|
+
|
|
48
|
+
### Inference with a pretrained model
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from cadence import Cadence
|
|
52
|
+
|
|
53
|
+
model = Cadence.from_pretrained("amirrouh/cadence-mimic-100k")
|
|
54
|
+
next_event, days_until = model.predict(patient_events)
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
### Training on your own data
|
|
58
|
+
|
|
59
|
+
```python
|
|
60
|
+
from cadence import Cadence
|
|
61
|
+
|
|
62
|
+
model = Cadence()
|
|
63
|
+
model.fit(events_df)
|
|
64
|
+
model.save("my-model/")
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## Input data format
|
|
68
|
+
|
|
69
|
+
`events_df` is a pandas DataFrame with the following columns:
|
|
70
|
+
|
|
71
|
+
- `patient_id` — patient identifier (any hashable type)
|
|
72
|
+
- `timestamp` — event time (datetime or ISO string; coerced via `pd.to_datetime`)
|
|
73
|
+
- `event_text` — free-text event description (e.g. "Patient admitted with chest pain")
|
|
74
|
+
- `cluster_id` — integer event cluster (optional; auto-assigned via sentence-transformers + KMeans if omitted)
|
|
75
|
+
|
|
76
|
+
Example:
|
|
77
|
+
|
|
78
|
+
| patient_id | timestamp | event_text | cluster_id |
|
|
79
|
+
|------------|---------------------|-------------------------------------|------------|
|
|
80
|
+
| P001 | 2024-01-15 09:30 | Patient admitted with chest pain | 3 |
|
|
81
|
+
| P001 | 2024-01-15 11:45 | ECG performed, ST elevation | 7 |
|
|
82
|
+
| P002 | 2024-02-03 14:20 | Routine check-up, vitals normal | 1 |
|
|
83
|
+
|
|
84
|
+
`.predict(patient_events)` returns `(next_event_label, days_until)` for `top_k=1`, or a dict of top-k predictions with confidences when `top_k > 1`.
|
|
85
|
+
|
|
86
|
+
## Architecture
|
|
87
|
+
|
|
88
|
+
Cadence implements the NVC-Clean v14 champion model:
|
|
89
|
+
|
|
90
|
+
- **Feature engineering**: 884-d handcrafted features (population anomaly scores, narrative velocity, temporal-gap statistics, cluster bag-of-words)
|
|
91
|
+
- **Optional**: PubMedBERT embeddings (mean + last token, 1536-d) appended → 2420-d total input
|
|
92
|
+
- **Backbone**: flat-MLP with BatchNorm (Linear 884→1024→1024→512 with residual skip)
|
|
93
|
+
- **Classification head**: Asymmetric Loss (ASL, Ridnik et al. 2021)
|
|
94
|
+
- **Regression head**: quantile-bin softmax expectation for time-to-next-event
|
|
95
|
+
- **Training**: Phase 1 (frozen) + Phase 2 (full), MixUp augmentation, Stochastic Weight Averaging, self-knowledge distillation
|
|
96
|
+
|
|
97
|
+
## Citation
|
|
98
|
+
|
|
99
|
+
Manuscript in preparation; citation forthcoming.
|
|
100
|
+
|
|
101
|
+
## License
|
|
102
|
+
|
|
103
|
+
MIT. Copyright 2026 Amir Rouhollahi.
|
|
104
|
+
|
|
105
|
+
## Links
|
|
106
|
+
|
|
107
|
+
- GitHub: https://github.com/amirrouh/cadence
|
|
108
|
+
- Issues: https://github.com/amirrouh/cadence/issues
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# Cadence
|
|
2
|
+
|
|
3
|
+
Clinical next-event prediction: a flat-MLP with PubMedBERT-enriched features and self-knowledge distillation, trained on EHR event sequences.
|
|
4
|
+
|
|
5
|
+
## Install
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install cadence-core
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Quickstart
|
|
12
|
+
|
|
13
|
+
### Inference with a pretrained model
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from cadence import Cadence
|
|
17
|
+
|
|
18
|
+
model = Cadence.from_pretrained("amirrouh/cadence-mimic-100k")
|
|
19
|
+
next_event, days_until = model.predict(patient_events)
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
### Training on your own data
|
|
23
|
+
|
|
24
|
+
```python
|
|
25
|
+
from cadence import Cadence
|
|
26
|
+
|
|
27
|
+
model = Cadence()
|
|
28
|
+
model.fit(events_df)
|
|
29
|
+
model.save("my-model/")
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Input data format
|
|
33
|
+
|
|
34
|
+
`events_df` is a pandas DataFrame with the following columns:
|
|
35
|
+
|
|
36
|
+
- `patient_id` — patient identifier (any hashable type)
|
|
37
|
+
- `timestamp` — event time (datetime or ISO string; coerced via `pd.to_datetime`)
|
|
38
|
+
- `event_text` — free-text event description (e.g. "Patient admitted with chest pain")
|
|
39
|
+
- `cluster_id` — integer event cluster (optional; auto-assigned via sentence-transformers + KMeans if omitted)
|
|
40
|
+
|
|
41
|
+
Example:
|
|
42
|
+
|
|
43
|
+
| patient_id | timestamp | event_text | cluster_id |
|
|
44
|
+
|------------|---------------------|-------------------------------------|------------|
|
|
45
|
+
| P001 | 2024-01-15 09:30 | Patient admitted with chest pain | 3 |
|
|
46
|
+
| P001 | 2024-01-15 11:45 | ECG performed, ST elevation | 7 |
|
|
47
|
+
| P002 | 2024-02-03 14:20 | Routine check-up, vitals normal | 1 |
|
|
48
|
+
|
|
49
|
+
`.predict(patient_events)` returns `(next_event_label, days_until)` for `top_k=1`, or a dict of top-k predictions with confidences when `top_k > 1`.
|
|
50
|
+
|
|
51
|
+
## Architecture
|
|
52
|
+
|
|
53
|
+
Cadence implements the NVC-Clean v14 champion model:
|
|
54
|
+
|
|
55
|
+
- **Feature engineering**: 884-d handcrafted features (population anomaly scores, narrative velocity, temporal-gap statistics, cluster bag-of-words)
|
|
56
|
+
- **Optional**: PubMedBERT embeddings (mean + last token, 1536-d) appended → 2420-d total input
|
|
57
|
+
- **Backbone**: flat-MLP with BatchNorm (Linear 884→1024→1024→512 with residual skip)
|
|
58
|
+
- **Classification head**: Asymmetric Loss (ASL, Ridnik et al. 2021)
|
|
59
|
+
- **Regression head**: quantile-bin softmax expectation for time-to-next-event
|
|
60
|
+
- **Training**: Phase 1 (frozen) + Phase 2 (full), MixUp augmentation, Stochastic Weight Averaging, self-knowledge distillation
|
|
61
|
+
|
|
62
|
+
## Citation
|
|
63
|
+
|
|
64
|
+
Manuscript in preparation; citation forthcoming.
|
|
65
|
+
|
|
66
|
+
## License
|
|
67
|
+
|
|
68
|
+
MIT. Copyright 2026 Amir Rouhollahi.
|
|
69
|
+
|
|
70
|
+
## Links
|
|
71
|
+
|
|
72
|
+
- GitHub: https://github.com/amirrouh/cadence
|
|
73
|
+
- Issues: https://github.com/amirrouh/cadence/issues
|
|
@@ -0,0 +1,430 @@
|
|
|
1
|
+
"""Cadence: flat-MLP with PubMedBERT-enriched self-distillation for
|
|
2
|
+
clinical next-event prediction.
|
|
3
|
+
|
|
4
|
+
Quick start
|
|
5
|
+
-----------
|
|
6
|
+
Inference with a pretrained model::
|
|
7
|
+
|
|
8
|
+
from cadence import Cadence
|
|
9
|
+
model = Cadence.from_pretrained("amirrouh/cadence-mimic-100k")
|
|
10
|
+
next_event, days_until = model.predict(patient_events)
|
|
11
|
+
|
|
12
|
+
Training on your own data::
|
|
13
|
+
|
|
14
|
+
from cadence import Cadence
|
|
15
|
+
model = Cadence()
|
|
16
|
+
model.fit(events_df)
|
|
17
|
+
model.save("my-model/")
|
|
18
|
+
|
|
19
|
+
See README.md and examples/quickstart.py for a complete walkthrough.
|
|
20
|
+
"""
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import logging
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
import torch
|
|
30
|
+
import torch.nn.functional as F
|
|
31
|
+
|
|
32
|
+
from .config import CadenceConfig
|
|
33
|
+
from .model import NVCFlatMLP
|
|
34
|
+
from .features import (
|
|
35
|
+
build_population_prior,
|
|
36
|
+
build_feature_matrix,
|
|
37
|
+
extract_features,
|
|
38
|
+
LOG_DAYS_CLIP,
|
|
39
|
+
)
|
|
40
|
+
from .data import events_df_to_records, CadenceDataset, validate_events_df
|
|
41
|
+
from .trainer import CadenceTrainer, compute_quantile_bins
|
|
42
|
+
from .pretrained import save_checkpoint, load_checkpoint, download_from_hub
|
|
43
|
+
|
|
44
|
+
__version__ = "0.1.0"
|
|
45
|
+
__all__ = ["Cadence", "CadenceConfig", "__version__"]
|
|
46
|
+
|
|
47
|
+
log = logging.getLogger(__name__)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class Cadence:
|
|
51
|
+
"""High-level API for training, inference, and checkpoint management.
|
|
52
|
+
|
|
53
|
+
Parameters
|
|
54
|
+
----------
|
|
55
|
+
config : CadenceConfig or None
|
|
56
|
+
Hyperparameter configuration. Defaults to ``CadenceConfig()`` (50
|
|
57
|
+
clusters, 884-d features, NVC-Clean v14 champion settings).
|
|
58
|
+
|
|
59
|
+
Examples
|
|
60
|
+
--------
|
|
61
|
+
>>> model = Cadence()
|
|
62
|
+
>>> model.fit(events_df) # trains on your data
|
|
63
|
+
>>> next_event, days = model.predict(patient_df) # single-patient inference
|
|
64
|
+
>>> model.save("my-model/")
|
|
65
|
+
>>> model2 = Cadence.from_pretrained("my-model/")
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
def __init__(self, config: Optional[CadenceConfig] = None) -> None:
|
|
69
|
+
self.config = config or CadenceConfig()
|
|
70
|
+
self._model: Optional[NVCFlatMLP] = None
|
|
71
|
+
self._clusterer = None # CadenceClusterer | None
|
|
72
|
+
self._prior: Optional[dict] = None
|
|
73
|
+
self._bin_centers: Optional[np.ndarray] = None
|
|
74
|
+
self._bin_edges: Optional[np.ndarray] = None
|
|
75
|
+
self._cluster_labels: Optional[dict] = None
|
|
76
|
+
self._device = torch.device(
|
|
77
|
+
"cuda" if torch.cuda.is_available() else "cpu"
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
# ------------------------------------------------------------------
|
|
81
|
+
# Fit
|
|
82
|
+
# ------------------------------------------------------------------
|
|
83
|
+
|
|
84
|
+
def fit(
    self,
    events_df,  # pd.DataFrame
    epochs: Optional[int] = None,
    val_df=None,  # pd.DataFrame | None — if None, 10 % split used
    verbose: bool = True,
) -> "Cadence":
    """Train Cadence on ``events_df``.

    Parameters
    ----------
    events_df : pd.DataFrame
        Columns: ``patient_id``, ``timestamp``, ``event_text``.
        Optional column ``cluster_id`` (skips auto-clustering when present).
    epochs : int or None
        Forwarded unchanged to ``CadenceTrainer.fit``. Per the original
        documentation, ``None`` means
        ``config.phase1_epochs + config.phase2_epochs``.
    val_df : pd.DataFrame or None
        Validation dataframe. When None, 10 % of patients are held out.
    verbose : bool
        When True, ``logging.basicConfig`` is called at INFO level so the
        progress messages logged here become visible.

    Returns
    -------
    self

    Notes
    -----
    Clusters are auto-fitted only when ``events_df`` has no ``cluster_id``
    column *and* no clusterer is attached yet. A clusterer fitted
    beforehand with ``fit_clusters()`` (or restored by
    ``from_pretrained()``) is therefore reused instead of being silently
    replaced.

    ``config.n_features`` is overwritten with the width of the training
    feature matrix.
    """
    if verbose:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s %(levelname)-8s %(message)s",
            datefmt="%H:%M:%S",
        )

    validate_events_df(events_df)
    cfg = self.config

    # ── Fit clusters if needed ────────────────────────────────────────────
    # Respect an explicitly fitted clusterer: only auto-fit when the data
    # carries no cluster ids and nothing has been fitted so far.
    if "cluster_id" not in events_df.columns and self._clusterer is None:
        self._fit_clusters_from_df(events_df)

    # ── Train / val split ─────────────────────────────────────────────────
    if val_df is None:
        events_df, val_df = self._split_patients(events_df, val_frac=0.1)

    # ── Build records ─────────────────────────────────────────────────────
    train_records = events_df_to_records(
        events_df, clusterer=self._clusterer,
        n_clusters=cfg.n_clusters, max_history=cfg.max_history,
    )
    val_records = events_df_to_records(
        val_df, clusterer=self._clusterer,
        n_clusters=cfg.n_clusters, max_history=cfg.max_history,
    )
    log.info(
        "Records: train=%d, val=%d", len(train_records), len(val_records)
    )

    # ── Population prior ──────────────────────────────────────────────────
    # Built from the training records only; the same prior is then used to
    # featurise the validation records.
    self._prior = build_population_prior(train_records, cfg.n_clusters)

    # ── Feature matrices ──────────────────────────────────────────────────
    X_tr, y_cls_tr, y_reg_tr = build_feature_matrix(
        train_records, self._prior, cfg.n_clusters, cfg.max_history
    )
    X_val, y_cls_val, y_reg_val = build_feature_matrix(
        val_records, self._prior, cfg.n_clusters, cfg.max_history
    )
    log.info("Feature matrix: train=%s, val=%s", X_tr.shape, X_val.shape)

    # Actual feature dim may differ from config default (user data)
    n_features = X_tr.shape[1]
    cfg.n_features = n_features

    # ── Quantile bins ─────────────────────────────────────────────────────
    bin_edges, bin_centers = compute_quantile_bins(y_reg_tr, cfg.n_reg_bins)
    self._bin_edges = bin_edges
    self._bin_centers = bin_centers

    # ── DataLoaders ───────────────────────────────────────────────────────
    from torch.utils.data import DataLoader

    train_ds = CadenceDataset(X_tr, y_cls_tr, y_reg_tr)
    val_ds = CadenceDataset(X_val, y_cls_val, y_reg_val)
    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True,
        num_workers=cfg.num_workers, pin_memory=self._device.type == "cuda",
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size * 2, shuffle=False,
        num_workers=cfg.num_workers,
    )

    # ── Build model ───────────────────────────────────────────────────────
    bin_centers_t = torch.tensor(bin_centers, dtype=torch.float32)
    self._model = NVCFlatMLP(
        n_features=n_features,
        n_classes=cfg.n_clusters,
        bin_centers=bin_centers_t,
        config=cfg,
    ).to(self._device)
    log.info(
        "NVCFlatMLP: n_features=%d, n_classes=%d, params=%d",
        n_features, cfg.n_clusters, self._model.n_params,
    )

    # ── Train ─────────────────────────────────────────────────────────────
    trainer = CadenceTrainer(
        model=self._model,
        config=cfg,
        device=self._device,
        bin_edges=bin_edges,
        bin_centers=bin_centers,
    )
    # The trainer returns the model to keep, which replaces the one built above.
    self._model = trainer.fit(train_loader, val_loader, epochs=epochs)
    return self
|
|
199
|
+
|
|
200
|
+
# ------------------------------------------------------------------
|
|
201
|
+
# Predict
|
|
202
|
+
# ------------------------------------------------------------------
|
|
203
|
+
|
|
204
|
+
def predict(
|
|
205
|
+
self,
|
|
206
|
+
patient_events, # pd.DataFrame — single patient, sorted by timestamp
|
|
207
|
+
top_k: int = 1,
|
|
208
|
+
) -> Union[Tuple[str, float], dict]:
|
|
209
|
+
"""Predict the next event and days-until for one patient.
|
|
210
|
+
|
|
211
|
+
Parameters
|
|
212
|
+
----------
|
|
213
|
+
patient_events : pd.DataFrame
|
|
214
|
+
History for a single patient. Same schema as ``events_df``
|
|
215
|
+
(columns: ``patient_id``, ``timestamp``, ``event_text``).
|
|
216
|
+
Must have at least 1 row.
|
|
217
|
+
top_k : int
|
|
218
|
+
When 1, returns ``(event_label, days)``.
|
|
219
|
+
When > 1, returns a dict with ``predictions`` (list of
|
|
220
|
+
``{label, cluster_id, confidence, days}``).
|
|
221
|
+
|
|
222
|
+
Returns
|
|
223
|
+
-------
|
|
224
|
+
(next_event_label, days_until) when top_k=1, else dict.
|
|
225
|
+
"""
|
|
226
|
+
if self._model is None:
|
|
227
|
+
raise RuntimeError(
|
|
228
|
+
"Model is not trained. Call .fit() or .from_pretrained() first."
|
|
229
|
+
)
|
|
230
|
+
if self._prior is None:
|
|
231
|
+
raise RuntimeError(
|
|
232
|
+
"Population prior is missing. The model may not have been "
|
|
233
|
+
"trained with .fit()."
|
|
234
|
+
)
|
|
235
|
+
|
|
236
|
+
validate_events_df(patient_events)
|
|
237
|
+
|
|
238
|
+
# Build record
|
|
239
|
+
records = events_df_to_records(
|
|
240
|
+
patient_events,
|
|
241
|
+
clusterer=self._clusterer,
|
|
242
|
+
n_clusters=self.config.n_clusters,
|
|
243
|
+
max_history=self.config.max_history,
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
if not records:
|
|
247
|
+
raise ValueError(
|
|
248
|
+
"patient_events must have at least 2 rows to form one "
|
|
249
|
+
"prediction example (history + target)."
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
# Use the last record (most recent history window)
|
|
253
|
+
record = records[-1]
|
|
254
|
+
feat = extract_features(
|
|
255
|
+
record, self._prior,
|
|
256
|
+
n_clusters=self.config.n_clusters,
|
|
257
|
+
max_history=self.config.max_history,
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
X = torch.tensor(feat, dtype=torch.float32).unsqueeze(0).to(self._device)
|
|
261
|
+
|
|
262
|
+
self._model.eval()
|
|
263
|
+
with torch.no_grad():
|
|
264
|
+
logits, reg_logits = self._model(X)
|
|
265
|
+
days = self._model.predict_days(reg_logits).item()
|
|
266
|
+
|
|
267
|
+
probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
|
|
268
|
+
|
|
269
|
+
if top_k == 1:
|
|
270
|
+
best_cid = int(probs.argmax())
|
|
271
|
+
label = self._cluster_label(best_cid)
|
|
272
|
+
return label, days
|
|
273
|
+
|
|
274
|
+
# top_k > 1
|
|
275
|
+
top_ids = np.argsort(-probs)[:top_k]
|
|
276
|
+
preds = [
|
|
277
|
+
{
|
|
278
|
+
"label": self._cluster_label(int(cid)),
|
|
279
|
+
"cluster_id": int(cid),
|
|
280
|
+
"confidence": float(probs[cid]),
|
|
281
|
+
"days": days,
|
|
282
|
+
}
|
|
283
|
+
for cid in top_ids
|
|
284
|
+
]
|
|
285
|
+
return {"predictions": preds}
|
|
286
|
+
|
|
287
|
+
def _cluster_label(self, cluster_id: int) -> str:
|
|
288
|
+
if self._cluster_labels and str(cluster_id) in self._cluster_labels:
|
|
289
|
+
return self._cluster_labels[str(cluster_id)]
|
|
290
|
+
if self._cluster_labels and cluster_id in self._cluster_labels:
|
|
291
|
+
return self._cluster_labels[cluster_id]
|
|
292
|
+
return f"cluster_{cluster_id}"
|
|
293
|
+
|
|
294
|
+
# ------------------------------------------------------------------
|
|
295
|
+
# Save / load
|
|
296
|
+
# ------------------------------------------------------------------
|
|
297
|
+
|
|
298
|
+
def save(self, directory: Union[str, Path]) -> None:
|
|
299
|
+
"""Save the model, config, and clusterer to ``directory``.
|
|
300
|
+
|
|
301
|
+
Parameters
|
|
302
|
+
----------
|
|
303
|
+
directory : str | Path
|
|
304
|
+
"""
|
|
305
|
+
if self._model is None:
|
|
306
|
+
raise RuntimeError("No model to save. Call .fit() first.")
|
|
307
|
+
|
|
308
|
+
save_checkpoint(
|
|
309
|
+
model=self._model,
|
|
310
|
+
config=self.config,
|
|
311
|
+
bin_centers=self._bin_centers,
|
|
312
|
+
save_dir=directory,
|
|
313
|
+
clusterer=self._clusterer,
|
|
314
|
+
cluster_labels=self._cluster_labels,
|
|
315
|
+
extra={"prior": self._prior},
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
@classmethod
def from_pretrained(
    cls,
    path_or_repo: Union[str, Path],
    device: Optional[Union[str, torch.device]] = None,
    revision: Optional[str] = None,
) -> "Cadence":
    """Build a ``Cadence`` instance from a saved checkpoint.

    Parameters
    ----------
    path_or_repo : str | Path
        An existing local path is loaded directly. Anything that does not
        exist on disk is treated as a HuggingFace repo ID (e.g.
        ``"amirrouh/cadence-mimic-100k"``) and fetched with
        ``download_from_hub``.
    device : str | torch.device | None
        Forwarded to ``load_checkpoint``. When None, the instance adopts
        the device of the loaded model's first parameter.
    revision : str | None
        HuggingFace revision / tag; only used when downloading.

    Returns
    -------
    Cadence instance populated from the checkpoint.
    """
    checkpoint_dir = Path(path_or_repo)
    if not checkpoint_dir.exists():
        # Not a local path, so resolve it through the HuggingFace Hub.
        checkpoint_dir = download_from_hub(
            str(path_or_repo), revision=revision
        )

    loaded = load_checkpoint(checkpoint_dir, device=device)
    network, config, bin_centers, clusterer, cluster_labels = loaded

    # The population prior is read back from config.json (None if absent).
    raw_config = json.loads((checkpoint_dir / "config.json").read_text())
    prior = raw_config.get("prior", None)

    restored = cls(config=config)
    restored._model = network
    restored._clusterer = clusterer
    restored._bin_centers = bin_centers
    restored._cluster_labels = cluster_labels
    restored._prior = prior
    restored._device = (
        torch.device(device)
        if device is not None
        else next(network.parameters()).device
    )
    return restored
|
|
367
|
+
|
|
368
|
+
# ------------------------------------------------------------------
|
|
369
|
+
# Cluster helpers
|
|
370
|
+
# ------------------------------------------------------------------
|
|
371
|
+
|
|
372
|
+
def fit_clusters(
    self,
    texts: List[str],
    n_clusters: int = 50,
    encoder_model: str = "all-MiniLM-L6-v2",
) -> "Cadence":
    """Fit the event-text clusterer explicitly from raw event strings.

    Use this ahead of ``fit()`` when the clustering step should be
    controlled by hand.

    Parameters
    ----------
    texts : list of str
        Event strings handed to ``CadenceClusterer.fit``.
    n_clusters : int
        Number of clusters; also written to ``config.n_clusters``.
    encoder_model : str
        Encoder name passed to ``CadenceClusterer``.

    Returns
    -------
    self
    """
    from .clustering import CadenceClusterer

    clusterer = CadenceClusterer(
        n_clusters=n_clusters, encoder_model=encoder_model
    )
    self._clusterer = clusterer.fit(texts)

    # Keep the config in step with the cluster count just fitted.
    self.config.n_clusters = n_clusters
    return self
|
|
400
|
+
|
|
401
|
+
# ------------------------------------------------------------------
|
|
402
|
+
# Internal helpers
|
|
403
|
+
# ------------------------------------------------------------------
|
|
404
|
+
|
|
405
|
+
def _fit_clusters_from_df(self, events_df) -> None:
    """Fit ``self._clusterer`` on the distinct, non-null ``event_text`` values.

    The cluster count and encoder name come from ``self.config``
    (``n_clusters`` and ``cluster_encoder``).
    """
    from .clustering import CadenceClusterer

    cfg = self.config
    unique_texts = events_df["event_text"].dropna().unique().tolist()
    log.info(
        "Auto-fitting clusters: %d unique event texts → %d clusters",
        len(unique_texts), cfg.n_clusters,
    )

    clusterer = CadenceClusterer(
        n_clusters=cfg.n_clusters,
        encoder_model=cfg.cluster_encoder,
    )
    self._clusterer = clusterer.fit(unique_texts)
|
|
418
|
+
|
|
419
|
+
@staticmethod
|
|
420
|
+
def _split_patients(df, val_frac: float = 0.1):
|
|
421
|
+
"""Hold out val_frac of patients as the validation set."""
|
|
422
|
+
import pandas as pd
|
|
423
|
+
|
|
424
|
+
patients = np.array(df["patient_id"].unique())
|
|
425
|
+
np.random.shuffle(patients)
|
|
426
|
+
n_val = max(1, int(len(patients) * val_frac))
|
|
427
|
+
val_patients = set(patients[:n_val])
|
|
428
|
+
train_df = df[~df["patient_id"].isin(val_patients)].copy()
|
|
429
|
+
val_df = df[df["patient_id"].isin(val_patients)].copy()
|
|
430
|
+
return train_df, val_df
|