cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,6 @@ from cehrbert_data.const.artificial_tokens import (
21
21
  DISCHARGE_UNKNOWN_TOKEN,
22
22
  GENDER_UNKNOWN_TOKEN,
23
23
  RACE_UNKNOWN_TOKEN,
24
- VISIT_UNKNOWN_TOKEN,
25
24
  )
26
25
  from cehrbert_data.const.common import NA
27
26
  from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
@@ -29,6 +28,12 @@ from datasets.formatting.formatting import LazyBatch
29
28
  from dateutil.relativedelta import relativedelta
30
29
  from pandas import Series
31
30
 
31
+ from cehrgpt.gpt_utils import (
32
+ construct_age_sequence,
33
+ construct_time_sequence,
34
+ encode_demographics,
35
+ multiple_of_10,
36
+ )
32
37
  from cehrgpt.models.tokenization_hf_cehrgpt import (
33
38
  NONE_BIN,
34
39
  UNKNOWN_BIN,
@@ -44,13 +49,20 @@ CEHRGPT_COLUMNS = [
44
49
  "concept_values",
45
50
  "units",
46
51
  "epoch_times",
52
+ "ages",
47
53
  ]
48
54
 
49
55
 
50
- def convert_date_to_posix_time(index_date: datetime.date) -> float:
51
- return datetime.datetime.combine(
52
- index_date, datetime.datetime.min.time()
53
- ).timestamp()
56
+ def convert_date_to_posix_time(index_date: Union[datetime.date, int, float]) -> float:
57
+ if isinstance(index_date, datetime.date):
58
+ return (
59
+ datetime.datetime.combine(index_date, datetime.datetime.min.time())
60
+ .replace(tzinfo=datetime.timezone.utc)
61
+ .timestamp()
62
+ )
63
+ elif isinstance(index_date, datetime.datetime):
64
+ return index_date.replace(tzinfo=datetime.timezone.utc).timestamp()
65
+ return index_date
54
66
 
55
67
 
56
68
  class DatasetMappingDecorator(DatasetMapping):
@@ -116,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
116
128
  cehrgpt_record: Dict[str, Any],
117
129
  code: str,
118
130
  time: datetime.datetime,
131
+ age: int,
119
132
  concept_value_mask: int = 0,
120
133
  number_as_value: float = 0.0,
121
134
  concept_as_value: str = "0",
@@ -123,17 +136,21 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
123
136
  unit: str = NA,
124
137
  ) -> None:
125
138
  cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
139
+ cehrgpt_record["ages"].append(age)
126
140
  cehrgpt_record["concept_value_masks"].append(concept_value_mask)
127
141
  cehrgpt_record["number_as_values"].append(number_as_value)
128
142
  cehrgpt_record["concept_as_values"].append(concept_as_value)
129
143
  cehrgpt_record["units"].append(unit)
130
144
  cehrgpt_record["is_numeric_types"].append(is_numeric_type)
131
- cehrgpt_record["epoch_times"].append(time.timestamp())
145
+ cehrgpt_record["epoch_times"].append(
146
+ time.replace(tzinfo=datetime.timezone.utc).timestamp()
147
+ )
132
148
 
133
149
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
134
150
  cehrgpt_record = {
135
151
  "person_id": record["patient_id"],
136
152
  "concept_ids": [],
153
+ "ages": [],
137
154
  "concept_value_masks": [],
138
155
  "number_as_values": [],
139
156
  "concept_as_values": [],
@@ -161,14 +178,21 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
161
178
  first_visit_start_datetime: datetime.datetime = get_value(
162
179
  first_visit, "visit_start_datetime"
163
180
  )
181
+ starting_age = relativedelta(first_visit_start_datetime, birth_datetime).years
164
182
  year_str = f"year:{str(first_visit_start_datetime.year)}"
165
- age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
183
+ age_str = f"age:{starting_age}"
184
+ self._update_cehrgpt_record(
185
+ cehrgpt_record, year_str, first_visit_start_datetime, starting_age
186
+ )
187
+ self._update_cehrgpt_record(
188
+ cehrgpt_record, age_str, first_visit_start_datetime, starting_age
189
+ )
190
+ self._update_cehrgpt_record(
191
+ cehrgpt_record, gender, first_visit_start_datetime, starting_age
192
+ )
166
193
  self._update_cehrgpt_record(
167
- cehrgpt_record, year_str, first_visit_start_datetime
194
+ cehrgpt_record, race, first_visit_start_datetime, starting_age
168
195
  )
169
- self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
170
- self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
171
- self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
172
196
 
173
197
  # Use a data cursor to keep track of time
174
198
  datetime_cursor: Optional[datetime.datetime] = None
@@ -204,6 +228,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
204
228
  cehrgpt_record,
205
229
  code=self._time_token_function(time_delta),
206
230
  time=visit_start_datetime,
231
+ age=relativedelta(datetime_cursor, birth_datetime).years,
207
232
  )
208
233
 
209
234
  datetime_cursor = visit_start_datetime
@@ -212,12 +237,14 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
212
237
  cehrgpt_record,
213
238
  code="[VS]",
214
239
  time=datetime_cursor,
240
+ age=relativedelta(datetime_cursor, birth_datetime).years,
215
241
  )
216
242
  # Add a visit type token
217
243
  self._update_cehrgpt_record(
218
244
  cehrgpt_record,
219
245
  code=visit_type,
220
246
  time=datetime_cursor,
247
+ age=relativedelta(datetime_cursor, birth_datetime).years,
221
248
  )
222
249
  # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
223
250
  # with respect to the midnight of the day
@@ -228,6 +255,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
228
255
  cehrgpt_record,
229
256
  code=f"i-H{datetime_cursor.hour}",
230
257
  time=datetime_cursor,
258
+ age=relativedelta(datetime_cursor, birth_datetime).years,
231
259
  )
232
260
 
233
261
  # Keep track of the existing outpatient events, we don't want to add them again
@@ -274,6 +302,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
274
302
  cehrgpt_record,
275
303
  code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
276
304
  time=event_time,
305
+ age=relativedelta(event_time, birth_datetime).years,
277
306
  )
278
307
 
279
308
  if self._include_inpatient_hour_token:
@@ -293,6 +322,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
293
322
  cehrgpt_record,
294
323
  code=f"i-H{time_diff_hours}",
295
324
  time=event_time,
325
+ age=relativedelta(event_time, birth_datetime).years,
296
326
  )
297
327
 
298
328
  if event_identity in existing_duplicate_events:
@@ -302,6 +332,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
302
332
  cehrgpt_record,
303
333
  code=code,
304
334
  time=event_time,
335
+ age=relativedelta(event_time, birth_datetime).years,
305
336
  concept_value_mask=concept_value_mask,
306
337
  unit=unit,
307
338
  number_as_value=numeric_value if numeric_value else 0.0,
@@ -341,6 +372,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
341
372
  cehrgpt_record,
342
373
  code=discharge_facility,
343
374
  time=datetime_cursor,
375
+ age=relativedelta(datetime_cursor, birth_datetime).years,
344
376
  )
345
377
 
346
378
  # Reuse the age and date calculated for the last event in the patient timeline
@@ -348,6 +380,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
348
380
  cehrgpt_record,
349
381
  code="[VE]",
350
382
  time=datetime_cursor,
383
+ age=relativedelta(datetime_cursor, birth_datetime).years,
351
384
  )
352
385
 
353
386
  # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -360,7 +393,9 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
360
393
  cehrgpt_record["num_of_visits"] = len(visits)
361
394
 
362
395
  if record.get("index_date", None) is not None:
363
- cehrgpt_record["index_date"] = record["index_date"]
396
+ cehrgpt_record["index_date"] = (
397
+ record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
398
+ )
364
399
  if record.get("label", None) is not None:
365
400
  cehrgpt_record["label"] = record["label"]
366
401
  if record.get("age_at_index", None) is not None:
@@ -419,6 +454,13 @@ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
419
454
  return record
420
455
 
421
456
  def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
457
+ # Reconstruct the ages input before the filter is applied
458
+ record["ages"] = construct_age_sequence(
459
+ record["concept_ids"], record.get("ages", None)
460
+ )
461
+ record["epoch_times"] = construct_time_sequence(
462
+ record["concept_ids"], record.get("epoch_times", None)
463
+ )
422
464
  # Remove the tokens from patient sequences that do not exist in the tokenizer
423
465
  record = self.filter_out_invalid_tokens(record)
424
466
  # If any concept has a value associated with it, we normalize the value
@@ -529,9 +571,13 @@ class ExtractTokenizedSequenceDataMapping:
529
571
  prediction_start_end_times = [
530
572
  (
531
573
  self._calculate_prediction_start_time(
532
- prediction_time_label_map["index_date"].timestamp()
574
+ prediction_time_label_map["index_date"]
575
+ .replace(tzinfo=datetime.timezone.utc)
576
+ .timestamp()
533
577
  ),
534
- prediction_time_label_map["index_date"].timestamp(),
578
+ prediction_time_label_map["index_date"]
579
+ .replace(tzinfo=datetime.timezone.utc)
580
+ .timestamp(),
535
581
  prediction_time_label_map["label"],
536
582
  )
537
583
  for prediction_time_label_map in prediction_times
@@ -0,0 +1,316 @@
1
+ import datetime
2
+ import os
3
+ import random
4
+ import shutil
5
+ from pathlib import Path
6
+ from typing import Any, Dict
7
+
8
+ import numpy as np
9
+ import polars as pl
10
+ import torch
11
+ import torch.distributed as dist
12
+ from cehrbert.runners.runner_util import generate_prepared_ds_path
13
+ from datasets import load_from_disk
14
+ from meds import held_out_split, train_split, tuning_split
15
+ from torch.utils.data import DataLoader
16
+ from tqdm import tqdm
17
+ from transformers.trainer_utils import is_main_process
18
+ from transformers.utils import is_flash_attn_2_available, logging
19
+
20
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
21
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
22
+ from cehrgpt.generation.generate_batch_hf_gpt_sequence import (
23
+ generate_single_batch,
24
+ normalize_value,
25
+ )
26
+ from cehrgpt.gpt_utils import (
27
+ extract_time_interval_in_days,
28
+ extract_time_interval_in_hours,
29
+ is_att_token,
30
+ is_inpatient_hour_token,
31
+ is_visit_end,
32
+ is_visit_start,
33
+ )
34
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
35
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
36
+ from cehrgpt.runners.data_utils import (
37
+ extract_cohort_sequences,
38
+ prepare_finetune_dataset,
39
+ )
40
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
41
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
42
+
43
+ LOG = logging.get_logger("transformers")
44
+
45
+
46
+ def map_data_split_name(split: str) -> str:
47
+ if split == "train":
48
+ return train_split
49
+ elif split == "validation":
50
+ return tuning_split
51
+ elif split == "test":
52
+ return held_out_split
53
+ raise ValueError(f"Unknown split: {split}")
54
+
55
+
56
+ def seed_all(seed: int = 42):
57
+ """Set seed for Python, NumPy, and PyTorch (CPU & CUDA)."""
58
+ random.seed(seed) # Python random
59
+ np.random.seed(seed) # NumPy
60
+ torch.manual_seed(seed) # PyTorch CPU
61
+ torch.cuda.manual_seed(seed) # Current GPU
62
+ torch.cuda.manual_seed_all(seed) # All GPUs
63
+
64
+ # For reproducibility in dataloader workers
65
+ os.environ["PYTHONHASHSEED"] = str(seed)
66
+
67
+
68
+ def generate_trajectories_per_batch(
69
+ batch: Dict[str, Any],
70
+ cehrgpt_tokenizer: CehrGptTokenizer,
71
+ cehrgpt_model: CEHRGPT2LMHeadModel,
72
+ device,
73
+ data_output_path: Path,
74
+ max_length: int,
75
+ ):
76
+ subject_ids = batch["person_id"].squeeze().detach().cpu().tolist()
77
+ prediction_times = batch["index_date"].squeeze().detach().cpu().tolist()
78
+ batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()
79
+ batched_input_ids = batch["input_ids"]
80
+ batched_ages = batch["ages"]
81
+ batched_value_indicators = batch["value_indicators"]
82
+ batched_values = batch["values"]
83
+ # Make sure the batch does not exceed batch_size
84
+ batch_sequences = generate_single_batch(
85
+ cehrgpt_model,
86
+ cehrgpt_tokenizer,
87
+ batched_input_ids,
88
+ ages=batched_ages,
89
+ values=batched_values,
90
+ value_indicators=batched_value_indicators,
91
+ max_length=max_length,
92
+ top_p=1.0,
93
+ top_k=cehrgpt_tokenizer.vocab_size,
94
+ device=device,
95
+ )
96
+ # Clear the cache
97
+ torch.cuda.empty_cache()
98
+
99
+ trajectories = []
100
+ for sample_i, (concept_ids, value_indicators, values) in enumerate(
101
+ zip(
102
+ batch_sequences["sequences"],
103
+ batch_sequences["value_indicators"],
104
+ batch_sequences["values"],
105
+ )
106
+ ):
107
+ (
108
+ concept_ids,
109
+ is_numeric_types,
110
+ number_as_values,
111
+ concept_as_values,
112
+ units,
113
+ ) = normalize_value(concept_ids, values, cehrgpt_tokenizer)
114
+
115
+ epoch_times = batched_epoch_times[sample_i]
116
+ input_length = len(epoch_times)
117
+ # Getting the last observed event time from the token before the prediction time
118
+ window_last_observed = epoch_times[input_length - 1]
119
+ current_cursor = epoch_times[-1]
120
+ generated_epoch_times = []
121
+ valid_indices = []
122
+
123
+ for i in range(input_length, len(concept_ids)):
124
+ concept_id = concept_ids[i]
125
+ # We use the left padding strategy in the data collator
126
+ if concept_id in [cehrgpt_tokenizer.pad_token, cehrgpt_tokenizer.end_token]:
127
+ continue
128
+ # We need to construct the time stamp
129
+ if is_att_token(concept_id):
130
+ current_cursor += extract_time_interval_in_days(concept_id) * 24 * 3600
131
+ elif is_inpatient_hour_token(concept_id):
132
+ current_cursor += extract_time_interval_in_hours(concept_id) * 3600
133
+ elif is_visit_start(concept_id) or is_visit_end(concept_id):
134
+ continue
135
+ else:
136
+ valid_indices.append(i)
137
+ generated_epoch_times.append(
138
+ datetime.datetime.utcfromtimestamp(current_cursor).replace(
139
+ tzinfo=None
140
+ )
141
+ )
142
+
143
+ trajectories.append(
144
+ {
145
+ "subject_id": subject_ids[sample_i],
146
+ "prediction_time": datetime.datetime.utcfromtimestamp(
147
+ prediction_times[sample_i]
148
+ ).replace(tzinfo=None),
149
+ "window_last_observed_time": datetime.datetime.utcfromtimestamp(
150
+ window_last_observed
151
+ ).replace(tzinfo=None),
152
+ "times": generated_epoch_times,
153
+ "concept_ids": np.asarray(concept_ids)[valid_indices].tolist(),
154
+ "numeric_values": np.asarray(number_as_values)[valid_indices].tolist(),
155
+ "text_value": np.asarray(concept_as_values)[valid_indices].tolist(),
156
+ "units": np.asarray(units)[valid_indices].tolist(),
157
+ }
158
+ )
159
+
160
+ trajectories = (
161
+ pl.DataFrame(trajectories)
162
+ .explode(["times", "concept_ids", "numeric_values", "text_value", "units"])
163
+ .rename(
164
+ {
165
+ "times": "time",
166
+ "concept_ids": "code",
167
+ "numeric_values": "numeric_value",
168
+ "units": "unit",
169
+ }
170
+ )
171
+ .select(
172
+ "subject_id",
173
+ "prediction_time",
174
+ "window_last_observed_time",
175
+ "time",
176
+ "code",
177
+ "numeric_value",
178
+ "text_value",
179
+ "unit",
180
+ )
181
+ )
182
+ trajectories.write_parquet(data_output_path)
183
+
184
+
185
+ def main():
186
+ cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
187
+ if torch.cuda.is_available():
188
+ device = torch.device("cuda")
189
+ else:
190
+ device = torch.device("cpu")
191
+ cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
192
+ model_args.tokenizer_name_or_path
193
+ )
194
+ cehrgpt_model = (
195
+ CEHRGPT2LMHeadModel.from_pretrained(
196
+ model_args.model_name_or_path,
197
+ attn_implementation=(
198
+ "flash_attention_2" if is_flash_attn_2_available() else "eager"
199
+ ),
200
+ )
201
+ .eval()
202
+ .to(device)
203
+ )
204
+ cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
205
+ cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
206
+ cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
207
+
208
+ if not os.path.exists(training_args.output_dir):
209
+ os.makedirs(training_args.output_dir)
210
+
211
+ prepared_ds_path = generate_prepared_ds_path(
212
+ data_args, model_args, data_folder=data_args.cohort_folder
213
+ )
214
+
215
+ processed_dataset = None
216
+ if any(prepared_ds_path.glob("*")):
217
+ LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
218
+ processed_dataset = load_from_disk(str(prepared_ds_path))
219
+ LOG.info("Prepared dataset loaded from disk...")
220
+ if cehrgpt_args.expand_tokenizer:
221
+ if tokenizer_exists(training_args.output_dir):
222
+ cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
223
+ training_args.output_dir
224
+ )
225
+ else:
226
+ LOG.warning(
227
+ f"CehrGptTokenizer must exist in {training_args.output_dir} "
228
+ f"when the dataset has been processed and expand_tokenizer is set to True. "
229
+ f"Please delete the processed dataset at {prepared_ds_path}."
230
+ )
231
+ processed_dataset = None
232
+ shutil.rmtree(prepared_ds_path)
233
+
234
+ if processed_dataset is None and is_main_process(training_args.local_rank):
235
+ # If the full dataset has been tokenized, we don't want to tokenize the cohort containing
236
+ # the subset of the data. We should slice out the portion of the tokenized sequences for each sample
237
+ if cehrgpt_args.tokenized_full_dataset_path is not None:
238
+ processed_dataset = extract_cohort_sequences(data_args, cehrgpt_args)
239
+ else:
240
+ # Organize them into a single DatasetDict
241
+ final_splits = prepare_finetune_dataset(
242
+ data_args, training_args, cehrgpt_args
243
+ )
244
+ # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
245
+ if not data_args.streaming:
246
+ all_columns = final_splits["train"].column_names
247
+ if "visit_concept_ids" in all_columns:
248
+ final_splits = final_splits.remove_columns(["visit_concept_ids"])
249
+
250
+ processed_dataset = create_cehrgpt_finetuning_dataset(
251
+ dataset=final_splits,
252
+ cehrgpt_tokenizer=cehrgpt_tokenizer,
253
+ data_args=data_args,
254
+ )
255
+ if not data_args.streaming:
256
+ processed_dataset.save_to_disk(prepared_ds_path)
257
+ processed_dataset.cleanup_cache_files()
258
+
259
+ # After main-process-only operations, synchronize all processes to ensure consistency
260
+ if dist.is_available() and dist.is_initialized():
261
+ dist.barrier()
262
+
263
+ # We suppress the additional learning objectives in fine-tuning
264
+ data_collator = CehrGptDataCollator(
265
+ tokenizer=cehrgpt_tokenizer,
266
+ max_length=cehrgpt_args.generation_input_length,
267
+ include_values=cehrgpt_model.config.include_values,
268
+ pretraining=False,
269
+ include_ttv_prediction=False,
270
+ use_sub_time_tokenization=False,
271
+ include_demographics=False,
272
+ add_linear_prob_token=False,
273
+ )
274
+
275
+ LOG.info(
276
+ "Generating %s trajectories per sample",
277
+ cehrgpt_args.num_of_trajectories_per_sample,
278
+ )
279
+ for sample_i in range(cehrgpt_args.num_of_trajectories_per_sample):
280
+ for split, dataset in processed_dataset.items():
281
+ meds_split = map_data_split_name(split)
282
+ dataloader = DataLoader(
283
+ dataset=dataset,
284
+ batch_size=training_args.per_device_eval_batch_size,
285
+ num_workers=training_args.dataloader_num_workers,
286
+ collate_fn=data_collator,
287
+ pin_memory=training_args.dataloader_pin_memory,
288
+ )
289
+ sample_output_dir = (
290
+ Path(training_args.output_dir) / meds_split / f"{sample_i}"
291
+ )
292
+ sample_output_dir.mkdir(exist_ok=True, parents=True)
293
+ for batch_i, batch in tqdm(
294
+ enumerate(dataloader),
295
+ desc=f"Generating Trajectories for split {meds_split} with trajectory {sample_i + 1}",
296
+ ):
297
+ output_parquet_file = sample_output_dir / f"{batch_i}.parquet"
298
+ if output_parquet_file.exists():
299
+ LOG.info("%s already exists, skip...", output_parquet_file)
300
+ continue
301
+
302
+ generate_trajectories_per_batch(
303
+ batch,
304
+ cehrgpt_tokenizer,
305
+ cehrgpt_model,
306
+ device,
307
+ sample_output_dir / f"{batch_i}.parquet",
308
+ cehrgpt_args.generation_max_new_tokens
309
+ + cehrgpt_args.generation_input_length,
310
+ )
311
+
312
+
313
+ if __name__ == "__main__":
314
+ # ✅ Call first thing inside main()
315
+ seed_all(42)
316
+ main()
@@ -2,7 +2,7 @@ import datetime
2
2
  import os
3
3
  import random
4
4
  import uuid
5
- from typing import Any, Dict, List, Optional, Sequence, Tuple
5
+ from typing import Any, Dict, Optional, Sequence, Tuple
6
6
 
7
7
  import numpy as np
8
8
  import pandas as pd
@@ -13,7 +13,7 @@ from transformers.utils import is_flash_attn_2_available, logging
13
13
 
14
14
  from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
15
15
  from cehrgpt.generation.omop_converter_batch import START_TOKEN_SIZE
16
- from cehrgpt.gpt_utils import get_cehrgpt_output_folder
16
+ from cehrgpt.gpt_utils import construct_age_sequence, get_cehrgpt_output_folder
17
17
  from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
18
18
  from cehrgpt.models.special_tokens import END_TOKEN
19
19
  from cehrgpt.models.tokenization_hf_cehrgpt import (
@@ -72,9 +72,13 @@ def normalize_value(
72
72
 
73
73
  def generate_single_batch(
74
74
  model: CEHRGPT2LMHeadModel,
75
- tokenizer: CehrGptTokenizer,
76
- prompts: List[List[int]],
77
- max_new_tokens=512,
75
+ cehrgpt_tokenizer: CehrGptTokenizer,
76
+ prompts: torch.Tensor,
77
+ max_length: int,
78
+ ages: Optional[torch.Tensor] = None,
79
+ values: Optional[torch.Tensor] = None,
80
+ value_indicators: Optional[torch.Tensor] = None,
81
+ max_new_tokens: Optional[int] = None,
78
82
  mini_num_of_concepts=1,
79
83
  top_p=0.95,
80
84
  top_k=50,
@@ -88,7 +92,8 @@ def generate_single_batch(
88
92
  with torch.no_grad():
89
93
  generation_config = GenerationConfig(
90
94
  repetition_penalty=repetition_penalty,
91
- max_length=max_new_tokens,
95
+ max_new_tokens=max_new_tokens,
96
+ max_length=max_length,
92
97
  min_length=mini_num_of_concepts,
93
98
  temperature=temperature,
94
99
  top_p=top_p,
@@ -107,20 +112,33 @@ def generate_single_batch(
107
112
  num_beam_groups=num_beam_groups,
108
113
  epsilon_cutoff=epsilon_cutoff,
109
114
  )
110
- batched_prompts = torch.tensor(prompts).to(device)
115
+
116
+ batched_prompts = prompts.to(device)
117
+ if ages is not None:
118
+ ages = ages.to(device)
119
+ if values is not None:
120
+ values = values.to(device)
121
+ if value_indicators is not None:
122
+ value_indicators = value_indicators.to(device)
123
+
111
124
  results = model.generate(
112
125
  inputs=batched_prompts,
126
+ ages=ages,
127
+ values=values,
128
+ value_indicators=value_indicators,
113
129
  generation_config=generation_config,
114
- lab_token_ids=tokenizer.lab_token_ids,
130
+ cehrgpt_tokenizer=cehrgpt_tokenizer,
115
131
  )
116
132
 
117
133
  sequences = [
118
- tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
134
+ cehrgpt_tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
119
135
  for seq in results.sequences
120
136
  ]
121
137
  if results.sequence_vals is not None:
122
138
  values = [
123
- tokenizer.decode_value(values.cpu().numpy(), skip_special_tokens=False)
139
+ cehrgpt_tokenizer.decode_value(
140
+ values.cpu().numpy(), skip_special_tokens=False
141
+ )
124
142
  for values in results.sequence_vals
125
143
  ]
126
144
  else:
@@ -202,6 +220,7 @@ def main(args):
202
220
 
203
221
  # Randomly pick demographics from the existing population
204
222
  random_prompts = []
223
+ random_prompt_ages = []
205
224
  iter = 0
206
225
  while len(random_prompts) < args.batch_size:
207
226
  for row in dataset.select(
@@ -212,9 +231,9 @@ def main(args):
212
231
  <= len(row["concept_ids"])
213
232
  <= max_seq_allowed
214
233
  ):
215
- random_prompts.append(
216
- cehrgpt_tokenizer.encode(row["concept_ids"][:prompt_size])
217
- )
234
+ prompt = row["concept_ids"][:prompt_size]
235
+ random_prompts.append(cehrgpt_tokenizer.encode(prompt))
236
+ random_prompt_ages.append(construct_age_sequence(prompt))
218
237
  iter += 1
219
238
  if not random_prompts and iter > 10:
220
239
  raise RuntimeError(
@@ -225,8 +244,9 @@ def main(args):
225
244
  batch_sequences = generate_single_batch(
226
245
  cehrgpt_model,
227
246
  cehrgpt_tokenizer,
228
- random_prompts[: args.batch_size],
229
- max_new_tokens=args.context_window,
247
+ torch.tensor(random_prompts[: args.batch_size]),
248
+ ages=torch.tensor(random_prompt_ages[: args.batch_size]),
249
+ max_length=args.context_window,
230
250
  mini_num_of_concepts=args.min_num_of_concepts,
231
251
  top_p=args.top_p,
232
252
  top_k=args.top_k,
@@ -270,20 +270,24 @@ def gpt_to_omop_converter_batch(
270
270
 
271
271
  is_numeric_types = (
272
272
  is_numeric_types[START_TOKEN_SIZE:]
273
- if is_numeric_types is not None
273
+ if is_numeric_types is not None and not np.all(pd.isna(is_numeric_types))
274
274
  else None
275
275
  )
276
276
  number_as_values = (
277
277
  number_as_values[START_TOKEN_SIZE:]
278
- if number_as_values is not None
278
+ if number_as_values is not None and not np.all(pd.isna(number_as_values))
279
279
  else None
280
280
  )
281
281
  concept_as_values = (
282
282
  concept_as_values[START_TOKEN_SIZE:]
283
- if concept_as_values is not None
283
+ if concept_as_values is not None and not np.all(pd.isna(concept_as_values))
284
+ else None
285
+ )
286
+ units = (
287
+ units[START_TOKEN_SIZE:]
288
+ if units is not None and not np.all(pd.isna(units))
284
289
  else None
285
290
  )
286
- units = units[START_TOKEN_SIZE:] if units is not None else None
287
291
 
288
292
  # TODO:Need to decode if the input is tokenized
289
293
  [start_year, start_age, start_gender, start_race] = concept_ids[
@@ -441,6 +445,9 @@ def gpt_to_omop_converter_batch(
441
445
  ]:
442
446
  # If it's a start token, skip it
443
447
  pass
448
+ elif event.endswith("/0"):
449
+ # This should capture the concept such as Visit/0, Discharge/0
450
+ pass
444
451
  else:
445
452
  try:
446
453
  concept_id = int(event)