cehrgpt-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- cehrgpt/__init__.py +0 -0
- cehrgpt/analysis/__init__.py +0 -0
- cehrgpt/analysis/privacy/__init__.py +0 -0
- cehrgpt/analysis/privacy/attribute_inference.py +275 -0
- cehrgpt/analysis/privacy/attribute_inference_config.yml +8975 -0
- cehrgpt/analysis/privacy/member_inference.py +172 -0
- cehrgpt/analysis/privacy/nearest_neighbor_inference.py +189 -0
- cehrgpt/analysis/privacy/reid_inference.py +407 -0
- cehrgpt/analysis/privacy/utils.py +255 -0
- cehrgpt/cehrgpt_args.py +142 -0
- cehrgpt/data/__init__.py +0 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +80 -0
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +482 -0
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +116 -0
- cehrgpt/generation/__init__.py +0 -0
- cehrgpt/generation/chatgpt_generation.py +106 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +333 -0
- cehrgpt/generation/omop_converter_batch.py +644 -0
- cehrgpt/generation/omop_entity.py +515 -0
- cehrgpt/gpt_utils.py +331 -0
- cehrgpt/models/__init__.py +0 -0
- cehrgpt/models/config.py +205 -0
- cehrgpt/models/hf_cehrgpt.py +1817 -0
- cehrgpt/models/hf_modeling_outputs.py +158 -0
- cehrgpt/models/pretrained_embeddings.py +82 -0
- cehrgpt/models/special_tokens.py +30 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +1077 -0
- cehrgpt/omop/__init__.py +0 -0
- cehrgpt/omop/condition_era.py +20 -0
- cehrgpt/omop/observation_period.py +43 -0
- cehrgpt/omop/omop_argparse.py +38 -0
- cehrgpt/omop/omop_table_builder.py +86 -0
- cehrgpt/omop/queries/__init__.py +0 -0
- cehrgpt/omop/queries/condition_era.py +86 -0
- cehrgpt/omop/queries/observation_period.py +135 -0
- cehrgpt/omop/sample_omop_tables.py +71 -0
- cehrgpt/runners/__init__.py +0 -0
- cehrgpt/runners/gpt_runner_util.py +99 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +746 -0
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +370 -0
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +137 -0
- cehrgpt/runners/hyperparameter_search_util.py +223 -0
- cehrgpt/time_to_event/__init__.py +0 -0
- cehrgpt/time_to_event/config/30_day_readmission.yaml +8 -0
- cehrgpt/time_to_event/config/next_visit_type_prediction.yaml +8 -0
- cehrgpt/time_to_event/config/t2dm_hf.yaml +8 -0
- cehrgpt/time_to_event/time_to_event_model.py +226 -0
- cehrgpt/time_to_event/time_to_event_prediction.py +347 -0
- cehrgpt/time_to_event/time_to_event_utils.py +55 -0
- cehrgpt/tools/__init__.py +0 -0
- cehrgpt/tools/ehrshot_benchmark.py +74 -0
- cehrgpt/tools/generate_pretrained_embeddings.py +130 -0
- cehrgpt/tools/merge_synthetic_real_dataasets.py +218 -0
- cehrgpt/tools/upload_omop_tables.py +108 -0
- cehrgpt-0.0.1.dist-info/LICENSE +21 -0
- cehrgpt-0.0.1.dist-info/METADATA +66 -0
- cehrgpt-0.0.1.dist-info/RECORD +60 -0
- cehrgpt-0.0.1.dist-info/WHEEL +5 -0
- cehrgpt-0.0.1.dist-info/top_level.txt +2 -0
cehrgpt/data/hf_cehrgpt_dataset_collator.py
@@ -0,0 +1,482 @@
import random
from typing import Any, Dict

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from cehrgpt.gpt_utils import (
    DEMOGRAPHIC_PROMPT_SIZE,
    collect_demographic_prompts_at_visits,
    extract_time_interval_in_days,
    is_att_token,
    is_inpatient_att_token,
    random_slice_gpt_sequence,
)
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

INPATIENT_STAY_DURATION_LIMIT = 30


class CehrGptDataCollator:
    def __init__(
        self,
        tokenizer: CehrGptTokenizer,
        max_length: int,
        shuffle_records: bool = False,
        include_values: bool = False,
        include_ttv_prediction: bool = False,
        use_sub_time_tokenization: bool = False,
        pretraining: bool = True,
        include_demographics: bool = False,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Pre-compute these so we can use them later on.
        # We used "VS" for the historical data; the newer data uses "[VS]",
        # so we need to check both cases.
        self.vs_token_id = tokenizer._convert_token_to_id("VS")
        if self.vs_token_id == tokenizer._oov_token_id:
            self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
        self.ve_token_id = tokenizer._convert_token_to_id("VE")
        if self.ve_token_id == tokenizer._oov_token_id:
            self.ve_token_id = tokenizer._convert_token_to_id("[VE]")

        self.shuffle_records = shuffle_records
        self.include_values = include_values
        self.include_ttv_prediction = include_ttv_prediction
        self.use_sub_time_tokenization = use_sub_time_tokenization
        self.pretraining = pretraining
        self.include_demographics = include_demographics

        if self.use_sub_time_tokenization:
            token_to_time_token_mapping = tokenizer.token_to_time_token_mapping
            if not token_to_time_token_mapping:
                raise ValueError(
                    "The token_to_time_token_mapping in CehrGptTokenizer cannot be None "
                    "when use_sub_time_tokenization is enabled"
                )
            # Create the tensors for converting time tokens to the sub time tokens
            self.time_tokens = torch.tensor(
                list(token_to_time_token_mapping.keys()), dtype=torch.int64
            )
            self.mapped_sub_time_tokens = torch.tensor(
                list(token_to_time_token_mapping.values()), dtype=torch.int64
            )

    def _try_reverse_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        if not self.pretraining:
            return torch.flip(tensor, dims=[-1])
        return tensor

    @staticmethod
    def _convert_to_tensor(features: Any) -> torch.Tensor:
        if isinstance(features, torch.Tensor):
            return features
        else:
            return torch.tensor(features)

    @staticmethod
    def _convert_time_to_event(concept_ids):
        def default_value(c):
            try:
                if is_att_token(c):
                    time_to_visit = extract_time_interval_in_days(c)
                    if (
                        is_inpatient_att_token(c)
                        and time_to_visit > INPATIENT_STAY_DURATION_LIMIT
                    ):
                        return -100
                    return time_to_visit
                return -100
            except ValueError:
                return -100

        return [float(default_value(_)) for _ in concept_ids]

    def __call__(self, examples):
        examples = [self.generate_start_end_index(_) for _ in examples]
        examples = [self.random_sort(_) for _ in examples]
        batch = {}

        # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
        batch_input_ids = [
            self._try_reverse_tensor(self._convert_to_tensor(example["input_ids"]))
            for example in examples
        ]
        batch_attention_mask = [
            self._try_reverse_tensor(
                torch.ones_like(
                    self._convert_to_tensor(example["input_ids"]), dtype=torch.float
                )
            )
            for example in examples
        ]

        # Pad sequences to the max length in the batch
        batch["input_ids"] = self._try_reverse_tensor(
            pad_sequence(
                batch_input_ids,
                batch_first=True,
                padding_value=self.tokenizer.pad_token_id,
            ).to(torch.int64)
        )

        batch["attention_mask"] = self._try_reverse_tensor(
            pad_sequence(batch_attention_mask, batch_first=True, padding_value=0.0)
        )
        assert batch["input_ids"].shape[1] <= self.max_length
        assert batch["attention_mask"].shape[1] <= self.max_length

        if self.pretraining:
            batch["labels"] = self._try_reverse_tensor(
                pad_sequence(
                    batch_input_ids,
                    batch_first=True,
                    padding_value=-100,
                ).to(torch.int64)
            )

        if self.use_sub_time_tokenization:
            time_token_indicators = torch.isin(batch["input_ids"], self.time_tokens)
            masked_tokens = batch["input_ids"].clone()
            masked_tokens[~time_token_indicators] = -1
            # Get the index of the sub_time_tokens from the time_tokens tensor
            sub_time_token_indices = torch.argmax(
                (
                    masked_tokens.unsqueeze(-1)
                    == self.time_tokens.unsqueeze(0).unsqueeze(0)
                ).to(torch.int32),
                dim=-1,
            )
            sub_time_tokens = self.mapped_sub_time_tokens[sub_time_token_indices]
            batch["time_token_indicators"] = time_token_indicators
            batch["sub_time_tokens"] = sub_time_tokens

        if self.include_ttv_prediction:
            batch_time_to_visits = [
                self._try_reverse_tensor(
                    self._convert_to_tensor(example["time_to_visits"])
                )
                for example in examples
            ]
            batch["time_to_visits"] = self._try_reverse_tensor(
                pad_sequence(
                    batch_time_to_visits, batch_first=True, padding_value=-100.0
                )
            )

        if self.include_values:
            batch_value_indicators = [
                self._try_reverse_tensor(
                    self._convert_to_tensor(example["value_indicators"])
                )
                for example in examples
            ]
            batch_values = [
                self._try_reverse_tensor(self._convert_to_tensor(example["values"]))
                for example in examples
            ]

            batch["value_indicators"] = self._try_reverse_tensor(
                pad_sequence(
                    batch_value_indicators, batch_first=True, padding_value=False
                )
            )
            batch["values"] = self._try_reverse_tensor(
                pad_sequence(
                    batch_values,
                    batch_first=True,
                    padding_value=self.tokenizer.pad_value_token_id,
                ).to(torch.int64)
            )
            assert batch["value_indicators"].shape[1] <= self.max_length
            assert batch["values"].shape[1] <= self.max_length

            if self.pretraining:
                batch["true_value_indicators"] = batch["value_indicators"].clone()
                batch["true_values"] = torch.where(
                    batch["value_indicators"], batch["values"].clone(), -100
                )

        if "person_id" in examples[0]:
            batch["person_id"] = torch.cat(
                [
                    self._convert_to_tensor(example["person_id"]).reshape(-1, 1)
                    for example in examples
                ],
                dim=0,
            ).to(torch.int32)

        if "index_date" in examples[0]:
            batch["index_date"] = torch.cat(
                [
                    self._convert_to_tensor(example["index_date"]).reshape(-1, 1)
                    for example in examples
                ],
                dim=0,
            ).to(torch.float32)

        if "age_at_index" in examples[0]:
            batch["age_at_index"] = torch.cat(
                [
                    self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1)
                    for example in examples
                ],
                dim=0,
            ).to(torch.float32)

        if "classifier_label" in examples[0]:
            batch["classifier_label"] = torch.cat(
                [
                    self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1)
                    for example in examples
                ],
                dim=0,
            ).to(torch.float32)

        return batch

    def random_sort(self, record: Dict[str, Any]) -> Dict[str, Any]:
        if not self.shuffle_records:
            return record

        if "record_ranks" not in record:
            return record

        sorting_column = record["record_ranks"]
        random_order = np.random.rand(len(sorting_column))

        if self.include_values:
            iterator = zip(
                sorting_column,
                random_order,
                record["input_ids"],
                record["value_indicators"],
                record["values"],
            )
            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
            _, _, sorted_input_ids, sorted_value_indicators, sorted_values = zip(
                *list(sorted_list)
            )
            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
            record["value_indicators"] = self._convert_to_tensor(
                sorted_value_indicators
            )
            record["values"] = self._convert_to_tensor(sorted_values)
        else:
            iterator = zip(sorting_column, random_order, record["input_ids"])
            sorted_list = sorted(iterator, key=lambda tup2: (tup2[0], tup2[1], tup2[2]))
            _, _, sorted_input_ids = zip(*list(sorted_list))
            record["input_ids"] = self._convert_to_tensor(sorted_input_ids)
        return record

    def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Add the start and end indices to extract a portion of the patient sequence."""
        # concept_ids will be used for time to event predictions and identifying the visit starts
        input_ids = record["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.detach().tolist()
        concept_ids = self.tokenizer.decode(input_ids, skip_special_tokens=False)
        seq_length = len(record["input_ids"])
        new_max_length = self.max_length - 1  # Subtract one for the [END] token

        # Return the record directly if the actual sequence length is less than the max sequence
        if seq_length <= new_max_length:
            record["input_ids"] = torch.concat(
                [
                    self._convert_to_tensor(record["input_ids"]),
                    self._convert_to_tensor([self.tokenizer.end_token_id]),
                ]
            )
            if self.include_values:
                record["value_indicators"] = torch.concat(
                    [
                        self._convert_to_tensor(record["value_indicators"]),
                        self._convert_to_tensor([False]),
                    ]
                ).to(torch.bool)
                record["values"] = torch.concat(
                    [
                        self._convert_to_tensor(record["values"]),
                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
                    ]
                )
            if self.include_ttv_prediction:
                record["time_to_visits"] = torch.concat(
                    [
                        self._convert_to_tensor(
                            self._convert_time_to_event(concept_ids)
                        ),
                        self._convert_to_tensor([-100.0]),
                    ]
                )

            return record

        if self.pretraining:
            # There is a 50% chance we randomly slice out a portion of the patient history and update the demographic
            # prompt depending on the new starting point
            if random.random() < 0.5:
                start_index, end_index, demographic_tokens = random_slice_gpt_sequence(
                    concept_ids, new_max_length
                )
                if start_index != end_index:
                    record["input_ids"] = self._convert_to_tensor(
                        record["input_ids"][start_index : end_index + 1]
                    )
                    if self.include_values:
                        record["value_indicators"] = self._convert_to_tensor(
                            record["value_indicators"][start_index : end_index + 1]
                        ).to(torch.bool)
                        record["values"] = self._convert_to_tensor(
                            record["values"][start_index : end_index + 1]
                        )
                    if self.include_ttv_prediction:
                        record["time_to_visits"] = self._convert_to_tensor(
                            self._convert_time_to_event(
                                concept_ids[start_index : end_index + 1]
                            )
                        )
                    return record

            # The default employs a right truncation strategy, where the demographic prompt is reserved
            end_index = new_max_length
            for i in reversed(list(range(0, end_index))):
                current_token = record["input_ids"][i]
                if current_token == self.ve_token_id:
                    end_index = i
                    break

            record["input_ids"] = record["input_ids"][0:end_index]
            if self.include_values:
                record["value_indicators"] = self._convert_to_tensor(
                    record["value_indicators"][0:end_index]
                ).to(torch.bool)
                record["values"] = self._convert_to_tensor(
                    record["values"][0:end_index]
                )
            if self.include_ttv_prediction:
                record["time_to_visits"] = self._convert_to_tensor(
                    self._convert_time_to_event(concept_ids[0:end_index])
                )
            return record
        else:
            if self.include_demographics:
                # We employ a left truncation strategy, where the most recent patient history is reserved for fine-tuning
                demographic_prompts_at_visits = collect_demographic_prompts_at_visits(
                    concept_ids
                )
                for token_index, demographic_prompt in demographic_prompts_at_visits:
                    if (
                        seq_length - token_index
                        <= new_max_length - DEMOGRAPHIC_PROMPT_SIZE
                    ):
                        demographic_tokens = self.tokenizer.encode(demographic_prompt)
                        record["input_ids"] = torch.concat(
                            [
                                self._convert_to_tensor(demographic_tokens),
                                self._convert_to_tensor(
                                    record["input_ids"][token_index:seq_length]
                                ),
                            ]
                        )
                        if self.include_values:
                            record["value_indicators"] = torch.concat(
                                [
                                    torch.zeros(
                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
                                    ).to(torch.bool),
                                    self._convert_to_tensor(
                                        record["value_indicators"][
                                            token_index:seq_length
                                        ]
                                    ),
                                ]
                            )
                            record["values"] = torch.concat(
                                [
                                    torch.zeros(
                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
                                    )
                                    .to(torch.int32)
                                    .fill_(self.tokenizer.pad_value_token_id),
                                    self._convert_to_tensor(
                                        record["values"][token_index:seq_length]
                                    ),
                                ]
                            )
                        if self.include_ttv_prediction:
                            record["time_to_visits"] = torch.concat(
                                [
                                    torch.zeros(
                                        [DEMOGRAPHIC_PROMPT_SIZE], dtype=torch.int32
                                    )
                                    .to(torch.float32)
                                    .fill_(-100.0),
                                    record["time_to_visits"][token_index:seq_length],
                                ]
                            )
                        break
            else:
                start_index = seq_length - new_max_length
                end_index = seq_length
                for i in range(start_index, end_index):
                    current_token = record["input_ids"][i]
                    if current_token == self.vs_token_id:
                        record["input_ids"] = record["input_ids"][i:end_index]
                        if self.include_values:
                            record["value_indicators"] = record["value_indicators"][
                                i:end_index
                            ]
                            record["values"] = record["values"][i:end_index]
                        if self.include_ttv_prediction:
                            record["time_to_visits"] = record["time_to_visits"][
                                i:end_index
                            ]
                        break

            # This could happen when the last visit contains more than new_max_length number of tokens
            # We simply take the last new_max_length number of tokens from the patient sequence
            if len(record["input_ids"]) > new_max_length:
                record["input_ids"] = record["input_ids"][-new_max_length:]
                if self.include_values:
                    record["value_indicators"] = record["value_indicators"][
                        -new_max_length:
                    ]
                    record["values"] = record["values"][-new_max_length:]
                if self.include_ttv_prediction:
                    record["time_to_visits"] = record["time_to_visits"][
                        -new_max_length:
                    ]

            # Finally we add the end token to the end of the sequence
            record["input_ids"] = torch.concat(
                [
                    self._convert_to_tensor(record["input_ids"]),
                    self._convert_to_tensor([self.tokenizer.end_token_id]),
                ]
            )
            if self.include_values:
                record["value_indicators"] = torch.concat(
                    [
                        self._convert_to_tensor(record["value_indicators"]),
                        self._convert_to_tensor([False]),
                    ]
                ).to(torch.bool)
                record["values"] = torch.concat(
                    [
                        self._convert_to_tensor(record["values"]),
                        self._convert_to_tensor([self.tokenizer.pad_value_token_id]),
                    ]
                )
            if self.include_ttv_prediction:
                record["time_to_visits"] = torch.concat(
                    [
                        record["time_to_visits"],
                        self._convert_to_tensor([-100.0]),
                    ]
                )
            return record
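For orientation, below is a minimal sketch of wiring this collator into a PyTorch DataLoader. It is an illustration rather than the package's documented entry point: tokenizer_dir and tokenized_dataset are hypothetical placeholders, and it assumes CehrGptTokenizer exposes a Hugging Face style from_pretrained.

# Hypothetical usage sketch; tokenizer_dir and tokenized_dataset are placeholders.
from torch.utils.data import DataLoader

from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer

tokenizer = CehrGptTokenizer.from_pretrained(tokenizer_dir)
collator = CehrGptDataCollator(
    tokenizer=tokenizer,
    max_length=512,    # padded batches are asserted to fit within this bound
    pretraining=True,  # also emits "labels" padded with -100 for the LM loss
)
loader = DataLoader(tokenized_dataset, batch_size=8, collate_fn=collator)
batch = next(iter(loader))  # dict with input_ids, attention_mask, labels, ...

Note the _try_reverse_tensor trick in the collator: when pretraining is False, each sequence is flipped before pad_sequence and flipped back afterwards, which amounts to left-padding so the most recent events align at the end of every row.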
cehrgpt/data/hf_cehrgpt_dataset_mapping.py
@@ -0,0 +1,116 @@
import datetime
from typing import Any, Dict

import numpy as np
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import DatasetMapping

from cehrgpt.models.tokenization_hf_cehrgpt import (
    NONE_BIN,
    UNKNOWN_BIN,
    CehrGptTokenizer,
)


def convert_date_to_posix_time(index_date: datetime.date) -> float:
    return datetime.datetime.combine(
        index_date, datetime.datetime.min.time()
    ).timestamp()


class HFCehrGptTokenizationMapping(DatasetMapping):
    def __init__(
        self,
        concept_tokenizer: CehrGptTokenizer,
    ):
        self._concept_tokenizer = concept_tokenizer
        self._lab_token_ids = self._concept_tokenizer.lab_token_ids

    def remove_columns(self):
        return [
            "concept_value_masks",
            "is_numeric_types",
        ]

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        # If any concept has a value associated with it, we normalize the value
        record["input_ids"] = self._concept_tokenizer.encode(record["concept_ids"])
        record["value_indicators"] = record["concept_value_masks"]
        if "number_as_values" not in record or "concept_as_values" not in record:
            record["number_as_values"] = [
                float(value) if isinstance(value, float) else None
                for value in record["concept_values"]
            ]
            record["is_numeric_types"] = [
                int(isinstance(value, float)) for value in record["concept_values"]
            ]
            record["concept_as_values"] = [
                value if isinstance(value, str) else None
                for value in record["concept_values"]
            ]
        if np.any(np.asarray(record["concept_value_masks"]) > 0):
            values = []
            for i, (
                concept_id,
                unit,
                concept_value_mask,
                number_as_value,
                concept_as_value,
                is_numeric_type,
            ) in enumerate(
                zip(
                    record["concept_ids"],
                    record["units"],
                    record["concept_value_masks"],
                    record["number_as_values"],
                    record["concept_as_values"],
                    record["is_numeric_types"],
                )
            ):
                if concept_value_mask == 1:
                    value = UNKNOWN_BIN
                    if is_numeric_type == 1:
                        if concept_id in self._concept_tokenizer.numeric_concept_ids:
                            value = self._concept_tokenizer.normalize(
                                concept_id, unit, number_as_value
                            )
                    elif isinstance(concept_as_value, str):
                        value = concept_as_value
                    values.append(value)
                else:
                    values.append(NONE_BIN)
            assert len(values) == len(record["input_ids"])
            record["values"] = self._concept_tokenizer.encode_value(values)
        else:
            record["values"] = self._concept_tokenizer.encode_value(
                [NONE_BIN for _ in range(len(record["concept_value_masks"]))]
            )
        # Delete these features because they contain null values and pyarrow cannot concatenate multiple records
        del record["number_as_values"]
        del record["concept_as_values"]
        return record


class HFFineTuningMapping(HFCehrGptTokenizationMapping):
    """Consider removing this transformation in the future."""

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        record = super().transform(record)
        record.update(
            {
                "age_at_index": (
                    record["age"] if "age" in record else record["age_at_index"]
                ),
                "classifier_label": int(record["label"] > 0),
                "index_date": (
                    convert_date_to_posix_time(record["index_date"])
                    if "index_date" in record
                    else None
                ),
            }
        )
        return record

    def remove_columns(self):
        columns = super().remove_columns()
        columns.append("label")
        return columns