cehrgpt 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,586 @@
1
+ import warnings
2
+ from collections import defaultdict
3
+ from copy import deepcopy
4
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from accelerate.utils import is_deepspeed_available
10
+ from datasets import Dataset
11
+ from transformers import PreTrainedModel, Trainer
12
+ from transformers.trainer_callback import TrainerCallback
13
+ from transformers.trainer_utils import EvalLoopOutput
14
+ from trl.trainer.callbacks import SyncRefModelCallback
15
+ from trl.trainer.dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
16
+ from trl.trainer.utils import RunningMoments, cap_exp, disable_dropout_in_model
17
+
18
+ from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
19
+ from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
20
+ from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
21
+
22
+ if is_deepspeed_available():
23
+ import deepspeed
24
+
25
+
26
+ class CehrGptDPOTrainer(Trainer):
27
+ def __init__(
28
+ self,
29
+ model: CEHRGPT2LMHeadModel,
30
+ ref_model: CEHRGPT2LMHeadModel,
31
+ tokenizer: CehrGptTokenizer,
32
+ args: DPOConfig,
33
+ data_collator: CehrGptDPODataCollator,
34
+ train_dataset: Optional[Dataset] = None,
35
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
36
+ callbacks: Optional[List[TrainerCallback]] = None,
37
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
38
+ None,
39
+ None,
40
+ ),
41
+ preprocess_logits_for_metrics: Optional[
42
+ Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
43
+ ] = None,
44
+ disable_dropout: bool = True,
45
+ compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
46
+ ):
47
+ if ref_model is model:
48
+ raise ValueError(
49
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
50
+ "same as `model`, you must pass a copy of it, or `None` if you use peft."
51
+ )
52
+
53
+ if getattr(args, "gradient_checkpointing", False):
54
+ # For backward compatibility with older versions of transformers
55
+ if hasattr(model, "enable_input_require_grads"):
56
+ model.enable_input_require_grads()
57
+ else:
58
+
59
+ def make_inputs_require_grad(module, input, output):
60
+ output.requires_grad_(True)
61
+
62
+ model.get_input_embeddings().register_forward_hook(
63
+ make_inputs_require_grad
64
+ )
65
+
66
+ if model is not None:
67
+ self.is_encoder_decoder = model.config.is_encoder_decoder
68
+ elif args.is_encoder_decoder is None:
69
+ raise ValueError(
70
+ "When no model is provided, you need to pass the parameter is_encoder_decoder to the DPOTrainer/DPOConfig."
71
+ )
72
+ else:
73
+ self.is_encoder_decoder = args.is_encoder_decoder
74
+
75
+ self.tokenizer = tokenizer
76
+ self.ref_model = ref_model
77
+
78
+ if not disable_dropout:
79
+ warnings.warn(
80
+ "You passed `disable_dropout` to the DPOTrainer; the value you passed will override the one in the `DPOConfig`."
81
+ )
82
+ args.disable_dropout = disable_dropout
83
+
84
+ if args.disable_dropout:
85
+ disable_dropout_in_model(model)
86
+ if self.ref_model is not None:
87
+ disable_dropout_in_model(self.ref_model)
88
+
89
+ self.beta = args.beta
90
+ self.label_smoothing = args.label_smoothing
91
+ self.loss_type = args.loss_type
92
+ self._stored_metrics = defaultdict(lambda: defaultdict(list))
93
+ self.f_divergence_type = args.f_divergence_type
94
+ self.f_divergence_params = {
95
+ FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef
96
+ }
97
+ self.reference_free = args.reference_free
98
+ # α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
99
+ # weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
100
+ # DPO loss. The paper recommends `rpo_alpha=1.0`.
101
+ self.rpo_alpha = args.rpo_alpha
102
+ self.vs_token_id = tokenizer._convert_token_to_id("VS")
103
+ if self.vs_token_id == tokenizer._oov_token_id:
104
+ self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
105
+ self.ve_token_id = tokenizer._convert_token_to_id("VE")
106
+ if self.ve_token_id == tokenizer._oov_token_id:
107
+ self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
108
+ self.non_active_token_ids = [
109
+ tokenizer.pad_token_id,
110
+ self.vs_token_id,
111
+ self.ve_token_id,
112
+ ]
113
+
114
+ super().__init__(
115
+ model=model,
116
+ args=args,
117
+ data_collator=data_collator,
118
+ train_dataset=train_dataset,
119
+ eval_dataset=eval_dataset,
120
+ tokenizer=tokenizer,
121
+ compute_metrics=compute_metrics,
122
+ callbacks=callbacks,
123
+ optimizers=optimizers,
124
+ preprocess_logits_for_metrics=preprocess_logits_for_metrics,
125
+ )
126
+
127
+ if not hasattr(self, "accelerator"):
128
+ raise AttributeError(
129
+ "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
130
+ )
131
+
132
+ if self.is_deepspeed_enabled:
133
+ self.ref_model = self._prepare_deepspeed(self.ref_model)
134
+ else:
135
+ self.ref_model = self.accelerator.prepare_model(
136
+ self.ref_model, evaluation_mode=True
137
+ )
138
+ self.add_callback(
139
+ SyncRefModelCallback(
140
+ ref_model=self.ref_model, accelerator=self.accelerator
141
+ )
142
+ )
143
+ if self.loss_type == "bco_pair":
144
+ self.running = RunningMoments(self.accelerator)
145
+
146
+ def _prepare_deepspeed(self, model: CEHRGPT2LMHeadModel):
147
+ # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
148
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
149
+ config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
150
+
151
+ if model is not None:
152
+ if hasattr(model, "config"):
153
+ hidden_size = (
154
+ max(model.config.hidden_sizes)
155
+ if getattr(model.config, "hidden_sizes", None)
156
+ else getattr(model.config, "hidden_size", None)
157
+ )
158
+ if (
159
+ hidden_size is not None
160
+ and config_kwargs["zero_optimization"]["stage"] == 3
161
+ ):
162
+ # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
163
+ # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
164
+ config_kwargs.update(
165
+ {
166
+ "zero_optimization.reduce_bucket_size": hidden_size
167
+ * hidden_size,
168
+ "zero_optimization.stage3_param_persistence_threshold": 10
169
+ * hidden_size,
170
+ "zero_optimization.stage3_prefetch_bucket_size": 0.9
171
+ * hidden_size
172
+ * hidden_size,
173
+ }
174
+ )
175
+
176
+ # If ZeRO-3 is used, we shard both the active and reference model.
177
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
178
+ if config_kwargs["zero_optimization"]["stage"] != 3:
179
+ config_kwargs["zero_optimization"]["stage"] = 0
180
+ model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
181
+ model.eval()
182
+ return model
183
+
184
+ def forward(
185
+ self,
186
+ model: CEHRGPT2LMHeadModel,
187
+ batch: Dict[str, torch.Tensor],
188
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
189
+ """Run the given model separately on the chosen and the rejected sequences of the batch.
190
+
191
+ Returns the sequence log probabilities and logits for the chosen and rejected inputs, together with the NLL loss on the chosen sequences.
192
+ """
193
+ chosen_outputs = model(
194
+ input_ids=batch["chosen_input_ids"],
195
+ attention_mask=batch["chosen_attention_mask"],
196
+ value_indicators=(
197
+ batch["chosen_value_indicators"]
198
+ if "chosen_value_indicators" in batch
199
+ else None
200
+ ),
201
+ values=batch["chosen_values"] if "chosen_values" in batch else None,
202
+ )
203
+
204
+ chosen_logps, chosen_logits = self.get_batch_logps(
205
+ chosen_outputs.logits,
206
+ batch["chosen_input_ids"],
207
+ pad_token_ids=[self.tokenizer.pad_token_id],
208
+ )
209
+
210
+ # move labels to correct device to enable model parallelism
211
+ labels = batch["chosen_input_ids"].to(chosen_outputs.logits.device)
212
+ # Shift so that tokens < n predict n
213
+ shift_logits = chosen_outputs.logits[..., :-1, :].contiguous()
214
+ shift_labels = labels[..., 1:].contiguous()
215
+ # Flatten the tokens
216
+ loss_fct = nn.CrossEntropyLoss()
217
+ nll_loss = loss_fct(
218
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
219
+ )
220
+
221
+ rejected_outputs = model(
222
+ input_ids=batch["rejected_input_ids"],
223
+ attention_mask=batch["rejected_attention_mask"],
224
+ value_indicators=(
225
+ batch["rejected_value_indicators"]
226
+ if "rejected_value_indicators" in batch
227
+ else None
228
+ ),
229
+ values=batch["rejected_values"] if "rejected_values" in batch else None,
230
+ )
231
+
232
+ rejected_logps, rejected_logits = self.get_batch_logps(
233
+ rejected_outputs.logits,
234
+ batch["rejected_input_ids"],
235
+ pad_token_ids=[self.tokenizer.pad_token_id],
236
+ )
237
+
238
+ return (
239
+ chosen_logps,
240
+ rejected_logps,
241
+ chosen_logits,
242
+ rejected_logits,
243
+ nll_loss,
244
+ )
245
+
246
+ def dpo_loss(
247
+ self,
248
+ policy_chosen_logps: torch.Tensor,
249
+ policy_rejected_logps: torch.Tensor,
250
+ reference_chosen_logps: torch.Tensor,
251
+ reference_rejected_logps: torch.Tensor,
252
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
253
+ """Compute the DPO loss for a batch of policy and reference model log probabilities.
254
+
255
+ Args:
256
+ policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
257
+ policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
258
+ reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
259
+ reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
260
+
261
+ Returns:
262
+ A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
263
+ The losses tensor contains the DPO loss for each example in the batch.
264
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
265
+ """
266
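+ # With the default KL f-divergence and reference_free=False, the `logits` margin computed below is
+ #     z = [log pi_theta(y_w|x) - log pi_ref(y_w|x)] - [log pi_theta(y_l|x) - log pi_ref(y_l|x)],
+ # and the default "sigmoid" branch implements the standard DPO objective with label smoothing eps:
+ #     losses = -(1 - eps) * logsigmoid(beta * z) - eps * logsigmoid(-beta * z)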
+ chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - (
267
+ not self.reference_free
268
+ ) * reference_chosen_logps.to(self.accelerator.device)
269
+ rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - (
270
+ not self.reference_free
271
+ ) * reference_rejected_logps.to(self.accelerator.device)
272
+
273
+ if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
274
+ # The alpha-divergence formula: (1 - u^-alpha) / alpha
275
+ # The divergence difference between the chosen and rejected sample is:
276
+ # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
277
+ # = (u[l]^-alpha - u[w]^-alpha) / alpha
278
+ # where u[w] and u[l] are the policy/reference probability ratios
279
+ # for the chosen and rejected samples, respectively.
280
+ alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
281
+ if (
282
+ self.f_divergence_params
283
+ and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY
284
+ in self.f_divergence_params
285
+ ):
286
+ alpha_coef = float(
287
+ self.f_divergence_params[
288
+ FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY
289
+ ]
290
+ )
291
+ logits = (
292
+ cap_exp(rejected_logratios * -alpha_coef)
293
+ - cap_exp(chosen_logratios * -alpha_coef)
294
+ ) / alpha_coef
295
+ else:
296
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
297
+ if self.reference_free:
298
+ ref_logratios = torch.tensor(
299
+ [0], dtype=pi_logratios.dtype, device=pi_logratios.device
300
+ )
301
+ else:
302
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
303
+
304
+ pi_logratios = pi_logratios.to(self.accelerator.device)
305
+ ref_logratios = ref_logratios.to(self.accelerator.device)
306
+ logits = pi_logratios - ref_logratios
307
+
308
+ if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
309
+ # The js-divergence formula: log(2 * u / (1 + u))
310
+ # The divergence difference between the chosen and rejected sample is:
311
+ # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
312
+ # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
313
+ # where u[w] and u[l] are the policy/reference probability ratios
314
+ # for the chosen and rejected samples, respectively.
315
+ logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
316
+
317
+ # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
318
+ # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
319
+ # calculates a conservative DPO loss.
320
+ if self.loss_type == "sigmoid":
321
+ losses = (
322
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
323
+ - F.logsigmoid(-self.beta * logits) * self.label_smoothing
324
+ )
325
+ elif self.loss_type == "robust":
326
+ losses = (
327
+ -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
328
+ + F.logsigmoid(-self.beta * logits) * self.label_smoothing
329
+ ) / (1 - 2 * self.label_smoothing)
330
+ elif self.loss_type == "exo_pair":
331
+ # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856
332
+ import math
333
+
334
+ if self.label_smoothing == 0:
335
+ self.label_smoothing = 1e-3
336
+ losses = (self.beta * logits).sigmoid() * (
337
+ F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing)
338
+ ) + (-self.beta * logits).sigmoid() * (
339
+ F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)
340
+ )
341
+ elif self.loss_type == "hinge":
342
+ losses = torch.relu(1 - self.beta * logits)
343
+ elif self.loss_type == "ipo":
344
+ # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
345
+ losses = (logits - 1 / (2 * self.beta)) ** 2
346
+ elif self.loss_type == "bco_pair":
347
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
348
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
349
+
350
+ chosen_rewards = self.beta * chosen_logratios
351
+ rejected_rewards = self.beta * rejected_logratios
352
+ rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
353
+ self.running.update(rewards)
354
+ delta = self.running.mean
355
+
356
+ losses = -F.logsigmoid(
357
+ (self.beta * chosen_logratios) - delta
358
+ ) - F.logsigmoid(-(self.beta * rejected_logratios - delta))
359
+ elif self.loss_type == "sppo_hard":
360
+ # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is set to 1 for the winner and 0 for the loser.
361
+ a = policy_chosen_logps - reference_chosen_logps
362
+ b = policy_rejected_logps - reference_rejected_logps
363
+
364
+ losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
365
+ elif self.loss_type == "nca_pair":
366
+ chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.beta
367
+ rejected_rewards = (
368
+ policy_rejected_logps - reference_rejected_logps
369
+ ) * self.beta
370
+ losses = (
371
+ -F.logsigmoid(chosen_rewards)
372
+ - 0.5 * F.logsigmoid(-chosen_rewards)
373
+ - 0.5 * F.logsigmoid(-rejected_rewards)
374
+ )
375
+ elif self.loss_type == "aot_pair":
376
+ chosen_logratios = policy_chosen_logps - reference_chosen_logps
377
+ rejected_logratios = policy_rejected_logps - reference_rejected_logps
378
+
379
+ chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0)
380
+ rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0)
381
+
382
+ delta = chosen_logratios_sorted - rejected_logratios_sorted
383
+
384
+ losses = (
385
+ -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
386
+ - F.logsigmoid(-self.beta * delta) * self.label_smoothing
387
+ )
388
+
389
+ elif self.loss_type == "aot":
390
+ pi_logratios = policy_chosen_logps - policy_rejected_logps
391
+ ref_logratios = reference_chosen_logps - reference_rejected_logps
392
+
393
+ pi_logratios_sorted, _ = torch.sort(pi_logratios, dim=0)
394
+ ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0)
395
+
396
+ delta = pi_logratios_sorted - ref_logratios_sorted
397
+
398
+ losses = (
399
+ -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
400
+ - F.logsigmoid(-self.beta * delta) * self.label_smoothing
401
+ )
402
+
403
+ elif self.loss_type == "apo_zero":
404
+ # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
405
+ # Use this loss when you believe the chosen outputs are better than your model's default output
406
+
407
+ losses_chosen = 1 - F.sigmoid(
408
+ self.beta * chosen_logratios
409
+ ) # Increase chosen likelihood
410
+ losses_rejected = F.sigmoid(
411
+ self.beta * rejected_logratios
412
+ ) # Decrease rejected likelihood
413
+
414
+ losses = losses_chosen + losses_rejected
415
+
416
+ elif self.loss_type == "apo_down":
417
+ # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
418
+ # Use this loss when you believe the chosen outputs are worse than your model's default output
419
+
420
+ losses_chosen = F.sigmoid(
421
+ self.beta * chosen_logratios
422
+ ) # Decrease chosen likelihood
423
+ losses_rejected = 1 - F.sigmoid(
424
+ self.beta * (chosen_logratios - rejected_logratios)
425
+ ) # Decrease rejected likelihood more
426
+
427
+ losses = losses_chosen + losses_rejected
428
+
429
+ else:
430
+ raise ValueError(
431
+ f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']"
432
+ )
433
+
434
+ chosen_rewards = (
435
+ self.beta
436
+ * (
437
+ policy_chosen_logps.to(self.accelerator.device)
438
+ - reference_chosen_logps.to(self.accelerator.device)
439
+ ).detach()
440
+ )
441
+ rejected_rewards = (
442
+ self.beta
443
+ * (
444
+ policy_rejected_logps.to(self.accelerator.device)
445
+ - reference_rejected_logps.to(self.accelerator.device)
446
+ ).detach()
447
+ )
448
+
449
+ return losses, chosen_rewards, rejected_rewards
450
+
451
+ def get_batch_loss_metrics(
452
+ self,
453
+ model,
454
+ batch: Dict[str, Union[List, torch.LongTensor]],
455
+ train_eval: Literal["train", "eval"] = "train",
456
+ ):
457
+ """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
458
+ metrics = {}
459
+ forward_output = self.forward(model, batch)
460
+ (
461
+ policy_chosen_logps,
462
+ policy_rejected_logps,
463
+ policy_chosen_logits,
464
+ policy_rejected_logits,
465
+ policy_nll_loss,
466
+ ) = forward_output[:5]
467
+
468
+ with torch.no_grad():
469
+ reference_chosen_logps, reference_rejected_logps = self.forward(
470
+ self.ref_model, batch
471
+ )[:2]
472
+
473
+ losses, chosen_rewards, rejected_rewards = self.dpo_loss(
474
+ policy_chosen_logps,
475
+ policy_rejected_logps,
476
+ reference_chosen_logps,
477
+ reference_rejected_logps,
478
+ )
479
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
480
+
481
+ if self.rpo_alpha is not None:
482
+ # RPO loss from V3 of the paper:
483
+ losses = losses + policy_nll_loss * self.rpo_alpha
484
+
485
+ prefix = "eval_" if train_eval == "eval" else ""
486
+ metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
487
+ metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
488
+ metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
489
+ metrics[f"{prefix}rewards/margins"] = (
490
+ (chosen_rewards - rejected_rewards).mean().cpu()
491
+ )
492
+ metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
493
+ metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
494
+ metrics[f"{prefix}logits/rejected"] = (
495
+ policy_rejected_logits.detach().mean().cpu()
496
+ )
497
+ metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
498
+ if self.rpo_alpha is not None:
499
+ metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
500
+ return losses.mean(), metrics
501
+
502
+ @staticmethod
503
+ def get_batch_logps(
504
+ all_logits: torch.FloatTensor, labels: torch.Tensor, pad_token_ids: List[int]
505
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
506
+ # remove the demographics
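+ # The slices below drop the first four tokens (the demographic prompt) and shift the logits by one:
+ # logits[:, 3:-1] at position i scores labels[:, 4:] at position i, so each kept logit predicts the next token.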
507
+ all_logits = all_logits[:, 3:-1]
508
+ labels = labels.clone()[:, 4:]
509
+ # Calculate the sequence log probability log p(x)
510
+ per_token_logps = torch.gather(
511
+ all_logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
512
+ ).squeeze(2)
513
+ non_pad_mask = ~torch.isin(
514
+ labels, torch.tensor(pad_token_ids, device=labels.device)
515
+ )
516
+ per_token_logps *= non_pad_mask
517
+ all_logps = per_token_logps.sum(-1)
518
+ return all_logps, all_logits
519
+
520
+ def compute_loss(
521
+ self,
522
+ model: Union[PreTrainedModel, nn.Module],
523
+ inputs: Dict[str, Union[torch.Tensor, Any]],
524
+ return_outputs=False,
525
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
526
+ loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
527
+ # Move the loss back to the device where the parent `Trainer` class accumulates the overall loss.
528
+ loss = loss.to(self.args.device)
529
+ # force log the metrics
530
+ self.store_metrics(metrics, train_eval="train")
531
+ if return_outputs:
532
+ return (loss, metrics)
533
+ return loss
534
+
535
+ def store_metrics(
536
+ self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
537
+ ) -> None:
538
+ for key, value in metrics.items():
539
+ self._stored_metrics[train_eval][key].append(value)
540
+
541
+ def log(self, logs: Dict[str, float]) -> None:
542
+ """
543
+ Log `logs` on the various objects watching training, including stored metrics.
544
+
545
+ Args:
546
+ logs (`Dict[str, float]`):
547
+ The values to log.
548
+ """
549
+ # logs has either 'loss' (train) or 'eval_loss' (eval)
550
+ train_eval = "train" if "loss" in logs else "eval"
551
+ # Add averaged stored metrics to logs
552
+ for key, metrics in self._stored_metrics[train_eval].items():
553
+ logs[key] = torch.tensor(metrics).mean().item()
554
+ del self._stored_metrics[train_eval]
555
+ return super().log(logs)
556
+
557
+ def prediction_step(
558
+ self,
559
+ model: Union[PreTrainedModel, nn.Module],
560
+ inputs: Dict[str, Union[torch.Tensor, Any]],
561
+ prediction_loss_only: bool,
562
+ ignore_keys: Optional[List[str]] = None,
563
+ ):
564
+ with torch.no_grad():
565
+ loss, metrics = self.get_batch_loss_metrics(
566
+ model, inputs, train_eval="eval"
567
+ )
568
+
569
+ # force log the metrics
570
+ self.store_metrics(metrics, train_eval="eval")
571
+
572
+ if prediction_loss_only:
573
+ return (loss.detach(), None, None)
574
+
575
+ # logits for the chosen and rejected samples from model
576
+ logits_dict = {
577
+ "eval_logits/chosen": metrics["eval_logits/chosen"],
578
+ "eval_logits/rejected": metrics["eval_logits/rejected"],
579
+ }
580
+ logits = tuple(
581
+ v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in (ignore_keys or [])
582
+ )
583
+ logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
584
+ labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
585
+
586
+ return (loss.detach(), logits, labels)
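
For orientation, the sketch below shows one way the new CehrGptDPOTrainer could be wired up. It is not part of the released package: the checkpoint paths, the tokenizer and data-collator constructor calls, the dataset schema, and the import path of this new module are assumptions, so consult the package source for the exact signatures.

# Hypothetical usage sketch; names marked "assumed" below are not taken from the diff.
from datasets import Dataset
from trl.trainer.dpo_config import DPOConfig

from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
# from ... import CehrGptDPOTrainer  # the module path of this new file is not shown in the diff

tokenizer = CehrGptTokenizer.from_pretrained("path/to/cehrgpt_tokenizer")      # assumed API
policy = CEHRGPT2LMHeadModel.from_pretrained("path/to/cehrgpt_checkpoint")
reference = CEHRGPT2LMHeadModel.from_pretrained("path/to/cehrgpt_checkpoint")  # must not be the same object as `policy`

# The trainer's forward() expects chosen_*/rejected_* keys, so the collator must turn paired
# preference examples into chosen_input_ids / rejected_input_ids (plus attention masks and,
# optionally, value indicators and values).
train_dataset: Dataset = ...  # paired chosen/rejected patient sequences; schema defined by the collator
data_collator = CehrGptDPODataCollator(tokenizer=tokenizer, max_length=1024)   # assumed constructor arguments

args = DPOConfig(
    output_dir="cehrgpt-dpo",
    beta=0.1,              # DPO temperature used by dpo_loss
    loss_type="sigmoid",   # any of the loss types dispatched in dpo_loss
    rpo_alpha=None,        # set to 1.0 to add the NLL term as in RPO
    remove_unused_columns=False,
)

trainer = CehrGptDPOTrainer(
    model=policy,
    ref_model=reference,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)
trainer.train()

Inside the constructor the reference model is prepared in evaluation mode (or sharded with DeepSpeed ZeRO-3 when enabled), so it needs no additional setup beyond being a distinct object from the policy model.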