genhpf-1.0.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/datasets/genhpf_dataset.py
@@ -0,0 +1,451 @@
import logging
from typing import List, Union

import h5pickle
import numpy as np
import pandas as pd
import torch

from genhpf.datasets.dataset import BaseDataset

logger = logging.getLogger(__name__)


class GenHPFDataset(BaseDataset):
    def __init__(
        self,
        manifest_paths: List[str],
        structure: str,
        vocab_size: int = 28996,
        pad_token_id: int = 0,
        sep_token_id: int = 102,
        ignore_index: int = -100,
        apply_mask: bool = False,
        mask_token_id: int = 103,
        mask_prob: float = 0,
        mask_unit: str = "individual",
        simclr: bool = False,
        **kwargs,
    ):
        super().__init__()

        if structure == "hierarchical":
            structure = "hi"
        elif structure == "flattened":
            structure = "fl"
        self.structure = structure

        self.pad_token_id = pad_token_id

        self.ignore_index = ignore_index

        self.apply_mask = apply_mask
        self.mask_prob = mask_prob
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.mask_unit = mask_unit
        self.sep_token_id = sep_token_id

        self.simclr = simclr

        for k, v in kwargs.items():
            self.__setattr__(k, v)

        self.data = []
        self.subjects = []
        self.labels = {}
        for i, manifest_path in enumerate(manifest_paths):
            with open(manifest_path, "r") as f:
                data_root = f.readline().strip()
                label_root = f.readline().strip()
                self.data.append(h5pickle.File(data_root, "r")["ehr"])
                labels = pd.read_csv(label_root)
                labels.index = [(i, str(x)) for x in labels["stay_id"]]
                labels = labels.drop(columns=["stay_id"])
                self.labels.update(labels.to_dict(orient="index"))
                for line in f:
                    items = line.strip().split("\t")
                    assert len(items) == 1, line
                    self.subjects.append((i, items[0]))
        logger.info(f"loaded {len(self.subjects)} samples from {len(manifest_paths)} dataset(s)")

    def __len__(self):
        return len(self.subjects)

    def __getitem__(self, index):
        raise NotImplementedError

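The manifest layout consumed by `__init__` above is implied by the parser rather than documented; a minimal sketch, with hypothetical paths and stay ids:

    /data/mimiciii/train.h5
    /data/mimiciii/train_labels.csv
    200001
    200003
    200007

The first line names the HDF5 file whose "ehr" group holds the tokenized samples, the second a CSV with a stay_id column plus one column per task; each remaining line is a single stay id (the loader asserts each of these lines holds exactly one tab-separated field).
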
class HierarchicalGenHPFDataset(GenHPFDataset):
    def __init__(
        self,
        manifest_paths: List[str],
        label: bool = False,
        tasks: List[str] = None,
        num_labels: List[int] = None,
        dummy_token_id: int = 101,
        **kwargs,
    ):
        kwargs.pop("structure", None)
        super().__init__(manifest_paths=manifest_paths, structure="hierarchical", **kwargs)

        self.label = label
        self.tasks = tasks
        self.num_labels = num_labels
        self.dummy_token_id = dummy_token_id

    def mask(self, tokens: Union[np.ndarray, torch.Tensor], **kwargs):
        for i, event in enumerate(tokens):
            tokens[i], _ = super().mask(event, **kwargs)
        return tokens

    def collator(self, samples):
        samples = [s for s in samples if s["input_ids"] is not None]
        if len(samples) == 0:
            return {}

        if self.simclr:
            input_ids = sum(
                [
                    [s["input_ids"][: len(s["input_ids"]) // 2], s["input_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            type_ids = sum(
                [
                    [s["type_ids"][: len(s["input_ids"]) // 2], s["type_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            dpe_ids = sum(
                [
                    [s["dpe_ids"][: len(s["input_ids"]) // 2], s["dpe_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
        else:
            input_ids = [s["input_ids"] for s in samples]
            type_ids = [s["type_ids"] for s in samples]
            dpe_ids = [s["dpe_ids"] for s in samples]

        sizes = [s.size(0) for s in input_ids]
        target_size = max(sizes)

        collated_input_ids = (
            input_ids[0].new_zeros((len(input_ids), target_size, len(input_ids[0][0]))).long()
        )
        collated_type_ids = type_ids[0].new_zeros((len(type_ids), target_size, len(type_ids[0][0]))).long()
        collated_dpe_ids = dpe_ids[0].new_zeros((len(dpe_ids), target_size, len(dpe_ids[0][0]))).long()
        for i, size in enumerate(sizes):
            diff = size - target_size
            if diff == 0:
                collated_input_ids[i] = input_ids[i]
                collated_type_ids[i] = type_ids[i]
                collated_dpe_ids[i] = dpe_ids[i]
            elif diff < 0:
                collated_input_ids[i] = torch.cat(
                    [
                        input_ids[i],
                        input_ids[i].new_zeros(-diff, len(input_ids[i][0])),
                    ],
                    dim=0,
                )
                # add a dummy token at the start of each padded event, as the event encoder
                # can break down when all of its input tokens are pad tokens
                collated_input_ids[i][diff:, 0] = self.dummy_token_id
                collated_type_ids[i] = torch.cat(
                    [
                        type_ids[i],
                        type_ids[i].new_zeros(-diff, len(type_ids[i][0])),
                    ],
                    dim=0,
                )
                collated_dpe_ids[i] = torch.cat(
                    [
                        dpe_ids[i],
                        dpe_ids[i].new_zeros(-diff, len(dpe_ids[i][0])),
                    ],
                    dim=0,
                )
            else:
                raise ValueError(f"size mismatch, expected <={target_size}, got {size}")

        out = {"id": [s["id"] for s in samples]}
        out["net_input"] = {
            "input_ids": collated_input_ids,
            "type_ids": collated_type_ids,
            "dpe_ids": collated_dpe_ids,
        }

        if self.label:
            label = {}
            for task in self.tasks:
                label[task] = torch.stack([s[task] for s in samples])
            out["label"] = label

        return out

    def __getitem__(self, index):
        data_index, subject = self.subjects[index]
        data = self.data[data_index][subject][self.structure][:]

        if self.apply_mask:
            data = self.mask(
                data,
                mask_prob=self.mask_prob,
                vocab_size=self.vocab_size,
                mask_token_id=self.mask_token_id,
                mask_unit=self.mask_unit,
                sep_token_id=self.sep_token_id,
            )

        ret = {
            "id": subject,
            "input_ids": torch.LongTensor(data[:, 0, :]),
            "type_ids": torch.LongTensor(data[:, 1, :]),
            "dpe_ids": torch.LongTensor(data[:, 2, :]),
        }

        if self.label:
            for i, task in enumerate(self.tasks):
                ret[task] = self.labels[self.subjects[index]][task]
                if isinstance(ret[task], str):
                    ret[task] = eval(ret[task])
                # for multi-label classification, where the label is given by a list of class indices
                if isinstance(ret[task], list):
                    ret[task] = list(map(int, ret[task]))
                    num_label = self.num_labels[i]
                    label = np.zeros(num_label, dtype=np.int16)
                    label[ret[task]] = 1
                    ret[task] = torch.tensor(label)
                else:
                    if np.isnan(ret[task]) or ret[task] < 0:
                        ret[task] = self.ignore_index
                    ret[task] = torch.tensor(ret[task])
        return ret

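A minimal usage sketch for the class above (the manifest path and task name are hypothetical; only the constructor arguments and collator shown in this file are assumed):

    from torch.utils.data import DataLoader

    dataset = HierarchicalGenHPFDataset(
        manifest_paths=["/data/mimiciii/train.tsv"],
        label=True,
        tasks=["mortality"],
        num_labels=[2],
    )
    # the collator pads every sample to the longest event sequence in the batch, so
    # net_input["input_ids"] comes out as (batch, max_events_in_batch, tokens_per_event)
    loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collator)
    batch = next(iter(loader))
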
class FlattenedGenHPFDataset(GenHPFDataset):
    def __init__(
        self,
        manifest_paths: List[str],
        label: bool = False,
        tasks: List[str] = None,
        num_labels: List[int] = None,
        **kwargs,
    ):
        kwargs.pop("structure", None)
        super().__init__(manifest_paths=manifest_paths, structure="flattened", **kwargs)

        self.label = label
        self.tasks = tasks
        self.num_labels = num_labels

    def sample_crop_indices(self, size, diff):
        # note: the published code tests `if self.mask:`, which is always true because
        # `mask` is a bound method; `self.apply_mask` is presumably what was intended
        if self.apply_mask:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        else:
            start = 0
            end = size - diff
        return start, end

    def pad_to_max_size(self, sample, max_len):
        if len(sample) < max_len:
            sample = np.concatenate([sample, np.zeros(max_len - len(sample), dtype=np.int16)])
        else:
            sample = sample[:max_len]
        return sample

    def collator(self, samples):
        samples = [s for s in samples if s["input_ids"] is not None]
        if len(samples) == 0:
            return {}

        if self.simclr:
            input_ids = sum(
                [
                    [s["input_ids"][: len(s["input_ids"]) // 2], s["input_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            type_ids = sum(
                [
                    [s["type_ids"][: len(s["input_ids"]) // 2], s["type_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            dpe_ids = sum(
                [
                    [s["dpe_ids"][: len(s["input_ids"]) // 2], s["dpe_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            if self.apply_mask:
                input_label = sum(
                    [
                        [
                            s["input_label"][: len(s["input_ids"]) // 2],
                            s["input_label"][len(s["input_ids"]) // 2 :],
                        ]
                        for s in samples
                    ],
                    [],
                )
                # note: the published code builds only input_label in this branch, but
                # type_label and dpe_label are referenced below whenever apply_mask is set;
                # the analogous splits are added here to avoid a NameError
                type_label = sum(
                    [
                        [
                            s["type_label"][: len(s["input_ids"]) // 2],
                            s["type_label"][len(s["input_ids"]) // 2 :],
                        ]
                        for s in samples
                    ],
                    [],
                )
                dpe_label = sum(
                    [
                        [
                            s["dpe_label"][: len(s["input_ids"]) // 2],
                            s["dpe_label"][len(s["input_ids"]) // 2 :],
                        ]
                        for s in samples
                    ],
                    [],
                )
        else:
            input_ids = [s["input_ids"] for s in samples]
            type_ids = [s["type_ids"] for s in samples]
            dpe_ids = [s["dpe_ids"] for s in samples]
            if self.apply_mask:
                input_label = [s["input_label"] for s in samples]
                type_label = [s["type_label"] for s in samples]
                dpe_label = [s["dpe_label"] for s in samples]

        sizes = [s.size(0) for s in input_ids]
        target_size = max(sizes)

        collated_input_ids = input_ids[0].new_zeros((len(input_ids), target_size)).long()
        collated_type_ids = type_ids[0].new_zeros((len(type_ids), target_size)).long()
        collated_dpe_ids = dpe_ids[0].new_zeros((len(dpe_ids), target_size)).long()
        if self.apply_mask:
            collated_input_label = input_label[0].new_zeros((len(input_label), target_size)).long()
            collated_type_label = type_label[0].new_zeros((len(type_label), target_size)).long()
            collated_dpe_label = dpe_label[0].new_zeros((len(dpe_label), target_size)).long()
        for i, size in enumerate(sizes):
            diff = size - target_size
            if diff == 0:
                collated_input_ids[i] = input_ids[i]
                collated_type_ids[i] = type_ids[i]
                collated_dpe_ids[i] = dpe_ids[i]
                if self.apply_mask:
                    collated_input_label[i] = input_label[i]
                    collated_type_label[i] = type_label[i]
                    collated_dpe_label[i] = dpe_label[i]
            elif diff < 0:
                collated_input_ids[i] = torch.cat([input_ids[i], input_ids[i].new_zeros(-diff)], dim=0)
                collated_type_ids[i] = torch.cat([type_ids[i], type_ids[i].new_zeros(-diff)], dim=0)
                collated_dpe_ids[i] = torch.cat([dpe_ids[i], dpe_ids[i].new_zeros(-diff)], dim=0)
                if self.apply_mask:
                    collated_input_label[i] = torch.cat(
                        [input_label[i], input_label[i].new_zeros(-diff)], dim=0
                    )
                    collated_type_label[i] = torch.cat(
                        [type_label[i], type_label[i].new_zeros(-diff)], dim=0
                    )
                    collated_dpe_label[i] = torch.cat(
                        [dpe_label[i], dpe_label[i].new_zeros(-diff)], dim=0
                    )
            else:
                raise ValueError(f"size mismatch, expected <={target_size}, got {size}")

        out = {"id": [s["id"] for s in samples]}
        out["net_input"] = {
            "input_ids": collated_input_ids,
            "type_ids": collated_type_ids,
            "dpe_ids": collated_dpe_ids,
        }

        if self.apply_mask:
            out["input_label"] = collated_input_label
            out["type_label"] = collated_type_label
            out["dpe_label"] = collated_dpe_label

        if self.label:
            label = {}
            for task in self.tasks:
                label[task] = torch.stack([s[task] for s in samples])
            out["label"] = label

        return out

    def __getitem__(self, index):
        data_index, subject = self.subjects[index]
        data = self.data[data_index][subject][self.structure][:]

        if self.apply_mask:
            data, mlm_labels = self.mask(
                data,
                mask_prob=self.mask_prob,
                vocab_size=self.vocab_size,
                mask_token_id=self.mask_token_id,
                mask_unit=self.mask_unit,
                sep_token_id=self.sep_token_id,
            )

        ret = {
            "id": self.subjects[index],
            "input_ids": torch.LongTensor(data[0, :]),
            "type_ids": torch.LongTensor(data[1, :]),
            "dpe_ids": torch.LongTensor(data[2, :]),
        }
        if self.apply_mask:
            ret["input_label"] = torch.LongTensor(mlm_labels[0, :])
            ret["type_label"] = torch.LongTensor(mlm_labels[1, :])
            ret["dpe_label"] = torch.LongTensor(mlm_labels[2, :])

        if self.label:
            for i, task in enumerate(self.tasks):
                ret[task] = self.labels[self.subjects[index]][task]
                if isinstance(ret[task], str):
                    ret[task] = eval(ret[task])
                # for multi-label classification, where the label is given by a list of class indices
                if isinstance(ret[task], list):
                    ret[task] = list(map(int, ret[task]))
                    num_label = self.num_labels[i]
                    label = np.zeros(num_label, dtype=np.int16)
                    label[ret[task]] = 1
                    ret[task] = torch.tensor(label)
                else:
                    if np.isnan(ret[task]) or ret[task] < 0:
                        ret[task] = self.ignore_index
                    ret[task] = torch.tensor(ret[task])

        return ret

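When apply_mask is set, the flattened collator also returns input_label, type_label, and dpe_label for masked-language-model pretraining. A sketch of how those labels could be consumed, assuming BaseDataset.mask fills unmasked positions with ignore_index (the model call is hypothetical; the actual loss implementations live under genhpf/criterions):

    import torch.nn.functional as F

    # hypothetical model producing token logits of shape (batch, seq_len, vocab_size)
    logits = model(**batch["net_input"])
    loss = F.cross_entropy(
        logits.transpose(1, 2),   # cross_entropy expects (batch, num_classes, seq_len)
        batch["input_label"],     # masked positions presumably hold the original token ids
        ignore_index=-100,        # the dataset's default ignore_index
    )
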
genhpf/datasets/meds_dataset.py
@@ -0,0 +1,232 @@
import logging
import os
from typing import List, Union

import h5pickle
import numpy as np
import torch

from genhpf.datasets.dataset import BaseDataset

logger = logging.getLogger(__name__)


class MEDSDataset(BaseDataset):
    def __init__(
        self,
        manifest_paths: List[str],
        structure: str = "hierarchical",
        vocab_size: int = 28996,
        pad_token_id: int = 0,
        sep_token_id: int = 102,
        ignore_index: int = -100,
        apply_mask: bool = False,
        mask_token_id: int = 103,
        mask_prob: float = 0,
        mask_unit: str = "individual",
        simclr: bool = False,
        debug: bool = False,
        **kwargs,
    ):
        super().__init__()

        if structure == "hierarchical":
            structure = "hi"
        elif structure == "flattened":
            raise NotImplementedError("Flattened structure is not supported yet.")
        self.structure = structure

        self.pad_token_id = pad_token_id

        self.ignore_index = ignore_index

        self.apply_mask = apply_mask
        self.mask_prob = mask_prob
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.mask_unit = mask_unit
        self.sep_token_id = sep_token_id

        self.simclr = simclr

        for k, v in kwargs.items():
            self.__setattr__(k, v)

        self.data = []
        self.subjects = []
        self.shard_ids = []
        self.sizes = []
        for i, manifest_path in enumerate(manifest_paths):
            with open(manifest_path, "r") as f:
                data_i_root = f.readline().strip()
                shard_ids = []
                for j, line in enumerate(f):
                    if debug and j >= 300:
                        break
                    items = line.strip().split("\t")
                    assert len(items) == 3, line
                    subject_id, num_events, shard_id = items
                    self.subjects.append((i, subject_id))
                    self.sizes.append(int(num_events))
                    shard_ids.append(int(shard_id))

            data_i = {}
            unique_shard_ids = np.unique(shard_ids)
            for shard_id in unique_shard_ids:
                data_i[shard_id] = h5pickle.File(os.path.join(data_i_root, f"{shard_id}.h5"))["ehr"]
            self.data.append(data_i)
            self.shard_ids.extend(shard_ids)
        logger.info(f"loaded {len(self.subjects)} samples from {len(manifest_paths)} dataset(s)")

    def __len__(self):
        return len(self.subjects)

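As with the GenHPF manifests, the MEDS manifest format is implied by the parser; a sketch with hypothetical values (fields are tab-separated):

    /data/meds/train
    subject_00017    842    0
    subject_00042    57     3

The first line is the directory holding per-shard HDF5 files named {shard_id}.h5; each remaining line gives a subject id, its event count, and the shard that contains it.
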
class HierarchicalMEDSDataset(MEDSDataset):
    def __init__(
        self,
        manifest_paths: List[str],
        max_events: int = 256,
        label: bool = False,
        tasks: List[str] = None,
        num_labels: List[int] = None,
        dummy_token_id: int = 101,
        **kwargs,
    ):
        kwargs.pop("structure", None)
        super().__init__(manifest_paths=manifest_paths, structure="hierarchical", **kwargs)

        self.max_events = max_events
        self.label = label
        self.tasks = tasks
        self.num_labels = num_labels
        self.dummy_token_id = dummy_token_id

    def mask(self, tokens: Union[np.ndarray, torch.Tensor], **kwargs):
        for i, event in enumerate(tokens):
            tokens[i], _ = super().mask(event, **kwargs)
        return tokens

    def collator(self, samples):
        samples = [s for s in samples if s["input_ids"] is not None]
        if len(samples) == 0:
            return {}

        if self.simclr:
            input_ids = sum(
                [
                    [s["input_ids"][: len(s["input_ids"]) // 2], s["input_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            type_ids = sum(
                [
                    [s["type_ids"][: len(s["input_ids"]) // 2], s["type_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
            dpe_ids = sum(
                [
                    [s["dpe_ids"][: len(s["input_ids"]) // 2], s["dpe_ids"][len(s["input_ids"]) // 2 :]]
                    for s in samples
                ],
                [],
            )
        else:
            input_ids = [s["input_ids"] for s in samples]
            type_ids = [s["type_ids"] for s in samples]
            dpe_ids = [s["dpe_ids"] for s in samples]

        sizes = [s.size(0) for s in input_ids]
        target_size = self.max_events

        collated_input_ids = (
            input_ids[0].new_zeros((len(input_ids), target_size, len(input_ids[0][0]))).long()
        )
        collated_type_ids = type_ids[0].new_zeros((len(type_ids), target_size, len(type_ids[0][0]))).long()
        collated_dpe_ids = dpe_ids[0].new_zeros((len(dpe_ids), target_size, len(dpe_ids[0][0]))).long()
        for i, size in enumerate(sizes):
            diff = size - target_size
            if diff == 0:
                collated_input_ids[i] = input_ids[i]
                collated_type_ids[i] = type_ids[i]
                collated_dpe_ids[i] = dpe_ids[i]
            elif diff < 0:
                collated_input_ids[i] = torch.cat(
                    [
                        input_ids[i],
                        input_ids[i].new_zeros(-diff, len(input_ids[i][0])),
                    ],
                    dim=0,
                )
                # add a dummy token at the start of each padded event, as the event encoder
                # can break down when all of its input tokens are pad tokens
                collated_input_ids[i][diff:, 0] = self.dummy_token_id
                collated_type_ids[i] = torch.cat(
                    [
                        type_ids[i],
                        type_ids[i].new_zeros(-diff, len(type_ids[i][0])),
                    ],
                    dim=0,
                )
                collated_dpe_ids[i] = torch.cat(
                    [
                        dpe_ids[i],
                        dpe_ids[i].new_zeros(-diff, len(dpe_ids[i][0])),
                    ],
                    dim=0,
                )
            else:
                # keep the most recent target_size events when a stay exceeds max_events
                collated_input_ids[i] = input_ids[i][-target_size:]
                collated_type_ids[i] = type_ids[i][-target_size:]
                collated_dpe_ids[i] = dpe_ids[i][-target_size:]

        out = {"id": [s["id"] for s in samples]}
        out["net_input"] = {
            "input_ids": collated_input_ids,
            "type_ids": collated_type_ids,
            "dpe_ids": collated_dpe_ids,
        }

        if self.label:
            label = {}
            for task in self.tasks:
                if len(samples[0][task]) == 1:
                    label[task] = torch.cat([s[task] for s in samples])
                else:
                    label[task] = torch.stack([s[task] for s in samples])
            out["label"] = label

        return out

    def __getitem__(self, idx):
        data_idx, subject = self.subjects[idx]
        data = self.data[data_idx][self.shard_ids[idx]][subject]

        tokens = data[self.structure][:]
        if self.apply_mask:
            tokens = self.mask(
                tokens,
                mask_prob=self.mask_prob,
                vocab_size=self.vocab_size,
                mask_token_id=self.mask_token_id,
                mask_unit=self.mask_unit,
                sep_token_id=self.sep_token_id,
            )

        ret = {
            "id": subject,
            "input_ids": torch.LongTensor(tokens[:, 0, :]),
            "type_ids": torch.LongTensor(tokens[:, 1, :]),
            "dpe_ids": torch.LongTensor(tokens[:, 2, :]),
        }

        if self.label:
            for i, task in enumerate(self.tasks):
                try:
                    ret[task] = torch.LongTensor(data["label"][i])
                except ValueError:
                    ret[task] = torch.LongTensor([data["label"][()]])
        return ret

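Unlike the GenHPF collator, which pads to the longest sequence in the batch, this one pads or truncates every sample to a fixed max_events, keeping the most recent events of an over-long stay. A minimal usage sketch (manifest path and task name hypothetical):

    from torch.utils.data import DataLoader

    dataset = HierarchicalMEDSDataset(
        manifest_paths=["/data/meds/train.tsv"],
        max_events=256,
        label=True,
        tasks=["boolean_value"],
        num_labels=[2],
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collator)
    batch = next(iter(loader))  # net_input["input_ids"]: (8, 256, tokens_per_event)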