cehrgpt 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/gpt_utils.py
ADDED
@@ -0,0 +1,331 @@
|
|
1
|
+
import random
|
2
|
+
import re
|
3
|
+
from datetime import date, timedelta
|
4
|
+
from typing import List, Sequence, Tuple
|
5
|
+
|
6
|
+
from cehrgpt.cehrgpt_args import SamplingStrategy
|
7
|
+
from cehrgpt.models.special_tokens import (
|
8
|
+
DISCHARGE_CONCEPT_IDS,
|
9
|
+
END_TOKEN,
|
10
|
+
VISIT_CONCEPT_IDS,
|
11
|
+
)
|
12
|
+
|
13
|
+
# Matches inpatient ATT (artificial time) tokens that begin with "VS-D<n>" or
# "i-D<n>", e.g. "VS-D7-VE" or "i-D7"; group 1 captures the day count <n>.
|
14
|
+
INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
|
15
|
+
# Number of demographic prompt tokens at the start of every patient sequence:
# year, age, gender and race, in that order.
DEMOGRAPHIC_PROMPT_SIZE = 4
|
16
|
+
|
17
|
+
|
18
|
+
class RandomSampleCache:
    """
    Hand out randomly sampled data indices, refilling an internal buffer in batches.

    Indices are drawn with replacement via :func:`random.choices`; each refill draws
    ``cache_size`` indices at once, and :meth:`next` pops them one at a time.
    """

    def __init__(
        self,
        data_indices: Sequence[int],
        cache_size: int,
        sample_weights: Sequence[float] = None,
    ):
        """
        Initialize the RandomSampleCache.

        :param data_indices: Sequence of data indices to sample from.
        :param cache_size: Number of indices drawn per refill.
        :param sample_weights: Optional per-index sampling weights; when given they
            must sum to 1 (within 1e-8).
        :raises ValueError: If ``sample_weights`` does not sum to 1.
        """
        self._data_indices = data_indices
        self._sample_weights = sample_weights
        self._cache_size = cache_size
        self._cache = []

        if self._sample_weights is not None:
            total = sum(self._sample_weights)
            # abs() is essential: without it every total below 1 slipped through.
            # An explicit raise (not assert) keeps the check active under `python -O`.
            if abs(total - 1) >= 1e-8:
                raise ValueError(f"sample_weights must sum to 1, got {total}")

    def next(self):
        """
        Get the next sample from the cache, refilling it first if it is empty.

        :return: A sampled data index.
        """
        if not self._cache:
            # weights=None makes random.choices use its uniform path, identical to
            # calling it without weights.
            self._cache.extend(
                random.choices(
                    self._data_indices,
                    weights=self._sample_weights,
                    k=self._cache_size,
                )
            )
        return self._cache.pop()
|
62
|
+
|
63
|
+
|
64
|
+
def collect_demographic_prompts_at_visits(patient_history: List[str]):
    """
    Reconstruct the demographic prompt that applies at each visit start.

    The first ``DEMOGRAPHIC_PROMPT_SIZE`` tokens of ``patient_history`` are unpacked as
    year, age, gender and race. The year and age tokens are parsed as ``"<label>:<int>"``.
    A date cursor starts on January 1st of the start year and is advanced by every
    artificial time token. At each visit-start token, the current year and the age
    (cursor year minus birth year) are recorded.

    :param patient_history: Token sequence whose first four tokens are the demographic
        prompt (year, age, gender, race).
    :return: List of ``(index, (year_token, age_token, gender, race))`` tuples, one per
        visit-start token. If the year/age prompt cannot be parsed, every entry uses
        ``"year:unknown"`` and ``"age:unknown"``.
    :raises ValueError: If ``patient_history`` has fewer than four tokens, a time
        token cannot be interpreted, or a derived date is out of range.
    """
    demographic_prompts_at_visits = []
    start_year, start_age, start_gender, start_race = patient_history[
        :DEMOGRAPHIC_PROMPT_SIZE
    ]
    try:
        start_year = int(start_year.split(":")[1])
        start_age = int(start_age.split(":")[1])
        valid_prompt = True
    except (IndexError, ValueError):
        # The exceptions must be given as a tuple: ``IndexError | ValueError`` is not
        # a valid except target and turned this fallback into a TypeError.
        # Placeholder values keep the date arithmetic below well-defined.
        start_year = 1900
        start_age = 0
        valid_prompt = False
    data_cursor = date(start_year, 1, 1)
    birth_date = date(start_year - start_age, 1, 1)
    for i, current_token in enumerate(patient_history):
        if is_visit_start(current_token):
            if valid_prompt:
                reconstructed_year = f"year:{data_cursor.year}"
                reconstructed_age = f"age:{data_cursor.year - birth_date.year}"
            else:
                reconstructed_year = "year:unknown"
                reconstructed_age = "age:unknown"
            demographic_prompts_at_visits.append(
                (
                    i,
                    (reconstructed_year, reconstructed_age, start_gender, start_race),
                )
            )
        elif is_att_token(current_token):
            att_date_delta = extract_time_interval_in_days(current_token)
            data_cursor += timedelta(days=att_date_delta)
    return demographic_prompts_at_visits
|
104
|
+
|
105
|
+
|
106
|
+
def random_slice_gpt_sequence(concept_ids, max_seq_len):
    """
    Randomly pick a visit-aligned window of a patient sequence.

    Candidate starts are visit-start tokens after the demographic prompt whose window
    of ``max_seq_len - DEMOGRAPHIC_PROMPT_SIZE`` tokens still fits in the sequence.
    While scanning, artificial time tokens advance a date cursor, so the year and age
    at the chosen start can be rebuilt into a fresh demographic prompt. The window is
    then cut at the last visit-end token it contains.

    :param concept_ids: Token sequence whose first four tokens are year, age, gender
        and race.
    :param max_seq_len: Maximum sequence length, including the demographic prompt.
    :return: Tuple ``(start_index, end_index, demographic_tokens)``:

        * normally the chosen visit-start index, the index of the last visit-end
          token inside the window (``start_index`` if there is none), and the rebuilt
          ``[year, age, gender, race]`` tokens;
        * ``(0, 0, concept_ids[:DEMOGRAPHIC_PROMPT_SIZE])`` when no visit start is
          eligible;
        * ``(0, max_seq_len - 1, [])`` when any step of the parsing fails.
    :raises ValueError: If ``concept_ids`` has fewer than four tokens.
    """
    seq_length = len(concept_ids)
    starting_points = []
    start_year, start_age, start_gender, start_race = concept_ids[
        :DEMOGRAPHIC_PROMPT_SIZE
    ]
    try:
        start_year = int(start_year.split(":")[1])
        start_age = int(start_age.split(":")[1])
        data_cursor = date(start_year, 1, 1)
        birth_date = date(start_year - start_age, 1, 1)
        # Exclusive upper bound on starts that still leave room for a full window.
        start_upper_bound = min(
            seq_length, seq_length - max_seq_len + DEMOGRAPHIC_PROMPT_SIZE
        )
        for i in range(DEMOGRAPHIC_PROMPT_SIZE, start_upper_bound):
            current_token = concept_ids[i]
            if is_visit_start(current_token):
                starting_points.append(
                    (i, data_cursor.year, data_cursor.year - birth_date.year)
                )
            elif is_att_token(current_token):
                att_date_delta = extract_time_interval_in_days(current_token)
                data_cursor = data_cursor + timedelta(days=att_date_delta)

        if len(starting_points) == 0:
            return 0, 0, concept_ids[:DEMOGRAPHIC_PROMPT_SIZE]

        random_starting_index, random_starting_year, random_starting_age = (
            random.choice(starting_points)
        )
        demographic_tokens = [
            f"year:{random_starting_year}",
            f"age:{random_starting_age}",
            start_gender,
            start_race,
        ]
        # Reserve room for the rebuilt demographic prompt, then end the slice on the
        # last visit-end token in the window. is_visit_end accepts both "VE" and
        # "[VE]", matching is_visit_start; comparing only against "VE" produced an
        # empty slice for bracketed tokens.
        window_end = random_starting_index + max_seq_len - DEMOGRAPHIC_PROMPT_SIZE
        random_end_index = random_starting_index
        for i in reversed(range(random_starting_index, window_end)):
            if is_visit_end(concept_ids[i]):
                random_end_index = i
                break
        return random_starting_index, random_end_index, demographic_tokens

    except Exception:
        # Deliberately broad best-effort fallback when the prompt or time tokens
        # cannot be parsed.
        return 0, max_seq_len - 1, []
|
165
|
+
|
166
|
+
|
167
|
+
def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
    """
    Derive an output folder name that encodes the generation settings in ``args``.

    The name starts with the sampling strategy and its parameters. Temperature,
    repetition penalty, beam settings and epsilon cutoff are appended only when they
    differ from their neutral values. Float settings are scaled (x10000, or x100000
    for ``epsilon_cutoff``) and truncated to integers.

    Side effects: for the top-k strategy ``args.top_p`` is set to 1.0; for the top-p
    strategy ``args.top_k`` is set to ``cehrgpt_tokenizer.vocab_size``.

    :param args: Generation arguments (``sampling_strategy``, ``top_k``, ``top_p``,
        ``temperature``, ``repetition_penalty``, ``num_beams``, ``num_beam_groups``,
        ``epsilon_cutoff``).
    :param cehrgpt_tokenizer: Tokenizer whose ``vocab_size`` is used for top-p sampling.
    :return: The folder name.
    :raises RuntimeError: If ``args.sampling_strategy`` is not a known strategy.
    """
    strategy = args.sampling_strategy
    if strategy == SamplingStrategy.TopKStrategy.value:
        name = f"top_k{args.top_k}"
        args.top_p = 1.0
    elif strategy == SamplingStrategy.TopPStrategy.value:
        name = f"top_p{int(args.top_p * 10000)}"
        args.top_k = cehrgpt_tokenizer.vocab_size
    elif strategy == SamplingStrategy.TopMixStrategy.value:
        name = f"top_mix_p{int(args.top_p * 10000)}_k{args.top_k}"
    else:
        raise RuntimeError(
            "sampling_strategy has to be one of the following three options [TopKStrategy, TopPStrategy, TopMixStrategy]"
        )

    suffixes = []
    if args.temperature != 1.0:
        suffixes.append(f"_temp_{int(args.temperature * 10000)}")
    if args.repetition_penalty != 1.0:
        suffixes.append(
            f"_repetition_penalty_{int(args.repetition_penalty * 10000)}"
        )
    if args.num_beams > 1:
        suffixes.append(f"_num_beams_{int(args.num_beams)}")
    if args.num_beam_groups > 1:
        suffixes.append(f"_num_beam_groups_{int(args.num_beam_groups)}")
    if args.epsilon_cutoff > 0.0:
        suffixes.append(f"_epsilon_cutoff_{int(args.epsilon_cutoff * 100000)}")
    return name + "".join(suffixes)
|
195
|
+
|
196
|
+
|
197
|
+
def is_clinical_event(token: str) -> bool:
    """
    Check whether ``token`` denotes a clinical event.

    A token counts as a clinical event when it is non-empty and every character is
    numeric, as reported by :meth:`str.isnumeric`.

    :param token: Token to check.
    :return: True for purely numeric tokens, False otherwise.
    """
    all_numeric = token.isnumeric()
    return all_numeric
|
199
|
+
|
200
|
+
|
201
|
+
def is_visit_start(token: str):
    """
    Check if the token marks the start of a visit.

    Both the plain (``"VS"``) and bracketed (``"[VS]"``) spellings are accepted.

    :param token: Token to check.
    :return: True if the token is a visit start token, False otherwise.
    """
    visit_start_markers = ("VS", "[VS]")
    return token in visit_start_markers
|
209
|
+
|
210
|
+
|
211
|
+
def is_visit_end(token: str) -> bool:
    """
    Check if the token marks the end of a visit.

    Both the plain (``"VE"``) and bracketed (``"[VE]"``) spellings are accepted.

    :param token: Token to check.
    :return: True if the token is a visit end token, False otherwise.
    """
    visit_end_markers = ("VE", "[VE]")
    return token in visit_end_markers
|
213
|
+
|
214
|
+
|
215
|
+
def is_att_token(token: str):
    """
    Check if the token is an artificial time token (ATT).

    Recognized forms:

    * ``D<n>``, ``W<n>``, ``M<n>``, ``Y<n>`` (day/week/month/year intervals; only
      the prefix is checked);
    * ``LT``;
    * anything starting with ``VS-`` (e.g. ``VS-D7-VE``);
    * anything starting with ``i-`` (e.g. ``i-D7``), except hour tokens starting
      with ``i-H``.

    :param token: Token to check.
    :return: True if the token is an artificial time token, False otherwise.
    """
    # One character class replaces the separate day/week/month/year patterns.
    if re.match(r"^[DWMY]\d+", token):
        return True
    if token == "LT" or token.startswith("VS-"):
        return True
    return token.startswith("i-") and not token.startswith("i-H")
|
239
|
+
|
240
|
+
|
241
|
+
def is_artificial_token(token: str) -> bool:
    """
    Check whether ``token`` is an artificial token.

    A token is artificial when it is a visit concept id, a discharge concept id, a
    visit start or end marker, an artificial time token, or the end token.

    :param token: Token to check.
    :return: True if any of the above applies, False otherwise.
    """
    if token in VISIT_CONCEPT_IDS or token in DISCHARGE_CONCEPT_IDS:
        return True
    if is_visit_start(token) or is_visit_end(token):
        return True
    return is_att_token(token) or token == END_TOKEN
|
255
|
+
|
256
|
+
|
257
|
+
def is_inpatient_att_token(token: str) -> bool:
    """
    Check if the token is an inpatient ATT token.

    A token qualifies when it begins with ``VS-D<n>`` or ``i-D<n>`` (it must match
    ``INPATIENT_ATT_PATTERN`` at the start), e.g. ``"VS-D7-VE"`` or ``"i-D7"``.

    :param token: Token to check.
    :return: True if the token is an inpatient ATT token, False otherwise.
    """
    # Previously the raw re.Match/None leaked out; coerce to a real bool so the result
    # matches the documented contract (truthiness is unchanged for existing callers).
    return INPATIENT_ATT_PATTERN.match(token) is not None
|
265
|
+
|
266
|
+
|
267
|
+
def extract_time_interval_in_days(token: str):
    """
    Extract the time interval in days from a token.

    Supported forms (``n`` is an integer):

    * ``Dn`` -> n, ``Wn`` -> 7 * n, ``Mn`` -> 30 * n, ``Yn`` -> 365 * n;
    * ``LT`` -> 1095 (3 * 365);
    * ``VS-<x>...`` / ``i-<x>...`` -> the integer after the first character of
      ``<x>`` (e.g. ``VS-D7-VE`` -> 7, ``i-D7`` -> 7), or 1095 when ``<x>`` starts
      with ``LT``.

    :param token: Token to extract from.
    :return: Time interval in days.
    :raises ValueError: If the token is invalid. When parsing failed part-way, the
        original error is attached as ``__cause__``.
    """
    long_term_days = 365 * 3
    days_per_unit = {"D": 1, "W": 7, "M": 30, "Y": 365}
    try:
        unit = token[0]
        if unit in days_per_unit:
            return int(token[1:]) * days_per_unit[unit]
        if token == "LT":
            return long_term_days
        if token[:3] == "VS-" or token[:2] == "i-":  # VS-D7-VE / i-D7
            part = token.split("-")[1]
            if part.startswith("LT"):
                return long_term_days
            return int(part[1:])
    except Exception as exc:
        # Chain the parsing error so the root cause stays visible in tracebacks.
        raise ValueError(f"Invalid time token: {token}") from exc
    raise ValueError(f"Invalid time token: {token}")
|
299
|
+
|
300
|
+
|
301
|
+
def convert_time_interval_to_time_tuple(
    time_interval: int, is_inpatient: bool
) -> Tuple[str, str, str]:
    """
    Convert a time interval to a tuple of time tokens.

    The interval is decomposed using 365-day years and 30-day months, so the month
    component ranges over 0-12 and the day component over 0-29.

    :param time_interval: Time interval in days; must be non-negative.
    :param is_inpatient: Whether the interval is for an inpatient; selects the
        ``i-day:`` prefix instead of ``day:``.
    :return: Tuple of year, month, and day tokens, e.g.
        ``("year:1", "month:1", "day:5")`` for 400 days.
    :raises ValueError: If ``time_interval`` is negative.
    """
    # An explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which let negative intervals through and produced nonsensical tokens.
    if time_interval < 0:
        raise ValueError(
            f"the time interval must be greater than or equal to zero, got {time_interval}"
        )
    year, remainder = divmod(time_interval, 365)
    month, day = divmod(remainder, 30)
    day_prefix = "i-day" if is_inpatient else "day"
    return f"year:{year}", f"month:{month}", f"{day_prefix}:{day}"
|
319
|
+
|
320
|
+
|
321
|
+
def generate_artificial_time_tokens():
    """
    Generate all the time tokens used in training.

    The vocabulary consists of, in order: day tokens ``D0``-``D1999``, week tokens
    ``W0``-``W3``, month tokens ``M0``-``M11``, and the long-term token ``LT``.

    :return: List of time tokens.
    """
    tokens = [f"D{day}" for day in range(2000)]
    tokens.extend(f"W{week}" for week in range(4))
    tokens.extend(f"M{month}" for month in range(12))
    tokens.append("LT")
    return tokens
|
File without changes
|
cehrgpt/models/config.py
ADDED
@@ -0,0 +1,205 @@
|
|
1
|
+
from typing import Dict, List
|
2
|
+
|
3
|
+
from transformers import PretrainedConfig
|
4
|
+
|
5
|
+
|
6
|
+
class CEHRGPTConfig(PretrainedConfig):
    """
    Configuration for the CEHR-GPT model.

    The transformer architecture parameters use the Hugging Face GPT-2 names.
    ``attribute_map`` exposes the standard aliases ``hidden_size``,
    ``max_position_embeddings``, ``num_attention_heads`` and ``num_hidden_layers``.

    Args:
        vocab_size (`int`, *optional*, defaults to 50257):
            Number of distinct tokens the model can represent.
        time_token_vocab_size (`int`, *optional*, defaults to 50257):
            Size of the time-token vocabulary (presumably used together with
            ``use_sub_time_tokenization``; confirm against the model).
        n_positions (`int`, *optional*, defaults to 1024):
            Maximum sequence length the model can be used with.
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads per layer.
        n_inner (`int`, *optional*):
            Inner feed-forward dimensionality (GPT-2 convention: `None` means
            4 * n_embd).
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function name (GPT-2 options: "relu", "silu", "gelu",
            "tanh", "gelu_new").
        resid_pdrop, embd_pdrop, attn_pdrop (`float`, *optional*, default 0.1):
            Dropout ratios for the fully connected/residual layers, the embeddings
            and the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            Epsilon used by the layer-normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal weight initializer.
        summary_type, summary_use_proj, summary_activation, summary_proj_to_labels,
        summary_first_dropout (*optional*, defaults "cls_index", True, None, True, 0.1):
            GPT-2 style sequence-summary settings.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Whether to scale the attention weights (GPT-2 semantics).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/value attentions.
        bos_token_id, eos_token_id (`int`, *optional*, default 50256):
            Ids of the beginning- and end-of-sequence tokens.
        lab_token_ids (`List[int]`, *optional*):
            Token ids treated as lab tokens; ``lab_token_exists`` reports whether
            this list is non-empty.
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys before the attention dot product and upcast the
            dot-product/softmax to float under mixed precision.
        exclude_position_ids, include_values, value_vocab_size,
        include_ttv_prediction, causal_sfm, demographics_size (defaults False,
        False, None, False, False, 4):
            CEHR-GPT model options stored verbatim on the config; their effect is
            defined by the model implementation.
        use_sub_time_tokenization (`bool`, *optional*, defaults to `True`):
            Sub-time-tokenization switch, stored verbatim.
        token_to_time_token_mapping (`Dict[int, List]`, *optional*):
            Mapping from a token id to its sub-time-token ids. It is stored as
            ``_token_to_time_token_mapping`` and read back through the
            ``token_to_time_token_mapping`` property, which converts keys and values
            to ``int`` because saving the config turns dict keys into strings.
        use_pretrained_embeddings (`bool`, *optional*, defaults to `False`):
            Pretrained-embedding switch. It also forces
            ``tie_word_embeddings = not use_pretrained_embeddings``, overriding any
            value supplied through ``kwargs``.
        n_pretrained_embeddings_layers, pretrained_embedding_dim,
        pretrained_token_ids (defaults 2, 768, empty list):
            Settings for the pretrained embeddings, stored verbatim.
        time_token_loss_weight, time_to_visit_loss_weight, lab_token_penalty,
        lab_token_loss_weight, entropy_penalty, entropy_penalty_alpha (defaults
        1.0, 1.0, False, 0.9, False, 0.01):
            Loss-related options stored verbatim for the model to consume.
    """

    model_type = "cehrgpt"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    @property
    def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
        """Token id -> sub-time-token ids, with every id coerced back to ``int``."""
        # A saved config stores the dict keys as strings; undo that on every read.
        raw_mapping = self._token_to_time_token_mapping
        return {
            int(token): [int(sub_token) for sub_token in sub_tokens]
            for token, sub_tokens in raw_mapping.items()
        }

    def __init__(
        self,
        vocab_size=50257,
        time_token_vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=0.1,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        lab_token_ids=None,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        exclude_position_ids=False,
        include_values=False,
        value_vocab_size=None,
        include_ttv_prediction=False,
        use_sub_time_tokenization=True,
        token_to_time_token_mapping: Dict[int, List] = None,
        use_pretrained_embeddings=False,
        n_pretrained_embeddings_layers=2,
        pretrained_embedding_dim=768,
        pretrained_token_ids: List[int] = None,
        time_token_loss_weight=1.0,
        time_to_visit_loss_weight=1.0,
        causal_sfm=False,
        demographics_size=4,
        lab_token_penalty=False,
        lab_token_loss_weight=0.9,
        entropy_penalty=False,
        entropy_penalty_alpha=0.01,
        **kwargs,
    ):
        # --- GPT-2 style architecture ---
        self.vocab_size = vocab_size
        self.time_token_vocab_size = time_token_vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # --- special tokens ---
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.lab_token_ids = lab_token_ids

        # --- CEHR-GPT input options ---
        self.exclude_position_ids = exclude_position_ids
        self.include_values = include_values
        self.value_vocab_size = value_vocab_size

        # --- time tokens and auxiliary heads ---
        self.include_ttv_prediction = include_ttv_prediction
        self.use_sub_time_tokenization = use_sub_time_tokenization
        self._token_to_time_token_mapping = (
            {} if token_to_time_token_mapping is None else token_to_time_token_mapping
        )
        self.time_token_loss_weight = time_token_loss_weight
        self.time_to_visit_loss_weight = time_to_visit_loss_weight
        self.causal_sfm = causal_sfm
        self.demographics_size = demographics_size

        # --- pretrained embeddings ---
        self.use_pretrained_embeddings = use_pretrained_embeddings
        self.pretrained_embedding_dim = pretrained_embedding_dim
        self.pretrained_token_ids = (
            [] if pretrained_token_ids is None else pretrained_token_ids
        )
        self.n_pretrained_embeddings_layers = n_pretrained_embeddings_layers

        # --- loss shaping ---
        self.lab_token_penalty = lab_token_penalty
        self.lab_token_loss_weight = lab_token_loss_weight
        self.entropy_penalty = entropy_penalty
        self.entropy_penalty_alpha = entropy_penalty_alpha

        # Weight tying follows use_pretrained_embeddings and always overrides any
        # tie_word_embeddings value that arrived through kwargs.
        kwargs["tie_word_embeddings"] = not use_pretrained_embeddings

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def lab_token_exists(self) -> bool:
        """Whether ``lab_token_ids`` is set and non-empty."""
        if self.lab_token_ids is None:
            return False
        return len(self.lab_token_ids) > 0
|