cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/gpt_utils.py CHANGED
@@ -1,7 +1,12 @@
 import random
 import re
-from datetime import date, timedelta
-from typing import List, Sequence, Tuple
+from datetime import date, datetime, timedelta, timezone
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
+from meds import death_code
+from transformers.utils import logging
 
 from cehrgpt.cehrgpt_args import SamplingStrategy
 from cehrgpt.models.special_tokens import (
@@ -14,6 +19,7 @@ from cehrgpt.models.special_tokens import (
 MEDS_CODE_PATTERN = re.compile(r".*/.*")
 INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
 DEMOGRAPHIC_PROMPT_SIZE = 4
+logger = logging.get_logger("transformers")
 
 
 class RandomSampleCache:
@@ -62,6 +68,68 @@ class RandomSampleCache:
         return self._cache.pop()
 
 
+def construct_time_sequence(
+    concept_ids: List[str], epoch_times: Optional[List[Union[int, float]]] = None
+) -> List[float]:
+    if epoch_times is not None:
+        return epoch_times
+
+    if concept_ids[0].lower().startswith("year"):
+        year_str = concept_ids[0].split(":")[1]
+    else:
+        year_str = "1985"
+
+    datetime_cursor = datetime(
+        int(year_str), month=1, day=1, hour=0, minute=0, second=0
+    ).replace(tzinfo=timezone.utc)
+    epoch_times = []
+    for concept_id in concept_ids:
+        if is_att_token(concept_id):
+            att_days = extract_time_interval_in_days(concept_id)
+            datetime_cursor += timedelta(days=att_days)
+        epoch_times.append(datetime_cursor.timestamp())
+    return epoch_times
+
+
+def construct_age_sequence(
+    concept_ids: List[str], ages: Optional[List[int]] = None
+) -> List[int]:
+    if ages is not None:
+        return ages
+    elif concept_ids[1].lower().startswith("age"):
+        age_str = concept_ids[1].split(":")[1]
+        assert age_str.isnumeric(), f"age_str: {age_str}"
+        ages = []
+        time_delta = 0
+        for concept_id in concept_ids:
+            if is_att_token(concept_id):
+                time_delta += extract_time_interval_in_days(concept_id)
+            ages.append(int(age_str) + time_delta // 365)
+        return ages
+    else:
+        logger.warning(
+            "The second token is not a valid age token. The first 4 tokens are: %s. "
+            "Trying to fall back to ages, but it is not valid either %s. "
+            "Fall back to a zero vector [0, 0, 0, ...., 0]",
+            concept_ids[:4],
+            ages,
+        )
+        return np.zeros_like(concept_ids, dtype=int).tolist()
+
+
+def multiple_of_10(n: int) -> int:
+    return ((n // 10) + 1) * 10
+
+
+def encode_demographics(
+    age: int, gender: int, race: int, max_age=200, max_gender=10, max_race=10
+) -> int:
+    assert 0 <= age < max_age, f"age: {age}"
+    assert 0 <= gender < max_gender, f"gender: {gender}"
+    assert 0 <= race < max_race, f"race: {race}"
+    return age + max_age * gender + max_age * max_gender * race
+
+
 def collect_demographic_prompts_at_visits(patient_history: List[str]):
     demographic_prompts_at_visits = []
     start_year, start_age, start_gender, start_race = patient_history[
@@ -156,7 +224,7 @@ def random_slice_gpt_sequence(concept_ids, max_seq_len):
             )
         ):
             current_token = concept_ids[i]
-            if current_token == "VE":
+            if is_visit_end(current_token):
                 random_end_index = i
                 break
         return random_starting_index, random_end_index, demographic_tokens
@@ -198,6 +266,8 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
 def is_clinical_event(token: str, meds: bool = False) -> bool:
     if token.isnumeric():
         return True
+    if token in [DEATH_TOKEN, death_code]:
+        return True
     if meds:
         return bool(MEDS_CODE_PATTERN.match(token))
     return False
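
Usage note (not part of the published diff): a minimal sketch of the new gpt_utils helpers added above. The demographic prompt layout (year:YYYY and age:NN in the first two positions) follows the code shown in the hunk; is_att_token and extract_time_interval_in_days come from the surrounding module, while the concrete concept IDs below are purely illustrative.

from cehrgpt.gpt_utils import (
    construct_age_sequence,
    construct_time_sequence,
    encode_demographics,
)

# Hypothetical sequence: demographic prompt (year, age, gender, race) plus one visit.
concept_ids = ["year:2010", "age:45", "8507", "8527", "VS", "320128", "VE"]

ages = construct_age_sequence(concept_ids)    # ATT interval tokens, if present, would advance the age
times = construct_time_sequence(concept_ids)  # UTC epoch seconds anchored at 2010-01-01

# encode_demographics packs (age, gender, race) into a single integer using a
# mixed-radix scheme: age + max_age * gender + max_age * max_gender * race.
code = encode_demographics(age=45, gender=1, race=2)  # 45 + 200 * 1 + 2000 * 2 = 4245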
cehrgpt/models/activations.py ADDED
@@ -0,0 +1,27 @@
+# From https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
+# coding=utf-8
+
+from __future__ import absolute_import, division, print_function
+
+import torch
+import torch.nn as nn
+import transformers.pytorch_utils
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """MistralRMSNorm is equivalent to T5LayerNorm."""
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+transformers.pytorch_utils.ALL_LAYERNORM_LAYERS.extend([RMSNorm])
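
Usage note (not part of the published diff): a minimal sketch of the new RMSNorm module, assuming it is imported from cehrgpt.models.activations as defined above.

import torch

from cehrgpt.models.activations import RMSNorm

# Normalize a batch of hidden states of shape (batch, seq_len, hidden_size).
norm = RMSNorm(hidden_size=768, eps=1e-6)
hidden_states = torch.randn(2, 16, 768)

# Each vector is rescaled by the reciprocal of its root mean square (computed in
# float32), then multiplied by the learned per-feature weight (initialized to ones).
normalized = norm(hidden_states)  # same shape: (2, 16, 768)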
cehrgpt/models/config.py CHANGED
@@ -106,6 +106,8 @@ class CEHRGPTConfig(PretrainedConfig):
         n_head=12,
         n_inner=None,
         activation_function="gelu_new",
+        decoder_mlp="GPT2MLP",
+        mlp_bias=False,
         resid_pdrop=0.1,
         embd_pdrop=0.1,
         attn_pdrop=0.1,
@@ -124,7 +126,7 @@
         ve_token_id=None,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
-        exclude_position_ids=False,
+        apply_rotary=False,
         include_values=False,
         value_vocab_size=None,
         include_ttv_prediction=False,
@@ -169,6 +171,8 @@
         self.n_head = n_head
         self.n_inner = n_inner
         self.activation_function = activation_function
+        self.decoder_mlp = decoder_mlp
+        self.mlp_bias = mlp_bias
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attn_pdrop = attn_pdrop
@@ -188,7 +192,7 @@
         self.eos_token_id = eos_token_id
         self.lab_token_ids = lab_token_ids
 
-        self.exclude_position_ids = exclude_position_ids
+        self.apply_rotary = apply_rotary
         self.include_values = include_values
         self.value_vocab_size = value_vocab_size
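
Usage note (not part of the published diff): a sketch of how the renamed and newly added configuration options might be set. Only the keyword names visible in this diff (activation_function, decoder_mlp, mlp_bias, apply_rotary) are taken from the source; whether the remaining constructor arguments can all be left at their defaults is an assumption.

from cehrgpt.models.config import CEHRGPTConfig

config = CEHRGPTConfig(
    activation_function="gelu_new",
    decoder_mlp="GPT2MLP",  # selects the decoder MLP implementation
    mlp_bias=False,         # whether the MLP linear layers carry bias terms
    apply_rotary=True,      # rotary position handling; replaces the removed exclude_position_ids flag
)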