cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,6 @@ from cehrbert_data.const.artificial_tokens import (
|
|
21
21
|
DISCHARGE_UNKNOWN_TOKEN,
|
22
22
|
GENDER_UNKNOWN_TOKEN,
|
23
23
|
RACE_UNKNOWN_TOKEN,
|
24
|
-
VISIT_UNKNOWN_TOKEN,
|
25
24
|
)
|
26
25
|
from cehrbert_data.const.common import NA
|
27
26
|
from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
|
@@ -29,6 +28,12 @@ from datasets.formatting.formatting import LazyBatch
|
|
29
28
|
from dateutil.relativedelta import relativedelta
|
30
29
|
from pandas import Series
|
31
30
|
|
31
|
+
from cehrgpt.gpt_utils import (
|
32
|
+
construct_age_sequence,
|
33
|
+
construct_time_sequence,
|
34
|
+
encode_demographics,
|
35
|
+
multiple_of_10,
|
36
|
+
)
|
32
37
|
from cehrgpt.models.tokenization_hf_cehrgpt import (
|
33
38
|
NONE_BIN,
|
34
39
|
UNKNOWN_BIN,
|
@@ -44,13 +49,20 @@ CEHRGPT_COLUMNS = [
|
|
44
49
|
"concept_values",
|
45
50
|
"units",
|
46
51
|
"epoch_times",
|
52
|
+
"ages",
|
47
53
|
]
|
48
54
|
|
49
55
|
|
50
|
-
def convert_date_to_posix_time(index_date: datetime.date) -> float:
|
51
|
-
|
52
|
-
|
53
|
-
|
56
|
+
def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
    """Convert a date or datetime to POSIX seconds, interpreting it as UTC.

    Args:
        index_date: A ``datetime.datetime``, a ``datetime.date``, or a value
            that is already numeric.

    Returns:
        Seconds since the Unix epoch. A ``datetime.datetime`` keeps its time of
        day; a plain ``datetime.date`` is taken at midnight. Any other value is
        returned unchanged.
    """
    # datetime.datetime is a subclass of datetime.date, so it must be tested
    # first; otherwise the date branch would swallow datetimes and truncate
    # them to midnight.
    if isinstance(index_date, datetime.datetime):
        # replace() stamps UTC onto the value without shifting the clock time,
        # so an existing tzinfo is overwritten rather than converted.
        return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
    if isinstance(index_date, datetime.date):
        midnight = datetime.datetime.combine(index_date, datetime.datetime.min.time())
        return midnight.replace(tzinfo=datetime.timezone.utc).timestamp()
    # Anything else is passed through untouched
    return index_date
|
54
66
|
|
55
67
|
|
56
68
|
class DatasetMappingDecorator(DatasetMapping):
|
@@ -116,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
116
128
|
cehrgpt_record: Dict[str, Any],
|
117
129
|
code: str,
|
118
130
|
time: datetime.datetime,
|
131
|
+
age: int,
|
119
132
|
concept_value_mask: int = 0,
|
120
133
|
number_as_value: float = 0.0,
|
121
134
|
concept_as_value: str = "0",
|
@@ -123,17 +136,21 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
123
136
|
unit: str = NA,
|
124
137
|
) -> None:
|
125
138
|
cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
|
139
|
+
cehrgpt_record["ages"].append(age)
|
126
140
|
cehrgpt_record["concept_value_masks"].append(concept_value_mask)
|
127
141
|
cehrgpt_record["number_as_values"].append(number_as_value)
|
128
142
|
cehrgpt_record["concept_as_values"].append(concept_as_value)
|
129
143
|
cehrgpt_record["units"].append(unit)
|
130
144
|
cehrgpt_record["is_numeric_types"].append(is_numeric_type)
|
131
|
-
cehrgpt_record["epoch_times"].append(
|
145
|
+
cehrgpt_record["epoch_times"].append(
|
146
|
+
time.replace(tzinfo=datetime.timezone.utc).timestamp()
|
147
|
+
)
|
132
148
|
|
133
149
|
def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
|
134
150
|
cehrgpt_record = {
|
135
151
|
"person_id": record["patient_id"],
|
136
152
|
"concept_ids": [],
|
153
|
+
"ages": [],
|
137
154
|
"concept_value_masks": [],
|
138
155
|
"number_as_values": [],
|
139
156
|
"concept_as_values": [],
|
@@ -161,14 +178,21 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
161
178
|
first_visit_start_datetime: datetime.datetime = get_value(
|
162
179
|
first_visit, "visit_start_datetime"
|
163
180
|
)
|
181
|
+
starting_age = relativedelta(first_visit_start_datetime, birth_datetime).years
|
164
182
|
year_str = f"year:{str(first_visit_start_datetime.year)}"
|
165
|
-
age_str = f"age:{
|
183
|
+
age_str = f"age:{starting_age}"
|
184
|
+
self._update_cehrgpt_record(
|
185
|
+
cehrgpt_record, year_str, first_visit_start_datetime, starting_age
|
186
|
+
)
|
187
|
+
self._update_cehrgpt_record(
|
188
|
+
cehrgpt_record, age_str, first_visit_start_datetime, starting_age
|
189
|
+
)
|
190
|
+
self._update_cehrgpt_record(
|
191
|
+
cehrgpt_record, gender, first_visit_start_datetime, starting_age
|
192
|
+
)
|
166
193
|
self._update_cehrgpt_record(
|
167
|
-
cehrgpt_record,
|
194
|
+
cehrgpt_record, race, first_visit_start_datetime, starting_age
|
168
195
|
)
|
169
|
-
self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
|
170
|
-
self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
|
171
|
-
self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
|
172
196
|
|
173
197
|
# Use a data cursor to keep track of time
|
174
198
|
datetime_cursor: Optional[datetime.datetime] = None
|
@@ -204,6 +228,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
204
228
|
cehrgpt_record,
|
205
229
|
code=self._time_token_function(time_delta),
|
206
230
|
time=visit_start_datetime,
|
231
|
+
age=relativedelta(datetime_cursor, birth_datetime).years,
|
207
232
|
)
|
208
233
|
|
209
234
|
datetime_cursor = visit_start_datetime
|
@@ -212,12 +237,14 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
212
237
|
cehrgpt_record,
|
213
238
|
code="[VS]",
|
214
239
|
time=datetime_cursor,
|
240
|
+
age=relativedelta(datetime_cursor, birth_datetime).years,
|
215
241
|
)
|
216
242
|
# Add a visit type token
|
217
243
|
self._update_cehrgpt_record(
|
218
244
|
cehrgpt_record,
|
219
245
|
code=visit_type,
|
220
246
|
time=datetime_cursor,
|
247
|
+
age=relativedelta(datetime_cursor, birth_datetime).years,
|
221
248
|
)
|
222
249
|
# We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
|
223
250
|
# with respect to the midnight of the day
|
@@ -228,6 +255,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
228
255
|
cehrgpt_record,
|
229
256
|
code=f"i-H{datetime_cursor.hour}",
|
230
257
|
time=datetime_cursor,
|
258
|
+
age=relativedelta(datetime_cursor, birth_datetime).years,
|
231
259
|
)
|
232
260
|
|
233
261
|
# Keep track of the existing outpatient events, we don't want to add them again
|
@@ -274,6 +302,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
274
302
|
cehrgpt_record,
|
275
303
|
code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
|
276
304
|
time=event_time,
|
305
|
+
age=relativedelta(event_time, birth_datetime).years,
|
277
306
|
)
|
278
307
|
|
279
308
|
if self._include_inpatient_hour_token:
|
@@ -293,6 +322,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
293
322
|
cehrgpt_record,
|
294
323
|
code=f"i-H{time_diff_hours}",
|
295
324
|
time=event_time,
|
325
|
+
age=relativedelta(event_time, birth_datetime).years,
|
296
326
|
)
|
297
327
|
|
298
328
|
if event_identity in existing_duplicate_events:
|
@@ -302,6 +332,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
302
332
|
cehrgpt_record,
|
303
333
|
code=code,
|
304
334
|
time=event_time,
|
335
|
+
age=relativedelta(event_time, birth_datetime).years,
|
305
336
|
concept_value_mask=concept_value_mask,
|
306
337
|
unit=unit,
|
307
338
|
number_as_value=numeric_value if numeric_value else 0.0,
|
@@ -341,6 +372,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
341
372
|
cehrgpt_record,
|
342
373
|
code=discharge_facility,
|
343
374
|
time=datetime_cursor,
|
375
|
+
age=relativedelta(datetime_cursor, birth_datetime).years,
|
344
376
|
)
|
345
377
|
|
346
378
|
# Reuse the age and date calculated for the last event in the patient timeline
|
@@ -348,6 +380,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
348
380
|
cehrgpt_record,
|
349
381
|
code="[VE]",
|
350
382
|
time=datetime_cursor,
|
383
|
+
age=relativedelta(datetime_cursor, birth_datetime).years,
|
351
384
|
)
|
352
385
|
|
353
386
|
# Generate the orders of the concepts that the cehrbert dataset mapping function expects
|
@@ -360,7 +393,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
|
|
360
393
|
cehrgpt_record["num_of_visits"] = len(visits)
|
361
394
|
|
362
395
|
if record.get("index_date", None) is not None:
|
363
|
-
cehrgpt_record["index_date"] =
|
396
|
+
cehrgpt_record["index_date"] = (
|
397
|
+
record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
|
398
|
+
)
|
364
399
|
if record.get("label", None) is not None:
|
365
400
|
cehrgpt_record["label"] = record["label"]
|
366
401
|
if record.get("age_at_index", None) is not None:
|
@@ -419,6 +454,13 @@ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
|
|
419
454
|
return record
|
420
455
|
|
421
456
|
def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
|
457
|
+
# Reconstruct the ages input before the filter is applied
|
458
|
+
record["ages"] = construct_age_sequence(
|
459
|
+
record["concept_ids"], record.get("ages", None)
|
460
|
+
)
|
461
|
+
record["epoch_times"] = construct_time_sequence(
|
462
|
+
record["concept_ids"], record.get("epoch_times", None)
|
463
|
+
)
|
422
464
|
# Remove the tokens from patient sequences that do not exist in the tokenizer
|
423
465
|
record = self.filter_out_invalid_tokens(record)
|
424
466
|
# If any concept has a value associated with it, we normalize the value
|
@@ -529,9 +571,13 @@ class ExtractTokenizedSequenceDataMapping:
|
|
529
571
|
prediction_start_end_times = [
|
530
572
|
(
|
531
573
|
self._calculate_prediction_start_time(
|
532
|
-
prediction_time_label_map["index_date"]
|
574
|
+
prediction_time_label_map["index_date"]
|
575
|
+
.replace(tzinfo=datetime.timezone.utc)
|
576
|
+
.timestamp()
|
533
577
|
),
|
534
|
-
prediction_time_label_map["index_date"]
|
578
|
+
prediction_time_label_map["index_date"]
|
579
|
+
.replace(tzinfo=datetime.timezone.utc)
|
580
|
+
.timestamp(),
|
535
581
|
prediction_time_label_map["label"],
|
536
582
|
)
|
537
583
|
for prediction_time_label_map in prediction_times
|
@@ -0,0 +1,316 @@
|
|
1
|
+
import datetime
|
2
|
+
import os
|
3
|
+
import random
|
4
|
+
import shutil
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Dict
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
import polars as pl
|
10
|
+
import torch
|
11
|
+
import torch.distributed as dist
|
12
|
+
from cehrbert.runners.runner_util import generate_prepared_ds_path
|
13
|
+
from datasets import load_from_disk
|
14
|
+
from meds import held_out_split, train_split, tuning_split
|
15
|
+
from torch.utils.data import DataLoader
|
16
|
+
from tqdm import tqdm
|
17
|
+
from transformers.trainer_utils import is_main_process
|
18
|
+
from transformers.utils import is_flash_attn_2_available, logging
|
19
|
+
|
20
|
+
from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
|
21
|
+
from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
|
22
|
+
from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
|
23
|
+
generate_single_batch,
|
24
|
+
normalize_value,
|
25
|
+
)
|
26
|
+
from cehrgpt.gpt_utils import (
|
27
|
+
extract_time_interval_in_days,
|
28
|
+
extract_time_interval_in_hours,
|
29
|
+
is_att_token,
|
30
|
+
is_inpatient_hour_token,
|
31
|
+
is_visit_end,
|
32
|
+
is_visit_start,
|
33
|
+
)
|
34
|
+
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
35
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
|
36
|
+
from cehrgpt.runners.data_utils import (
|
37
|
+
extract_cohort_sequences,
|
38
|
+
prepare_finetune_dataset,
|
39
|
+
)
|
40
|
+
from cehrgpt.runners.gpt_runner_util import parse_runner_args
|
41
|
+
from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
|
42
|
+
|
43
|
+
LOG = logging.get_logger("transformers")
|
44
|
+
|
45
|
+
|
46
|
+
def map_data_split_name(split: str) -> str:
    """Return the MEDS split constant that corresponds to a dataset split name.

    Args:
        split: One of ``"train"``, ``"validation"`` or ``"test"``.

    Returns:
        ``train_split``, ``tuning_split`` or ``held_out_split`` respectively.

    Raises:
        ValueError: If ``split`` is none of the three recognised names.
    """
    known_splits = (
        ("train", train_split),
        ("validation", tuning_split),
        ("test", held_out_split),
    )
    for split_name, meds_name in known_splits:
        if split == split_name:
            return meds_name
    raise ValueError(f"Unknown split: {split}")
|
54
|
+
|
55
|
+
|
56
|
+
def seed_all(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with one value.

    Args:
        seed: The seed handed to every generator; it is also exported as the
            ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,  # Python random
        np.random.seed,  # NumPy
        torch.manual_seed,  # PyTorch CPU
        torch.cuda.manual_seed,  # Current GPU
        torch.cuda.manual_seed_all,  # All GPUs
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Exported through the environment for reproducibility in dataloader workers
    os.environ["PYTHONHASHSEED"] = str(seed)
|
66
|
+
|
67
|
+
|
68
|
+
def _to_naive_utc(epoch_seconds: float) -> datetime.datetime:
    """Convert POSIX seconds to a naive datetime expressed in UTC."""
    # Same result as the deprecated datetime.utcfromtimestamp()
    return datetime.datetime.fromtimestamp(
        epoch_seconds, tz=datetime.timezone.utc
    ).replace(tzinfo=None)


def generate_trajectories_per_batch(
    batch: Dict[str, Any],
    cehrgpt_tokenizer: CehrGptTokenizer,
    cehrgpt_model: CEHRGPT2LMHeadModel,
    device,
    data_output_path: Path,
    max_length: int,
):
    """Generate one continuation per sample in ``batch`` and write it to parquet.

    The prompt tensors in ``batch`` are passed to ``generate_single_batch``.
    For every token generated after the prompt, a timestamp is rebuilt by
    advancing a cursor that starts at the prompt's last epoch time. The
    remaining event tokens are written, one row per event, to
    ``data_output_path``.

    Args:
        batch: Collated batch; the keys read here are ``person_id``,
            ``index_date``, ``epoch_times``, ``input_ids``, ``ages``,
            ``value_indicators`` and ``values``.
        cehrgpt_tokenizer: Tokenizer used for decoding and for the pad/end
            tokens that are dropped from the output.
        cehrgpt_model: Model used for generation.
        device: Device forwarded to ``generate_single_batch``.
        data_output_path: Destination parquet file.
        max_length: Maximum total sequence length forwarded to generation.
    """
    # reshape(-1) instead of squeeze(): squeeze() turns a one-sample batch into
    # a 0-d tensor, whose tolist() is a bare scalar that cannot be indexed.
    subject_ids = batch["person_id"].reshape(-1).detach().cpu().tolist()
    prediction_times = batch["index_date"].reshape(-1).detach().cpu().tolist()
    batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()

    batch_sequences = generate_single_batch(
        cehrgpt_model,
        cehrgpt_tokenizer,
        batch["input_ids"],
        ages=batch["ages"],
        values=batch["values"],
        value_indicators=batch["value_indicators"],
        max_length=max_length,
        # top_p=1.0 with top_k equal to the vocabulary size applies no truncation
        top_p=1.0,
        top_k=cehrgpt_tokenizer.vocab_size,
        device=device,
    )
    # Clear the cache
    torch.cuda.empty_cache()

    skipped_tokens = [cehrgpt_tokenizer.pad_token, cehrgpt_tokenizer.end_token]
    trajectories = []
    for sample_i, (concept_ids, _value_indicators, values) in enumerate(
        zip(
            batch_sequences["sequences"],
            batch_sequences["value_indicators"],
            batch_sequences["values"],
        )
    ):
        (
            concept_ids,
            _is_numeric_types,
            number_as_values,
            concept_as_values,
            units,
        ) = normalize_value(concept_ids, values, cehrgpt_tokenizer)

        epoch_times = batched_epoch_times[sample_i]
        input_length = len(epoch_times)
        # The last prompt position carries the last observed event time before
        # the prediction time; the time cursor starts from that same value.
        window_last_observed = epoch_times[-1]
        current_cursor = window_last_observed
        generated_epoch_times = []
        valid_indices = []

        # Only the tokens generated after the prompt are emitted
        for i in range(input_length, len(concept_ids)):
            concept_id = concept_ids[i]
            # We use the left padding strategy in the data collator
            if concept_id in skipped_tokens:
                continue
            # Time tokens advance the cursor and are not emitted as events
            if is_att_token(concept_id):
                current_cursor += extract_time_interval_in_days(concept_id) * 24 * 3600
            elif is_inpatient_hour_token(concept_id):
                current_cursor += extract_time_interval_in_hours(concept_id) * 3600
            elif is_visit_start(concept_id) or is_visit_end(concept_id):
                continue
            else:
                valid_indices.append(i)
                generated_epoch_times.append(_to_naive_utc(current_cursor))

        trajectories.append(
            {
                "subject_id": subject_ids[sample_i],
                "prediction_time": _to_naive_utc(prediction_times[sample_i]),
                "window_last_observed_time": _to_naive_utc(window_last_observed),
                "times": generated_epoch_times,
                "concept_ids": np.asarray(concept_ids)[valid_indices].tolist(),
                "numeric_values": np.asarray(number_as_values)[valid_indices].tolist(),
                "text_value": np.asarray(concept_as_values)[valid_indices].tolist(),
                "units": np.asarray(units)[valid_indices].tolist(),
            }
        )

    # One row per generated event, with the list columns renamed to singular names
    trajectories = (
        pl.DataFrame(trajectories)
        .explode(["times", "concept_ids", "numeric_values", "text_value", "units"])
        .rename(
            {
                "times": "time",
                "concept_ids": "code",
                "numeric_values": "numeric_value",
                "units": "unit",
            }
        )
        .select(
            "subject_id",
            "prediction_time",
            "window_last_observed_time",
            "time",
            "code",
            "numeric_value",
            "text_value",
            "unit",
        )
    )
    trajectories.write_parquet(data_output_path)
|
183
|
+
|
184
|
+
|
185
|
+
def main():
    """Generate trajectories for every split of a cohort and write them as parquet.

    Steps:
        1. Parse the runner arguments and load the tokenizer and model.
        2. Load the prepared dataset from disk if one exists; otherwise the
           main process builds it (and saves it unless streaming).
        3. For each requested trajectory index and each split, run generation
           batch by batch, writing
           ``<output_dir>/<meds split>/<trajectory index>/<batch index>.parquet``.
           Batches whose output file already exists are skipped.

    Raises:
        RuntimeError: If a non-main process has no dataset to load after the
            main process finished preparing it (streaming mode).
    """
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
        model_args.tokenizer_name_or_path
    )
    cehrgpt_model = (
        CEHRGPT2LMHeadModel.from_pretrained(
            model_args.model_name_or_path,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
        )
        .eval()
        .to(device)
    )
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id

    # exist_ok avoids the check-then-create race when several processes start together
    os.makedirs(training_args.output_dir, exist_ok=True)

    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )

    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        if cehrgpt_args.expand_tokenizer:
            if tokenizer_exists(training_args.output_dir):
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
                    training_args.output_dir
                )
            else:
                LOG.warning(
                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
                    f"when the dataset has been processed and expand_tokenizer is set to True. "
                    f"Please delete the processed dataset at {prepared_ds_path}."
                )
                # The cached dataset cannot be used without its tokenizer; discard it and rebuild
                processed_dataset = None
                shutil.rmtree(prepared_ds_path)

    if processed_dataset is None and is_main_process(training_args.local_rank):
        # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
        # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
        if cehrgpt_args.tokenized_full_dataset_path is not None:
            processed_dataset = extract_cohort_sequences(data_args, cehrgpt_args)
        else:
            # Organize them into a single DatasetDict
            final_splits = prepare_finetune_dataset(
                data_args, training_args, cehrgpt_args
            )
            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
            if not data_args.streaming:
                all_columns = final_splits["train"].column_names
                if "visit_concept_ids" in all_columns:
                    final_splits = final_splits.remove_columns(["visit_concept_ids"])

            processed_dataset = create_cehrgpt_finetuning_dataset(
                dataset=final_splits,
                cehrgpt_tokenizer=cehrgpt_tokenizer,
                data_args=data_args,
            )
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)
            processed_dataset.cleanup_cache_files()

    # After main-process-only operations, synchronize all processes to ensure consistency
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    if processed_dataset is None:
        # Only non-main processes get here: they skipped the preparation above
        # and must pick up what the main process saved, instead of failing
        # later on `None.items()`.
        if data_args.streaming:
            raise RuntimeError(
                "The processed dataset is only available on the main process "
                "when streaming is enabled."
            )
        processed_dataset = load_from_disk(str(prepared_ds_path))

    # We suppress the additional learning objectives in fine-tuning
    data_collator = CehrGptDataCollator(
        tokenizer=cehrgpt_tokenizer,
        max_length=cehrgpt_args.generation_input_length,
        include_values=cehrgpt_model.config.include_values,
        pretraining=False,
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=False,
        add_linear_prob_token=False,
    )

    # Total sequence budget per sample: prompt length plus newly generated tokens
    generation_max_length = (
        cehrgpt_args.generation_max_new_tokens + cehrgpt_args.generation_input_length
    )

    LOG.info(
        "Generating %s trajectories per sample",
        cehrgpt_args.num_of_trajectories_per_sample,
    )
    for sample_i in range(cehrgpt_args.num_of_trajectories_per_sample):
        for split, dataset in processed_dataset.items():
            meds_split = map_data_split_name(split)
            dataloader = DataLoader(
                dataset=dataset,
                batch_size=training_args.per_device_eval_batch_size,
                num_workers=training_args.dataloader_num_workers,
                collate_fn=data_collator,
                pin_memory=training_args.dataloader_pin_memory,
            )
            sample_output_dir = (
                Path(training_args.output_dir) / meds_split / f"{sample_i}"
            )
            sample_output_dir.mkdir(exist_ok=True, parents=True)
            for batch_i, batch in tqdm(
                enumerate(dataloader),
                desc=f"Generating Trajectories for split {meds_split} with trajectory {sample_i + 1}",
            ):
                output_parquet_file = sample_output_dir / f"{batch_i}.parquet"
                # Existing files are kept, which lets an interrupted run resume
                if output_parquet_file.exists():
                    LOG.info("%s already exists, skip...", output_parquet_file)
                    continue

                generate_trajectories_per_batch(
                    batch,
                    cehrgpt_tokenizer,
                    cehrgpt_model,
                    device,
                    output_parquet_file,
                    generation_max_length,
                )


if __name__ == "__main__":
    # Seed every random number generator before main() runs
    seed_all(42)
    main()
|
@@ -2,7 +2,7 @@ import datetime
|
|
2
2
|
import os
|
3
3
|
import random
|
4
4
|
import uuid
|
5
|
-
from typing import Any, Dict,
|
5
|
+
from typing import Any, Dict, Optional, Sequence, Tuple
|
6
6
|
|
7
7
|
import numpy as np
|
8
8
|
import pandas as pd
|
@@ -13,7 +13,7 @@ from transformers.utils import is_flash_attn_2_available, logging
|
|
13
13
|
|
14
14
|
from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
|
15
15
|
from cehrgpt.generation.omop_converter_batch import START_TOKEN_SIZE
|
16
|
-
from cehrgpt.gpt_utils import get_cehrgpt_output_folder
|
16
|
+
from cehrgpt.gpt_utils import construct_age_sequence, get_cehrgpt_output_folder
|
17
17
|
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
|
18
18
|
from cehrgpt.models.special_tokens import END_TOKEN
|
19
19
|
from cehrgpt.models.tokenization_hf_cehrgpt import (
|
@@ -72,9 +72,13 @@ def normalize_value(
|
|
72
72
|
|
73
73
|
def generate_single_batch(
|
74
74
|
model: CEHRGPT2LMHeadModel,
|
75
|
-
|
76
|
-
prompts:
|
77
|
-
|
75
|
+
cehrgpt_tokenizer: CehrGptTokenizer,
|
76
|
+
prompts: torch.Tensor,
|
77
|
+
max_length: int,
|
78
|
+
ages: Optional[torch.Tensor] = None,
|
79
|
+
values: Optional[torch.Tensor] = None,
|
80
|
+
value_indicators: Optional[torch.Tensor] = None,
|
81
|
+
max_new_tokens: Optional[int] = None,
|
78
82
|
mini_num_of_concepts=1,
|
79
83
|
top_p=0.95,
|
80
84
|
top_k=50,
|
@@ -88,7 +92,8 @@ def generate_single_batch(
|
|
88
92
|
with torch.no_grad():
|
89
93
|
generation_config = GenerationConfig(
|
90
94
|
repetition_penalty=repetition_penalty,
|
91
|
-
|
95
|
+
max_new_tokens=max_new_tokens,
|
96
|
+
max_length=max_length,
|
92
97
|
min_length=mini_num_of_concepts,
|
93
98
|
temperature=temperature,
|
94
99
|
top_p=top_p,
|
@@ -107,20 +112,33 @@ def generate_single_batch(
|
|
107
112
|
num_beam_groups=num_beam_groups,
|
108
113
|
epsilon_cutoff=epsilon_cutoff,
|
109
114
|
)
|
110
|
-
|
115
|
+
|
116
|
+
batched_prompts = prompts.to(device)
|
117
|
+
if ages is not None:
|
118
|
+
ages = ages.to(device)
|
119
|
+
if values is not None:
|
120
|
+
values = values.to(device)
|
121
|
+
if value_indicators is not None:
|
122
|
+
value_indicators = value_indicators.to(device)
|
123
|
+
|
111
124
|
results = model.generate(
|
112
125
|
inputs=batched_prompts,
|
126
|
+
ages=ages,
|
127
|
+
values=values,
|
128
|
+
value_indicators=value_indicators,
|
113
129
|
generation_config=generation_config,
|
114
|
-
|
130
|
+
cehrgpt_tokenizer=cehrgpt_tokenizer,
|
115
131
|
)
|
116
132
|
|
117
133
|
sequences = [
|
118
|
-
|
134
|
+
cehrgpt_tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
|
119
135
|
for seq in results.sequences
|
120
136
|
]
|
121
137
|
if results.sequence_vals is not None:
|
122
138
|
values = [
|
123
|
-
|
139
|
+
cehrgpt_tokenizer.decode_value(
|
140
|
+
values.cpu().numpy(), skip_special_tokens=False
|
141
|
+
)
|
124
142
|
for values in results.sequence_vals
|
125
143
|
]
|
126
144
|
else:
|
@@ -202,6 +220,7 @@ def main(args):
|
|
202
220
|
|
203
221
|
# Randomly pick demographics from the existing population
|
204
222
|
random_prompts = []
|
223
|
+
random_prompt_ages = []
|
205
224
|
iter = 0
|
206
225
|
while len(random_prompts) < args.batch_size:
|
207
226
|
for row in dataset.select(
|
@@ -212,9 +231,9 @@ def main(args):
|
|
212
231
|
<= len(row["concept_ids"])
|
213
232
|
<= max_seq_allowed
|
214
233
|
):
|
215
|
-
|
216
|
-
|
217
|
-
)
|
234
|
+
prompt = row["concept_ids"][:prompt_size]
|
235
|
+
random_prompts.append(cehrgpt_tokenizer.encode(prompt))
|
236
|
+
random_prompt_ages.append(construct_age_sequence(prompt))
|
218
237
|
iter += 1
|
219
238
|
if not random_prompts and iter > 10:
|
220
239
|
raise RuntimeError(
|
@@ -225,8 +244,9 @@ def main(args):
|
|
225
244
|
batch_sequences = generate_single_batch(
|
226
245
|
cehrgpt_model,
|
227
246
|
cehrgpt_tokenizer,
|
228
|
-
random_prompts[: args.batch_size],
|
229
|
-
|
247
|
+
torch.tensor(random_prompts[: args.batch_size]),
|
248
|
+
ages=torch.tensor(random_prompt_ages[: args.batch_size]),
|
249
|
+
max_length=args.context_window,
|
230
250
|
mini_num_of_concepts=args.min_num_of_concepts,
|
231
251
|
top_p=args.top_p,
|
232
252
|
top_k=args.top_k,
|
@@ -270,20 +270,24 @@ def gpt_to_omop_converter_batch(
|
|
270
270
|
|
271
271
|
is_numeric_types = (
|
272
272
|
is_numeric_types[START_TOKEN_SIZE:]
|
273
|
-
if is_numeric_types is not None
|
273
|
+
if is_numeric_types is not None and not np.all(pd.isna(is_numeric_types))
|
274
274
|
else None
|
275
275
|
)
|
276
276
|
number_as_values = (
|
277
277
|
number_as_values[START_TOKEN_SIZE:]
|
278
|
-
if number_as_values is not None
|
278
|
+
if number_as_values is not None and not np.all(pd.isna(number_as_values))
|
279
279
|
else None
|
280
280
|
)
|
281
281
|
concept_as_values = (
|
282
282
|
concept_as_values[START_TOKEN_SIZE:]
|
283
|
-
if concept_as_values is not None
|
283
|
+
if concept_as_values is not None and not np.all(pd.isna(concept_as_values))
|
284
|
+
else None
|
285
|
+
)
|
286
|
+
units = (
|
287
|
+
units[START_TOKEN_SIZE:]
|
288
|
+
if units is not None and not np.all(pd.isna(units))
|
284
289
|
else None
|
285
290
|
)
|
286
|
-
units = units[START_TOKEN_SIZE:] if units is not None else None
|
287
291
|
|
288
292
|
# TODO:Need to decode if the input is tokenized
|
289
293
|
[start_year, start_age, start_gender, start_race] = concept_ids[
|
@@ -441,6 +445,9 @@ def gpt_to_omop_converter_batch(
|
|
441
445
|
]:
|
442
446
|
# If it's a start token, skip it
|
443
447
|
pass
|
448
|
+
elif event.endswith("/0"):
|
449
|
+
# This should capture the concept such as Visit/0, Discharge/0
|
450
|
+
pass
|
444
451
|
else:
|
445
452
|
try:
|
446
453
|
concept_id = int(event)
|