cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,394 +0,0 @@
1
- import datetime
2
- import os
3
- import pickle
4
- import random
5
- from collections import Counter, defaultdict
6
- from functools import partial
7
- from typing import Any, Dict, List, Tuple
8
-
9
- import numpy as np
10
- import torch
11
- from cehrbert.models.hf_models.tokenization_utils import agg_helper
12
- from cehrbert.runners.runner_util import load_parquet_as_dataset
13
- from tqdm import tqdm
14
- from transformers import GenerationConfig
15
- from transformers.utils import is_flash_attn_2_available, logging
16
- from trl import (
17
- AutoModelForCausalLMWithValueHead,
18
- PPOConfig,
19
- PPOTrainer,
20
- create_reference_model,
21
- )
22
-
23
- from cehrgpt.cehrgpt_args import create_inference_base_arg_parser
24
- from cehrgpt.gpt_utils import get_cehrgpt_output_folder
25
- from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
26
- from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
27
-
28
# Module-level logger. It uses the logger named "transformers" (presumably so this
# script's output follows the verbosity configured via transformers' logging utils).
LOG = logging.get_logger("transformers")
29
-
30
-
31
def extract_demographics_info(
    records: Dict[str, Any]
) -> Dict[Tuple[str, str, str, str], Dict[str, int]]:
    """Group concept counts by the demographic prompt of each sequence.

    Each entry of ``records["concept_ids"]`` is treated as a token sequence whose
    first four tokens form the prompt key ``(start_year, start_age, gender, race)``.
    The remaining tokens are counted per concept id under that key. A special
    ``"total"`` entry in the same dict counts how many sequences had the prompt.

    Args:
        records: a batch mapping with a ``"concept_ids"`` list of sequences.

    Returns:
        A ``defaultdict(dict)`` mapping each prompt tuple to its concept counts
        (plus ``"total"``).

    Raises:
        ValueError: if a sequence has fewer than four tokens (tuple unpacking).
    """
    stats_by_prompt = defaultdict(dict)
    for sequence in records["concept_ids"]:
        year, age, gender, race = sequence[:4]
        bucket = stats_by_prompt[(year, age, gender, race)]
        for concept, occurrences in Counter(sequence[4:]).items():
            bucket[concept] = bucket.get(concept, 0) + occurrences
        # One more sequence observed for this demographic prompt.
        bucket["total"] = bucket.get("total", 0) + 1
    return stats_by_prompt
49
-
50
-
51
def generate_single_batch(
    model,
    tokenizer,
    batched_prompts,
    max_new_tokens=512,
    mini_num_of_concepts=1,
    top_p=0.95,
    top_k=50,
    temperature=1.0,
    repetition_penalty=1.0,
    num_beams=1,
    num_beam_groups=1,
    epsilon_cutoff=0.0,
) -> List[List[str]]:
    """Sample continuations for a batch of prompts and decode each one.

    Args:
        model: object exposing ``generate``; it is called with
            ``return_dict_in_generate=True`` and its ``.sequences`` are decoded.
        tokenizer: supplies ``end_token_id`` (used as both BOS and EOS),
            ``pad_token_id`` and ``decode``.
        batched_prompts: prompt token ids passed to ``generate`` as ``inputs``.
        max_new_tokens: forwarded as ``max_length``. NOTE(review): despite the
            name, this bounds the total sequence length (prompt included), not
            only the newly generated tokens.
        mini_num_of_concepts: forwarded as ``min_length``.
        top_p, top_k, temperature, repetition_penalty, epsilon_cutoff: sampling
            controls forwarded unchanged (``do_sample=True`` with renormalized
            logits).
        num_beams, num_beam_groups: beam settings forwarded unchanged.

    Returns:
        One decoded sequence per row of ``results.sequences``. Presumably each
        row includes the prompt tokens; ``main`` slices the first four tokens off
        as the query.
    """
    generation_config = GenerationConfig(
        do_sample=True,
        renormalize_logits=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        epsilon_cutoff=epsilon_cutoff,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        max_length=max_new_tokens,
        min_length=mini_num_of_concepts,
        bos_token_id=tokenizer.end_token_id,
        eos_token_id=tokenizer.end_token_id,
        pad_token_id=tokenizer.pad_token_id,
        use_cache=True,
        return_dict_in_generate=True,
        output_attentions=False,
        output_hidden_states=False,
        output_scores=False,
    )
    # Sampling only; no gradients are needed for the generated tokens.
    with torch.no_grad():
        generated = model.generate(
            inputs=batched_prompts, generation_config=generation_config
        )

    return [tokenizer.decode(row.cpu().numpy()) for row in generated.sequences]
92
-
93
-
94
def _load_cehrgpt_model(model_path):
    """Load a ``CEHRGPT2LMHeadModel`` from ``model_path``.

    Uses flash-attention 2 with bfloat16 weights when
    ``is_flash_attn_2_available()`` is true, otherwise eager attention with
    float32 weights.
    """
    use_flash_attn = is_flash_attn_2_available()
    return CEHRGPT2LMHeadModel.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2" if use_flash_attn else "eager",
        torch_dtype=torch.bfloat16 if use_flash_attn else torch.float32,
    )


def _build_prompt_concept_stats(demographic_data_path, num_proc):
    """Aggregate per-demographic-prompt concept distributions from a dataset.

    Rows whose ``num_of_concepts`` is not greater than 4 are dropped. The rest
    are grouped by their first four tokens via ``extract_demographics_info``,
    and the per-shard results are merged.

    Args:
        demographic_data_path: parquet location passed to
            ``load_parquet_as_dataset``.
        num_proc: worker processes for the aggregation ``map``.

    Returns:
        ``(prompts_and_concept_stats, prompt_counts)``.
        ``prompts_and_concept_stats[prompt]`` maps each concept id to its
        relative frequency among that prompt's post-prompt tokens.
        ``prompt_counts[prompt]`` is the number of rows with that prompt, in the
        same key order.
    """
    dataset = load_parquet_as_dataset(demographic_data_path)
    parts = dataset.filter(
        lambda batched: [
            num_of_concepts > 4 for num_of_concepts in batched["num_of_concepts"]
        ],
        batched=True,
    ).map(
        partial(agg_helper, map_func=extract_demographics_info),
        batched=True,
        batch_size=1000,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
    )

    prompts_and_concept_stats = defaultdict(dict)
    for stat in tqdm(parts, desc="Aggregating the concept counts"):
        # ``data`` holds the pickled output of the ``agg_helper`` map above
        # (presumably one partial aggregate per map shard).
        for prompt, concept_stats in pickle.loads(stat["data"]).items():
            merged_stats = prompts_and_concept_stats[prompt]
            for concept_id, count in concept_stats.items():
                merged_stats[concept_id] = merged_stats.get(concept_id, 0) + count

    prompt_counts = {}
    for prompt, concept_stats in prompts_and_concept_stats.items():
        # "total" counts rows, not concepts; take it out before normalizing.
        prompt_counts[prompt] = concept_stats.pop("total")
        total_count = sum(concept_stats.values())
        for concept_id in concept_stats:
            concept_stats[concept_id] /= total_count
    return prompts_and_concept_stats, prompt_counts


def main(args):
    """Fine-tune a CEHR-GPT model with PPO against demographic concept marginals.

    Each of ``args.num_of_steps`` steps:
      1. samples one demographic prompt (the first four tokens of a sequence),
         weighted by how often it occurs in ``args.demographic_data_path``;
      2. generates ``args.batch_size // args.mini_batch_size`` micro-batches of
         ``args.mini_batch_size`` sequences from it;
      3. rewards all of them with ``compute_marginal_dist_reward`` against that
         prompt's reference concept distribution;
      4. runs one ``PPOTrainer.step``.

    The trained model is saved under
    ``<output_folder>/<get_cehrgpt_output_folder(...)>/model``, along with
    ``ppo_finetune_stats.pkl``, which holds the per-step rewards.

    Raises:
        ValueError: if no row in the demographic dataset has more than four
            concepts, so there is no prompt to sample.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cehrgpt_tokenizer = CehrGptTokenizer.from_pretrained(args.tokenizer_folder)
    model_folder_name = os.path.join(
        args.output_folder, get_cehrgpt_output_folder(args, cehrgpt_tokenizer), "model"
    )
    os.makedirs(model_folder_name, exist_ok=True)

    # Optionally resume from the output folder. Any failure deliberately falls
    # back to the base model, but the traceback is kept for diagnosis.
    cehrgpt_model = None
    model_source = args.model_folder
    if args.restore_from_checkpoint:
        try:
            cehrgpt_model = _load_cehrgpt_model(model_folder_name)
            model_source = model_folder_name
        except Exception:
            LOG.warning(
                "Checkpoint does not exist in %s, loading from the %s",
                model_folder_name,
                args.model_folder,
                exc_info=True,
            )
    if cehrgpt_model is None:
        cehrgpt_model = _load_cehrgpt_model(args.model_folder)

    # The tokenizer's end token is used as both BOS and EOS for generation.
    cehrgpt_model.generation_config.pad_token_id = cehrgpt_tokenizer.pad_token_id
    cehrgpt_model.generation_config.eos_token_id = cehrgpt_tokenizer.end_token_id
    cehrgpt_model.generation_config.bos_token_id = cehrgpt_tokenizer.end_token_id
    model = AutoModelForCausalLMWithValueHead(cehrgpt_model).to(device)
    model.is_peft_model = False
    ref_model = create_reference_model(model).to(device)

    ppo_trainer = PPOTrainer(
        config=PPOConfig(
            batch_size=args.batch_size,
            mini_batch_size=args.mini_batch_size,
            init_kl_coef=args.init_kl_coef,
            vf_coef=args.vf_coef,
            kl_penalty=args.kl_penalty,
            gamma=args.gamma,
        ),
        model=model,
        ref_model=ref_model,
        tokenizer=cehrgpt_tokenizer,
    )

    LOG.info(f"Loading tokenizer at {args.tokenizer_folder}")
    LOG.info(f"Loading model at {model_source}")
    LOG.info(f"Will save the fine-tuned model at {model_folder_name}")
    LOG.info(f"Context window {args.context_window}")
    LOG.info(f"Temperature {args.temperature}")
    LOG.info(f"Repetition Penalty {args.repetition_penalty}")
    LOG.info(f"Sampling Strategy {args.sampling_strategy}")
    LOG.info(f"Num beam {args.num_beams}")
    LOG.info(f"Num beam groups {args.num_beam_groups}")
    LOG.info(f"Epsilon cutoff {args.epsilon_cutoff}")
    LOG.info(f"Top P {args.top_p}")
    LOG.info(f"Top K {args.top_k}")
    LOG.info(f"Loading demographic_info at {args.demographic_data_path}")

    prompts_and_concept_stats, prompt_counts = _build_prompt_concept_stats(
        args.demographic_data_path, args.num_proc
    )
    if not prompt_counts:
        raise ValueError(
            "No sequences with more than 4 concepts were found in "
            f"{args.demographic_data_path}; there is no prompt to sample."
        )
    prompts = list(prompt_counts)
    weight_sum = sum(prompt_counts.values())
    prompt_weights = np.asarray(list(prompt_counts.values())) / weight_sum

    device = ppo_trainer.current_device
    num_of_micro_batches = args.batch_size // args.mini_batch_size
    logs = []
    for i in tqdm(range(args.num_of_steps)):
        LOG.info(f"{datetime.datetime.now()}: Batch {i} started")
        # Sample an index (same RNG draw as sampling the prompt itself), so the
        # weight and count are direct lookups instead of an O(n) list.index().
        prompt_index = random.choices(
            range(len(prompts)), weights=prompt_weights, k=1
        )[0]
        random_prompt = prompts[prompt_index]
        LOG.info(
            "%s: Batch %s random_prompt: %s with weight %.2f%% (%d / %s)",
            datetime.datetime.now(),
            i,
            random_prompt,
            prompt_weights[prompt_index] * 100,
            # Exact count; int(weight * weight_sum) could round down by one.
            prompt_counts[random_prompt],
            weight_sum,
        )
        expected_concept_dist = prompts_and_concept_stats[random_prompt]

        # Every row of every micro-batch starts from the same sampled prompt.
        prompt_token_ids = cehrgpt_tokenizer.encode(random_prompt)
        batched_sequences = []
        for _ in range(num_of_micro_batches):
            batched_prompts = torch.tensor(
                [prompt_token_ids] * args.mini_batch_size
            ).to(device)
            mini_batched_sequences = generate_single_batch(
                cehrgpt_model,
                cehrgpt_tokenizer,
                batched_prompts,
                max_new_tokens=args.context_window,
                mini_num_of_concepts=args.min_num_of_concepts,
                top_p=args.top_p,
                top_k=args.top_k,
                temperature=args.temperature,
                repetition_penalty=args.repetition_penalty,
                num_beams=args.num_beams,
                num_beam_groups=args.num_beam_groups,
                epsilon_cutoff=args.epsilon_cutoff,
            )
            # Clear the cache
            torch.cuda.empty_cache()
            batched_sequences.extend(mini_batched_sequences)

        LOG.info(f"{datetime.datetime.now()}: Batch {i} sequence generated")
        reward = compute_marginal_dist_reward(
            batched_sequences, expected_concept_dist, cehrgpt_tokenizer
        )
        LOG.info(f"{datetime.datetime.now()}: Batch {i} KL divergence reward: {reward}")

        # The first four tokens are the prompt (PPO query); the rest is the
        # response. Every sequence receives the same batch-level reward.
        query_tensors = []
        response_tensors = []
        for sequence in batched_sequences:
            query_tensors.append(torch.tensor(cehrgpt_tokenizer.encode(sequence[:4])))
            response_tensors.append(
                torch.tensor(cehrgpt_tokenizer.encode(sequence[4:]))
            )
        rewards = [reward] * len(batched_sequences)

        train_stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        LOG.info(f"{datetime.datetime.now()}: Batch {i} stats: {train_stats}")
        logs.append(reward)
        ppo_trainer.log_stats(stats=train_stats, batch={}, rewards=rewards)

    ppo_trainer.save_pretrained(model_folder_name)
    with open(os.path.join(model_folder_name, "ppo_finetune_stats.pkl"), "wb") as f:
        pickle.dump(logs, f)
279
-
280
-
281
def compute_marginal_dist_reward(
    batched_sequences: List[List[str]],
    expected_concept_dist: Dict[str, float],
    tokenizer: CehrGptTokenizer,
) -> torch.Tensor:
    """Score a batch by how closely its concept marginal matches a reference.

    Concept ids after the first four (prompt) tokens are pooled across all
    sequences and normalized to relative frequencies. Both this distribution and
    ``expected_concept_dist`` are scattered into ``tokenizer.vocab_size``-long
    vectors at the ids returned by ``tokenizer.encode``. A 1e-10 epsilon is added
    (without renormalizing) and the log is taken.

    ``kl_div(input=reference, target=generated, log_target=True)`` sums
    ``exp(target) * (target - input)``, i.e. KL(generated || reference).

    NOTE(review): if several concept ids encode to the same token id, the
    last-assigned probability wins in the scatter.

    Returns:
        The negated KL divergence as a 0-dim float64 tensor (higher is better).
    """
    # Add a small epsilon to avoid log(0)
    epsilon = 1e-10

    def to_log_vocab_dist(concept_dist: Dict[str, float]) -> torch.Tensor:
        # Translate the concept ids to token ids in a dense vocabulary vector.
        dense = np.zeros(tokenizer.vocab_size)
        dense[tokenizer.encode(list(concept_dist.keys()))] = list(
            concept_dist.values()
        )
        return torch.tensor(np.log(dense + epsilon))

    generated_counts = Counter(
        concept_id for sequence in batched_sequences for concept_id in sequence[4:]
    )
    n_generated = sum(generated_counts.values())
    observed_dist = {
        concept_id: count / n_generated
        for concept_id, count in generated_counts.items()
    }

    logprob_dist = to_log_vocab_dist(observed_dist)
    ref_logprob_dist = to_log_vocab_dist(expected_concept_dist)

    # Flip is required due to this issue? :https://github.com/pytorch/pytorch/issues/57459
    divergence = torch.nn.functional.kl_div(
        ref_logprob_dist, logprob_dist, log_target=True, reduction="none"
    )
    return -divergence.sum(-1)
317
-
318
-
319
def create_arg_parser():
    """Build the CLI parser for PPO fine-tuning.

    Extends ``create_inference_base_arg_parser`` with PPO hyper-parameters (each
    forwarded to trl's ``PPOConfig`` in ``main``) and the data/checkpoint options
    used by this script.

    Returns:
        The configured ``argparse``-style parser.
    """
    base_arg_parser = create_inference_base_arg_parser(
        description="Arguments for finetuning cehr-gpt using PPO"
    )
    base_arg_parser.add_argument(
        "--mini_batch_size",
        dest="mini_batch_size",
        action="store",
        type=int,
        required=True,
        help=(
            "Sequences generated per generate() call and the PPO mini-batch size; "
            "batch_size // mini_batch_size micro-batches are generated per step."
        ),
    )
    base_arg_parser.add_argument(
        "--init_kl_coef",
        dest="init_kl_coef",
        action="store",
        type=float,
        required=False,
        default=0.1,
        help="Initial KL penalty coefficient (PPOConfig.init_kl_coef).",
    )
    base_arg_parser.add_argument(
        "--vf_coef",
        dest="vf_coef",
        action="store",
        type=float,
        required=False,
        default=0.1,
        help="Value-function loss coefficient (PPOConfig.vf_coef).",
    )
    base_arg_parser.add_argument(
        "--kl_penalty",
        dest="kl_penalty",
        action="store",
        choices=["kl", "abs", "mse", "full"],
        required=False,
        default="kl",
        help="KL penalty estimator (PPOConfig.kl_penalty).",
    )
    base_arg_parser.add_argument(
        "--gamma",
        dest="gamma",
        action="store",
        type=float,
        required=False,
        default=0.99,
        help="Discount factor (PPOConfig.gamma).",
    )
    base_arg_parser.add_argument(
        "--num_proc",
        dest="num_proc",
        action="store",
        type=int,
        default=4,
        required=False,
        help="Worker processes used to aggregate the demographic concept statistics.",
    )
    base_arg_parser.add_argument(
        "--num_of_steps",
        dest="num_of_steps",
        action="store",
        type=int,
        default=1028,
        required=False,
        help="Number of PPO steps; each step samples one demographic prompt.",
    )
    base_arg_parser.add_argument(
        "--demographic_data_path",
        dest="demographic_data_path",
        action="store",
        help=(
            "Path to the parquet dataset (with concept_ids and num_of_concepts "
            "columns) used to derive demographic prompts and their reference "
            "concept distributions."
        ),
        required=True,
    )
    base_arg_parser.add_argument(
        "--restore_from_checkpoint",
        dest="restore_from_checkpoint",
        action="store_true",
        help=(
            "Resume from the model previously saved in the output folder; falls "
            "back to --model_folder if it cannot be loaded."
        ),
    )
    return base_arg_parser
391
-
392
-
393
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run PPO fine-tuning.
    main(create_arg_parser().parse_args())