cehrgpt 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/generation/chatgpt_generation.py
@@ -0,0 +1,106 @@
+ import os
+ from textwrap import dedent
+
+ import numpy as np
+ from jinja2 import BaseLoader, Environment
+ from openai import OpenAI
+ from pydantic import BaseModel
+
+ MODEL = "gpt-4o-2024-08-06"
+ TEMPLATE = """
+ You are a medical professional tasked with generating a synthetic patient sequence using the CEHR-GPT format, outlined as follows:
+
+ [year]: Represents the start year of the patient sequence.
+ [age]: Represents the start age of the patient sequence.
+ [gender]: The patient's gender; allowed values are "Male," "Female," and "Unknown."
+ [race]: The patient's race; allowed values are "White," "Black," "Asian," and "Unknown."
+ [VS]: Marks the start of a visit.
+ [VE]: Marks the end of a visit.
+ [VT]: Type of visit, with allowed values "9202" (outpatient), "9201" (inpatient), and "9203" (emergency room).
+ [C_i]: Clinical concept represented by an OMOP concept ID (could be a drug, condition, or procedure).
+ [ATT]: Artificial time tokens representing the interval in days between visits (e.g., "D1," "D10").
+ [i-ATT]: Inpatient-specific artificial time tokens representing intervals within an inpatient stay (e.g., "i-D1"); these tokens should only appear in inpatient visits.
+ Each sequence can encompass multiple concepts within each visit and vary from one to ten visits, reflecting real-world clinical scenarios. All clinical concepts must correspond to valid OMOP IDs. The sequence must end on [VE].
+
+ Example of a sequence:
+
+ {
+ "seq": ['year:2008', 'age:28', '8532', '8527', '[VS]', '9202', '4301351',
+ '19078924', '35603428', '35603429', '40221381', '40223365',
+ '4155151', '4239130', '42536500', '4294382', '2108974', '433736',
+ '[VE]', 'D7', '[VS]', '9201', '43011850', '35603429', '35603600',
+ '35605482', '40163870', '40169706', '40221381', '35603428',
+ '19078921', '40244026', '948080', '1154615', '1593063', '4056973',
+ '4155151', '4194550', '3047860', '35604843', '43011962', '4160730',
+ 'i-D1', '35604843', '40162587', '43011962', '433736', '948080',
+ '0', '[VE]', 'D14', '[VS]', '9202', '4019497', '[VE]', 'D26', '[VS]',
+ '1', '4019497', '[VE]', 'D198', '[VS]', '581477', '433736',
+ '[VE]', 'D19', '[VS]', '581477', '194152', '320128', '40483287', '433736', '[VE]']
+ }
+
+ When creating the sequence, please use the demographic tokens {{ demographic_prompt }} to construct a realistic and medically plausible patient trajectory.
+ """
+
+
+ class PatientSequence(BaseModel):
+     seq: list[str]
+
+
+ if __name__ == "__main__":
+     import argparse
+     import uuid
+
+     import pandas as pd
+     from tqdm import tqdm
+
+     parser = argparse.ArgumentParser("ChatGPT patient generation")
+     parser.add_argument(
+         "--demographic_data",
+         dest="demographic_data",
+         action="store",
+         help="The path to your demographic data",
+         required=True,
+     )
+     parser.add_argument(
+         "--output_folder",
+         dest="output_folder",
+         action="store",
+         help="The path to your output folder",
+         required=True,
+     )
+     parser.add_argument(
+         "--num_sequences",
+         dest="num_sequences",
+         action="store",
+         type=int,
+         help="The number of sequences to generate",
+         required=True,
+     )
+     args = parser.parse_args()
+     # Create a Jinja2 environment and render the template
+     env = Environment(loader=BaseLoader())
+     template = env.from_string(TEMPLATE)
+     demographics = pd.read_parquet(args.demographic_data)
+     # Create the OpenAI client once and reuse it for every request
+     client = OpenAI(api_key=os.environ.get("OPEN_AI_KEY"))
+
+     for _ in tqdm(range(args.num_sequences)):
+         # Seed the prompt with one randomly sampled patient's demographic tokens
+         demographic_tokens = str(demographics.sample(1).concept_ids.iloc[0].tolist())
+         prompt = template.render(demographic_prompt=demographic_tokens)
+         completion = client.beta.chat.completions.parse(
+             model=MODEL,
+             messages=[
+                 {"role": "system", "content": "You are a medical professional."},
+                 {"role": "user", "content": dedent(prompt)},
+             ],
+             response_format=PatientSequence,
+         )
+         patient_sequence = completion.choices[0].message.parsed.seq
+         # concept_values is a placeholder column; these sequences carry no numeric values
+         pd.DataFrame(
+             [
+                 {
+                     "concept_ids": patient_sequence,
+                     "concept_values": np.zeros_like(patient_sequence),
+                 }
+             ],
+             columns=["concept_ids", "concept_values"],
+         ).to_parquet(os.path.join(args.output_folder, f"{uuid.uuid4()}.parquet"))
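
Note: response_format=PatientSequence uses the OpenAI structured-outputs API, so the reply is parsed directly into a list of token strings rather than free text. Below is a minimal sketch, not part of the wheel, of how a returned sequence could be checked against the format the prompt describes; the helper name is hypothetical:

def is_well_formed(seq: list[str]) -> bool:
    """Check the demographic prefix and [VS]/[VE] pairing of a CEHR-GPT sequence."""
    # The first four tokens are the demographic prompt: year, age, gender, race
    if len(seq) < 4 or not seq[0].startswith("year:") or not seq[1].startswith("age:"):
        return False
    open_visit = False
    for token in seq[4:]:
        if token == "[VS]":
            if open_visit:  # nested visits are not allowed
                return False
            open_visit = True
        elif token == "[VE]":
            if not open_visit:
                return False
            open_visit = False
    # The prompt requires every sequence to end on [VE]
    return not open_visit and seq[-1] == "[VE]"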

cehrgpt/generation/generate_batch_hf_gpt_sequence.py
@@ -0,0 +1,333 @@
+ import datetime
+ import os
+ import random
+ import uuid
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from cehrbert.runners.runner_util import load_parquet_as_dataset
+ from transformers import GenerationConfig
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
+ from cehrgpt.generation.omop_converter_batch import START_TOKEN_SIZE
+ from cehrgpt.gpt_utils import get_cehrgpt_output_folder
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+ from cehrgpt.models.special_tokens import END_TOKEN
+ from cehrgpt.models.tokenization_hf_cehrgpt import (
+     NA,
+     CehrGptTokenizer,
+     is_valid_valid_bin,
+ )
+
+ LOG = logging.get_logger("transformers")
+
+
+ def normalize_value(
+     seq: Sequence[str],
+     values: Sequence[str],
+     tokenizer: CehrGptTokenizer,
+ ) -> Tuple[
+     Sequence[str],
+     Optional[Sequence[Optional[int]]],
+     Optional[Sequence[Optional[float]]],
+     Optional[Sequence[Optional[str]]],
+     Optional[Sequence[str]],
+ ]:
+     concepts = []
+     number_as_values = []
+     concept_as_values = []
+     is_numeric_types = []
+     units = []
+     for concept, value in zip(seq, values):
+         if concept == END_TOKEN:
+             break
+         number_as_value = None
+         concept_as_value = value if value and value.isnumeric() else None
+         is_numeric_type = 0
+         unit = NA
+         # If the parallel value token is a valid value bin, map it back to a
+         # numeric value and unit using the tokenizer's statistics
+         if is_valid_valid_bin(value):
+             converted_value, unit = tokenizer.denormalize(concept, value)
+             if isinstance(converted_value, float):
+                 number_as_value = converted_value
+                 is_numeric_type = 1
+
+         concepts.append(concept)
+         number_as_values.append(number_as_value)
+         concept_as_values.append(concept_as_value)
+         is_numeric_types.append(is_numeric_type)
+         units.append(unit)
+
+     return (
+         concepts,
+         is_numeric_types,
+         number_as_values,
+         concept_as_values,
+         units,
+     )
+
+
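
A quick illustration of the calling convention (a sketch: the concept ID and bin token are made up, and tokenizer is assumed to be a loaded CehrGptTokenizer). seq and values are parallel sequences, and the function returns five parallel columns:

concepts, is_numeric_types, number_as_values, concept_as_values, units = normalize_value(
    seq=["3004249", "[VE]"],  # decoded concept tokens (illustrative lab concept)
    values=["bin_4", NA],     # parallel value tokens; "bin_4" is made up
    tokenizer=tokenizer,
)
# If "bin_4" is a valid value bin, number_as_values[0] holds the denormalized
# float and units[0] its unit; otherwise they stay None / NA.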
+ def generate_single_batch(
+     model: CEHRGPT2LMHeadModel,
+     tokenizer: CehrGptTokenizer,
+     prompts: List[List[int]],
+     max_new_tokens=512,
+     mini_num_of_concepts=1,
+     top_p=0.95,
+     top_k=50,
+     temperature=1.0,
+     repetition_penalty=1.0,
+     num_beams=1,
+     num_beam_groups=1,
+     epsilon_cutoff=0.0,
+     device: Any = "cpu",
+ ) -> Dict[str, Any]:
+     with torch.no_grad():
+         generation_config = GenerationConfig(
+             repetition_penalty=repetition_penalty,
+             max_length=max_new_tokens,
+             min_length=mini_num_of_concepts,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             bos_token_id=tokenizer.end_token_id,
+             eos_token_id=tokenizer.end_token_id,
+             pad_token_id=tokenizer.pad_token_id,
+             do_sample=True,
+             use_cache=True,
+             return_dict_in_generate=True,
+             output_attentions=False,
+             output_hidden_states=False,
+             output_scores=False,
+             renormalize_logits=True,
+             num_beams=num_beams,
+             num_beam_groups=num_beam_groups,
+             epsilon_cutoff=epsilon_cutoff,
+         )
+         batched_prompts = torch.tensor(prompts).to(device)
+         results = model.generate(
+             inputs=batched_prompts,
+             generation_config=generation_config,
+             lab_token_ids=tokenizer.lab_token_ids,
+         )
+
+     sequences = [
+         tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
+         for seq in results.sequences
+     ]
+     if results.sequence_vals is not None:
+         values = [
+             tokenizer.decode_value(values.cpu().numpy(), skip_special_tokens=False)
+             for values in results.sequence_vals
+         ]
+     else:
+         # No value output from the model: fill the value columns with the NA token
+         values = np.zeros_like(sequences)
+         values.fill(NA)
+     if results.sequence_val_masks is not None:
+         value_indicators = results.sequence_val_masks.cpu().numpy()
+     else:
+         value_indicators = np.zeros_like(sequences, dtype=np.int32).astype(bool)
+     return {
+         "sequences": sequences,
+         "values": values,
+         "value_indicators": value_indicators,
+     }
+
+
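
A minimal sketch of calling generate_single_batch directly rather than through main(); the model/tokenizer path is hypothetical, and the prompt is an encoded demographic prefix like the one in the example above:

import torch
from cehrgpt.generation.generate_batch_hf_gpt_sequence import generate_single_batch
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = CehrGptTokenizer.from_pretrained("models/cehrgpt")  # hypothetical path
model = CEHRGPT2LMHeadModel.from_pretrained("models/cehrgpt").eval().to(device)

# One prompt: year/age/gender/race tokens taken from a real patient record
prompts = [tokenizer.encode(["year:2008", "age:28", "8532", "8527"])]
batch = generate_single_batch(model, tokenizer, prompts, max_new_tokens=512, device=device)
print(batch["sequences"][0])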
+ def main(args):
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
+     cehrgpt_model = (
+         CEHRGPT2LMHeadModel.from_pretrained(
+             args.model_folder,
+             attn_implementation=(
+                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
+             ),
+             torch_dtype=(
+                 torch.bfloat16
+                 if is_flash_attn_2_available() and args.use_bfloat16
+                 else torch.float32
+             ),
+         )
+         .eval()
+         .to(device)
+     )
+     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
+     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
+     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
+
+     folder_name = get_cehrgpt_output_folder(args, cehrgpt_tokenizer)
+     output_folder_name = os.path.join(
+         args.output_folder, folder_name, "generated_sequences"
+     )
+
+     if not os.path.exists(output_folder_name):
+         os.makedirs(output_folder_name)
+
+     # Cap the prompt length only when --drop_long_sequences is set;
+     # otherwise accept demographics of any sequence length
+     max_seq_allowed = (
+         cehrgpt_model.config.n_positions
+         if args.drop_long_sequences
+         else np.iinfo(np.int32).max
+     )
+
+     LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
+     LOG.info(f"Loading model at {args.model_folder}")
+     LOG.info(f"Write sequences to {output_folder_name}")
+     LOG.info(f"Context window {args.context_window}")
+     LOG.info(f"Max sequence allowed {max_seq_allowed}")
+     LOG.info(f"Temperature {args.temperature}")
+     LOG.info(f"Repetition Penalty {args.repetition_penalty}")
+     LOG.info(f"Sampling Strategy {args.sampling_strategy}")
+     LOG.info(f"Num beam {args.num_beams}")
+     LOG.info(f"Num beam groups {args.num_beam_groups}")
+     LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
+     LOG.info(f"Top P {args.top_p}")
+     LOG.info(f"Top K {args.top_k}")
+     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+
+     dataset = load_parquet_as_dataset(args.demographic_data_path)
+     total_rows = len(dataset)
+
+     num_of_batches = args.num_of_patients // args.batch_size + 1
+     sequence_to_flush = []
+     current_person_id = 1
+     for i in range(num_of_batches):
+         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
+
+         # Randomly pick demographics from the existing population
+         random_prompts = []
+         attempts = 0
+         while len(random_prompts) < args.batch_size:
+             for row in dataset.select(
+                 random.sample(range(total_rows), k=args.batch_size)
+             ):
+                 if (
+                     args.min_num_of_concepts
+                     <= len(row["concept_ids"])
+                     <= max_seq_allowed
+                 ):
+                     random_prompts.append(
+                         cehrgpt_tokenizer.encode(row["concept_ids"][:START_TOKEN_SIZE])
+                     )
+             attempts += 1
+             # Give up if ten rounds of sampling produced no qualifying prompt
+             if not random_prompts and attempts > 10:
+                 raise RuntimeError(
+                     f"The length of concept_ids in {args.demographic_data_path} does not qualify!"
+                 )
+
+         # Make sure the batch does not exceed batch_size
+         batch_sequences = generate_single_batch(
+             cehrgpt_model,
+             cehrgpt_tokenizer,
+             random_prompts[: args.batch_size],
+             max_new_tokens=args.context_window,
+             mini_num_of_concepts=args.min_num_of_concepts,
+             top_p=args.top_p,
+             top_k=args.top_k,
+             temperature=args.temperature,
+             repetition_penalty=args.repetition_penalty,
+             num_beams=args.num_beams,
+             num_beam_groups=args.num_beam_groups,
+             epsilon_cutoff=args.epsilon_cutoff,
+             device=device,
+         )
+
+         # Clear the GPU cache between batches
+         torch.cuda.empty_cache()
+
+         for concept_ids, value_indicators, values in zip(
+             batch_sequences["sequences"],
+             batch_sequences["value_indicators"],
+             batch_sequences["values"],
+         ):
+             (
+                 concept_ids,
+                 is_numeric_types,
+                 number_as_values,
+                 concept_as_values,
+                 units,
+             ) = normalize_value(concept_ids, values, cehrgpt_tokenizer)
+             output = {"concept_ids": concept_ids, "person_id": current_person_id}
+             if is_numeric_types is not None:
+                 output["is_numeric_types"] = is_numeric_types
+             if number_as_values is not None:
+                 output["number_as_values"] = number_as_values
+             if concept_as_values is not None:
+                 output["concept_as_values"] = concept_as_values
+             if value_indicators is not None:
+                 output["concept_value_masks"] = value_indicators
+             if units is not None:
+                 output["units"] = units
+
+             sequence_to_flush.append(output)
+             current_person_id += 1
+
+         if len(sequence_to_flush) >= args.buffer_size:
+             LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Batch {i}")
+             pd.DataFrame(
+                 sequence_to_flush,
+                 columns=[
+                     "concept_ids",
+                     "person_id",
+                     "is_numeric_types",
+                     "number_as_values",
+                     "concept_as_values",
+                     "concept_value_masks",
+                     "units",
+                 ],
+             ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}.parquet"))
+             sequence_to_flush.clear()
+
+     if len(sequence_to_flush) > 0:
+         LOG.info(f"{datetime.datetime.now()}: Flushing to the Disk at Final Batch")
+         pd.DataFrame(
+             sequence_to_flush,
+             columns=[
+                 "concept_ids",
+                 "person_id",
+                 "is_numeric_types",
+                 "number_as_values",
+                 "concept_as_values",
+                 "concept_value_masks",
+                 "units",
+             ],
+         ).to_parquet(os.path.join(output_folder_name, f"{uuid.uuid4()}-last.parquet"))
+
+
+ def create_arg_parser():
+     base_arg_parser = create_inference_base_arg_parser(
+         description="Arguments for generating patient sequences"
+     )
+     base_arg_parser.add_argument(
+         "--num_of_patients",
+         dest="num_of_patients",
+         action="store",
+         type=int,
+         help="The number of patients that will be generated",
+         required=True,
+     )
+     base_arg_parser.add_argument(
+         "--demographic_data_path",
+         dest="demographic_data_path",
+         action="store",
+         help="The path to the demographic data used to seed the prompts",
+         required=True,
+     )
+     base_arg_parser.add_argument(
+         "--drop_long_sequences",
+         dest="drop_long_sequences",
+         action="store_true",
+     )
+     return base_arg_parser
+
+
+ if __name__ == "__main__":
+     main(create_arg_parser().parse_args())
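
The flushed batches can be inspected with pandas; a sketch (the run folder name is produced by get_cehrgpt_output_folder at runtime, so the path below is hypothetical):

import pandas as pd

# pandas/pyarrow can read a directory of parquet files in one call
df = pd.read_parquet("output/<run_folder>/generated_sequences")
print(df[["person_id", "concept_ids", "units"]].head())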