cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,373 @@
+ import datetime
+ import os
+ import pickle
+ from collections import Counter, defaultdict
+ from functools import partial
+ from typing import Any, Dict, List
+
+ import numpy as np
+ import torch
+ from cehrbert.models.hf_models.tokenization_utils import agg_helper
+ from cehrbert.runners.runner_util import load_parquet_as_dataset
+ from tqdm import tqdm
+ from transformers.utils import is_flash_attn_2_available, logging
+ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, create_reference_model
+
+ from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
+ from cehrgpt.generation.generate_batch_hf_gpt_sequence import generate_single_batch
+ from cehrgpt.gpt_utils import get_cehrgpt_output_folder
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+ from cehrgpt.rl_finetune.cehrgpt_ppo_trainer import (
+     CehrGptPPODataCollator,
+     CehrGptPPOTrainer,
+ )
+
+ LOG = logging.get_logger("transformers")
+
+
+ def extract_concept_frequency(records: Dict[str, Any]) -> Dict[str, int]:
+     batched_concept_ids = records["concept_ids"]
+     outputs = defaultdict(int)
+     for concept_ids in batched_concept_ids:
+         for concept_id, cnt in dict(Counter(concept_ids[4:])).items():
+             outputs[concept_id] += cnt
+     return outputs
+
+
+ def main(args):
+     if torch.cuda.is_available():
+         device = torch.device("cuda")
+     else:
+         device = torch.device("cpu")
+
+     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
+     model_folder_name = os.path.join(
+         args.output_folder, get_cehrgpt_output_folder(args, cehrgpt_tokenizer), "model"
+     )
+
+     if not os.path.exists(model_folder_name):
+         os.makedirs(model_folder_name)
+
+     if args.restore_from_checkpoint:
+         try:
+             cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
+                 model_folder_name,
+                 attn_implementation=(
+                     "flash_attention_2" if is_flash_attn_2_available() else "eager"
+                 ),
+                 torch_dtype=(
+                     torch.bfloat16 if is_flash_attn_2_available() else torch.float32
+                 ),
+             )
+         except Exception:
+             LOG.warning(
+                 "Checkpoint does not exist in %s, loading from %s",
+                 model_folder_name,
+                 args.model_folder,
+             )
+             cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
+                 args.model_folder,
+                 attn_implementation=(
+                     "flash_attention_2" if is_flash_attn_2_available() else "eager"
+                 ),
+                 torch_dtype=(
+                     torch.bfloat16 if is_flash_attn_2_available() else torch.float32
+                 ),
+             )
+     else:
+         cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
+             args.model_folder,
+             attn_implementation=(
+                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
+             ),
+             torch_dtype=(
+                 torch.bfloat16 if is_flash_attn_2_available() else torch.float32
+             ),
+         )
+
+     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
+     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
+     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
+     model = AutoModelForCausalLMWithValueHead(cehrgpt_model).to(device)
+     model.is_peft_model = False
+     ref_model = create_reference_model(model).to(device)
+
+     # create a ppo trainer
+     ppo_trainer = CehrGptPPOTrainer(
+         config=PPOConfig(
+             batch_size=args.batch_size,
+             mini_batch_size=args.mini_batch_size,
+             init_kl_coef=args.init_kl_coef,
+             vf_coef=args.vf_coef,
+             kl_penalty=args.kl_penalty,
+             gamma=args.gamma,
+             use_score_scaling=args.use_score_scaling,
+         ),
+         model=model,
+         ref_model=ref_model,
+         tokenizer=cehrgpt_tokenizer,
+         training_data_collator=CehrGptPPODataCollator(
+             cehrgpt_tokenizer, max_length=args.context_window
+         ),
+     )
+
+     LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
+     LOG.info(f"Loading model at {args.model_folder}")
+     LOG.info(f"Will save the fine-tuned model at {model_folder_name}")
+     LOG.info(f"Context window {args.context_window}")
+     LOG.info(f"Temperature {args.temperature}")
+     LOG.info(f"Repetition Penalty {args.repetition_penalty}")
+     LOG.info(f"Sampling Strategy {args.sampling_strategy}")
+     LOG.info(f"Num beam {args.num_beams}")
+     LOG.info(f"Num beam groups {args.num_beam_groups}")
+     LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
+     LOG.info(f"Top P {args.top_p}")
+     LOG.info(f"Top K {args.top_k}")
+     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
+
+     dataset = load_parquet_as_dataset(args.demographic_data_path).filter(
+         lambda batched: [
+             model.config.n_positions >= num_of_concepts > args.min_num_tokens
+             for num_of_concepts in batched["num_of_concepts"]
+         ],
+         batched=True,
+     )
+     parts = dataset.map(
+         partial(agg_helper, map_func=extract_concept_frequency),
+         batched=True,
+         batch_size=1000,
+         num_proc=args.num_proc,
+         remove_columns=dataset.column_names,
+     )
+
+     concept_stats = defaultdict(float)
+     for stat in tqdm(parts, desc="Aggregating the concept counts"):
+         fixed_stat = pickle.loads(stat["data"])
+         for concept_id, count in fixed_stat.items():
+             concept_stats[concept_id] += count
+     total_sum = sum(concept_stats.values())
+     for concept_id, count in concept_stats.items():
+         concept_stats[concept_id] = count / total_sum
+
+     logs = []
+     device = ppo_trainer.current_device
+     total_rows = len(dataset)
+     num_of_micro_batches = args.batch_size // args.mini_batch_size
+     for i in tqdm(range(args.num_of_steps)):
+         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
+         random_prompts = []
+         batched_sequences = []
+         batched_values = []
+         batched_value_indicators = []
+         for _ in range(num_of_micro_batches):
+             random_indices = np.random.randint(0, total_rows, args.mini_batch_size)
+             random_prompts_micro_batch = [
+                 record["concept_ids"][:4] for record in dataset.select(random_indices)
+             ]
+             random_prompts.extend(random_prompts_micro_batch)
+             micro_batched_prompts = [
+                 cehrgpt_tokenizer.encode(random_prompt)
+                 for random_prompt in random_prompts_micro_batch
+             ]
+
+             micro_batched_sequences = generate_single_batch(
+                 cehrgpt_model,
+                 cehrgpt_tokenizer,
+                 micro_batched_prompts,
+                 max_new_tokens=args.context_window,
+                 mini_num_of_concepts=args.min_num_of_concepts,
+                 top_p=args.top_p,
+                 top_k=args.top_k,
+                 temperature=args.temperature,
+                 repetition_penalty=args.repetition_penalty,
+                 num_beams=args.num_beams,
+                 num_beam_groups=args.num_beam_groups,
+                 epsilon_cutoff=args.epsilon_cutoff,
+                 device=device,
+             )
+             # Clear the cache
+             torch.cuda.empty_cache()
+             batched_sequences.extend(micro_batched_sequences["sequences"])
+             batched_values.extend(micro_batched_sequences["values"])
+             batched_value_indicators.extend(micro_batched_sequences["value_indicators"])
+
+         LOG.info(f"{datetime.datetime.now()}: Batch {i} sequence generated")
+         reward = compute_marginal_dist_reward(
+             batched_sequences, concept_stats, cehrgpt_tokenizer
+         )
+         LOG.info(f"{datetime.datetime.now()}: Batch {i} KL divergence reward: {reward}")
+         query_tensors = []
+         response_tensors = []
+         value_tensors = []
+         value_indicator_tensors = []
+         rewards = []
+         for sequence, values, value_indicators in zip(
+             batched_sequences, batched_values, batched_value_indicators
+         ):
+             # Convert sequence to a NumPy array if it's not already one
+             sequence_array = np.asarray(sequence)
+             # Find the end token
+             condition_array = sequence_array == cehrgpt_tokenizer.end_token
+             end_index = (
+                 np.argmax(condition_array)
+                 if condition_array.any()
+                 else len(sequence_array) - 1
+             )
+
+             sequence = sequence[: end_index + 1]
+             values = values[: end_index + 1]
+             value_indicators = value_indicators[: end_index + 1]
+
+             query_tensors.append(torch.tensor(cehrgpt_tokenizer.encode(sequence[:4])))
+             response_tensors.append(
+                 torch.tensor(cehrgpt_tokenizer.encode(sequence[4:]))
+             )
+             value_tensors.append(torch.tensor(cehrgpt_tokenizer.encode_value(values)))
+             value_indicator_tensors.append(torch.tensor(value_indicators))
+             rewards.append(reward)
+
+         train_stats = ppo_trainer.step(
+             query_tensors,
+             response_tensors,
+             rewards,
+             value_tensors,
+             value_indicator_tensors,
+         )
+         LOG.info(f"{datetime.datetime.now()}: Batch {i} stats: {train_stats}")
+         logs.append(reward)
+         ppo_trainer.log_stats(stats=train_stats, batch={}, rewards=rewards)
+         ppo_trainer.save_pretrained(model_folder_name)
+         with open(os.path.join(model_folder_name, "ppo_finetune_stats.pkl"), "wb") as f:
+             pickle.dump(logs, f)
+
+
+ def compute_marginal_dist_reward(
+     batched_sequences: List[List[str]],
+     expected_concept_dist: Dict[str, float],
+     tokenizer: CehrGptTokenizer,
+ ) -> torch.Tensor:
+     actual_concept_dist = dict(
+         Counter(
+             [
+                 concept_id
+                 for sequence in batched_sequences
+                 for concept_id in sequence[4:]
+             ]
+         )
+     )
+     total_count = sum(actual_concept_dist.values())
+     for concept_id in actual_concept_dist.keys():
+         actual_concept_dist[concept_id] /= total_count
+     # Translate the concept ids to token ids
+     actual_dist = np.zeros(tokenizer.vocab_size)
+     actual_dist[tokenizer.encode(list(actual_concept_dist.keys()))] = list(
+         actual_concept_dist.values()
+     )
+     # Add a small epsilon to avoid log(0)
+     epsilon = 1e-10
+     logprob_dist = torch.tensor(np.log(actual_dist + epsilon))
+     # Translate the concept ids to token ids
+     ref_dist = np.zeros(tokenizer.vocab_size)
+     ref_dist[tokenizer.encode(list(expected_concept_dist.keys()))] = list(
+         expected_concept_dist.values()
+     )
+     ref_logprob_dist = torch.tensor(np.log(ref_dist + epsilon))
+
+     # Flipped argument order is required due to this issue: https://github.com/pytorch/pytorch/issues/57459
+     return torch.exp(
+         -torch.nn.functional.kl_div(
+             ref_logprob_dist, logprob_dist, log_target=True, reduction="none"
+         ).sum(-1)
+     )
+
+
+ def create_arg_parser():
+     base_arg_parser = create_inference_base_arg_parser(
+         description="Arguments for finetuning cehr-gpt using PPO"
+     )
+     base_arg_parser.add_argument(
+         "--mini_batch_size",
+         dest="mini_batch_size",
+         action="store",
+         type=int,
+         required=True,
+     )
+     base_arg_parser.add_argument(
+         "--init_kl_coef",
+         dest="init_kl_coef",
+         action="store",
+         type=float,
+         required=False,
+         default=0.1,
+     )
+     base_arg_parser.add_argument(
+         "--vf_coef",
+         dest="vf_coef",
+         action="store",
+         type=float,
+         required=False,
+         default=0.1,
+     )
+     base_arg_parser.add_argument(
+         "--kl_penalty",
+         dest="kl_penalty",
+         action="store",
+         choices=["kl", "abs", "mse", "full"],
+         required=False,
+         default="kl",
+     )
+     base_arg_parser.add_argument(
+         "--gamma",
+         dest="gamma",
+         action="store",
+         type=float,
+         required=False,
+         default=0.99,
+     )
+     base_arg_parser.add_argument(
+         "--num_proc",
+         dest="num_proc",
+         action="store",
+         type=int,
+         default=4,
+         required=False,
+     )
+     base_arg_parser.add_argument(
+         "--num_of_steps",
+         dest="num_of_steps",
+         action="store",
+         type=int,
+         default=1028,
+         required=False,
+     )
+     base_arg_parser.add_argument(
+         "--min_num_tokens",
+         dest="min_num_tokens",
+         action="store",
+         type=int,
+         default=4,
+         required=False,
+     )
+     base_arg_parser.add_argument(
+         "--demographic_data_path",
+         dest="demographic_data_path",
+         action="store",
+         help="The path to the demographic data",
+         required=True,
+     )
+     base_arg_parser.add_argument(
+         "--restore_from_checkpoint",
+         dest="restore_from_checkpoint",
+         action="store_true",
+     )
+     base_arg_parser.add_argument(
+         "--use_score_scaling",
+         dest="use_score_scaling",
+         action="store_true",
+     )
+     return base_arg_parser
+
+
+ if __name__ == "__main__":
+     main(create_arg_parser().parse_args())
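In the PPO script above, every sequence in a generated batch receives the same scalar reward: exp(-KL(actual ‖ expected)) between the marginal concept distribution of the batch and the empirical distribution aggregated from the source parquet. A minimal, self-contained sketch of that reward follows, with a toy vocabulary and a hypothetical helper name (`marginal_dist_reward`); note that `torch.nn.functional.kl_div(input, target, log_target=True)` evaluates KL(target ‖ input), which is why the reference log-probabilities appear as the first argument here and in the script.

```python
import numpy as np
import torch


def marginal_dist_reward(actual_counts, expected_probs, vocab):
    """exp(-KL(actual || expected)) over a shared concept vocabulary."""
    epsilon = 1e-10  # avoids log(0) for concepts that never occur
    index = {concept: i for i, concept in enumerate(vocab)}
    actual = np.zeros(len(vocab))
    total = sum(actual_counts.values())
    for concept, count in actual_counts.items():
        actual[index[concept]] = count / total
    expected = np.zeros(len(vocab))
    for concept, prob in expected_probs.items():
        expected[index[concept]] = prob
    log_actual = torch.tensor(np.log(actual + epsilon))
    log_expected = torch.tensor(np.log(expected + epsilon))
    # kl_div(input, target, log_target=True) computes KL(target || input)
    kl = torch.nn.functional.kl_div(
        log_expected, log_actual, log_target=True, reduction="none"
    ).sum(-1)
    return torch.exp(-kl)


# Toy usage: a perfectly matched batch would score 1.0; mismatches score lower.
vocab = ["diabetes", "hypertension", "asthma"]
expected = {"diabetes": 0.5, "hypertension": 0.25, "asthma": 0.25}
print(marginal_dist_reward({"diabetes": 3, "asthma": 1}, expected, vocab))
```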
@@ -0,0 +1,119 @@
+ from cehrbert.data_generators.hf_data_generator.hf_dataset import (
+     apply_cehrbert_dataset_mapping,
+ )
+ from cehrbert.runners.runner_util import (
+     generate_prepared_ds_path,
+     get_last_hf_checkpoint,
+     load_parquet_as_dataset,
+ )
+ from datasets import DatasetDict, load_from_disk
+ from transformers import set_seed
+ from transformers.utils import is_flash_attn_2_available, logging
+
+ from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
+ from cehrgpt.data.hf_cehrgpt_dpo_dataset_mapping import HFCehrGptDPOTokenizationMapping
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
+ from cehrgpt.rl_finetune.cehrgpt_dpo_trainer import CehrGptDPOTrainer
+ from cehrgpt.runners.gpt_runner_util import parse_dpo_runner_args
+ from cehrgpt.runners.hf_cehrgpt_finetune_runner import load_pretrained_tokenizer
+
+ LOG = logging.get_logger("transformers")
+
+
+ def main():
+     cehrgpt_args, data_args, model_args, dpo_config = parse_dpo_runner_args()
+     tokenizer = load_pretrained_tokenizer(model_args)
+     prepared_ds_path = generate_prepared_ds_path(
+         data_args, model_args, data_folder=data_args.cohort_folder
+     )
+     if any(prepared_ds_path.glob("*")):
+         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+         processed_dataset = load_from_disk(str(prepared_ds_path))
+         LOG.info("Prepared dataset loaded from disk...")
+     else:
+         dataset = load_parquet_as_dataset(data_args.data_folder)
+         # Random split
+         dataset = dataset.train_test_split(
+             test_size=data_args.validation_split_percentage, seed=dpo_config.seed
+         )
+         processed_dataset = apply_cehrbert_dataset_mapping(
+             dataset,
+             mapping_function=HFCehrGptDPOTokenizationMapping(tokenizer),
+             batch_size=data_args.preprocessing_batch_size,
+             num_proc=data_args.preprocessing_num_workers,
+             streaming=data_args.streaming,
+         )
+
+         processed_dataset = processed_dataset.filter(
+             lambda batch: [
+                 len(chosen_concept_ids) < model_args.max_position_embeddings
+                 for chosen_concept_ids in batch["chosen_concept_ids"]
+             ],
+             batched=True,
+             batch_size=data_args.preprocessing_batch_size,
+             num_proc=data_args.preprocessing_num_workers,
+         ).filter(
+             lambda batch: [
+                 len(rejected_concept_ids) < model_args.max_position_embeddings
+                 for rejected_concept_ids in batch["rejected_concept_ids"]
+             ],
+             batched=True,
+             batch_size=data_args.preprocessing_batch_size,
+             num_proc=data_args.preprocessing_num_workers,
+         )
+         processed_dataset.save_to_disk(prepared_ds_path)
+
+     # Set seed before initializing model.
+     set_seed(dpo_config.seed)
+     processed_dataset.set_format("pt")
+
+     # A hacky way to prevent the training from removing unmatched inputs
+     dpo_config.label_names = [
+         "chosen_input_ids",
+         "rejected_input_ids",
+         "chosen_concept_values",
+         "rejected_concept_values",
+         "chosen_concept_value_masks",
+         "rejected_concept_value_masks",
+     ]
+
+     attn_implementation = (
+         "flash_attention_2" if is_flash_attn_2_available() else "eager"
+     )
+     model = CEHRGPT2LMHeadModel.from_pretrained(
+         model_args.model_name_or_path,
+         attn_implementation=attn_implementation,
+     )
+     ref_model = CEHRGPT2LMHeadModel.from_pretrained(
+         model_args.model_name_or_path,
+         attn_implementation=attn_implementation,
+     )
+
+     # Initialize the DPO trainer
+     trainer = CehrGptDPOTrainer(
+         model=model,
+         ref_model=ref_model,
+         args=dpo_config,
+         tokenizer=tokenizer,
+         train_dataset=processed_dataset["train"],
+         eval_dataset=processed_dataset["test"],
+         data_collator=CehrGptDPODataCollator(
+             tokenizer=tokenizer,
+             max_length=model_args.max_position_embeddings,
+             pretraining=False,
+             include_ttv_prediction=False,
+             use_sub_time_tokenization=False,
+         ),
+     )
+     # Train the model
+     checkpoint = get_last_hf_checkpoint(dpo_config)
+     train_result = trainer.train(resume_from_checkpoint=checkpoint)
+     trainer.save_model()  # Saves the tokenizer too for easy upload
+     metrics = train_result.metrics
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+     trainer.save_state()
+
+
+ if __name__ == "__main__":
+     main()
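The DPO runner above keeps only preference pairs whose chosen and rejected sequences both fit the model's context. The script applies the two length checks as chained `filter` calls; the sketch below combines them into one for brevity, on a toy Hugging Face dataset where the 16-token limit stands in for `model_args.max_position_embeddings`.

```python
from datasets import Dataset

max_len = 16  # stand-in for model_args.max_position_embeddings
pairs = Dataset.from_dict(
    {
        "chosen_concept_ids": [list(range(4)), list(range(32))],
        "rejected_concept_ids": [list(range(8)), list(range(8))],
    }
)
# Keep a pair only if both sides fit within the context window.
pairs = pairs.filter(
    lambda batch: [
        len(chosen) < max_len and len(rejected) < max_len
        for chosen, rejected in zip(
            batch["chosen_concept_ids"], batch["rejected_concept_ids"]
        )
    ],
    batched=True,
)
print(len(pairs))  # 1 -- the pair whose chosen sequence has 32 tokens is dropped
```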
@@ -43,6 +43,7 @@ from transformers.utils import is_flash_attn_2_available, logging
 
  from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_finetuning_dataset
  from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
  from cehrgpt.models.hf_cehrgpt import (
      CEHRGPTConfig,
      CehrGptForClassification,
@@ -408,10 +409,24 @@ def main():
          except Exception as e:
              LOG.exception(e)
              dataset = create_dataset_from_meds_reader(
-                 data_args, is_pretraining=False
+                 data_args=data_args,
+                 dataset_mappings=[
+                     MedToCehrGPTDatasetMapping(
+                         data_args=data_args,
+                         is_pretraining=False,
+                         include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                     )
+                 ],
              )
              if not data_args.streaming:
-                 dataset.save_to_disk(meds_extension_path)
+                 dataset.save_to_disk(str(meds_extension_path))
+                 stats = dataset.cleanup_cache_files()
+                 LOG.info(
+                     "Cleaned up the cached files for the cehrgpt dataset transformed from MEDS: %s",
+                     stats,
+                 )
+                 dataset = load_from_disk(str(meds_extension_path))
+
      train_set = dataset["train"]
      validation_set = dataset["validation"]
      test_set = dataset["test"]
@@ -451,7 +466,13 @@ def main():
          dataset=final_splits, cehrgpt_tokenizer=tokenizer, data_args=data_args
      )
      if not data_args.streaming:
-         processed_dataset.save_to_disk(prepared_ds_path)
+         processed_dataset.save_to_disk(str(prepared_ds_path))
+         stats = processed_dataset.cleanup_cache_files()
+         LOG.info(
+             "Cleaned up the cached files for the cehrgpt finetuning dataset: %s",
+             stats,
+         )
+         processed_dataset = load_from_disk(str(prepared_ds_path))
 
      # Set seed before initializing model.
      set_seed(training_args.seed)
@@ -21,12 +21,13 @@ from transformers.utils import is_flash_attn_2_available, logging
 
  from cehrgpt.data.hf_cehrgpt_dataset import create_cehrgpt_pretraining_dataset
  from cehrgpt.data.hf_cehrgpt_dataset_collator import CehrGptDataCollator
+ from cehrgpt.data.hf_cehrgpt_dataset_mapping import MedToCehrGPTDatasetMapping
  from cehrgpt.models.config import CEHRGPTConfig
  from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
  from cehrgpt.models.pretrained_embeddings import PretrainedEmbeddings
  from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
  from cehrgpt.runners.gpt_runner_util import parse_runner_args
- from src.cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
+ from cehrgpt.runners.hf_gpt_runner_argument_dataclass import CehrGPTArguments
 
  LOG = logging.get_logger("transformers")
 
@@ -82,11 +83,25 @@ def load_and_create_model(
      model_abspath = os.path.expanduser(model_args.model_name_or_path)
      if cehrgpt_args.continue_pretrain:
          try:
-             return CEHRGPT2LMHeadModel.from_pretrained(
+             pretrained_model = CEHRGPT2LMHeadModel.from_pretrained(
                  model_abspath,
                  attn_implementation=attn_implementation,
                  torch_dtype=torch_dtype,
              )
+             if (
+                 pretrained_model.config.max_position_embeddings
+                 < model_args.max_position_embeddings
+             ):
+                 LOG.info(
+                     f"Increasing model.config.max_position_embeddings to {model_args.max_position_embeddings}"
+                 )
+                 pretrained_model.config.max_position_embeddings = (
+                     model_args.max_position_embeddings
+                 )
+                 pretrained_model.resize_position_embeddings(
+                     model_args.max_position_embeddings
+                 )
+             return pretrained_model
          except Exception as e:
              LOG.error(
                  f"When continue_pretrain is set to True, it assumes that CEHR-GPT has been trained "
@@ -94,7 +109,7 @@ def main():
              )
              raise e
      try:
-         model_config = AutoConfig.from_pretrained(
+         model_config = CEHRGPTConfig.from_pretrained(
              model_abspath, attn_implementation=attn_implementation
          )
      except Exception as e:
@@ -148,7 +163,7 @@ def main():
      # The iterable dataset doesn't have sharding implemented, so the number of workers has to be set to 0
      # Otherwise the trainer will throw an error
      training_args.dataloader_num_workers = 0
-     training_args.dataloader_prefetch_factor = 0
+     training_args.dataloader_prefetch_factor = None
 
      prepared_ds_path = generate_prepared_ds_path(data_args, model_args)
      if os.path.exists(os.path.join(data_args.data_folder, "dataset_dict.json")):
@@ -212,14 +227,29 @@ def main():
          except FileNotFoundError as e:
              LOG.exception(e)
              dataset = create_dataset_from_meds_reader(
-                 data_args, is_pretraining=True
+                 data_args=data_args,
+                 dataset_mappings=[
+                     MedToCehrGPTDatasetMapping(
+                         data_args=data_args,
+                         is_pretraining=True,
+                         include_inpatient_hour_token=cehrgpt_args.include_inpatient_hour_token,
+                     )
+                 ],
              )
              if not data_args.streaming:
-                 dataset.save_to_disk(meds_extension_path)
+                 dataset.save_to_disk(str(meds_extension_path))
+                 stats = dataset.cleanup_cache_files()
+                 LOG.info(
+                     "Cleaned up the cached files for the cehrgpt dataset transformed from MEDS: %s",
+                     stats,
+                 )
+                 dataset = load_from_disk(str(meds_extension_path))
      else:
          # Load the dataset from the parquet files
          dataset = load_parquet_as_dataset(
-             data_args.data_folder, split="train", streaming=data_args.streaming
+             os.path.expanduser(data_args.data_folder),
+             split="train",
+             streaming=data_args.streaming,
          )
      # If streaming is enabled, we need to manually split the data into train/val
      if data_args.streaming and data_args.validation_split_num:
@@ -274,7 +304,13 @@ def main():
      )
      # only save the data to the disk if it is not streaming
      if not data_args.streaming:
-         processed_dataset.save_to_disk(prepared_ds_path)
+         processed_dataset.save_to_disk(str(prepared_ds_path))
+         stats = processed_dataset.cleanup_cache_files()
+         LOG.info(
+             "Cleaned up the cached files for the cehrgpt pretraining dataset: %s",
+             stats,
+         )
+         processed_dataset = load_from_disk(str(prepared_ds_path))
 
      def filter_func(examples):
          if cehrgpt_args.drop_long_sequences:
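Both runners now follow the same save/cleanup/reload pattern after preprocessing: persist the processed dataset, delete the intermediate Arrow cache files that accumulated during `map`/`filter`, then reload the saved copy so downstream code works off the on-disk version. A minimal sketch with a hypothetical output path:

```python
from datasets import Dataset, load_from_disk

ds = Dataset.from_dict({"x": list(range(10))}).map(lambda row: {"x": row["x"] + 1})
ds.save_to_disk("/tmp/processed_ds")  # hypothetical path
removed = ds.cleanup_cache_files()  # drops the intermediate Arrow cache files
print(f"removed {removed} cache file(s)")
ds = load_from_disk("/tmp/processed_ds")  # memory-mapped reload of the saved copy
```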
@@ -6,6 +6,10 @@ from typing import List, Optional
  class CehrGPTArguments:
      """Arguments pertaining to what data we are going to input our model for training and eval."""
 
+     include_inpatient_hour_token: Optional[bool] = dataclasses.field(
+         default=True,
+         metadata={"help": "Include inpatient hour token"},
+     )
      include_demographics: Optional[bool] = dataclasses.field(
          default=False,
          metadata={
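The final hunk registers the new flag on `CehrGPTArguments` using the usual HF-style dataclass field. In isolation the pattern looks like the sketch below; the class name is an illustrative stand-in.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ExampleArguments:  # illustrative stand-in for CehrGPTArguments
    include_inpatient_hour_token: Optional[bool] = dataclasses.field(
        default=True,
        metadata={"help": "Include inpatient hour token"},
    )


print(ExampleArguments().include_inpatient_hour_token)  # True
```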