cehrgpt 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +1 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +398 -36
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +214 -12
- cehrgpt/data/sample_packing_sampler.py +36 -6
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +25 -0
- cehrgpt/models/hf_cehrgpt.py +227 -33
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +354 -71
- cehrgpt/runners/data_utils.py +117 -2
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +75 -50
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +59 -7
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +48 -0
- cehrgpt/runners/hyperparameter_search_util.py +6 -7
- cehrgpt/runners/sample_packing_trainer.py +17 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +85 -57
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +8 -7
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +27 -25
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {cehrgpt-0.1.0.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
CHANGED
@@ -1,8 +1,10 @@
 import datetime
-from ...
+from collections import defaultdict
+from typing import Any, Dict, Generator, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+from cehrbert.data_generators.hf_data_generator import UNKNOWN_VALUE
 from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
     ED_VISIT_TYPE_CODES,
     INPATIENT_VISIT_TYPE_CODES,
@@ -15,9 +17,17 @@ from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
 )
 from cehrbert.med_extension.schema_extension import Event
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+from cehrbert_data.const.artificial_tokens import (
+    DISCHARGE_UNKNOWN_TOKEN,
+    GENDER_UNKNOWN_TOKEN,
+    RACE_UNKNOWN_TOKEN,
+    VISIT_UNKNOWN_TOKEN,
+)
 from cehrbert_data.const.common import NA
 from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
+from datasets.formatting.formatting import LazyBatch
 from dateutil.relativedelta import relativedelta
+from pandas import Series
 
 from cehrgpt.models.tokenization_hf_cehrgpt import (
     NONE_BIN,
@@ -25,6 +35,17 @@ from cehrgpt.models.tokenization_hf_cehrgpt import (
     CehrGptTokenizer,
 )
 
+CEHRGPT_COLUMNS = [
+    "concept_ids",
+    "concept_value_masks",
+    "number_as_values",
+    "concept_as_values",
+    "is_numeric_types",
+    "concept_values",
+    "units",
+    "epoch_times",
+]
+
 
 def convert_date_to_posix_time(index_date: datetime.date) -> float:
     return datetime.datetime.combine(
@@ -32,7 +53,36 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
     ).timestamp()
 
 
-class MedToCehrGPTDatasetMapping(DatasetMapping):
+class DatasetMappingDecorator(DatasetMapping):
+
+    def batch_transform(
+        self, records: Union[LazyBatch, Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Drop index_date if it contains None.
+
+        :param records:
+        :return:
+        """
+        if isinstance(records, LazyBatch):
+            table = records.pa_table
+
+            if "index_date" in table.column_names:
+                index_col = table.column("index_date")
+                if index_col.null_count > 0:
+                    table = table.drop(["index_date"])
+                    records = LazyBatch(pa_table=table, formatter=records.formatter)
+        else:
+            if "index_date" in records:
+                if pd.isna(records["index_date"][0]):
+                    del records["index_date"]
+        return super().batch_transform(records=records)
+
+    def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]:
+        raise NotImplemented("Must be implemented")
+
+
+class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
     def __init__(
         self,
         data_args: DataTrainingArguments,
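
Note: the new DatasetMappingDecorator drops a null-containing index_date column from a batch, whether it arrives as an Arrow-backed LazyBatch or a plain dict, before delegating to cehrbert's DatasetMapping.batch_transform. A standalone sketch mirroring the dict branch (the batch values are made up):

    import pandas as pd

    batch = {"concept_ids": [["[VS]", "9201", "[VE]"]], "index_date": [None]}
    # A batch whose first index_date is None/NaN loses the whole column
    # before the wrapped batch_transform runs.
    if "index_date" in batch and pd.isna(batch["index_date"][0]):
        del batch["index_date"]
    assert "index_date" not in batch
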
@@ -65,6 +115,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
     def _update_cehrgpt_record(
         cehrgpt_record: Dict[str, Any],
         code: str,
+        time: datetime.datetime,
         concept_value_mask: int = 0,
         number_as_value: float = 0.0,
         concept_as_value: str = "0",
@@ -77,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         cehrgpt_record["concept_as_values"].append(concept_as_value)
         cehrgpt_record["units"].append(unit)
         cehrgpt_record["is_numeric_types"].append(is_numeric_type)
+        cehrgpt_record["epoch_times"].append(time.timestamp())
 
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         cehrgpt_record = {
@@ -87,13 +139,16 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             "concept_as_values": [],
             "units": [],
             "is_numeric_types": [],
+            "epoch_times": [],
         }
         # Extract the demographic information
         birth_datetime = record["birth_datetime"]
         if isinstance(birth_datetime, pd.Timestamp):
             birth_datetime = birth_datetime.to_pydatetime()
         gender = record["gender"]
+        gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
         race = record["race"]
+        race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race
         visits = record["visits"]
         # This indicates this is columnar format
         if isinstance(visits, dict):
@@ -108,10 +163,12 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         )
         year_str = f"year:{str(first_visit_start_datetime.year)}"
         age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
-        self._update_cehrgpt_record(cehrgpt_record, year_str)
-        self._update_cehrgpt_record(cehrgpt_record, age_str)
-        self._update_cehrgpt_record(cehrgpt_record, gender)
-        self._update_cehrgpt_record(cehrgpt_record, race)
+        self._update_cehrgpt_record(
+            cehrgpt_record, year_str, first_visit_start_datetime
+        )
+        self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
+        self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
+        self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
 
         # Use a data cursor to keep track of time
         datetime_cursor: Optional[datetime.datetime] = None
@@ -146,6 +203,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=self._time_token_function(time_delta),
+                    time=visit_start_datetime,
                 )
 
             datetime_cursor = visit_start_datetime
@@ -153,11 +211,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code="[VS]",
+                time=datetime_cursor,
             )
             # Add a visit type token
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code=visit_type,
+                time=datetime_cursor,
             )
             # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
             # with respect to the midnight of the day
@@ -167,6 +227,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=f"i-H{datetime_cursor.hour}",
+                    time=datetime_cursor,
                 )
 
             # Keep track of the existing outpatient events, we don't want to add them again
@@ -185,6 +246,10 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 concept_value_mask = int(
                     numeric_value is not None or text_value is not None
                 )
+                if numeric_value is None and text_value is not None:
+                    if text_value.isnumeric():
+                        numeric_value = float(text_value)
+
                 is_numeric_type = int(numeric_value is not None)
                 code = replace_escape_chars(e["code"])
 
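
Note: the new text_value fallback above only fires for digit-only strings, since str.isnumeric rejects signs and decimal points. A quick illustration of that edge:

    # str.isnumeric() drives the fallback: "7" converts, "7.5" and "-3" stay text.
    assert "7".isnumeric() and float("7") == 7.0
    assert not "7.5".isnumeric()
    assert not "-3".isnumeric()
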
@@ -208,6 +273,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                         self._update_cehrgpt_record(
                             cehrgpt_record,
                             code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                            time=event_time,
                         )
 
                         if self._include_inpatient_hour_token:
@@ -226,6 +292,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                                 self._update_cehrgpt_record(
                                     cehrgpt_record,
                                     code=f"i-H{time_diff_hours}",
+                                    time=event_time,
                                 )
 
                 if event_identity in existing_duplicate_events:
@@ -234,6 +301,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=code,
+                    time=event_time,
                     concept_value_mask=concept_value_mask,
                     unit=unit,
                     number_as_value=numeric_value if numeric_value else 0.0,
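
Note: with the time argument threaded through every _update_cehrgpt_record call, each appended token gets a matching POSIX timestamp in epoch_times; the four demographic tokens all reuse first_visit_start_datetime. A standalone illustration of the bookkeeping (toy values):

    import datetime

    first_visit_start_datetime = datetime.datetime(2021, 6, 1, 8, 0)
    epoch_times = []
    for _token in ("year:2021", "age:40", "Gender/8532", "Race/8527"):
        # Mirrors cehrgpt_record["epoch_times"].append(time.timestamp())
        epoch_times.append(first_visit_start_datetime.timestamp())
    assert len(epoch_times) == 4
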
@@ -262,17 +330,24 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             # facility event
             discharge_facility = get_value(visit, "discharge_facility")
             if not discharge_facility:
-                discharge_facility = ...
-
+                discharge_facility = DISCHARGE_UNKNOWN_TOKEN
+            else:
+                discharge_facility = (
+                    DISCHARGE_UNKNOWN_TOKEN
+                    if discharge_facility == UNKNOWN_VALUE
+                    else discharge_facility
+                )
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code=discharge_facility,
+                time=datetime_cursor,
             )
 
             # Reuse the age and date calculated for the last event in the patient timeline
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code="[VE]",
+                time=datetime_cursor,
             )
 
         # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -284,17 +359,21 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
         cehrgpt_record["num_of_visits"] = len(visits)
 
-        if record.get("index_date", None):
+        if record.get("index_date", None) is not None:
             cehrgpt_record["index_date"] = record["index_date"]
-        if record.get("label", None):
+        if record.get("label", None) is not None:
             cehrgpt_record["label"] = record["label"]
-        if record.get("age_at_index", None):
+        if record.get("age_at_index", None) is not None:
             cehrgpt_record["age_at_index"] = record["age_at_index"]
 
+        assert len(cehrgpt_record["epoch_times"]) == len(
+            cehrgpt_record["concept_ids"]
+        ), "The number of time stamps must match with the number of concepts in the sequence"
+
         return cehrgpt_record
 
 
-class HFCehrGptTokenizationMapping(DatasetMapping):
+class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
     def __init__(
         self,
         concept_tokenizer: CehrGptTokenizer,
@@ -308,9 +387,46 @@ class HFCehrGptTokenizationMapping(DatasetMapping):
             "is_numeric_types",
         ]
 
+    def filter_out_invalid_tokens(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        column_names = []
+        seq_length = len(record["concept_ids"])
+
+        # We can't have "0" as a token in the tokenizer because it would break tokenization for "Race/0", "Visit/0"
+        # This is a pre-caution
+        if "0" in record["concept_ids"]:
+            if isinstance(record["concept_ids"], np.ndarray):
+                record["concept_ids"][record["concept_ids"] == "0"] = "Unknown"
+            else:
+                record["concept_ids"] = [
+                    "Unknown" if x == "0" else x for x in record["concept_ids"]
+                ]
+
+        for k, v in record.items():
+            if k not in CEHRGPT_COLUMNS:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                column_names.append(k)
+        valid_concept_ids = self._concept_tokenizer.get_vocab().keys()
+        valid_indices = [
+            idx
+            for idx, concept_id in enumerate(record["concept_ids"])
+            if concept_id in valid_concept_ids
+        ]
+        if len(valid_indices) != len(record["concept_ids"]):
+            for column in column_names:
+                values = record[column]
+                record[column] = [values[idx] for idx in valid_indices]
+        return record
+
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        # Remove the tokens from patient sequences that do not exist in the tokenizer
+        record = self.filter_out_invalid_tokens(record)
         # If any concept has a value associated with it, we normalize the value
         record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
+        assert len(record["input_ids"]) == len(record["concept_ids"]), (
+            "The number of tokens must equal to the number of concepts\n"
+            f"decoded concept_ids: {self._concept_tokenizer.decode(record['input_ids'], skip_special_tokens=False)}"
+        )
        record["value_indicators"] = record["concept_value_masks"]
         if "number_as_values" not in record or "concept_as_values" not in record:
             record["number_as_values"] = [
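
Note: filter_out_invalid_tokens keeps every sequence-aligned column listed in CEHRGPT_COLUMNS in lockstep when it drops out-of-vocabulary concepts, which is what makes the input_ids/concept_ids length assert hold afterwards. A toy illustration of that invariant (data made up):

    concept_ids = ["year:2020", "age:40", "Gender/8532", "NotInVocab"]
    epoch_times = [1.0, 1.0, 1.0, 2.0]
    vocab = {"year:2020", "age:40", "Gender/8532"}
    valid_indices = [i for i, c in enumerate(concept_ids) if c in vocab]
    # Both aligned columns are sliced with the same indices.
    concept_ids = [concept_ids[i] for i in valid_indices]
    epoch_times = [epoch_times[i] for i in valid_indices]
    assert len(concept_ids) == len(epoch_times) == 3
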
@@ -391,3 +507,89 @@ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
         columns = super().remove_columns()
         columns.append("label")
         return columns
+
+
+class ExtractTokenizedSequenceDataMapping:
+    def __init__(
+        self,
+        person_index_date_map: Dict[int, List[Dict[str, Any]]],
+        observation_window: int = 0,
+    ):
+        self.person_index_date_map = person_index_date_map
+        self.observation_window = observation_window
+
+    def _calculate_prediction_start_time(self, prediction_time: float):
+        if self.observation_window and self.observation_window > 0:
+            return max(prediction_time - self.observation_window * 24 * 3600, 0)
+        return 0
+
+    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        person_id = record["person_id"]
+        prediction_times = self.person_index_date_map[person_id]
+        prediction_start_end_times = [
+            (
+                self._calculate_prediction_start_time(
+                    prediction_time_label_map["index_date"].timestamp()
+                ),
+                prediction_time_label_map["index_date"].timestamp(),
+                prediction_time_label_map["label"],
+            )
+            for prediction_time_label_map in prediction_times
+        ]
+        observation_window_indices = np.zeros(
+            (len(prediction_times), len(record["epoch_times"])), dtype=bool
+        )
+        for i, epoch_time in enumerate(record["epoch_times"]):
+            for sample_n, (
+                feature_extraction_time_start,
+                feature_extraction_end_end,
+                _,
+            ) in enumerate(prediction_start_end_times):
+                if (
+                    feature_extraction_time_start
+                    <= epoch_time
+                    <= feature_extraction_end_end
+                ):
+                    observation_window_indices[sample_n][i] = True
+
+        seq_length = len(record["epoch_times"])
+        time_series_columns = ["concept_ids", "input_ids"]
+        static_inputs = dict()
+        for k, v in record.items():
+            if k in ["concept_ids", "input_ids"]:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                time_series_columns.append(k)
+            else:
+                static_inputs[k] = v
+
+        batched_samples = defaultdict(list)
+        for (_, index_date, label), observation_window_index in zip(
+            prediction_start_end_times, observation_window_indices
+        ):
+            for k, v in static_inputs.items():
+                batched_samples[k].append(v)
+            batched_samples["classifier_label"].append(label)
+            batched_samples["index_date"].append(index_date)
+            try:
+                start_age = int(record["concept_ids"][1].split(":")[1])
+            except Exception:
+                start_age = -1
+            batched_samples["age_at_index"].append(start_age)
+            for time_series_column in time_series_columns:
+                batched_samples[time_series_column].append(
+                    np.asarray(record[time_series_column])[observation_window_index]
+                )
+        return batched_samples
+
+    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        all_batched_record = defaultdict(list)
+        all_columns = record.keys()
+        for i in range(len(record["concept_ids"])):
+            one_record = {}
+            for column in all_columns:
+                one_record[column] = record[column][i]
+            new_batched_record = self.transform(one_record)
+            for k, v in new_batched_record.items():
+                all_batched_record[k].extend(v)
+        return all_batched_record
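
Note: ExtractTokenizedSequenceDataMapping fans one tokenized patient record out into one sample per prediction time, masking the time-series columns to an observation window that ends at each index_date (observation_window is in days; 0 keeps the full history). A hedged usage sketch with hypothetical inputs:

    import datetime
    from cehrgpt.data.hf_cehrgpt_dataset_mapping import (
        ExtractTokenizedSequenceDataMapping,
    )

    # person_id -> list of {"index_date": datetime, "label": int} entries
    person_index_date_map = {
        1: [
            {"index_date": datetime.datetime(2021, 6, 1), "label": 1},
            {"index_date": datetime.datetime(2022, 6, 1), "label": 0},
        ]
    }
    mapping = ExtractTokenizedSequenceDataMapping(
        person_index_date_map, observation_window=365
    )
    # mapping.transform(record) expects person_id, epoch_times, concept_ids,
    # input_ids, ... and returns columns batched over the two prediction times.
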
cehrgpt/data/sample_packing_sampler.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import Iterator, List, Optional
 
+import numpy as np
 import torch
 import torch.distributed as dist
 from torch.utils.data import Sampler
@@ -33,6 +34,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         rank: Optional[int] = None,
         seed: int = 0,
         drop_last: bool = False,
+        negative_sampling_probability: Optional[float] = None,
+        labels: Optional[List[int]] = None,
     ):
         """
         Args:
@@ -73,6 +76,11 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
                 f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
             )
 
+        if negative_sampling_probability is not None and labels is None:
+            raise ValueError(
+                f"When the negative sampling probability is provide, the labels must be provided as well"
+            )
+
         self.lengths = lengths
         self.max_tokens_per_batch = max_tokens_per_batch
         self.max_position_embeddings = max_position_embeddings
@@ -80,6 +88,8 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         self.rank = rank
         self.seed = seed
         self.drop_last = drop_last
+        self.negative_sampling_probability = negative_sampling_probability
+        self.labels = labels
         # Trainer https://github.com/huggingface/transformers/blame/main/src/transformers/trainer.py#L2470
         # http://github.com/huggingface/accelerate/blob/v0.31.0/src/accelerate/data_loader.py#L482
         # the huggingface trainer will call the accelerate.data_loader.DataLoaderShard.set_epoch,
@@ -100,6 +110,14 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         current_batch_tokens = 0
 
         for idx in indices:
+            # There is a chance to skip the negative samples to account for the class imbalance
+            # in the fine-tuning dataset
+            if self.negative_sampling_probability:
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and self.labels[idx] == 0
+                ):
+                    continue
             # We take the minimum of the two because each sequence will be truncated to fit
             # the context window of the model
             sample_length = min(self.lengths[idx], self.max_position_embeddings)
@@ -131,10 +149,22 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         if len(self.lengths) == 0:
             return 0
 
-        # We need to truncate the lengths due to the context window limit imposed by the model
-        truncated_lengths = [
-            min(self.max_position_embeddings, length + 2) for length in self.lengths
-        ]
+        # There is a chance to skip the negative samples to account for the class imbalance
+        # in the fine-tuning dataset
+        if self.negative_sampling_probability:
+            truncated_lengths = []
+            for length, label in zip(self.lengths, self.labels):
+                if (
+                    np.random.random() > self.negative_sampling_probability
+                    and label == 0
+                ):
+                    continue
+                truncated_lengths.append(length)
+        else:
+            # We need to truncate the lengths due to the context window limit imposed by the model
+            truncated_lengths = [
+                min(self.max_position_embeddings, length + 2) for length in self.lengths
+            ]
 
         # Calculate average sequence length
         avg_seq_length = sum(truncated_lengths) // len(truncated_lengths)
@@ -145,7 +175,7 @@ class SamplePackingBatchSampler(Sampler[List[int]]):
         # Estimate total number of batches
         if self.drop_last:
             # If dropping last incomplete batch
-            return len(truncated_lengths) // seqs_per_batch
+            return len(truncated_lengths) // seqs_per_batch
         else:
             # If keeping last incomplete batch, ensure at least 1 batch
-            return max(1, len(truncated_lengths) // seqs_per_batch)
+            return max(1, len(truncated_lengths) // seqs_per_batch)
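
Note: the two new constructor arguments let the packed batch sampler randomly drop label-0 examples: each epoch a negative sample survives with probability negative_sampling_probability, so both the emitted batches and len(sampler) become stochastic. A hedged construction sketch assuming a single-process run with default num_replicas/rank handling (argument values are illustrative):

    from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

    sampler = SamplePackingBatchSampler(
        lengths=[512, 1024, 256, 2048],
        max_tokens_per_batch=4096,
        max_position_embeddings=1024,
        negative_sampling_probability=0.3,  # keep each negative with p = 0.3
        labels=[1, 0, 0, 1],                # required once the probability is set
    )
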
cehrgpt/generation/omop_converter_batch.py
CHANGED
@@ -60,6 +60,24 @@ OOV_CONCEPT_MAP = {
 }
 
 
+def extract_gender_concept_id(gender_token: str) -> int:
+    if gender_token.startswith("Gender/"):
+        return int(gender_token[len("Gender/") :])
+    elif gender_token.isnumeric():
+        return int(gender_token)
+    else:
+        return 0
+
+
+def extract_race_concept_id(race_token: str) -> int:
+    if race_token.startswith("Race/"):
+        return int(race_token[len("Race/") :])
+    elif race_token.isnumeric():
+        return int(race_token)
+    else:
+        return 0
+
+
 def create_folder_if_not_exists(output_folder, table_name):
     if not os.path.isdir(Path(output_folder) / table_name):
         os.mkdir(Path(output_folder) / table_name)
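
Note: the new parsers accept both prefixed demographic tokens and bare numeric concept ids, falling back to concept id 0 for anything else. Behavior sketch with toy inputs:

    from cehrgpt.generation.omop_converter_batch import extract_gender_concept_id

    assert extract_gender_concept_id("Gender/8532") == 8532  # prefixed token
    assert extract_gender_concept_id("8532") == 8532         # bare numeric token
    assert extract_gender_concept_id("Unknown") == 0         # everything else
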
@@ -288,7 +306,13 @@ def gpt_to_omop_converter_batch(
         if int(birth_year) < 1900 or int(birth_year) > datetime.date.today().year:
             continue
 
-        p = Person(...)
+        p = Person(
+            person_id=person_id,
+            gender_concept_id=extract_gender_concept_id(start_gender),
+            year_of_birth=birth_year,
+            race_concept_id=extract_race_concept_id(start_race),
+        )
+
         append_to_dict(omop_export_dict, p, person_id)
         id_mappings_dict["person"][person_id] = person_id
         pt_seq_dict[person_id] = " ".join(concept_ids)
@@ -316,7 +340,12 @@ def gpt_to_omop_converter_batch(
                     id_mappings_dict["death"][person_id] = person_id
             else:
                 try:
-                    visit_concept_id = int(clinical_events[event_idx + 1])
+                    if clinical_events[event_idx + 1].startswith("Visit/"):
+                        visit_concept_id = int(
+                            clinical_events[event_idx + 1][len("Visit/") :]
+                        )
+                    else:
+                        visit_concept_id = int(clinical_events[event_idx + 1])
                     inpatient_visit_indicator = visit_concept_id in [
                         9201,
                         262,
@@ -349,6 +378,7 @@ def gpt_to_omop_converter_batch(
                             visit_occurrence_id
                         ] = person_id
                         visit_occurrence_id += 1
+
                 elif event in ATT_TIME_TOKENS:
                     if event[0] == "D":
                         att_date_delta = int(event[1:])
cehrgpt/gpt_utils.py
CHANGED
@@ -11,6 +11,7 @@ from cehrgpt.models.special_tokens import (
 )
 
 # Regular expression pattern to match inpatient attendance tokens
+MEDS_CODE_PATTERN = re.compile(r".*/.*")
 INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
 DEMOGRAPHIC_PROMPT_SIZE = 4
 
@@ -194,8 +195,12 @@ def get_cehrgpt_output_folder(args, cehrgpt_tokenizer) -> str:
     return folder_name
 
 
-def is_clinical_event(token: str) -> bool:
-    return token.isnumeric()
+def is_clinical_event(token: str, meds: bool = False) -> bool:
+    if token.isnumeric():
+        return True
+    if meds:
+        return bool(MEDS_CODE_PATTERN.match(token))
+    return False
 
 
 def is_visit_start(token: str):
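
Note: is_clinical_event still treats bare numeric OMOP concept ids as clinical events; with meds=True it also accepts MEDS-style vocabulary/code tokens, which MEDS_CODE_PATTERN matches on the presence of a slash. Behavior sketch with toy tokens:

    from cehrgpt.gpt_utils import is_clinical_event

    assert is_clinical_event("320128")                    # numeric OMOP concept id
    assert is_clinical_event("LOINC/8302-2", meds=True)   # matches MEDS_CODE_PATTERN
    assert not is_clinical_event("LOINC/8302-2")          # meds defaults to False
    assert not is_clinical_event("[VE]", meds=True)       # no slash, not numeric
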
@@ -212,6 +217,18 @@ def is_visit_end(token: str) -> bool:
     return token in ["VE", "[VE]"]
 
 
+def is_inpatient_hour_token(token: str) -> bool:
+    return token.startswith("i-H")
+
+
+def extract_time_interval_in_hours(token: str) -> int:
+    try:
+        hour = int(token[3:])
+        return hour
+    except ValueError:
+        return 0
+
+
 def is_att_token(token: str):
     """
     Check if the token is an attention token.
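
Note: the two helpers pair up: is_inpatient_hour_token recognizes the i-H prefix, and extract_time_interval_in_hours strips it (token[3:]), returning 0 when the suffix is not an integer. Behavior sketch with toy tokens:

    from cehrgpt.gpt_utils import (
        extract_time_interval_in_hours,
        is_inpatient_hour_token,
    )

    assert is_inpatient_hour_token("i-H6")
    assert extract_time_interval_in_hours("i-H6") == 6    # "i-H6"[3:] == "6"
    assert extract_time_interval_in_hours("i-Hxx") == 0   # ValueError -> 0
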
@@ -251,6 +268,7 @@ def is_artificial_token(token: str) -> bool:
         return True
     if token == END_TOKEN:
         return True
+
     return False
 
 
cehrgpt/models/config.py
CHANGED
@@ -121,6 +121,7 @@ class CEHRGPTConfig(PretrainedConfig):
         bos_token_id=50256,
         eos_token_id=50256,
         lab_token_ids=None,
+        ve_token_id=None,
         scale_attn_by_inverse_layer_idx=False,
         reorder_and_upcast_attn=False,
         exclude_position_ids=False,
@@ -128,6 +129,10 @@ class CEHRGPTConfig(PretrainedConfig):
         value_vocab_size=None,
         include_ttv_prediction=False,
         use_sub_time_tokenization=True,
+        include_motor_time_to_event=True,
+        motor_tte_vocab_size=None,
+        motor_time_to_event_weight=1.0,
+        motor_num_time_pieces=16,
         token_to_time_token_mapping: Dict[int, List] = None,
         use_pretrained_embeddings=False,
         n_pretrained_embeddings_layers=2,
@@ -144,6 +149,7 @@ class CEHRGPTConfig(PretrainedConfig):
         entropy_penalty=False,
         entropy_penalty_alpha=0.01,
         sample_packing_max_positions=None,
+        class_weights=None,
         **kwargs,
     ):
         if token_to_time_token_mapping is None:
@@ -192,6 +198,22 @@ class CEHRGPTConfig(PretrainedConfig):
         self._token_to_time_token_mapping = token_to_time_token_mapping
         self.time_token_loss_weight = time_token_loss_weight
         self.time_to_visit_loss_weight = time_to_visit_loss_weight
+
+        # MOTOR TTE configuration
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.include_motor_time_to_event = (
+            include_motor_time_to_event
+            and self.motor_tte_vocab_size
+            and self.motor_tte_vocab_size > 0
+        )
+        if self.include_motor_time_to_event and not ve_token_id:
+            raise RuntimeError(
+                f"ve_token_id must be provided when include_motor_time_to_event is True"
+            )
+        self.ve_token_id = ve_token_id
+        self.motor_time_to_event_weight = motor_time_to_event_weight
+        self.motor_num_time_pieces = motor_num_time_pieces
+
         self.causal_sfm = causal_sfm
         self.demographics_size = demographics_size
         self.use_pretrained_embeddings = use_pretrained_embeddings
@@ -206,6 +228,9 @@ class CEHRGPTConfig(PretrainedConfig):
         self.entropy_penalty_alpha = entropy_penalty_alpha
         self.value_prediction_loss_weight = value_prediction_loss_weight
 
+        # Class weights for fine-tuning
+        self.class_weights = class_weights
+
         kwargs["tie_word_embeddings"] = not use_pretrained_embeddings
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
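
Note: the MOTOR time-to-event head only turns on when motor_tte_vocab_size is a positive number, and it then insists on ve_token_id (the id of the [VE] visit-end token) so the head knows where visits end. A hedged construction sketch using only the new kwargs from this hunk (values are illustrative):

    from cehrgpt.models.config import CEHRGPTConfig

    config = CEHRGPTConfig(
        include_motor_time_to_event=True,
        motor_tte_vocab_size=512,   # > 0 activates the MOTOR head
        ve_token_id=4,              # example id; omitting it raises RuntimeError
        motor_num_time_pieces=16,
        motor_time_to_event_weight=1.0,
    )
    assert config.include_motor_time_to_event
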