torchrl-nightly 2025.6.19__cp39-cp39-win_amd64.whl → 2025.6.21__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
  2. torchrl/collectors/collectors.py +49 -24
  3. torchrl/collectors/llm/base.py +13 -6
  4. torchrl/collectors/llm/ray_collector.py +3 -0
  5. torchrl/data/__init__.py +2 -0
  6. torchrl/data/datasets/minari_data.py +1 -1
  7. torchrl/data/llm/__init__.py +2 -0
  8. torchrl/data/llm/chat.py +59 -9
  9. torchrl/data/llm/topk.py +186 -0
  10. torchrl/data/replay_buffers/ray_buffer.py +15 -1
  11. torchrl/data/replay_buffers/replay_buffers.py +50 -11
  12. torchrl/data/replay_buffers/samplers.py +98 -21
  13. torchrl/data/replay_buffers/storages.py +29 -2
  14. torchrl/envs/llm/__init__.py +2 -0
  15. torchrl/envs/llm/chat.py +4 -1
  16. torchrl/envs/llm/reward/gsm8k.py +15 -8
  17. torchrl/envs/llm/transforms/__init__.py +2 -1
  18. torchrl/envs/llm/transforms/kl.py +240 -4
  19. torchrl/envs/transforms/transforms.py +11 -27
  20. torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
  21. torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
  22. torchrl/objectives/llm/__init__.py +2 -1
  23. torchrl/objectives/llm/sft.py +465 -0
  24. torchrl/objectives/ppo.py +35 -12
  25. torchrl/version.py +2 -2
  26. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA +1 -1
  27. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/RECORD +30 -28
  28. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/LICENSE +0 -0
  29. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/WHEEL +0 -0
  30. {torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/top_level.txt +0 -0
torchrl/objectives/llm/sft.py ADDED
@@ -0,0 +1,465 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import contextlib
+ import warnings
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ from tensordict import NestedKey, TensorClass, TensorDictBase
+ from tensordict.nn import TensorDictModule
+ from tensordict.utils import _zip_strict
+ from torchrl.data import History
+ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+ from torchrl.objectives.common import LossModule
+
+
+ def sft_loss(summed_log_probs: torch.Tensor, reduction: str) -> torch.Tensor:
+     """Compute the SFT loss."""
+     if reduction == "mean":
+         loss = -summed_log_probs.mean()
+     elif reduction == "sum":
+         loss = -summed_log_probs.sum()
+     elif reduction == "none":
+         loss = -summed_log_probs
+     else:
+         raise ValueError(f"Invalid reduction: {reduction}.")
+     return loss
+
+
+ def minor_sft_loss(
+     log_probs: torch.Tensor,
+     ref_log_probs: torch.Tensor,
+     beta: float,
+     reduction: str,
+ ) -> torch.Tensor:
+     """Compute the MinorSFT loss.
+
+     This loss is inspired by DPO and is designed to be less aggressive than standard SFT.
+     It computes ``-log_sigmoid(beta * (log_probs - ref_log_probs))``.
+
+     Args:
+         log_probs (torch.Tensor): The log probabilities from the model being trained.
+         ref_log_probs (torch.Tensor): The log probabilities from the reference model.
+         beta (float): The beta parameter from DPO.
+         reduction (str): The reduction to apply to the loss.
+
+     Returns:
+         The MinorSFT loss.
+
+     References:
+         - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
+           `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
+     """
+     if log_probs.shape != ref_log_probs.shape:
+         raise ValueError(
+             f"Current log probabilities and reference log probabilities have different shapes: {log_probs.shape=} vs {ref_log_probs.shape=}."
+         )
+     loss = -torch.nn.functional.logsigmoid(beta * (log_probs - ref_log_probs))
+     if reduction == "mean":
+         return loss.mean()
+     if reduction == "sum":
+         return loss.sum()
+     if reduction == "none":
+         return loss
+     raise ValueError(f"Invalid reduction: {reduction}")
+
+
+ class SFTLossOutput(TensorClass["nocast"]):
+     """SFT Loss Output.
+
+     Attributes:
+         loss_sft (torch.Tensor): The loss for the SFT objective.
+         loss_kl_to_ref (torch.Tensor | None): The loss for the KL divergence to the reference model.
+         kl_to_ref (torch.Tensor | None): The KL divergence to the reference model.
+
+     .. note::
+         The loss components are kept separate to allow for logging and visualization.
+         Before backpropagation, the loss components should be summed together. Since the non-loss components are not differentiable
+         when the loss is constructed via :class:`~torchrl.objectives.llm.sft.SFTLoss`, summing
+         the :class:`~torchrl.objectives.llm.sft.SFTLossOutput` directly is a valid way of obtaining the total loss.
+
+         >>> loss_fn = SFTLoss(...)
+         >>> loss_output = loss_fn(td)
+         >>> loss = loss_output.loss_sft + loss_output.loss_kl_to_ref
+         >>> loss.backward()
+         >>> # or equivalently
+         >>> loss = loss_fn(td)
+         >>> loss.sum(reduce=True).backward()
+     """
+
+     loss_sft: torch.Tensor
+     loss_kl_to_ref: torch.Tensor | None = None
+     kl_to_ref: torch.Tensor | None = None
+
+
+ class SFTLoss(LossModule):
+     r"""Supervised fine-tuning loss.
+
+     Args:
+         actor_network (TensorDictModule): the actor network. Usually a :class:`~torchrl.modules.llm.TransformersWrapper` instance,
+             with `return_log_probs=True` and `from_text=True`.
+         tokenizer (`Tokenizer`): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor_network`.
+         tokenizer_kwargs (dict, optional): keyword arguments to pass to the tokenizer during :meth:`~torchrl.data.llm.chat.History.apply_chat_template`.
+             This can be used to override arguments such as the `chat_template` or `chat_template_name`.
+         reduction (Literal["mean", "sum", "none"], optional): the reduction to apply to the loss. Defaults to `"mean"`.
+         normalize_by_seq_length (bool, optional): whether to normalize the loss by the sequence length. Defaults to `True`.
+         kl_to_ref_coeff (float | None, optional): coefficient for KL divergence to the reference model. Defaults to `None`.
+         loss_function (Literal["sft", "minor_sft"], optional): The loss function to use. Defaults to `"sft"`.
+         beta (float, optional): The beta parameter for MinorSFT loss. This is only used when `loss_function` is `"minor_sft"`.
+             Higher values of beta make the loss more aggressive (pushing the model to generate responses further from the reference model):
+
+             .. math::
+                 \text{loss} = -\log\sigma(\beta \cdot (\text{log_probs} - \text{ref_log_probs}))
+
+             Defaults to `0.1`.
+         device (torch.device | None, optional): the device to use for the loss, when tokenizing the input. Defaults to `None`.
+
+     .. note::
+         The input tensordict is expected to contain the following keys by default:
+         - ``("next", "history")``: The chat history
+         - ``("next", "ref_log_prob")`` (optional): Reference model log probabilities, required if kl_to_ref_coeff is set
+
+         These keys can be customized using the ``set_keys()`` method.
+
+     .. seealso:: :class:`~torchrl.envs.llm.transforms.RetrieveLogProb` for the KL divergence computation.
+
+     References:
+         - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
+           `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
+
+     Examples:
+         >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+         >>> from torchrl.envs.llm.transforms import RetrieveLogProb
+         >>> from torchrl.modules.llm import TransformersWrapper
+         >>> from torchrl.objectives.llm.sft import SFTLoss
+         >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+         >>> from tensordict import TensorDict, lazy_stack
+         >>> import torch
+         >>>
+         >>> # Create chat data
+         >>> chats = [
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "Hello, how are you?"},
+         ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
+         ...     ],
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "What's the weather like?"},
+         ...         {"role": "assistant", "content": "I can't check the weather for you."},
+         ...     ],
+         ... ]
+         >>> history = History.from_chats(chats)
+         >>>
+         >>> # Setup tokenizer and model
+         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
+         >>> model = OPTForCausalLM(OPTConfig()).eval()
+         >>>
+         >>> # Create training and reference policies
+         >>> policy_train = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>> policy_ref = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     return_log_probs=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>>
+         >>> # Create the RetrieveLogProb transform
+         >>> transform = RetrieveLogProb(
+         ...     policy_ref,
+         ...     assistant_only=True,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     tokenizer=tokenizer,
+         ... )
+         >>>
+         >>> # Prepare data
+         >>> text = history[:, :-1].apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
+         ... )
+         >>> text_response = history.apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
+         ... )
+         >>> text_response = [
+         ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
+         ... ]
+         >>> td = TensorDict(
+         ...     text=text,
+         ...     text_response=text_response,
+         ...     history=history,
+         ...     next=TensorDict(
+         ...         reward=torch.randn(2, 1),
+         ...         done=torch.zeros(2, dtype=torch.bool),
+         ...         history=history,
+         ...     ),
+         ...     batch_size=(2,),
+         ... )
+         >>> data = lazy_stack(list(td.unbind(0)))
+         >>>
+         >>> # Apply the transform to get reference log probabilities
+         >>> data = transform(data)
+         >>> assert "ref_log_prob" in data["next"].keys()
+         >>>
+         >>> # Use with SFTLoss for KL regularization
+         >>> loss = SFTLoss(
+         ...     actor_network=policy_train,
+         ...     tokenizer=tokenizer,
+         ...     reduction="mean",
+         ...     normalize_by_seq_length=True,
+         ...     kl_to_ref_coeff=0.1,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     loss_function="sft",
+         ... )
+         >>> loss_vals = loss(data)
+         >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
+         >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             history (NestedKey): The input tensordict key where the chat history is expected.
+                 Defaults to ``("next", "history")``.
+             ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected.
+                 Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_prob")``.
+             log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written.
+                 Defaults to ``"log_probs"``.
+         """
+
+         history: NestedKey = ("next", "history")
+         ref_log_prob: NestedKey = ("next", "ref_log_prob")
+         log_probs: NestedKey = "log_probs"
+
+     default_keys = _AcceptedKeys
+     tensor_keys: _AcceptedKeys
+
+     def __init__(
+         self,
+         actor_network: TensorDictModule | TransformersWrapper,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         tokenizer_kwargs: dict | None = None,
+         reduction: Literal["mean", "sum", "none"] = "mean",
+         normalize_by_seq_length: bool = True,
+         kl_to_ref_coeff: float | None = None,
+         loss_function: Literal["sft", "minor_sft"] = "sft",
+         beta: float = 0.1,
+         device: torch.device | None = None,
+     ):
+         super().__init__()
+         self.in_keys = []
+         self.actor_network = actor_network
+         if tokenizer is None:
+             tokenizer = actor_network.tokenizer
+         self.tokenizer = tokenizer
+         if tokenizer_kwargs is None:
+             tokenizer_kwargs = {}
+         if tokenizer is None:
+             raise ValueError("Tokenizer must be provided.")
+         tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+         tokenizer_kwargs.setdefault("tokenize", True)
+         tokenizer_kwargs.setdefault("return_tensors", "pt")
+         tokenizer_kwargs.setdefault("padding", False)
+         tokenizer_kwargs.setdefault("add_generation_prompt", False)
+         self.tokenizer_kwargs = tokenizer_kwargs
+         self.reduction = reduction
+         self.normalize_by_seq_length = normalize_by_seq_length
+         self.kl_to_ref_coeff = kl_to_ref_coeff
+         self.loss_function = loss_function
+         if self.loss_function == "minor_sft" and kl_to_ref_coeff:
+             warnings.warn(
+                 "kl_to_ref_coeff should not be set when using minor_sft loss, as KL regularization is implicit. Setting kl_to_ref_coeff to 0.0."
+             )
+             self.kl_to_ref_coeff = 0.0
+         self.beta = beta
+         self._set_in_keys()
+         self.device = device
+
+     def _set_in_keys(self) -> None:
+         """Sets the input keys for the loss module."""
+         in_keys = [self.tensor_keys.history]
+         if self.kl_to_ref_coeff is not None or self.loss_function == "minor_sft":
+             in_keys.append(self.tensor_keys.ref_log_prob)
+         self.in_keys = in_keys
+         self.out_keys = []  # Loss modules typically don't have out_keys
+
+     def _kl_to_ref(
+         self,
+         cur_log_prob: list[torch.Tensor],
+         ref_log_prob: list[torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute KL divergence to reference model.
+
+         Args:
+             cur_log_prob (List[torch.Tensor]): Log probabilities from current model. Must have shape [T] where T is the number of tokens in the assistant response.
+             ref_log_prob (List[torch.Tensor]): Log probabilities from reference model. Must have shape [T] where T is the number of tokens in the assistant response.
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor]: (KL loss term, KL penalty for logging)
+         """
+         # Apply mask
+         ref_log_prob = torch.cat(ref_log_prob)
+         cur_log_prob = torch.cat(cur_log_prob)
+         # ref_log_prob = ref_log_prob[mask]
+         # cur_log_prob = cur_log_prob[mask].squeeze()
+         if cur_log_prob.shape != ref_log_prob.shape:
+             raise ValueError(
+                 f"Current log probabilities and reference log probabilities have different shapes: {cur_log_prob.shape=} vs {ref_log_prob.shape=}."
+             )
+         # Compute KL using same approximation as GRPO
+         diff = ref_log_prob - cur_log_prob
+
+         kl_penalty = (diff.expm1() - diff).mean()
+         return self.kl_to_ref_coeff * kl_penalty, kl_penalty
+
+     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+         # Gather history
+         history: History = tensordict[self.tensor_keys.history]
+
+         # Apply tokenizer to history and gather mask
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext():
+             token_struct = history.apply_chat_template(
+                 tokenizer=self.tokenizer, **self.tokenizer_kwargs
+             )
+         if "assistant_masks" not in token_struct:
+             raise ValueError(
+                 f"Assistant masks are not present in the token structure: {token_struct=}."
+             )
+         assistant_masks = token_struct.get(
+             "assistant_masks",
+             as_list=True,
+         )
+         assistant_masks = [mask.bool() for mask in assistant_masks]
+         attention_mask = token_struct.get("attention_mask", as_list=True)
+         attention_mask = [mask.bool() for mask in attention_mask]
+         assistant_masks = [
+             mask & a_mask for mask, a_mask in zip(assistant_masks, attention_mask)
+         ]
+
+         if not any(mask.any(-1).all() for mask in assistant_masks):
+             raise ValueError("Some inputs have no valid assistant masks.")
+         input_loss = tensordict.select(self.tensor_keys.history)
+         if (
+             isinstance(self.tensor_keys.history, tuple)
+             and self.tensor_keys.history[0] == "next"
+         ):
+             input_loss = input_loss["next"]
+
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext():
+             output_loss = self.actor_network(input_loss)
+
+         # get log-probs
+         log_probs = output_loss.get(
+             self.tensor_keys.log_probs,
+             as_list=True,
+         )
+         # apply mask
+         if not all(
+             mask.shape == lp.shape
+             for mask, lp in _zip_strict(assistant_masks, log_probs)
+         ):
+             raise ValueError(
+                 f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs {[lp.shape for lp in log_probs]}. Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
+             )
+
+         log_probs_masked = [
+             lp.masked_fill(~mask, 0.0)
+             for lp, mask in _zip_strict(log_probs, assistant_masks)
+         ]
+
+         # Sum log probs, optionally normalize by sequence length
+         summed_log_probs = torch.stack(
+             [lp.sum(tensordict.ndim - 1) for lp in log_probs_masked]
+         )
+         seq_lengths = torch.stack(
+             [mask.sum(tensordict.ndim - 1) for mask in assistant_masks]
+         )
+         if self.normalize_by_seq_length:
+             # Compute sequence lengths for normalization (number of assistant tokens)
+             summed_log_probs = summed_log_probs / seq_lengths.clamp(min=1)
+
+         # Compute main loss
+         if self.loss_function == "sft":
+             loss = sft_loss(summed_log_probs, self.reduction)
+             # Add KL divergence loss if reference model is provided
+             if self.kl_to_ref_coeff is not None:
+                 ref_log_probs = tensordict.get(
+                     self.tensor_keys.ref_log_prob,
+                     default=None,
+                     as_list=True,
+                 )
+                 if ref_log_probs is None:
+                     raise ValueError(
+                         "Reference log probs not found in tensordict but kl_to_ref_coeff was set"
+                     )
+
+                 loss_kl, kl_penalty = self._kl_to_ref(
+                     [lp[mask] for lp, mask in _zip_strict(log_probs, assistant_masks)],
+                     ref_log_probs,
+                 )
+                 output = SFTLossOutput(
+                     loss_sft=loss,
+                     loss_kl_to_ref=loss_kl,
+                     kl_to_ref=kl_penalty.detach(),
+                 )
+             else:
+                 output = SFTLossOutput(loss_sft=loss)
+         elif self.loss_function == "minor_sft":
+             ref_log_probs = tensordict.get(self.tensor_keys.ref_log_prob, as_list=True)
+             if ref_log_probs is None:
+                 raise ValueError(
+                     f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict but loss_function is 'minor_sft'"
+                 )
+
+             # we need to re-sum ref_log_probs as they are not summed per-sequence
+             summed_ref_log_probs = torch.stack([lp.sum() for lp in ref_log_probs]).to(
+                 summed_log_probs.device
+             )
+             if self.normalize_by_seq_length:
+                 summed_ref_log_probs = summed_ref_log_probs / seq_lengths.clamp(min=1)
+             loss = minor_sft_loss(
+                 summed_log_probs, summed_ref_log_probs, self.beta, self.reduction
+             )
+             if self.kl_to_ref_coeff is not None:
+                 with torch.no_grad():
+                     loss_kl, kl_penalty = self._kl_to_ref(
+                         [
+                             lp[mask]
+                             for lp, mask in _zip_strict(log_probs, assistant_masks)
+                         ],
+                         ref_log_probs,
+                     )
+                 output = SFTLossOutput(
+                     loss_sft=loss,
+                     loss_kl_to_ref=loss_kl,
+                     kl_to_ref=kl_penalty.detach(),
+                 )
+             else:
+                 output = SFTLossOutput(loss_sft=loss)
+         else:
+             raise ValueError(f"Invalid loss function: {self.loss_function}")
+
+         return output
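
The two helper losses defined at the top of this file can be exercised directly on toy per-sequence log-prob sums. A minimal sketch (toy values; assumes the 2025.6.21 nightly is installed so that torchrl.objectives.llm.sft is importable):

    import torch
    from torchrl.objectives.llm.sft import minor_sft_loss, sft_loss

    # Toy per-sequence summed log-probabilities for three assistant responses
    summed_log_probs = torch.tensor([-12.5, -8.0, -20.3])
    ref_log_probs = torch.tensor([-13.0, -7.5, -19.0])

    # Plain SFT: negative log-likelihood, averaged over the batch -> tensor(13.6000)
    print(sft_loss(summed_log_probs, reduction="mean"))

    # MinorSFT: -logsigmoid(beta * (log_probs - ref_log_probs)), gentler than plain SFT
    print(minor_sft_loss(summed_log_probs, ref_log_probs, beta=0.1, reduction="mean"))

Inside SFTLoss.forward, these same helpers receive the assistant-token log-prob sums (optionally normalized by sequence length) extracted from the tensordict.
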
torchrl/objectives/ppo.py CHANGED
@@ -29,7 +29,7 @@ from tensordict.nn import (
  from tensordict.utils import NestedKey
  from torch import distributions as d
 
- from torchrl._utils import _standardize, logger as torchrl_logger
+ from torchrl._utils import _standardize, logger as torchrl_logger, VERBOSE
  from torchrl.objectives.common import LossModule
  from torchrl.objectives.utils import (
      _cache_values,
@@ -104,6 +104,9 @@ class PPOLoss(LossModule):
              * **Scalar**: one value applied to the summed entropy of every action head.
              * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
              Defaults to ``0.01``.
+         log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
+             predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
+             This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
          critic_coef (scalar, optional): critic loss multiplier when computing the total
              loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
              loss from the forward outputs.
@@ -349,6 +352,7 @@
          entropy_bonus: bool = True,
          samples_mc_entropy: int = 1,
          entropy_coeff: float | Mapping[str, float] = 0.01,
+         log_explained_variance: bool = True,
          critic_coef: float | None = None,
          loss_critic_type: str = "smooth_l1",
          normalize_advantage: bool = False,
@@ -413,6 +417,7 @@
          self.critic_network_params = None
          self.target_critic_network_params = None
 
+         self.log_explained_variance = log_explained_variance
          self.samples_mc_entropy = samples_mc_entropy
          self.entropy_bonus = entropy_bonus
          self.separate_losses = separate_losses
@@ -564,14 +569,16 @@
              entropy = dist.entropy()
              if not entropy.isfinite().all():
                  del entropy
-                 torchrl_logger.info(
-                     "Entropy is not finite. Using Monte Carlo sampling."
-                 )
+                 if VERBOSE:
+                     torchrl_logger.info(
+                         "Entropy is not finite. Using Monte Carlo sampling."
+                     )
                  raise NotImplementedError
          except NotImplementedError:
-             torchrl_logger.warn(
-                 f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
-             )
+             if VERBOSE:
+                 torchrl_logger.warning(
+                     f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
+                 )
              if getattr(dist, "has_rsample", False):
                  x = dist.rsample((self.samples_mc_entropy,))
              else:
@@ -743,6 +750,16 @@
              self.loss_critic_type,
          )
 
+         explained_variance = None
+         if self.log_explained_variance:
+             with torch.no_grad():  # <-- break grad-flow
+                 tgt = target_return.detach()
+                 pred = state_value.detach()
+                 eps = torch.finfo(tgt.dtype).eps
+                 resid = torch.var(tgt - pred, unbiased=False, dim=0)
+                 total = torch.var(tgt, unbiased=False, dim=0)
+                 explained_variance = 1.0 - resid / (total + eps)
+
          self._clear_weakrefs(
              tensordict,
              "actor_network_params",
@@ -751,8 +768,8 @@
              "target_critic_network_params",
          )
          if self._has_critic:
-             return self.critic_coef * loss_value, clip_fraction
-         return loss_value, clip_fraction
+             return self.critic_coef * loss_value, clip_fraction, explained_variance
+         return loss_value, clip_fraction, explained_variance
 
      @property
      @_cache_values
@@ -802,10 +819,12 @@
              td_out.set("entropy", entropy.detach().mean())  # for logging
              td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
          if self._has_critic:
-             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+             loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
              td_out.set("loss_critic", loss_critic)
              if value_clip_fraction is not None:
                  td_out.set("value_clip_fraction", value_clip_fraction)
+             if explained_variance is not None:
+                 td_out.set("explained_variance", explained_variance)
          td_out = td_out.named_apply(
              lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
              if name.startswith("loss_")
@@ -1170,10 +1189,12 @@ class ClipPPOLoss(PPOLoss):
              td_out.set("entropy", entropy.detach().mean())  # for logging
              td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
          if self._has_critic:
-             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
+             loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
              td_out.set("loss_critic", loss_critic)
              if value_clip_fraction is not None:
                  td_out.set("value_clip_fraction", value_clip_fraction)
+             if explained_variance is not None:
+                 td_out.set("explained_variance", explained_variance)
 
          td_out.set("ESS", _reduce(ess, self.reduction) / batch)
          td_out = td_out.named_apply(
@@ -1516,10 +1537,12 @@ class KLPENPPOLoss(PPOLoss):
              td_out.set("entropy", entropy.detach().mean())  # for logging
              td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
          if self._has_critic:
-             loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
+             loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
              td_out.set("loss_critic", loss_critic)
              if value_clip_fraction is not None:
                  td_out.set("value_clip_fraction", value_clip_fraction)
+             if explained_variance is not None:
+                 td_out.set("explained_variance", explained_variance)
          td_out = td_out.named_apply(
              lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
              if name.startswith("loss_")
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
- __version__ = '2025.6.19'
- git_version = '350fa1d790df24f6d05da6f7a36cf02c4daec97f'
+ __version__ = '2025.6.21'
+ git_version = '77dbc6c9ffbce3d2ce3f26b659355cd46d8132c3'
{torchrl_nightly-2025.6.19.dist-info → torchrl_nightly-2025.6.21.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torchrl-nightly
- Version: 2025.6.19
+ Version: 2025.6.21
  Summary: UNKNOWN
  Home-page: https://github.com/pytorch/rl
  Author: torchrl contributors