cadence-core 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,40 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ jobs:
9
+ build-and-publish:
10
+ runs-on: ubuntu-latest
11
+ permissions:
12
+ contents: read
13
+
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+
17
+ - name: Set up Python
18
+ uses: actions/setup-python@v5
19
+ with:
20
+ python-version: "3.11"
21
+
22
+ - name: Install uv
23
+ uses: astral-sh/setup-uv@v4
24
+ with:
25
+ version: "latest"
26
+
27
+ - name: Install build dependencies
28
+ run: uv pip install --system hatchling build twine
29
+
30
+ - name: Build package
31
+ run: python -m build
32
+
33
+ - name: Check distribution
34
+ run: twine check dist/*
35
+
36
+ - name: Publish to PyPI
37
+ env:
38
+ TWINE_USERNAME: __token__
39
+ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
40
+ run: twine upload dist/*
@@ -0,0 +1,25 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .venv/
6
+ venv/
7
+ .uv/
8
+ build/
9
+ dist/
10
+ *.egg-info/
11
+ .eggs/
12
+ .pytest_cache/
13
+ .ruff_cache/
14
+ .mypy_cache/
15
+ *.pt
16
+ *.bin
17
+ *.pth
18
+ *.ckpt
19
+ *.pkl
20
+ local.env
21
+ .env
22
+ *.log
23
+ .DS_Store
24
+ Thumbs.db
25
+ uv.lock
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Amir Rouhollahi
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,108 @@
1
+ Metadata-Version: 2.4
2
+ Name: cadence-core
3
+ Version: 0.1.0
4
+ Summary: Flat-MLP with PubMedBERT-enriched self-distillation for clinical next-event prediction
5
+ Project-URL: Homepage, https://github.com/amirrouh/cadence
6
+ Project-URL: Repository, https://github.com/amirrouh/cadence
7
+ Project-URL: Issues, https://github.com/amirrouh/cadence/issues
8
+ Author-email: Amir Rouhollahi <arouhollahi@bwh.harvard.edu>
9
+ License: MIT
10
+ License-File: LICENSE
11
+ Keywords: clinical,ehr,healthcare-ml,next-event-prediction,pubmedbert
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
20
+ Requires-Python: >=3.10
21
+ Requires-Dist: huggingface-hub>=0.23
22
+ Requires-Dist: numpy>=1.24
23
+ Requires-Dist: pandas>=2.0
24
+ Requires-Dist: scikit-learn>=1.3
25
+ Requires-Dist: sentence-transformers>=2.7
26
+ Requires-Dist: torch>=2.1
27
+ Requires-Dist: tqdm>=4.66
28
+ Requires-Dist: transformers>=4.40
29
+ Provides-Extra: dev
30
+ Requires-Dist: build; extra == 'dev'
31
+ Requires-Dist: pytest>=7; extra == 'dev'
32
+ Requires-Dist: ruff>=0.5; extra == 'dev'
33
+ Requires-Dist: twine; extra == 'dev'
34
+ Description-Content-Type: text/markdown
35
+
36
+ # Cadence
37
+
38
+ Clinical next-event prediction: a flat-MLP with PubMedBERT-enriched features and self-knowledge distillation, trained on EHR event sequences.
39
+
40
+ ## Install
41
+
42
+ ```bash
43
+ pip install cadence-core
44
+ ```
45
+
46
+ ## Quickstart
47
+
48
+ ### Inference with a pretrained model
49
+
50
+ ```python
51
+ from cadence import Cadence
52
+
53
+ model = Cadence.from_pretrained("amirrouh/cadence-mimic-100k")
54
+ next_event, days_until = model.predict(patient_events)
55
+ ```
56
+
57
+ ### Training on your own data
58
+
59
+ ```python
60
+ from cadence import Cadence
61
+
62
+ model = Cadence()
63
+ model.fit(events_df)
64
+ model.save("my-model/")
65
+ ```
66
+
67
+ ## Input data format
68
+
69
+ `events_df` is a pandas DataFrame with the following columns:
70
+
71
+ - `patient_id` — patient identifier (any hashable type)
72
+ - `timestamp` — event time (datetime or ISO string; coerced via `pd.to_datetime`)
73
+ - `event_text` — free-text event description (e.g. "Patient admitted with chest pain")
74
+ - `cluster_id` — integer event cluster (optional; auto-assigned via sentence-transformers + KMeans if omitted)
75
+
76
+ Example:
77
+
78
+ | patient_id | timestamp | event_text | cluster_id |
79
+ |------------|---------------------|-------------------------------------|------------|
80
+ | P001 | 2024-01-15 09:30 | Patient admitted with chest pain | 3 |
81
+ | P001 | 2024-01-15 11:45 | ECG performed, ST elevation | 7 |
82
+ | P002 | 2024-02-03 14:20 | Routine check-up, vitals normal | 1 |
83
+
84
+ `.predict(patient_events)` returns `(next_event_label, days_until)` for `top_k=1`, or a dict of top-k predictions with confidences when `top_k > 1`.
85
+
86
+ ## Architecture
87
+
88
+ Cadence implements the NVC-Clean v14 champion model:
89
+
90
+ - **Feature engineering**: 884-d handcrafted features (population anomaly scores, narrative velocity, temporal-gap statistics, cluster bag-of-words)
91
+ - **Optional**: PubMedBERT embeddings (mean + last token, 1536-d) appended → 2420-d total input
92
+ - **Backbone**: flat-MLP with BatchNorm (Linear 884→1024→1024→512 with residual skip)
93
+ - **Classification head**: Asymmetric Loss (ASL, Ridnik et al. 2021)
94
+ - **Regression head**: quantile-bin softmax expectation for time-to-next-event
95
+ - **Training**: Phase 1 (frozen) + Phase 2 (full), MixUp augmentation, Stochastic Weight Averaging, self-knowledge distillation
96
+
97
+ ## Citation
98
+
99
+ Manuscript in preparation; citation forthcoming.
100
+
101
+ ## License
102
+
103
+ MIT. Copyright 2026 Amir Rouhollahi.
104
+
105
+ ## Links
106
+
107
+ - GitHub: https://github.com/amirrouh/cadence
108
+ - Issues: https://github.com/amirrouh/cadence/issues
@@ -0,0 +1,73 @@
1
+ # Cadence
2
+
3
+ Clinical next-event prediction: a flat-MLP with PubMedBERT-enriched features and self-knowledge distillation, trained on EHR event sequences.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install cadence-core
9
+ ```
10
+
11
+ ## Quickstart
12
+
13
+ ### Inference with a pretrained model
14
+
15
+ ```python
16
+ from cadence import Cadence
17
+
18
+ model = Cadence.from_pretrained("amirrouh/cadence-mimic-100k")
19
+ next_event, days_until = model.predict(patient_events)
20
+ ```
21
+
22
+ ### Training on your own data
23
+
24
+ ```python
25
+ from cadence import Cadence
26
+
27
+ model = Cadence()
28
+ model.fit(events_df)
29
+ model.save("my-model/")
30
+ ```
31
+
32
+ ## Input data format
33
+
34
+ `events_df` is a pandas DataFrame with the following columns:
35
+
36
+ - `patient_id` — patient identifier (any hashable type)
37
+ - `timestamp` — event time (datetime or ISO string; coerced via `pd.to_datetime`)
38
+ - `event_text` — free-text event description (e.g. "Patient admitted with chest pain")
39
+ - `cluster_id` — integer event cluster (optional; auto-assigned via sentence-transformers + KMeans if omitted)
40
+
41
+ Example:
42
+
43
+ | patient_id | timestamp | event_text | cluster_id |
44
+ |------------|---------------------|-------------------------------------|------------|
45
+ | P001 | 2024-01-15 09:30 | Patient admitted with chest pain | 3 |
46
+ | P001 | 2024-01-15 11:45 | ECG performed, ST elevation | 7 |
47
+ | P002 | 2024-02-03 14:20 | Routine check-up, vitals normal | 1 |
48
+
49
+ `.predict(patient_events)` returns `(next_event_label, days_until)` for `top_k=1`, or a dict of top-k predictions with confidences when `top_k > 1`.
50
+
51
+ ## Architecture
52
+
53
+ Cadence implements the NVC-Clean v14 champion model:
54
+
55
+ - **Feature engineering**: 884-d handcrafted features (population anomaly scores, narrative velocity, temporal-gap statistics, cluster bag-of-words)
56
+ - **Optional**: PubMedBERT embeddings (mean + last token, 1536-d) appended → 2420-d total input
57
+ - **Backbone**: flat-MLP with BatchNorm (Linear 884→1024→1024→512 with residual skip)
58
+ - **Classification head**: Asymmetric Loss (ASL, Ridnik et al. 2021)
59
+ - **Regression head**: quantile-bin softmax expectation for time-to-next-event
60
+ - **Training**: Phase 1 (frozen) + Phase 2 (full), MixUp augmentation, Stochastic Weight Averaging, self-knowledge distillation
61
+
62
+ ## Citation
63
+
64
+ Manuscript in preparation; citation forthcoming.
65
+
66
+ ## License
67
+
68
+ MIT. Copyright 2026 Amir Rouhollahi.
69
+
70
+ ## Links
71
+
72
+ - GitHub: https://github.com/amirrouh/cadence
73
+ - Issues: https://github.com/amirrouh/cadence/issues
@@ -0,0 +1,430 @@
1
+ """Cadence: flat-MLP with PubMedBERT-enriched self-distillation for
2
+ clinical next-event prediction.
3
+
4
+ Quick start
5
+ -----------
6
+ Inference with a pretrained model::
7
+
8
+ from cadence import Cadence
9
+ model = Cadence.from_pretrained("amirrouh/cadence-mimic-100k")
10
+ next_event, days_until = model.predict(patient_events)
11
+
12
+ Training on your own data::
13
+
14
+ from cadence import Cadence
15
+ model = Cadence()
16
+ model.fit(events_df)
17
+ model.save("my-model/")
18
+
19
+ See README.md and examples/quickstart.py for a complete walkthrough.
20
+ """
21
+ from __future__ import annotations
22
+
23
+ import json
24
+ import logging
25
+ from pathlib import Path
26
+ from typing import Dict, List, Optional, Tuple, Union
27
+
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+
32
+ from .config import CadenceConfig
33
+ from .model import NVCFlatMLP
34
+ from .features import (
35
+ build_population_prior,
36
+ build_feature_matrix,
37
+ extract_features,
38
+ LOG_DAYS_CLIP,
39
+ )
40
+ from .data import events_df_to_records, CadenceDataset, validate_events_df
41
+ from .trainer import CadenceTrainer, compute_quantile_bins
42
+ from .pretrained import save_checkpoint, load_checkpoint, download_from_hub
43
+
44
+ __version__ = "0.1.0"
45
+ __all__ = ["Cadence", "CadenceConfig", "__version__"]
46
+
47
+ log = logging.getLogger(__name__)
48
+
49
+
50
+ class Cadence:
51
+ """High-level API for training, inference, and checkpoint management.
52
+
53
+ Parameters
54
+ ----------
55
+ config : CadenceConfig or None
56
+ Hyperparameter configuration. Defaults to ``CadenceConfig()`` (50
57
+ clusters, 884-d features, NVC-Clean v14 champion settings).
58
+
59
+ Examples
60
+ --------
61
+ >>> model = Cadence()
62
+ >>> model.fit(events_df) # trains on your data
63
+ >>> next_event, days = model.predict(patient_df) # single-patient inference
64
+ >>> model.save("my-model/")
65
+ >>> model2 = Cadence.from_pretrained("my-model/")
66
+ """
67
+
68
+ def __init__(self, config: Optional[CadenceConfig] = None) -> None:
69
+ self.config = config or CadenceConfig()
70
+ self._model: Optional[NVCFlatMLP] = None
71
+ self._clusterer = None # CadenceClusterer | None
72
+ self._prior: Optional[dict] = None
73
+ self._bin_centers: Optional[np.ndarray] = None
74
+ self._bin_edges: Optional[np.ndarray] = None
75
+ self._cluster_labels: Optional[dict] = None
76
+ self._device = torch.device(
77
+ "cuda" if torch.cuda.is_available() else "cpu"
78
+ )
79
+
80
+ # ------------------------------------------------------------------
81
+ # Fit
82
+ # ------------------------------------------------------------------
83
+
84
    def fit(
        self,
        events_df,  # pd.DataFrame
        epochs: Optional[int] = None,
        val_df=None,  # pd.DataFrame | None — if None, 10 % split used
        verbose: bool = True,
    ) -> "Cadence":
        """Train Cadence on ``events_df``.

        Pipeline: validate input → (optionally) auto-fit text clusters →
        patient-level train/val split → build records → population prior →
        feature matrices → quantile regression bins → DataLoaders → build
        ``NVCFlatMLP`` → train via ``CadenceTrainer``.

        Parameters
        ----------
        events_df : pd.DataFrame
            Columns: ``patient_id``, ``timestamp``, ``event_text``.
            Optional column ``cluster_id`` (skips auto-clustering when present).
        epochs : int or None
            Total training epochs. Defaults to
            ``config.phase1_epochs + config.phase2_epochs``.
        val_df : pd.DataFrame or None
            Validation dataframe. When None, 10 % of patients are held out.
        verbose : bool
            Whether to log training progress.

        Returns
        -------
        self
        """
        if verbose:
            # NOTE(review): configures the *root* logger from library code,
            # which affects the host application's logging — confirm intended.
            logging.basicConfig(
                level=logging.INFO,
                format="%(asctime)s %(levelname)-8s %(message)s",
                datefmt="%H:%M:%S",
            )

        validate_events_df(events_df)
        cfg = self.config

        # ── Fit clusters if needed ────────────────────────────────────────────
        # Auto-cluster only when the caller did not supply cluster_id.
        if "cluster_id" not in events_df.columns:
            self._fit_clusters_from_df(events_df)

        # ── Train / val split ─────────────────────────────────────────────────
        # Split by patient (not by row) so no patient leaks across the split.
        if val_df is None:
            events_df, val_df = self._split_patients(events_df, val_frac=0.1)

        # ── Build records ─────────────────────────────────────────────────────
        train_records = events_df_to_records(
            events_df, clusterer=self._clusterer,
            n_clusters=cfg.n_clusters, max_history=cfg.max_history,
        )
        val_records = events_df_to_records(
            val_df, clusterer=self._clusterer,
            n_clusters=cfg.n_clusters, max_history=cfg.max_history,
        )
        log.info(
            "Records: train=%d, val=%d", len(train_records), len(val_records)
        )

        # ── Population prior ──────────────────────────────────────────────────
        # Prior is built from TRAIN records only, then reused for both splits
        # and kept on the instance for predict()/save().
        self._prior = build_population_prior(train_records, cfg.n_clusters)

        # ── Feature matrices ──────────────────────────────────────────────────
        X_tr, y_cls_tr, y_reg_tr = build_feature_matrix(
            train_records, self._prior, cfg.n_clusters, cfg.max_history
        )
        X_val, y_cls_val, y_reg_val = build_feature_matrix(
            val_records, self._prior, cfg.n_clusters, cfg.max_history
        )
        log.info("Feature matrix: train=%s, val=%s", X_tr.shape, X_val.shape)

        # Actual feature dim may differ from config default (user data).
        # The config object is mutated in place — presumably so that save()
        # records the true input width; TODO confirm against save_checkpoint.
        n_features = X_tr.shape[1]
        cfg.n_features = n_features

        # ── Quantile bins ─────────────────────────────────────────────────────
        # Regression targets are discretized into quantile bins computed from
        # the training targets only.
        bin_edges, bin_centers = compute_quantile_bins(y_reg_tr, cfg.n_reg_bins)
        self._bin_edges = bin_edges
        self._bin_centers = bin_centers

        # ── DataLoaders ───────────────────────────────────────────────────────
        from torch.utils.data import DataLoader

        train_ds = CadenceDataset(X_tr, y_cls_tr, y_reg_tr)
        val_ds = CadenceDataset(X_val, y_cls_val, y_reg_val)
        train_loader = DataLoader(
            train_ds, batch_size=cfg.batch_size, shuffle=True,
            num_workers=cfg.num_workers, pin_memory=self._device.type == "cuda",
        )
        # Validation can use a larger batch (no gradients kept) and no shuffle.
        val_loader = DataLoader(
            val_ds, batch_size=cfg.batch_size * 2, shuffle=False,
            num_workers=cfg.num_workers,
        )

        # ── Build model ───────────────────────────────────────────────────────
        bin_centers_t = torch.tensor(bin_centers, dtype=torch.float32)
        self._model = NVCFlatMLP(
            n_features=n_features,
            n_classes=cfg.n_clusters,
            bin_centers=bin_centers_t,
            config=cfg,
        ).to(self._device)
        log.info(
            "NVCFlatMLP: n_features=%d, n_classes=%d, params=%d",
            n_features, cfg.n_clusters, self._model.n_params,
        )

        # ── Train ─────────────────────────────────────────────────────────────
        trainer = CadenceTrainer(
            model=self._model,
            config=cfg,
            device=self._device,
            bin_edges=bin_edges,
            bin_centers=bin_centers,
        )
        # trainer.fit returns the (possibly SWA-averaged) final model.
        self._model = trainer.fit(train_loader, val_loader, epochs=epochs)
        return self
199
+
200
+ # ------------------------------------------------------------------
201
+ # Predict
202
+ # ------------------------------------------------------------------
203
+
204
    def predict(
        self,
        patient_events,  # pd.DataFrame — single patient, sorted by timestamp
        top_k: int = 1,
    ) -> Union[Tuple[str, float], dict]:
        """Predict the next event and days-until for one patient.

        Parameters
        ----------
        patient_events : pd.DataFrame
            History for a single patient. Same schema as ``events_df``
            (columns: ``patient_id``, ``timestamp``, ``event_text``).
            Must have at least 2 rows — one history event plus one target
            event are needed to form a prediction example (see the
            ValueError below).
        top_k : int
            When 1, returns ``(event_label, days)``.
            When > 1, returns a dict with ``predictions`` (list of
            ``{label, cluster_id, confidence, days}``).

        Returns
        -------
        (next_event_label, days_until) when top_k=1, else dict.

        Raises
        ------
        RuntimeError
            If the model or population prior is not available.
        ValueError
            If ``patient_events`` yields no prediction records.
        """
        if self._model is None:
            raise RuntimeError(
                "Model is not trained. Call .fit() or .from_pretrained() first."
            )
        if self._prior is None:
            raise RuntimeError(
                "Population prior is missing. The model may not have been "
                "trained with .fit()."
            )

        validate_events_df(patient_events)

        # Build record(s) from the patient's history window.
        records = events_df_to_records(
            patient_events,
            clusterer=self._clusterer,
            n_clusters=self.config.n_clusters,
            max_history=self.config.max_history,
        )

        if not records:
            raise ValueError(
                "patient_events must have at least 2 rows to form one "
                "prediction example (history + target)."
            )

        # Use the last record (most recent history window).
        record = records[-1]
        feat = extract_features(
            record, self._prior,
            n_clusters=self.config.n_clusters,
            max_history=self.config.max_history,
        )

        # Single-example batch on the model's device.
        X = torch.tensor(feat, dtype=torch.float32).unsqueeze(0).to(self._device)

        self._model.eval()
        with torch.no_grad():
            logits, reg_logits = self._model(X)
            days = self._model.predict_days(reg_logits).item()

        probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()

        if top_k == 1:
            best_cid = int(probs.argmax())
            label = self._cluster_label(best_cid)
            return label, days

        # top_k > 1: descending-probability order. Note the same ``days``
        # value (one regression head output) is attached to every candidate.
        top_ids = np.argsort(-probs)[:top_k]
        preds = [
            {
                "label": self._cluster_label(int(cid)),
                "cluster_id": int(cid),
                "confidence": float(probs[cid]),
                "days": days,
            }
            for cid in top_ids
        ]
        return {"predictions": preds}
286
+
287
+ def _cluster_label(self, cluster_id: int) -> str:
288
+ if self._cluster_labels and str(cluster_id) in self._cluster_labels:
289
+ return self._cluster_labels[str(cluster_id)]
290
+ if self._cluster_labels and cluster_id in self._cluster_labels:
291
+ return self._cluster_labels[cluster_id]
292
+ return f"cluster_{cluster_id}"
293
+
294
+ # ------------------------------------------------------------------
295
+ # Save / load
296
+ # ------------------------------------------------------------------
297
+
298
+ def save(self, directory: Union[str, Path]) -> None:
299
+ """Save the model, config, and clusterer to ``directory``.
300
+
301
+ Parameters
302
+ ----------
303
+ directory : str | Path
304
+ """
305
+ if self._model is None:
306
+ raise RuntimeError("No model to save. Call .fit() first.")
307
+
308
+ save_checkpoint(
309
+ model=self._model,
310
+ config=self.config,
311
+ bin_centers=self._bin_centers,
312
+ save_dir=directory,
313
+ clusterer=self._clusterer,
314
+ cluster_labels=self._cluster_labels,
315
+ extra={"prior": self._prior},
316
+ )
317
+
318
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        revision: Optional[str] = None,
    ) -> "Cadence":
        """Load a Cadence model from a local directory or HuggingFace Hub.

        Parameters
        ----------
        path_or_repo : str | Path
            Local directory path OR HuggingFace repo ID (e.g.
            ``"amirrouh/cadence-mimic-100k"``).
        device : str | torch.device | None
        revision : str | None
            HuggingFace revision / tag (ignored for local paths).

        Returns
        -------
        Cadence instance, ready for inference.
        """
        local_path = Path(path_or_repo)

        if not local_path.exists():
            # Not a local directory — treat the string as a Hub repo ID.
            local_path = download_from_hub(
                str(path_or_repo), revision=revision
            )

        model_obj, config, bin_centers, clusterer, cluster_labels = load_checkpoint(
            local_path, device=device
        )

        # Restore population prior if saved in config.json.
        # NOTE(review): assumes save_checkpoint merges its ``extra`` payload
        # into config.json under the "prior" key — verify against
        # pretrained.save_checkpoint; raises FileNotFoundError if
        # config.json is absent.
        cfg_dict = json.loads((local_path / "config.json").read_text())
        prior = cfg_dict.get("prior", None)

        instance = cls(config=config)
        instance._model = model_obj
        instance._clusterer = clusterer
        instance._bin_centers = bin_centers
        instance._cluster_labels = cluster_labels
        instance._prior = prior
        # NOTE(review): _bin_edges is not restored here and stays None on a
        # loaded instance; predict() does not read it, but confirm nothing
        # else does.
        if device is not None:
            instance._device = torch.device(device)
        else:
            # Fall back to wherever load_checkpoint placed the parameters.
            instance._device = next(model_obj.parameters()).device
        return instance
367
+
368
+ # ------------------------------------------------------------------
369
+ # Cluster helpers
370
+ # ------------------------------------------------------------------
371
+
372
+ def fit_clusters(
373
+ self,
374
+ texts: List[str],
375
+ n_clusters: int = 50,
376
+ encoder_model: str = "all-MiniLM-L6-v2",
377
+ ) -> "Cadence":
378
+ """Fit event-text clusters from a list of raw event strings.
379
+
380
+ Call this before ``fit()`` if you want to control the cluster
381
+ fitting step explicitly.
382
+
383
+ Parameters
384
+ ----------
385
+ texts : list of str
386
+ n_clusters : int
387
+ encoder_model : str
388
+
389
+ Returns
390
+ -------
391
+ self
392
+ """
393
+ from .clustering import CadenceClusterer
394
+
395
+ self._clusterer = CadenceClusterer(
396
+ n_clusters=n_clusters, encoder_model=encoder_model
397
+ ).fit(texts)
398
+ self.config.n_clusters = n_clusters
399
+ return self
400
+
401
+ # ------------------------------------------------------------------
402
+ # Internal helpers
403
+ # ------------------------------------------------------------------
404
+
405
+ def _fit_clusters_from_df(self, events_df) -> None:
406
+ """Auto-fit clusters from unique event texts in events_df."""
407
+ from .clustering import CadenceClusterer
408
+
409
+ texts = events_df["event_text"].dropna().unique().tolist()
410
+ log.info(
411
+ "Auto-fitting clusters: %d unique event texts → %d clusters",
412
+ len(texts), self.config.n_clusters,
413
+ )
414
+ self._clusterer = CadenceClusterer(
415
+ n_clusters=self.config.n_clusters,
416
+ encoder_model=self.config.cluster_encoder,
417
+ ).fit(texts)
418
+
419
+ @staticmethod
420
+ def _split_patients(df, val_frac: float = 0.1):
421
+ """Hold out val_frac of patients as the validation set."""
422
+ import pandas as pd
423
+
424
+ patients = np.array(df["patient_id"].unique())
425
+ np.random.shuffle(patients)
426
+ n_val = max(1, int(len(patients) * val_frac))
427
+ val_patients = set(patients[:n_val])
428
+ train_df = df[~df["patient_id"].isin(val_patients)].copy()
429
+ val_df = df[df["patient_id"].isin(val_patients)].copy()
430
+ return train_df, val_df