cehrgpt 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +285 -652
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +38 -5
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +2 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +20 -12
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +183 -460
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -78
  22. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  23. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -34
  24. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  25. cehrgpt/runners/sample_packing_trainer.py +11 -2
  26. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +8 -2
  27. cehrgpt-0.1.4.dist-info/METADATA +238 -0
  28. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/RECORD +32 -22
  29. cehrgpt-0.1.2.dist-info/METADATA +0 -209
  30. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  31. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/WHEEL +0 -0
  32. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/licenses/LICENSE +0 -0
  33. {cehrgpt-0.1.2.dist-info → cehrgpt-0.1.4.dist-info}/top_level.txt +0 -0
@@ -6,9 +6,8 @@ import numpy as np
  import torch
  import torch.nn.functional as f
  from torch import nn
- from torch.distributions import Exponential, Gamma
+ from torch.distributions import Gamma
  from torch.nn import CrossEntropyLoss
- from torch.nn import functional as F
  from transformers import PreTrainedModel
  from transformers.activations import gelu_new
  from transformers.generation.logits_process import LogitsProcessorList
@@ -18,16 +17,20 @@ from transformers.generation.stopping_criteria import (
  )
  from transformers.generation.streamers import BaseStreamer
  from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
- from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
  from transformers.pytorch_utils import Conv1D
- from transformers.utils import (
-     is_accelerate_available,
-     is_flash_attn_2_available,
-     logging,
- )
+ from transformers.utils import is_flash_attn_2_available, logging
  from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

+ from cehrgpt.gpt_utils import (
+     construct_age_sequence,
+     encode_demographics,
+     extract_time_interval_in_days,
+     is_att_token,
+     multiple_of_10,
+ )
+ from cehrgpt.models.activations import RMSNorm
  from cehrgpt.models.config import CEHRGPTConfig
+ from cehrgpt.models.gpt2 import GPT2Block, is_sample_pack
  from cehrgpt.models.hf_modeling_outputs import (
      CehrGptCausalLMOutput,
      CehrGptGenerateDecoderOnlyOutput,
@@ -35,13 +38,6 @@ from cehrgpt.models.hf_modeling_outputs import (
      CehrGptSequenceClassifierOutput,
  )

- if is_flash_attn_2_available():
-     from flash_attn import flash_attn_func, flash_attn_varlen_func
-     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
- if is_accelerate_available():
-     from accelerate.hooks import add_hook_to_module
-
  logger = logging.get_logger(__name__)

@@ -50,7 +46,7 @@ def extract_features_from_packed_sequence(
      attention_mask: torch.Tensor,
  ) -> torch.Tensor:
      max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
-     padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
+     padded_attention_mask = f.pad(attention_mask[:, : max_index + 1], (0, 1))
      feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
      return hidden_state[:, feature_indices]

@@ -81,315 +77,41 @@ def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.
      return attn_matrix


- def is_sample_pack(attention_mask: torch.Tensor) -> bool:
-     """
-     Determines whether any sequence in the batch is likely sample-packed.
-
-     A sample-packed sequence is one where there are non-padding (1) tokens
-     after a padding (0) token, indicating multiple sequences packed together
-     with padding as a separator.
-
-     Args:
-         attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
-             where 1 indicates a real token and 0 indicates padding.
-
-     Returns:
-         bool: True if any sample in the batch is sample-packed, False otherwise.
-     """
-
-     # If the attention_maks is left padded, we will flip it so we can use the same logic below
-     if (attention_mask[:, 0] == 0).any():
-         attention_mask = attention_mask.flip(dims=[1])
-
-     nonzero_counts = attention_mask.sum(dim=1)
-     max_token_positions = torch.argmax(
-         attention_mask.to(torch.int32).flip(dims=[1]), dim=1
-     )
-     max_indices = attention_mask.shape[1] - 1 - max_token_positions
-     return torch.any(nonzero_counts < (max_indices + 1)).item()
-
-
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
- def _get_unpad_data(attention_mask):
-     # This infers sample packing
-     if is_sample_pack(attention_mask):
-         # Assume input: attention_mask shape = (batch, seq_len)
-         attention_mask = attention_mask.flatten()  # shape: (seq_len,)
-
-         # Compute max_index of the last non-zero element
-         nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
-         max_index = nonzero[-1].item()
-
-         # Pad the truncated attention mask
-         padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
-
-         # Indices of all tokens
-         indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
-
-         # Find where 0s occur (segment boundaries)
-         cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
-             padded_attention_mask == 0
-         ]
-
-         # Compute seqlens per segment
-         seqlens_in_batch = (
-             cumsum_seqlens_in_batch
-             - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
-         ).to(torch.int)
-
-         max_seqlen_in_batch = (
-             seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
-         )
-         cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
-     else:
-         seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-         indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-         max_seqlen_in_batch = seqlens_in_batch.max().item()
-         cu_seqlens = F.pad(
-             torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-         )
-
-     return (
-         indices,
-         cu_seqlens,
-         max_seqlen_in_batch,
-     )
-
-
- class GPT2FlashAttention(GPT2Attention):
-     """
-     GPT2FlashAttention inherits from `GPT2Attention`.
-
-     The primary change is in the forward pass, where it correctly
-     calls the public API of flash attention and handles padding tokens.
-     """
-
-     def forward(
-         self,
-         hidden_states: Optional[Tuple[torch.FloatTensor]],
-         layer_past: Optional[Tuple[torch.Tensor]] = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         head_mask: Optional[torch.FloatTensor] = None,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = False,
-         output_attentions: Optional[bool] = False,
-     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-         # Prepare query, key, and value
-         if encoder_hidden_states is not None:
-             if not hasattr(self, "q_attn"):
-                 raise ValueError(
-                     "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                     "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
-                 )
-
-             query = self.q_attn(hidden_states)
-             key, value = self.c_attn(encoder_hidden_states).split(
-                 self.split_size, dim=2
-             )
-             attention_mask = encoder_attention_mask
-         else:
-             query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-         query = self._split_heads(query, self.num_heads, self.head_dim)
-         key = self._split_heads(key, self.num_heads, self.head_dim)
-         value = self._split_heads(value, self.num_heads, self.head_dim)
-
-         if layer_past is not None:
-             past_key, past_value = layer_past
-             key = torch.cat((past_key, key), dim=-2)
-             value = torch.cat((past_value, value), dim=-2)
-
-         if use_cache is True:
-             present = (key, value)
-         else:
-             present = None
-
-         # Apply Flash Attention Forward
-         if self.reorder_and_upcast_attn:
-             attn_output, attn_weights = self._upcast_and_reordered_attn(
-                 query, key, value, attention_mask, head_mask
-             )
-         else:
-             # Flash Attention forward pass
-             attn_output = self._flash_attention_forward(
-                 query,
-                 key,
-                 value,
-                 attention_mask,
-                 query.size(-2),
-                 self.attn_dropout.p,
-                 softmax_scale=None,
-             )
-
-         # Merge heads and project back to hidden size
-         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-         attn_output = self.c_proj(attn_output)
-         attn_output = self.resid_dropout(attn_output)
-
-         outputs = (attn_output, present)
-         if output_attentions:
-             outputs += (attn_weights,)
-
-         return outputs
-
-     def _flash_attention_forward(
+ class MotorTaskHead(nn.Module):
+     def __init__(
          self,
-         query_states,
-         key_states,
-         value_states,
-         attention_mask,
-         query_length,
-         dropout=0.0,
-         softmax_scale=None,
+         input_dim,
+         motor_tte_vocab_size,
+         motor_num_time_pieces,
+         eps=1e-6,
      ):
-         """
-         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token.
-
-         first unpad the input, then computes the attention scores and pad the final attention scores.
-         Args:
-             query_states (`torch.Tensor`):
-                 Input query states to be passed to Flash Attention API
-             key_states (`torch.Tensor`):
-                 Input key states to be passed to Flash Attention API
-             value_states (`torch.Tensor`):
-                 Input value states to be passed to Flash Attention API
-             attention_mask (`torch.Tensor`):
-                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                 position of padding tokens and 1 for the position of non-padding tokens.
-             dropout (`int`, *optional*):
-                 Attention dropout
-             softmax_scale (`float`, *optional*):
-                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-         """
-
-         # Flash attention requires the input to have the shape
-         # batch_size x seq_length x head_dim x hidden_dim
-         # therefore we just need to keep the original shape
-         dtype = query_states.dtype
-         query_states = query_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-         key_states = key_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-         value_states = value_states.permute(0, 2, 1, 3).contiguous().to(torch.bfloat16)
-
-         # Contains at least one padding token in the sequence
-         if attention_mask is not None:
-             batch_size = query_states.shape[0]
-
-             (
-                 query_states,
-                 key_states,
-                 value_states,
-                 indices_q,
-                 cu_seq_lens,
-                 max_seq_lens,
-             ) = self._upad_input(
-                 query_states, key_states, value_states, attention_mask, query_length
-             )
-
-             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-             attn_output_unpad = flash_attn_varlen_func(
-                 query_states,
-                 key_states,
-                 value_states,
-                 cu_seqlens_q=cu_seqlens_q,
-                 cu_seqlens_k=cu_seqlens_k,
-                 max_seqlen_q=max_seqlen_in_batch_q,
-                 max_seqlen_k=max_seqlen_in_batch_k,
-                 dropout_p=dropout,
-                 softmax_scale=softmax_scale,
-                 causal=True,
-             )
-             # (batch, seq_length, n_heads, head_dim)
-             attn_output = pad_input(
-                 attn_output_unpad, indices_q, batch_size, query_length
-             )
-         else:
-             attn_output = flash_attn_func(
-                 query_states,
-                 key_states,
-                 value_states,
-                 dropout,
-                 softmax_scale=softmax_scale,
-                 causal=self.is_causal,
-             )
-         # re-order the tensor back to (batch, n_heads, seq_length, head_dim)
-         return attn_output.permute(0, 2, 1, 3).contiguous().to(dtype)
-
-     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
-     def _upad_input(
-         self, query_layer, key_layer, value_layer, attention_mask, query_length
-     ):
-         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-         key_layer = index_first_axis(
-             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-             indices_k,
-         )
-         value_layer = index_first_axis(
-             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
-             indices_k,
-         )
-         if query_length == kv_seq_len:
-             query_layer = index_first_axis(
-                 query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
-                 indices_k,
-             )
-             cu_seqlens_q = cu_seqlens_k
-             max_seqlen_in_batch_q = max_seqlen_in_batch_k
-             indices_q = indices_k
-         elif query_length == 1:
-             max_seqlen_in_batch_q = 1
-             cu_seqlens_q = torch.arange(
-                 batch_size + 1, dtype=torch.int32, device=query_layer.device
-             )  # There is a memcpy here, that is very bad.
-             indices_q = cu_seqlens_q[:-1]
-             query_layer = query_layer.squeeze(1)
-         else:
-             # The -q_len: slice assumes left padding.
-             attention_mask = attention_mask[:, -query_length:]
-             query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
-                 query_layer, attention_mask
-             )
-
-         return (
-             query_layer,
-             key_layer,
-             value_layer,
-             indices_q,
-             (cu_seqlens_q, cu_seqlens_k),
-             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-         )
-
-
- class MotorTaskHead(nn.Module):
-     def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
          super(MotorTaskHead, self).__init__()
          self.input_dim = input_dim
          self.motor_tte_vocab_size = motor_tte_vocab_size
          self.motor_num_time_pieces = motor_num_time_pieces
-         self.linear = nn.Sequential(
-             nn.Linear(input_dim, input_dim // 2),
-             gelu_new,
-             nn.Linear(
-                 input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
-             ),
+         self.final_layer = nn.Linear(input_dim, input_dim * motor_num_time_pieces)
+         self.norm = RMSNorm(input_dim, eps)
+         self.task_layer = nn.Linear(input_dim, motor_tte_vocab_size)
+         self.task_time_bias = nn.Parameter(
+             torch.zeros(1, self.motor_num_time_pieces, motor_tte_vocab_size)
          )

      def forward(self, x):
          # Ensure scale is positive
          length = x.shape[0]
          # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
-         lambda_p = f.softplus(self.linear(x))
-         # Check for NaN values
-         if torch.isnan(lambda_p).any():
-             logger.warning(f"NaN values found in scale_param. x: {x}")
-         # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
-         return lambda_p.view(
-             length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+         x = self.final_layer(x).reshape(
+             length, self.motor_num_time_pieces, self.input_dim
          )
+         x = self.norm(x)
+         x = self.task_layer(x) + self.task_time_bias
+         # lambda_p = f.softplus(x)
+
+         # # Check for NaN values
+         # if torch.isnan(lambda_p).any():
+         #     logger.warning(f"NaN values found in scale_param. x: {x}")
+         # # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+         return x


  class VisitTimeToEventHead(nn.Module):
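
Note on the reworked MotorTaskHead above (an annotation, not part of the package source): the new head expands each hidden vector into one vector per time piece, normalizes it, and projects to the MOTOR vocabulary with a per-time-piece bias, so the output is (num_tokens, motor_num_time_pieces, motor_tte_vocab_size). A minimal shape sketch with made-up sizes, using LayerNorm as a stand-in for the package's RMSNorm:

import torch
import torch.nn as nn

num_tokens, input_dim, num_pieces, vocab = 7, 16, 4, 10  # hypothetical sizes
x = torch.randn(num_tokens, input_dim)

final_layer = nn.Linear(input_dim, input_dim * num_pieces)
task_layer = nn.Linear(input_dim, vocab)
task_time_bias = torch.zeros(1, num_pieces, vocab)

h = final_layer(x).reshape(num_tokens, num_pieces, input_dim)  # (7, 4, 16)
h = nn.LayerNorm(input_dim)(h)           # the real head uses RMSNorm(input_dim, eps)
logits = task_layer(h) + task_time_bias  # (7, 4, 10)
print(logits.shape)                      # torch.Size([7, 4, 10])
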
@@ -725,7 +447,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
      def __init__(self, config: CEHRGPTConfig):
          super().__init__(config)

-         self.exclude_position_ids = config.exclude_position_ids
          self.include_values = config.include_values
          self.include_ttv_prediction = config.include_ttv_prediction
          self.embed_dim = config.hidden_size
@@ -736,8 +457,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          self.pretrained_wte = None

          self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
-         if not self.exclude_position_ids:
-             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
          if self.include_values:
              self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
              self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -748,9 +467,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          gpt_blocks = []
          for i in range(config.num_hidden_layers):
              gpt_block = GPT2Block(config, layer_idx=i)
-             if getattr(config, "_attn_implementation", "eager") == "flash_attention_2":
-                 gpt_block.attn = GPT2FlashAttention(config, layer_idx=i)
-                 gpt_block.is_causal = True
+             gpt_block.is_causal = True
              gpt_blocks.append(gpt_block)
          self.h = nn.ModuleList(gpt_blocks)
          self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -771,10 +488,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          )
          self.update_attn_bias(self.config.sample_packing_max_positions)

-     def enable_position_embeddings(self):
-         self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
-         self.config.exclude_position_ids = False
-
      def initialize_pretrained_embeddings(self):
          layers = [
              nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -817,8 +530,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          self.wte = self.wte.to(self.first_device)
          if self.config.use_pretrained_embeddings:
              self.pretrained_wte = self.pretrained_wte.to(self.first_device)
-         if not self.exclude_position_ids:
-             self.wpe = self.wpe.to(self.first_device)
          if self.include_values:
              self.vte = self.vte.to(self.first_device)
              self.concept_value_transformation_layer = (
@@ -844,8 +555,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          self.wte = self.wte.to("cpu")
          if self.config.use_pretrained_embeddings:
              self.pretrained_wte = self.pretrained_wte.to("cpu")
-         if not self.exclude_position_ids:
-             self.wpe = self.wpe.to("cpu")
          self.vte = self.vte.to("cpu")
          self.concept_value_transformation_layer = (
              self.concept_value_transformation_layer.to("cpu")
@@ -873,8 +582,6 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
      def get_position_embeddings(
          self,
      ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
-         if not self.exclude_position_ids:
-             return self.wpe
          return None

      def set_position_embeddings(self, new_embeddings: nn.Embedding):
@@ -948,24 +655,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
          # Convert list back to torch.Size if needed
          input_shape = torch.Size(shape_list)

-         device = input_ids.device
+         input_ids.device

          if past_key_values is None:
-             past_length = 0
              past_key_values = tuple([None] * len(self.h))
          else:
-             past_length = past_key_values[0][0].size(-2)
-
-         # This is normally called during training or fine-tuning.
-         # While the generation logic will handle position_ids in the sampling logic
-         if position_ids is None and not self.exclude_position_ids:
-             position_ids = torch.arange(
-                 past_length,
-                 input_shape[-1] + past_length,
-                 dtype=torch.long,
-                 device=device,
-             )
-             position_ids = position_ids.unsqueeze(0)
+             past_key_values[0][0].size(-2)

          # GPT2Attention mask.
          if attention_mask is not None:
@@ -1048,10 +743,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
              ]
              if random_vectors is None:
                  random_vectors = torch.rand_like(input_embeddings[:, :1])
+
              input_embeddings = torch.concat(
                  [demographic_embeddings, random_vectors, medical_event_embeddings],
                  dim=1,
              )
+             position_ids = torch.concat(
+                 [
+                     position_ids[:, : self.config.demographics_size],
+                     position_ids[:, :1],
+                     position_ids[:, self.config.demographics_size :],
+                 ],
+                 dim=1,
+             )

          if self.include_values:
              if (
@@ -1077,13 +781,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                  value_embeddings=value_embeddings,
              )

-         if not self.exclude_position_ids:
-             position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
-             hidden_states = input_embeddings + position_embeds
-         else:
-             hidden_states = input_embeddings
-
-         hidden_states = self.drop(hidden_states)
+         hidden_states = self.drop(input_embeddings)

          output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

@@ -1111,6 +809,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                      attention_mask = attention_mask.to(hidden_states.device)
                  if isinstance(head_mask, torch.Tensor):
                      head_mask = head_mask.to(hidden_states.device)
+
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

@@ -1118,6 +817,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                  outputs = self._gradient_checkpointing_func(
                      block.__call__,
                      hidden_states,
+                     position_ids,
                      None,
                      attention_mask,
                      head_mask[i],
@@ -1129,6 +829,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
              else:
                  outputs = block(
                      hidden_states,
+                     position_ids=position_ids,
                      layer_past=layer_past,
                      attention_mask=attention_mask,
                      head_mask=head_mask[i],
@@ -1202,7 +903,9 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

          if self.config.include_motor_time_to_event:
              self.motor_tte = MotorTaskHead(
-                 config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+                 input_dim=config.n_embd,
+                 motor_tte_vocab_size=config.motor_tte_vocab_size,
+                 motor_num_time_pieces=config.motor_num_time_pieces,
              )

          # Model parallel
@@ -1302,12 +1005,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
      def prepare_inputs_for_generation(
          self,
          input_ids,
+         cehrgpt_tokenizer,
          past_key_values=None,
          inputs_embeds=None,
-         lab_token_ids=None,
          **kwargs,
      ):
-
+         ages = kwargs.get("ages")
          # Omit tokens covered by past_key_values
          if past_key_values:
              past_length = past_key_values[0][0].shape[2]
@@ -1322,33 +1025,10 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                  remove_prefix_length = input_ids.shape[1] - 1

              input_ids = input_ids[:, remove_prefix_length:]
+             ages = ages[:, remove_prefix_length:]

          attention_mask = kwargs.get("attention_mask", None)
-         position_ids = kwargs.get("position_ids", None)
          random_vectors = kwargs.get("random_vectors", None)
-
-         if attention_mask is not None and position_ids is None:
-             # create position_ids on the fly for batch generation
-             position_ids = attention_mask.long().cumsum(-1) - 1
-             position_ids.masked_fill_(attention_mask == 0, 1)
-             if past_key_values:
-                 position_ids = position_ids[:, -input_ids.shape[1] :]
-
-             # Add one more position for the random vectors
-             if (
-                 self.cehrgpt.config.causal_sfm
-                 and position_ids.shape[-1] >= self.cehrgpt.config.demographics_size
-             ):
-                 position_ids = torch.concat(
-                     [
-                         position_ids,
-                         torch.max(position_ids, dim=-1, keepdim=True)[0] + 1,
-                     ],
-                     dim=-1,
-                 )
-         else:
-             position_ids = None
-
          # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
          if inputs_embeds is not None and past_key_values is None:
              model_inputs = {"inputs_embeds": inputs_embeds}
@@ -1386,7 +1066,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              {
                  "past_key_values": past_key_values,
                  "use_cache": kwargs.get("use_cache"),
-                 "position_ids": position_ids,
+                 "ages": ages,
                  "attention_mask": attention_mask,
                  "random_vectors": random_vectors,
              }
@@ -1396,12 +1076,12 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

      def motor_nll_loss(
          self,
-         ve_token_features,
-         motor_time_to_event_vectors,
-         motor_event_indicators,
-         motor_time_to_event_to_include,
-         motor_time_indicators,
-         batch_motor_end_index,
+         hidden_states,
+         motor_tte_times,
+         motor_tte_event_indicators,
+         motor_tte_task_indicators,
+         motor_tte_masks,
+         motor_end_index,
      ):
          """
          Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
@@ -1409,58 +1089,62 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          for modeling time-to-event data at each visit.

          Args:
-             ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
-             motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
-             motor_time_to_event_to_include: (Tensor): Bool indicators (True if included, False if not included).
-             motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
-             motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
-                 time bucket (1 if censored, 0 if event occurred).
-             batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+             hidden_states (Tensor): Hidden representations for sequence tokens [num_of_concepts, hidden_dim].
+             motor_tte_times (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+             motor_tte_task_indicators: (Tensor): Bool indicators (True if included, False if not included).
+             motor_tte_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+             motor_tte_masks (Tensor): Binary indicators whether the prediction should be masked
+                 (1 if not masked, 0 if masked).
+             motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.

          Returns:
              Tensor: Scalar loss value (mean negative log-likelihood).
          """
-         batch_motor_end_index = batch_motor_end_index.sum().item()
-         motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+         motor_end_index = motor_end_index.sum().item()
+         motor_tte_times = motor_tte_times.view(
              (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-         )[:batch_motor_end_index].clamp(min=1e-3)
-         motor_event_indicators = motor_event_indicators.reshape(
-             (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-         )[:batch_motor_end_index]
-         motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
-             :batch_motor_end_index
-         ]
-         motor_time_indicators = motor_time_indicators.view(
+         )[:motor_end_index].clamp(min=1e-3)
+         motor_tte_event_indicators = motor_tte_event_indicators.reshape(
              (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
-         )[:batch_motor_end_index]
-         assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+         )[:motor_end_index]
+         # motor_tte_masks = motor_tte_masks.view(
+         #     (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+         # )[:motor_end_index]
+
+         tte_features = hidden_states[motor_tte_task_indicators].view(
+             (-1, self.config.n_embd)
+         )
+
+         assert tte_features.shape[0] == motor_tte_times.shape[0], (
              "The number of VE tokens in the labels needs to match up "
              "with the first dimension of motor_time_to_event_vectors. "
-             f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
-             f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+             f"Received ve_token_features.shape[0]: {tte_features.shape[0]}, "
+             f"motor_time_to_event_vectors.shape[0]: {motor_tte_times.shape[0]}"
          )
-         motor_time_to_event_vectors = motor_time_to_event_vectors[
-             motor_time_to_event_to_include
-         ]
-         motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
-         motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
-         ve_token_features = ve_token_features[motor_time_to_event_to_include]

          # Get Exponential parameters from model
-         lambda_p = self.motor_tte(ve_token_features)
-         # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
-         dist = Exponential(lambda_p.clamp(min=1e-3))
+         time_dependent_logits = self.motor_tte(tte_features)

          # Compute event loss
-         tte_loss = torch.where(
-             motor_event_indicators,
-             -dist.log_prob(motor_time_to_event_vectors),
-             -torch.log(
-                 1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
-             ),
+         # Calculate the accumulative hazard
+         # exp(-sum_{j} lambda_j)
+         survival_loss = torch.exp2(time_dependent_logits + motor_tte_times).mean()
+         event_loss = (
+             -math.log(2)
+             * torch.where(motor_tte_event_indicators, time_dependent_logits, 0).mean()
          )
-         tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
-         return torch.mean(tte_loss)
+
+         # survival_loss = (
+         #     torch.where(motor_tte_masks, lambda_p * motor_tte_times, 0)
+         #     .sum(dim=1)
+         #     .mean()
+         # )
+         # event_loss = (
+         #     -torch.where(motor_tte_event_indicators, torch.log(lambda_p), 0)
+         #     .sum(dim=1)
+         #     .mean()
+         # )
+         return survival_loss + event_loss

      def forward(
          self,
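
One way to read the rewritten motor_nll_loss above (an interpretation, not stated in the package): if time_dependent_logits are base-2 log-hazards and motor_tte_times holds log2 of the time accumulated in each piece, then exp2(logits + times) equals hazard * time (the cumulative-hazard term of a piecewise-exponential likelihood) and -log(2) * logit equals -log(hazard) wherever an event occurred, so the sum is the usual exponential negative log-likelihood. A tiny numeric sketch under that assumption:

import math
import torch

# Assumption (not from the diff): log2_lambda = log2(hazard), log2_t = log2(time in piece).
log2_lambda = torch.tensor([[1.0, -2.0]])
log2_t = torch.tensor([[0.5, 1.5]])
event = torch.tensor([[True, False]])

survival_loss = torch.exp2(log2_lambda + log2_t).mean()  # mean of hazard * time
event_loss = -math.log(2) * torch.where(event, log2_lambda, torch.zeros_like(log2_lambda)).mean()
nll = survival_loss + event_loss  # exponential NLL: hazard*t - [event] * log(hazard)
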
@@ -1469,7 +1153,6 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          values: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          random_vectors: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
@@ -1478,21 +1161,23 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          time_to_visits: Optional[torch.FloatTensor] = None,
          time_token_indicators: Optional[torch.BoolTensor] = None,
          sub_time_tokens: Optional[torch.LongTensor] = None,
-         motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
-         motor_event_indicators: Optional[torch.BoolTensor] = None,
-         motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
-         motor_time_indicators: Optional[torch.BoolTensor] = None,
+         motor_tte_times: Optional[torch.FloatTensor] = None,
+         motor_tte_event_indicators: Optional[torch.BoolTensor] = None,
+         motor_tte_task_indicators: Optional[torch.BoolTensor] = None,
+         motor_tte_masks: Optional[torch.BoolTensor] = None,
          motor_end_index: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         ages: Optional[torch.FloatTensor] = None,
+         epoch_times: Optional[torch.FloatTensor] = None,
      ) -> Union[Tuple, CehrGptCausalLMOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
+             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
          """
          return_dict = (
              return_dict if return_dict is not None else self.config.use_return_dict
@@ -1504,7 +1189,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              values=values,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
-             position_ids=position_ids,
+             position_ids=ages,
              random_vectors=random_vectors,
              head_mask=head_mask,
              use_cache=use_cache,
@@ -1613,23 +1298,19 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

          if (
              self.config.include_motor_time_to_event
-             and motor_time_to_event_vectors is not None
-             and motor_event_indicators is not None
-             and motor_time_to_event_to_include is not None
-             and motor_time_indicators is not None
+             and motor_tte_times is not None
+             and motor_tte_event_indicators is not None
+             and motor_tte_task_indicators is not None
+             and motor_tte_masks is not None
              and motor_end_index is not None
          ):
-             ve_token_id_indices = labels == self.config.ve_token_id
-             ve_token_features = hidden_states[ve_token_id_indices]
-             # Get rid of the last VE features because it's already reached the end of the patient sequence and
-             # there is nothing to predict.
              motor_tte_loss = self.motor_nll_loss(
-                 ve_token_features=ve_token_features,
-                 motor_time_to_event_vectors=motor_time_to_event_vectors,
-                 motor_event_indicators=motor_event_indicators,
-                 motor_time_to_event_to_include=motor_time_to_event_to_include,
-                 motor_time_indicators=motor_time_indicators,
-                 batch_motor_end_index=motor_end_index,
+                 hidden_states=hidden_states,
+                 motor_tte_times=motor_tte_times,
+                 motor_tte_event_indicators=motor_tte_event_indicators,
+                 motor_tte_task_indicators=motor_tte_task_indicators,
+                 motor_tte_masks=motor_tte_masks,
+                 motor_end_index=motor_end_index,
              )
              loss += motor_tte_loss * self.config.motor_time_to_event_weight

@@ -1835,6 +1516,15 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              else self.generation_config.return_dict_in_generate
          )

+         if "cehrgpt_tokenizer" not in model_kwargs:
+             raise RuntimeError(
+                 "The cehr-gpt tokenizer must be provided to the "
+                 "model.generate(..., cehrgpt_tokenizer=cehrgpt_tokenizer)"
+             )
+
+         # Remove this from the model_kwargs and will pass it to other functions explicitly
+         cehrgpt_tokenizer = model_kwargs.pop("cehrgpt_tokenizer")
+
          # init attention / hidden states / scores tuples
          scores = () if (return_dict_in_generate and output_scores) else None
          raw_logits = () if (return_dict_in_generate and output_logits) else None
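
Because of the new guard above, 0.1.4 requires the CEHR-GPT tokenizer to be passed through generate(); without it the sampling loop raises the RuntimeError shown. A hedged usage sketch (the tokenizer class name and checkpoint paths are assumptions, not taken from this diff):

from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer  # class name assumed

model = CEHRGPT2LMHeadModel.from_pretrained("path/to/checkpoint")          # hypothetical path
cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained("path/to/tokenizer")  # hypothetical path

# prompt_token_ids: a (batch, seq_len) LongTensor of tokenized demographic/prompt tokens built elsewhere.
outputs = model.generate(
    inputs=prompt_token_ids,
    max_new_tokens=128,
    cehrgpt_tokenizer=cehrgpt_tokenizer,  # required by the 0.1.4 sampling loop
)
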
@@ -1858,18 +1548,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              batch_size, dtype=torch.long, device=input_ids.device
          )
          model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-         # Use the lab_token_ids in the argument, otherwise default to the configuration token_ids
-         if "lab_token_ids" in model_kwargs:
-             lab_token_ids = torch.tensor(
-                 model_kwargs["lab_token_ids"],
-                 dtype=torch.int32,
-             )
-         else:
-             lab_token_ids = torch.tensor(
-                 [] if self.config.lab_token_ids is None else self.config.lab_token_ids,
-                 dtype=torch.int32,
-             )
-
+         # Getting the lab token ids
+         lab_token_ids = torch.tensor(
+             cehrgpt_tokenizer.lab_token_ids,
+             dtype=torch.int32,
+         )
          if model_kwargs.get("value_indicators", None) is not None:
              value_indicators = model_kwargs.get("value_indicators")
          else:
@@ -1895,11 +1578,33 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
          model_kwargs["random_vectors"] = None
          model_kwargs["value_indicators"] = value_indicators
          model_kwargs["values"] = values
+
+         # A variable to keep track of time and initialize it to zero
+         batched_time_delta = np.zeros((batch_size,), dtype=np.float32)
+         batched_ages = model_kwargs.get("ages", None)
+         if batched_ages is None:
+             batched_ages = []
+             for token_ids in input_ids.detach().cpu():
+                 concept_ids = cehrgpt_tokenizer.decode(
+                     token_ids.numpy(), skip_special_tokens=False
+                 )
+                 batched_ages.append(construct_age_sequence(concept_ids))
+             # Turn this to a numpy array for easy manipulation
+             batched_ages = np.asarray(batched_ages)
+         else:
+             batched_ages = batched_ages.cpu().numpy()
+         # This is the base to which we will add the time delta
+         base_ages = np.asarray([ages[-1] for ages in batched_ages])
+         # Update the keyword arguments for the prepare_inputs_for_generation
+         model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
          while self._has_unfinished_sequences(
              this_peer_finished, synced_gpus, device=input_ids.device
          ):
              # prepare model inputs
-             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+             model_inputs = self.prepare_inputs_for_generation(
+                 input_ids, cehrgpt_tokenizer, **model_kwargs
+             )

              # forward pass to get next token
              outputs = self(
@@ -1944,6 +1649,22 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
              probs = nn.functional.softmax(next_token_scores, dim=-1)
              next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

+             # TODO: decode to get time tokens and recalculate the age at this time step
+             # Look for a potential time token
+             for batch_i, next_concept_id in enumerate(
+                 cehrgpt_tokenizer.decode(
+                     next_tokens.detach().cpu().numpy(), skip_special_tokens=False
+                 )
+             ):
+                 if is_att_token(next_concept_id):
+                     batched_time_delta[batch_i] += extract_time_interval_in_days(
+                         next_concept_id
+                     )
+
+             next_age = (base_ages + batched_time_delta // 365).astype(int)[..., None]
+             batched_ages = np.concatenate([batched_ages, next_age], axis=-1)
+             model_kwargs["ages"] = torch.tensor(batched_ages).to(input_ids.device)
+
              # finished sentences should have their next token be a padding token
              if eos_token_id is not None:
                  if pad_token_id is None:
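
The block above keeps a running age per sample while sampling: each generated ATT (artificial time) token adds its day count to a per-sample delta, and the age fed to the next step is the prompt's last age plus the whole years elapsed. A standalone sketch of just that bookkeeping (token handling elided; the 30-day increment is hypothetical):

import numpy as np

base_ages = np.array([40, 63])              # last age observed in each prompt
time_delta = np.zeros(2, dtype=np.float32)  # days accumulated during sampling

# Suppose sample 0 just emitted a time token worth 30 days and sample 1 a regular concept.
time_delta[0] += 30

# Whole years elapsed are added to the base age for the next decoding step.
next_age = (base_ages + time_delta // 365).astype(int)
print(next_age)  # [40 63] until a full 365 days have accumulated
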
@@ -1979,6 +1700,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

              if streamer is not None:
                  streamer.put(next_tokens.cpu())
+
              model_kwargs = self._update_model_kwargs_for_generation(
                  outputs,
                  model_kwargs,
@@ -2023,7 +1745,7 @@ class FocalLoss(nn.Module):
          self.reduction = reduction

      def forward(self, logits, targets):
-         bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+         bce_loss = f.binary_cross_entropy_with_logits(logits, targets, reduction="none")
          probs = torch.sigmoid(logits)
          pt = torch.where(targets == 1, probs, 1 - probs)
          focal_term = (1 - pt) ** self.gamma
@@ -2105,12 +1827,13 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
          values: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         ages: Optional[torch.FloatTensor] = None,
+         epoch_times: Optional[torch.FloatTensor] = None,
          **kwargs,
      ) -> CehrGptSequenceClassifierOutput:

@@ -2120,12 +1843,12 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
              values=values,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
-             position_ids=position_ids,
              head_mask=head_mask,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
+             position_ids=ages,
          )

          if is_sample_pack(attention_mask):
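
For reference, is_sample_pack (imported from cehrgpt.models.gpt2 as of 0.1.4) flags a batch as sample-packed when real tokens appear again after padding inside a row, per the docstring removed earlier in this diff. A small illustration:

import torch
from cehrgpt.models.gpt2 import is_sample_pack

packed = torch.tensor([[1, 1, 0, 1, 1, 0]])  # tokens resume after a padding gap
padded = torch.tensor([[1, 1, 1, 1, 0, 0]])  # padding only at the end

print(is_sample_pack(packed))  # True
print(is_sample_pack(padded))  # False
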