cehrgpt 0.0.2__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. cehrgpt/data/hf_cehrgpt_dataset.py +24 -4
  2. cehrgpt/data/hf_cehrgpt_dataset_collator.py +260 -84
  3. cehrgpt/data/hf_cehrgpt_dataset_mapping.py +99 -88
  4. cehrgpt/data/sample_packing_sampler.py +151 -0
  5. cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
  6. cehrgpt/models/config.py +10 -0
  7. cehrgpt/models/hf_cehrgpt.py +243 -73
  8. cehrgpt/models/tokenization_hf_cehrgpt.py +4 -0
  9. cehrgpt/runners/data_utils.py +243 -0
  10. cehrgpt/runners/gpt_runner_util.py +0 -10
  11. cehrgpt/runners/hf_cehrgpt_finetune_runner.py +152 -279
  12. cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +229 -105
  13. cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +42 -0
  14. cehrgpt/runners/hyperparameter_search_util.py +4 -1
  15. cehrgpt/runners/sample_packing_trainer.py +168 -0
  16. cehrgpt/simulations/generate_plots.py +95 -0
  17. cehrgpt/simulations/run_simulation.sh +24 -0
  18. cehrgpt/simulations/time_embedding_simulation.py +250 -0
  19. cehrgpt/simulations/time_token_simulation.py +177 -0
  20. cehrgpt/tools/linear_prob/__init__.py +0 -0
  21. cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +467 -0
  22. cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
  23. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/METADATA +7 -5
  24. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/RECORD +28 -26
  25. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/WHEEL +1 -1
  26. cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
  27. cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
  28. cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
  29. cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
  30. cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
  31. cehrgpt/rl_finetune/ppo_finetune.py +0 -394
  32. cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
  33. cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
  34. /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
  35. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info/licenses}/LICENSE +0 -0
  36. {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.0.dist-info}/top_level.txt +0 -0
cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py (removed)
@@ -1,464 +0,0 @@
-import time
-from typing import List, Optional
-
-import numpy as np
-import torch
-from torch.nn.utils.rnn import pad_sequence
-from trl.core import (
-    WANDB_PADDING,
-    PPODecorators,
-    convert_to_scalar,
-    logprobs_from_logits,
-    stack_dicts,
-    stats_to_np,
-)
-from trl.trainer import PPOTrainer
-
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-
-class CehrGptPPODataCollator:
-    def __init__(self, tokenizer: CehrGptTokenizer, max_length: int):
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-
-    def __call__(self, examples):
-
-        batch = {}
-
-        # Pad sequences to the max length in the batch
-        batch["input_ids"] = pad_sequence(
-            [example["input_ids"] for example in examples],
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id,
-        ).to(torch.int64)
-
-        batch["attention_mask"] = pad_sequence(
-            [example["attention_mask"] for example in examples],
-            batch_first=True,
-            padding_value=0.0,
-        )
-
-        assert (
-            batch["input_ids"].shape[1] <= self.max_length
-        ), f"Invalid input_ids length: {batch['input_ids'].shape[1]}"
-
-        if "value_indicators" in examples[0]:
-            batch["value_indicators"] = pad_sequence(
-                [example["value_indicators"] for example in examples],
-                batch_first=True,
-                padding_value=False,
-            )
-
-        if "values" in examples[0]:
-            batch["values"] = pad_sequence(
-                [example["values"] for example in examples],
-                batch_first=True,
-                padding_value=self.tokenizer.pad_value_token_id,
-            )
-            assert batch["value_indicators"].shape[1] <= self.max_length
-            assert batch["values"].shape[1] <= self.max_length
-
-        return batch
-
-
-class CehrGptPPOTrainer(PPOTrainer):
-    def _step_safety_checker(
-        self,
-        batch_size: int,
-        queries: List[torch.LongTensor],
-        responses: List[torch.LongTensor],
-        scores: List[torch.FloatTensor],
-        values: List[torch.Tensor] = None,
-        value_indicators: List[torch.BoolTensor] = None,
-        masks: Optional[List[torch.LongTensor]] = None,
-    ):
-        """
-        Check if the input data is valid for training.
-
-        Args:
-            batch_size (int):
-                Batch size from the config file.
-            queries (List[`torch.LongTensor`]):
-                List of tensors containing the encoded queries of shape (`query_length`)
-            responses (List[`torch.LongTensor`]):
-                List of tensors containing the encoded responses of shape (`response_length`)
-            scores (List[`torch.FloatTensor`]):
-                List of tensors containing the scores.
-            masks (List[`torch.LongTensor`], *optional*):
-                list of optional tensors containing the masks of shape (`response_length`)
-
-        Returns:
-            `tuple`: The input processed data.
-        """
-        for name, tensor_list in zip(
-            ["queries", "responses", "scores", "values", "value_indicators"],
-            [queries, responses, scores, values, value_indicators],
-        ):
-            if not isinstance(tensor_list, list):
-                raise ValueError(
-                    f"{name} must be a list of tensors - got {type(tensor_list)}"
-                )
-            if not isinstance(tensor_list[0], torch.Tensor):
-                raise ValueError(
-                    f"Elements in {name} must be tensors - got {type(tensor_list[0])}"
-                )
-            if batch_size is not None and len(tensor_list) != batch_size:
-                raise ValueError(
-                    f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
-                )
-
-        # add queries, scores and responses on the correct device
-        queries = [tensor.to(self.current_device) for tensor in queries]
-        responses = [tensor.to(self.current_device) for tensor in responses]
-        scores = [tensor.to(self.current_device) for tensor in scores]
-        masks = (
-            [tensor.to(self.current_device) for tensor in masks]
-            if masks is not None
-            else None
-        )
-        values = (
-            [tensor.to(self.current_device) for tensor in values]
-            if values is not None
-            else None
-        )
-        value_indicators = (
-            [tensor.to(self.current_device) for tensor in value_indicators]
-            if value_indicators is not None
-            else None
-        )
-
-        # squeeze scores if needed
-        for i, score in enumerate(scores):
-            if score.dim() > 1:
-                raise ValueError(
-                    f"Scores must be 1-dimensional - got {score.dim()} for {score}"
-                )
-            elif score.dim() == 1:
-                scores[i] = score.squeeze()
-
-        return queries, responses, scores, values, value_indicators, masks
-
-    @PPODecorators.empty_device_cache()
-    def step(
-        self,
-        queries: List[torch.LongTensor],
-        responses: List[torch.LongTensor],
-        scores: List[torch.FloatTensor],
-        values: List[torch.Tensor] = None,
-        value_indicators: List[torch.BoolTensor] = None,
-        response_masks: Optional[List[torch.LongTensor]] = None,
-    ):
-
-        bs = self.config.batch_size
-
-        queries, responses, scores, values, value_indicators, response_masks = (
-            self._step_safety_checker(
-                bs, queries, responses, scores, values, value_indicators, response_masks
-            )
-        )
-        scores = torch.tensor(scores, device=self.current_device)
-        if self.config.use_score_scaling:
-            # Score scaling
-            scores_mean, scores_std = self.running.update(scores)
-            score_scaling_factor = scores_std + torch.finfo(scores.dtype).eps
-            if self.config.use_score_norm:
-                scores = (scores - scores_mean) / score_scaling_factor
-            else:
-                scores /= score_scaling_factor
-
-        if self.config.score_clip is not None:
-            # Score clipping
-            scores_dtype = scores.dtype
-            scores = torch.clip(
-                scores.float(), -self.config.score_clip, self.config.score_clip
-            ).to(dtype=scores_dtype)
-
-        # if we want to push best model to the hub
-        if hasattr(self, "highest_reward"):
-            if self.compare_step % self.config.compare_steps == 0:
-                curr_mean_reward = scores.mean()
-                # if the best reward ever seen
-                if curr_mean_reward > self.highest_reward:
-                    self.highest_reward = curr_mean_reward
-                    # push model to hub
-                    self.push_to_hub(**self.push_to_hub_kwargs)
-            self.compare_step += 1
-
-        timing = dict()
-        t0 = time.time()
-
-        t = time.time()
-
-        model_inputs = self.prepare_model_inputs(
-            queries, responses, values, value_indicators
-        )
-
-        if self.is_distributed:
-            pad_first = self.tokenizer.padding_side == "left"
-
-            model_inputs["input_ids"] = self.accelerator.pad_across_processes(
-                model_inputs["input_ids"],
-                dim=1,
-                pad_index=self.tokenizer.pad_token_id,
-                pad_first=pad_first,
-            )
-            model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
-                model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
-            )
-            if values is not None:
-                model_inputs["values"] = self.accelerator.pad_across_processes(
-                    model_inputs["values"],
-                    dim=1,
-                    pad_index=self.tokenizer.pad_value_token_id,
-                    pad_first=pad_first,
-                )
-            if value_indicators is not None:
-                model_inputs["value_indicators"] = (
-                    self.accelerator.pad_across_processes(
-                        model_inputs["value_indicators"],
-                        dim=1,
-                        pad_index=False,
-                        pad_first=pad_first,
-                    )
-                )
-            if self.is_encoder_decoder:
-                model_inputs["decoder_input_ids"] = (
-                    self.accelerator.pad_across_processes(
-                        model_inputs["decoder_input_ids"],
-                        dim=1,
-                        pad_index=self.tokenizer.pad_token_id,
-                        pad_first=pad_first,
-                    )
-                )
-                model_inputs["decoder_attention_mask"] = (
-                    self.accelerator.pad_across_processes(
-                        model_inputs["decoder_attention_mask"],
-                        dim=1,
-                        pad_index=0,
-                        pad_first=pad_first,
-                    )
-                )
-
-        model_inputs_names = list(model_inputs.keys())
-
-        full_kl_penalty = self.config.kl_penalty == "full"
-
-        with torch.no_grad():
-            all_logprobs, logits_or_none, states_values, masks = (
-                self.batched_forward_pass(
-                    self.model,
-                    queries,
-                    responses,
-                    model_inputs,
-                    response_masks=response_masks,
-                    return_logits=full_kl_penalty,
-                )
-            )
-            with self.optional_peft_ctx():
-                ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
-                    self.model if self.is_peft_model else self.ref_model,
-                    queries,
-                    responses,
-                    model_inputs,
-                    return_logits=full_kl_penalty,
-                )
-
-        timing["time/ppo/forward_pass"] = time.time() - t
-
-        with torch.no_grad():
-            t = time.time()
-            if full_kl_penalty:
-                active_full_logprobs = logprobs_from_logits(
-                    logits_or_none, None, gather=False
-                )
-                ref_full_logprobs = logprobs_from_logits(
-                    ref_logits_or_none, None, gather=False
-                )
-
-                rewards, non_score_reward, kls = self.compute_rewards(
-                    scores, active_full_logprobs, ref_full_logprobs, masks
-                )
-            else:
-                rewards, non_score_reward, kls = self.compute_rewards(
-                    scores, all_logprobs, ref_logprobs, masks
-                )
-            timing["time/ppo/compute_rewards"] = time.time() - t
-
-            t = time.time()
-            states_values, advantages, returns = self.compute_advantages(
-                states_values, rewards, masks
-            )
-            timing["time/ppo/compute_advantages"] = time.time() - t
-
-        # upcast to float32 to avoid dataset issues
-        batch_dict = {
-            "queries": queries,
-            "responses": responses,
-            "logprobs": all_logprobs.to(torch.float32),
-            "states_values": states_values.to(torch.float32),
-            "masks": masks,
-            "advantages": advantages,
-            "returns": returns,
-        }
-        batch_dict.update(model_inputs)
-
-        t = time.time()
-        all_stats = []
-        early_stop = False
-        for _ in range(self.config.ppo_epochs):
-            if early_stop:
-                break
-            b_inds = np.random.permutation(bs)
-            for backward_batch_start in range(0, bs, self.config.backward_batch_size):
-                backward_batch_end = (
-                    backward_batch_start + self.config.backward_batch_size
-                )
-                backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
-
-                for mini_batch_start in range(
-                    0, self.config.backward_batch_size, self.config.mini_batch_size
-                ):
-                    mini_batch_end = mini_batch_start + self.config.mini_batch_size
-                    mini_batch_inds = backward_batch_inds[
-                        mini_batch_start:mini_batch_end
-                    ]
-                    mini_batch_dict = {
-                        "logprobs": batch_dict["logprobs"][mini_batch_inds],
-                        "states_values": batch_dict["states_values"][mini_batch_inds],
-                        "masks": batch_dict["masks"][mini_batch_inds],
-                        # hacks: the queries and responses are ragged.
-                        "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
-                        "responses": [
-                            batch_dict["responses"][i] for i in mini_batch_inds
-                        ],
-                        "advantages": batch_dict["advantages"][mini_batch_inds],
-                        "returns": batch_dict["returns"][mini_batch_inds],
-                    }
-                    for k in model_inputs_names:
-                        mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
-                    with self.accelerator.accumulate(self.model):
-                        model_inputs = {
-                            k: mini_batch_dict[k] for k in model_inputs_names
-                        }
-
-                        logprobs, logits, vpreds, _ = self.batched_forward_pass(
-                            self.model,
-                            mini_batch_dict["queries"],
-                            mini_batch_dict["responses"],
-                            model_inputs,
-                            return_logits=True,
-                        )
-                        train_stats = self.train_minibatch(
-                            mini_batch_dict["logprobs"],
-                            mini_batch_dict["states_values"],
-                            logprobs,
-                            logits,
-                            vpreds,
-                            mini_batch_dict["masks"],
-                            mini_batch_dict["advantages"],
-                            mini_batch_dict["returns"],
-                        )
-                        all_stats.append(train_stats)
-
-            # typically, early stopping is done at the epoch level
-            if self.config.early_stopping:
-                policykl = train_stats["policy/policykl"]
-                early_stop = self._early_stop(policykl)
-                if early_stop:
-                    break
-
-        timing["time/ppo/optimize_step"] = time.time() - t
-
-        t = time.time()
-        train_stats = stack_dicts(all_stats)
-
-        # reshape advantages/ratios such that they are not averaged.
-        train_stats["policy/advantages"] = torch.flatten(
-            train_stats["policy/advantages"]
-        ).unsqueeze(0)
-        train_stats["policy/advantages"] = torch.nan_to_num(
-            train_stats["policy/advantages"], WANDB_PADDING
-        )
-        train_stats["policy/ratio"] = torch.flatten(
-            train_stats["policy/ratio"]
-        ).unsqueeze(0)
-
-        stats = self.record_step_stats(
-            scores=scores,
-            logprobs=all_logprobs,
-            ref_logprobs=ref_logprobs,
-            non_score_reward=non_score_reward,
-            train_stats=train_stats,
-            kl_coef=self.kl_ctl.value,
-            masks=masks,
-            queries=queries,
-            responses=responses,
-            kls=kls,
-        )
-        # Gather/Reduce stats from all processes
-        if self.is_distributed:
-            stats = self.gather_stats(stats)
-        stats = stats_to_np(stats)
-        timing["time/ppo/calc_stats"] = time.time() - t
-        stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
-
-        # Update the KL control - multiply the batch_size by the number of processes
-        self.kl_ctl.update(
-            stats["objective/kl"],
-            self.config.batch_size * self.accelerator.num_processes,
-        )
-
-        # Log the total ppo time
-        timing["time/ppo/total"] = time.time() - t0
-        stats.update(timing)
-
-        # post-process stats for tensorboard and other loggers
-        if self.config.log_with != "wandb":
-            stats = convert_to_scalar(stats)
-
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
-
-        return stats
-
-    def prepare_model_inputs(
-        self,
-        queries: torch.Tensor,
-        responses: torch.Tensor,
-        values: torch.Tensor,
-        value_indicators: torch.Tensor,
-    ):
-        if self.is_encoder_decoder:
-            input_data = self.data_collator(
-                [
-                    {"input_ids": q, "attention_mask": torch.ones_like(q)}
-                    for q in queries
-                ]
-            ).to(self.current_device)
-
-            decoder_inputs = self.data_collator(
-                [
-                    {"input_ids": r, "attention_mask": torch.ones_like(r)}
-                    for r in responses
-                ]
-            ).to(self.current_device)
-            input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
-            input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
-        else:
-            input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
-            input_data = self.data_collator(
-                [
-                    {
-                        "input_ids": ids,
-                        "attention_mask": torch.ones_like(ids),
-                        "values": v_s,
-                        "value_indicators": v_indicators,
-                    }
-                    for ids, v_s, v_indicators in zip(
-                        input_ids, values, value_indicators
-                    )
-                ]
-            )
-        input_data.pop("labels", None)  # we don't want to compute LM losses
-        return input_data