genhpf 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
@@ -0,0 +1,451 @@
+import logging
+from typing import List, Union
+
+import h5pickle
+import numpy as np
+import pandas as pd
+import torch
+
+from genhpf.datasets.dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class GenHPFDataset(BaseDataset):
+    def __init__(
+        self,
+        manifest_paths: List[str],
+        structure: str,
+        vocab_size: int = 28996,
+        pad_token_id: int = 0,
+        sep_token_id: int = 102,
+        ignore_index: int = -100,
+        apply_mask: bool = False,
+        mask_token_id: int = 103,
+        mask_prob: float = 0,
+        mask_unit: str = "individual",
+        simclr: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if structure == "hierarchical":
+            structure = "hi"
+        elif structure == "flattened":
+            structure = "fl"
+        self.structure = structure
+
+        self.pad_token_id = pad_token_id
+
+        self.ignore_index = ignore_index
+
+        self.apply_mask = apply_mask
+        self.mask_prob = mask_prob
+        self.vocab_size = vocab_size
+        self.mask_token_id = mask_token_id
+        self.mask_unit = mask_unit
+        self.sep_token_id = sep_token_id
+
+        self.simclr = simclr
+
+        for k, v in kwargs.items():
+            self.__setattr__(k, v)
+
+        self.data = []
+        self.subjects = []
+        self.labels = {}
+        for i, manifest_path in enumerate(manifest_paths):
+            with open(manifest_path, "r") as f:
+                data_root = f.readline().strip()
+                label_root = f.readline().strip()
+                self.data.append(h5pickle.File(data_root, "r")["ehr"])
+                labels = pd.read_csv(label_root)
+                labels.index = [(i, str(x)) for x in labels["stay_id"]]
+                labels = labels.drop(columns=["stay_id"])
+                self.labels.update(labels.to_dict(orient="index"))
+                for line in f:
+                    items = line.strip().split("\t")
+                    assert len(items) == 1, line
+                    self.subjects.append((i, items[0]))
+        logger.info(f"loaded {len(self.subjects)} samples from {len(manifest_paths)} dataset(s)")
+
+    def __len__(self):
+        return len(self.subjects)
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+
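For orientation, the loop above implies a simple manifest layout: the first line gives the HDF5 data path, the second the label CSV path, and each remaining line one stay_id. A minimal sketch, assuming hypothetical file names not shipped with the package:

    # a hypothetical manifest, as consumed by GenHPFDataset.__init__ above
    manifest = """\
    /data/mimiciii/ehr.h5
    /data/mimiciii/labels.csv
    200001
    200010
    """
    # line 1 -> h5pickle.File(...)["ehr"]; line 2 -> pd.read_csv(...);
    # remaining lines -> one stay_id each, stored as (dataset_index, stay_id)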
+class HierarchicalGenHPFDataset(GenHPFDataset):
+    def __init__(
+        self,
+        manifest_paths: List[str],
+        label: bool = False,
+        tasks: List[str] = None,
+        num_labels: List[int] = None,
+        dummy_token_id: int = 101,
+        **kwargs,
+    ):
+        kwargs.pop("structure", None)
+        super().__init__(manifest_paths=manifest_paths, structure="hierarchical", **kwargs)
+
+        self.label = label
+        self.tasks = tasks
+        self.num_labels = num_labels
+        self.dummy_token_id = dummy_token_id
+
+    def mask(self, tokens: Union[np.ndarray, torch.Tensor], **kwargs):
+        # mask each event row independently using the base-class masking routine
+        for i, event in enumerate(tokens):
+            tokens[i], _ = super().mask(event, **kwargs)
+        return tokens
+
+    def collator(self, samples):
+        samples = [s for s in samples if s["input_ids"] is not None]
+        if len(samples) == 0:
+            return {}
+
+        if self.simclr:
+            # for SimCLR, split each sample into two halves to form the two views
+            input_ids = sum(
+                [
+                    [s["input_ids"][: len(s["input_ids"]) // 2], s["input_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            type_ids = sum(
+                [
+                    [s["type_ids"][: len(s["input_ids"]) // 2], s["type_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            dpe_ids = sum(
+                [
+                    [s["dpe_ids"][: len(s["input_ids"]) // 2], s["dpe_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+        else:
+            input_ids = [s["input_ids"] for s in samples]
+            type_ids = [s["type_ids"] for s in samples]
+            dpe_ids = [s["dpe_ids"] for s in samples]
+
+        sizes = [s.size(0) for s in input_ids]
+        target_size = max(sizes)
+
+        collated_input_ids = (
+            input_ids[0].new_zeros((len(input_ids), target_size, len(input_ids[0][0]))).long()
+        )
+        collated_type_ids = type_ids[0].new_zeros((len(type_ids), target_size, len(type_ids[0][0]))).long()
+        collated_dpe_ids = dpe_ids[0].new_zeros((len(dpe_ids), target_size, len(dpe_ids[0][0]))).long()
+        for i, size in enumerate(sizes):
+            diff = size - target_size
+            if diff == 0:
+                collated_input_ids[i] = input_ids[i]
+                collated_type_ids[i] = type_ids[i]
+                collated_dpe_ids[i] = dpe_ids[i]
+            elif diff < 0:
+                collated_input_ids[i] = torch.cat(
+                    [
+                        input_ids[i],
+                        input_ids[i].new_zeros(-diff, len(input_ids[i][0])),
+                    ],
+                    dim=0,
+                )
+                # add a dummy token at the start of each padded event, as the event
+                # encoder can break down when all the input tokens are pad tokens
+                collated_input_ids[i][diff:, 0] = self.dummy_token_id
+                collated_type_ids[i] = torch.cat(
+                    [
+                        type_ids[i],
+                        type_ids[i].new_zeros(-diff, len(type_ids[i][0])),
+                    ],
+                    dim=0,
+                )
+                collated_dpe_ids[i] = torch.cat(
+                    [
+                        dpe_ids[i],
+                        dpe_ids[i].new_zeros(-diff, len(dpe_ids[i][0])),
+                    ],
+                    dim=0,
+                )
+            else:
+                raise ValueError(f"size mismatch, expected <={target_size}, got {size}")
+
+        out = {"id": [s["id"] for s in samples]}
+        out["net_input"] = {
+            "input_ids": collated_input_ids,
+            "type_ids": collated_type_ids,
+            "dpe_ids": collated_dpe_ids,
+        }
+
+        if self.label:
+            label = {}
+            for task in self.tasks:
+                label[task] = torch.stack([s[task] for s in samples])
+            out["label"] = label
+
+        return out
+
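To make the padding concrete, a minimal runnable sketch of the collator's behavior above, assuming two samples with 3 and 5 events of 128 tokens each (all values illustrative):

    import torch

    dummy_token_id = 101  # default used by HierarchicalGenHPFDataset
    input_ids = [torch.ones(3, 128, dtype=torch.long), torch.ones(5, 128, dtype=torch.long)]
    target_size = max(s.size(0) for s in input_ids)  # 5
    batch = input_ids[0].new_zeros(len(input_ids), target_size, 128)
    for i, s in enumerate(input_ids):
        batch[i, : s.size(0)] = s
        # padded events receive a dummy token at position 0 so the event encoder
        # never sees an all-pad event
        batch[i, s.size(0):, 0] = dummy_token_id
    print(batch.shape)  # torch.Size([2, 5, 128])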
+    def __getitem__(self, index):
+        data_index, subject = self.subjects[index]
+        data = self.data[data_index][subject][self.structure][:]
+
+        if self.apply_mask:
+            data = self.mask(
+                data,
+                mask_prob=self.mask_prob,
+                vocab_size=self.vocab_size,
+                mask_token_id=self.mask_token_id,
+                mask_unit=self.mask_unit,
+                sep_token_id=self.sep_token_id,
+            )
+
+        ret = {
+            "id": subject,
+            "input_ids": torch.LongTensor(data[:, 0, :]),
+            "type_ids": torch.LongTensor(data[:, 1, :]),
+            "dpe_ids": torch.LongTensor(data[:, 2, :]),
+        }
+
+        if self.label:
+            for i, task in enumerate(self.tasks):
+                ret[task] = self.labels[self.subjects[index]][task]
+                # labels stored as strings (e.g., "[0, 2]") are parsed back into Python objects
+                if isinstance(ret[task], str):
+                    ret[task] = eval(ret[task])
+                # for multi-label classification, where the label is given by a list of
+                # class indices, convert the list into a multi-hot vector
+                if isinstance(ret[task], list):
+                    ret[task] = list(map(int, ret[task]))
+                    num_label = self.num_labels[i]
+                    label = np.zeros(num_label, dtype=np.int16)
+                    label[ret[task]] = 1
+                    ret[task] = torch.tensor(label)
+                else:
+                    if np.isnan(ret[task]) or ret[task] < 0:
+                        ret[task] = self.ignore_index
+                    ret[task] = torch.tensor(ret[task])
+        return ret
+
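The label branch above turns a stringified list of class indices into a multi-hot vector; a small sketch (num_label = 5 is illustrative):

    import numpy as np
    import torch

    raw = "[0, 2]"  # as stored in the label CSV for a multi-label task
    indices = list(map(int, eval(raw)))  # the dataset parses the string with eval
    label = np.zeros(5, dtype=np.int16)
    label[indices] = 1
    print(torch.tensor(label))  # tensor([1, 0, 1, 0, 0], dtype=torch.int16)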
+
+class FlattenedGenHPFDataset(GenHPFDataset):
+    def __init__(
+        self,
+        manifest_paths: List[str],
+        label: bool = False,
+        tasks: List[str] = None,
+        num_labels: List[int] = None,
+        **kwargs,
+    ):
+        kwargs.pop("structure", None)
+        super().__init__(manifest_paths=manifest_paths, structure="flattened", **kwargs)
+
+        self.label = label
+        self.tasks = tasks
+        self.num_labels = num_labels
+
+    def sample_crop_indices(self, size, diff):
+        # note: the released code tests `self.mask` here, which is a bound method and
+        # therefore always truthy; `self.apply_mask` appears to be the intended flag
+        if self.apply_mask:
+            start = np.random.randint(0, diff + 1)
+            end = size - diff + start
+        else:
+            start = 0
+            end = size - diff
+        return start, end
+
+    def pad_to_max_size(self, sample, max_len):
+        if len(sample) < max_len:
+            sample = np.concatenate([sample, np.zeros(max_len - len(sample), dtype=np.int16)])
+        else:
+            sample = sample[:max_len]
+        return sample
+
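A short sketch of the cropping rule above: the crop window has a fixed length of size - diff, and only its starting offset is randomized (assuming the masking path is taken):

    import numpy as np

    size, diff = 10, 4
    start = np.random.randint(0, diff + 1)  # random offset in [0, diff]
    end = size - diff + start
    assert end - start == size - diff  # window length is always 6 here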
+    def collator(self, samples):
+        samples = [s for s in samples if s["input_ids"] is not None]
+        if len(samples) == 0:
+            return {}
+
+        if self.simclr:
+            # for SimCLR, split each sample into two halves to form the two views
+            input_ids = sum(
+                [
+                    [s["input_ids"][: len(s["input_ids"]) // 2], s["input_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            type_ids = sum(
+                [
+                    [s["type_ids"][: len(s["input_ids"]) // 2], s["type_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            dpe_ids = sum(
+                [
+                    [s["dpe_ids"][: len(s["input_ids"]) // 2], s["dpe_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            if self.apply_mask:
+                # note: only input_label is split here; type_label and dpe_label are not,
+                # so combining simclr with apply_mask would raise a NameError below
+                input_label = sum(
+                    [
+                        [
+                            s["input_label"][: len(s["input_ids"]) // 2],
+                            s["input_label"][len(s["input_ids"]) // 2 :],
+                        ]
+                        for s in samples
+                    ],
+                    [],
+                )
+        else:
+            input_ids = [s["input_ids"] for s in samples]
+            type_ids = [s["type_ids"] for s in samples]
+            dpe_ids = [s["dpe_ids"] for s in samples]
+            if self.apply_mask:
+                input_label = [s["input_label"] for s in samples]
+                type_label = [s["type_label"] for s in samples]
+                dpe_label = [s["dpe_label"] for s in samples]
+
+        sizes = [s.size(0) for s in input_ids]
+        target_size = max(sizes)
+
+        collated_input_ids = input_ids[0].new_zeros((len(input_ids), target_size)).long()
+        collated_type_ids = type_ids[0].new_zeros((len(type_ids), target_size)).long()
+        collated_dpe_ids = dpe_ids[0].new_zeros((len(dpe_ids), target_size)).long()
+        if self.apply_mask:
+            collated_input_label = input_label[0].new_zeros((len(input_label), target_size)).long()
+            collated_type_label = type_label[0].new_zeros((len(type_label), target_size)).long()
+            collated_dpe_label = dpe_label[0].new_zeros((len(dpe_label), target_size)).long()
+        for i, size in enumerate(sizes):
+            diff = size - target_size
+            if diff == 0:
+                collated_input_ids[i] = input_ids[i]
+                collated_type_ids[i] = type_ids[i]
+                collated_dpe_ids[i] = dpe_ids[i]
+                if self.apply_mask:
+                    collated_input_label[i] = input_label[i]
+                    collated_type_label[i] = type_label[i]
+                    collated_dpe_label[i] = dpe_label[i]
+            elif diff < 0:
+                collated_input_ids[i] = torch.cat([input_ids[i], input_ids[i].new_zeros(-diff)], dim=0)
+                collated_type_ids[i] = torch.cat([type_ids[i], type_ids[i].new_zeros(-diff)], dim=0)
+                collated_dpe_ids[i] = torch.cat([dpe_ids[i], dpe_ids[i].new_zeros(-diff)], dim=0)
+                if self.apply_mask:
+                    collated_input_label[i] = torch.cat(
+                        [input_label[i], input_label[i].new_zeros(-diff)], dim=0
+                    )
+                    collated_type_label[i] = torch.cat(
+                        [type_label[i], type_label[i].new_zeros(-diff)], dim=0
+                    )
+                    collated_dpe_label[i] = torch.cat(
+                        [dpe_label[i], dpe_label[i].new_zeros(-diff)], dim=0
+                    )
+            else:
+                raise ValueError(f"size mismatch, expected <={target_size}, got {size}")
+
+        out = {"id": [s["id"] for s in samples]}
+        out["net_input"] = {
+            "input_ids": collated_input_ids,
+            "type_ids": collated_type_ids,
+            "dpe_ids": collated_dpe_ids,
+        }
+
+        if self.apply_mask:
+            out["input_label"] = collated_input_label
+            out["type_label"] = collated_type_label
+            out["dpe_label"] = collated_dpe_label
+
+        if self.label:
+            label = {}
+            for task in self.tasks:
+                label[task] = torch.stack([s[task] for s in samples])
+            out["label"] = label
+
+        return out
+
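The collator is intended to be handed to a PyTorch DataLoader as collate_fn; a hypothetical setup (constructor arguments, paths, and task names are illustrative):

    from torch.utils.data import DataLoader

    dataset = FlattenedGenHPFDataset(
        manifest_paths=["/data/mimiciii/train.tsv"],  # illustrative path
        label=True,
        tasks=["mortality"],
        num_labels=[2],
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collator)
    batch = next(iter(loader))  # {"id": ..., "net_input": {...}, "label": {...}}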
+    def __getitem__(self, index):
+        data_index, subject = self.subjects[index]
+        data = self.data[data_index][subject][self.structure][:]
+
+        if self.apply_mask:
+            data, mlm_labels = self.mask(
+                data,
+                mask_prob=self.mask_prob,
+                vocab_size=self.vocab_size,
+                mask_token_id=self.mask_token_id,
+                mask_unit=self.mask_unit,
+                sep_token_id=self.sep_token_id,
+            )
+
+        ret = {
+            "id": self.subjects[index],
+            "input_ids": torch.LongTensor(data[0, :]),
+            "type_ids": torch.LongTensor(data[1, :]),
+            "dpe_ids": torch.LongTensor(data[2, :]),
+        }
+        if self.apply_mask:
+            ret["input_label"] = torch.LongTensor(mlm_labels[0, :])
+            ret["type_label"] = torch.LongTensor(mlm_labels[1, :])
+            ret["dpe_label"] = torch.LongTensor(mlm_labels[2, :])
+
+        if self.label:
+            for i, task in enumerate(self.tasks):
+                ret[task] = self.labels[self.subjects[index]][task]
+                # labels stored as strings (e.g., "[0, 2]") are parsed back into Python objects
+                if isinstance(ret[task], str):
+                    ret[task] = eval(ret[task])
+                # for multi-label classification, where the label is given by a list of
+                # class indices, convert the list into a multi-hot vector
+                if isinstance(ret[task], list):
+                    ret[task] = list(map(int, ret[task]))
+                    num_label = self.num_labels[i]
+                    label = np.zeros(num_label, dtype=np.int16)
+                    label[ret[task]] = 1
+                    ret[task] = torch.tensor(label)
+                else:
+                    if np.isnan(ret[task]) or ret[task] < 0:
+                        ret[task] = self.ignore_index
+                    ret[task] = torch.tensor(ret[task])
+
+        return ret
@@ -0,0 +1,232 @@
+import logging
+import os
+from typing import List, Union
+
+import h5pickle
+import numpy as np
+import torch
+
+from genhpf.datasets.dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class MEDSDataset(BaseDataset):
+    def __init__(
+        self,
+        manifest_paths: List[str],
+        structure: str = "hierarchical",
+        vocab_size: int = 28996,
+        pad_token_id: int = 0,
+        sep_token_id: int = 102,
+        ignore_index: int = -100,
+        apply_mask: bool = False,
+        mask_token_id: int = 103,
+        mask_prob: float = 0,
+        mask_unit: str = "individual",
+        simclr: bool = False,
+        debug: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if structure == "hierarchical":
+            structure = "hi"
+        elif structure == "flattened":
+            raise NotImplementedError("Flattened structure is not supported yet.")
+        self.structure = structure
+
+        self.pad_token_id = pad_token_id
+
+        self.ignore_index = ignore_index
+
+        self.apply_mask = apply_mask
+        self.mask_prob = mask_prob
+        self.vocab_size = vocab_size
+        self.mask_token_id = mask_token_id
+        self.mask_unit = mask_unit
+        self.sep_token_id = sep_token_id
+
+        self.simclr = simclr
+
+        for k, v in kwargs.items():
+            self.__setattr__(k, v)
+
+        self.data = []
+        self.subjects = []
+        self.shard_ids = []
+        self.sizes = []
+        for i, manifest_path in enumerate(manifest_paths):
+            with open(manifest_path, "r") as f:
+                data_i_root = f.readline().strip()
+                shard_ids = []
+                for j, line in enumerate(f):
+                    if debug and j >= 300:
+                        break
+                    items = line.strip().split("\t")
+                    assert len(items) == 3, line
+                    subject_id, num_events, shard_id = items
+                    self.subjects.append((i, subject_id))
+                    self.sizes.append(int(num_events))
+                    shard_ids.append(int(shard_id))
+
+            # open each referenced shard file once and keep a handle per shard id
+            data_i = {}
+            unique_shard_ids = np.unique(shard_ids)
+            for shard_id in unique_shard_ids:
+                data_i[shard_id] = h5pickle.File(os.path.join(data_i_root, f"{shard_id}.h5"))["ehr"]
+            self.data.append(data_i)
+            self.shard_ids.extend(shard_ids)
+        logger.info(f"loaded {len(self.subjects)} samples from {len(manifest_paths)} dataset(s)")
+
+    def __len__(self):
+        return len(self.subjects)
+
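The MEDS manifest differs from the GenHPF one: the first line is the directory holding the shard files, and each following line is tab-separated. A hypothetical example (names are illustrative):

    # a hypothetical MEDS manifest, as consumed by MEDSDataset.__init__ above;
    # fields per line: subject_id <TAB> num_events <TAB> shard_id
    manifest = """\
    /data/meds/train
    10001\t42\t0
    10002\t17\t0
    10003\t88\t1
    """
    # each shard id maps to a file {shard_id}.h5 under the directory on line 1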
+
+class HierarchicalMEDSDataset(MEDSDataset):
+    def __init__(
+        self,
+        manifest_paths: List[str],
+        max_events: int = 256,
+        label: bool = False,
+        tasks: List[str] = None,
+        num_labels: List[int] = None,
+        dummy_token_id: int = 101,
+        **kwargs,
+    ):
+        kwargs.pop("structure", None)
+        super().__init__(manifest_paths=manifest_paths, structure="hierarchical", **kwargs)
+
+        self.max_events = max_events
+        self.label = label
+        self.tasks = tasks
+        self.num_labels = num_labels
+        self.dummy_token_id = dummy_token_id
+
+    def mask(self, tokens: Union[np.ndarray, torch.Tensor], **kwargs):
+        # mask each event row independently using the base-class masking routine
+        for i, event in enumerate(tokens):
+            tokens[i], _ = super().mask(event, **kwargs)
+        return tokens
+
+    def collator(self, samples):
+        samples = [s for s in samples if s["input_ids"] is not None]
+        if len(samples) == 0:
+            return {}
+
+        if self.simclr:
+            # for SimCLR, split each sample into two halves to form the two views
+            input_ids = sum(
+                [
+                    [s["input_ids"][: len(s["input_ids"]) // 2], s["input_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            type_ids = sum(
+                [
+                    [s["type_ids"][: len(s["input_ids"]) // 2], s["type_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+            dpe_ids = sum(
+                [
+                    [s["dpe_ids"][: len(s["input_ids"]) // 2], s["dpe_ids"][len(s["input_ids"]) // 2 :]]
+                    for s in samples
+                ],
+                [],
+            )
+        else:
+            input_ids = [s["input_ids"] for s in samples]
+            type_ids = [s["type_ids"] for s in samples]
+            dpe_ids = [s["dpe_ids"] for s in samples]
+
+        sizes = [s.size(0) for s in input_ids]
+        # unlike the GenHPF collator, batches are always padded or truncated to a
+        # fixed number of events
+        target_size = self.max_events
+
+        collated_input_ids = (
+            input_ids[0].new_zeros((len(input_ids), target_size, len(input_ids[0][0]))).long()
+        )
+        collated_type_ids = type_ids[0].new_zeros((len(type_ids), target_size, len(type_ids[0][0]))).long()
+        collated_dpe_ids = dpe_ids[0].new_zeros((len(dpe_ids), target_size, len(dpe_ids[0][0]))).long()
+        for i, size in enumerate(sizes):
+            diff = size - target_size
+            if diff == 0:
+                collated_input_ids[i] = input_ids[i]
+                collated_type_ids[i] = type_ids[i]
+                collated_dpe_ids[i] = dpe_ids[i]
+            elif diff < 0:
+                collated_input_ids[i] = torch.cat(
+                    [
+                        input_ids[i],
+                        input_ids[i].new_zeros(-diff, len(input_ids[i][0])),
+                    ],
+                    dim=0,
+                )
+                # add a dummy token at the start of each padded event, as the event
+                # encoder can break down when all the input tokens are pad tokens
+                collated_input_ids[i][diff:, 0] = self.dummy_token_id
+                collated_type_ids[i] = torch.cat(
+                    [
+                        type_ids[i],
+                        type_ids[i].new_zeros(-diff, len(type_ids[i][0])),
+                    ],
+                    dim=0,
+                )
+                collated_dpe_ids[i] = torch.cat(
+                    [
+                        dpe_ids[i],
+                        dpe_ids[i].new_zeros(-diff, len(dpe_ids[i][0])),
+                    ],
+                    dim=0,
+                )
+            else:
+                # sequences longer than max_events keep only the most recent events
+                collated_input_ids[i] = input_ids[i][-target_size:]
+                collated_type_ids[i] = type_ids[i][-target_size:]
+                collated_dpe_ids[i] = dpe_ids[i][-target_size:]
+
+        out = {"id": [s["id"] for s in samples]}
+        out["net_input"] = {
+            "input_ids": collated_input_ids,
+            "type_ids": collated_type_ids,
+            "dpe_ids": collated_dpe_ids,
+        }
+
+        if self.label:
+            label = {}
+            for task in self.tasks:
+                if len(samples[0][task]) == 1:
+                    # scalar labels are concatenated into a 1-D tensor
+                    label[task] = torch.cat([s[task] for s in samples])
+                else:
+                    # vector labels (e.g., multi-hot) are stacked into a 2-D tensor
+                    label[task] = torch.stack([s[task] for s in samples])
+            out["label"] = label
+
+        return out
+
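Unlike the GenHPF collator, which raises on oversize samples, the truncation branch above keeps the most recent events; a tiny sketch with max_events = 4 (values illustrative):

    import torch

    max_events = 4
    events = torch.arange(6).unsqueeze(-1).expand(6, 128)  # 6 events, 128 tokens each
    kept = events[-max_events:]
    print(kept[:, 0])  # tensor([2, 3, 4, 5]) -- the 4 most recent events survive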
+    def __getitem__(self, idx):
+        data_idx, subject = self.subjects[idx]
+        data = self.data[data_idx][self.shard_ids[idx]][subject]
+
+        tokens = data[self.structure][:]
+        if self.apply_mask:
+            tokens = self.mask(
+                tokens,
+                mask_prob=self.mask_prob,
+                vocab_size=self.vocab_size,
+                mask_token_id=self.mask_token_id,
+                mask_unit=self.mask_unit,
+                sep_token_id=self.sep_token_id,
+            )
+
+        ret = {
+            "id": subject,
+            "input_ids": torch.LongTensor(tokens[:, 0, :]),
+            "type_ids": torch.LongTensor(tokens[:, 1, :]),
+            "dpe_ids": torch.LongTensor(tokens[:, 2, :]),
+        }
+
+        if self.label:
+            for i, task in enumerate(self.tasks):
+                try:
+                    ret[task] = torch.LongTensor(data["label"][i])
+                except ValueError:
+                    # scalar (0-d) label datasets cannot be indexed, so read the value directly
+                    ret[task] = torch.LongTensor([data["label"][()]])
+        return ret
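Putting it together, a hypothetical end-to-end use of the MEDS dataset (paths and task names are illustrative, not prescribed by the package):

    from torch.utils.data import DataLoader

    dataset = HierarchicalMEDSDataset(
        manifest_paths=["/data/meds/train.tsv"],  # illustrative path
        max_events=256,
        label=True,
        tasks=["mortality"],
        num_labels=[2],
    )
    loader = DataLoader(dataset, batch_size=16, collate_fn=dataset.collator)
    batch = next(iter(loader))
    # batch["net_input"]["input_ids"]: (16, 256, tokens_per_event)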