genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/criterions/wav2vec2_criterion.py
@@ -0,0 +1,130 @@
+ import math
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+ import genhpf.utils.utils as utils
+ from genhpf.criterions import BaseCriterion, register_criterion
+ from genhpf.criterions.criterion import CriterionConfig
+ from genhpf.loggings import metrics
+ from genhpf.loggings.meters import safe_round
+
+
+ @dataclass
+ class Wav2Vec2CriterionConfig(CriterionConfig):
+     loss_weights: Optional[List[float]] = field(
+         default=None, metadata={"help": "weights for additional loss terms (not first one)"}
+     )
+
+
+ @register_criterion("wav2vec2_criterion", dataclass=Wav2Vec2CriterionConfig)
+ class Wav2Vec2Criterion(BaseCriterion):
+     def __init__(self, cfg: Wav2Vec2CriterionConfig):
+         super().__init__(cfg)
+
+         self.loss_weights = cfg.loss_weights
+
+     def compute_loss(
+         self, logits: torch.Tensor, targets: torch.Tensor, sample=None, net_output=None, model=None
+     ) -> Tuple[torch.Tensor, List[float]]:
+         """Compute the loss given the logits and targets from the model."""
+
+         losses = []
+
+         loss = F.cross_entropy(logits, targets, reduction="sum")
+         losses.append(loss.detach().item())
+
+         sample_size = self.get_sample_size(sample, targets)
+
+         if self.loss_weights is not None:
+             assert hasattr(model, "get_extra_losses")
+             extra_losses = model.get_extra_losses(net_output)
+             if torch.is_tensor(extra_losses):
+                 extra_losses = [extra_losses]
+             if len(self.loss_weights) == 1 and len(extra_losses) != 1:
+                 self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
+             assert len(extra_losses) == len(
+                 self.loss_weights
+             ), f"{len(extra_losses)}, {len(self.loss_weights)}"
+             for p, coef in zip(extra_losses, self.loss_weights):
+                 if coef != 0 and p is not None:
+                     p = coef * p.float() * sample_size
+                     loss += p
+                     losses.append(p.detach().item())
+
+         return loss, losses
+
+     def get_sample_size(self, sample, targets):
+         """
+         Get the sample size, which is used as the denominator for the gradient
+         """
+         if "sample_size" in sample:
+             sample_size = sample["sample_size"]
+         else:
+             sample_size = targets.numel()
+         return sample_size
+
+     def get_logging_outputs(
+         self, logging_output, logits: torch.Tensor, targets: torch.Tensor, sample=None
+     ) -> List[Dict[str, Any]]:
+         """
+         Get the logging output to display while training
+         """
+         with torch.no_grad():
+             if logits.numel() == 0:
+                 corr = 0
+                 count = 0
+             else:
+                 assert logits.dim() > 1, logits.shape
+                 # the positive candidate is always at index 0, so a prediction is
+                 # correct when index 0 holds the (unique) maximum logit
+                 is_max = logits.argmax(-1) == 0
+                 is_min = logits.argmin(-1) == 0
+
+                 both = is_max & is_min
+                 corr = is_max.long().sum().item() - both.long().sum().item()
+                 count = float(is_max.numel())
+
+         logging_output["correct"] = corr
+         logging_output["count"] = count
+         return logging_output
+
+     @staticmethod
+     def reduce_metrics(logging_outputs: List[Dict[str, Any]], prefix: str = None) -> None:
+         """Aggregate logging outputs from data parallel training."""
+         if prefix is None:
+             prefix = ""
+         elif not prefix.endswith("_"):
+             prefix = prefix + "_"
+
+         loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
+
+         sample_size = utils.item(sum(log.get("sample_size", 0) for log in logging_outputs))
+
+         metrics.log_scalar(f"{prefix}loss", loss_sum / (sample_size or 1) / math.log(2), sample_size, round=3)
+
+         correct = sum(log.get("correct", 0) for log in logging_outputs)
+         metrics.log_scalar(f"_{prefix}correct", correct)
+
+         total = sum(log.get("count", 0) for log in logging_outputs)
+         metrics.log_scalar(f"_{prefix}total", total)
+
+         if total > 0:
+             metrics.log_derived(
+                 f"{prefix}accuracy",
+                 lambda meters: safe_round(meters[f"_{prefix}correct"].sum / meters[f"_{prefix}total"].sum, 5)
+                 if meters[f"_{prefix}total"].sum > 0
+                 else float("nan"),
+             )
+
+         builtin_keys = {"loss", "sample_size", "correct", "count"}
+
+         for k in logging_outputs[0]:
+             if k not in builtin_keys:
+                 val = sum(log.get(k, 0) for log in logging_outputs)
+                 if k.startswith("loss"):
+                     metrics.log_scalar(
+                         prefix + k, val / (sample_size or 1) / math.log(2), sample_size, round=3
+                     )
+                 else:
+                     metrics.log_scalar(prefix + k, val / len(logging_outputs), round=3)
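Note on wav2vec2_criterion.py above: each auxiliary loss returned by model.get_extra_losses is scaled by its configured weight and by the sample size before being added to the main contrastive loss, and a single configured weight is broadcast across all auxiliary losses. A minimal sketch of that behavior, assuming the config can be built with only loss_weights set; the stub model and tensor shapes here are hypothetical, not part of the package:

    import torch

    from genhpf.criterions.wav2vec2_criterion import Wav2Vec2Criterion, Wav2Vec2CriterionConfig

    class StubModel:  # hypothetical stand-in exposing only what compute_loss needs
        def get_extra_losses(self, net_output):
            return [torch.tensor(0.5), torch.tensor(0.2)]  # e.g. diversity and feature penalties

    criterion = Wav2Vec2Criterion(Wav2Vec2CriterionConfig(loss_weights=[0.1]))
    logits = torch.randn(8, 101)                # one positive + 100 distractor candidates per example
    targets = torch.zeros(8, dtype=torch.long)  # the positive candidate is always index 0
    loss, losses = criterion.compute_loss(
        logits, targets, sample={"sample_size": 8}, net_output={}, model=StubModel()
    )
    # losses == [contrastive, 0.1 * 0.5 * 8, 0.1 * 0.2 * 8]: the single weight was broadcast

This also shows why get_logging_outputs counts a prediction as correct only when index 0 is the unique argmax: rows where argmax and argmin coincide (all logits equal) are subtracted out.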
genhpf/datasets/__init__.py
@@ -0,0 +1,84 @@
+ from typing import List
+
+ from genhpf.configs import Config
+
+ from .dataset import BaseDataset
+ from .genhpf_dataset import FlattenedGenHPFDataset, HierarchicalGenHPFDataset
+ from .meds_dataset import HierarchicalMEDSDataset
+
+ __all__ = ["BaseDataset", "HierarchicalGenHPFDataset", "FlattenedGenHPFDataset", "HierarchicalMEDSDataset"]
+
+
+ def load_dataset(
+     data_path: str,
+     subsets: List[str],
+     cfg: Config,
+ ):
+     dataset_cfg = cfg.dataset
+     model_cfg = cfg.model
+     criterion_cfg = cfg.criterion
+
+     manifest_paths = [f"{data_path}/{subset.strip()}.tsv" for subset in subsets]
+
+     if dataset_cfg.data_format == "genhpf":
+         if model_cfg.structure == "hierarchical":
+             dataset = HierarchicalGenHPFDataset(
+                 manifest_paths=manifest_paths,
+                 label=dataset_cfg.label,
+                 tasks=getattr(criterion_cfg, "task_names", None),
+                 num_labels=getattr(criterion_cfg, "num_labels", None),
+                 vocab_size=dataset_cfg.vocab_size,
+                 pad_token_id=dataset_cfg.pad_token_id,
+                 sep_token_id=dataset_cfg.sep_token_id,
+                 ignore_index=dataset_cfg.ignore_index,
+                 apply_mask=dataset_cfg.apply_mask or "mlm" in model_cfg._name,
+                 mask_token_id=dataset_cfg.mask_token_id,
+                 mask_prob=dataset_cfg.mask_prob,
+                 mask_unit=dataset_cfg.mask_unit,
+                 simclr="simclr" in model_cfg._name,
+                 dummy_token_id=dataset_cfg.dummy_token_id,
+             )
+         else:
+             dataset = FlattenedGenHPFDataset(
+                 manifest_paths=manifest_paths,
+                 label=dataset_cfg.label,
+                 tasks=getattr(criterion_cfg, "task_names", None),
+                 num_labels=getattr(criterion_cfg, "num_labels", None),
+                 vocab_size=dataset_cfg.vocab_size,
+                 pad_token_id=dataset_cfg.pad_token_id,
+                 sep_token_id=dataset_cfg.sep_token_id,
+                 ignore_index=dataset_cfg.ignore_index,
+                 apply_mask=dataset_cfg.apply_mask or "mlm" in model_cfg._name,
+                 mask_token_id=dataset_cfg.mask_token_id,
+                 mask_prob=dataset_cfg.mask_prob,
+                 mask_unit=dataset_cfg.mask_unit,
+                 simclr="simclr" in model_cfg._name,
+             )
+     elif dataset_cfg.data_format == "meds":
+         assert model_cfg.structure == "hierarchical", (
+             "we currently only support hierarchical structure for MEDS dataset."
+             " please set model.structure to 'hierarchical'"
+         )
+         dataset = HierarchicalMEDSDataset(
+             manifest_paths=manifest_paths,
+             max_events=model_cfg.agg_max_seq_len,
+             label=dataset_cfg.label,
+             tasks=getattr(criterion_cfg, "task_names", None),
+             num_labels=getattr(criterion_cfg, "num_labels", None),
+             structure=model_cfg.structure,
+             vocab_size=dataset_cfg.vocab_size,
+             pad_token_id=dataset_cfg.pad_token_id,
+             sep_token_id=dataset_cfg.sep_token_id,
+             ignore_index=dataset_cfg.ignore_index,
+             apply_mask=dataset_cfg.apply_mask or "mlm" in model_cfg._name,
+             mask_token_id=dataset_cfg.mask_token_id,
+             mask_prob=dataset_cfg.mask_prob,
+             mask_unit=dataset_cfg.mask_unit,
+             simclr="simclr" in model_cfg._name,
+             dummy_token_id=dataset_cfg.dummy_token_id,
+             debug=cfg.common.debug,
+         )
+     else:
+         raise NotImplementedError(f"unsupported data format: {dataset_cfg.data_format}")
+
+     return dataset
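As a usage note for load_dataset above: it expects one <subset>.tsv manifest per requested subset directly under data_path, and dispatches on dataset.data_format and model.structure. A hedged sketch of the call site; the resolved-config plumbing and the manifest location are assumptions based on the code above, not a documented API:

    from genhpf.datasets import load_dataset

    # cfg is assumed to be the fully resolved Hydra config that the train/test
    # scripts assemble from genhpf/configs/config.yaml; with data_format="genhpf"
    # and structure="hierarchical", this reads /data/mimiciii/train.tsv and
    # /data/mimiciii/valid.tsv into a single HierarchicalGenHPFDataset
    dataset = load_dataset("/data/mimiciii", subsets=["train", "valid"], cfg=cfg)
    batch = dataset.collator([dataset[i] for i in range(4)])  # each dataset class implements its own collator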
genhpf/datasets/dataset.py
@@ -0,0 +1,109 @@
+ import logging
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torch.utils.data
+
+ from genhpf.configs import ChoiceEnum
+
+ MASK_UNIT_CHOICES = ChoiceEnum(["token", "event"])
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseDataset(torch.utils.data.Dataset):
+     def __len__(self):
+         raise NotImplementedError
+
+     def __getitem__(self, index):
+         raise NotImplementedError
+
+     def collator(self, samples):
+         raise NotImplementedError
+
+     def mask(
+         self,
+         tokens: Union[np.ndarray, torch.Tensor],
+         mask_prob: float,
+         vocab_size: int,
+         mask_token_id: int,
+         mask_unit: MASK_UNIT_CHOICES = "token",
+         sep_token_id: int = None,
+     ):
+         """
+         Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
+         """
+         assert 0.0 < mask_prob < 1.0, "mask_prob must be in the range (0.0, 1.0)"
+
+         if isinstance(tokens, np.ndarray):
+             tokens = torch.LongTensor(tokens)
+         tokens = tokens.long()
+         labels = tokens.clone()
+
+         assert tokens.dim() == 2, (
+             "input tokens must be a 2D tensor, where the first dimension is the number of embeddings "
+             "(i.e., input_ids, type_ids, dpe_ids) and the second dimension is the length of the token "
+             "sequence."
+         )
+
+         if mask_unit == "token":
+             probability_matrix = torch.full((labels.size(-1),), mask_prob)
+             # do not mask special tokens
+             if not hasattr(self, "tokenizer"):
+                 from transformers import AutoTokenizer
+
+                 self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+             special_tokens_mask = torch.tensor(
+                 self.tokenizer.get_special_tokens_mask(labels[0], already_has_special_tokens=True),
+                 dtype=torch.bool,
+             )
+             probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
+             probability_matrix[torch.where(labels[0] == self.pad_token_id)] = 0.0
+
+             mask_indices = torch.bernoulli(probability_matrix).bool()
+             while mask_indices.sum() == 0:
+                 mask_indices = torch.bernoulli(probability_matrix).bool()
+         elif mask_unit == "event":
+             if sep_token_id is None:
+                 logger.warning(
+                     "sep_token_id is not provided. Using the default [SEP] token id (102) as "
+                     "the event delimiter."
+                 )
+                 sep_token_id = 102  # token id for [SEP]
+             event_indices = torch.where(tokens[0] == sep_token_id)[0]
+             assert len(event_indices) > 1, (
+                 "there must be at least two events in the input sequence to apply span masking. "
+                 "check whether you are using the hierarchical structure, which does not support "
+                 "the span masking method."
+             )
+             mask_indices = torch.zeros_like(tokens[0]).bool()
+             masked_event_indices = torch.randperm(len(event_indices) - 1)[
+                 : round((len(event_indices) - 1) * mask_prob)
+             ]
+             for i in masked_event_indices:
+                 mask_indices[event_indices[i] : event_indices[i + 1]] = True
+         else:
+             raise ValueError(f"mask_unit must be one of {MASK_UNIT_CHOICES}")
+
+         labels[:, ~mask_indices] = self.ignore_index  # we only compute loss on masked tokens
+
+         # 80% of the time, we replace the masked input tokens with tokenizer.mask_token ([MASK])
+         mask_positions = torch.bernoulli(torch.full(mask_indices.shape, 0.8)).bool() & mask_indices
+         tokens[0, mask_positions] = mask_token_id  # for input_ids
+         tokens[1, mask_positions] = 4  # for type_ids
+         tokens[2, mask_positions] = 15  # for dpe_ids
+
+         # 10% of the time, we replace the masked input tokens with a random word
+         random_positions = (
+             torch.bernoulli(torch.full(mask_indices.shape, 0.5)).bool() & mask_indices & ~mask_positions
+         )
+         random_words = torch.randint(vocab_size, mask_indices.shape, dtype=torch.long)
+         tokens[0, random_positions] = random_words[random_positions]
+         random_types = torch.randint(7, mask_indices.shape, dtype=torch.long)
+         tokens[1, random_positions] = random_types[random_positions]
+         random_dpes = torch.randint(16, mask_indices.shape, dtype=torch.long)
+         tokens[2, random_positions] = random_dpes[random_positions]
+
+         # the rest of the time (10%), we keep the masked input tokens unchanged
+         return tokens, labels
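A sketch of how a concrete dataset would call mask during MLM pre-training: the 3xT token layout (input_ids, type_ids, dpe_ids) follows the assert in the hunk above, while the toy subclass and the BERT-style ids (103 for [MASK], a 28996-entry vocab) are illustrative assumptions. Note that mask_unit="token" lazily downloads the Bio_ClinicalBERT tokenizer on first use:

    import torch

    from genhpf.datasets.dataset import BaseDataset

    class ToyDataset(BaseDataset):  # hypothetical subclass; the shipped datasets set these in __init__
        pad_token_id = 0
        ignore_index = -100

    ds = ToyDataset()
    tokens = torch.randint(5, 1000, (3, 128))  # rows: input_ids, type_ids, dpe_ids
    masked, labels = ds.mask(
        tokens, mask_prob=0.15, vocab_size=28996, mask_token_id=103, mask_unit="token"
    )
    # labels keep the original ids only at masked positions (ignore_index elsewhere),
    # so an MLM criterion computes loss on roughly 15% of the tokens

Because tokens is modified in place when a long tensor is passed, callers that need the originals should pass a copy.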