torchrl-nightly 2025.6.20__cp313-cp313-manylinux1_x86_64.whl → 2025.6.22__cp313-cp313-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrl/objectives/llm/sft.py ADDED
@@ -0,0 +1,465 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import contextlib
+ import warnings
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ from tensordict import NestedKey, TensorClass, TensorDictBase
+ from tensordict.nn import TensorDictModule
+ from tensordict.utils import _zip_strict
+ from torchrl.data import History
+ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+ from torchrl.objectives.common import LossModule
+
+
+ def sft_loss(summed_log_probs: torch.Tensor, reduction: str) -> torch.Tensor:
+     """Compute the SFT loss."""
+     if reduction == "mean":
+         loss = -summed_log_probs.mean()
+     elif reduction == "sum":
+         loss = -summed_log_probs.sum()
+     elif reduction == "none":
+         loss = -summed_log_probs
+     else:
+         raise ValueError(f"Invalid reduction: {reduction}.")
+     return loss
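+
+ # For intuition: ``sft_loss(torch.tensor([-1.0, -3.0]), "mean")`` returns
+ # ``tensor(2.)``, i.e. the negated mean of the per-sequence summed log probabilities.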
+
+
+ def minor_sft_loss(
+     log_probs: torch.Tensor,
+     ref_log_probs: torch.Tensor,
+     beta: float,
+     reduction: str,
+ ) -> torch.Tensor:
+     """Compute the MinorSFT loss.
+
+     This loss is inspired by DPO and is designed to be less aggressive than standard SFT.
+     It computes ``-log_sigmoid(beta * (log_probs - ref_log_probs))``.
+
+     Args:
+         log_probs (torch.Tensor): The log probabilities from the model being trained.
+         ref_log_probs (torch.Tensor): The log probabilities from the reference model.
+         beta (float): The beta parameter from DPO.
+         reduction (str): The reduction to apply to the loss.
+
+     Returns:
+         The MinorSFT loss.
+
+     References:
+         - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
+           `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
+     """
+     if log_probs.shape != ref_log_probs.shape:
+         raise ValueError(
+             f"Current log probabilities and reference log probabilities have different shapes: {log_probs.shape=} vs {ref_log_probs.shape=}."
+         )
+     loss = -torch.nn.functional.logsigmoid(beta * (log_probs - ref_log_probs))
+     if reduction == "mean":
+         return loss.mean()
+     if reduction == "sum":
+         return loss.sum()
+     if reduction == "none":
+         return loss
+     raise ValueError(f"Invalid reduction: {reduction}")
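+
+
+ # For intuition: when the trained model matches the reference (``log_probs ==
+ # ref_log_probs``), the argument of ``logsigmoid`` is 0 and the loss is
+ # ``-log(0.5) = log(2)`` regardless of ``beta``; it decays towards 0 as the trained
+ # model assigns higher likelihood to the data than the reference does.
+ #     >>> lp = torch.zeros(3)
+ #     >>> minor_sft_loss(lp, lp, beta=0.1, reduction="mean")
+ #     tensor(0.6931)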
+
+
+ class SFTLossOutput(TensorClass["nocast"]):
+     """SFT Loss Output.
+
+     Attributes:
+         loss_sft (torch.Tensor): The loss for the SFT objective.
+         loss_kl_to_ref (torch.Tensor | None): The loss for the KL divergence to the reference model.
+         kl_to_ref (torch.Tensor | None): The KL divergence to the reference model.
+
+     .. note::
+         The loss components are kept separate to allow for logging and visualization.
+         Before backpropagation, the loss components must be summed together. Since the
+         non-loss components are detached (and therefore not differentiable) when the loss
+         is constructed via :class:`~torchrl.objectives.llm.sft.SFTLoss`, summing the
+         :class:`~torchrl.objectives.llm.sft.SFTLossOutput` directly is a proper way of
+         obtaining the total loss.
+
+         >>> loss_fn = SFTLoss(...)
+         >>> loss_output = loss_fn(td)
+         >>> loss = loss_output.loss_sft + loss_output.loss_kl_to_ref
+         >>> loss.backward()
+         >>> # or equivalently
+         >>> loss = loss_fn(td)
+         >>> loss.sum(reduce=True).backward()
+     """
+
+     loss_sft: torch.Tensor
+     loss_kl_to_ref: torch.Tensor | None = None
+     kl_to_ref: torch.Tensor | None = None
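+
+     # ``SFTLossOutput.sum(reduce=True)`` adds up all tensor fields, including the
+     # detached ``kl_to_ref``; that extra term is constant with respect to the model
+     # parameters, so the gradient matches that of ``loss_sft + loss_kl_to_ref``.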
+
+
+ class SFTLoss(LossModule):
+     r"""Supervised fine-tuning loss.
+
+     Args:
+         actor_network (TensorDictModule): the actor network. Usually a :class:`~torchrl.modules.llm.TransformersWrapper` instance,
+             with `return_log_probs=True` and `from_text=True`.
+         tokenizer (`Tokenizer`): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor_network`.
+         tokenizer_kwargs (dict, optional): keyword arguments to pass to the tokenizer during :meth:`~torchrl.data.llm.chat.History.apply_chat_template`.
+             This can be used to override arguments such as the `chat_template` or `chat_template_name`.
+         reduction (Literal["mean", "sum", "none"], optional): the reduction to apply to the loss. Defaults to `"mean"`.
+         normalize_by_seq_length (bool, optional): whether to normalize the loss by the sequence length. Defaults to `True`.
+         kl_to_ref_coeff (float | None, optional): coefficient for the KL divergence to the reference model. Defaults to `None`.
+         loss_function (Literal["sft", "minor_sft"], optional): the loss function to use. Defaults to `"sft"`.
+         beta (float, optional): the beta parameter for the MinorSFT loss. Only used when `loss_function` is `"minor_sft"`.
+             Higher values of beta make the loss more aggressive (the model is pushed to generate responses further from the reference model):
+
+             .. math::
+                 \text{loss} = -\log\sigma(\beta \cdot (\text{log\_probs} - \text{ref\_log\_probs}))
+
+             Defaults to `0.1`.
+         device (torch.device | None, optional): the device to use for the loss when tokenizing the input. Defaults to `None`.
+
+     .. note::
+         The input tensordict is expected to contain the following keys by default:
+
+         - ``("next", "history")``: The chat history
+         - ``("next", "ref_log_prob")`` (optional): Reference model log probabilities, required if ``kl_to_ref_coeff`` is set
+
+         These keys can be customized using the ``set_keys()`` method.
+
+     .. seealso:: :class:`~torchrl.envs.llm.transforms.RetrieveLogProb` for the KL divergence computation.
+
+     References:
+         - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
+           `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
+
+     Examples:
+         >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+         >>> from torchrl.envs.llm.transforms import RetrieveLogProb
+         >>> from torchrl.modules.llm import TransformersWrapper
+         >>> from torchrl.objectives.llm.sft import SFTLoss
+         >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+         >>> from tensordict import TensorDict, lazy_stack
+         >>> import torch
+         >>>
+         >>> # Create chat data
+         >>> chats = [
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "Hello, how are you?"},
+         ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
+         ...     ],
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "What's the weather like?"},
+         ...         {"role": "assistant", "content": "I can't check the weather for you."},
+         ...     ],
+         ... ]
+         >>> history = History.from_chats(chats)
+         >>>
+         >>> # Setup tokenizer and model
+         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
+         >>> model = OPTForCausalLM(OPTConfig()).eval()
+         >>>
+         >>> # Create training and reference policies
+         >>> policy_train = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>> policy_ref = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     return_log_probs=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>>
+         >>> # Create the RetrieveLogProb transform
+         >>> transform = RetrieveLogProb(
+         ...     policy_ref,
+         ...     assistant_only=True,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     tokenizer=tokenizer,
+         ... )
+         >>>
+         >>> # Prepare data
+         >>> text = history[:, :-1].apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
+         ... )
+         >>> text_response = history.apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
+         ... )
+         >>> text_response = [
+         ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
+         ... ]
+         >>> td = TensorDict(
+         ...     text=text,
+         ...     text_response=text_response,
+         ...     history=history,
+         ...     next=TensorDict(
+         ...         reward=torch.randn(2, 1),
+         ...         done=torch.zeros(2, dtype=torch.bool),
+         ...         history=history,
+         ...     ),
+         ...     batch_size=(2,),
+         ... )
+         >>> data = lazy_stack(list(td.unbind(0)))
+         >>>
+         >>> # Apply the transform to get reference log probabilities
+         >>> data = transform(data)
+         >>> assert "ref_log_prob" in data["next"].keys()
+         >>>
+         >>> # Use with SFTLoss for KL regularization
+         >>> loss = SFTLoss(
+         ...     actor_network=policy_train,
+         ...     tokenizer=tokenizer,
+         ...     reduction="mean",
+         ...     normalize_by_seq_length=True,
+         ...     kl_to_ref_coeff=0.1,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     loss_function="sft",
+         ... )
+         >>> loss_vals = loss(data)
+         >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
+         >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             history (NestedKey): The input tensordict key where the chat history is expected.
+                 Defaults to ``("next", "history")``.
+             ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected.
+                 Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_prob")``.
+             log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written.
+                 Defaults to ``"log_probs"``.
+         """
+
+         history: NestedKey = ("next", "history")
+         ref_log_prob: NestedKey = ("next", "ref_log_prob")
+         log_probs: NestedKey = "log_probs"
+
+     default_keys = _AcceptedKeys
+     tensor_keys: _AcceptedKeys
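+
+     # The keys above can be remapped on an instantiated loss through the standard
+     # ``LossModule.set_keys`` API, e.g.
+     # ``loss.set_keys(ref_log_prob=("next", "my_ref_log_prob"))`` (key name illustrative).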
+
+     def __init__(
+         self,
+         actor_network: TensorDictModule | TransformersWrapper,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         tokenizer_kwargs: dict | None = None,
+         reduction: Literal["mean", "sum", "none"] = "mean",
+         normalize_by_seq_length: bool = True,
+         kl_to_ref_coeff: float | None = None,
+         loss_function: Literal["sft", "minor_sft"] = "sft",
+         beta: float = 0.1,
+         device: torch.device | None = None,
+     ):
+         super().__init__()
+         self.in_keys = []
+         self.actor_network = actor_network
+         if tokenizer is None:
+             tokenizer = actor_network.tokenizer
+         self.tokenizer = tokenizer
+         if tokenizer_kwargs is None:
+             tokenizer_kwargs = {}
+         if tokenizer is None:
+             raise ValueError("Tokenizer must be provided.")
+         tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+         tokenizer_kwargs.setdefault("tokenize", True)
+         tokenizer_kwargs.setdefault("return_tensors", "pt")
+         tokenizer_kwargs.setdefault("padding", False)
+         tokenizer_kwargs.setdefault("add_generation_prompt", False)
+         self.tokenizer_kwargs = tokenizer_kwargs
+         self.reduction = reduction
+         self.normalize_by_seq_length = normalize_by_seq_length
+         self.kl_to_ref_coeff = kl_to_ref_coeff
+         self.loss_function = loss_function
+         if self.loss_function == "minor_sft" and kl_to_ref_coeff:
+             warnings.warn(
+                 "kl_to_ref_coeff should not be set when using minor_sft loss, as KL regularization is implicit. Setting kl_to_ref_coeff to 0.0."
+             )
+             self.kl_to_ref_coeff = 0.0
+         self.beta = beta
+         self._set_in_keys()
+         self.device = device
+
+     def _set_in_keys(self) -> None:
+         """Sets the input keys for the loss module."""
+         in_keys = [self.tensor_keys.history]
+         if self.kl_to_ref_coeff is not None or self.loss_function == "minor_sft":
+             in_keys.append(self.tensor_keys.ref_log_prob)
+         self.in_keys = in_keys
+         self.out_keys = []  # Loss modules typically don't have out_keys
+
+     def _kl_to_ref(
+         self,
+         cur_log_prob: list[torch.Tensor],
+         ref_log_prob: list[torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute the KL divergence to the reference model.
+
+         Args:
+             cur_log_prob (list[torch.Tensor]): Log probabilities from the current model. Each element must have shape [T], where T is the number of tokens in the assistant response.
+             ref_log_prob (list[torch.Tensor]): Log probabilities from the reference model. Each element must have shape [T], where T is the number of tokens in the assistant response.
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor]: (KL loss term, KL penalty for logging)
+         """
+         # Concatenate the per-sample log probabilities into flat tensors
+         ref_log_prob = torch.cat(ref_log_prob)
+         cur_log_prob = torch.cat(cur_log_prob)
+         if cur_log_prob.shape != ref_log_prob.shape:
+             raise ValueError(
+                 f"Current log probabilities and reference log probabilities have different shapes: {cur_log_prob.shape=} vs {ref_log_prob.shape=}."
+             )
+         # Compute KL using the same approximation as GRPO
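+         # ``expm1(diff) - diff`` below equals exp(d) - 1 - d, the non-negative "k3"
+         # KL estimator: with d = ref_log_prob - cur_log_prob and samples drawn from
+         # the current model, E[exp(d) - 1] = 0, so its mean is unbiased for KL(cur || ref).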
+         diff = ref_log_prob - cur_log_prob
+
+         kl_penalty = (diff.expm1() - diff).mean()
+         return self.kl_to_ref_coeff * kl_penalty, kl_penalty
+
+     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+         # Gather history
+         history: History = tensordict[self.tensor_keys.history]
+
+         # Apply tokenizer to history and gather mask
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext():
+             token_struct = history.apply_chat_template(
+                 tokenizer=self.tokenizer, **self.tokenizer_kwargs
+             )
+         if "assistant_masks" not in token_struct:
+             raise ValueError(
+                 f"Assistant masks are not present in the token structure: {token_struct=}."
+             )
+         assistant_masks = token_struct.get(
+             "assistant_masks",
+             as_list=True,
+         )
+         assistant_masks = [mask.bool() for mask in assistant_masks]
+         attention_mask = token_struct.get("attention_mask", as_list=True)
+         attention_mask = [mask.bool() for mask in attention_mask]
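+         # Restrict the assistant mask to attended tokens so that padded or otherwise
+         # masked-out positions never contribute to the loss.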
+         assistant_masks = [
+             mask & a_mask for mask, a_mask in zip(assistant_masks, attention_mask)
+         ]
+
+         if not all(mask.any(-1).all() for mask in assistant_masks):
+             raise ValueError("Some inputs have no valid assistant masks.")
+         input_loss = tensordict.select(self.tensor_keys.history)
+         if (
+             isinstance(self.tensor_keys.history, tuple)
+             and self.tensor_keys.history[0] == "next"
+         ):
+             input_loss = input_loss["next"]
+
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext():
+             output_loss = self.actor_network(input_loss)
+
+         # get log-probs
+         log_probs = output_loss.get(
+             self.tensor_keys.log_probs,
+             as_list=True,
+         )
+         # check shapes, then apply the mask
+         if not all(
+             mask.shape == lp.shape
+             for mask, lp in _zip_strict(assistant_masks, log_probs)
+         ):
+             raise ValueError(
+                 f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs {[lp.shape for lp in log_probs]}. Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
+             )
+
+         log_probs_masked = [
+             lp.masked_fill(~mask, 0.0)
+             for lp, mask in _zip_strict(log_probs, assistant_masks)
+         ]
+
+         # Sum log probs, optionally normalize by sequence length
+         summed_log_probs = torch.stack(
+             [lp.sum(tensordict.ndim - 1) for lp in log_probs_masked]
+         )
+         seq_lengths = torch.stack(
+             [mask.sum(tensordict.ndim - 1) for mask in assistant_masks]
+         )
+         if self.normalize_by_seq_length:
+             # Normalize by the number of assistant tokens in each sequence
+             summed_log_probs = summed_log_probs / seq_lengths.clamp(min=1)
+
+         # Compute main loss
+         if self.loss_function == "sft":
+             loss = sft_loss(summed_log_probs, self.reduction)
+             # Add KL divergence loss if reference model is provided
+             if self.kl_to_ref_coeff is not None:
+                 ref_log_probs = tensordict.get(
+                     self.tensor_keys.ref_log_prob,
+                     default=None,
+                     as_list=True,
+                 )
+                 if ref_log_probs is None:
+                     raise ValueError(
+                         "Reference log probs not found in tensordict but kl_to_ref_coeff was set"
+                     )
+
+                 loss_kl, kl_penalty = self._kl_to_ref(
+                     [lp[mask] for lp, mask in _zip_strict(log_probs, assistant_masks)],
+                     ref_log_probs,
+                 )
+                 output = SFTLossOutput(
+                     loss_sft=loss,
+                     loss_kl_to_ref=loss_kl,
+                     kl_to_ref=kl_penalty.detach(),
+                 )
+             else:
+                 output = SFTLossOutput(loss_sft=loss)
+         elif self.loss_function == "minor_sft":
+             ref_log_probs = tensordict.get(self.tensor_keys.ref_log_prob, as_list=True)
+             if ref_log_probs is None:
+                 raise ValueError(
+                     f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict but loss_function is 'minor_sft'"
+                 )
+
+             # we need to re-sum ref_log_probs as they are not summed per-sequence
+             summed_ref_log_probs = torch.stack([lp.sum() for lp in ref_log_probs]).to(
+                 summed_log_probs.device
+             )
+             if self.normalize_by_seq_length:
+                 summed_ref_log_probs = summed_ref_log_probs / seq_lengths.clamp(min=1)
+             loss = minor_sft_loss(
+                 summed_log_probs, summed_ref_log_probs, self.beta, self.reduction
+             )
+             if self.kl_to_ref_coeff is not None:
+                 with torch.no_grad():
+                     loss_kl, kl_penalty = self._kl_to_ref(
+                         [
+                             lp[mask]
+                             for lp, mask in _zip_strict(log_probs, assistant_masks)
+                         ],
+                         ref_log_probs,
+                     )
+                 output = SFTLossOutput(
+                     loss_sft=loss,
+                     loss_kl_to_ref=loss_kl,
+                     kl_to_ref=kl_penalty.detach(),
+                 )
+             else:
+                 output = SFTLossOutput(loss_sft=loss)
+         else:
+             raise ValueError(f"Invalid loss function: {self.loss_function}")
+
+         return output
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
- __version__ = '2025.6.20'
- git_version = 'e1e15d692a6df69bfdc80e85f45d37a4b967e625'
+ __version__ = '2025.6.22'
+ git_version = '77dbc6c9ffbce3d2ce3f26b659355cd46d8132c3'
torchrl_nightly-2025.6.20.dist-info/METADATA → torchrl_nightly-2025.6.22.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torchrl-nightly
- Version: 2025.6.20
+ Version: 2025.6.22
  Home-page: https://github.com/pytorch/rl
  Author: torchrl contributors
  Author-email: vmoens@fb.com
torchrl_nightly-2025.6.20.dist-info/RECORD → torchrl_nightly-2025.6.22.dist-info/RECORD RENAMED
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
  build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
  torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
  torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
- torchrl/_torchrl.cpython-313-x86_64-linux-gnu.so,sha256=K2mDnBxDxSj0qflvFxCBKYjue5HjA5oOtzfr9s26sVk,21451208
+ torchrl/_torchrl.cpython-313-x86_64-linux-gnu.so,sha256=BCB4sN_0LejjRpv5bAtaI_d135o_BAHu_XLlQdHGNQg,21451176
  torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
- torchrl/version.py,sha256=tueh-ZkmvytwgeWC2kBvr7G0Iu6dTR_gZ03kkRW7v8Q,83
+ torchrl/version.py,sha256=-iU3qRfg2kpxfWLGxmIBA7z2dNl2hOX09zklpdgeA4Y,83
  torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
- torchrl/collectors/collectors.py,sha256=5yY-WMEAzj2vKBA0RxijFjBq7y5faQGY2BwkPfI5qSU,177457
+ torchrl/collectors/collectors.py,sha256=CdTerIwhCTr6n5OoJLNad0bNQ5OLliPZFWkU18QBKSA,177625
  torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
  torchrl/collectors/weight_update.py,sha256=nSIfs8ALsfggLoC2ylg1oOAqdGku1tt4e-50JCZJBww,21073
  torchrl/collectors/distributed/__init__.py,sha256=_24P0ALFunLhL-ls7EsssGUhJkZ_m3nw7krfMTwPqS0,705
@@ -18,12 +18,12 @@ torchrl/collectors/distributed/rpc.py,sha256=0xQDqKlvLmCb_2wL9oZojt4rONaSq09abPL
  torchrl/collectors/distributed/sync.py,sha256=oZW3nUYrUK52N6pMYX9M0WWhMeTzLl25maxM7X2G8Ec,27272
  torchrl/collectors/distributed/utils.py,sha256=MuxSeb4TkiyWJYyMyXWLgyCDgtgbGU6g8nNVf59xqCE,6464
  torchrl/collectors/llm/__init__.py,sha256=rx9DktowQ-gvFleb07US9d9WFc4aNG6zKpiOPSW4A7U,355
- torchrl/collectors/llm/base.py,sha256=wyZmNIZ_92lUkfZKgbCh8OXDUoHNhVn_7s4qizmrH58,20388
- torchrl/collectors/llm/ray_collector.py,sha256=cc1oZ1zh322lJL21bgJd4b6w9-QUQiCOx2i63zzgMLo,10948
+ torchrl/collectors/llm/base.py,sha256=G6n2_U7CIr0BBUMbrOSv-AIRSxFFMqft_Ia3Ir3Ggks,20671
+ torchrl/collectors/llm/ray_collector.py,sha256=1o9rbQtoJ48Ovo_YP76KQ-dLlKJt-bGdH2VMaK_-olg,11101
  torchrl/collectors/llm/utils.py,sha256=-KRSlOmjj34M0c3msP7yS_0DlLmqCijEbf_bADLjzuM,1148
  torchrl/collectors/llm/weight_update/__init__.py,sha256=bKjvD7yZG5VnHgvYc4EmKI1seK4FyMBKTqeLzkqR_3s,272
  torchrl/collectors/llm/weight_update/vllm.py,sha256=81ShmKzNjVIg7hxlPvLHhF-YqeXv98cIk0l6ByD-MDU,11276
- torchrl/data/__init__.py,sha256=RuBnwrzJqJZxU1drtdzUHdWTrZpL6z4SPLYBYM2AMqc,4769
+ torchrl/data/__init__.py,sha256=oowsio6ZUOZnJV8JV43xgs17B37XO1yKAYIQPdk8yt0,4819
  torchrl/data/rlhf.py,sha256=JUmdYBWgkN229DwpXuDrhy9ddjduNvU2kyHzHR6MoA0,963
  torchrl/data/tensor_specs.py,sha256=rfuYM9WLUnF4vHwM4opvypShZ3RN7954WhiPMyG3CSU,247841
  torchrl/data/utils.py,sha256=attuNwzfgjszyp0lJSrV06f2peX3r0qTjRZWEwfl6Yg,12108
@@ -39,12 +39,13 @@ torchrl/data/datasets/openx.py,sha256=QXjJPZHoRhefVux00iAL-g4spynrWjmI_M2IuaQ8TA
  torchrl/data/datasets/roboset.py,sha256=rLPdyEQI9yEibXU6SZFA0YD79EGFKY8o5oyUlvcn4aM,16648
  torchrl/data/datasets/utils.py,sha256=nAFDTlBIPyEoPoJC-Hc_fcOhzE7UZQE4BwKxq15Vhvk,339
  torchrl/data/datasets/vd4rl.py,sha256=z90MqrxKzod8TPGK0uzkC6vw5wQIE4cgrDAC4e72jyk,18262
- torchrl/data/llm/__init__.py,sha256=FWApDyEPlyE7jA7CRMmGpxnLYQ4ZKKAdCbbpmOMf-OU,908
- torchrl/data/llm/chat.py,sha256=aGnrefzwxvEhuPlTHmLPHyzqG2xWbMnzwjWBfoldNIM,30817
+ torchrl/data/llm/__init__.py,sha256=By2FWnjqADPmHnNXh6DVLQ9CYPj51gn3HxPW_DYPMyc,971
+ torchrl/data/llm/chat.py,sha256=K5Cuw4GHSJWGg5vXwGyV9oqS7X0ddcx1FA1sUNEvjKY,33174
  torchrl/data/llm/common.py,sha256=CYBaAop8QETotOCBGTw_pfKjxFYlsoSGElki6wBx5jo,2135
  torchrl/data/llm/dataset.py,sha256=t-41hAzQcjrdoKwpHIMbcrT7pRcQ7DHl2a1-lr6E7W4,20703
  torchrl/data/llm/prompt.py,sha256=bg5LzJfwOq5Ns72KQMciIprMWAmDDinzdopwdopU04c,8380
  torchrl/data/llm/reward.py,sha256=FbPchNXG3smJV9NCbB5Yk4grsCa2Se4KZ_tojVLKWQM,8404
+ torchrl/data/llm/topk.py,sha256=SZq89yeFr8rNbpVR-S5vC7AVoeb6JKYZPeSS-n4FwKE,8353
  torchrl/data/llm/utils.py,sha256=axe3wSovfWBm5YmR_uJYpfAmYtd__2i9SCKgUSezkBk,23600
  torchrl/data/map/__init__.py,sha256=1IB8lWApscQOOscsCEhQrUDy_AE1wWV51Tcl1Segsqk,555
  torchrl/data/map/hash.py,sha256=29cKgYjd5vVeR2bu2kI5BwtOq9FeZD41RA7Q3UxP9vo,7252
@@ -56,11 +57,11 @@ torchrl/data/postprocs/__init__.py,sha256=Z9JpRKMGsuFGpB3ro4R9Y_hYTBqkkzbkWZR79T
  torchrl/data/postprocs/postprocs.py,sha256=h8LO8zBosRm7iLmUOxdtPxZ84yavkv9usYtLSBq9tC4,15118
  torchrl/data/replay_buffers/__init__.py,sha256=v_oKflSohims6uw40XhLkjDX7vZM9UwXrWAeZfftogw,2360
  torchrl/data/replay_buffers/checkpointers.py,sha256=VF18DlRiy361gecbT2HL5VLTQU4Faxq7mULsownjYiQ,14790
- torchrl/data/replay_buffers/ray_buffer.py,sha256=joYh_ypj4Zk2CNUEdfgNiAybvA8vNJTmrAlwC1bhejg,9043
- torchrl/data/replay_buffers/replay_buffers.py,sha256=c0yoPBzQjfWfTLOXbPj7VqA-ZnTP_7oS0OVpwo5gaPk,88639
- torchrl/data/replay_buffers/samplers.py,sha256=HZguztdX2tvDwIPKT0STAKuncJi8l005FCuuNKnAVoE,105427
+ torchrl/data/replay_buffers/ray_buffer.py,sha256=at8rYXxtlctoPCnL5oJRNoEkjEASHoXjPIt6UH16OCA,9854
+ torchrl/data/replay_buffers/replay_buffers.py,sha256=lKTcEQOooT_MY4cuVuAdaYPKN9Ob9v3o46FGdnCyOS8,90459
+ torchrl/data/replay_buffers/samplers.py,sha256=Kp48OPzvEWeTbPS8LNMRiGaYwUdrMgVVc3OaRIkNIR4,110296
  torchrl/data/replay_buffers/scheduler.py,sha256=SRZf_FJLUEIBz684W9RlLt3In158s9N5h4xb_MWnBgY,10152
- torchrl/data/replay_buffers/storages.py,sha256=WnNopbDT3DbjiN7QRdb-Iet-sLBmt7hSfEnKDx1-VEI,59579
+ torchrl/data/replay_buffers/storages.py,sha256=9h2iyLv9jnKG7kB1925SRlcxly-IABqGjPhoMGov-6Y,60704
  torchrl/data/replay_buffers/utils.py,sha256=tU98Nc_j9bMrWBs96gFUTDXLmWEZCvHRYjSXjPMc_lY,38603
  torchrl/data/replay_buffers/writers.py,sha256=p9b8k89u-JrqoObT4aCLa0qCkKWdM__l7lGUQDKSdsU,27727
  torchrl/envs/__init__.py,sha256=c-_VtMuAcRdg0hBmltn6AbTU7B1X-ARBEfqOQoPFEZk,5817
@@ -96,8 +97,8 @@ torchrl/envs/libs/smacv2.py,sha256=i0TRHuZ9S9v0NfufPgQAcTlvAjf6JKv8hHvOzjSgsaw,2
  torchrl/envs/libs/unity_mlagents.py,sha256=Z3qSU0H3o2NXbS2lNvQ7OmYxkr3AWAMyRHfxeCtNZrk,49667
  torchrl/envs/libs/utils.py,sha256=RgiR16KJWFEtQim44-AIcHByGTq_NrtpjWoYIC13aYA,5207
  torchrl/envs/libs/vmas.py,sha256=a71_jU4r627hFXcMsT5wNSb4TMpyd3punLdOF3Cc8O0,36297
- torchrl/envs/llm/__init__.py,sha256=DiYt8YjoxmwoM62XPtNUPMYaqZyf1UXY6dAD_vcBIfE,1221
- torchrl/envs/llm/chat.py,sha256=DT_kcsfpM0W3bayRVk3rdtNKyv3pjoOsicw56LG6fp8,17619
+ torchrl/envs/llm/__init__.py,sha256=o8uAVGHYngy_k6xM5qIkqgHaz__S1HyG7QjLd78gtaA,1265
+ torchrl/envs/llm/chat.py,sha256=mVLjmBTwd6IWdlKJMRcynDJNVVbiHjCop5EVUXpaaAA,17794
  torchrl/envs/llm/envs.py,sha256=Er-ahjgvtYG4LB7_EWOMbdobiUV5DOHPBQYkVTu80r4,34677
  torchrl/envs/llm/datasets/__init__.py,sha256=FFethtv8unJWzphGLPQVC5QD9NMdaygEjx25O1DHHZk,473
  torchrl/envs/llm/datasets/gsm8k.py,sha256=wTntpV-bi0gbyvJ-JnuHQmPXjXgV4hEssGFed8GRGGc,15299
@@ -105,18 +106,18 @@ torchrl/envs/llm/datasets/ifeval.py,sha256=fVbMSVjpnlZR36B0yDUgDcM1Ye-EP6ui7g9nP
  torchrl/envs/llm/libs/__init__.py,sha256=vhEm5Fhz1sLWt107zfZLy5pzGmfQi0fNBGazTq1m7dU,266
  torchrl/envs/llm/libs/mlgym.py,sha256=ECnkrNoPV73L1fIO05SlTTXuTSNOM2pdX6aJcEYJVlo,31372
  torchrl/envs/llm/reward/__init__.py,sha256=a-Xsye29z2LugO1cOCFM2FNsqNwEp-5XwQk4saVQlu8,370
- torchrl/envs/llm/reward/gsm8k.py,sha256=6y6I8UdPanS6g7skWFStNm_nXP0nS5ctcAHFWEkFup0,7702
+ torchrl/envs/llm/reward/gsm8k.py,sha256=2pUXYkCw6_arM6HCZJcrEYwRZMDntsFAzdpf3QXNthI,7862
  torchrl/envs/llm/reward/ifeval/__init__.py,sha256=g5NtrwfwqK22hRcoIdz8-KWBh5Ogre9J-Bf3uGWE9Pg,314
  torchrl/envs/llm/reward/ifeval/_instructions.py,sha256=rAoTdwG42smCLJgwW7kAwJrNonjIS6OwdohDE70oMOA,61696
  torchrl/envs/llm/reward/ifeval/_instructions_main.py,sha256=CofKXvG0J2H-1ZXP1fL6UZI8ArNCIO2w5R_37drRIW8,4117
  torchrl/envs/llm/reward/ifeval/_instructions_registry.py,sha256=3_guc8LZ0mWQc-n6E4cQgYMgZRYa6xfgvXgrze9aO_w,3814
  torchrl/envs/llm/reward/ifeval/_instructions_util.py,sha256=aA3fupO8MvqBCqD7Y_Qk6y32toWF1lZGAflWON1ruXM,26042
  torchrl/envs/llm/reward/ifeval/_scorer.py,sha256=zJHBgaGlluEv6czsI6ZtLqArV_J_W9zY7UPAJhT5YIo,14563
- torchrl/envs/llm/transforms/__init__.py,sha256=fpcS83ud3OC2NWnkFeTdE8r4Mtlbcp_OiITzDM03aes,748
+ torchrl/envs/llm/transforms/__init__.py,sha256=roEOZVFOs1PhC1cGF-LIXQt5DlXZx6mgIJ-1k0JDTfI,788
  torchrl/envs/llm/transforms/browser.py,sha256=zF7jHHHrdpxUCjFFtiYK-vhw-p1YqsqwP8_b4SiK0Rs,10423
  torchrl/envs/llm/transforms/dataloading.py,sha256=dv4IV3OWEa6-evxBk3WAZjkBi1_yKUs2NQ2gGmL2lKQ,24533
  torchrl/envs/llm/transforms/format.py,sha256=ESn0S9k5G4FQPBICq9h6ZsLKXZqiU71tYW8UnW4rgLI,2519
- torchrl/envs/llm/transforms/kl.py,sha256=lRWW1Gf8bu71jMatAlk91Eeuh50mmPedjKrnXKUm5D0,11721
+ torchrl/envs/llm/transforms/kl.py,sha256=N68378chSx54X5a7YLJzIV6d870H5xrBb5-qWqzpX1U,22744
  torchrl/envs/llm/transforms/policy_version.py,sha256=by2TjsZLwVjQbq7ggBoAco2Iq_2aEYgyxh9asTXL1vk,6893
  torchrl/envs/llm/transforms/tokenizer.py,sha256=CcuKRu33YnyDgLtQtyxTGDFC6iI3b3fUA6Nb1Lnh7h8,13953
  torchrl/envs/llm/transforms/tools.py,sha256=I-HR0zjH4tFMp9xPH556H5Q5JqmqXdsAXwElAR93e5U,29498
@@ -147,8 +148,8 @@ torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un
  torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
  torchrl/modules/llm/policies/__init__.py,sha256=rVQwVhSTS1hLcSynvPXKq9_9gGC6gC1SyOz5DNg1qcc,426
  torchrl/modules/llm/policies/common.py,sha256=m76rSjgYbf-ZMEUFZNbjBbyXNHbR8BXt1z5o9honJOM,3019
- torchrl/modules/llm/policies/transformers_wrapper.py,sha256=yn_qVpFqjr41HrkkxGhLDfIjtS9PCgklnbkAecu4Evc,22615
- torchrl/modules/llm/policies/vllm_wrapper.py,sha256=g3eaQSNti6NQBpKcokeLL9b0K3Kt38ltaPv8qlIIqDo,29782
+ torchrl/modules/llm/policies/transformers_wrapper.py,sha256=M0Drk7MFY596Ek8_duNTXFpc4c2Ar94Jy3viXnhRS2M,25370
+ torchrl/modules/llm/policies/vllm_wrapper.py,sha256=1vwfoIYxOL2IwBMVZUFrwOexIwS7x1xbhBVdru6gYxY,31487
  torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
  torchrl/modules/models/batchrenorm.py,sha256=TojpTUluIcFdTSemIVRLGtB2O5q54mRHy3vJP6DuI5I,4750
  torchrl/modules/models/decision_transformer.py,sha256=Lttf_wZMNqXbB_vpxMYgEp18gEzOvm3NvMnxQkHkH4M,6604
@@ -194,8 +195,9 @@ torchrl/objectives/sac.py,sha256=Oq9Iq90s9KFbnM4KSRUd2onU1JfW6aW80LWGdtO0CY8,639
  torchrl/objectives/td3.py,sha256=RnlkGzBBTY0KrfRKytsFbNyoVUy2HLfwSL4_9YQRep8,23190
  torchrl/objectives/td3_bc.py,sha256=jHGwCzPuCbN37zAxsiDQIe92yR1UE7rjcnJoy8b_NjE,25950
  torchrl/objectives/utils.py,sha256=nhB7a2gLoZMLgYSWTpSgQqZWEGRBkvVoa8yszTlecm4,24001
- torchrl/objectives/llm/__init__.py,sha256=LnYwAuaG-ylQQcu2BRQWavaDhjMPikXNT6YaH_3QoEU,328
+ torchrl/objectives/llm/__init__.py,sha256=SXYwry5YoDp5m0QRFmOYzz60siJQmofcTvCOmC1DlXw,396
  torchrl/objectives/llm/grpo.py,sha256=rsPVvfE_2Bbl8K1aq_LIry1ViDnibfGYWexfSIbJx80,16788
+ torchrl/objectives/llm/sft.py,sha256=zAdVT1CmXJJPjEwPt4SPJNzFUC2m-flcfOsejIuAFkg,20107
  torchrl/objectives/multiagent/__init__.py,sha256=CHxWmq5_3kveLcAdyB7cgSVYVIald7EZo81RRgozxo0,237
  torchrl/objectives/multiagent/qmixer.py,sha256=JyDcZeV2zv2MqKsyJ-ql9ISYHJ58e3pzb5-0BThswhI,16973
  torchrl/objectives/value/__init__.py,sha256=AdluF370wYzOAcP_yglUAFnNByKVZzivBYJafkDQbJA,561
@@ -221,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=qH-2YJwMtDAYAPXTTYy3cOPiq4ILC6xTjfnGUU
  torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
  torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
  torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
- torchrl_nightly-2025.6.20.dist-info/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
- torchrl_nightly-2025.6.20.dist-info/METADATA,sha256=fR7h0Sq2FvApealDc_cnX5Aj5QtIL4acVWv-Cz60FTk,39023
- torchrl_nightly-2025.6.20.dist-info/WHEEL,sha256=HRqO1yy0EkQFVSOPjhgaTzf773tbWecKJXRlZH64XT8,104
- torchrl_nightly-2025.6.20.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
- torchrl_nightly-2025.6.20.dist-info/RECORD,,
+ torchrl_nightly-2025.6.22.dist-info/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+ torchrl_nightly-2025.6.22.dist-info/METADATA,sha256=p6Lp-DEGEipD6Ak0XOcCM_dRoqAan8l4PSSAtgyr7K4,39023
+ torchrl_nightly-2025.6.22.dist-info/WHEEL,sha256=HRqO1yy0EkQFVSOPjhgaTzf773tbWecKJXRlZH64XT8,104
+ torchrl_nightly-2025.6.22.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+ torchrl_nightly-2025.6.22.dist-info/RECORD,,