torchrl-nightly 2025.7.15__cp313-cp313-macosx_10_13_universal2.whl → 2025.7.16__cp313-cp313-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchrl/_torchrl.cpython-313-darwin.so +0 -0
- torchrl/data/tensor_specs.py +34 -9
- torchrl/envs/transforms/transforms.py +0 -1
- torchrl/modules/llm/policies/common.py +28 -0
- torchrl/objectives/a2c.py +32 -13
- torchrl/objectives/ppo.py +50 -32
- torchrl/trainers/helpers/losses.py +2 -2
- torchrl/version.py +2 -2
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/RECORD +13 -13
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/top_level.txt +0 -0
torchrl/_torchrl.cpython-313-darwin.so
CHANGED
Binary file
torchrl/data/tensor_specs.py
CHANGED
@@ -4449,12 +4449,18 @@ class Binary(Categorical):
             f"shape of the {self.__class__.__name__} spec in expand()."
         )
         return self.__class__(
-            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
+            n=self.shape[-1] if len(self.shape) > 0 else None,
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
         )

     def _reshape(self, shape):
         return self.__class__(
-            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
+            n=self.shape[-1] if len(self.shape) > 0 else None,
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
         )

     def _unflatten(self, dim, sizes):
@@ -4464,7 +4470,10 @@ class Binary(Categorical):
             .shape
         )
         return self.__class__(
-            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
+            n=self.shape[-1] if len(self.shape) > 0 else None,
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
         )

     def squeeze(self, dim=None):
@@ -4472,13 +4481,19 @@ class Binary(Categorical):
         if shape is None:
             return self
         return self.__class__(
-            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
+            n=self.shape[-1] if len(self.shape) > 0 else None,
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
        )

     def unsqueeze(self, dim: int):
         shape = _unsqueezed_shape(self.shape, dim)
         return self.__class__(
-            n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
+            n=self.shape[-1] if len(self.shape) > 0 else None,
+            shape=shape,
+            device=self.device,
+            dtype=self.dtype,
         )

     def unbind(self, dim: int = 0):
@@ -4495,7 +4510,10 @@ class Binary(Categorical):
         shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
         return tuple(
             self.__class__(
-                n=self.shape[-1], shape=shape, device=self.device, dtype=self.dtype
+                n=self.shape[-1] if len(self.shape) > 0 else None,
+                shape=shape,
+                device=self.device,
+                dtype=self.dtype,
             )
             for i in range(self.shape[dim])
         )
@@ -4512,12 +4530,15 @@ class Binary(Categorical):
         if dest_device == self.device and dest_dtype == self.dtype:
             return self
         return self.__class__(
-            n=self.shape[-1], shape=self.shape, device=dest_device, dtype=dest_dtype
+            n=self.shape[-1] if len(self.shape) > 0 else None,
+            shape=self.shape,
+            device=dest_device,
+            dtype=dest_dtype,
         )

     def clone(self) -> Binary:
         return self.__class__(
-            n=self.shape[-1],
+            n=self.shape[-1] if len(self.shape) > 0 else None,
             shape=self.shape,
             device=self.device,
             dtype=self.dtype,
@@ -4528,6 +4549,8 @@ class Binary(Categorical):

         The last dimension of the spec (length n of the binary vector) cannot be indexed.
         """
+        if not len(self.shape):
+            raise ValueError("Cannot index a Binary spec with an empty shape")
         indexed_shape = _shape_indexing(self.shape[:-1], idx)
         return self.__class__(
             n=self.shape[-1],
@@ -5533,8 +5556,10 @@ class Composite(TensorSpec):
         sub_str = [
             indent(f"{k}: {str(item)}", 4 * " ") for k, item in self._specs.items()
         ]
+        if len(sub_str) == 0:
+            return f"{self.__class__.__name__}(device={self._device}, shape={self.shape}, data_cls={self.data_cls})"
         sub_str = ",\n".join(sub_str)
-        return f"{self.__class__.__name__}(\n{sub_str},\n device={self._device},\n shape={self.shape})"
+        return f"{self.__class__.__name__}(\n{sub_str},\n device={self._device},\n shape={self.shape},\n data_cls={self.data_cls})"

     def type_check(
         self,
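Taken together, the Binary hunks guard every self.shape[-1] access with an empty-shape check, so a scalar-shaped spec passes n=None instead of raising an IndexError inside expand, _reshape, _unflatten, squeeze, unsqueeze, unbind, to and clone, and indexing such a spec now fails with an explicit ValueError; the Composite hunk gives an empty Composite a compact one-line repr. A minimal sanity-check sketch of the touched surface, assuming the public torchrl.data.Binary and Composite APIs of this nightly (the empty-shape path itself is only noted in comments):

    import torch
    from torchrl.data import Binary, Composite

    # A 4-bit binary spec with a non-default dtype.
    spec = Binary(n=4, dtype=torch.bool)

    # Shape ops rebuild the spec with its device/dtype; with this patch, `n` is
    # only read from shape[-1] when the shape is non-empty (otherwise None).
    expanded = spec.expand(2, 4)
    assert expanded.shape == torch.Size([2, 4])
    assert expanded.dtype == torch.bool
    assert spec.clone().dtype == torch.bool

    # Composite.__repr__: an empty Composite now prints on a single line
    # (including its data_cls) instead of an empty multi-line block.
    print(Composite())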
torchrl/envs/transforms/transforms.py
CHANGED
@@ -1211,7 +1211,6 @@ but got an object of type {type(transform)}."""
         if tensordict is not None:
            # We must avoid modifying the original tensordict so a shallow copy is necessary.
            # We just select the input data and reset signal, which is all we need.
-            self.transform.transform_input_spec(self.base_env.input_spec.unlock_())
             tensordict = tensordict.select(
                 *self.reset_keys, *self.state_spec.keys(True, True), strict=False
             )
torchrl/modules/llm/policies/common.py
CHANGED
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

+import warnings
 import weakref
 from typing import Any, Literal, overload

@@ -171,6 +172,33 @@ class ChatHistory(TensorClass["nocast"]):
         step_mdp_static=True,
     )

+    def __post_init__(self):
+        # Check that all history objects have one more batch dimension than the ChatHistory object
+        if self.prompt is not None:
+            if self.prompt.batch_dims != self.batch_dims + 1:
+                warnings.warn(
+                    "Prompt history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
+                    f"got {self.prompt.batch_dims} and {self.batch_dims}. "
+                    "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
+                )
+                self.prompt = self.prompt.unsqueeze(-1)
+        if self.response is not None:
+            if self.response.batch_dims != self.batch_dims + 1:
+                warnings.warn(
+                    "Response history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
+                    f"got {self.response.batch_dims} and {self.batch_dims}. "
+                    "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
+                )
+                self.response = self.response.unsqueeze(-1)
+        if self.full is not None:
+            if self.full.batch_dims != self.batch_dims + 1:
+                warnings.warn(
+                    "Full history should have one more batch dimension than the ChatHistory object to handle multi-turn conversations, "
+                    f"got {self.full.batch_dims} and {self.batch_dims}. "
+                    "The batch dimension of the ChatHistory object will be unsqueezed along the last dimension."
+                )
+                self.full = self.full.unsqueeze(-1)
+

 class LogProbs(TensorClass["nocast"]):
     """A log-probability container.
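The new __post_init__ enforces a simple convention: each history field (prompt, response, full) is expected to carry exactly one more batch dimension than the ChatHistory container (the extra axis indexes conversation turns); anything else is warned about and unsqueezed along the last dimension. The check boils down to the following standalone pattern on tensordict-like objects; this is a sketch of the convention only, and ensure_extra_batch_dim is a made-up helper name, not part of the ChatHistory API:

    import warnings

    import torch
    from tensordict import TensorDict


    def ensure_extra_batch_dim(container_batch_dims: int, history: TensorDict) -> TensorDict:
        """Warn and unsqueeze when `history` lacks the extra turn dimension."""
        if history.batch_dims != container_batch_dims + 1:
            warnings.warn(
                "history should have one more batch dimension than its container; "
                "unsqueezing along the last dimension."
            )
            history = history.unsqueeze(-1)
        return history


    # One extra batch dim (the turn axis): accepted unchanged.
    ok = ensure_extra_batch_dim(0, TensorDict({"tokens": torch.zeros(3, 5)}, batch_size=[3]))
    assert ok.batch_dims == 1

    # Same number of batch dims as the container: warned about and unsqueezed.
    fixed = ensure_extra_batch_dim(0, TensorDict({"tokens": torch.zeros(5)}, batch_size=[]))
    assert fixed.batch_dims == 1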
torchrl/objectives/a2c.py
CHANGED
@@ -70,7 +70,7 @@ class A2CLoss(LossModule):
             samples will be used to compute this estimate.
             Defaults to ``1``.
         entropy_coeff (:obj:`float`): the weight of the entropy loss. Defaults to `0.01``.
-        critic_coef (:obj:`float`): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic
+        critic_coeff (:obj:`float`): the weight of the critic loss. Defaults to ``1.0``. If ``None``, the critic
             loss won't be included and the in-keys will miss the critic inputs.
         loss_critic_type (str): loss function for the value discrepancy.
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -156,7 +156,7 @@ class A2CLoss(LossModule):
         the expected keyword arguments are:
         ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and critic.
         The return value is a tuple of tensors in the following order:
-        ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coef is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coef is not None
+        ``["loss_objective"]`` + ``["loss_critic"]`` if critic_coeff is not None + ``["entropy", "loss_entropy"]`` if entropy_bonus is True and critic_coeff is not None

     Examples:
         >>> import torch
@@ -277,8 +277,8 @@ class A2CLoss(LossModule):
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float = 0.01,
-        critic_coef: float = 1.0,
+        entropy_coeff: float | None = None,
+        critic_coeff: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         gamma: float | None = None,
         separate_losses: bool = False,
@@ -291,13 +291,32 @@ class A2CLoss(LossModule):
         clip_value: float | None = None,
         **kwargs,
     ):
+        # Handle deprecated entropy_coef argument
         if "entropy_coef" in kwargs:
+            if entropy_coeff is not None:  # Check if entropy_coeff was explicitly set
+                raise ValueError(
+                    "Cannot specify both 'entropy_coef' and 'entropy_coeff'"
+                )
             warnings.warn(
                 "'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
                 DeprecationWarning,
             )
             entropy_coeff = kwargs.pop("entropy_coef")

+        # Set default value if None
+        if entropy_coeff is None:
+            entropy_coeff = 0.01
+
+        # Handle deprecated critic_coef argument
+        if "critic_coef" in kwargs:
+            if critic_coeff != 1.0:  # Check if critic_coeff was explicitly set
+                raise ValueError("Cannot specify both 'critic_coef' and 'critic_coeff'")
+            warnings.warn(
+                "'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.",
+                DeprecationWarning,
+            )
+            critic_coeff = kwargs.pop("critic_coef")
+
         if actor is not None:
             actor_network = actor
             del actor
@@ -349,12 +368,12 @@ class A2CLoss(LossModule):
         self.register_buffer(
             "entropy_coeff", torch.as_tensor(entropy_coeff, device=device)
         )
-        if critic_coef is not None:
+        if critic_coeff is not None:
             self.register_buffer(
-                "critic_coef", torch.as_tensor(critic_coef, device=device)
+                "critic_coeff", torch.as_tensor(critic_coeff, device=device)
             )
         else:
-            self.critic_coef = None
+            self.critic_coeff = None

         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
@@ -399,7 +418,7 @@ class A2CLoss(LossModule):
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
         ]
-        if self.critic_coef is not None:
+        if self.critic_coeff is not None:
             keys.extend(self.critic_network.in_keys)
         return list(set(keys))

@@ -407,7 +426,7 @@ class A2CLoss(LossModule):
     def out_keys(self):
         if self._out_keys is None:
             outs = ["loss_objective"]
-            if self.critic_coef is not None:
+            if self.critic_coeff is not None:
                 outs.append("loss_critic")
             if self.entropy_bonus:
                 outs.append("entropy")
@@ -478,7 +497,7 @@ class A2CLoss(LossModule):
         return log_prob, dist

     def loss_critic(self, tensordict: TensorDictBase) -> tuple[torch.Tensor, float]:
-        """Returns the loss value of the critic, multiplied by ``critic_coef`` if it is not ``None``.
+        """Returns the loss value of the critic, multiplied by ``critic_coeff`` if it is not ``None``.

         Returns the loss and the clip-fraction.

@@ -539,8 +558,8 @@ class A2CLoss(LossModule):
                 "target_actor_network_params",
                 "target_critic_network_params",
             )
-        if self.critic_coef is not None:
-            return self.critic_coef * loss_value, clip_fraction
+        if self.critic_coeff is not None:
+            return self.critic_coeff * loss_value, clip_fraction
         return loss_value, clip_fraction

     @property
@@ -568,7 +587,7 @@ class A2CLoss(LossModule):
             entropy = self.get_entropy_bonus(dist)
             td_out.set("entropy", entropy.detach().mean()) # for logging
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
-        if self.critic_coef is not None:
+        if self.critic_coeff is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
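In practice the A2C change is the critic_coef -> critic_coeff rename with a deprecation shim: the old keyword still works (emitting a DeprecationWarning), mixing the two spellings raises a ValueError, and entropy_coeff now defaults to None, resolved to 0.01 inside the constructor. A hedged sketch against the documented constructor; the toy actor/critic setup mirrors the A2CLoss docstring example and the sizes are arbitrary:

    import warnings

    import torch
    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.data import Bounded
    from torchrl.modules import NormalParamExtractor, ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import A2CLoss

    n_obs, n_act = 3, 4
    actor = ProbabilisticActor(
        TensorDictModule(
            nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        spec=Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)),
        distribution_class=TanhNormal,
    )
    critic = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

    # New spellings are the supported ones.
    loss = A2CLoss(actor, critic, critic_coeff=0.5, entropy_coeff=0.01)

    # The old spelling still works but emits a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        A2CLoss(actor, critic, critic_coef=0.5)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Mixing both spellings is rejected outright.
    try:
        A2CLoss(actor, critic, critic_coef=0.5, critic_coeff=0.5)
    except ValueError:
        pass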
torchrl/objectives/ppo.py
CHANGED
@@ -102,13 +102,13 @@ class PPOLoss(LossModule):
             Defaults to ``1``.
         entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
-            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
+            * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
         log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
             predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
             This can help monitor critic quality during training. Best possible score is 1.0, lower values are worse. Defaults to ``True``.
-        critic_coef (scalar, optional): critic loss multiplier when computing the total
-            loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
+        critic_coeff (scalar, optional): critic loss multiplier when computing the total
+            loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
             loss from the forward outputs.
         loss_critic_type (str, optional): loss function for the value discrepancy.
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -239,7 +239,7 @@ class PPOLoss(LossModule):
         the expected keyword arguments are:
         ``["action", "sample_log_prob", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and value network.
         The return value is a tuple of tensors in the following order:
-        ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coef is not ``None``.
+        ``["loss_objective"]`` + ``["entropy", "loss_entropy"]`` if entropy_bonus is set + ``"loss_critic"`` if critic_coeff is not ``None``.
         The output keys can also be filtered using :meth:`PPOLoss.select_out_keys` method.

     Examples:
@@ -351,9 +351,9 @@ class PPOLoss(LossModule):
         *,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] = 0.01,
+        entropy_coeff: float | Mapping[str, float] | None = None,
         log_explained_variance: bool = True,
-        critic_coef: float | None = None,
+        critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -377,13 +377,23 @@ class PPOLoss(LossModule):
             critic_network = critic
             del critic

-        if critic_coef is None and critic_network is not None:
-            critic_coef = 1.0
-        elif critic_coef in (None, 0) and critic_network is not None:
-            critic_coef = None
+        # Handle deprecated critic_coef argument
+        if "critic_coef" in kwargs:
+            if critic_coeff is not None:
+                raise ValueError("Cannot specify both 'critic_coef' and 'critic_coeff'")
+            warnings.warn(
+                "'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.",
+                DeprecationWarning,
+            )
+            critic_coeff = kwargs.pop("critic_coef")
+
+        if critic_coeff is None and critic_network is not None:
+            critic_coeff = 1.0
+        elif critic_coeff in (None, 0) and critic_network is not None:
+            critic_coeff = None

         if actor_network is None or (
-            critic_network is None and critic_coef not in (None, 0.0)
+            critic_network is None and critic_coeff not in (None, 0.0)
         ):
             raise TypeError(
                 "Missing positional arguments actor_network or critic_network."
@@ -431,13 +441,21 @@ class PPOLoss(LossModule):
                 torch, "get_default_device", lambda: torch.device("cpu")
             )()

-        # Handle deprecated
-        if "
+        # Handle deprecated entropy_coef argument
+        if "entropy_coef" in kwargs:
+            if entropy_coeff is not None:  # Check if entropy_coeff was explicitly set
+                raise ValueError(
+                    "Cannot specify both 'entropy_coef' and 'entropy_coeff'"
+                )
             warnings.warn(
-                "'
+                "'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.",
                 DeprecationWarning,
             )
-            entropy_coeff = kwargs.pop("
+            entropy_coeff = kwargs.pop("entropy_coef")
+
+        # Set default value if None
+        if entropy_coeff is None:
+            entropy_coeff = 0.01

         if isinstance(entropy_coeff, Mapping):
             # Store the mapping for per-head coefficients
@@ -457,13 +475,13 @@ class PPOLoss(LossModule):
             self._entropy_coeff_map = None
         else:
             raise TypeError("entropy_coeff must be a float or a Mapping[str, float]")
-        if critic_coef is not None:
+        if critic_coeff is not None:
             self.register_buffer(
-                "critic_coef", torch.tensor(critic_coef, device=device)
+                "critic_coeff", torch.tensor(critic_coeff, device=device)
             )
         else:
-            self.critic_coef = None
-        self._has_critic = bool(self.critic_coef is not None and self.critic_coef > 0)
+            self.critic_coeff = None
+        self._has_critic = bool(self.critic_coeff is not None and self.critic_coeff > 0)
         self.loss_critic_type = loss_critic_type
         self.normalize_advantage = normalize_advantage
         self.normalize_advantage_exclude_dims = normalize_advantage_exclude_dims
@@ -692,7 +710,7 @@ class PPOLoss(LossModule):
     def loss_critic(
         self, tensordict: TensorDictBase
     ) -> tuple[torch.Tensor | TensorDict, ...]:
-        """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
+        """Returns the critic loss multiplied by ``critic_coeff``, if it is not ``None``."""
         # TODO: if the advantage is gathered by forward, this introduces an
         # overhead that we could easily reduce.
         if self.separate_losses:
@@ -766,7 +784,7 @@ class PPOLoss(LossModule):
                 "target_critic_network_params",
             )
         if self._has_critic:
-            return self.critic_coef * loss_value, clip_fraction, explained_variance
+            return self.critic_coeff * loss_value, clip_fraction, explained_variance
         return loss_value, clip_fraction, explained_variance

     @property
@@ -954,10 +972,10 @@ class ClipPPOLoss(PPOLoss):
             Defaults to ``1``.
         entropy_coeff: (scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
-            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
+            * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
-        critic_coef (scalar, optional): critic loss multiplier when computing the total
-            loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
+        critic_coeff (scalar, optional): critic loss multiplier when computing the total
+            loss. Defaults to ``1.0``. Set ``critic_coeff`` to ``None`` to exclude the value
             loss from the forward outputs.
         loss_critic_type (str, optional): loss function for the value discrepancy.
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -1057,8 +1075,8 @@ class ClipPPOLoss(PPOLoss):
         clip_epsilon: float = 0.2,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] = 0.01,
-        critic_coef: float | None = None,
+        entropy_coeff: float | Mapping[str, float] | None = None,
+        critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -1079,7 +1097,7 @@ class ClipPPOLoss(PPOLoss):
             entropy_bonus=entropy_bonus,
             samples_mc_entropy=samples_mc_entropy,
             entropy_coeff=entropy_coeff,
-            critic_coef=critic_coef,
+            critic_coeff=critic_coeff,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
             normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
@@ -1247,9 +1265,9 @@ class KLPENPPOLoss(PPOLoss):
             Defaults to ``1``.
         entropy_coeff: scalar | Mapping[str, scalar], optional): entropy multiplier when computing the total loss.
             * **Scalar**: one value applied to the summed entropy of every action head.
-            * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
+            * **Mapping** ``{head_name: coeff}`` gives an individual coefficient for each action-head's entropy.
             Defaults to ``0.01``.
-        critic_coef (scalar, optional): critic loss multiplier when computing the total
+        critic_coeff (scalar, optional): critic loss multiplier when computing the total
             loss. Defaults to ``1.0``.
         loss_critic_type (str, optional): loss function for the value discrepancy.
             Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``.
@@ -1351,8 +1369,8 @@ class KLPENPPOLoss(PPOLoss):
         samples_mc_kl: int = 1,
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
-        entropy_coeff: float | Mapping[str, float] = 0.01,
-        critic_coef: float | None = None,
+        entropy_coeff: float | Mapping[str, float] | None = None,
+        critic_coeff: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -1369,7 +1387,7 @@ class KLPENPPOLoss(PPOLoss):
             entropy_bonus=entropy_bonus,
             samples_mc_entropy=samples_mc_entropy,
             entropy_coeff=entropy_coeff,
-            critic_coef=critic_coef,
+            critic_coeff=critic_coeff,
             loss_critic_type=loss_critic_type,
             normalize_advantage=normalize_advantage,
             normalize_advantage_exclude_dims=normalize_advantage_exclude_dims,
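The PPO family gets the same rename plus a shim for the old keyword, and the construction-time resolution shown above maps an unset critic_coeff (with a critic network present) to the historical 1.0 default. A hedged sketch of that behavior, reusing the same kind of toy actor/critic as the A2C example; module names and sizes are arbitrary:

    import torch
    from tensordict.nn import TensorDictModule
    from torch import nn
    from torchrl.data import Bounded
    from torchrl.modules import NormalParamExtractor, ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import ClipPPOLoss

    n_obs, n_act = 3, 4
    actor = ProbabilisticActor(
        TensorDictModule(
            nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
            in_keys=["observation"],
            out_keys=["loc", "scale"],
        ),
        in_keys=["loc", "scale"],
        spec=Bounded(-torch.ones(n_act), torch.ones(n_act), (n_act,)),
        distribution_class=TanhNormal,
    )
    critic = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

    # Leaving critic_coeff unset (None) with a critic network resolves to 1.0.
    loss = ClipPPOLoss(actor, critic)
    assert float(loss.critic_coeff) == 1.0

    # An explicit value is stored as the `critic_coeff` buffer and scales the
    # critic term returned by loss_critic() during forward().
    loss_half = ClipPPOLoss(actor, critic, critic_coeff=0.5, entropy_coeff=0.01)
    assert float(loss_half.critic_coeff) == 0.5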
torchrl/trainers/helpers/losses.py
CHANGED
@@ -86,7 +86,7 @@ class A2CLossConfig:
     # Decay factor for return computation. Default=0.99.
     entropy_coeff: float = 1e-3
     # Entropy factor for the A2C loss
-    critic_coef: float = 1.0
+    critic_coeff: float = 1.0
     # Critic factor for the A2C loss
     critic_loss_function: str = "smooth_l1"
     # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
@@ -112,7 +112,7 @@ class PPOLossConfig:
     # Number of samples to use for a Monte-Carlo estimate if the policy distribution has not closed formula.
     loss_function: str = "smooth_l1"
     # loss function for the value network. Either one of l1, l2 or smooth_l1 (default).
-    critic_coef: float = 1.0
+    critic_coeff: float = 1.0
     # Critic loss multiplier when computing the total loss.

     # ClipPPOLoss parameters:
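The helper configs follow the rename as well. Assuming A2CLossConfig and PPOLossConfig remain plain dataclasses, as the other helper configs in torchrl.trainers.helpers are, the new field spelling is used directly:

    from torchrl.trainers.helpers.losses import A2CLossConfig, PPOLossConfig

    # Field names now match the loss-module keyword arguments.
    a2c_cfg = A2CLossConfig(entropy_coeff=1e-3, critic_coeff=1.0)
    ppo_cfg = PPOLossConfig(critic_coeff=0.5)
    assert a2c_cfg.critic_coeff == 1.0 and ppo_cfg.critic_coeff == 0.5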
torchrl/version.py
CHANGED
@@ -1,2 +1,2 @@
-__version__ = '2025.7.15'
-git_version = '
+__version__ = '2025.7.16'
+git_version = '361a8da6edc77979e17409cf19396230d18c18a9'
{torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/RECORD
RENAMED
@@ -3,9 +3,9 @@ build_tools/setup_helpers/__init__.py,sha256=7l8TvVqxKezgzKCLuRv20mvGLloprFVZYm8
 build_tools/setup_helpers/extension.py,sha256=4-PDLr-pw40bJnd9SfxnTaSjUyuXU_Tg8yOg69Kl0o4,5914
 torchrl/__init__.py,sha256=mhDBx2UIuBKc0gmi8dVNHokQ6tCbIovruZmyAxcSsy8,2938
 torchrl/_extension.py,sha256=z7wQ8i1iYWYcnygq_j0nq9sT-koY13tfHhTLNbMk17Q,2353
-torchrl/_torchrl.cpython-313-darwin.so,sha256=
+torchrl/_torchrl.cpython-313-darwin.so,sha256=hKxRdhDxuVm-1ogc0Pnu7F8xo6djoCREE7psO-strBc,1692464
 torchrl/_utils.py,sha256=Cw5EG6x5oSZF1iE3YCs1a32VUKp0rTXIs2u67q9zKUI,41078
-torchrl/version.py,sha256=
+torchrl/version.py,sha256=aShA8Hert0ILH-D5pqVnCE_n2SCgvZPjiPpxsO1x4eM,83
 torchrl/collectors/__init__.py,sha256=hJ3JD6shRku0BL6SzJQq44FZ5Q1RGR8LealFyU3FRn4,799
 torchrl/collectors/collectors.py,sha256=WoeR-MAfzcLiy8EHPWQ3uknm_jTWjA9Wi45CODG8NZI,177782
 torchrl/collectors/utils.py,sha256=MlXrkYuDmV0Em-tVNQiLL32FWgPNDgceYYG_GgpiviA,11320
@@ -25,7 +25,7 @@ torchrl/collectors/llm/weight_update/__init__.py,sha256=bKjvD7yZG5VnHgvYc4EmKI1s
 torchrl/collectors/llm/weight_update/vllm.py,sha256=slKUmrIo4eL6R4J1oEnmlP6Q7Zer09p92JU8zbIHFUM,11515
 torchrl/data/__init__.py,sha256=oowsio6ZUOZnJV8JV43xgs17B37XO1yKAYIQPdk8yt0,4819
 torchrl/data/rlhf.py,sha256=JUmdYBWgkN229DwpXuDrhy9ddjduNvU2kyHzHR6MoA0,963
-torchrl/data/tensor_specs.py,sha256=
+torchrl/data/tensor_specs.py,sha256=RlMckj6PJo9MQMzneHzbcVe9xUyMB_n7pnSz0jytB9s,253907
 torchrl/data/utils.py,sha256=attuNwzfgjszyp0lJSrV06f2peX3r0qTjRZWEwfl6Yg,12108
 torchrl/data/datasets/__init__.py,sha256=NQpXsHecbZmza8AocX9mkqQQNkdFzeUrMTZoi6hbbU4,733
 torchrl/data/datasets/atari_dqn.py,sha256=3ij6-UGfKev-QJuUEhZEEmn_3yL210CqKJALaFvlc5M,40739
@@ -131,7 +131,7 @@ torchrl/envs/transforms/llm.py,sha256=rQDzuut807wvFpSPCm5tynt8-cMKTgVKVjSVu9D99P
 torchrl/envs/transforms/r3m.py,sha256=sdTVLpnxHfzFVo5rO8WnXf2uUg9cr4LBOLBsWaFgGT8,13478
 torchrl/envs/transforms/rb_transforms.py,sha256=6ohnKXHHAEh2Hz3Seaw6eDrcFMu-1IVQrT7RVywh3YQ,7447
 torchrl/envs/transforms/rlhf.py,sha256=lOVXYqQaoDfm4_n77Dxw_wjicBpMtDvavKmBIK2N3lU,628
-torchrl/envs/transforms/transforms.py,sha256=
+torchrl/envs/transforms/transforms.py,sha256=cDv_NxElzTOW8qQO-2krvOBmlKVGPOKMfqM6XyuLckU,482882
 torchrl/envs/transforms/utils.py,sha256=7ToVFnD4-DkOMtML91g4bqXeY0bZ-gmCaSLxC93oaKM,3264
 torchrl/envs/transforms/vc1.py,sha256=mho5BvdAK-f9hD9t-iah52wT2B06qPmaJO7chrfIOWY,10534
 torchrl/envs/transforms/vecnorm.py,sha256=XahMcWvK3zjOB6EACSZtJ6UMP3yQ2zD9xf87UEB37Eg,34047
@@ -147,7 +147,7 @@ torchrl/modules/llm/utils.py,sha256=gf_F-4bEMwkcI3jLQM7ifB7nsjRctGebB5E2c-AznO0,
 torchrl/modules/llm/backends/__init__.py,sha256=WdVy9EdiAfk8i5zFa49TEkRvcUd0L4Un4v6wqWBy8l8,438
 torchrl/modules/llm/backends/vllm.py,sha256=x57Xop1xd5ZShicsh47ZFmz4VpfZ3eCzVx7k0COvpqQ,9387
 torchrl/modules/llm/policies/__init__.py,sha256=nfZ2mcVuucxnY3WCuzIQrTLIf1yEd36k8-AlvwnSa8Y,545
-torchrl/modules/llm/policies/common.py,sha256=
+torchrl/modules/llm/policies/common.py,sha256=jmWoaR6fgKQ5lKa9HqDVW73K3oySpb2zO3WJlNTz6iQ,38117
 torchrl/modules/llm/policies/transformers_wrapper.py,sha256=HTkubIsbEui2hWqAZ3GwsATI2NGmA0kry1nW5RjnEJ0,74326
 torchrl/modules/llm/policies/vllm_wrapper.py,sha256=u0ITRdVI8pNhpRRMy2yXEh9bK_TkYRUOUEzix2m2aR0,78231
 torchrl/modules/models/__init__.py,sha256=DrOG-7hynjjUh_tc2EqysiUiNMRiDR0WLtZql9TPNcI,1743
@@ -176,7 +176,7 @@ torchrl/modules/utils/__init__.py,sha256=KXaF_xEghKSPsNg0JyfxChK6KWHFRy0lwkL2Rip
 torchrl/modules/utils/mappings.py,sha256=VMYrPxDk1ywgl2l_f6HXZaRsVOKcYR7VF5DNkmi3lHk,362
 torchrl/modules/utils/utils.py,sha256=WPfcE-AoemnrP7Ny4FxJ-_LoQsBnX-y77Zb7MnZjXV0,2916
 torchrl/objectives/__init__.py,sha256=pnprzIXA6E9Ph7isYgNLh4SFTU0pxIQg4oUNcaQ6doc,2148
-torchrl/objectives/a2c.py,sha256=
+torchrl/objectives/a2c.py,sha256=_xdp8D2ErOPyHwpxqPHtUr-EvZw7MqcuhhK9Isnewgo,28791
 torchrl/objectives/common.py,sha256=40inZ0z3bFdQUkXuup3PWP_KmCx1m13cKTksjOp_b6I,28571
 torchrl/objectives/cql.py,sha256=8faIZmA9e65NQ39HAi6torMofr98bkngjtBXm0UbnVM,54925
 torchrl/objectives/crossq.py,sha256=a_vAjET5GG-2U7zZDgMnA0QP1iPCtv2ho6q-XvvLsnc,28858
@@ -188,7 +188,7 @@ torchrl/objectives/dreamer.py,sha256=vIJQN91oPXYnPubDFQpaF5d3fR_WwIYuIVYtoCvw0TY
 torchrl/objectives/functional.py,sha256=ZaglBjEGuOTNGeFA-Ox-ugZVcNegQMUj--KWHDRBmaU,2106
 torchrl/objectives/gail.py,sha256=0m34XmcN-EDk5OfNIo5bKYbKKZfATsYRv4zQe3v2UwA,9576
 torchrl/objectives/iql.py,sha256=1jvlSznWke6NZSwfuYyHVnVBE7Cz3q169GnCRC7iel4,42991
-torchrl/objectives/ppo.py,sha256=
+torchrl/objectives/ppo.py,sha256=0soC2aiCOFNM5hCL20-99LX_NZi6XIXDmG2IkGEHSek,76082
 torchrl/objectives/redq.py,sha256=4usM-nG2UWujeL-VEqzf7-uOwRFx6itkKCeitKuJhtw,28507
 torchrl/objectives/reinforce.py,sha256=ySXLp5C-OOUYayqjrf4taQmL8LgRvMgPCgHDsle8JDc,22339
 torchrl/objectives/sac.py,sha256=Oq9Iq90s9KFbnM4KSRUd2onU1JfW6aW80LWGdtO0CY8,63993
@@ -219,12 +219,12 @@ torchrl/trainers/helpers/__init__.py,sha256=HhDB2Ubq2gZodV-hB6xw4ZgCgwaZKUoZgOfV
 torchrl/trainers/helpers/collectors.py,sha256=NjMMvGWEe4TWkVXzx7AlJ_Qa_AxEzMl6EUmEgUzHkoE,18715
 torchrl/trainers/helpers/envs.py,sha256=1yqJZgz7mc5wa58HmSDGpPQINeDHFZB0_KTgwdKm9QE,22084
 torchrl/trainers/helpers/logger.py,sha256=FtuEiLnK4NmxVVNyEEWaoCu3nG7WbNpHP3UYGQRJmgo,1278
-torchrl/trainers/helpers/losses.py,sha256=
+torchrl/trainers/helpers/losses.py,sha256=sHlJqjh02t8cKN73X35Azd_OoWGurohLuviB8Yeo4JQ,5272
 torchrl/trainers/helpers/models.py,sha256=ihTERG2c96E8cS3Tnul6a_ys6iDEEJmHh05p9blQTW8,21807
 torchrl/trainers/helpers/replay_buffer.py,sha256=ZUZHOa0TILyeWJ3iahzTJ6UvMl_0FdxuZfJEja94Bn8,2001
 torchrl/trainers/helpers/trainers.py,sha256=j6B5XA7_FFHMQeOIQwjNcO0CGE_4mZKUC9_jH_iqqh4,12071
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
-torchrl_nightly-2025.7.
+torchrl_nightly-2025.7.16.dist-info/licenses/LICENSE,sha256=xdjS4_xk-IwnLuIFCvTYTl9Y8aXRejqpmke3dGam_nI,1098
+torchrl_nightly-2025.7.16.dist-info/METADATA,sha256=vNEXmcQPsZZGDFx6pCdysVMqIIBiKpH4RpRNO6AMuTs,42990
+torchrl_nightly-2025.7.16.dist-info/WHEEL,sha256=A6iggJuFsuu67bHdjxJADhwSEJmqwgO3xFoNCIwjOxc,115
+torchrl_nightly-2025.7.16.dist-info/top_level.txt,sha256=JeTJ1jV7QJwLcUS1nr21aPn_wb-XlAZ9c-z_EH472JA,20
+torchrl_nightly-2025.7.16.dist-info/RECORD,,
{torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/WHEEL
RENAMED
File without changes

{torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/licenses/LICENSE
RENAMED
File without changes

{torchrl_nightly-2025.7.15.dist-info → torchrl_nightly-2025.7.16.dist-info}/top_level.txt
RENAMED
File without changes