surgicalplan 0.1.0__tar.gz
- surgicalplan-0.1.0/PKG-INFO +405 -0
- surgicalplan-0.1.0/README.md +348 -0
- surgicalplan-0.1.0/license.txt +7 -0
- surgicalplan-0.1.0/setup.cfg +4 -0
- surgicalplan-0.1.0/setup.py +70 -0
- surgicalplan-0.1.0/surgicalplan/DirectInference/__init__.py +9 -0
- surgicalplan-0.1.0/surgicalplan/DirectInference/direct_inference.py +82 -0
- surgicalplan-0.1.0/surgicalplan/JointFinetuning/__init__.py +13 -0
- surgicalplan-0.1.0/surgicalplan/JointFinetuning/joint_finetuning.py +317 -0
- surgicalplan-0.1.0/surgicalplan/JointFinetuning/model.py +163 -0
- surgicalplan-0.1.0/surgicalplan/JointFinetuning/trainer.py +76 -0
- surgicalplan-0.1.0/surgicalplan/MultiTaskFinetuning/MultiTaskLearningPrediction.py +503 -0
- surgicalplan-0.1.0/surgicalplan/MultiTaskFinetuning/__init__.py +9 -0
- surgicalplan-0.1.0/surgicalplan/MultiTaskFinetuning/model.py +156 -0
- surgicalplan-0.1.0/surgicalplan/MultiTaskFinetuning/trainer.py +74 -0
- surgicalplan-0.1.0/surgicalplan/__init__.py +22 -0
- surgicalplan-0.1.0/surgicalplan.egg-info/PKG-INFO +405 -0
- surgicalplan-0.1.0/surgicalplan.egg-info/SOURCES.txt +19 -0
- surgicalplan-0.1.0/surgicalplan.egg-info/dependency_links.txt +1 -0
- surgicalplan-0.1.0/surgicalplan.egg-info/requires.txt +21 -0
- surgicalplan-0.1.0/surgicalplan.egg-info/top_level.txt +1 -0
Metadata-Version: 2.4
Name: surgicalplan
Version: 0.1.0
Summary: SurgicalPLAN is a Python package for predicting postoperative risks from clinical notes using language models. It provides training and inference workflows for fine-tuned models, semi-supervised methods, and multi-task prediction of multiple surgical outcomes. The package is intended for clinical research and educational use, notably for the American College of Surgeons.
Home-page: https://github.com/cja5553/ACS_demo_postoperative_risk_prediction_with_clinical_notes
Author: Charles Alba
Author-email: alba@wustl.edu
License: MIT
Keywords: multi-task learning,NLP,clinical,perioperative,BERT,Bio_ClinicalBERT
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Healthcare Industry
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
Requires-Python: >=3.9,<3.13
Description-Content-Type: text/markdown
License-File: license.txt
Requires-Dist: transformers<5,>=4.36
Requires-Dist: tokenizers>=0.15
Requires-Dist: huggingface_hub>=0.20
Requires-Dist: accelerate>=0.25
Requires-Dist: datasets>=2.14
Requires-Dist: safetensors>=0.4
Requires-Dist: torch>=2.0
Requires-Dist: numpy>=1.23
Requires-Dist: pandas>=2.0
Requires-Dist: pyarrow>=14
Requires-Dist: tqdm>=4.66
Provides-Extra: classifiers
Requires-Dist: scikit-learn>=1.3; extra == "classifiers"
Requires-Dist: xgboost>=1.7; extra == "classifiers"
Provides-Extra: dev
Requires-Dist: pytest>=7; extra == "dev"
Requires-Dist: jupyter; extra == "dev"
Requires-Dist: ipykernel; extra == "dev"
Requires-Dist: ipywidgets>=8; extra == "dev"
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: home-page
Dynamic: keywords
Dynamic: license
Dynamic: license-file
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

## Overview

SurgicalPLAN (*Surgical **P**ostoperative Risk Prediction with **L**anguage Models **A**dapting to Clinical **N**otes*) is a Python package for predicting postoperative risks from clinical notes using language models. It provides flexible and clinically oriented workflows that support a range of perioperative use cases, enabling clinicians, researchers, and healthcare institutions to train and fine-tune models using preoperative or intraoperative clinical text.

The package is designed to be accessible to a broad range of users, including clinicians, surgeons, and researchers with limited programming experience. It minimizes the need to interact with lower-level machine learning frameworks such as PyTorch: with just a few high-level function calls, users can begin training and fine-tuning their own models.

SurgicalPLAN supports multiple modeling strategies, including:

1. **Direct inference** with fine-tuned language models
2. **(Joint) Semi-supervised learning** approaches for leveraging partially labeled data
3. A **multi-task learning framework** that enables simultaneous prediction of multiple postoperative outcomes

The package was developed for the American College of Surgeons (ACS) workshop, *AI for Clinicians and Surgeons: A Hands-On Introduction Across the Care Continuum*.

The accompanying work is:
[*The foundational capabilities of large language models in predicting postoperative risks using clinical notes*](https://www.nature.com/articles/s41746-025-01489-2)
Alba, Xue, Abraham, Kannampallil, and Lu (2025), *npj Digital Medicine*

---

## Installation

```bash
pip install surgicalplan
```

PyPI does not host every CUDA build of `torch`, so first install the PyTorch build that matches your GPU's CUDA version, then install this package. For example, on a machine with CUDA 11.8 drivers:

```bash
pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu118
pip install surgicalplan
```

**Python version**: 3.9–3.12 (tested on 3.12).

---

## Quick example

```python
import pandas as pd
from surgicalplan import mtl_finetune, get_postoperative_outcome_scores

df = pd.read_csv("my_clinical_data.csv")
# df columns: "text", "Outcome_1", "Outcome_2", "Outcome_3", "Outcome_4"

# 1. Fine-tune one multi-task model on all four outcomes
mtl_finetune(
    df,
    text_col="text",
    outcome_cols=["Outcome_1", "Outcome_2", "Outcome_3", "Outcome_4"],
    output_dir="my_finetuned_model",
)

# 2. Score a new scenario
note_1 = (
    "83-year-old male, ASA 4, scheduled for coronary artery bypass graft (emergent three-vessel). "
    "Indication: severe CAD with LAD stenosis, presenting with unstable angina. "
    "PMH: COPD, type 2 diabetes mellitus, coronary artery disease, prior MI, chronic kidney disease stage 3. "
    "Social: current smoker, 1 pack per day. "
    "BMI 34 (obese). "
    "Home medications: metoprolol, aspirin 81 mg, atorvastatin, insulin glargine, furosemide. "
    "Allergies: NKDA. "
    "Preop labs within acceptable limits. Consent obtained, plan to proceed."
)

scores = get_postoperative_outcome_scores(
    "my_finetuned_model",
    note_1,
)
# e.g. {'Outcome_1': 0.12, 'Outcome_2': 0.28, 'Outcome_3': 0.04, 'Outcome_4': 0.39}
```

---

## API reference

### Direct inference

Direct inference lets users apply out-of-the-box models that have already been trained on clinical notes and their associated postoperative outcomes. Unlike the fine-tuning workflows below, it loads a pre-trained, ready-to-use model from the HuggingFace Hub and therefore requires no model training.

The default model is `cja5553/BJH-perioperative-notes-bioClinicalBERT`, our Bio+ClinicalBERT variant multi-task fine-tuned on six postoperative outcomes: (1) 30-day mortality, (2) DVT, (3) PE, (4) AKI, (5) delirium, and (6) pneumonia. This model was used in our [accompanying *npj Digital Medicine* paper](https://www.nature.com/articles/s41746-025-01489-2).

#### `direct_inference_from_trained_model`

Score clinical text against a pre-trained multi-task model without any fine-tuning step. The model is downloaded from HuggingFace Hub on first use and cached locally thereafter.

**Example**

```python
from surgicalplan import direct_inference_from_trained_model

note = (
    "Redo coronary artery bypass graft with aortic valve replacement "
    "bioprosthetic. Indication: severe ischemic cardiomyopathy, "
    "ejection fraction 25 percent, prior MI, ventricular arrhythmia "
    "status post AICD placement, stage 3 chronic kidney disease, COPD."
)

scores = direct_inference_from_trained_model(text=note)
# {'DVT': 0.17, 'PE': 0.06, 'PNA': 0.28, 'postop_del': 0.81,
#  'death_in_30': 0.46, 'post_aki_status': 0.93}
```

**Parameters**

- `text` (`str | list[str]`, *required*): One clinical scenario, or a list of them. Determines the shape of the return value.
- `outcomes` (`list[str] | None`, default: `None`): Which outcomes to score. Defaults to all outcomes the default model was trained on (`DVT`, `PE`, `PNA`, `postop_del`, `death_in_30`, `post_aki_status`), recovered from the model's `mtl_metadata.json`. Pass a subset to score only some.
- `model_name` (`str`, default: `"cja5553/BJH-perioperative-notes-bioClinicalBERT"`): HuggingFace repo ID or local path. Override to use your own fine-tuned model.
- `max_length` (`int | None`, default: `None`): Token sequence length. Defaults to the value used during fine-tuning, recovered from metadata.
- `device` (`str | None`, default: `None`): `"cuda"`, `"cpu"`, or `None` to auto-detect.
- `hf_token` (`str | None`, default: `None`): Optional HuggingFace token, required only if the model repo is gated/private.

**Returns**

- `dict[str, float]` when `text` is a string: maps each outcome name to a probability in `[0, 1]`.
- `list[dict[str, float]]` when `text` is a list: one dict per input, in the same order.
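Because each outcome has its own binary head, the scores are independent probabilities and need not sum to 1 (the six values in the example above add up to well over 1). The sketch below shows how per-head logits could be mapped to the two documented return shapes. It assumes, consistent with the per-outcome BCE loss described for fine-tuning, that each head's logit passes through a sigmoid. `score_logits` is a hypothetical stand-in for the model forward pass, and this is not the package's internal code.

```python
import math
from typing import Callable, Dict, List, Sequence, Union


def sigmoid(x: float) -> float:
    """Numerically stable logistic function mapping a logit to (0, 1)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def to_scores(
    text: Union[str, List[str]],
    outcomes: Sequence[str],
    score_logits: Callable[[str], Sequence[float]],  # hypothetical: one raw logit per outcome
) -> Union[Dict[str, float], List[Dict[str, float]]]:
    """Apply an independent sigmoid per outcome head and mirror the input shape."""

    def one(note: str) -> Dict[str, float]:
        logits = score_logits(note)
        return {name: sigmoid(z) for name, z in zip(outcomes, logits)}

    if isinstance(text, str):
        return one(text)  # single note -> one dict
    return [one(t) for t in text]  # list of notes -> list of dicts, same order
```

Passing a list rather than looping over single strings mirrors how the documented functions accept batches.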

**Notes**

- First call downloads the model (~440 MB) from HuggingFace and caches it locally; subsequent calls use the cache.
- Inference runs on CPU in ~5 seconds per note, or ~0.5 seconds with a GPU.
- For users who want to fine-tune their own model, see `mtl_finetune` (multi-outcome) or `joint_finetune` (single-outcome).

---

### Joint or semi-supervised finetuning

Joint single-outcome fine-tuning trains a separate model for each postoperative outcome of interest. The model jointly learns the structure of your clinical notes (through masked language modeling) while learning to predict the outcome, so it captures both the linguistic patterns of your institution's documentation style and the clinical features that drive your specific outcome. Unlike `MultiTaskLearningPrediction` below, this workflow targets a single outcome rather than several.

#### `JointFinetuning`

Perform joint (or semi-supervised) fine-tuning with `joint_finetune`.

**Example**

```python
import pandas as pd
from surgicalplan import joint_finetune

df = pd.read_csv("my_clinical_data.csv")  # must contain "clinical_notes" and "DVT" columns

joint_finetune(
    df,
    text_col="clinical_notes",
    outcome_col="DVT",
    output_dir="DVT_model",
    training_configs={
        "num_train_epochs": 3,
        "per_device_train_batch_size": 16,
        "evaluation_strategy": "steps",  # renamed "eval_strategy" in newer transformers releases
        "eval_steps": 100,
        "logging_steps": 100,
        "learning_rate": 2e-5,
    },
)
```

`joint_finetune` fine-tunes Bio+ClinicalBERT on masked language modeling (MLM) jointly with a single binary classification head for one outcome.
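The joint objective adds a class-weighted binary cross-entropy on the outcome head to the MLM loss, as described under `lambda_constant` and `weight` in the parameters. Below is a plain-Python sketch of that arithmetic, assuming the MLM loss has already been computed. It follows the formula PyTorch documents for `BCEWithLogitsLoss` with `pos_weight` and mean reduction, but it is not the package's implementation. `pos_weight_from_labels` is a hypothetical helper showing where a value like the `torch.tensor([20.0])` suggested for ~5% prevalence comes from: the negative-to-positive ratio, which is 19 at exactly 5%.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pos_weight_from_labels(labels: Sequence[int]) -> float:
    """Hypothetical helper: ratio of negative to positive labels."""
    positives = sum(labels)
    negatives = len(labels) - positives
    if positives == 0:
        raise ValueError("no positive labels; pos_weight is undefined")
    return negatives / positives


def weighted_bce_with_logits(logits: Sequence[float], labels: Sequence[int],
                             pos_weight: float = 1.0) -> float:
    """Mean of -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    total = 0.0
    for x, y in zip(logits, labels):
        # log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
        total += pos_weight * y * softplus(-x) + (1 - y) * softplus(x)
    return total / len(logits)


def joint_loss(mlm_loss: float, logits: Sequence[float], labels: Sequence[int],
               lambda_constant: float = 2.0, pos_weight: float = 1.0) -> float:
    """Total loss = MLM + lambda * BCE, matching the documented formula."""
    return mlm_loss + lambda_constant * weighted_bce_with_logits(logits, labels, pos_weight)
```

A larger `lambda_constant` tilts training toward the outcome head, and `pos_weight` raises the cost of missing a positive case only, without changing the loss on negatives.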

**Parameters**

- `df` (`pandas.DataFrame`, *required*): Must contain `text_col` and `outcome_col`.
- `text_col` (`str`, *required*): Name of the free-text column.
- `outcome_col` (`str`, *required*): Name of a single binary (0/1) outcome column. Rows with NaN in this column are dropped before training.
- `output_dir` (`str`, default `"joint_finetuned"`): Directory to save the fine-tuned model, tokenizer, and metadata. Also used as the HuggingFace Trainer `output_dir` for checkpoints and logs.
- `base_model` (`str`, default `"emilyalsentzer/Bio_ClinicalBERT"`): HuggingFace model id to start from. Any BERT-architecture model should work.
- `hf_token` (`str | None`, default `None`): Optional HuggingFace token for gated/private base models. If `None`, uses the cached CLI login when present.
- `max_length` (`int`, default `512`): Token sequence length for tokenization.
- `lambda_constant` (`float`, default `2`): Weight on the auxiliary (BCE) loss relative to MLM loss. Total loss = MLM + λ · BCE.
- `mlm_probability` (`float`, default `0.15`): Token masking probability for MLM.
- `val_fraction` (`float`, default `1/8`): Fraction of `df` held out for validation during training.
- `weight` (`torch.Tensor | None`, default `None`): Optional `pos_weight` for `BCEWithLogitsLoss` to handle class imbalance. Useful for rare outcomes (e.g., `torch.tensor([20.0])` for ~5% positive prevalence).
- `training_configs` (`dict | None`, default `None`): Any keyword arguments accepted by `transformers.TrainingArguments`. User-provided values override the defaults below. Default `training_configs` is `{"num_train_epochs": 5, "per_device_train_batch_size": 24, "per_device_eval_batch_size": 24, "learning_rate": 1e-5, "warmup_ratio": 0.06, "weight_decay": 1e-3, "logging_steps": 1000, "save_strategy": "epoch", "seed": 42, "report_to": "none"}`.
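The override rule for `training_configs` amounts to a shallow dictionary merge in which user keys win. A minimal sketch, assuming the defaults listed above; the helper name `merge_training_configs` is hypothetical. The final hand-off to `transformers.TrainingArguments` appears only as a comment so the sketch runs without `transformers` installed.

```python
from typing import Any, Dict, Optional

# Defaults as documented above for joint_finetune's training_configs.
JOINT_DEFAULTS: Dict[str, Any] = {
    "num_train_epochs": 5,
    "per_device_train_batch_size": 24,
    "per_device_eval_batch_size": 24,
    "learning_rate": 1e-5,
    "warmup_ratio": 0.06,
    "weight_decay": 1e-3,
    "logging_steps": 1000,
    "save_strategy": "epoch",
    "seed": 42,
    "report_to": "none",
}


def merge_training_configs(user: Optional[Dict[str, Any]],
                           defaults: Dict[str, Any] = JOINT_DEFAULTS) -> Dict[str, Any]:
    """Return a new dict: defaults first, then user-provided keys override them."""
    return {**defaults, **(user or {})}


merged = merge_training_configs({"num_train_epochs": 3, "learning_rate": 2e-5, "eval_steps": 100})
# Next step in a real run (requires transformers):
# args = transformers.TrainingArguments(output_dir="DVT_model", **merged)
```

Keys you do not pass keep their documented defaults, so a partial dict such as `{"num_train_epochs": 3}` is enough to shorten training.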

**Returns**

`str`: the `output_dir` path. After training, this directory contains:

- `pytorch_model.bin` (or `model.safetensors`): model weights
- `config.json`: model architecture config
- `tokenizer.json`, `vocab.txt`, `tokenizer_config.json`, `special_tokens_map.json`: tokenizer
- `joint_metadata.json`: records `outcome_col`, `text_col`, `max_length`, `base_model`, `lambda_constant`, `num_tasks` (always 1), and `workflow` so inference can recover them automatically
- `checkpoint-*`: per-epoch training checkpoints (can be deleted after training)
- `logs/`: TensorBoard-compatible training logs

---

#### `get_outcome_score`

Score a text scenario (or list of scenarios) against the single auxiliary head of a joint-finetuned model.

**Example**

```python
from surgicalplan import get_outcome_score

prob = get_outcome_score(
    model_name="DVT_model",
    text="83-year-old male, ASA 4, scheduled for CABG. PMH: COPD, diabetes.",
)
# prob is a float in [0, 1]
```

**Parameters**

- `model_name` (`str`, *required*): Path to a directory saved by `joint_finetune`.
- `text` (`str | list[str]`, *required*): One scenario string, or a list of them. Determines the shape of the return value.
- `max_length` (`int | None`, default: `None`): Token sequence length. Defaults to the value used during fine-tuning, recovered from `joint_metadata.json`, otherwise `512`.
- `device` (`str | None`, default: `None`): `"cuda"`, `"cpu"`, or `None` to auto-detect.
- `hf_token` (`str | None`, default: `None`): Optional HuggingFace token for gated/private models.
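The `max_length` fallback order reads as: an explicit argument, then the value saved in `joint_metadata.json`, then `512`. A sketch of that lookup, assuming the metadata file stores `max_length` as a top-level JSON key (it is listed among the recorded fields under `joint_finetune`'s **Returns**). The helper `resolve_max_length` is hypothetical, and only local directories are handled here.

```python
import json
import os
from typing import Optional

DEFAULT_MAX_LENGTH = 512


def resolve_max_length(model_dir: str, max_length: Optional[int] = None,
                       metadata_file: str = "joint_metadata.json") -> int:
    """Explicit argument wins; else the value saved at training time; else 512."""
    if max_length is not None:
        return max_length
    path = os.path.join(model_dir, metadata_file)
    if os.path.isfile(path):
        with open(path, encoding="utf-8") as fh:
            saved = json.load(fh).get("max_length")
        if saved is not None:
            return int(saved)
    return DEFAULT_MAX_LENGTH
```

Reusing the training-time length keeps inference truncation consistent with what the model saw during fine-tuning.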

**Returns**

- `float` when `text` is a string: the predicted probability for the trained outcome, in `[0, 1]`.
- `list[float]` when `text` is a list: one probability per input, in the same order.

---

### Multi-task finetuning

Multi-task learning (MTL) lets you train a single versatile model that predicts multiple postoperative outcomes from the same clinical notes. Traditional fine-tuning requires one model per outcome; MTL instead produces one model that predicts several risks simultaneously, analogous to a foundation model.

#### `MultiTaskLearningPrediction`

Perform multi-task (MTL) fine-tuning with `mtl_finetune`.

**Example**

```python
import pandas as pd
from surgicalplan import mtl_finetune

df = pd.read_csv("my_clinical_data.csv")  # must contain "clinical_notes" and every outcome column

mtl_finetune(
    df,
    text_col="clinical_notes",
    outcome_cols=["death_30d", "dvt", "pneumonia", "aki", "AUR", "PE"],
    output_dir="my_run",
    training_configs={
        "num_train_epochs": 3,
        "per_device_train_batch_size": 16,
        "evaluation_strategy": "steps",  # renamed "eval_strategy" in newer transformers releases
        "eval_steps": 100,
        "logging_steps": 100,
        "learning_rate": 2e-5,
    },
)
```

`mtl_finetune` fine-tunes Bio+ClinicalBERT on MLM jointly with one binary classification head per outcome.

**Parameters**

- `df` (`pandas.DataFrame`, *required*): Must contain `text_col` and all `outcome_cols`.
- `text_col` (`str`, *required*): Name of the free-text column.
- `outcome_cols` (`list[str]`, *required*): Names of binary (0/1) outcome columns. One auxiliary head is trained per outcome. Rows with NaN in a given outcome are dropped for that outcome's task but used for the others.
- `output_dir` (`str`, default `"mtl_finetuned"`): Directory to save the fine-tuned model, tokenizer, and metadata. Also used as the HuggingFace Trainer `output_dir` for checkpoints and logs.
- `base_model` (`str`, default `"emilyalsentzer/Bio_ClinicalBERT"`): HuggingFace model id to start from. Any BERT-architecture model should work.
- `max_length` (`int`, default `512`): Token sequence length for tokenization.
- `lambda_constant` (`float`, default `2`): Weight on the auxiliary (per-outcome BCE) loss relative to MLM loss. Total loss = MLM + λ · mean(per-task BCE).
- `val_fraction` (`float`, default `1/8`): Fraction of `df` held out for validation during training.
- `training_configs` (`dict | None`, default `None`): Any keyword arguments accepted by `transformers.TrainingArguments`. User-provided values override the defaults below. Default `training_configs` is `{"num_train_epochs": 5, "per_device_train_batch_size": 24, "per_device_eval_batch_size": 24, "learning_rate": 1e-5, "warmup_ratio": 0.06, "weight_decay": 1e-3, "logging_steps": 1000, "save_strategy": "epoch", "seed": 42}`.
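The multi-task objective is MLM + λ · mean(per-task BCE), where a row with a missing label for one outcome is excluded from that task only. Below is a plain-Python sketch of that masking. It assumes that missing labels arrive as `float('nan')` and that each task's BCE is averaged over its own labeled rows before the tasks are averaged. How the package treats a task with no labeled rows in a batch is not documented; in this sketch, such a task is simply skipped.

```python
import math
from typing import List, Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def masked_task_bce(logits: Sequence[float], labels: Sequence[float]) -> float:
    """Mean BCE-with-logits over rows whose label is not NaN (NaN = unlabeled)."""
    terms = [
        y * softplus(-x) + (1 - y) * softplus(x)
        for x, y in zip(logits, labels)
        if not math.isnan(y)
    ]
    return sum(terms) / len(terms) if terms else float("nan")


def mtl_loss(mlm_loss: float, logits_per_task: List[Sequence[float]],
             labels_per_task: List[Sequence[float]], lambda_constant: float = 2.0) -> float:
    """Total loss = MLM + lambda * mean over tasks of each task's masked BCE."""
    task_losses = [masked_task_bce(lg, lb) for lg, lb in zip(logits_per_task, labels_per_task)]
    task_losses = [t for t in task_losses if not math.isnan(t)]  # skip tasks with no labels
    if not task_losses:
        return mlm_loss
    return mlm_loss + lambda_constant * sum(task_losses) / len(task_losses)
```

Masking per task rather than dropping whole rows is what lets a note with only some outcomes recorded still contribute to the other heads and to the MLM loss.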

**Returns**

`str`: the `output_dir` path. After training, this directory contains:

- `pytorch_model.bin` (or `model.safetensors`): model weights
- `config.json`: model architecture config
- `tokenizer.json`, `vocab.txt`, `tokenizer_config.json`, `special_tokens_map.json`: tokenizer
- `mtl_metadata.json`: records `outcome_cols`, `text_col`, `max_length`, `base_model`, `lambda_constant`, and `num_tasks` so inference can recover them automatically
- `checkpoint-*`: per-epoch training checkpoints (can be deleted after training)
- `logs/`: TensorBoard-compatible training logs

---

#### `get_postoperative_outcome_scores`

Score a text scenario (or list of scenarios) against each auxiliary head of a fine-tuned MTL model.

**Example**

```python
from surgicalplan import get_postoperative_outcome_scores

scores = get_postoperative_outcome_scores(
    model_name="my_run",  # directory written by mtl_finetune
    text="83-year-old male, ASA 4, scheduled for CABG. PMH: COPD, diabetes.",
    outcomes=["death_30d", "dvt", "pneumonia", "aki", "AUR", "PE"],
)
```

**Parameters**

- `model_name` (`str`, *required*): Path to a directory saved by `mtl_finetune`.
- `text` (`str | list[str]`, *required*): One scenario string, or a list of them. Determines the shape of the return value.
- `outcomes` (`list[str] | None`, default: `None`): Which outcomes to score. Defaults to all outcomes the model was trained on, recovered from `mtl_metadata.json`. Pass a subset to score only some. Names must match those used in `mtl_finetune`.
- `max_length` (`int | None`, default: `None`): Token sequence length. Defaults to the value used during fine-tuning, recovered from metadata, otherwise `512`.
- `device` (`str | None`, default: `None`): `"cuda"`, `"cpu"`, or `None` to auto-detect.

**Returns**

- `dict[str, float]` when `text` is a string: maps each outcome name to a probability in `[0, 1]`.
- `list[dict[str, float]]` when `text` is a list: one dict per input, in the same order.
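When `outcomes` is a subset, only those heads are reported, and the names must match the `outcome_cols` recorded at training. The sketch below illustrates that selection step. It assumes head order follows the saved `outcome_cols` list. `select_outcomes` and `subset_scores` are hypothetical helpers, and raising `ValueError` on an unknown name is an assumption; the package's behavior in that case is not documented here.

```python
from typing import Dict, List, Optional, Sequence


def select_outcomes(trained: Sequence[str], requested: Optional[Sequence[str]] = None) -> List[int]:
    """Map requested outcome names to head indices; None means every trained head."""
    if requested is None:
        return list(range(len(trained)))
    index = {name: i for i, name in enumerate(trained)}
    unknown = [name for name in requested if name not in index]
    if unknown:
        raise ValueError(f"unknown outcomes {unknown}; model was trained on {list(trained)}")
    return [index[name] for name in requested]


def subset_scores(all_probs: Sequence[float], trained: Sequence[str],
                  requested: Optional[Sequence[str]] = None) -> Dict[str, float]:
    """Keep only the requested heads, in the order the caller asked for them."""
    return {trained[i]: all_probs[i] for i in select_outcomes(trained, requested)}
```

Name matching in this sketch is case-sensitive, so `"DVT"` would not match a model trained with `"dvt"`.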

---

#### `get_pseudo_data`

Generate a small synthetic dataset of preoperative clinical notes with binary outcomes for testing and demonstration. Outcomes are not random: each is driven by realistic feature combinations in the note (procedure type, age, ASA class, comorbidities), so a fine-tuned model is expected to learn meaningful associations.

**Example**

```python
from surgicalplan import get_pseudo_data

df = get_pseudo_data()
print(df.shape)  # (1000, 5)
print(df.columns.tolist())  # ['text', 'Outcome_1', 'Outcome_2', 'Outcome_3', 'Outcome_4']
```

**Parameters**

None.

**Returns**

`pandas.DataFrame` with 1000 rows and 5 columns:

- `text` (`str`): synthetic preoperative note.
- `Outcome_1` to `Outcome_4` (`int`, 0/1): binary outcomes driven by clinical features in the note.
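To illustrate what feature-driven rather than random outcomes means, here is a toy generator in the same spirit. It is not `get_pseudo_data`'s actual code, and the note template, feature effects, and risk formula are all hypothetical. Each outcome is drawn from a logistic risk score built from features written into the note text, which is what makes the label learnable from the text alone.

```python
import math
import random
from typing import Dict, List


def toy_pseudo_data(n: int = 1000, seed: int = 0) -> List[Dict[str, object]]:
    """Hypothetical generator: outcomes depend on age, ASA class, procedure, and COPD."""
    rng = random.Random(seed)
    procedures = {"CABG": 1.0, "colectomy": 0.6, "hip arthroplasty": 0.4, "hernia repair": -0.5}
    rows = []
    for _ in range(n):
        age = rng.randint(25, 90)
        asa = rng.randint(1, 4)
        proc = rng.choice(list(procedures))
        copd = rng.random() < 0.2
        text = (f"{age}-year-old patient, ASA {asa}, scheduled for {proc}. "
                f"PMH: {'COPD' if copd else 'none significant'}.")
        # Shared risk signal derived only from features that appear in the note.
        base = 0.04 * (age - 60) + 0.7 * (asa - 2) + procedures[proc]
        row: Dict[str, object] = {"text": text}
        # Each outcome weights the features slightly differently (illustrative only).
        for k, extra in enumerate([0.0, 1.2 * copd, 0.5 * (proc == "CABG"), -0.3], start=1):
            risk = 1.0 / (1.0 + math.exp(-(base + extra - 2.0)))
            row[f"Outcome_{k}"] = int(rng.random() < risk)
        rows.append(row)
    return rows
```

The toy version returns a list of dicts; `pandas.DataFrame(toy_pseudo_data())` would give the same five-column layout as the documented function.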

---