cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +267 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +71 -0
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +61 -0
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +224 -0
- cehrgpt/generation/omop_converter_batch.py +3 -0
- cehrgpt/models/hf_cehrgpt.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +2 -2
- cehrgpt/rl_finetune/__init__.py +0 -0
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +586 -0
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +464 -0
- cehrgpt/rl_finetune/ppo_finetune.py +394 -0
- cehrgpt/rl_finetune/ppo_finetune_v2.py +373 -0
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +119 -0
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +24 -3
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +44 -8
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +4 -0
- cehrgpt/tools/generate_causal_patient_split_by_age.py +146 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/METADATA +52 -6
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/RECORD +22 -12
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/WHEEL +1 -1
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/LICENSE +0 -0
- {cehrgpt-0.0.1.dist-info → cehrgpt-0.0.2.dist-info}/top_level.txt +0 -0
cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py
@@ -0,0 +1,464 @@
+import time
+from typing import List, Optional
+
+import numpy as np
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from trl.core import (
+    WANDB_PADDING,
+    PPODecorators,
+    convert_to_scalar,
+    logprobs_from_logits,
+    stack_dicts,
+    stats_to_np,
+)
+from trl.trainer import PPOTrainer
+
+from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
+
+
+class CehrGptPPODataCollator:
+    def __init__(self, tokenizer: CehrGptTokenizer, max_length: int):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __call__(self, examples):
+
+        batch = {}
+
+        # Pad sequences to the max length in the batch
+        batch["input_ids"] = pad_sequence(
+            [example["input_ids"] for example in examples],
+            batch_first=True,
+            padding_value=self.tokenizer.pad_token_id,
+        ).to(torch.int64)
+
+        batch["attention_mask"] = pad_sequence(
+            [example["attention_mask"] for example in examples],
+            batch_first=True,
+            padding_value=0.0,
+        )
+
+        assert (
+            batch["input_ids"].shape[1] <= self.max_length
+        ), f"Invalid input_ids length: {batch['input_ids'].shape[1]}"
+
+        if "value_indicators" in examples[0]:
+            batch["value_indicators"] = pad_sequence(
+                [example["value_indicators"] for example in examples],
+                batch_first=True,
+                padding_value=False,
+            )
+
+        if "values" in examples[0]:
+            batch["values"] = pad_sequence(
+                [example["values"] for example in examples],
+                batch_first=True,
+                padding_value=self.tokenizer.pad_value_token_id,
+            )
+            assert batch["value_indicators"].shape[1] <= self.max_length
+            assert batch["values"].shape[1] <= self.max_length
+
+        return batch
+
+
+class CehrGptPPOTrainer(PPOTrainer):
+    def _step_safety_checker(
+        self,
+        batch_size: int,
+        queries: List[torch.LongTensor],
+        responses: List[torch.LongTensor],
+        scores: List[torch.FloatTensor],
+        values: List[torch.Tensor] = None,
+        value_indicators: List[torch.BoolTensor] = None,
+        masks: Optional[List[torch.LongTensor]] = None,
+    ):
+        """
+        Check if the input data is valid for training.
+
+        Args:
+            batch_size (int):
+                Batch size from the config file.
+            queries (List[`torch.LongTensor`]):
+                List of tensors containing the encoded queries of shape (`query_length`)
+            responses (List[`torch.LongTensor`]):
+                List of tensors containing the encoded responses of shape (`response_length`)
+            scores (List[`torch.FloatTensor`]):
+                List of tensors containing the scores.
+            masks (List[`torch.LongTensor`], *optional*):
+                list of optional tensors containing the masks of shape (`response_length`)
+
+        Returns:
+            `tuple`: The input processed data.
+        """
+        for name, tensor_list in zip(
+            ["queries", "responses", "scores", "values", "value_indicators"],
+            [queries, responses, scores, values, value_indicators],
+        ):
+            if not isinstance(tensor_list, list):
+                raise ValueError(
+                    f"{name} must be a list of tensors - got {type(tensor_list)}"
+                )
+            if not isinstance(tensor_list[0], torch.Tensor):
+                raise ValueError(
+                    f"Elements in {name} must be tensors - got {type(tensor_list[0])}"
+                )
+            if batch_size is not None and len(tensor_list) != batch_size:
+                raise ValueError(
+                    f"Batch size ({batch_size}) does not match number of examples - but got {len(tensor_list)} for: {name}"
+                )
+
+        # add queries, scores and responses on the correct device
+        queries = [tensor.to(self.current_device) for tensor in queries]
+        responses = [tensor.to(self.current_device) for tensor in responses]
+        scores = [tensor.to(self.current_device) for tensor in scores]
+        masks = (
+            [tensor.to(self.current_device) for tensor in masks]
+            if masks is not None
+            else None
+        )
+        values = (
+            [tensor.to(self.current_device) for tensor in values]
+            if values is not None
+            else None
+        )
+        value_indicators = (
+            [tensor.to(self.current_device) for tensor in value_indicators]
+            if value_indicators is not None
+            else None
+        )
+
+        # squeeze scores if needed
+        for i, score in enumerate(scores):
+            if score.dim() > 1:
+                raise ValueError(
+                    f"Scores must be 1-dimensional - got {score.dim()} for {score}"
+                )
+            elif score.dim() == 1:
+                scores[i] = score.squeeze()
+
+        return queries, responses, scores, values, value_indicators, masks
+
+    @PPODecorators.empty_device_cache()
+    def step(
+        self,
+        queries: List[torch.LongTensor],
+        responses: List[torch.LongTensor],
+        scores: List[torch.FloatTensor],
+        values: List[torch.Tensor] = None,
+        value_indicators: List[torch.BoolTensor] = None,
+        response_masks: Optional[List[torch.LongTensor]] = None,
+    ):
+
+        bs = self.config.batch_size
+
+        queries, responses, scores, values, value_indicators, response_masks = (
+            self._step_safety_checker(
+                bs, queries, responses, scores, values, value_indicators, response_masks
+            )
+        )
+        scores = torch.tensor(scores, device=self.current_device)
+        if self.config.use_score_scaling:
+            # Score scaling
+            scores_mean, scores_std = self.running.update(scores)
+            score_scaling_factor = scores_std + torch.finfo(scores.dtype).eps
+            if self.config.use_score_norm:
+                scores = (scores - scores_mean) / score_scaling_factor
+            else:
+                scores /= score_scaling_factor
+
+        if self.config.score_clip is not None:
+            # Score clipping
+            scores_dtype = scores.dtype
+            scores = torch.clip(
+                scores.float(), -self.config.score_clip, self.config.score_clip
+            ).to(dtype=scores_dtype)
+
+        # if we want to push best model to the hub
+        if hasattr(self, "highest_reward"):
+            if self.compare_step % self.config.compare_steps == 0:
+                curr_mean_reward = scores.mean()
+                # if the best reward ever seen
+                if curr_mean_reward > self.highest_reward:
+                    self.highest_reward = curr_mean_reward
+                    # push model to hub
+                    self.push_to_hub(**self.push_to_hub_kwargs)
+            self.compare_step += 1
+
+        timing = dict()
+        t0 = time.time()
+
+        t = time.time()
+
+        model_inputs = self.prepare_model_inputs(
+            queries, responses, values, value_indicators
+        )
+
+        if self.is_distributed:
+            pad_first = self.tokenizer.padding_side == "left"
+
+            model_inputs["input_ids"] = self.accelerator.pad_across_processes(
+                model_inputs["input_ids"],
+                dim=1,
+                pad_index=self.tokenizer.pad_token_id,
+                pad_first=pad_first,
+            )
+            model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
+                model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
+            )
+            if values is not None:
+                model_inputs["values"] = self.accelerator.pad_across_processes(
+                    model_inputs["values"],
+                    dim=1,
+                    pad_index=self.tokenizer.pad_value_token_id,
+                    pad_first=pad_first,
+                )
+            if value_indicators is not None:
+                model_inputs["value_indicators"] = (
+                    self.accelerator.pad_across_processes(
+                        model_inputs["value_indicators"],
+                        dim=1,
+                        pad_index=False,
+                        pad_first=pad_first,
+                    )
+                )
+            if self.is_encoder_decoder:
+                model_inputs["decoder_input_ids"] = (
+                    self.accelerator.pad_across_processes(
+                        model_inputs["decoder_input_ids"],
+                        dim=1,
+                        pad_index=self.tokenizer.pad_token_id,
+                        pad_first=pad_first,
+                    )
+                )
+                model_inputs["decoder_attention_mask"] = (
+                    self.accelerator.pad_across_processes(
+                        model_inputs["decoder_attention_mask"],
+                        dim=1,
+                        pad_index=0,
+                        pad_first=pad_first,
+                    )
+                )
+
+        model_inputs_names = list(model_inputs.keys())
+
+        full_kl_penalty = self.config.kl_penalty == "full"
+
+        with torch.no_grad():
+            all_logprobs, logits_or_none, states_values, masks = (
+                self.batched_forward_pass(
+                    self.model,
+                    queries,
+                    responses,
+                    model_inputs,
+                    response_masks=response_masks,
+                    return_logits=full_kl_penalty,
+                )
+            )
+            with self.optional_peft_ctx():
+                ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
+                    self.model if self.is_peft_model else self.ref_model,
+                    queries,
+                    responses,
+                    model_inputs,
+                    return_logits=full_kl_penalty,
+                )
+
+        timing["time/ppo/forward_pass"] = time.time() - t
+
+        with torch.no_grad():
+            t = time.time()
+            if full_kl_penalty:
+                active_full_logprobs = logprobs_from_logits(
+                    logits_or_none, None, gather=False
+                )
+                ref_full_logprobs = logprobs_from_logits(
+                    ref_logits_or_none, None, gather=False
+                )
+
+                rewards, non_score_reward, kls = self.compute_rewards(
+                    scores, active_full_logprobs, ref_full_logprobs, masks
+                )
+            else:
+                rewards, non_score_reward, kls = self.compute_rewards(
+                    scores, all_logprobs, ref_logprobs, masks
+                )
+            timing["time/ppo/compute_rewards"] = time.time() - t
+
+            t = time.time()
+            states_values, advantages, returns = self.compute_advantages(
+                states_values, rewards, masks
+            )
+            timing["time/ppo/compute_advantages"] = time.time() - t
+
+        # upcast to float32 to avoid dataset issues
+        batch_dict = {
+            "queries": queries,
+            "responses": responses,
+            "logprobs": all_logprobs.to(torch.float32),
+            "states_values": states_values.to(torch.float32),
+            "masks": masks,
+            "advantages": advantages,
+            "returns": returns,
+        }
+        batch_dict.update(model_inputs)
+
+        t = time.time()
+        all_stats = []
+        early_stop = False
+        for _ in range(self.config.ppo_epochs):
+            if early_stop:
+                break
+            b_inds = np.random.permutation(bs)
+            for backward_batch_start in range(0, bs, self.config.backward_batch_size):
+                backward_batch_end = (
+                    backward_batch_start + self.config.backward_batch_size
+                )
+                backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
+
+                for mini_batch_start in range(
+                    0, self.config.backward_batch_size, self.config.mini_batch_size
+                ):
+                    mini_batch_end = mini_batch_start + self.config.mini_batch_size
+                    mini_batch_inds = backward_batch_inds[
+                        mini_batch_start:mini_batch_end
+                    ]
+                    mini_batch_dict = {
+                        "logprobs": batch_dict["logprobs"][mini_batch_inds],
+                        "states_values": batch_dict["states_values"][mini_batch_inds],
+                        "masks": batch_dict["masks"][mini_batch_inds],
+                        # hacks: the queries and responses are ragged.
+                        "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
+                        "responses": [
+                            batch_dict["responses"][i] for i in mini_batch_inds
+                        ],
+                        "advantages": batch_dict["advantages"][mini_batch_inds],
+                        "returns": batch_dict["returns"][mini_batch_inds],
+                    }
+                    for k in model_inputs_names:
+                        mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
+                    with self.accelerator.accumulate(self.model):
+                        model_inputs = {
+                            k: mini_batch_dict[k] for k in model_inputs_names
+                        }
+
+                        logprobs, logits, vpreds, _ = self.batched_forward_pass(
+                            self.model,
+                            mini_batch_dict["queries"],
+                            mini_batch_dict["responses"],
+                            model_inputs,
+                            return_logits=True,
+                        )
+                        train_stats = self.train_minibatch(
+                            mini_batch_dict["logprobs"],
+                            mini_batch_dict["states_values"],
+                            logprobs,
+                            logits,
+                            vpreds,
+                            mini_batch_dict["masks"],
+                            mini_batch_dict["advantages"],
+                            mini_batch_dict["returns"],
+                        )
+                        all_stats.append(train_stats)
+
+            # typically, early stopping is done at the epoch level
+            if self.config.early_stopping:
+                policykl = train_stats["policy/policykl"]
+                early_stop = self._early_stop(policykl)
+                if early_stop:
+                    break
+
+        timing["time/ppo/optimize_step"] = time.time() - t
+
+        t = time.time()
+        train_stats = stack_dicts(all_stats)
+
+        # reshape advantages/ratios such that they are not averaged.
+        train_stats["policy/advantages"] = torch.flatten(
+            train_stats["policy/advantages"]
+        ).unsqueeze(0)
+        train_stats["policy/advantages"] = torch.nan_to_num(
+            train_stats["policy/advantages"], WANDB_PADDING
+        )
+        train_stats["policy/ratio"] = torch.flatten(
+            train_stats["policy/ratio"]
+        ).unsqueeze(0)
+
+        stats = self.record_step_stats(
+            scores=scores,
+            logprobs=all_logprobs,
+            ref_logprobs=ref_logprobs,
+            non_score_reward=non_score_reward,
+            train_stats=train_stats,
+            kl_coef=self.kl_ctl.value,
+            masks=masks,
+            queries=queries,
+            responses=responses,
+            kls=kls,
+        )
+        # Gather/Reduce stats from all processes
+        if self.is_distributed:
+            stats = self.gather_stats(stats)
+        stats = stats_to_np(stats)
+        timing["time/ppo/calc_stats"] = time.time() - t
+        stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
+
+        # Update the KL control - multiply the batch_size by the number of processes
+        self.kl_ctl.update(
+            stats["objective/kl"],
+            self.config.batch_size * self.accelerator.num_processes,
+        )
+
+        # Log the total ppo time
+        timing["time/ppo/total"] = time.time() - t0
+        stats.update(timing)
+
+        # post-process stats for tensorboard and other loggers
+        if self.config.log_with != "wandb":
+            stats = convert_to_scalar(stats)
+
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
+        return stats
+
+    def prepare_model_inputs(
+        self,
+        queries: torch.Tensor,
+        responses: torch.Tensor,
+        values: torch.Tensor,
+        value_indicators: torch.Tensor,
+    ):
+        if self.is_encoder_decoder:
+            input_data = self.data_collator(
+                [
+                    {"input_ids": q, "attention_mask": torch.ones_like(q)}
+                    for q in queries
+                ]
+            ).to(self.current_device)
+
+            decoder_inputs = self.data_collator(
+                [
+                    {"input_ids": r, "attention_mask": torch.ones_like(r)}
+                    for r in responses
+                ]
+            ).to(self.current_device)
+            input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
+            input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
+        else:
+            input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
+            input_data = self.data_collator(
+                [
+                    {
+                        "input_ids": ids,
+                        "attention_mask": torch.ones_like(ids),
+                        "values": v_s,
+                        "value_indicators": v_indicators,
+                    }
+                    for ids, v_s, v_indicators in zip(
+                        input_ids, values, value_indicators
+                    )
+                ]
+            )
+        input_data.pop("labels", None)  # we don't want to compute LM losses
+        return input_data
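The new `CehrGptPPODataCollator` right-pads `input_ids` and `attention_mask` (plus `values`/`value_indicators` when present) to the longest sequence in the batch, and `CehrGptPPOTrainer.step` threads those extra value tensors through TRL's PPO loop. A minimal sketch of the collator in isolation is shown below; the `DummyTokenizer` stand-in and the toy token ids are illustrative assumptions, not part of the package.

```python
import torch

from cehrgpt.rl_finetune.cehrgpt_ppo_trainer import CehrGptPPODataCollator


class DummyTokenizer:
    """Stand-in exposing only the attributes the collator reads (illustrative)."""

    pad_token_id = 0
    pad_value_token_id = 0


# Two ragged examples; the collator right-pads them to a common length.
examples = [
    {"input_ids": torch.tensor([5, 6, 7]), "attention_mask": torch.tensor([1, 1, 1])},
    {"input_ids": torch.tensor([5, 6]), "attention_mask": torch.tensor([1, 1])},
]

collator = CehrGptPPODataCollator(DummyTokenizer(), max_length=16)
batch = collator(examples)
print(batch["input_ids"])       # second row padded with pad_token_id (0)
print(batch["attention_mask"])  # second row padded with 0
```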