cehrgpt-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. __init__.py +0 -0
  2. cehrgpt/__init__.py +0 -0
  3. cehrgpt/analysis/__init__.py +0 -0
  4. cehrgpt/analysis/privacy/__init__.py +0 -0
  5. cehrgpt/analysis/privacy/attribute_inference.py +275 -0
  6. cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
  7. cehrgpt/analysis/privacy/member_inference.py +172 -0
  8. cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
  9. cehrgpt/analysis/privacy/reid_inference.py +407 -0
  10. cehrgpt/analysis/privacy/utils.py +255 -0
  11. cehrgpt/cehrgpt_args.py +142 -0
  12. cehrgpt/data/__init__.py +0 -0
  13. cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
  14. cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
  15. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
  16. cehrgpt/generation/__init__.py +0 -0
  17. cehrgpt/generation/chatgpt_generation.py +106 -0
  18. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
  19. cehrgpt/generation/omop_converter_batch.py +644 -0
  20. cehrgpt/generation/omop_entity.py +515 -0
  21. cehrgpt/gpt_utils.py +331 -0
  22. cehrgpt/models/__init__.py +0 -0
  23. cehrgpt/models/config.py +205 -0
  24. cehrgpt/models/hf_cehrgpt.py +1817 -0
  25. cehrgpt/models/hf_modeling_outputs.py +158 -0
  26. cehrgpt/models/pretrained_embeddings.py +82 -0
  27. cehrgpt/models/special_tokens.py +30 -0
  28. cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
  29. cehrgpt/omop/__init__.py +0 -0
  30. cehrgpt/omop/condition_era.py +20 -0
  31. cehrgpt/omop/observation_period.py +43 -0
  32. cehrgpt/omop/omop_argparse.py +38 -0
  33. cehrgpt/omop/omop_table_builder.py +86 -0
  34. cehrgpt/omop/queries/__init__.py +0 -0
  35. cehrgpt/omop/queries/condition_era.py +86 -0
  36. cehrgpt/omop/queries/observation_period.py +135 -0
  37. cehrgpt/omop/sample_omop_tables.py +71 -0
  38. cehrgpt/runners/__init__.py +0 -0
  39. cehrgpt/runners/gpt_runner_util.py +99 -0
  40. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
  41. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
  42. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
  43. cehrgpt/runners/hyperparameter_search_util.py +223 -0
  44. cehrgpt/time_to_event/__init__.py +0 -0
  45. cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
  46. cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
  47. cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
  48. cehrgpt/time_to_event/time_to_event_model.py +226 -0
  49. cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
  50. cehrgpt/time_to_event/time_to_event_utils.py +55 -0
  51. cehrgpt/tools/__init__.py +0 -0
  52. cehrgpt/tools/ehrshot_benchmark.py +74 -0
  53. cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
  54. cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
  55. cehrgpt/tools/upload_omop_tables.py +108 -0
  56. cehrgpt-0.0.1.dist-info/LICENSE +21 -0
  57. cehrgpt-0.0.1.dist-info/METADATA +66 -0
  58. cehrgpt-0.0.1.dist-info/RECORD +60 -0
  59. cehrgpt-0.0.1.dist-info/WHEEL +5 -0
  60. cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/gpt_utils.py ADDED
@@ -0,0 +1,331 @@
+ import random
+ import re
+ from datetime import date, timedelta
+ from typing import List, Optional, Sequence, Tuple
+
+ from cehrgpt.cehrgpt_args import SamplingStrategy
+ from cehrgpt.models.special_tokens import (
+     DISCHARGE_CONCEPT_IDS,
+     END_TOKEN,
+     VISIT_CONCEPT_IDS,
+ )
+
+ # Regular expression pattern to match inpatient artificial time tokens, e.g. VS-D7-VE or i-D7
+ INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
+ DEMOGRAPHIC_PROMPT_SIZE = 4
+
+
+ class RandomSampleCache:
+     def __init__(
+         self,
+         data_indices: Sequence[int],
+         cache_size: int,
+         sample_weights: Optional[Sequence[float]] = None,
+     ):
+         """
+         Initialize the RandomSampleCache.
+
+         :param data_indices: Sequence of data indices to sample from.
+         :param cache_size: Size of the cache.
+         :param sample_weights: Optional sequence of weights for sampling.
+         """
+         self._data_indices = data_indices
+         self._sample_weights = sample_weights
+         self._cache_size = cache_size
+         self._cache = []
+
+         if self._sample_weights is not None:
+             assert abs(sum(self._sample_weights) - 1) < 1e-8
+
+     def next(self):
+         """
+         Get the next sample from the cache.
+
+         If the cache is empty, refill it.
+
+         :return: A sampled data index.
+         """
+         if not self._cache:
+             if self._sample_weights is not None:
+                 self._cache.extend(
+                     random.choices(
+                         self._data_indices,
+                         k=self._cache_size,
+                         weights=self._sample_weights,
+                     )
+                 )
+             else:
+                 self._cache.extend(
+                     random.choices(self._data_indices, k=self._cache_size)
+                 )
+         return self._cache.pop()
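
For orientation, a minimal usage sketch (illustrative, not from the package): the cache draws cache_size indices with a single random.choices call and then serves them one at a time, amortizing the sampling cost across many next() calls.

    cache = RandomSampleCache(
        data_indices=[0, 1, 2, 3],
        cache_size=100,
        sample_weights=[0.1, 0.2, 0.3, 0.4],  # must sum to 1
    )
    samples = [cache.next() for _ in range(10)]  # indices drawn according to the weights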
+
+
+ def collect_demographic_prompts_at_visits(patient_history: List[str]):
+     demographic_prompts_at_visits = []
+     start_year, start_age, start_gender, start_race = patient_history[
+         :DEMOGRAPHIC_PROMPT_SIZE
+     ]
+     try:
+         start_year = int(start_year.split(":")[1])
+         start_age = int(start_age.split(":")[1])
+         valid_prompt = True
+     except (IndexError, ValueError):
+         start_year = 1900
+         start_age = 0
+         valid_prompt = False
+     data_cursor = date(start_year, 1, 1)
+     birth_date = date(start_year - start_age, 1, 1)
+     for i, current_token in enumerate(patient_history):
+         if is_visit_start(current_token):
+             reconstructed_year = (
+                 f"year:{data_cursor.year}" if valid_prompt else "year:unknown"
+             )
+             reconstructed_age = (
+                 f"age:{data_cursor.year - birth_date.year}"
+                 if valid_prompt
+                 else "age:unknown"
+             )
+             demographic_prompts_at_visits.append(
+                 (
+                     i,
+                     (
+                         reconstructed_year,
+                         reconstructed_age,
+                         start_gender,
+                         start_race,
+                     ),
+                 )
+             )
+         elif is_att_token(current_token):
+             att_date_delta = extract_time_interval_in_days(current_token)
+             data_cursor = data_cursor + timedelta(days=att_date_delta)
+     return demographic_prompts_at_visits
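
To illustrate (concept IDs and demographic token formats are hypothetical): given a history whose first four tokens are the demographic prompt, followed by two visits separated by a one-month ATT token, the function re-derives the year/age prompt at each visit-start index.

    history = [
        "year:2015", "age:40", "gender:8507", "race:8527",
        "VS", "320128", "VE", "M1", "VS", "201826", "VE",
    ]
    collect_demographic_prompts_at_visits(history)
    # [(4, ("year:2015", "age:40", "gender:8507", "race:8527")),
    #  (8, ("year:2015", "age:40", "gender:8507", "race:8527"))]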
+
+
+ def random_slice_gpt_sequence(concept_ids, max_seq_len):
+     """
+     Randomly slice a GPT sequence.
+
+     :param concept_ids: List of concept IDs.
+     :param max_seq_len: Maximum sequence length.
+     :return: Tuple containing start index, end index, and demographic tokens.
+     """
+     seq_length = len(concept_ids)
+     starting_points = []
+     start_year, start_age, start_gender, start_race = concept_ids[
+         :DEMOGRAPHIC_PROMPT_SIZE
+     ]
+     try:
+         start_year = int(start_year.split(":")[1])
+         start_age = int(start_age.split(":")[1])
+         data_cursor = date(start_year, 1, 1)
+         birth_date = date(start_year - start_age, 1, 1)
+         for i in range(
+             DEMOGRAPHIC_PROMPT_SIZE,
+             min(seq_length, seq_length - max_seq_len + DEMOGRAPHIC_PROMPT_SIZE),
+         ):
+             current_token = concept_ids[i]
+             if is_visit_start(current_token):
+                 starting_points.append(
+                     (i, data_cursor.year, data_cursor.year - birth_date.year)
+                 )
+             elif is_att_token(current_token):
+                 att_date_delta = extract_time_interval_in_days(current_token)
+                 data_cursor = data_cursor + timedelta(days=att_date_delta)
+
+         if len(starting_points) == 0:
+             return 0, 0, concept_ids[:DEMOGRAPHIC_PROMPT_SIZE]
+
+         random_starting_index, random_starting_year, random_starting_age = (
+             random.choice(starting_points)
+         )
+         demographic_tokens = [
+             f"year:{random_starting_year}",
+             f"age:{random_starting_age}",
+             start_gender,
+             start_race,
+         ]
+         # Reserve room for the demographic tokens, then walk backwards so the
+         # slice ends at the last visit-end token that fits within max_seq_len
+         random_end_index = random_starting_index
+         for i in reversed(
+             range(
+                 random_starting_index,
+                 random_starting_index + max_seq_len - DEMOGRAPHIC_PROMPT_SIZE,
+             )
+         ):
+             current_token = concept_ids[i]
+             if current_token == "VE":
+                 random_end_index = i
+                 break
+         return random_starting_index, random_end_index, demographic_tokens
+
+     except Exception:
+         return 0, max_seq_len - 1, []
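
A sketch of the intended call pattern (variable names hypothetical): when a patient sequence exceeds the model context, the slice is prepended with the re-dated demographic prompt so the training window stays internally consistent.

    start, end, demo = random_slice_gpt_sequence(concept_ids, max_seq_len=512)
    training_sequence = demo + concept_ids[start : end + 1]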
+
+
+ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
+     if args.sampling_strategy == SamplingStrategy.TopKStrategy.value:
+         folder_name = f"top_k{args.top_k}"
+         args.top_p = 1.0
+     elif args.sampling_strategy == SamplingStrategy.TopPStrategy.value:
+         folder_name = f"top_p{int(args.top_p * 10000)}"
+         args.top_k = cehrgpt_tokenizer.vocab_size
+     elif args.sampling_strategy == SamplingStrategy.TopMixStrategy.value:
+         folder_name = f"top_mix_p{int(args.top_p * 10000)}_k{args.top_k}"
+     else:
+         raise RuntimeError(
+             "sampling_strategy has to be one of the following three options: "
+             "[TopKStrategy, TopPStrategy, TopMixStrategy]"
+         )
+     if args.temperature != 1.0:
+         folder_name = f"{folder_name}_temp_{int(args.temperature * 10000)}"
+     if args.repetition_penalty != 1.0:
+         folder_name = (
+             f"{folder_name}_repetition_penalty_{int(args.repetition_penalty * 10000)}"
+         )
+     if args.num_beams > 1:
+         folder_name = f"{folder_name}_num_beams_{int(args.num_beams)}"
+     if args.num_beam_groups > 1:
+         folder_name = f"{folder_name}_num_beam_groups_{int(args.num_beam_groups)}"
+     if args.epsilon_cutoff > 0.0:
+         folder_name = (
+             f"{folder_name}_epsilon_cutoff_{int(args.epsilon_cutoff * 100000)}"
+         )
+     return folder_name
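
For example (hypothetical argparse namespace; assumes each SamplingStrategy enum value equals its member name): float hyperparameters are encoded as fixed-point integers to keep folder names filesystem-safe.

    from types import SimpleNamespace

    args = SimpleNamespace(
        sampling_strategy="TopPStrategy", top_p=0.95, top_k=0,
        temperature=1.5, repetition_penalty=1.0,
        num_beams=1, num_beam_groups=1, epsilon_cutoff=0.0,
    )
    tokenizer = SimpleNamespace(vocab_size=30000)  # stand-in for the real tokenizer
    get_cehrgpt_output_folder(args, tokenizer)  # -> "top_p9500_temp_15000"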
+
+
+ def is_clinical_event(token: str) -> bool:
+     return token.isnumeric()
+
+
+ def is_visit_start(token: str):
+     """
+     Check if the token indicates the start of a visit.
+
+     :param token: Token to check.
+     :return: True if the token is a visit start token, False otherwise.
+     """
+     return token in ["VS", "[VS]"]
+
+
+ def is_visit_end(token: str) -> bool:
+     return token in ["VE", "[VE]"]
+
+
+ def is_att_token(token: str):
+     """
+     Check if the token is an artificial time token (ATT) representing a time interval.
+
+     :param token: Token to check.
+     :return: True if the token is an ATT token, False otherwise.
+     """
+     if bool(re.match(r"^D\d+", token)):  # day tokens
+         return True
+     elif bool(re.match(r"^W\d+", token)):  # week tokens
+         return True
+     elif bool(re.match(r"^M\d+", token)):  # month tokens
+         return True
+     elif bool(re.match(r"^Y\d+", token)):  # year tokens
+         return True
+     elif token == "LT":  # long-term token
+         return True
+     elif token[:3] == "VS-":  # inpatient ATT spanning a visit, e.g. VS-D7-VE
+         return True
+     elif token[:2] == "i-" and not token.startswith("i-H"):  # i-D7, excluding hour tokens
+         return True
+     return False
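
A few illustrative calls (token values chosen to exercise the main branches):

    is_att_token("D7")           # True: 7-day interval token
    is_att_token("VS-D7-VE")     # True: inpatient interval spanning a visit
    is_att_token("i-H12")        # False: hour tokens are excluded
    is_visit_start("[VS]")       # True
    is_clinical_event("320128")  # True: purely numeric tokens are concept IDs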
+
+
+ def is_artificial_token(token: str) -> bool:
+     if token in VISIT_CONCEPT_IDS:
+         return True
+     if token in DISCHARGE_CONCEPT_IDS:
+         return True
+     if is_visit_start(token):
+         return True
+     if is_visit_end(token):
+         return True
+     if is_att_token(token):
+         return True
+     if token == END_TOKEN:
+         return True
+     return False
+
+
+ def is_inpatient_att_token(token: str) -> bool:
+     """
+     Check if the token is an inpatient ATT token.
+
+     :param token: Token to check.
+     :return: True if the token is an inpatient ATT token, False otherwise.
+     """
+     return bool(INPATIENT_ATT_PATTERN.match(token))
+
+
+ def extract_time_interval_in_days(token: str) -> int:
+     """
+     Extract the time interval in days from a token.
+
+     :param token: Token to extract from.
+     :return: Time interval in days.
+     :raises ValueError: If the token is invalid.
+     """
+     try:
+         if token[0] == "D":  # day tokens
+             return int(token[1:])
+         elif token[0] == "W":  # week tokens
+             return int(token[1:]) * 7
+         elif token[0] == "M":  # month tokens
+             return int(token[1:]) * 30
+         elif token[0] == "Y":  # year tokens
+             return int(token[1:]) * 365
+         elif token == "LT":
+             return 365 * 3
+         elif token[:3] == "VS-":  # VS-D7-VE
+             part = token.split("-")[1]
+             if part.startswith("LT"):
+                 return 365 * 3
+             return int(part[1:])
+         elif token[:2] == "i-":  # i-D7
+             part = token.split("-")[1]
+             if part.startswith("LT"):
+                 return 365 * 3
+             return int(part[1:])
+     except Exception:
+         raise ValueError(f"Invalid time token: {token}")
+     raise ValueError(f"Invalid time token: {token}")
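
Worked conversions, following the branches above: weeks count as 7 days, months are approximated as 30 days, years as 365, and LT (long-term) is fixed at three years.

    extract_time_interval_in_days("D7")        # 7
    extract_time_interval_in_days("W2")        # 14
    extract_time_interval_in_days("M3")        # 90
    extract_time_interval_in_days("VS-D7-VE")  # 7
    extract_time_interval_in_days("LT")        # 1095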
+
+
+ def convert_time_interval_to_time_tuple(
+     time_interval: int, is_inpatient: bool
+ ) -> Tuple[str, str, str]:
+     """
+     Convert a time interval to a tuple of time tokens.
+
+     :param time_interval: Time interval in days.
+     :param is_inpatient: Whether the interval is for an inpatient.
+     :return: Tuple of year, month, and day tokens.
+     """
+     assert time_interval >= 0, "the time interval must be greater than or equal to zero"
+     year = time_interval // 365
+     month = time_interval % 365 // 30
+     day = time_interval % 365 % 30
+     year_token = f"year:{year}"
+     month_token = f"month:{month}"
+     day_token = f"i-day:{day}" if is_inpatient else f"day:{day}"
+     return year_token, month_token, day_token
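
For instance, 400 days decomposes as 400 // 365 = 1 year, (400 % 365) // 30 = 1 month, and 35 % 30 = 5 days:

    convert_time_interval_to_time_tuple(400, is_inpatient=False)
    # ("year:1", "month:1", "day:5")
    convert_time_interval_to_time_tuple(400, is_inpatient=True)
    # ("year:1", "month:1", "i-day:5")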
+
+
+ def generate_artificial_time_tokens():
+     """
+     Generate all the time tokens used in training.
+
+     :return: List of time tokens.
+     """
+     day_tokens = [f"D{i}" for i in range(2000)]
+     week_tokens = [f"W{i}" for i in range(4)]
+     month_tokens = [f"M{i}" for i in range(12)]
+     long_term_tokens = ["LT"]
+     return day_tokens + week_tokens + month_tokens + long_term_tokens
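
The generated vocabulary therefore contains 2000 + 4 + 12 + 1 = 2017 artificial time tokens: D0-D1999, W0-W3, M0-M11, and LT.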
cehrgpt/models/config.py ADDED
@@ -0,0 +1,205 @@
+ from typing import Dict, List, Optional
+
+ from transformers import PretrainedConfig
+
+
+ class CEHRGPTConfig(PretrainedConfig):
+     """
+     Args:
+         vocab_size (`int`, *optional*, defaults to 50257):
+             Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by
+             the `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
+         n_positions (`int`, *optional*, defaults to 1024):
+             The maximum sequence length that this model might ever be used with. Typically set this to something
+             large just in case (e.g., 512 or 1024 or 2048).
+         n_embd (`int`, *optional*, defaults to 768):
+             Dimensionality of the embeddings and hidden states.
+         n_layer (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         n_head (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         n_inner (`int`, *optional*):
+             Dimensionality of the inner feed-forward layers. `None` will set it to 4 times `n_embd`.
+         activation_function (`str`, *optional*, defaults to `"gelu_new"`):
+             Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
+         resid_pdrop (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         embd_pdrop (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the embeddings.
+         attn_pdrop (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the attention.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
+             The epsilon to use in the layer normalization layers.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         summary_type (`string`, *optional*, defaults to `"cls_index"`):
+             Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+             [`TFGPT2DoubleHeadsModel`].
+
+             Has to be one of the following options:
+
+             - `"last"`: Take the last token hidden state (like XLNet).
+             - `"first"`: Take the first token hidden state (like BERT).
+             - `"mean"`: Take the mean of all tokens hidden states.
+             - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
+             - `"attn"`: Not implemented now, use multi-head attention.
+         summary_use_proj (`bool`, *optional*, defaults to `True`):
+             Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+             [`TFGPT2DoubleHeadsModel`].
+
+             Whether or not to add a projection after the vector extraction.
+         summary_activation (`str`, *optional*):
+             Argument used when doing sequence summary. Used for the multiple choice head in
+             [`GPT2DoubleHeadsModel`].
+
+             Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
+         summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
+             Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+             [`TFGPT2DoubleHeadsModel`].
+
+             Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
+         summary_first_dropout (`float`, *optional*, defaults to 0.1):
+             Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
+             [`TFGPT2DoubleHeadsModel`].
+
+             The dropout ratio to be used after the projection and activation.
+         scale_attn_weights (`bool`, *optional*, defaults to `True`):
+             Scale attention weights by dividing by sqrt(hidden_size).
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+         bos_token_id (`int`, *optional*, defaults to 50256):
+             Id of the beginning of sentence token in the vocabulary.
+         eos_token_id (`int`, *optional*, defaults to 50256):
+             Id of the end of sentence token in the vocabulary.
+         scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
+             Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
+         reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
+             Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
+             dot-product/softmax to float() when training with mixed precision.
+     """
+
+     model_type = "cehrgpt"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {
+         "hidden_size": "n_embd",
+         "max_position_embeddings": "n_positions",
+         "num_attention_heads": "n_head",
+         "num_hidden_layers": "n_layer",
+     }
+
+     @property
+     def token_to_time_token_mapping(self) -> Dict[int, List[int]]:
+         # Serialization stores the keys of _token_to_time_token_mapping as strings,
+         # so convert them back to int here
+         return {
+             int(token): list(map(int, sub_tokens))
+             for token, sub_tokens in self._token_to_time_token_mapping.items()
+         }
+
+     def __init__(
+         self,
+         vocab_size=50257,
+         time_token_vocab_size=50257,
+         n_positions=1024,
+         n_embd=768,
+         n_layer=12,
+         n_head=12,
+         n_inner=None,
+         activation_function="gelu_new",
+         resid_pdrop=0.1,
+         embd_pdrop=0.1,
+         attn_pdrop=0.1,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         summary_type="cls_index",
+         summary_use_proj=True,
+         summary_activation=None,
+         summary_proj_to_labels=True,
+         summary_first_dropout=0.1,
+         scale_attn_weights=True,
+         use_cache=True,
+         bos_token_id=50256,
+         eos_token_id=50256,
+         lab_token_ids=None,
+         scale_attn_by_inverse_layer_idx=False,
+         reorder_and_upcast_attn=False,
+         exclude_position_ids=False,
+         include_values=False,
+         value_vocab_size=None,
+         include_ttv_prediction=False,
+         use_sub_time_tokenization=True,
+         token_to_time_token_mapping: Optional[Dict[int, List[int]]] = None,
+         use_pretrained_embeddings=False,
+         n_pretrained_embeddings_layers=2,
+         pretrained_embedding_dim=768,
+         pretrained_token_ids: Optional[List[int]] = None,
+         time_token_loss_weight=1.0,
+         time_to_visit_loss_weight=1.0,
+         causal_sfm=False,
+         demographics_size=4,
+         lab_token_penalty=False,
+         lab_token_loss_weight=0.9,
+         entropy_penalty=False,
+         entropy_penalty_alpha=0.01,
+         **kwargs,
+     ):
+         if token_to_time_token_mapping is None:
+             token_to_time_token_mapping = {}
+         if pretrained_token_ids is None:
+             pretrained_token_ids = []
+         self.vocab_size = vocab_size
+         self.time_token_vocab_size = time_token_vocab_size
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_inner = n_inner
+         self.activation_function = activation_function
+         self.resid_pdrop = resid_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.attn_pdrop = attn_pdrop
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+         self.summary_type = summary_type
+         self.summary_use_proj = summary_use_proj
+         self.summary_activation = summary_activation
+         self.summary_first_dropout = summary_first_dropout
+         self.summary_proj_to_labels = summary_proj_to_labels
+         self.scale_attn_weights = scale_attn_weights
+         self.use_cache = use_cache
+         self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
+         self.reorder_and_upcast_attn = reorder_and_upcast_attn
+
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.lab_token_ids = lab_token_ids
+
+         self.exclude_position_ids = exclude_position_ids
+         self.include_values = include_values
+         self.value_vocab_size = value_vocab_size
+
+         self.include_ttv_prediction = include_ttv_prediction
+         self.use_sub_time_tokenization = use_sub_time_tokenization
+         self._token_to_time_token_mapping = token_to_time_token_mapping
+         self.time_token_loss_weight = time_token_loss_weight
+         self.time_to_visit_loss_weight = time_to_visit_loss_weight
+         self.causal_sfm = causal_sfm
+         self.demographics_size = demographics_size
+         self.use_pretrained_embeddings = use_pretrained_embeddings
+         self.pretrained_embedding_dim = pretrained_embedding_dim
+         self.pretrained_token_ids = pretrained_token_ids
+         self.n_pretrained_embeddings_layers = n_pretrained_embeddings_layers
+
+         self.lab_token_penalty = lab_token_penalty
+         self.lab_token_loss_weight = lab_token_loss_weight
+         self.entropy_penalty = entropy_penalty
+         self.entropy_penalty_alpha = entropy_penalty_alpha
+
+         # Word embeddings are only tied to the output layer when no pretrained
+         # embeddings are used
+         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
+
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+     @property
+     def lab_token_exists(self) -> bool:
+         return self.lab_token_ids is not None and len(self.lab_token_ids) > 0
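
As a closing sketch (all hyperparameter values invented for illustration), the CEHR-GPT extensions over GPT-2 are driven by the extra constructor flags; note that token_to_time_token_mapping keys survive a save/load round trip as strings and are restored to int by the property above.

    config = CEHRGPTConfig(
        vocab_size=20000,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        use_sub_time_tokenization=True,
        token_to_time_token_mapping={5: [101, 102, 103]},  # hypothetical ATT token id -> sub-token ids
        lab_token_penalty=True,
        lab_token_loss_weight=0.9,
    )
    config.token_to_time_token_mapping  # {5: [101, 102, 103]}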