torchrl-nightly 2025.6.21__cp312-cp312-macosx_10_13_universal2.whl → 2025.6.23__cp312-cp312-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cpython-312-darwin.so +0 -0
- torchrl/envs/llm/__init__.py +2 -0
- torchrl/envs/llm/chat.py +14 -1
- torchrl/envs/llm/datasets/gsm8k.py +4 -1
- torchrl/envs/llm/datasets/ifeval.py +3 -0
- torchrl/envs/llm/reward/gsm8k.py +24 -2
- torchrl/envs/llm/transforms/__init__.py +2 -0
- torchrl/envs/llm/transforms/dataloading.py +12 -0
- torchrl/envs/llm/transforms/reason.py +260 -0
- torchrl/envs/transforms/transforms.py +52 -1
- torchrl/objectives/ppo.py +13 -7
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info}/METADATA +117 -107
- {torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info}/RECORD +17 -16
- {torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info}/WHEEL +1 -1
- {torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info/licenses}/LICENSE +0 -0
- {torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info}/top_level.txt +0 -0
torchrl/_torchrl.cpython-312-darwin.so
CHANGED
Binary file
torchrl/envs/llm/__init__.py
CHANGED
@@ -15,6 +15,7 @@ from .envs import LLMEnv, LLMHashingEnv
 from .libs import make_mlgym, MLGymWrapper
 from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
 from .transforms import (
+    AddThinkingPrompt,
     as_nested_tensor,
     as_padded_tensor,
     BrowserTransform,
@@ -33,6 +34,7 @@ __all__ = [
     "ChatEnv",
     "DataLoadingPrimer",
     "DatasetChatEnv",
+    "AddThinkingPrompt",
     "GSM8KEnv",
     "GSM8KPrepareQuestion",
     "GSM8KRewardParser",
torchrl/envs/llm/chat.py
CHANGED
@@ -284,6 +284,7 @@ class DatasetChatEnv(TransformedEnv):
 
     Keyword Args:
         dataset (str): The name of the dataset.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         name (str, optional): name of the dataset configuration.
         split (str, optional): the split to use (usually from `"train"`, `"val"` or `"test"`). Defaults to `None` (no split).
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
@@ -317,6 +318,7 @@ class DatasetChatEnv(TransformedEnv):
         self,
         *,
         dataset: str,
+        shuffle: bool = True,
         name: str | None = None,
         split: Literal["train", "val", "test"] | None = None,
         num_envs: int = 1,
@@ -355,7 +357,7 @@ class DatasetChatEnv(TransformedEnv):
         dataloader = DataLoader(  # noqa: TOR401
             dataset,
             batch_size=batch_size_dl,
-            shuffle=
+            shuffle=shuffle,
             collate_fn=collate_fn,
             generator=generator,
         )
@@ -375,3 +377,14 @@ class DatasetChatEnv(TransformedEnv):
             apply_template=apply_template,
         )
         return super().__init__(env_base, primer)
+
+    def reset_dataloader(self):
+        """Reset the dataloader.
+
+        This is useful when the dataloader is not infinite and we want to reset it.
+
+        Returns:
+            self: The environment itself.
+        """
+        self.transform[0].reset_dataloader()
+        return self
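A minimal usage sketch (not part of the diff) tying the two additions together, assuming `transformers` and `datasets` are installed and that `GSM8KEnv` forwards its keyword arguments to `DatasetChatEnv` as shown in the next file:

from transformers import AutoTokenizer
from torchrl.envs.llm import GSM8KEnv

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")

# shuffle=False keeps the dataset order deterministic, e.g. for evaluation sweeps.
env = GSM8KEnv(tokenizer=tokenizer, shuffle=False, num_envs=2)
td = env.reset()

# If the (finite) dataloader gets exhausted, rewind it and keep sampling.
env.reset_dataloader()
td = env.reset()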
torchrl/envs/llm/datasets/gsm8k.py
CHANGED
@@ -135,6 +135,7 @@ class GSM8KEnv(DatasetChatEnv):
 
     Keyword Args:
         dataset (str, optional): The name of the dataset. Defaults to `"gsm8k"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
@@ -284,12 +285,13 @@ class GSM8KEnv(DatasetChatEnv):
     SYSTEM_PROMPT = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think></think> and <answer></answer> tags, respectively,
-i.e., <think>reasoning process here</think> <answer>answer here</answer>."""
+i.e., <think>reasoning process here</think> <answer>answer here</answer>. The answer should be a number."""
 
     def __init__(
         self,
         *,
         dataset: str = "gsm8k",
+        shuffle: bool = True,
         num_envs: int = 1,
         repeats: int | None = None,
         batch_size_dl: int = 1,
@@ -307,6 +309,7 @@ i.e., <think>reasoning process here</think> <answer>answer here</answer>."""
         collate_fn = _collate_fn
         super().__init__(
             dataset=dataset,
+            shuffle=shuffle,
             name="main",
             num_envs=num_envs,
             repeats=repeats,
torchrl/envs/llm/datasets/ifeval.py
CHANGED
@@ -41,6 +41,7 @@ class IFEvalEnv(DatasetChatEnv):
 
     Keyword Args:
         dataset (str, optional): The name of the dataset. Defaults to `"google/IFeval"`.
+        shuffle (bool, optional): Whether to shuffle the dataset. Defaults to `True`.
         num_envs (int, optional): The number of environments to create. Defaults to `1`.
         repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
             based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
@@ -146,6 +147,7 @@ You will be assessed by the content of the answer block only, so make sure it co
         self,
         *,
         dataset: str = "google/IFeval",
+        shuffle: bool = True,
         num_envs: int = 1,
         repeats: int | None = None,
         batch_size_dl: int = 1,
@@ -163,6 +165,7 @@ You will be assessed by the content of the answer block only, so make sure it co
         collate_fn = _collate_fn
         super().__init__(
             dataset=dataset,
+            shuffle=shuffle,
             num_envs=num_envs,
             repeats=repeats,
             batch_size_dl=batch_size_dl,
torchrl/envs/llm/reward/gsm8k.py
CHANGED
@@ -20,6 +20,7 @@ class GSM8KRewardParser(Transform):
         in_keys (list of NestedKey): the input keys. Defaults to `["text_response", "answer"]`.
         out_keys (list of NestedKey): the output keys. Defaults to `[ "reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"]`.
         eos_token (str): the end of sentence token. Defaults to `tokenizer.eos_token` if not provided.
+        set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
 
     """
 
@@ -29,10 +30,18 @@ class GSM8KRewardParser(Transform):
         in_keys: list[NestedKey] | None = None,
         out_keys: list[NestedKey] | None = None,
         eos_token: str | None = None,
+        set_done_if_answer: bool = True,
     ):
         super().__init__()
         self.tokenizer = tokenizer
-        self.eos_token =
+        self.eos_token = (
+            eos_token
+            if eos_token is not None
+            else tokenizer.eos_token
+            if tokenizer is not None
+            else None
+        )
+        self.set_done_if_answer = set_done_if_answer
         if in_keys is None:
             in_keys = ["text_response", "answer"]
         if not isinstance(in_keys, list) or len(in_keys) != 2:
@@ -118,7 +127,20 @@ class GSM8KRewardParser(Transform):
         tds = tds.add(
             next_td_exist, default=torch.zeros((), device=next_tensordict.device)
         )
-
+        next_tensordict = next_tensordict.update(tds)
+        if (
+            self.set_done_if_answer
+            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+        ):
+            done = next_tensordict.get("done")
+            if done is not None:
+                next_tensordict.set("done", reward_answer.view_as(done) | done)
+            terminated = next_tensordict.get("terminated")
+            if terminated is not None:
+                next_tensordict.set(
+                    "terminated", reward_answer.view_as(terminated) | terminated
+                )
+        return next_tensordict
 
     def transform_reward_spec(self, reward_spec: Composite) -> Composite:
         shape = reward_spec.shape + (1, 1)
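The `set_done_if_answer` flag added above ends the episode as soon as a rewarded <answer> block is parsed. A short sketch (not part of the diff; the tokenizer wiring is an assumption) of turning it off so that rollouts keep going, for instance to let AddThinkingPrompt ask for another attempt:

from transformers import AutoTokenizer
from torchrl.envs.llm.reward import GSM8KRewardParser

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
# With set_done_if_answer=False the parser still writes reward_answer/reward/success,
# but leaves the done and terminated flags untouched.
parser = GSM8KRewardParser(tokenizer=tokenizer, set_done_if_answer=False)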
torchrl/envs/llm/transforms/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
 from .format import TemplateTransform
 from .kl import KLRewardTransform, RetrieveLogProb
 from .policy_version import PolicyVersion
+from .reason import AddThinkingPrompt
 from .tokenizer import Tokenizer
 from .tools import MCPToolTransform, PythonInterpreter
 
@@ -19,6 +20,7 @@ __all__ = [
     "MCPToolTransform",
     "PolicyVersion",
     "PythonInterpreter",
+    "AddThinkingPrompt",
     "TemplateTransform",
     "Tokenizer",
     "as_nested_tensor",
torchrl/envs/llm/transforms/dataloading.py
CHANGED
@@ -447,6 +447,18 @@ class DataLoadingPrimer(TensorDictPrimer):
         )
         self._reset_key = "_reset"
 
+    def reset_dataloader(self):
+        """Reset the dataloader.
+
+        This is useful when the dataloader is not infinite and we want to reset it.
+
+        Returns:
+            self: The transform itself.
+        """
+        self._queue.clear()
+        self.endless_dataloader = self._endless_iter(self.dataloader)
+        return self
+
     @classmethod
     def _endless_iter(self, obj):
         while True:
torchrl/envs/llm/transforms/reason.py
ADDED
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import re
+from typing import Callable, Literal
+
+from tensordict import lazy_stack, TensorDictBase
+
+from torchrl.data.llm.chat import History
+from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
+
+
+class AddThinkingPrompt(Transform):
+    """A transform that adds thinking prompts to encourage the LLM to reconsider its response.
+
+    This transform can either add a new thinking prompt as a separate message or edit the last
+    assistant response to include a thinking prompt before the final answer. This is useful for
+    training LLMs to self-correct and think more carefully when their initial responses are
+    incorrect or incomplete.
+
+    Args:
+        cond (Callable[[TensorDictBase], bool], optional): Condition function that determines
+            when to add the thinking prompt. Takes a tensordict and returns `True` if the prompt
+            should be added.
+        prompt (str, optional): The thinking prompt to add. If None, a default prompt is used.
+            Defaults to `"But wait, let me think about this more carefully..."`.
+        random_prompt (bool, optional): Whether to randomly select from predefined prompts.
+            Defaults to `False`.
+        role (Literal["user", "assistant"], optional): The role for the thinking prompt.
+            If `"assistant"`, the prompt is added to the assistant's response. If `"user"`, it's
+            added as a separate user message. Defaults to `"assistant"`.
+        edit_last_turn (bool, optional): Whether to edit the last assistant response instead
+            of adding a new message. Only works with `role="assistant"`. Defaults to `True`.
+        zero_reward (bool, optional): Whether to zero out the reward when the thinking prompt
+            is added. If `None`, defaults to the value of `edit_last_turn`. Defaults to the same value as `edit_last_turn`.
+        undo_done (bool, optional): Whether to undo the done flag when the thinking prompt
+            is added. Defaults to `True`.
+
+    Examples:
+        >>> from torchrl.envs.llm.transforms import AddThinkingPrompt
+        >>> from torchrl.envs.llm import GSM8KEnv
+        >>> from transformers import AutoTokenizer
+        >>> import torch
+        >>>
+        >>> # Create environment with thinking prompt transform
+        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        >>> env = GSM8KEnv(tokenizer=tokenizer, max_steps=10)
+        >>> env = env.append_transform(
+        ...     AddThinkingPrompt(
+        ...         cond=lambda td: td["reward"] < 50,
+        ...         role="assistant",
+        ...         edit_last_turn=True,
+        ...         zero_reward=True,
+        ...         undo_done=True
+        ...     )
+        ... )
+        >>>
+        >>> # Test with wrong answer (low reward)
+        >>> reset = env.reset()
+        >>> wrong_answer = (
+        ...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
+        ...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+        ...     "To find the total, I need to add April and May: 48 + 24 = 72. "
+        ...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
+        ...     "<answer>322 clips</answer><|im_end|>"
+        ... )
+        >>> reset["text_response"] = [wrong_answer]
+        >>> s = env.step(reset)
+        >>> assert (s["next", "reward"] == 0).all()  # Reward zeroed
+        >>> assert (s["next", "done"] == 0).all()  # Done undone
+        >>> assert s["next", "history"].shape == (1, 3)  # History modified
+        >>>
+        >>> # Test with correct answer (high reward)
+        >>> reset = env.reset()
+        >>> correct_answer = (
+        ...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
+        ...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
+        ...     "To find the total, I need to add April and May: 48 + 24 = 72. "
+        ...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
+        ...     "<answer>72</answer><|im_end|>"
+        ... )
+        >>> reset["text_response"] = [correct_answer]
+        >>> s = env.step(reset)
+        >>> assert (s["next", "reward"] != 0).all()  # Reward not zeroed
+        >>> assert s["next", "done"].all()  # Done remains True
+        >>> assert s["next", "history"].shape == (1, 3)  # History unchanged
+    """
+
+    # Predefined thinking prompts
+    DEFAULT_PROMPTS = [
+        "But wait, let me think about this more carefully...",
+        "Actually, let me reconsider this...",
+        "Let me think about it step by step...",
+        "Wait, I need to double-check my reasoning...",
+        "Actually, let me think about it more carefully...",
+    ]
+
+    def __init__(
+        self,
+        cond: Callable[[TensorDictBase], bool],
+        prompt: str | None = None,
+        random_prompt: bool = False,
+        role: Literal["user", "assistant"] = "assistant",
+        edit_last_turn: bool = True,
+        zero_reward: bool | None = None,
+        undo_done: bool = True,
+    ) -> None:
+        super().__init__()
+
+        # Set the prompt
+        if prompt is None:
+            prompt = self.DEFAULT_PROMPTS[0]
+        self._prompt = prompt
+        self.random_prompt = random_prompt
+
+        # Set condition and role
+        self.cond = cond
+        self.role = role
+
+        # Validate edit_last_turn constraint
+        if edit_last_turn and role != "assistant":
+            raise ValueError("edit_last_turn can only be used with role='assistant'")
+        self.edit_last_turn = edit_last_turn
+
+        # Set zero_reward behavior
+        if zero_reward is None:
+            zero_reward = edit_last_turn
+        self.zero_reward = zero_reward
+        self.undo_done = undo_done
+
+    @property
+    def prompt(self) -> str:
+        if self.random_prompt:
+            import random
+
+            return random.choice(self.DEFAULT_PROMPTS)
+        return self._prompt
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        """Process the tensordict and add thinking prompts based on the condition.
+
+        Args:
+            tensordict: The current tensordict
+            next_tensordict: The next tensordict containing the most recent history and reward
+
+        Returns:
+            The modified next_tensordict
+        """
+        print("Reward", next_tensordict["reward"])
+        # Handle batch dimensions
+        if next_tensordict.batch_dims >= 1:
+            ntds = []
+            for td, next_td in zip(tensordict.unbind(0), next_tensordict.unbind(0)):
+                ntds.append(self._step(td, next_td))
+            next_tensordict.update(lazy_stack(ntds))
+            return next_tensordict
+
+        # Check if we should add the thinking prompt
+        if self.cond(next_tensordict):
+            history: History = next_tensordict["history"]
+            last_turn = history[..., -1]
+
+            if self.edit_last_turn:
+                # Edit the last assistant response
+                content = last_turn.content
+                modified_content = self._replace_answer_with_prompt(content)
+
+                # Create new history entry with modified content
+                new_turn = History(
+                    role="assistant",
+                    content=modified_content,
+                    batch_size=last_turn.batch_size,
+                    device=last_turn.device,
+                )
+
+                # Replace the last turn in history
+                history = history[..., :-1].append(new_turn)
+                next_tensordict["history"] = history
+
+            else:
+                # Add a new message
+                prompt = self.prompt
+
+                history = history.append(History(role=self.role, content=prompt))
+                next_tensordict["history"] = history
+
+            if self.undo_done:
+                parent: EnvBase = self.parent
+                if parent is not None:
+                    done_keys = parent.done_keys
+                    for key in done_keys:
+                        done = next_tensordict.get(key)
+                        if done is not None:
+                            next_tensordict.set(key, done.zero_())
+
+            # Zero out reward if requested
+            if self.zero_reward:
+                parent: EnvBase = self.parent
+                if parent is not None:
+                    reward_keys = parent.reward_keys
+                    for key in reward_keys:
+                        reward = next_tensordict.get(key)
+                        if reward is not None:
+                            next_tensordict.set(key, reward.zero_())
+        return next_tensordict
+
+    def _replace_answer_with_prompt(self, content: str) -> str:
+        """Replace the answer section with a thinking prompt.
+
+        This method uses regex to find and replace the <answer>...</answer> section
+        with the thinking prompt, preserving any content before the answer tag.
+
+        Args:
+            content: The original content string
+
+        Returns:
+            The modified content with the answer replaced by the thinking prompt
+        """
+        # Pattern to match <answer>...</answer> with optional EOS token
+        answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
+
+        # Check if there's an answer tag
+        if "<answer>" in content:
+            # Replace the answer section with the thinking prompt
+            prompt = self.prompt
+
+            # Replace the answer section
+            modified_content = re.sub(answer_pattern, prompt, content, flags=re.DOTALL)
+
+            # Clean up any trailing whitespace
+            modified_content = modified_content.rstrip()
+
+        else:
+            # No answer tag found, just append the prompt
+            prompt = self.prompt
+
+            modified_content = content.rstrip() + "\n\n" + prompt
+
+        return modified_content
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        """Reset the transform state.
+
+        Args:
+            tensordict: The current tensordict
+            tensordict_reset: The reset tensordict
+
+        Returns:
+            The reset tensordict
+        """
+        return tensordict_reset
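Beyond the `edit_last_turn` path exercised in the docstring above, the transform can also append the nudge as a separate user turn. A small sketch, not part of the diff, assuming `env` is the GSM8KEnv built as in the docstring example:

from torchrl.envs.llm.transforms import AddThinkingPrompt

# Append a randomly chosen reconsideration prompt as a new user message whenever the
# parsed reward stays low; reward and done flags are left untouched in this variant.
env = env.append_transform(
    AddThinkingPrompt(
        cond=lambda td: td["reward"] < 50,
        role="user",
        edit_last_turn=False,
        random_prompt=True,
        zero_reward=False,
        undo_done=False,
    )
)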
torchrl/envs/transforms/transforms.py
CHANGED
@@ -738,7 +738,7 @@ class Transform(nn.Module):
         self.__dict__.update(state)
 
     @property
-    def parent(self) ->
+    def parent(self) -> TransformedEnv | None:
         """Returns the parent env of the transform.
 
         The parent env is the env that contains all the transforms up until the current one.
@@ -1249,6 +1249,7 @@ but got an object of type {type(transform)}."""
     def empty_cache(self):
         self.__dict__["_output_spec"] = None
         self.__dict__["_input_spec"] = None
+        self.transform.empty_cache()
         super().empty_cache()
 
     def append_transform(
@@ -1429,6 +1430,50 @@ class Compose(Transform):
         for t in transforms:
             t.set_container(self)
 
+    def pop(self, index: int | None = None) -> Transform:
+        """Pop a transform from the chain.
+
+        Args:
+            index (int, optional): The index of the transform to pop. If None, the last transform is popped.
+
+        Returns:
+            The popped transform.
+        """
+        if index is None:
+            index = len(self.transforms) - 1
+        result = self.transforms.pop(index)
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+        return result
+
+    def __delitem__(self, index: int | slice | list):
+        """Delete a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` or callable are accepted.
+        """
+        del self.transforms[index]
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
+    def __setitem__(
+        self,
+        index: int | slice | list,
+        value: Transform | Callable[[TensorDictBase], TensorDictBase],
+    ):
+        """Set a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` or callable are accepted.
+        """
+        self.transforms[index] = value
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
     def close(self):
         """Close the transform."""
         for t in self.transforms:
@@ -1594,6 +1639,9 @@ class Compose(Transform):
         else:
             self.transforms.append(transform)
             transform.set_container(self)
+            parent = self.parent
+            if parent is not None:
+                parent.empty_cache()
 
     def set_container(self, container: Transform | EnvBase) -> None:
         self.reset_parent()
@@ -1626,6 +1674,9 @@ class Compose(Transform):
 
         # empty cache of all transforms to reset parents and specs
         self.empty_cache()
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()
         if index < 0:
             index = index + len(self.transforms)
         transform.eval()
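The new `Compose.pop`, `__delitem__` and `__setitem__` methods (together with the extra `empty_cache` calls) make transform chains editable in place while keeping the parent env's cached specs consistent. A minimal sketch, not part of the diff, assuming a Gym backend is installed:

from torchrl.envs import Compose, GymEnv, RewardSum, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(StepCounter(), RewardSum()))

last = env.transform.pop()                     # removes and returns RewardSum
env.transform[0] = StepCounter(max_steps=200)  # replace a transform in place
del env.transform[0]                           # drop the remaining transform
# Each mutation invalidates the parent's cached specs, so they are recomputed lazily.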
torchrl/objectives/ppo.py
CHANGED
@@ -752,10 +752,10 @@ class PPOLoss(LossModule):
 
         explained_variance = None
         if self.log_explained_variance:
-            with torch.no_grad():
-                tgt
-                pred
-                eps
+            with torch.no_grad():  # <‑‑ break grad‐flow
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
                 resid = torch.var(tgt - pred, unbiased=False, dim=0)
                 total = torch.var(tgt, unbiased=False, dim=0)
                 explained_variance = 1.0 - resid / (total + eps)
@@ -819,7 +819,9 @@ class PPOLoss(LossModule):
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1189,7 +1191,9 @@ class ClipPPOLoss(PPOLoss):
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1537,7 +1541,9 @@ class KLPENPPOLoss(PPOLoss):
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict_copy
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
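For reference, the explained-variance statistic gated by `log_explained_variance` can be reproduced outside the loss module; the snippet below mirrors the formula in the hunk above, with stand-in tensors only:

import torch

tgt = torch.randn(64, 1)               # value targets (stand-in data)
pred = tgt + 0.1 * torch.randn(64, 1)  # critic predictions (stand-in data)
eps = torch.finfo(tgt.dtype).eps
resid = torch.var(tgt - pred, unbiased=False, dim=0)
total = torch.var(tgt, unbiased=False, dim=0)
explained_variance = 1.0 - resid / (total + eps)  # approaches 1.0 when the critic tracks returns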
torchrl/version.py
CHANGED
@@ -1,2 +1,2 @@
-__version__ = '2025.6.
-git_version = '
+__version__ = '2025.6.23'
+git_version = 'ed051bc3e5b33d00f64f2a785023bca9a6c72c9b'
{torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: torchrl-nightly
-Version: 2025.6.
+Version: 2025.6.23
 Home-page: https://github.com/pytorch/rl
 Author: torchrl contributors
 Author-email: vmoens@fb.com
@@ -18,119 +18,129 @@ Classifier: License :: OSI Approved :: BSD License
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch
+Requires-Dist: torch>=2.1.0
 Requires-Dist: numpy
 Requires-Dist: packaging
 Requires-Dist: cloudpickle
 Requires-Dist: tensordict-nightly
-Provides-Extra: all
-Requires-Dist: accelerate ; extra == 'all'
-Requires-Dist: datasets ; extra == 'all'
-Requires-Dist: dm-meltingpot ; extra == 'all'
-Requires-Dist: dm-control ; extra == 'all'
-Requires-Dist: einops ; extra == 'all'
-Requires-Dist: git ; extra == 'all'
-Requires-Dist: gymnasium <1.0 ; extra == 'all'
-Requires-Dist: gymnasium[atari] ; extra == 'all'
-Requires-Dist: h5py ; extra == 'all'
-Requires-Dist: huggingface-hub ; extra == 'all'
-Requires-Dist: hydra-core >=1.1 ; extra == 'all'
-Requires-Dist: hydra-submitit-launcher ; extra == 'all'
-Requires-Dist: immutabledict ; extra == 'all'
-Requires-Dist: langdetect ; extra == 'all'
-Requires-Dist: minari ; extra == 'all'
-Requires-Dist: moviepy <2.0.0 ; extra == 'all'
-Requires-Dist: mujoco ; extra == 'all'
-Requires-Dist: nltk ; extra == 'all'
-Requires-Dist: open-spiel >=1.5 ; extra == 'all'
-Requires-Dist: pandas ; extra == 'all'
-Requires-Dist: pettingzoo >=1.24.1 ; extra == 'all'
-Requires-Dist: pillow ; extra == 'all'
-Requires-Dist: playwright ; extra == 'all'
-Requires-Dist: protobuf ; extra == 'all'
-Requires-Dist: pytest ; extra == 'all'
-Requires-Dist: pytest-asyncio ; extra == 'all'
-Requires-Dist: pytest-benchmark ; extra == 'all'
-Requires-Dist: pytest-cov ; extra == 'all'
-Requires-Dist: pytest-error-for-skips ; extra == 'all'
-Requires-Dist: pytest-instafail ; extra == 'all'
-Requires-Dist: pytest-mock ; extra == 'all'
-Requires-Dist: pytest-rerunfailures ; extra == 'all'
-Requires-Dist: pyyaml ; extra == 'all'
-Requires-Dist: requests ; extra == 'all'
-Requires-Dist: safetensors ; extra == 'all'
-Requires-Dist: scikit-learn ; extra == 'all'
-Requires-Dist: scipy ; extra == 'all'
-Requires-Dist: sentencepiece ; extra == 'all'
-Requires-Dist: tensorboard ; extra == 'all'
-Requires-Dist: torch >=2.7.0 ; extra == 'all'
-Requires-Dist: torchsnapshot ; extra == 'all'
-Requires-Dist: torchvision ; extra == 'all'
-Requires-Dist: tqdm ; extra == 'all'
-Requires-Dist: transformers ; extra == 'all'
-Requires-Dist: vllm ; extra == 'all'
-Requires-Dist: vmas >=1.2.10 ; extra == 'all'
-Requires-Dist: wandb ; extra == 'all'
 Provides-Extra: atari
-Requires-Dist: gymnasium[atari]
-Provides-Extra:
-Requires-Dist:
-Provides-Extra:
-Requires-Dist:
-Provides-Extra:
-Requires-Dist: gymnasium
-Requires-Dist: mujoco
-Provides-Extra: llm
-Requires-Dist: transformers ; extra == 'llm'
-Requires-Dist: vllm ; extra == 'llm'
-Requires-Dist: playwright ; extra == 'llm'
-Requires-Dist: datasets ; extra == 'llm'
-Requires-Dist: langdetect ; extra == 'llm'
-Requires-Dist: nltk ; extra == 'llm'
-Requires-Dist: immutabledict ; extra == 'llm'
-Requires-Dist: accelerate ; extra == 'llm'
-Requires-Dist: sentencepiece ; extra == 'llm'
-Requires-Dist: protobuf ; extra == 'llm'
-Requires-Dist: einops ; extra == 'llm'
-Requires-Dist: safetensors ; extra == 'llm'
-Provides-Extra: marl
-Requires-Dist: vmas >=1.2.10 ; extra == 'marl'
-Requires-Dist: pettingzoo >=1.24.1 ; extra == 'marl'
-Requires-Dist: dm-meltingpot ; extra == 'marl'
-Provides-Extra: offline-data
-Requires-Dist: huggingface-hub ; extra == 'offline-data'
-Requires-Dist: minari ; extra == 'offline-data'
-Requires-Dist: requests ; extra == 'offline-data'
-Requires-Dist: tqdm ; extra == 'offline-data'
-Requires-Dist: torchvision ; extra == 'offline-data'
-Requires-Dist: scikit-learn ; extra == 'offline-data'
-Requires-Dist: pandas ; extra == 'offline-data'
-Requires-Dist: h5py ; extra == 'offline-data'
-Requires-Dist: pillow ; extra == 'offline-data'
-Provides-Extra: open_spiel
-Requires-Dist: open-spiel >=1.5 ; extra == 'open_spiel'
+Requires-Dist: gymnasium[atari]; extra == "atari"
+Provides-Extra: dm-control
+Requires-Dist: dm_control; extra == "dm-control"
+Provides-Extra: replay-buffer
+Requires-Dist: torch>=2.7.0; extra == "replay-buffer"
+Provides-Extra: gym-continuous
+Requires-Dist: gymnasium<1.0; extra == "gym-continuous"
+Requires-Dist: mujoco; extra == "gym-continuous"
 Provides-Extra: rendering
-Requires-Dist: moviepy
-Provides-Extra: replay_buffer
-Requires-Dist: torch >=2.7.0 ; extra == 'replay_buffer'
+Requires-Dist: moviepy<2.0.0; extra == "rendering"
 Provides-Extra: tests
-Requires-Dist: pytest
-Requires-Dist: pyyaml
-Requires-Dist: pytest-instafail
-Requires-Dist: scipy
-Requires-Dist: pytest-mock
-Requires-Dist: pytest-cov
-Requires-Dist: pytest-asyncio
-Requires-Dist: pytest-benchmark
-Requires-Dist: pytest-rerunfailures
-Requires-Dist: pytest-error-for-skips
+Requires-Dist: pytest; extra == "tests"
+Requires-Dist: pyyaml; extra == "tests"
+Requires-Dist: pytest-instafail; extra == "tests"
+Requires-Dist: scipy; extra == "tests"
+Requires-Dist: pytest-mock; extra == "tests"
+Requires-Dist: pytest-cov; extra == "tests"
+Requires-Dist: pytest-asyncio; extra == "tests"
+Requires-Dist: pytest-benchmark; extra == "tests"
+Requires-Dist: pytest-rerunfailures; extra == "tests"
+Requires-Dist: pytest-error-for-skips; extra == "tests"
 Provides-Extra: utils
-Requires-Dist: tensorboard
-Requires-Dist: wandb
-Requires-Dist: tqdm
-Requires-Dist: hydra-core
-Requires-Dist: hydra-submitit-launcher
-Requires-Dist: git
+Requires-Dist: tensorboard; extra == "utils"
+Requires-Dist: wandb; extra == "utils"
+Requires-Dist: tqdm; extra == "utils"
+Requires-Dist: hydra-core>=1.1; extra == "utils"
+Requires-Dist: hydra-submitit-launcher; extra == "utils"
+Requires-Dist: git; extra == "utils"
+Provides-Extra: checkpointing
+Requires-Dist: torchsnapshot; extra == "checkpointing"
+Provides-Extra: offline-data
+Requires-Dist: huggingface_hub; extra == "offline-data"
+Requires-Dist: minari; extra == "offline-data"
+Requires-Dist: requests; extra == "offline-data"
+Requires-Dist: tqdm; extra == "offline-data"
+Requires-Dist: torchvision; extra == "offline-data"
+Requires-Dist: scikit-learn; extra == "offline-data"
+Requires-Dist: pandas; extra == "offline-data"
+Requires-Dist: h5py; extra == "offline-data"
+Requires-Dist: pillow; extra == "offline-data"
+Provides-Extra: marl
+Requires-Dist: vmas>=1.2.10; extra == "marl"
+Requires-Dist: pettingzoo>=1.24.1; extra == "marl"
+Requires-Dist: dm-meltingpot; extra == "marl"
+Provides-Extra: open-spiel
+Requires-Dist: open_spiel>=1.5; extra == "open-spiel"
+Provides-Extra: llm
+Requires-Dist: transformers; extra == "llm"
+Requires-Dist: vllm; extra == "llm"
+Requires-Dist: playwright; extra == "llm"
+Requires-Dist: datasets; extra == "llm"
+Requires-Dist: langdetect; extra == "llm"
+Requires-Dist: nltk; extra == "llm"
+Requires-Dist: immutabledict; extra == "llm"
+Requires-Dist: accelerate; extra == "llm"
+Requires-Dist: sentencepiece; extra == "llm"
+Requires-Dist: protobuf; extra == "llm"
+Requires-Dist: einops; extra == "llm"
+Requires-Dist: safetensors; extra == "llm"
+Provides-Extra: all
+Requires-Dist: accelerate; extra == "all"
+Requires-Dist: datasets; extra == "all"
+Requires-Dist: dm-meltingpot; extra == "all"
+Requires-Dist: dm_control; extra == "all"
+Requires-Dist: einops; extra == "all"
+Requires-Dist: git; extra == "all"
+Requires-Dist: gymnasium<1.0; extra == "all"
+Requires-Dist: gymnasium[atari]; extra == "all"
+Requires-Dist: h5py; extra == "all"
+Requires-Dist: huggingface_hub; extra == "all"
+Requires-Dist: hydra-core>=1.1; extra == "all"
+Requires-Dist: hydra-submitit-launcher; extra == "all"
+Requires-Dist: immutabledict; extra == "all"
+Requires-Dist: langdetect; extra == "all"
+Requires-Dist: minari; extra == "all"
+Requires-Dist: moviepy<2.0.0; extra == "all"
+Requires-Dist: mujoco; extra == "all"
+Requires-Dist: nltk; extra == "all"
+Requires-Dist: open_spiel>=1.5; extra == "all"
+Requires-Dist: pandas; extra == "all"
+Requires-Dist: pettingzoo>=1.24.1; extra == "all"
+Requires-Dist: pillow; extra == "all"
+Requires-Dist: playwright; extra == "all"
+Requires-Dist: protobuf; extra == "all"
+Requires-Dist: pytest; extra == "all"
+Requires-Dist: pytest-asyncio; extra == "all"
+Requires-Dist: pytest-benchmark; extra == "all"
+Requires-Dist: pytest-cov; extra == "all"
+Requires-Dist: pytest-error-for-skips; extra == "all"
+Requires-Dist: pytest-instafail; extra == "all"
+Requires-Dist: pytest-mock; extra == "all"
+Requires-Dist: pytest-rerunfailures; extra == "all"
+Requires-Dist: pyyaml; extra == "all"
+Requires-Dist: requests; extra == "all"
+Requires-Dist: safetensors; extra == "all"
+Requires-Dist: scikit-learn; extra == "all"
+Requires-Dist: scipy; extra == "all"
+Requires-Dist: sentencepiece; extra == "all"
+Requires-Dist: tensorboard; extra == "all"
+Requires-Dist: torch>=2.7.0; extra == "all"
+Requires-Dist: torchsnapshot; extra == "all"
+Requires-Dist: torchvision; extra == "all"
+Requires-Dist: tqdm; extra == "all"
+Requires-Dist: transformers; extra == "all"
+Requires-Dist: vllm; extra == "all"
+Requires-Dist: vmas>=1.2.10; extra == "all"
+Requires-Dist: wandb; extra == "all"
+Dynamic: author
+Dynamic: author-email
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: license-file
+Dynamic: provides-extra
+Dynamic: requires-dist
 
 [](https://github.com/pytorch/rl/actions/workflows/test-linux.yml)
 [](https://pytorch.org/rl/)
{torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info}/RECORD
RENAMED
@@ -3,9 +3,9 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
 build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
 torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
 torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
-torchrl/_torchrl.cpython-312-darwin.so,sha256=
+torchrl/_torchrl.cpython-312-darwin.so,sha256=2G08KUB1lgllQaOiK_aiC-u0bm9sPqvrm5fF8LPT_bc,1691072
 torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
-torchrl/version.py,sha256=
+torchrl/version.py,sha256=4YeSUDGHrB3YeBHYcVaU3pmlvqaCzwjkSwnDvVAiGUQ,83
 torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
 torchrl/collectors/collectors.py,sha256=CdTerIwhCTr6n5OoJLNad0bNQ5OLliPZFWkU18QBKSA,177625
 torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
@@ -97,28 +97,29 @@ torchrl/envs/libs/smacv2.py,sha256=i0TRHuZ9S9v0NfufPgQAcTlvAjf6JKv8hHvOzjSgsaw,2
 torchrl/envs/libs/unity_mlagents.py,sha256=Z3qSU0H3o2NXbS2lNvQ7OmYxkr3AWAMyRHfxeCtNZrk,49667
 torchrl/envs/libs/utils.py,sha256=RgiR16KJWFEtQim44-AIcHByGTq_NrtpjWoYIC13aYA,5207
 torchrl/envs/libs/vmas.py,sha256=a71_jU4r627hFXcMsT5wNSb4TMpyd3punLdOF3Cc8O0,36297
-torchrl/envs/llm/__init__.py,sha256=
-torchrl/envs/llm/chat.py,sha256=
+torchrl/envs/llm/__init__.py,sha256=HGpJZYZHR3tJVZ0EKq-Zh2r715JSH_H82PNa1z8F9V0,1313
+torchrl/envs/llm/chat.py,sha256=2j1S1-_EC52_RpIdN4Gy6_mkxBG8aXL8Yo93SQ2YRIM,18201
 torchrl/envs/llm/envs.py,sha256=Er-ahjgvtYG4LB7_EWOMbdobiUV5DOHPBQYkVTu80r4,34677
 torchrl/envs/llm/datasets/__init__.py,sha256=FFethtv8unJWzphGLPQVC5QD9NMdaygEjx25O1DHHZk,473
-torchrl/envs/llm/datasets/gsm8k.py,sha256=
-torchrl/envs/llm/datasets/ifeval.py,sha256=
+torchrl/envs/llm/datasets/gsm8k.py,sha256=pAOWJh8ArCvTdOKWmr7bQb4o6Hqpoq6PjS0h9HAaRDE,15475
+torchrl/envs/llm/datasets/ifeval.py,sha256=dzvSgOgqVxogFq0rC8O1SMqPfjQD8NEAGCGg5LnmXiU,8472
 torchrl/envs/llm/libs/__init__.py,sha256=vhEm5Fhz1sLWt107zfZLy5pzGmfQi0fNBGazTq1m7dU,266
 torchrl/envs/llm/libs/mlgym.py,sha256=ECnkrNoPV73L1fIO05SlTTXuTSNOM2pdX6aJcEYJVlo,31372
 torchrl/envs/llm/reward/__init__.py,sha256=a-Xsye29z2LugO1cOCFM2FNsqNwEp-5XwQk4saVQlu8,370
-torchrl/envs/llm/reward/gsm8k.py,sha256=
+torchrl/envs/llm/reward/gsm8k.py,sha256=GYd0l_YRaIiivZBLRGjhJeQiFj6jm-BUh9T3pEze3a8,8760
 torchrl/envs/llm/reward/ifeval/__init__.py,sha256=g5NtrwfwqK22hRcoIdz8-KWBh5Ogre9J-Bf3uGWE9Pg,314
 torchrl/envs/llm/reward/ifeval/_instructions.py,sha256=rAoTdwG42smCLJgwW7kAwJrNonjIS6OwdohDE70oMOA,61696
 torchrl/envs/llm/reward/ifeval/_instructions_main.py,sha256=CofKXvG0J2H-1ZXP1fL6UZI8ArNCIO2w5R_37drRIW8,4117
 torchrl/envs/llm/reward/ifeval/_instructions_registry.py,sha256=3_guc8LZ0mWQc-n6E4cQgYMgZRYa6xfgvXgrze9aO_w,3814
 torchrl/envs/llm/reward/ifeval/_instructions_util.py,sha256=aA3fupO8MvqBCqD7Y_Qk6y32toWF1lZGAflWON1ruXM,26042
 torchrl/envs/llm/reward/ifeval/_scorer.py,sha256=zJHBgaGlluEv6czsI6ZtLqArV_J_W9zY7UPAJhT5YIo,14563
-torchrl/envs/llm/transforms/__init__.py,sha256=
+torchrl/envs/llm/transforms/__init__.py,sha256=PNwdol9ItWXPfzyKcf4Id7Yu6oKFFAtta2J78ksSrf0,851
 torchrl/envs/llm/transforms/browser.py,sha256=zF7jHHHrdpxUCjFFtiYK-vhw-p1YqsqwP8_b4SiK0Rs,10423
-torchrl/envs/llm/transforms/dataloading.py,sha256=
+torchrl/envs/llm/transforms/dataloading.py,sha256=4P-e5yjUdxRtfaOmMxtNRisJLLtCqurAhWAqV7GiXHI,24872
 torchrl/envs/llm/transforms/format.py,sha256=ESn0S9k5G4FQPBICq9h6ZsLKXZqiU71tYW8UnW4rgLI,2519
 torchrl/envs/llm/transforms/kl.py,sha256=N68378chSx54X5a7YLJzIV6d870H5xrBb5-qWqzpX1U,22744
 torchrl/envs/llm/transforms/policy_version.py,sha256=by2TjsZLwVjQbq7ggBoAco2Iq_2aEYgyxh9asTXL1vk,6893
+torchrl/envs/llm/transforms/reason.py,sha256=Q3LRbl7QmatRfAt7bOjOw_aLuZJgqRZvmKwT67cWX7s,10561
 torchrl/envs/llm/transforms/tokenizer.py,sha256=CcuKRu33YnyDgLtQtyxTGDFC6iI3b3fUA6Nb1Lnh7h8,13953
 torchrl/envs/llm/transforms/tools.py,sha256=I-HR0zjH4tFMp9xPH556H5Q5JqmqXdsAXwElAR93e5U,29498
 torchrl/envs/model_based/__init__.py,sha256=AkgZvTP5AerIg6ZwXfCfk3bnSr01hlwZWDiRd3UjBE4,331
@@ -131,7 +132,7 @@ torchrl/envs/transforms/llm.py,sha256=rQDzuut807wvFpSPCm5tynt8-cMKTgVKVjSVu9D99P
 torchrl/envs/transforms/r3m.py,sha256=sdTVLpnxHfzFVo5rO8WnXf2uUg9cr4LBOLBsWaFgGT8,13478
 torchrl/envs/transforms/rb_transforms.py,sha256=pxtL1VHvzEq6djuWsccLu4P-tnbAKsavemLGyt80I6c,7448
 torchrl/envs/transforms/rlhf.py,sha256=lOVXYqQaoDfm4_n77Dxw_wjicBpMtDvavKmBIK2N3lU,628
-torchrl/envs/transforms/transforms.py,sha256=
+torchrl/envs/transforms/transforms.py,sha256=8aXDl-NfugfqlBK-FcPBKYuDU-oIXeabW3uIXZ6QMik,481272
 torchrl/envs/transforms/utils.py,sha256=VXGH69Jxdmnw5eP9L3uM8ronQA5aIbT-Ktpjn5Frds0,2058
 torchrl/envs/transforms/vc1.py,sha256=mho5BvdAK-f9hD9t-iah52wT2B06qPmaJO7chrfIOWY,10534
 torchrl/envs/transforms/vecnorm.py,sha256=XahMcWvK3zjOB6EACSZtJ6UMP3yQ2zD9xf87UEB37Eg,34047
@@ -188,7 +189,7 @@ torchrl/objectives/dreamer.py,sha256=vIJQN91oPXYnPubDFQpaF5d3fR_WwIYuIVYtoCvw0TY
 torchrl/objectives/functional.py,sha256=ZaglBjEGuOTNGeFA-Ox-ugZVcNegQMUj--KWHDRBmaU,2106
 torchrl/objectives/gail.py,sha256=0m34XmcN-EDk5OfNIo5bKYbKKZfATsYRv4zQe3v2UwA,9576
 torchrl/objectives/iql.py,sha256=1jvlSznWke6NZSwfuYyHVnVBE7Cz3q169GnCRC7iel4,42991
-torchrl/objectives/ppo.py,sha256=
+torchrl/objectives/ppo.py,sha256=4fzV-DSFSGv0VHrI0YCk0EBUB35gkuWyo3j_4KhSoqE,75340
 torchrl/objectives/redq.py,sha256=4usM-nG2UWujeL-VEqzf7-uOwRFx6itkKCeitKuJhtw,28507
 torchrl/objectives/reinforce.py,sha256=ySXLp5C-OOUYayqjrf4taQmL8LgRvMgPCgHDsle8JDc,22339
 torchrl/objectives/sac.py,sha256=Oq9Iq90s9KFbnM4KSRUd2onU1JfW6aW80LWGdtO0CY8,63993
@@ -223,8 +224,8 @@ torchrl/trainers/helpers/losses.py,sha256=qH-2YJwMtDAYAPXTTYy3cOPiq4ILC6xTjfnGUU
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.6.
-torchrl_nightly-2025.6.
-torchrl_nightly-2025.6.
-torchrl_nightly-2025.6.
-torchrl_nightly-2025.6.
+torchrl_nightly-2025.6.23.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.6.23.dist-info/METADATA,sha256=Akc3RKlo_nIxX-wV8Dm44Cgrmskwm90hWXd6FV5xJe0,39131
+torchrl_nightly-2025.6.23.dist-info/WHEEL,sha256=9_3tTSxMJq-dgdzMiScNvtT5eTBVd3l6RgHS7HwTzpA,115
+torchrl_nightly-2025.6.23.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+torchrl_nightly-2025.6.23.dist-info/RECORD,,
{torchrl_nightly-2025.6.21.dist-info → torchrl_nightly-2025.6.23.dist-info/licenses}/LICENSE
RENAMED
File without changes