cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py
@@ -0,0 +1,495 @@
+ import glob
+ import os
+ import shutil
+ import uuid
+ from datetime import datetime
+ from functools import partial
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ import torch
+ import torch.distributed as dist
+ from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
+ from cehrbert.runners.runner_util import generate_prepared_ds_path
+ from datasets import concatenate_datasets, load_from_disk
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers.trainer_utils import is_main_process
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
+     CehrGptDataCollator,
+     SamplePackingCehrGptDataCollator,
+ )
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import ExtractTokenizedSequenceDataMapping
+ from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
+ from cehrgpt.models.hf_cehrgpt import (
+     CEHRGPT2Model,
+     extract_features_from_packed_sequence,
+ )
+ from cehrgpt.models.special_tokens import LINEAR_PROB_TOKEN
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.runners.data_utils import (
+     extract_cohort_sequences,
+     prepare_finetune_dataset,
+ )
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
+
+ LOG = logging.get_logger("transformers")
+
+
+ def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
+     if torch_dtype and hasattr(torch, torch_dtype):
+         return getattr(torch, torch_dtype)
+     return torch.float32
+
+
+ def extract_averaged_embeddings_from_packed_sequence(
+     hidden_states: torch.Tensor,
+     attention_mask: torch.Tensor,
+     ve_token_indicators: torch.BoolTensor,
+ ) -> torch.Tensor:
+     """
+     Average the hidden states over the VE tokens of each sample in a packed sequence.
+
+     Args:
+         hidden_states: (batch_size=1, seq_len, hidden_dim) tensor
+         attention_mask: (batch_size=1, seq_len) tensor, where 0 indicates padding
+         ve_token_indicators: (batch_size=1, seq_len) bool tensor, True if the token is a VE token
+
+     Returns:
+         (num_samples, hidden_dim) tensor: averaged embeddings over VE tokens for each sample
+     """
+     # Step 1: Create segment IDs
+     mask = attention_mask[0]  # (seq_len,)
+     segment_ids = (mask == 0).cumsum(dim=0) + 1  # start segment IDs from 1
+     segment_ids = (segment_ids * mask).to(torch.int32)  # set PAD positions back to 0
+
+     # Step 2: Only keep tokens that are both valid and VE tokens
+     valid = (segment_ids > 0) & (ve_token_indicators[0])
+     valid_embeddings = hidden_states[0, valid].to(
+         torch.float32
+     )  # (num_valid_ve_tokens, hidden_dim)
+     valid_segments = segment_ids[valid]  # (num_valid_ve_tokens,)
+
+     # Step 3: Group by segment ID and average
+     num_segments = int(segment_ids.max().item())
+
+     sample_embeddings = torch.zeros(
+         num_segments, hidden_states.size(-1), device=hidden_states.device
+     )
+     counts = torch.zeros(num_segments, device=hidden_states.device)
+
+     sample_embeddings.index_add_(0, valid_segments - 1, valid_embeddings)
+     counts.index_add_(
+         0, valid_segments - 1, torch.ones_like(valid_segments, dtype=counts.dtype)
+     )
+
+     # Avoid divide-by-zero: segments with no VE tokens keep zero embeddings
+     counts = counts.masked_fill(counts == 0, 1.0)
+
+     sample_embeddings = sample_embeddings / counts.unsqueeze(-1)
+
+     return sample_embeddings
+
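To make the segment averaging above concrete, here is a minimal, self-contained sketch of calling the function on a toy packed batch (the tensor values are made up, and the import assumes the installed wheel exposes this module):

import torch
from cehrgpt.tools.linear_prob.compute_cehrgpt_features import (
    extract_averaged_embeddings_from_packed_sequence,
)

# Two samples packed into one row: positions 0-2 belong to sample 1,
# position 3 is padding, positions 4-5 belong to sample 2, position 6 is padding.
hidden_states = torch.arange(28, dtype=torch.float32).reshape(1, 7, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0]])
ve_token_indicators = torch.tensor([[True, False, True, False, True, True, False]])

embeddings = extract_averaged_embeddings_from_packed_sequence(
    hidden_states, attention_mask, ve_token_indicators
)
print(embeddings.shape)  # torch.Size([2, 4]): one averaged embedding per packed sample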
+ def main():
+     cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+         model_args.tokenizer_name_or_path
+     )
+     torch_dtype = get_torch_dtype(model_args.torch_dtype)
+     cehrgpt_model = (
+         CEHRGPT2Model.from_pretrained(
+             model_args.model_name_or_path,
+             attn_implementation=(
+                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
+             ),
+             torch_dtype=torch_dtype,
+         )
+         .eval()
+         .to(device)
+     )
+
+     if LINEAR_PROB_TOKEN not in cehrgpt_tokenizer.get_vocab():
+         cehrgpt_tokenizer.add_tokens(LINEAR_PROB_TOKEN)
+         cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
+
+     prepared_ds_path = generate_prepared_ds_path(
+         data_args, model_args, data_folder=data_args.cohort_folder
+     )
+     cache_file_collector = CacheFileCollector()
+     processed_dataset = None
+     if any(prepared_ds_path.glob("*")):
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         processed_dataset = load_from_disk(str(prepared_ds_path))
+         LOG.info("Prepared dataset loaded from disk...")
+         if cehrgpt_args.expand_tokenizer:
+             if tokenizer_exists(training_args.output_dir):
+                 cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                     training_args.output_dir
+                 )
+             else:
+                 LOG.warning(
+                     f"CehrGptTokenizer must exist in {training_args.output_dir} "
+                     f"when the dataset has been processed and expand_tokenizer is set to True. "
+                     f"Please delete the processed dataset at {prepared_ds_path}."
+                 )
+                 processed_dataset = None
+                 shutil.rmtree(prepared_ds_path)
+
+     if processed_dataset is None:
+         if is_main_process(training_args.local_rank):
+             # If the full dataset has already been tokenized, we don't want to re-tokenize
+             # the cohort containing the subset of the data; instead, we slice the relevant
+             # portion out of the tokenized sequences for each sample.
+             if cehrgpt_args.tokenized_full_dataset_path is not None:
+                 processed_dataset = extract_cohort_sequences(
+                     data_args, cehrgpt_args, cache_file_collector
+                 )
+             else:
+                 # Organize them into a single DatasetDict
+                 final_splits = prepare_finetune_dataset(
+                     data_args, training_args, cehrgpt_args, cache_file_collector
+                 )
+                 if cehrgpt_args.expand_tokenizer:
+                     new_tokenizer_path = os.path.expanduser(training_args.output_dir)
+                     if tokenizer_exists(new_tokenizer_path):
+                         cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
+                             new_tokenizer_path
+                         )
+                     else:
+                         cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
+                             cehrgpt_tokenizer=cehrgpt_tokenizer,
+                             dataset=final_splits["train"],
+                             data_args=data_args,
+                             concept_name_mapping={},
+                         )
+                         cehrgpt_tokenizer.save_pretrained(
+                             os.path.expanduser(training_args.output_dir)
+                         )
+
+                 # TODO: temporary solution; this column is mixed-typed and causes an issue
+                 # when transforming the data
+                 if not data_args.streaming:
+                     all_columns = final_splits["train"].column_names
+                     if "visit_concept_ids" in all_columns:
+                         final_splits = final_splits.remove_columns(
+                             ["visit_concept_ids"]
+                         )
+
+                 processed_dataset = create_cehrgpt_finetuning_dataset(
+                     dataset=final_splits,
+                     cehrgpt_tokenizer=cehrgpt_tokenizer,
+                     data_args=data_args,
+                     cache_file_collector=cache_file_collector,
+                 )
+             if not data_args.streaming:
+                 processed_dataset.save_to_disk(prepared_ds_path)
+                 processed_dataset.cleanup_cache_files()
+
+             # Remove any cached files that processed_dataset.cleanup_cache_files()
+             # did not remove already
+             cache_file_collector.remove_cache_files()
+
+         # After main-process-only operations, synchronize all processes to ensure consistency
+         if dist.is_available() and dist.is_initialized():
+             dist.barrier()
+
+         # Load the dataset from disk again in torch distributed training
+         processed_dataset = load_from_disk(str(prepared_ds_path))
+
+     # Collect the existing features
+     feature_folders = glob.glob(
+         os.path.join(training_args.output_dir, "*", "features", "*.parquet")
+     )
+     if feature_folders:
+         existing_features = pd.concat(
+             [
+                 pd.read_parquet(f, columns=["subject_id", "prediction_time_posix"])
+                 for f in feature_folders
+             ],
+             ignore_index=True,
+         )
+         subject_prediction_tuples = set(
+             existing_features.apply(
+                 lambda row: f"{int(row['subject_id'])}-{int(row['prediction_time_posix'])}",
+                 axis=1,
+             ).tolist()
+         )
+         processed_dataset = processed_dataset.filter(
+             lambda _batch: [
+                 f"{int(subject)}-{int(time)}" not in subject_prediction_tuples
+                 for subject, time in zip(_batch["person_id"], _batch["index_date"])
+             ],
+             num_proc=data_args.preprocessing_num_workers,
+             batch_size=data_args.preprocessing_batch_size,
+             batched=True,
+         )
+         LOG.info(
+             "The datasets after filtering (train: %s, validation: %s, test: %s)",
+             len(processed_dataset["train"]),
+             len(processed_dataset["validation"]),
+             len(processed_dataset["test"]),
+         )
+
+     LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
+     LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
+     if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
+         cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
+     if (
+         cehrgpt_model.config.max_position_embeddings
+         < model_args.max_position_embeddings
+     ):
+         LOG.info(
+             f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
+         )
+         cehrgpt_model.config.max_position_embeddings = (
+             model_args.max_position_embeddings
+         )
+         cehrgpt_model.resize_position_embeddings(model_args.max_position_embeddings)
+
+     train_set = concatenate_datasets(
+         [processed_dataset["train"], processed_dataset["validation"]]
+     )
+
+     if cehrgpt_args.sample_packing:
+         per_device_eval_batch_size = 1
+         data_collator_fn = partial(
+             SamplePackingCehrGptDataCollator,
+             cehrgpt_args.max_tokens_per_batch,
+             cehrgpt_model.config.max_position_embeddings,
+             add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
+         )
+         train_batch_sampler = SamplePackingBatchSampler(
+             lengths=train_set["num_of_concepts"],
+             max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+             max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
+             drop_last=training_args.dataloader_drop_last,
+             seed=training_args.seed,
+         )
+         test_batch_sampler = SamplePackingBatchSampler(
+             lengths=processed_dataset["test"]["num_of_concepts"],
+             max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
+             max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
+             drop_last=training_args.dataloader_drop_last,
+             seed=training_args.seed,
+         )
+     else:
+         data_collator_fn = CehrGptDataCollator
+         train_batch_sampler = None
+         test_batch_sampler = None
+         per_device_eval_batch_size = training_args.per_device_eval_batch_size
+
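The batching above is driven by a token budget rather than a fixed batch size. A schematic sketch of the general idea follows (this illustrates token-budget batching only; it is not the package's SamplePackingBatchSampler, which also handles shuffling, seeding, and position-embedding limits):

def pack_by_token_budget(lengths, max_tokens_per_batch):
    # Greedily group sample indices so each batch stays within the token budget.
    batches, batch, used = [], [], 0
    for idx, length in enumerate(lengths):
        if batch and used + length > max_tokens_per_batch:
            batches.append(batch)
            batch, used = [], 0
        batch.append(idx)
        used += length
    if batch:
        batches.append(batch)
    return batches

print(pack_by_token_budget([512, 300, 700, 128], max_tokens_per_batch=1024))
# [[0, 1], [2, 3]]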
+     # We suppress the additional learning objectives in fine-tuning
+     data_collator = data_collator_fn(
+         tokenizer=cehrgpt_tokenizer,
+         max_length=(
+             cehrgpt_args.max_tokens_per_batch
+             if cehrgpt_args.sample_packing
+             else model_args.max_position_embeddings
+         ),
+         include_values=cehrgpt_model.config.include_values,
+         pretraining=False,
+         include_ttv_prediction=False,
+         use_sub_time_tokenization=False,
+         include_demographics=cehrgpt_args.include_demographics,
+         add_linear_prob_token=True,
+     )
+
+     train_loader = DataLoader(
+         dataset=train_set,
+         batch_size=per_device_eval_batch_size,
+         num_workers=training_args.dataloader_num_workers,
+         collate_fn=data_collator,
+         pin_memory=training_args.dataloader_pin_memory,
+         batch_sampler=train_batch_sampler,
+     )
+
+     test_dataloader = DataLoader(
+         dataset=processed_dataset["test"],
+         batch_size=per_device_eval_batch_size,
+         num_workers=training_args.dataloader_num_workers,
+         collate_fn=data_collator,
+         pin_memory=training_args.dataloader_pin_memory,
+         batch_sampler=test_batch_sampler,
+     )
+
+     if data_args.is_data_in_meds:
+         demographics_dict = dict()
+     else:
+         # Loading demographics
+         print("Loading demographics as a dictionary")
+         demographics_df = pd.concat(
+             [
+                 pd.read_parquet(
+                     data_dir,
+                     columns=[
+                         "person_id",
+                         "index_date",
+                         "gender_concept_id",
+                         "race_concept_id",
+                     ],
+                 )
+                 for data_dir in [data_args.data_folder, data_args.test_data_folder]
+             ]
+         )
+         # This is a precaution in case index_date is not a datetime type
+         demographics_df["index_date"] = pd.to_datetime(
+             demographics_df["index_date"]
+         ).dt.date
+         demographics_dict = {
+             (row["person_id"], row["index_date"]): {
+                 "gender_concept_id": row["gender_concept_id"],
+                 "race_concept_id": row["race_concept_id"],
+             }
+             for _, row in demographics_df.iterrows()
+         }
+
+     data_loaders = [("train", train_loader), ("test", test_dataloader)]
+
+     ve_token_id = cehrgpt_tokenizer._convert_token_to_id("[VE]")
+     for split, data_loader in data_loaders:
+         # Ensure the prediction folder exists
+         feature_output_folder = (
+             Path(training_args.output_dir) / "features_with_label" / f"{split}_features"
+         )
+         feature_output_folder.mkdir(parents=True, exist_ok=True)
+
+         LOG.info("Generating features for %s set at %s", split, feature_output_folder)
+
+         with torch.no_grad():
+             for index, batch in enumerate(
+                 tqdm(data_loader, desc="Generating features")
+             ):
+                 prediction_time_ages = (
+                     batch.pop("age_at_index").numpy().astype(float).squeeze()
+                 )
+                 if prediction_time_ages.ndim == 0:
+                     prediction_time_ages = np.asarray([prediction_time_ages])
+
+                 person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
+                 if person_ids.ndim == 0:
+                     person_ids = np.asarray([person_ids])
+                 prediction_time_posix = batch.pop("index_date").numpy().squeeze()
+                 if prediction_time_posix.ndim == 0:
+                     prediction_time_posix = np.asarray([prediction_time_posix])
+                 prediction_time = list(
+                     map(datetime.fromtimestamp, prediction_time_posix)
+                 )
+                 labels = (
+                     batch.pop("classifier_label")
+                     .float()
+                     .cpu()
+                     .numpy()
+                     .astype(bool)
+                     .squeeze()
+                 )
+                 if labels.ndim == 0:
+                     labels = np.asarray([labels])
+
+                 batch = {k: v.to(device) for k, v in batch.items()}
+                 # Forward pass
+                 cehrgpt_output = cehrgpt_model(
+                     **batch, output_attentions=False, output_hidden_states=False
+                 )
+                 if cehrgpt_args.sample_packing:
+                     if cehrgpt_args.average_over_sequence:
+                         ve_token_indicators: torch.BoolTensor = (
+                             batch["input_ids"] == ve_token_id
+                         )
+                         features = (
+                             extract_averaged_embeddings_from_packed_sequence(
+                                 cehrgpt_output.last_hidden_state,
+                                 batch["attention_mask"],
+                                 ve_token_indicators,
+                             )
+                             .cpu()
+                             .float()
+                             .detach()
+                             .numpy()
+                         )
+                     else:
+                         features = (
+                             extract_features_from_packed_sequence(
+                                 cehrgpt_output.last_hidden_state,
+                                 batch["attention_mask"],
+                             )
+                             .cpu()
+                             .float()
+                             .detach()
+                             .numpy()
+                             .squeeze(axis=0)
+                         )
+                 else:
+                     if cehrgpt_args.average_over_sequence:
+                         features = torch.where(
+                             batch["attention_mask"].unsqueeze(dim=-1).to(torch.bool),
+                             cehrgpt_output.last_hidden_state,
+                             0,
+                         )
+                         # Average across the sequence
+                         features = features.mean(dim=1)
+                     else:
+                         last_end_token = any(
+                             [
+                                 cehrgpt_tokenizer.end_token_id == input_id
+                                 for input_id in batch.pop("input_ids")
+                                 .cpu()
+                                 .numpy()
+                                 .squeeze()
+                                 .tolist()
+                             ]
+                         )
+                         last_token_index = -2 if last_end_token else -1
+                         LOG.debug(
+                             "If the last token is [END], we use the token index before it: %s",
+                             last_token_index,
+                         )
+                         features = (
+                             cehrgpt_output.last_hidden_state[..., last_token_index, :]
+                             .cpu()
+                             .float()
+                             .detach()
+                             .numpy()
+                         )
+
+                 # Flatten features or handle them as a list of arrays (one array per row)
+                 features_list = [feature for feature in features]
+                 race_concept_ids = []
+                 gender_concept_ids = []
+                 for person_id, index_date in zip(person_ids, prediction_time):
+                     key = (person_id, index_date.date())
+                     if key in demographics_dict:
+                         demographics = demographics_dict[key]
+                         gender_concept_ids.append(demographics["gender_concept_id"])
+                         race_concept_ids.append(demographics["race_concept_id"])
+                     else:
+                         gender_concept_ids.append(0)
+                         race_concept_ids.append(0)
+
+                 features_pd = pd.DataFrame(
+                     {
+                         "subject_id": person_ids,
+                         "prediction_time": prediction_time,
+                         "prediction_time_posix": prediction_time_posix,
+                         "boolean_value": labels,
+                         "age_at_index": prediction_time_ages,
+                     }
+                 )
+                 # Add features as a separate column where each row contains a feature array
+                 features_pd["features"] = features_list
+                 features_pd["race_concept_id"] = race_concept_ids
+                 features_pd["gender_concept_id"] = gender_concept_ids
+                 features_pd.to_parquet(
+                     feature_output_folder / f"{uuid.uuid4()}.parquet"
+                 )
+
+
+ if __name__ == "__main__":
+     main()
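Each pass through the loop above writes one parquet shard per batch. A minimal sketch of inspecting a shard (the output path is hypothetical; substitute your own --output_dir):

import glob
import pandas as pd

shard = glob.glob("output_dir/features_with_label/train_features/*.parquet")[0]
df = pd.read_parquet(shard)
print(df.columns.tolist())
# ['subject_id', 'prediction_time', 'prediction_time_posix', 'boolean_value',
#  'age_at_index', 'features', 'race_concept_id', 'gender_concept_id']
print(len(df["features"].iloc[0]))  # the hidden size of the CEHR-GPT model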
cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py
@@ -0,0 +1,152 @@
+ import argparse
+ import json
+ import pickle
+ from pathlib import Path
+ from typing import Any, Dict, Union
+
+ import numpy as np
+ import pandas as pd
+ import polars as pl
+ from sklearn.linear_model import LogisticRegressionCV
+ from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
+ from sklearn.preprocessing import OneHotEncoder, StandardScaler
+
+
+ def prepare_dataset(
+     df: pd.DataFrame, feature_processor: Dict[str, Union[StandardScaler, OneHotEncoder]]
+ ) -> Dict[str, Any]:
+     age_scaler = feature_processor["age_scaler"]
+     gender_encoder = feature_processor["gender_encoder"]
+     race_encoder = feature_processor["race_encoder"]
+     age_scaler.transform(df[["age_at_index"]].to_numpy())
+
+     one_hot_gender = gender_encoder.transform(
+         np.expand_dims(df.gender_concept_id.to_numpy(), axis=1)
+     )
+     one_hot_race = race_encoder.transform(
+         np.expand_dims(df.race_concept_id.to_numpy(), axis=1)
+     )
+
+     # NOTE: the demographic features are computed above but currently not
+     # concatenated into the feature matrix (see the commented-out hstack below)
+     features = np.stack(df["features"].apply(lambda x: np.array(x).flatten()))
+     # features = np.hstack(
+     #     [scaled_age, one_hot_gender.toarray(), one_hot_race.toarray(), features]
+     # )
+     return {
+         "subject_id": df["subject_id"].to_numpy(),
+         "prediction_time": df["prediction_time"].tolist(),
+         "features": features,
+         "boolean_value": df["boolean_value"].to_numpy(),
+     }
+
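A minimal sketch of what prepare_dataset expects (toy values; 768 is an assumed hidden size, the import assumes the installed wheel exposes this module, and the encoders must already be fitted, as main() does below):

import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from cehrgpt.tools.linear_prob.train_with_cehrgpt_features import prepare_dataset

toy = pd.DataFrame(
    {
        "subject_id": [1, 2],
        "prediction_time": pd.to_datetime(["2020-01-01", "2020-06-01"]).tolist(),
        "age_at_index": [54.0, 61.0],
        "gender_concept_id": [8507, 8532],
        "race_concept_id": [0, 0],
        "features": [np.random.rand(768).tolist(), np.random.rand(768).tolist()],
        "boolean_value": [True, False],
    }
)
feature_processor = {
    "age_scaler": StandardScaler().fit(toy[["age_at_index"]].to_numpy()),
    "gender_encoder": OneHotEncoder(handle_unknown="ignore").fit(
        toy[["gender_concept_id"]].to_numpy()
    ),
    "race_encoder": OneHotEncoder(handle_unknown="ignore").fit(
        toy[["race_concept_id"]].to_numpy()
    ),
}
dataset = prepare_dataset(toy, feature_processor)
print(dataset["features"].shape)  # (2, 768): one flattened feature vector per row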
+ def main(args):
+     features_data_dir = Path(args.features_data_dir)
+     output_dir = Path(args.output_dir)
+     feature_processor_path = output_dir / "feature_processor.pickle"
+     logistic_dir = output_dir / "logistic"
+     logistic_dir.mkdir(exist_ok=True, parents=True)
+     logistic_test_result_file = logistic_dir / "metrics.json"
+     if logistic_test_result_file.exists():
+         print("The models have already been trained, skipping ...")
+         exit(0)
+
+     feature_train = pd.read_parquet(
+         features_data_dir / "features_with_label" / "train_features"
+     )
+     feature_test = pd.read_parquet(
+         features_data_dir / "features_with_label" / "test_features"
+     )
+
+     feature_train = feature_train.sort_values(["subject_id", "prediction_time"]).sample(
+         frac=1.0,
+         random_state=42,
+         replace=False,
+     )
+
+     if feature_processor_path.exists():
+         with open(feature_processor_path, "rb") as f:
+             feature_processor = pickle.load(f)
+     else:
+         age_scaler, gender_encoder, race_encoder = (
+             StandardScaler(),
+             OneHotEncoder(handle_unknown="ignore"),
+             OneHotEncoder(handle_unknown="ignore"),
+         )
+         age_scaler = age_scaler.fit(feature_train[["age_at_index"]].to_numpy())
+         gender_encoder = gender_encoder.fit(
+             feature_train[["gender_concept_id"]].to_numpy()
+         )
+         race_encoder = race_encoder.fit(feature_train[["race_concept_id"]].to_numpy())
+         feature_processor = {
+             "age_scaler": age_scaler,
+             "gender_encoder": gender_encoder,
+             "race_encoder": race_encoder,
+         }
+         with open(feature_processor_path, "wb") as f:
+             pickle.dump(feature_processor, f)
+
+     if logistic_test_result_file.exists():
+         print(
+             f"The results for logistic regression already exist at {logistic_test_result_file}"
+         )
+     else:
+         logistic_model_file = logistic_dir / "model.pickle"
+         if logistic_model_file.exists():
+             print(
+                 f"The logistic regression model already exists, loading it from {logistic_model_file}"
+             )
+             with open(logistic_model_file, "rb") as f:
+                 model = pickle.load(f)
+         else:
+             train_dataset = prepare_dataset(feature_train, feature_processor)
+             # Train logistic regression
+             model = LogisticRegressionCV(scoring="roc_auc", random_state=42)
+             model.fit(train_dataset["features"], train_dataset["boolean_value"])
+             with open(logistic_model_file, "wb") as f:
+                 pickle.dump(model, f)
+
+         test_dataset = prepare_dataset(feature_test, feature_processor)
+         y_pred = model.predict_proba(test_dataset["features"])[:, 1]
+         logistic_predictions = pl.DataFrame(
+             {
+                 "subject_id": test_dataset["subject_id"].tolist(),
+                 "prediction_time": test_dataset["prediction_time"],
+                 "predicted_boolean_probability": y_pred.tolist(),
+                 "predicted_boolean_value": None,
+                 "boolean_value": test_dataset["boolean_value"].astype(bool).tolist(),
+             }
+         )
+         logistic_predictions = logistic_predictions.with_columns(
+             pl.col("predicted_boolean_value").cast(pl.Boolean())
+         )
+         logistic_test_predictions = logistic_dir / "test_predictions"
+         logistic_test_predictions.mkdir(exist_ok=True, parents=True)
+         logistic_predictions.write_parquet(
+             logistic_test_predictions / "predictions.parquet"
+         )
+
+         roc_auc = roc_auc_score(test_dataset["boolean_value"], y_pred)
+         precision, recall, _ = precision_recall_curve(
+             test_dataset["boolean_value"], y_pred
+         )
+         pr_auc = auc(recall, precision)
+
+         metrics = {"roc_auc": roc_auc, "pr_auc": pr_auc}
+         print("Logistic:", features_data_dir.name, metrics)
+         with open(logistic_test_result_file, "w") as f:
+             json.dump(metrics, f, indent=4)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Train logistic regression model with cehrgpt features"
+     )
+     parser.add_argument(
+         "--features_data_dir",
+         required=True,
+         help="Directory containing training and test feature files",
+     )
+     parser.add_argument(
+         "--output_dir", required=True, help="Directory to save the output results"
+     )
+     main(parser.parse_args())
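Because the predictions parquet stores both the probabilities and the labels, the reported metrics can be re-derived offline. A minimal sketch (the path is hypothetical; substitute your own --output_dir):

import polars as pl
from sklearn.metrics import roc_auc_score

preds = pl.read_parquet("output_dir/logistic/test_predictions/predictions.parquet")
print(
    roc_auc_score(
        preds["boolean_value"].to_numpy(),
        preds["predicted_boolean_probability"].to_numpy(),
    )
)  # should match "roc_auc" in output_dir/logistic/metrics.json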
{cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA
@@ -1,6 +1,6 @@
- Metadata-Version: 2.2
+ Metadata-Version: 2.4
  Name: cehrgpt
- Version: 0.0.2
+ Version: 0.1.1
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
  License: MIT License
@@ -12,13 +12,15 @@ Classifier: Programming Language :: Python :: 3
  Requires-Python: >=3.10.0
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: cehrbert==1.3.8
+ Requires-Dist: cehrbert==1.4.5
+ Requires-Dist: cehrbert_data==0.0.11
  Requires-Dist: openai==1.54.3
  Requires-Dist: optuna==4.0.0
- Requires-Dist: transformers==4.40.0
+ Requires-Dist: transformers==4.44.1
  Requires-Dist: tokenizers==0.19.0
  Requires-Dist: peft==0.10.0
- Requires-Dist: trl==0.11.4
+ Requires-Dist: lightgbm
+ Requires-Dist: polars
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
  Requires-Dist: pytest; extra == "dev"
@@ -29,14 +31,15 @@ Requires-Dist: hypothesis; extra == "dev"
  Requires-Dist: black; extra == "dev"
  Provides-Extra: flash-attn
  Requires-Dist: flash_attn; extra == "flash-attn"
+ Dynamic: license-file

  # CEHRGPT

  [![PyPI - Version](https://img.shields.io/pypi/v/cehrgpt)](https://pypi.org/project/cehrgpt/)
  ![Python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)
- [![tests](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt-public/actions/workflows/tests.yml)
- [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt-public/blob/main/LICENSE)
- [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt-public.svg)](https://github.com/knatarajan-lab/cehrgpt-public/graphs/contributors)
+ [![tests](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml/badge.svg)](https://github.com/knatarajan-lab/cehrgpt/actions/workflows/tests.yaml)
+ [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/knatarajan-lab/cehrgpt/blob/main/LICENSE)
+ [![contributors](https://img.shields.io/github/contributors/knatarajan-lab/cehrgpt.svg)](https://github.com/knatarajan-lab/cehrgpt/graphs/contributors)

  ## Description
  CEHRGPT is a synthetic data generation model developed to handle structured electronic health records (EHR) with enhanced privacy and reliability. It leverages state-of-the-art natural language processing techniques to create realistic, anonymized patient data that can be used for research and development without compromising patient privacy.