torchrl-0.10.0-cp39-cp39-win_amd64.whl → torchrl-0.10.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/test_llm.py +5 -0
- torchrl/__init__.py +4 -1
- torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
- torchrl/_utils.py +3 -1
- torchrl/collectors/collectors.py +11 -10
- torchrl/collectors/distributed/generic.py +3 -3
- torchrl/collectors/distributed/ray.py +10 -4
- torchrl/collectors/distributed/rpc.py +3 -3
- torchrl/collectors/distributed/sync.py +3 -3
- torchrl/data/map/tree.py +2 -2
- torchrl/data/tensor_specs.py +191 -8
- torchrl/envs/batched_envs.py +1 -1
- torchrl/envs/common.py +1 -1
- torchrl/envs/custom/llm.py +3 -3
- torchrl/envs/llm/envs.py +3 -3
- torchrl/envs/transforms/transforms.py +2 -2
- torchrl/modules/distributions/discrete.py +1 -1
- torchrl/modules/llm/backends/vllm/vllm_async.py +1 -1
- torchrl/modules/llm/policies/transformers_wrapper.py +2 -1
- torchrl/modules/llm/policies/vllm_wrapper.py +1 -0
- torchrl/objectives/a2c.py +3 -3
- torchrl/objectives/cql.py +2 -2
- torchrl/objectives/crossq.py +2 -2
- torchrl/objectives/ddpg.py +1 -1
- torchrl/objectives/decision_transformer.py +2 -2
- torchrl/objectives/deprecated.py +2 -2
- torchrl/objectives/dqn.py +4 -4
- torchrl/objectives/gail.py +1 -1
- torchrl/objectives/iql.py +4 -4
- torchrl/objectives/multiagent/qmixer.py +1 -1
- torchrl/objectives/redq.py +2 -2
- torchrl/objectives/reinforce.py +3 -3
- torchrl/objectives/sac.py +5 -5
- torchrl/objectives/td3.py +2 -2
- torchrl/objectives/td3_bc.py +2 -2
- torchrl/record/loggers/wandb.py +3 -3
- {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/METADATA +1 -1
- {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/RECORD +42 -41
- torchrl-0.10.1.dist-info/entry_points.txt +2 -0
- {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/LICENSE +0 -0
- {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/WHEEL +0 -0
- {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/top_level.txt +0 -0
benchmarks/test_llm.py
CHANGED
@@ -16,6 +16,11 @@ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
 
 _has_transformers = importlib.import_module("transformers") is not None
 
+# Skip all these tests if gpu is not available
+pytestmark = pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="GPU not available"
+)
+
 
 @pytest.fixture(scope="module")
 def transformers_wrapper():
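A module-level `pytestmark` applies its marker to every test collected from the file, which is what makes the four added lines above skip the whole benchmark module on CPU-only machines. A minimal standalone sketch (hypothetical test module, not part of the wheel):

    import pytest
    import torch

    # Applies to every test in this module.
    pytestmark = pytest.mark.skipif(
        not torch.cuda.is_available(), reason="GPU not available"
    )

    def test_needs_gpu():
        # Never runs on CPU-only machines; pytest reports it as skipped.
        assert torch.zeros(1, device="cuda").is_cuda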
torchrl/__init__.py
CHANGED
@@ -27,7 +27,10 @@ from ._extension import _init_extension
 try:
     from .version import __version__
 except ImportError:
-    __version__ = None
+    try:
+        from ._version import __version__
+    except ImportError:
+        __version__ = "0.0.0+unknown"
 
 try:
     from torch.compiler import is_dynamo_compiling
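The nested fallback keeps `torchrl.__version__` importable even when no version file was generated. A standalone sketch of the same chain for a hypothetical package `mypkg` (that `_version.py` is a build-time artifact is an assumption):

    def resolve_version() -> str:
        try:
            from mypkg.version import __version__  # written by the release build (assumption)
        except ImportError:
            try:
                from mypkg._version import __version__  # dev/source-build artifact (assumption)
            except ImportError:
                __version__ = "0.0.0+unknown"  # sentinel when neither file exists
        return __version__

    print(resolve_version())  # "0.0.0+unknown" if mypkg is absent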
torchrl/_torchrl.cp39-win_amd64.pyd
CHANGED
Binary file
torchrl/_utils.py
CHANGED
@@ -410,7 +410,9 @@ def accept_remote_rref_udf_invocation(decorated_class):
     """Class decorator that applies `accept_remote_rref_invocation` to all public methods."""
     # ignores private methods
     for name in dir(decorated_class):
-        method = getattr(decorated_class, name)
+        method = getattr(decorated_class, name, None)
+        if method is None:
+            continue
         if callable(method) and not name.startswith("_"):
             setattr(decorated_class, name, accept_remote_rref_invocation(method))
     return decorated_class
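The three-argument `getattr` matters because `dir()` can advertise names that do not resolve to attributes, in which case the two-argument form raises `AttributeError` mid-loop. A minimal illustration (the metaclass forcing a phantom name is contrived):

    class Meta(type):
        def __dir__(cls):
            # dir() trusts __dir__, so it can list names with no backing attribute
            return [*super().__dir__(), "phantom"]

    class Decorated(metaclass=Meta):
        def method(self):
            return 42

    for name in dir(Decorated):
        method = getattr(Decorated, name, None)  # None default skips "phantom"
        if method is None:
            continue
        # ... wrap public callables here, as the decorator above does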
torchrl/collectors/collectors.py
CHANGED
@@ -283,12 +283,13 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
     ) -> None:
         """Shuts down the collector when started asynchronously with the `start` method.
 
-
+        Args:
             timeout (float, optional): The maximum time to wait for the collector to shutdown.
             close_env (bool, optional): If True, the collector will close the contained environment.
                 Defaults to `True`.
 
         .. seealso:: :meth:`~.start`
+
         """
         return self.shutdown(timeout=timeout, close_env=close_env)
 
@@ -440,7 +441,7 @@ class SyncDataCollector(DataCollectorBase):
         - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -1784,7 +1785,7 @@ class _MultiDataCollector(DataCollectorBase):
         ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -2749,8 +2750,8 @@ class MultiSyncDataCollector(_MultiDataCollector):
         ...     if i == 2:
         ...         print(data)
         ...         break
-
-
+        ... collector.shutdown()
+        ... del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3130,8 +3131,8 @@ class MultiaSyncDataCollector(_MultiDataCollector):
         ...     if i == 2:
         ...         print(data)
         ...         break
-        ...
-        ...
+        ... collector.shutdown()
+        ... del collector
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3366,7 +3367,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
         - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -3380,8 +3381,8 @@ class aSyncDataCollector(MultiaSyncDataCollector):
             total number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
             ``frames_per_batch``, an exception is raised.
-
-
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (never ending collector).
         device (int, str or torch.device, optional): The generic device of the
             collector. The ``device`` args fills any non-specified device: if
             ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
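The docstrings above now state the `total_frames=-1` default explicitly. A sketch of an endless collector stopped by hand (assumes gymnasium is installed; passing `policy=None` to fall back to random actions is an assumption about the collector's defaults):

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv

    collector = SyncDataCollector(
        GymEnv("Pendulum-v1"),
        policy=None,           # assumption: None falls back to a random policy
        frames_per_batch=64,
        total_frames=-1,       # endless: iterate until we break out
    )
    for i, batch in enumerate(collector):
        if i == 2:
            break
    collector.shutdown()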
torchrl/collectors/distributed/generic.py
CHANGED
@@ -282,7 +282,7 @@ class DistributedDataCollector(DataCollectorBase):
         - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -296,8 +296,8 @@ class DistributedDataCollector(DataCollectorBase):
             number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
             ``frames_per_batch``, an exception is raised.
-
-
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (endless collector).
         device (int, str or torch.device, optional): The generic device of the
             collector. The ``device`` args fills any non-specified device: if
             ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or

torchrl/collectors/distributed/ray.py
CHANGED
@@ -131,7 +131,7 @@ class RayCollector(DataCollectorBase):
         - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -263,6 +263,10 @@ class RayCollector(DataCollectorBase):
             If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging
             Ray's distributed capabilities.
             Consider using a constructor if the updater needs to be serialized.
+        use_env_creator (bool, optional): if ``True``, the environment constructor functions will be wrapped
+            in :class:`~torchrl.envs.EnvCreator`. This is useful for multiprocessed settings where shared memory
+            needs to be managed, but Ray has its own object storage mechanism, so this is typically not needed.
+            Defaults to ``False``.
 
     Examples:
         >>> from torch import nn
@@ -326,6 +330,7 @@ class RayCollector(DataCollectorBase):
         weight_updater: WeightUpdaterBase
         | Callable[[], WeightUpdaterBase]
         | None = None,
+        use_env_creator: bool = False,
     ):
         self.frames_per_batch = frames_per_batch
         if remote_configs is None:
@@ -400,9 +405,10 @@ class RayCollector(DataCollectorBase):
         create_env_fn, collector_kwargs, remote_configs = out_lists
         num_collectors = len(create_env_fn)
 
-        for i in range(len(create_env_fn)):
-            if not isinstance(create_env_fn[i], (EnvBase, EnvCreator)):
-                create_env_fn[i] = EnvCreator(create_env_fn[i])
+        if use_env_creator:
+            for i in range(len(create_env_fn)):
+                if not isinstance(create_env_fn[i], (EnvBase, EnvCreator)):
+                    create_env_fn[i] = EnvCreator(create_env_fn[i])
 
         # If ray available, try to connect to an existing Ray cluster or start one and connect to it.
         if not _has_ray:

torchrl/collectors/distributed/rpc.py
CHANGED
@@ -121,7 +121,7 @@ class RPCDataCollector(DataCollectorBase):
         - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -135,8 +135,8 @@ class RPCDataCollector(DataCollectorBase):
             number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
             ``frames_per_batch``, an exception is raised.
-
-
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (endless collector).
         device (int, str or torch.device, optional): The generic device of the
             collector. The ``device`` args fills any non-specified device: if
             ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or

torchrl/collectors/distributed/sync.py
CHANGED
@@ -158,7 +158,7 @@ class DistributedSyncDataCollector(DataCollectorBase):
         - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -172,8 +172,8 @@ class DistributedSyncDataCollector(DataCollectorBase):
             number of frames returned by the collector
             during its lifespan. If the ``total_frames`` is not divisible by
             ``frames_per_batch``, an exception is raised.
-
-
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (endless collector).
         device (int, str or torch.device, optional): The generic device of the
             collector. The ``device`` args fills any non-specified device: if
             ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
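A sketch of the new `use_env_creator` flag on `RayCollector` (assumes `ray` and gymnasium are installed and that Ray may start a local cluster; the `policy=None` fallback is an assumption):

    from torchrl.collectors.distributed import RayCollector
    from torchrl.envs import GymEnv

    collector = RayCollector(
        [lambda: GymEnv("Pendulum-v1")],  # env constructor functions
        policy=None,                      # assumption: None -> random policy
        frames_per_batch=64,
        total_frames=128,
        use_env_creator=True,             # new in 0.10.1: wrap constructors in EnvCreator
    )
    for batch in collector:
        pass
    collector.shutdown()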
torchrl/data/map/tree.py
CHANGED
@@ -610,7 +610,7 @@ class Tree(TensorClass["nocast"]):
         This function can pull out information from each of the nodes in a tree,
         so it can be useful for debugging. The nodes are listed line-by-line.
         Each line contains the path to the node, followed by the string
-        representation of that node generated with
+        representation of that node generated with ``node_format_fn``. Each
         line is indented according to number of steps in the path required to
         get to the corresponding node.
 
@@ -1370,7 +1370,7 @@ class MCTSForest:
         This function can pull out information from each of the nodes in a tree,
         so it can be useful for debugging. The nodes are listed line-by-line.
         Each line contains the path to the node, followed by the string
-        representation of that node generated with
+        representation of that node generated with ``node_format_fn``. Each
         line is indented according to number of steps in the path required to
         get to the corresponding node.
 
torchrl/data/tensor_specs.py
CHANGED
@@ -13,7 +13,7 @@ import math
 import warnings
 import weakref
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import deepcopy
+from copy import copy, deepcopy
 from dataclasses import dataclass, field
 from functools import wraps
 from textwrap import indent
@@ -5095,6 +5095,7 @@ class Composite(TensorSpec):
 
     shape: torch.Size
     domain: str = "composite"
+    _td_dim_names: list[str] | None = None
 
     SPEC_HANDLED_FUNCTIONS = {}
 
@@ -5111,6 +5112,7 @@ class Composite(TensorSpec):
         device: torch.device | None = None,
         data_cls: type | None = None,
         step_mdp_static: bool = False,
+        names: Sequence[str] | None = None,
         **kwargs,
     ):
         # For compatibility with TensorDict
@@ -5126,6 +5128,12 @@ class Composite(TensorSpec):
         self._specs = {}
         self.step_mdp_static = step_mdp_static
 
+        # Initialize names
+        if names is not None:
+            self._td_dim_names = list(names)
+        else:
+            self._td_dim_names = None
+
         _device = (
             _make_ordinal_device(torch.device(device)) if device is not None else device
         )
@@ -5142,6 +5150,8 @@ class Composite(TensorSpec):
         )
         for k, item in argdict.items():
             if isinstance(item, dict):
+                # Create nested Composite with appropriate names
+                # Note: nested specs will get their names propagated later in the names setter
                 item = Composite(item, shape=shape, device=_device)
             self[k] = item
         for k, item in kwargs.items():
@@ -5150,6 +5160,10 @@ class Composite(TensorSpec):
         self.encode = self._encode_eager
         self._encode_memo_dict = {}
 
+        # Propagate names to nested specs if names were provided
+        if names is not None:
+            self._propagate_names_to_nested()
+
     def memoize_encode(self, mode: bool = True) -> None:
         super().memoize_encode(mode=mode)
         for spec in self._specs.values():
@@ -5354,6 +5368,127 @@ class Composite(TensorSpec):
             spec.clear_device_()
         return self
 
+    def _has_names(self):
+        """Returns True if names are set for this Composite."""
+        return self._td_dim_names is not None
+
+    def _erase_names(self):
+        """Erases the names of this Composite."""
+        self._td_dim_names = None
+
+    def _propagate_names_to_nested(self):
+        """Propagates names to nested Composite specs."""
+        if not self._has_names():
+            return
+        for spec in self._specs.values():
+            if isinstance(spec, Composite):
+                # For nested specs, we need to propagate the names
+                # The nested spec should have the same leading dimensions
+                if spec.ndim >= self.ndim:
+                    nested_names = list(self.names) + [None] * (spec.ndim - self.ndim)
+                    spec.names = nested_names
+
+    @property
+    def names(self):
+        """Returns the names of the dimensions of this Composite."""
+        names = self._td_dim_names
+        if names is None:
+            return [None for _ in range(self.ndim)]
+        # Return a copy but don't use copy to make dynamo happy
+        return list(names)
+
+    @names.setter
+    def names(self, value):
+        """Sets the names of the dimensions of this Composite."""
+        if value is None:
+            self._td_dim_names = None
+            return
+        if len(value) != self.ndim:
+            raise ValueError(
+                f"Expected {self.ndim} names, but got {len(value)} names: {value}"
+            )
+        self._td_dim_names = list(value)
+        # Propagate names to nested Composite specs
+        for spec in self._specs.values():
+            if isinstance(spec, Composite):
+                # For nested specs, we need to propagate the names
+                # The nested spec should have the same leading dimensions
+                if spec.ndim >= self.ndim:
+                    nested_names = list(value) + [None] * (spec.ndim - self.ndim)
+                    spec.names = nested_names
+
+    def refine_names(self, *names):
+        """Refines the dimension names of self according to names.
+
+        Refining is a special case of renaming that "lifts" unnamed dimensions.
+        A None dim can be refined to have any name; a named dim can only be
+        refined to have the same name.
+
+        Because named specs can coexist with unnamed specs, refining names
+        gives a nice way to write named-spec-aware code that works with both
+        named and unnamed specs.
+
+        names may contain up to one Ellipsis (...). The Ellipsis is expanded
+        greedily; it is expanded in-place to fill names to the same length as
+        self.ndim using names from the corresponding indices of self.names.
+
+        Returns: the same composite spec with dimensions named according to the input.
+
+        Examples:
+            >>> spec = Composite({}, shape=[3, 4, 5, 6])
+            >>> spec_refined = spec.refine_names(None, None, None, "d")
+            >>> assert spec_refined.names == [None, None, None, "d"]
+            >>> spec_refined = spec.refine_names("a", None, None, "d")
+            >>> assert spec_refined.names == ["a", None, None, "d"]
+
+        """
+        # replace ellipsis if any
+        names_copy = copy(names)
+        if any(name is Ellipsis for name in names):
+            ellipsis_name = [None for _ in range(self.ndim - len(names) + 1)]
+            names = []
+            for name in names_copy:
+                if name is Ellipsis:
+                    names += ellipsis_name
+                else:
+                    names.append(name)
+
+        # check that the names that are set are either None or identical
+        curr_names = self.names
+        for i, name in enumerate(names):
+            if curr_names[i] is None:
+                continue
+            if curr_names[i] == name:
+                continue
+            else:
+                raise RuntimeError(
+                    f"refine_names: cannot coerce Composite names {self.names} with {names_copy}."
+                )
+        self.names = names
+        return self
+
+    def _get_names_idx(self, idx):
+        """Helper method to get names after indexing."""
+        if not self._has_names():
+            return None
+
+        names = copy(self.names)
+        if isinstance(idx, (int, slice)):
+            # Single dimension indexing
+            if isinstance(idx, int):
+                names.pop(idx)
+            else:
+                # For slice, we keep the names but adjust for the slice
                pass
+        elif isinstance(idx, tuple):
+            # Multi-dimensional indexing
+            for i, sub_idx in enumerate(idx):
+                if isinstance(sub_idx, int):
+                    # Remove the dimension
+                    names.pop(i)
+                # For slices, we keep the name
+        return names
+
     def __getitem__(self, idx):
         """Indexes the current Composite based on the provided index."""
         if isinstance(idx, (str, tuple)):
@@ -5393,10 +5528,15 @@ class Composite(TensorSpec):
         except RuntimeError:
             device = self._device
 
+        names = None
+        if self._has_names():
+            names = self._get_names_idx(idx)
+
         return self.__class__(
             indexed_specs,
             shape=indexed_shape,
             device=device,
+            names=names,
         )
 
     def get(self, item, default=NO_DEFAULT):
@@ -5600,16 +5740,22 @@ class Composite(TensorSpec):
         for key, item in self.items():
             if item is not None:
                 _dict[key] = item.rand(shape)
-        if self.data_cls is not None:
-            cls = self.data_cls
+
+        cls = self.data_cls if self.data_cls is not None else TensorDict
+        if cls is not TensorDict:
+            kwargs = {}
+            if self._td_dim_names is not None:
+                warnings.warn(f"names for cls {cls} is not supported for rand.")
         else:
-            cls = TensorDict
+            kwargs = {"names": self._td_dim_names}
+
         # No need to run checks since we know Composite is compliant with
         # TensorDict requirements
         return cls.from_dict(
             _dict,
             batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
             device=self.device,
+            **kwargs,
         )
 
     def keys(
@@ -5760,6 +5906,7 @@ class Composite(TensorSpec):
             shape=self.shape,
             data_cls=self.data_cls,
             step_mdp_static=self.step_mdp_static,
+            names=self.names if self._has_names() else None,
         )
         if not isinstance(dest, (str, int, torch.device)):
             raise ValueError(
@@ -5782,6 +5929,7 @@ class Composite(TensorSpec):
             shape=self.shape,
             data_cls=self.data_cls,
             step_mdp_static=self.step_mdp_static,
+            names=self.names if self._has_names() else None,
         )
 
     def clone(self) -> Composite:
@@ -5802,6 +5950,7 @@ class Composite(TensorSpec):
             shape=self.shape,
             data_cls=self.data_cls,
             step_mdp_static=self.step_mdp_static,
+            names=self.names if self._has_names() else None,
         )
 
     def cardinality(self) -> int:
@@ -5874,10 +6023,13 @@ class Composite(TensorSpec):
         except RuntimeError:
             device = self._device
 
-        if self.data_cls is not None:
-            cls = self.data_cls
+        cls = self.data_cls if self.data_cls is not None else TensorDict
+        if cls is not TensorDict:
+            kwargs = {}
+            if self._td_dim_names is not None:
+                warnings.warn(f"names for cls {cls} is not supported for zero.")
         else:
-            cls = TensorDict
+            kwargs = {"names": self._td_dim_names}
 
         return cls.from_dict(
             {
@@ -5887,6 +6039,7 @@ class Composite(TensorSpec):
             },
             batch_size=_size([*shape, *self._safe_shape]),
             device=device,
+            **kwargs,
         )
 
     def __eq__(self, other: object) -> bool:
@@ -5942,12 +6095,17 @@ class Composite(TensorSpec):
             else None
             for key, value in tuple(self.items())
         }
+        names = None
+        if self._has_names():
+            names = [None] * (len(shape) - self.ndim) + self.names
+
         out = Composite(
             specs,
             shape=shape,
             device=device,
             data_cls=self.data_cls,
             step_mdp_static=self.step_mdp_static,
+            names=names,
         )
         return out
 
@@ -5965,12 +6123,21 @@ class Composite(TensorSpec):
         except RuntimeError:
             device = self._device
 
+        names = None
+        if self._has_names():
+            names = copy(self.names)
+            names.pop(dim)
+            # If all names are None after popping, set to None
+            if all(name is None for name in names):
+                names = None
+
         return self.__class__(
             {key: value.squeeze(dim) for key, value in self.items()},
             shape=shape,
             device=device,
             data_cls=self.data_cls,
             step_mdp_static=self.step_mdp_static,
+            names=names,
         )
 
         if self.shape.count(1) == 0:
@@ -5993,6 +6160,11 @@ class Composite(TensorSpec):
         except RuntimeError:
             device = self._device
 
+        names = None
+        if self._has_names():
+            names = copy(self.names)
+            names.insert(dim, None)
+
         return self.__class__(
             {
                 key: value.unsqueeze(dim) if value is not None else None
@@ -6002,6 +6174,7 @@ class Composite(TensorSpec):
             device=device,
             data_cls=self.data_cls,
             step_mdp_static=self.step_mdp_static,
+            names=names,
         )
 
     def unbind(self, dim: int = 0) -> tuple[Composite, ...]:
@@ -6012,8 +6185,17 @@ class Composite(TensorSpec):
             raise ValueError(
                 f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
             )
-        shape = (s for i, s in enumerate(self.shape) if i != dim)
+        shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
         unbound_vals = {key: val.unbind(dim) for key, val in self.items()}
+
+        names = None
+        if self._has_names():
+            names = copy(self.names)
+            names.pop(dim)
+            # If all names are None after popping, set to None
+            if all(name is None for name in names):
+                names = None
+
         return tuple(
             self.__class__(
                 {key: val[i] for key, val in unbound_vals.items()},
@@ -6021,6 +6203,7 @@ class Composite(TensorSpec):
                 device=self.device,
                 data_cls=self.data_cls,
                 step_mdp_static=self.step_mdp_static,
+                names=names,
             )
             for i in range(self.shape[dim])
         )
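Taken together, the hunks above thread dimension names through construction, indexing, `expand`, `squeeze`/`unsqueeze`, `unbind`, `clone`, and `rand`/`zero`. A short sketch of the resulting behavior (using `Unbounded` for brevity; the assertions mirror the setter and `expand` logic shown above):

    from torchrl.data import Composite, Unbounded

    spec = Composite(
        observation=Unbounded(shape=(4, 3)),
        shape=(4,),
        names=["time"],               # new in 0.10.1
    )
    assert spec.names == ["time"]
    spec = spec.refine_names("time")  # no-op: dimension already carries that name
    expanded = spec.expand(2, 4)      # new leading dim is unnamed
    assert expanded.names == [None, "time"]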
torchrl/envs/batched_envs.py
CHANGED
@@ -308,7 +308,7 @@ class BatchedEnvBase(EnvBase):
         num_sub_threads: int = 1,
         serial_for_single: bool = False,
         non_blocking: bool = False,
-        mp_start_method: str = None,
+        mp_start_method: str | None = None,
         use_buffers: bool | None = None,
         consolidate: bool = True,
     ):
torchrl/envs/common.py
CHANGED
@@ -2267,7 +2267,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
         entry_point: Callable | None = None,
         transform: Transform | None = None,  # noqa: F821
         info_keys: list[NestedKey] | None = None,
-        backend: str = None,
+        backend: str | None = None,
         to_numpy: bool = False,
         reward_threshold: float | None = None,
         nondeterministic: bool = False,
torchrl/envs/custom/llm.py
CHANGED
@@ -28,10 +28,10 @@ class LLMHashingEnv(EnvBase):
     The primary goal of this environment is to identify token chains using a hashing function.
     This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
     identifiers, or easily prune repeated token chains in a data structure.
-    The following figure gives an overview of this workflow:
 
-    .. figure
-
+    .. The following figure gives an overview of this workflow:
+    .. .. figure:: /_static/img/rollout-llm.png
+    ..    :alt: Data collection loop with our LLM environment.
 
     Args:
         vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
torchrl/envs/llm/envs.py
CHANGED
@@ -601,10 +601,10 @@ class LLMHashingEnv(EnvBase):
     The primary goal of this environment is to identify token chains using a hashing function.
     This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
     identifiers, or easily prune repeated token chains in a data structure.
-    The following figure gives an overview of this workflow:
 
-    .. figure
-
+    .. The following figure gives an overview of this workflow:
+    .. .. figure:: /_static/img/rollout-llm.png
+    ..    :alt: Data collection loop with our LLM environment.
 
     Args:
         vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.

torchrl/envs/transforms/transforms.py
CHANGED
@@ -5423,8 +5423,8 @@ class Hash(UnaryTransform):
         """Look up the input that was given for a particular hash output.
 
         This feature is only available if, during initialization, either the
-
-
+        ``repertoire`` argument was given or both the ``in_keys_inv`` and
+        ``out_keys_inv`` arguments were given.
 
         Args:
             hash_tensor (Tensor): The hash output.

torchrl/modules/distributions/discrete.py
CHANGED
@@ -622,7 +622,7 @@ class Ordinal(D.Categorical):
     not impose any notion of proximity or ordering over its support's atoms.
     The `Ordinal` distribution explicitly encodes those concepts, which is
     useful for learning discrete sampling from continuous sets. See §5 of
-    `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
+    `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
 
     .. note::
         This class is mostly useful when you want to learn a distribution over

torchrl/modules/llm/backends/vllm/vllm_async.py
CHANGED
@@ -526,7 +526,7 @@ class AsyncVLLM(RLvLLMEngine):
     See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
 
     Example:
-        >>> from torchrl.modules.llm
+        >>> from torchrl.modules.llm import AsyncVLLM
        >>> from vllm import SamplingParams
        >>>
        >>> # Simple usage - single GPU, single replica

torchrl/modules/llm/policies/transformers_wrapper.py
CHANGED
@@ -172,6 +172,7 @@ class TransformersWrapper(LLMWrapperBase):
 
     Input Keys:
         The input key depends on both `input_mode` and `generate`:
+
         - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
         - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
         - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
@@ -2460,7 +2461,7 @@ class RemoteTransformersWrapper:
         model,
         max_concurrency: int = 16,
         validate_model: bool = True,
-        actor_name: str = None,
+        actor_name: str | None = None,
         num_gpus: int = 1,
         num_cpus: int = 1,
         **kwargs,

torchrl/modules/llm/policies/vllm_wrapper.py
CHANGED
@@ -194,6 +194,7 @@ class vLLMWrapper(LLMWrapperBase):
 
     Input Keys:
         The input key depends on both `input_mode` and `generate`:
+
         - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
         - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
         - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
torchrl/objectives/a2c.py
CHANGED
@@ -282,12 +282,12 @@ class A2CLoss(LossModule):
         loss_critic_type: str = "smooth_l1",
         gamma: float | None = None,
         separate_losses: bool = False,
-        advantage_key: str = None,
-        value_target_key: str = None,
+        advantage_key: str | None = None,
+        value_target_key: str | None = None,
         functional: bool = True,
         actor: ProbabilisticTensorDictSequential = None,
         critic: ProbabilisticTensorDictSequential = None,
-        reduction: str = None,
+        reduction: str | None = None,
         clip_value: float | None = None,
         **kwargs,
     ):
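The recurring annotation change in this and the following objective modules is purely a typing fix: `reduction: str = None` declares a `str` while defaulting to `None`, which strict type checkers reject. The runtime defaulting is unchanged, as this sketch of the shared pattern shows:

    def resolve_reduction(reduction: str | None = None) -> str:
        # PEP 604 syntax makes the None default explicit for type checkers.
        if reduction is None:
            reduction = "mean"
        return reduction

    assert resolve_reduction() == "mean"
    assert resolve_reduction("sum") == "sum"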
torchrl/objectives/cql.py
CHANGED
@@ -291,7 +291,7 @@ class CQLLoss(LossModule):
         num_random: int = 10,
         with_lagrange: bool = False,
         lagrange_thresh: float = 0.0,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ) -> None:
         self._out_keys = None
@@ -1100,7 +1100,7 @@ class DiscreteCQLLoss(LossModule):
         delay_value: bool = True,
         gamma: float | None = None,
         action_space=None,
-        reduction: str = None,
+        reduction: str | None = None,
     ) -> None:
         self._in_keys = None
         if reduction is None:
torchrl/objectives/crossq.py
CHANGED
@@ -266,9 +266,9 @@ class CrossQLoss(LossModule):
         action_spec=None,
         fixed_alpha: bool = False,
         target_entropy: str | float = "auto",
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
torchrl/objectives/ddpg.py
CHANGED
torchrl/objectives/decision_transformer.py
CHANGED
@@ -85,7 +85,7 @@ class OnlineDTLoss(LossModule):
         fixed_alpha: bool = False,
         target_entropy: str | float = "auto",
         samples_mc_entropy: int = 1,
-        reduction: str = None,
+        reduction: str | None = None,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -296,7 +296,7 @@ class DTLoss(LossModule):
         actor_network: ProbabilisticActor,
         *,
         loss_function: str = "l2",
-        reduction: str = None,
+        reduction: str | None = None,
         device: torch.device | None = None,
     ) -> None:
         self._in_keys = None
torchrl/objectives/deprecated.py
CHANGED
@@ -163,9 +163,9 @@ class REDQLoss_deprecated(LossModule):
         delay_qvalue: bool = True,
         gSDE: bool = False,
         gamma: float | None = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ):
         self._in_keys = None
torchrl/objectives/dqn.py
CHANGED
@@ -179,8 +179,8 @@ class DQNLoss(LossModule):
         double_dqn: bool = False,
         gamma: float | None = None,
         action_space: str | TensorSpec = None,
-        priority_key: str = None,
-        reduction: str = None,
+        priority_key: str | None = None,
+        reduction: str | None = None,
     ) -> None:
         if reduction is None:
             reduction = "mean"
@@ -455,8 +455,8 @@ class DistributionalDQNLoss(LossModule):
         *,
         gamma: float,
         delay_value: bool = True,
-        priority_key: str = None,
-        reduction: str = None,
+        priority_key: str | None = None,
+        reduction: str | None = None,
     ):
         if reduction is None:
             reduction = "mean"
torchrl/objectives/gail.py
CHANGED
torchrl/objectives/iql.py
CHANGED
@@ -266,9 +266,9 @@ class IQLLoss(LossModule):
         temperature: float = 1.0,
         expectile: float = 0.5,
         gamma: float | None = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
@@ -785,9 +785,9 @@ class DiscreteIQLLoss(IQLLoss):
         temperature: float = 1.0,
         expectile: float = 0.5,
         gamma: float | None = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
     ) -> None:
         self._in_keys = None
         self._out_keys = None

torchrl/objectives/multiagent/qmixer.py
CHANGED
@@ -195,7 +195,7 @@ class QMixerLoss(LossModule):
         delay_value: bool = True,
         gamma: float | None = None,
         action_space: str | TensorSpec = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
     ) -> None:
         super().__init__()
         self._in_keys = None
torchrl/objectives/redq.py
CHANGED
@@ -279,9 +279,9 @@ class REDQLoss(LossModule):
         delay_qvalue: bool = True,
         gSDE: bool = False,
         gamma: float | None = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ):
         if reduction is None:
torchrl/objectives/reinforce.py
CHANGED
@@ -249,13 +249,13 @@ class ReinforceLoss(LossModule):
         delay_value: bool = False,
         loss_critic_type: str = "smooth_l1",
         gamma: float | None = None,
-        advantage_key: str = None,
-        value_target_key: str = None,
+        advantage_key: str | None = None,
+        value_target_key: str | None = None,
         separate_losses: bool = False,
         functional: bool = True,
         actor: ProbabilisticTensorDictSequential = None,
         critic: ProbabilisticTensorDictSequential = None,
-        reduction: str = None,
+        reduction: str | None = None,
         clip_value: float | None = None,
     ) -> None:
         if actor is not None:
torchrl/objectives/sac.py
CHANGED
@@ -325,16 +325,16 @@ class SACLoss(LossModule):
         alpha_init: float = 1.0,
         min_alpha: float | None = None,
         max_alpha: float | None = None,
-        action_spec=None,
+        action_spec: TensorSpec | None = None,
         fixed_alpha: bool = False,
         target_entropy: str | float = "auto",
         delay_actor: bool = False,
         delay_qvalue: bool = True,
         delay_value: bool = True,
         gamma: float | None = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         skip_done_states: bool = False,
         deactivate_vmap: bool = False,
     ) -> None:
@@ -1195,9 +1195,9 @@ class DiscreteSACLoss(LossModule):
         target_entropy_weight: float = 0.98,
         target_entropy: str | Number = "auto",
         delay_qvalue: bool = True,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         skip_done_states: bool = False,
         deactivate_vmap: bool = False,
     ):
torchrl/objectives/td3.py
CHANGED
@@ -236,9 +236,9 @@ class TD3Loss(LossModule):
         delay_actor: bool = True,
         delay_qvalue: bool = True,
         gamma: float | None = None,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ) -> None:
         if reduction is None:
torchrl/objectives/td3_bc.py
CHANGED
@@ -251,9 +251,9 @@ class TD3BCLoss(LossModule):
         loss_function: str = "smooth_l1",
         delay_actor: bool = True,
         delay_qvalue: bool = True,
-        priority_key: str = None,
+        priority_key: str | None = None,
         separate_losses: bool = False,
-        reduction: str = None,
+        reduction: str | None = None,
         deactivate_vmap: bool = False,
     ) -> None:
         if reduction is None:
torchrl/record/loggers/wandb.py
CHANGED
@@ -52,9 +52,9 @@ class WandbLogger(Logger):
         self,
         exp_name: str,
         offline: bool = False,
-        save_dir: str = None,
-        id: str = None,
-        project: str = None,
+        save_dir: str | None = None,
+        id: str | None = None,
+        project: str | None = None,
         *,
         video_fps: int = 32,
         **kwargs,
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: torchrl
-Version: 0.10.0
+Version: 0.10.1
 Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
 Author-email: torchrl contributors <vmoens@fb.com>
 Maintainer-email: torchrl contributors <vmoens@fb.com>
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/RECORD
CHANGED
@@ -4,7 +4,7 @@ benchmarks\requirements.txt,sha256=zq-bWlShbTeSnj-Ud4NDgGngvWR_jdpnTAlMahTjr3k,7
 benchmarks\test_collectors_benchmark.py,sha256=-8MQHzK7ItL4eQRuE8V40WP3EJVXktwDtgzGN23rJy8,6912
 benchmarks\test_compressed_storage_benchmark.py,sha256=8NjVwRGKxi5CSuq9O9xLzM6MYG8R3a1afBjaXtGQjPY,6074
 benchmarks\test_envs_benchmark.py,sha256=ntXAMgLSKtasErZso_lsM_F2ae6gnuBS-Ow0tI4OPoM,3653
-benchmarks\test_llm.py,sha256=
+benchmarks\test_llm.py,sha256=kHONY3smLA1JbyCBbSSeX8XSiapD6SiFxNKwikMf2-I,4038
 benchmarks\test_objectives_benchmarks.py,sha256=EipTFZNKV_VcofGHTyIQiKVQv45k_u1QHrAMgncNB8I,34608
 benchmarks\test_replaybuffer_benchmark.py,sha256=3VoDtjg0h3q76em1Ptq75jh_qfeHplYR4d3e1vTFsRc,8274
 benchmarks\ecosystem\gym_env_throughput.py,sha256=_QDRDLoq0LfGHNWV_o0E9s_k-YUuvcyWR0v3-w-Sjig,14282
@@ -92,18 +92,18 @@ sota-implementations\td3\utils.py,sha256=AU18AXsfPXnsOOB7wKgo_5pe1wFobquJywnI0ay
 sota-implementations\td3_bc\td3_bc.py,sha256=GUGzbGkaR7MDbmhvMCs69Ewd-iCVjaQCEt6dxcSE8yw,5927
 sota-implementations\td3_bc\utils.py,sha256=ZMMYehJP6KZaiiwTRlJRFKC528M49QhfEb_RAxaBj-U,6951
 torchrl\_extension.py,sha256=5DqUOUHZJPLqcSvGztUUBy4rJFAzznuJCAu5c84OJo8,2981
-torchrl\_torchrl.cp39-win_amd64.pyd,sha256=
-torchrl\_utils.py,sha256=
-torchrl\__init__.py,sha256=
-torchrl\collectors\collectors.py,sha256=
+torchrl\_torchrl.cp39-win_amd64.pyd,sha256=yPDxj_5ErX7tV2rp4VD2LeYx0pY9DVj0zFEETqCDaFc,418816
+torchrl\_utils.py,sha256=2CYX3KGAendGu5VCE1r_oRr_rGFCHd6jIUa1GMASLg4,32554
+torchrl\__init__.py,sha256=KINiNwlkd5Qj3ts2XIGQpk8P10aaPERlR1iLegQgtpM,3190
+torchrl\collectors\collectors.py,sha256=BdkinVV0S9rhd_ypGalsgm-5eOFBqurR2h0jozT8SQE,183779
 torchrl\collectors\utils.py,sha256=DLHoLf9MvjAmJssFuG1GE9wDpjIuWgmbmMrnphg7mwc,11588
 torchrl\collectors\weight_update.py,sha256=S9aSLYV6MtNMk23TM9ZY5nsk72s7bUJ2wnslh7k8sSI,24504
 torchrl\collectors\__init__.py,sha256=jdZJgqPB15BBKHgd46Md1FYy6GaQjGJBIsQ2mAo8W_E,898
 torchrl\collectors\distributed\default_configs.py,sha256=Kvbn84NRz9l7mBk_681P9I4AQ4hyeNPBfzLqydiFijQ,895
-torchrl\collectors\distributed\generic.py,sha256=
-torchrl\collectors\distributed\ray.py,sha256=
-torchrl\collectors\distributed\rpc.py,sha256=
-torchrl\collectors\distributed\sync.py,sha256=
+torchrl\collectors\distributed\generic.py,sha256=tlxFJEYeaYJXML6kUgvykwsXj5nuZGsKWUKWAnOMpYM,46521
+torchrl\collectors\distributed\ray.py,sha256=kFuamafzMlZL7uT5ALzgm8vVb5GXfcibG7CXN3vqxN8,37082
+torchrl\collectors\distributed\rpc.py,sha256=kJZUX-5bFZGTvjRnYnsXumD9eESAkZIz7Yzn2IVpYvA,40492
+torchrl\collectors\distributed\sync.py,sha256=Tg4dlYGpxrTgILXJAugH8c0ZzrIIdW3vvAxiQF1jnAA,27933
 torchrl\collectors\distributed\utils.py,sha256=eY6M-vLCSzyACHRNBx5bHcieWsZfLg7DfNKGIv0IgHI,6625
 torchrl\collectors\distributed\__init__.py,sha256=cKDWdNlwx2LoJkTwf-DKUXbq3Y-0Z1DctPYPcdgOSU0,730
 torchrl\collectors\llm\base.py,sha256=ELhduZbaELYFpgitPoXB4qPhX2bsunqJgB2pOpwjG4U,22189
@@ -120,7 +120,7 @@ torchrl\csrc\torch_utils.h,sha256=k7gTjLle9wW_TG6GrqqOYIG1MKsWqjHPcoqPDxLys5Y,73
 torchrl\csrc\utils.cpp,sha256=agOkJ0G4ytRsuGmkTXT1kBLGWOxrpkd6LHDVeZymEgA,1690
 torchrl\csrc\utils.h,sha256=bXlPW94HH4UMRDXXbPgfC25SvI_txvAkueBmSex9g7M,1132
 torchrl\data\rlhf.py,sha256=y4KwcYjtlons4czR72LGLqTjfl913EKkP_qXNXO_LC4,1040
-torchrl\data\tensor_specs.py,sha256=
+torchrl\data\tensor_specs.py,sha256=NpRGa6Kz_uvkn8ZjZAguFiIA7B9r2zPeMP91G9BlKoE,268154
 torchrl\data\utils.py,sha256=krn1klWLdfdxy7afevsVecX0KGTuNUSz2D2yNCGW1Hw,12459
 torchrl\data\__init__.py,sha256=SBu_aozTi8iEWZ_EJcWUMqERDXH9i-S3O6hjMb5xjno,5166
 torchrl\data\datasets\atari_dqn.py,sha256=zet-dUhsxIbPNMqcbqksMCBSP9-ZvzIqxsviRaRTJtY,41626
@@ -145,7 +145,7 @@ torchrl\data\llm\__init__.py,sha256=-_UPiaQgHzFVLM_gfiOS0sCGVRwkxZ1_Z5B6C9ThN9o,
 torchrl\data\map\hash.py,sha256=AilOzYQ0KYhCZpVZCm63AhRuXn2P_RLB4PcbIx6qnlA,7446
 torchrl\data\map\query.py,sha256=CfAC9XJh7KpdCJNgqfJ5CUCi7BWqekgCMlNVSIF60To,8125
 torchrl\data\map\tdstorage.py,sha256=HzdR7M8Fjt9543vqdEpy0QQ8tHgKLCqWYCA6P02_2dM,15143
-torchrl\data\map\tree.py,sha256=
+torchrl\data\map\tree.py,sha256=EhYin_p5qRfb1mrfD_1mkYelyNupq4upMRaRPIvPLUY,60782
 torchrl\data\map\utils.py,sha256=fEjqCzaE4Vqjb8OzUvnClmLxVMooqeFOBMs7wroYvxs,3022
 torchrl\data\map\__init__.py,sha256=bON0vqCksU7FPoWNqiNcdl60t7yWUh9SdLhNtglj7jI,576
 torchrl\data\postprocs\postprocs.py,sha256=dpXOKWlhdKy4Um7HdzRKe42PJ_Q1jHC7AX5plR9AIiw,15509
@@ -160,15 +160,15 @@ torchrl\data\replay_buffers\utils.py,sha256=RPKS5C5U2GDPJQ1zSiawEi4LIwWuazJfcgQJ
 torchrl\data\replay_buffers\writers.py,sha256=WVChK3QPVb2Ehoqn3U_c8W2FJ5nU57hqCb7XTthrUgU,29488
 torchrl\data\replay_buffers\__init__.py,sha256=RcJSEXHz6zt1gQSFhEaKwQHUhZkYa2x-buNSDkNAslE,2595
 torchrl\envs\async_envs.py,sha256=5ao4aEdETWTaTHoUru5GZOzH5uvNtGGmHxnniSk4Wdc,44125
-torchrl\envs\batched_envs.py,sha256=
-torchrl\envs\common.py,sha256=
+torchrl\envs\batched_envs.py,sha256=hhlNbzE0h1euxAWLgbxClTRPXySIDKKec15Sqlk3bjs,121388
+torchrl\envs\common.py,sha256=vZgFUaunYISFYoYCZ2OlgDWek4aWWq8g61FArECUTLs,174230
 torchrl\envs\env_creator.py,sha256=AAuZNNgvm_jX2_014AWWvNI5EQCmerU8ICCIAvv3PiM,10329
 torchrl\envs\gym_like.py,sha256=dky7JLsHAVnTdLimf4KAZGsPP104SFLD4fVzlmyAYh8,32381
 torchrl\envs\utils.py,sha256=YTCNO6XyCfxm9njCFafrDh1nnZ7wTPYH8uPAeCtG4mI,74861
 torchrl\envs\vec_envs.py,sha256=B3lrPCVk4jRIXy0V0berwktZInHpx_UBABTcPUlA1Lw,377
 torchrl\envs\__init__.py,sha256=2eVr8StUSMiNd-IoD5BQAFFuV10pAtO926b6QzRzB_M,6082
 torchrl\envs\custom\chess.py,sha256=IiudT29mY6Yvau6CIXB5678d8kvkCoBvLh9FoQeSiM0,25285
-torchrl\envs\custom\llm.py,sha256=
+torchrl\envs\custom\llm.py,sha256=FnZltPFKQUOH49qlvABk9dth2d2ZRkO_tiA0kiu6LKs,8869
 torchrl\envs\custom\pendulum.py,sha256=8sBgT8DvHrg5-YSOduC8GAHM9v7SXXfd8BzHvQ3wHOU,18617
 torchrl\envs\custom\san_moves.txt,sha256=AMStL2XCAnEbO6UZEYfDSCp5zRO2811gPEIzOWbmmRY,217492
 torchrl\envs\custom\tictactoeenv.py,sha256=voszQ7rPl7PbCBB6BQ8OELCi5US5YFm5gb-CLOkAIRM,12547
@@ -194,7 +194,7 @@ torchrl\envs\libs\vmas.py,sha256=giTORg2AqYzyjrazdD94fD2dNYwX7qe5TFnr-E1mjIg,371
 torchrl\envs\libs\_gym_utils.py,sha256=JYCNtWW4gAYwLq4k87ZdwLtDR_mRSyWQOi4seuXllOI,13150
 torchrl\envs\libs\__init__.py,sha256=tNiqWDxI-PQ2sxer7atuaSmKN_35pUpDIVFor-GIPfg,1821
 torchrl\envs\llm\chat.py,sha256=vCtRIhlw_8jQ55KL8x8-7N9dXxb83z68MirCD2jqt6E,32529
-torchrl\envs\llm\envs.py,sha256=
+torchrl\envs\llm\envs.py,sha256=z73Urni2ZCTPlQvMyHaxTzkvog7TwWqBI-c-GRd5TEk,35810
 torchrl\envs\llm\__init__.py,sha256=mxvPV4WD_jViongRHTYScOHwwGPjUjKhch90jA11IuY,1504
 torchrl\envs\llm\datasets\gsm8k.py,sha256=sja-7uzYSRd2CUYI6pe3SWq3YlKWWVmLqoNSFuw5tW0,17040
 torchrl\envs\llm\datasets\ifeval.py,sha256=9uSTySm3PKDIsgX7axHI1b-Z2IZmcNzUIACAIoDXuQM,12373
@@ -231,7 +231,7 @@ torchrl\envs\transforms\llm.py,sha256=V2ZY8-QY27GCpGY5i0UrryohQclybyL7aZwU9glc7w
 torchrl\envs\transforms\r3m.py,sha256=3B-JB3GHh3s1Af69WZ3wl3BU8SP0g_QmuH8IPztXRbQ,13850
 torchrl\envs\transforms\rb_transforms.py,sha256=eoJVEOv2ckVHth7nBgRaULW4TICf7YoQHcbWn9n1Cns,7661
 torchrl\envs\transforms\rlhf.py,sha256=DlAgMrLWVFkUQ3inpFgfHDkGjJKmFwhxiRGV3FWGZK8,691
-torchrl\envs\transforms\transforms.py,sha256=
+torchrl\envs\transforms\transforms.py,sha256=OqABOhEw36yzPc_oYEhMkMkj0PBbCIsGc9igrx3V-8Q,499488
 torchrl\envs\transforms\utils.py,sha256=V7YAV2BcJWvhC6aUV9LcwOodZFPbKmsltyR747OdRTU,3358
 torchrl\envs\transforms\vc1.py,sha256=snXdONyRKkyMiaW-bT7SwDJUQVb5GWr1mqY1W78Ohn0,10841
 torchrl\envs\transforms\vecnorm.py,sha256=udY-bdOhm-Aqjpt-STQT_mT6Ee50j5XH7v70gmyoKKk,34915
@@ -239,7 +239,7 @@ torchrl\envs\transforms\vip.py,sha256=r8Ni0hAYY1gispLj0TXV2VIedrgC4eW3hAhJBv47Q7
 torchrl\envs\transforms\__init__.py,sha256=d0p0afpcykAcBU6HENDJxtq91UEkooY69wbFgyOFIxE,3281
 torchrl\modules\__init__.py,sha256=TuJj3WUlvilYY39nUH-ykXkyprTxjq9NLW0QQXADqJk,4343
 torchrl\modules\distributions\continuous.py,sha256=Q9T8okHY625AGthEet5WOWfT0DTTSBeqWbIOB3lqPlE,26443
-torchrl\modules\distributions\discrete.py,sha256=
+torchrl\modules\distributions\discrete.py,sha256=A4-LMggbd0RBGsM3XmaQIyp7tzncdJpUHwMUzGQE9CM,36500
 torchrl\modules\distributions\truncated_normal.py,sha256=l5G3TePasl7q12DjwisyQC_E0OfZZo2g_HzBhZREVxc,6122
 torchrl\modules\distributions\utils.py,sha256=q4AFDKFpacRhrl4rjJ54UhxQzjOcj_SKlz0UIcZlUVc,7796
 torchrl\modules\distributions\__init__.py,sha256=Evkiz96ZPs7VUZp2n03h9kd7rmUCEEvMVl2f7RhzMhQ,1670
@@ -247,13 +247,13 @@ torchrl\modules\llm\utils.py,sha256=b2s9ngHwXnNbLggygU3-ScNwk0MWICketq2pZBshGqM,
 torchrl\modules\llm\__init__.py,sha256=_jPEt6oFb_R75zcVWlrfl4OMVxEs3Y9zCjVZp8dgg38,1098
 torchrl\modules\llm\backends\__init__.py,sha256=O7RanoHTBR4xLQLUJ2JUG4-zlVZ7PJDx1d2_AezVmaU,1027
 torchrl\modules\llm\backends\vllm\base.py,sha256=KZs36Q0sNveEkHJrub6xD_SzZafAdWz5ZK5ssHphMHM,2149
-torchrl\modules\llm\backends\vllm\vllm_async.py,sha256=
+torchrl\modules\llm\backends\vllm\vllm_async.py,sha256=rj9vyftLJVFjIQqLEFvHw3YhA_jzpYs0C9hQl1LImVs,80766
 torchrl\modules\llm\backends\vllm\vllm_sync.py,sha256=K7P_da0XqAEO9lzniN3gso8IKWgP_S4ueyGncZtMEXU,15958
 torchrl\modules\llm\backends\vllm\vllm_utils.py,sha256=qRwirmMfTvqwr34-AtwcpjMp-InhDj7Iz2ZnEbkTcxE,4320
 torchrl\modules\llm\backends\vllm\__init__.py,sha256=LPqxC7ijHq6kjcA1a0kqGZU_JK_r8oT8BetsU4KKsWY,1816
 torchrl\modules\llm\policies\common.py,sha256=dhw_Q64Nk3WrfLqjfLlLVY5BMSCp1pbjUJT8MHuoEjQ,58104
-torchrl\modules\llm\policies\transformers_wrapper.py,sha256
-torchrl\modules\llm\policies\vllm_wrapper.py,sha256=
+torchrl\modules\llm\policies\transformers_wrapper.py,sha256=HN2qsnFXdCoLHkJsuMp2Kp3wwMN9hEpWMebDHkJ-8xg,114583
+torchrl\modules\llm\policies\vllm_wrapper.py,sha256=uMbis1Coqj7oke2fTwtQZoPnbUgRT-CPrcM4SQWrwlA,95599
 torchrl\modules\llm\policies\__init__.py,sha256=4rOIFNkYAZXMB5WIAomMV0tTZVP_J9mA4uatDpySbk8,628
 torchrl\modules\models\batchrenorm.py,sha256=bR4ZhaJ5E1cSK5o8L2dNX5KVLIb-bgrYxcq6yhx0I1A,4869
 torchrl\modules\models\decision_transformer.py,sha256=ANFTOm3k9_3Uv1vKGdXumRy3meBPnDdT8HqhVvJ2RCo,6783
@@ -281,30 +281,30 @@ torchrl\modules\tensordict_module\__init__.py,sha256=iTz8iCBmxt661GrGJRBfw4tBoTu
 torchrl\modules\utils\mappings.py,sha256=HEPGNHhQrPNU85-Bq0cYm1TZIhSkdEBkLvgrmjFMa4Q,371
 torchrl\modules\utils\utils.py,sha256=Ae2bl6GDxm9kU73WeLi-0ZEsrFt-XTaGqdxdXiX9LSU,3005
 torchrl\modules\utils\__init__.py,sha256=NQ_ko0JAIPY_X5RgBJnZLZXnYSH2q_kuD0tvXGqqY3k,1165
-torchrl\objectives\a2c.py,sha256=
+torchrl\objectives\a2c.py,sha256=Pc-xmumKqLw0ovMzLZ5F0BkovUB-LsUEsGvF2WVwTI4,29470
 torchrl\objectives\common.py,sha256=nslOhX1hi0nYvS8v7ylt5qB40fkQ7k-_XvO2p99ppTM,29274
-torchrl\objectives\cql.py,sha256=
-torchrl\objectives\crossq.py,sha256=
-torchrl\objectives\ddpg.py,sha256=
-torchrl\objectives\decision_transformer.py,sha256
-torchrl\objectives\deprecated.py,sha256=
-torchrl\objectives\dqn.py,sha256=
+torchrl\objectives\cql.py,sha256=a6Fltgf05z_bP3YaXIydYtD4VZNSVhbHgMSxZsynHAg,56266
+torchrl\objectives\crossq.py,sha256=KLG7DGZApOZWAktYYV-HLRh2C070zLAjcUh24xPqkbI,29571
+torchrl\objectives\ddpg.py,sha256=KxZDlsNehydFh7-4oVVouGK6Dz8CO19uEpgHayd53mM,18221
+torchrl\objectives\decision_transformer.py,sha256=vX-Gr8bXSQwq-gyPtFQWcrL_SO5ELuQfmBZm1uWzNzc,13419
+torchrl\objectives\deprecated.py,sha256=jm77VQqK68nZziTlL14VZ3FHIAPVcUv1DUMGW33diBo,21267
+torchrl\objectives\dqn.py,sha256=XhtPvhtNty1Gka9swHIJIF7HOM3huZqk0XTJ8RXVJcw,28831
 torchrl\objectives\dreamer.py,sha256=65EntKqou3auLMYxD1uaKGNyucfktabqaATNT1bExQc,18497
 torchrl\objectives\functional.py,sha256=0Pr_debAMM2bp06HPGVIpLTcyBue4DvcyUJVsaa6AjE,2154
-torchrl\objectives\gail.py,sha256=
-torchrl\objectives\iql.py,sha256=
+torchrl\objectives\gail.py,sha256=MCJ-TE_asCp-NTfSgrqkUx9DWrR1GXth7VqrH46lndA,9855
+torchrl\objectives\iql.py,sha256=CwymtcSV3RDktRVCbsQihST5PFik382-5J5qnRA3U8E,44004
 torchrl\objectives\ppo.py,sha256=l2fGbQ45Zd0mwSijLtHYk3rvaoOR451LPXwcVq1L7Zw,82038
-torchrl\objectives\redq.py,sha256=
-torchrl\objectives\reinforce.py,sha256=
-torchrl\objectives\sac.py,sha256=
-torchrl\objectives\td3.py,sha256=
-torchrl\objectives\td3_bc.py,sha256=
+torchrl\objectives\redq.py,sha256=tLGi5wh8gErf0Ds725n_mwd3iSvpCXSh2w0WFJoqQvY,29191
+torchrl\objectives\reinforce.py,sha256=Fzu-7VBiepxZT_bXX17yc4fyPl7BFjqQerCokVR29E8,22882
+torchrl\objectives\sac.py,sha256=Djj2yGhqu1cglIVoOvq232Q9fkoNq5D1P7z9Flis9GI,70910
+torchrl\objectives\td3.py,sha256=vixzm2sITvLVrojCnIYGopyimvcL8o_dML4nc93-WJY,23757
+torchrl\objectives\td3_bc.py,sha256=oKZm3BRHsg-DK5fzFw1DC6fX8hrzbh8tmCOGvw5a62Q,26572
 torchrl\objectives\utils.py,sha256=CfEk41IWgVpzgZ7jq7rWS7FZbh4ymYsv92Td2rCWFRE,25990
 torchrl\objectives\__init__.py,sha256=Ug1FX1kFbTSz_i51uaDw7pOBIXSUIbH7BE5_8PZNbHM,2245
 torchrl\objectives\llm\grpo.py,sha256=kXcHSA_uPAEqvqRjjLTrzyupqMG9yuiRh_G0AxMAzaQ,25307
 torchrl\objectives\llm\sft.py,sha256=U1jtwZfDYLaU1YzA6edGhRo4t00dHpWbVKQLyyF2f1g,21352
 torchrl\objectives\llm\__init__.py,sha256=tZmIz3rkeclw3MzJoOWEs2gkewjx2USKrKJbWdyiiaQ,406
-torchrl\objectives\multiagent\qmixer.py,sha256=
+torchrl\objectives\multiagent\qmixer.py,sha256=MQST8UvktQLG9Z1b12Fj0RdxloWKj87paxL0DWwSucc,17369
 torchrl\objectives\multiagent\__init__.py,sha256=5uebDe5KrvlzeYV_BSd5vdmfruJQYMeDVVbU4iHErEg,245
 torchrl\objectives\value\advantages.py,sha256=Nz0IANqvV7uAMzghxA6Ta1DdkEfLo2w2Z21wd2w1duo,86865
 torchrl\objectives\value\functional.py,sha256=bgZiXJKuOmqlKdtTzWXvbgab4yT01xik-cTW2YQkTNU,51091
@@ -317,7 +317,7 @@ torchrl\record\loggers\csv.py,sha256=uNFjiPLq7mMr5z2WPyjyr9HGexu4ZkUwbX09FsV1mJ4
 torchrl\record\loggers\mlflow.py,sha256=9N-a5OUJJGwYej0WvTxQPkrazsahhmgof8seMDCnjM0,5098
 torchrl\record\loggers\tensorboard.py,sha256=x1Mo7KE4-iGG5NVToAP-1XceG_F6Vipr26B_1CK9Tg0,5005
 torchrl\record\loggers\utils.py,sha256=rZqyZi-ebLozHh8pbV-7m23W27zXsbnR04Rh-yWMPV8,2432
-torchrl\record\loggers\wandb.py,sha256=
+torchrl\record\loggers\wandb.py,sha256=tkedfdd4RqTPFfe6xnyIhPm8T5a4Lzi9wW0ww1PY4iE,7298
 torchrl\record\loggers\__init__.py,sha256=pa6ttxj0FORHS6MgiYg05iFoABwJ8vqBHn45wkqshT4,568
 torchrl\trainers\trainers.py,sha256=_1SvHfNGwUd8w_YeRJBpJPm4Bvrv-5u7HsyKk5V3FCA,66905
 torchrl\trainers\__init__.py,sha256=LEUdW1zV5jydpMjZqnJ7XZW77MBIVv8dZ1nGv-m89RQ,853
@@ -344,8 +344,9 @@ torchrl\trainers\helpers\models.py,sha256=VujBq9H92sEzpCtU1iTrJQNlwvyOO-Rho4bzsM
 torchrl\trainers\helpers\replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
 torchrl\trainers\helpers\trainers.py,sha256=VVhAXHcutHyVa7kJEo_RtaI9U5h0Hk2qLEnONXFpPQ8,12350
 torchrl\trainers\helpers\__init__.py,sha256=sCBIXQqFQKRrbcNojgPxIh82HpXnXKgA_kMa3uZESSk,1137
-torchrl-0.10.
-torchrl-0.10.
-torchrl-0.10.
-torchrl-0.10.
-torchrl-0.10.
+torchrl-0.10.1.dist-info\entry_points.txt,sha256=kjqZUboF3jzU21uy15NPn2WDfbwGE21Ls0fmfhqhmy4,110
+torchrl-0.10.1.dist-info\LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
+torchrl-0.10.1.dist-info\METADATA,sha256=S5rAbED_7qvdSpMbLin0amLt_NkQ1aFqx93zdIPNz2g,50106
+torchrl-0.10.1.dist-info\RECORD,,
+torchrl-0.10.1.dist-info\top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
+torchrl-0.10.1.dist-info\WHEEL,sha256=DmZ7B4aiAganfWOCUyjXG_z2uvUu-tkD3rVXMivgyOM,99
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/LICENSE
File without changes
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/WHEEL
File without changes
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/top_level.txt
File without changes