cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/models/hf_cehrgpt.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as f
 from torch import nn
-from torch.distributions import Gamma
+from torch.distributions import Exponential, Gamma
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PreTrainedModel
@@ -45,12 +45,108 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)


+def extract_features_from_packed_sequence(
+    hidden_state: torch.Tensor,
+    attention_mask: torch.Tensor,
+) -> torch.Tensor:
+    max_index = attention_mask.nonzero(as_tuple=False).flatten()[-1]
+    padded_attention_mask = F.pad(attention_mask[:, : max_index + 1], (0, 1))
+    feature_indices = torch.nonzero(padded_attention_mask == 0)[:, 1] - 1
+    return hidden_state[:, feature_indices]
+
+
+def create_sample_packing_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
+    """
+    Create a block-diagonal attention mask for packed sequences within a batch.
+
+    Args:
+        attention_mask (torch.Tensor): (batch_size, seq_len) binary mask where 1 = token, 0 = padding
+
+    Returns:
+        torch.Tensor: (batch_size, seq_len, seq_len) attention mask where entries are 1 if tokens
+            can attend to each other (within same packed segment), 0 otherwise.
+    """
+    # Step 1: Identify segments within each sample
+    cumsum_mask = (attention_mask == 0).cumsum(dim=-1)
+    segment_ids = cumsum_mask * attention_mask  # zeros remain zero
+
+    # Step 2: Compare segment IDs pairwise per batch element
+    # Shape: (batch_size, seq_len, seq_len)
+    attn_matrix = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).int()
+
+    # Step 3: Mask out padding tokens
+    mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
+    attn_matrix = attn_matrix * mask
+
+    return attn_matrix
+
+
+def is_sample_pack(attention_mask: torch.Tensor) -> bool:
+    """
+    Determines whether any sequence in the batch is likely sample-packed.
+
+    A sample-packed sequence is one where there are non-padding (1) tokens
+    after a padding (0) token, indicating multiple sequences packed together
+    with padding as a separator.
+
+    Args:
+        attention_mask (torch.Tensor): A tensor of shape (batch_size, seq_len)
+            where 1 indicates a real token and 0 indicates padding.
+
+    Returns:
+        bool: True if any sample in the batch is sample-packed, False otherwise.
+    """
+
+    # If the attention_maks is left padded, we will flip it so we can use the same logic below
+    if (attention_mask[:, 0] == 0).any():
+        attention_mask = attention_mask.flip(dims=[1])
+
+    nonzero_counts = attention_mask.sum(dim=1)
+    max_token_positions = torch.argmax(attention_mask.flip(dims=[1]), dim=1)
+    max_indices = attention_mask.shape[1] - 1 - max_token_positions
+    return torch.any(nonzero_counts < (max_indices + 1)).item()
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
+    # This infers sample packing
+    if is_sample_pack(attention_mask):
+        # Assume input: attention_mask shape = (batch, seq_len)
+        attention_mask = attention_mask.flatten()  # shape: (seq_len,)
+
+        # Compute max_index of the last non-zero element
+        nonzero = torch.nonzero(attention_mask, as_tuple=False).flatten()
+        max_index = nonzero[-1].item()
+
+        # Pad the truncated attention mask
+        padded_attention_mask = F.pad(attention_mask[: max_index + 1], (0, 1), value=0)
+
+        # Indices of all tokens
+        indices = torch.nonzero(attention_mask, as_tuple=False).flatten()
+
+        # Find where 0s occur (segment boundaries)
+        cumsum_seqlens_in_batch = torch.cumsum(padded_attention_mask, dim=0)[
+            padded_attention_mask == 0
+        ]
+
+        # Compute seqlens per segment
+        seqlens_in_batch = (
+            cumsum_seqlens_in_batch
+            - F.pad(cumsum_seqlens_in_batch, (1, 0), value=0)[:-1]
+        ).to(torch.int)
+
+        max_seqlen_in_batch = (
+            seqlens_in_batch.max().item() if seqlens_in_batch.numel() > 0 else 0
+        )
+        cu_seqlens = F.pad(cumsum_seqlens_in_batch, (1, 0)).to(torch.int)
+    else:
+        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
     return (
         indices,
         cu_seqlens,
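The three helpers above drive the new sample-packing support. As a quick illustration (not part of the diff), a toy attention mask shows what `is_sample_pack` detects and what the block-diagonal mask looks like; the import path assumes the functions are exposed at module level of `cehrgpt/models/hf_cehrgpt.py`, as added in this hunk.

```python
# Toy sketch of the new sample-packing helpers (illustrative values only).
import torch
from cehrgpt.models.hf_cehrgpt import (
    create_sample_packing_attention_mask,
    is_sample_pack,
)

# Two sequences of lengths 3 and 2 packed into one row, separated by a 0.
packed = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])
# A normal right-padded row has no real tokens after the first 0.
padded = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]])

print(is_sample_pack(packed))  # True: real tokens appear after a padding gap
print(is_sample_pack(padded))  # False

# Block-diagonal mask: tokens attend only within their own packed segment.
block_mask = create_sample_packing_attention_mask(packed)
print(block_mask[0])
```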
@@ -266,9 +362,37 @@ class GPT2FlashAttention(GPT2Attention):
         )


-class
+class MotorTaskHead(nn.Module):
+    def __init__(self, input_dim, motor_tte_vocab_size, motor_num_time_pieces):
+        super(MotorTaskHead, self).__init__()
+        self.input_dim = input_dim
+        self.motor_tte_vocab_size = motor_tte_vocab_size
+        self.motor_num_time_pieces = motor_num_time_pieces
+        self.linear = nn.Sequential(
+            nn.Linear(input_dim, input_dim // 2),
+            gelu_new,
+            nn.Linear(
+                input_dim // 2, motor_tte_vocab_size * self.motor_num_time_pieces
+            ),
+        )
+
+    def forward(self, x):
+        # Ensure scale is positive
+        length = x.shape[0]
+        # (num_visits_in_batch, motor_tte_vocab_size * motor_num_time_pieces)
+        lambda_p = f.softplus(self.linear(x))
+        # Check for NaN values
+        if torch.isnan(lambda_p).any():
+            logger.warning(f"NaN values found in scale_param. x: {x}")
+        # (num_visits_in_batch, motor_num_time_pieces, motor_tte_vocab_size,)
+        return lambda_p.view(
+            length, self.motor_num_time_pieces, self.motor_tte_vocab_size
+        )
+
+
+class VisitTimeToEventHead(nn.Module):
     def __init__(self, input_dim):
-        super(
+        super(VisitTimeToEventHead, self).__init__()
         self.linear1 = nn.Sequential(
             nn.Linear(input_dim, input_dim // 2), gelu_new, nn.Linear(input_dim // 2, 1)
         )
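A shape sketch (illustrative, with made-up sizes) of the new `MotorTaskHead`: it maps [VE]-token features to one positive Exponential rate per (time piece, TTE vocabulary entry), which the `motor_nll_loss` hunk further below consumes. It assumes the class is importable from the released module.

```python
import torch
from cehrgpt.models.hf_cehrgpt import MotorTaskHead

head = MotorTaskHead(input_dim=768, motor_tte_vocab_size=100, motor_num_time_pieces=8)
ve_features = torch.randn(5, 768)      # 5 visit-end tokens in the batch (toy values)
rates = head(ve_features)
print(rates.shape)                     # torch.Size([5, 8, 100])
print((rates > 0).all().item())        # True: softplus keeps the rates positive
```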
@@ -565,32 +689,33 @@ class CEHRGPTPreTrainedModel(PreTrainedModel):
             hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         )
         wpe = self.get_position_embeddings()
-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
-        if is_deepspeed_zero3_enabled() and not is_quantized:
-            import deepspeed
+        if wpe is not None:
+            max_position, embed_dim = wpe.weight.shape
+            if new_num_position_embeddings > max_position:
+                new_embeddings = nn.Embedding(
+                    new_num_position_embeddings,
+                    embed_dim,
+                    device=wpe.weight.device,
+                    dtype=wpe.weight.dtype,
+                )

+                # initialize all new embeddings (in particular added tokens)
+                self._init_weights(new_embeddings)
+                if is_deepspeed_zero3_enabled() and not is_quantized:
+                    import deepspeed
+
+                    params = [wpe.weight, new_embeddings.weight]
+                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                        new_embeddings.weight.data[:max_position, :] = (
+                            wpe.weight.data[:max_position, :]
+                        )
+                else:
                     new_embeddings.weight.data[:max_position, :] = wpe.weight.data[
                         :max_position, :
                     ]
-        ]
-        self.set_position_embeddings(new_embeddings)
-        self.config.max_position_embeddings = new_num_position_embeddings
-        self.update_attn_bias(new_num_position_embeddings)
+                self.set_position_embeddings(new_embeddings)
+            self.config.max_position_embeddings = new_num_position_embeddings
+            self.update_attn_bias(new_num_position_embeddings)


 class CEHRGPT2Model(CEHRGPTPreTrainedModel):
@@ -609,7 +734,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.pretrained_wte = None

         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        if not self.exclude_position_ids:
+            self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         if self.include_values:
             self.vte = nn.Embedding(config.value_vocab_size, self.embed_dim)
             self.concept_value_transformation_layer = ConceptValueTransformationLayer(
@@ -635,6 +761,18 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+        # We do need to update the pre-computed attention bias matrix if sample packing requires a larger context window
+        if self.config.sample_packing_max_positions > self.config.n_positions:
+            logger.info(
+                "Updated attn_bias to %s according to sample_packing_max_positions",
+                config.sample_packing_max_positions,
+            )
+            self.update_attn_bias(self.config.sample_packing_max_positions)
+
+    def enable_position_embeddings(self):
+        self.wpe = nn.Embedding(self.config.max_position_embeddings, self.embed_dim)
+        self.config.exclude_position_ids = False
+
     def initialize_pretrained_embeddings(self):
         layers = [
             nn.Embedding(self.config.vocab_size, self.config.pretrained_embedding_dim),
@@ -677,7 +815,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to(self.first_device)
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to(self.first_device)
+        if not self.exclude_position_ids:
+            self.wpe = self.wpe.to(self.first_device)
         if self.include_values:
             self.vte = self.vte.to(self.first_device)
             self.concept_value_transformation_layer = (
@@ -703,7 +842,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
         self.wte = self.wte.to("cpu")
         if self.config.use_pretrained_embeddings:
             self.pretrained_wte = self.pretrained_wte.to("cpu")
+        if not self.exclude_position_ids:
+            self.wpe = self.wpe.to("cpu")
         self.vte = self.vte.to("cpu")
         self.concept_value_transformation_layer = (
             self.concept_value_transformation_layer.to("cpu")
@@ -728,8 +868,12 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             persistent=False,
         )

-    def get_position_embeddings(
+    def get_position_embeddings(
+        self,
+    ) -> Optional[Union[nn.Embedding, Tuple[nn.Embedding]]]:
+        if not self.exclude_position_ids:
+            return self.wpe
+        return None

     def set_position_embeddings(self, new_embeddings: nn.Embedding):
         self.wpe = new_embeddings
@@ -758,8 +902,8 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor],
-        value_indicators: Optional[torch.BoolTensor],
-        values: Optional[torch.LongTensor],
+        value_indicators: Optional[torch.BoolTensor] = None,
+        values: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
@@ -850,12 +994,19 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
                 == "flash_attention_2"
             ):
                 attention_mask = attention_mask.view(batch_size, -1)
-            #
+
+            # If this is sample packing, we need to great the
+            if is_sample_pack(attention_mask):
+                attention_mask = create_sample_packing_attention_mask(
+                    attention_mask
+                )[:, None, :, :]
+            else:
+                # We create a 3D attention mask from a 2D tensor mask.
+                # Sizes are [batch_size, 1, 1, to_seq_length]
+                # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+                # this attention mask is more simple than the triangular masking of causal attention
+                # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+                attention_mask = attention_mask[:, None, None, :]

             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -925,7 +1076,7 @@ class CEHRGPT2Model(CEHRGPTPreTrainedModel):
             )

         if not self.exclude_position_ids:
-            position_embeds = self.wpe(position_ids)
+            position_embeds = self.wpe(position_ids).to(input_embeddings.dtype)
             hidden_states = input_embeddings + position_embeds
         else:
             hidden_states = input_embeddings
@@ -1034,7 +1185,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         super().__init__(config)
         self.cehrgpt = CEHRGPT2Model(config)
         if self.config.include_ttv_prediction:
-            self.tte_head =
+            self.tte_head = VisitTimeToEventHead(config.n_embd)

         if self.config.use_sub_time_tokenization:
             self.time_token_lm_head = nn.Linear(
@@ -1047,6 +1198,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                 config.n_embd, config.value_vocab_size, bias=False
             )

+        if self.config.include_motor_time_to_event:
+            self.motor_tte = MotorTaskHead(
+                config.n_embd, config.motor_tte_vocab_size, config.motor_num_time_pieces
+            )
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -1074,6 +1230,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             self.value_head = self.value_head.to(self.cehrgpt.first_device)
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to(self.cehrgpt.first_device)
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to(self.cehrgpt.first_device)
         self.model_parallel = True

     def deparallelize(self):
@@ -1088,6 +1246,8 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             self.value_head = self.value_head.to("cpu")
         if self.config.include_ttv_prediction:
             self.tte_head = self.tte_head.to("cpu")
+        if self.config.include_motor_time_to_event:
+            self.motor_tte = self.motor_tte.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()

@@ -1115,6 +1275,28 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
     def update_attn_bias(self, max_position_embeddings: int):
         self.cehrgpt.update_attn_bias(max_position_embeddings)

+    def update_motor_tte_vocab_size(
+        self, motor_tte_vocab_size: Optional[int] = None
+    ) -> None:
+        update_motor_tte_layer = False
+        if motor_tte_vocab_size and motor_tte_vocab_size > 0:
+            if self.config.include_motor_time_to_event:
+                if self.config.motor_tte_vocab_size != motor_tte_vocab_size:
+                    self.config.include_motor_time_to_event = True
+                    self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                    update_motor_tte_layer = True
+            else:
+                self.config.include_motor_time_to_event = True
+                self.config.motor_tte_vocab_size = motor_tte_vocab_size
+                update_motor_tte_layer = True
+
+        if update_motor_tte_layer:
+            self.motor_tte = MotorTaskHead(
+                self.config.n_embd,
+                self.config.motor_tte_vocab_size,
+                self.config.motor_num_time_pieces,
+            )
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
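If fine-tuning introduces MOTOR time-to-event supervision with a different label vocabulary, the hook above rebuilds the head to match. A hypothetical call sequence (the checkpoint path is a placeholder, not taken from the diff):

```python
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel

model = CEHRGPT2LMHeadModel.from_pretrained("path/to/cehrgpt-checkpoint")
# Resize/enable the MOTOR TTE head for a task with 512 TTE concepts.
model.update_motor_tte_vocab_size(motor_tte_vocab_size=512)
```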
@@ -1210,6 +1392,74 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):

         return model_inputs

+    def motor_nll_loss(
+        self,
+        ve_token_features,
+        motor_time_to_event_vectors,
+        motor_event_indicators,
+        motor_time_to_event_to_include,
+        motor_time_indicators,
+        batch_motor_end_index,
+    ):
+        """
+        Computes the negative log-likelihood (NLL) loss using the LogNormal distribution.
+
+        for modeling time-to-event data at each visit.
+
+        Args:
+            ve_token_features (Tensor): Hidden representations for the [VE] tokens [num_visits, hidden_dim].
+            motor_time_to_event_vectors (Tensor): Raw time-to-event durations [B, T, motor_vocab_size] (flattened).
+            motor_time_to_event_to_include: (Tensor): Bool indicators (True if included, False if not included).
+            motor_event_indicators (Tensor): Binary indicators (1 if censored, 0 if event occurred).
+            motor_time_indicators (Tensor): Binary indicators whether the time occurs in the current
+                time bucket (1 if censored, 0 if event occurred).
+            batch_motor_end_index (Tensor): Tensor indicating the number of valid [VE] tokens in the batch.
+
+        Returns:
+            Tensor: Scalar loss value (mean negative log-likelihood).
+        """
+        batch_motor_end_index = batch_motor_end_index.sum().item()
+        motor_time_to_event_vectors = motor_time_to_event_vectors.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index].clamp(min=1e-3)
+        motor_event_indicators = motor_event_indicators.reshape(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        motor_time_to_event_to_include = motor_time_to_event_to_include.flatten()[
+            :batch_motor_end_index
+        ]
+        motor_time_indicators = motor_time_indicators.view(
+            (-1, self.config.motor_num_time_pieces, self.config.motor_tte_vocab_size)
+        )[:batch_motor_end_index]
+        assert ve_token_features.shape[0] == motor_time_to_event_vectors.shape[0], (
+            "The number of VE tokens in the labels needs to match up "
+            "with the first dimension of motor_time_to_event_vectors. "
+            f"Received ve_token_features.shape[0]: {ve_token_features.shape[0]}, "
+            f"motor_time_to_event_vectors.shape[0]: {motor_time_to_event_vectors.shape[0]}"
+        )
+        motor_time_to_event_vectors = motor_time_to_event_vectors[
+            motor_time_to_event_to_include
+        ]
+        motor_event_indicators = motor_event_indicators[motor_time_to_event_to_include]
+        motor_time_indicators = motor_time_indicators[motor_time_to_event_to_include]
+        ve_token_features = ve_token_features[motor_time_to_event_to_include]
+
+        # Get Exponential parameters from model
+        lambda_p = self.motor_tte(ve_token_features)
+        # (num_visits_in_batch, num_of_pieces, motor_vocab_size)
+        dist = Exponential(lambda_p.clamp(min=1e-3))
+
+        # Compute event loss
+        tte_loss = torch.where(
+            motor_event_indicators,
+            -dist.log_prob(motor_time_to_event_vectors),
+            -torch.log(
+                1 - dist.cdf(motor_time_to_event_vectors).clamp(max=1 - 1e-6) + 1e-6
+            ),
+        )
+        tte_loss = torch.where(motor_time_indicators, tte_loss, 0.0)
+        return torch.mean(tte_loss)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
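The survival-style objective in `motor_nll_loss` boils down to: an observed event contributes the negative log density of the Exponential, a censored one the negative log survival (1 − CDF). The sketch below restates that idea on toy tensors with my own shapes and `event_observed` naming; it is not the package's code.

```python
import torch
from torch.distributions import Exponential

rates = torch.tensor([[0.5], [2.0]])           # (2 visits, 1 task); clamped elsewhere
times = torch.tensor([[1.0], [0.3]])           # time-to-event durations
event_observed = torch.tensor([[True], [False]])

dist = Exponential(rates)
nll = torch.where(
    event_observed,
    -dist.log_prob(times),                      # event: negative log density
    -torch.log(1 - dist.cdf(times) + 1e-6),     # censored: negative log survival
)
print(nll.mean())
```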
@@ -1226,6 +1476,11 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_to_visits: Optional[torch.FloatTensor] = None,
         time_token_indicators: Optional[torch.BoolTensor] = None,
         sub_time_tokens: Optional[torch.LongTensor] = None,
+        motor_time_to_event_vectors: Optional[torch.FloatTensor] = None,
+        motor_event_indicators: Optional[torch.BoolTensor] = None,
+        motor_time_to_event_to_include: Optional[torch.BoolTensor] = None,
+        motor_time_indicators: Optional[torch.BoolTensor] = None,
+        motor_end_index: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1285,12 +1540,31 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         time_token_loss = None
         time_to_visit_loss = None
         token_value_loss = None
+        motor_tte_loss = None
+
         if labels is not None:
             # move labels to correct device to enable model parallelism
             labels = labels.to(lm_logits.device)
+
+            if self.config.causal_sfm:
+                # Ensure demographic_labels matches the dtype of original labels
+                demographic_labels = torch.full(
+                    (labels.shape[0], self.config.demographics_size),
+                    -100,
+                    dtype=labels.dtype,  # Match the original labels' dtype
+                    device=labels.device,  # Ensure on the same device
+                )
+                # Concatenate the demographic labels with the rest of the original labels
+                labels = torch.cat(
+                    (demographic_labels, labels[:, self.config.demographics_size :]),
+                    dim=1,
+                )
+
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
+            valid_tokens: torch.BoolTensor = shift_labels != 100
+            total_num_tokens = valid_tokens.sum()
             if (
                 self.cehrgpt.config.lab_token_penalty
                 and self.cehrgpt.config.lab_token_exists
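This hunk and the ones that follow move every auxiliary loss onto the same pattern: compute per-token losses with `reduction="none"`, zero out positions that should not contribute, then divide by the number of valid next-token targets instead of taking a plain mean. A generic sketch of that pattern (using the conventional -100 ignore index; shapes are toy values):

```python
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 5, 10)                 # (batch, seq, vocab)
labels = torch.randint(0, 10, (2, 5))
labels[:, -2:] = -100                          # padded / ignored positions

shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
valid = shift_labels != -100
total_valid = valid.sum()

loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)
per_token = loss_fct(shift_logits.reshape(-1, 10), shift_labels.reshape(-1))
loss = per_token.sum() / total_valid           # masked mean over valid targets only
print(loss)
```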
@@ -1310,28 +1584,60 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                     lab_index,
                     token_loss * self.cehrgpt.config.lab_token_loss_weight,
                     token_loss,
-                )
+                )
+
+                token_loss = token_loss.sum() / total_num_tokens
             else:
                 # Flatten the tokens
-                loss_fct = CrossEntropyLoss()
+                loss_fct = CrossEntropyLoss(reduction="none")
                 token_loss = loss_fct(
                     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                 )
+                token_loss = token_loss.sum() / total_num_tokens
+
+            loss = token_loss * self.cehrgpt.config.next_token_prediction_loss_weight

             if self.cehrgpt.config.entropy_penalty:
                 # Compute probabilities using softmax
-                probs = torch.softmax(
+                probs = torch.softmax(shift_logits, dim=-1)
                 # Compute negative entropy: sum(p * log(p))
                 entropy = torch.sum(
                     probs * torch.log(probs + 1e-9), dim=-1
                 )  # Add epsilon for numerical stability
+                entropy = torch.where(valid_tokens, entropy, 0)
                 # Regularization term: mean entropy scaled by alpha
+                entropy_penalty = entropy.sum() / total_num_tokens
+                loss += entropy_penalty * self.cehrgpt.config.entropy_penalty_alpha
+
+            if (
+                self.config.include_motor_time_to_event
+                and motor_time_to_event_vectors is not None
+                and motor_event_indicators is not None
+                and motor_time_to_event_to_include is not None
+                and motor_time_indicators is not None
+                and motor_end_index is not None
+            ):
+                ve_token_id_indices = labels == self.config.ve_token_id
+                ve_token_features = hidden_states[ve_token_id_indices]
+                # Get rid of the last VE features because it's already reached the end of the patient sequence and
+                # there is nothing to predict.
+                motor_tte_loss = self.motor_nll_loss(
+                    ve_token_features=ve_token_features,
+                    motor_time_to_event_vectors=motor_time_to_event_vectors,
+                    motor_event_indicators=motor_event_indicators,
+                    motor_time_to_event_to_include=motor_time_to_event_to_include,
+                    motor_time_indicators=motor_time_indicators,
+                    batch_motor_end_index=motor_end_index,
+                )
+                loss += motor_tte_loss * self.config.motor_time_to_event_weight

             # We add another loss term when use_sub_time_tokenization is enabled, we need to recover the sub time token
             # predictions for year/month/token
-            if
+            if (
+                self.config.use_sub_time_tokenization
+                and sub_time_tokens is not None
+                and time_token_indicators is not None
+            ):
                 # Split the last dimensions into three parts
                 time_loss_fct = CrossEntropyLoss(reduction="none")
                 time_token_logits = self.time_token_lm_head(
@@ -1352,54 +1658,61 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
                     ),
                     shifted_time_token_labels.view(-1),
                 )
-                    -1, 3
-                time_token_loss = time_token_loss.sum(-1)
-                time_token_loss = (
-                    torch.mean(time_token_loss) * self.config.time_token_loss_weight
+                time_token_loss = torch.where(
+                    shifted_time_token_indicators.view(-1, 1).to(torch.bool),
+                    time_token_loss.view(-1, 3),
+                    0,
                 )
+                time_token_loss = time_token_loss.sum() / total_num_tokens
+                loss += time_token_loss * self.config.time_token_loss_weight
+
+            if time_to_visits is not None and time_to_visits is not None:
+                # Get lambda and k parameters
+                lambda_param, k_param = self.tte_head(hidden_states)
+
+                # Perform slicing before tensors are split across GPUs
+                shifted_lambda_param = lambda_param[..., :-1, :].contiguous()
+                shifted_k_param = k_param[..., :-1, :].contiguous()
+                shift_time_to_visits = time_to_visits[..., 1:].contiguous()
+
+                # Move to the same device as lambda_param
+                shift_time_to_visits = shift_time_to_visits.to(lambda_param.device)
+                time_to_visit_indicator = shift_time_to_visits >= 0
+                # Define the Gamma distribution
+                dist = Gamma(
+                    shifted_k_param.squeeze(-1), shifted_lambda_param.squeeze(-1)
+                )
+                # Compute log-probs and apply the time_to_visit_indicator
+                log_probs = dist.log_prob(
+                    torch.clamp(shift_time_to_visits, min=1e-3) + 1e-6
+                )
+                log_probs = torch.where(time_to_visit_indicator, log_probs, 0)
+                time_to_visit_loss = -log_probs.sum() / total_num_tokens
+                # Compute the loss
+                loss += time_to_visit_loss * self.config.time_to_visit_loss_weight
+
+            if true_values is not None and true_value_indicators is not None:
+                true_values = true_values.to(value_logits.device)
+                shift_value_logits = value_logits[..., :-1, :].contiguous()
+                shift_value_indicators = true_value_indicators[..., :-1].contiguous()
+                shift_next_values = true_values[..., 1:].contiguous()
+                value_loss_fct = CrossEntropyLoss(reduction="none")
+                token_value_loss = value_loss_fct(
+                    shift_value_logits.view(-1, shift_value_logits.size(-1)),
+                    shift_next_values.view(-1),
+                )
+                token_value_loss = torch.where(
+                    shift_value_indicators.view(-1), token_value_loss, 0
+                )
+                token_value_loss = token_value_loss.sum() / total_num_tokens
+                if (
+                    self.cehrgpt.config.lab_token_penalty
+                    and self.cehrgpt.config.lab_token_exists
+                ):
+                    token_value_loss = (
+                        token_value_loss * self.config.lab_token_loss_weight
+                    )
+                loss += token_value_loss * self.config.value_prediction_loss_weight

         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -1417,6 +1730,7 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
             time_token_loss=time_token_loss,
             time_to_visit_loss=time_to_visit_loss,
             token_value_loss=token_value_loss,
+            motor_tte_loss=motor_tte_loss,
         )

     @staticmethod
@@ -1690,6 +2004,27 @@ class CEHRGPT2LMHeadModel(CEHRGPTPreTrainedModel):
         )


+class FocalLoss(nn.Module):
+    def __init__(self, alpha=0.25, gamma=2.0, reduction="mean"):
+        super().__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, logits, targets):
+        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+        probs = torch.sigmoid(logits)
+        pt = torch.where(targets == 1, probs, 1 - probs)
+        focal_term = (1 - pt) ** self.gamma
+        loss = self.alpha * focal_term * bce_loss
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
 class CehrGptForClassification(CEHRGPTPreTrainedModel):
     _keep_in_fp32_modules = ["age_batch_norm", "dense_layer", "classifier"]

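A quick usage sketch for the newly added `FocalLoss` (binary focal loss over logits); the tensors below are toy values and this call is not shown anywhere in the diff itself.

```python
import torch
from cehrgpt.models.hf_cehrgpt import FocalLoss

loss_fn = FocalLoss(alpha=0.25, gamma=2.0, reduction="mean")
logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])   # float targets, as BCE-with-logits expects
print(loss_fn(logits, targets))
```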
@@ -1712,7 +2047,6 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         self.model_parallel = False
         self.device_map = None
         self.gradient_checkpointing = False
-
         # Initialize weights and apply final processing
         self.post_init()

@@ -1768,6 +2102,7 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> CehrGptSequenceClassifierOutput:
+
         cehrgpt_output = self.cehrgpt(
             input_ids=input_ids,
             value_indicators=value_indicators,
@@ -1782,17 +2117,39 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):
             return_dict=return_dict,
         )

+        if is_sample_pack(attention_mask):
+            features = extract_features_from_packed_sequence(
+                cehrgpt_output.last_hidden_state, attention_mask
+            )
+            assert features.shape[1] == classifier_label.shape[1], (
+                "the length of the features need to be the same as the length of classifier_label. "
+                f"features.shape[1]: {features.shape[1]}, "
+                f"classifier_label.shape[1]: {classifier_label.shape[1]}"
+            )
+            assert features.shape[1] == age_at_index.shape[1], (
+                "the length of the features need to be the same as the length of age_at_index. "
+                f"features.shape[1]: {features.shape[1]}, "
+                f"age_at_index.shape[1]: {age_at_index.shape[1]}"
+            )
+            num_samples = age_at_index.shape[1]
+            features = features.view((num_samples, -1))
+            classifier_label = classifier_label.view((num_samples, -1))
+            with torch.autocast(device_type="cuda", enabled=False):
+                normalized_age = self._apply_age_norm(
+                    age_at_index.view((num_samples, 1))
+                )
+        else:
+            features = cehrgpt_output.last_hidden_state[..., -1, :]
+            # Disable autocasting for precision-sensitive operations
+            with torch.autocast(device_type="cuda", enabled=False):
+                normalized_age = self._apply_age_norm(age_at_index)

         # In case the model is in bfloat16
-        if
-        normalized_age = normalized_age.to(
+        if features.dtype != normalized_age.dtype:
+            normalized_age = normalized_age.to(features.dtype)

         # In fine-tuning, the sequences are left-padded, so we use the last element as the pooler
-        next_input = self.dropout(output_pooler)
+        next_input = self.dropout(features)
         next_input = torch.cat([next_input, normalized_age], dim=1)
         next_input = self.dense_layer(next_input)
         next_input = nn.functional.relu(next_input)
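To see what the packed-sequence pooling above selects, a toy run (illustrative only) shows that `extract_features_from_packed_sequence` returns the hidden state of the last real token of each packed sample, which is then used in place of the usual last-position pooler:

```python
import torch
from cehrgpt.models.hf_cehrgpt import extract_features_from_packed_sequence

hidden = torch.arange(8).float().view(1, 8, 1)   # (batch=1, seq_len=8, hidden=1)
mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 0]])  # two samples packed into one row

features = extract_features_from_packed_sequence(hidden, mask)
print(features.squeeze(-1))                      # tensor([[2., 5.]]): last token of each sample
```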
@@ -1801,7 +2158,14 @@ class CehrGptForClassification(CEHRGPTPreTrainedModel):

         loss = None
         if classifier_label is not None:
+            if self.config.class_weights:
+                class_weights = torch.tensor(
+                    [self.config.class_weights[1] / self.config.class_weights[0]],
+                    dtype=torch.float32,
+                ).to(logits.device)
+            else:
+                class_weights = None
+            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
             loss = loss_fct(logits, classifier_label)

         return CehrGptSequenceClassifierOutput(
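The classifier's new weighting path reduces to a single `pos_weight` for `BCEWithLogitsLoss`, computed as `class_weights[1] / class_weights[0]`. How `config.class_weights` is populated is not shown in this file, so the values below are placeholders used only to demonstrate the mechanics.

```python
import torch
from torch import nn

class_weights = [3.0, 1.0]   # placeholder values for config.class_weights
pos_weight = torch.tensor([class_weights[1] / class_weights[0]], dtype=torch.float32)
loss_fct = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([[0.2], [-1.3]])
labels = torch.tensor([[1.0], [0.0]])
print(loss_fct(logits, labels))
```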