cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,36 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import polars as pl
|
4
|
+
|
5
|
+
from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
|
6
|
+
|
7
|
+
|
8
|
+
def main(args):
    """Print the mean and standard deviation of the time-interval tokens.

    Reads every ``*.parquet`` file under ``args.input_dir``, flattens the
    ``concept_ids`` list column into one row per concept, keeps the rows for
    which ``is_att_token`` is true, converts each remaining token with
    ``extract_time_interval_in_days`` and prints a dict with the keys
    ``mean`` and ``std`` computed over the converted values.

    Args:
        args: Parsed command-line arguments; only ``args.input_dir`` is read.
    """
    # Only `concept_ids` is used below, so avoid loading the other columns.
    dataset = pl.read_parquet(
        os.path.join(args.input_dir, "*.parquet"),
        columns=["concept_ids"],
    )
    time_token_frequency_df = (
        dataset.select(pl.col("concept_ids").explode().alias("concept_id"))
        # An explicit boolean dtype keeps the predicate valid even when polars
        # has no rows to infer the output type from.
        .filter(
            pl.col("concept_id").map_elements(is_att_token, return_dtype=pl.Boolean)
        )
        .with_columns(
            pl.col("concept_id")
            .map_elements(extract_time_interval_in_days)
            .alias("time_interval")
        )
    )
    results = time_token_frequency_df.select(
        pl.mean("time_interval").alias("mean"), pl.std("time_interval").alias("std")
    ).to_dicts()[0]
    print(results)
|
23
|
+
|
24
|
+
|
25
|
+
if __name__ == "__main__":
    import argparse

    # Command-line entry point: a single required option naming the folder
    # that holds the input parquet files, handed straight to main().
    cli = argparse.ArgumentParser(description="EHR Irregularity analysis")
    cli.add_argument(
        "--input_dir",
        required=True,
        action="store",
        dest="input_dir",
        help="The path for where the input data folder",
    )
    cli_args = cli.parse_args()
    main(cli_args)
|
@@ -1,9 +1,10 @@
|
|
1
|
-
from typing import Union
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
from cehrbert.data_generators.hf_data_generator.hf_dataset import (
|
4
4
|
FINETUNING_COLUMNS,
|
5
5
|
apply_cehrbert_dataset_mapping,
|
6
6
|
)
|
7
|
+
from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
|
7
8
|
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
|
8
9
|
from datasets import Dataset, DatasetDict
|
9
10
|
|
@@ -22,6 +23,7 @@ CEHRGPT_COLUMNS = [
|
|
22
23
|
"num_of_visits",
|
23
24
|
"values",
|
24
25
|
"value_indicators",
|
26
|
+
"epoch_times",
|
25
27
|
]
|
26
28
|
|
27
29
|
TRANSFORMER_COLUMNS = ["input_ids"]
|
@@ -31,16 +33,25 @@ def create_cehrgpt_pretraining_dataset(
|
|
31
33
|
dataset: Union[Dataset, DatasetDict],
|
32
34
|
cehrgpt_tokenizer: CehrGptTokenizer,
|
33
35
|
data_args: DataTrainingArguments,
|
34
|
-
|
36
|
+
cache_file_collector: Optional[CacheFileCollector] = None,
|
37
|
+
) -> Union[Dataset, DatasetDict]:
|
35
38
|
required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
|
39
|
+
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
40
|
+
if not data_args.streaming:
|
41
|
+
if isinstance(dataset, DatasetDict):
|
42
|
+
all_columns = dataset["train"].column_names
|
43
|
+
else:
|
44
|
+
all_columns = dataset.column_names
|
45
|
+
if "visit_concept_ids" in all_columns:
|
46
|
+
dataset.remove_columns(["visit_concept_ids"])
|
36
47
|
dataset = apply_cehrbert_dataset_mapping(
|
37
48
|
dataset,
|
38
49
|
HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
|
39
50
|
num_proc=data_args.preprocessing_num_workers,
|
40
51
|
batch_size=data_args.preprocessing_batch_size,
|
41
52
|
streaming=data_args.streaming,
|
53
|
+
cache_file_collector=cache_file_collector,
|
42
54
|
)
|
43
|
-
|
44
55
|
if not data_args.streaming:
|
45
56
|
if isinstance(dataset, DatasetDict):
|
46
57
|
all_columns = dataset["train"].column_names
|
@@ -56,8 +67,17 @@ def create_cehrgpt_finetuning_dataset(
|
|
56
67
|
dataset: Union[Dataset, DatasetDict],
|
57
68
|
cehrgpt_tokenizer: CehrGptTokenizer,
|
58
69
|
data_args: DataTrainingArguments,
|
59
|
-
|
70
|
+
cache_file_collector: Optional[CacheFileCollector] = None,
|
71
|
+
) -> Union[Dataset, DatasetDict]:
|
60
72
|
required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
|
73
|
+
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
|
74
|
+
if not data_args.streaming:
|
75
|
+
if isinstance(dataset, DatasetDict):
|
76
|
+
all_columns = dataset["train"].column_names
|
77
|
+
else:
|
78
|
+
all_columns = dataset.column_names
|
79
|
+
if "visit_concept_ids" in all_columns:
|
80
|
+
dataset.remove_columns(["visit_concept_ids"])
|
61
81
|
mapping_functions = [
|
62
82
|
HFFineTuningMapping(cehrgpt_tokenizer),
|
63
83
|
]
|
@@ -68,6 +88,7 @@ def create_cehrgpt_finetuning_dataset(
|
|
68
88
|
num_proc=data_args.preprocessing_num_workers,
|
69
89
|
batch_size=data_args.preprocessing_batch_size,
|
70
90
|
streaming=data_args.streaming,
|
91
|
+
cache_file_collector=cache_file_collector,
|
71
92
|
)
|
72
93
|
|
73
94
|
if not data_args.streaming:
|