torchrl-nightly 2025.6.20-cp312-cp312-win_amd64.whl → 2025.6.21-cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrl/objectives/llm/sft.py ADDED
@@ -0,0 +1,465 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ from __future__ import annotations
+
+ import contextlib
+ import warnings
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ from tensordict import NestedKey, TensorClass, TensorDictBase
+ from tensordict.nn import TensorDictModule
+ from tensordict.utils import _zip_strict
+ from torchrl.data import History
+ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+ from torchrl.objectives.common import LossModule
+
+
+ def sft_loss(summed_log_probs: torch.Tensor, reduction: str) -> torch.Tensor:
+     """Compute the SFT loss."""
+     if reduction == "mean":
+         loss = -summed_log_probs.mean()
+     elif reduction == "sum":
+         loss = -summed_log_probs.sum()
+     elif reduction == "none":
+         loss = -summed_log_probs
+     else:
+         raise ValueError(f"Invalid reduction: {reduction}.")
+     return loss
+
+
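`sft_loss` simply negates the summed (and, in `SFTLoss`, optionally length-normalized) per-sequence log-probabilities and applies the requested reduction. A minimal sketch on made-up values:

```python
import torch
from torchrl.objectives.llm.sft import sft_loss

# Per-sequence log-likelihoods of the assistant tokens (values are made up).
summed_log_probs = torch.tensor([-12.5, -8.0])

print(sft_loss(summed_log_probs, "mean"))  # tensor(10.2500)
print(sft_loss(summed_log_probs, "sum"))   # tensor(20.5000)
print(sft_loss(summed_log_probs, "none"))  # tensor([12.5000, 8.0000])
```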
+ def minor_sft_loss(
+     log_probs: torch.Tensor,
+     ref_log_probs: torch.Tensor,
+     beta: float,
+     reduction: str,
+ ) -> torch.Tensor:
+     """Compute the MinorSFT loss.
+
+     This loss is inspired by DPO and is designed to be less aggressive than standard SFT.
+     It computes ``-log_sigmoid(beta * (log_probs - ref_log_probs))``.
+
+     Args:
+         log_probs (torch.Tensor): The log probabilities from the model being trained.
+         ref_log_probs (torch.Tensor): The log probabilities from the reference model.
+         beta (float): The beta parameter from DPO.
+         reduction (str): The reduction to apply to the loss.
+
+     Returns:
+         The MinorSFT loss.
+
+     References:
+         - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
+           `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
+     """
+     if log_probs.shape != ref_log_probs.shape:
+         raise ValueError(
+             f"Current log probabilities and reference log probabilities have different shapes: {log_probs.shape=} vs {ref_log_probs.shape=}."
+         )
+     loss = -torch.nn.functional.logsigmoid(beta * (log_probs - ref_log_probs))
+     if reduction == "mean":
+         return loss.mean()
+     if reduction == "sum":
+         return loss.sum()
+     if reduction == "none":
+         return loss
+     raise ValueError(f"Invalid reduction: {reduction}")
+
+
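To see why MinorSFT is "less aggressive" than plain SFT: the loss saturates once the trained model is ahead of the reference, so the gradient it sends back shrinks rather than staying constant. A small sketch with made-up log-probabilities:

```python
import torch
from torchrl.objectives.llm.sft import minor_sft_loss

beta = 0.1
log_probs = torch.tensor([-10.0, -10.0, -10.0], requires_grad=True)
ref_log_probs = torch.tensor([-12.0, -10.0, -8.0])  # model ahead / even / behind

loss = minor_sft_loss(log_probs, ref_log_probs, beta, "none")
print(loss)  # -logsigmoid(0.1 * [2, 0, -2]) ~= [0.5981, 0.6931, 0.7981]

loss.sum().backward()
# Gradient is -beta * sigmoid(-beta * (log_probs - ref_log_probs)):
# smallest where the model is already ahead of the reference.
print(log_probs.grad)  # ~[-0.0450, -0.0500, -0.0550]
```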
+ class SFTLossOutput(TensorClass["nocast"]):
+     """SFT Loss Output.
+
+     Attributes:
+         loss_sft (torch.Tensor): The loss for the SFT objective.
+         loss_kl_to_ref (torch.Tensor | None): The loss for the KL divergence to the reference model.
+         kl_to_ref (torch.Tensor | None): The KL divergence to the reference model.
+
+     .. note::
+         The loss components are kept separate to allow for logging and visualization.
+         They must be summed before backpropagation. Since the non-loss components (such as
+         ``kl_to_ref``) are detached when the loss is constructed via :class:`~torchrl.objectives.llm.sft.SFTLoss`,
+         summing the :class:`~torchrl.objectives.llm.sft.SFTLossOutput` directly is a valid way of obtaining the total loss.
+
+         >>> loss_fn = SFTLoss(...)
+         >>> loss_output = loss_fn(td)
+         >>> loss = loss_output.loss_sft + loss_output.loss_kl_to_ref
+         >>> loss.backward()
+         >>> # or equivalently
+         >>> loss = loss_fn(td)
+         >>> loss.sum(reduce=True).backward()
+     """
+
+     loss_sft: torch.Tensor
+     loss_kl_to_ref: torch.Tensor | None = None
+     kl_to_ref: torch.Tensor | None = None
+
+
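To make the note above concrete: because `kl_to_ref` is stored detached, `sum(reduce=True)` folds its value into the reported total but contributes no gradient. A toy illustration with scalar components (values are made up):

```python
import torch
from torchrl.objectives.llm.sft import SFTLossOutput

output = SFTLossOutput(
    loss_sft=torch.tensor(2.0, requires_grad=True),
    loss_kl_to_ref=torch.tensor(0.3, requires_grad=True),
    kl_to_ref=torch.tensor(3.0),  # detached diagnostic: no grad flows through it
)
total = output.sum(reduce=True)  # 2.0 + 0.3 + 3.0 = 5.3
total.backward()                 # gradients reach only the two loss terms
```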
+ class SFTLoss(LossModule):
+     r"""Supervised fine-tuning loss.
+
+     Args:
+         actor_network (TensorDictModule): the actor network. Usually a :class:`~torchrl.modules.llm.TransformersWrapper` instance,
+             with `return_log_probs=True` and `from_text=True`.
+         tokenizer (`Tokenizer`): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor_network`.
+         tokenizer_kwargs (dict, optional): keyword arguments to pass to the tokenizer during :meth:`~torchrl.data.llm.chat.History.apply_chat_template`.
+             This can be used to override arguments such as the `chat_template` or `chat_template_name`.
+         reduction (Literal["mean", "sum", "none"], optional): the reduction to apply to the loss. Defaults to `"mean"`.
+         normalize_by_seq_length (bool, optional): whether to normalize the loss by the sequence length. Defaults to `True`.
+         kl_to_ref_coeff (float | None, optional): coefficient for the KL divergence to the reference model. Defaults to `None`.
+         loss_function (Literal["sft", "minor_sft"], optional): the loss function to use. Defaults to `"sft"`.
+         beta (float, optional): the beta parameter for the MinorSFT loss. Only used when `loss_function` is `"minor_sft"`.
+             Higher values of beta make the loss more aggressive (pushing the model to generate responses further from the reference model):
+
+             .. math::
+                 \text{loss} = -\log\sigma(\beta \cdot (\text{log_probs} - \text{ref_log_probs}))
+
+             Defaults to `0.1`.
+         device (torch.device | None, optional): the device to use for the loss, when tokenizing the input. Defaults to `None`.
+
+     .. note::
+         The input tensordict is expected to contain the following keys by default:
+
+         - ``("next", "history")``: the chat history
+         - ``("next", "ref_log_prob")`` (optional): reference model log probabilities, required if `kl_to_ref_coeff` is set
+
+         These keys can be customized using the ``set_keys()`` method.
+
+     .. seealso:: :class:`~torchrl.envs.llm.transforms.RetrieveLogProb` for the KL divergence computation.
+
+     References:
+         - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
+           `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
+
+     Examples:
+         >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+         >>> from torchrl.envs.llm.transforms import RetrieveLogProb
+         >>> from torchrl.modules.llm import TransformersWrapper
+         >>> from torchrl.objectives.llm.sft import SFTLoss
+         >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+         >>> from tensordict import TensorDict, lazy_stack
+         >>> import torch
+         >>>
+         >>> # Create chat data
+         >>> chats = [
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "Hello, how are you?"},
+         ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
+         ...     ],
+         ...     [
+         ...         {"role": "system", "content": "You are a helpful assistant."},
+         ...         {"role": "user", "content": "What's the weather like?"},
+         ...         {"role": "assistant", "content": "I can't check the weather for you."},
+         ...     ],
+         ... ]
+         >>> history = History.from_chats(chats)
+         >>>
+         >>> # Setup tokenizer and model
+         >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+         >>> tokenizer.pad_token = tokenizer.eos_token
+         >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
+         >>> model = OPTForCausalLM(OPTConfig()).eval()
+         >>>
+         >>> # Create training and reference policies
+         >>> policy_train = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>> policy_ref = TransformersWrapper(
+         ...     model,
+         ...     tokenizer=tokenizer,
+         ...     generate=False,
+         ...     from_text=True,
+         ...     return_log_probs=True,
+         ...     chat_template_name="qwen",
+         ... )
+         >>>
+         >>> # Create the RetrieveLogProb transform
+         >>> transform = RetrieveLogProb(
+         ...     policy_ref,
+         ...     assistant_only=True,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     tokenizer=tokenizer,
+         ... )
+         >>>
+         >>> # Prepare data
+         >>> text = history[:, :-1].apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
+         ... )
+         >>> text_response = history.apply_chat_template(
+         ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
+         ... )
+         >>> text_response = [
+         ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
+         ... ]
+         >>> td = TensorDict(
+         ...     text=text,
+         ...     text_response=text_response,
+         ...     history=history,
+         ...     next=TensorDict(
+         ...         reward=torch.randn(2, 1),
+         ...         done=torch.zeros(2, dtype=torch.bool),
+         ...         history=history,
+         ...     ),
+         ...     batch_size=(2,),
+         ... )
+         >>> data = lazy_stack(list(td.unbind(0)))
+         >>>
+         >>> # Apply the transform to get reference log probabilities
+         >>> data = transform(data)
+         >>> assert "ref_log_prob" in data["next"].keys()
+         >>>
+         >>> # Use with SFTLoss for KL regularization
+         >>> loss = SFTLoss(
+         ...     actor_network=policy_train,
+         ...     tokenizer=tokenizer,
+         ...     reduction="mean",
+         ...     normalize_by_seq_length=True,
+         ...     kl_to_ref_coeff=0.1,
+         ...     tokenizer_kwargs={"chat_template_name": "qwen"},
+         ...     loss_function="sft",
+         ... )
+         >>> loss_vals = loss(data)
+         >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
+         >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
+
+     """
+
+     @dataclass
+     class _AcceptedKeys:
+         """Maintains default values for all configurable tensordict keys.
+
+         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+         default values.
+
+         Attributes:
+             history (NestedKey): The input tensordict key where the chat history is expected.
+                 Defaults to ``("next", "history")``.
+             ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected.
+                 Only used when `kl_to_ref_coeff` is set. Defaults to ``("next", "ref_log_prob")``.
+             log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written.
+                 Defaults to ``"log_probs"``.
+         """
+
+         history: NestedKey = ("next", "history")
+         ref_log_prob: NestedKey = ("next", "ref_log_prob")
+         log_probs: NestedKey = "log_probs"
+
+     default_keys = _AcceptedKeys
+     tensor_keys: _AcceptedKeys
+
+     def __init__(
+         self,
+         actor_network: TensorDictModule | TransformersWrapper,
+         tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
+         tokenizer_kwargs: dict | None = None,
+         reduction: Literal["mean", "sum", "none"] = "mean",
+         normalize_by_seq_length: bool = True,
+         kl_to_ref_coeff: float | None = None,
+         loss_function: Literal["sft", "minor_sft"] = "sft",
+         beta: float = 0.1,
+         device: torch.device | None = None,
+     ):
+         super().__init__()
+         self.in_keys = []
+         self.actor_network = actor_network
+         if tokenizer is None:
+             tokenizer = actor_network.tokenizer
+         self.tokenizer = tokenizer
+         if tokenizer_kwargs is None:
+             tokenizer_kwargs = {}
+         if tokenizer is None:
+             raise ValueError("Tokenizer must be provided.")
+         tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+         tokenizer_kwargs.setdefault("tokenize", True)
+         tokenizer_kwargs.setdefault("return_tensors", "pt")
+         tokenizer_kwargs.setdefault("padding", False)
+         tokenizer_kwargs.setdefault("add_generation_prompt", False)
+         self.tokenizer_kwargs = tokenizer_kwargs
+         self.reduction = reduction
+         self.normalize_by_seq_length = normalize_by_seq_length
+         self.kl_to_ref_coeff = kl_to_ref_coeff
+         self.loss_function = loss_function
+         if self.loss_function == "minor_sft" and kl_to_ref_coeff:
+             warnings.warn(
+                 "kl_to_ref_coeff should not be set when using minor_sft loss, as KL regularization is implicit. Setting kl_to_ref_coeff to 0.0."
+             )
+             self.kl_to_ref_coeff = 0.0
+         self.beta = beta
+         self._set_in_keys()
+         self.device = device
+
+     def _set_in_keys(self) -> None:
+         """Sets the input keys for the loss module."""
+         in_keys = [self.tensor_keys.history]
+         if self.kl_to_ref_coeff is not None or self.loss_function == "minor_sft":
+             in_keys.append(self.tensor_keys.ref_log_prob)
+         self.in_keys = in_keys
+         self.out_keys = []  # Loss modules typically don't have out_keys
+
+     def _kl_to_ref(
+         self,
+         cur_log_prob: list[torch.Tensor],
+         ref_log_prob: list[torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         """Compute KL divergence to reference model.
+
+         Args:
+             cur_log_prob (list[torch.Tensor]): Log probabilities from the current model. Each element must have shape [T], where T is the number of tokens in the assistant response.
+             ref_log_prob (list[torch.Tensor]): Log probabilities from the reference model. Each element must have shape [T], where T is the number of tokens in the assistant response.
+
+         Returns:
+             tuple[torch.Tensor, torch.Tensor]: (KL loss term, KL penalty for logging)
+         """
+         # Concatenate the per-sample (already masked) log-probs into flat tensors
+         ref_log_prob = torch.cat(ref_log_prob)
+         cur_log_prob = torch.cat(cur_log_prob)
+         if cur_log_prob.shape != ref_log_prob.shape:
+             raise ValueError(
+                 f"Current log probabilities and reference log probabilities have different shapes: {cur_log_prob.shape=} vs {ref_log_prob.shape=}."
+             )
+         # Compute KL using the same approximation as GRPO: exp(d) - 1 - d, with d = ref - cur
+         diff = ref_log_prob - cur_log_prob
+         kl_penalty = (diff.expm1() - diff).mean()
+         return self.kl_to_ref_coeff * kl_penalty, kl_penalty
+
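The `diff.expm1() - diff` line is the low-variance `exp(d) - 1 - d` KL estimator (Schulman's "k3", also used in GRPO), which is unbiased for KL(current ‖ reference) when tokens are sampled from the current model. A quick sanity check with two arbitrary categorical distributions:

```python
import torch

p = torch.distributions.Categorical(probs=torch.tensor([0.6, 0.3, 0.1]))  # "current"
q = torch.distributions.Categorical(probs=torch.tensor([0.4, 0.4, 0.2]))  # "reference"

x = p.sample((100_000,))           # tokens drawn from the current model
d = q.log_prob(x) - p.log_prob(x)  # ref_log_prob - cur_log_prob, as in _kl_to_ref
k3 = (d.expm1() - d).mean()

print(k3, torch.distributions.kl_divergence(p, q))  # both ~0.0877
```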
+     def forward(self, tensordict: TensorDictBase) -> SFTLossOutput:
+         # Gather history
+         history: History = tensordict[self.tensor_keys.history]
+
+         # Apply tokenizer to history and gather mask
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext():
+             token_struct = history.apply_chat_template(
+                 tokenizer=self.tokenizer, **self.tokenizer_kwargs
+             )
+         if "assistant_masks" not in token_struct:
+             raise ValueError(
+                 f"Assistant masks are not present in the token structure: {token_struct=}."
+             )
+         assistant_masks = token_struct.get(
+             "assistant_masks",
+             as_list=True,
+         )
+         assistant_masks = [mask.bool() for mask in assistant_masks]
+         attention_mask = token_struct.get("attention_mask", as_list=True)
+         attention_mask = [mask.bool() for mask in attention_mask]
+         assistant_masks = [
+             mask & a_mask for mask, a_mask in zip(assistant_masks, attention_mask)
+         ]
+
+         if not all(mask.any(-1).all() for mask in assistant_masks):
+             raise ValueError("Some inputs have no valid assistant masks.")
+         input_loss = tensordict.select(self.tensor_keys.history)
+         if (
+             isinstance(self.tensor_keys.history, tuple)
+             and self.tensor_keys.history[0] == "next"
+         ):
+             input_loss = input_loss["next"]
+
+         with torch.device(
+             self.device
+         ) if self.device is not None else contextlib.nullcontext():
+             output_loss = self.actor_network(input_loss)
+
+         # Get log-probs
+         log_probs = output_loss.get(
+             self.tensor_keys.log_probs,
+             as_list=True,
+         )
+         # Apply mask
+         if not all(
+             mask.shape == lp.shape
+             for mask, lp in _zip_strict(assistant_masks, log_probs)
+         ):
+             raise ValueError(
+                 f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs {[lp.shape for lp in log_probs]}. Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
+             )
+
+         log_probs_masked = [
+             lp.masked_fill(~mask, 0.0)
+             for lp, mask in _zip_strict(log_probs, assistant_masks)
+         ]
+
+         # Sum log-probs over tokens, then optionally normalize by sequence length
+         summed_log_probs = torch.stack(
+             [lp.sum(tensordict.ndim - 1) for lp in log_probs_masked]
+         )
+         # Sequence lengths = number of assistant tokens per sample
+         seq_lengths = torch.stack(
+             [mask.sum(tensordict.ndim - 1) for mask in assistant_masks]
+         )
+         if self.normalize_by_seq_length:
+             summed_log_probs = summed_log_probs / seq_lengths.clamp(min=1)
+
+         # Compute main loss
+         if self.loss_function == "sft":
+             loss = sft_loss(summed_log_probs, self.reduction)
+             # Add KL divergence loss if a reference model is provided
+             if self.kl_to_ref_coeff is not None:
+                 ref_log_probs = tensordict.get(
+                     self.tensor_keys.ref_log_prob,
+                     default=None,
+                     as_list=True,
+                 )
+                 if ref_log_probs is None:
+                     raise ValueError(
+                         "Reference log probs not found in tensordict but kl_to_ref_coeff was set"
+                     )
+
+                 loss_kl, kl_penalty = self._kl_to_ref(
+                     [lp[mask] for lp, mask in _zip_strict(log_probs, assistant_masks)],
+                     ref_log_probs,
+                 )
+                 output = SFTLossOutput(
+                     loss_sft=loss,
+                     loss_kl_to_ref=loss_kl,
+                     kl_to_ref=kl_penalty.detach(),
+                 )
+             else:
+                 output = SFTLossOutput(loss_sft=loss)
+         elif self.loss_function == "minor_sft":
+             ref_log_probs = tensordict.get(self.tensor_keys.ref_log_prob, as_list=True)
+             if ref_log_probs is None:
+                 raise ValueError(
+                     f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict but loss_function is 'minor_sft'"
+                 )
+
+             # Re-sum the reference log-probs: they are stored per token, not per sequence
+             summed_ref_log_probs = torch.stack([lp.sum() for lp in ref_log_probs]).to(
+                 summed_log_probs.device
+             )
+             if self.normalize_by_seq_length:
+                 summed_ref_log_probs = summed_ref_log_probs / seq_lengths.clamp(min=1)
+             loss = minor_sft_loss(
+                 summed_log_probs, summed_ref_log_probs, self.beta, self.reduction
+             )
+             if self.kl_to_ref_coeff is not None:
+                 with torch.no_grad():
+                     loss_kl, kl_penalty = self._kl_to_ref(
+                         [
+                             lp[mask]
+                             for lp, mask in _zip_strict(log_probs, assistant_masks)
+                         ],
+                         ref_log_probs,
+                     )
+                 output = SFTLossOutput(
+                     loss_sft=loss,
+                     loss_kl_to_ref=loss_kl,
+                     kl_to_ref=kl_penalty.detach(),
+                 )
+             else:
+                 output = SFTLossOutput(loss_sft=loss)
+         else:
+             raise ValueError(f"Invalid loss function: {self.loss_function}")
+
+         return output
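Everything in `forward` hinges on `apply_chat_template` returning an `assistant_masks` entry, which requires a chat template with `{% generation %}` markers (torchrl's ChatML-style template provides them). A rough sketch of how that mask is produced at the raw `transformers` level, using a deliberately minimal, hypothetical template rather than torchrl's own:

```python
from transformers import AutoTokenizer

# Hypothetical template: only assistant content is wrapped in generation markers.
template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{% else %}{{ message['content'] }}{% endif %}"
    "{% endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
out = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}],
    chat_template=template,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
print(out["assistant_masks"])  # 1 on assistant tokens, 0 elsewhere
```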
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
- __version__ = '2025.6.20'
- git_version = 'e1e15d692a6df69bfdc80e85f45d37a4b967e625'
+ __version__ = '2025.6.21'
+ git_version = '77dbc6c9ffbce3d2ce3f26b659355cd46d8132c3'
torchrl_nightly-2025.6.21.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torchrl-nightly
- Version: 2025.6.20
+ Version: 2025.6.21
  Home-page: https://github.com/pytorch/rl
  Author: torchrl contributors
  Author-email: vmoens@fb.com
torchrl_nightly-2025.6.21.dist-info/RECORD CHANGED
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=l9zlK7Nm5bT7P_onQx-hZeIGzKKyCFm1PFk
  build_tools/setup_helpers/extension.py,sha256=ihV8jz8kqOvpqzuD006XqF1oNX5ukKGlwIOJRb1Vd-o,6075
  torchrl/__init__.py,sha256=76lKYwYKmAKORhyVt2tURmYAIRTifxxO3gWsskrHAXU,3054
  torchrl/_extension.py,sha256=x6Nqj2brF3VhlEwxmNA2fYbmpxq1HHGrHMnP0YnQwdc,2412
- torchrl/_torchrl.cp312-win_amd64.pyd,sha256=us0SQpPTkMVaK0lkwlpmvnN2kGx_rMYBsxmhFXE_Mv8,429568
+ torchrl/_torchrl.cp312-win_amd64.pyd,sha256=vo4yMM0Omnpg4cHCwuWbHKHclWaa2pwnu_uX2wRMld4,429568
  torchrl/_utils.py,sha256=2N35rdD65U1khMi5gVIz8-nMjlZsoVq0kCiQftVRSxw,42297
- torchrl/version.py,sha256=HnpF56pzVs2b-GMFYIzT-B7z9-9fXZZkDTrkv-8dp-M,85
+ torchrl/version.py,sha256=qrSt-wV9IZ8YXCgIg42ODdVhUj3c0MW9OPYRWulbkGY,85
  torchrl/collectors/__init__.py,sha256=LzTyfxmkNGPSa5-3rS5unQK7HfT5ZEdr2NV291rAOlU,832
- torchrl/collectors/collectors.py,sha256=nWwPcWWdeEyj3Fe0bdFqN69Yw_jsjPhD49Iwc8H61D0,181298
+ torchrl/collectors/collectors.py,sha256=Pz6VYYrekjBiVBQiyzp6zIyZrBdjSv4-FqlfrGYQz3E,181469
  torchrl/collectors/utils.py,sha256=aBmBLpphhfplqQjRCyn1jtWWJ-Wtc7TWvM0rOBN8SsE,11579
  torchrl/collectors/weight_update.py,sha256=Ydq5nJSTV3Q1uqLtJ_1Nj1JB5rwHwrG5StaLxymWFV4,21572
  torchrl/collectors/distributed/__init__.py,sha256=cKDWdNlwx2LoJkTwf-DKUXbq3Y-0Z1DctPYPcdgOSU0,730
@@ -18,12 +18,12 @@ torchrl/collectors/distributed/rpc.py,sha256=xta5tptC0mLlIY_AecLrARvFBYh7nMbclrr
  torchrl/collectors/distributed/sync.py,sha256=zjp0HEEcSMaDzq8xndoBWvyqYCdf7hp8urYRRiJP2GI,27912
  torchrl/collectors/distributed/utils.py,sha256=eY6M-vLCSzyACHRNBx5bHcieWsZfLg7DfNKGIv0IgHI,6625
  torchrl/collectors/llm/__init__.py,sha256=u03aQ97C3sb5-C0-s2tMBAGGs3kJTfZUSse29fHDkIk,365
- torchrl/collectors/llm/base.py,sha256=fs7YcEsGF8WtuS1my_4dyzAk94Yw4J25NnZxwnB7AzE,20842
- torchrl/collectors/llm/ray_collector.py,sha256=lxtHFkgmnMCNIUa9E3iDmpQfTs8tCuuZmydF6SJcnnM,11208
+ torchrl/collectors/llm/base.py,sha256=Wxdo4drsMk_i5u5DzDlikY4j5-TM9f6Ac4xjB6wJgPw,21132
+ torchrl/collectors/llm/ray_collector.py,sha256=nk-i61ZAsYkDNZW2Y7vcDhddjkUyyIuiDYjU9iWYklE,11364
  torchrl/collectors/llm/utils.py,sha256=GnDY2cTu4XEdwqqhFCP4QWfS2tsgaLTy8nwpIaTEsQI,1184
  torchrl/collectors/llm/weight_update/__init__.py,sha256=ngbL_sPfXh8FMM3r_j0B9QEP_jQIVSOa8pZVouHg9ec,281
  torchrl/collectors/llm/weight_update/vllm.py,sha256=4kRlEBHb6093d9lkKVIqU8ZwiPoCFtmVVaADuhxKLL4,11571
- torchrl/data/__init__.py,sha256=LEjVJd1OcHNz60lnoDMwoFtp-ejbqhWhhkydH_bgCfE,4974
+ torchrl/data/__init__.py,sha256=h6ZHGWzvWDfyw6tgo69gld5C_Kgg-T4DP-EVv1hy0Xk,5026
  torchrl/data/rlhf.py,sha256=_ENSvNe84snnFQG0jlTtOI419nIYtbBHvAw-pdFMiSs,1002
  torchrl/data/tensor_specs.py,sha256=PF-sta3dHy0UaDu26FS5-ZVlpAsagTUUMLmYHs5PQhk,254762
  torchrl/data/utils.py,sha256=tXBPxl5VHqPUfJF1VLqURmb066zDd9lipRDER4R1FY8,12444
@@ -39,12 +39,13 @@ torchrl/data/datasets/openx.py,sha256=0p2H3phnvsUgrFFfQulTyGEDcUrvXiFp-p18uFk8Xk
  torchrl/data/datasets/roboset.py,sha256=sLdDknyPj7f2NF5z5EbzEQ9QhD03VBKHEXXo_wxSRc4,17011
  torchrl/data/datasets/utils.py,sha256=tRZkarWl-BX_fnGEVNRt6Fjo0wmyKCNAuVmw5Le-0C4,350
  torchrl/data/datasets/vd4rl.py,sha256=YFjXvP-QGNzF7UWNKkGMKPFthcB0I8v6sJc7oESegYA,18694
- torchrl/data/llm/__init__.py,sha256=ZWoisCgT7JbAAHgcyrlp8o_q2mhgrgU4k-0kikLsTh8,941
- torchrl/data/llm/chat.py,sha256=z-PujCw0tVZXUktQqf8NUevKCSjL56_qKiouyse6nVo,31554
+ torchrl/data/llm/__init__.py,sha256=X86bYW_uNAwKJcxK2AVQpJo56tDejtXSDFtOyZMNhHQ,1006
+ torchrl/data/llm/chat.py,sha256=akIL6wj2QtpdWlFvOcRfPcrk7HE5NoSpgZxB8gtdOos,33962
  torchrl/data/llm/common.py,sha256=3Gb8sMojtNss6wi6hKGSIUDAwK1PC_8ve9W-bHCPjAk,2181
  torchrl/data/llm/dataset.py,sha256=GDuzflBq2ThgYn_V4bOr_1MHOhESnQ5jX1Wlcw69lfM,21194
  torchrl/data/llm/prompt.py,sha256=ikHWafhTIoCONCpuMHwIuGfpnPSpg5drZQTxTCmygQE,8578
  torchrl/data/llm/reward.py,sha256=QW1HWpNRORd3InwWLg-hAhjTlPqX4ffzAkYHEz0jQxo,8629
+ torchrl/data/llm/topk.py,sha256=4MTxYTTdfSBM5vxDHnleY7FatanzHgneXs8rjgKwwqQ,8539
  torchrl/data/llm/utils.py,sha256=K2NQoEhBC6VWowsMeDHu2Q8vbg3ZPEWBBN6z4qifiNM,24143
  torchrl/data/map/__init__.py,sha256=bON0vqCksU7FPoWNqiNcdl60t7yWUh9SdLhNtglj7jI,576
  torchrl/data/map/hash.py,sha256=XRYdaFHQUm87fL9pWjhvi2LeZVaqJsASkCU-G_Gus8s,7437
@@ -56,11 +57,11 @@ torchrl/data/postprocs/__init__.py,sha256=fOyX5OMaDb5HGrQbn9W72_QwncNdh6l3DkVSqR
  torchrl/data/postprocs/postprocs.py,sha256=dpXOKWlhdKy4Um7HdzRKe42PJ_Q1jHC7AX5plR9AIiw,15509
  torchrl/data/replay_buffers/__init__.py,sha256=oINoSWKO3Ku6YIBF-0KnbVLZwelZbANN4nLU4q6Mir0,2455
  torchrl/data/replay_buffers/checkpointers.py,sha256=eizAw4W0tQ2EWgfx6-EUV_3EuZMcZVdoahLoux15Nr4,15186
- torchrl/data/replay_buffers/ray_buffer.py,sha256=_8ZFBCJKgEcPWSjMoMGoJzUfN5eF9UD79dXK6Lh7dhc,9310
- torchrl/data/replay_buffers/replay_buffers.py,sha256=fO9SyUjNl-VUNch3FkwVKIRPkprAcypIcpbQpvCleUI,90708
- torchrl/data/replay_buffers/samplers.py,sha256=tBDi6GN2oe9WrJpnbqHmAAkNT0Du0Ce7Ldw9wWZ_7Jk,107888
+ torchrl/data/replay_buffers/ray_buffer.py,sha256=p8EkiXOP4EVMkkpjOyje7wiBfgbWOB0xhEJzDdelAxc,10135
+ torchrl/data/replay_buffers/replay_buffers.py,sha256=YzNV543zDpvENbUnqjjghkHHq6IwyRms1I3DXl-ayq4,92567
+ torchrl/data/replay_buffers/samplers.py,sha256=bFP8j3BahHULASAjeIGtxJX36GRsg3yBCNFu7vR6Zdo,112834
  torchrl/data/replay_buffers/scheduler.py,sha256=cGm4LZcZ2lo8azDMWKGTdhWApxjZFh0KfynApxAkVK4,10416
- torchrl/data/replay_buffers/storages.py,sha256=bP_pak6fi8X57OkXlQnbMX-ze7-VilRYKOTPr8KGkuA,61146
+ torchrl/data/replay_buffers/storages.py,sha256=VdEYOQ29FWGwDeHfLafZuobMlmuagxHVHnntKnU-yX4,62298
  torchrl/data/replay_buffers/utils.py,sha256=vlGfyHVKUAMKBR0l7fJM9NI47ZinS18Qzf8lpwoo6pI,39644
  torchrl/data/replay_buffers/writers.py,sha256=-aI6Y28oisuFDutMVlPp4e8wTe6x0wlY0MY1OUKHl4Q,28466
  torchrl/envs/__init__.py,sha256=2eVr8StUSMiNd-IoD5BQAFFuV10pAtO926b6QzRzB_M,6082
@@ -96,8 +97,8 @@ torchrl/envs/libs/smacv2.py,sha256=pr03oGHE2G_fc86qHeSQjSz3S6IH_l2hX0J2umb020M,2
  torchrl/envs/libs/unity_mlagents.py,sha256=vszCYjEX0S9AmIwLvGsoqc0Jr7jvlBAqZ1HQ1uqesjM,50558
  torchrl/envs/libs/utils.py,sha256=Ce8nAYc2MQOBTYCV17Yswk98pg3PStnaGPFVW2jqARQ,5354
  torchrl/envs/libs/vmas.py,sha256=giTORg2AqYzyjrazdD94fD2dNYwX7qe5TFnr-E1mjIg,37140
- torchrl/envs/llm/__init__.py,sha256=_srpJ1x42TqTbsMvJpTObVeyAlRhvG6XHpbRqk8qYi4,1274
- torchrl/envs/llm/chat.py,sha256=2ZN5fef0cYUGwLPHbDK3YNb9bI2pewIApQKzong-LKA,17993
+ torchrl/envs/llm/__init__.py,sha256=Iz5HtLoVy8O4u1mrPmyql4G8SU9S-MCinP_Gh8sbUWo,1320
+ torchrl/envs/llm/chat.py,sha256=YvADxo11RKkjD06rBvbbljch3Jb_H4snaBcgkU2Q-7w,18171
  torchrl/envs/llm/envs.py,sha256=wphbzLwDKYO_OTV63WYW4iTK5Ek4vmb1zNv5gehzodY,35450
  torchrl/envs/llm/datasets/__init__.py,sha256=6-x0WlKD7lpMVLKA4W1AktvgUs6adMuaGAqYYhgQ_hk,490
  torchrl/envs/llm/datasets/gsm8k.py,sha256=MfCFu0U7uetDtLdzUdvqX4rENXPsL8msnkMT98Q29jE,15624
@@ -105,18 +106,18 @@ torchrl/envs/llm/datasets/ifeval.py,sha256=sJ4bvXEWBzzNnDDbKkj6yz_1zemDStVFNfxqo
  torchrl/envs/llm/libs/__init__.py,sha256=zvUe6oe3pjZwGefV-_x4MAC6K89TMqxh3TZs5s3ADkI,274
  torchrl/envs/llm/libs/mlgym.py,sha256=TMaoV9P5w5EGgBSmLiw42_DOyKEh7ZGf3mDf-LaZ9W0,32237
  torchrl/envs/llm/reward/__init__.py,sha256=KYNJxyDOe2mZkjyH4CSuQ8qM0_Zu3EAaIGocYhLduPQ,380
- torchrl/envs/llm/reward/gsm8k.py,sha256=gkDQu0Xa8r4PEyuhH146M3dnlJbI7XAVzThIt3Opm1M,7899
+ torchrl/envs/llm/reward/gsm8k.py,sha256=TW2lACMLXHRlcTTRfTcFTLl7NIJA0TNh6qfSJiC52QI,8066
  torchrl/envs/llm/reward/ifeval/__init__.py,sha256=vvh7JSUQaEiMjNeMeJvWlcFb2-6_J1LfM6l4mENn4Zg,324
  torchrl/envs/llm/reward/ifeval/_instructions.py,sha256=jlNvIO3dykk8fBFXC35PQSDJ9vLF3knS-ywq2ILfF00,63362
  torchrl/envs/llm/reward/ifeval/_instructions_main.py,sha256=DEc7QqfujGxYvqcm2y_zPasBqB7FgSfXRt3QQ4HQUz0,4244
  torchrl/envs/llm/reward/ifeval/_instructions_registry.py,sha256=bY8R51RgjKJYiim67j5IXSfYhtWtvZrRF61yuqi0Tzs,3914
  torchrl/envs/llm/reward/ifeval/_instructions_util.py,sha256=63ZJbqUKaqMA_SDhnYT7VppULbible8udHGihiamxKc,27719
  torchrl/envs/llm/reward/ifeval/_scorer.py,sha256=iv-316dBYlz4fz6WUtzP7151y4xEwuwOq-Wf7Qazgmc,14928
- torchrl/envs/llm/transforms/__init__.py,sha256=P6-KVptxmsweLvwmnTl5dkfLS9JNAizwFPWOl8Hr3So,773
+ torchrl/envs/llm/transforms/__init__.py,sha256=BnVW7WVCYPlaNPd4cEyXIUI1Qfd7YQkZVsXvj5q5UVw,814
  torchrl/envs/llm/transforms/browser.py,sha256=d0JIUZ3TfgmqBcci5ihzzTqZA9KeTrs1iCProRWQQK8,10715
  torchrl/envs/llm/transforms/dataloading.py,sha256=Zl--I6bT2AqWDZmM6RMQq4ds3b1PFGqilAaIbXBJuNc,25054
  torchrl/envs/llm/transforms/format.py,sha256=tME390wkG0h2V5DAWHZa7EhJ5Or-6cga6AIjxPuy1l8,2592
- torchrl/envs/llm/transforms/kl.py,sha256=T-ab6P2PCpue0N4Pkz3ExttcEkTuPZylLlK5MGKAXuY,11992
+ torchrl/envs/llm/transforms/kl.py,sha256=GrANICxnF-FC_yVkdO4EU66bj7aTfkB8OXG7wA2uxDo,23251
  torchrl/envs/llm/transforms/policy_version.py,sha256=fko23hsQrAMmUqFwKjV_CQVavDhixXFUeVE0lJBASOA,7080
  torchrl/envs/llm/transforms/tokenizer.py,sha256=Nest15FD1iPLNZuw0rAobyb7n3ce6KFX00qfN3dUE2M,14274
  torchrl/envs/llm/transforms/tools.py,sha256=WoNgUN1Me4mhbqH5ef9XNc5qKXE0H-g3ZobjEKDM_kw,30308
@@ -147,8 +148,8 @@ torchrl/modules/llm/backends/__init__.py,sha256=ABKK4mJeRtoLXEqfnMvIuiovs7VJoCxn
  torchrl/modules/llm/backends/vllm.py,sha256=5P78jEtAIytgYHzEkOrg-wwqh1ryhiMVy4M_AxNQ9JQ,9649
  torchrl/modules/llm/policies/__init__.py,sha256=CK7VEdfShjkeNu_-TmYOobrCEjKTIb2aw2hE6s5RBNs,439
  torchrl/modules/llm/policies/common.py,sha256=GXzmVRa0SJvQ8iPMeuNjwV7EaZDOPrVy5k_LlJ10QXY,3111
- torchrl/modules/llm/policies/transformers_wrapper.py,sha256=yylsatdc63hodGcjhfO79CAA2-kQmuvRqiFT38mbPCQ,23177
- torchrl/modules/llm/policies/vllm_wrapper.py,sha256=iDG0593FfZfzf8-JiQQPANQ6O-WnF8ifQJw1-O4vDHs,30488
+ torchrl/modules/llm/policies/transformers_wrapper.py,sha256=0wDYGpC1T5T8ZVyZOi5S2Qoa2wtg5Oix2W9Y_bKMKs8,25988
+ torchrl/modules/llm/policies/vllm_wrapper.py,sha256=LBoFbrTyGiEiilYVmU7Ze-WBpwYinvAVIEtbU1QKajw,32226
  torchrl/modules/models/__init__.py,sha256=Y1XTkBOB5EMj6IaMru6V3CDwFLnkUtxzsHcqzeqq_4Y,1829
  torchrl/modules/models/batchrenorm.py,sha256=bR4ZhaJ5E1cSK5o8L2dNX5KVLIb-bgrYxcq6yhx0I1A,4869
  torchrl/modules/models/decision_transformer.py,sha256=ANFTOm3k9_3Uv1vKGdXumRy3meBPnDdT8HqhVvJ2RCo,6783
@@ -194,8 +195,9 @@ torchrl/objectives/sac.py,sha256=gKOgCU399miKgpgu7Bmzs1bkIF8JTm_lybHn8V4wDuk,654
  torchrl/objectives/td3.py,sha256=Rq2q5gXo3AMuHm2OjRZvpfvKsAl1lIK5ALh2_sZM1ZE,23743
  torchrl/objectives/td3_bc.py,sha256=1pjB8mjCT2CLvQzjnqwAfZoc7yhjMB9UQjuJ5wZfTUY,26558
  torchrl/objectives/utils.py,sha256=Vrjj07SjMYANfFyn3n1xS7izBIs5Mq9mCvyITMzifZs,24705
- torchrl/objectives/llm/__init__.py,sha256=hlj2mKgz0BJR1ob1uObF8IkjyrDBnF24fXQf2zKEqaw,337
+ torchrl/objectives/llm/__init__.py,sha256=tZmIz3rkeclw3MzJoOWEs2gkewjx2USKrKJbWdyiiaQ,406
  torchrl/objectives/llm/grpo.py,sha256=nT3Ukjaz7nZZnkS5tnb-pDnRzvZ3L1edpcNCzi5WZRs,17164
+ torchrl/objectives/llm/sft.py,sha256=9fzX9Qo0Goyjxuwca6eLN1PUQ24F0LZGRpjzTDLFfs4,20572
  torchrl/objectives/multiagent/__init__.py,sha256=5uebDe5KrvlzeYV_BSd5vdmfruJQYMeDVVbU4iHErEg,245
  torchrl/objectives/multiagent/qmixer.py,sha256=yttOxc5FNylKw4iMnYSG1qO8EbHvx8imAhxNxW9_iLw,17362
  torchrl/objectives/value/__init__.py,sha256=QkSnenYVqe_3FVtwGr_D86N52unnpBvRXfcC5JFTBOw,589
@@ -221,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=rWKure02dl8hLBzLUs-jhNJV8L3QHWtFbl3HbX
  torchrl/trainers/helpers/models.py,sha256=VujBq9H92sEzpCtU1iTrJQNlwvyOO-Rho4bzsMonX6s,22465
  torchrl/trainers/helpers/replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
  torchrl/trainers/helpers/trainers.py,sha256=hB1FtHtP-S0PBQ4LF6WPy37caaLpacyaLThj1BNl5Ho,12372
- torchrl_nightly-2025.6.20.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
- torchrl_nightly-2025.6.20.dist-info/METADATA,sha256=jj9lUe-poxEhkdNE2J7ExqB1sKG47hV8vCKY1Bn2b_k,40113
- torchrl_nightly-2025.6.20.dist-info/WHEEL,sha256=VjOakRrFjQDaJ3SL0TIqKBZGtb43B2QJnIB-eW3qItk,101
- torchrl_nightly-2025.6.20.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
- torchrl_nightly-2025.6.20.dist-info/RECORD,,
+ torchrl_nightly-2025.6.21.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
+ torchrl_nightly-2025.6.21.dist-info/METADATA,sha256=aopN2qXYW22ie8bRVw4tzRvtX_rS99CWnQruvvzukkY,40113
+ torchrl_nightly-2025.6.21.dist-info/WHEEL,sha256=VjOakRrFjQDaJ3SL0TIqKBZGtb43B2QJnIB-eW3qItk,101
+ torchrl_nightly-2025.6.21.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+ torchrl_nightly-2025.6.21.dist-info/RECORD,,