cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
@@ -6,9 +6,8 @@ import numpy as np
  import torch
  import torch.nn.functional as f
  from torch import nn
- from torch.distributions import Exponential, Gamma
+ from torch.distributions import Gamma
  from torch.nn import CrossEntropyLoss
- from torch.nn import functional as F
  from transformers import PreTrainedModel
  from transformers.activations import gelu_new
  from transformers.generation.logits_process import LogitsProcessorList
@@ -18,16 +17,20 @@ from transformers.generation.stopping_criteria import (
  )
  from transformers.generation.streamers import BaseStreamer
  from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
- from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
  from transformers.pytorch_utils import Conv1D
- from transformers.utils import (
-     is_accelerate_available,
-     is_flash_attn_2_available,
-     logging,
- )
+ from transformers.utils import is_flash_attn_2_available, logging
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

+ from cehrgpt.gpt_utils import (
+     construct_age_sequence,
+     encode_demographics,
+     extract_time_interval_in_days,
+     is_att_token,
+     multiple_of_10,
+ )
+ from cehrgpt.models.activations import RMSNorm
  from cehrgpt.models.config import CEHRGPTConfig
+ from cehrgpt.models.gpt2 import GPT2Block, is_sample_pack
  from cehrgpt.models.hf_modeling_outputs import (
      CehrGptCausalLMOutput,
      CehrGptGenerateDecoderOnlyOutput,
@@ -35,13 +38,6 @@ from cehrgpt.models.hf_modeling_outputs import (
      CehrGptSequenceClassifierOutput,
  )

- if is_flash_attn_2_available():
-     from flash_attn import flash_attn_func, flash_attn_varlen_func
-     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
- if is_accelerate_available():
-     from accelerate.hooks import add_hook_to_module
-
  logger = logging.get_logger(__name__)


@@ -50,7 +46,7 @@ def extract_features_from_packed_sequence(
      attention_mask: torch.Tensor,
  ) -> torch.Tensor:
      max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
-     padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
+     padded_attention_mask = f.pad(attention_mask[:, : max_index + 1], (0, 1))
      feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
      return hidden_state[:, feature_indices]

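For orientation, here is a toy run of the helper above (a minimal sketch with made-up shapes, not code from the package): in a sample-packed row, the zero right after each packed segment marks its boundary, so the helper gathers the hidden state of the last real token of every segment.

import torch

# Toy packed row: two sequences packed together, separated by padding zeros.
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])
hidden_state = torch.arange(8).view(1, 8, 1).float()  # (batch, seq_len, hidden_dim)

max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]              # last real token at 5
padded = torch.nn.functional.pad(attention_mask[:, : max_index + 1], (0, 1))  # append a trailing 0
feature_indices = torch.nonzero(padded == 0)[:, 1] - 1                        # tensor([2, 5])
features = hidden_state[:, feature_indices]                                   # last token of each segment
print(feature_indices.tolist(), features.squeeze(-1).tolist())                # [2, 5] [[2.0, 5.0]]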
@@ -81,313 +77,41 @@ def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
      return attn_matrix


- def is_sample_pack(attention_mask: torch.Tensor) -> bool:
-     """
-     Determines whether any sequence in the batch is likely sample-packed.
-
-     A sample-packed sequence is one where there are non-padding (1) tokens
-     after a padding (0) token, indicating multiple sequences packed together
-     with padding as a separator.
-
-     Args:
-         attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
-             where 1 indicates a real token and 0 indicates padding.
-
-     Returns:
-         bool: True if any sample in the batch is sample-packed, False otherwise.
-     """
-
-     # If the attention_maks is left padded, we will flip it so we can use the same logic below
-     if (attention_mask[:, 0] == 0).any():
-         attention_mask = attention_mask.flip(dims=[1])
-
-     nonzero_counts = attention_mask.sum(dim=1)
-     max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
-     max_indices = attention_mask.shape[1] - 1 - max_token_positions
-     return torch.any(nonzero_counts < (max_indices + 1)).item()
-
-
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
- def _get_unpad_data(attention_mask):
-     # This infers sample packing
-     if is_sample_pack(attention_mask):
-         # Assume input: attention_mask shape = (batch, seq_len)
-         attention_mask = attention_mask.flatten()  # shape: (seq_len,)
-
-         # Compute max_index of the last non-zero element
-         nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
-         max_index = nonzero[-1].item()
-
-         # Pad the truncated attention mask
-         padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
-
-         # Indices of all tokens
-         indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
-
-         # Find where 0s occur (segment boundaries)
-         cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
-             padded_attention_mask == 0
-         ]
-
-         # Compute seqlens per segment
-         seqlens_in_batch = (
-             cumsum_seqlens_in_batch
-             - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
-         ).to(torch.int)
-
-         max_seqlen_in_batch = (
-             seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
-         )
-         cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
-     else:
-         seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-         indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-         max_seqlen_in_batch = seqlens_in_batch.max().item()
-         cu_seqlens = F.pad(
-             torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-         )
-
-     return (
-         indices,
-         cu_seqlens,
-         max_seqlen_in_batch,
-     )
-
-
- class GPT2FlashAttention(GPT2Attention):
-     """
-     GPT2FlashAttention inherits from `GPT2Attention`.
-
-     The primary change is in the forward pass, where it correctly
-     calls the public API of flash attention and handles padding tokens.
-     """
-
-     def forward(
-         self,
-         hidden_states: Optional[Tuple[torch.FloatTensor]],
-         layer_past: Optional[Tuple[torch.Tensor]] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = False,
-         output_attentions: Optional[bool] = False,
-     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-         # Prepare query, key, and value
-         if encoder_hidden_states is not None:
-             if not hasattr(self, "q_attn"):
-                 raise ValueError(
-                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                 )
-
-             query = self.q_attn(hidden_states)
-             key, value = self.c_attn(encoder_hidden_states).split(
-                 self.split_size, dim=2
-             )
-             attention_mask = encoder_attention_mask
-         else:
-             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-         query = self._split_heads(query, self.num_heads, self.head_dim)
-         key = self._split_heads(key, self.num_heads, self.head_dim)
-         value = self._split_heads(value, self.num_heads, self.head_dim)
-
-         if layer_past is not None:
-             past_key, past_value = layer_past
-             key = torch.cat((past_key, key), dim=-2)
-             value = torch.cat((past_value, value), dim=-2)
-
-         if use_cache is True:
-             present = (key, value)
-         else:
-             present = None
-
-         # Apply Flash Attention Forward
-         if self.reorder_and_upcast_attn:
-             attn_output, attn_weights = self._upcast_and_reordered_attn(
-                 query, key, value, attention_mask, head_mask
-             )
-         else:
-             # Flash Attention forward pass
-             attn_output = self._flash_attention_forward(
-                 query,
-                 key,
-                 value,
-                 attention_mask,
-                 query.size(-2),
-                 self.attn_dropout.p,
-                 softmax_scale=None,
-             )
-
-         # Merge heads and project back to hidden size
-         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-         attn_output = self.c_proj(attn_output)
-         attn_output = self.resid_dropout(attn_output)
-
-         outputs = (attn_output, present)
-         if output_attentions:
-             outputs += (attn_weights,)
-
-         return outputs
-
-     def _flash_attention_forward(
+ class MotorTaskHead(nn.Module):
+     def __init__(
          self,
-         query_states,
-         key_states,
-         value_states,
-         attention_mask,
-         query_length,
-         dropout=0.0,
-         softmax_scale=None,
+         input_dim,
+         motor_tte_vocab_size,
+         motor_num_time_pieces,
+         eps=1e-6,
      ):
-         """
-         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token.
-
-         first unpad the input, then computes the attention scores and pad the final attention scores.
-         Args:
-             query_states (`torch.Tensor`):
-                 Input query states to be passed to Flash Attention API
-             key_states (`torch.Tensor`):
-                 Input key states to be passed to Flash Attention API
-             value_states (`torch.Tensor`):
-                 Input value states to be passed to Flash Attention API
-             attention_mask (`torch.Tensor`):
-                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                 position of padding tokens and 1 for the position of non-padding tokens.
-             dropout (`int`, *optional*):
-                 Attention dropout
-             softmax_scale (`float`, *optional*):
-                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-         """
-
-         # Flash attention requires the input to have the shape
-         # batch_size x seq_length x head_dim x hidden_dim
-         # therefore we just need to keep the original shape
-         dtype = query_states.dtype
-         query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-         key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-         value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
-         # Contains at least one padding token in the sequence
-         if attention_mask is not None:
-             batch_size = query_states.shape[0]
-
-             (
-                 query_states,
-                 key_states,
-                 value_states,
-                 indices_q,
-                 cu_seq_lens,
-                 max_seq_lens,
-             ) = self._upad_input(
-                 query_states, key_states, value_states, attention_mask, query_length
-             )
-
-             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-             attn_output_unpad = flash_attn_varlen_func(
-                 query_states,
-                 key_states,
-                 value_states,
-                 cu_seqlens_q=cu_seqlens_q,
-                 cu_seqlens_k=cu_seqlens_k,
-                 max_seqlen_q=max_seqlen_in_batch_q,
-                 max_seqlen_k=max_seqlen_in_batch_k,
-                 dropout_p=dropout,
-                 softmax_scale=softmax_scale,
-                 causal=True,
-             )
-             # (batch, seq_length, n_heads, head_dim)
-             attn_output = pad_input(
-                 attn_output_unpad, indices_q, batch_size, query_length
-             )
-         else:
-             attn_output = flash_attn_func(
-                 query_states,
-                 key_states,
-                 value_states,
-                 dropout,
-                 softmax_scale=softmax_scale,
-                 causal=self.is_causal,
-             )
-         # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-         return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
-
-     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
-     def _upad_input(
-         self, query_layer, key_layer, value_layer, attention_mask, query_length
-     ):
-         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-         key_layer = index_first_axis(
-             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-             indices_k,
-         )
-         value_layer = index_first_axis(
-             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-             indices_k,
-         )
-         if query_length == kv_seq_len:
-             query_layer = index_first_axis(
-                 query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
-                 indices_k,
-             )
-             cu_seqlens_q = cu_seqlens_k
-             max_seqlen_in_batch_q = max_seqlen_in_batch_k
-             indices_q = indices_k
-         elif query_length == 1:
-             max_seqlen_in_batch_q = 1
-             cu_seqlens_q = torch.arange(
-                 batch_size + 1, dtype=torch.int32, device=query_layer.device
-             )  # There is a memcpy here, that is very bad.
-             indices_q = cu_seqlens_q[:-1]
-             query_layer = query_layer.squeeze(1)
-         else:
-             # The -q_len: slice assumes left padding.
-             attention_mask = attention_mask[:, -query_length:]
-             query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
-                 query_layer, attention_mask
-             )
-
-         return (
-             query_layer,
-             key_layer,
-             value_layer,
-             indices_q,
-             (cu_seqlens_q, cu_seqlens_k),
-             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-         )
-
-
- class MotorTaskHead(nn.Module):
-     def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
          super(MotorTaskHead, self).__init__()
          self.input_dim = input_dim
          self.motor_tte_vocab_size = motor_tte_vocab_size
          self.motor_num_time_pieces = motor_num_time_pieces
-         self.linear = nn.Sequential(
-             nn.Linear(input_dim, input_dim // 2),
-             gelu_new,
-             nn.Linear(
-                 input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
-             ),
+         self.final_layer = nn.Linear(input_dim, input_dim * motor_num_time_pieces)
+         self.norm = RMSNorm(input_dim, eps)
+         self.task_layer = nn.Linear(input_dim, motor_tte_vocab_size)
+         self.task_time_bias = nn.Parameter(
+             torch.zeros(1, self.motor_num_time_pieces, motor_tte_vocab_size)
          )

      def forward(self, x):
          # Ensure scale is positive
          length = x.shape[0]
          # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
-         lambda_p = f.softplus(self.linear(x))
-         # Check for NaN values
-         if torch.isnan(lambda_p).any():
-             logger.warning(f"NaN values found in scale_param. x: {x}")
-         # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
-         return lambda_p.view(
-             length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+         x = self.final_layer(x).reshape(
+             length, self.motor_num_time_pieces, self.input_dim
          )
+         x = self.norm(x)
+         x = self.task_layer(x) + self.task_time_bias
+         # lambda_p = f.softplus(x)
+
+         # # Check for NaN values
+         # if torch.isnan(lambda_p).any():
+         # logger.warning(f"NaN values found in scale_param. x: {x}")
+         # # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+         return x


  class VisitTimeToEventHead(nn.Module):
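A minimal shape-flow sketch of the reworked MotorTaskHead above, using toy dimensions and a plain LayerNorm as a stand-in for the RMSNorm the real head imports from cehrgpt.models.activations: each prediction point is expanded into one hidden vector per time piece, normalized, and projected to per-piece, per-task logits plus a learned time bias.

import torch
from torch import nn

input_dim, num_pieces, vocab = 8, 3, 5
final_layer = nn.Linear(input_dim, input_dim * num_pieces)
norm = nn.LayerNorm(input_dim)                          # stand-in for RMSNorm(input_dim, eps)
task_layer = nn.Linear(input_dim, vocab)
task_time_bias = nn.Parameter(torch.zeros(1, num_pieces, vocab))

x = torch.randn(4, input_dim)                           # (num_prediction_points, input_dim)
x = final_layer(x).reshape(-1, num_pieces, input_dim)   # one hidden vector per time piece
x = norm(x)
logits = task_layer(x) + task_time_bias                 # bias broadcasts to (4, num_pieces, vocab)
print(logits.shape)                                     # torch.Size([4, 3, 5])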
@@ -723,7 +447,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
      def __init__(self, config: CEHRGPTConfig):
          super().__init__(config)

-         self.exclude_position_ids = config.exclude_position_ids
          self.include_values = config.include_values
          self.include_ttv_prediction = config.include_ttv_prediction
          self.embed_dim = config.hidden_size
@@ -734,8 +457,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          self.pretrained_wte = None

          self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
-         if not self.exclude_position_ids:
-             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
          if self.include_values:
              self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
              self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -746,9 +467,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          gpt_blocks = []
          for i in range(config.num_hidden_layers):
              gpt_block = GPT2Block(config, layer_idx=i)
-             if getattr(config, "_attn_implementation", "eager") == "flash_attention_2":
-                 gpt_block.attn = GPT2FlashAttention(config, layer_idx=i)
-                 gpt_block.is_causal = True
+             gpt_block.is_causal = True
              gpt_blocks.append(gpt_block)
          self.h = nn.ModuleList(gpt_blocks)
          self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -769,10 +488,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          )
          self.update_attn_bias(self.config.sample_packing_max_positions)

-     def enable_position_embeddings(self):
-         self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
-         self.config.exclude_position_ids = False
-
      def initialize_pretrained_embeddings(self):
          layers = [
              nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -815,8 +530,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          self.wte = self.wte.to(self.first_device)
          if self.config.use_pretrained_embeddings:
              self.pretrained_wte = self.pretrained_wte.to(self.first_device)
-         if not self.exclude_position_ids:
-             self.wpe = self.wpe.to(self.first_device)
          if self.include_values:
              self.vte = self.vte.to(self.first_device)
              self.concept_value_transformation_layer = (
@@ -842,8 +555,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          self.wte = self.wte.to("cpu")
          if self.config.use_pretrained_embeddings:
              self.pretrained_wte = self.pretrained_wte.to("cpu")
-         if not self.exclude_position_ids:
-             self.wpe = self.wpe.to("cpu")
          self.vte = self.vte.to("cpu")
          self.concept_value_transformation_layer = (
              self.concept_value_transformation_layer.to("cpu")
@@ -871,8 +582,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
      def get_position_embeddings(
          self,
      ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
-         if not self.exclude_position_ids:
-             return self.wpe
          return None

      def set_position_embeddings(self, new_embeddings: nn.Embedding):
@@ -946,24 +655,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          # Convert list back to torch.Size if needed
          input_shape = torch.Size(shape_list)

-         device = input_ids.device
+         input_ids.device

          if past_key_values is None:
-             past_length = 0
              past_key_values = tuple([None] * len(self.h))
          else:
-             past_length = past_key_values[0][0].size(-2)
-
-         # This is normally called during training or fine-tuning.
-         # While the generation logic will handle position_ids in the sampling logic
-         if position_ids is None and not self.exclude_position_ids:
-             position_ids = torch.arange(
-                 past_length,
-                 input_shape[-1] + past_length,
-                 dtype=torch.long,
-                 device=device,
-             )
-             position_ids = position_ids.unsqueeze(0)
+             past_key_values[0][0].size(-2)

          # GPT2Attention mask.
          if attention_mask is not None:
@@ -1046,10 +743,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
              ]
              if random_vectors is None:
                  random_vectors = torch.rand_like(input_embeddings[:, :1])
+
              input_embeddings = torch.concat(
                  [demographic_embeddings, random_vectors, medical_event_embeddings],
                  dim=1,
              )
+             position_ids = torch.concat(
+                 [
+                     position_ids[:, : self.config.demographics_size],
+                     position_ids[:, :1],
+                     position_ids[:, self.config.demographics_size :],
+                 ],
+                 dim=1,
+             )

          if self.include_values:
              if (
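To make the splice above concrete, a toy example (plain integer positions standing in for the age values the new code passes as position_ids; demographics_size is illustrative): one extra position id is spliced in after the demographic block so the ids stay aligned with the random vector inserted into the embeddings.

import torch

demographics_size = 4
position_ids = torch.arange(10).unsqueeze(0)            # (1, seq_len)
position_ids = torch.concat(
    [
        position_ids[:, :demographics_size],            # demographic positions 0..3
        position_ids[:, :1],                            # reused id for the random-vector slot
        position_ids[:, demographics_size:],            # remaining medical-event positions
    ],
    dim=1,
)
print(position_ids)  # tensor([[0, 1, 2, 3, 0, 4, 5, 6, 7, 8, 9]])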
@@ -1075,13 +781,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                  value_embeddings=value_embeddings,
              )

-         if not self.exclude_position_ids:
-             position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
-             hidden_states = input_embeddings + position_embeds
-         else:
-             hidden_states = input_embeddings
-
-         hidden_states = self.drop(hidden_states)
+         hidden_states = self.drop(input_embeddings)

          output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

@@ -1109,6 +809,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                      attention_mask = attention_mask.to(hidden_states.device)
                  if isinstance(head_mask, torch.Tensor):
                      head_mask = head_mask.to(hidden_states.device)
+
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

@@ -1116,6 +817,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                  outputs = self._gradient_checkpointing_func(
                      block.__call__,
                      hidden_states,
+                     position_ids,
                      None,
                      attention_mask,
                      head_mask[i],
@@ -1127,6 +829,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
              else:
                  outputs = block(
                      hidden_states,
+                     position_ids=position_ids,
                      layer_past=layer_past,
                      attention_mask=attention_mask,
                      head_mask=head_mask[i],
@@ -1200,7 +903,9 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

          if self.config.include_motor_time_to_event:
              self.motor_tte = MotorTaskHead(
-                 config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+                 input_dim=config.n_embd,
+                 motor_tte_vocab_size=config.motor_tte_vocab_size,
+                 motor_num_time_pieces=config.motor_num_time_pieces,
              )

          # Model parallel
@@ -1300,12 +1005,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
      def prepare_inputs_for_generation(
          self,
          input_ids,
+         cehrgpt_tokenizer,
          past_key_values=None,
          inputs_embeds=None,
-         lab_token_ids=None,
          **kwargs,
      ):
-
+         ages = kwargs.get("ages")
          # Omit tokens covered by past_key_values
          if past_key_values:
              past_length = past_key_values[0][0].shape[2]
@@ -1320,33 +1025,10 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                  remove_prefix_length = input_ids.shape[1] - 1

              input_ids = input_ids[:, remove_prefix_length:]
+             ages = ages[:, remove_prefix_length:]

          attention_mask = kwargs.get("attention_mask", None)
-         position_ids = kwargs.get("position_ids", None)
          random_vectors = kwargs.get("random_vectors", None)
-
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -input_ids.shape[1] :]
-
-             # Add one more position for the random vectors
-             if (
-                 self.cehrgpt.config.causal_sfm
-                 and position_ids.shape[-1] >= self.cehrgpt.config.demographics_size
-             ):
-                 position_ids = torch.concat(
-                     [
-                         position_ids,
-                         torch.max(position_ids, dim=-1, keepdim=True)[0] + 1,
-                     ],
-                     dim=-1,
-                 )
-         else:
-             position_ids = None
-
          # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
          if inputs_embeds is not None and past_key_values is None:
              model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1384,7 +1066,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              {
                  "past_key_values": past_key_values,
                  "use_cache": kwargs.get("use_cache"),
-                 "position_ids": position_ids,
+                 "ages": ages,
                  "attention_mask": attention_mask,
                  "random_vectors": random_vectors,
              }
@@ -1394,12 +1076,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

      def motor_nll_loss(
          self,
-         ve_token_features,
-         motor_time_to_event_vectors,
-         motor_event_indicators,
-         motor_time_to_event_to_include,
-         motor_time_indicators,
-         batch_motor_end_index,
+         hidden_states,
+         motor_tte_times,
+         motor_tte_event_indicators,
+         motor_tte_task_indicators,
+         motor_tte_masks,
+         motor_end_index,
      ):
          """
          Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
@@ -1407,58 +1089,62 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          for modeling time-to-event data at each visit.

          Args:
-             ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
-             motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
-             motor_time_to_event_to_include: (Tensor): Bool indicators (True if included, False if not included).
-             motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
-             motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
-                 time bucket (1 if censored, 0 if event occurred).
-             batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+             hidden_states (Tensor): Hidden representations for sequence tokens [num_of_concepts, hidden_dim].
+             motor_tte_times (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+             motor_tte_task_indicators: (Tensor): Bool indicators (True if included, False if not included).
+             motor_tte_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+             motor_tte_masks (Tensor): Binary indicators whether the prediction should be masked
+                 (1 if not masked, 0 if masked).
+             motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.

          Returns:
              Tensor: Scalar loss value (mean negative log-likelihood).
          """
-         batch_motor_end_index = batch_motor_end_index.sum().item()
-         motor_time_to_event_vectors = motor_time_to_event_vectors.view(
-             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-         )[:batch_motor_end_index].clamp(min=1e-3)
-         motor_event_indicators = motor_event_indicators.reshape(
+         motor_end_index = motor_end_index.sum().item()
+         motor_tte_times = motor_tte_times.view(
              (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-         )[:batch_motor_end_index]
-         motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
-             :batch_motor_end_index
-         ]
-         motor_time_indicators = motor_time_indicators.view(
+         )[:motor_end_index].clamp(min=1e-3)
+         motor_tte_event_indicators = motor_tte_event_indicators.reshape(
              (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-         )[:batch_motor_end_index]
-         assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+         )[:motor_end_index]
+         # motor_tte_masks = motor_tte_masks.view(
+         # (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+         # )[:motor_end_index]
+
+         tte_features = hidden_states[motor_tte_task_indicators].view(
+             (-1, self.config.n_embd)
+         )
+
+         assert tte_features.shape[0] == motor_tte_times.shape[0], (
              "The number of VE tokens in the labels needs to match up "
              "with the first dimension of motor_time_to_event_vectors. "
-             f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
-             f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+             f"Received ve_token_features.shape[0]: {tte_features.shape[0]}, "
+             f"motor_time_to_event_vectors.shape[0]: {motor_tte_times.shape[0]}"
          )
-         motor_time_to_event_vectors = motor_time_to_event_vectors[
-             motor_time_to_event_to_include
-         ]
-         motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
-         motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
-         ve_token_features = ve_token_features[motor_time_to_event_to_include]

          # Get Exponential parameters from model
-         lambda_p = self.motor_tte(ve_token_features)
-         # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
-         dist = Exponential(lambda_p.clamp(min=1e-3))
+         time_dependent_logits = self.motor_tte(tte_features)

          # Compute event loss
-         tte_loss = torch.where(
-             motor_event_indicators,
-             -dist.log_prob(motor_time_to_event_vectors),
-             -torch.log(
-                 1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
-             ),
+         # Calculate the accumulative hazard
+         # exp(-sum_{j} lambda_j)
+         survival_loss = torch.exp2(time_dependent_logits + motor_tte_times).mean()
+         event_loss = (
+             -math.log(2)
+             * torch.where(motor_tte_event_indicators, time_dependent_logits, 0).mean()
          )
-         tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
-         return torch.mean(tte_loss)
+
+         # survival_loss = (
+         #     torch.where(motor_tte_masks, lambda_p * motor_tte_times, 0)
+         #     .sum(dim=1)
+         #     .mean()
+         # )
+         # event_loss = (
+         #     -torch.where(motor_tte_event_indicators, torch.log(lambda_p), 0)
+         #     .sum(dim=1)
+         #     .mean()
+         # )
+         return survival_loss + event_loss

      def forward(
          self,
@@ -1467,7 +1153,6 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          values: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          random_vectors: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
@@ -1476,21 +1161,23 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          time_to_visits: Optional[torch.FloatTensor] = None,
          time_token_indicators: Optional[torch.BoolTensor] = None,
          sub_time_tokens: Optional[torch.LongTensor] = None,
-         motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
-         motor_event_indicators: Optional[torch.BoolTensor] = None,
-         motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
-         motor_time_indicators: Optional[torch.BoolTensor] = None,
+         motor_tte_times: Optional[torch.FloatTensor] = None,
+         motor_tte_event_indicators: Optional[torch.BoolTensor] = None,
+         motor_tte_task_indicators: Optional[torch.BoolTensor] = None,
+         motor_tte_masks: Optional[torch.BoolTensor] = None,
          motor_end_index: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         ages: Optional[torch.FloatTensor] = None,
+         epoch_times: Optional[torch.FloatTensor] = None,
      ) -> Union[Tuple, CehrGptCausalLMOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
          """
          return_dict = (
              return_dict if return_dict is not None else self.config.use_return_dict
@@ -1502,7 +1189,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              values=values,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
-             position_ids=position_ids,
+             position_ids=ages,
              random_vectors=random_vectors,
              head_mask=head_mask,
              use_cache=use_cache,
@@ -1611,23 +1298,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

              if (
                  self.config.include_motor_time_to_event
-                 and motor_time_to_event_vectors is not None
-                 and motor_event_indicators is not None
-                 and motor_time_to_event_to_include is not None
-                 and motor_time_indicators is not None
+                 and motor_tte_times is not None
+                 and motor_tte_event_indicators is not None
+                 and motor_tte_task_indicators is not None
+                 and motor_tte_masks is not None
                  and motor_end_index is not None
              ):
-                 ve_token_id_indices = labels == self.config.ve_token_id
-                 ve_token_features = hidden_states[ve_token_id_indices]
-                 # Get rid of the last VE features because it's already reached the end of the patient sequence and
-                 # there is nothing to predict.
                  motor_tte_loss = self.motor_nll_loss(
-                     ve_token_features=ve_token_features,
-                     motor_time_to_event_vectors=motor_time_to_event_vectors,
-                     motor_event_indicators=motor_event_indicators,
-                     motor_time_to_event_to_include=motor_time_to_event_to_include,
-                     motor_time_indicators=motor_time_indicators,
-                     batch_motor_end_index=motor_end_index,
+                     hidden_states=hidden_states,
+                     motor_tte_times=motor_tte_times,
+                     motor_tte_event_indicators=motor_tte_event_indicators,
+                     motor_tte_task_indicators=motor_tte_task_indicators,
+                     motor_tte_masks=motor_tte_masks,
+                     motor_end_index=motor_end_index,
                  )
                  loss += motor_tte_loss * self.config.motor_time_to_event_weight

@@ -1833,6 +1516,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              else self.generation_config.return_dict_in_generate
          )

+         if "cehrgpt_tokenizer" not in model_kwargs:
+             raise RuntimeError(
+                 "The cehr-gpt tokenizer must be provided to the "
+                 "model.generate(..., cehrgpt_tokenizer=cehrgpt_tokenizer)"
+             )
+
+         # Remove this from the model_kwargs and will pass it to other functions explicitly
+         cehrgpt_tokenizer = model_kwargs.pop("cehrgpt_tokenizer")
+
          # init attention / hidden states / scores tuples
          scores = () if (return_dict_in_generate and output_scores) else None
          raw_logits = () if (return_dict_in_generate and output_logits) else None
@@ -1848,6 +1540,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

          # keep track of which sequences are already finished
          batch_size, cur_len = input_ids.shape
+         model_kwargs["attention_mask"] = input_ids != pad_token_id
          if "inputs_embeds" in model_kwargs:
              cur_len = model_kwargs["inputs_embeds"].shape[1]
          this_peer_finished = False
@@ -1855,22 +1548,23 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              batch_size, dtype=torch.long, device=input_ids.device
          )
          model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-         # Use the lab_token_ids in the argument, otherwise default to the configuration token_ids
-         if "lab_token_ids" in model_kwargs:
-             lab_token_ids = torch.tensor(
-                 model_kwargs["lab_token_ids"],
-                 dtype=torch.int32,
-             )
+         # Getting the lab token ids
+         lab_token_ids = torch.tensor(
+             cehrgpt_tokenizer.lab_token_ids,
+             dtype=torch.int32,
+         )
+         if model_kwargs.get("value_indicators", None) is not None:
+             value_indicators = model_kwargs.get("value_indicators")
          else:
-             lab_token_ids = torch.tensor(
-                 [] if self.config.lab_token_ids is None else self.config.lab_token_ids,
+             value_indicators = torch.zeros_like(input_ids).to(torch.bool)
+
+         if model_kwargs.get("values", None) is not None:
+             values = model_kwargs.get("values")
+         else:
+             values = torch.zeros_like(
+                 input_ids,
                  dtype=torch.int32,
              )
-         value_indicators = torch.zeros_like(input_ids).to(torch.bool)
-         values = torch.zeros_like(
-             input_ids,
-             dtype=torch.int32,
-         )
          # Generate initial random_vectors
          if self.cehrgpt.config.causal_sfm:
              model_kwargs["random_vectors"] = torch.rand(
@@ -1884,11 +1578,33 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              model_kwargs["random_vectors"] = None
          model_kwargs["value_indicators"] = value_indicators
          model_kwargs["values"] = values
+
+         # A variable to keep track of time and initialize it to zero
+         batched_time_delta = np.zeros((batch_size,), dtype=np.float32)
+         batched_ages = model_kwargs.get("ages", None)
+         if batched_ages is None:
+             batched_ages = []
+             for token_ids in input_ids.detach().cpu():
+                 concept_ids = cehrgpt_tokenizer.decode(
+                     token_ids.numpy(), skip_special_tokens=False
+                 )
+                 batched_ages.append(construct_age_sequence(concept_ids))
+             # Turn this to a numpy array for easy manipulation
+             batched_ages = np.asarray(batched_ages)
+         else:
+             batched_ages = batched_ages.cpu().numpy()
+         # This is the base to which we will add the time delta
+         base_ages = np.asarray([ages[-1] for ages in batched_ages])
+         # Update the keyword arguments for the prepare_inputs_for_generation
+         model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
          while self._has_unfinished_sequences(
              this_peer_finished, synced_gpus, device=input_ids.device
          ):
              # prepare model inputs
-             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             model_inputs = self.prepare_inputs_for_generation(
+                 input_ids, cehrgpt_tokenizer, **model_kwargs
+             )

              # forward pass to get next token
              outputs = self(
@@ -1933,6 +1649,22 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              probs = nn.functional.softmax(next_token_scores, dim=-1)
              next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

+             # TODO: decode to get time tokens and recalculate the age at this time step
+             # Look for a potential time token
+             for batch_i, next_concept_id in enumerate(
+                 cehrgpt_tokenizer.decode(
+                     next_tokens.detach().cpu().numpy(), skip_special_tokens=False
+                 )
+             ):
+                 if is_att_token(next_concept_id):
+                     batched_time_delta[batch_i] += extract_time_interval_in_days(
+                         next_concept_id
+                     )
+
+             next_age = (base_ages + batched_time_delta // 365).astype(int)[..., None]
+             batched_ages = np.concatenate([batched_ages, next_age], axis=-1)
+             model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
              # finished sentences should have their next token be a padding token
              if eos_token_id is not None:
                  if pad_token_id is None:
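For orientation, a standalone sketch of the age bookkeeping this loop performs (hard-coded day deltas stand in for the values that is_att_token / extract_time_interval_in_days derive from the decoded next tokens): each sampled step appends one age per sequence, and an age only advances once 365 generated days have accumulated on top of the base age.

import numpy as np

batch_size = 2
base_ages = np.asarray([40.0, 67.0], dtype=np.float32)    # age at the last prompt token
batched_ages = base_ages[:, None].copy()                   # (batch, cur_len)
batched_time_delta = np.zeros((batch_size,), dtype=np.float32)

for day_deltas in ([0.0, 30.0], [365.0, 0.0], [7.0, 400.0]):   # pretend decoded time tokens
    batched_time_delta += np.asarray(day_deltas, dtype=np.float32)
    next_age = (base_ages + batched_time_delta // 365).astype(int)[..., None]
    batched_ages = np.concatenate([batched_ages, next_age], axis=-1)

print(batched_ages)   # [[40. 40. 41. 41.]
                      #  [67. 67. 67. 68.]]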
@@ -1968,6 +1700,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

              if streamer is not None:
                  streamer.put(next_tokens.cpu())
+
              model_kwargs = self._update_model_kwargs_for_generation(
                  outputs,
                  model_kwargs,
@@ -2012,7 +1745,7 @@ class FocalLoss(nn.Module):
          self.reduction = reduction

      def forward(self, logits, targets):
-         bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+         bce_loss = f.binary_cross_entropy_with_logits(logits, targets, reduction="none")
          probs = torch.sigmoid(logits)
          pt = torch.where(targets == 1, probs, 1 - probs)
          focal_term = (1 - pt) ** self.gamma
@@ -2094,12 +1827,13 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
          values: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         ages: Optional[torch.FloatTensor] = None,
+         epoch_times: Optional[torch.FloatTensor] = None,
          **kwargs,
      ) -> CehrGptSequenceClassifierOutput:

@@ -2109,12 +1843,12 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
              values=values,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
-             position_ids=position_ids,
              head_mask=head_mask,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
+             position_ids=ages,
          )

          if is_sample_pack(attention_mask):