cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/rl_finetune/ppo_finetune_v2.py (deleted)
@@ -1,373 +0,0 @@
- import datetime
- import os
- import pickle
- from collections import Counter, defaultdict
- from functools import partial
- from typing import Any, Dict, List
-
- import numpy as np
- import torch
- from cehrbert.models.hf_models.tokenization_utils import agg_helper
- from cehrbert.runners.runner_util import load_parquet_as_dataset
- from tqdm import tqdm
- from transformers.utils import is_flash_attn_2_available, logging
- from trl import AutoModelForCausalLMWithValueHead, PPOConfig, create_reference_model
-
- from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
- from cehrgpt.generation.generate_batch_hf_gpt_sequence import generate_single_batch
- from cehrgpt.gpt_utils import get_cehrgpt_output_folder
- from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
- from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
- from cehrgpt.rl_finetune.cehrgpt_ppo_trainer import (
-     CehrGptPPODataCollator,
-     CehrGptPPOTrainer,
- )
-
- LOG = logging.get_logger("transformers")
-
-
- def extract_concept_frequency(records: Dict[str, Any]) -> Dict[str, int]:
-     batched_concept_ids = records["concept_ids"]
-     outputs = defaultdict(int)
-     for concept_ids in batched_concept_ids:
-         for concept_id, cnt in dict(Counter(concept_ids[4:])).items():
-             outputs[concept_id] += cnt
-     return outputs
-
-
- def main(args):
-     if torch.cuda.is_available():
-         device = torch.device("cuda")
-     else:
-         device = torch.device("cpu")
-
-     cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
-     model_folder_name = os.path.join(
-         args.output_folder, get_cehrgpt_output_folder(args, cehrgpt_tokenizer), "model"
-     )
-
-     if not os.path.exists(model_folder_name):
-         os.makedirs(model_folder_name)
-
-     if args.restore_from_checkpoint:
-         try:
-             cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
-                 model_folder_name,
-                 attn_implementation=(
-                     "flash_attention_2" if is_flash_attn_2_available() else "eager"
-                 ),
-                 torch_dtype=(
-                     torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-                 ),
-             )
-         except Exception:
-             LOG.warning(
-                 "Checkpoint does not exist in %s, loading from %s",
-                 model_folder_name,
-                 args.model_folder,
-             )
-             cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
-                 args.model_folder,
-                 attn_implementation=(
-                     "flash_attention_2" if is_flash_attn_2_available() else "eager"
-                 ),
-                 torch_dtype=(
-                     torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-                 ),
-             )
-     else:
-         cehrgpt_model = CEHRGPT2LMHeadModel.from_pretrained(
-             args.model_folder,
-             attn_implementation=(
-                 "flash_attention_2" if is_flash_attn_2_available() else "eager"
-             ),
-             torch_dtype=(
-                 torch.bfloat16 if is_flash_attn_2_available() else torch.float32
-             ),
-         )
-
-     cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
-     cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
-     cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
-     model = AutoModelForCausalLMWithValueHead(cehrgpt_model).to(device)
-     model.is_peft_model = False
-     ref_model = create_reference_model(model).to(device)
-
-     # Create a PPO trainer
-     ppo_trainer = CehrGptPPOTrainer(
-         config=PPOConfig(
-             batch_size=args.batch_size,
-             mini_batch_size=args.mini_batch_size,
-             init_kl_coef=args.init_kl_coef,
-             vf_coef=args.vf_coef,
-             kl_penalty=args.kl_penalty,
-             gamma=args.gamma,
-             use_score_scaling=args.use_score_scaling,
-         ),
-         model=model,
-         ref_model=ref_model,
-         tokenizer=cehrgpt_tokenizer,
-         training_data_collator=CehrGptPPODataCollator(
-             cehrgpt_tokenizer, max_length=args.context_window
-         ),
-     )
-
-     LOG.info(f"Loading tokenizer at {args.model_folder}")
-     LOG.info(f"Loading model at {args.model_folder}")
-     LOG.info(f"Will save the fine-tuned model at {model_folder_name}")
-     LOG.info(f"Context window {args.context_window}")
-     LOG.info(f"Temperature {args.temperature}")
-     LOG.info(f"Repetition Penalty {args.repetition_penalty}")
-     LOG.info(f"Sampling Strategy {args.sampling_strategy}")
-     LOG.info(f"Num beam {args.num_beams}")
-     LOG.info(f"Num beam groups {args.num_beam_groups}")
-     LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
-     LOG.info(f"Top P {args.top_p}")
-     LOG.info(f"Top K {args.top_k}")
-     LOG.info(f"Loading demographic_info at {args.demographic_data_path}")
-
-     dataset = load_parquet_as_dataset(args.demographic_data_path).filter(
-         lambda batched: [
-             model.config.n_positions >= num_of_concepts > args.min_num_tokens
-             for num_of_concepts in batched["num_of_concepts"]
-         ],
-         batched=True,
-     )
-     parts = dataset.map(
-         partial(agg_helper, map_func=extract_concept_frequency),
-         batched=True,
-         batch_size=1000,
-         num_proc=args.num_proc,
-         remove_columns=dataset.column_names,
-     )
-
-     concept_stats = defaultdict(float)
-     for stat in tqdm(parts, desc="Aggregating the concept counts"):
-         fixed_stat = pickle.loads(stat["data"])
-         for concept_id, count in fixed_stat.items():
-             concept_stats[concept_id] += count
-     total_sum = sum(concept_stats.values())
-     for concept_id, count in concept_stats.items():
-         concept_stats[concept_id] = count / total_sum
-
-     logs = []
-     device = ppo_trainer.current_device
-     total_rows = len(dataset)
-     num_of_micro_batches = args.batch_size // args.mini_batch_size
-     for i in tqdm(range(args.num_of_steps)):
-         LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
-         random_prompts = []
-         batched_sequences = []
-         batched_values = []
-         batched_value_indicators = []
-         for _ in range(num_of_micro_batches):
-             random_indices = np.random.randint(0, total_rows, args.mini_batch_size)
-             random_prompts_micro_batch = [
-                 record["concept_ids"][:4] for record in dataset.select(random_indices)
-             ]
-             random_prompts.extend(random_prompts_micro_batch)
-             micro_batched_prompts = [
-                 cehrgpt_tokenizer.encode(random_prompt)
-                 for random_prompt in random_prompts_micro_batch
-             ]
-
-             micro_batched_sequences = generate_single_batch(
-                 cehrgpt_model,
-                 cehrgpt_tokenizer,
-                 micro_batched_prompts,
-                 max_new_tokens=args.context_window,
-                 mini_num_of_concepts=args.min_num_of_concepts,
-                 top_p=args.top_p,
-                 top_k=args.top_k,
-                 temperature=args.temperature,
-                 repetition_penalty=args.repetition_penalty,
-                 num_beams=args.num_beams,
-                 num_beam_groups=args.num_beam_groups,
-                 epsilon_cutoff=args.epsilon_cutoff,
-                 device=device,
-             )
-             # Clear the cache
-             torch.cuda.empty_cache()
-             batched_sequences.extend(micro_batched_sequences["sequences"])
-             batched_values.extend(micro_batched_sequences["values"])
-             batched_value_indicators.extend(micro_batched_sequences["value_indicators"])
-
-         LOG.info(f"{datetime.datetime.now()}: Batch {i} sequence generated")
-         reward = compute_marginal_dist_reward(
-             batched_sequences, concept_stats, cehrgpt_tokenizer
-         )
-         LOG.info(f"{datetime.datetime.now()}: Batch {i} KL divergence reward: {reward}")
-         query_tensors = []
-         response_tensors = []
-         value_tensors = []
-         value_indicator_tensors = []
-         rewards = []
-         for sequence, values, value_indicators in zip(
-             batched_sequences, batched_values, batched_value_indicators
-         ):
-             # Convert sequence to a NumPy array if it's not already one
-             sequence_array = np.asarray(sequence)
-             # Find the end token
-             condition_array = sequence_array == cehrgpt_tokenizer.end_token
-             end_index = (
-                 np.argmax(condition_array)
-                 if condition_array.any()
-                 else len(sequence_array) - 1
-             )
-
-             sequence = sequence[: end_index + 1]
-             values = values[: end_index + 1]
-             value_indicators = value_indicators[: end_index + 1]
-
-             query_tensors.append(torch.tensor(cehrgpt_tokenizer.encode(sequence[:4])))
-             response_tensors.append(
-                 torch.tensor(cehrgpt_tokenizer.encode(sequence[4:]))
-             )
-             value_tensors.append(torch.tensor(cehrgpt_tokenizer.encode_value(values)))
-             value_indicator_tensors.append(torch.tensor(value_indicators))
-             rewards.append(reward)
-
-         train_stats = ppo_trainer.step(
-             query_tensors,
-             response_tensors,
-             rewards,
-             value_tensors,
-             value_indicator_tensors,
-         )
-         LOG.info(f"{datetime.datetime.now()}: Batch {i} stats: {train_stats}")
-         logs.append(reward)
-         ppo_trainer.log_stats(stats=train_stats, batch={}, rewards=rewards)
-     ppo_trainer.save_pretrained(model_folder_name)
-     with open(os.path.join(model_folder_name, "ppo_finetune_stats.pkl"), "wb") as f:
-         pickle.dump(logs, f)
-
-
- def compute_marginal_dist_reward(
-     batched_sequences: List[List[str]],
-     expected_concept_dist: Dict[str, float],
-     tokenizer: CehrGptTokenizer,
- ) -> torch.Tensor:
-     actual_concept_dist = dict(
-         Counter(
-             [
-                 concept_id
-                 for sequence in batched_sequences
-                 for concept_id in sequence[4:]
-             ]
-         )
-     )
-     total_count = sum(actual_concept_dist.values())
-     for concept_id in actual_concept_dist.keys():
-         actual_concept_dist[concept_id] /= total_count
-     # Translate the concept ids to token ids
-     actual_dist = np.zeros(tokenizer.vocab_size)
-     actual_dist[tokenizer.encode(list(actual_concept_dist.keys()))] = list(
-         actual_concept_dist.values()
-     )
-     # Add a small epsilon to avoid log(0)
-     epsilon = 1e-10
-     logprob_dist = torch.tensor(np.log(actual_dist + epsilon))
-     # Translate the concept ids to token ids
-     ref_dist = np.zeros(tokenizer.vocab_size)
-     ref_dist[tokenizer.encode(list(expected_concept_dist.keys()))] = list(
-         expected_concept_dist.values()
-     )
-     ref_logprob_dist = torch.tensor(np.log(ref_dist + epsilon))
-
-     # The flipped argument order is required due to this issue: https://github.com/pytorch/pytorch/issues/57459
-     return torch.exp(
-         -torch.nn.functional.kl_div(
-             ref_logprob_dist, logprob_dist, log_target=True, reduction="none"
-         ).sum(-1)
-     )
-
-
- def create_arg_parser():
-     base_arg_parser = create_inference_base_arg_parser(
-         description="Arguments for finetuning cehr-gpt using PPO"
-     )
-     base_arg_parser.add_argument(
-         "--mini_batch_size",
-         dest="mini_batch_size",
-         action="store",
-         type=int,
-         required=True,
-     )
-     base_arg_parser.add_argument(
-         "--init_kl_coef",
-         dest="init_kl_coef",
-         action="store",
-         type=float,
-         required=False,
-         default=0.1,
-     )
-     base_arg_parser.add_argument(
-         "--vf_coef",
-         dest="vf_coef",
-         action="store",
-         type=float,
-         required=False,
-         default=0.1,
-     )
-     base_arg_parser.add_argument(
-         "--kl_penalty",
-         dest="kl_penalty",
-         action="store",
-         choices=["kl", "abs", "mse", "full"],
-         required=False,
-         default="kl",
-     )
-     base_arg_parser.add_argument(
-         "--gamma",
-         dest="gamma",
-         action="store",
-         type=float,
-         required=False,
-         default=0.99,
-     )
-     base_arg_parser.add_argument(
-         "--num_proc",
-         dest="num_proc",
-         action="store",
-         type=int,
-         default=4,
-         required=False,
-     )
-     base_arg_parser.add_argument(
-         "--num_of_steps",
-         dest="num_of_steps",
-         action="store",
-         type=int,
-         default=1028,
-         required=False,
-     )
-     base_arg_parser.add_argument(
-         "--min_num_tokens",
-         dest="min_num_tokens",
-         action="store",
-         type=int,
-         default=4,
-         required=False,
-     )
-     base_arg_parser.add_argument(
-         "--demographic_data_path",
-         dest="demographic_data_path",
-         action="store",
-         help="The path for your concept_path",
-         required=True,
-     )
-     base_arg_parser.add_argument(
-         "--restore_from_checkpoint",
-         dest="restore_from_checkpoint",
-         action="store_true",
-     )
-     base_arg_parser.add_argument(
-         "--use_score_scaling",
-         dest="use_score_scaling",
-         action="store_true",
-     )
-     return base_arg_parser
-
-
- if __name__ == "__main__":
-     main(create_arg_parser().parse_args())
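An editorial note on the file removed above: its reward signal compares the marginal concept distribution of each generated batch against the empirical concept distribution of the training corpus, and maps the KL divergence into a bounded reward via exp(-KL). The sketch below is a minimal, self-contained illustration of that idea only; the function name, toy vocabulary, and usage values are hypothetical, and only the numpy/torch calls mirror the removed code.

import numpy as np
import torch

def marginal_dist_reward(actual_counts, expected_dist, vocab):
    """Reward = exp(-KL(actual || expected)) over a shared vocabulary."""
    epsilon = 1e-10  # avoids log(0) for concepts that never occur
    index = {concept: i for i, concept in enumerate(vocab)}

    # Normalize observed counts into a probability vector over the vocabulary.
    actual = np.zeros(len(vocab))
    total = sum(actual_counts.values())
    for concept, count in actual_counts.items():
        actual[index[concept]] = count / total

    # Scatter the reference probabilities into the same vector space.
    expected = np.zeros(len(vocab))
    for concept, prob in expected_dist.items():
        expected[index[concept]] = prob

    log_actual = torch.tensor(np.log(actual + epsilon))
    log_expected = torch.tensor(np.log(expected + epsilon))
    # F.kl_div(input, target, log_target=True) computes KL(target || input),
    # so passing (log_expected, log_actual) yields KL(actual || expected),
    # matching the flipped argument order in the removed file.
    kl = torch.nn.functional.kl_div(
        log_expected, log_actual, log_target=True, reduction="sum"
    )
    return torch.exp(-kl)

# Toy usage: a generator whose marginals match the corpus earns reward ~1.0.
vocab = ["c1", "c2", "c3"]
print(marginal_dist_reward({"c1": 50, "c2": 30, "c3": 20},
                           {"c1": 0.5, "c2": 0.3, "c3": 0.2}, vocab))

Because exp(-KL) lies in (0, 1], the reward stays bounded, which keeps the PPO value targets well scaled even when the generated distribution drifts far from the corpus.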
cehrgpt/runners/hf_cehrgpt_dpo_runner.py (deleted)
@@ -1,119 +0,0 @@
- from cehrbert.data_generators.hf_data_generator.hf_dataset import (
-     apply_cehrbert_dataset_mapping,
- )
- from cehrbert.runners.runner_util import (
-     generate_prepared_ds_path,
-     get_last_hf_checkpoint,
-     load_parquet_as_dataset,
- )
- from datasets import DatasetDict, load_from_disk
- from transformers import set_seed
- from transformers.utils import is_flash_attn_2_available, logging
-
- from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
- from cehrgpt.data.hf_cehrgpt_dpo_dataset_mapping import HFCehrGptDPOTokenizationMapping
- from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
- from cehrgpt.rl_finetune.cehrgpt_dpo_trainer import CehrGptDPOTrainer
- from cehrgpt.runners.gpt_runner_util import parse_dpo_runner_args
- from cehrgpt.runners.hf_cehrgpt_finetune_runner import load_pretrained_tokenizer
-
- LOG = logging.get_logger("transformers")
-
-
- def main():
-     cehrgpt_args, data_args, model_args, dpo_config = parse_dpo_runner_args()
-     tokenizer = load_pretrained_tokenizer(model_args)
-     prepared_ds_path = generate_prepared_ds_path(
-         data_args, model_args, data_folder=data_args.cohort_folder
-     )
-     if any(prepared_ds_path.glob("*")):
-         LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
-         processed_dataset = load_from_disk(str(prepared_ds_path))
-         LOG.info("Prepared dataset loaded from disk...")
-     else:
-         dataset = load_parquet_as_dataset(data_args.data_folder)
-         # Random split
-         dataset = dataset.train_test_split(
-             test_size=data_args.validation_split_percentage, seed=dpo_config.seed
-         )
-         processed_dataset = apply_cehrbert_dataset_mapping(
-             dataset,
-             mapping_function=HFCehrGptDPOTokenizationMapping(tokenizer),
-             batch_size=data_args.preprocessing_batch_size,
-             num_proc=data_args.preprocessing_num_workers,
-             streaming=data_args.streaming,
-         )
-
-         processed_dataset = processed_dataset.filter(
-             lambda batch: [
-                 len(chosen_concept_ids) < model_args.max_position_embeddings
-                 for chosen_concept_ids in batch["chosen_concept_ids"]
-             ],
-             batched=True,
-             batch_size=data_args.preprocessing_batch_size,
-             num_proc=data_args.preprocessing_num_workers,
-         ).filter(
-             lambda batch: [
-                 len(rejected_concept_ids) < model_args.max_position_embeddings
-                 for rejected_concept_ids in batch["rejected_concept_ids"]
-             ],
-             batched=True,
-             batch_size=data_args.preprocessing_batch_size,
-             num_proc=data_args.preprocessing_num_workers,
-         )
-         processed_dataset.save_to_disk(prepared_ds_path)
-
-     # Set seed before initializing model.
-     set_seed(dpo_config.seed)
-     processed_dataset.set_format("pt")
-
-     # A hacky way to prevent the trainer from removing unmatched inputs
-     dpo_config.label_names = [
-         "chosen_input_ids",
-         "rejected_input_ids",
-         "chosen_concept_values",
-         "rejected_concept_values",
-         "chosen_concept_value_masks",
-         "rejected_concept_value_masks",
-     ]
-
-     attn_implementation = (
-         "flash_attention_2" if is_flash_attn_2_available() else "eager"
-     )
-     model = CEHRGPT2LMHeadModel.from_pretrained(
-         model_args.model_name_or_path,
-         attn_implementation=attn_implementation,
-     )
-     ref_model = CEHRGPT2LMHeadModel.from_pretrained(
-         model_args.model_name_or_path,
-         attn_implementation=attn_implementation,
-     )
-
-     # Initialize the DPO trainer on the train/test split
-     trainer = CehrGptDPOTrainer(
-         model=model,
-         ref_model=ref_model,
-         args=dpo_config,
-         tokenizer=tokenizer,
-         train_dataset=processed_dataset["train"],
-         eval_dataset=processed_dataset["test"],
-         data_collator=CehrGptDPODataCollator(
-             tokenizer=tokenizer,
-             max_length=model_args.max_position_embeddings,
-             pretraining=False,
-             include_ttv_prediction=False,
-             use_sub_time_tokenization=False,
-         ),
-     )
-     # Train the model, resuming from the last checkpoint if one exists
-     checkpoint = get_last_hf_checkpoint(dpo_config)
-     train_result = trainer.train(resume_from_checkpoint=checkpoint)
-     trainer.save_model()  # Saves the tokenizer too for easy upload
-     metrics = train_result.metrics
-     trainer.log_metrics("train", metrics)
-     trainer.save_metrics("train", metrics)
-     trainer.save_state()
-
-
- if __name__ == "__main__":
-     main()
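A small aside on the removed runner above: its two chained .filter passes scan the dataset twice, once for the chosen sequences and once for the rejected ones. They can be fused into a single batched predicate that enforces both length constraints in one scan. A minimal sketch using Hugging Face datasets follows; the toy rows and max_len are hypothetical stand-ins for the real data and model_args.max_position_embeddings.

from datasets import Dataset

max_len = 8  # stands in for model_args.max_position_embeddings

ds = Dataset.from_dict({
    "chosen_concept_ids": [[1, 2, 3], [1] * 10, [4, 5]],
    "rejected_concept_ids": [[6, 7], [8, 9], [1] * 12],
})

# One batched pass that checks both columns per row.
filtered = ds.filter(
    lambda batch: [
        len(chosen) < max_len and len(rejected) < max_len
        for chosen, rejected in zip(
            batch["chosen_concept_ids"], batch["rejected_concept_ids"]
        )
    ],
    batched=True,
)
print(filtered["chosen_concept_ids"])  # [[1, 2, 3]]; the other rows exceed max_len

The fused predicate also keeps the pair semantics explicit: a preference pair is dropped as a unit whenever either side exceeds the context window, which is the same net effect the two sequential filters achieved.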
Files 34-36 are renames with no content changes.