cehrgpt 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +279 -2
- cehrgpt/data/sample_packing_sampler.py +151 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/config.py +10 -0
- cehrgpt/models/hf_cehrgpt.py +244 -73
- cehrgpt/models/tokenization_hf_cehrgpt.py +6 -2
- cehrgpt/runners/data_utils.py +243 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +154 -260
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +250 -90
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +46 -0
- cehrgpt/runners/hyperparameter_search_util.py +4 -1
- cehrgpt/runners/sample_packing_trainer.py +168 -0
- cehrgpt/simulations/__init__.py +0 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +57 -9
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +30 -18
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,168 @@
|
|
1
|
+
from typing import Optional, Union
|
2
|
+
|
3
|
+
from datasets import Dataset
|
4
|
+
from torch.utils.data import DataLoader
|
5
|
+
from transformers import Trainer
|
6
|
+
from transformers.trainer_utils import has_length
|
7
|
+
from transformers.utils import import_utils, logging
|
8
|
+
|
9
|
+
from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
|
10
|
+
|
11
|
+
# Token budget per packed batch, used when SamplePackingTrainer is not given
# an explicit ``max_tokens_per_batch`` keyword argument.
DEFAULT_MAX_TOKENS_PER_BATCH = 16_384

# Log through the "transformers" logger so messages share the Trainer's logging setup.
LOG = logging.get_logger("transformers")
|
14
|
+
|
15
|
+
|
16
|
+
class SamplePackingTrainer(Trainer):
    """``Trainer`` whose train/eval dataloaders pack samples via ``SamplePackingBatchSampler``.

    Batches are assembled by ``SamplePackingBatchSampler`` from per-sample lengths rather
    than by a fixed batch size. The following extra keyword arguments are popped from
    ``kwargs`` before ``Trainer.__init__`` receives them:

    * ``max_tokens_per_batch``: token budget handed to the sampler; defaults to
      ``DEFAULT_MAX_TOKENS_PER_BATCH``.
    * ``max_position_embeddings``: handed to the sampler; defaults to
      ``max_tokens_per_batch``.
    * ``train_lengths``: optional precomputed per-sample lengths of ``train_dataset``.
    * ``validation_lengths``: optional precomputed per-sample lengths of the default
      ``eval_dataset``; only used when that exact dataset object is evaluated.

    When lengths are not supplied they are read from a ``num_of_concepts`` column if the
    dataset has one, otherwise computed as ``len(sample["input_ids"])`` for each sample.
    """

    def __init__(self, *args, **kwargs):
        if "max_tokens_per_batch" in kwargs:
            self.max_tokens_per_batch = kwargs.pop("max_tokens_per_batch")
            LOG.info("max_tokens_per_batch: %s", self.max_tokens_per_batch)
        else:
            self.max_tokens_per_batch = DEFAULT_MAX_TOKENS_PER_BATCH
            LOG.info(
                "max_tokens_per_batch is not provided to SamplePackingTrainer and will default to %s",
                DEFAULT_MAX_TOKENS_PER_BATCH,
            )

        if "max_position_embeddings" in kwargs:
            self.max_position_embeddings = kwargs.pop("max_position_embeddings")
            LOG.info("max_position_embeddings: %s", self.max_position_embeddings)
        else:
            self.max_position_embeddings = self.max_tokens_per_batch
            LOG.info(
                "max_position_embeddings is not provided to SamplePackingTrainer and will default to %s",
                self.max_tokens_per_batch,
            )

        # Optional precomputed lengths; popped so Trainer.__init__ never sees them.
        self.train_lengths = kwargs.pop("train_lengths", None)
        self.validation_lengths = kwargs.pop("validation_lengths", None)
        super().__init__(*args, **kwargs)
        # Packed batches hold a variable number of samples; presumably accelerate's
        # even_batches padding (which assumes equally sized batches) has to be disabled
        # for that reason -- confirm against the accelerate version in use.
        self.accelerator.even_batches = False

    def num_examples(self, dataloader: DataLoader) -> int:
        """Return ``len(dataloader)``, i.e. the number of packed batches.

        Unlike the base ``Trainer``, this counts batches rather than samples, because
        each packed batch contains a variable number of samples.

        Raises:
            RuntimeError: if the dataloader does not define a length.
        """
        if has_length(dataloader):
            return len(dataloader)
        raise RuntimeError("DataLoader in SamplePackingTrainer must have length")

    @staticmethod
    def _compute_sample_lengths(dataset, description: str):
        """Return the per-sample length of every example in ``dataset``.

        Uses the ``num_of_concepts`` column when the dataset exposes it, otherwise falls
        back to ``len(sample["input_ids"])`` for each sample.
        """
        LOG.info("Started computing lengths for the %s dataset", description)
        # Non-HF datasets may not have ``column_names`` (or may report None).
        column_names = getattr(dataset, "column_names", None) or []
        if "num_of_concepts" in column_names:
            lengths = dataset["num_of_concepts"]
        else:
            lengths = [len(sample["input_ids"]) for sample in dataset]
        LOG.info("Finished computing lengths for the %s dataset", description)
        return lengths

    def _build_sample_packing_dataloader(
        self, dataset, lengths, description: str
    ) -> DataLoader:
        """Build an unprepared ``DataLoader`` over ``dataset`` using ``SamplePackingBatchSampler``.

        ``lengths`` must already be computed: removing unused columns below may drop
        the ``num_of_concepts`` column the lengths were read from.
        """
        data_collator = self.data_collator
        if import_utils.is_datasets_available() and isinstance(dataset, Dataset):
            dataset = self._remove_unused_columns(dataset, description=description)
        else:
            data_collator = self._get_collator_with_removed_columns(
                data_collator, description=description
            )

        batch_sampler = SamplePackingBatchSampler(
            lengths=lengths,
            max_tokens_per_batch=self.max_tokens_per_batch,
            max_position_embeddings=self.max_position_embeddings,
            drop_last=self.args.dataloader_drop_last,
            seed=self.args.seed,
        )
        return DataLoader(
            dataset,
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            persistent_workers=self.args.dataloader_persistent_workers,
            batch_sampler=batch_sampler,
        )

    def get_train_dataloader(self) -> DataLoader:
        """Returns the training dataloader with our custom batch sampler.

        Raises:
            ValueError: if the trainer has no ``train_dataset``.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_dataset = self.train_dataset

        if self.train_lengths is None:
            lengths = self._compute_sample_lengths(train_dataset, "train")
        else:
            lengths = self.train_lengths

        return self.accelerator.prepare(
            self._build_sample_packing_dataloader(train_dataset, lengths, "training")
        )

    def get_eval_dataloader(
        self, eval_dataset: Optional[Union[str, Dataset]] = None
    ) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`] with our custom batch sampler.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.

        ``validation_lengths`` is only reused when the resolved dataset is
        ``self.eval_dataset`` itself; any other dataset has its lengths computed, so the
        sampler never indexes with lengths belonging to a different dataset.

        Raises:
            ValueError: if neither ``eval_dataset`` nor ``self.eval_dataset`` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        # If we have persistent workers, don't do a fork bomb especially as eval datasets
        # don't change during training
        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
        if (
            hasattr(self, "_eval_dataloaders")
            and dataloader_key in self._eval_dataloaders
            and self.args.dataloader_persistent_workers
        ):
            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])

        if isinstance(eval_dataset, str):
            eval_dataset = self.eval_dataset[eval_dataset]
        elif eval_dataset is None:
            eval_dataset = self.eval_dataset

        if self.validation_lengths is not None and eval_dataset is self.eval_dataset:
            lengths = self.validation_lengths
        else:
            lengths = self._compute_sample_lengths(eval_dataset, "evaluation")

        # accelerator.free_memory() will destroy the references, so
        # we need to store the non-prepared version
        eval_dataloader = self._build_sample_packing_dataloader(
            eval_dataset, lengths, "evaluation"
        )
        if self.args.dataloader_persistent_workers:
            if hasattr(self, "_eval_dataloaders"):
                self._eval_dataloaders[dataloader_key] = eval_dataloader
            else:
                self._eval_dataloaders = {dataloader_key: eval_dataloader}

        return self.accelerator.prepare(eval_dataloader)
|
File without changes
|
@@ -0,0 +1,95 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
|
5
|
+
import matplotlib.pyplot as plt
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
|
9
|
+
def main(output_dir: str):
    """Plot accuracy and ROC AUC of the time-embedding vs. time-token simulations.

    Reads ``time_embedding_metrics.json`` and ``time_token_metrics.json`` from
    ``output_dir``. Each is expected to hold parallel ``steps``, ``roc_auc`` and
    ``accuracy`` lists. Only steps present in both runs are kept. The plots are written
    to ``accuracy_comparison.png`` and ``roc_auc_comparison.png`` in ``output_dir``.
    """
    with open(os.path.join(output_dir, "time_embedding_metrics.json"), "r") as f:
        time_embedding_metrics = json.load(f)
    with open(os.path.join(output_dir, "time_token_metrics.json"), "r") as f:
        time_token_metrics = json.load(f)

    # Sorted so the x-axis is monotonic. Every y-series below is re-indexed to this
    # exact order; iterating a raw set would give arbitrary order and misaligned points.
    common_steps = sorted(
        set(time_embedding_metrics["steps"]) & set(time_token_metrics["steps"])
    )

    _plot_comparison(
        common_steps,
        _values_at_steps(time_embedding_metrics, "accuracy", common_steps),
        _values_at_steps(time_token_metrics, "accuracy", common_steps),
        title="Accuracy Comparison Over Time",
        ylabel="Accuracy",
        path=os.path.join(output_dir, "accuracy_comparison.png"),
    )
    _plot_comparison(
        common_steps,
        _values_at_steps(time_embedding_metrics, "roc_auc", common_steps),
        _values_at_steps(time_token_metrics, "roc_auc", common_steps),
        title="ROC AUC Comparison Over Time",
        ylabel="ROC AUC",
        path=os.path.join(output_dir, "roc_auc_comparison.png"),
    )


def _values_at_steps(metrics, key, steps):
    """Return ``metrics[key]`` re-ordered so that the i-th value belongs to ``steps[i]``."""
    by_step = dict(zip(metrics["steps"], metrics[key]))
    return [by_step[step] for step in steps]


def _plot_comparison(steps, time_embedding_values, time_token_values, title, ylabel, path):
    """Draw time-embedding (solid blue) vs. time-token (dashed red) curves and save to ``path``."""
    fig = plt.figure(figsize=(8, 5))  # Define figure size
    plt.plot(
        steps,
        time_embedding_values,
        linestyle="-",
        color="b",
        label="Time Embedding",
        lw=1,
    )
    plt.plot(
        steps,
        time_token_values,
        linestyle="--",
        color="r",
        label="Time Token",
        lw=1,
    )
    plt.title(title)
    plt.xlabel("Training Steps")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(False)
    plt.savefig(path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
92
|
+
|
93
|
+
|
94
|
+
if __name__ == "__main__":
    # One positional argument is required: the directory holding both metrics JSON files.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <output_dir>")
    main(sys.argv[1])
|
@@ -0,0 +1,24 @@
|
|
1
|
+
#!/bin/bash

# This script runs the time-token and time-embedding simulations, then plots their
# metrics side by side.
# It accepts three parameters: output directory, number of steps, and number of samples.

# Stop at the first failing command (and on unset variables or pipeline failures) so
# plots are never generated from missing or stale metrics files.
set -euo pipefail

# Check if all arguments are provided
if [ "$#" -ne 3 ]; then
  echo "Usage: $0 <output_dir> <n_steps> <n_samples>" >&2
  exit 1
fi

# Assigning command line arguments to variables
OUTPUT_DIR="$1"
N_STEPS="$2"
N_SAMPLES="$3"

# Run time token simulation
python -u -m cehrgpt.simulations.time_token_simulation --output_dir "$OUTPUT_DIR" --n_steps "$N_STEPS" --n_samples "$N_SAMPLES"

# Run time embedding simulation
python -u -m cehrgpt.simulations.time_embedding_simulation --output_dir "$OUTPUT_DIR" --n_steps "$N_STEPS" --n_samples "$N_SAMPLES"

# Generate plots
python -u -m cehrgpt.simulations.generate_plots "$OUTPUT_DIR"
|
@@ -0,0 +1,250 @@
|
|
1
|
+
from typing import Optional, Tuple
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import torch
|
5
|
+
import torch.optim as optim
|
6
|
+
from sklearn.metrics import accuracy_score, roc_auc_score
|
7
|
+
from torch.nn import CrossEntropyLoss
|
8
|
+
from transformers import BertConfig, BertModel
|
9
|
+
|
10
|
+
|
11
|
+
class ModelTimeEmbedding(torch.nn.Module):
    """Tiny BERT classifier used by the time-embedding simulation.

    Concept ids and time-stamp ids are looked up in the same embedding table and
    summed position-wise. The sum is fed to a small BERT encoder as ``inputs_embeds``.
    The encoder outputs of all positions are concatenated and mapped to 2 logits.

    Args:
        vocab_size: size of the shared embedding table (concept and time tokens).
        hidden_size: embedding / BERT hidden dimension (default 16).
        seq_length: positions per sample (default 2); used as BERT's
            ``max_position_embeddings`` and to size the output layer.
        intermediate_size: BERT feed-forward dimension (default 32).
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 16,
        seq_length: int = 2,
        intermediate_size: int = 32,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.bert = BertModel(
            BertConfig(
                vocab_size=vocab_size,
                hidden_size=hidden_size,
                num_attention_heads=2,
                num_hidden_layers=2,
                intermediate_size=intermediate_size,
                hidden_dropout_prob=0.0,
                attention_probs_dropout_prob=0.0,
                max_position_embeddings=seq_length,
            ),
            add_pooling_layer=False,
        )
        # The classifier sees every position's hidden state concatenated together.
        self.linear = torch.nn.Linear(seq_length * hidden_size, 2)

    def forward(
        self,
        input_ids: torch.LongTensor,
        time_stamps: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.FloatTensor], torch.FloatTensor]:
        """Return ``(loss, logits)``; ``loss`` is None when ``labels`` is not given.

        ``input_ids`` and ``time_stamps`` are expected to be ``(batch, seq_length)``
        id tensors (the reshape below requires it).
        """
        batch_size = input_ids.shape[0]
        # Shared table: concept and time embeddings are added element-wise.
        inputs_embeds = self.embedding(input_ids) + self.embedding(time_stamps)
        # Call the module (not .forward) so registered hooks run.
        bert_output = self.bert(inputs_embeds=inputs_embeds, return_dict=True)
        features = bert_output.last_hidden_state.reshape(
            (batch_size, self.seq_length * self.hidden_size)
        )
        logits = self.linear(features)
        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits, labels)
        return loss, logits
|
48
|
+
|
49
|
+
|
50
|
+
def generate_simulation_data(sample_size: int = 1000, seed: int = 42) -> np.ndarray:
    """Generate a synthetic binary-label dataset whose label depends on a time gap.

    Each sample has two binary concepts ``x1``/``x2`` and a start time ``t1`` drawn
    from 0..20. The end time is ``t2 = t1 + gap``, with ``gap`` also drawn from 0..20.
    The label is taken from the first rule that matches:

    1. ``x1 == 1`` and ``gap % 4 == 0``  -> ``x2 == 0``
    2. ``x1 == 0`` and ``gap % 3 == 0``  -> ``x2 == 1``
    3. ``gap <= 7``                      -> ``x1 XOR x2``
    4. ``7 < gap <= 14``                 -> ``x1 AND x2``
    5. ``14 < gap <= 21``                -> ``x1 OR x2``
    6. otherwise                         -> 0

    Note: this seeds NumPy's *global* RNG via ``np.random.seed``, which also affects
    subsequent ``np.random`` calls made elsewhere.

    Returns:
        Integer array of shape ``(sample_size, 5)`` with columns ``x1, x2, t1, t2, y``.
    """
    np.random.seed(seed)

    concept_choices = [0, 1]
    time_choices = list(range(0, 21))

    # The draw order (x1, x2, t1, gap) is what makes a given seed reproducible.
    x1 = np.random.choice(concept_choices, size=sample_size)
    x2 = np.random.choice(concept_choices, size=sample_size)
    t1 = np.random.choice(time_choices, size=sample_size)
    t2 = t1 + np.random.choice(time_choices, size=sample_size)
    gap = t2 - t1

    # np.select picks the outcome of the first true condition, row by row.
    rules = [
        (x1 == 1) & (gap % 4 == 0),
        (x1 == 0) & (gap % 3 == 0),
        gap <= 7,
        (gap > 7) & (gap <= 14),
        (gap > 14) & (gap <= 21),
    ]
    outcomes = [
        (x2 == 0).astype(int),
        (x2 == 1).astype(int),
        (x2 != x1).astype(int),
        (x2 & x1).astype(int),
        (x2 | x1).astype(int),
    ]
    labels = np.select(rules, outcomes, default=0)

    return np.column_stack((x1, x2, t1, t2, labels))
|
97
|
+
|
98
|
+
|
99
|
+
def create_time_embedding_tokenizer(simulated_data):
    """Build a token -> id vocabulary from simulated ``(x1, x2, t1, t2, y)`` rows.

    Concept values become ``"c-<value>"`` tokens and time stamps ``"t-<value>"`` tokens.
    Ids start at 1 and follow first-appearance order: rows in order, and within a row
    x1, x2, t1, t2. The label column is ignored.
    """
    token_ids = {}
    for x1, x2, t1, t2, _label in simulated_data:
        for token in (f"c-{x1}", f"c-{x2}", f"t-{t1}", f"t-{t2}"):
            # len() is evaluated before insertion, so a new token gets the next id.
            token_ids.setdefault(token, len(token_ids) + 1)
    return token_ids
|
116
|
+
|
117
|
+
|
118
|
+
# Use the default CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
119
|
+
|
120
|
+
|
121
|
+
def eval_step(
    simulated_data,
    time_embedding_tokenizer,
    time_embedding_model,
    time_embedding_optimizer,
):
    """Score the model on every row of ``simulated_data`` and return ``(roc_auc, accuracy)``.

    Each ``(x1, x2, t1, t2, y)`` row is tokenized as concepts ``c-x1, c-x2`` and time
    stamps ``t-t1, t-t2``. The class-1 softmax probability is the ROC AUC score. A row is
    predicted as class 1 when its class-1 probability exceeds its class-0 probability.
    Both metrics are printed.

    The optimizer's gradients are zeroed; no parameter update happens. The model runs
    in eval mode and is then restored to its previous train/eval mode, so training
    that follows is not left in eval mode. Inputs are placed on the model's device.
    """
    time_embedding_optimizer.zero_grad()
    was_training = time_embedding_model.training
    model_device = next(time_embedding_model.parameters()).device

    eval_input_ids = []
    eval_time_stamps = []
    eval_y = []
    for x1, x2, t1, t2, y in simulated_data:
        eval_input_ids.append(
            [time_embedding_tokenizer[f"c-{x1}"], time_embedding_tokenizer[f"c-{x2}"]]
        )
        eval_time_stamps.append(
            [time_embedding_tokenizer[f"t-{t1}"], time_embedding_tokenizer[f"t-{t2}"]]
        )
        eval_y.append(y)
    eval_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(model_device)
    eval_time_stamps = torch.tensor(eval_time_stamps, dtype=torch.long).to(model_device)
    eval_y = np.asarray(eval_y)

    time_embedding_model.eval()
    try:
        with torch.no_grad():
            _, y_pred = time_embedding_model(eval_input_ids, eval_time_stamps)
    finally:
        # Restore whatever mode the caller had the model in.
        time_embedding_model.train(was_training)

    y_probs = torch.nn.functional.softmax(y_pred, dim=1).detach().cpu().numpy()
    roc_auc = roc_auc_score(eval_y, y_probs[:, 1])
    accuracy = accuracy_score(eval_y, y_probs[:, 1] > y_probs[:, 0])
    print(f"ROC AUC: {roc_auc}")
    print(f"Accuracy: {accuracy}")
    return roc_auc, accuracy
|
159
|
+
|
160
|
+
|
161
|
+
def train_step(
    simulated_data,
    time_embedding_tokenizer,
    time_embedding_model,
    time_embedding_optimizer,
    batch_size: int = 8,
):
    """Run one optimisation step on a random mini-batch and return the loss tensor.

    Args:
        simulated_data: 2-D array of ``(x1, x2, t1, t2, y)`` rows.
        time_embedding_tokenizer: mapping from ``"c-<v>"`` / ``"t-<v>"`` tokens to ids.
        time_embedding_model: module called as ``model(input_ids, time_stamps, labels)``
            that returns ``(loss, logits)``.
        time_embedding_optimizer: optimizer over the model's parameters.
        batch_size: rows sampled without replacement from NumPy's global RNG
            (default 8), capped at the number of available rows.

    Returns:
        The scalar loss tensor of this step.
    """
    # The model may have been left in eval mode (e.g. by an evaluation pass).
    time_embedding_model.train()
    model_device = next(time_embedding_model.parameters()).device

    n_rows = simulated_data.shape[0]
    indices = np.random.choice(n_rows, size=min(batch_size, n_rows), replace=False)

    batched_input_ids = []
    batched_time_stamps = []
    batched_y = []
    for x1, x2, t1, t2, y in simulated_data[indices, :]:
        batched_input_ids.append(
            [time_embedding_tokenizer[f"c-{x1}"], time_embedding_tokenizer[f"c-{x2}"]]
        )
        batched_time_stamps.append(
            [time_embedding_tokenizer[f"t-{t1}"], time_embedding_tokenizer[f"t-{t2}"]]
        )
        batched_y.append(y)
    batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(model_device)
    batched_time_stamps = torch.tensor(batched_time_stamps, dtype=torch.long).to(
        model_device
    )
    batched_y = torch.tensor(batched_y, dtype=torch.long).to(model_device)

    time_embedding_optimizer.zero_grad()
    loss, _ = time_embedding_model(batched_input_ids, batched_time_stamps, batched_y)
    loss.backward()
    time_embedding_optimizer.step()
    return loss
|
196
|
+
|
197
|
+
|
198
|
+
def main(args):
    """Train the time-embedding model on simulated data and collect periodic eval metrics.

    Args:
        args: parsed CLI namespace; reads ``n_samples``, ``n_steps`` and
            ``eval_frequency``.

    Returns:
        A dict with parallel lists ``steps``, ``roc_auc`` and ``accuracy``, with one
        entry per evaluation.

    Raises:
        ValueError: if ``args.eval_frequency`` is not positive.
    """
    if args.eval_frequency <= 0:
        raise ValueError("eval_frequency must be a positive integer")

    simulated_data = generate_simulation_data(args.n_samples)
    time_embedding_tokenizer = create_time_embedding_tokenizer(simulated_data)
    # Token ids start at 1, so the embedding table needs one extra row.
    time_embedding_model = ModelTimeEmbedding(len(time_embedding_tokenizer) + 1).to(
        device
    )
    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
    steps = []
    roc_aucs = []
    accuracies = []
    for step in range(args.n_steps):
        loss = train_step(
            simulated_data,
            time_embedding_tokenizer,
            time_embedding_model,
            time_embedding_optimizer,
        )
        print(f"Step {step}: Loss = {loss.item()}")
        # Evaluate periodically based on the current step, starting once past the
        # first eval_frequency steps.
        if step % args.eval_frequency == 0 and step > args.eval_frequency:
            roc_auc, accuracy = eval_step(
                simulated_data,
                time_embedding_tokenizer,
                time_embedding_model,
                time_embedding_optimizer,
            )
            steps.append(step)
            roc_aucs.append(roc_auc)
            accuracies.append(accuracy)
    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
|
232
|
+
|
233
|
+
|
234
|
+
if __name__ == "__main__":
    import argparse
    import json
    from pathlib import Path

    # Command-line interface. Note: --batch_size is parsed but main() does not read it.
    arg_parser = argparse.ArgumentParser("Model with time embedding simulation")
    arg_parser.add_argument("--output_dir", type=str, required=True)
    arg_parser.add_argument("--n_steps", type=int, default=10000)
    arg_parser.add_argument("--n_samples", type=int, default=1000)
    arg_parser.add_argument("--batch_size", type=int, default=128)
    arg_parser.add_argument("--eval_frequency", type=int, default=100)
    cli_args = arg_parser.parse_args()

    metrics_dir = Path(cli_args.output_dir)
    metrics_dir.mkdir(parents=True, exist_ok=True)
    simulation_metrics = main(cli_args)

    # Persist the metrics for generate_plots.py to compare against the time-token run.
    metrics_path = metrics_dir / "time_embedding_metrics.json"
    with metrics_path.open("w") as metrics_file:
        json.dump(simulation_metrics, metrics_file)
|