cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/runners/sample_packing_trainer.py
@@ -0,0 +1,168 @@
+ from typing import Optional, Union
+
+ from datasets import Dataset
+ from torch.utils.data import DataLoader
+ from transformers import Trainer
+ from transformers.trainer_utils import has_length
+ from transformers.utils import import_utils, logging
+
+ from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
+
+ DEFAULT_MAX_TOKENS_PER_BATCH = 16384
+
+ LOG = logging.get_logger("transformers")
+
+
+ class SamplePackingTrainer(Trainer):
+     def __init__(self, *args, **kwargs):
+         if "max_tokens_per_batch" in kwargs:
+             self.max_tokens_per_batch = kwargs.pop("max_tokens_per_batch")
+             LOG.info("max_tokens_per_batch: %s", self.max_tokens_per_batch)
+         else:
+             self.max_tokens_per_batch = DEFAULT_MAX_TOKENS_PER_BATCH
+             LOG.info(
+                 "max_tokens_per_batch is not provided to SamplePackingTrainer and will default to %s",
+                 DEFAULT_MAX_TOKENS_PER_BATCH,
+             )
+
+         if "max_position_embeddings" in kwargs:
+             self.max_position_embeddings = kwargs.pop("max_position_embeddings")
+             LOG.info("max_position_embeddings: %s", self.max_position_embeddings)
+         else:
+             self.max_position_embeddings = self.max_tokens_per_batch
+             LOG.info(
+                 "max_position_embeddings is not provided to SamplePackingTrainer and will default to %s",
+                 self.max_tokens_per_batch,
+             )
+
+         self.train_lengths = kwargs.pop("train_lengths", None)
+         self.validation_lengths = kwargs.pop("validation_lengths", None)
+         super().__init__(*args, **kwargs)
+         self.accelerator.even_batches = False
+
+     def num_examples(self, dataloader: DataLoader) -> int:
+         if has_length(dataloader):
+             return len(dataloader)
+         raise RuntimeError("DataLoader in SamplePackingTrainer must have length")
+
+     def get_train_dataloader(self) -> DataLoader:
+         """Returns the training dataloader with our custom batch sampler."""
+         train_dataset = self.train_dataset
+
+         if self.train_lengths is None:
+             LOG.info("Started computing lengths for the train dataset")
+             # Calculate lengths of all sequences in the dataset
+             if "num_of_concepts" in train_dataset.column_names:
+                 lengths = train_dataset["num_of_concepts"]
+             else:
+                 lengths = [len(sample["input_ids"]) for sample in train_dataset]
+
+             LOG.info("Finished computing lengths for the train dataset")
+         else:
+             lengths = self.train_lengths
+
+         data_collator = self.data_collator
+         if import_utils.is_datasets_available() and isinstance(train_dataset, Dataset):
+             train_dataset = self._remove_unused_columns(
+                 train_dataset, description="training"
+             )
+         else:
+             data_collator = self._get_collator_with_removed_columns(
+                 data_collator, description="training"
+             )
+         # Create our custom batch sampler
+         batch_sampler = SamplePackingBatchSampler(
+             lengths=lengths,
+             max_tokens_per_batch=self.max_tokens_per_batch,
+             max_position_embeddings=self.max_position_embeddings,
+             drop_last=self.args.dataloader_drop_last,
+             seed=self.args.seed,
+         )
+         dataloader_params = {
+             "collate_fn": data_collator,
+             "num_workers": self.args.dataloader_num_workers,
+             "pin_memory": self.args.dataloader_pin_memory,
+             "persistent_workers": self.args.dataloader_persistent_workers,
+             "batch_sampler": batch_sampler,
+         }
+         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+     def get_eval_dataloader(
+         self, eval_dataset: Optional[Union[str, Dataset]] = None
+     ) -> DataLoader:
+         """
+         Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+         Subclass and override this method if you want to inject some custom behavior.
+
+         Args:
+             eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+                 If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
+         """
+         if eval_dataset is None and self.eval_dataset is None:
+             raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+         # If we have persistent workers, don't do a fork bomb, especially as eval
+         # datasets don't change during training
+         dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+         if (
+             hasattr(self, "_eval_dataloaders")
+             and dataloader_key in self._eval_dataloaders
+             and self.args.dataloader_persistent_workers
+         ):
+             return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+
+         eval_dataset = (
+             self.eval_dataset[eval_dataset]
+             if isinstance(eval_dataset, str)
+             else eval_dataset if eval_dataset is not None else self.eval_dataset
+         )
+
+         if self.validation_lengths is None:
+             LOG.info("Started computing lengths for the eval dataset")
+             # Calculate lengths of all sequences in the dataset
+             if "num_of_concepts" in eval_dataset.column_names:
+                 lengths = eval_dataset["num_of_concepts"]
+             else:
+                 lengths = [len(sample["input_ids"]) for sample in eval_dataset]
+
+             LOG.info("Finished computing lengths for the eval dataset")
+         else:
+             lengths = self.validation_lengths
+
+         data_collator = self.data_collator
+
+         if import_utils.is_datasets_available() and isinstance(eval_dataset, Dataset):
+             eval_dataset = self._remove_unused_columns(
+                 eval_dataset, description="evaluation"
+             )
+         else:
+             data_collator = self._get_collator_with_removed_columns(
+                 data_collator, description="evaluation"
+             )
+
+         # Create our custom batch sampler
+         batch_sampler = SamplePackingBatchSampler(
+             lengths=lengths,
+             max_tokens_per_batch=self.max_tokens_per_batch,
+             max_position_embeddings=self.max_position_embeddings,
+             drop_last=self.args.dataloader_drop_last,
+             seed=self.args.seed,
+         )
+         dataloader_params = {
+             "collate_fn": data_collator,
+             "num_workers": self.args.dataloader_num_workers,
+             "pin_memory": self.args.dataloader_pin_memory,
+             "persistent_workers": self.args.dataloader_persistent_workers,
+             "batch_sampler": batch_sampler,
+         }
+         # accelerator.free_memory() will destroy the references, so
+         # we need to store the non-prepared version
+         eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
+         if self.args.dataloader_persistent_workers:
+             if hasattr(self, "_eval_dataloaders"):
+                 self._eval_dataloaders[dataloader_key] = eval_dataloader
+             else:
+                 self._eval_dataloaders = {dataloader_key: eval_dataloader}
+
+         return self.accelerator.prepare(eval_dataloader)
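The SamplePackingBatchSampler this trainer builds is added separately in cehrgpt/data/sample_packing_sampler.py (file 4 above); its hunk is not shown in this section, but the constructor calls above fix its interface. A minimal smoke test sketch, assuming the sampler yields lists of dataset indices the way a standard torch batch sampler does:

import random

from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler

# Toy sequence lengths; in the trainer these come from the "num_of_concepts"
# column or from len(sample["input_ids"]).
lengths = [random.randint(32, 2048) for _ in range(100)]

sampler = SamplePackingBatchSampler(
    lengths=lengths,
    max_tokens_per_batch=4096,     # token budget for one packed batch
    max_position_embeddings=2048,  # cap on any single sequence
    drop_last=False,
    seed=42,
)

# Each yielded batch of indices should stay within the token budget.
for batch in sampler:
    print(batch, sum(lengths[i] for i in batch))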
cehrgpt/simulations/generate_plots.py
@@ -0,0 +1,95 @@
+ import json
+ import os
+ import sys
+
+ import matplotlib.pyplot as plt
+
+
+ def main(output_dir: str):
+     with open(os.path.join(output_dir, "time_embedding_metrics.json"), "r") as f:
+         time_embedding_metrics = json.load(f)
+     with open(os.path.join(output_dir, "time_token_metrics.json"), "r") as f:
+         time_token_metrics = json.load(f)
+
+     # Keep only the steps at which both models were evaluated, in ascending order
+     common_steps = sorted(
+         set(time_embedding_metrics["steps"]) & set(time_token_metrics["steps"])
+     )
+
+     time_embedding_aucs = []
+     time_embedding_accuracies = []
+     for step, roc_auc, accuracy in zip(
+         time_embedding_metrics["steps"],
+         time_embedding_metrics["roc_auc"],
+         time_embedding_metrics["accuracy"],
+     ):
+         if step in common_steps:
+             time_embedding_aucs.append(roc_auc)
+             time_embedding_accuracies.append(accuracy)
+
+     time_token_aucs = []
+     time_token_accuracies = []
+     for step, roc_auc, accuracy in zip(
+         time_token_metrics["steps"],
+         time_token_metrics["roc_auc"],
+         time_token_metrics["accuracy"],
+     ):
+         if step in common_steps:
+             time_token_aucs.append(roc_auc)
+             time_token_accuracies.append(accuracy)
+
+     # Create the accuracy plot
+     plt.figure(figsize=(8, 5))  # Define figure size
+     plt.plot(
+         common_steps,
+         time_embedding_accuracies,
+         linestyle="-",
+         color="b",
+         label="Time Embedding",
+         lw=1,
+     )
+     plt.plot(
+         common_steps,
+         time_token_accuracies,
+         linestyle="--",
+         color="r",
+         label="Time Token",
+         lw=1,
+     )
+     plt.title("Accuracy Comparison Over Time")
+     plt.xlabel("Training Steps")
+     plt.ylabel("Accuracy")
+     plt.legend()
+     plt.grid(False)
+     plt.savefig(os.path.join(output_dir, "accuracy_comparison.png"))
+
+     # Create the ROC AUC plot
+     plt.figure(figsize=(8, 5))  # Define figure size
+     plt.plot(
+         common_steps,
+         time_embedding_aucs,
+         linestyle="-",
+         color="b",
+         label="Time Embedding",
+         lw=1,
+     )
+     plt.plot(
+         common_steps,
+         time_token_aucs,
+         linestyle="--",
+         color="r",
+         label="Time Token",
+         lw=1,
+     )
+     plt.title("ROC AUC Comparison Over Time")
+     plt.xlabel("Training Steps")
+     plt.ylabel("ROC AUC")
+     plt.legend()
+     plt.grid(False)
+     plt.savefig(
+         os.path.join(output_dir, "roc_auc_comparison.png")
+     )  # Save the plot as a PNG file
+
+
+ if __name__ == "__main__":
+     main(sys.argv[1])
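The two metrics files this script reads are produced by the simulation scripts below. A minimal sketch of the expected schema (the values here are illustrative): three parallel lists, one entry per evaluation step.

import json

# Illustrative contents of time_embedding_metrics.json / time_token_metrics.json
metrics = {
    "steps": [100, 200, 300],
    "roc_auc": [0.61, 0.68, 0.74],
    "accuracy": [0.58, 0.64, 0.70],
}
with open("time_embedding_metrics.json", "w") as f:
    json.dump(metrics, f)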
cehrgpt/simulations/run_simulation.sh
@@ -0,0 +1,24 @@
+ #!/bin/bash
+
+ # This script runs various Python simulations and generates plots
+ # It accepts three parameters: output directory, number of steps, and number of samples
+
+ # Check if all arguments are provided
+ if [ "$#" -ne 3 ]; then
+     echo "Usage: $0 <output_dir> <n_steps> <n_samples>"
+     exit 1
+ fi
+
+ # Assigning command line arguments to variables
+ OUTPUT_DIR="$1"
+ N_STEPS="$2"
+ N_SAMPLES="$3"
+
+ # Run time token simulation
+ python -u -m cehrgpt.simulations.time_token_simulation --output_dir "$OUTPUT_DIR" --n_steps "$N_STEPS" --n_samples "$N_SAMPLES"
+
+ # Run time embedding simulation
+ python -u -m cehrgpt.simulations.time_embedding_simulation --output_dir "$OUTPUT_DIR" --n_steps "$N_STEPS" --n_samples "$N_SAMPLES"
+
+ # Generate plots
+ python -u -m cehrgpt.simulations.generate_plots "$OUTPUT_DIR"
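An example invocation (the directory name and sizes are illustrative):

    bash cehrgpt/simulations/run_simulation.sh ./simulation_output 10000 1000

This runs both simulations for 10000 training steps over 1000 samples, then writes accuracy_comparison.png and roc_auc_comparison.png alongside the metrics JSON files in ./simulation_output.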
cehrgpt/simulations/time_embedding_simulation.py
@@ -0,0 +1,250 @@
+ from typing import Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from sklearn.metrics import accuracy_score, roc_auc_score
+ from torch.nn import CrossEntropyLoss
+ from transformers import BertConfig, BertModel
+
+
+ class ModelTimeEmbedding(torch.nn.Module):
+     def __init__(self, vocab_size: int):
+         super().__init__()
+         self.embedding = torch.nn.Embedding(vocab_size, 16)
+         self.bert = BertModel(
+             BertConfig(
+                 vocab_size=vocab_size,
+                 hidden_size=16,
+                 num_attention_heads=2,
+                 num_hidden_layers=2,
+                 intermediate_size=32,
+                 hidden_dropout_prob=0.0,
+                 attention_probs_dropout_prob=0.0,
+                 max_position_embeddings=2,
+             ),
+             add_pooling_layer=False,
+         )
+         self.linear = torch.nn.Linear(32, 2)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         time_stamps: torch.LongTensor,
+         labels: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+         bz = input_ids.shape[0]
+         x = self.embedding(input_ids)
+         t = self.embedding(time_stamps)
+         x = x + t
+         bert_output = self.bert.forward(inputs_embeds=x, return_dict=True)
+         # Flatten the two 16-dim token states into a single 32-dim vector
+         output = bert_output.last_hidden_state.reshape((bz, 32))
+         y = self.linear(output)
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(y, labels)
+         return loss, y
+
+
+ def generate_simulation_data(sample_size: int = 1000, seed: int = 42) -> np.ndarray:
+     np.random.seed(seed)  # Set the seed for reproducibility
+
+     # Define input values and time stamps
+     x_values = [0, 1]
+     time_stamp_values = list(range(0, 21))
+
+     # Generate random choices for features and time stamps
+     x1 = np.random.choice(x_values, size=sample_size)
+     x2 = np.random.choice(x_values, size=sample_size)
+     t1 = np.random.choice(time_stamp_values, size=sample_size)
+     t2 = t1 + np.random.choice(time_stamp_values, size=sample_size)
+
+     # Define conditions based on time differences
+     time_diff = t2 - t1
+     # Complex conditions involving modulo operations
+     is_custom_func_1 = (x1 == 1) & (time_diff % 4 == 0)
+     is_custom_func_2 = (x1 == 0) & (time_diff % 3 == 0)
+     is_xor = time_diff <= 7
+     is_and = (time_diff > 7) & (time_diff <= 14)
+     is_or = (time_diff > 14) & (time_diff <= 21)
+
+     # Logical operations based on x1 and x2
+     xor = (x2 != x1).astype(int)
+     logical_and = (x2 & x1).astype(int)
+     logical_or = (x2 | x1).astype(int)
+     # Additional complexity: introduce new rules based on more complex conditions
+     custom_func_1_result = (x2 == 0).astype(int)  # For example, use a different rule
+     custom_func_2_result = (x2 == 1).astype(int)  # For example, use a different rule
+
+     # Determine the output based on multiple conditions
+     y = np.where(
+         is_custom_func_1,
+         custom_func_1_result,
+         np.where(
+             is_custom_func_2,
+             custom_func_2_result,
+             np.where(
+                 is_xor,
+                 xor,
+                 np.where(is_and, logical_and, np.where(is_or, logical_or, 0)),
+             ),
+         ),
+     )
+
+     # Return the data as a single numpy array with features and output
+     return np.column_stack((x1, x2, t1, t2, y))
+
+
+ def create_time_embedding_tokenizer(simulated_data):
+     vocab = []
+     for row in simulated_data:
+         x1, x2, t1, t2, y = row
+         x1 = f"c-{x1}"
+         x2 = f"c-{x2}"
+         t1 = f"t-{t1}"
+         t2 = f"t-{t2}"
+         if x1 not in vocab:
+             vocab.append(x1)
+         if x2 not in vocab:
+             vocab.append(x2)
+         if t1 not in vocab:
+             vocab.append(t1)
+         if t2 not in vocab:
+             vocab.append(t2)
+     # Reserve index 0 (e.g., for padding); real tokens start at 1
+     return {c: i + 1 for i, c in enumerate(vocab)}
+
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ def eval_step(
+     simulated_data,
+     time_embedding_tokenizer,
+     time_embedding_model,
+     time_embedding_optimizer,
+ ):
+     time_embedding_optimizer.zero_grad()
+     time_embedding_model.eval()
+     eval_input_ids = []
+     eval_time_stamps = []
+     eval_y = []
+     for row in simulated_data:
+         x1, x2, t1, t2, y = row
+         x1 = f"c-{x1}"
+         x2 = f"c-{x2}"
+         t1 = f"t-{t1}"
+         t2 = f"t-{t2}"
+         eval_input_ids.append(
+             [time_embedding_tokenizer[x1], time_embedding_tokenizer[x2]]
+         )
+         eval_time_stamps.append(
+             [time_embedding_tokenizer[t1], time_embedding_tokenizer[t2]]
+         )
+         eval_y.append(y)
+     eval_input_ids = torch.tensor(eval_input_ids, dtype=torch.long).to(device)
+     eval_time_stamps = torch.tensor(eval_time_stamps, dtype=torch.long).to(device)
+     eval_y = np.asarray(eval_y)
+     with torch.no_grad():
+         # Forward pass only; no loss is needed for evaluation
+         _, y_pred = time_embedding_model(eval_input_ids, eval_time_stamps)
+         y_probs = torch.nn.functional.softmax(y_pred, dim=1)
+         y_probs = y_probs.detach().cpu().numpy()
+         roc_auc = roc_auc_score(eval_y, y_probs[:, 1])
+         accuracy = accuracy_score(eval_y, y_probs[:, 1] > y_probs[:, 0])
+         print(f"ROC AUC: {roc_auc}")
+         print(f"Accuracy: {accuracy}")
+         return roc_auc, accuracy
+
+
+ def train_step(
+     simulated_data,
+     time_embedding_tokenizer,
+     time_embedding_model,
+     time_embedding_optimizer,
+ ):
+     batched_input_ids = []
+     batched_time_stamps = []
+     batched_y = []
+     # Sample a fixed mini-batch of 8 rows without replacement
+     indices = np.random.choice(simulated_data.shape[0], size=8, replace=False)
+     for row in simulated_data[indices, :]:
+         x1, x2, t1, t2, y = row
+         x1 = f"c-{x1}"
+         x2 = f"c-{x2}"
+         t1 = f"t-{t1}"
+         t2 = f"t-{t2}"
+         batched_input_ids.append(
+             [time_embedding_tokenizer[x1], time_embedding_tokenizer[x2]]
+         )
+         batched_time_stamps.append(
+             [time_embedding_tokenizer[t1], time_embedding_tokenizer[t2]]
+         )
+         batched_y.append(y)
+     batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
+     batched_time_stamps = torch.tensor(batched_time_stamps, dtype=torch.long).to(device)
+     batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)
+     # Zero the gradients
+     time_embedding_optimizer.zero_grad()
+     # Forward pass and loss computation
+     loss, _ = time_embedding_model(batched_input_ids, batched_time_stamps, batched_y)
+     # Backward pass (compute gradients)
+     loss.backward()
+     # Update model parameters
+     time_embedding_optimizer.step()
+     return loss
+
+
+ def main(args):
+     simulated_data = generate_simulation_data(args.n_samples)
+     time_embedding_tokenizer = create_time_embedding_tokenizer(simulated_data)
+     time_embedding_model = ModelTimeEmbedding(len(time_embedding_tokenizer) + 1).to(
+         device
+     )
+     time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
+     steps = []
+     roc_aucs = []
+     accuracies = []
+     for step in range(args.n_steps):
+         loss = train_step(
+             simulated_data,
+             time_embedding_tokenizer,
+             time_embedding_model,
+             time_embedding_optimizer,
+         )
+         print(f"Step {step}: Loss = {loss.item()}")
+         # Evaluate periodically, skipping step 0
+         if step % args.eval_frequency == 0 and step > 0:
+             roc_auc, accuracy = eval_step(
+                 simulated_data,
+                 time_embedding_tokenizer,
+                 time_embedding_model,
+                 time_embedding_optimizer,
+             )
+             steps.append(step)
+             roc_aucs.append(roc_auc)
+             accuracies.append(accuracy)
+     return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}
+
+
+ if __name__ == "__main__":
+     import argparse
+     import json
+     from pathlib import Path
+
+     parser = argparse.ArgumentParser("Model with time embedding simulation")
+     parser.add_argument("--output_dir", type=str, required=True)
+     parser.add_argument("--n_steps", type=int, default=10000)
+     parser.add_argument("--n_samples", type=int, default=1000)
+     parser.add_argument("--batch_size", type=int, default=128)
+     parser.add_argument("--eval_frequency", type=int, default=100)
+     args = parser.parse_args()
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(exist_ok=True, parents=True)
+     metrics = main(args)
+     with open(output_dir / "time_embedding_metrics.json", "w") as f:
+         json.dump(metrics, f)
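To make the labeling rule in generate_simulation_data concrete, here is a hand-checked trace of two rows (the values are picked for illustration, and only the branches that fire for these rows are reproduced):

import numpy as np

# Two hand-picked rows: (x1, x2, t1, t2) = (1, 0, 3, 7) and (0, 1, 2, 7)
x1, x2 = np.array([1, 0]), np.array([0, 1])
t1, t2 = np.array([3, 2]), np.array([7, 7])
time_diff = t2 - t1  # [4, 5]

# Row 1: x1 == 1 and time_diff % 4 == 0, so is_custom_func_1 fires
# and y = (x2 == 0) = 1.
# Row 2: neither custom rule fires; time_diff <= 7, so the XOR branch
# gives y = (x2 != x1) = 1.
is_custom_func_1 = (x1 == 1) & (time_diff % 4 == 0)  # [True, False]
is_xor = time_diff <= 7                              # [True, True]
y = np.where(
    is_custom_func_1,
    (x2 == 0).astype(int),
    np.where(is_xor, (x2 != x1).astype(int), 0),
)
print(y)  # [1 1]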