genhpf-1.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. genhpf/__init__.py +9 -0
  2. genhpf/configs/__init__.py +23 -0
  3. genhpf/configs/config.yaml +8 -0
  4. genhpf/configs/configs.py +240 -0
  5. genhpf/configs/constants.py +29 -0
  6. genhpf/configs/initialize.py +58 -0
  7. genhpf/configs/utils.py +29 -0
  8. genhpf/criterions/__init__.py +74 -0
  9. genhpf/criterions/binary_cross_entropy.py +114 -0
  10. genhpf/criterions/binary_cross_entropy_with_logits.py +115 -0
  11. genhpf/criterions/criterion.py +87 -0
  12. genhpf/criterions/cross_entropy.py +202 -0
  13. genhpf/criterions/multi_task_criterion.py +177 -0
  14. genhpf/criterions/simclr_criterion.py +84 -0
  15. genhpf/criterions/wav2vec2_criterion.py +130 -0
  16. genhpf/datasets/__init__.py +84 -0
  17. genhpf/datasets/dataset.py +109 -0
  18. genhpf/datasets/genhpf_dataset.py +451 -0
  19. genhpf/datasets/meds_dataset.py +232 -0
  20. genhpf/loggings/__init__.py +0 -0
  21. genhpf/loggings/meters.py +374 -0
  22. genhpf/loggings/metrics.py +155 -0
  23. genhpf/loggings/progress_bar.py +445 -0
  24. genhpf/models/__init__.py +73 -0
  25. genhpf/models/genhpf.py +244 -0
  26. genhpf/models/genhpf_mlm.py +64 -0
  27. genhpf/models/genhpf_predictor.py +73 -0
  28. genhpf/models/genhpf_simclr.py +58 -0
  29. genhpf/models/genhpf_wav2vec2.py +304 -0
  30. genhpf/modules/__init__.py +15 -0
  31. genhpf/modules/gather_layer.py +23 -0
  32. genhpf/modules/grad_multiply.py +12 -0
  33. genhpf/modules/gumbel_vector_quantizer.py +204 -0
  34. genhpf/modules/identity_layer.py +8 -0
  35. genhpf/modules/layer_norm.py +27 -0
  36. genhpf/modules/positional_encoding.py +24 -0
  37. genhpf/scripts/__init__.py +0 -0
  38. genhpf/scripts/preprocess/__init__.py +0 -0
  39. genhpf/scripts/preprocess/genhpf/README.md +75 -0
  40. genhpf/scripts/preprocess/genhpf/__init__.py +0 -0
  41. genhpf/scripts/preprocess/genhpf/ehrs/__init__.py +36 -0
  42. genhpf/scripts/preprocess/genhpf/ehrs/ehr.py +919 -0
  43. genhpf/scripts/preprocess/genhpf/ehrs/eicu.py +550 -0
  44. genhpf/scripts/preprocess/genhpf/ehrs/mimiciii.py +839 -0
  45. genhpf/scripts/preprocess/genhpf/ehrs/mimiciv.py +619 -0
  46. genhpf/scripts/preprocess/genhpf/main.py +175 -0
  47. genhpf/scripts/preprocess/genhpf/manifest.py +79 -0
  48. genhpf/scripts/preprocess/genhpf/sample_dataset.py +177 -0
  49. genhpf/scripts/preprocess/genhpf/utils/__init__.py +3 -0
  50. genhpf/scripts/preprocess/genhpf/utils/utils.py +16 -0
  51. genhpf/scripts/preprocess/manifest.py +83 -0
  52. genhpf/scripts/preprocess/preprocess_meds.py +674 -0
  53. genhpf/scripts/test.py +264 -0
  54. genhpf/scripts/train.py +365 -0
  55. genhpf/trainer.py +370 -0
  56. genhpf/utils/checkpoint_utils.py +171 -0
  57. genhpf/utils/data_utils.py +130 -0
  58. genhpf/utils/distributed_utils.py +497 -0
  59. genhpf/utils/file_io.py +170 -0
  60. genhpf/utils/pdb.py +38 -0
  61. genhpf/utils/utils.py +204 -0
  62. genhpf-1.0.11.dist-info/LICENSE +21 -0
  63. genhpf-1.0.11.dist-info/METADATA +202 -0
  64. genhpf-1.0.11.dist-info/RECORD +67 -0
  65. genhpf-1.0.11.dist-info/WHEEL +5 -0
  66. genhpf-1.0.11.dist-info/entry_points.txt +6 -0
  67. genhpf-1.0.11.dist-info/top_level.txt +1 -0
genhpf/models/genhpf.py
@@ -0,0 +1,244 @@
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from omegaconf import II
+
+ from genhpf.configs import BaseConfig, ChoiceEnum
+ from genhpf.modules import PositionalEncoding
+
+ logger = logging.getLogger(__name__)
+
+ GENHPF_MODEL_ARCH_CHOICES = ChoiceEnum(["hierarchical", "flattened"])
+ GENHPF_AGGREGATOR_ARCH_CHOICES = ChoiceEnum(["transformer", "performer"])
+ GENHPF_EMBEDDING_METHOD_CHOICES = ChoiceEnum(["code", "text"])
+
+
+ @dataclass
+ class GenHPFConfig(BaseConfig):
+     structure: GENHPF_MODEL_ARCH_CHOICES = field(
+         default="hierarchical",
+         metadata={"help": "Architecture choice for GenHPF. Choose from hierarchical or flattened"},
+     )
+     embedding_method: GENHPF_EMBEDDING_METHOD_CHOICES = field(
+         default="text", metadata={"help": "Embedding method choice for GenHPF. Choose from code or text"}
+     )
+
+     encoder_max_seq_len: int = field(
+         default=128,
+         metadata={
+             "help": "max sequence length for the event encoder, only used when structure is "
+             "hierarchical. this is the max number of tokens in an event."
+         },
+     )
+     agg_max_seq_len: int = field(
+         default=256,
+         metadata={
+             "help": "max sequence length for the event aggregator. In the hierarchical structure, "
+             "this is the max number of events in a sample. In the flattened structure, this is the"
+             "max sequence length of the flattened input."
+         },
+     )
+
+     # configs for event encoder in hierarchical structure
+     encoder_layers: int = field(default=2, metadata={"help": "num encoder layers in the transformer"})
+     encoder_embed_dim: int = field(default=128, metadata={"help": "encoder embedding dimension"})
+     encoder_ffn_embed_dim: int = field(default=512, metadata={"help": "encoder embedding dimension for FFN"})
+     encoder_attention_heads: int = field(
+         default=4, metadata={"help": "num attention heads in the transformer"}
+     )
+
+     # configs for event aggregator
+     agg_arch: GENHPF_AGGREGATOR_ARCH_CHOICES = field(
+         default="transformer",
+         metadata={
+             "help": "Architecture choice for the event aggregator. Choose from transformer or " "performer"
+         },
+     )
+     agg_layers: int = field(default=2, metadata={"help": "num layers in the transformer"})
+     agg_embed_dim: int = field(default=128, metadata={"help": "hidden dimension for the event aggregator"})
+     agg_ffn_embed_dim: int = field(
+         default=512, metadata={"help": "hidden dimension for the FFN in the event aggregator"}
+     )
+     agg_attention_heads: int = field(
+         default=4, metadata={"help": "num attention heads for the event aggregator"}
+     )
+
+     dropout: float = field(default=0.2, metadata={"help": "dropout probability"})
+
+     from_pretrained: Optional[str] = field(
+         default=None, metadata={"help": "path to the pretrained model if available"}
+     )
+
+     vocab_size: int = II("dataset.vocab_size")
+     debug: bool = II("common.debug")
+
+
+ class GenHPF(nn.Module):
+     def __init__(self, cfg: GenHPFConfig):
+         super().__init__()
+         self.cfg = cfg
+
+         if cfg.debug:
+             cfg.encoder_layers = 1
+             cfg.encoder_embed_dim = 32
+             cfg.encoder_ffn_embed_dim = 128
+             cfg.encoder_attention_heads = 2
+             cfg.agg_layers = 1
+             cfg.agg_embed_dim = 32
+             cfg.agg_ffn_embed_dim = 128
+             cfg.agg_attention_heads = 2
+
+         self.structure = cfg.structure
+         assert self.structure in GENHPF_MODEL_ARCH_CHOICES
+
+         self.embedding_method = cfg.embedding_method
+         assert self.embedding_method in GENHPF_EMBEDDING_METHOD_CHOICES
+
+         self.word_embeddings = nn.Embedding(cfg.vocab_size, cfg.encoder_embed_dim, padding_idx=0)
+         if self.embedding_method == "text":
+             # we currently use 7 token types and 16 digit places in this version
+             self.token_type_vocab_size = 7
+             self.digit_place_vocab_size = 16
+             self.token_type_embeddings = nn.Embedding(
+                 self.token_type_vocab_size, cfg.encoder_embed_dim, padding_idx=0
+             )
+             self.digit_place_embeddings = nn.Embedding(
+                 self.digit_place_vocab_size, cfg.encoder_embed_dim, padding_idx=0
+             )
+
+         max_length = cfg.encoder_max_seq_len if self.structure == "hierarchical" else cfg.agg_max_seq_len
+         self.positional_encodings = PositionalEncoding(cfg.encoder_embed_dim, cfg.dropout, max_length)
+
+         self.event_encoder = None
+         if self.structure == "hierarchical" and self.embedding_method == "text":
+             self.encoder_layer_norm = nn.LayerNorm(cfg.encoder_embed_dim, eps=1e-12)
+             encoder_layer = nn.TransformerEncoderLayer(
+                 cfg.encoder_embed_dim,
+                 cfg.encoder_attention_heads,
+                 cfg.encoder_ffn_embed_dim,
+                 cfg.dropout,
+                 batch_first=True,
+             )
+             self.event_encoder = nn.TransformerEncoder(encoder_layer, self.cfg.encoder_layers)
+             self.post_encode_proj = nn.Linear(cfg.encoder_embed_dim, cfg.agg_embed_dim)
+             self.event_positional_encodings = PositionalEncoding(
+                 cfg.agg_embed_dim, cfg.dropout, cfg.agg_max_seq_len
+             )
+
+         self.agg_layer_norm = nn.LayerNorm(cfg.agg_embed_dim, eps=1e-12)
+         if cfg.agg_arch == "transformer":
+             agg_layer = nn.TransformerEncoderLayer(
+                 cfg.agg_embed_dim,
+                 cfg.agg_attention_heads,
+                 cfg.agg_ffn_embed_dim,
+                 cfg.dropout,
+                 batch_first=True,
+             )
+             self.event_aggregator = nn.TransformerEncoder(agg_layer, cfg.agg_layers)
+         elif cfg.agg_arch == "performer":
+             from performer_pytorch import Performer
+
+             self.event_aggregator = Performer(
+                 dim=cfg.agg_embed_dim,
+                 depth=cfg.agg_layers,
+                 heads=cfg.agg_attention_heads,
+                 dim_head=64,
+                 nb_features=64,
+                 reversible=True,
+                 generalized_attention=True,
+                 ff_dropout=cfg.dropout,
+                 attn_dropout=cfg.dropout,
+                 shift_tokens=True,
+             )
+         else:
+             raise NotImplementedError(f"Unsupported event aggregator architecture: {cfg.agg_arch}")
+
+     @classmethod
+     def build_model(cls, cfg):
+         """Build a new model instance."""
+         if cfg.from_pretrained:
+             return cls.from_pretrained(cfg, cfg.from_pretrained)
+         else:
+             return cls(cfg)
+
+     @classmethod
+     def from_pretrained(cls, cfg, checkpoint_path):
+         model = cls(cfg)
+         state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
+
+         missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+         assert len(unexpected_keys) == 0, f"unexpected keys: {unexpected_keys}"
+
+         finetuning_parameter_names = []
+         if hasattr(model, "get_finetuning_parameter_names"):
+             finetuning_parameter_names = model.get_finetuning_parameter_names()
+         missing_keys = list(set(missing_keys) - set(finetuning_parameter_names))
+         assert len(missing_keys) == 0, f"missing keys: {missing_keys}"
+
+         logger.info(f"loaded pre-trained model from {checkpoint_path}")
+
+         return model
+
+     def get_logits(self, sample, net_output):
+         """Get logits from the model output."""
+         raise NotImplementedError
+
+     def get_targets(self, sample, net_output):
+         """Get targets from the sample or model output."""
+         raise NotImplementedError
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         type_ids: torch.Tensor = None,
+         dpe_ids: torch.Tensor = None,
+         padding_mask: torch.Tensor = None,
+         encoder_only: bool = False,
+         **kwargs,
+     ):
+         x = self.word_embeddings(input_ids)
+         if self.embedding_method == "text":
+             if type_ids is not None:
+                 x += self.token_type_embeddings(type_ids)
+             if dpe_ids is not None:
+                 x += self.digit_place_embeddings(dpe_ids)
+
+         if self.structure == "hierarchical":
+             assert input_ids.ndim == 3  # (batch, num_events, num_words)
+             batch_size, num_events = input_ids.shape[0], input_ids.shape[1]
+
+             x = x.view(batch_size * num_events, -1, self.cfg.encoder_embed_dim)
+             x = self.positional_encodings(x)
+             x = self.encoder_layer_norm(x)
+             if padding_mask is None:
+                 padding_mask = input_ids.view(batch_size * num_events, -1).eq(0).to(x.device)
+             # x: (batch * num_events, num_words, embed_dim)
+             x = self.event_encoder(x, src_key_padding_mask=padding_mask)
+
+             if padding_mask.any():
+                 x[padding_mask] = 0
+             x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))
+             x = self.post_encode_proj(x).view(batch_size, num_events, -1)  # (batch, num_events, embed_dim)
+
+             padding_mask = input_ids[:, :, 1].eq(0).to(x.device)
+
+             x = self.event_positional_encodings(x)
+         else:
+             assert input_ids.ndim == 2  # (batch, seq_len)
+             x = self.positional_encodings(x)  # (batch, seq_len, embed_dim)
+             if padding_mask is None:
+                 padding_mask = input_ids.eq(0).to(x.device)
+
+         if encoder_only:
+             return {
+                 "x": x,
+                 "padding_mask": padding_mask,
+             }
+
+         # breakpoint()
+         x = self.event_aggregator(x, src_key_padding_mask=padding_mask)
+
+         return x, padding_mask
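
The hierarchical forward pass above is mostly shape bookkeeping, so a compressed illustration may help. The following standalone sketch (plain PyTorch with toy dimensions, not the package's own API) reproduces the main steps: per-event token encoding, masked mean pooling into one vector per event, and the event-level padding mask derived from each event's second token. One simplification to note: the pooling below divides by the non-padded token count, whereas the model divides by the per-dimension count of nonzero activations, which coincides only when no encoded value is exactly zero.

import torch
import torch.nn as nn

B, E, W, D = 2, 4, 8, 16  # toy sizes: batch, events per sample, tokens per event, embed dim

input_ids = torch.randint(1, 100, (B, E, W))
input_ids[:, :, -2:] = 0  # pretend the last two tokens of every event are padding (index 0)

emb = nn.Embedding(100, D, padding_idx=0)
event_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(D, nhead=2, dim_feedforward=32, batch_first=True), num_layers=1
)

x = emb(input_ids).view(B * E, W, D)        # fold events into the batch dimension
pad = input_ids.view(B * E, W).eq(0)        # True where the token is padding
x = event_encoder(x, src_key_padding_mask=pad)

x = x.masked_fill(pad.unsqueeze(-1), 0.0)   # zero out padded positions before pooling
x = x.sum(dim=1) / (~pad).sum(dim=1, keepdim=True).clamp(min=1)  # one vector per event

x = x.view(B, E, D)                         # back to (batch, num_events, embed_dim)
event_padding_mask = input_ids[:, :, 1].eq(0)  # events whose second token is 0 are empty
print(x.shape, event_padding_mask.shape)    # torch.Size([2, 4, 16]) torch.Size([2, 4])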
genhpf/models/genhpf_mlm.py
@@ -0,0 +1,64 @@
+ import logging
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ from omegaconf import II
+
+ from genhpf.models import register_model
+ from genhpf.models.genhpf import GenHPF, GenHPFConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class GenHPFMLMConfig(GenHPFConfig):
+     ignore_index: int = II("dataset.ignore_index")
+
+
+ @register_model("genhpf_mlm", dataclass=GenHPFMLMConfig)
+ class GenHPFMLM(GenHPF):
+     def __init__(self, cfg: GenHPFMLMConfig):
+         super().__init__(cfg)
+
+         self.ignore_index = cfg.ignore_index
+
+         self.input_ids_proj = nn.Linear(cfg.agg_embed_dim, cfg.vocab_size)
+
+     @classmethod
+     def build_model(cls, cfg):
+         """Build a new model instance."""
+         return cls(cfg)
+
+     def get_logits(self, sample, net_output):
+         masked_indices = torch.where(
+             (sample["input_label"] > 0) & (sample["input_label"] != self.ignore_index)
+         )
+         return net_output["input_ids"][masked_indices]
+
+     def get_targets(self, sample, net_output):
+         masked_indices = torch.where(
+             (sample["input_label"] > 0) & (sample["input_label"] != self.ignore_index)
+         )
+         return sample["input_label"][masked_indices]
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         type_ids: torch.Tensor = None,
+         dpe_ids: torch.Tensor = None,
+         padding_mask: torch.Tensor = None,
+         **kwargs,
+     ):
+         x, padding_mask = super().forward(
+             input_ids=input_ids, type_ids=type_ids, dpe_ids=dpe_ids, padding_mask=padding_mask, **kwargs
+         )
+
+         input_ids = self.input_ids_proj(x)
+
+         return {"input_ids": input_ids, "padding_mask": padding_mask}
+
+     def get_pretraining_parameter_names(self):
+         ret = []
+         ret.extend(["input_ids_proj" + "." + x[0] for x in self.input_ids_proj.named_parameters()])
+         return ret
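
To connect get_logits/get_targets above with an actual loss, here is a hedged sketch of the token-level cross-entropy they are meant to feed. The real loss lives in the package's criterion modules (e.g. genhpf/criterions/cross_entropy.py in the file list); the batch layout, the ignore_index value of -100, and the label tensor below are assumptions made purely for illustration.

import torch
import torch.nn.functional as F

vocab_size, ignore_index = 100, -100            # ignore_index value is assumed, not taken from the package
logits = torch.randn(2, 6, vocab_size)          # stands in for net_output["input_ids"]
input_label = torch.full((2, 6), ignore_index)  # non-masked positions carry ignore_index
input_label[0, 1] = 7                           # masked positions carry the original token id
input_label[1, 4] = 42

# same selection rule as get_logits / get_targets above
masked = torch.where((input_label > 0) & (input_label != ignore_index))
loss = F.cross_entropy(logits[masked], input_label[masked])
print(loss.item())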
genhpf/models/genhpf_predictor.py
@@ -0,0 +1,73 @@
+ import logging
+ from dataclasses import dataclass
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ from omegaconf import II
+
+ from genhpf.models import register_model
+ from genhpf.models.genhpf import GenHPF, GenHPFConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class GenHPFPredictorConfig(GenHPFConfig):
+     tasks: List[str] = II("criterion.task_names")
+     num_labels: List[int] = II("criterion.num_labels")
+
+
+ @register_model("genhpf_predictor", dataclass=GenHPFPredictorConfig)
+ class GenHPFPredictor(GenHPF):
+     def __init__(self, cfg: GenHPFPredictorConfig):
+         super().__init__(cfg)
+
+         self.num_labels = cfg.num_labels
+         self.tasks = cfg.tasks
+         assert len(self.num_labels) == len(
+             cfg.tasks
+         ), "The number of num_labels must be equal to the number of tasks"
+
+         self.final_proj = nn.ModuleDict()
+         for i, task in enumerate(cfg.tasks):
+             self.final_proj[task] = nn.Linear(cfg.agg_embed_dim, self.num_labels[i])
+
+     def get_logits(self, sample, net_output):
+         if len(self.tasks) == 1:
+             return net_output[self.tasks[0]]
+         else:
+             return net_output
+
+     def get_targets(self, sample, net_output):
+         if len(self.tasks) == 1:
+             return sample["label"][self.tasks[0]]
+         else:
+             return sample["label"]
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         type_ids: torch.Tensor = None,
+         dpe_ids: torch.Tensor = None,
+         padding_mask: torch.Tensor = None,
+         **kwargs,
+     ):
+         x, padding_mask = super().forward(
+             input_ids=input_ids, type_ids=type_ids, dpe_ids=dpe_ids, padding_mask=padding_mask, **kwargs
+         )
+
+         if padding_mask is not None and padding_mask.any():
+             x[padding_mask] = 0
+         x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))
+
+         ret = {}
+         for task, proj in self.final_proj.items():
+             ret[task] = proj(x)
+
+         return ret
+
+     def get_finetuning_parameter_names(self):
+         ret = []
+         ret.extend(["final_proj" + "." + x[0] for x in self.final_proj.named_parameters()])
+         return ret
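
A small hedged sketch of the per-task head layout may also be useful: an nn.ModuleDict maps the masked-mean-pooled representation to one logit tensor per task. The task names and label counts below are invented; in the package they are interpolated from criterion.task_names and criterion.num_labels.

import torch
import torch.nn as nn

tasks, num_labels, embed_dim = ["mortality", "readmission"], [1, 1], 128  # invented values
final_proj = nn.ModuleDict(
    {task: nn.Linear(embed_dim, n) for task, n in zip(tasks, num_labels)}
)

pooled = torch.randn(4, embed_dim)  # stands in for the pooled aggregator output
out = {task: proj(pooled) for task, proj in final_proj.items()}
print({task: tuple(v.shape) for task, v in out.items()})  # {'mortality': (4, 1), 'readmission': (4, 1)}

# Names such as "final_proj.mortality.weight" are what get_finetuning_parameter_names
# returns, which is why from_pretrained can treat them as expected missing keys when
# loading a pre-trained backbone checkpoint.
print(["final_proj." + name for name, _ in final_proj.named_parameters()])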
genhpf/models/genhpf_simclr.py
@@ -0,0 +1,58 @@
+ import logging
+ from dataclasses import dataclass, field
+
+ import torch
+
+ from genhpf.models import register_model
+ from genhpf.models.genhpf import GenHPF, GenHPFConfig
+ from genhpf.modules import GatherLayer
+ from genhpf.utils import distributed_utils as dist_utils
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class GenHPFSimCLRConfig(GenHPFConfig):
+     all_gather: bool = field(
+         default=True, metadata={"help": "whether or not to apply all gather across different gpus"}
+     )
+
+
+ @register_model("genhpf_simclr", dataclass=GenHPFSimCLRConfig)
+ class GenHPFSimCLR(GenHPF):
+     def __init__(self, cfg: GenHPFSimCLRConfig):
+         super().__init__(cfg)
+
+         self.all_gather = cfg.all_gather
+
+     @classmethod
+     def build_model(cls, cfg):
+         """Build a new model instance."""
+         return cls(cfg)
+
+     def get_logits(self, sample, net_output):
+         return net_output
+
+     def get_targets(self, sample, net_output):
+         return None
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         type_ids: torch.Tensor = None,
+         dpe_ids: torch.Tensor = None,
+         padding_mask: torch.Tensor = None,
+         **kwargs,
+     ):
+         x, padding_mask = super().forward(
+             input_ids=input_ids, type_ids=type_ids, dpe_ids=dpe_ids, padding_mask=padding_mask, **kwargs
+         )
+
+         if padding_mask is not None and padding_mask.any():
+             x[padding_mask] = 0
+         x = torch.div(x.sum(dim=1), (x != 0).sum(dim=1))
+
+         if self.all_gather and dist_utils.get_data_parallel_world_size() > 1:
+             x = torch.cat(GatherLayer.apply(x), dim=0)
+
+         return x
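
GatherLayer itself is defined in genhpf/modules/gather_layer.py, which is not reproduced in this excerpt. The sketch below shows the common shape of that pattern (all-gather in the forward, all-reduce of the incoming gradients in the backward so the local shard still receives a gradient); it assumes an initialized torch.distributed process group and is not necessarily identical to the package's implementation.

import torch
import torch.distributed as dist


class GatherLayerSketch(torch.autograd.Function):
    """Gather tensors from every rank while keeping a gradient path to the local shard."""

    @staticmethod
    def forward(ctx, x):
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)  # gathered copies carry no autograd history of their own
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)          # sum the gradient contributions from every rank
        return all_grads[dist.get_rank()]   # return the slice owned by this rank


# usage mirrors the forward above: x = torch.cat(GatherLayerSketch.apply(x), dim=0)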