cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cehrgpt/analysis/irregularity.py +36 -0
  2. cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
  3. cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
  4. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
  5. cehrgpt/data/sample_packing_sampler.py +181 -0
  6. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  7. cehrgpt/generation/omop_converter_batch.py +32 -2
  8. cehrgpt/gpt_utils.py +20 -2
  9. cehrgpt/models/config.py +35 -0
  10. cehrgpt/models/hf_cehrgpt.py +470 -106
  11. cehrgpt/models/hf_modeling_outputs.py +1 -0
  12. cehrgpt/models/special_tokens.py +1 -0
  13. cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
  14. cehrgpt/runners/data_utils.py +358 -0
  15. cehrgpt/runners/gpt_runner_util.py +0 -10
  16. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
  17. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
  18. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
  19. cehrgpt/runners/hyperparameter_search_util.py +10 -8
  20. cehrgpt/runners/sample_packing_trainer.py +185 -0
  21. cehrgpt/simulations/generate_plots.py +95 -0
  22. cehrgpt/simulations/run_simulation.sh +24 -0
  23. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  24. cehrgpt/simulations/time_token_simulation.py +177 -0
  25. cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
  26. cehrgpt/time_to_event/time_to_event_model.py +2 -13
  27. cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
  28. cehrgpt/tools/linear_prob/__init__.py +0 -0
  29. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
  30. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  31. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
  32. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
  33. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
  34. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  35. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  36. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  37. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  38. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  39. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  40. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  41. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  42. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  43. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
  44. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hyperparameter_search_util.py
@@ -4,12 +4,7 @@ from typing import Callable, Tuple
 import optuna
 from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
 from datasets import Dataset, DatasetDict
-from transformers import (
-    EarlyStoppingCallback,
-    Trainer,
-    TrainerCallback,
-    TrainingArguments,
-)
+from transformers import EarlyStoppingCallback, TrainerCallback, TrainingArguments
 from transformers.utils import logging

 from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
@@ -85,7 +80,9 @@ def hp_space(
             "per_device_train_batch_size", batch_sizes
         ),
         "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
-        "num_train_epochs": trial.suggest_int("num_train_epochs", *num_train_epochs),
+        "num_train_epochs": trial.suggest_categorical(
+            "num_train_epochs", num_train_epochs
+        ),
     }

@@ -126,6 +123,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:


 def perform_hyperparameter_search(
+    trainer_class,
     model_init: Callable,
     dataset: DatasetDict,
     data_collator: CehrGptDataCollator,
@@ -142,6 +140,7 @@ def perform_hyperparameter_search(
     After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.

     Args:
+        trainer_class: A Trainer or its subclass
         model_init (Callable): A function to initialize the model, used for each hyperparameter trial.
         dataset (DatasetDict): A Hugging Face DatasetDict containing "train" and "validation" datasets.
         data_collator (CehrGptDataCollator): A data collator for processing batches.
@@ -157,6 +156,7 @@ def perform_hyperparameter_search(
     Example:
         ```
         best_training_args = perform_hyperparameter_search(
+            trainer_class=Trainer,
             model_init=my_model_init,
             dataset=my_dataset_dict,
             data_collator=my_data_collator,
@@ -187,7 +187,7 @@ def perform_hyperparameter_search(
         cehrgpt_args.hyperparameter_tuning_percentage,
         training_args.seed,
     )
-    hyperparam_trainer = Trainer(
+    hyperparam_trainer = trainer_class(
         model_init=model_init,
         data_collator=data_collator,
         train_dataset=sampled_train,
@@ -214,6 +214,8 @@ def perform_hyperparameter_search(
         backend="optuna",
         n_trials=cehrgpt_args.n_trials,
         compute_objective=lambda m: m["optuna_best_metric"],
+        # Ensure reproducibility
+        sampler=optuna.samplers.TPESampler(seed=training_args.seed),
     )
     LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
     # Update training arguments with best hyperparameters and set epochs based on adjusted effective epochs
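The changes above swap the hard-coded Trainer for a caller-supplied trainer_class, sample num_train_epochs from an explicit list of choices instead of an integer range, and seed the Optuna sampler. A minimal standalone sketch of the Optuna side of this change (the objective and epoch choices below are hypothetical stand-ins, not taken from the package):

import optuna

# Hypothetical list of allowed epoch counts; the package passes its own
# num_train_epochs sequence into hp_space.
epoch_choices = [1, 2, 4]

def objective(trial: optuna.Trial) -> float:
    # Old behavior: trial.suggest_int("num_train_epochs", low, high) could
    # propose any integer in the range. New behavior: only the listed values.
    epochs = trial.suggest_categorical("num_train_epochs", epoch_choices)
    return float(epochs)  # stand-in for the real training objective

# Seeding the TPE sampler makes repeated searches propose the same trials,
# mirroring the sampler=optuna.samplers.TPESampler(seed=training_args.seed) addition above.
study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=5)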
cehrgpt/runners/sample_packing_trainer.py
@@ -0,0 +1,185 @@
+from typing import Optional, Union
+
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers import Trainer
+from transformers.trainer_utils import has_length
+from transformers.utils import import_utils, logging
+
+from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
+
+DEFAULT_MAX_TOKENS_PER_BATCH = 16384
+
+LOG = logging.get_logger("transformers")
+
+
+class SamplePackingTrainer(Trainer):
+    def __init__(self, *args, **kwargs):
+        if "max_tokens_per_batch" in kwargs:
+            self.max_tokens_per_batch = kwargs.pop("max_tokens_per_batch")
+            LOG.info("max_tokens_per_batch: %s", self.max_tokens_per_batch)
+        else:
+            self.max_tokens_per_batch = DEFAULT_MAX_TOKENS_PER_BATCH
+            LOG.info(
+                "max_tokens_per_batch is not provided to SamplePackingTrainer and will default to %s",
+                DEFAULT_MAX_TOKENS_PER_BATCH,
+            )
+
+        if "max_position_embeddings" in kwargs:
+            self.max_position_embeddings = kwargs.pop("max_position_embeddings")
+            LOG.info("max_position_embeddings: %s", self.max_position_embeddings)
+        else:
+            self.max_position_embeddings = self.max_tokens_per_batch
+            LOG.info(
+                "max_position_embeddings is not provided to SamplePackingTrainer and will default to %s",
+                self.max_tokens_per_batch,
+            )
+
+        self.negative_sampling_probability = kwargs.pop(
+            "negative_sampling_probability", None
+        )
+        if self.negative_sampling_probability:
+            LOG.info(
+                "negative_sampling_probability: %s", self.negative_sampling_probability
+            )
+        self.train_lengths = kwargs.pop("train_lengths", None)
+        self.validation_lengths = kwargs.pop("validation_lengths", None)
+        super().__init__(*args, **kwargs)
+        self.accelerator.even_batches = False
+
+    def num_examples(self, dataloader: DataLoader) -> int:
+        if has_length(dataloader):
+            return len(dataloader)
+        raise RuntimeError("DataLoader in SamplePackingTrainer must have length")
+
+    def get_train_dataloader(self) -> DataLoader:
+        """Returns the training dataloader with our custom batch sampler."""
+        train_dataset = self.train_dataset
+
+        if self.train_lengths is None:
+            LOG.info("Started computing lengths for the train dataset")
+            # Calculate lengths of all sequences in dataset
+            if "num_of_concepts" in train_dataset.column_names:
+                lengths = train_dataset["num_of_concepts"]
+            else:
+                lengths = [len(sample["input_ids"]) for sample in train_dataset]
+
+            LOG.info("Finished computing lengths for the train dataset")
+        else:
+            lengths = self.train_lengths
+
+        data_collator = self.data_collator
+        if import_utils.is_datasets_available() and isinstance(train_dataset, Dataset):
+            train_dataset = self._remove_unused_columns(
+                train_dataset, description="training"
+            )
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                data_collator, description="training"
+            )
+
+        labels = None
+        if (
+            self.negative_sampling_probability is not None
+            and "classifier_label" in train_dataset.column_names
+        ):
+            labels = train_dataset["classifier_label"]
+
+        # Create our custom batch sampler
+        batch_sampler = SamplePackingBatchSampler(
+            lengths=lengths,
+            max_tokens_per_batch=self.max_tokens_per_batch,
+            max_position_embeddings=self.max_position_embeddings,
+            drop_last=self.args.dataloader_drop_last,
+            seed=self.args.seed,
+            negative_sampling_probability=self.negative_sampling_probability,
+            labels=labels,
+        )
+        dataloader_params = {
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+            "batch_sampler": batch_sampler,
+        }
+        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Union[str, Dataset]] = None
+    ) -> DataLoader:
+        """
+        Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
+        """
+        if eval_dataset is None and self.eval_dataset is None:
+            raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+        # If we have persistent workers, don't do a fork bomb especially as eval datasets
+        # don't change during training
+        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+        if (
+            hasattr(self, "_eval_dataloaders")
+            and dataloader_key in self._eval_dataloaders
+            and self.args.dataloader_persistent_workers
+        ):
+            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+
+        eval_dataset = (
+            self.eval_dataset[eval_dataset]
+            if isinstance(eval_dataset, str)
+            else eval_dataset if eval_dataset is not None else self.eval_dataset
+        )
+
+        if self.validation_lengths is None:
+            LOG.info("Started computing lengths for the train dataset")
+            # Calculate lengths of all sequences in dataset
+            if "num_of_concepts" in eval_dataset.column_names:
+                lengths = eval_dataset["num_of_concepts"]
+            else:
+                lengths = [len(sample["input_ids"]) for sample in eval_dataset]
+
+            LOG.info("Finished computing lengths for the train dataset")
+        else:
+            lengths = self.validation_lengths
+
+        data_collator = self.data_collator
+
+        if import_utils.is_datasets_available() and isinstance(eval_dataset, Dataset):
+            eval_dataset = self._remove_unused_columns(
+                eval_dataset, description="evaluation"
+            )
+        else:
+            data_collator = self._get_collator_with_removed_columns(
+                data_collator, description="evaluation"
+            )
+
+        # Create our custom batch sampler
+        batch_sampler = SamplePackingBatchSampler(
+            lengths=lengths,
+            max_tokens_per_batch=self.max_tokens_per_batch,
+            max_position_embeddings=self.max_position_embeddings,
+            drop_last=self.args.dataloader_drop_last,
+            seed=self.args.seed,
+        )
+        dataloader_params = {
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+            "batch_sampler": batch_sampler,
+        }
+        # accelerator.free_memory() will destroy the references, so
+        # we need to store the non-prepared version
+        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+        if self.args.dataloader_persistent_workers:
+            if hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders[dataloader_key] = eval_dataloader
+            else:
+                self._eval_dataloaders = {dataloader_key: eval_dataloader}
+
+        return self.accelerator.prepare(eval_dataloader)
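SamplePackingTrainer pops its packing-specific keyword arguments before delegating to the standard Trainer constructor, so it is built with the usual Trainer arguments plus a token budget. A minimal sketch, assuming model, training_args, train_set, eval_set, and collator are constructed elsewhere (none of them are defined in this diff):

from cehrgpt.runners.sample_packing_trainer import SamplePackingTrainer

trainer = SamplePackingTrainer(
    model=model,                    # a Hugging Face model prepared by the caller
    args=training_args,             # transformers.TrainingArguments
    train_dataset=train_set,
    eval_dataset=eval_set,
    data_collator=collator,
    # Packing-specific kwargs consumed in __init__ before super().__init__():
    max_tokens_per_batch=16384,     # token budget per packed batch
    max_position_embeddings=1024,   # cap on any single packed sequence
    train_lengths=train_set["num_of_concepts"],  # optional precomputed lengths
)
trainer.train()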
cehrgpt/simulations/generate_plots.py
@@ -0,0 +1,95 @@
+import json
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def main(output_dir: str):
+    with open(os.path.join(output_dir, "time_embedding_metrics.json"), "r") as f:
+        time_embedding_metrics = json.load(f)
+    with open(os.path.join(output_dir, "time_token_metrics.json"), "r") as f:
+        time_token_metrics = json.load(f)
+
+    common_steps = list(
+        set(time_embedding_metrics["steps"]) & set(time_token_metrics["steps"])
+    )
+
+    time_embedding_aucs = []
+    time_embedding_accuracies = []
+    for step, roc_auc, accuracy in zip(
+        time_embedding_metrics["steps"],
+        time_embedding_metrics["roc_auc"],
+        time_embedding_metrics["accuracy"],
+    ):
+        if step in common_steps:
+            time_embedding_aucs.append(roc_auc)
+            time_embedding_accuracies.append(accuracy)
+
+    time_token_aucs = []
+    time_token_accuracies = []
+    for step, roc_auc, accuracy in zip(
+        time_token_metrics["steps"],
+        time_token_metrics["roc_auc"],
+        time_token_metrics["accuracy"],
+    ):
+        if step in common_steps:
+            time_token_aucs.append(roc_auc)
+            time_token_accuracies.append(accuracy)
+
+    # Create the accuracy plot
+    plt.figure(figsize=(8, 5))  # Define figure size
+    plt.plot(
+        common_steps,
+        time_embedding_accuracies,
+        linestyle="-",
+        color="b",
+        label="Time Embedding",
+        lw=1,
+    )
+    plt.plot(
+        common_steps,
+        time_token_accuracies,
+        linestyle="--",
+        color="r",
+        label="Time Token",
+        lw=1,
+    )
+    plt.title("Accuracy Comparison Over Time")
+    plt.xlabel("Training Steps")
+    plt.ylabel("Accuracy")
+    plt.legend()
+    plt.grid(False)
+    plt.savefig(os.path.join(output_dir, "accuracy_comparison.png"))
+
+    # Create the ROC AUC plot
+    plt.figure(figsize=(8, 5))  # Define figure size
+    plt.plot(
+        common_steps,
+        time_embedding_aucs,
+        linestyle="-",
+        color="b",
+        label="Time Embedding",
+        lw=1,
+    )
+    plt.plot(
+        common_steps,
+        time_token_aucs,
+        linestyle="--",
+        color="r",
+        label="Time Token",
+        lw=1,
+    )
+    plt.title("ROC AUC Comparison Over Time")
+    plt.xlabel("Training Steps")
+    plt.ylabel("ROC AUC")
+    plt.legend()
+    plt.grid(False)
+    plt.savefig(
+        os.path.join(output_dir, "roc_auc_comparison.png")
+    )  # Save the plot as a PNG file
+
+
+if __name__ == "__main__":
+    main(sys.argv[1])
cehrgpt/simulations/run_simulation.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# This script runs various Python simulations and generates plots
+# It accepts three parameters: output directory, number of steps, and number of samples
+
+# Check if all arguments are provided
+if [ "$#" -ne 3 ]; then
+    echo "Usage: $0 <output_dir> <n_steps> <n_samples>"
+    exit 1
+fi
+
+# Assigning command line arguments to variables
+OUTPUT_DIR="$1"
+N_STEPS="$2"
+N_SAMPLES="$3"
+
+# Run time token simulation
+python -u -m cehrgpt.simulations.time_token_simulation --output_dir "$OUTPUT_DIR" --n_steps "$N_STEPS" --n_samples "$N_SAMPLES"
+
+# Run time embedding simulation
+python -u -m cehrgpt.simulations.time_embedding_simulation --output_dir "$OUTPUT_DIR" --n_steps "$N_STEPS" --n_samples "$N_SAMPLES"
+
+# Generate plots
+python -u -m cehrgpt.simulations.generate_plots "$OUTPUT_DIR"
cehrgpt/simulations/time_embedding_simulation.py
@@ -0,0 +1,250 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.optim as optim
+from sklearn.metrics import accuracy_score, roc_auc_score
+from torch.nn import CrossEntropyLoss
+from transformers import BertConfig, BertModel
+
+
+class ModelTimeEmbedding(torch.nn.Module):
+    def __init__(self, vocab_size: int):
+        super(ModelTimeEmbedding, self).__init__()
+        self.embedding = torch.nn.Embedding(vocab_size, 16)
+        self.bert = BertModel(
+            BertConfig(
+                vocab_size=vocab_size,
+                hidden_size=16,
+                num_attention_heads=2,
+                num_hidden_layers=2,
+                intermediate_size=32,
+                hidden_dropout_prob=0.0,
+                attention_probs_dropout_prob=0.0,
+                max_position_embeddings=2,
+            ),
+            add_pooling_layer=False,
+        )
+        self.linear = torch.nn.Linear(32, 2)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        time_stamps: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        bz = input_ids.shape[0]
+        x = self.embedding(input_ids)
+        t = self.embedding(time_stamps)
+        x = x + t
+        bert_output = self.bert.forward(inputs_embeds=x, return_dict=True)
+        output = bert_output.last_hidden_state.reshape((bz, 32))
+        y = self.linear(output)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(y, labels)
+        return loss, y
+
+
+def generate_simulation_data(sample_size: int = 1000, seed: int = 42) -> np.ndarray:
+    np.random.seed(seed)  # Set the seed for reproducibility
+
+    # Define input values and time stamps
+    x_values = [0, 1]
+    time_stamp_values = list(range(0, 21))
+
+    # Generate random choices for features and time stamps
+    x1 = np.random.choice(x_values, size=sample_size)
+    x2 = np.random.choice(x_values, size=sample_size)
+    t1 = np.random.choice(time_stamp_values, size=sample_size)
+    t2 = t1 + np.random.choice(time_stamp_values, size=sample_size)
+
+    # Define conditions based on time differences
+    time_diff = t2 - t1
+    # Complex condition involving modulo operation
+    is_custom_func_1 = (x1 == 1) & (time_diff % 4 == 0)
+    is_custom_func_2 = (x1 == 0) & (time_diff % 3 == 0)
+    is_xor = time_diff <= 7
+    is_and = (time_diff > 7) & (time_diff <= 14)
+    is_or = (time_diff > 14) & (time_diff <= 21)
+
+    # Logical operations based on x1 and x2
+    xor = (x2 != x1).astype(int)
+    logical_and = (x2 & x1).astype(int)
+    logical_or = (x2 | x1).astype(int)
+    # Additional complexity: introduce a new rule based on a more complex condition
+    custom_func_1_result = (x2 == 0).astype(int)  # For example, use a different rule
+    custom_func_2_result = (x2 == 1).astype(int)  # For example, use a different rule
+
+    # Determine output based on multiple conditions
+    y = np.where(
+        is_custom_func_1,
+        custom_func_1_result,
+        np.where(
+            is_custom_func_2,
+            custom_func_2_result,
+            np.where(
+                is_xor,
+                xor,
+                np.where(is_and, logical_and, np.where(is_or, logical_or, 0)),
+            ),
+        ),
+    )
+
+    # Return the data as a single numpy array with features and output
+    return np.column_stack((x1, x2, t1, t2, y))
+
+
+def create_time_embedding_tokenizer(simulated_data):
+    vocab = []
+    for row in simulated_data:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t1 = f"t-{t1}"
+        t2 = f"t-{t2}"
+        if x1 not in vocab:
+            vocab.append(x1)
+        if x2 not in vocab:
+            vocab.append(x2)
+        if t1 not in vocab:
+            vocab.append(t1)
+        if t2 not in vocab:
+            vocab.append(t2)
+    return {c: i + 1 for i, c in enumerate(vocab)}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def eval_step(
+    simulated_data,
+    time_embedding_tokenizer,
+    time_embedding_model,
+    time_embedding_optimizer,
+):
+    time_embedding_optimizer.zero_grad()
+    time_embedding_model.eval()
+    eval_input_ids = []
+    eval_time_stamps = []
+    eval_y = []
+    for row in simulated_data:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t1 = f"t-{t1}"
+        t2 = f"t-{t2}"
+        eval_input_ids.append(
+            [time_embedding_tokenizer[x1], time_embedding_tokenizer[x2]]
+        )
+        eval_time_stamps.append(
+            [time_embedding_tokenizer[t1], time_embedding_tokenizer[t2]]
+        )
+        eval_y.append(y)
+    eval_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(device)
+    eval_time_stamps = torch.tensor(eval_time_stamps, dtype=torch.long).to(device)
+    eval_y = np.asarray(eval_y)
+    with torch.no_grad():
+        # Compute loss and forward pass
+        _, y_pred = time_embedding_model(eval_input_ids, eval_time_stamps)
+        y_probs = torch.nn.functional.softmax(y_pred, dim=1)
+        y_probs = y_probs.detach().cpu().numpy()
+        # print(np.concatenate((y_probs, batched_y[:, None]), axis=1))
+        roc_auc = roc_auc_score(eval_y, y_probs[:, 1])
+        accuracy = accuracy_score(eval_y, y_probs[:, 1] > y_probs[:, 0])
+        print(f"ROC AUC: {roc_auc}")
+        print(f"Accuracy: {accuracy}")
+        return roc_auc, accuracy
+
+
+def train_step(
+    simulated_data,
+    time_embedding_tokenizer,
+    time_embedding_model,
+    time_embedding_optimizer,
+):
+    batched_input_ids = []
+    batched_time_stamps = []
+    batched_y = []
+    indices = np.random.choice(simulated_data.shape[0], size=8, replace=False)
+    for row in simulated_data[indices, :]:
+        x1, x2, t1, t2, y = row
+        x1 = f"c-{x1}"
+        x2 = f"c-{x2}"
+        t1 = f"t-{t1}"
+        t2 = f"t-{t2}"
+        batched_input_ids.append(
+            [time_embedding_tokenizer[x1], time_embedding_tokenizer[x2]]
+        )
+        batched_time_stamps.append(
+            [time_embedding_tokenizer[t1], time_embedding_tokenizer[t2]]
+        )
+        batched_y.append(y)
+    batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
+    batched_time_stamps = torch.tensor(batched_time_stamps, dtype=torch.long).to(device)
+    batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)
+    # Zero the gradients
+    time_embedding_optimizer.zero_grad()
+    # Compute loss and forward pass
+    loss, _ = time_embedding_model(batched_input_ids, batched_time_stamps, batched_y)
+    # Backward pass (compute gradients)
+    loss.backward()
+    # Update model parameters
+    time_embedding_optimizer.step()
+    return loss
+
+
+def main(args):
+    simulated_data = generate_simulation_data(args.n_samples)
+    time_embedding_tokenizer = create_time_embedding_tokenizer(simulated_data)
+    time_embedding_model = ModelTimeEmbedding(len(time_embedding_tokenizer) + 1).to(
+        device
+    )
+    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
+    steps = []
+    roc_aucs = []
+    accuracies = []
+    for step in range(args.n_steps):
+        loss = train_step(
+            simulated_data,
+            time_embedding_tokenizer,
+            time_embedding_model,
+            time_embedding_optimizer,
+        )
+        print(f"Step {step}: Loss = {loss.item()}")
+        # Evaluation
+        if (
+            args.n_steps % args.eval_frequency == 0
+            and args.n_steps > args.eval_frequency
+        ):
+            # Zero the gradients
+            roc_auc, accuracy = eval_step(
+                simulated_data,
+                time_embedding_tokenizer,
+                time_embedding_model,
+                time_embedding_optimizer,
+            )
+            steps.append(step)
+            roc_aucs.append(roc_auc)
+            accuracies.append(accuracy)
+    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
+
+
+if __name__ == "__main__":
+    import argparse
+    import json
+    from pathlib import Path
+
+    parser = argparse.ArgumentParser("Model with time embedding simulation")
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--n_steps", type=int, default=10000)
+    parser.add_argument("--n_samples", type=int, default=1000)
+    parser.add_argument("--batch_size", type=int, default=128)
+    parser.add_argument("--eval_frequency", type=int, default=100)
+    args = parser.parse_args()
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(exist_ok=True, parents=True)
+    metrics = main(args)
+    with open(output_dir / "time_embedding_metrics.json", "w") as f:
+        json.dump(metrics, f)
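The simulation above labels each pair of events by a rule that depends on the gap between the two time stamps, then trains a small BERT with summed concept and time embeddings. A quick way to inspect the generated data and the derived vocabulary (a small sketch; it only calls functions defined in the file above):

import numpy as np

from cehrgpt.simulations.time_embedding_simulation import (
    create_time_embedding_tokenizer,
    generate_simulation_data,
)

data = generate_simulation_data(sample_size=100, seed=42)
print(data.shape)                # (100, 5): columns are x1, x2, t1, t2, y
print(np.bincount(data[:, -1]))  # label balance of the simulated outcome

tokenizer = create_time_embedding_tokenizer(data)
print(len(tokenizer))            # concept tokens (c-0, c-1) plus observed time tokens (t-*)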