torchrl-nightly 2025.6.20__cp313-cp313-win_amd64.whl → 2025.6.22__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cp313-win_amd64.pyd +0 -0
- torchrl/collectors/collectors.py +8 -5
- torchrl/collectors/llm/base.py +13 -6
- torchrl/collectors/llm/ray_collector.py +3 -0
- torchrl/data/__init__.py +2 -0
- torchrl/data/llm/__init__.py +2 -0
- torchrl/data/llm/chat.py +59 -8
- torchrl/data/llm/topk.py +186 -0
- torchrl/data/replay_buffers/ray_buffer.py +15 -1
- torchrl/data/replay_buffers/replay_buffers.py +50 -11
- torchrl/data/replay_buffers/samplers.py +98 -21
- torchrl/data/replay_buffers/storages.py +29 -2
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +4 -1
- torchrl/envs/llm/reward/gsm8k.py +15 -8
- torchrl/envs/llm/transforms/__init__.py +2 -1
- torchrl/envs/llm/transforms/kl.py +240 -4
- torchrl/modules/llm/policies/transformers_wrapper.py +71 -15
- torchrl/modules/llm/policies/vllm_wrapper.py +38 -5
- torchrl/objectives/llm/__init__.py +2 -1
- torchrl/objectives/llm/sft.py +465 -0
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/RECORD +27 -25
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/LICENSE +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.6.20.dist-info → torchrl_nightly-2025.6.22.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,465 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
#
|
3
|
+
# This source code is licensed under the MIT license found in the
|
4
|
+
# LICENSE file in the root directory of this source tree.
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import contextlib
|
8
|
+
import warnings
|
9
|
+
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from typing import Literal
|
12
|
+
|
13
|
+
import torch
|
14
|
+
from tensordict import NestedKey, TensorClass, TensorDictBase
|
15
|
+
from tensordict.nn import TensorDictModule
|
16
|
+
from tensordict.utils import _zip_strict
|
17
|
+
from torchrl.data import History
|
18
|
+
from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
|
19
|
+
from torchrl.objectives.common import LossModule
|
20
|
+
|
21
|
+
|
22
|
+
def sft_loss(summed_log_probs: torch.Tensor, reduction: str) -> torch.Tensor:
    """Turn per-sequence log-probabilities into a supervised fine-tuning loss.

    The loss is the negated log-probability, aggregated as requested.

    Args:
        summed_log_probs (torch.Tensor): log-probabilities, one entry per sequence.
        reduction (str): ``"mean"``, ``"sum"`` or ``"none"``.

    Returns:
        The negated mean, negated sum, or element-wise negation of the input.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    if reduction == "mean":
        return -summed_log_probs.mean()
    if reduction == "sum":
        return -summed_log_probs.sum()
    if reduction == "none":
        return -summed_log_probs
    raise ValueError(f"Invalid reduction: {reduction}.")
|
33
|
+
|
34
|
+
|
35
|
+
def minor_sft_loss(
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    beta: float,
    reduction: str,
) -> torch.Tensor:
    """Compute the MinorSFT loss.

    A DPO-inspired objective meant to be less aggressive than plain SFT. The
    per-element value is ``-log_sigmoid(beta * (log_probs - ref_log_probs))``.

    Args:
        log_probs (torch.Tensor): log-probabilities of the model being trained.
        ref_log_probs (torch.Tensor): log-probabilities of the reference model.
            Must have the same shape as ``log_probs``.
        beta (float): the DPO ``beta`` scaling applied to the log-ratio.
        reduction (str): ``"mean"``, ``"sum"`` or ``"none"``.

    Returns:
        The MinorSFT loss, reduced according to ``reduction``.

    Raises:
        ValueError: if the two inputs differ in shape, or if ``reduction`` is
            not one of the supported values.

    References:
        - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
          `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_
    """
    shapes_match = log_probs.shape == ref_log_probs.shape
    if not shapes_match:
        raise ValueError(
            f"Current log probabilities and reference log probabilities have different shapes: {log_probs.shape=} vs {ref_log_probs.shape=}."
        )

    # Scaled log-ratio between the trained policy and the reference policy.
    scaled_log_ratio = beta * (log_probs - ref_log_probs)
    per_element = -torch.nn.functional.logsigmoid(scaled_log_ratio)

    if reduction == "mean":
        result = per_element.mean()
    elif reduction == "sum":
        result = per_element.sum()
    elif reduction == "none":
        result = per_element
    else:
        raise ValueError(f"Invalid reduction: {reduction}")
    return result
|
71
|
+
|
72
|
+
|
73
|
+
class SFTLossOutput(TensorClass["nocast"]):
    """Container for the values produced by the SFT objective.

    Attributes:
        loss_sft (torch.Tensor): the supervised fine-tuning loss term.
        loss_kl_to_ref (torch.Tensor | None): the loss term penalising the KL divergence
            to the reference model. ``None`` when no KL term is computed.
        kl_to_ref (torch.Tensor | None): the KL divergence to the reference model,
            kept for logging. ``None`` when no KL term is computed.

    .. note::
        Loss terms are stored separately so that each can be logged and visualized.
        They must be added together before calling ``backward``. When the output is
        built by :class:`~torchrl.objectives.llm.sft.SFTLoss`, the non-loss entries are
        not differentiable, so summing the whole
        :class:`~torchrl.objectives.llm.sft.SFTLossOutput` is a valid way to get the
        total loss.

        >>> loss_fn = SFTLoss(...)
        >>> loss_output = loss_fn(td)
        >>> loss = loss_output.loss_sft + loss_output.loss_kl_to_ref
        >>> loss.backward()
        >>> # or equivalently
        >>> loss = loss_fn(td)
        >>> loss.sum(reduce=True).backward()
    """

    loss_sft: torch.Tensor
    loss_kl_to_ref: torch.Tensor | None = None
    kl_to_ref: torch.Tensor | None = None
|
99
|
+
|
100
|
+
|
101
|
+
class SFTLoss(LossModule):
    r"""Supervised fine-tuning loss.

    Args:
        actor_network (TensorDictModule): the actor network. Usually a :class:`~torchrl.modules.llm.TransformersWrapper` instance,
            with `return_log_probs=True` and `from_text=True`.
        tokenizer (`Tokenizer`): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `actor_network`.
        tokenizer_kwargs (dict, optional): keyword arguments to pass to the tokenizer during :meth:`~torchrl.data.llm.chat.History.apply_chat_template`.
            This can be used to override arguments such as the `chat_template` or `chat_template_name`.
            The dictionary is copied; the caller's object is never modified.
        reduction (Literal["mean", "sum", "none"], optional): the reduction to apply to the loss. Defaults to `"mean"`.
        normalize_by_seq_length (bool, optional): whether to normalize the loss by the sequence length. Defaults to `True`.
        kl_to_ref_coeff (float | None, optional): coefficient for KL divergence to reference model. Defaults to `None`.
        loss_function (Literal["sft", "minor_sft"], optional): The loss function to use. Defaults to `"sft"`.
        beta (float, optional): The beta parameter for MinorSFT loss. This is only used when `loss_function` is `"minor_sft"`.
            Higher values of beta make the loss more aggressive (pushes the model to generate responses further from the reference model):

            .. math::
                \text{loss} = -\log\sigma(\beta \cdot (\text{log_probs} - \text{ref_log_probs}))

            Defaults to `0.1`.
        device (torch.device | None, optional): the device to use for the loss, when tokenizing the input. Defaults to `None`.

    Raises:
        ValueError: if no tokenizer is passed and none can be read from ``actor_network``.

    .. note::
        The input tensordict is expected to contain the following keys by default:
            - ``("next", "history")``: The chat history
            - ``("next", "ref_log_prob")`` (optional): Reference model log probabilities, required if kl_to_ref_coeff is set

        These keys can be customized using the ``set_keys()`` method.

    .. seealso:: :class:`~torchrl.envs.llm.transforms.RetrieveLogProb` for the KL divergence computation.

    References:
        - Shiming Xie, Hong Chen, Fred Yu, Zeye Sun, Xiuyu Wu, 2024.
          `"Minor SFT loss for LLM fine-tune to increase performance and reduce model deviation" <https://arxiv.org/abs/2408.10642>`_

    Examples:
        >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
        >>> from torchrl.envs.llm.transforms import RetrieveLogProb
        >>> from torchrl.modules.llm import TransformersWrapper
        >>> from torchrl.objectives.llm.sft import SFTLoss
        >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
        >>> from tensordict import TensorDict, lazy_stack
        >>> import torch
        >>>
        >>> # Create chat data
        >>> chats = [
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "Hello, how are you?"},
        ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
        ...     ],
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "What's the weather like?"},
        ...         {"role": "assistant", "content": "I can't check the weather for you."},
        ...     ],
        ... ]
        >>> history = History.from_chats(chats)
        >>>
        >>> # Setup tokenizer and model
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        >>> tokenizer.pad_token = tokenizer.eos_token
        >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
        >>> model = OPTForCausalLM(OPTConfig()).eval()
        >>>
        >>> # Create training and reference policies
        >>> policy_train = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     generate=False,
        ...     from_text=True,
        ...     chat_template_name="qwen",
        ... )
        >>> policy_ref = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     generate=False,
        ...     from_text=True,
        ...     return_log_probs=True,
        ...     chat_template_name="qwen",
        ... )
        >>>
        >>> # Create the RetrieveLogProb transform
        >>> transform = RetrieveLogProb(
        ...     policy_ref,
        ...     assistant_only=True,
        ...     tokenizer_kwargs={"chat_template_name": "qwen"},
        ...     tokenizer=tokenizer,
        ... )
        >>>
        >>> # Prepare data
        >>> text = history[:, :-1].apply_chat_template(
        ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
        ... )
        >>> text_response = history.apply_chat_template(
        ...     tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
        ... )
        >>> text_response = [
        ...     txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
        ... ]
        >>> td = TensorDict(
        ...     text=text,
        ...     text_response=text_response,
        ...     history=history,
        ...     next=TensorDict(
        ...         reward=torch.randn(2, 1),
        ...         done=torch.zeros(2, dtype=torch.bool),
        ...         history=history,
        ...     ),
        ...     batch_size=(2,),
        ... )
        >>> data = lazy_stack(list(td.unbind(0)))
        >>>
        >>> # Apply the transform to get reference log probabilities
        >>> data = transform(data)
        >>> assert "ref_log_prob" in data["next"].keys()
        >>>
        >>> # Use with SFTLoss for KL regularization
        >>> loss = SFTLoss(
        ...     actor_network=policy_train,
        ...     tokenizer=tokenizer,
        ...     reduction="mean",
        ...     normalize_by_seq_length=True,
        ...     kl_to_ref_coeff=0.1,
        ...     tokenizer_kwargs={"chat_template_name": "qwen"},
        ...     loss_function="sft",
        ... )
        >>> loss_vals = loss(data)
        >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
        >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")

    """

    @dataclass
    class _AcceptedKeys:
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values.

        Attributes:
            history (NestedKey): The input tensordict key where the chat history is expected.
                Defaults to ``("next", "history")``.
            ref_log_prob (NestedKey): The input tensordict key where the reference model log probabilities are expected.
                Only used when kl_to_ref_coeff is set. Defaults to ``("next", "ref_log_prob")``.
            log_probs (NestedKey): The output tensordict key where the model's log probabilities will be written.
                Defaults to ``"log_probs"``.
        """

        history: NestedKey = ("next", "history")
        ref_log_prob: NestedKey = ("next", "ref_log_prob")
        log_probs: NestedKey = "log_probs"

    default_keys = _AcceptedKeys
    tensor_keys: _AcceptedKeys

    def __init__(
        self,
        actor_network: TensorDictModule | TransformersWrapper,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa: F821
        tokenizer_kwargs: dict | None = None,
        reduction: Literal["mean", "sum", "none"] = "mean",
        normalize_by_seq_length: bool = True,
        kl_to_ref_coeff: float | None = None,
        loss_function: Literal["sft", "minor_sft"] = "sft",
        beta: float = 0.1,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.in_keys = []
        self.actor_network = actor_network

        # Fall back to the actor's tokenizer. A plain TensorDictModule may not
        # expose one, in which case the explicit ValueError below is raised
        # instead of an AttributeError.
        if tokenizer is None:
            tokenizer = getattr(actor_network, "tokenizer", None)
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided.")
        self.tokenizer = tokenizer

        # Work on a copy so the defaults below never leak into the caller's dict.
        tokenizer_kwargs = {} if tokenizer_kwargs is None else dict(tokenizer_kwargs)
        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
        tokenizer_kwargs.setdefault("tokenize", True)
        tokenizer_kwargs.setdefault("return_tensors", "pt")
        tokenizer_kwargs.setdefault("padding", False)
        tokenizer_kwargs.setdefault("add_generation_prompt", False)
        self.tokenizer_kwargs = tokenizer_kwargs

        self.reduction = reduction
        self.normalize_by_seq_length = normalize_by_seq_length
        self.kl_to_ref_coeff = kl_to_ref_coeff
        self.loss_function = loss_function
        if self.loss_function == "minor_sft" and kl_to_ref_coeff:
            warnings.warn(
                "kl_to_ref_coeff should not be set when using minor_sft loss, as KL regularization is implicit. Setting kl_to_ref_coeff to 0.0."
            )
            # 0.0 (rather than None) keeps the KL value in the output for
            # logging while contributing nothing to the total loss.
            self.kl_to_ref_coeff = 0.0
        self.beta = beta
        self._set_in_keys()
        self.device = device

    def _set_in_keys(self) -> None:
        """Sets the input keys for the loss module."""
        in_keys = [self.tensor_keys.history]
        if self.kl_to_ref_coeff is not None or self.loss_function == "minor_sft":
            in_keys.append(self.tensor_keys.ref_log_prob)
        self.in_keys = in_keys
        self.out_keys = []  # Loss modules typically don't have out_keys

    def _device_context(self):
        """Return a context manager selecting ``self.device``, or a no-op context when it is unset."""
        if self.device is None:
            return contextlib.nullcontext()
        return torch.device(self.device)

    def _kl_to_ref(
        self,
        cur_log_prob: list[torch.Tensor],
        ref_log_prob: list[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute KL divergence to reference model.

        Args:
            cur_log_prob (List[torch.Tensor]): Log probabilities from current model. Must have shape [T] where T is the number of tokens in the assistant response.
            ref_log_prob (List[torch.Tensor]): Log probabilities from reference model. Must have shape [T] where T is the number of tokens in the assistant response.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (KL loss term, KL penalty for logging)

        Raises:
            ValueError: if the concatenated inputs do not have the same shape.
        """
        # Flatten the per-sequence lists into one token-level vector each.
        cur_log_prob = torch.cat(cur_log_prob)
        # Align devices, as the minor_sft branch of ``forward`` does for the
        # summed reference log-probs; a no-op when both already match.
        ref_log_prob = torch.cat(ref_log_prob).to(cur_log_prob.device)
        if cur_log_prob.shape != ref_log_prob.shape:
            raise ValueError(
                f"Current log probabilities and reference log probabilities have different shapes: {cur_log_prob.shape=} vs {ref_log_prob.shape=}."
            )
        # Compute KL using same approximation as GRPO:
        # mean(exp(diff) - 1 - diff) with diff = ref - cur.
        diff = ref_log_prob - cur_log_prob
        kl_penalty = (diff.expm1() - diff).mean()
        return self.kl_to_ref_coeff * kl_penalty, kl_penalty

    def _get_ref_log_probs(
        self, tensordict: TensorDictBase, missing_msg: str
    ) -> list[torch.Tensor]:
        """Fetch the reference log-probs as a list, raising ``ValueError(missing_msg)`` when absent."""
        ref_log_probs = tensordict.get(
            self.tensor_keys.ref_log_prob,
            default=None,
            as_list=True,
        )
        if ref_log_probs is None:
            raise ValueError(missing_msg)
        return ref_log_probs

    def _output_with_kl(
        self,
        loss: torch.Tensor,
        log_probs: list[torch.Tensor],
        assistant_masks: list[torch.Tensor],
        ref_log_probs: list[torch.Tensor],
        *,
        track_grad: bool,
    ) -> SFTLossOutput:
        """Build the output holding the main loss plus the KL term, computed with or without gradient tracking."""
        grad_ctx = contextlib.nullcontext() if track_grad else torch.no_grad()
        with grad_ctx:
            loss_kl, kl_penalty = self._kl_to_ref(
                [lp[mask] for lp, mask in _zip_strict(log_probs, assistant_masks)],
                ref_log_probs,
            )
        return SFTLossOutput(
            loss_sft=loss,
            loss_kl_to_ref=loss_kl,
            kl_to_ref=kl_penalty.detach(),
        )

    def forward(self, tensordict: TensorDictBase) -> SFTLossOutput:
        """Compute the SFT (or MinorSFT) loss for a batch of chat histories.

        Args:
            tensordict (TensorDictBase): input data holding the chat history under
                ``tensor_keys.history`` and, when a KL term or the MinorSFT loss is
                requested, reference log-probs under ``tensor_keys.ref_log_prob``.

        Returns:
            SFTLossOutput: the main loss and, when ``kl_to_ref_coeff`` is not
            ``None``, the KL loss term and the detached KL value.

        Raises:
            ValueError: if the tokenizer output has no assistant masks, if any
                sample has no assistant token, if masks and log-probs differ in
                shape, if required reference log-probs are missing, or if
                ``loss_function`` is not supported.
        """
        # Gather history
        history: History = tensordict[self.tensor_keys.history]

        # Apply tokenizer to history and gather mask
        with self._device_context():
            token_struct = history.apply_chat_template(
                tokenizer=self.tokenizer, **self.tokenizer_kwargs
            )
        if "assistant_masks" not in token_struct:
            raise ValueError(
                f"Assistant masks are not present in the token structure: {token_struct=}."
            )
        assistant_masks = token_struct.get(
            "assistant_masks",
            as_list=True,
        )
        attention_mask = token_struct.get("attention_mask", as_list=True)
        # Only tokens that are both assistant-generated and attended to count.
        assistant_masks = [
            mask.bool() & a_mask.bool()
            for mask, a_mask in _zip_strict(assistant_masks, attention_mask)
        ]

        # Every sample needs at least one assistant token. A sample without any
        # would contribute a constant term with no gradient to the loss.
        if not all(mask.any(-1).all() for mask in assistant_masks):
            raise ValueError("Some inputs have no valid assistant masks.")

        input_loss = tensordict.select(self.tensor_keys.history)
        if (
            isinstance(self.tensor_keys.history, tuple)
            and self.tensor_keys.history[0] == "next"
        ):
            # Pass the content of the "next" sub-tensordict to the actor.
            input_loss = input_loss["next"]

        with self._device_context():
            output_loss = self.actor_network(input_loss)

        # get log-probs
        log_probs = output_loss.get(
            self.tensor_keys.log_probs,
            as_list=True,
        )
        # apply mask
        if not all(
            mask.shape == lp.shape
            for mask, lp in _zip_strict(assistant_masks, log_probs)
        ):
            raise ValueError(
                f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs {[lp.shape for lp in log_probs]}. Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
            )

        log_probs_masked = [
            lp.masked_fill(~mask, 0.0)
            for lp, mask in _zip_strict(log_probs, assistant_masks)
        ]

        # Sum log probs, optionally normalize by sequence length
        summed_log_probs = torch.stack(
            [lp.sum(tensordict.ndim - 1) for lp in log_probs_masked]
        )
        # Number of assistant tokens per sequence.
        seq_lengths = torch.stack(
            [mask.sum(tensordict.ndim - 1) for mask in assistant_masks]
        )
        if self.normalize_by_seq_length:
            summed_log_probs = summed_log_probs / seq_lengths.clamp(min=1)

        # Compute main loss
        if self.loss_function == "sft":
            loss = sft_loss(summed_log_probs, self.reduction)
            # Add KL divergence loss if reference model is provided
            if self.kl_to_ref_coeff is None:
                return SFTLossOutput(loss_sft=loss)
            ref_log_probs = self._get_ref_log_probs(
                tensordict,
                "Reference log probs not found in tensordict but kl_to_ref_coeff was set",
            )
            return self._output_with_kl(
                loss, log_probs, assistant_masks, ref_log_probs, track_grad=True
            )

        if self.loss_function == "minor_sft":
            ref_log_probs = self._get_ref_log_probs(
                tensordict,
                f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict but loss_function is 'minor_sft'",
            )

            # we need to re-sum ref_log_probs as they are not summed per-sequence
            summed_ref_log_probs = torch.stack([lp.sum() for lp in ref_log_probs]).to(
                summed_log_probs.device
            )
            if self.normalize_by_seq_length:
                summed_ref_log_probs = summed_ref_log_probs / seq_lengths.clamp(min=1)
            loss = minor_sft_loss(
                summed_log_probs, summed_ref_log_probs, self.beta, self.reduction
            )
            if self.kl_to_ref_coeff is None:
                return SFTLossOutput(loss_sft=loss)
            # Regularization is implicit in MinorSFT: the KL value is computed
            # without gradient tracking.
            return self._output_with_kl(
                loss, log_probs, assistant_masks, ref_log_probs, track_grad=False
            )

        raise ValueError(f"Invalid loss function: {self.loss_function}")
|
torchrl/version.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1
|
-
__version__ = '2025.6.
|
2
|
-
git_version = '
|
1
|
+
__version__ = '2025.6.22'
|
2
|
+
git_version = '77dbc6c9ffbce3d2ce3f26b659355cd46d8132c3'
|
@@ -3,11 +3,11 @@ build_tools/setup_helpers/__init__.py,sha256=l9zlK7Nm5bT7P_onQx-hZeIGzKKyCFm1PFk
|
|
3
3
|
build_tools/setup_helpers/extension.py,sha256=ihV8jz8kqOvpqzuD006XqF1oNX5ukKGlwIOJRb1Vd-o,6075
|
4
4
|
torchrl/__init__.py,sha256=76lKYwYKmAKORhyVt2tURmYAIRTifxxO3gWsskrHAXU,3054
|
5
5
|
torchrl/_extension.py,sha256=x6Nqj2brF3VhlEwxmNA2fYbmpxq1HHGrHMnP0YnQwdc,2412
|
6
|
-
torchrl/_torchrl.cp313-win_amd64.pyd,sha256=
|
6
|
+
torchrl/_torchrl.cp313-win_amd64.pyd,sha256=CP0G1g9_FfwAaXAXkaznbN0USlRSkpOLIZeC6_V-QZE,429056
|
7
7
|
torchrl/_utils.py,sha256=2N35rdD65U1khMi5gVIz8-nMjlZsoVq0kCiQftVRSxw,42297
|
8
|
-
torchrl/version.py,sha256=
|
8
|
+
torchrl/version.py,sha256=sFWjKZnDuj1wpE9ARKNe8cTHlKoOAExwF3ZTthG-Y-U,85
|
9
9
|
torchrl/collectors/__init__.py,sha256=LzTyfxmkNGPSa5-3rS5unQK7HfT5ZEdr2NV291rAOlU,832
|
10
|
-
torchrl/collectors/collectors.py,sha256=
|
10
|
+
torchrl/collectors/collectors.py,sha256=Pz6VYYrekjBiVBQiyzp6zIyZrBdjSv4-FqlfrGYQz3E,181469
|
11
11
|
torchrl/collectors/utils.py,sha256=aBmBLpphhfplqQjRCyn1jtWWJ-Wtc7TWvM0rOBN8SsE,11579
|
12
12
|
torchrl/collectors/weight_update.py,sha256=Ydq5nJSTV3Q1uqLtJ_1Nj1JB5rwHwrG5StaLxymWFV4,21572
|
13
13
|
torchrl/collectors/distributed/__init__.py,sha256=cKDWdNlwx2LoJkTwf-DKUXbq3Y-0Z1DctPYPcdgOSU0,730
|
@@ -18,12 +18,12 @@ torchrl/collectors/distributed/rpc.py,sha256=xta5tptC0mLlIY_AecLrARvFBYh7nMbclrr
|
|
18
18
|
torchrl/collectors/distributed/sync.py,sha256=zjp0HEEcSMaDzq8xndoBWvyqYCdf7hp8urYRRiJP2GI,27912
|
19
19
|
torchrl/collectors/distributed/utils.py,sha256=eY6M-vLCSzyACHRNBx5bHcieWsZfLg7DfNKGIv0IgHI,6625
|
20
20
|
torchrl/collectors/llm/__init__.py,sha256=u03aQ97C3sb5-C0-s2tMBAGGs3kJTfZUSse29fHDkIk,365
|
21
|
-
torchrl/collectors/llm/base.py,sha256=
|
22
|
-
torchrl/collectors/llm/ray_collector.py,sha256=
|
21
|
+
torchrl/collectors/llm/base.py,sha256=Wxdo4drsMk_i5u5DzDlikY4j5-TM9f6Ac4xjB6wJgPw,21132
|
22
|
+
torchrl/collectors/llm/ray_collector.py,sha256=nk-i61ZAsYkDNZW2Y7vcDhddjkUyyIuiDYjU9iWYklE,11364
|
23
23
|
torchrl/collectors/llm/utils.py,sha256=GnDY2cTu4XEdwqqhFCP4QWfS2tsgaLTy8nwpIaTEsQI,1184
|
24
24
|
torchrl/collectors/llm/weight_update/__init__.py,sha256=ngbL_sPfXh8FMM3r_j0B9QEP_jQIVSOa8pZVouHg9ec,281
|
25
25
|
torchrl/collectors/llm/weight_update/vllm.py,sha256=4kRlEBHb6093d9lkKVIqU8ZwiPoCFtmVVaADuhxKLL4,11571
|
26
|
-
torchrl/data/__init__.py,sha256=
|
26
|
+
torchrl/data/__init__.py,sha256=h6ZHGWzvWDfyw6tgo69gld5C_Kgg-T4DP-EVv1hy0Xk,5026
|
27
27
|
torchrl/data/rlhf.py,sha256=_ENSvNe84snnFQG0jlTtOI419nIYtbBHvAw-pdFMiSs,1002
|
28
28
|
torchrl/data/tensor_specs.py,sha256=PF-sta3dHy0UaDu26FS5-ZVlpAsagTUUMLmYHs5PQhk,254762
|
29
29
|
torchrl/data/utils.py,sha256=tXBPxl5VHqPUfJF1VLqURmb066zDd9lipRDER4R1FY8,12444
|
@@ -39,12 +39,13 @@ torchrl/data/datasets/openx.py,sha256=0p2H3phnvsUgrFFfQulTyGEDcUrvXiFp-p18uFk8Xk
|
|
39
39
|
torchrl/data/datasets/roboset.py,sha256=sLdDknyPj7f2NF5z5EbzEQ9QhD03VBKHEXXo_wxSRc4,17011
|
40
40
|
torchrl/data/datasets/utils.py,sha256=tRZkarWl-BX_fnGEVNRt6Fjo0wmyKCNAuVmw5Le-0C4,350
|
41
41
|
torchrl/data/datasets/vd4rl.py,sha256=YFjXvP-QGNzF7UWNKkGMKPFthcB0I8v6sJc7oESegYA,18694
|
42
|
-
torchrl/data/llm/__init__.py,sha256=
|
43
|
-
torchrl/data/llm/chat.py,sha256=
|
42
|
+
torchrl/data/llm/__init__.py,sha256=X86bYW_uNAwKJcxK2AVQpJo56tDejtXSDFtOyZMNhHQ,1006
|
43
|
+
torchrl/data/llm/chat.py,sha256=akIL6wj2QtpdWlFvOcRfPcrk7HE5NoSpgZxB8gtdOos,33962
|
44
44
|
torchrl/data/llm/common.py,sha256=3Gb8sMojtNss6wi6hKGSIUDAwK1PC_8ve9W-bHCPjAk,2181
|
45
45
|
torchrl/data/llm/dataset.py,sha256=GDuzflBq2ThgYn_V4bOr_1MHOhESnQ5jX1Wlcw69lfM,21194
|
46
46
|
torchrl/data/llm/prompt.py,sha256=ikHWafhTIoCONCpuMHwIuGfpnPSpg5drZQTxTCmygQE,8578
|
47
47
|
torchrl/data/llm/reward.py,sha256=QW1HWpNRORd3InwWLg-hAhjTlPqX4ffzAkYHEz0jQxo,8629
|
48
|
+
torchrl/data/llm/topk.py,sha256=4MTxYTTdfSBM5vxDHnleY7FatanzHgneXs8rjgKwwqQ,8539
|
48
49
|
torchrl/data/llm/utils.py,sha256=K2NQoEhBC6VWowsMeDHu2Q8vbg3ZPEWBBN6z4qifiNM,24143
|
49
50
|
torchrl/data/map/__init__.py,sha256=bON0vqCksU7FPoWNqiNcdl60t7yWUh9SdLhNtglj7jI,576
|
50
51
|
torchrl/data/map/hash.py,sha256=XRYdaFHQUm87fL9pWjhvi2LeZVaqJsASkCU-G_Gus8s,7437
|
@@ -56,11 +57,11 @@ torchrl/data/postprocs/__init__.py,sha256=fOyX5OMaDb5HGrQbn9W72_QwncNdh6l3DkVSqR
|
|
56
57
|
torchrl/data/postprocs/postprocs.py,sha256=dpXOKWlhdKy4Um7HdzRKe42PJ_Q1jHC7AX5plR9AIiw,15509
|
57
58
|
torchrl/data/replay_buffers/__init__.py,sha256=oINoSWKO3Ku6YIBF-0KnbVLZwelZbANN4nLU4q6Mir0,2455
|
58
59
|
torchrl/data/replay_buffers/checkpointers.py,sha256=eizAw4W0tQ2EWgfx6-EUV_3EuZMcZVdoahLoux15Nr4,15186
|
59
|
-
torchrl/data/replay_buffers/ray_buffer.py,sha256=
|
60
|
-
torchrl/data/replay_buffers/replay_buffers.py,sha256=
|
61
|
-
torchrl/data/replay_buffers/samplers.py,sha256=
|
60
|
+
torchrl/data/replay_buffers/ray_buffer.py,sha256=p8EkiXOP4EVMkkpjOyje7wiBfgbWOB0xhEJzDdelAxc,10135
|
61
|
+
torchrl/data/replay_buffers/replay_buffers.py,sha256=YzNV543zDpvENbUnqjjghkHHq6IwyRms1I3DXl-ayq4,92567
|
62
|
+
torchrl/data/replay_buffers/samplers.py,sha256=bFP8j3BahHULASAjeIGtxJX36GRsg3yBCNFu7vR6Zdo,112834
|
62
63
|
torchrl/data/replay_buffers/scheduler.py,sha256=cGm4LZcZ2lo8azDMWKGTdhWApxjZFh0KfynApxAkVK4,10416
|
63
|
-
torchrl/data/replay_buffers/storages.py,sha256=
|
64
|
+
torchrl/data/replay_buffers/storages.py,sha256=VdEYOQ29FWGwDeHfLafZuobMlmuagxHVHnntKnU-yX4,62298
|
64
65
|
torchrl/data/replay_buffers/utils.py,sha256=vlGfyHVKUAMKBR0l7fJM9NI47ZinS18Qzf8lpwoo6pI,39644
|
65
66
|
torchrl/data/replay_buffers/writers.py,sha256=-aI6Y28oisuFDutMVlPp4e8wTe6x0wlY0MY1OUKHl4Q,28466
|
66
67
|
torchrl/envs/__init__.py,sha256=2eVr8StUSMiNd-IoD5BQAFFuV10pAtO926b6QzRzB_M,6082
|
@@ -96,8 +97,8 @@ torchrl/envs/libs/smacv2.py,sha256=pr03oGHE2G_fc86qHeSQjSz3S6IH_l2hX0J2umb020M,2
|
|
96
97
|
torchrl/envs/libs/unity_mlagents.py,sha256=vszCYjEX0S9AmIwLvGsoqc0Jr7jvlBAqZ1HQ1uqesjM,50558
|
97
98
|
torchrl/envs/libs/utils.py,sha256=Ce8nAYc2MQOBTYCV17Yswk98pg3PStnaGPFVW2jqARQ,5354
|
98
99
|
torchrl/envs/libs/vmas.py,sha256=giTORg2AqYzyjrazdD94fD2dNYwX7qe5TFnr-E1mjIg,37140
|
99
|
-
torchrl/envs/llm/__init__.py,sha256=
|
100
|
-
torchrl/envs/llm/chat.py,sha256=
|
100
|
+
torchrl/envs/llm/__init__.py,sha256=Iz5HtLoVy8O4u1mrPmyql4G8SU9S-MCinP_Gh8sbUWo,1320
|
101
|
+
torchrl/envs/llm/chat.py,sha256=YvADxo11RKkjD06rBvbbljch3Jb_H4snaBcgkU2Q-7w,18171
|
101
102
|
torchrl/envs/llm/envs.py,sha256=wphbzLwDKYO_OTV63WYW4iTK5Ek4vmb1zNv5gehzodY,35450
|
102
103
|
torchrl/envs/llm/datasets/__init__.py,sha256=6-x0WlKD7lpMVLKA4W1AktvgUs6adMuaGAqYYhgQ_hk,490
|
103
104
|
torchrl/envs/llm/datasets/gsm8k.py,sha256=MfCFu0U7uetDtLdzUdvqX4rENXPsL8msnkMT98Q29jE,15624
|
@@ -105,18 +106,18 @@ torchrl/envs/llm/datasets/ifeval.py,sha256=sJ4bvXEWBzzNnDDbKkj6yz_1zemDStVFNfxqo
|
|
105
106
|
torchrl/envs/llm/libs/__init__.py,sha256=zvUe6oe3pjZwGefV-_x4MAC6K89TMqxh3TZs5s3ADkI,274
|
106
107
|
torchrl/envs/llm/libs/mlgym.py,sha256=TMaoV9P5w5EGgBSmLiw42_DOyKEh7ZGf3mDf-LaZ9W0,32237
|
107
108
|
torchrl/envs/llm/reward/__init__.py,sha256=KYNJxyDOe2mZkjyH4CSuQ8qM0_Zu3EAaIGocYhLduPQ,380
|
108
|
-
torchrl/envs/llm/reward/gsm8k.py,sha256=
|
109
|
+
torchrl/envs/llm/reward/gsm8k.py,sha256=TW2lACMLXHRlcTTRfTcFTLl7NIJA0TNh6qfSJiC52QI,8066
|
109
110
|
torchrl/envs/llm/reward/ifeval/__init__.py,sha256=vvh7JSUQaEiMjNeMeJvWlcFb2-6_J1LfM6l4mENn4Zg,324
|
110
111
|
torchrl/envs/llm/reward/ifeval/_instructions.py,sha256=jlNvIO3dykk8fBFXC35PQSDJ9vLF3knS-ywq2ILfF00,63362
|
111
112
|
torchrl/envs/llm/reward/ifeval/_instructions_main.py,sha256=DEc7QqfujGxYvqcm2y_zPasBqB7FgSfXRt3QQ4HQUz0,4244
|
112
113
|
torchrl/envs/llm/reward/ifeval/_instructions_registry.py,sha256=bY8R51RgjKJYiim67j5IXSfYhtWtvZrRF61yuqi0Tzs,3914
|
113
114
|
torchrl/envs/llm/reward/ifeval/_instructions_util.py,sha256=63ZJbqUKaqMA_SDhnYT7VppULbible8udHGihiamxKc,27719
|
114
115
|
torchrl/envs/llm/reward/ifeval/_scorer.py,sha256=iv-316dBYlz4fz6WUtzP7151y4xEwuwOq-Wf7Qazgmc,14928
|
115
|
-
torchrl/envs/llm/transforms/__init__.py,sha256=
|
116
|
+
torchrl/envs/llm/transforms/__init__.py,sha256=BnVW7WVCYPlaNPd4cEyXIUI1Qfd7YQkZVsXvj5q5UVw,814
|
116
117
|
torchrl/envs/llm/transforms/browser.py,sha256=d0JIUZ3TfgmqBcci5ihzzTqZA9KeTrs1iCProRWQQK8,10715
|
117
118
|
torchrl/envs/llm/transforms/dataloading.py,sha256=Zl--I6bT2AqWDZmM6RMQq4ds3b1PFGqilAaIbXBJuNc,25054
|
118
119
|
torchrl/envs/llm/transforms/format.py,sha256=tME390wkG0h2V5DAWHZa7EhJ5Or-6cga6AIjxPuy1l8,2592
|
119
|
-
torchrl/envs/llm/transforms/kl.py,sha256=
|
120
|
+
torchrl/envs/llm/transforms/kl.py,sha256=GrANICxnF-FC_yVkdO4EU66bj7aTfkB8OXG7wA2uxDo,23251
|
120
121
|
torchrl/envs/llm/transforms/policy_version.py,sha256=fko23hsQrAMmUqFwKjV_CQVavDhixXFUeVE0lJBASOA,7080
|
121
122
|
torchrl/envs/llm/transforms/tokenizer.py,sha256=Nest15FD1iPLNZuw0rAobyb7n3ce6KFX00qfN3dUE2M,14274
|
122
123
|
torchrl/envs/llm/transforms/tools.py,sha256=WoNgUN1Me4mhbqH5ef9XNc5qKXE0H-g3ZobjEKDM_kw,30308
|
@@ -147,8 +148,8 @@ torchrl/modules/llm/backends/__init__.py,sha256=ABKK4mJeRtoLXEqfnMvIuiovs7VJoCxn
|
|
147
148
|
torchrl/modules/llm/backends/vllm.py,sha256=5P78jEtAIytgYHzEkOrg-wwqh1ryhiMVy4M_AxNQ9JQ,9649
|
148
149
|
torchrl/modules/llm/policies/__init__.py,sha256=CK7VEdfShjkeNu_-TmYOobrCEjKTIb2aw2hE6s5RBNs,439
|
149
150
|
torchrl/modules/llm/policies/common.py,sha256=GXzmVRa0SJvQ8iPMeuNjwV7EaZDOPrVy5k_LlJ10QXY,3111
|
150
|
-
torchrl/modules/llm/policies/transformers_wrapper.py,sha256=
|
151
|
-
torchrl/modules/llm/policies/vllm_wrapper.py,sha256=
|
151
|
+
torchrl/modules/llm/policies/transformers_wrapper.py,sha256=0wDYGpC1T5T8ZVyZOi5S2Qoa2wtg5Oix2W9Y_bKMKs8,25988
|
152
|
+
torchrl/modules/llm/policies/vllm_wrapper.py,sha256=LBoFbrTyGiEiilYVmU7Ze-WBpwYinvAVIEtbU1QKajw,32226
|
152
153
|
torchrl/modules/models/__init__.py,sha256=Y1XTkBOB5EMj6IaMru6V3CDwFLnkUtxzsHcqzeqq_4Y,1829
|
153
154
|
torchrl/modules/models/batchrenorm.py,sha256=bR4ZhaJ5E1cSK5o8L2dNX5KVLIb-bgrYxcq6yhx0I1A,4869
|
154
155
|
torchrl/modules/models/decision_transformer.py,sha256=ANFTOm3k9_3Uv1vKGdXumRy3meBPnDdT8HqhVvJ2RCo,6783
|
@@ -194,8 +195,9 @@ torchrl/objectives/sac.py,sha256=gKOgCU399miKgpgu7Bmzs1bkIF8JTm_lybHn8V4wDuk,654
|
|
194
195
|
torchrl/objectives/td3.py,sha256=Rq2q5gXo3AMuHm2OjRZvpfvKsAl1lIK5ALh2_sZM1ZE,23743
|
195
196
|
torchrl/objectives/td3_bc.py,sha256=1pjB8mjCT2CLvQzjnqwAfZoc7yhjMB9UQjuJ5wZfTUY,26558
|
196
197
|
torchrl/objectives/utils.py,sha256=Vrjj07SjMYANfFyn3n1xS7izBIs5Mq9mCvyITMzifZs,24705
|
197
|
-
torchrl/objectives/llm/__init__.py,sha256=
|
198
|
+
torchrl/objectives/llm/__init__.py,sha256=tZmIz3rkeclw3MzJoOWEs2gkewjx2USKrKJbWdyiiaQ,406
|
198
199
|
torchrl/objectives/llm/grpo.py,sha256=nT3Ukjaz7nZZnkS5tnb-pDnRzvZ3L1edpcNCzi5WZRs,17164
|
200
|
+
torchrl/objectives/llm/sft.py,sha256=9fzX9Qo0Goyjxuwca6eLN1PUQ24F0LZGRpjzTDLFfs4,20572
|
199
201
|
torchrl/objectives/multiagent/__init__.py,sha256=5uebDe5KrvlzeYV_BSd5vdmfruJQYMeDVVbU4iHErEg,245
|
200
202
|
torchrl/objectives/multiagent/qmixer.py,sha256=yttOxc5FNylKw4iMnYSG1qO8EbHvx8imAhxNxW9_iLw,17362
|
201
203
|
torchrl/objectives/value/__init__.py,sha256=QkSnenYVqe_3FVtwGr_D86N52unnpBvRXfcC5JFTBOw,589
|
@@ -221,8 +223,8 @@ torchrl/trainers/helpers/losses.py,sha256=rWKure02dl8hLBzLUs-jhNJV8L3QHWtFbl3HbX
|
|
221
223
|
torchrl/trainers/helpers/models.py,sha256=VujBq9H92sEzpCtU1iTrJQNlwvyOO-Rho4bzsMonX6s,22465
|
222
224
|
torchrl/trainers/helpers/replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
|
223
225
|
torchrl/trainers/helpers/trainers.py,sha256=hB1FtHtP-S0PBQ4LF6WPy37caaLpacyaLThj1BNl5Ho,12372
|
224
|
-
torchrl_nightly-2025.6.
|
225
|
-
torchrl_nightly-2025.6.
|
226
|
-
torchrl_nightly-2025.6.
|
227
|
-
torchrl_nightly-2025.6.
|
228
|
-
torchrl_nightly-2025.6.
|
226
|
+
torchrl_nightly-2025.6.22.dist-info/LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
|
227
|
+
torchrl_nightly-2025.6.22.dist-info/METADATA,sha256=hHTi-fhdpfc3I2ipKsmcjKtiCzkK_mWcn_7nZwnH-0U,40113
|
228
|
+
torchrl_nightly-2025.6.22.dist-info/WHEEL,sha256=34WDYMHIz_-H9l4AW0T3KmyzuxEGEmufYB-iQzXMn2g,101
|
229
|
+
torchrl_nightly-2025.6.22.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
|
230
|
+
torchrl_nightly-2025.6.22.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|