cehrgpt 0.0.2__tar.gz → 0.1.0__tar.gz

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (115)
  1. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/.gitignore +3 -0
  2. {cehrgpt-0.0.2/src/cehrgpt.egg-info → cehrgpt-0.1.0}/PKG-INFO +7 -5
  3. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/pyproject.toml +4 -3
  4. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/scripts/level_three_evaluation.sh +10 -6
  5. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  6. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  7. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  8. cehrgpt-0.1.0/src/cehrgpt/data/sample_packing_sampler.py +151 -0
  9. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  10. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/config.py +10 -0
  11. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/hf_cehrgpt.py +243 -73
  12. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  13. cehrgpt-0.1.0/src/cehrgpt/runners/data_utils.py +243 -0
  14. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/runners/gpt_runner_util.py +0 -10
  15. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  16. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  17. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  18. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/runners/hyperparameter_search_util.py +4 -1
  19. cehrgpt-0.1.0/src/cehrgpt/runners/sample_packing_trainer.py +168 -0
  20. cehrgpt-0.1.0/src/cehrgpt/simulations/generate_plots.py +95 -0
  21. cehrgpt-0.1.0/src/cehrgpt/simulations/run_simulation.sh +24 -0
  22. cehrgpt-0.1.0/src/cehrgpt/simulations/time_embedding_simulation.py +250 -0
  23. cehrgpt-0.1.0/src/cehrgpt/simulations/time_token_simulation.py +177 -0
  24. cehrgpt-0.1.0/src/cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  25. cehrgpt-0.1.0/src/cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  26. {cehrgpt-0.0.2 → cehrgpt-0.1.0/src/cehrgpt.egg-info}/PKG-INFO +7 -5
  27. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt.egg-info/SOURCES.txt +13 -9
  28. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt.egg-info/requires.txt +4 -3
  29. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +11 -5
  30. cehrgpt-0.1.0/tests/integration_tests/runners/hf_cehrgpt_pretrain_sample_packing_runner_test.py +115 -0
  31. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +9 -3
  32. cehrgpt-0.1.0/tests/unit_tests/models/model_utils_test.py +131 -0
  33. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +4 -4
  34. cehrgpt-0.1.0/tests/unit_tests/tools/__init__.py +0 -0
  35. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  36. cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  37. cehrgpt-0.0.2/src/cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  38. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  39. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  40. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  41. cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  42. cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  43. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/.github/workflows/build-python.yaml +0 -0
  44. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/.github/workflows/tests.yaml +0 -0
  45. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/.pre-commit-config.yaml +0 -0
  46. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/LICENSE +0 -0
  47. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/README.md +0 -0
  48. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/sample_configs/cehrgpt_pretrain_sample_config.yaml +0 -0
  49. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/sample_data/pretrain/patient_sequence.parquet +0 -0
  50. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
  51. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
  52. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/scripts/omop_pipeline.sh +0 -0
  53. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/scripts/pool_generated_sequences.sh +0 -0
  54. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/setup.cfg +0 -0
  55. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/__init__.py +0 -0
  56. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/__init__.py +0 -0
  57. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/__init__.py +0 -0
  58. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
  59. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
  60. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
  61. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
  62. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
  63. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
  64. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/analysis/privacy/utils.py +0 -0
  65. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/cehrgpt_args.py +0 -0
  66. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/data/__init__.py +0 -0
  67. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/generation/__init__.py +0 -0
  68. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
  69. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/generation/omop_converter_batch.py +0 -0
  70. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/generation/omop_entity.py +0 -0
  71. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/gpt_utils.py +0 -0
  72. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/__init__.py +0 -0
  73. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/hf_modeling_outputs.py +0 -0
  74. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
  75. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/models/special_tokens.py +0 -0
  76. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/__init__.py +0 -0
  77. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/condition_era.py +0 -0
  78. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/observation_period.py +0 -0
  79. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/omop_argparse.py +0 -0
  80. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/omop_table_builder.py +0 -0
  81. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/queries/__init__.py +0 -0
  82. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/queries/condition_era.py +0 -0
  83. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/queries/observation_period.py +0 -0
  84. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
  85. {cehrgpt-0.0.2/src/cehrgpt/rl_finetune → cehrgpt-0.1.0/src/cehrgpt/runners}/__init__.py +0 -0
  86. {cehrgpt-0.0.2/src/cehrgpt/runners → cehrgpt-0.1.0/src/cehrgpt/simulations}/__init__.py +0 -0
  87. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/__init__.py +0 -0
  88. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
  89. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
  90. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
  91. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/time_to_event_model.py +0 -0
  92. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/time_to_event_prediction.py +0 -0
  93. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
  94. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/tools/__init__.py +0 -0
  95. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
  96. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +0 -0
  97. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
  98. {cehrgpt-0.0.2/tests → cehrgpt-0.1.0/src/cehrgpt/tools/linear_prob}/__init__.py +0 -0
  99. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
  100. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
  101. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
  102. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt.egg-info/top_level.txt +0 -0
  103. {cehrgpt-0.0.2/tests/integration_tests → cehrgpt-0.1.0/tests}/__init__.py +0 -0
  104. {cehrgpt-0.0.2/tests/integration_tests/runners → cehrgpt-0.1.0/tests/integration_tests}/__init__.py +0 -0
  105. {cehrgpt-0.0.2/tests/unit_tests → cehrgpt-0.1.0/tests/integration_tests/runners}/__init__.py +0 -0
  106. {cehrgpt-0.0.2/tests/unit_tests/models → cehrgpt-0.1.0/tests/unit_tests}/__init__.py +0 -0
  107. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/gpt_utils_test.py +0 -0
  108. {cehrgpt-0.0.2/tests/unit_tests/models/tokenization → cehrgpt-0.1.0/tests/unit_tests/models}/__init__.py +0 -0
  109. {cehrgpt-0.0.2/tests/unit_tests/runners → cehrgpt-0.1.0/tests/unit_tests/models/tokenization}/__init__.py +0 -0
  110. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
  111. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
  112. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
  113. {cehrgpt-0.0.2/tests/unit_tests/tools → cehrgpt-0.1.0/tests/unit_tests/runners}/__init__.py +0 -0
  114. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/tokenization_test.py +0 -0
  115. {cehrgpt-0.0.2 → cehrgpt-0.1.0}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
{cehrgpt-0.0.2 → cehrgpt-0.1.0}/.gitignore
@@ -4,6 +4,9 @@
 venv*
 dist/*
 
+*png
+*json
+
 *ipynb_checkpoints/
 *h5
 *logs
{cehrgpt-0.0.2/src/cehrgpt.egg-info → cehrgpt-0.1.0}/PKG-INFO
@@ -1,6 +1,6 @@
-Metadata-Version: 2.2
+Metadata-Version: 2.4
 Name: cehrgpt
-Version: 0.0.2
+Version: 0.1.0
 Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
 Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
 License: MIT License
@@ -12,13 +12,14 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.10.0
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: cehrbert==1.3.8
+Requires-Dist: cehrbert==1.4.1
+Requires-Dist: cehrbert_data==0.0.7
 Requires-Dist: openai==1.54.3
 Requires-Dist: optuna==4.0.0
-Requires-Dist: transformers==4.40.0
+Requires-Dist: transformers==4.44.0
 Requires-Dist: tokenizers==0.19.0
 Requires-Dist: peft==0.10.0
-Requires-Dist: trl==0.11.4
+Requires-Dist: lightgbm
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
 Requires-Dist: pytest; extra == "dev"
@@ -29,6 +30,7 @@ Requires-Dist: hypothesis; extra == "dev"
 Requires-Dist: black; extra == "dev"
 Provides-Extra: flash-attn
 Requires-Dist: flash_attn; extra == "flash-attn"
+Dynamic: license-file
 
 # CEHRGPT
 
{cehrgpt-0.0.2 → cehrgpt-0.1.0}/pyproject.toml
@@ -28,13 +28,14 @@ classifiers = [
 ]
 
 dependencies = [
-    "cehrbert==1.3.8",
+    "cehrbert==1.4.1",
+    "cehrbert_data==0.0.7",
     "openai==1.54.3",
     "optuna==4.0.0",
-    "transformers==4.40.0",
+    "transformers==4.44.0",
     "tokenizers==0.19.0",
     "peft==0.10.0",
-    "trl==0.11.4",
+    "lightgbm",
 ]
 
 [tool.setuptools_scm]
{cehrgpt-0.0.2 → cehrgpt-0.1.0}/scripts/level_three_evaluation.sh
@@ -29,7 +29,8 @@ python -u -m cehrbert_data.prediction_cohorts.cad_cabg_cohort \
   -dl 1985-01-01 -du 2023-12-31 \
   -l 18 -u 100 -ow 360 -ps 0 -pw 360 -f \
   --att_type cehr_bert \
-  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+  --is_remove_index_prediction_starts
 
 # Run Predictions on CAD CABG
 echo "Run predictions on cad_cabg"
@@ -56,9 +57,10 @@ python -u -m cehrbert_data.prediction_cohorts.hf_readmission \
   -c hf_readmission_bow \
   -i "$OMOP_FOLDER" \
   -o "$OMOP_FOLDER/cohorts/hf_readmission" \
-  -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 0 -pw 30 -f \
+  -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 1 -pw 30 -f \
   --att_type cehr_bert \
-  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+  --is_remove_index_prediction_starts
 
 # Run predictions on HF Readmission
 echo "Run predictions on hf_readmission"
@@ -85,9 +87,10 @@ python -u -m cehrbert_data.prediction_cohorts.copd_readmission \
   -c copd_readmission_bow \
   -i "$OMOP_FOLDER" \
   -o "$OMOP_FOLDER/cohorts/copd_readmission" \
-  -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 720 -ps 0 -pw 360 -f \
+  -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 360 -ps 1 -pw 30 -f \
   --att_type cehr_bert \
-  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+  --is_remove_index_prediction_starts
 
 # Run predictions on COPD Readmission
 echo "Run predictions on copd_readmission"
@@ -145,7 +148,8 @@ python -u -m cehrbert_data.prediction_cohorts.afib_ischemic_stroke \
   -o "$OMOP_FOLDER/cohorts/afib_ischemic_stroke" \
   -dl 1985-01-01 -du 2023-12-31 -l 18 -u 100 -ow 720 -ps 0 -pw 360 -f \
   --att_type cehr_bert \
-  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv
+  --ehr_table_list condition_occurrence procedure_occurrence drug_exposure -iv \
+  --is_remove_index_prediction_starts
 
 # Run predictions on AFIB Ischemic Stroke
 echo "Run predictions on afib_ischemic_stroke"
{cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/data/hf_cehrgpt_dataset.py
@@ -1,9 +1,10 @@
-from typing import Union
+from typing import Optional, Union
 
 from cehrbert.data_generators.hf_data_generator.hf_dataset import (
     FINETUNING_COLUMNS,
     apply_cehrbert_dataset_mapping,
 )
+from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 from datasets import Dataset, DatasetDict
 
@@ -31,16 +32,25 @@ def create_cehrgpt_pretraining_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Dataset:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
+    # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset.remove_columns(["visit_concept_ids"])
     dataset = apply_cehrbert_dataset_mapping(
         dataset,
         HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
-
     if not data_args.streaming:
         if isinstance(dataset, DatasetDict):
             all_columns = dataset["train"].column_names
@@ -56,8 +66,17 @@ def create_cehrgpt_finetuning_dataset(
     dataset: Union[Dataset, DatasetDict],
     cehrgpt_tokenizer: CehrGptTokenizer,
     data_args: DataTrainingArguments,
-) -> Dataset:
+    cache_file_collector: Optional[CacheFileCollector] = None,
+) -> Union[Dataset, DatasetDict]:
     required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
+    # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+    if not data_args.streaming:
+        if isinstance(dataset, DatasetDict):
+            all_columns = dataset["train"].column_names
+        else:
+            all_columns = dataset.column_names
+        if "visit_concept_ids" in all_columns:
+            dataset.remove_columns(["visit_concept_ids"])
     mapping_functions = [
         HFFineTuningMapping(cehrgpt_tokenizer),
     ]
@@ -68,6 +87,7 @@
         num_proc=data_args.preprocessing_num_workers,
         batch_size=data_args.preprocessing_batch_size,
         streaming=data_args.streaming,
+        cache_file_collector=cache_file_collector,
     )
 
     if not data_args.streaming:
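
Both dataset builders now accept an optional cache_file_collector so that cache files produced by the mapping steps can be tracked and cleaned up later. A minimal usage sketch follows; the no-argument CacheFileCollector constructor and its cleanup API are assumptions, since this diff only shows the import and the pass-through argument.

    # Hedged sketch, not the package's own code: wires the new
    # cache_file_collector argument through create_cehrgpt_pretraining_dataset.
    from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
    from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset

    def build_pretraining_dataset(dataset, cehrgpt_tokenizer, data_args):
        collector = CacheFileCollector()  # assumed no-arg constructor
        tokenized = create_cehrgpt_pretraining_dataset(
            dataset,
            cehrgpt_tokenizer,
            data_args,
            cache_file_collector=collector,
        )
        # The collector can later be used to remove intermediate cache files.
        return tokenized, collector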
{cehrgpt-0.0.2 → cehrgpt-0.1.0}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -1,5 +1,5 @@
 import random
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
@@ -105,9 +105,12 @@ class CehrGptDataCollator:
             self._try_reverse_tensor(self._convert_to_tensor(example["input_ids"]))
             for example in examples
         ]
+
         batch_attention_mask = [
             self._try_reverse_tensor(
-                torch.ones_like(
+                self._convert_to_tensor(example["attention_mask"]).to(torch.float)
+                if "attention_mask" in example
+                else torch.ones_like(
                     self._convert_to_tensor(example["input_ids"]), dtype=torch.float
                 )
             )
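
Note on the hunk above: the collator now keeps a precomputed attention_mask when one is present (sample packing supplies it) and only falls back to an all-ones mask when the field is absent. A standalone sketch of that fallback with toy inputs:

    import torch

    def get_attention_mask(example: dict) -> torch.Tensor:
        # Prefer a precomputed mask (sample packing provides one) ...
        if "attention_mask" in example:
            return torch.as_tensor(example["attention_mask"], dtype=torch.float)
        # ... otherwise attend to every token.
        return torch.ones(len(example["input_ids"]), dtype=torch.float)

    print(get_attention_mask({"input_ids": [5, 9, 2]}))
    # tensor([1., 1., 1.])
    print(get_attention_mask({"input_ids": [5, 9, 2], "attention_mask": [1, 1, 0]}))
    # tensor([1., 1., 0.])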
@@ -128,16 +131,40 @@
         )
         assert batch["input_ids"].shape[1] <= self.max_length
         assert batch["attention_mask"].shape[1] <= self.max_length
+        assert batch["attention_mask"].shape[1] == batch["input_ids"].shape[1], (
+            f'batch["attention_mask"].shape[1]: {batch["attention_mask"].shape[1]}, '
+            f'batch["input_ids"].shape[1]: {batch["input_ids"].shape[1]}'
+        )
+        assert batch["input_ids"].max() < self.tokenizer.vocab_size, (
+            f"batch['input_ids'].max(): {batch['input_ids'].max()} must be smaller than "
+            f"self.tokenizer.vocab_size: {self.tokenizer.vocab_size}. "
+            f"batch['input_ids']: {batch['input_ids']} "
+        )
 
-        if self.pretraining:
-            batch["labels"] = self._try_reverse_tensor(
+        if "position_ids" in examples[0]:
+            batch_position_ids = [
+                self._try_reverse_tensor(
+                    self._convert_to_tensor(example["position_ids"])
+                )
+                for example in examples
+            ]
+            # Pad sequences to the max length in the batch
+            batch["position_ids"] = self._try_reverse_tensor(
                 pad_sequence(
-                    batch_input_ids,
+                    batch_position_ids,
                     batch_first=True,
-                    padding_value=-100,
+                    padding_value=self.max_length,
                 ).to(torch.int64)
             )
 
+        if self.pretraining:
+            batch["labels"] = torch.where(
+                (batch["input_ids"] != self.tokenizer.pad_token_id)
+                & batch["attention_mask"].to(torch.bool),
+                batch["input_ids"],
+                -100,
+            )
+
         if self.use_sub_time_tokenization:
             time_token_indicators = torch.isin(batch["input_ids"], self.time_tokens)
             masked_tokens = batch["input_ids"].clone()
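
Note on the label change above: pretraining labels are no longer a copy of input_ids padded with -100; instead, any position that is a pad token or masked out of attention is excluded from the loss. A toy reproduction of the new rule, where pad_token_id = 0 is an arbitrary stand-in for the tokenizer's real id:

    import torch

    pad_token_id = 0  # toy value; the real id comes from the CehrGptTokenizer
    input_ids = torch.tensor([[12, 34, 0, 0], [56, 78, 90, 0]])
    attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

    # Pad tokens and unattended positions become -100, which the LM loss ignores.
    labels = torch.where(
        (input_ids != pad_token_id) & attention_mask.to(torch.bool),
        input_ids,
        -100,
    )
    print(labels)
    # tensor([[  12,   34, -100, -100],
    #         [  56,   78,   90, -100]])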
@@ -170,7 +197,7 @@
         if self.include_values:
             batch_value_indicators = [
                 self._try_reverse_tensor(
-                    self._convert_to_tensor(example["value_indicators"])
+                    self._convert_to_tensor(example["value_indicators"]).to(torch.bool)
                 )
                 for example in examples
             ]
@@ -178,7 +205,6 @@
                 self._try_reverse_tensor(self._convert_to_tensor(example["values"]))
                 for example in examples
             ]
-
             batch["value_indicators"] = self._try_reverse_tensor(
                 pad_sequence(
                     batch_value_indicators, batch_first=True, padding_value=False
@@ -200,41 +226,58 @@
                 batch["value_indicators"], batch["values"].clone(), -100
             )
 
+        bz = len(examples)
         if "person_id" in examples[0]:
-            batch["person_id"] = torch.cat(
-                [
-                    self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
-                    for example in examples
-                ],
-                dim=0,
-            ).to(torch.int32)
+            batch["person_id"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.int32)
+                .reshape(bz, -1)
+            )
 
         if "index_date" in examples[0]:
             batch["index_date"] = torch.cat(
                 [
-                    self._convert_to_tensor(example["index_date"]).reshape(-1, 1)
+                    torch.tensor(example["index_date"], dtype=torch.float64).reshape(
+                        -1, 1
+                    )
                     for example in examples
                 ],
                 dim=0,
-            ).to(torch.float32)
+            ).reshape(bz, -1)
 
         if "age_at_index" in examples[0]:
-            batch["age_at_index"] = torch.cat(
-                [
-                    self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
-                    for example in examples
-                ],
-                dim=0,
-            ).to(torch.float32)
+            batch["age_at_index"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.float32)
+                .reshape(bz, -1)
+            )
 
         if "classifier_label" in examples[0]:
-            batch["classifier_label"] = torch.cat(
-                [
-                    self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1)
-                    for example in examples
-                ],
-                dim=0,
-            ).to(torch.float32)
+            batch["classifier_label"] = (
+                torch.cat(
+                    [
+                        self._convert_to_tensor(example["classifier_label"]).reshape(
+                            -1, 1
+                        )
+                        for example in examples
+                    ],
+                    dim=0,
+                )
+                .to(torch.float32)
+                .reshape(bz, -1)
+            )
 
         return batch
 
@@ -273,53 +316,69 @@
         record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
         return record
 
-    def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
+    def generate_start_end_index(
+        self, record: Dict[str, Any], max_length_allowed: Optional[int] = None
+    ) -> Dict[str, Any]:
         """Adding the start and end indices to extract a portion of the patient sequence."""
         # concept_ids will be used to for time to event predictions and identifying the visit starts
+        max_length_allowed = (
+            self.max_length if max_length_allowed is None else max_length_allowed
+        )
+        sample_packing = getattr(self, "sample_packing", False)
         input_ids = record["input_ids"]
         if isinstance(input_ids, torch.Tensor):
             input_ids = input_ids.detach().tolist()
         concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
         seq_length = len(record["input_ids"])
-        new_max_length = self.max_length - 1  # Subtract one for the [END] token
+
+        # Subtract one for the [END] token when sample_packing is not enabled
+        new_max_length = (
+            max_length_allowed if sample_packing else max_length_allowed - 1
+        )
+
+        if self.include_ttv_prediction:
+            record["time_to_visits"] = torch.concat(
+                [self._convert_to_tensor(self._convert_time_to_event(concept_ids))]
+            )
 
         # Return the record directly if the actual sequence length is less than the max sequence
         if seq_length <= new_max_length:
-            record["input_ids"] = torch.concat(
-                [
-                    self._convert_to_tensor(record["input_ids"]),
-                    self._convert_to_tensor([self.tokenizer.end_token_id]),
-                ]
-            )
-            if self.include_values:
-                record["value_indicators"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["value_indicators"]),
-                        self._convert_to_tensor([False]),
-                    ]
-                ).to(torch.bool)
-                record["values"] = torch.concat(
-                    [
-                        self._convert_to_tensor(record["values"]),
-                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
-                    ]
-                )
-            if self.include_ttv_prediction:
-                record["time_to_visits"] = torch.concat(
-                    [
-                        self._convert_to_tensor(
-                            self._convert_time_to_event(concept_ids)
-                        ),
-                        self._convert_to_tensor([-100.0]),
-                    ]
-                )
+            if not sample_packing:
+                record["input_ids"] = torch.concat(
+                    [
+                        self._convert_to_tensor(record["input_ids"]),
+                        self._convert_to_tensor([self.tokenizer.end_token_id]),
+                    ]
+                )
+                if self.include_values:
+                    record["value_indicators"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["value_indicators"]),
+                            self._convert_to_tensor([False]),
+                        ]
+                    ).to(torch.bool)
+                    record["values"] = torch.concat(
+                        [
+                            self._convert_to_tensor(record["values"]),
+                            self._convert_to_tensor(
+                                [self.tokenizer.pad_value_token_id]
+                            ),
+                        ]
+                    )
+                if self.include_ttv_prediction:
+                    record["time_to_visits"] = torch.concat(
+                        [
+                            record["time_to_visits"],
+                            self._convert_to_tensor([-100.0]),
+                        ]
+                    )
 
             return record
 
         if self.pretraining:
             # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
             # prompt depending on the new starting point
-            if random.random() < 0.5:
+            if random.random() < 0.5 and not sample_packing:
                 start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
                     concept_ids, new_max_length
                 )
@@ -351,6 +410,11 @@
                     break
 
             record["input_ids"] = record["input_ids"][0:end_index]
+
+            # We want to make sure we take the subset of attention_mask in sample packing if this field is available
+            if sample_packing and "attention_mask" in record:
+                record["attention_mask"] = record["attention_mask"][0:end_index]
+
             if self.include_values:
                 record["value_indicators"] = self._convert_to_tensor(
                     record["value_indicators"][0:end_index]
@@ -364,7 +428,7 @@
                 )
             return record
         else:
-            if self.include_demographics:
+            if self.include_demographics and not sample_packing:
                 # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
                 demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
                     concept_ids
@@ -427,6 +491,10 @@
                 current_token = record["input_ids"][i]
                 if current_token == self.vs_token_id:
                     record["input_ids"] = record["input_ids"][i:end_index]
+                    if sample_packing and "attention_mask" in record:
+                        record["attention_mask"] = record["attention_mask"][
+                            i:end_index
+                        ]
                     if self.include_values:
                         record["value_indicators"] = record["value_indicators"][
                             i:end_index
@@ -442,6 +510,10 @@
             # We simply take the last new_max_length number of tokens from the patient sequence
             if len(record["input_ids"]) > new_max_length:
                 record["input_ids"] = record["input_ids"][-new_max_length:]
+                if sample_packing and "attention_mask" in record:
+                    record["attention_mask"] = record["attention_mask"][
+                        -new_max_length:
+                    ]
                 if self.include_values:
                     record["value_indicators"] = record["value_indicators"][
                         -new_max_length:
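
Note on the truncation hunks above: whenever input_ids is sliced, every parallel field (attention_mask, value_indicators, values) must be sliced with the same bounds, or the collator's shape assertions will fail. A minimal sketch of the tail-truncation case:

    def truncate_tail(record: dict, new_max_length: int) -> dict:
        # Keep only the most recent history, slicing every parallel field identically.
        for key in ("input_ids", "attention_mask", "value_indicators", "values"):
            if key in record:
                record[key] = record[key][-new_max_length:]
        return record

    print(truncate_tail({"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1]}, 2))
    # {'input_ids': [3, 4], 'attention_mask': [1, 1]}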
@@ -452,31 +524,135 @@
                         -new_max_length:
                     ]
 
-        # Finally we add the end token to the end of the sequence
-        record["input_ids"] = torch.concat(
-            [
-                self._convert_to_tensor(record["input_ids"]),
-                self._convert_to_tensor([self.tokenizer.end_token_id]),
-            ]
-        )
-        if self.include_values:
-            record["value_indicators"] = torch.concat(
-                [
-                    self._convert_to_tensor(record["value_indicators"]),
-                    self._convert_to_tensor([False]),
-                ]
-            ).to(torch.bool)
-            record["values"] = torch.concat(
-                [
-                    self._convert_to_tensor(record["values"]),
-                    self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
-                ]
-            )
-        if self.include_ttv_prediction:
-            record["time_to_visits"] = torch.concat(
-                [
-                    record["time_to_visits"],
-                    self._convert_to_tensor([-100.0]),
-                ]
-            )
+        if not sample_packing:
+            # Finally we add the end token to the end of the sequence
+            record["input_ids"] = torch.concat(
+                [
+                    self._convert_to_tensor(record["input_ids"]),
+                    self._convert_to_tensor([self.tokenizer.end_token_id]),
+                ]
+            )
+            if self.include_values:
+                record["value_indicators"] = torch.concat(
+                    [
+                        self._convert_to_tensor(record["value_indicators"]),
+                        self._convert_to_tensor([False]),
+                    ]
+                ).to(torch.bool)
+                record["values"] = torch.concat(
+                    [
+                        self._convert_to_tensor(record["values"]),
+                        self._convert_to_tensor(
+                            [self.tokenizer.pad_value_token_id]
+                        ),
+                    ]
+                )
+            if self.include_ttv_prediction:
+                record["time_to_visits"] = torch.concat(
+                    [
+                        record["time_to_visits"],
+                        self._convert_to_tensor([-100.0]),
+                    ]
+                )
         return record
+
+
+class SamplePackingCehrGptDataCollator(CehrGptDataCollator):
+    def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
+        self.max_tokens_per_batch = max_tokens
+        self.max_position_embeddings = max_position_embeddings
+        self.sample_packing = True
+        self.add_end_token_in_sample_packing = kwargs.pop(
+            "add_end_token_in_sample_packing", False
+        )
+        super(SamplePackingCehrGptDataCollator, self).__init__(*args, **kwargs)
+
+    def __call__(self, examples):
+        current_input_ids = []
+        current_attention_mask = []
+        current_position_ids = []
+        current_value_indicators = []
+        current_values = []
+
+        # Demographics
+        current_person_ids = []
+        current_index_dates = []
+
+        # Binary classification inputs
+        current_ages = []
+        current_labels = []
+
+        for idx, example in enumerate(examples):
+
+            # If the sample length exceeds the model's capacity, truncate this example
+            add_end_token = (
+                len(example["input_ids"]) <= self.max_position_embeddings
+                and self.add_end_token_in_sample_packing
+            )
+
+            if len(example["input_ids"]) > self.max_position_embeddings:
+                example = self.generate_start_end_index(
+                    example, self.max_position_embeddings
+                )
+
+            input_ids = example["input_ids"]
+            # We add [END] [PAD], we want to attend to [END], adding [END] is important for sequence generation.
+            # If the sequence length of the sequence is less than the context window, we add both [END][PAD], otherwise
+            # we only add [PAD] token to the end of the sequence because it's not finished
+            current_input_ids.extend(
+                list(input_ids)
+                + (
+                    [self.tokenizer.end_token_id, self.tokenizer.pad_token_id]
+                    if add_end_token
+                    else [self.tokenizer.pad_token_id]
+                )
+            )
+            current_attention_mask.extend(
+                np.ones_like(input_ids).tolist() + ([1, 0] if add_end_token else [0])
+            )
+            num_tokens_to_pad = 1 + int(add_end_token)
+            current_position_ids.extend(list(range(len(input_ids) + num_tokens_to_pad)))
+            if self.include_values:
+                current_value_indicators.extend(
+                    list(example["value_indicators"]) + [False] * num_tokens_to_pad
+                )
+                current_values.extend(
+                    list(example["values"])
+                    + [self.tokenizer.pad_value_token_id] * num_tokens_to_pad
+                )
+
+            if "person_id" in example:
+                current_person_ids.append(example["person_id"])
+
+            if "index_date" in example:
+                current_index_dates.append(example["index_date"])
+
+            if "age_at_index" in example:
+                current_ages.append(example["age_at_index"])
+
+            if "classifier_label" in example:
+                current_labels.append(example["classifier_label"])
+
+        assert (
+            len(current_input_ids) <= self.max_tokens_per_batch
+        ), f"the total number of tokens in the packed sequence should be less than { self.max_tokens_per_batch}"
+        packed_example = {
+            "input_ids": current_input_ids,
+            "attention_mask": current_attention_mask,
+            "position_ids": current_position_ids,
+        }
+        if self.include_values:
+            packed_example.update({"value_indicators": current_value_indicators})
+            packed_example.update({"values": current_values})
+
+        if current_labels:
+            packed_example.update(
+                {
+                    "person_id": current_person_ids,
+                    "index_date": current_index_dates,
+                    "age_at_index": current_ages,
+                    "classifier_label": current_labels,
+                }
+            )
+
+        return super().__call__([packed_example])
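
For orientation, a stripped-down, self-contained rendering of what the packing loop above produces: sequences are concatenated into a single row, each followed by an unattended [PAD] (plus an attended [END] when the sequence fits the context window), with position_ids restarting at 0 for every packed sequence. Token ids 99 and 0 are toy stand-ins for the tokenizer's [END] and [PAD]:

    import numpy as np

    END, PAD = 99, 0  # toy token ids; the real ones come from the tokenizer

    def pack(sequences, add_end_token=True):
        input_ids, attention_mask, position_ids = [], [], []
        for seq in sequences:
            tail_ids = [END, PAD] if add_end_token else [PAD]
            tail_mask = [1, 0] if add_end_token else [0]
            input_ids.extend(list(seq) + tail_ids)
            attention_mask.extend(np.ones_like(seq).tolist() + tail_mask)
            # Position ids restart for every packed sequence.
            position_ids.extend(range(len(seq) + len(tail_ids)))
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
        }

    packed = pack([[5, 6, 7], [8, 9]])
    print(packed["input_ids"])       # [5, 6, 7, 99, 0, 8, 9, 99, 0]
    print(packed["attention_mask"])  # [1, 1, 1, 1, 0, 1, 1, 1, 0]
    print(packed["position_ids"])    # [0, 1, 2, 3, 4, 0, 1, 2, 3]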