cehrgpt 0.0.1__tar.gz → 0.0.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt-0.0.2/.gitignore +25 -0
- {cehrgpt-0.0.1/src/cehrgpt.egg-info → cehrgpt-0.0.2}/PKG-INFO +52 -6
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/README.md +49 -4
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/pyproject.toml +3 -2
- cehrgpt-0.0.2/sample_configs/cehrgpt_pretrain_sample_config.yaml +51 -0
- cehrgpt-0.0.2/scripts/omop_pipeline.sh +55 -0
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +382 -0
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_collator.py +71 -0
- cehrgpt-0.0.2/src/cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +61 -0
- cehrgpt-0.0.2/src/cehrgpt/generation/generate_paired_cehrgpt_sequence.py +224 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/omop_converter_batch.py +3 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/hf_cehrgpt.py +1 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/tokenization_hf_cehrgpt.py +2 -2
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +586 -0
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +464 -0
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune.py +394 -0
- cehrgpt-0.0.2/src/cehrgpt/rl_finetune/ppo_finetune_v2.py +373 -0
- cehrgpt-0.0.2/src/cehrgpt/runners/hf_cehrgpt_dpo_runner.py +119 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -3
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +44 -8
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +4 -0
- cehrgpt-0.0.2/src/cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2/src/cehrgpt.egg-info}/PKG-INFO +52 -6
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/SOURCES.txt +11 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/requires.txt +3 -2
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_runner_test.py +4 -0
- cehrgpt-0.0.2/tests/unit_tests/tools/__init__.py +0 -0
- cehrgpt-0.0.1/.gitignore +0 -38
- cehrgpt-0.0.1/scripts/omop_pipeline.sh +0 -73
- cehrgpt-0.0.1/src/cehrgpt/data/hf_cehrgpt_dataset_mapping.py +0 -116
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/.github/workflows/build-python.yaml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/.github/workflows/tests.yaml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/.pre-commit-config.yaml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/LICENSE +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/sample_data/pretrain/patient_sequence.parquet +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/sample_data/pretrained_embeddings/pretrained_embedding_concepts.pkl +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/sample_data/pretrained_embeddings/pretrained_embedding_vectors.npy +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/scripts/level_three_evaluation.sh +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/scripts/pool_generated_sequences.sh +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/setup.cfg +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/attribute_inference.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/attribute_inference_config.yml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/member_inference.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/nearest_neighbor_inference.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/reid_inference.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/analysis/privacy/utils.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/cehrgpt_args.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/data/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/data/hf_cehrgpt_dataset.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/data/hf_cehrgpt_dataset_collator.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/chatgpt_generation.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/generate_batch_hf_gpt_sequence.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/generation/omop_entity.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/gpt_utils.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/config.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/hf_modeling_outputs.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/pretrained_embeddings.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/models/special_tokens.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/condition_era.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/observation_period.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/omop_argparse.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/omop_table_builder.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/queries/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/queries/condition_era.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/queries/observation_period.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/omop/sample_omop_tables.py +0 -0
- {cehrgpt-0.0.1/src/cehrgpt/runners → cehrgpt-0.0.2/src/cehrgpt/rl_finetune}/__init__.py +0 -0
- {cehrgpt-0.0.1/src/cehrgpt/time_to_event → cehrgpt-0.0.2/src/cehrgpt/runners}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/gpt_runner_util.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/runners/hyperparameter_search_util.py +0 -0
- {cehrgpt-0.0.1/src/cehrgpt/tools → cehrgpt-0.0.2/src/cehrgpt/time_to_event}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/config/30_day_readmission.yaml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/config/t2dm_hf.yaml +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/time_to_event_model.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/time_to_event_prediction.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/time_to_event/time_to_event_utils.py +0 -0
- {cehrgpt-0.0.1/tests → cehrgpt-0.0.2/src/cehrgpt/tools}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/ehrshot_benchmark.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/generate_pretrained_embeddings.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/merge_synthetic_real_dataasets.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt/tools/upload_omop_tables.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/dependency_links.txt +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/src/cehrgpt.egg-info/top_level.txt +0 -0
- {cehrgpt-0.0.1/tests/integration_tests → cehrgpt-0.0.2/tests}/__init__.py +0 -0
- {cehrgpt-0.0.1/tests/integration_tests/runners → cehrgpt-0.0.2/tests/integration_tests}/__init__.py +0 -0
- {cehrgpt-0.0.1/tests/unit_tests → cehrgpt-0.0.2/tests/integration_tests/runners}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/integration_tests/runners/hf_cehrgpt_pretrain_sfm_runner_test.py +0 -0
- {cehrgpt-0.0.1/tests/unit_tests/models → cehrgpt-0.0.2/tests/unit_tests}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/gpt_utils_test.py +0 -0
- {cehrgpt-0.0.1/tests/unit_tests/models/tokenization → cehrgpt-0.0.2/tests/unit_tests/models}/__init__.py +0 -0
- {cehrgpt-0.0.1/tests/unit_tests/runners → cehrgpt-0.0.2/tests/unit_tests/models/tokenization}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/models/tokenization/create_bins_with_spline_test.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/models/tokenization/create_sample_from_bins_test.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/numeric_concept_statistics_test.py +0 -0
- {cehrgpt-0.0.1/tests/unit_tests/tools → cehrgpt-0.0.2/tests/unit_tests/runners}/__init__.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/runners/hf_cehrgpt_finetune_runner_test.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/tokenization_test.py +0 -0
- {cehrgpt-0.0.1 → cehrgpt-0.0.2}/tests/unit_tests/tools/upload_omop_tables_test.py +0 -0
cehrgpt-0.0.2/.gitignore
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
.DS_Store
|
2
|
+
.idea/
|
3
|
+
.vscode/
|
4
|
+
venv*
|
5
|
+
dist/*
|
6
|
+
|
7
|
+
*ipynb_checkpoints/
|
8
|
+
*h5
|
9
|
+
*logs
|
10
|
+
*nohup.out
|
11
|
+
*ipynb
|
12
|
+
|
13
|
+
*__pycache__/
|
14
|
+
.eggs/
|
15
|
+
*.dat
|
16
|
+
.metastore_db/
|
17
|
+
|
18
|
+
build/
|
19
|
+
|
20
|
+
*.out
|
21
|
+
*.egg-info/
|
22
|
+
|
23
|
+
test_data
|
24
|
+
test_dataset_prepared
|
25
|
+
test*results
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: cehrgpt
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.2
|
4
4
|
Summary: CEHR-GPT: Generating Electronic Health Records with Chronological Patient Timelines
|
5
5
|
Author-email: Chao Pang <chaopang229@gmail.com>, Xinzhuo Jiang <xj2193@cumc.columbia.edu>, Krishna Kalluri <kk3326@cumc.columbia.edu>, Elise Minto <em3697@cumc.columbia.edu>, Jason Patterson <jp3477@cumc.columbia.edu>, Nishanth Parameshwar Pavinkurve <np2689@cumc.columbia.edu>, Karthik Natarajan <kn2174@cumc.columbia.edu>
|
6
6
|
License: MIT License
|
@@ -12,11 +12,12 @@ Classifier: Programming Language :: Python :: 3
|
|
12
12
|
Requires-Python: >=3.10.0
|
13
13
|
Description-Content-Type: text/markdown
|
14
14
|
License-File: LICENSE
|
15
|
-
Requires-Dist: cehrbert==1.3.
|
15
|
+
Requires-Dist: cehrbert==1.3.8
|
16
16
|
Requires-Dist: openai==1.54.3
|
17
17
|
Requires-Dist: optuna==4.0.0
|
18
18
|
Requires-Dist: transformers==4.40.0
|
19
|
-
Requires-Dist: tokenizers==0.19
|
19
|
+
Requires-Dist: tokenizers==0.19.0
|
20
|
+
Requires-Dist: peft==0.10.0
|
20
21
|
Requires-Dist: trl==0.11.4
|
21
22
|
Provides-Extra: dev
|
22
23
|
Requires-Dist: pre-commit; extra == "dev"
|
@@ -50,11 +51,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
|
|
50
51
|
To install CEHRGPT, clone this repository and install the required dependencies.
|
51
52
|
|
52
53
|
```bash
|
53
|
-
git clone https://github.com/knatarajan-lab/cehrgpt
|
54
|
-
cd cehrgpt
|
54
|
+
git clone https://github.com/knatarajan-lab/cehrgpt.git
|
55
|
+
cd cehrgpt
|
55
56
|
pip install .
|
56
57
|
```
|
57
58
|
|
59
|
+
## Pretrain
|
60
|
+
Pretrain CEHR-GPT using the Hugging Face trainer; the training parameters are defined in the sample configuration YAML (`sample_configs/cehrgpt_pretrain_sample_config.yaml`).
|
61
|
+
```bash
|
62
|
+
mkdir test_results
|
63
|
+
# This is NOT required when streaming is set to true
|
64
|
+
mkdir test_dataset_prepared
|
65
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
|
66
|
+
```
|
67
|
+
|
68
|
+
## Generate synthetic sequences
|
69
|
+
Generate synthetic sequences using the trained model
|
70
|
+
```bash
|
71
|
+
export TRANSFORMERS_VERBOSITY=info
|
72
|
+
export CUDA_VISIBLE_DEVICES="0"
|
73
|
+
python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
|
74
|
+
--model_folder test_results \
|
75
|
+
--tokenizer_folder test_results \
|
76
|
+
--output_folder test_results \
|
77
|
+
--num_of_patients 128 \
|
78
|
+
--batch_size 32 \
|
79
|
+
--buffer_size 128 \
|
80
|
+
--context_window 1024 \
|
81
|
+
--sampling_strategy TopPStrategy \
|
82
|
+
--top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
|
83
|
+
--epsilon_cutoff 0.00 \
|
84
|
+
--demographic_data_path sample_data/pretrain
|
85
|
+
```
|
86
|
+
|
87
|
+
## Convert synthetic sequences to OMOP
|
88
|
+
```bash
|
89
|
+
# omop converter requires the OHDSI vocabulary
|
90
|
+
export OMOP_VOCAB_DIR=""
|
91
|
+
# the omop derived tables need to be built using pyspark
|
92
|
+
export SPARK_WORKER_INSTANCES="1"
|
93
|
+
export SPARK_WORKER_CORES="8"
|
94
|
+
export SPARK_EXECUTOR_CORES="2"
|
95
|
+
export SPARK_DRIVER_MEMORY="2g"
|
96
|
+
export SPARK_EXECUTOR_MEMORY="2g"
|
97
|
+
|
98
|
+
# Convert the sequences, create the omop derived tables
|
99
|
+
bash scripts/omop_pipeline.sh \
|
100
|
+
test_results/top_p10000/generated_sequences/ \
|
101
|
+
test_results/top_p10000/restored_omop/ \
|
102
|
+
$OMOP_VOCAB_DIR
|
103
|
+
```
|
104
|
+
|
58
105
|
## Citation
|
59
106
|
```
|
60
107
|
@article{cehrgpt2024,
|
@@ -63,4 +110,3 @@ pip install .
|
|
63
110
|
journal={arXiv preprint arXiv:2402.04400},
|
64
111
|
year={2024}
|
65
112
|
}
|
66
|
-
```
|
@@ -19,11 +19,57 @@ CEHRGPT is a synthetic data generation model developed to handle structured elec
|
|
19
19
|
To install CEHRGPT, clone this repository and install the required dependencies.
|
20
20
|
|
21
21
|
```bash
|
22
|
-
git clone https://github.com/knatarajan-lab/cehrgpt
|
23
|
-
cd cehrgpt
|
22
|
+
git clone https://github.com/knatarajan-lab/cehrgpt.git
|
23
|
+
cd cehrgpt
|
24
24
|
pip install .
|
25
25
|
```
|
26
26
|
|
27
|
+
## Pretrain
|
28
|
+
Pretrain CEHR-GPT using the Hugging Face trainer; the training parameters are defined in the sample configuration YAML (`sample_configs/cehrgpt_pretrain_sample_config.yaml`).
|
29
|
+
```bash
|
30
|
+
mkdir test_results
|
31
|
+
# This is NOT required when streaming is set to true
|
32
|
+
mkdir test_dataset_prepared
|
33
|
+
python -u -m cehrgpt.runners.hf_cehrgpt_pretrain_runner sample_configs/cehrgpt_pretrain_sample_config.yaml
|
34
|
+
```
|
35
|
+
|
36
|
+
## Generate synthetic sequences
|
37
|
+
Generate synthetic sequences using the trained model
|
38
|
+
```bash
|
39
|
+
export TRANSFORMERS_VERBOSITY=info
|
40
|
+
export CUDA_VISIBLE_DEVICES="0"
|
41
|
+
python -u -m cehrgpt.generation.generate_batch_hf_gpt_sequence \
|
42
|
+
--model_folder test_results \
|
43
|
+
--tokenizer_folder test_results \
|
44
|
+
--output_folder test_results \
|
45
|
+
--num_of_patients 128 \
|
46
|
+
--batch_size 32 \
|
47
|
+
--buffer_size 128 \
|
48
|
+
--context_window 1024 \
|
49
|
+
--sampling_strategy TopPStrategy \
|
50
|
+
--top_p 1.0 --temperature 1.0 --repetition_penalty 1.0 \
|
51
|
+
--epsilon_cutoff 0.00 \
|
52
|
+
--demographic_data_path sample_data/pretrain
|
53
|
+
```
|
54
|
+
|
55
|
+
## Convert synthetic sequences to OMOP
|
56
|
+
```bash
|
57
|
+
# omop converter requires the OHDSI vocabulary
|
58
|
+
export OMOP_VOCAB_DIR=""
|
59
|
+
# the omop derived tables need to be built using pyspark
|
60
|
+
export SPARK_WORKER_INSTANCES="1"
|
61
|
+
export SPARK_WORKER_CORES="8"
|
62
|
+
export SPARK_EXECUTOR_CORES="2"
|
63
|
+
export SPARK_DRIVER_MEMORY="2g"
|
64
|
+
export SPARK_EXECUTOR_MEMORY="2g"
|
65
|
+
|
66
|
+
# Convert the sequences, create the omop derived tables
|
67
|
+
bash scripts/omop_pipeline.sh \
|
68
|
+
test_results/top_p10000/generated_sequences/ \
|
69
|
+
test_results/top_p10000/restored_omop/ \
|
70
|
+
$OMOP_VOCAB_DIR
|
71
|
+
```
|
72
|
+
|
27
73
|
## Citation
|
28
74
|
```
|
29
75
|
@article{cehrgpt2024,
|
@@ -31,5 +77,4 @@ pip install .
|
|
31
77
|
author={Natarajan, K and others},
|
32
78
|
journal={arXiv preprint arXiv:2402.04400},
|
33
79
|
year={2024}
|
34
|
-
}
|
35
|
-
```
|
80
|
+
}
|
@@ -28,11 +28,12 @@ classifiers = [
|
|
28
28
|
]
|
29
29
|
|
30
30
|
dependencies = [
|
31
|
-
"cehrbert==1.3.
|
31
|
+
"cehrbert==1.3.8",
|
32
32
|
"openai==1.54.3",
|
33
33
|
"optuna==4.0.0",
|
34
34
|
"transformers==4.40.0",
|
35
|
-
"tokenizers==0.19",
|
35
|
+
"tokenizers==0.19.0",
|
36
|
+
"peft==0.10.0",
|
36
37
|
"trl==0.11.4",
|
37
38
|
]
|
38
39
|
|
@@ -0,0 +1,51 @@
|
|
1
|
+
model_name_or_path: "test_results"
|
2
|
+
tokenizer_name_or_path: "test_results"
|
3
|
+
|
4
|
+
data_folder: "sample_data/pretrain"
|
5
|
+
dataset_prepared_path: "test_dataset_prepared"
|
6
|
+
validation_split_percentage: 0.05
|
7
|
+
validation_split_num: 10
|
8
|
+
preprocessing_num_workers: 4
|
9
|
+
preprocessing_batch_size: 1000
|
10
|
+
streaming: true
|
11
|
+
|
12
|
+
#Tokenizer
|
13
|
+
vocab_size: 50000
|
14
|
+
min_frequency: 0
|
15
|
+
|
16
|
+
do_train: true
|
17
|
+
overwrite_output_dir: false
|
18
|
+
resume_from_checkpoint: # path to the checkpoint folder
|
19
|
+
seed: 42
|
20
|
+
|
21
|
+
num_hidden_layers: 6
|
22
|
+
hidden_size: 768
|
23
|
+
n_head: 12
|
24
|
+
max_position_embeddings: 1024
|
25
|
+
|
26
|
+
# torch dataloader configs
|
27
|
+
dataloader_num_workers: 4
|
28
|
+
dataloader_prefetch_factor: 2
|
29
|
+
|
30
|
+
output_dir: "test_results"
|
31
|
+
save_strategy: "steps"
|
32
|
+
evaluation_strategy: "no"
|
33
|
+
learning_rate: 0.00005
|
34
|
+
per_device_train_batch_size: 4
|
35
|
+
per_device_eval_batch_size: 4
|
36
|
+
gradient_accumulation_steps: 1
|
37
|
+
num_train_epochs: 1
|
38
|
+
# When streaming is set to True, max_steps needs to be provided
|
39
|
+
max_steps: 1000
|
40
|
+
save_steps: 500
|
41
|
+
|
42
|
+
warmup_steps: 100
|
43
|
+
weight_decay: 0.01
|
44
|
+
logging_dir: "./logs"
|
45
|
+
logging_steps: 100
|
46
|
+
save_total_limit: 5
|
47
|
+
load_best_model_at_end: false
|
48
|
+
metric_for_best_model: "eval_loss"
|
49
|
+
greater_is_better: false
|
50
|
+
|
51
|
+
report_to: "none"
|
@@ -0,0 +1,55 @@
|
|
1
|
+
#!/bin/bash
# Rebuild an OMOP instance from generated CEHR-GPT patient sequences, then derive
# the observation_period and condition_era tables from it.
#
# Usage: bash omop_pipeline.sh PATIENT_SEQUENCE_FOLDER OMOP_FOLDER SOURCE_OMOP_FOLDER
#   PATIENT_SEQUENCE_FOLDER  folder with the generated patient sequences
#   OMOP_FOLDER              output folder for the reconstructed OMOP tables
#   SOURCE_OMOP_FOLDER       folder holding the concept, concept_relationship and
#                            concept_ancestor tables, which are copied into OMOP_FOLDER
#
# This script must be run with bash, not POSIX sh: the table clean-up below relies
# on brace expansion.

# Refuse to run without three non-empty arguments. An empty OMOP_FOLDER would turn
# the rm -rf calls below into deletions relative to the filesystem root, and an
# unquoted empty variable on the caller's side silently drops an argument.
if [ "$#" -lt 3 ] || [ -z "$1" ] || [ -z "$2" ] || [ -z "$3" ]; then
    echo "Usage: $0 PATIENT_SEQUENCE_FOLDER OMOP_FOLDER SOURCE_OMOP_FOLDER" >&2
    exit 1
fi

# Every stage consumes the output of the previous one, so stop at the first failure
set -e

# Exporting input arguments as environment variables
export PATIENT_SEQUENCE_FOLDER="$1"
export OMOP_FOLDER="$2"
export SOURCE_OMOP_FOLDER="$3"
export PATIENT_SPLITS_FOLDER="$SOURCE_OMOP_FOLDER/patient_splits"

# Echoing the values of the environment variables
echo "PATIENT_SEQUENCE_FOLDER=$PATIENT_SEQUENCE_FOLDER"
echo "OMOP_FOLDER=$OMOP_FOLDER"
echo "SOURCE_OMOP_FOLDER=$SOURCE_OMOP_FOLDER"

# Ensure OMOP_FOLDER exists
if [ ! -d "$OMOP_FOLDER" ]; then
    echo "Creating $OMOP_FOLDER"
    mkdir -p "$OMOP_FOLDER"
fi

# Removing existing OMOP tables (the folder is quoted; the brace list is expanded by bash)
rm -rf "$OMOP_FOLDER"/{person,visit_occurrence,condition_occurrence,procedure_occurrence,drug_exposure,death,measurement,observation_period,condition_era}

# Removing existing OMOP concept tables so they are refreshed from SOURCE_OMOP_FOLDER
rm -rf "$OMOP_FOLDER"/{concept,concept_ancestor,concept_relationship}

# Copying the OMOP concept tables; they were removed above, so they are always copied
for table in concept concept_relationship concept_ancestor; do
    echo "Creating $OMOP_FOLDER/$table"
    cp -r "$SOURCE_OMOP_FOLDER/$table" "$OMOP_FOLDER/$table"
done

# Reconstructing the OMOP instance from patient sequences
echo "Reconstructing the OMOP instance from patient sequences in $OMOP_FOLDER"
python -m cehrgpt.generation.omop_converter_batch \
  --patient_sequence_path "$PATIENT_SEQUENCE_FOLDER" \
  --output_folder "$OMOP_FOLDER" \
  --concept_path "$OMOP_FOLDER/concept" \
  --buffer_size 1280 \
  --cpu_cores 10

# Create observation_period
echo "Reconstructing observation_period in $OMOP_FOLDER"
python -u -m cehrgpt.omop.observation_period \
  --input_folder "$OMOP_FOLDER" \
  --output_folder "$OMOP_FOLDER" \
  --domain_table_list "condition_occurrence drug_exposure procedure_occurrence measurement"

# Create condition_era
echo "Reconstructing condition_era in $OMOP_FOLDER"
python -u -m cehrgpt.omop.condition_era \
  --input_folder "$OMOP_FOLDER" \
  --output_folder "$OMOP_FOLDER" \
  --domain_table_list "condition_occurrence"
|
@@ -0,0 +1,382 @@
|
|
1
|
+
import datetime
|
2
|
+
from typing import Any, Dict
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import pandas as pd
|
6
|
+
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
|
7
|
+
ED_VISIT_TYPE_CODES,
|
8
|
+
INPATIENT_VISIT_TYPE_CODES,
|
9
|
+
INPATIENT_VISIT_TYPES,
|
10
|
+
DatasetMapping,
|
11
|
+
replace_escape_chars,
|
12
|
+
)
|
13
|
+
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
|
14
|
+
from cehrbert_data.const.common import NA
|
15
|
+
from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
|
16
|
+
from dateutil.relativedelta import relativedelta
|
17
|
+
|
18
|
+
from cehrgpt.models.tokenization_hf_cehrgpt import (
|
19
|
+
NONE_BIN,
|
20
|
+
UNKNOWN_BIN,
|
21
|
+
CehrGptTokenizer,
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
def convert_date_to_posix_time(index_date: datetime.date) -> float:
    """Return the POSIX timestamp of the midnight that starts ``index_date``.

    Only the calendar date of ``index_date`` is used (any time-of-day part of a
    ``datetime`` argument is discarded). The midnight value is a naive datetime,
    so ``timestamp()`` interprets it in the local time zone of the process.
    """
    start_of_day = datetime.datetime.combine(index_date, datetime.time())
    return start_of_day.timestamp()
|
29
|
+
|
30
|
+
|
31
|
+
class MedToCehrGPTDatasetMapping(DatasetMapping):
    """Convert a MEDS-style patient record into the CEHR-GPT token format.

    MEDS: https://github.com/Medical-Event-Data-Standard/meds/tree/main

    The incoming ``record`` is read as follows:

    - ``patient_id``, ``birth_datetime``, ``gender`` and ``race`` at the top level.
      The demographic prompt (year and age at the earliest visit, gender, race)
      is built from these.
    - ``visits``: a list of visits, each with ``visit_start_datetime``,
      ``visit_type`` and ``events``. ``visit_type`` is compared against the
      inpatient/ED visit type constants. ``visit_end_datetime`` and
      ``discharge_facility`` are optional and only used for inpatient/ER visits.
    - each event has ``time`` and ``code``, and optionally ``numeric_value``,
      ``text_value`` and ``unit``.
    - ``label`` and ``age_at_index`` are optional and copied through unchanged.
    """

    def __init__(
        self,
        data_args: DataTrainingArguments,
        is_pretraining: bool = True,
        include_inpatient_hour_token: bool = True,
    ):
        """Store the tokenization options.

        Args:
            data_args: provides ``att_function_type``, ``inpatient_att_function_type``,
                ``include_auxiliary_token`` and ``include_demographic_prompt``.
            is_pretraining: selects which columns :meth:`remove_columns` drops.
            include_inpatient_hour_token: stored on the instance; not read within
                this class.
        """
        # Turns a gap in days between two visits into an artificial time token
        self._time_token_function = get_att_function(data_args.att_function_type)
        self._include_auxiliary_token = data_args.include_auxiliary_token
        # Turns a gap in days between events inside an inpatient/ER visit into a time token
        self._inpatient_time_token_function = get_att_function(
            data_args.inpatient_att_function_type
        )
        self._include_demographic_prompt = data_args.include_demographic_prompt
        self._is_pretraining = is_pretraining
        self._include_inpatient_hour_token = include_inpatient_hour_token

    def remove_columns(self):
        """Return the input columns to drop after :meth:`transform` has run."""
        if self._is_pretraining:
            return ["visits", "birth_datetime", "index_date"]
        else:
            return [
                "visits",
                "birth_datetime",
                "visit_concept_ids",
            ]

    @staticmethod
    def _update_cehrgpt_record(
        cehrgpt_record: Dict[str, Any],
        code: str,
        concept_value_mask: int = 0,
        number_as_value: float = 0.0,
        concept_as_value: str = "0",
        is_numeric_type: int = 0,
        unit: str = NA,
    ) -> None:
        """Append one token and its value metadata to the parallel lists of ``cehrgpt_record``."""
        cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
        cehrgpt_record["concept_value_masks"].append(concept_value_mask)
        cehrgpt_record["number_as_values"].append(number_as_value)
        cehrgpt_record["concept_as_values"].append(concept_as_value)
        cehrgpt_record["units"].append(unit)
        cehrgpt_record["is_numeric_types"].append(is_numeric_type)

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one patient ``record`` into parallel CEHR-GPT token lists.

        The timeline consists of the demographic prompt, followed, for every
        non-empty visit in chronological order, by:

        - an inter-visit time token (not before the first visit);
        - ``[VS]`` and the visit type;
        - the visit's events, interleaved with ``i-`` time tokens for
          inpatient/ER visits;
        - the discharge facility (inpatient/ER only, when auxiliary tokens are
          enabled);
        - ``[VE]``.

        Returns:
            A dict with:

            - ``person_id``;
            - the parallel lists ``concept_ids``, ``concept_value_masks``,
              ``number_as_values``, ``concept_as_values``, ``units``,
              ``is_numeric_types`` and ``orders``;
            - ``num_of_concepts`` and ``num_of_visits``;
            - ``label`` / ``age_at_index`` when present in ``record``.

        Raises:
            IndexError: if ``record["visits"]`` is empty.
        """
        cehrgpt_record = {
            "person_id": record["patient_id"],
            "concept_ids": [],
            "concept_value_masks": [],
            "number_as_values": [],
            "concept_as_values": [],
            "units": [],
            "is_numeric_types": [],
        }
        # Extract the demographic information
        birth_datetime = record["birth_datetime"]
        if isinstance(birth_datetime, pd.Timestamp):
            birth_datetime = birth_datetime.to_pydatetime()
        gender = record["gender"]
        race = record["race"]

        # Sort the visits once so the demographic prompt and the timeline agree on
        # which visit comes first (the input list is not guaranteed to be ordered).
        visits = sorted(record["visits"], key=lambda v: v["visit_start_datetime"])

        # Add the demographic tokens: year and age at the earliest visit, gender, race
        first_visit_start = visits[0]["visit_start_datetime"]
        year_str = f"year:{first_visit_start.year}"
        age_str = f"age:{relativedelta(first_visit_start, birth_datetime).years}"
        self._update_cehrgpt_record(cehrgpt_record, year_str)
        self._update_cehrgpt_record(cehrgpt_record, age_str)
        self._update_cehrgpt_record(cehrgpt_record, gender)
        self._update_cehrgpt_record(cehrgpt_record, race)

        # Most recent point on the patient timeline, used to compute time gaps
        date_cursor = None

        for visit in visits:
            events = visit["events"]

            # Skip this visit if it contains no events
            if events is None or len(events) == 0:
                continue

            visit_start_datetime = visit["visit_start_datetime"]
            time_delta = (
                (visit_start_datetime - date_cursor).days
                if date_cursor is not None
                else None
            )
            date_cursor = visit_start_datetime

            visit_type = visit["visit_type"]
            is_er_or_inpatient = (
                visit_type in INPATIENT_VISIT_TYPES
                or visit_type in INPATIENT_VISIT_TYPE_CODES
                or visit_type in ED_VISIT_TYPE_CODES
            )

            # Add an artificial time token for the gap since the previous visit
            if time_delta is not None:
                self._update_cehrgpt_record(
                    cehrgpt_record,
                    code=self._time_token_function(time_delta),
                )

            # Week number since the epoch; part of the outpatient de-duplication key
            visit_week = (
                visit_start_datetime - datetime.datetime(year=1970, month=1, day=1)
            ).days // 7

            # Mark the start of the visit and its type
            self._update_cehrgpt_record(
                cehrgpt_record,
                code="[VS]",
            )
            self._update_cehrgpt_record(
                cehrgpt_record,
                code=visit_type,
            )

            # Outpatient events already added for this visit; a set keeps lookups O(1)
            existing_outpatient_events = set()
            for e in events:
                # If the event doesn't have a time stamp, we skip it
                if not e["time"]:
                    continue

                # A numeric or text value turns the event into a concept/value pair
                numeric_value = e.get("numeric_value", None)
                text_value = e.get("text_value", None)
                # The unit might be missing or populated with a falsy value
                unit = e.get("unit") or NA
                concept_value_mask = int(
                    numeric_value is not None or text_value is not None
                )
                is_numeric_type = int(numeric_value is not None)
                code = replace_escape_chars(e["code"])

                if is_er_or_inpatient:
                    # The patient can stay for a while, so time tokens are derived
                    # from the event time stamps relative to the previous event.
                    meas_time_diff = (e["time"] - date_cursor).days
                    # Only advance the cursor (and emit a token) once at least a day has passed
                    if meas_time_diff > 0:
                        date_cursor = e["time"]
                        if self._inpatient_time_token_function:
                            self._update_cehrgpt_record(
                                cehrgpt_record,
                                code=f"i-{self._inpatient_time_token_function(meas_time_diff)}",
                            )
                else:
                    # Outpatient visits are assumed to start and end on the same day,
                    # so an identical week/code/value combination is added only once.
                    event_key = (
                        visit_week,
                        code,
                        numeric_value,
                        text_value,
                        concept_value_mask,
                    )
                    if event_key in existing_outpatient_events:
                        continue
                    existing_outpatient_events.add(event_key)

                self._update_cehrgpt_record(
                    cehrgpt_record,
                    code=code,
                    concept_value_mask=concept_value_mask,
                    unit=unit,
                    number_as_value=numeric_value if numeric_value else 0.0,
                    concept_as_value=(
                        replace_escape_chars(text_value) if text_value else "0"
                    ),
                    is_numeric_type=is_numeric_type,
                )

            # Inpatient/ER visits may carry an end time and a discharge facility
            if is_er_or_inpatient:
                # If visit_end_datetime is populated, the next visit gap is measured from it
                visit_end_datetime = visit.get("visit_end_datetime", None)
                if visit_end_datetime:
                    date_cursor = visit_end_datetime

                if self._include_auxiliary_token:
                    discharge_facility = visit.get("discharge_facility") or "0"
                    self._update_cehrgpt_record(
                        cehrgpt_record,
                        code=discharge_facility,
                    )

            # Mark the end of the visit
            self._update_cehrgpt_record(
                cehrgpt_record,
                code="[VE]",
            )

        # Generate the orders of the concepts that the cehrbert dataset mapping function expects
        cehrgpt_record["orders"] = list(
            range(1, len(cehrgpt_record["concept_ids"]) + 1)
        )

        # Add some count information for this sequence
        cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
        cehrgpt_record["num_of_visits"] = len(record["visits"])

        if "label" in record:
            cehrgpt_record["label"] = record["label"]
        if "age_at_index" in record:
            cehrgpt_record["age_at_index"] = record["age_at_index"]

        return cehrgpt_record
|
284
|
+
|
285
|
+
|
286
|
+
class HFCehrGptTokenizationMapping(DatasetMapping):
    """Tokenize CEHR-GPT concept sequences and encode the value attached to each concept."""

    def __init__(
        self,
        concept_tokenizer: CehrGptTokenizer,
    ):
        """Keep a reference to ``concept_tokenizer`` and cache its ``lab_token_ids``."""
        self._concept_tokenizer = concept_tokenizer
        self._lab_token_ids = self._concept_tokenizer.lab_token_ids

    def remove_columns(self):
        """Return the intermediate columns to drop after :meth:`transform`."""
        dropped = ["concept_value_masks", "is_numeric_types"]
        return dropped

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Add ``input_ids``, ``value_indicators`` and encoded ``values`` to ``record``.

        When ``number_as_values`` or ``concept_as_values`` is missing, both are
        derived from ``record["concept_values"]``, together with ``is_numeric_types``:
        floats are numeric and strings are categorical.

        Each position flagged by ``concept_value_masks`` gets one of:

        - the tokenizer-normalized bin, for a numeric concept the tokenizer knows;
        - ``UNKNOWN_BIN``, for a numeric concept it does not know;
        - the text value, for a categorical value;
        - ``UNKNOWN_BIN``, when no usable value exists.

        Unflagged positions get ``NONE_BIN``. The mutated ``record`` is returned.
        """
        tokenizer = self._concept_tokenizer
        record["input_ids"] = tokenizer.encode(record["concept_ids"])
        record["value_indicators"] = record["concept_value_masks"]

        needs_value_split = (
            "number_as_values" not in record or "concept_as_values" not in record
        )
        if needs_value_split:
            raw_values = record["concept_values"]
            record["number_as_values"] = [
                float(raw) if isinstance(raw, float) else None for raw in raw_values
            ]
            record["is_numeric_types"] = [
                int(isinstance(raw, float)) for raw in raw_values
            ]
            record["concept_as_values"] = [
                raw if isinstance(raw, str) else None for raw in raw_values
            ]

        value_masks = record["concept_value_masks"]
        if not np.any(np.asarray(value_masks) > 0):
            # No concept carries a value: every position is an empty bin
            record["values"] = tokenizer.encode_value([NONE_BIN] * len(value_masks))
        else:
            bin_values = []
            columns = zip(
                record["concept_ids"],
                record["units"],
                value_masks,
                record["number_as_values"],
                record["concept_as_values"],
                record["is_numeric_types"],
            )
            for (
                concept_id,
                unit,
                value_mask,
                number_as_value,
                concept_as_value,
                is_numeric_type,
            ) in columns:
                if value_mask != 1:
                    bin_values.append(NONE_BIN)
                elif is_numeric_type == 1:
                    bin_values.append(
                        tokenizer.normalize(concept_id, unit, number_as_value)
                        if concept_id in tokenizer.numeric_concept_ids
                        else UNKNOWN_BIN
                    )
                elif isinstance(concept_as_value, str):
                    bin_values.append(concept_as_value)
                else:
                    bin_values.append(UNKNOWN_BIN)
            assert len(bin_values) == len(record["input_ids"])
            record["values"] = tokenizer.encode_value(bin_values)

        # These features may contain nulls, which pyarrow cannot concatenate across records
        del record["number_as_values"]
        del record["concept_as_values"]
        return record
|
357
|
+
|
358
|
+
|
359
|
+
class HFFineTuningMapping(HFCehrGptTokenizationMapping):
    """Tokenization mapping that also attaches fine-tuning targets.

    Consider removing this transformation in the future.
    """

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Tokenize ``record`` and add ``age_at_index``, ``classifier_label`` and ``index_date``.

        - ``age_at_index`` takes ``record["age"]`` when present, otherwise
          ``record["age_at_index"]``.
        - ``classifier_label`` is ``1`` when ``record["label"] > 0``, else ``0``.
        - ``index_date`` is the POSIX timestamp of the index date, or ``None``
          when the key is absent or its value is null.
        """
        record = super().transform(record)
        # A null index_date would make convert_date_to_posix_time raise a TypeError
        # (datetime.combine rejects None), so it is treated like a missing one.
        index_date = record.get("index_date")
        record.update(
            {
                "age_at_index": (
                    record["age"] if "age" in record else record["age_at_index"]
                ),
                "classifier_label": int(record["label"] > 0),
                "index_date": (
                    convert_date_to_posix_time(index_date)
                    if index_date is not None
                    else None
                ),
            }
        )
        return record

    def remove_columns(self):
        """Drop the parent's intermediate columns plus the raw ``label``."""
        columns = super().remove_columns()
        columns.append("label")
        return columns
|