cehrgpt 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cehrgpt/analysis/htn_treatment_pathway.py +546 -0
  2. cehrgpt/analysis/treatment_pathway/__init__.py +0 -0
  3. cehrgpt/analysis/treatment_pathway/depression_treatment_pathway.py +94 -0
  4. cehrgpt/analysis/treatment_pathway/diabetes_treatment_pathway.py +94 -0
  5. cehrgpt/analysis/treatment_pathway/htn_treatment_pathway.py +94 -0
  6. cehrgpt/analysis/treatment_pathway/treatment_pathway.py +631 -0
  7. cehrgpt/data/cehrgpt_data_processor.py +549 -0
  8. cehrgpt/data/hf_cehrgpt_dataset.py +4 -0
  9. cehrgpt/data/hf_cehrgpt_dataset_collator.py +286 -629
  10. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +60 -14
  11. cehrgpt/generation/cehrgpt_conditional_generation.py +316 -0
  12. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +35 -15
  13. cehrgpt/generation/omop_converter_batch.py +11 -4
  14. cehrgpt/gpt_utils.py +73 -3
  15. cehrgpt/models/activations.py +27 -0
  16. cehrgpt/models/config.py +6 -2
  17. cehrgpt/models/gpt2.py +560 -0
  18. cehrgpt/models/hf_cehrgpt.py +193 -459
  19. cehrgpt/models/tokenization_hf_cehrgpt.py +380 -50
  20. cehrgpt/omop/ontology.py +154 -0
  21. cehrgpt/runners/data_utils.py +17 -6
  22. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +33 -79
  23. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +48 -44
  24. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +58 -34
  25. cehrgpt/runners/hyperparameter_search_util.py +180 -69
  26. cehrgpt/runners/sample_packing_trainer.py +11 -2
  27. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +27 -31
  28. cehrgpt-0.1.3.dist-info/METADATA +238 -0
  29. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/RECORD +33 -22
  30. cehrgpt-0.1.1.dist-info/METADATA +0 -115
  31. /cehrgpt/tools/{merge_synthetic_real_dataasets.py → merge_synthetic_real_datasets.py} +0 -0
  32. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/WHEEL +0 -0
  33. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/licenses/LICENSE +0 -0
  34. {cehrgpt-0.1.1.dist-info → cehrgpt-0.1.3.dist-info}/top_level.txt +0 -0
cehrgpt/runners/hf_gpt_runner_argument_dataclass.py
@@ -1,5 +1,7 @@
  import dataclasses
- from typing import List, Optional
+ from typing import List, Literal, Optional
+
+ from cehrgpt.models.gpt2 import ACT2FN


  @dataclasses.dataclass
@@ -12,6 +14,14 @@ class CehrGPTArguments:
  "help": "The path to the tokenized dataset created for the full population"
  },
  )
+ activation_function: Literal[tuple(ACT2FN.keys())] = dataclasses.field(
+ default="gelu_new",
+ metadata={"help": "The activation function to use"},
+ )
+ decoder_mlp: Literal["GPT2MLP", "LlamaMLP"] = dataclasses.field(
+ default="GPT2MLP",
+ metadata={"help": "The decoder MLP architecture"},
+ )
  include_inpatient_hour_token: Optional[bool] = dataclasses.field(
  default=True,
  metadata={"help": "Include inpatient hour token"},
@@ -54,6 +64,14 @@ class CehrGPTArguments:
  default=128,
  metadata={"help": "The number of examples from the training set."},
  )
+ hyperparameter_tuning: Optional[bool] = dataclasses.field(
+ default=False,
+ metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
+ )
+ hyperparameter_tuning_is_grid: Optional[bool] = dataclasses.field(
+ default=True,
+ metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
+ )
  hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
  default=0.1,
  metadata={
@@ -66,10 +84,6 @@ class CehrGPTArguments:
  "help": "The number of trails will be use for hyperparameter tuning."
  },
  )
- hyperparameter_tuning: Optional[bool] = dataclasses.field(
- default=False,
- metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
- )
  hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
  default_factory=lambda: [4, 8, 16],
  metadata={"help": "Hyperparameter search batch sizes"},
@@ -78,29 +92,13 @@ class CehrGPTArguments:
  default_factory=lambda: [10],
  metadata={"help": "Hyperparameter search num_train_epochs"},
  )
- lr_low: Optional[float] = dataclasses.field(
- default=1e-5,
- metadata={
- "help": "The lower bound of the learning rate range for hyperparameter tuning."
- },
+ hyperparameter_learning_rates: Optional[List[int]] = dataclasses.field(
+ default_factory=lambda: [1e-5],
+ metadata={"help": "Hyperparameter search learning rates"},
  )
- lr_high: Optional[float] = dataclasses.field(
- default=5e-5,
- metadata={
- "help": "The upper bound of the learning rate range for hyperparameter tuning."
- },
- )
- weight_decays_low: Optional[float] = dataclasses.field(
- default=1e-3,
- metadata={
- "help": "The lower bound of the weight decays range for hyperparameter tuning."
- },
- )
- weight_decays_high: Optional[float] = dataclasses.field(
- default=1e-2,
- metadata={
- "help": "The upper bound of the weight decays range for hyperparameter tuning."
- },
+ hyperparameter_weight_decays: Optional[List[int]] = dataclasses.field(
+ default_factory=lambda: [1e-2],
+ metadata={"help": "Hyperparameter search learning rates"},
  )
  causal_sfm: Optional[bool] = dataclasses.field(
  default=False,
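The scalar `lr_low`/`lr_high` and `weight_decays_low`/`weight_decays_high` bounds are folded into the list-valued `hyperparameter_learning_rates` and `hyperparameter_weight_decays` fields, whose interpretation depends on `hyperparameter_tuning_is_grid`. A hedged sketch of the two configurations, assuming the remaining dataclass fields keep their defaults (the values are examples, not recommendations):

```python
# Illustrative only: the same two list fields drive both search modes.
from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments

# Grid search: every listed value is a discrete candidate.
grid_args = CehrGPTArguments(
    hyperparameter_tuning=True,
    hyperparameter_tuning_is_grid=True,
    hyperparameter_learning_rates=[1e-5, 2e-5, 5e-5],
    hyperparameter_weight_decays=[1e-3, 1e-2],
)

# Bayesian (TPE) search: exactly two values act as log-scale lower/upper bounds.
tpe_args = CehrGPTArguments(
    hyperparameter_tuning=True,
    hyperparameter_tuning_is_grid=False,
    hyperparameter_learning_rates=[1e-5, 5e-5],
    hyperparameter_weight_decays=[1e-3, 1e-2],
)
```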
@@ -168,6 +166,16 @@ class CehrGPTArguments:
  "help": "A threshold to denote how much the specified metric must improve to satisfy early stopping conditions."
  },
  )
+ inner_dim: Optional[int] = dataclasses.field(
+ default=None,
+ metadata={"help": "The dimensionality of the hidden layer"},
+ )
+ apply_rotary: Optional[bool] = dataclasses.field(
+ default=False,
+ metadata={
+ "help": "A flag to indicate whether we want to use rotary encoder layers"
+ },
+ )
  sample_packing: Optional[bool] = dataclasses.field(
  default=False,
  metadata={
@@ -177,12 +185,6 @@ class CehrGPTArguments:
  max_tokens_per_batch: int = dataclasses.field(
  default=16384, metadata={"help": "Maximum number of tokens in each batch"}
  )
- add_end_token_in_sample_packing: Optional[bool] = dataclasses.field(
- default=False,
- metadata={
- "help": "A flag to indicate whether we want to add end token in sample packing"
- },
- )
  include_motor_time_to_event: Optional[bool] = dataclasses.field(
  default=False,
  metadata={
@@ -203,7 +205,17 @@ class CehrGPTArguments:
  "help": "The number of times each motor_num_time_pieces piece has to be"
  },
  )
- concept_dir: Optional[str] = dataclasses.field(
+ motor_use_ontology: Optional[bool] = dataclasses.field(
+ default=False,
+ metadata={
+ "help": "A flag to indicate whether we want to use motor_use_ontology"
+ },
+ )
+ motor_sampling_probability: Optional[float] = dataclasses.field(
+ default=0.0,
+ metadata={"help": "A flag to indicate whether we want to use sample packing"},
+ )
+ vocab_dir: Optional[str] = dataclasses.field(
  default=None,
  metadata={"help": "The directory where the concept data is stored."},
  )
@@ -229,3 +241,15 @@ class CehrGPTArguments:
  "help": "The probability of negative samples will be included in the training data"
  },
  )
+ num_of_trajectories_per_sample: Optional[int] = dataclasses.field(
+ default=1,
+ metadata={"help": "The number of trajectories per sample"},
+ )
+ generation_input_length: Optional[int] = dataclasses.field(
+ default=1024,
+ metadata={"help": "The length of the input sequence"},
+ )
+ generation_max_new_tokens: Optional[int] = dataclasses.field(
+ default=1024,
+ metadata={"help": "The maximum number of tokens in the generation sequence"},
+ )
cehrgpt/runners/hyperparameter_search_util.py
@@ -1,5 +1,5 @@
  from functools import partial
- from typing import Callable, Tuple
+ from typing import Callable, List, Optional, Tuple, Union

  import optuna
  from cehrbert.runners.hf_runner_argument_dataclass import ModelArguments
@@ -64,28 +64,99 @@ class OptunaMetricCallback(TrainerCallback):
  metrics.update({"optuna_best_metric": metrics["eval_loss"]})


- # Define the hyperparameter search space with parameters
- def hp_space(
- trial: optuna.Trial,
- lr_range: Tuple[float, float] = (1e-5, 5e-5),
- batch_sizes=None,
- weight_decays: Tuple[float, float] = (1e-4, 1e-2),
- num_train_epochs: Tuple[float, ...] = 10,
- ):
- if batch_sizes is None:
- batch_sizes = [4, 8]
+ def get_suggestion(
+ trial,
+ hyperparameter_name: str,
+ hyperparameters: List[Union[float, int]],
+ is_grid: bool = False,
+ ) -> Union[float, int]:
+ """
+ Get hyperparameter suggestion based on search mode.
+
+ Args:
+ trial: Optuna trial object
+ hyperparameter_name: Name of the hyperparameter
+ hyperparameters: List of hyperparameter values
+ is_grid: Whether to use grid search mode
+
+ Returns:
+ Suggested hyperparameter value
+
+ Raises:
+ RuntimeError: If Bayesian mode is used with incorrect number of bounds
+ """
+ if is_grid:
+ return trial.suggest_categorical(hyperparameter_name, hyperparameters)
+
+ # For Bayesian optimization, we need exactly 2 values (lower and upper bounds)
+ if len(hyperparameters) != 2:
+ raise RuntimeError(
+ f"{hyperparameter_name} must contain exactly two values (lower and upper bound) "
+ f"for Bayesian Optimization, but {len(hyperparameters)} values were provided: {hyperparameters}"
+ )
+
+ # Ensure bounds are sorted
+ lower, upper = sorted(hyperparameters)
+ return trial.suggest_float(hyperparameter_name, lower, upper, log=True)
+
+
+ def hp_space(trial: optuna.Trial, cehrgpt_args: CehrGPTArguments):
+ """
+ Define the hyperparameter search space.
+
+ Args:
+ trial: Optuna trial object
+ cehrgpt_args: CehrGPTArguments
+ Returns:
+ Dictionary of hyperparameter suggestions
+ """
+
+ is_grid = cehrgpt_args.hyperparameter_tuning_is_grid
+ learning_rates = cehrgpt_args.hyperparameter_learning_rates
+ weight_decays = cehrgpt_args.hyperparameter_weight_decays
+ batch_sizes = cehrgpt_args.hyperparameter_batch_sizes
+ num_train_epochs = cehrgpt_args.hyperparameter_num_train_epochs
+
  return {
- "learning_rate": trial.suggest_float("learning_rate", *lr_range, log=True),
+ "learning_rate": get_suggestion(
+ trial, "learning_rate", learning_rates, is_grid
+ ),
  "per_device_train_batch_size": trial.suggest_categorical(
  "per_device_train_batch_size", batch_sizes
  ),
- "weight_decay": trial.suggest_float("weight_decay", *weight_decays, log=True),
+ "weight_decay": get_suggestion(trial, "weight_decay", weight_decays, is_grid),
  "num_train_epochs": trial.suggest_categorical(
  "num_train_epochs", num_train_epochs
  ),
  }


+ def create_grid_search_space(cehrgpt_args: CehrGPTArguments):
+ """
+ Create the search space dictionary for GridSampler.
+
+ Args:
+ cehrgpt_args: CehrGPTArguments
+
+ Returns:
+ Dictionary defining the grid search space
+ """
+ return {
+ "learning_rate": cehrgpt_args.hyperparameter_learning_rates,
+ "weight_decay": cehrgpt_args.hyperparameter_weight_decays,
+ "per_device_train_batch_size": cehrgpt_args.hyperparameter_batch_sizes,
+ "num_train_epochs": cehrgpt_args.hyperparameter_num_train_epochs,
+ }
+
+
+ def calculate_total_combinations(search_space: dict) -> int:
+ """Calculate total number of combinations in grid search."""
+ total = 1
+ for values in search_space.values():
+ total *= len(values)
+ return total
+
+
  def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
  """
  Samples a subset of the given dataset based on a specified percentage.
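`get_suggestion` is what lets the same list of values serve both modes: in grid mode the list is handed to `suggest_categorical`, and in Bayesian mode exactly two values become sorted, log-scale bounds for `suggest_float`. A minimal sketch of that behaviour against an `optuna.trial.FixedTrial` (the parameter values are illustrative):

```python
# Minimal check of the two suggestion modes; FixedTrial avoids running a study.
import optuna

from cehrgpt.runners.hyperparameter_search_util import (
    calculate_total_combinations,
    get_suggestion,
)

trial = optuna.trial.FixedTrial({"learning_rate": 2e-5, "weight_decay": 5e-3})

# Grid mode: the list is a discrete set of candidates.
lr = get_suggestion(trial, "learning_rate", [1e-5, 2e-5, 5e-5], is_grid=True)

# Bayesian mode: exactly two values become sorted, log-scale bounds.
wd = get_suggestion(trial, "weight_decay", [1e-2, 1e-3], is_grid=False)

assert (lr, wd) == (2e-5, 5e-3)

# The grid bookkeeping used later by perform_hyperparameter_search:
space = {"learning_rate": [1e-5, 2e-5, 5e-5], "weight_decay": [1e-3, 1e-2]}
assert calculate_total_combinations(space) == 6
```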
@@ -113,7 +184,7 @@ def sample_dataset(data: Dataset, percentage: float, seed: int) -> Dataset:
  returns the "test" portion, which is the specified percentage of the dataset.
  - Ensure that `percentage` is between 0 and 1 to avoid errors.
  """
- if percentage == 1.0:
+ if percentage >= 1.0:
  return data

  return data.train_test_split(
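A small sketch of `sample_dataset`'s contract after this change: any percentage at or above 1.0 now short-circuits and returns the dataset unchanged, while a fractional percentage still yields the sampled "test" split described in the docstring (toy data below):

```python
# Toy data; only the short-circuit behaviour is asserted exactly.
from datasets import Dataset

from cehrgpt.runners.hyperparameter_search_util import sample_dataset

ds = Dataset.from_dict({"x": list(range(100))})

assert len(sample_dataset(ds, percentage=1.5, seed=42)) == 100  # returned as-is
print(len(sample_dataset(ds, percentage=0.1, seed=42)))         # ~10 sampled rows
```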
@@ -130,14 +201,13 @@ def perform_hyperparameter_search(
  training_args: TrainingArguments,
  model_args: ModelArguments,
  cehrgpt_args: CehrGPTArguments,
- ) -> TrainingArguments:
+ ) -> Tuple[TrainingArguments, Optional[str]]:
  """
  Perform hyperparameter tuning for the CehrGPT model using Optuna with the Hugging Face Trainer.

- This function initializes a Trainer with sampled training and validation sets, and performs
- a hyperparameter search using Optuna. The search tunes learning rate, batch size, and weight decay
- to optimize model performance based on a specified objective metric (e.g., validation loss).
- After the search, it updates the provided `TrainingArguments` with the best hyperparameters found.
+ This function supports two modes:
+ 1. Bayesian Optimization (TPE): Intelligently explores hyperparameter space using bounds
+ 2. Grid Search: Exhaustively tests all combinations of discrete values

  Args:
  trainer_class: A Trainer or its subclass
@@ -147,15 +217,15 @@ def perform_hyperparameter_search(
  training_args (TrainingArguments): Configuration for training parameters (e.g., epochs, evaluation strategy).
  model_args (ModelArguments): Model configuration arguments, including early stopping parameters.
  cehrgpt_args (CehrGPTArguments): Additional arguments specific to CehrGPT, including hyperparameter
- tuning options such as learning rate range, batch sizes, and tuning percentage.
+ tuning options and search mode configuration.

  Returns:
- TrainingArguments: Updated `TrainingArguments` instance containing the best hyperparameters found
- from the search.
+ Tuple[TrainingArguments, Optional[str]]: Updated TrainingArguments with best hyperparameters
+ and optional run_id of the best trial.

  Example:
  ```
- best_training_args = perform_hyperparameter_search(
+ best_training_args, run_id = perform_hyperparameter_search(
  trainer_class=Trainer,
  model_init=my_model_init,
  dataset=my_dataset_dict,
@@ -176,50 +246,91 @@ def perform_hyperparameter_search(
  Logging:
  Logs the best hyperparameters found at the end of the search.
  """
- if cehrgpt_args.hyperparameter_tuning:
- sampled_train = sample_dataset(
- dataset["train"],
- cehrgpt_args.hyperparameter_tuning_percentage,
- training_args.seed,
- )
- sampled_val = sample_dataset(
- dataset["validation"],
- cehrgpt_args.hyperparameter_tuning_percentage,
- training_args.seed,
- )
- hyperparam_trainer = trainer_class(
- model_init=model_init,
- data_collator=data_collator,
- train_dataset=sampled_train,
- eval_dataset=sampled_val,
- callbacks=[
- EarlyStoppingCallback(model_args.early_stopping_patience),
- OptunaMetricCallback(),
- ],
- args=training_args,
- )
- # Perform hyperparameter search
- best_trial = hyperparam_trainer.hyperparameter_search(
- direction="minimize",
- hp_space=partial(
- hp_space,
- lr_range=(cehrgpt_args.lr_low, cehrgpt_args.lr_high),
- weight_decays=(
- cehrgpt_args.weight_decays_low,
- cehrgpt_args.weight_decays_high,
- ),
- batch_sizes=cehrgpt_args.hyperparameter_batch_sizes,
- num_train_epochs=cehrgpt_args.hyperparameter_num_train_epochs,
- ),
- backend="optuna",
- n_trials=cehrgpt_args.n_trials,
- compute_objective=lambda m: m["optuna_best_metric"],
- # Ensure reproducibility
- sampler=optuna.samplers.TPESampler(seed=training_args.seed),
- )
- LOG.info("Best hyperparameters: %s", best_trial.hyperparameters)
- # Update training arguments with best hyperparameters and set epochs based on adjusted effective epochs
- for k, v in best_trial.hyperparameters.items():
- setattr(training_args, k, v)
+ if not cehrgpt_args.hyperparameter_tuning:
+ return training_args, None
+
+ # Prepare hyperparameters based on mode
+ if (
+ cehrgpt_args.hyperparameter_tuning_is_grid
+ and cehrgpt_args.hyperparameter_tuning_is_grid
+ ):
+ search_space = create_grid_search_space(cehrgpt_args)
+ total_combinations = calculate_total_combinations(search_space)
+
+ LOG.info(f"Grid search mode: Testing {total_combinations} combinations")
+ LOG.info(f"Search space: {search_space}")
+
+ # Adjust n_trials for grid search if not set appropriately
+ if cehrgpt_args.n_trials < total_combinations:
+ LOG.warning(
+ f"n_trials ({cehrgpt_args.n_trials}) is less than total combinations ({total_combinations}). "
+ f"Setting n_trials to {total_combinations} to test all combinations."
+ )
+ cehrgpt_args.n_trials = total_combinations
+
+ # Configure sampler based on search mode
+ sampler = optuna.samplers.GridSampler(search_space, seed=training_args.seed)
+ else:
+ LOG.info("Bayesian optimization mode (TPE)")
+ LOG.info(f"Learning rate bounds: {cehrgpt_args.hyperparameter_learning_rates}")
+ LOG.info(f"Weight decay bounds: {cehrgpt_args.hyperparameter_weight_decays}")
+ LOG.info(f"Batch sizes: {cehrgpt_args.hyperparameter_batch_sizes}")
+ LOG.info(f"Epochs: {cehrgpt_args.hyperparameter_num_train_epochs}")
+ # Configure the TPE sampler
+ sampler = optuna.samplers.TPESampler(seed=training_args.seed)
+
+ # Prepare datasets
+ save_total_limit_original = training_args.save_total_limit
+ training_args.save_total_limit = 1
+
+ sampled_train = sample_dataset(
+ dataset["train"],
+ cehrgpt_args.hyperparameter_tuning_percentage,
+ training_args.seed,
+ )
+ sampled_val = sample_dataset(
+ dataset["validation"],
+ cehrgpt_args.hyperparameter_tuning_percentage,
+ training_args.seed,
+ )
+ # Create trainer
+ hyperparam_trainer = trainer_class(
+ model_init=model_init,
+ data_collator=data_collator,
+ train_dataset=sampled_train,
+ eval_dataset=sampled_val,
+ callbacks=[
+ EarlyStoppingCallback(model_args.early_stopping_patience),
+ OptunaMetricCallback(),
+ ],
+ args=training_args,
+ )
+
+ best_trial = hyperparam_trainer.hyperparameter_search(
+ direction="minimize",
+ hp_space=partial(
+ hp_space,
+ cehrgpt_args=cehrgpt_args,
+ ),
+ backend="optuna",
+ n_trials=cehrgpt_args.n_trials,
+ compute_objective=lambda m: m["optuna_best_metric"],
+ sampler=sampler,
+ )
+
+ # Log results
+ LOG.info("=" * 50)
+ LOG.info("HYPERPARAMETER SEARCH COMPLETED")
+ LOG.info("=" * 50)
+ LOG.info(f"Best hyperparameters: {best_trial.hyperparameters}")
+ LOG.info(f"Best metric (eval_loss): {best_trial.objective}")
+ LOG.info(f"Best run_id: {best_trial.run_id}")
+ LOG.info("=" * 50)
+
+ # Restore original settings and update with best hyperparameters
+ training_args.save_total_limit = save_total_limit_original
+ for k, v in best_trial.hyperparameters.items():
+ setattr(training_args, k, v)
+ LOG.info(f"Updated training_args.{k} = {v}")

- return training_args
+ return training_args, best_trial.run_id
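The grid branch above has to push `n_trials` up to the number of grid points, otherwise Optuna's `GridSampler` would leave combinations unvisited. A self-contained sketch of that interplay, with a dummy objective standing in for the Trainer (all values illustrative):

```python
# Standalone illustration of GridSampler coverage; no Trainer involved.
import optuna

search_space = {
    "learning_rate": [1e-5, 5e-5],
    "weight_decay": [1e-2],
    "per_device_train_batch_size": [4, 8],
    "num_train_epochs": [10],
}
total_combinations = 1
for values in search_space.values():
    total_combinations *= len(values)  # 2 * 1 * 2 * 1 = 4

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_categorical("learning_rate", search_space["learning_rate"])
    wd = trial.suggest_categorical("weight_decay", search_space["weight_decay"])
    bs = trial.suggest_categorical(
        "per_device_train_batch_size", search_space["per_device_train_batch_size"]
    )
    epochs = trial.suggest_categorical("num_train_epochs", search_space["num_train_epochs"])
    return lr * wd * bs * epochs  # stand-in for the real eval_loss objective

study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.GridSampler(search_space),
)
# n_trials must cover every combination, mirroring the n_trials adjustment above.
study.optimize(objective, n_trials=total_combinations)
assert len(study.trials) == total_combinations
```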
cehrgpt/runners/sample_packing_trainer.py
@@ -1,9 +1,10 @@
  from typing import Optional, Union

+ import torch
  from datasets import Dataset
  from torch.utils.data import DataLoader
  from transformers import Trainer
- from transformers.trainer_utils import has_length
+ from transformers.trainer_utils import has_length, seed_worker
  from transformers.utils import import_utils, logging

  from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
@@ -62,7 +63,10 @@ class SamplePackingTrainer(Trainer):
  if "num_of_concepts" in train_dataset.column_names:
  lengths = train_dataset["num_of_concepts"]
  else:
- lengths = [len(sample["input_ids"]) for sample in train_dataset]
+ lengths = [
+ len(sample["input_ids"])
+ for sample in train_dataset.select_columns("input_ids")
+ ]

  LOG.info("Finished computing lengths for the train dataset")
  else:
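The length computation now iterates over `train_dataset.select_columns("input_ids")` rather than the full dataset, presumably so only the one needed column is decoded per row. A toy sketch of the same pattern with `datasets` (made-up data):

```python
# Toy data; select_columns narrows the dataset before iterating row by row.
from datasets import Dataset

ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5], [6]],
        "value_indicators": [[0, 0, 1], [1, 0], [0]],
    }
)

lengths = [len(sample["input_ids"]) for sample in ds.select_columns("input_ids")]
assert lengths == [3, 2, 1]
```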
@@ -102,6 +106,11 @@ class SamplePackingTrainer(Trainer):
  "persistent_workers": self.args.dataloader_persistent_workers,
  "batch_sampler": batch_sampler,
  }
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
+ dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+
  return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

  def get_eval_dataloader(
cehrgpt/tools/linear_prob/compute_cehrgpt_features.py
@@ -1,8 +1,8 @@
+ import datetime
  import glob
  import os
  import shutil
  import uuid
- from datetime import datetime
  from functools import partial
  from pathlib import Path
  from typing import Optional, Union
@@ -15,6 +15,9 @@ import torch.distributed as dist
  from cehrbert.data_generators.hf_data_generator.meds_utils import CacheFileCollector
  from cehrbert.runners.runner_util import generate_prepared_ds_path
  from datasets import concatenate_datasets, load_from_disk
+ from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
+ batched_powerSGD_hook,
+ )
  from torch.utils.data import DataLoader
  from tqdm import tqdm
  from transformers.trainer_utils import is_main_process
@@ -25,7 +28,6 @@ from cehrgpt.data.hf_cehrgpt_dataset_collator import (
  CehrGptDataCollator,
  SamplePackingCehrGptDataCollator,
  )
- from cehrgpt.data.hf_cehrgpt_dataset_mapping import ExtractTokenizedSequenceDataMapping
  from cehrgpt.data.sample_packing_sampler import SamplePackingBatchSampler
  from cehrgpt.models.hf_cehrgpt import (
  CEHRGPT2Model,
@@ -159,24 +161,7 @@ def main():
  final_splits = prepare_finetune_dataset(
  data_args, training_args, cehrgpt_args, cache_file_collector
  )
- if cehrgpt_args.expand_tokenizer:
- new_tokenizer_path = os.path.expanduser(training_args.output_dir)
- if tokenizer_exists(new_tokenizer_path):
- cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(
- new_tokenizer_path
- )
- else:
- cehrgpt_tokenizer = CehrGptTokenizer.expand_trained_tokenizer(
- cehrgpt_tokenizer=cehrgpt_tokenizer,
- dataset=final_splits["train"],
- data_args=data_args,
- concept_name_mapping={},
- )
- cehrgpt_tokenizer.save_pretrained(
- os.path.expanduser(training_args.output_dir)
- )
-
- # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+ # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
  if not data_args.streaming:
  all_columns = final_splits["train"].column_names
  if "visit_concept_ids" in all_columns:
@@ -238,10 +223,6 @@ def main():
  len(processed_dataset["test"]),
  )

- LOG.info(f"cehrgpt_model.config.vocab_size: {cehrgpt_model.config.vocab_size}")
- LOG.info(f"cehrgpt_tokenizer.vocab_size: {cehrgpt_tokenizer.vocab_size}")
- if cehrgpt_model.config.vocab_size < cehrgpt_tokenizer.vocab_size:
- cehrgpt_model.resize_token_embeddings(cehrgpt_tokenizer.vocab_size)
  if (
  cehrgpt_model.config.max_position_embeddings
  < model_args.max_position_embeddings
@@ -264,7 +245,6 @@ def main():
  SamplePackingCehrGptDataCollator,
  cehrgpt_args.max_tokens_per_batch,
  cehrgpt_model.config.max_position_embeddings,
- add_end_token_in_sample_packing=cehrgpt_args.add_end_token_in_sample_packing,
  )
  train_batch_sampler = SamplePackingBatchSampler(
  lengths=train_set["num_of_concepts"],
@@ -339,10 +319,12 @@ def main():
  for data_dir in [data_args.data_folder, data_args.test_data_folder]
  ]
  )
- # This is a pre-caution in case the index_date is not a datetime type
- demographics_df["index_date"] = pd.to_datetime(
- demographics_df["index_date"]
- ).dt.date
+
+ demographics_df["index_date"] = (
+ demographics_df["index_date"].dt.tz_localize("UTC")
+ - datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
+ ).dt.total_seconds()
+
  demographics_dict = {
  (row["person_id"], row["index_date"]): {
  "gender_concept_id": row["gender_concept_id"],
@@ -353,7 +335,7 @@ def main():

  data_loaders = [("train", train_loader), ("test", test_dataloader)]

- ve_token_id = cehrgpt_tokenizer._convert_token_to_id("[VE]")
+ ve_token_id = cehrgpt_tokenizer.ve_token_id
  for split, data_loader in data_loaders:
  # Ensure prediction folder exists
  feature_output_folder = (
@@ -379,9 +361,16 @@ def main():
  prediction_time_posix = batch.pop("index_date").numpy().squeeze()
  if prediction_time_posix.ndim == 0:
  prediction_time_posix = np.asarray([prediction_time_posix])
+
  prediction_time = list(
- map(datetime.fromtimestamp, prediction_time_posix)
+ map(
+ lambda posix_time: datetime.datetime.utcfromtimestamp(
+ posix_time
+ ).replace(tzinfo=None),
+ prediction_time_posix,
+ )
  )
+
  labels = (
  batch.pop("classifier_label")
  .float()
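The other half of the round trip: the POSIX seconds stored in `index_date` are turned back into naive UTC datetimes for the output. A small sketch using the timestamp from the example above; `utcfromtimestamp` already returns a tz-naive value, so the `.replace(tzinfo=None)` in the diff is belt-and-braces:

```python
# Rebuild a naive UTC datetime from POSIX seconds, as in the hunk above.
import datetime

prediction_time_posix = [1583020800.0]
prediction_time = list(
    map(
        lambda posix_time: datetime.datetime.utcfromtimestamp(posix_time).replace(
            tzinfo=None
        ),
        prediction_time_posix,
    )
)
assert prediction_time == [datetime.datetime(2020, 3, 1)]
```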
@@ -393,6 +382,13 @@ def main():
  if labels.ndim == 0:
  labels = np.asarray([labels])

+ # Right now the model does not support this column, we need to pop it
+ if "epoch_times" in batch:
+ batch.pop("epoch_times")
+
+ if "ages" in batch:
+ batch.pop("ages")
+
  batch = {k: v.to(device) for k, v in batch.items()}
  # Forward pass
  cehrgpt_output = cehrgpt_model(