cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/data/hf_cehrgpt_dataset_mapping.py (+308 -95)

@@ -1,19 +1,33 @@
 import datetime
-from
+from collections import defaultdict
+from typing import Any, Dict, Generator, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+from cehrbert.data_generators.hf_data_generator import UNKNOWN_VALUE
 from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
     ED_VISIT_TYPE_CODES,
     INPATIENT_VISIT_TYPE_CODES,
     INPATIENT_VISIT_TYPES,
     DatasetMapping,
+    VisitObject,
+    get_value,
+    has_events_and_get_events,
     replace_escape_chars,
 )
+from cehrbert.med_extension.schema_extension import Event
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
+from cehrbert_data.const.artificial_tokens import (
+    DISCHARGE_UNKNOWN_TOKEN,
+    GENDER_UNKNOWN_TOKEN,
+    RACE_UNKNOWN_TOKEN,
+    VISIT_UNKNOWN_TOKEN,
+)
 from cehrbert_data.const.common import NA
 from cehrbert_data.decorators.patient_event_decorator_base import get_att_function
+from datasets.formatting.formatting import LazyBatch
 from dateutil.relativedelta import relativedelta
+from pandas import Series
 
 from cehrgpt.models.tokenization_hf_cehrgpt import (
     NONE_BIN,
@@ -21,6 +35,17 @@ from cehrgpt.models.tokenization_hf_cehrgpt import (
     CehrGptTokenizer,
 )
 
+CEHRGPT_COLUMNS = [
+    "concept_ids",
+    "concept_value_masks",
+    "number_as_values",
+    "concept_as_values",
+    "is_numeric_types",
+    "concept_values",
+    "units",
+    "epoch_times",
+]
+
 
 def convert_date_to_posix_time(index_date: datetime.date) -> float:
     return datetime.datetime.combine(
@@ -28,11 +53,39 @@ def convert_date_to_posix_time(index_date: datetime.date) -> float:
     ).timestamp()
 
 
-class
+class DatasetMappingDecorator(DatasetMapping):
+
+    def batch_transform(
+        self, records: Union[LazyBatch, Dict[str, Any]]
+    ) -> List[Dict[str, Any]]:
+        """
+        Drop index_date if it contains None.
+
+        :param records:
+        :return:
+        """
+        if isinstance(records, LazyBatch):
+            table = records.pa_table
+
+            if "index_date" in table.column_names:
+                index_col = table.column("index_date")
+                if index_col.null_count > 0:
+                    table = table.drop(["index_date"])
+                    records = LazyBatch(pa_table=table, formatter=records.formatter)
+        else:
+            if "index_date" in records:
+                if pd.isna(records["index_date"][0]):
+                    del records["index_date"]
+        return super().batch_transform(records=records)
+
+    def transform(self, record: Dict[str, Any]) -> Union[Dict[str, Any], Series]:
+        raise NotImplemented("Must be implemented")
+
+
+class MedToCehrGPTDatasetMapping(DatasetMappingDecorator):
     def __init__(
         self,
         data_args: DataTrainingArguments,
-        is_pretraining: bool = True,
         include_inpatient_hour_token: bool = True,
     ):
         self._time_token_function = get_att_function(data_args.att_function_type)
@@ -41,7 +94,6 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             data_args.inpatient_att_function_type
         )
         self._include_demographic_prompt = data_args.include_demographic_prompt
-        self._is_pretraining = is_pretraining
         self._include_inpatient_hour_token = include_inpatient_hour_token
 
     """
@@ -57,19 +109,13 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
     """
 
     def remove_columns(self):
-
-            return ["visits", "birth_datetime", "index_date"]
-        else:
-            return [
-                "visits",
-                "birth_datetime",
-                "visit_concept_ids",
-            ]
+        return ["patient_id", "visits", "birth_datetime"]
 
     @staticmethod
     def _update_cehrgpt_record(
         cehrgpt_record: Dict[str, Any],
         code: str,
+        time: datetime.datetime,
         concept_value_mask: int = 0,
         number_as_value: float = 0.0,
         concept_as_value: str = "0",
@@ -82,6 +128,7 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
         cehrgpt_record["concept_as_values"].append(concept_as_value)
         cehrgpt_record["units"].append(unit)
         cehrgpt_record["is_numeric_types"].append(is_numeric_type)
+        cehrgpt_record["epoch_times"].append(time.timestamp())
 
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         cehrgpt_record = {
@@ -92,45 +139,57 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             "concept_as_values": [],
             "units": [],
             "is_numeric_types": [],
+            "epoch_times": [],
         }
         # Extract the demographic information
         birth_datetime = record["birth_datetime"]
         if isinstance(birth_datetime, pd.Timestamp):
             birth_datetime = birth_datetime.to_pydatetime()
         gender = record["gender"]
+        gender = GENDER_UNKNOWN_TOKEN if gender == UNKNOWN_VALUE else gender
         race = record["race"]
+        race = RACE_UNKNOWN_TOKEN if race == UNKNOWN_VALUE else race
+        visits = record["visits"]
+        # This indicates this is columnar format
+        if isinstance(visits, dict):
+            visits = sorted(self.convert_visit_columnar_to_python(visits))
+        else:
+            visits = sorted(visits, key=lambda _: get_value(_, "visit_start_datetime"))
 
         # Add the demographic tokens
-        first_visit =
-
-
-
-
-
-        self._update_cehrgpt_record(
+        first_visit = visits[0]
+        first_visit_start_datetime: datetime.datetime = get_value(
+            first_visit, "visit_start_datetime"
+        )
+        year_str = f"year:{str(first_visit_start_datetime.year)}"
+        age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
+        self._update_cehrgpt_record(
+            cehrgpt_record, year_str, first_visit_start_datetime
+        )
+        self._update_cehrgpt_record(cehrgpt_record, age_str, first_visit_start_datetime)
+        self._update_cehrgpt_record(cehrgpt_record, gender, first_visit_start_datetime)
+        self._update_cehrgpt_record(cehrgpt_record, race, first_visit_start_datetime)
 
         # Use a data cursor to keep track of time
-
-
-        # Loop through all the visits
-        for i, visit in enumerate(
-
-
-
-            events = visit["events"]
-
-            # Skip this visit if the number measurements in the event is zero
-            if events is None or len(events) == 0:
+        datetime_cursor: Optional[datetime.datetime] = None
+        visit: VisitObject
+        # Loop through all the visits
+        for i, visit in enumerate(visits):
+            events: Generator[Event, None, None] = get_value(visit, "events")
+            has_events, events = has_events_and_get_events(events)
+            if not has_events:
                 continue
 
-            visit_start_datetime =
-
-
+            visit_start_datetime: datetime.datetime = get_value(
+                visit, "visit_start_datetime"
+            )
+            # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+            visit_end_datetime: Optional[datetime.datetime] = get_value(
+                visit, "visit_end_datetime"
             )
-            date_cursor = visit_start_datetime
 
             # We assume the first measurement to be the visit type of the current visit
-            visit_type = visit
+            visit_type = get_value(visit, "visit_type")
             is_er_or_inpatient = (
                 visit_type in INPATIENT_VISIT_TYPES
                 or visit_type in INPATIENT_VISIT_TYPE_CODES
@@ -138,36 +197,45 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
             )
 
             # Add artificial time tokens to the patient timeline if timedelta exists
-            if
+            if datetime_cursor is not None:
+                time_delta = max((visit_start_datetime - datetime_cursor).days, 0)
                 # This generates an artificial time token depending on the choice of the time token functions
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=self._time_token_function(time_delta),
+                    time=visit_start_datetime,
                 )
 
-
-                relativedelta(visit["visit_start_datetime"], birth_datetime).years
-            # Calculate the week number since the epoch time
-            date = (
-                visit["visit_start_datetime"]
-                - datetime.datetime(year=1970, month=1, day=1)
-            ).days // 7
-
+            datetime_cursor = visit_start_datetime
             # Add a [VS] token
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code="[VS]",
+                time=datetime_cursor,
             )
             # Add a visit type token
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code=visit_type,
+                time=datetime_cursor,
             )
+            # We need to insert an inpatient hour token right after the visit type, we calculate the hour interval
+            # with respect to the midnight of the day
+            if is_er_or_inpatient and self._include_inpatient_hour_token:
+                if datetime_cursor.hour > 0:
+                    # This generates an artificial time token depending on the choice of the time token functions
+                    self._update_cehrgpt_record(
+                        cehrgpt_record,
+                        code=f"i-H{datetime_cursor.hour}",
+                        time=datetime_cursor,
+                    )
+
             # Keep track of the existing outpatient events, we don't want to add them again
-
+            existing_duplicate_events = list()
             for e in events:
                 # If the event doesn't have a time stamp, we skip it
-
+                event_time: datetime.datetime = e["time"]
+                if not event_time:
                     continue
 
                 # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask
@@ -178,47 +246,62 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                 concept_value_mask = int(
                     numeric_value is not None or text_value is not None
                 )
+                if numeric_value is None and text_value is not None:
+                    if text_value.isnumeric():
+                        numeric_value = float(text_value)
+
                 is_numeric_type = int(numeric_value is not None)
                 code = replace_escape_chars(e["code"])
 
+                # Create the event identity
+                event_identity = (
+                    (event_time, code, text_value, unit)
+                    if is_er_or_inpatient
+                    else (event_time.date(), code, text_value, unit)
+                )
+
                 # Add a medical token to the patient timeline
                 # If this is an inpatient visit, we use the event time stamps to calculate age and date
                 # because the patient can stay in the hospital for a period of time.
                 if is_er_or_inpatient:
-                    # Calculate the week number since the epoch time
-                    date = (
-                        e["time"] - datetime.datetime(year=1970, month=1, day=1)
-                    ).days // 7
                     # Calculate the time diff in days w.r.t the previous measurement
-
-                    # Update the
+                    time_diff_days = (event_time - datetime_cursor).days
+                    # Update the datetime_cursor if the time diff between two neighboring measurements is greater than and
                     # equal to 1 day
-                    if
-
-
+                    if self._inpatient_time_token_function and time_diff_days > 0:
+                        # This generates an artificial time token depending on the choice of the time token functions
+                        self._update_cehrgpt_record(
+                            cehrgpt_record,
+                            code=f"i-{self._inpatient_time_token_function(time_diff_days)}",
+                            time=event_time,
+                        )
+
+                    if self._include_inpatient_hour_token:
+                        # if the time difference in days is greater than 0, we calculate the hour interval
+                        # with respect to the midnight of the day
+                        time_diff_hours = (
+                            event_time.hour
+                            if time_diff_days > 0
+                            else int(
+                                (event_time - datetime_cursor).total_seconds() // 3600
+                            )
+                        )
+
+                        if time_diff_hours > 0:
                             # This generates an artificial time token depending on the choice of the time token functions
                             self._update_cehrgpt_record(
                                 cehrgpt_record,
-                                code=f"i-{
+                                code=f"i-H{time_diff_hours}",
+                                time=event_time,
                             )
-
-
-
-                # We check whether the date/code/value combination already exists in the existing events
-                # If they exist, we do not add them to the patient timeline for outpatient visits.
-                if (
-                    date,
-                    code,
-                    numeric_value,
-                    text_value,
-                    concept_value_mask,
-                    numeric_value,
-                ) in existing_outpatient_events:
-                    continue
+
+                if event_identity in existing_duplicate_events:
+                    continue
 
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=code,
+                    time=event_time,
                     concept_value_mask=concept_value_mask,
                     unit=unit,
                     number_as_value=numeric_value if numeric_value else 0.0,
@@ -227,43 +310,44 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
                     ),
                     is_numeric_type=is_numeric_type,
                 )
-
-
-
-
-
-
-
-
-                    )
-                )
+                existing_duplicate_events.append(event_identity)
+                # we only want to update the time stamp when data_cursor is less than the event time
+                if datetime_cursor < event_time or datetime_cursor is None:
+                    datetime_cursor = event_time
+                    # We need to bound the datetime_cursor if the current visit is an admission type of visit
+                    # as the associated events could be generated after the visits are complete
+                    if is_er_or_inpatient and visit_end_datetime is not None:
+                        datetime_cursor = min(datetime_cursor, visit_end_datetime)
 
             # For inpatient or ER visits, we want to discharge_facility to the end of the visit
             if is_er_or_inpatient:
-                # If visit_end_datetime is populated for the inpatient visit, we update the
-                visit_end_datetime
-
-                    date_cursor = visit_end_datetime
+                # If visit_end_datetime is populated for the inpatient visit, we update the datetime_cursor
+                if visit_end_datetime is not None:
+                    datetime_cursor = visit_end_datetime
 
             if self._include_auxiliary_token:
                 # Reuse the age and date calculated for the last event in the patient timeline for the discharge
                 # facility event
-                discharge_facility = (
-
-
-
-
-
-
+                discharge_facility = get_value(visit, "discharge_facility")
+                if not discharge_facility:
+                    discharge_facility = DISCHARGE_UNKNOWN_TOKEN
+                else:
+                    discharge_facility = (
+                        DISCHARGE_UNKNOWN_TOKEN
+                        if discharge_facility == UNKNOWN_VALUE
+                        else discharge_facility
+                    )
                 self._update_cehrgpt_record(
                     cehrgpt_record,
                     code=discharge_facility,
+                    time=datetime_cursor,
                 )
 
             # Reuse the age and date calculated for the last event in the patient timeline
             self._update_cehrgpt_record(
                 cehrgpt_record,
                 code="[VE]",
+                time=datetime_cursor,
             )
 
             # Generate the orders of the concepts that the cehrbert dataset mapping function expects
@@ -273,17 +357,23 @@ class MedToCehrGPTDatasetMapping(DatasetMapping):
 
         # Add some count information for this sequence
         cehrgpt_record["num_of_concepts"] = len(cehrgpt_record["concept_ids"])
-        cehrgpt_record["num_of_visits"] = len(
+        cehrgpt_record["num_of_visits"] = len(visits)
 
-        if "
+        if record.get("index_date", None) is not None:
+            cehrgpt_record["index_date"] = record["index_date"]
+        if record.get("label", None) is not None:
             cehrgpt_record["label"] = record["label"]
-        if "age_at_index"
+        if record.get("age_at_index", None) is not None:
             cehrgpt_record["age_at_index"] = record["age_at_index"]
 
+        assert len(cehrgpt_record["epoch_times"]) == len(
+            cehrgpt_record["concept_ids"]
+        ), "The number of time stamps must match with the number of concepts in the sequence"
+
         return cehrgpt_record
 
 
-class HFCehrGptTokenizationMapping(
+class HFCehrGptTokenizationMapping(DatasetMappingDecorator):
     def __init__(
         self,
         concept_tokenizer: CehrGptTokenizer,
@@ -297,9 +387,46 @@ class HFCehrGptTokenizationMapping(DatasetMapping):
             "is_numeric_types",
         ]
 
+    def filter_out_invalid_tokens(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        column_names = []
+        seq_length = len(record["concept_ids"])
+
+        # We can't have "0" as a token in the tokenizer because it would break tokenization for "Race/0", "Visit/0"
+        # This is a pre-caution
+        if "0" in record["concept_ids"]:
+            if isinstance(record["concept_ids"], np.ndarray):
+                record["concept_ids"][record["concept_ids"] == "0"] = "Unknown"
+            else:
+                record["concept_ids"] = [
+                    "Unknown" if x == "0" else x for x in record["concept_ids"]
+                ]
+
+        for k, v in record.items():
+            if k not in CEHRGPT_COLUMNS:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                column_names.append(k)
+        valid_concept_ids = self._concept_tokenizer.get_vocab().keys()
+        valid_indices = [
+            idx
+            for idx, concept_id in enumerate(record["concept_ids"])
+            if concept_id in valid_concept_ids
+        ]
+        if len(valid_indices) != len(record["concept_ids"]):
+            for column in column_names:
+                values = record[column]
+                record[column] = [values[idx] for idx in valid_indices]
+        return record
+
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        # Remove the tokens from patient sequences that do not exist in the tokenizer
+        record = self.filter_out_invalid_tokens(record)
         # If any concept has a value associated with it, we normalize the value
         record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
+        assert len(record["input_ids"]) == len(record["concept_ids"]), (
+            "The number of tokens must equal to the number of concepts\n"
+            f"decoded concept_ids: {self._concept_tokenizer.decode(record['input_ids'], skip_special_tokens=False)}"
+        )
         record["value_indicators"] = record["concept_value_masks"]
         if "number_as_values" not in record or "concept_as_values" not in record:
             record["number_as_values"] = [
@@ -380,3 +507,89 @@ class HFFineTuningMapping(HFCehrGptTokenizationMapping):
         columns = super().remove_columns()
         columns.append("label")
         return columns
+
+
+class ExtractTokenizedSequenceDataMapping:
+    def __init__(
+        self,
+        person_index_date_map: Dict[int, List[Dict[str, Any]]],
+        observation_window: int = 0,
+    ):
+        self.person_index_date_map = person_index_date_map
+        self.observation_window = observation_window
+
+    def _calculate_prediction_start_time(self, prediction_time: float):
+        if self.observation_window and self.observation_window > 0:
+            return max(prediction_time - self.observation_window * 24 * 3600, 0)
+        return 0
+
+    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        person_id = record["person_id"]
+        prediction_times = self.person_index_date_map[person_id]
+        prediction_start_end_times = [
+            (
+                self._calculate_prediction_start_time(
+                    prediction_time_label_map["index_date"].timestamp()
+                ),
+                prediction_time_label_map["index_date"].timestamp(),
+                prediction_time_label_map["label"],
+            )
+            for prediction_time_label_map in prediction_times
+        ]
+        observation_window_indices = np.zeros(
+            (len(prediction_times), len(record["epoch_times"])), dtype=bool
+        )
+        for i, epoch_time in enumerate(record["epoch_times"]):
+            for sample_n, (
+                feature_extraction_time_start,
+                feature_extraction_end_end,
+                _,
+            ) in enumerate(prediction_start_end_times):
+                if (
+                    feature_extraction_time_start
+                    <= epoch_time
+                    <= feature_extraction_end_end
+                ):
+                    observation_window_indices[sample_n][i] = True
+
+        seq_length = len(record["epoch_times"])
+        time_series_columns = ["concept_ids", "input_ids"]
+        static_inputs = dict()
+        for k, v in record.items():
+            if k in ["concept_ids", "input_ids"]:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                time_series_columns.append(k)
+            else:
+                static_inputs[k] = v
+
+        batched_samples = defaultdict(list)
+        for (_, index_date, label), observation_window_index in zip(
+            prediction_start_end_times, observation_window_indices
+        ):
+            for k, v in static_inputs.items():
+                batched_samples[k].append(v)
+            batched_samples["classifier_label"].append(label)
+            batched_samples["index_date"].append(index_date)
+            try:
+                start_age = int(record["concept_ids"][1].split(":")[1])
+            except Exception:
+                start_age = -1
+            batched_samples["age_at_index"].append(start_age)
+            for time_series_column in time_series_columns:
+                batched_samples[time_series_column].append(
+                    np.asarray(record[time_series_column])[observation_window_index]
+                )
+        return batched_samples
+
+    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        all_batched_record = defaultdict(list)
+        all_columns = record.keys()
+        for i in range(len(record["concept_ids"])):
+            one_record = {}
+            for column in all_columns:
+                one_record[column] = record[column][i]
+            new_batched_record = self.transform(one_record)
+            for k, v in new_batched_record.items():
+                all_batched_record[k].extend(v)
+        return all_batched_record