cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,146 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
# Lookup table of race concept ids (stored as strings) that get rewritten to a
# replacement id; any id that is not a key here is left unchanged by map_race.
race_mapping = dict(
    [
        ("38003613", "8557"),
        ("38003610", "8557"),
        ("38003579", "8515"),
        ("44814653", "0"),
    ]
)

# Age-group tokens that are filtered out before sampling. Each token has the form
# "age:LOWER-UPPER" with UPPER = LOWER + 10, so only the lower bounds are spelled
# out: every decade from 100 to 190, a handful of larger values, and -10.
invalid_age_groups = [
    f"age:{lower}-{lower + 10}"
    for lower in (*range(100, 200, 10), 640, 680, 730, 740, 890, 900, -10)
]
32
+
33
+
34
def age_group_func(age_str):
    """
    Map an "age:N" token onto the token of its 10-year bucket.

    Args:
        age_str (str): Age token whose second ":"-separated field is an
            integer, e.g. "age:45".

    Returns:
        str: Bucket token "age:LOWER-UPPER", where LOWER is the age floored to
        a multiple of ten and UPPER is LOWER + 10 (e.g. "age:40-50").
    """
    age_text = age_str.split(":")[1]
    # Floor division rounds towards negative infinity, so "age:-5" lands in
    # the "age:-10-0" bucket.
    lower_bound = (int(age_text) // 10) * 10
    upper_bound = lower_bound + 10
    return f"age:{lower_bound}-{upper_bound}"
47
+
48
+
49
def map_race(race):
    """
    Translate a race concept id through ``race_mapping``.

    Args:
        race: Race concept id to look up.

    Returns:
        The mapped id when ``race`` is a key of ``race_mapping``; otherwise
        ``race`` itself, unchanged.
    """
    try:
        return race_mapping[race]
    except KeyError:
        return race
51
+
52
+
53
def main(args):
    """
    Build a train/validation patient split using age-weighted sampling.

    Reads the patient sequence parquet, derives year / age group / gender /
    race / death from each row's ``concept_ids``, drops rows whose gender is
    "0" or whose age group is listed in ``invalid_age_groups``, samples
    ``args.num_patients`` rows without replacement (younger age groups are
    up-weighted) and labels them "train", labels every remaining eligible row
    "validation", and writes the combined frame as parquet.

    Args:
        args: Namespace with ``patient_sequence`` (input parquet path),
            ``num_patients`` (number of rows to sample for training) and
            ``output_folder`` (passed as-is to ``DataFrame.to_parquet``).

    Raises:
        ValueError: If ``args.num_patients`` is larger than the number of
            eligible rows, since sampling is done without replacement.
    """
    # Load data
    patient_sequence = pd.read_parquet(args.patient_sequence)

    # Extract demographics: the first four tokens of each sequence are read as
    # year, age, gender and race (in that order).
    concept_ids = patient_sequence.concept_ids
    patient_sequence["year"] = concept_ids.apply(lambda ids: ids[0])
    patient_sequence["age"] = concept_ids.apply(lambda ids: age_group_func(ids[1]))
    patient_sequence["gender"] = concept_ids.apply(lambda ids: ids[2])
    patient_sequence["race"] = concept_ids.apply(lambda ids: ids[3])
    # A patient is flagged as dead when the second-to-last token is "[DEATH]".
    patient_sequence["death"] = concept_ids.apply(
        lambda ids: int(ids[-2] == "[DEATH]")
    )

    # Work on an explicit copy so the column assignment below does not write
    # into a slice of ``patient_sequence`` (avoids SettingWithCopy warnings).
    demographics = patient_sequence[
        ["person_id", "death", "year", "age", "gender", "race", "split"]
    ].copy()
    # NOTE(review): the remapped race only lives in this side frame; the frame
    # that is sampled and written below is built from ``patient_sequence`` and
    # therefore keeps the raw race value. Confirm whether that is intended.
    demographics["race"] = demographics.race.apply(map_race)

    is_eligible = (demographics.gender != "0") & (
        ~demographics.age.isin(invalid_age_groups)
    )
    demographics_clean = demographics[is_eligible]
    patient_sequence_clean = patient_sequence[
        patient_sequence.person_id.isin(demographics_clean.person_id)
    ]

    # Calculate probabilities: share of eligible rows per age group.
    probs = (
        demographics_clean.groupby(["age"])["person_id"].count()
        / len(demographics_clean)
    ).reset_index()
    probs.rename(columns={"person_id": "prob"}, inplace=True)

    # Adjust probabilities
    # Global NumPy seed kept from the original script; the sampling below is
    # driven by its own ``random_state``.
    np.random.seed(42)
    # Tie each weight to the age decade itself (0-10 -> 10, ..., 90-100 -> 1)
    # instead of to the row position in ``probs``; positional weights shift
    # whenever an age group is absent or sorts out of numeric order.
    lower_bounds = (
        probs["age"].str.extract(r"^age:(-?\d+)-", expand=False).astype(int)
    )
    age_weights = (10 - lower_bounds // 10).clip(lower=0, upper=10)
    adjusted_probs = probs.prob * age_weights
    probs["adjusted_prob"] = adjusted_probs / adjusted_probs.sum()

    demographics_for_sampling = patient_sequence_clean[
        ["year", "age", "race", "gender", "person_id"]
    ].merge(probs, on="age")
    # Re-normalise per row so the sampling weights sum to one.
    demographics_for_sampling["adjusted_prob"] = (
        demographics_for_sampling.adjusted_prob
        / demographics_for_sampling.adjusted_prob.sum()
    )

    # Train/Validation Split
    num_eligible = len(demographics_for_sampling)
    if args.num_patients > num_eligible:
        raise ValueError(
            f"num_patients ({args.num_patients}) exceeds the number of eligible "
            f"patient rows ({num_eligible}); sampling is done without replacement"
        )
    causal_train_split = demographics_for_sampling.sample(
        args.num_patients, replace=False, weights="adjusted_prob", random_state=1
    )
    causal_train_split["split"] = "train"
    # Copy before labelling so the assignment targets a standalone frame.
    causal_val_split = demographics_for_sampling[
        ~demographics_for_sampling.person_id.isin(causal_train_split.person_id)
    ].copy()
    causal_val_split["split"] = "validation"

    causal_train_val_split = pd.concat([causal_train_split, causal_val_split])

    # Save outputs
    causal_train_val_split.to_parquet(args.output_folder, index=False)
123
+
124
+
125
if __name__ == "__main__":
    import argparse

    # Command-line interface: input parquet, training sample size, output path.
    parser = argparse.ArgumentParser(
        description="Arguments for a causal patient split by age groups"
    )
    parser.add_argument("--patient_sequence", required=True)
    parser.add_argument("--num_patients", type=int, default=1_000_000)
    parser.add_argument("--output_folder", required=True)

    # Hand the parsed namespace straight to main().
    main(parser.parse_args())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: cehrgpt
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -12,11 +12,12 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10.0
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: cehrbert==1.3.3
15
+ Requires-Dist: cehrbert==1.3.8
16
16
  Requires-Dist: openai==1.54.3
17
17
  Requires-Dist: optuna==4.0.0
18
18
  Requires-Dist: transformers==4.40.0
19
- Requires-Dist: tokenizers==0.19
19
+ Requires-Dist: tokenizers==0.19.0
20
+ Requires-Dist: peft==0.10.0
20
21
  Requires-Dist: trl==0.11.4
21
22
  Provides-Extra: dev
22
23
  Requires-Dist: pre-commit; extra == "dev"
@@ -50,11 +51,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
50
51
  To install CEHRGPT, clone this repository and install the required dependencies.
51
52
 
52
53
  ```bash
53
- git clone https://github.com/knatarajan-lab/cehrgpt-public.git
54
- cd cehrgpt-public
54
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
55
+ cd cehrgpt
55
56
  pip install .
56
57
  ```
57
58
 
59
+ ## Pretrain
60
+ Pretrain cehrgpt using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file.
61
+ ```bash
62
+ mkdir test_results
63
+ # This is NOT required when streaming is set to true
64
+ mkdir test_dataset_prepared
65
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
66
+ ```
67
+
68
+ ## Generate synthetic sequences
69
+ Generate synthetic sequences using the trained model
70
+ ```bash
71
+ export TRANSFORMERS_VERBOSITY=info
72
+ export CUDA_VISIBLE_DEVICES="0"
73
+ python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
74
+ --model_folder test_results \
75
+ --tokenizer_folder test_results \
76
+ --output_folder test_results \
77
+ --num_of_patients 128 \
78
+ --batch_size 32 \
79
+ --buffer_size 128 \
80
+ --context_window 1024 \
81
+ --sampling_strategy TopPStrategy \
82
+ --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
83
+ --epsilon_cutoff 0.00 \
84
+ --demographic_data_path sample_data/pretrain
85
+ ```
86
+
87
+ ## Convert synthetic sequences to OMOP
88
+ ```bash
89
+ # omop converter requires the OHDSI vocabulary
90
+ export OMOP_VOCAB_DIR=""
91
+ # the omop derived tables need to be built using pyspark
92
+ export SPARK_WORKER_INSTANCES="1"
93
+ export SPARK_WORKER_CORES="8"
94
+ export SPARK_EXECUTOR_CORES="2"
95
+ export SPARK_DRIVER_MEMORY="2g"
96
+ export SPARK_EXECUTOR_MEMORY="2g"
97
+
98
+ # Convert the sequences, create the omop derived tables
99
+ sh scripts/omop_pipeline.sh \
100
+ test_results/top_p10000/generated_sequences/ \
101
+ test_results/top_p10000/restored_omop/ \
102
+ $OMOP_VOCAB_DIR
103
+ ```
104
+
58
105
  ## Citation
59
106
  ```
60
107
  @article{cehrgpt2024,
@@ -63,4 +110,3 @@ pip install .
63
110
  journal={arXiv preprint arXiv:2402.04400},
64
111
  year={2024}
65
112
  }
66
- ```
@@ -13,19 +13,22 @@ cehrgpt/analysis/privacy/utils.py,sha256=CRA4H9mPLBjMQGKzZ_x_3ro3tMap-NjsMDVqSOj
13
13
  cehrgpt/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  cehrgpt/data/hf_cehrgpt_dataset.py,sha256=7hvjjqE8WInVuRvAtNkFI_J-xluFBv1Ij4TPTdUxPM4,2570
15
15
  cehrgpt/data/hf_cehrgpt_dataset_collator.py,sha256=RYw5Isrwa4sdyQQ3Nf3cu7xPDA3m-c5ecCFf_y1TJKY,20497
16
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=aQ0gsThOFhrh9ExpJhRmuiwN9ShIKheLgCIci-N7HOM,4305
16
+ cehrgpt/data/hf_cehrgpt_dataset_mapping.py,sha256=IjGwLKbEfNPxH3hsNmb8p48_imHnMWtslDK6f7R_1pc,16053
17
+ cehrgpt/data/hf_cehrgpt_dpo_collator.py,sha256=cqDK0SUOt3yAqUHWKGuLVi3WmmUMZ6eyxTv9fC9idZA,2787
18
+ cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py,sha256=uCLF5VEsyZAG1aNwqEM6Jy5Lx7bI5ALku52Z6Anine0,2574
17
19
  cehrgpt/generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
20
  cehrgpt/generation/chatgpt_generation.py,sha256=SrnLwHLdNtnAOEg36gNjqfoT9yd12iyPgpZffL2AFJo,4428
19
21
  cehrgpt/generation/generate_batch_hf_gpt_sequence.py,sha256=-WLpKlulVVDJSdA2jXyp87gfLW4Q3aAtwULK8fDtn_E,11408
20
- cehrgpt/generation/omop_converter_batch.py,sha256=SDpWjqzi8dsgVzbbFes42GMdZEvrJ3sm4RbP5UpmIlk,25280
22
+ cehrgpt/generation/generate_paired_cehrgpt_sequence.py,sha256=fLu3SHhRe_ZQfS09ebOktq2dekStgYfxmbrRawZQAO4,8280
23
+ cehrgpt/generation/omop_converter_batch.py,sha256=-c0AlDVy5pJ5Afhr8ERiCHhoRrEk8ozJi3g0yFdWaMI,25348
21
24
  cehrgpt/generation/omop_entity.py,sha256=Q5Sr0AlyuPAm1FRPfnJO13q-u1fqRgYVHXruZ9g4xNE,19400
22
25
  cehrgpt/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
26
  cehrgpt/models/config.py,sha256=xek4W_siO7WtMAKE7zDsENotsIE70F8dcW-PTC0kBKk,9700
24
- cehrgpt/models/hf_cehrgpt.py,sha256=YrHhT8c92xcOVTb6FjFQokyHrDOcXgEDMBs0BksSBpA,75739
27
+ cehrgpt/models/hf_cehrgpt.py,sha256=CKseTvGkBFwXK40Z_uKD1_d84oSYCFqKmHI0qtdk72g,75757
25
28
  cehrgpt/models/hf_modeling_outputs.py,sha256=LaWa1jI6BRIKMEjWOy1QUeOfTur5y_p2c-JyuGVTdtw,10301
26
29
  cehrgpt/models/pretrained_embeddings.py,sha256=vLLVs17TLpXRqCVEWQxGGwPHkUJUO7laNTeBuyBK_yk,3238
27
30
  cehrgpt/models/special_tokens.py,sha256=-a7HPJBbdIH0qQ6B3CcRKqvpG6FZlm4nbVPTswGSJ4U,485
28
- cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=jQR5aHjdHhS14nC1qnqDmybS1gpB27WK2-qVNz9cxW0,42156
31
+ cehrgpt/models/tokenization_hf_cehrgpt.py,sha256=JAZjnmQq-JjUxZK7XIsqdZB07ZB7BC2WraCjpO_6AOM,42161
29
32
  cehrgpt/omop/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
33
  cehrgpt/omop/condition_era.py,sha256=hPZALz2XaWnro_1bwIYNkI48foOJjueyg3CZ1BliCno,626
31
34
  cehrgpt/omop/observation_period.py,sha256=TRMgv5Ya2RaS2im7oQ6BLC_5JL9EJYNYR62ApxIuHvg,1211
@@ -35,11 +38,17 @@ cehrgpt/omop/sample_omop_tables.py,sha256=2JZ8BNSvssceinwFanvuCRh-YlKrKn25U9w1pL
35
38
  cehrgpt/omop/queries/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
39
  cehrgpt/omop/queries/condition_era.py,sha256=LFB6vBAvshHJxtYIRkl7cfrF0kf7ay0piBKpmHBwrpE,2578
37
40
  cehrgpt/omop/queries/observation_period.py,sha256=fpzr5DMNw-QLoSwp2Iatfch88E3hyhZ75usiIdG3A0U,6410
41
+ cehrgpt/rl_finetune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py,sha256=VQHf5vy5i8K1imcqYakhitfAW-d2mnaEzkSoAYSW5kg,26062
43
+ cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py,sha256=nYWYPCaWNjDGEwlo6UHOK1rvOZUx1vuJ8kYuAszI8Zg,17925
44
+ cehrgpt/rl_finetune/ppo_finetune.py,sha256=tSy-C0Kzgj5ffclBIDj-RTj78ZfrLmTESxVxd0n9yuE,13971
45
+ cehrgpt/rl_finetune/ppo_finetune_v2.py,sha256=7dChwKpq4zKmpkcxP4hryqBoIkcwmTJ44_BF8R2RghQ,13285
38
46
  cehrgpt/runners/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
47
  cehrgpt/runners/gpt_runner_util.py,sha256=88HKSVj-ADGBCMo7C3znKSMPnAAALa1iU_6P6i9sD0M,3867
40
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=aGw87ZJuUIH196ryaZzt9D4hCAHVcDyKnvvdVPdipwc,31568
41
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=6xulvnjwy6LDRPIL_zgsYH7sJMiXJ9AvFg3p2o35S6c,16510
42
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=2l1X5bp1zckoFp0rQkxGptXyG8u3PgNw0dqYVDWLYjg,5155
48
+ cehrgpt/runners/hf_cehrgpt_dpo_runner.py,sha256=Z4qNl9CZFC5YvUBc9ZzdOV5wsBFvMTdxfTn4jjtJQ-Y,4583
49
+ cehrgpt/runners/hf_cehrgpt_finetune_runner.py,sha256=reflNRb6YB6f_3jAfzFAdwKtTl6hvdIp9Jc7DC-Sv-U,32580
50
+ cehrgpt/runners/hf_cehrgpt_pretrain_runner.py,sha256=L3UpjtzxuS8a_tshlqpZN_sXnJSs3yzry0GZNT__05A,18200
51
+ cehrgpt/runners/hf_gpt_runner_argument_dataclass.py,sha256=gKVf4BLzNCFiJR7nZVkf-QRcj8fAEVvIUTV-AVH0g_U,5312
43
52
  cehrgpt/runners/hyperparameter_search_util.py,sha256=i4qAb_22JO78l40MSyBPwDgAGuGc96efXmg_833cSSo,9044
44
53
  cehrgpt/time_to_event/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
45
54
  cehrgpt/time_to_event/time_to_event_model.py,sha256=tfXa24l_0q1TBZ68BPRrHRC_3KRWYxrWGIv4myJlIb8,8497
@@ -50,11 +59,12 @@ cehrgpt/time_to_event/config/next_visit_type_prediction.yaml,sha256=WMj2ZutEvHKI
50
59
  cehrgpt/time_to_event/config/t2dm_hf.yaml,sha256=_oMQzh2eJTYzEaMOpmhAzbX-qmdsKlkORELL6HxOxHo,202
51
60
  cehrgpt/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
52
61
  cehrgpt/tools/ehrshot_benchmark.py,sha256=E-m_5srlYEw7Y7i9twIJWDvrkwNlop-6yZB-80FZid0,2667
62
+ cehrgpt/tools/generate_causal_patient_split_by_age.py,sha256=dmHiPAL_kR1WrhRteIiHH9dwMtMi3PVl8jXm2O06_gI,4177
53
63
  cehrgpt/tools/generate_pretrained_embeddings.py,sha256=lhFSacGv8bMld6qigKZN8Op8eXpFi0DsJuQbWKOWXqI,4160
54
64
  cehrgpt/tools/merge_synthetic_real_dataasets.py,sha256=O1dbQ32Le0t15fwymwAh9mfNVLEWuFwW53DNvESrWbY,7589
55
65
  cehrgpt/tools/upload_omop_tables.py,sha256=vdBAbkeAsGPA4NsyhNjelPVj3gS8yzmS1sKNM1Qk96g,3791
56
- cehrgpt-0.0.1.dist-info/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
57
- cehrgpt-0.0.1.dist-info/METADATA,sha256=BZrsoZe0Smn4JoA3cCI63fC4nBvOVrC9sgZ0Ct1NJsA,3388
58
- cehrgpt-0.0.1.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
59
- cehrgpt-0.0.1.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
60
- cehrgpt-0.0.1.dist-info/RECORD,,
66
+ cehrgpt-0.0.2.dist-info/LICENSE,sha256=LOfC32zkfUIdGm8e_098jPbt8OHKtNWymDzxn2pA9Zk,1093
67
+ cehrgpt-0.0.2.dist-info/METADATA,sha256=joUmDJWMEBvYphrkwYiK273FwSL9okY74D93ncrbvMU,4878
68
+ cehrgpt-0.0.2.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
69
+ cehrgpt-0.0.2.dist-info/top_level.txt,sha256=akNCJBbMSLV8nkOzdVzdy13hMJ5CIQURnAS_YYEDVwA,17
70
+ cehrgpt-0.0.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.1)
2
+ Generator: setuptools (76.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5