torchrl-nightly 2025.6.21__cp312-cp312-macosx_10_13_universal2.whl → 2025.6.23__cp312-cp312-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
torchrl/_torchrl.cpython-312-darwin.so CHANGED
Binary file
torchrl/envs/llm/__init__.py CHANGED
@@ -15,6 +15,7 @@ from .envs import LLMEnv, LLMHashingEnv
 from .libs import make_mlgym, MLGymWrapper
 from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
 from .transforms import (
+    AddThinkingPrompt,
     as_nested_tensor,
     as_padded_tensor,
     BrowserTransform,
@@ -33,6 +34,7 @@ __all__ = [
     "ChatEnv",
     "DataLoadingPrimer",
     "DatasetChatEnv",
+    "AddThinkingPrompt",
     "GSM8KEnv",
     "GSM8KPrepareQuestion",
     "GSM8KRewardParser",
torchrl/envs/llm/chat.py CHANGED
@@ -284,6 +284,7 @@ class DatasetChatEnv(TransformedEnv):

     Keyword Args:
         dataset (str): The name of the dataset.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         name (str, optional): name of the dataset configuration.
         split (str, optional): the split to use (usually from `"train"`, `"val"` or `"test"`). Defaults to `None` (no split).
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
@@ -317,6 +318,7 @@ class DatasetChatEnv(TransformedEnv):
         self,
         *,
         dataset: str,
+        shuffle: bool = True,
         name: str | None = None,
         split: Literal["train", "val", "test"] | None = None,
         num_envs: int = 1,
@@ -355,7 +357,7 @@ class DatasetChatEnv(TransformedEnv):
         dataloader = DataLoader(  # noqa: TOR401
             dataset,
             batch_size=batch_size_dl,
-            shuffle=True,
+            shuffle=shuffle,
             collate_fn=collate_fn,
             generator=generator,
         )
@@ -375,3 +377,14 @@ class DatasetChatEnv(TransformedEnv):
             apply_template=apply_template,
         )
         return super().__init__(env_base, primer)
+
+    def reset_dataloader(self):
+        """Reset the dataloader.
+
+        This is useful when the dataloader is not infinite and we want to reset it.
+
+        Returns:
+            self: The environment itself.
+        """
+        self.transform[0].reset_dataloader()
+        return self
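Taken together, the chat.py hunks thread a new `shuffle` flag through to the underlying `DataLoader` and expose a `reset_dataloader` method that delegates to the first transform in the chain (the `DataLoadingPrimer`). A minimal usage sketch, assuming a `DatasetChatEnv` subclass such as `GSM8KEnv` and a locally available Hugging Face tokenizer (the model name is illustrative, not pinned by this diff):

```python
# Sketch only: model/tokenizer names are assumptions, not mandated by the diff.
from transformers import AutoTokenizer
from torchrl.envs.llm import GSM8KEnv

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
# shuffle=False (new in this release) keeps the dataset order deterministic,
# which is convenient for evaluation sweeps.
env = GSM8KEnv(tokenizer=tokenizer, shuffle=False)

# If the underlying dataloader is finite and has been exhausted, rewind it
# in place; the method returns the env itself, so calls can be chained.
reset = env.reset_dataloader().reset()
```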
torchrl/envs/llm/datasets/gsm8k.py CHANGED
@@ -135,6 +135,7 @@ class GSM8KEnv(DatasetChatEnv):

     Keyword Args:
         dataset (str, optional): The name of the dataset. Defaults to `"gsm8k"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
@@ -284,12 +285,13 @@ class GSM8KEnv(DatasetChatEnv):
     SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively,
-i.e., <think>reasoning process here</think> <answer>answer here</answer>."""
+i.e., <think>reasoning process here</think> <answer>answer here</answer>. The answer should be a number."""

     def __init__(
         self,
         *,
         dataset: str = "gsm8k",
+        shuffle: bool = True,
         num_envs: int = 1,
         repeats: int | None = None,
         batch_size_dl: int = 1,
@@ -307,6 +309,7 @@ i.e., <think>reasoning process here</think> <answer>answer here</answer>."""
         collate_fn = _collate_fn
         super().__init__(
             dataset=dataset,
+            shuffle=shuffle,
             name="main",
             num_envs=num_envs,
             repeats=repeats,
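The subclass simply forwards the new flag to `DatasetChatEnv`. A hedged sketch of how `shuffle` combines with the existing `repeats` argument, which the docstring above describes as duplicating each sample (mainly for Monte-Carlo value estimation):

```python
# Illustrative only; assumes the gsm8k dataset can be downloaded and that a
# local tokenizer is available.
from transformers import AutoTokenizer
from torchrl.envs.llm import GSM8KEnv

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
# repeats=4 serves each question several times so multiple completions can be
# scored per prompt; shuffle only affects the order of distinct prompts.
env = GSM8KEnv(tokenizer=tokenizer, shuffle=True, repeats=4, num_envs=2)
td = env.reset()
```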
torchrl/envs/llm/datasets/ifeval.py CHANGED
@@ -41,6 +41,7 @@ class IFEvalEnv(DatasetChatEnv):

     Keyword Args:
         dataset (str, optional): The name of the dataset. Defaults to `"google/IFeval"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
@@ -146,6 +147,7 @@ You will be assessed by the content of the answer block only, so make sure it co
         self,
         *,
         dataset: str = "google/IFeval",
+        shuffle: bool = True,
         num_envs: int = 1,
         repeats: int | None = None,
         batch_size_dl: int = 1,
@@ -163,6 +165,7 @@ You will be assessed by the content of the answer block only, so make sure it co
         collate_fn = _collate_fn
         super().__init__(
             dataset=dataset,
+            shuffle=shuffle,
             num_envs=num_envs,
             repeats=repeats,
             batch_size_dl=batch_size_dl,
torchrl/envs/llm/reward/gsm8k.py CHANGED
@@ -20,6 +20,7 @@ class GSM8KRewardParser(Transform):
         in_keys (list of NestedKey): the input keys. Defaults to `["text_response", "answer"]`.
         out_keys (list of NestedKey): the output keys. Defaults to `["reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"]`.
         eos_token (str): the end of sentence token. Defaults to `tokenizer.eos_token` if not provided.
+        set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.

     """

@@ -29,10 +30,18 @@ class GSM8KRewardParser(Transform):
         in_keys: list[NestedKey] | None = None,
         out_keys: list[NestedKey] | None = None,
         eos_token: str | None = None,
+        set_done_if_answer: bool = True,
     ):
         super().__init__()
         self.tokenizer = tokenizer
-        self.eos_token = eos_token if eos_token is not None else tokenizer.eos_token
+        self.eos_token = (
+            eos_token
+            if eos_token is not None
+            else tokenizer.eos_token
+            if tokenizer is not None
+            else None
+        )
+        self.set_done_if_answer = set_done_if_answer
         if in_keys is None:
             in_keys = ["text_response", "answer"]
         if not isinstance(in_keys, list) or len(in_keys) != 2:
@@ -118,7 +127,20 @@ class GSM8KRewardParser(Transform):
         tds = tds.add(
             next_td_exist, default=torch.zeros((), device=next_tensordict.device)
         )
-        return next_tensordict.update(tds)
+        next_tensordict = next_tensordict.update(tds)
+        if (
+            self.set_done_if_answer
+            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+        ):
+            done = next_tensordict.get("done")
+            if done is not None:
+                next_tensordict.set("done", reward_answer.view_as(done) | done)
+            terminated = next_tensordict.get("terminated")
+            if terminated is not None:
+                next_tensordict.set(
+                    "terminated", reward_answer.view_as(terminated) | terminated
+                )
+        return next_tensordict

     def transform_reward_spec(self, reward_spec: Composite) -> Composite:
         shape = reward_spec.shape + (1, 1)
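The new `set_done_if_answer` flag ends the episode as soon as a well-formed answer earns a positive `reward_answer`, by OR-ing that mask into both `done` and `terminated`. A small sketch of the masking logic in isolation, on plain tensors (shapes and values are illustrative):

```python
import torch

# Suppose a batch of two environments: env 0 produced a parseable answer,
# env 1 did not. reward_answer > 0 marks "answer present".
reward_answer = torch.tensor([[1.0], [0.0]]) > 0
done = torch.tensor([[False], [False]])

# The same update the transform applies: it never clears an existing done
# flag, it only raises it where an answer was produced.
done = reward_answer.view_as(done) | done
print(done)  # tensor([[ True], [False]])
```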
torchrl/envs/llm/transforms/__init__.py CHANGED
@@ -8,6 +8,7 @@ from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
 from .format import TemplateTransform
 from .kl import KLRewardTransform, RetrieveLogProb
 from .policy_version import PolicyVersion
+from .reason import AddThinkingPrompt
 from .tokenizer import Tokenizer
 from .tools import MCPToolTransform, PythonInterpreter

@@ -19,6 +20,7 @@ __all__ = [
     "MCPToolTransform",
     "PolicyVersion",
     "PythonInterpreter",
+    "AddThinkingPrompt",
     "TemplateTransform",
     "Tokenizer",
     "as_nested_tensor",
torchrl/envs/llm/transforms/dataloading.py CHANGED
@@ -447,6 +447,18 @@ class DataLoadingPrimer(TensorDictPrimer):
         )
         self._reset_key = "_reset"

+    def reset_dataloader(self):
+        """Reset the dataloader.
+
+        This is useful when the dataloader is not infinite and we want to reset it.
+
+        Returns:
+            self: The transform itself.
+        """
+        self._queue.clear()
+        self.endless_dataloader = self._endless_iter(self.dataloader)
+        return self
+
     @classmethod
     def _endless_iter(self, obj):
         while True:
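`reset_dataloader` drops any batches already staged in `_queue` and rebuilds the endless iterator, so the next sample is drawn from the start of the dataloader again. The wrapper being rebuilt is the generator shown in context above; a self-contained sketch of the same pattern on a plain list:

```python
# Minimal reproduction of the _endless_iter pattern, illustrative only.
def endless_iter(obj):
    while True:
        yield from obj

it = endless_iter([1, 2, 3])
assert [next(it) for _ in range(5)] == [1, 2, 3, 1, 2]

# "Resetting" is just rebuilding the generator, which is exactly what
# reset_dataloader does (after clearing the staging queue).
it = endless_iter([1, 2, 3])
assert next(it) == 1
```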
torchrl/envs/llm/transforms/reason.py ADDED
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import re
+from typing import Callable, Literal
+
+from tensordict import lazy_stack, TensorDictBase
+
+from torchrl.data.llm.chat import History
+from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
+
+
+class AddThinkingPrompt(Transform):
+    """A transform that adds thinking prompts to encourage the LLM to reconsider its response.
+
+    This transform can either add a new thinking prompt as a separate message or edit the last
+    assistant response to include a thinking prompt before the final answer. This is useful for
+    training LLMs to self-correct and think more carefully when their initial responses are
+    incorrect or incomplete.
+
+    Args:
+        cond (Callable[[TensorDictBase], bool], optional): Condition function that determines
+            when to add the thinking prompt. Takes a tensordict and returns `True` if the prompt
+            should be added.
+        prompt (str, optional): The thinking prompt to add. If None, a default prompt is used.
+            Defaults to `"But wait, let me think about this more carefully..."`.
+        random_prompt (bool, optional): Whether to randomly select from predefined prompts.
+            Defaults to `False`.
+        role (Literal["user", "assistant"], optional): The role for the thinking prompt.
+            If `"assistant"`, the prompt is added to the assistant's response. If `"user"`, it's
+            added as a separate user message. Defaults to `"assistant"`.
+        edit_last_turn (bool, optional): Whether to edit the last assistant response instead
+            of adding a new message. Only works with `role="assistant"`. Defaults to `True`.
+        zero_reward (bool, optional): Whether to zero out the reward when the thinking prompt
+            is added. If `None`, defaults to the value of `edit_last_turn`.
+        undo_done (bool, optional): Whether to undo the done flag when the thinking prompt
+            is added. Defaults to `True`.
+
+    Examples:
+        >>> from torchrl.envs.llm.transforms import AddThinkingPrompt
+        >>> from torchrl.envs.llm import GSM8KEnv
+        >>> from transformers import AutoTokenizer
+        >>> import torch
+        >>>
+        >>> # Create environment with thinking prompt transform
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        >>> env = GSM8KEnv(tokenizer=tokenizer, max_steps=10)
+        >>> env = env.append_transform(
+        ...     AddThinkingPrompt(
+        ...         cond=lambda td: td["reward"] < 50,
+        ...         role="assistant",
+        ...         edit_last_turn=True,
+        ...         zero_reward=True,
+        ...         undo_done=True
+        ...     )
+        ... )
+        >>>
+        >>> # Test with wrong answer (low reward)
+        >>> reset = env.reset()
+        >>> wrong_answer = (
+        ...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
+        ...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+        ...     "To find the total, I need to add April and May: 48 + 24 = 72. "
+        ...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
+        ...     "<answer>322 clips</answer><|im_end|>"
+        ... )
+        >>> reset["text_response"] = [wrong_answer]
+        >>> s = env.step(reset)
+        >>> assert (s["next", "reward"] == 0).all()  # Reward zeroed
+        >>> assert (s["next", "done"] == 0).all()  # Done undone
+        >>> assert s["next", "history"].shape == (1, 3)  # History modified
+        >>>
+        >>> # Test with correct answer (high reward)
+        >>> reset = env.reset()
+        >>> correct_answer = (
+        ...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
+        ...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+        ...     "To find the total, I need to add April and May: 48 + 24 = 72. "
+        ...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
+        ...     "<answer>72</answer><|im_end|>"
+        ... )
+        >>> reset["text_response"] = [correct_answer]
+        >>> s = env.step(reset)
+        >>> assert (s["next", "reward"] != 0).all()  # Reward not zeroed
+        >>> assert s["next", "done"].all()  # Done remains True
+        >>> assert s["next", "history"].shape == (1, 3)  # History unchanged
+    """
+
+    # Predefined thinking prompts
+    DEFAULT_PROMPTS = [
+        "But wait, let me think about this more carefully...",
+        "Actually, let me reconsider this...",
+        "Let me think about it step by step...",
+        "Wait, I need to double-check my reasoning...",
+        "Actually, let me think about it more carefully...",
+    ]
+
+    def __init__(
+        self,
+        cond: Callable[[TensorDictBase], bool],
+        prompt: str | None = None,
+        random_prompt: bool = False,
+        role: Literal["user", "assistant"] = "assistant",
+        edit_last_turn: bool = True,
+        zero_reward: bool | None = None,
+        undo_done: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Set the prompt
+        if prompt is None:
+            prompt = self.DEFAULT_PROMPTS[0]
+        self._prompt = prompt
+        self.random_prompt = random_prompt
+
+        # Set condition and role
+        self.cond = cond
+        self.role = role
+
+        # Validate edit_last_turn constraint
+        if edit_last_turn and role != "assistant":
+            raise ValueError("edit_last_turn can only be used with role='assistant'")
+        self.edit_last_turn = edit_last_turn
+
+        # Set zero_reward behavior
+        if zero_reward is None:
+            zero_reward = edit_last_turn
+        self.zero_reward = zero_reward
+        self.undo_done = undo_done
+
+    @property
+    def prompt(self) -> str:
+        if self.random_prompt:
+            import random
+
+            return random.choice(self.DEFAULT_PROMPTS)
+        return self._prompt
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        """Process the tensordict and add thinking prompts based on the condition.
+
+        Args:
+            tensordict: The current tensordict
+            next_tensordict: The next tensordict containing the most recent history and reward
+
+        Returns:
+            The modified next_tensordict
+        """
+        print("Reward", next_tensordict["reward"])
+        # Handle batch dimensions
+        if next_tensordict.batch_dims >= 1:
+            ntds = []
+            for td, next_td in zip(tensordict.unbind(0), next_tensordict.unbind(0)):
+                ntds.append(self._step(td, next_td))
+            next_tensordict.update(lazy_stack(ntds))
+            return next_tensordict
+
+        # Check if we should add the thinking prompt
+        if self.cond(next_tensordict):
+            history: History = next_tensordict["history"]
+            last_turn = history[..., -1]
+
+            if self.edit_last_turn:
+                # Edit the last assistant response
+                content = last_turn.content
+                modified_content = self._replace_answer_with_prompt(content)
+
+                # Create new history entry with modified content
+                new_turn = History(
+                    role="assistant",
+                    content=modified_content,
+                    batch_size=last_turn.batch_size,
+                    device=last_turn.device,
+                )
+
+                # Replace the last turn in history
+                history = history[..., :-1].append(new_turn)
+                next_tensordict["history"] = history
+
+            else:
+                # Add a new message
+                prompt = self.prompt
+
+                history = history.append(History(role=self.role, content=prompt))
+                next_tensordict["history"] = history
+
+            if self.undo_done:
+                parent: EnvBase = self.parent
+                if parent is not None:
+                    done_keys = parent.done_keys
+                    for key in done_keys:
+                        done = next_tensordict.get(key)
+                        if done is not None:
+                            next_tensordict.set(key, done.zero_())
+
+            # Zero out reward if requested
+            if self.zero_reward:
+                parent: EnvBase = self.parent
+                if parent is not None:
+                    reward_keys = parent.reward_keys
+                    for key in reward_keys:
+                        reward = next_tensordict.get(key)
+                        if reward is not None:
+                            next_tensordict.set(key, reward.zero_())
+        return next_tensordict
+
+    def _replace_answer_with_prompt(self, content: str) -> str:
+        """Replace the answer section with a thinking prompt.
+
+        This method uses regex to find and replace the <answer>...</answer> section
+        with the thinking prompt, preserving any content before the answer tag.
+
+        Args:
+            content: The original content string
+
+        Returns:
+            The modified content with the answer replaced by the thinking prompt
+        """
+        # Pattern to match <answer>...</answer> with optional EOS token
+        answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
+
+        # Check if there's an answer tag
+        if "<answer>" in content:
+            # Replace the answer section with the thinking prompt
+            prompt = self.prompt
+
+            # Replace the answer section
+            modified_content = re.sub(answer_pattern, prompt, content, flags=re.DOTALL)
+
+            # Clean up any trailing whitespace
+            modified_content = modified_content.rstrip()
+
+        else:
+            # No answer tag found, just append the prompt
+            prompt = self.prompt
+
+            modified_content = content.rstrip() + "\n\n" + prompt
+
+        return modified_content
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        """Reset the transform state.
+
+        Args:
+            tensordict: The current tensordict
+            tensordict_reset: The reset tensordict
+
+        Returns:
+            The reset tensordict
+        """
+        return tensordict_reset
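Combined with `GSM8KRewardParser(set_done_if_answer=True)` above, this transform closes the loop: the parser terminates the episode once an answer appears, and `AddThinkingPrompt` can re-open it (zeroing the reward and clearing `done`/`terminated`) when the answer looks wrong. A hedged sketch of a custom condition on top of the docstring example; the threshold and key are illustrative:

```python
# Sketch only: the reward threshold is an assumption, not fixed by the diff.
from torchrl.envs.llm.transforms import AddThinkingPrompt

def needs_retry(td) -> bool:
    # Re-open the episode when the scalar reward is below 1.0, i.e. the
    # parsed answer was not fully correct.
    return bool((td["reward"] < 1.0).any())

transform = AddThinkingPrompt(
    cond=needs_retry,
    random_prompt=True,   # rotate through DEFAULT_PROMPTS
    edit_last_turn=True,  # rewrite the last assistant turn in place
    undo_done=True,       # clear done/terminated so the rollout continues
)
```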
torchrl/envs/transforms/transforms.py CHANGED
@@ -738,7 +738,7 @@ class Transform(nn.Module):
         self.__dict__.update(state)

     @property
-    def parent(self) -> EnvBase | None:
+    def parent(self) -> TransformedEnv | None:
         """Returns the parent env of the transform.

         The parent env is the env that contains all the transforms up until the current one.
@@ -1249,6 +1249,7 @@ but got an object of type {type(transform)}."""
     def empty_cache(self):
         self.__dict__["_output_spec"] = None
         self.__dict__["_input_spec"] = None
+        self.transform.empty_cache()
         super().empty_cache()

     def append_transform(
@@ -1429,6 +1430,50 @@ class Compose(Transform):
         for t in transforms:
             t.set_container(self)

+    def pop(self, index: int | None = None) -> Transform:
+        """Pop a transform from the chain.
+
+        Args:
+            index (int, optional): The index of the transform to pop. If None, the last transform is popped.
+
+        Returns:
+            The popped transform.
+        """
+        if index is None:
+            index = len(self.transforms) - 1
+        result = self.transforms.pop(index)
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+        return result
+
+    def __delitem__(self, index: int | slice | list):
+        """Delete a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` or callable are accepted.
+        """
+        del self.transforms[index]
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
+    def __setitem__(
+        self,
+        index: int | slice | list,
+        value: Transform | Callable[[TensorDictBase], TensorDictBase],
+    ):
+        """Set a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` or callable are accepted.
+        """
+        self.transforms[index] = value
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
     def close(self):
         """Close the transform."""
         for t in self.transforms:
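These three methods give `Compose` list-like mutation, with each operation invalidating the cached specs of the parent env. A short sketch using two stock transforms (any `Transform` instances would do):

```python
# Illustrative sketch of the new Compose mutation API.
from torchrl.envs.transforms import Compose, RewardSum, StepCounter

chain = Compose(StepCounter(max_steps=100), RewardSum())

# __setitem__: swap a transform in place (spec caches are invalidated).
chain[0] = StepCounter(max_steps=10)

# pop(): remove and return a transform; defaults to the last one.
reward_sum = chain.pop()

# __delitem__: delete by index, slice, or list of indices.
del chain[0]
assert len(chain.transforms) == 0
```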
@@ -1594,6 +1639,9 @@ class Compose(Transform):
         else:
             self.transforms.append(transform)
         transform.set_container(self)
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()

     def set_container(self, container: Transform | EnvBase) -> None:
         self.reset_parent()
@@ -1626,6 +1674,9 @@ class Compose(Transform):

         # empty cache of all transforms to reset parents and specs
         self.empty_cache()
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()
         if index < 0:
             index = index + len(self.transforms)
         transform.eval()
torchrl/objectives/ppo.py CHANGED
@@ -752,10 +752,10 @@ class PPOLoss(LossModule):

         explained_variance = None
         if self.log_explained_variance:
-            with torch.no_grad(): # <‑‑ break grad‐flow
-                tgt = target_return.detach()
-                pred = state_value.detach()
-                eps = torch.finfo(tgt.dtype).eps
+            with torch.no_grad(): # <‑‑ break grad‐flow
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
                 resid = torch.var(tgt - pred, unbiased=False, dim=0)
                 total = torch.var(tgt, unbiased=False, dim=0)
                 explained_variance = 1.0 - resid / (total + eps)
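The hunk above is a whitespace-only cleanup, but the metric it touches is worth spelling out: explained variance, 1 - Var(returns - values) / Var(returns), computed under `no_grad` so that logging never contributes gradients. A standalone sketch of the same computation:

```python
import torch

def explained_variance(target_return: torch.Tensor, state_value: torch.Tensor) -> torch.Tensor:
    # 1.0 means the critic predicts returns perfectly; values <= 0 mean it
    # does no better than predicting the mean return.
    with torch.no_grad():
        tgt = target_return.detach()
        pred = state_value.detach()
        eps = torch.finfo(tgt.dtype).eps
        resid = torch.var(tgt - pred, unbiased=False, dim=0)
        total = torch.var(tgt, unbiased=False, dim=0)
        return 1.0 - resid / (total + eps)

ev = explained_variance(torch.randn(64, 1), torch.randn(64, 1))
```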
@@ -819,7 +819,9 @@ class PPOLoss(LossModule):
         td_out.set("entropy", entropy.detach().mean())  # for logging
         td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1189,7 +1191,9 @@ class ClipPPOLoss(PPOLoss):
         td_out.set("entropy", entropy.detach().mean())  # for logging
         td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1537,7 +1541,9 @@ class KLPENPPOLoss(PPOLoss):
         td_out.set("entropy", entropy.detach().mean())  # for logging
         td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict_copy
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
torchrl/version.py CHANGED
@@ -1,2 +1,2 @@
-__version__ = '2025.6.21'
-git_version = '77dbc6c9ffbce3d2ce3f26b659355cd46d8132c3'
+__version__ = '2025.6.23'
+git_version = 'ed051bc3e5b33d00f64f2a785023bca9a6c72c9b'
torchrl_nightly-2025.6.23.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: torchrl-nightly
-Version: 2025.6.21
+Version: 2025.6.23
 Home-page: https://github.com/pytorch/rl
 Author: torchrl contributors
 Author-email: vmoens@fb.com
@@ -18,119 +18,129 @@ Classifier: License :: OSI Approved :: BSD License
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch >=2.1.0
+Requires-Dist: torch>=2.1.0
 Requires-Dist: numpy
 Requires-Dist: packaging
 Requires-Dist: cloudpickle
 Requires-Dist: tensordict-nightly
-Provides-Extra: all
-Requires-Dist: accelerate ; extra == 'all'
-Requires-Dist: datasets ; extra == 'all'
-Requires-Dist: dm-meltingpot ; extra == 'all'
-Requires-Dist: dm-control ; extra == 'all'
-Requires-Dist: einops ; extra == 'all'
-Requires-Dist: git ; extra == 'all'
-Requires-Dist: gymnasium <1.0 ; extra == 'all'
-Requires-Dist: gymnasium[atari] ; extra == 'all'
-Requires-Dist: h5py ; extra == 'all'
-Requires-Dist: huggingface-hub ; extra == 'all'
-Requires-Dist: hydra-core >=1.1 ; extra == 'all'
-Requires-Dist: hydra-submitit-launcher ; extra == 'all'
-Requires-Dist: immutabledict ; extra == 'all'
-Requires-Dist: langdetect ; extra == 'all'
-Requires-Dist: minari ; extra == 'all'
-Requires-Dist: moviepy <2.0.0 ; extra == 'all'
-Requires-Dist: mujoco ; extra == 'all'
-Requires-Dist: nltk ; extra == 'all'
-Requires-Dist: open-spiel >=1.5 ; extra == 'all'
-Requires-Dist: pandas ; extra == 'all'
-Requires-Dist: pettingzoo >=1.24.1 ; extra == 'all'
-Requires-Dist: pillow ; extra == 'all'
-Requires-Dist: playwright ; extra == 'all'
-Requires-Dist: protobuf ; extra == 'all'
-Requires-Dist: pytest ; extra == 'all'
-Requires-Dist: pytest-asyncio ; extra == 'all'
-Requires-Dist: pytest-benchmark ; extra == 'all'
-Requires-Dist: pytest-cov ; extra == 'all'
-Requires-Dist: pytest-error-for-skips ; extra == 'all'
-Requires-Dist: pytest-instafail ; extra == 'all'
-Requires-Dist: pytest-mock ; extra == 'all'
-Requires-Dist: pytest-rerunfailures ; extra == 'all'
-Requires-Dist: pyyaml ; extra == 'all'
-Requires-Dist: requests ; extra == 'all'
-Requires-Dist: safetensors ; extra == 'all'
-Requires-Dist: scikit-learn ; extra == 'all'
-Requires-Dist: scipy ; extra == 'all'
-Requires-Dist: sentencepiece ; extra == 'all'
-Requires-Dist: tensorboard ; extra == 'all'
-Requires-Dist: torch >=2.7.0 ; extra == 'all'
-Requires-Dist: torchsnapshot ; extra == 'all'
-Requires-Dist: torchvision ; extra == 'all'
-Requires-Dist: tqdm ; extra == 'all'
-Requires-Dist: transformers ; extra == 'all'
-Requires-Dist: vllm ; extra == 'all'
-Requires-Dist: vmas >=1.2.10 ; extra == 'all'
-Requires-Dist: wandb ; extra == 'all'
 Provides-Extra: atari
-Requires-Dist: gymnasium[atari] ; extra == 'atari'
-Provides-Extra: checkpointing
-Requires-Dist: torchsnapshot ; extra == 'checkpointing'
-Provides-Extra: dm_control
-Requires-Dist: dm-control ; extra == 'dm_control'
-Provides-Extra: gym_continuous
-Requires-Dist: gymnasium <1.0 ; extra == 'gym_continuous'
-Requires-Dist: mujoco ; extra == 'gym_continuous'
-Provides-Extra: llm
-Requires-Dist: transformers ; extra == 'llm'
-Requires-Dist: vllm ; extra == 'llm'
-Requires-Dist: playwright ; extra == 'llm'
-Requires-Dist: datasets ; extra == 'llm'
-Requires-Dist: langdetect ; extra == 'llm'
-Requires-Dist: nltk ; extra == 'llm'
-Requires-Dist: immutabledict ; extra == 'llm'
-Requires-Dist: accelerate ; extra == 'llm'
-Requires-Dist: sentencepiece ; extra == 'llm'
-Requires-Dist: protobuf ; extra == 'llm'
-Requires-Dist: einops ; extra == 'llm'
-Requires-Dist: safetensors ; extra == 'llm'
-Provides-Extra: marl
-Requires-Dist: vmas >=1.2.10 ; extra == 'marl'
-Requires-Dist: pettingzoo >=1.24.1 ; extra == 'marl'
-Requires-Dist: dm-meltingpot ; extra == 'marl'
-Provides-Extra: offline-data
-Requires-Dist: huggingface-hub ; extra == 'offline-data'
-Requires-Dist: minari ; extra == 'offline-data'
-Requires-Dist: requests ; extra == 'offline-data'
-Requires-Dist: tqdm ; extra == 'offline-data'
-Requires-Dist: torchvision ; extra == 'offline-data'
-Requires-Dist: scikit-learn ; extra == 'offline-data'
-Requires-Dist: pandas ; extra == 'offline-data'
-Requires-Dist: h5py ; extra == 'offline-data'
-Requires-Dist: pillow ; extra == 'offline-data'
-Provides-Extra: open_spiel
-Requires-Dist: open-spiel >=1.5 ; extra == 'open_spiel'
+Requires-Dist: gymnasium[atari]; extra == "atari"
+Provides-Extra: dm-control
+Requires-Dist: dm_control; extra == "dm-control"
+Provides-Extra: replay-buffer
+Requires-Dist: torch>=2.7.0; extra == "replay-buffer"
+Provides-Extra: gym-continuous
+Requires-Dist: gymnasium<1.0; extra == "gym-continuous"
+Requires-Dist: mujoco; extra == "gym-continuous"
 Provides-Extra: rendering
-Requires-Dist: moviepy <2.0.0 ; extra == 'rendering'
-Provides-Extra: replay_buffer
-Requires-Dist: torch >=2.7.0 ; extra == 'replay_buffer'
+Requires-Dist: moviepy<2.0.0; extra == "rendering"
 Provides-Extra: tests
-Requires-Dist: pytest ; extra == 'tests'
-Requires-Dist: pyyaml ; extra == 'tests'
-Requires-Dist: pytest-instafail ; extra == 'tests'
-Requires-Dist: scipy ; extra == 'tests'
-Requires-Dist: pytest-mock ; extra == 'tests'
-Requires-Dist: pytest-cov ; extra == 'tests'
-Requires-Dist: pytest-asyncio ; extra == 'tests'
-Requires-Dist: pytest-benchmark ; extra == 'tests'
-Requires-Dist: pytest-rerunfailures ; extra == 'tests'
-Requires-Dist: pytest-error-for-skips ; extra == 'tests'
+Requires-Dist: pytest; extra == "tests"
+Requires-Dist: pyyaml; extra == "tests"
+Requires-Dist: pytest-instafail; extra == "tests"
+Requires-Dist: scipy; extra == "tests"
+Requires-Dist: pytest-mock; extra == "tests"
+Requires-Dist: pytest-cov; extra == "tests"
+Requires-Dist: pytest-asyncio; extra == "tests"
+Requires-Dist: pytest-benchmark; extra == "tests"
+Requires-Dist: pytest-rerunfailures; extra == "tests"
+Requires-Dist: pytest-error-for-skips; extra == "tests"
 Provides-Extra: utils
-Requires-Dist: tensorboard ; extra == 'utils'
-Requires-Dist: wandb ; extra == 'utils'
-Requires-Dist: tqdm ; extra == 'utils'
-Requires-Dist: hydra-core >=1.1 ; extra == 'utils'
-Requires-Dist: hydra-submitit-launcher ; extra == 'utils'
-Requires-Dist: git ; extra == 'utils'
+Requires-Dist: tensorboard; extra == "utils"
+Requires-Dist: wandb; extra == "utils"
+Requires-Dist: tqdm; extra == "utils"
+Requires-Dist: hydra-core>=1.1; extra == "utils"
+Requires-Dist: hydra-submitit-launcher; extra == "utils"
+Requires-Dist: git; extra == "utils"
+Provides-Extra: checkpointing
+Requires-Dist: torchsnapshot; extra == "checkpointing"
+Provides-Extra: offline-data
+Requires-Dist: huggingface_hub; extra == "offline-data"
+Requires-Dist: minari; extra == "offline-data"
+Requires-Dist: requests; extra == "offline-data"
+Requires-Dist: tqdm; extra == "offline-data"
+Requires-Dist: torchvision; extra == "offline-data"
+Requires-Dist: scikit-learn; extra == "offline-data"
+Requires-Dist: pandas; extra == "offline-data"
+Requires-Dist: h5py; extra == "offline-data"
+Requires-Dist: pillow; extra == "offline-data"
+Provides-Extra: marl
+Requires-Dist: vmas>=1.2.10; extra == "marl"
+Requires-Dist: pettingzoo>=1.24.1; extra == "marl"
+Requires-Dist: dm-meltingpot; extra == "marl"
+Provides-Extra: open-spiel
+Requires-Dist: open_spiel>=1.5; extra == "open-spiel"
+Provides-Extra: llm
+Requires-Dist: transformers; extra == "llm"
+Requires-Dist: vllm; extra == "llm"
+Requires-Dist: playwright; extra == "llm"
+Requires-Dist: datasets; extra == "llm"
+Requires-Dist: langdetect; extra == "llm"
+Requires-Dist: nltk; extra == "llm"
+Requires-Dist: immutabledict; extra == "llm"
+Requires-Dist: accelerate; extra == "llm"
+Requires-Dist: sentencepiece; extra == "llm"
+Requires-Dist: protobuf; extra == "llm"
+Requires-Dist: einops; extra == "llm"
+Requires-Dist: safetensors; extra == "llm"
+Provides-Extra: all
+Requires-Dist: accelerate; extra == "all"
+Requires-Dist: datasets; extra == "all"
+Requires-Dist: dm-meltingpot; extra == "all"
+Requires-Dist: dm_control; extra == "all"
+Requires-Dist: einops; extra == "all"
+Requires-Dist: git; extra == "all"
+Requires-Dist: gymnasium<1.0; extra == "all"
+Requires-Dist: gymnasium[atari]; extra == "all"
+Requires-Dist: h5py; extra == "all"
+Requires-Dist: huggingface_hub; extra == "all"
+Requires-Dist: hydra-core>=1.1; extra == "all"
+Requires-Dist: hydra-submitit-launcher; extra == "all"
+Requires-Dist: immutabledict; extra == "all"
+Requires-Dist: langdetect; extra == "all"
+Requires-Dist: minari; extra == "all"
+Requires-Dist: moviepy<2.0.0; extra == "all"
+Requires-Dist: mujoco; extra == "all"
+Requires-Dist: nltk; extra == "all"
+Requires-Dist: open_spiel>=1.5; extra == "all"
+Requires-Dist: pandas; extra == "all"
+Requires-Dist: pettingzoo>=1.24.1; extra == "all"
+Requires-Dist: pillow; extra == "all"
+Requires-Dist: playwright; extra == "all"
+Requires-Dist: protobuf; extra == "all"
+Requires-Dist: pytest; extra == "all"
+Requires-Dist: pytest-asyncio; extra == "all"
+Requires-Dist: pytest-benchmark; extra == "all"
+Requires-Dist: pytest-cov; extra == "all"
+Requires-Dist: pytest-error-for-skips; extra == "all"
+Requires-Dist: pytest-instafail; extra == "all"
+Requires-Dist: pytest-mock; extra == "all"
+Requires-Dist: pytest-rerunfailures; extra == "all"
+Requires-Dist: pyyaml; extra == "all"
+Requires-Dist: requests; extra == "all"
+Requires-Dist: safetensors; extra == "all"
+Requires-Dist: scikit-learn; extra == "all"
+Requires-Dist: scipy; extra == "all"
+Requires-Dist: sentencepiece; extra == "all"
+Requires-Dist: tensorboard; extra == "all"
+Requires-Dist: torch>=2.7.0; extra == "all"
+Requires-Dist: torchsnapshot; extra == "all"
+Requires-Dist: torchvision; extra == "all"
+Requires-Dist: tqdm; extra == "all"
+Requires-Dist: transformers; extra == "all"
+Requires-Dist: vllm; extra == "all"
+Requires-Dist: vmas>=1.2.10; extra == "all"
+Requires-Dist: wandb; extra == "all"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: license-file
+Dynamic: provides-extra
+Dynamic: requires-dist

 [![Unit-tests](https://github.com/pytorch/rl/actions/workflows/test-linux.yml/badge.svg)](https://github.com/pytorch/rl/actions/workflows/test-linux.yml)
 [![Documentation](https://img.shields.io/badge/Documentation-blue.svg)](https://pytorch.org/rl/)
torchrl_nightly-2025.6.23.dist-info/RECORD CHANGED
@@ -3,9 +3,9 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
 build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
 torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
 torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
-torchrl/_torchrl.cpython-312-darwin.so,sha256=TU6MyMJjef82BMjWzfoA-jerlMeKCo4x02TNL8y-IrU,1691072
+torchrl/_torchrl.cpython-312-darwin.so,sha256=2G08KUB1lgllQaOiK_aiC-u0bm9sPqvrm5fF8LPT_bc,1691072
 torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
-torchrl/version.py,sha256=9TpIavFD2hzZlpmXpm_tjHh5avX5AXfLXuNk_r1S5wc,83
+torchrl/version.py,sha256=4YeSUDGHrB3YeBHYcVaU3pmlvqaCzwjkSwnDvVAiGUQ,83
 torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
 torchrl/collectors/collectors.py,sha256=CdTerIwhCTr6n5OoJLNad0bNQ5OLliPZFWkU18QBKSA,177625
 torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
@@ -97,28 +97,29 @@ torchrl/envs/libs/smacv2.py,sha256=i0TRHuZ9S9v0NfufPgQAcTlvAjf6JKv8hHvOzjSgsaw,2
 torchrl/envs/libs/unity_mlagents.py,sha256=Z3qSU0H3o2NXbS2lNvQ7OmYxkr3AWAMyRHfxeCtNZrk,49667
 torchrl/envs/libs/utils.py,sha256=RgiR16KJWFEtQim44-AIcHByGTq_NrtpjWoYIC13aYA,5207
 torchrl/envs/libs/vmas.py,sha256=a71_jU4r627hFXcMsT5wNSb4TMpyd3punLdOF3Cc8O0,36297
-torchrl/envs/llm/__init__.py,sha256=o8uAVGHYngy_k6xM5qIkqgHaz__S1HyG7QjLd78gtaA,1265
-torchrl/envs/llm/chat.py,sha256=mVLjmBTwd6IWdlKJMRcynDJNVVbiHjCop5EVUXpaaAA,17794
+torchrl/envs/llm/__init__.py,sha256=HGpJZYZHR3tJVZ0EKq-Zh2r715JSH_H82PNa1z8F9V0,1313
+torchrl/envs/llm/chat.py,sha256=2j1S1-_EC52_RpIdN4Gy6_mkxBG8aXL8Yo93SQ2YRIM,18201
 torchrl/envs/llm/envs.py,sha256=Er-ahjgvtYG4LB7_EWOMbdobiUV5DOHPBQYkVTu80r4,34677
 torchrl/envs/llm/datasets/__init__.py,sha256=FFethtv8unJWzphGLPQVC5QD9NMdaygEjx25O1DHHZk,473
-torchrl/envs/llm/datasets/gsm8k.py,sha256=wTntpV-bi0gbyvJ-JnuHQmPXjXgV4hEssGFed8GRGGc,15299
-torchrl/envs/llm/datasets/ifeval.py,sha256=fVbMSVjpnlZR36B0yDUgDcM1Ye-EP6ui7g9nPRHX_vc,8327
+torchrl/envs/llm/datasets/gsm8k.py,sha256=pAOWJh8ArCvTdOKWmr7bQb4o6Hqpoq6PjS0h9HAaRDE,15475
+torchrl/envs/llm/datasets/ifeval.py,sha256=dzvSgOgqVxogFq0rC8O1SMqPfjQD8NEAGCGg5LnmXiU,8472
 torchrl/envs/llm/libs/__init__.py,sha256=vhEm5Fhz1sLWt107zfZLy5pzGmfQi0fNBGazTq1m7dU,266
 torchrl/envs/llm/libs/mlgym.py,sha256=ECnkrNoPV73L1fIO05SlTTXuTSNOM2pdX6aJcEYJVlo,31372
 torchrl/envs/llm/reward/__init__.py,sha256=a-Xsye29z2LugO1cOCFM2FNsqNwEp-5XwQk4saVQlu8,370
-torchrl/envs/llm/reward/gsm8k.py,sha256=2pUXYkCw6_arM6HCZJcrEYwRZMDntsFAzdpf3QXNthI,7862
+torchrl/envs/llm/reward/gsm8k.py,sha256=GYd0l_YRaIiivZBLRGjhJeQiFj6jm-BUh9T3pEze3a8,8760
 torchrl/envs/llm/reward/ifeval/__init__.py,sha256=g5NtrwfwqK22hRcoIdz8-KWBh5Ogre9J-Bf3uGWE9Pg,314
 torchrl/envs/llm/reward/ifeval/_instructions.py,sha256=rAoTdwG42smCLJgwW7kAwJrNonjIS6OwdohDE70oMOA,61696
 torchrl/envs/llm/reward/ifeval/_instructions_main.py,sha256=CofKXvG0J2H-1ZXP1fL6UZI8ArNCIO2w5R_37drRIW8,4117
 torchrl/envs/llm/reward/ifeval/_instructions_registry.py,sha256=3_guc8LZ0mWQc-n6E4cQgYMgZRYa6xfgvXgrze9aO_w,3814
 torchrl/envs/llm/reward/ifeval/_instructions_util.py,sha256=aA3fupO8MvqBCqD7Y_Qk6y32toWF1lZGAflWON1ruXM,26042
 torchrl/envs/llm/reward/ifeval/_scorer.py,sha256=zJHBgaGlluEv6czsI6ZtLqArV_J_W9zY7UPAJhT5YIo,14563
-torchrl/envs/llm/transforms/__init__.py,sha256=roEOZVFOs1PhC1cGF-LIXQt5DlXZx6mgIJ-1k0JDTfI,788
+torchrl/envs/llm/transforms/__init__.py,sha256=PNwdol9ItWXPfzyKcf4Id7Yu6oKFFAtta2J78ksSrf0,851
 torchrl/envs/llm/transforms/browser.py,sha256=zF7jHHHrdpxUCjFFtiYK-vhw-p1YqsqwP8_b4SiK0Rs,10423
-torchrl/envs/llm/transforms/dataloading.py,sha256=dv4IV3OWEa6-evxBk3WAZjkBi1_yKUs2NQ2gGmL2lKQ,24533
+torchrl/envs/llm/transforms/dataloading.py,sha256=4P-e5yjUdxRtfaOmMxtNRisJLLtCqurAhWAqV7GiXHI,24872
 torchrl/envs/llm/transforms/format.py,sha256=ESn0S9k5G4FQPBICq9h6ZsLKXZqiU71tYW8UnW4rgLI,2519
 torchrl/envs/llm/transforms/kl.py,sha256=N68378chSx54X5a7YLJzIV6d870H5xrBb5-qWqzpX1U,22744
 torchrl/envs/llm/transforms/policy_version.py,sha256=by2TjsZLwVjQbq7ggBoAco2Iq_2aEYgyxh9asTXL1vk,6893
+torchrl/envs/llm/transforms/reason.py,sha256=Q3LRbl7QmatRfAt7bOjOw_aLuZJgqRZvmKwT67cWX7s,10561
 torchrl/envs/llm/transforms/tokenizer.py,sha256=CcuKRu33YnyDgLtQtyxTGDFC6iI3b3fUA6Nb1Lnh7h8,13953
 torchrl/envs/llm/transforms/tools.py,sha256=I-HR0zjH4tFMp9xPH556H5Q5JqmqXdsAXwElAR93e5U,29498
 torchrl/envs/model_based/__init__.py,sha256=AkgZvTP5AerIg6ZwXfCfk3bnSr01hlwZWDiRd3UjBE4,331
@@ -131,7 +132,7 @@ torchrl/envs/transforms/llm.py,sha256=rQDzuut807wvFpSPCm5tynt8-cMKTgVKVjSVu9D99P
 torchrl/envs/transforms/r3m.py,sha256=sdTVLpnxHfzFVo5rO8WnXf2uUg9cr4LBOLBsWaFgGT8,13478
 torchrl/envs/transforms/rb_transforms.py,sha256=pxtL1VHvzEq6djuWsccLu4P-tnbAKsavemLGyt80I6c,7448
 torchrl/envs/transforms/rlhf.py,sha256=lOVXYqQaoDfm4_n77Dxw_wjicBpMtDvavKmBIK2N3lU,628
-torchrl/envs/transforms/transforms.py,sha256=59WHIbGryXTSvswHxvQSxHAza1k5-qtxwfWRzd6MQ6M,479710
+torchrl/envs/transforms/transforms.py,sha256=8aXDl-NfugfqlBK-FcPBKYuDU-oIXeabW3uIXZ6QMik,481272
 torchrl/envs/transforms/utils.py,sha256=VXGH69Jxdmnw5eP9L3uM8ronQA5aIbT-Ktpjn5Frds0,2058
 torchrl/envs/transforms/vc1.py,sha256=mho5BvdAK-f9hD9t-iah52wT2B06qPmaJO7chrfIOWY,10534
 torchrl/envs/transforms/vecnorm.py,sha256=XahMcWvK3zjOB6EACSZtJ6UMP3yQ2zD9xf87UEB37Eg,34047
@@ -188,7 +189,7 @@ torchrl/objectives/dreamer.py,sha256=vIJQN91oPXYnPubDFQpaF5d3fR_WwIYuIVYtoCvw0TY
 torchrl/objectives/functional.py,sha256=ZaglBjEGuOTNGeFA-Ox-ugZVcNegQMUj--KWHDRBmaU,2106
 torchrl/objectives/gail.py,sha256=0m34XmcN-EDk5OfNIo5bKYbKKZfATsYRv4zQe3v2UwA,9576
 torchrl/objectives/iql.py,sha256=1jvlSznWke6NZSwfuYyHVnVBE7Cz3q169GnCRC7iel4,42991
-torchrl/objectives/ppo.py,sha256=qoG7YiHHz6M5jn3XgtE32AmMERianoZqs-lSHQA35Rg,75284
+torchrl/objectives/ppo.py,sha256=4fzV-DSFSGv0VHrI0YCk0EBUB35gkuWyo3j_4KhSoqE,75340
 torchrl/objectives/redq.py,sha256=4usM-nG2UWujeL-VEqzf7-uOwRFx6itkKCeitKuJhtw,28507
 torchrl/objectives/reinforce.py,sha256=ySXLp5C-OOUYayqjrf4taQmL8LgRvMgPCgHDsle8JDc,22339
 torchrl/objectives/sac.py,sha256=Oq9Iq90s9KFbnM4KSRUd2onU1JfW6aW80LWGdtO0CY8,63993
@@ -223,8 +224,8 @@ torchrl/trainers/helpers/losses.py,sha256=qH-2YJwMtDAYAPXTTYy3cOPiq4ILC6xTjfnGUU
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.6.21.dist-info/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
-torchrl_nightly-2025.6.21.dist-info/METADATA,sha256=FbFAW_HINLwiA_5Vi6WG31aQU6K9088TRaz-QcHO5nA,39023
-torchrl_nightly-2025.6.21.dist-info/WHEEL,sha256=3K-ZUOK4xUOAXNNICzKF-g_5h4y1OCqLtypLLrsO4lc,115
-torchrl_nightly-2025.6.21.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
-torchrl_nightly-2025.6.21.dist-info/RECORD,,
+torchrl_nightly-2025.6.23.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.6.23.dist-info/METADATA,sha256=Akc3RKlo_nIxX-wV8Dm44Cgrmskwm90hWXd6FV5xJe0,39131
+torchrl_nightly-2025.6.23.dist-info/WHEEL,sha256=9_3tTSxMJq-dgdzMiScNvtT5eTBVd3l6RgHS7HwTzpA,115
+torchrl_nightly-2025.6.23.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+torchrl_nightly-2025.6.23.dist-info/RECORD,,
torchrl_nightly-2025.6.23.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (70.2.0)
+Generator: setuptools (78.1.0)
 Root-Is-Purelib: false
 Tag: cp312-cp312-macosx_10_13_universal2