cehrgpt 0.0.2__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- cehrgpt/analysis/irregularity.py +36 -0
- cehrgpt/data/hf_cehrgpt_dataset.py +25 -4
- cehrgpt/data/hf_cehrgpt_dataset_collator.py +635 -97
- cehrgpt/data/hf_cehrgpt_dataset_mapping.py +308 -95
- cehrgpt/data/sample_packing_sampler.py +181 -0
- cehrgpt/generation/generate_batch_hf_gpt_sequence.py +12 -9
- cehrgpt/generation/omop_converter_batch.py +32 -2
- cehrgpt/gpt_utils.py +20 -2
- cehrgpt/models/config.py +35 -0
- cehrgpt/models/hf_cehrgpt.py +470 -106
- cehrgpt/models/hf_modeling_outputs.py +1 -0
- cehrgpt/models/special_tokens.py +1 -0
- cehrgpt/models/tokenization_hf_cehrgpt.py +358 -71
- cehrgpt/runners/data_utils.py +358 -0
- cehrgpt/runners/gpt_runner_util.py +0 -10
- cehrgpt/runners/hf_cehrgpt_finetune_runner.py +181 -283
- cehrgpt/runners/hf_cehrgpt_pretrain_runner.py +288 -112
- cehrgpt/runners/hf_gpt_runner_argument_dataclass.py +90 -0
- cehrgpt/runners/hyperparameter_search_util.py +10 -8
- cehrgpt/runners/sample_packing_trainer.py +185 -0
- cehrgpt/simulations/generate_plots.py +95 -0
- cehrgpt/simulations/run_simulation.sh +24 -0
- cehrgpt/simulations/time_embedding_simulation.py +250 -0
- cehrgpt/simulations/time_token_simulation.py +177 -0
- cehrgpt/time_to_event/config/1_year_cabg.yaml +23 -0
- cehrgpt/time_to_event/time_to_event_model.py +2 -13
- cehrgpt/time_to_event/time_to_event_prediction.py +27 -13
- cehrgpt/tools/linear_prob/__init__.py +0 -0
- cehrgpt/tools/linear_prob/compute_cehrgpt_features.py +495 -0
- cehrgpt/tools/linear_prob/train_with_cehrgpt_features.py +152 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/METADATA +11 -8
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/RECORD +36 -32
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/WHEEL +1 -1
- cehrgpt/data/hf_cehrgpt_dpo_collator.py +0 -71
- cehrgpt/data/hf_cehrgpt_dpo_dataset_mapping.py +0 -61
- cehrgpt/generation/generate_paired_cehrgpt_sequence.py +0 -224
- cehrgpt/rl_finetune/cehrgpt_dpo_trainer.py +0 -586
- cehrgpt/rl_finetune/cehrgpt_ppo_trainer.py +0 -464
- cehrgpt/rl_finetune/ppo_finetune.py +0 -394
- cehrgpt/rl_finetune/ppo_finetune_v2.py +0 -373
- cehrgpt/runners/hf_cehrgpt_dpo_runner.py +0 -119
- /cehrgpt/{rl_finetune → simulations}/__init__.py +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info/licenses}/LICENSE +0 -0
- {cehrgpt-0.0.2.dist-info → cehrgpt-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,586 +0,0 @@
-import warnings
-from collections import defaultdict
-from copy import deepcopy
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from accelerate.utils import is_deepspeed_available
-from datasets import Dataset
-from transformers import PreTrainedModel, Trainer
-from transformers.trainer_callback import TrainerCallback
-from transformers.trainer_utils import EvalLoopOutput
-from trl.trainer.callbacks import SyncRefModelCallback
-from trl.trainer.dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
-from trl.trainer.utils import RunningMoments, cap_exp, disable_dropout_in_model
-
-from cehrgpt.data.hf_cehrgpt_dpo_collator import CehrGptDPODataCollator
-from cehrgpt.models.hf_cehrgpt import CEHRGPT2LMHeadModel
-from cehrgpt.models.tokenization_hf_cehrgpt import CehrGptTokenizer
-
-if is_deepspeed_available():
-    import deepspeed
-
-
-class CehrGptDPOTrainer(Trainer):
-    def __init__(
-        self,
-        model: CEHRGPT2LMHeadModel,
-        ref_model: CEHRGPT2LMHeadModel,
-        tokenizer: CehrGptTokenizer,
-        args: DPOConfig,
-        data_collator: CehrGptDPODataCollator,
-        train_dataset: Optional[Dataset] = None,
-        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
-        callbacks: Optional[List[TrainerCallback]] = None,
-        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-            None,
-            None,
-        ),
-        preprocess_logits_for_metrics: Optional[
-            Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
-        ] = None,
-        disable_dropout: bool = True,
-        compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
-    ):
-        if ref_model is model:
-            raise ValueError(
-                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
-                "same as `model`, you must pass a copy of it, or `None` if you use peft."
-            )
-
-        if getattr(args, "gradient_checkpointing", False):
-            # For backward compatibility with older versions of transformers
-            if hasattr(model, "enable_input_require_grads"):
-                model.enable_input_require_grads()
-            else:
-
-                def make_inputs_require_grad(module, input, output):
-                    output.requires_grad_(True)
-
-                model.get_input_embeddings().register_forward_hook(
-                    make_inputs_require_grad
-                )
-
-        if model is not None:
-            self.is_encoder_decoder = model.config.is_encoder_decoder
-        elif args.is_encoder_decoder is None:
-            raise ValueError(
-                "When no model is provided, you need to pass the parameter is_encoder_decoder to the DPOTrainer/DPOConfig."
-            )
-        else:
-            self.is_encoder_decoder = args.is_encoder_decoder
-
-        self.tokenizer = tokenizer
-        self.ref_model = ref_model
-
-        if not disable_dropout:
-            warnings.warn(
-                "You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
-            )
-            args.disable_dropout = disable_dropout
-
-        if args.disable_dropout:
-            disable_dropout_in_model(model)
-            if self.ref_model is not None:
-                disable_dropout_in_model(self.ref_model)
-
-        self.beta = args.beta
-        self.label_smoothing = args.label_smoothing
-        self.loss_type = args.loss_type
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
-        self.f_divergence_type = args.f_divergence_type
-        self.f_divergence_params = {
-            FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef
-        }
-        self.reference_free = args.reference_free
-        # α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
-        # weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
-        # DPO loss. The paper recommends `rpo_alpha=1.0`.
-        self.rpo_alpha = args.rpo_alpha
-        self.vs_token_id = tokenizer._convert_token_to_id("VS")
-        if self.vs_token_id == tokenizer._oov_token_id:
-            self.vs_token_id = tokenizer._convert_token_to_id("[VS]")
-        self.ve_token_id = tokenizer._convert_token_to_id("VE")
-        if self.ve_token_id == tokenizer._oov_token_id:
-            self.ve_token_id = tokenizer._convert_token_to_id("[VE]")
-        self.non_active_token_ids = [
-            tokenizer.pad_token_id,
-            self.vs_token_id,
-            self.ve_token_id,
-        ]
-
-        super().__init__(
-            model=model,
-            args=args,
-            data_collator=data_collator,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            compute_metrics=compute_metrics,
-            callbacks=callbacks,
-            optimizers=optimizers,
-            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
-        )
-
-        if not hasattr(self, "accelerator"):
-            raise AttributeError(
-                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
-            )
-
-        if self.is_deepspeed_enabled:
-            self.ref_model = self._prepare_deepspeed(self.ref_model)
-        else:
-            self.ref_model = self.accelerator.prepare_model(
-                self.ref_model, evaluation_mode=True
-            )
-        self.add_callback(
-            SyncRefModelCallback(
-                ref_model=self.ref_model, accelerator=self.accelerator
-            )
-        )
-        if self.loss_type == "bco_pair":
-            self.running = RunningMoments(self.accelerator)
-
-    def _prepare_deepspeed(self, model: CEHRGPT2LMHeadModel):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-
-        if model is not None:
-            if hasattr(model, "config"):
-                hidden_size = (
-                    max(model.config.hidden_sizes)
-                    if getattr(model.config, "hidden_sizes", None)
-                    else getattr(model.config, "hidden_size", None)
-                )
-                if (
-                    hidden_size is not None
-                    and config_kwargs["zero_optimization"]["stage"] == 3
-                ):
-                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                    config_kwargs.update(
-                        {
-                            "zero_optimization.reduce_bucket_size": hidden_size
-                            * hidden_size,
-                            "zero_optimization.stage3_param_persistence_threshold": 10
-                            * hidden_size,
-                            "zero_optimization.stage3_prefetch_bucket_size": 0.9
-                            * hidden_size
-                            * hidden_size,
-                        }
-                    )
-
-        # If ZeRO-3 is used, we shard both the active and reference model.
-        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-        if config_kwargs["zero_optimization"]["stage"] != 3:
-            config_kwargs["zero_optimization"]["stage"] = 0
-        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-        model.eval()
-        return model
-
-    def forward(
-        self,
-        model: CEHRGPT2LMHeadModel,
-        batch: Dict[str, torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
-
-        We do this to avoid doing two forward passes, because it's faster for FSDP.
-        """
-        chosen_outputs = model(
-            input_ids=batch["chosen_input_ids"],
-            attention_mask=batch["chosen_attention_mask"],
-            value_indicators=(
-                batch["chosen_value_indicators"]
-                if "chosen_value_indicators" in batch
-                else None
-            ),
-            values=batch["chosen_values"] if "chosen_values" in batch else None,
-        )
-
-        chosen_logps, chosen_logits = self.get_batch_logps(
-            chosen_outputs.logits,
-            batch["chosen_input_ids"],
-            pad_token_ids=[self.tokenizer.pad_token_id],
-        )
-
-        # move labels to correct device to enable model parallelism
-        labels = batch["chosen_input_ids"].to(chosen_outputs.logits.device)
-        # Shift so that tokens < n predict n
-        shift_logits = chosen_outputs.logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = nn.CrossEntropyLoss()
-        nll_loss = loss_fct(
-            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
-        )
-
-        rejected_outputs = model(
-            input_ids=batch["rejected_input_ids"],
-            attention_mask=batch["rejected_attention_mask"],
-            value_indicators=(
-                batch["rejected_value_indicators"]
-                if "rejected_value_indicators" in batch
-                else None
-            ),
-            values=batch["rejected_values"] if "rejected_values" in batch else None,
-        )
-
-        rejected_logps, rejected_logits = self.get_batch_logps(
-            rejected_outputs.logits,
-            batch["rejected_input_ids"],
-            pad_token_ids=[self.tokenizer.pad_token_id],
-        )
-
-        return (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits,
-            rejected_logits,
-            nll_loss,
-        )
-
-    def dpo_loss(
-        self,
-        policy_chosen_logps: torch.Tensor,
-        policy_rejected_logps: torch.Tensor,
-        reference_chosen_logps: torch.Tensor,
-        reference_rejected_logps: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Compute the DPO loss for a batch of policy and reference model log probabilities.
-
-        Args:
-            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
-            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
-            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
-            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
-
-        Returns:
-            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
-            The losses tensor contains the DPO loss for each example in the batch.
-            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
-        """
-        chosen_logratios = policy_chosen_logps.to(self.accelerator.device) - (
-            not self.reference_free
-        ) * reference_chosen_logps.to(self.accelerator.device)
-        rejected_logratios = policy_rejected_logps.to(self.accelerator.device) - (
-            not self.reference_free
-        ) * reference_rejected_logps.to(self.accelerator.device)
-
-        if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value:
-            # The alpha-divergence formula: (1 - u^-alpha) / alpha
-            # The divergence difference between the chosen and rejected sample is:
-            # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha
-            # = (u[l]^-alpha - u[w]^-alpha) / alpha
-            # where u[w] and u[l] are the policy/reference probability ratios
-            # for the chosen and rejected samples, respectively.
-            alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT
-            if (
-                self.f_divergence_params
-                and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY
-                in self.f_divergence_params
-            ):
-                alpha_coef = float(
-                    self.f_divergence_params[
-                        FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY
-                    ]
-                )
-            logits = (
-                cap_exp(rejected_logratios * -alpha_coef)
-                - cap_exp(chosen_logratios * -alpha_coef)
-            ) / alpha_coef
-        else:
-            pi_logratios = policy_chosen_logps - policy_rejected_logps
-            if self.reference_free:
-                ref_logratios = torch.tensor(
-                    [0], dtype=pi_logratios.dtype, device=pi_logratios.device
-                )
-            else:
-                ref_logratios = reference_chosen_logps - reference_rejected_logps
-
-            pi_logratios = pi_logratios.to(self.accelerator.device)
-            ref_logratios = ref_logratios.to(self.accelerator.device)
-            logits = pi_logratios - ref_logratios
-
-            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
-                # The js-divergence formula: log(2 * u / (1 + u))
-                # The divergence difference between the chosen and rejected sample is:
-                # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
-                # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
-                # where u[w] and u[l] are the policy/reference probability ratios
-                # for the chosen and rejected samples, respectively.
-                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)
-
-        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
-        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
-        # calculates a conservative DPO loss.
-        if self.loss_type == "sigmoid":
-            losses = (
-                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
-            )
-        elif self.loss_type == "robust":
-            losses = (
-                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
-            ) / (1 - 2 * self.label_smoothing)
-        elif self.loss_type == "exo_pair":
-            # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856
-            import math
-
-            if self.label_smoothing == 0:
-                self.label_smoothing = 1e-3
-            losses = (self.beta * logits).sigmoid() * (
-                F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing)
-            ) + (-self.beta * logits).sigmoid() * (
-                F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)
-            )
-        elif self.loss_type == "hinge":
-            losses = torch.relu(1 - self.beta * logits)
-        elif self.loss_type == "ipo":
-            # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
-            losses = (logits - 1 / (2 * self.beta)) ** 2
-        elif self.loss_type == "bco_pair":
-            chosen_logratios = policy_chosen_logps - reference_chosen_logps
-            rejected_logratios = policy_rejected_logps - reference_rejected_logps
-
-            chosen_rewards = self.beta * chosen_logratios
-            rejected_rewards = self.beta * rejected_logratios
-            rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
-            self.running.update(rewards)
-            delta = self.running.mean
-
-            losses = -F.logsigmoid(
-                (self.beta * chosen_logratios) - delta
-            ) - F.logsigmoid(-(self.beta * rejected_logratios - delta))
-        elif self.loss_type == "sppo_hard":
-            # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is set to 1 for the winner and 0 for the loser.
-            a = policy_chosen_logps - reference_chosen_logps
-            b = policy_rejected_logps - reference_rejected_logps
-
-            losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2
-        elif self.loss_type == "nca_pair":
-            chosen_rewards = (policy_chosen_logps - reference_chosen_logps) * self.beta
-            rejected_rewards = (
-                policy_rejected_logps - reference_rejected_logps
-            ) * self.beta
-            losses = (
-                -F.logsigmoid(chosen_rewards)
-                - 0.5 * F.logsigmoid(-chosen_rewards)
-                - 0.5 * F.logsigmoid(-rejected_rewards)
-            )
-        elif self.loss_type == "aot_pair":
-            chosen_logratios = policy_chosen_logps - reference_chosen_logps
-            rejected_logratios = policy_rejected_logps - reference_rejected_logps
-
-            chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0)
-            rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0)
-
-            delta = chosen_logratios_sorted - rejected_logratios_sorted
-
-            losses = (
-                -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
-                - F.logsigmoid(-self.beta * delta) * self.label_smoothing
-            )
-
-        elif self.loss_type == "aot":
-            pi_logratios = policy_chosen_logps - policy_rejected_logps
-            ref_logratios = reference_chosen_logps - reference_rejected_logps
-
-            pi_logratios_sorted, _ = torch.sort(pi_logratios, dim=0)
-            ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0)
-
-            delta = pi_logratios_sorted - ref_logratios_sorted
-
-            losses = (
-                -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing)
-                - F.logsigmoid(-self.beta * delta) * self.label_smoothing
-            )
-
-        elif self.loss_type == "apo_zero":
-            # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266)
-            # Use this loss when you believe the chosen outputs are better than your model's default output
-
-            losses_chosen = 1 - F.sigmoid(
-                self.beta * chosen_logratios
-            )  # Increase chosen likelihood
-            losses_rejected = F.sigmoid(
-                self.beta * rejected_logratios
-            )  # Decrease rejected likelihood
-
-            losses = losses_chosen + losses_rejected
-
-        elif self.loss_type == "apo_down":
-            # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266)
-            # Use this loss when you believe the chosen outputs are worse than your model's default output
-
-            losses_chosen = F.sigmoid(
-                self.beta * chosen_logratios
-            )  # Decrease chosen likelihood
-            losses_rejected = 1 - F.sigmoid(
-                self.beta * (chosen_logratios - rejected_logratios)
-            )  # Decrease rejected likelihood more
-
-            losses = losses_chosen + losses_rejected
-
-        else:
-            raise ValueError(
-                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down']"
-            )
-
-        chosen_rewards = (
-            self.beta
-            * (
-                policy_chosen_logps.to(self.accelerator.device)
-                - reference_chosen_logps.to(self.accelerator.device)
-            ).detach()
-        )
-        rejected_rewards = (
-            self.beta
-            * (
-                policy_rejected_logps.to(self.accelerator.device)
-                - reference_rejected_logps.to(self.accelerator.device)
-            ).detach()
-        )
-
-        return losses, chosen_rewards, rejected_rewards
-
-    def get_batch_loss_metrics(
-        self,
-        model,
-        batch: Dict[str, Union[List, torch.LongTensor]],
-        train_eval: Literal["train", "eval"] = "train",
-    ):
-        """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
-        metrics = {}
-        forward_output = self.forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_nll_loss,
-        ) = forward_output[:5]
-
-        with torch.no_grad():
-            reference_chosen_logps, reference_rejected_logps = self.forward(
-                self.ref_model, batch
-            )[:2]
-
-        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
-            policy_chosen_logps,
-            policy_rejected_logps,
-            reference_chosen_logps,
-            reference_rejected_logps,
-        )
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
-        if self.rpo_alpha is not None:
-            # RPO loss from V3 of the paper:
-            losses = losses + policy_nll_loss * self.rpo_alpha
-
-        prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (
-            (chosen_rewards - rejected_rewards).mean().cpu()
-        )
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = (
-            policy_rejected_logits.detach().mean().cpu()
-        )
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
-        if self.rpo_alpha is not None:
-            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
-        return losses.mean(), metrics
-
-    @staticmethod
-    def get_batch_logps(
-        all_logits: torch.FloatTensor, labels: torch.Tensor, pad_token_ids: List[int]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # remove the demographics
-        all_logits = all_logits[:, 3:-1]
-        labels = labels.clone()[:, 4:]
-        # Calculate the sequence log probability log p(x)
-        per_token_logps = torch.gather(
-            all_logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
-        ).squeeze(2)
-        non_pad_mask = ~torch.isin(
-            labels, torch.tensor(pad_token_ids, device=labels.device)
-        )
-        per_token_logps *= non_pad_mask
-        all_logps = per_token_logps.sum(-1)
-        return all_logps, all_logits
-
-    def compute_loss(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
-        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
-        # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
-        loss = loss.to(self.args.device)
-        # force log the metrics
-        self.store_metrics(metrics, train_eval="train")
-        if return_outputs:
-            return (loss, metrics)
-        return loss
-
-    def store_metrics(
-        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
-    ) -> None:
-        for key, value in metrics.items():
-            self._stored_metrics[train_eval][key].append(value)
-
-    def log(self, logs: Dict[str, float]) -> None:
-        """
-        Log `logs` on the various objects watching training, including stored metrics.
-
-        Args:
-            logs (`Dict[str, float]`):
-                The values to log.
-        """
-        # logs either has 'loss' or 'eval_loss'
-        train_eval = "train" if "loss" in logs else "eval"
-        # Add averaged stored metrics to logs
-        for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
-        del self._stored_metrics[train_eval]
-        return super().log(logs)
-
-    def prediction_step(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[List[str]] = None,
-    ):
-        with torch.no_grad():
-            loss, metrics = self.get_batch_loss_metrics(
-                model, inputs, train_eval="eval"
-            )
-
-        # force log the metrics
-        self.store_metrics(metrics, train_eval="eval")
-
-        if prediction_loss_only:
-            return (loss.detach(), None, None)
-
-        # logits for the chosen and rejected samples from model
-        logits_dict = {
-            "eval_logits/chosen": metrics["eval_logits/chosen"],
-            "eval_logits/rejected": metrics["eval_logits/rejected"],
-        }
-        logits = tuple(
-            v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys
-        )
-        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
-        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
-
-        return (loss.detach(), logits, labels)