genhpf-1.0.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- genhpf/__init__.py +9 -0
- genhpf/configs/__init__.py +23 -0
- genhpf/configs/config.yaml +8 -0
- genhpf/configs/configs.py +240 -0
- genhpf/configs/constants.py +29 -0
- genhpf/configs/initialize.py +58 -0
- genhpf/configs/utils.py +29 -0
- genhpf/criterions/__init__.py +74 -0
- genhpf/criterions/binary_cross_entropy.py +114 -0
- genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
- genhpf/criterions/criterion.py +87 -0
- genhpf/criterions/cross_entropy.py +202 -0
- genhpf/criterions/multi_task_criterion.py +177 -0
- genhpf/criterions/simclr_criterion.py +84 -0
- genhpf/criterions/wav2vec2_criterion.py +130 -0
- genhpf/datasets/__init__.py +84 -0
- genhpf/datasets/dataset.py +109 -0
- genhpf/datasets/genhpf_dataset.py +451 -0
- genhpf/datasets/meds_dataset.py +232 -0
- genhpf/loggings/__init__.py +0 -0
- genhpf/loggings/meters.py +374 -0
- genhpf/loggings/metrics.py +155 -0
- genhpf/loggings/progress_bar.py +445 -0
- genhpf/models/__init__.py +73 -0
- genhpf/models/genhpf.py +244 -0
- genhpf/models/genhpf_mlm.py +64 -0
- genhpf/models/genhpf_predictor.py +73 -0
- genhpf/models/genhpf_simclr.py +58 -0
- genhpf/models/genhpf_wav2vec2.py +304 -0
- genhpf/modules/__init__.py +15 -0
- genhpf/modules/gather_layer.py +23 -0
- genhpf/modules/grad_multiply.py +12 -0
- genhpf/modules/gumbel_vector_quantizer.py +204 -0
- genhpf/modules/identity_layer.py +8 -0
- genhpf/modules/layer_norm.py +27 -0
- genhpf/modules/positional_encoding.py +24 -0
- genhpf/scripts/__init__.py +0 -0
- genhpf/scripts/preprocess/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/README.md +75 -0
- genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
- genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
- genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
- genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
- genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
- genhpf/scripts/preprocess/genhpf/main.py +175 -0
- genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
- genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
- genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
- genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
- genhpf/scripts/preprocess/manifest.py +83 -0
- genhpf/scripts/preprocess/preprocess_meds.py +674 -0
- genhpf/scripts/test.py +264 -0
- genhpf/scripts/train.py +365 -0
- genhpf/trainer.py +370 -0
- genhpf/utils/checkpoint_utils.py +171 -0
- genhpf/utils/data_utils.py +130 -0
- genhpf/utils/distributed_utils.py +497 -0
- genhpf/utils/file_io.py +170 -0
- genhpf/utils/pdb.py +38 -0
- genhpf/utils/utils.py +204 -0
- genhpf-1.0.11.dist-info/LICENSE +21 -0
- genhpf-1.0.11.dist-info/METADATA +202 -0
- genhpf-1.0.11.dist-info/RECORD +67 -0
- genhpf-1.0.11.dist-info/WHEEL +5 -0
- genhpf-1.0.11.dist-info/entry_points.txt +6 -0
- genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/criterions/wav2vec2_criterion.py
@@ -0,0 +1,130 @@

```python
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F

import genhpf.utils.utils as utils
from genhpf.criterions import BaseCriterion, register_criterion
from genhpf.criterions.criterion import CriterionConfig
from genhpf.loggings import metrics
from genhpf.loggings.meters import safe_round


@dataclass
class Wav2Vec2CriterionConfig(CriterionConfig):
    loss_weights: Optional[List[float]] = field(
        default=None, metadata={"help": "weights for additional loss terms (not first one)"}
    )


@register_criterion("wav2vec2_criterion", dataclass=Wav2Vec2CriterionConfig)
class Wav2Vec2Criterion(BaseCriterion):
    def __init__(self, cfg: Wav2Vec2CriterionConfig):
        super().__init__(cfg)

        self.loss_weights = cfg.loss_weights

    def compute_loss(
        self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
    ) -> Tuple[torch.Tensor, List[float]]:
        """Compute the loss given the logits and targets from the model."""

        losses = []

        loss = F.cross_entropy(logits, targets, reduction="sum")
        losses.append(loss.detach().item())

        sample_size = self.get_sample_size(sample, targets)

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            assert len(extra_losses) == len(
                self.loss_weights
            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, coef in zip(extra_losses, self.loss_weights):
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    losses.append(p.detach().item())

        return loss, losses

    def get_sample_size(self, sample, targets):
        """Get the sample size, which is used as the denominator for the gradient."""
        if "sample_size" in sample:
            sample_size = sample["sample_size"]
        else:
            sample_size = targets.numel()
        return sample_size

    def get_logging_outputs(
        self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
    ) -> Dict[str, Any]:
        """Get the logging output to display while training."""
        with torch.no_grad():
            if logits.numel() == 0:
                corr = 0
                count = 0
            else:
                assert logits.dim() > 1, logits.shape
                # the positive candidate sits at index 0, so a prediction is
                # correct when index 0 attains the unique maximum
                max_is_first = logits.argmax(-1) == 0
                min_is_first = logits.argmin(-1) == 0

                both = max_is_first & min_is_first
                corr = max_is_first.long().sum().item() - both.long().sum().item()
                count = float(max_is_first.numel())

            logging_output["correct"] = corr
            logging_output["count"] = count
        return logging_output

    @staticmethod
    def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: Optional[str] = None) -> None:
        """Aggregate logging outputs from data parallel training."""
        if prefix is None:
            prefix = ""
        elif not prefix.endswith("_"):
            prefix = prefix + "_"

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))

        metrics.log_scalar(
            f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3
        )

        correct = sum(log.get("correct", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}correct", correct)

        total = sum(log.get("count", 0) for log in logging_outputs)
        metrics.log_scalar(f"_{prefix}total", total)

        if total > 0:
            metrics.log_derived(
                f"{prefix}accuracy",
                lambda meters: safe_round(
                    meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5
                )
                if meters[f"_{prefix}total"].sum > 0
                else float("nan"),
            )

        builtin_keys = {"loss", "sample_size", "correct", "count"}

        for k in logging_outputs[0]:
            if k not in builtin_keys:
                val = sum(log.get(k, 0) for log in logging_outputs)
                if k.startswith("loss"):
                    metrics.log_scalar(
                        prefix + k, val / (sample_size or 1) / math.log(2), sample_size, round=3
                    )
                else:
                    metrics.log_scalar(prefix + k, val / len(logging_outputs), round=3)
```
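As the accuracy computation suggests, the logits here score each masked position against a set of candidates with the positive one at index 0, so the main term is a contrastive cross-entropy toward class 0, with any extra model losses re-scaled by their weight and the sample size. A minimal standalone sketch of that arithmetic (plain PyTorch; the shapes and weight values are illustrative assumptions, not package defaults):

```python
import torch
import torch.nn.functional as F

# 8 masked positions, each scored against 1 positive + 4 distractors;
# the positive candidate sits at index 0, so the target class is always 0
logits = torch.randn(8, 5)
targets = torch.zeros(8, dtype=torch.long)

loss = F.cross_entropy(logits, targets, reduction="sum")

# auxiliary terms (e.g. codebook diversity or feature penalties) are scaled
# by their weight and by the sample size before being added, mirroring
# `p = coef * p.float() * sample_size` in compute_loss above
sample_size = targets.numel()
extra_losses = [torch.tensor(0.02), torch.tensor(0.001)]  # illustrative values
loss_weights = [0.1, 10.0]  # illustrative values
for p, coef in zip(extra_losses, loss_weights):
    loss = loss + coef * p * sample_size

print(loss.item())
```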
genhpf/datasets/__init__.py
@@ -0,0 +1,84 @@

```python
from typing import List

from genhpf.configs import Config

from .dataset import BaseDataset
from .genhpf_dataset import FlattenedGenHPFDataset, HierarchicalGenHPFDataset
from .meds_dataset import HierarchicalMEDSDataset

__all__ = ["BaseDataset", "HierarchicalGenHPFDataset", "FlattenedGenHPFDataset", "HierarchicalMEDSDataset"]


def load_dataset(
    data_path: str,
    subsets: List[str],
    cfg: Config,
):
    dataset_cfg = cfg.dataset
    model_cfg = cfg.model
    criterion_cfg = cfg.criterion

    manifest_paths = [f"{data_path}/{subset.strip()}.tsv" for subset in subsets]

    if dataset_cfg.data_format == "genhpf":
        if model_cfg.structure == "hierarchical":
            dataset = HierarchicalGenHPFDataset(
                manifest_paths=manifest_paths,
                label=dataset_cfg.label,
                tasks=getattr(criterion_cfg, "task_names", None),
                num_labels=getattr(criterion_cfg, "num_labels", None),
                vocab_size=dataset_cfg.vocab_size,
                pad_token_id=dataset_cfg.pad_token_id,
                sep_token_id=dataset_cfg.sep_token_id,
                ignore_index=dataset_cfg.ignore_index,
                apply_mask=dataset_cfg.apply_mask or "mlm" in model_cfg._name,
                mask_token_id=dataset_cfg.mask_token_id,
                mask_prob=dataset_cfg.mask_prob,
                mask_unit=dataset_cfg.mask_unit,
                simclr="simclr" in model_cfg._name,
                dummy_token_id=dataset_cfg.dummy_token_id,
            )
        else:
            dataset = FlattenedGenHPFDataset(
                manifest_paths=manifest_paths,
                label=dataset_cfg.label,
                tasks=getattr(criterion_cfg, "task_names", None),
                num_labels=getattr(criterion_cfg, "num_labels", None),
                vocab_size=dataset_cfg.vocab_size,
                pad_token_id=dataset_cfg.pad_token_id,
                sep_token_id=dataset_cfg.sep_token_id,
                ignore_index=dataset_cfg.ignore_index,
                apply_mask=dataset_cfg.apply_mask or "mlm" in model_cfg._name,
                mask_token_id=dataset_cfg.mask_token_id,
                mask_prob=dataset_cfg.mask_prob,
                mask_unit=dataset_cfg.mask_unit,
                simclr="simclr" in model_cfg._name,
            )
    elif dataset_cfg.data_format == "meds":
        assert model_cfg.structure == "hierarchical", (
            "we currently only support hierarchical structure for MEDS dataset."
            " please set model.structure to 'hierarchical'"
        )
        dataset = HierarchicalMEDSDataset(
            manifest_paths=manifest_paths,
            max_events=model_cfg.agg_max_seq_len,
            label=dataset_cfg.label,
            tasks=getattr(criterion_cfg, "task_names", None),
            num_labels=getattr(criterion_cfg, "num_labels", None),
            structure=model_cfg.structure,
            vocab_size=dataset_cfg.vocab_size,
            pad_token_id=dataset_cfg.pad_token_id,
            sep_token_id=dataset_cfg.sep_token_id,
            ignore_index=dataset_cfg.ignore_index,
            apply_mask=dataset_cfg.apply_mask or "mlm" in model_cfg._name,
            mask_token_id=dataset_cfg.mask_token_id,
            mask_prob=dataset_cfg.mask_prob,
            mask_unit=dataset_cfg.mask_unit,
            simclr="simclr" in model_cfg._name,
            dummy_token_id=dataset_cfg.dummy_token_id,
            debug=cfg.common.debug,
        )
    else:
        raise NotImplementedError(f"unsupported data format: {dataset_cfg.data_format}")

    return dataset
```
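For orientation, `load_dataset` appears to be the factory the train/test scripts use to build a split from `<data_path>/<subset>.tsv` manifests. A hedged sketch of calling it directly; the path and subset name below are placeholders, and `cfg` is assumed to be the composed `genhpf.configs.Config` (e.g. produced by Hydra from genhpf/configs/config.yaml):

```python
from genhpf.datasets import load_dataset

def build_split(cfg, split: str = "train"):
    # expects /path/to/manifests/<split>.tsv to exist, as produced by the
    # preprocessing scripts under genhpf/scripts/preprocess/
    return load_dataset(
        data_path="/path/to/manifests",  # placeholder
        subsets=[split],
        cfg=cfg,
    )
```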
genhpf/datasets/dataset.py
@@ -0,0 +1,109 @@

```python
import logging
from typing import Union

import numpy as np
import torch.utils.data

from genhpf.configs import ChoiceEnum

MASK_UNIT_CHOICES = ChoiceEnum(["token", "event"])

logger = logging.getLogger(__name__)


class BaseDataset(torch.utils.data.Dataset):
    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError

    def collator(self, samples):
        raise NotImplementedError

    def mask(
        self,
        tokens: Union[np.ndarray, torch.Tensor],
        mask_prob: float,
        vocab_size: int,
        mask_token_id: int,
        mask_unit: MASK_UNIT_CHOICES = "token",
        sep_token_id: int = None,
    ):
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        assert 0.0 < mask_prob < 1.0, "mask_prob must be in the range (0.0, 1.0)"

        if isinstance(tokens, np.ndarray):
            tokens = torch.LongTensor(tokens)
        tokens = tokens.long()
        labels = tokens.clone()

        assert tokens.dim() == 2, (
            "input tokens must be a 2D tensor, where the first dimension is the number of embeddings "
            "(i.e., input_ids, type_ids, dpe_ids) and the second dimension is the length of the token "
            "sequence."
        )

        if mask_unit == "token":
            probability_matrix = torch.full((labels.size(-1),), mask_prob)
            # do not mask special tokens
            if not hasattr(self, "tokenizer"):
                from transformers import AutoTokenizer

                self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

            special_tokens_mask = torch.tensor(
                self.tokenizer.get_special_tokens_mask(labels[0], already_has_special_tokens=True),
                dtype=torch.bool,
            )
            probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
            probability_matrix[torch.where(labels[0] == self.pad_token_id)] = 0.0

            mask_indices = torch.bernoulli(probability_matrix).bool()
            while mask_indices.sum() == 0:
                mask_indices = torch.bernoulli(probability_matrix).bool()
        elif mask_unit == "event":
            if sep_token_id is None:
                logger.warning(
                    "sep_token_id is not provided. Using the default [SEP] token id (102) as "
                    "the event delimiter."
                )
                sep_token_id = 102  # token id for [SEP]
            event_indices = torch.where(tokens[0] == sep_token_id)[0]
            assert len(event_indices) > 1, (
                "there must be at least two events in the input sequence to apply span masking. "
                "check whether you are using the hierarchical structure, which does not support "
                "span masking."
            )
            mask_indices = torch.zeros_like(tokens[0]).bool()
            masked_event_indices = torch.randperm(len(event_indices) - 1)[
                : round((len(event_indices) - 1) * mask_prob)
            ]
            for i in masked_event_indices:
                mask_indices[event_indices[i] : event_indices[i + 1]] = True
        else:
            raise ValueError(f"mask_unit must be one of {MASK_UNIT_CHOICES}")

        labels[:, ~mask_indices] = self.ignore_index  # we only compute loss on masked tokens

        # 80% of the time, we replace the masked input tokens with tokenizer.mask_token ([MASK])
        mask_positions = torch.bernoulli(torch.full(mask_indices.shape, 0.8)).bool() & mask_indices
        tokens[0, mask_positions] = mask_token_id  # for input_ids
        tokens[1, mask_positions] = 4  # for type_ids
        tokens[2, mask_positions] = 15  # for dpe_ids

        # 10% of the time, we replace the masked input tokens with a random word
        random_positions = (
            torch.bernoulli(torch.full(mask_indices.shape, 0.5)).bool() & mask_indices & ~mask_positions
        )
        random_words = torch.randint(vocab_size, mask_indices.shape, dtype=torch.long)
        tokens[0, random_positions] = random_words[random_positions]
        random_types = torch.randint(7, mask_indices.shape, dtype=torch.long)
        tokens[1, random_positions] = random_types[random_positions]
        random_dpes = torch.randint(16, mask_indices.shape, dtype=torch.long)
        tokens[2, random_positions] = random_dpes[random_positions]

        # the rest of the time, we keep the masked input tokens unchanged
        return tokens, labels
```
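For intuition, the 80/10/10 replacement logic above can be exercised in isolation. This sketch reimplements just that split on a single row of token ids; the vocabulary size, mask token id, ignore index of -100, and mask probability are arbitrary assumptions for illustration:

```python
import torch

torch.manual_seed(0)

vocab_size, mask_token_id, mask_prob = 100, 3, 0.3  # illustrative values
tokens = torch.randint(5, vocab_size, (20,))

# choose positions to mask, then split them 80% [MASK] / 10% random / 10% kept
mask_indices = torch.bernoulli(torch.full(tokens.shape, mask_prob)).bool()
mask_positions = torch.bernoulli(torch.full(tokens.shape, 0.8)).bool() & mask_indices
random_positions = (
    torch.bernoulli(torch.full(tokens.shape, 0.5)).bool() & mask_indices & ~mask_positions
)

labels = tokens.clone()
labels[~mask_indices] = -100  # loss is computed only on masked positions

tokens[mask_positions] = mask_token_id
tokens[random_positions] = torch.randint(vocab_size, tokens.shape)[random_positions]
# remaining masked positions keep their original token, exactly as in `mask`
print(tokens, labels)
```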