cehrgpt 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. cehrgpt-0.0.2/.gitignore +25 -0
  2. {cehrgpt-0.0.1/src/cehrgpt.egg-info → cehrgpt-0.0.2}/PKG-INFO +52 -6
  3. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/README.md +49 -4
  4. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/pyproject.toml +3 -2
  5. cehrgpt-0.0.2/sample_configs/cehrgpt_pretrain_sample_config.yaml +51 -0
  6. cehrgpt-0.0.2/scripts/omop_pipeline.sh +55 -0
  7. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +382 -0
  8. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_collator.py +71 -0
  9. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +61 -0
  10. cehrgpt-0.0.2/src/cehrgpt/generation/generate_paired_cehrgpt_sequence.py +224 -0
  11. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/omop_converter_batch.py +3 -0
  12. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/hf_cehrgpt.py +1 -0
  13. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +2 -2
  14. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +586 -0
  15. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +464 -0
  16. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune.py +394 -0
  17. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune_v2.py +373 -0
  18. cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_dpo_runner.py +119 -0
  19. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -3
  20. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +44 -8
  21. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +4 -0
  22. cehrgpt-0.0.2/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
  23. {cehrgpt-0.0.1 → cehrgpt-0.0.2/src/cehrgpt.egg-info}/PKG-INFO +52 -6
  24. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/SOURCES.txt +11 -0
  25. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/requires.txt +3 -2
  26. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +4 -0
  27. cehrgpt-0.0.2/tests/unit_tests/tools/__init__.py +0 -0
  28. cehrgpt-0.0.1/.gitignore +0 -38
  29. cehrgpt-0.0.1/scripts/omop_pipeline.sh +0 -73
  30. cehrgpt-0.0.1/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +0 -116
  31. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/.github/workflows/build-python.yaml +0 -0
  32. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/.github/workflows/tests.yaml +0 -0
  33. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/.pre-commit-config.yaml +0 -0
  34. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/LICENSE +0 -0
  35. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/sample_data/pretrain/patient_sequence.parquet +0 -0
  36. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
  37. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
  38. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/scripts/level_three_evaluation.sh +0 -0
  39. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/scripts/pool_generated_sequences.sh +0 -0
  40. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/setup.cfg +0 -0
  41. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/__init__.py +0 -0
  42. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/__init__.py +0 -0
  43. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/__init__.py +0 -0
  44. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
  45. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
  46. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
  47. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
  48. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
  49. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
  50. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/utils.py +0 -0
  51. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/cehrgpt_args.py +0 -0
  52. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/data/__init__.py +0 -0
  53. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/data/hf_cehrgpt_dataset.py +0 -0
  54. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +0 -0
  55. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/__init__.py +0 -0
  56. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
  57. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +0 -0
  58. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/omop_entity.py +0 -0
  59. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/gpt_utils.py +0 -0
  60. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/__init__.py +0 -0
  61. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/config.py +0 -0
  62. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/hf_modeling_outputs.py +0 -0
  63. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
  64. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/special_tokens.py +0 -0
  65. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/__init__.py +0 -0
  66. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/condition_era.py +0 -0
  67. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/observation_period.py +0 -0
  68. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/omop_argparse.py +0 -0
  69. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/omop_table_builder.py +0 -0
  70. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/queries/__init__.py +0 -0
  71. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/queries/condition_era.py +0 -0
  72. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/queries/observation_period.py +0 -0
  73. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
  74. {cehrgpt-0.0.1/src/cehrgpt/runners → cehrgpt-0.0.2/src/cehrgpt/rl_finetune}/__init__.py +0 -0
  75. {cehrgpt-0.0.1/src/cehrgpt/time_to_event → cehrgpt-0.0.2/src/cehrgpt/runners}/__init__.py +0 -0
  76. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/gpt_runner_util.py +0 -0
  77. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hyperparameter_search_util.py +0 -0
  78. {cehrgpt-0.0.1/src/cehrgpt/tools → cehrgpt-0.0.2/src/cehrgpt/time_to_event}/__init__.py +0 -0
  79. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
  80. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
  81. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
  82. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/time_to_event_model.py +0 -0
  83. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/time_to_event_prediction.py +0 -0
  84. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
  85. {cehrgpt-0.0.1/tests → cehrgpt-0.0.2/src/cehrgpt/tools}/__init__.py +0 -0
  86. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
  87. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
  88. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
  89. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
  90. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
  91. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/top_level.txt +0 -0
  92. {cehrgpt-0.0.1/tests/integration_tests → cehrgpt-0.0.2/tests}/__init__.py +0 -0
  93. {cehrgpt-0.0.1/tests/integration_tests/runners → cehrgpt-0.0.2/tests/integration_tests}/__init__.py +0 -0
  94. {cehrgpt-0.0.1/tests/unit_tests → cehrgpt-0.0.2/tests/integration_tests/runners}/__init__.py +0 -0
  95. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +0 -0
  96. {cehrgpt-0.0.1/tests/unit_tests/models → cehrgpt-0.0.2/tests/unit_tests}/__init__.py +0 -0
  97. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/gpt_utils_test.py +0 -0
  98. {cehrgpt-0.0.1/tests/unit_tests/models/tokenization → cehrgpt-0.0.2/tests/unit_tests/models}/__init__.py +0 -0
  99. {cehrgpt-0.0.1/tests/unit_tests/runners → cehrgpt-0.0.2/tests/unit_tests/models/tokenization}/__init__.py +0 -0
  100. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
  101. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
  102. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
  103. {cehrgpt-0.0.1/tests/unit_tests/tools → cehrgpt-0.0.2/tests/unit_tests/runners}/__init__.py +0 -0
  104. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +0 -0
  105. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/tokenization_test.py +0 -0
  106. {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
@@ -0,0 +1,25 @@
1
+ .DS_Store
2
+ .idea/
3
+ .vscode/
4
+ venv*
5
+ dist/*
6
+
7
+ *ipynb_checkpoints/
8
+ *h5
9
+ *logs
10
+ *nohup.out
11
+ *ipynb
12
+
13
+ *__pycache__/
14
+ .eggs/
15
+ *.dat
16
+ .metastore_db/
17
+
18
+ build/
19
+
20
+ *.out
21
+ *.egg-info/
22
+
23
+ test_data
24
+ test_dataset_prepared
25
+ test*results
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: cehrgpt
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
5
5
  Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
6
6
  License: MIT License
@@ -12,11 +12,12 @@ Classifier: Programming Language :: Python :: 3
12
12
  Requires-Python: >=3.10.0
13
13
  Description-Content-Type: text/markdown
14
14
  License-File: LICENSE
15
- Requires-Dist: cehrbert==1.3.3
15
+ Requires-Dist: cehrbert==1.3.8
16
16
  Requires-Dist: openai==1.54.3
17
17
  Requires-Dist: optuna==4.0.0
18
18
  Requires-Dist: transformers==4.40.0
19
- Requires-Dist: tokenizers==0.19
19
+ Requires-Dist: tokenizers==0.19.0
20
+ Requires-Dist: peft==0.10.0
20
21
  Requires-Dist: trl==0.11.4
21
22
  Provides-Extra: dev
22
23
  Requires-Dist: pre-commit; extra == "dev"
@@ -50,11 +51,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
50
51
  To install CEHRGPT, clone this repository and install the required dependencies.
51
52
 
52
53
  ```bash
53
- git clone https://github.com/knatarajan-lab/cehrgpt-public.git
54
- cd cehrgpt-public
54
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
55
+ cd cehrgpt
55
56
  pip install .
56
57
  ```
57
58
 
59
+ ## Pretrain
60
+ Pretrain cehrgpt using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file.
61
+ ```bash
62
+ mkdir test_results
63
+ # This is NOT required when streaming is set to true
64
+ mkdir test_dataset_prepared
65
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
66
+ ```
67
+
68
+ ## Generate synthetic sequences
69
+ Generate synthetic sequences using the trained model
70
+ ```bash
71
+ export TRANSFORMERS_VERBOSITY=info
72
+ export CUDA_VISIBLE_DEVICES="0"
73
+ python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
74
+ --model_folder test_results \
75
+ --tokenizer_folder test_results \
76
+ --output_folder test_results \
77
+ --num_of_patients 128 \
78
+ --batch_size 32 \
79
+ --buffer_size 128 \
80
+ --context_window 1024 \
81
+ --sampling_strategy TopPStrategy \
82
+ --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
83
+ --epsilon_cutoff 0.00 \
84
+ --demographic_data_path sample_data/pretrain
85
+ ```
86
+
87
+ ## Convert synthetic sequences to OMOP
88
+ ```bash
89
+ # omop converter requires the OHDSI vocabulary
90
+ export OMOP_VOCAB_DIR=""
91
+ # the omop derived tables need to be built using pyspark
92
+ export SPARK_WORKER_INSTANCES="1"
93
+ export SPARK_WORKER_CORES="8"
94
+ export SPARK_EXECUTOR_CORES="2"
95
+ export SPARK_DRIVER_MEMORY="2g"
96
+ export SPARK_EXECUTOR_MEMORY="2g"
97
+
98
+ # Convert the sequences, create the omop derived tables
99
+ sh scripts/omop_pipeline.sh \
100
+ test_results/top_p10000/generated_sequences/ \
101
+ test_results/top_p10000/restored_omop/ \
102
+ $OMOP_VOCAB_DIR
103
+ ```
104
+
58
105
  ## Citation
59
106
  ```
60
107
  @article{cehrgpt2024,
@@ -63,4 +110,3 @@ pip install .
63
110
  journal={arXiv preprint arXiv:2402.04400},
64
111
  year={2024}
65
112
  }
66
- ```
@@ -19,11 +19,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
19
19
  To install CEHRGPT, clone this repository and install the required dependencies.
20
20
 
21
21
  ```bash
22
- git clone https://github.com/knatarajan-lab/cehrgpt-public.git
23
- cd cehrgpt-public
22
+ git clone https://github.com/knatarajan-lab/cehrgpt.git
23
+ cd cehrgpt
24
24
  pip install .
25
25
  ```
26
26
 
27
+ ## Pretrain
28
+ Pretrain cehrgpt using the Hugging Face trainer; the parameters can be found in the sample configuration YAML file.
29
+ ```bash
30
+ mkdir test_results
31
+ # This is NOT required when streaming is set to true
32
+ mkdir test_dataset_prepared
33
+ python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
34
+ ```
35
+
36
+ ## Generate synthetic sequences
37
+ Generate synthetic sequences using the trained model
38
+ ```bash
39
+ export TRANSFORMERS_VERBOSITY=info
40
+ export CUDA_VISIBLE_DEVICES="0"
41
+ python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
42
+ --model_folder test_results \
43
+ --tokenizer_folder test_results \
44
+ --output_folder test_results \
45
+ --num_of_patients 128 \
46
+ --batch_size 32 \
47
+ --buffer_size 128 \
48
+ --context_window 1024 \
49
+ --sampling_strategy TopPStrategy \
50
+ --top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
51
+ --epsilon_cutoff 0.00 \
52
+ --demographic_data_path sample_data/pretrain
53
+ ```
54
+
55
+ ## Convert synthetic sequences to OMOP
56
+ ```bash
57
+ # omop converter requires the OHDSI vocabulary
58
+ export OMOP_VOCAB_DIR=""
59
+ # the omop derived tables need to be built using pyspark
60
+ export SPARK_WORKER_INSTANCES="1"
61
+ export SPARK_WORKER_CORES="8"
62
+ export SPARK_EXECUTOR_CORES="2"
63
+ export SPARK_DRIVER_MEMORY="2g"
64
+ export SPARK_EXECUTOR_MEMORY="2g"
65
+
66
+ # Convert the sequences, create the omop derived tables
67
+ sh scripts/omop_pipeline.sh \
68
+ test_results/top_p10000/generated_sequences/ \
69
+ test_results/top_p10000/restored_omop/ \
70
+ $OMOP_VOCAB_DIR
71
+ ```
72
+
27
73
  ## Citation
28
74
  ```
29
75
  @article{cehrgpt2024,
@@ -31,5 +77,4 @@ pip install .
31
77
  author={Natarajan, K and others},
32
78
  journal={arXiv preprint arXiv:2402.04400},
33
79
  year={2024}
34
- }
35
- ```
80
+ }
@@ -28,11 +28,12 @@ classifiers = [
28
28
  ]
29
29
 
30
30
  dependencies = [
31
- "cehrbert==1.3.3",
31
+ "cehrbert==1.3.8",
32
32
  "openai==1.54.3",
33
33
  "optuna==4.0.0",
34
34
  "transformers==4.40.0",
35
- "tokenizers==0.19",
35
+ "tokenizers==0.19.0",
36
+ "peft==0.10.0",
36
37
  "trl==0.11.4",
37
38
  ]
38
39
 
@@ -0,0 +1,51 @@
1
+ model_name_or_path: "test_results"
2
+ tokenizer_name_or_path: "test_results"
3
+
4
+ data_folder: "sample_data/pretrain"
5
+ dataset_prepared_path: "test_dataset_prepared"
6
+ validation_split_percentage: 0.05
7
+ validation_split_num: 10
8
+ preprocessing_num_workers: 4
9
+ preprocessing_batch_size: 1000
10
+ streaming: true
11
+
12
+ #Tokenizer
13
+ vocab_size: 50000
14
+ min_frequency: 0
15
+
16
+ do_train: true
17
+ overwrite_output_dir: false
18
+ resume_from_checkpoint: # path to the checkpoint folder
19
+ seed: 42
20
+
21
+ num_hidden_layers: 6
22
+ hidden_size: 768
23
+ n_head: 12
24
+ max_position_embeddings: 1024
25
+
26
+ # torch dataloader configs
27
+ dataloader_num_workers: 4
28
+ dataloader_prefetch_factor: 2
29
+
30
+ output_dir: "test_results"
31
+ save_strategy: "steps"
32
+ evaluation_strategy: "no"
33
+ learning_rate: 0.00005
34
+ per_device_train_batch_size: 4
35
+ per_device_eval_batch_size: 4
36
+ gradient_accumulation_steps: 1
37
+ num_train_epochs: 1
38
+ # When streaming is set to True, max_steps needs to be provided
39
+ max_steps: 1000
40
+ save_steps: 500
41
+
42
+ warmup_steps: 100
43
+ weight_decay: 0.01
44
+ logging_dir: "./logs"
45
+ logging_steps: 100
46
+ save_total_limit: 5
47
+ load_best_model_at_end: false
48
+ metric_for_best_model: "eval_loss"
49
+ greater_is_better: false
50
+
51
+ report_to: "none"
@@ -0,0 +1,55 @@
1
#!/bin/bash
#
# Rebuild an OMOP instance from generated CEHR-GPT patient sequences.
#
# Usage:
#   omop_pipeline.sh PATIENT_SEQUENCE_FOLDER OMOP_FOLDER SOURCE_OMOP_FOLDER
#
#   PATIENT_SEQUENCE_FOLDER  folder holding the generated patient sequences
#   OMOP_FOLDER              output folder for the reconstructed OMOP tables
#   SOURCE_OMOP_FOLDER       folder holding the concept, concept_relationship
#                            and concept_ancestor tables to copy from

# Refuse to run with missing or empty arguments: an empty OMOP_FOLDER would
# otherwise turn the cleanup below into "rm -rf /person /concept ...".
if [ "$#" -lt 3 ] || [ -z "$1" ] || [ -z "$2" ] || [ -z "$3" ]; then
    echo "Usage: $0 PATIENT_SEQUENCE_FOLDER OMOP_FOLDER SOURCE_OMOP_FOLDER" >&2
    exit 1
fi

# Exporting input arguments as environment variables
export PATIENT_SEQUENCE_FOLDER="$1"
export OMOP_FOLDER="$2"
export SOURCE_OMOP_FOLDER="$3"
export PATIENT_SPLITS_FOLDER="$SOURCE_OMOP_FOLDER/patient_splits"

# Echoing the values of the environment variables
echo "PATIENT_SEQUENCE_FOLDER=$PATIENT_SEQUENCE_FOLDER"
echo "OMOP_FOLDER=$OMOP_FOLDER"
echo "SOURCE_OMOP_FOLDER=$SOURCE_OMOP_FOLDER"

# Ensure OMOP_FOLDER exists
if [ ! -d "$OMOP_FOLDER" ]; then
    echo "Creating $OMOP_FOLDER"
    mkdir -p "$OMOP_FOLDER"
fi

# Removing existing OMOP tables and OMOP concept tables.
# A plain loop with quoted paths is used instead of brace expansion so the
# script also works when invoked through "sh", and so folder names containing
# spaces are handled. "${OMOP_FOLDER:?}" aborts if the variable is ever empty.
for table in person visit_occurrence condition_occurrence procedure_occurrence \
    drug_exposure death measurement observation_period condition_era \
    concept concept_ancestor concept_relationship; do
    rm -rf "${OMOP_FOLDER:?}/$table"
done

# Copying OMOP concept tables if they don't already exist
for table in concept concept_relationship concept_ancestor; do
    if [ ! -d "$OMOP_FOLDER/$table" ]; then
        echo "Creating $OMOP_FOLDER/$table"
        cp -r "$SOURCE_OMOP_FOLDER/$table" "$OMOP_FOLDER/$table"
    fi
done

# Reconstructing the OMOP instance from patient sequences
echo "Reconstructing the OMOP instance from patient sequences in $OMOP_FOLDER"
python -m cehrgpt.generation.omop_converter_batch \
    --patient_sequence_path "$PATIENT_SEQUENCE_FOLDER" \
    --output_folder "$OMOP_FOLDER" \
    --concept_path "$OMOP_FOLDER/concept" \
    --buffer_size 1280 \
    --cpu_cores 10

# Create observation_period
echo "Reconstructing observation_period in $OMOP_FOLDER"
python -u -m cehrgpt.omop.observation_period \
    --input_folder "$OMOP_FOLDER" \
    --output_folder "$OMOP_FOLDER" \
    --domain_table_list "condition_occurrence drug_exposure procedure_occurrence measurement"

# Create condition_era
echo "Reconstructing condition_era in $OMOP_FOLDER"
python -u -m cehrgpt.omop.condition_era \
    --input_folder "$OMOP_FOLDER" \
    --output_folder "$OMOP_FOLDER" \
    --domain_table_list "condition_occurrence"
@@ -0,0 +1,382 @@
1
+ import datetime
2
+ from typing import Any, Dict
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
7
+ ED_VISIT_TYPE_CODES,
8
+ INPATIENT_VISIT_TYPE_CODES,
9
+ INPATIENT_VISIT_TYPES,
10
+ DatasetMapping,
11
+ replace_escape_chars,
12
+ )
13
+ from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
14
+ from cehrbert_data.const.common import NA
15
+ from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
16
+ from dateutil.relativedelta import relativedelta
17
+
18
+ from cehrgpt.models.tokenization_hf_cehrgpt import (
19
+ NONE_BIN,
20
+ UNKNOWN_BIN,
21
+ CehrGptTokenizer,
22
+ )
23
+
24
+
25
def convert_date_to_posix_time(index_date: datetime.date) -> float:
    """Return the POSIX timestamp of midnight on ``index_date``.

    The date is combined with 00:00:00 into a naive ``datetime``, so the
    resulting timestamp is interpreted in the local timezone of the machine.
    """
    midnight = datetime.datetime.combine(index_date, datetime.time.min)
    return midnight.timestamp()
29
+
30
+
31
class MedToCehrGPTDatasetMapping(DatasetMapping):
    """
    Convert a patient record in the MEDS extension format into the CehrGPT format.

    MEDS: https://github.com/Medical-Event-Data-Standard/meds/tree/main

    ``transform`` reads the following fields from the input record:
        - ``patient_id``, ``birth_datetime``, ``gender``, ``race``
        - ``visits``: each visit provides ``visit_start_datetime``, ``visit_type``
          and ``events``, and optionally ``visit_end_datetime`` and
          ``discharge_facility``
        - each event provides ``time`` and ``code``, and optionally
          ``numeric_value``, ``text_value`` and ``unit``

    Notes carried over from the original implementation (assumptions about the data):
        - The first event contains the demographic information
        - From the second event onward
            - the time of the event is visit_start_datetime.
            - the first measurement contains the code indicating a standard OMOP
              Visit concept_id (e.g. 9201, 9202)
            - in case of inpatient visits, the last measurement is assumed to
              contain the standard OMOP concept id for discharge facilities (e.g 8536)
            - in case of inpatient visits, datetime_value of the last measurement
              stores visit_end_datetime
    """

    def __init__(
        self,
        data_args: DataTrainingArguments,
        is_pretraining: bool = True,
        include_inpatient_hour_token: bool = True,
    ):
        """
        Args:
            data_args: provides ``att_function_type``, ``inpatient_att_function_type``,
                ``include_auxiliary_token`` and ``include_demographic_prompt``.
            is_pretraining: selects which columns ``remove_columns`` reports.
            include_inpatient_hour_token: stored on the instance; not used by
                the methods of this class.
        """
        self._time_token_function = get_att_function(data_args.att_function_type)
        self._include_auxiliary_token = data_args.include_auxiliary_token
        self._inpatient_time_token_function = get_att_function(
            data_args.inpatient_att_function_type
        )
        self._include_demographic_prompt = data_args.include_demographic_prompt
        self._is_pretraining = is_pretraining
        self._include_inpatient_hour_token = include_inpatient_hour_token

    def remove_columns(self):
        """Return the input columns to drop after the transformation."""
        if self._is_pretraining:
            return ["visits", "birth_datetime", "index_date"]
        return [
            "visits",
            "birth_datetime",
            "visit_concept_ids",
        ]

    @staticmethod
    def _update_cehrgpt_record(
        cehrgpt_record: Dict[str, Any],
        code: str,
        concept_value_mask: int = 0,
        number_as_value: float = 0.0,
        concept_as_value: str = "0",
        is_numeric_type: int = 0,
        unit: str = NA,
    ) -> None:
        """Append one token and its value attributes to the parallel lists of ``cehrgpt_record``."""
        cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
        cehrgpt_record["concept_value_masks"].append(concept_value_mask)
        cehrgpt_record["number_as_values"].append(number_as_value)
        cehrgpt_record["concept_as_values"].append(concept_as_value)
        cehrgpt_record["units"].append(unit)
        cehrgpt_record["is_numeric_types"].append(is_numeric_type)

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the CehrGPT token sequence for one patient.

        The sequence starts with the year, age, gender and race tokens, followed by
        each non-empty visit in chronological order, wrapped in ``[VS]`` / ``[VE]``
        tokens and separated by artificial time tokens.

        Args:
            record: the patient record (see the class docstring for the fields read).

        Returns:
            A new dictionary with the token lists, ``orders``, ``num_of_concepts``,
            ``num_of_visits`` and, when present in the input, ``label`` and
            ``age_at_index``.

        Raises:
            IndexError: if ``record["visits"]`` is empty.
        """
        cehrgpt_record = {
            "person_id": record["patient_id"],
            "concept_ids": [],
            "concept_value_masks": [],
            "number_as_values": [],
            "concept_as_values": [],
            "units": [],
            "is_numeric_types": [],
        }
        # Extract the demographic information
        birth_datetime = record["birth_datetime"]
        if isinstance(birth_datetime, pd.Timestamp):
            birth_datetime = birth_datetime.to_pydatetime()
        gender = record["gender"]
        race = record["race"]

        # Sort once so that the demographic tokens and the timeline agree on which
        # visit comes first, even when the input visits are not ordered.
        visits = sorted(record["visits"], key=lambda v: v["visit_start_datetime"])

        # Add the demographic tokens
        first_visit = visits[0]
        year_str = f'year:{str(first_visit["visit_start_datetime"].year)}'
        age_str = f'age:{str(relativedelta(first_visit["visit_start_datetime"], birth_datetime).years)}'
        self._update_cehrgpt_record(cehrgpt_record, year_str)
        self._update_cehrgpt_record(cehrgpt_record, age_str)
        self._update_cehrgpt_record(cehrgpt_record, gender)
        self._update_cehrgpt_record(cehrgpt_record, race)

        # Use a date cursor to keep track of time
        date_cursor = None
        epoch = datetime.datetime(year=1970, month=1, day=1)

        for visit in visits:
            events = visit["events"]

            # Skip this visit if it has no events
            if events is None or len(events) == 0:
                continue

            visit_start_datetime = visit["visit_start_datetime"]
            time_delta = (
                (visit_start_datetime - date_cursor).days if date_cursor else None
            )
            date_cursor = visit_start_datetime

            visit_type = visit["visit_type"]
            is_er_or_inpatient = (
                visit_type in INPATIENT_VISIT_TYPES
                or visit_type in INPATIENT_VISIT_TYPE_CODES
                or visit_type in ED_VISIT_TYPE_CODES
            )

            # Add artificial time tokens to the patient timeline if timedelta exists
            if time_delta is not None:
                # This generates an artificial time token depending on the choice of the time token functions
                self._update_cehrgpt_record(
                    cehrgpt_record,
                    code=self._time_token_function(time_delta),
                )

            # Week number of the visit start since the epoch; part of the
            # de-duplication key of outpatient events
            date = (visit_start_datetime - epoch).days // 7

            # Add a [VS] token to mark the start of the visit
            self._update_cehrgpt_record(
                cehrgpt_record,
                code="[VS]",
            )
            # Add a visit type token
            self._update_cehrgpt_record(
                cehrgpt_record,
                code=visit_type,
            )
            # Keep track of the outpatient events already added, we don't want to add them again.
            # A set gives constant-time membership tests.
            existing_outpatient_events = set()
            for e in events:
                # If the event doesn't have a time stamp, we skip it
                if not e["time"]:
                    continue

                # If a numeric or text value exists, this is a concept/value tuple, we indicate this using a
                # concept_value_mask
                numeric_value = e.get("numeric_value", None)
                text_value = e.get("text_value", None)
                # The unit might be populated with a None value
                unit = e.get("unit") or NA
                concept_value_mask = int(
                    numeric_value is not None or text_value is not None
                )
                is_numeric_type = int(numeric_value is not None)
                code = replace_escape_chars(e["code"])

                if is_er_or_inpatient:
                    # For inpatient/ER visits, the patient can stay in the hospital for a period of time, so
                    # the event time stamps drive the time tokens.
                    # Calculate the time diff in days w.r.t the previous measurement
                    meas_time_diff = (e["time"] - date_cursor).days
                    # Update the date_cursor if the time diff between two neighboring measurements is at
                    # least 1 day
                    if meas_time_diff > 0:
                        date_cursor = e["time"]
                        if self._inpatient_time_token_function:
                            # This generates an artificial time token depending on the choice of the time
                            # token functions
                            self._update_cehrgpt_record(
                                cehrgpt_record,
                                code=f"i-{self._inpatient_time_token_function(meas_time_diff)}",
                            )
                else:
                    # For outpatient visits, we assume the visit starts and ends on the same day, so the
                    # date/code/value combination identifies an event. Duplicates are not added again.
                    event_key = (
                        date,
                        code,
                        numeric_value,
                        text_value,
                        concept_value_mask,
                    )
                    if event_key in existing_outpatient_events:
                        continue
                    existing_outpatient_events.add(event_key)

                # Add a medical token to the patient timeline
                self._update_cehrgpt_record(
                    cehrgpt_record,
                    code=code,
                    concept_value_mask=concept_value_mask,
                    unit=unit,
                    number_as_value=numeric_value if numeric_value else 0.0,
                    concept_as_value=(
                        replace_escape_chars(text_value) if text_value else "0"
                    ),
                    is_numeric_type=is_numeric_type,
                )

            # For inpatient or ER visits, we want to add discharge_facility to the end of the visit
            if is_er_or_inpatient:
                # If visit_end_datetime is populated for the inpatient visit, we update the date_cursor
                visit_end_datetime = visit.get("visit_end_datetime", None)
                if visit_end_datetime:
                    date_cursor = visit_end_datetime

                if self._include_auxiliary_token:
                    # A missing or empty discharge facility is encoded as "0"
                    discharge_facility = visit.get("discharge_facility") or "0"
                    self._update_cehrgpt_record(
                        cehrgpt_record,
                        code=discharge_facility,
                    )

            # Add a [VE] token to mark the end of the visit
            self._update_cehrgpt_record(
                cehrgpt_record,
                code="[VE]",
            )

        # Generate the orders of the concepts that the cehrbert dataset mapping function expects
        cehrgpt_record["orders"] = list(
            range(1, len(cehrgpt_record["concept_ids"]) + 1)
        )

        # Add some count information for this sequence
        cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
        cehrgpt_record["num_of_visits"] = len(visits)

        if "label" in record:
            cehrgpt_record["label"] = record["label"]
        if "age_at_index" in record:
            cehrgpt_record["age_at_index"] = record["age_at_index"]

        return cehrgpt_record
284
+
285
+
286
class HFCehrGptTokenizationMapping(DatasetMapping):
    """Tokenize the concepts of a CehrGPT record and encode their associated values."""

    def __init__(
        self,
        concept_tokenizer: CehrGptTokenizer,
    ):
        self._concept_tokenizer = concept_tokenizer
        self._lab_token_ids = self._concept_tokenizer.lab_token_ids

    def remove_columns(self):
        """Return the columns to drop after the transformation."""
        return ["concept_value_masks", "is_numeric_types"]

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """
        Add ``input_ids``, ``value_indicators`` and ``values`` to ``record``.

        The record is modified in place and returned; ``number_as_values`` and
        ``concept_as_values`` are removed from it at the end.
        """
        tokenizer = self._concept_tokenizer
        record["input_ids"] = tokenizer.encode(record["concept_ids"])
        record["value_indicators"] = record["concept_value_masks"]

        # When the split value columns are not both present, derive them from concept_values
        has_split_values = (
            "number_as_values" in record and "concept_as_values" in record
        )
        if not has_split_values:
            raw_values = record["concept_values"]
            record["number_as_values"] = [
                float(raw) if isinstance(raw, float) else None for raw in raw_values
            ]
            record["is_numeric_types"] = [
                int(isinstance(raw, float)) for raw in raw_values
            ]
            record["concept_as_values"] = [
                raw if isinstance(raw, str) else None for raw in raw_values
            ]

        masks = record["concept_value_masks"]
        if np.any(np.asarray(masks) > 0):
            # At least one concept carries a value: resolve one value token per concept
            value_tokens = []
            rows = zip(
                record["concept_ids"],
                record["units"],
                masks,
                record["number_as_values"],
                record["concept_as_values"],
                record["is_numeric_types"],
            )
            for concept_id, unit, mask, number, text, numeric_flag in rows:
                if mask != 1:
                    value_tokens.append(NONE_BIN)
                    continue

                if numeric_flag == 1:
                    # Numeric values are normalized only for concepts known to the tokenizer
                    if concept_id in tokenizer.numeric_concept_ids:
                        token = tokenizer.normalize(concept_id, unit, number)
                    else:
                        token = UNKNOWN_BIN
                elif isinstance(text, str):
                    token = text
                else:
                    token = UNKNOWN_BIN
                value_tokens.append(token)

            assert len(value_tokens) == len(record["input_ids"])
            record["values"] = tokenizer.encode_value(value_tokens)
        else:
            record["values"] = tokenizer.encode_value([NONE_BIN] * len(masks))

        # Delete these features because they contain null values and pyarrow cannot concatenate multiple records
        del record["number_as_values"]
        del record["concept_as_values"]
        return record
357
+
358
+
359
class HFFineTuningMapping(HFCehrGptTokenizationMapping):
    """Tokenization mapping that also adds the fine-tuning fields.

    Consider removing this transformation in the future.
    """

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize ``record`` and add ``age_at_index``, ``classifier_label`` and ``index_date``."""
        record = super().transform(record)

        # Compute every new field before writing any of them back
        if "age" in record:
            age_at_index = record["age"]
        else:
            age_at_index = record["age_at_index"]

        classifier_label = int(record["label"] > 0)

        index_date = None
        if "index_date" in record:
            index_date = convert_date_to_posix_time(record["index_date"])

        record.update(
            age_at_index=age_at_index,
            classifier_label=classifier_label,
            index_date=index_date,
        )
        return record

    def remove_columns(self):
        """Return the parent's columns to remove plus ``label``."""
        columns = super().remove_columns()
        columns.append("label")
        return columns