cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,177 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.optim as optim
6
+ from sklearn.metrics import accuracy_score, roc_auc_score
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers import BertConfig, BertModel
9
+
10
+ from cehrgpt.simulations.time_embedding_simulation import generate_simulation_data
11
+
12
+
13
class ModelTimeToken(torch.nn.Module):
    """Tiny BERT classifier over a fixed ``[concept, time token, concept]`` sequence.

    The final-layer hidden states of the three positions are flattened and fed
    to a linear layer producing two-class logits.
    """

    # Width of every embedding / hidden state in the toy encoder.
    HIDDEN_SIZE = 16
    # Every input row is exactly [x1, time token, x2].
    SEQ_LEN = 3

    def __init__(self, vocab_size: int):
        super().__init__()
        # Token embeddings are computed here and handed to BERT via
        # ``inputs_embeds``, so BERT's own word-embedding table is not used.
        self.embedding = torch.nn.Embedding(vocab_size, self.HIDDEN_SIZE)
        self.bert = BertModel(
            BertConfig(
                vocab_size=vocab_size,
                hidden_size=self.HIDDEN_SIZE,
                num_attention_heads=2,
                num_hidden_layers=2,
                intermediate_size=32,
                hidden_dropout_prob=0.0,
                attention_probs_dropout_prob=0.0,
                max_position_embeddings=self.SEQ_LEN,
            ),
            add_pooling_layer=False,
        )
        self.linear = torch.nn.Linear(self.SEQ_LEN * self.HIDDEN_SIZE, 2)

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is not given.

        Args:
            input_ids: ``(batch, SEQ_LEN)`` token ids (the reshape below requires
                exactly ``SEQ_LEN`` positions per row).
            labels: optional ``(batch,)`` class indices in ``{0, 1}``.
        """
        bz = input_ids.shape[0]
        x = self.embedding(input_ids)
        # Call the module rather than ``.forward`` so nn.Module hooks still run.
        bert_output = self.bert(inputs_embeds=x, return_dict=True)
        output = bert_output.last_hidden_state.reshape(
            (bz, self.SEQ_LEN * self.HIDDEN_SIZE)
        )
        y = self.linear(output)
        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(y, labels)
        return loss, y
47
+
48
+
49
def create_time_token_tokenizer(simulated_data):
    """Build a token -> id vocabulary from simulated ``(x1, x2, t1, t2, y)`` rows.

    Each row contributes the tokens ``c-{x1}``, ``c-{x2}`` and ``t-{t2 - t1}``
    (in that order). Ids are assigned in first-seen order starting at 1, so id 0
    is left unused.

    Args:
        simulated_data: iterable of 5-element rows ``(x1, x2, t1, t2, y)``.

    Returns:
        dict mapping each token string to its integer id.
    """
    # A dict gives O(1) membership checks and preserves insertion order, so ids
    # match a first-seen enumeration without the O(n) list scan per token.
    vocab = {}
    for x1, x2, t1, t2, _ in simulated_data:
        for token in (f"c-{x1}", f"c-{x2}", f"t-{t2 - t1}"):
            if token not in vocab:
                vocab[token] = len(vocab) + 1
    return vocab
63
+
64
+
65
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
66
+
67
+
68
def eval_step(simulated_data, time_token_tokenizer, time_embedding_model):
    """Score every row of ``simulated_data`` and report ROC AUC and accuracy.

    The model is switched to eval mode and run once over all rows without
    gradient tracking. A row counts as predicted-positive when the class-1
    probability exceeds the class-0 probability.

    Args:
        simulated_data: iterable of ``(x1, x2, t1, t2, y)`` rows.
        time_token_tokenizer: token -> id mapping covering every token in the rows.
        time_embedding_model: model returning ``(loss, logits)``.

    Returns:
        ``(accuracy, roc_auc)``.
    """
    time_embedding_model.eval()
    encoded_rows = []
    targets = []
    for x1, x2, t1, t2, y in simulated_data:
        encoded_rows.append(
            [
                time_token_tokenizer[f"c-{x1}"],
                time_token_tokenizer[f"t-{t2 - t1}"],
                time_token_tokenizer[f"c-{x2}"],
            ]
        )
        targets.append(y)

    with torch.no_grad():
        input_tensor = torch.tensor(encoded_rows, dtype=torch.long).to(device)
        expected = np.asarray(targets)
        _, logits = time_embedding_model(input_tensor)
        probabilities = (
            torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()
        )
        positive_scores = probabilities[:, 1]
        roc_auc = roc_auc_score(expected, positive_scores)
        accuracy = accuracy_score(expected, positive_scores > probabilities[:, 0])
        print(f"ROC AUC: {roc_auc}")
        print(f"Accuracy: {accuracy}")
    return accuracy, roc_auc
97
+
98
+
99
def train_step(
    simulated_data,
    time_token_tokenizer,
    time_embedding_model,
    time_embedding_optimizer,
    batch_size: int = 8,
):
    """Run one optimizer step on a random mini-batch drawn without replacement.

    Args:
        simulated_data: 2-D array whose rows unpack as ``(x1, x2, t1, t2, y)``.
        time_token_tokenizer: token -> id mapping covering every token in the rows.
        time_embedding_model: model returning ``(loss, logits)`` when given labels.
        time_embedding_optimizer: optimizer over the model parameters.
        batch_size: rows per step (default 8). Capped at the number of rows,
            because sampling is without replacement.

    Returns:
        The scalar loss tensor of this step.
    """
    # eval_step leaves the model in eval mode; make sure we train in train mode.
    time_embedding_model.train()
    n_rows = simulated_data.shape[0]
    indices = np.random.choice(n_rows, size=min(batch_size, n_rows), replace=False)
    batched_input_ids = []
    batched_y = []
    for x1, x2, t1, t2, y in simulated_data[indices, :]:
        batched_input_ids.append(
            [
                time_token_tokenizer[f"c-{x1}"],
                time_token_tokenizer[f"t-{t2 - t1}"],
                time_token_tokenizer[f"c-{x2}"],
            ]
        )
        batched_y.append(y)
    batched_input_ids = torch.tensor(batched_input_ids, dtype=torch.long).to(device)
    batched_y = torch.tensor(batched_y, dtype=torch.long).to(device)

    time_embedding_optimizer.zero_grad()
    loss, _ = time_embedding_model(batched_input_ids, batched_y)
    loss.backward()
    time_embedding_optimizer.step()
    return loss


def main(args):
    """Train ``ModelTimeToken`` on simulated data and evaluate it periodically.

    Args:
        args: namespace with ``n_samples``, ``n_steps``, ``eval_frequency`` and
            optionally ``batch_size`` (8 when absent).

    Returns:
        dict of parallel lists ``steps``, ``roc_auc`` and ``accuracy``, one
        entry per evaluation.
    """
    simulated_data = generate_simulation_data(args.n_samples)
    time_token_tokenizer = create_time_token_tokenizer(simulated_data)
    # +1 because token ids start at 1 (id 0 is unused).
    time_embedding_model = ModelTimeToken(len(time_token_tokenizer) + 1).to(device)
    time_embedding_optimizer = optim.Adam(time_embedding_model.parameters(), lr=0.001)
    batch_size = getattr(args, "batch_size", 8)
    steps = []
    roc_aucs = []
    accuracies = []
    for step in range(args.n_steps):
        loss = train_step(
            simulated_data,
            time_token_tokenizer,
            time_embedding_model,
            time_embedding_optimizer,
            batch_size=batch_size,
        )
        print(f"Step {step}: Loss = {loss.item()}")
        # The loop variable must drive this check; a test on args.n_steps is
        # constant across iterations and would fire on every step or never.
        if args.eval_frequency > 0 and step % args.eval_frequency == 0:
            accuracy, roc_auc = eval_step(
                simulated_data, time_token_tokenizer, time_embedding_model
            )
            steps.append(step)
            # Plain floats keep the metrics JSON-serializable.
            roc_aucs.append(float(roc_auc))
            accuracies.append(float(accuracy))
    return {"steps": steps, "roc_auc": roc_aucs, "accuracy": accuracies}


if __name__ == "__main__":
    import argparse
    import json
    from pathlib import Path

    parser = argparse.ArgumentParser("Model with time token simulation")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--n_steps", type=int, default=10000)
    parser.add_argument("--n_samples", type=int, default=1000)
    # Default matches the mini-batch size train_step uses when none is given.
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--eval_frequency", type=int, default=100)
    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    metrics = main(args)
    with open(output_dir / "time_token_metrics.json", "w") as f:
        json.dump(metrics, f)
File without changes
@@ -0,0 +1,467 @@
1
+ import glob
2
+ import os
3
+ import shutil
4
+ import uuid
5
+ from datetime import datetime
6
+ from functools import partial
7
+ from pathlib import Path
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ import torch.distributed as dist
14
+ from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
15
+ from cehrbert.runners.runner_util import generate_prepared_ds_path
16
+ from datasets import concatenate_datasets, load_from_disk
17
+ from torch.utils.data import DataLoader
18
+ from tqdm import tqdm
19
+ from transformers.trainer_utils import is_main_process
20
+ from transformers.utils import is_flash_attn_2_available, logging
21
+
22
+ from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
23
+ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
24
+ CehrGptDataCollator,
25
+ SamplePackingCehrGptDataCollator,
26
+ )
27
+ from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
28
+ from cehrgpt.models.hf_cehrgpt import (
29
+ CEHRGPT2Model,
30
+ extract_features_from_packed_sequence,
31
+ )
32
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
33
+ from cehrgpt.runners.data_utils import prepare_finetune_dataset
34
+ from cehrgpt.runners.gpt_runner_util import parse_runner_args
35
+ from cehrgpt.runners.hf_cehrgpt_pretrain_runner import tokenizer_exists
36
+
37
# Module logger: reuse the "transformers" logger obtained through transformers.utils.logging.
LOG = logging.get_logger(name="transformers")
38
+
39
+
40
def get_torch_dtype(torch_dtype: Optional[str] = None) -> Union[torch.dtype, str]:
    """Resolve a dtype name such as ``"bfloat16"`` to the matching ``torch.dtype``.

    Args:
        torch_dtype: attribute name on the ``torch`` module, or ``None``.

    Returns:
        The named ``torch.dtype``, or ``torch.float32`` when the name is empty,
        unknown, or refers to a ``torch`` attribute that is not a dtype
        (e.g. ``"cuda"`` or ``"auto"``).
    """
    if torch_dtype:
        # hasattr alone would also accept modules/functions such as torch.cuda.
        candidate = getattr(torch, torch_dtype, None)
        if isinstance(candidate, torch.dtype):
            return candidate
    return torch.float32
44
+
45
+
46
def extract_averaged_embeddings_from_packed_sequence(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
    ve_token_indicators: torch.BoolTensor,
) -> torch.Tensor:
    """
    Average VE-token hidden states separately for every sample in a packed row.

    Samples are the maximal runs of non-padding positions (``attention_mask != 0``)
    in the single packed row. Any number of padding positions may separate two
    samples, and leading/trailing padding is ignored, so exactly one output row
    is produced per sample.

    Args:

        hidden_states: (batch_size=1, seq_len, hidden_dim) tensor
        attention_mask: (batch_size=1, seq_len) tensor, where 0 indicates padding
        ve_token_indicators: (batch_size=1, seq_len) bool tensor, True if token is VE token
    Returns:
        (num_samples, hidden_dim) float32 tensor: the mean over VE tokens for each
        sample, in order of appearance. Samples without any VE token get an
        all-zero row.
    """
    mask = attention_mask[0] != 0  # (seq_len,)

    # Step 1: a sample starts where a non-padding token follows padding (or the row start).
    previous = torch.zeros_like(mask)
    previous[1:] = mask[:-1]
    starts = mask & ~previous
    # Counting starts labels every position of the k-th sample with k (1-based); padding -> 0.
    segment_ids = torch.cumsum(starts.to(torch.long), dim=0) * mask.to(torch.long)
    num_segments = int(starts.sum().item())

    # Step 2: only keep tokens that are both non-padding and VE tokens.
    valid = mask & ve_token_indicators[0].to(torch.bool)
    valid_embeddings = hidden_states[0, valid].to(
        torch.float32
    )  # (num_valid_ve_tokens, hidden_dim)
    valid_rows = segment_ids[valid] - 1  # 0-based output row per kept token

    # Step 3: sum per sample, then divide by the number of VE tokens in that sample.
    sample_embeddings = torch.zeros(
        num_segments,
        hidden_states.size(-1),
        dtype=torch.float32,
        device=hidden_states.device,
    )
    counts = torch.zeros(num_segments, dtype=torch.float32, device=hidden_states.device)
    sample_embeddings.index_add_(0, valid_rows, valid_embeddings)
    counts.index_add_(0, valid_rows, torch.ones_like(valid_rows, dtype=counts.dtype))

    # Samples with no VE token keep their zero sum; dividing by 1 avoids 0/0.
    counts = counts.clamp(min=1.0)
    return sample_embeddings / counts.unsqueeze(-1)
91
+
92
+
93
def _load_or_prepare_dataset(
    cehrgpt_args, data_args, model_args, training_args, cehrgpt_tokenizer
):
    """Load the prepared fine-tuning dataset, building and caching it on disk when absent.

    Returns:
        ``(processed_dataset, cehrgpt_tokenizer)``. When ``expand_tokenizer`` is set,
        the tokenizer is the expanded one stored in ``training_args.output_dir``.
    """
    prepared_ds_path = generate_prepared_ds_path(
        data_args, model_args, data_folder=data_args.cohort_folder
    )
    cache_file_collector = CacheFileCollector()
    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
        if cehrgpt_args.expand_tokenizer:
            if tokenizer_exists(training_args.output_dir):
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
                    training_args.output_dir
                )
            else:
                LOG.warning(
                    f"CehrGptTokenizer must exist in {training_args.output_dir} "
                    f"when the dataset has been processed and expand_tokenizer is set to True. "
                    f"Please delete the processed dataset at {prepared_ds_path}."
                )
                processed_dataset = None
                shutil.rmtree(prepared_ds_path)

    if processed_dataset is not None:
        return processed_dataset, cehrgpt_tokenizer

    new_tokenizer_path = os.path.expanduser(training_args.output_dir)
    if is_main_process(training_args.local_rank):
        # Organize them into a single DatasetDict
        final_splits = prepare_finetune_dataset(
            data_args, training_args, cehrgpt_args, cache_file_collector
        )
        if cehrgpt_args.expand_tokenizer:
            if tokenizer_exists(new_tokenizer_path):
                cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)
            else:
                cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
                    cehrgpt_tokenizer=cehrgpt_tokenizer,
                    dataset=final_splits["train"],
                    data_args=data_args,
                    concept_name_mapping={},
                )
                cehrgpt_tokenizer.save_pretrained(new_tokenizer_path)

        # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
        if not data_args.streaming:
            all_columns = final_splits["train"].column_names
            if "visit_concept_ids" in all_columns:
                final_splits = final_splits.remove_columns(["visit_concept_ids"])

        processed_dataset = create_cehrgpt_finetuning_dataset(
            dataset=final_splits,
            cehrgpt_tokenizer=cehrgpt_tokenizer,
            data_args=data_args,
            cache_file_collector=cache_file_collector,
        )
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)
            processed_dataset.cleanup_cache_files()

        # Remove all the cached files if processed_dataset.cleanup_cache_files() did not remove them already
        cache_file_collector.remove_cache_files()

    # After main-process-only operations, synchronize all processes to ensure consistency
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

    # Non-main ranks skipped the tokenizer expansion above; load the copy the main rank saved.
    if cehrgpt_args.expand_tokenizer and not is_main_process(training_args.local_rank):
        if tokenizer_exists(new_tokenizer_path):
            cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(new_tokenizer_path)

    # Load the dataset from disk again so every rank in torch distributed runs reads the same copy
    return load_from_disk(str(prepared_ds_path)), cehrgpt_tokenizer


def _existing_feature_keys(output_dir: str) -> set:
    """Return ``"{subject_id}-{prediction_time_posix}"`` keys of feature rows already on disk."""
    patterns = [
        # Layout written by main(): <output_dir>/features_with_label/<split>_features/*.parquet
        os.path.join(output_dir, "features_with_label", "*", "*.parquet"),
        # Also accept shards stored as <output_dir>/*/features/*.parquet
        os.path.join(output_dir, "*", "features", "*.parquet"),
    ]
    feature_files = sorted({f for pattern in patterns for f in glob.glob(pattern)})
    if not feature_files:
        return set()
    existing_features = pd.concat(
        [
            pd.read_parquet(f, columns=["subject_id", "prediction_time_posix"])
            for f in feature_files
        ],
        ignore_index=True,
    )
    return {
        f"{int(subject_id)}-{int(posix)}"
        for subject_id, posix in zip(
            existing_features["subject_id"], existing_features["prediction_time_posix"]
        )
    }


def _load_demographics(data_args) -> dict:
    """Map ``(person_id, index date)`` to gender/race concept ids from the cohort parquet folders.

    Folders that are not configured (``None``/empty) are skipped.
    """
    LOG.info("Loading demographics as a dictionary")
    data_dirs = [
        data_dir
        for data_dir in (data_args.data_folder, data_args.test_data_folder)
        if data_dir
    ]
    if not data_dirs:
        return {}
    demographics_df = pd.concat(
        [
            pd.read_parquet(
                data_dir,
                columns=[
                    "person_id",
                    "index_date",
                    "gender_concept_id",
                    "race_concept_id",
                ],
            )
            for data_dir in data_dirs
        ]
    )
    demographics_df["index_date"] = demographics_df.index_date.dt.date
    return {
        (person_id, index_date): {
            "gender_concept_id": gender_concept_id,
            "race_concept_id": race_concept_id,
        }
        for person_id, index_date, gender_concept_id, race_concept_id in zip(
            demographics_df["person_id"],
            demographics_df["index_date"],
            demographics_df["gender_concept_id"],
            demographics_df["race_concept_id"],
        )
    }


def _extract_batch_features(
    cehrgpt_output, batch, cehrgpt_args, ve_token_id: int, end_token_id: int
) -> np.ndarray:
    """Turn one batch's last hidden states into a float32 numpy array with one row per sample.

    Assumes ``batch["attention_mask"]`` is 1 for real tokens and 0 for padding.
    """
    last_hidden_state = cehrgpt_output.last_hidden_state
    attention_mask = batch["attention_mask"]

    if cehrgpt_args.sample_packing:
        if cehrgpt_args.average_over_sequence:
            ve_token_indicators: torch.BoolTensor = batch["input_ids"] == ve_token_id
            features = extract_averaged_embeddings_from_packed_sequence(
                last_hidden_state, attention_mask, ve_token_indicators
            )
            return features.detach().float().cpu().numpy()
        return (
            extract_features_from_packed_sequence(last_hidden_state, attention_mask)
            .detach()
            .float()
            .cpu()
            .numpy()
            .squeeze(axis=0)
        )

    if cehrgpt_args.average_over_sequence:
        # Mean over non-padding positions only (a plain mean over dim=1 would
        # divide by the padded length and shrink shorter sequences).
        valid = attention_mask.unsqueeze(dim=-1).to(torch.bool)
        summed = torch.where(valid, last_hidden_state.float(), 0.0).sum(dim=1)
        counts = valid.sum(dim=1).clamp(min=1).to(summed.dtype)
        return (summed / counts).cpu().numpy()

    # Right-most non-padding position per row (independent of the padding side).
    input_ids = batch["input_ids"]
    positions = torch.arange(input_ids.size(1), device=input_ids.device)
    last_index = (positions * attention_mask.to(positions.dtype)).argmax(dim=1)
    rows = torch.arange(input_ids.size(0), device=input_ids.device)
    # When a row ends with [END], use the token just before it.
    ends_with_end_token = (input_ids[rows, last_index] == end_token_id) & (
        last_index > 0
    )
    last_index = last_index - ends_with_end_token.to(last_index.dtype)
    return last_hidden_state[rows, last_index].detach().float().cpu().numpy()


def main():
    """Extract CEHR-GPT patient representations and write them, with labels, to parquet.

    Train and validation splits are written together to ``train_features``; the test
    split goes to ``test_features``. Both live under ``<output_dir>/features_with_label``.
    Rows whose ``(subject, prediction time)`` already appear in existing feature
    shards are skipped, so an interrupted run can be resumed.
    """
    cehrgpt_args, data_args, model_args, training_args = parse_runner_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
        model_args.tokenizer_name_or_path
    )
    cehrgpt_model = (
        CEHRGPT2Model.from_pretrained(
            model_args.model_name_or_path,
            attn_implementation=(
                "flash_attention_2" if is_flash_attn_2_available() else "eager"
            ),
            torch_dtype=get_torch_dtype(model_args.torch_dtype),
        )
        .eval()
        .to(device)
    )
    processed_dataset, cehrgpt_tokenizer = _load_or_prepare_dataset(
        cehrgpt_args, data_args, model_args, training_args, cehrgpt_tokenizer
    )

    existing_keys = _existing_feature_keys(training_args.output_dir)
    if existing_keys:
        processed_dataset = processed_dataset.filter(
            lambda _batch: [
                f"{int(subject)}-{int(time)}" not in existing_keys
                for subject, time in zip(_batch["person_id"], _batch["index_date"])
            ],
            num_proc=data_args.preprocessing_num_workers,
            batch_size=data_args.preprocessing_batch_size,
            batched=True,
        )
        LOG.info(
            "The datasets after filtering (train: %s, validation: %s, test: %s)",
            len(processed_dataset["train"]),
            len(processed_dataset["validation"]),
            len(processed_dataset["test"]),
        )

    LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
    LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
    if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
        cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
    if (
        cehrgpt_model.config.max_position_embeddings
        < model_args.max_position_embeddings
    ):
        LOG.info(
            f"Increase model.config.max_position_embeddings to {model_args.max_position_embeddings}"
        )
        cehrgpt_model.config.max_position_embeddings = (
            model_args.max_position_embeddings
        )
        cehrgpt_model.resize_position_embeddings(model_args.max_position_embeddings)

    train_set = concatenate_datasets(
        [processed_dataset["train"], processed_dataset["validation"]]
    )

    if cehrgpt_args.sample_packing:
        # A batch sampler requires the DataLoader batch_size to stay at 1.
        per_device_eval_batch_size = 1
        data_collator_fn = partial(
            SamplePackingCehrGptDataCollator,
            cehrgpt_args.max_tokens_per_batch,
            cehrgpt_model.config.max_position_embeddings,
        )
        train_batch_sampler = SamplePackingBatchSampler(
            lengths=train_set["num_of_concepts"],
            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
            max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
            drop_last=training_args.dataloader_drop_last,
            seed=training_args.seed,
        )
        test_batch_sampler = SamplePackingBatchSampler(
            lengths=processed_dataset["test"]["num_of_concepts"],
            max_tokens_per_batch=cehrgpt_args.max_tokens_per_batch,
            max_position_embeddings=cehrgpt_model.config.max_position_embeddings,
            drop_last=training_args.dataloader_drop_last,
            seed=training_args.seed,
        )
    else:
        data_collator_fn = CehrGptDataCollator
        train_batch_sampler = None
        test_batch_sampler = None
        per_device_eval_batch_size = training_args.per_device_eval_batch_size

    # We suppress the additional learning objectives in fine-tuning
    data_collator = data_collator_fn(
        tokenizer=cehrgpt_tokenizer,
        max_length=(
            cehrgpt_args.max_tokens_per_batch
            if cehrgpt_args.sample_packing
            else model_args.max_position_embeddings
        ),
        include_values=cehrgpt_model.config.include_values,
        pretraining=False,
        include_ttv_prediction=False,
        use_sub_time_tokenization=False,
        include_demographics=cehrgpt_args.include_demographics,
    )

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=per_device_eval_batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
        batch_sampler=train_batch_sampler,
    )
    test_dataloader = DataLoader(
        dataset=processed_dataset["test"],
        batch_size=per_device_eval_batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
        batch_sampler=test_batch_sampler,
    )

    demographics_dict = _load_demographics(data_args)

    ve_token_id = cehrgpt_tokenizer._convert_token_to_id("[VE]")
    for split, data_loader in [("train", train_loader), ("test", test_dataloader)]:
        feature_output_folder = (
            Path(training_args.output_dir) / "features_with_label" / f"{split}_features"
        )
        feature_output_folder.mkdir(parents=True, exist_ok=True)
        LOG.info("Generating features for %s set at %s", split, feature_output_folder)

        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Generating features"):
                # Metadata columns are popped so only model inputs reach the forward pass.
                prediction_time_ages = np.atleast_1d(
                    batch.pop("age_at_index").numpy().astype(float).squeeze()
                )
                person_ids = np.atleast_1d(
                    batch.pop("person_id").numpy().astype(int).squeeze()
                )
                prediction_time_posix = np.atleast_1d(
                    batch.pop("index_date").numpy().squeeze()
                )
                prediction_time = list(
                    map(datetime.fromtimestamp, prediction_time_posix)
                )
                labels = np.atleast_1d(
                    batch.pop("classifier_label")
                    .float()
                    .cpu()
                    .numpy()
                    .astype(bool)
                    .squeeze()
                )

                batch = {k: v.to(device) for k, v in batch.items()}
                cehrgpt_output = cehrgpt_model(
                    **batch, output_attentions=False, output_hidden_states=False
                )
                features = _extract_batch_features(
                    cehrgpt_output,
                    batch,
                    cehrgpt_args,
                    ve_token_id,
                    cehrgpt_tokenizer.end_token_id,
                )
                if len(features) != len(person_ids):
                    raise ValueError(
                        f"Extracted {len(features)} feature rows for "
                        f"{len(person_ids)} samples in the {split} split"
                    )

                gender_concept_ids = []
                race_concept_ids = []
                for person_id, index_date in zip(person_ids, prediction_time):
                    demographics = demographics_dict.get((person_id, index_date.date()))
                    if demographics is None:
                        gender_concept_ids.append(0)
                        race_concept_ids.append(0)
                    else:
                        gender_concept_ids.append(demographics["gender_concept_id"])
                        race_concept_ids.append(demographics["race_concept_id"])

                features_pd = pd.DataFrame(
                    {
                        "subject_id": person_ids,
                        "prediction_time": prediction_time,
                        "prediction_time_posix": prediction_time_posix,
                        "boolean_value": labels,
                        "age_at_index": prediction_time_ages,
                    }
                )
                # Each row holds its own feature array
                features_pd["features"] = list(features)
                features_pd["race_concept_id"] = race_concept_ids
                features_pd["gender_concept_id"] = gender_concept_ids
                features_pd.to_parquet(
                    feature_output_folder / f"{uuid.uuid4()}.parquet"
                )


if __name__ == "__main__":
    main()