cehrgpt 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/htn_treatment_pathway.py +546 -0
- cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
- cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
- cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
- cehrgpt/data/cehrgpt_data_processor.py +549 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
- cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
- cehrgpt/generation/omop_converter_batch.py +11 -4
- cehrgpt/gpt_utils.py +73 -3
- cehrgpt/models/activations.py +27 -0
- cehrgpt/models/config.py +6 -2
- cehrgpt/models/gpt2.py +560 -0
- cehrgpt/models/hf_cehrgpt.py +183 -460
- cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
- cehrgpt/omop/ontology.py +154 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
- cehrgpt/runners/hyperparameter_search_util.py +180 -69
- cehrgpt/runners/sample_packing_trainer.py +11 -2
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
- cehrgpt-0.1.3.dist-info/METADATA +238 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +32 -22
- cehrgpt-0.1.2.dist-info/METADATA +0 -209
- /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
CHANGED
@@ -28,6 +28,12 @@ from datasets.formatting.formatting import LazyBatch
 from dateutil.relativedelta import relativedelta
 from pandas import Series
 
+from cehrgpt.gpt_utils import (
+    construct_age_sequence,
+    construct_time_sequence,
+    encode_demographics,
+    multiple_of_10,
+)
 from cehrgpt.models.tokenization_hf_cehrgpt import (
     NONE_BIN,
     UNKNOWN_BIN,
@@ -43,6 +49,7 @@ CEHRGPT_COLUMNS = [
     "concept_values",
     "units",
     "epoch_times",
+    "ages",
 ]
 
 
@@ -121,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
         cehrgpt_record: Dict[str, Any],
         code: str,
         time: datetime.datetime,
+        age: int,
         concept_value_mask: int = 0,
         number_as_value: float = 0.0,
         concept_as_value: str = "0",
@@ -128,6 +136,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
         unit: str = NA,
     ) -> None:
         cehrgpt_record["concept_ids"].append(replace_escape_chars(code))
+        cehrgpt_record["ages"].append(age)
         cehrgpt_record["concept_value_masks"].append(concept_value_mask)
         cehrgpt_record["number_as_values"].append(number_as_value)
         cehrgpt_record["concept_as_values"].append(concept_as_value)
@@ -141,6 +150,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
         cehrgpt_record = {
             "person_id": record["patient_id"],
             "concept_ids": [],
+            "ages": [],
             "concept_value_masks": [],
             "number_as_values": [],
             "concept_as_values": [],
@@ -168,14 +178,21 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
         first_visit_start_datetime: datetime.datetime = get_value(
             first_visit, "visit_start_datetime"
         )
+        starting_age = relativedelta(first_visit_start_datetime, birth_datetime).years
         year_str = f"year:{str(first_visit_start_datetime.year)}"
-        age_str = f"age:{
+        age_str = f"age:{starting_age}"
+        self._update_cehrgpt_record(
+            cehrgpt_record, year_str, first_visit_start_datetime, starting_age
+        )
+        self._update_cehrgpt_record(
+            cehrgpt_record, age_str, first_visit_start_datetime, starting_age
+        )
         self._update_cehrgpt_record(
-            cehrgpt_record,
+            cehrgpt_record, gender, first_visit_start_datetime, starting_age
+        )
+        self._update_cehrgpt_record(
+            cehrgpt_record, race, first_visit_start_datetime, starting_age
         )
-        self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
-        self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
-        self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
 
         # Use a data cursor to keep track of time
         datetime_cursor: Optional[datetime.datetime] = None
@@ -211,6 +228,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                     cehrgpt_record,
                     code=self._time_token_function(time_delta),
                     time=visit_start_datetime,
+                    age=relativedelta(datetime_cursor, birth_datetime).years,
                 )
 
             datetime_cursor = visit_start_datetime
@@ -219,12 +237,14 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                 cehrgpt_record,
                 code="[VS]",
                 time=datetime_cursor,
+                age=relativedelta(datetime_cursor, birth_datetime).years,
             )
             # Add a visit type token
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code=visit_type,
                 time=datetime_cursor,
+                age=relativedelta(datetime_cursor, birth_datetime).years,
             )
             # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
             # with respect to the midnight of the day
@@ -235,6 +255,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                     cehrgpt_record,
                     code=f"i-H{datetime_cursor.hour}",
                     time=datetime_cursor,
+                    age=relativedelta(datetime_cursor, birth_datetime).years,
                 )
 
             # Keep track of the existing outpatient events, we don't want to add them again
@@ -281,6 +302,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                         cehrgpt_record,
                         code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
                         time=event_time,
+                        age=relativedelta(event_time, birth_datetime).years,
                     )
 
                     if self._include_inpatient_hour_token:
@@ -300,6 +322,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                             cehrgpt_record,
                             code=f"i-H{time_diff_hours}",
                             time=event_time,
+                            age=relativedelta(event_time, birth_datetime).years,
                         )
 
                 if event_identity in existing_duplicate_events:
@@ -309,6 +332,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                     cehrgpt_record,
                     code=code,
                     time=event_time,
+                    age=relativedelta(event_time, birth_datetime).years,
                     concept_value_mask=concept_value_mask,
                     unit=unit,
                     number_as_value=numeric_value if numeric_value else 0.0,
@@ -348,6 +372,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                     cehrgpt_record,
                     code=discharge_facility,
                     time=datetime_cursor,
+                    age=relativedelta(datetime_cursor, birth_datetime).years,
                 )
 
             # Reuse the age and date calculated for the last event in the patient timeline
@@ -355,6 +380,7 @@ class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
                 cehrgpt_record,
                 code="[VE]",
                 time=datetime_cursor,
+                age=relativedelta(datetime_cursor, birth_datetime).years,
            )
 
         # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -428,6 +454,13 @@ class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
         return record
 
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        # Reconstruct the ages input before the filter is applied
+        record["ages"] = construct_age_sequence(
+            record["concept_ids"], record.get("ages", None)
+        )
+        record["epoch_times"] = construct_time_sequence(
+            record["concept_ids"], record.get("epoch_times", None)
+        )
         # Remove the tokens from patient sequences that do not exist in the tokenizer
         record = self.filter_out_invalid_tokens(record)
         # If any concept has a value associated with it, we normalize the value
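Note: every `_update_cehrgpt_record` call above now threads an `age` argument derived from the event timestamp and the patient's `birth_datetime` via `dateutil.relativedelta`. A minimal sketch of that computation (the dates and the record dict are made up for illustration):

    import datetime
    from dateutil.relativedelta import relativedelta

    birth_datetime = datetime.datetime(1960, 5, 1)
    event_time = datetime.datetime(2021, 3, 15)

    record = {"concept_ids": [], "ages": []}
    record["concept_ids"].append("[VS]")
    record["ages"].append(relativedelta(event_time, birth_datetime).years)  # 60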
cehrgpt/generation/cehrgpt_conditional_generation.py
CHANGED
@@ -77,6 +77,7 @@ def generate_trajectories_per_batch(
     prediction_times = batch["index_date"].squeeze().detach().cpu().tolist()
     batched_epoch_times = batch["epoch_times"].detach().cpu().tolist()
     batched_input_ids = batch["input_ids"]
+    batched_ages = batch["ages"]
     batched_value_indicators = batch["value_indicators"]
     batched_values = batch["values"]
     # Make sure the batch does not exceed batch_size
@@ -84,6 +85,7 @@ def generate_trajectories_per_batch(
         cehrgpt_model,
         cehrgpt_tokenizer,
         batched_input_ids,
+        ages=batched_ages,
         values=batched_values,
         value_indicators=batched_value_indicators,
         max_length=max_length,
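Note: these two hunks only thread the collated `ages` feature through to batched generation. A condensed sketch of the flow, with placeholder tensors standing in for the data collator's output:

    import torch

    batch = {  # placeholder for a collated batch
        "input_ids": torch.tensor([[101, 102, 103]]),
        "ages": torch.tensor([[60, 60, 61]]),
    }
    batched_input_ids = batch["input_ids"]
    batched_ages = batch["ages"]
    # ...then forwarded as generate_single_batch(..., batched_input_ids, ages=batched_ages, ...)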
cehrgpt/generation/generate_batch_hf_gpt_sequence.py
CHANGED
@@ -2,7 +2,7 @@ import datetime
 import os
 import random
 import uuid
-from typing import Any, Dict,
+from typing import Any, Dict, Optional, Sequence, Tuple
 
 import numpy as np
 import pandas as pd
@@ -13,7 +13,7 @@ from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
 from cehrgpt.generation.omop_converter_batch import START_TOKEN_SIZE
-from cehrgpt.gpt_utils import get_cehrgpt_output_folder
+from cehrgpt.gpt_utils import construct_age_sequence, get_cehrgpt_output_folder
 from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
 from cehrgpt.models.special_tokens import END_TOKEN
 from cehrgpt.models.tokenization_hf_cehrgpt import (
@@ -72,9 +72,10 @@ def normalize_value(
 
 def generate_single_batch(
     model: CEHRGPT2LMHeadModel,
-
-    prompts:
+    cehrgpt_tokenizer: CehrGptTokenizer,
+    prompts: torch.Tensor,
     max_length: int,
+    ages: Optional[torch.Tensor] = None,
     values: Optional[torch.Tensor] = None,
     value_indicators: Optional[torch.Tensor] = None,
     max_new_tokens: Optional[int] = None,
@@ -112,7 +113,9 @@
         epsilon_cutoff=epsilon_cutoff,
     )
 
-    batched_prompts =
+    batched_prompts = prompts.to(device)
+    if ages is not None:
+        ages = ages.to(device)
     if values is not None:
         values = values.to(device)
     if value_indicators is not None:
@@ -120,19 +123,22 @@
 
     results = model.generate(
         inputs=batched_prompts,
+        ages=ages,
         values=values,
         value_indicators=value_indicators,
         generation_config=generation_config,
-
+        cehrgpt_tokenizer=cehrgpt_tokenizer,
     )
 
     sequences = [
-
+        cehrgpt_tokenizer.decode(seq.cpu().numpy(), skip_special_tokens=False)
        for seq in results.sequences
     ]
     if results.sequence_vals is not None:
         values = [
-
+            cehrgpt_tokenizer.decode_value(
+                values.cpu().numpy(), skip_special_tokens=False
+            )
             for values in results.sequence_vals
         ]
     else:
@@ -214,6 +220,7 @@ def main(args):
 
     # Randomly pick demographics from the existing population
     random_prompts = []
+    random_prompt_ages = []
     iter = 0
     while len(random_prompts) < args.batch_size:
         for row in dataset.select(
@@ -224,9 +231,9 @@
                 <= len(row["concept_ids"])
                 <= max_seq_allowed
             ):
-
-
-                )
+                prompt = row["concept_ids"][:prompt_size]
+                random_prompts.append(cehrgpt_tokenizer.encode(prompt))
+                random_prompt_ages.append(construct_age_sequence(prompt))
             iter += 1
         if not random_prompts and iter > 10:
             raise RuntimeError(
@@ -237,7 +244,8 @@
         batch_sequences = generate_single_batch(
             cehrgpt_model,
             cehrgpt_tokenizer,
-            random_prompts[: args.batch_size],
+            torch.tensor(random_prompts[: args.batch_size]),
+            ages=torch.tensor(random_prompt_ages[: args.batch_size]),
             max_length=args.context_window,
             mini_num_of_concepts=args.min_num_of_concepts,
             top_p=args.top_p,
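Note: the prompt-assembly loop now pairs each encoded demographic prompt with an age sequence rebuilt from the raw tokens. A hedged sketch of the idea (the token strings are illustrative; real prompts come from the sampled dataset rows):

    from cehrgpt.gpt_utils import construct_age_sequence

    # A demographic prompt: start year, start age, gender, race.
    prompt = ["year:2010", "age:45", "Gender/F", "Race/W"]
    ages = construct_age_sequence(prompt)
    # ages == [45, 45, 45, 45]: with no time-interval (ATT) tokens, the age never advances.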
cehrgpt/generation/omop_converter_batch.py
CHANGED
@@ -270,20 +270,24 @@ def gpt_to_omop_converter_batch(
 
     is_numeric_types = (
         is_numeric_types[START_TOKEN_SIZE:]
-        if is_numeric_types is not None
+        if is_numeric_types is not None and not np.all(pd.isna(is_numeric_types))
         else None
     )
     number_as_values = (
         number_as_values[START_TOKEN_SIZE:]
-        if number_as_values is not None
+        if number_as_values is not None and not np.all(pd.isna(number_as_values))
         else None
     )
     concept_as_values = (
         concept_as_values[START_TOKEN_SIZE:]
-        if concept_as_values is not None
+        if concept_as_values is not None and not np.all(pd.isna(concept_as_values))
+        else None
+    )
+    units = (
+        units[START_TOKEN_SIZE:]
+        if units is not None and not np.all(pd.isna(units))
         else None
     )
-    units = units[START_TOKEN_SIZE:] if units is not None else None
 
     # TODO:Need to decode if the input is tokenized
     [start_year, start_age, start_gender, start_race] = concept_ids[
@@ -441,6 +445,9 @@ def gpt_to_omop_converter_batch(
         ]:
             # If it's a start token, skip it
             pass
+        elif event.endswith("/0"):
+            # This should capture the concept such as Visit/0, Discharge/0
+            pass
         else:
             try:
                 concept_id = int(event)
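Note: the new `np.all(pd.isna(...))` guard makes a column that exists but is entirely NaN/None behave like a missing column. A quick illustration:

    import numpy as np
    import pandas as pd

    units = np.array([None, None, None], dtype=object)
    np.all(pd.isna(units))  # True -> the converter now falls back to None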
cehrgpt/gpt_utils.py
CHANGED
@@ -1,7 +1,12 @@
 import random
 import re
-from datetime import date, timedelta
-from typing import List, Sequence, Tuple
+from datetime import date, datetime, timedelta, timezone
+from typing import List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from cehrbert_data.const.artificial_tokens import DEATH_TOKEN
+from meds import death_code
+from transformers.utils import logging
 
 from cehrgpt.cehrgpt_args import SamplingStrategy
 from cehrgpt.models.special_tokens import (
@@ -14,6 +19,7 @@ from cehrgpt.models.special_tokens import (
 MEDS_CODE_PATTERN = re.compile(r".*/.*")
 INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
 DEMOGRAPHIC_PROMPT_SIZE = 4
+logger = logging.get_logger("transformers")
 
 
 class RandomSampleCache:
@@ -62,6 +68,68 @@
         return self._cache.pop()
 
 
+def construct_time_sequence(
+    concept_ids: List[str], epoch_times: Optional[List[Union[int, float]]] = None
+) -> List[float]:
+    if epoch_times is not None:
+        return epoch_times
+
+    if concept_ids[0].lower().startswith("year"):
+        year_str = concept_ids[0].split(":")[1]
+    else:
+        year_str = "1985"
+
+    datetime_cursor = datetime(
+        int(year_str), month=1, day=1, hour=0, minute=0, second=0
+    ).replace(tzinfo=timezone.utc)
+    epoch_times = []
+    for concept_id in concept_ids:
+        if is_att_token(concept_id):
+            att_days = extract_time_interval_in_days(concept_id)
+            datetime_cursor += timedelta(days=att_days)
+        epoch_times.append(datetime_cursor.timestamp())
+    return epoch_times
+
+
+def construct_age_sequence(
+    concept_ids: List[str], ages: Optional[List[int]] = None
+) -> List[int]:
+    if ages is not None:
+        return ages
+    elif concept_ids[1].lower().startswith("age"):
+        age_str = concept_ids[1].split(":")[1]
+        assert age_str.isnumeric(), f"age_str: {age_str}"
+        ages = []
+        time_delta = 0
+        for concept_id in concept_ids:
+            if is_att_token(concept_id):
+                time_delta += extract_time_interval_in_days(concept_id)
+            ages.append(int(age_str) + time_delta // 365)
+        return ages
+    else:
+        logger.warning(
+            "The second token is not a valid age token. The first 4 tokens are: %s. "
+            "Trying to fall back to ages, but it is not valid either %s. "
+            "Fall back to a zero vector [0, 0, 0, ...., 0]",
+            concept_ids[:4],
+            ages,
+        )
+        return np.zeros_like(concept_ids, dtype=int).tolist()
+
+
+def multiple_of_10(n: int) -> int:
+    return ((n // 10) + 1) * 10
+
+
+def encode_demographics(
+    age: int, gender: int, race: int, max_age=200, max_gender=10, max_race=10
+) -> int:
+    assert 0 <= age < max_age, f"age: {age}"
+    assert 0 <= gender < max_gender, f"gender: {gender}"
+    assert 0 <= race < max_race, f"race: {race}"
+    return age + max_age * gender + max_age * max_gender * race
+
+
 def collect_demographic_prompts_at_visits(patient_history: List[str]):
     demographic_prompts_at_visits = []
     start_year, start_age, start_gender, start_race = patient_history[
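To make the new helpers concrete, a small worked example (it assumes `is_att_token` recognizes the day-interval token "D365"; the clinical token strings are illustrative):

    from cehrgpt.gpt_utils import (
        construct_age_sequence,
        encode_demographics,
        multiple_of_10,
    )

    tokens = ["year:2010", "age:45", "Gender/F", "Race/W", "[VS]", "320128", "[VE]",
              "D365", "[VS]", "320128", "[VE]"]
    construct_age_sequence(tokens)
    # -> [45, 45, 45, 45, 45, 45, 45, 46, 46, 46, 46]
    # The 365-day interval token bumps the running age by 365 // 365 = 1 year.

    # encode_demographics packs (age, gender, race) into one integer:
    # age + 200 * gender + 200 * 10 * race
    encode_demographics(age=45, gender=2, race=1)  # 45 + 400 + 2000 = 2445

    multiple_of_10(45)  # 50: the next multiple of 10 strictly above n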
@@ -156,7 +224,7 @@ def random_slice_gpt_sequence(concept_ids, max_seq_len):
         )
     ):
         current_token = concept_ids[i]
-        if current_token
+        if is_visit_end(current_token):
             random_end_index = i
             break
     return random_starting_index, random_end_index, demographic_tokens
@@ -198,6 +266,8 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
 def is_clinical_event(token: str, meds: bool = False) -> bool:
     if token.isnumeric():
         return True
+    if token in [DEATH_TOKEN, death_code]:
+        return True
     if meds:
         return bool(MEDS_CODE_PATTERN.match(token))
     return False
cehrgpt/models/activations.py
ADDED
@@ -0,0 +1,27 @@
+# From https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py
+# coding=utf-8
+
+from __future__ import absolute_import, division, print_function
+
+import torch
+import torch.nn as nn
+import transformers.pytorch_utils
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """MistralRMSNorm is equivalent to T5LayerNorm."""
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+transformers.pytorch_utils.ALL_LAYERNORM_LAYERS.extend([RMSNorm])
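Usage is drop-in for a LayerNorm over the hidden dimension; a minimal check:

    import torch
    from cehrgpt.models.activations import RMSNorm

    norm = RMSNorm(hidden_size=768, eps=1e-6)
    x = torch.randn(2, 16, 768)  # (batch, seq, hidden)
    y = norm(x)  # variance computed in float32, result cast back to x.dtype
    assert y.shape == x.shape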
cehrgpt/models/config.py
CHANGED
@@ -106,6 +106,8 @@ class CEHRGPTConfig(PretrainedConfig):
         n_head=12,
         n_inner=None,
         activation_function="gelu_new",
+        decoder_mlp="GPT2MLP",
+        mlp_bias=False,
         resid_pdrop=0.1,
         embd_pdrop=0.1,
         attn_pdrop=0.1,
@@ -124,7 +126,7 @@ class CEHRGPTConfig(PretrainedConfig):
         ve_token_id=None,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
-
+        apply_rotary=False,
         include_values=False,
         value_vocab_size=None,
         include_ttv_prediction=False,
@@ -169,6 +171,8 @@ class CEHRGPTConfig(PretrainedConfig):
         self.n_head = n_head
         self.n_inner = n_inner
         self.activation_function = activation_function
+        self.decoder_mlp = decoder_mlp
+        self.mlp_bias = mlp_bias
         self.resid_pdrop = resid_pdrop
         self.embd_pdrop = embd_pdrop
         self.attn_pdrop = attn_pdrop
@@ -188,7 +192,7 @@
         self.eos_token_id = eos_token_id
         self.lab_token_ids = lab_token_ids
 
-        self.
+        self.apply_rotary = apply_rotary
         self.include_values = include_values
         self.value_vocab_size = value_vocab_size
 
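A hedged sketch of how the new knobs are passed at construction time (it assumes the remaining constructor defaults shown above are acceptable):

    from cehrgpt.models.config import CEHRGPTConfig

    config = CEHRGPTConfig(
        decoder_mlp="GPT2MLP",  # new in 0.1.3: selects the decoder MLP implementation
        mlp_bias=False,         # new in 0.1.3: bias terms in the MLP projections
        apply_rotary=False,     # new in 0.1.3: toggles rotary position embeddings
    )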