cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
- cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +193 -459
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/data_utils.py +17 -6
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
- cehrgpt-0.1.1.dist-info/METADATA +0 -115
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/gpt_utils.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1
1
|
import random
|
2
2
|
import re
|
3
|
-
from datetime import date, timedelta
|
4
|
-
from typing import List, Sequence, Tuple
|
3
|
+
from datetime import date, datetime, timedelta, timezone
|
4
|
+
from typing import List, Optional, Sequence, Tuple, Union
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
|
8
|
+
from meds import death_code
|
9
|
+
from transformers.utils import logging
|
5
10
|
|
6
11
|
from cehrgpt.cehrgpt_args import SamplingStrategy
|
7
12
|
from cehrgpt.models.special_tokens import (
|
@@ -14,6 +19,7 @@ from cehrgpt.models.special_tokens import (
|
|
14
19
|
MEDS_CODE_PATTERN = re.compile(r".*/.*")
|
15
20
|
INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
|
16
21
|
DEMOGRAPHIC_PROMPT_SIZE = 4
|
22
|
+
logger = logging.get_logger("transformers")
|
17
23
|
|
18
24
|
|
19
25
|
class RandomSampleCache:
|
@@ -62,6 +68,68 @@ class RandomSampleCache:
|
|
62
68
|
return self._cache.pop()
|
63
69
|
|
64
70
|
|
71
|
+
def construct_time_sequence(
    concept_ids: List[str], epoch_times: Optional[List[Union[int, float]]] = None
) -> List[float]:
    """Return one epoch timestamp (in seconds) per token of a patient sequence.

    When ``epoch_times`` is supplied it is returned unchanged. Otherwise the
    timestamps are reconstructed from the tokens: the clock starts at
    midnight UTC on January 1st of the year carried by the first token
    (``"year:<YYYY>"``, matched case-insensitively), or of 1985 when the first
    token is not a year token. Every token recognised by ``is_att_token``
    moves the clock forward by ``extract_time_interval_in_days(token)`` days
    before its own timestamp is recorded.

    Args:
        concept_ids: The tokens of the sequence.
        epoch_times: Optional precomputed timestamps, one per token.

    Returns:
        ``epoch_times`` itself when given, otherwise a new list with
        ``len(concept_ids)`` POSIX timestamps. An empty ``concept_ids``
        yields an empty list.

    Raises:
        IndexError: If the first token starts with ``"year"`` but contains no
            ``":"`` separator.
        ValueError: If the text after ``":"`` is not a valid integer year.
    """
    if epoch_times is not None:
        return epoch_times

    # An empty sequence has no first token to read the year from and no
    # timestamps to produce; return early instead of failing on concept_ids[0].
    if len(concept_ids) == 0:
        return []

    first_token = concept_ids[0]
    if first_token.lower().startswith("year"):
        year_str = first_token.split(":")[1]
    else:
        # No year token at the head of the sequence: use the fixed default year.
        year_str = "1985"

    # Timezone-aware so that .timestamp() does not depend on the local timezone.
    datetime_cursor = datetime(int(year_str), month=1, day=1, tzinfo=timezone.utc)
    epoch_times = []
    for concept_id in concept_ids:
        if is_att_token(concept_id):
            # The time token itself is stamped with the already-advanced clock.
            att_days = extract_time_interval_in_days(concept_id)
            datetime_cursor += timedelta(days=att_days)
        epoch_times.append(datetime_cursor.timestamp())
    return epoch_times
|
92
|
+
|
93
|
+
|
94
|
+
def construct_age_sequence(
    concept_ids: List[str], ages: Optional[List[int]] = None
) -> List[int]:
    """Return one age value per token of a patient sequence.

    When ``ages`` is supplied it is returned unchanged. Otherwise the starting
    age is read from the second token (``"age:<N>"``, matched
    case-insensitively) and each token's age is the starting age plus the
    whole number of 365-day periods accumulated so far from the tokens
    recognised by ``is_att_token``. If the second token is missing or is not
    an age token, a warning is logged and a list of zeros is returned.

    Args:
        concept_ids: The tokens of the sequence.
        ages: Optional precomputed ages, one per token.

    Returns:
        ``ages`` itself when given, otherwise a new list with
        ``len(concept_ids)`` integers.

    Raises:
        AssertionError: If the text after ``":"`` in the age token is not numeric.
        IndexError: If the age token contains no ``":"`` separator.
    """
    if ages is not None:
        return ages

    # The length guard routes sequences with fewer than two tokens to the
    # zero-vector fallback instead of raising IndexError on concept_ids[1].
    if len(concept_ids) > 1 and concept_ids[1].lower().startswith("age"):
        age_str = concept_ids[1].split(":")[1]
        assert age_str.isnumeric(), f"age_str: {age_str}"
        start_age = int(age_str)  # parsed once rather than on every iteration
        ages = []
        time_delta = 0
        for concept_id in concept_ids:
            if is_att_token(concept_id):
                time_delta += extract_time_interval_in_days(concept_id)
            # Elapsed days are converted to years with a fixed 365-day year.
            ages.append(start_age + time_delta // 365)
        return ages

    logger.warning(
        "The second token is not a valid age token. The first 4 tokens are: %s. "
        "Trying to fall back to ages, but it is not valid either %s. "
        "Fall back to a zero vector [0, 0, 0, ...., 0]",
        concept_ids[:4],
        ages,
    )
    return np.zeros_like(concept_ids, dtype=int).tolist()
|
118
|
+
|
119
|
+
|
120
|
+
def multiple_of_10(n: int) -> int:
    """Return the smallest multiple of 10 that is strictly greater than ``n``.

    A value that is already a multiple of 10 is bumped to the next one
    (``10 -> 20``).
    """
    tens, _ = divmod(n, 10)
    return 10 * (tens + 1)
|
122
|
+
|
123
|
+
|
124
|
+
def encode_demographics(
    age: int, gender: int, race: int, max_age=200, max_gender=10, max_race=10
) -> int:
    """Pack ``age``, ``gender`` and ``race`` into a single integer code.

    The three values are combined as mixed-radix digits: ``age`` is the
    lowest digit (radix ``max_age``), ``gender`` the next (radix
    ``max_gender``) and ``race`` the highest.

    Args:
        age: Value in ``[0, max_age)``.
        gender: Value in ``[0, max_gender)``.
        race: Value in ``[0, max_race)``.
        max_age: Exclusive upper bound for ``age``.
        max_gender: Exclusive upper bound for ``gender``.
        max_race: Exclusive upper bound for ``race``.

    Returns:
        ``age + max_age * gender + max_age * max_gender * race``.

    Raises:
        AssertionError: If any value falls outside its allowed range.
    """
    range_checks = (
        ("age", age, max_age),
        ("gender", gender, max_gender),
        ("race", race, max_race),
    )
    for label, value, upper_bound in range_checks:
        assert 0 <= value < upper_bound, f"{label}: {value}"
    # Horner form of age + max_age * gender + max_age * max_gender * race.
    return age + max_age * (gender + max_gender * race)
|
131
|
+
|
132
|
+
|
65
133
|
def collect_demographic_prompts_at_visits(patient_history: List[str]):
|
66
134
|
demographic_prompts_at_visits = []
|
67
135
|
start_year, start_age, start_gender, start_race = patient_history[
|
@@ -156,7 +224,7 @@ def random_slice_gpt_sequence(concept_ids, max_seq_len):
|
|
156
224
|
)
|
157
225
|
):
|
158
226
|
current_token = concept_ids[i]
|
159
|
-
if current_token
|
227
|
+
if is_visit_end(current_token):
|
160
228
|
random_end_index = i
|
161
229
|
break
|
162
230
|
return random_starting_index, random_end_index, demographic_tokens
|
@@ -198,6 +266,8 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
|
|
198
266
|
def is_clinical_event(token: str, meds: bool = False) -> bool:
    """Tell whether ``token`` is treated as a clinical event.

    A token qualifies when it is purely numeric or when it equals
    ``DEATH_TOKEN`` or ``death_code``. When ``meds`` is truthy, a token
    matching ``MEDS_CODE_PATTERN`` qualifies as well.

    Args:
        token: The token to classify.
        meds: Whether to also accept tokens matching ``MEDS_CODE_PATTERN``.

    Returns:
        ``True`` if the token is a clinical event, ``False`` otherwise.
    """
    if token.isnumeric() or token in (DEATH_TOKEN, death_code):
        return True
    return bool(MEDS_CODE_PATTERN.match(token)) if meds else False
|
@@ -0,0 +1,27 @@
|
|
1
|
+
# From https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
|
2
|
+
# coding=utf-8
|
3
|
+
|
4
|
+
from __future__ import absolute_import, division, print_function
|
5
|
+
|
6
|
+
import torch
|
7
|
+
import torch.nn as nn
|
8
|
+
import transformers.pytorch_utils
|
9
|
+
|
10
|
+
|
11
|
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
|
12
|
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The input is scaled by the reciprocal root of its mean square (no mean
    subtraction, no bias) and then multiplied by a learnable per-feature
    ``weight`` initialised to ones. MistralRMSNorm is equivalent to
    T5LayerNorm.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create the scale parameter of size ``hidden_size`` and store ``eps``."""
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` along its last dimension.

        The statistics are computed in float32; the normalised tensor is cast
        back to the input dtype before the learnable scale is applied.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        # eps keeps rsqrt finite when the input is all zeros.
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)
|
25
|
+
|
26
|
+
|
27
|
+
# Append RMSNorm to transformers' ALL_LAYERNORM_LAYERS list when this module is imported.
# NOTE(review): presumably so transformers utilities treat RMSNorm like the built-in
# layer norms (e.g. when grouping parameters for weight decay) — confirm for the pinned version.
transformers.pytorch_utils.ALL_LAYERNORM_LAYERS.extend([RMSNorm])
|
cehrgpt/models/config.py
CHANGED
@@ -106,6 +106,8 @@ class CEHRGPTConfig(PretrainedConfig):
|
|
106
106
|
n_head=12,
|
107
107
|
n_inner=None,
|
108
108
|
activation_function="gelu_new",
|
109
|
+
decoder_mlp="GPT2MLP",
|
110
|
+
mlp_bias=False,
|
109
111
|
resid_pdrop=0.1,
|
110
112
|
embd_pdrop=0.1,
|
111
113
|
attn_pdrop=0.1,
|
@@ -124,7 +126,7 @@ class CEHRGPTConfig(PretrainedConfig):
|
|
124
126
|
ve_token_id=None,
|
125
127
|
scale_attn_by_inverse_layer_idx=False,
|
126
128
|
reorder_and_upcast_attn=False,
|
127
|
-
|
129
|
+
apply_rotary=False,
|
128
130
|
include_values=False,
|
129
131
|
value_vocab_size=None,
|
130
132
|
include_ttv_prediction=False,
|
@@ -169,6 +171,8 @@ class CEHRGPTConfig(PretrainedConfig):
|
|
169
171
|
self.n_head = n_head
|
170
172
|
self.n_inner = n_inner
|
171
173
|
self.activation_function = activation_function
|
174
|
+
self.decoder_mlp = decoder_mlp
|
175
|
+
self.mlp_bias = mlp_bias
|
172
176
|
self.resid_pdrop = resid_pdrop
|
173
177
|
self.embd_pdrop = embd_pdrop
|
174
178
|
self.attn_pdrop = attn_pdrop
|
@@ -188,7 +192,7 @@ class CEHRGPTConfig(PretrainedConfig):
|
|
188
192
|
self.eos_token_id = eos_token_id
|
189
193
|
self.lab_token_ids = lab_token_ids
|
190
194
|
|
191
|
-
self.
|
195
|
+
self.apply_rotary = apply_rotary
|
192
196
|
self.include_values = include_values
|
193
197
|
self.value_vocab_size = value_vocab_size
|
194
198
|
|