cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,36 @@
1
+ import os
2
+
3
+ import polars as pl
4
+
5
+ from cehrgpt.gpt_utils import extract_time_interval_in_days, is_att_token
6
+
7
+
8
def main(args):
    """Print the mean and standard deviation of ATT-token time intervals.

    Every ``*.parquet`` file under ``args.input_dir`` is read into a single
    frame. The list column ``concept_ids`` is exploded to one token per row.
    Only tokens for which ``is_att_token`` is truthy are kept, and each kept
    token is mapped through ``extract_time_interval_in_days``. The result
    printed is a dict with the keys ``"mean"`` and ``"std"``.

    Args:
        args: Parsed CLI namespace; only ``args.input_dir`` is read.
    """
    parquet_glob = os.path.join(args.input_dir, "*.parquet")
    dataset = pl.read_parquet(parquet_glob)

    # Flatten the per-row token lists into one token per row.
    tokens = dataset.select(pl.col("concept_ids").explode().alias("concept_id"))

    # Keep only the tokens recognised by is_att_token (Python UDF per element).
    att_tokens = tokens.filter(pl.col("concept_id").map_elements(is_att_token))

    # Decode each kept token into its interval value (Python UDF per element).
    intervals = att_tokens.with_columns(
        pl.col("concept_id")
        .map_elements(extract_time_interval_in_days)
        .alias("time_interval")
    )

    summary = intervals.select(
        pl.mean("time_interval").alias("mean"),
        pl.std("time_interval").alias("std"),
    )
    results = summary.to_dicts()[0]
    print(results)
23
+
24
+
25
if __name__ == "__main__":
    import argparse

    # CLI entry point. argparse derives dest="input_dir" and uses the default
    # "store" action for the single required option.
    arg_parser = argparse.ArgumentParser(description="EHR Irregularity analysis")
    arg_parser.add_argument(
        "--input_dir",
        required=True,
        help="The path for where the input data folder",
    )
    cli_args = arg_parser.parse_args()
    main(cli_args)
@@ -1,9 +1,10 @@
1
- from typing import Union
1
+ from typing import Optional, Union
2
2
 
3
3
  from cehrbert.data_generators.hf_data_generator.hf_dataset import (
4
4
  FINETUNING_COLUMNS,
5
5
  apply_cehrbert_dataset_mapping,
6
6
  )
7
+ from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
7
8
  from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
8
9
  from datasets import Dataset, DatasetDict
9
10
 
@@ -22,6 +23,7 @@ CEHRGPT_COLUMNS = [
22
23
  "num_of_visits",
23
24
  "values",
24
25
  "value_indicators",
26
+ "epoch_times",
25
27
  ]
26
28
 
27
29
  TRANSFORMER_COLUMNS = ["input_ids"]
@@ -31,16 +33,25 @@ def create_cehrgpt_pretraining_dataset(
31
33
  dataset: Union[Dataset, DatasetDict],
32
34
  cehrgpt_tokenizer: CehrGptTokenizer,
33
35
  data_args: DataTrainingArguments,
34
- ) -> Dataset:
36
+ cache_file_collector: Optional[CacheFileCollector] = None,
37
+ ) -> Union[Dataset, DatasetDict]:
35
38
  required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS
39
+ # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
40
+ if not data_args.streaming:
41
+ if isinstance(dataset, DatasetDict):
42
+ all_columns = dataset["train"].column_names
43
+ else:
44
+ all_columns = dataset.column_names
45
+ if "visit_concept_ids" in all_columns:
46
+ dataset.remove_columns(["visit_concept_ids"])
36
47
  dataset = apply_cehrbert_dataset_mapping(
37
48
  dataset,
38
49
  HFCehrGptTokenizationMapping(cehrgpt_tokenizer),
39
50
  num_proc=data_args.preprocessing_num_workers,
40
51
  batch_size=data_args.preprocessing_batch_size,
41
52
  streaming=data_args.streaming,
53
+ cache_file_collector=cache_file_collector,
42
54
  )
43
-
44
55
  if not data_args.streaming:
45
56
  if isinstance(dataset, DatasetDict):
46
57
  all_columns = dataset["train"].column_names
@@ -56,8 +67,17 @@ def create_cehrgpt_finetuning_dataset(
56
67
  dataset: Union[Dataset, DatasetDict],
57
68
  cehrgpt_tokenizer: CehrGptTokenizer,
58
69
  data_args: DataTrainingArguments,
59
- ) -> Dataset:
70
+ cache_file_collector: Optional[CacheFileCollector] = None,
71
+ ) -> Union[Dataset, DatasetDict]:
60
72
  required_columns = TRANSFORMER_COLUMNS + CEHRGPT_COLUMNS + FINETUNING_COLUMNS
73
+ # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
74
+ if not data_args.streaming:
75
+ if isinstance(dataset, DatasetDict):
76
+ all_columns = dataset["train"].column_names
77
+ else:
78
+ all_columns = dataset.column_names
79
+ if "visit_concept_ids" in all_columns:
80
+ dataset.remove_columns(["visit_concept_ids"])
61
81
  mapping_functions = [
62
82
  HFFineTuningMapping(cehrgpt_tokenizer),
63
83
  ]
@@ -68,6 +88,7 @@ def create_cehrgpt_finetuning_dataset(
68
88
  num_proc=data_args.preprocessing_num_workers,
69
89
  batch_size=data_args.preprocessing_batch_size,
70
90
  streaming=data_args.streaming,
91
+ cache_file_collector=cache_file_collector,
71
92
  )
72
93
 
73
94
  if not data_args.streaming: