torchrl 0.10.0__cp39-cp39-win_amd64.whl → 0.10.1__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. benchmarks/test_llm.py +5 -0
  2. torchrl/__init__.py +4 -1
  3. torchrl/_torchrl.cp39-win_amd64.pyd +0 -0
  4. torchrl/_utils.py +3 -1
  5. torchrl/collectors/collectors.py +11 -10
  6. torchrl/collectors/distributed/generic.py +3 -3
  7. torchrl/collectors/distributed/ray.py +10 -4
  8. torchrl/collectors/distributed/rpc.py +3 -3
  9. torchrl/collectors/distributed/sync.py +3 -3
  10. torchrl/data/map/tree.py +2 -2
  11. torchrl/data/tensor_specs.py +191 -8
  12. torchrl/envs/batched_envs.py +1 -1
  13. torchrl/envs/common.py +1 -1
  14. torchrl/envs/custom/llm.py +3 -3
  15. torchrl/envs/llm/envs.py +3 -3
  16. torchrl/envs/transforms/transforms.py +2 -2
  17. torchrl/modules/distributions/discrete.py +1 -1
  18. torchrl/modules/llm/backends/vllm/vllm_async.py +1 -1
  19. torchrl/modules/llm/policies/transformers_wrapper.py +2 -1
  20. torchrl/modules/llm/policies/vllm_wrapper.py +1 -0
  21. torchrl/objectives/a2c.py +3 -3
  22. torchrl/objectives/cql.py +2 -2
  23. torchrl/objectives/crossq.py +2 -2
  24. torchrl/objectives/ddpg.py +1 -1
  25. torchrl/objectives/decision_transformer.py +2 -2
  26. torchrl/objectives/deprecated.py +2 -2
  27. torchrl/objectives/dqn.py +4 -4
  28. torchrl/objectives/gail.py +1 -1
  29. torchrl/objectives/iql.py +4 -4
  30. torchrl/objectives/multiagent/qmixer.py +1 -1
  31. torchrl/objectives/redq.py +2 -2
  32. torchrl/objectives/reinforce.py +3 -3
  33. torchrl/objectives/sac.py +5 -5
  34. torchrl/objectives/td3.py +2 -2
  35. torchrl/objectives/td3_bc.py +2 -2
  36. torchrl/record/loggers/wandb.py +3 -3
  37. {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/METADATA +1 -1
  38. {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/RECORD +42 -41
  39. torchrl-0.10.1.dist-info/entry_points.txt +2 -0
  40. {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/LICENSE +0 -0
  41. {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/WHEEL +0 -0
  42. {torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/top_level.txt +0 -0
benchmarks/test_llm.py CHANGED
@@ -16,6 +16,11 @@ from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrappe
 
  _has_transformers = importlib.import_module("transformers") is not None
 
+ # Skip all these tests if gpu is not available
+ pytestmark = pytest.mark.skipif(
+     not torch.cuda.is_available(), reason="GPU not available"
+ )
+
 
  @pytest.fixture(scope="module")
  def transformers_wrapper():
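
Note: a module-level ``pytestmark`` such as the one added above applies to every test collected from the file. A minimal self-contained sketch of the same pattern (illustrative, not TorchRL code):

    import pytest
    import torch

    # Skip the whole module when no CUDA device is available.
    pytestmark = pytest.mark.skipif(
        not torch.cuda.is_available(), reason="GPU not available"
    )


    def test_matmul_runs_on_gpu():
        x = torch.randn(8, 8, device="cuda")
        assert (x @ x).shape == (8, 8)
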
torchrl/__init__.py CHANGED
@@ -27,7 +27,10 @@ from ._extension import _init_extension
  try:
      from .version import __version__
  except ImportError:
-     __version__ = "0.0.0+unknown"
+     try:
+         from ._version import __version__
+     except ImportError:
+         __version__ = "0.0.0+unknown"
 
  try:
      from torch.compiler import is_dynamo_compiling
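
Note: with the added fallback, importing the package no longer fails just because the build-time ``version`` module is missing; the attribute always resolves to something printable (version string below is illustrative):

    import torchrl

    # Resolution order: torchrl.version, then torchrl._version, then the placeholder.
    print(torchrl.__version__)  # e.g. "0.10.1", or "0.0.0+unknown" as a last resort
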
torchrl/_torchrl.cp39-win_amd64.pyd CHANGED
Binary file
torchrl/_utils.py CHANGED
@@ -410,7 +410,9 @@ def accept_remote_rref_udf_invocation(decorated_class):
  """Class decorator that applies `accept_remote_rref_invocation` to all public methods."""
  # ignores private methods
  for name in dir(decorated_class):
-     method = getattr(decorated_class, name)
+     method = getattr(decorated_class, name, None)
+     if method is None:
+         continue
      if callable(method) and not name.startswith("_"):
          setattr(decorated_class, name, accept_remote_rref_invocation(method))
  return decorated_class
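
Note: the three-argument ``getattr`` makes the decorator tolerant of attributes that fail to resolve at access time. A standalone illustration of the pattern (hypothetical class, not TorchRL code):

    class LazyProxy:
        def __getattr__(self, name):
            # Attribute resolution can fail at access time, e.g. for lazily bound backends.
            raise AttributeError(name)


    obj = LazyProxy()
    # getattr(obj, "refresh") would raise AttributeError; the default keeps iteration going.
    print(getattr(obj, "refresh", None))  # None
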
torchrl/collectors/collectors.py CHANGED
@@ -283,12 +283,13 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
  ) -> None:
      """Shuts down the collector when started asynchronously with the `start` method.
 
-     Arg:
+     Args:
          timeout (float, optional): The maximum time to wait for the collector to shutdown.
          close_env (bool, optional): If True, the collector will close the contained environment.
              Defaults to `True`.
 
      .. seealso:: :meth:`~.start`
+
      """
      return self.shutdown(timeout=timeout, close_env=close_env)
 
@@ -440,7 +441,7 @@ class SyncDataCollector(DataCollectorBase):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -1784,7 +1785,7 @@ class _MultiDataCollector(DataCollectorBase):
      ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -2749,8 +2750,8 @@ class MultiSyncDataCollector(_MultiDataCollector):
  ...     if i == 2:
  ...         print(data)
  ...         break
- >>> collector.shutdown()
- >>> del collector
+ ...     collector.shutdown()
+ ...     del collector
  TensorDict(
      fields={
          action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3130,8 +3131,8 @@ class MultiaSyncDataCollector(_MultiDataCollector):
  ...     if i == 2:
  ...         print(data)
  ...         break
- ... collector.shutdown()
- ... del collector
+ ... collector.shutdown()
+ ... del collector
  TensorDict(
      fields={
          action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3366,7 +3367,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -3380,8 +3381,8 @@ class aSyncDataCollector(MultiaSyncDataCollector):
      total number of frames returned by the collector
      during its lifespan. If the ``total_frames`` is not divisible by
      ``frames_per_batch``, an exception is raised.
-     Endless collectors can be created by passing ``total_frames=-1``.
-     Defaults to ``-1`` (never ending collector).
+     Endless collectors can be created by passing ``total_frames=-1``.
+     Defaults to ``-1`` (never ending collector).
      device (int, str or torch.device, optional): The generic device of the
          collector. The ``device`` args fills any non-specified device: if
          ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
torchrl/collectors/distributed/generic.py CHANGED
@@ -282,7 +282,7 @@ class DistributedDataCollector(DataCollectorBase):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -296,8 +296,8 @@ class DistributedDataCollector(DataCollectorBase):
      number of frames returned by the collector
      during its lifespan. If the ``total_frames`` is not divisible by
      ``frames_per_batch``, an exception is raised.
-     Endless collectors can be created by passing ``total_frames=-1``.
-     Defaults to ``-1`` (endless collector).
+     Endless collectors can be created by passing ``total_frames=-1``.
+     Defaults to ``-1`` (endless collector).
      device (int, str or torch.device, optional): The generic device of the
          collector. The ``device`` args fills any non-specified device: if
          ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
torchrl/collectors/distributed/ray.py CHANGED
@@ -131,7 +131,7 @@ class RayCollector(DataCollectorBase):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -263,6 +263,10 @@ class RayCollector(DataCollectorBase):
      If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging
      Ray's distributed capabilities.
      Consider using a constructor if the updater needs to be serialized.
+ use_env_creator (bool, optional): if ``True``, the environment constructor functions will be wrapped
+     in :class:`~torchrl.envs.EnvCreator`. This is useful for multiprocessed settings where shared memory
+     needs to be managed, but Ray has its own object storage mechanism, so this is typically not needed.
+     Defaults to ``False``.
 
  Examples:
      >>> from torch import nn
@@ -326,6 +330,7 @@ class RayCollector(DataCollectorBase):
  weight_updater: WeightUpdaterBase
  | Callable[[], WeightUpdaterBase]
  | None = None,
+ use_env_creator: bool = False,
  ):
  self.frames_per_batch = frames_per_batch
  if remote_configs is None:
@@ -400,9 +405,10 @@ class RayCollector(DataCollectorBase):
  create_env_fn, collector_kwargs, remote_configs = out_lists
  num_collectors = len(create_env_fn)
 
- for i in range(len(create_env_fn)):
-     if not isinstance(create_env_fn[i], (EnvBase, EnvCreator)):
-         create_env_fn[i] = EnvCreator(create_env_fn[i])
+ if use_env_creator:
+     for i in range(len(create_env_fn)):
+         if not isinstance(create_env_fn[i], (EnvBase, EnvCreator)):
+             create_env_fn[i] = EnvCreator(create_env_fn[i])
 
  # If ray available, try to connect to an existing Ray cluster or start one and connect to it.
  if not _has_ray:
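
Note: EnvCreator wrapping is now opt-in for RayCollector. A rough usage sketch, assuming the usual call pattern from the class docstring (a list of environment constructors plus a TensorDictModule policy; names and frame counts are illustrative):

    from torch import nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors.distributed import RayCollector
    from torchrl.envs.libs.gym import GymEnv

    env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
    collector = RayCollector(
        [env_maker],
        policy,
        frames_per_batch=200,
        total_frames=600,
        use_env_creator=True,  # opt back in to EnvCreator wrapping; default is now False
    )
    for data in collector:
        print(data)
    collector.shutdown()
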
torchrl/collectors/distributed/rpc.py CHANGED
@@ -121,7 +121,7 @@ class RPCDataCollector(DataCollectorBase):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -135,8 +135,8 @@ class RPCDataCollector(DataCollectorBase):
      number of frames returned by the collector
      during its lifespan. If the ``total_frames`` is not divisible by
      ``frames_per_batch``, an exception is raised.
-     Endless collectors can be created by passing ``total_frames=-1``.
-     Defaults to ``-1`` (endless collector).
+     Endless collectors can be created by passing ``total_frames=-1``.
+     Defaults to ``-1`` (endless collector).
      device (int, str or torch.device, optional): The generic device of the
          collector. The ``device`` args fills any non-specified device: if
          ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
torchrl/collectors/distributed/sync.py CHANGED
@@ -158,7 +158,7 @@ class DistributedSyncDataCollector(DataCollectorBase):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
  .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-     pickled directly), the :arg:`policy_factory` should be used instead.
+     pickled directly), the ``policy_factory`` should be used instead.
 
  Keyword Args:
      policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -172,8 +172,8 @@ class DistributedSyncDataCollector(DataCollectorBase):
      number of frames returned by the collector
      during its lifespan. If the ``total_frames`` is not divisible by
      ``frames_per_batch``, an exception is raised.
-     Endless collectors can be created by passing ``total_frames=-1``.
-     Defaults to ``-1`` (endless collector).
+     Endless collectors can be created by passing ``total_frames=-1``.
+     Defaults to ``-1`` (endless collector).
      device (int, str or torch.device, optional): The generic device of the
          collector. The ``device`` args fills any non-specified device: if
          ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
torchrl/data/map/tree.py CHANGED
@@ -610,7 +610,7 @@ class Tree(TensorClass["nocast"]):
  This function can pull out information from each of the nodes in a tree,
  so it can be useful for debugging. The nodes are listed line-by-line.
  Each line contains the path to the node, followed by the string
- representation of that node generated with :arg:`node_format_fn`. Each
+ representation of that node generated with ``node_format_fn``. Each
  line is indented according to number of steps in the path required to
  get to the corresponding node.
 
@@ -1370,7 +1370,7 @@ class MCTSForest:
  This function can pull out information from each of the nodes in a tree,
  so it can be useful for debugging. The nodes are listed line-by-line.
  Each line contains the path to the node, followed by the string
- representation of that node generated with :arg:`node_format_fn`. Each
+ representation of that node generated with ``node_format_fn``. Each
  line is indented according to number of steps in the path required to
  get to the corresponding node.
 
torchrl/data/tensor_specs.py CHANGED
@@ -13,7 +13,7 @@ import math
  import warnings
  import weakref
  from collections.abc import Callable, Iterable, Mapping, Sequence
- from copy import deepcopy
+ from copy import copy, deepcopy
  from dataclasses import dataclass, field
  from functools import wraps
  from textwrap import indent
@@ -5095,6 +5095,7 @@ class Composite(TensorSpec):
 
  shape: torch.Size
  domain: str = "composite"
+ _td_dim_names: list[str] | None = None
 
  SPEC_HANDLED_FUNCTIONS = {}
 
@@ -5111,6 +5112,7 @@ class Composite(TensorSpec):
  device: torch.device | None = None,
  data_cls: type | None = None,
  step_mdp_static: bool = False,
+ names: Sequence[str] | None = None,
  **kwargs,
  ):
  # For compatibility with TensorDict
@@ -5126,6 +5128,12 @@ class Composite(TensorSpec):
  self._specs = {}
  self.step_mdp_static = step_mdp_static
 
+ # Initialize names
+ if names is not None:
+     self._td_dim_names = list(names)
+ else:
+     self._td_dim_names = None
+
  _device = (
      _make_ordinal_device(torch.device(device)) if device is not None else device
  )
@@ -5142,6 +5150,8 @@ class Composite(TensorSpec):
  )
  for k, item in argdict.items():
      if isinstance(item, dict):
+         # Create nested Composite with appropriate names
+         # Note: nested specs will get their names propagated later in the names setter
          item = Composite(item, shape=shape, device=_device)
      self[k] = item
  for k, item in kwargs.items():
@@ -5150,6 +5160,10 @@ class Composite(TensorSpec):
  self.encode = self._encode_eager
  self._encode_memo_dict = {}
 
+ # Propagate names to nested specs if names were provided
+ if names is not None:
+     self._propagate_names_to_nested()
+
  def memoize_encode(self, mode: bool = True) -> None:
      super().memoize_encode(mode=mode)
      for spec in self._specs.values():
@@ -5354,6 +5368,127 @@ class Composite(TensorSpec):
  spec.clear_device_()
  return self
 
+     def _has_names(self):
+         """Returns True if names are set for this Composite."""
+         return self._td_dim_names is not None
+
+     def _erase_names(self):
+         """Erases the names of this Composite."""
+         self._td_dim_names = None
+
+     def _propagate_names_to_nested(self):
+         """Propagates names to nested Composite specs."""
+         if not self._has_names():
+             return
+         for spec in self._specs.values():
+             if isinstance(spec, Composite):
+                 # For nested specs, we need to propagate the names
+                 # The nested spec should have the same leading dimensions
+                 if spec.ndim >= self.ndim:
+                     nested_names = list(self.names) + [None] * (spec.ndim - self.ndim)
+                     spec.names = nested_names
+
+     @property
+     def names(self):
+         """Returns the names of the dimensions of this Composite."""
+         names = self._td_dim_names
+         if names is None:
+             return [None for _ in range(self.ndim)]
+         # Return a copy but don't use copy to make dynamo happy
+         return list(names)
+
+     @names.setter
+     def names(self, value):
+         """Sets the names of the dimensions of this Composite."""
+         if value is None:
+             self._td_dim_names = None
+             return
+         if len(value) != self.ndim:
+             raise ValueError(
+                 f"Expected {self.ndim} names, but got {len(value)} names: {value}"
+             )
+         self._td_dim_names = list(value)
+         # Propagate names to nested Composite specs
+         for spec in self._specs.values():
+             if isinstance(spec, Composite):
+                 # For nested specs, we need to propagate the names
+                 # The nested spec should have the same leading dimensions
+                 if spec.ndim >= self.ndim:
+                     nested_names = list(value) + [None] * (spec.ndim - self.ndim)
+                     spec.names = nested_names
+
+     def refine_names(self, *names):
+         """Refines the dimension names of self according to names.
+
+         Refining is a special case of renaming that "lifts" unnamed dimensions.
+         A None dim can be refined to have any name; a named dim can only be
+         refined to have the same name.
+
+         Because named specs can coexist with unnamed specs, refining names
+         gives a nice way to write named-spec-aware code that works with both
+         named and unnamed specs.
+
+         names may contain up to one Ellipsis (...). The Ellipsis is expanded
+         greedily; it is expanded in-place to fill names to the same length as
+         self.ndim using names from the corresponding indices of self.names.
+
+         Returns: the same composite spec with dimensions named according to the input.
+
+         Examples:
+             >>> spec = Composite({}, shape=[3, 4, 5, 6])
+             >>> spec_refined = spec.refine_names(None, None, None, "d")
+             >>> assert spec_refined.names == [None, None, None, "d"]
+             >>> spec_refined = spec.refine_names("a", None, None, "d")
+             >>> assert spec_refined.names == ["a", None, None, "d"]
+
+         """
+         # replace ellipsis if any
+         names_copy = copy(names)
+         if any(name is Ellipsis for name in names):
+             ellipsis_name = [None for _ in range(self.ndim - len(names) + 1)]
+             names = []
+             for name in names_copy:
+                 if name is Ellipsis:
+                     names += ellipsis_name
+                 else:
+                     names.append(name)
+
+         # check that the names that are set are either None or identical
+         curr_names = self.names
+         for i, name in enumerate(names):
+             if curr_names[i] is None:
+                 continue
+             if curr_names[i] == name:
+                 continue
+             else:
+                 raise RuntimeError(
+                     f"refine_names: cannot coerce Composite names {self.names} with {names_copy}."
+                 )
+         self.names = names
+         return self
+
+     def _get_names_idx(self, idx):
+         """Helper method to get names after indexing."""
+         if not self._has_names():
+             return None
+
+         names = copy(self.names)
+         if isinstance(idx, (int, slice)):
+             # Single dimension indexing
+             if isinstance(idx, int):
+                 names.pop(idx)
+             else:
+                 # For slice, we keep the names but adjust for the slice
+                 pass
+         elif isinstance(idx, tuple):
+             # Multi-dimensional indexing
+             for i, sub_idx in enumerate(idx):
+                 if isinstance(sub_idx, int):
+                     # Remove the dimension
+                     names.pop(i)
+                 # For slices, we keep the name
+         return names
+
  def __getitem__(self, idx):
      """Indexes the current Composite based on the provided index."""
      if isinstance(idx, (str, tuple)):
@@ -5393,10 +5528,15 @@ class Composite(TensorSpec):
  except RuntimeError:
      device = self._device
 
+ names = None
+ if self._has_names():
+     names = self._get_names_idx(idx)
+
  return self.__class__(
      indexed_specs,
      shape=indexed_shape,
      device=device,
+     names=names,
  )
 
  def get(self, item, default=NO_DEFAULT):
@@ -5600,16 +5740,22 @@ class Composite(TensorSpec):
  for key, item in self.items():
      if item is not None:
          _dict[key] = item.rand(shape)
- if self.data_cls is None:
-     cls = TensorDict
+
+ cls = self.data_cls if self.data_cls is not None else TensorDict
+ if cls is not TensorDict:
+     kwargs = {}
+     if self._td_dim_names is not None:
+         warnings.warn(f"names for cls {cls} is not supported for rand.")
  else:
-     cls = self.data_cls
+     kwargs = {"names": self._td_dim_names}
+
  # No need to run checks since we know Composite is compliant with
  # TensorDict requirements
  return cls.from_dict(
      _dict,
      batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
      device=self.device,
+     **kwargs,
  )
 
  def keys(
@@ -5760,6 +5906,7 @@ class Composite(TensorSpec):
  shape=self.shape,
  data_cls=self.data_cls,
  step_mdp_static=self.step_mdp_static,
+ names=self.names if self._has_names() else None,
  )
  if not isinstance(dest, (str, int, torch.device)):
      raise ValueError(
@@ -5782,6 +5929,7 @@ class Composite(TensorSpec):
  shape=self.shape,
  data_cls=self.data_cls,
  step_mdp_static=self.step_mdp_static,
+ names=self.names if self._has_names() else None,
  )
 
  def clone(self) -> Composite:
@@ -5802,6 +5950,7 @@ class Composite(TensorSpec):
  shape=self.shape,
  data_cls=self.data_cls,
  step_mdp_static=self.step_mdp_static,
+ names=self.names if self._has_names() else None,
  )
 
  def cardinality(self) -> int:
@@ -5874,10 +6023,13 @@ class Composite(TensorSpec):
  except RuntimeError:
      device = self._device
 
- if self.data_cls is not None:
-     cls = self.data_cls
+ cls = self.data_cls if self.data_cls is not None else TensorDict
+ if cls is not TensorDict:
+     kwargs = {}
+     if self._td_dim_names is not None:
+         warnings.warn(f"names for cls {cls} is not supported for zero.")
  else:
-     cls = TensorDict
+     kwargs = {"names": self._td_dim_names}
 
  return cls.from_dict(
      {
@@ -5887,6 +6039,7 @@ class Composite(TensorSpec):
      },
      batch_size=_size([*shape, *self._safe_shape]),
      device=device,
+     **kwargs,
  )
 
  def __eq__(self, other: object) -> bool:
@@ -5942,12 +6095,17 @@ class Composite(TensorSpec):
      else None
      for key, value in tuple(self.items())
  }
+ names = None
+ if self._has_names():
+     names = [None] * (len(shape) - self.ndim) + self.names
+
  out = Composite(
      specs,
      shape=shape,
      device=device,
      data_cls=self.data_cls,
      step_mdp_static=self.step_mdp_static,
+     names=names,
  )
  return out
 
@@ -5965,12 +6123,21 @@ class Composite(TensorSpec):
  except RuntimeError:
      device = self._device
 
+ names = None
+ if self._has_names():
+     names = copy(self.names)
+     names.pop(dim)
+     # If all names are None after popping, set to None
+     if all(name is None for name in names):
+         names = None
+
  return self.__class__(
      {key: value.squeeze(dim) for key, value in self.items()},
      shape=shape,
      device=device,
      data_cls=self.data_cls,
      step_mdp_static=self.step_mdp_static,
+     names=names,
  )
 
  if self.shape.count(1) == 0:
@@ -5993,6 +6160,11 @@ class Composite(TensorSpec):
  except RuntimeError:
      device = self._device
 
+ names = None
+ if self._has_names():
+     names = copy(self.names)
+     names.insert(dim, None)
+
  return self.__class__(
      {
          key: value.unsqueeze(dim) if value is not None else None
@@ -6002,6 +6174,7 @@ class Composite(TensorSpec):
      device=device,
      data_cls=self.data_cls,
      step_mdp_static=self.step_mdp_static,
+     names=names,
  )
 
  def unbind(self, dim: int = 0) -> tuple[Composite, ...]:
@@ -6012,8 +6185,17 @@ class Composite(TensorSpec):
  raise ValueError(
      f"Cannot unbind along dim {orig_dim} with shape {self.shape}."
  )
- shape = (s for i, s in enumerate(self.shape) if i != dim)
+ shape = tuple(s for i, s in enumerate(self.shape) if i != dim)
  unbound_vals = {key: val.unbind(dim) for key, val in self.items()}
+
+ names = None
+ if self._has_names():
+     names = copy(self.names)
+     names.pop(dim)
+     # If all names are None after popping, set to None
+     if all(name is None for name in names):
+         names = None
+
  return tuple(
      self.__class__(
          {key: val[i] for key, val in unbound_vals.items()},
@@ -6021,6 +6203,7 @@ class Composite(TensorSpec):
      device=self.device,
      data_cls=self.data_cls,
      step_mdp_static=self.step_mdp_static,
+     names=names,
  )
  for i in range(self.shape[dim])
  )
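
Note: this is the largest change of the release; ``Composite`` specs can now carry TensorDict-style dimension names that survive indexing, cloning, squeezing/unsqueezing, unbinding and sampling. A rough usage sketch under the new API (assumes the ``Unbounded`` leaf spec exported from ``torchrl.data``; names and shapes illustrative):

    from torchrl.data import Composite, Unbounded

    spec = Composite(
        observation=Unbounded(shape=(10, 3)),
        shape=(10,),
        names=["time"],          # new in 0.10.1: dimension names on Composite
    )
    print(spec.names)            # ['time']
    print(spec.rand().names)     # names are forwarded to the sampled TensorDict
    spec = spec.unsqueeze(0)     # the inserted dimension gets a None name
    print(spec.names)            # [None, 'time']
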
torchrl/envs/batched_envs.py CHANGED
@@ -308,7 +308,7 @@ class BatchedEnvBase(EnvBase):
  num_sub_threads: int = 1,
  serial_for_single: bool = False,
  non_blocking: bool = False,
- mp_start_method: str = None,
+ mp_start_method: str | None = None,
  use_buffers: bool | None = None,
  consolidate: bool = True,
  ):
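
Note: this and the many similar signature edits below replace implicit-Optional defaults (``str = None``) with explicit ``str | None`` annotations. A minimal illustration of the pattern, not TorchRL code (``from __future__ import annotations`` lets the union syntax be used on Python 3.9):

    from __future__ import annotations


    def start_workers(mp_start_method: str | None = None) -> str:
        # None means "use the platform default" rather than a concrete method name.
        return mp_start_method or "spawn"


    print(start_workers())        # spawn
    print(start_workers("fork"))  # fork
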
torchrl/envs/common.py CHANGED
@@ -2267,7 +2267,7 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit):
  entry_point: Callable | None = None,
  transform: Transform | None = None,  # noqa: F821
  info_keys: list[NestedKey] | None = None,
- backend: str = None,
+ backend: str | None = None,
  to_numpy: bool = False,
  reward_threshold: float | None = None,
  nondeterministic: bool = False,
torchrl/envs/custom/llm.py CHANGED
@@ -28,10 +28,10 @@ class LLMHashingEnv(EnvBase):
  The primary goal of this environment is to identify token chains using a hashing function.
  This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
  identifiers, or easily prune repeated token chains in a data structure.
- The following figure gives an overview of this workflow:
 
- .. figure:: /_static/img/rollout-llm.png
-     :alt: Data collection loop with our LLM environment.
+ .. The following figure gives an overview of this workflow:
+ .. .. figure:: /_static/img/rollout-llm.png
+ ..     :alt: Data collection loop with our LLM environment.
 
  Args:
      vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
torchrl/envs/llm/envs.py CHANGED
@@ -601,10 +601,10 @@ class LLMHashingEnv(EnvBase):
  The primary goal of this environment is to identify token chains using a hashing function.
  This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node
  identifiers, or easily prune repeated token chains in a data structure.
- The following figure gives an overview of this workflow:
 
- .. figure:: /_static/img/rollout-llm.png
-     :alt: Data collection loop with our LLM environment.
+ .. The following figure gives an overview of this workflow:
+ .. .. figure:: /_static/img/rollout-llm.png
+ ..     :alt: Data collection loop with our LLM environment.
 
  Args:
      vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed.
torchrl/envs/transforms/transforms.py CHANGED
@@ -5423,8 +5423,8 @@ class Hash(UnaryTransform):
  """Look up the input that was given for a particular hash output.
 
  This feature is only available if, during initialization, either the
- :arg:`repertoire` argument was given or both the :arg:`in_keys_inv` and
- :arg:`out_keys_inv` arguments were given.
+ ``repertoire`` argument was given or both the ``in_keys_inv`` and
+ ``out_keys_inv`` arguments were given.
 
  Args:
      hash_tensor (Tensor): The hash output.
torchrl/modules/distributions/discrete.py CHANGED
@@ -622,7 +622,7 @@ class Ordinal(D.Categorical):
  not impose any notion of proximity or ordering over its support's atoms.
  The `Ordinal` distribution explicitly encodes those concepts, which is
  useful for learning discrete sampling from continuous sets. See §5 of
- `Tang & Agrawal, 2020<https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
+ `Tang & Agrawal, 2020 <https://arxiv.org/pdf/1901.10500.pdf>`_ for details.
 
  .. note::
      This class is mostly useful when you want to learn a distribution over
torchrl/modules/llm/backends/vllm/vllm_async.py CHANGED
@@ -526,7 +526,7 @@ class AsyncVLLM(RLvLLMEngine):
  See `this issue <https://github.com/vllm-project/vllm/issues/8268>`_ for more details.
 
  Example:
-     >>> from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+     >>> from torchrl.modules.llm import AsyncVLLM
      >>> from vllm import SamplingParams
      >>>
      >>> # Simple usage - single GPU, single replica
torchrl/modules/llm/policies/transformers_wrapper.py CHANGED
@@ -172,6 +172,7 @@ class TransformersWrapper(LLMWrapperBase):
 
  Input Keys:
      The input key depends on both `input_mode` and `generate`:
+
      - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
      - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
      - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
@@ -2460,7 +2461,7 @@ class RemoteTransformersWrapper:
  model,
  max_concurrency: int = 16,
  validate_model: bool = True,
- actor_name: str = None,
+ actor_name: str | None = None,
  num_gpus: int = 1,
  num_cpus: int = 1,
  **kwargs,
torchrl/modules/llm/policies/vllm_wrapper.py CHANGED
@@ -194,6 +194,7 @@ class vLLMWrapper(LLMWrapperBase):
 
  Input Keys:
      The input key depends on both `input_mode` and `generate`:
+
      - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
      - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
      - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
torchrl/objectives/a2c.py CHANGED
@@ -282,12 +282,12 @@ class A2CLoss(LossModule):
  loss_critic_type: str = "smooth_l1",
  gamma: float | None = None,
  separate_losses: bool = False,
- advantage_key: str = None,
- value_target_key: str = None,
+ advantage_key: str | None = None,
+ value_target_key: str | None = None,
  functional: bool = True,
  actor: ProbabilisticTensorDictSequential = None,
  critic: ProbabilisticTensorDictSequential = None,
- reduction: str = None,
+ reduction: str | None = None,
  clip_value: float | None = None,
  **kwargs,
  ):
torchrl/objectives/cql.py CHANGED
@@ -291,7 +291,7 @@ class CQLLoss(LossModule):
  num_random: int = 10,
  with_lagrange: bool = False,
  lagrange_thresh: float = 0.0,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ) -> None:
  self._out_keys = None
@@ -1100,7 +1100,7 @@ class DiscreteCQLLoss(LossModule):
  delay_value: bool = True,
  gamma: float | None = None,
  action_space=None,
- reduction: str = None,
+ reduction: str | None = None,
  ) -> None:
  self._in_keys = None
  if reduction is None:
torchrl/objectives/crossq.py CHANGED
@@ -266,9 +266,9 @@ class CrossQLoss(LossModule):
  action_spec=None,
  fixed_alpha: bool = False,
  target_entropy: str | float = "auto",
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ) -> None:
  self._in_keys = None
torchrl/objectives/ddpg.py CHANGED
@@ -201,7 +201,7 @@ class DDPGLoss(LossModule):
  delay_value: bool = True,
  gamma: float | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  ) -> None:
  self._in_keys = None
  if reduction is None:
torchrl/objectives/decision_transformer.py CHANGED
@@ -85,7 +85,7 @@ class OnlineDTLoss(LossModule):
  fixed_alpha: bool = False,
  target_entropy: str | float = "auto",
  samples_mc_entropy: int = 1,
- reduction: str = None,
+ reduction: str | None = None,
  ) -> None:
  self._in_keys = None
  self._out_keys = None
@@ -296,7 +296,7 @@ class DTLoss(LossModule):
  actor_network: ProbabilisticActor,
  *,
  loss_function: str = "l2",
- reduction: str = None,
+ reduction: str | None = None,
  device: torch.device | None = None,
  ) -> None:
  self._in_keys = None
torchrl/objectives/deprecated.py CHANGED
@@ -163,9 +163,9 @@ class REDQLoss_deprecated(LossModule):
  delay_qvalue: bool = True,
  gSDE: bool = False,
  gamma: float | None = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ):
  self._in_keys = None
torchrl/objectives/dqn.py CHANGED
@@ -179,8 +179,8 @@ class DQNLoss(LossModule):
  double_dqn: bool = False,
  gamma: float | None = None,
  action_space: str | TensorSpec = None,
- priority_key: str = None,
- reduction: str = None,
+ priority_key: str | None = None,
+ reduction: str | None = None,
  ) -> None:
  if reduction is None:
      reduction = "mean"
@@ -455,8 +455,8 @@ class DistributionalDQNLoss(LossModule):
  *,
  gamma: float,
  delay_value: bool = True,
- priority_key: str = None,
- reduction: str = None,
+ priority_key: str | None = None,
+ reduction: str | None = None,
  ):
  if reduction is None:
      reduction = "mean"
torchrl/objectives/gail.py CHANGED
@@ -78,7 +78,7 @@ class GAILLoss(LossModule):
  *,
  use_grad_penalty: bool = False,
  gp_lambda: float = 10,
- reduction: str = None,
+ reduction: str | None = None,
  ) -> None:
  self._in_keys = None
  self._out_keys = None
torchrl/objectives/iql.py CHANGED
@@ -266,9 +266,9 @@ class IQLLoss(LossModule):
  temperature: float = 1.0,
  expectile: float = 0.5,
  gamma: float | None = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ) -> None:
  self._in_keys = None
@@ -785,9 +785,9 @@ class DiscreteIQLLoss(IQLLoss):
  temperature: float = 1.0,
  expectile: float = 0.5,
  gamma: float | None = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  ) -> None:
  self._in_keys = None
  self._out_keys = None
torchrl/objectives/multiagent/qmixer.py CHANGED
@@ -195,7 +195,7 @@ class QMixerLoss(LossModule):
  delay_value: bool = True,
  gamma: float | None = None,
  action_space: str | TensorSpec = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  ) -> None:
  super().__init__()
  self._in_keys = None
torchrl/objectives/redq.py CHANGED
@@ -279,9 +279,9 @@ class REDQLoss(LossModule):
  delay_qvalue: bool = True,
  gSDE: bool = False,
  gamma: float | None = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ):
  if reduction is None:
torchrl/objectives/reinforce.py CHANGED
@@ -249,13 +249,13 @@ class ReinforceLoss(LossModule):
  delay_value: bool = False,
  loss_critic_type: str = "smooth_l1",
  gamma: float | None = None,
- advantage_key: str = None,
- value_target_key: str = None,
+ advantage_key: str | None = None,
+ value_target_key: str | None = None,
  separate_losses: bool = False,
  functional: bool = True,
  actor: ProbabilisticTensorDictSequential = None,
  critic: ProbabilisticTensorDictSequential = None,
- reduction: str = None,
+ reduction: str | None = None,
  clip_value: float | None = None,
  ) -> None:
  if actor is not None:
torchrl/objectives/sac.py CHANGED
@@ -325,16 +325,16 @@ class SACLoss(LossModule):
  alpha_init: float = 1.0,
  min_alpha: float | None = None,
  max_alpha: float | None = None,
- action_spec=None,
+ action_spec: TensorSpec | None = None,
  fixed_alpha: bool = False,
  target_entropy: str | float = "auto",
  delay_actor: bool = False,
  delay_qvalue: bool = True,
  delay_value: bool = True,
  gamma: float | None = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  skip_done_states: bool = False,
  deactivate_vmap: bool = False,
  ) -> None:
@@ -1195,9 +1195,9 @@ class DiscreteSACLoss(LossModule):
  target_entropy_weight: float = 0.98,
  target_entropy: str | Number = "auto",
  delay_qvalue: bool = True,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  skip_done_states: bool = False,
  deactivate_vmap: bool = False,
  ):
torchrl/objectives/td3.py CHANGED
@@ -236,9 +236,9 @@ class TD3Loss(LossModule):
  delay_actor: bool = True,
  delay_qvalue: bool = True,
  gamma: float | None = None,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ) -> None:
  if reduction is None:
torchrl/objectives/td3_bc.py CHANGED
@@ -251,9 +251,9 @@ class TD3BCLoss(LossModule):
  loss_function: str = "smooth_l1",
  delay_actor: bool = True,
  delay_qvalue: bool = True,
- priority_key: str = None,
+ priority_key: str | None = None,
  separate_losses: bool = False,
- reduction: str = None,
+ reduction: str | None = None,
  deactivate_vmap: bool = False,
  ) -> None:
  if reduction is None:
torchrl/record/loggers/wandb.py CHANGED
@@ -52,9 +52,9 @@ class WandbLogger(Logger):
  self,
  exp_name: str,
  offline: bool = False,
- save_dir: str = None,
- id: str = None,
- project: str = None,
+ save_dir: str | None = None,
+ id: str | None = None,
+ project: str | None = None,
  *,
  video_fps: int = 32,
  **kwargs,
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: torchrl
- Version: 0.10.0
+ Version: 0.10.1
  Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
  Author-email: torchrl contributors <vmoens@fb.com>
  Maintainer-email: torchrl contributors <vmoens@fb.com>
{torchrl-0.10.0.dist-info → torchrl-0.10.1.dist-info}/RECORD CHANGED
@@ -4,7 +4,7 @@ benchmarks\requirements.txt,sha256=zq-bWlShbTeSnj-Ud4NDgGngvWR_jdpnTAlMahTjr3k,7
  benchmarks\test_collectors_benchmark.py,sha256=-8MQHzK7ItL4eQRuE8V40WP3EJVXktwDtgzGN23rJy8,6912
  benchmarks\test_compressed_storage_benchmark.py,sha256=8NjVwRGKxi5CSuq9O9xLzM6MYG8R3a1afBjaXtGQjPY,6074
  benchmarks\test_envs_benchmark.py,sha256=ntXAMgLSKtasErZso_lsM_F2ae6gnuBS-Ow0tI4OPoM,3653
- benchmarks\test_llm.py,sha256=PL476UXGywrjzSbEgwnCP_35B8Y3UalOVsHgDFz2vuk,3888
+ benchmarks\test_llm.py,sha256=kHONY3smLA1JbyCBbSSeX8XSiapD6SiFxNKwikMf2-I,4038
  benchmarks\test_objectives_benchmarks.py,sha256=EipTFZNKV_VcofGHTyIQiKVQv45k_u1QHrAMgncNB8I,34608
  benchmarks\test_replaybuffer_benchmark.py,sha256=3VoDtjg0h3q76em1Ptq75jh_qfeHplYR4d3e1vTFsRc,8274
  benchmarks\ecosystem\gym_env_throughput.py,sha256=_QDRDLoq0LfGHNWV_o0E9s_k-YUuvcyWR0v3-w-Sjig,14282
@@ -92,18 +92,18 @@ sota-implementations\td3\utils.py,sha256=AU18AXsfPXnsOOB7wKgo_5pe1wFobquJywnI0ay
  sota-implementations\td3_bc\td3_bc.py,sha256=GUGzbGkaR7MDbmhvMCs69Ewd-iCVjaQCEt6dxcSE8yw,5927
  sota-implementations\td3_bc\utils.py,sha256=ZMMYehJP6KZaiiwTRlJRFKC528M49QhfEb_RAxaBj-U,6951
  torchrl\_extension.py,sha256=5DqUOUHZJPLqcSvGztUUBy4rJFAzznuJCAu5c84OJo8,2981
- torchrl\_torchrl.cp39-win_amd64.pyd,sha256=XUHgcTQC6nS1UUWxbL74fvxiokiqQCExWxPH2vgFpJU,418816
- torchrl\_utils.py,sha256=4tX8pOODv1pxgs4SQtauAhxdKgm3b_eh1fmj7JffELQ,32498
- torchrl\__init__.py,sha256=S0E9MQNriM-1IpeifdlYkejum3EQ263p10ojQOII5Rk,3108
- torchrl\collectors\collectors.py,sha256=Un4UDD7XxX3K-9AhkfsqOXKV6mBWf7Rd8AE_0blPgi8,183771
+ torchrl\_torchrl.cp39-win_amd64.pyd,sha256=yPDxj_5ErX7tV2rp4VD2LeYx0pY9DVj0zFEETqCDaFc,418816
+ torchrl\_utils.py,sha256=2CYX3KGAendGu5VCE1r_oRr_rGFCHd6jIUa1GMASLg4,32554
+ torchrl\__init__.py,sha256=KINiNwlkd5Qj3ts2XIGQpk8P10aaPERlR1iLegQgtpM,3190
+ torchrl\collectors\collectors.py,sha256=BdkinVV0S9rhd_ypGalsgm-5eOFBqurR2h0jozT8SQE,183779
  torchrl\collectors\utils.py,sha256=DLHoLf9MvjAmJssFuG1GE9wDpjIuWgmbmMrnphg7mwc,11588
  torchrl\collectors\weight_update.py,sha256=S9aSLYV6MtNMk23TM9ZY5nsk72s7bUJ2wnslh7k8sSI,24504
  torchrl\collectors\__init__.py,sha256=jdZJgqPB15BBKHgd46Md1FYy6GaQjGJBIsQ2mAo8W_E,898
  torchrl\collectors\distributed\default_configs.py,sha256=Kvbn84NRz9l7mBk_681P9I4AQ4hyeNPBfzLqydiFijQ,895
- torchrl\collectors\distributed\generic.py,sha256=cRVIGv4vfMB9VudJpar-5sU9gE_39mYzwcRYGZe5ez0,46526
- torchrl\collectors\distributed\ray.py,sha256=UOgRU0KWZT8e7GbE643Iw1Kv9ZuRxd7r6TFbgn_qAOs,36631
- torchrl\collectors\distributed\rpc.py,sha256=E9S_AFzP7hY3l6vHyNCk_A4Clzqhe8u-jXyIwbNJRnM,40497
- torchrl\collectors\distributed\sync.py,sha256=jJkrpVlYAo8JAE6ec8vA1wvaI7U6NwulPX3EbcW9vOc,27938
+ torchrl\collectors\distributed\generic.py,sha256=tlxFJEYeaYJXML6kUgvykwsXj5nuZGsKWUKWAnOMpYM,46521
+ torchrl\collectors\distributed\ray.py,sha256=kFuamafzMlZL7uT5ALzgm8vVb5GXfcibG7CXN3vqxN8,37082
+ torchrl\collectors\distributed\rpc.py,sha256=kJZUX-5bFZGTvjRnYnsXumD9eESAkZIz7Yzn2IVpYvA,40492
+ torchrl\collectors\distributed\sync.py,sha256=Tg4dlYGpxrTgILXJAugH8c0ZzrIIdW3vvAxiQF1jnAA,27933
  torchrl\collectors\distributed\utils.py,sha256=eY6M-vLCSzyACHRNBx5bHcieWsZfLg7DfNKGIv0IgHI,6625
  torchrl\collectors\distributed\__init__.py,sha256=cKDWdNlwx2LoJkTwf-DKUXbq3Y-0Z1DctPYPcdgOSU0,730
  torchrl\collectors\llm\base.py,sha256=ELhduZbaELYFpgitPoXB4qPhX2bsunqJgB2pOpwjG4U,22189
@@ -120,7 +120,7 @@ torchrl\csrc\torch_utils.h,sha256=k7gTjLle9wW_TG6GrqqOYIG1MKsWqjHPcoqPDxLys5Y,73
  torchrl\csrc\utils.cpp,sha256=agOkJ0G4ytRsuGmkTXT1kBLGWOxrpkd6LHDVeZymEgA,1690
  torchrl\csrc\utils.h,sha256=bXlPW94HH4UMRDXXbPgfC25SvI_txvAkueBmSex9g7M,1132
  torchrl\data\rlhf.py,sha256=y4KwcYjtlons4czR72LGLqTjfl913EKkP_qXNXO_LC4,1040
- torchrl\data\tensor_specs.py,sha256=dFEESV_bBmCiwW1J1Fz9-8tqHIb9dMsB4uWhQHc3s8Q,260938
+ torchrl\data\tensor_specs.py,sha256=NpRGa6Kz_uvkn8ZjZAguFiIA7B9r2zPeMP91G9BlKoE,268154
  torchrl\data\utils.py,sha256=krn1klWLdfdxy7afevsVecX0KGTuNUSz2D2yNCGW1Hw,12459
  torchrl\data\__init__.py,sha256=SBu_aozTi8iEWZ_EJcWUMqERDXH9i-S3O6hjMb5xjno,5166
  torchrl\data\datasets\atari_dqn.py,sha256=zet-dUhsxIbPNMqcbqksMCBSP9-ZvzIqxsviRaRTJtY,41626
@@ -145,7 +145,7 @@ torchrl\data\llm\__init__.py,sha256=-_UPiaQgHzFVLM_gfiOS0sCGVRwkxZ1_Z5B6C9ThN9o,
  torchrl\data\map\hash.py,sha256=AilOzYQ0KYhCZpVZCm63AhRuXn2P_RLB4PcbIx6qnlA,7446
  torchrl\data\map\query.py,sha256=CfAC9XJh7KpdCJNgqfJ5CUCi7BWqekgCMlNVSIF60To,8125
  torchrl\data\map\tdstorage.py,sha256=HzdR7M8Fjt9543vqdEpy0QQ8tHgKLCqWYCA6P02_2dM,15143
- torchrl\data\map\tree.py,sha256=MqktMw54JjGNJTdo_sjTR_UVQ3M4XPjhphijLcYwjKE,60788
+ torchrl\data\map\tree.py,sha256=EhYin_p5qRfb1mrfD_1mkYelyNupq4upMRaRPIvPLUY,60782
  torchrl\data\map\utils.py,sha256=fEjqCzaE4Vqjb8OzUvnClmLxVMooqeFOBMs7wroYvxs,3022
  torchrl\data\map\__init__.py,sha256=bON0vqCksU7FPoWNqiNcdl60t7yWUh9SdLhNtglj7jI,576
  torchrl\data\postprocs\postprocs.py,sha256=dpXOKWlhdKy4Um7HdzRKe42PJ_Q1jHC7AX5plR9AIiw,15509
@@ -160,15 +160,15 @@ torchrl\data\replay_buffers\utils.py,sha256=RPKS5C5U2GDPJQ1zSiawEi4LIwWuazJfcgQJ
  torchrl\data\replay_buffers\writers.py,sha256=WVChK3QPVb2Ehoqn3U_c8W2FJ5nU57hqCb7XTthrUgU,29488
  torchrl\data\replay_buffers\__init__.py,sha256=RcJSEXHz6zt1gQSFhEaKwQHUhZkYa2x-buNSDkNAslE,2595
  torchrl\envs\async_envs.py,sha256=5ao4aEdETWTaTHoUru5GZOzH5uvNtGGmHxnniSk4Wdc,44125
- torchrl\envs\batched_envs.py,sha256=jiUDll2tW2jJ57FlCFzV844DfC5OzVXCDX6idr8gLDc,121381
- torchrl\envs\common.py,sha256=GxKbl5CRO7fMDDrlJW0TmEpTifr3SYNSAUgi2iGQApA,174223
+ torchrl\envs\batched_envs.py,sha256=hhlNbzE0h1euxAWLgbxClTRPXySIDKKec15Sqlk3bjs,121388
+ torchrl\envs\common.py,sha256=vZgFUaunYISFYoYCZ2OlgDWek4aWWq8g61FArECUTLs,174230
  torchrl\envs\env_creator.py,sha256=AAuZNNgvm_jX2_014AWWvNI5EQCmerU8ICCIAvv3PiM,10329
  torchrl\envs\gym_like.py,sha256=dky7JLsHAVnTdLimf4KAZGsPP104SFLD4fVzlmyAYh8,32381
  torchrl\envs\utils.py,sha256=YTCNO6XyCfxm9njCFafrDh1nnZ7wTPYH8uPAeCtG4mI,74861
  torchrl\envs\vec_envs.py,sha256=B3lrPCVk4jRIXy0V0berwktZInHpx_UBABTcPUlA1Lw,377
  torchrl\envs\__init__.py,sha256=2eVr8StUSMiNd-IoD5BQAFFuV10pAtO926b6QzRzB_M,6082
  torchrl\envs\custom\chess.py,sha256=IiudT29mY6Yvau6CIXB5678d8kvkCoBvLh9FoQeSiM0,25285
- torchrl\envs\custom\llm.py,sha256=i_xxGNbdTUXCYJPY3tj6JjbaARFzdP9hqB3zzJjhywU,8860
+ torchrl\envs\custom\llm.py,sha256=FnZltPFKQUOH49qlvABk9dth2d2ZRkO_tiA0kiu6LKs,8869
  torchrl\envs\custom\pendulum.py,sha256=8sBgT8DvHrg5-YSOduC8GAHM9v7SXXfd8BzHvQ3wHOU,18617
  torchrl\envs\custom\san_moves.txt,sha256=AMStL2XCAnEbO6UZEYfDSCp5zRO2811gPEIzOWbmmRY,217492
  torchrl\envs\custom\tictactoeenv.py,sha256=voszQ7rPl7PbCBB6BQ8OELCi5US5YFm5gb-CLOkAIRM,12547
@@ -194,7 +194,7 @@ torchrl\envs\libs\vmas.py,sha256=giTORg2AqYzyjrazdD94fD2dNYwX7qe5TFnr-E1mjIg,371
  torchrl\envs\libs\_gym_utils.py,sha256=JYCNtWW4gAYwLq4k87ZdwLtDR_mRSyWQOi4seuXllOI,13150
  torchrl\envs\libs\__init__.py,sha256=tNiqWDxI-PQ2sxer7atuaSmKN_35pUpDIVFor-GIPfg,1821
  torchrl\envs\llm\chat.py,sha256=vCtRIhlw_8jQ55KL8x8-7N9dXxb83z68MirCD2jqt6E,32529
- torchrl\envs\llm\envs.py,sha256=ercyfYxVmIiKJq-U-Dn2ZgWxhEmDBYTDOVST1af8_sE,35801
+ torchrl\envs\llm\envs.py,sha256=z73Urni2ZCTPlQvMyHaxTzkvog7TwWqBI-c-GRd5TEk,35810
  torchrl\envs\llm\__init__.py,sha256=mxvPV4WD_jViongRHTYScOHwwGPjUjKhch90jA11IuY,1504
  torchrl\envs\llm\datasets\gsm8k.py,sha256=sja-7uzYSRd2CUYI6pe3SWq3YlKWWVmLqoNSFuw5tW0,17040
  torchrl\envs\llm\datasets\ifeval.py,sha256=9uSTySm3PKDIsgX7axHI1b-Z2IZmcNzUIACAIoDXuQM,12373
@@ -231,7 +231,7 @@ torchrl\envs\transforms\llm.py,sha256=V2ZY8-QY27GCpGY5i0UrryohQclybyL7aZwU9glc7w
  torchrl\envs\transforms\r3m.py,sha256=3B-JB3GHh3s1Af69WZ3wl3BU8SP0g_QmuH8IPztXRbQ,13850
  torchrl\envs\transforms\rb_transforms.py,sha256=eoJVEOv2ckVHth7nBgRaULW4TICf7YoQHcbWn9n1Cns,7661
  torchrl\envs\transforms\rlhf.py,sha256=DlAgMrLWVFkUQ3inpFgfHDkGjJKmFwhxiRGV3FWGZK8,691
- torchrl\envs\transforms\transforms.py,sha256=zJPpUPfpyzcrcg2X0AwkF2yDtMrzc99JtzqzQmcsyYw,499497
+ torchrl\envs\transforms\transforms.py,sha256=OqABOhEw36yzPc_oYEhMkMkj0PBbCIsGc9igrx3V-8Q,499488
  torchrl\envs\transforms\utils.py,sha256=V7YAV2BcJWvhC6aUV9LcwOodZFPbKmsltyR747OdRTU,3358
  torchrl\envs\transforms\vc1.py,sha256=snXdONyRKkyMiaW-bT7SwDJUQVb5GWr1mqY1W78Ohn0,10841
  torchrl\envs\transforms\vecnorm.py,sha256=udY-bdOhm-Aqjpt-STQT_mT6Ee50j5XH7v70gmyoKKk,34915
@@ -239,7 +239,7 @@ torchrl\envs\transforms\vip.py,sha256=r8Ni0hAYY1gispLj0TXV2VIedrgC4eW3hAhJBv47Q7
  torchrl\envs\transforms\__init__.py,sha256=d0p0afpcykAcBU6HENDJxtq91UEkooY69wbFgyOFIxE,3281
  torchrl\modules\__init__.py,sha256=TuJj3WUlvilYY39nUH-ykXkyprTxjq9NLW0QQXADqJk,4343
  torchrl\modules\distributions\continuous.py,sha256=Q9T8okHY625AGthEet5WOWfT0DTTSBeqWbIOB3lqPlE,26443
- torchrl\modules\distributions\discrete.py,sha256=KfZMRqX0NjTG82EiBgB-GYnf1kaIcAUqn_-D5zYUrGY,36499
+ torchrl\modules\distributions\discrete.py,sha256=A4-LMggbd0RBGsM3XmaQIyp7tzncdJpUHwMUzGQE9CM,36500
  torchrl\modules\distributions\truncated_normal.py,sha256=l5G3TePasl7q12DjwisyQC_E0OfZZo2g_HzBhZREVxc,6122
  torchrl\modules\distributions\utils.py,sha256=q4AFDKFpacRhrl4rjJ54UhxQzjOcj_SKlz0UIcZlUVc,7796
  torchrl\modules\distributions\__init__.py,sha256=Evkiz96ZPs7VUZp2n03h9kd7rmUCEEvMVl2f7RhzMhQ,1670
@@ -247,13 +247,13 @@ torchrl\modules\llm\utils.py,sha256=b2s9ngHwXnNbLggygU3-ScNwk0MWICketq2pZBshGqM,
  torchrl\modules\llm\__init__.py,sha256=_jPEt6oFb_R75zcVWlrfl4OMVxEs3Y9zCjVZp8dgg38,1098
  torchrl\modules\llm\backends\__init__.py,sha256=O7RanoHTBR4xLQLUJ2JUG4-zlVZ7PJDx1d2_AezVmaU,1027
  torchrl\modules\llm\backends\vllm\base.py,sha256=KZs36Q0sNveEkHJrub6xD_SzZafAdWz5ZK5ssHphMHM,2149
- torchrl\modules\llm\backends\vllm\vllm_async.py,sha256=VS_j-WmGJvW5G-nivGkYskEjzdm0V4j9RDqjKIMFM2E,80786
+ torchrl\modules\llm\backends\vllm\vllm_async.py,sha256=rj9vyftLJVFjIQqLEFvHw3YhA_jzpYs0C9hQl1LImVs,80766
  torchrl\modules\llm\backends\vllm\vllm_sync.py,sha256=K7P_da0XqAEO9lzniN3gso8IKWgP_S4ueyGncZtMEXU,15958
  torchrl\modules\llm\backends\vllm\vllm_utils.py,sha256=qRwirmMfTvqwr34-AtwcpjMp-InhDj7Iz2ZnEbkTcxE,4320
  torchrl\modules\llm\backends\vllm\__init__.py,sha256=LPqxC7ijHq6kjcA1a0kqGZU_JK_r8oT8BetsU4KKsWY,1816
  torchrl\modules\llm\policies\common.py,sha256=dhw_Q64Nk3WrfLqjfLlLVY5BMSCp1pbjUJT8MHuoEjQ,58104
- torchrl\modules\llm\policies\transformers_wrapper.py,sha256=-GcgzV_WGaPv3ZC8BNSSJOGFtXXjLn2FCbq45oG7dPc,114574
- torchrl\modules\llm\policies\vllm_wrapper.py,sha256=5mIzwW-tb_L-oqLz8Coqb4KbTRo1Rws3pbcugrGEpW8,95597
+ torchrl\modules\llm\policies\transformers_wrapper.py,sha256=HN2qsnFXdCoLHkJsuMp2Kp3wwMN9hEpWMebDHkJ-8xg,114583
+ torchrl\modules\llm\policies\vllm_wrapper.py,sha256=uMbis1Coqj7oke2fTwtQZoPnbUgRT-CPrcM4SQWrwlA,95599
  torchrl\modules\llm\policies\__init__.py,sha256=4rOIFNkYAZXMB5WIAomMV0tTZVP_J9mA4uatDpySbk8,628
  torchrl\modules\models\batchrenorm.py,sha256=bR4ZhaJ5E1cSK5o8L2dNX5KVLIb-bgrYxcq6yhx0I1A,4869
  torchrl\modules\models\decision_transformer.py,sha256=ANFTOm3k9_3Uv1vKGdXumRy3meBPnDdT8HqhVvJ2RCo,6783
@@ -281,30 +281,30 @@ torchrl\modules\tensordict_module\__init__.py,sha256=iTz8iCBmxt661GrGJRBfw4tBoTu
  torchrl\modules\utils\mappings.py,sha256=HEPGNHhQrPNU85-Bq0cYm1TZIhSkdEBkLvgrmjFMa4Q,371
  torchrl\modules\utils\utils.py,sha256=Ae2bl6GDxm9kU73WeLi-0ZEsrFt-XTaGqdxdXiX9LSU,3005
  torchrl\modules\utils\__init__.py,sha256=NQ_ko0JAIPY_X5RgBJnZLZXnYSH2q_kuD0tvXGqqY3k,1165
- torchrl\objectives\a2c.py,sha256=1rtRVZZx_HZFdfF6xLM1Q0kYknQg06rW5tC54CNb7Qk,29449
+ torchrl\objectives\a2c.py,sha256=Pc-xmumKqLw0ovMzLZ5F0BkovUB-LsUEsGvF2WVwTI4,29470
  torchrl\objectives\common.py,sha256=nslOhX1hi0nYvS8v7ylt5qB40fkQ7k-_XvO2p99ppTM,29274
- torchrl\objectives\cql.py,sha256=iu_wUH0MXPKFGYYV9AncQ81xzt42U2TR54f2WPl3fgE,56252
- torchrl\objectives\crossq.py,sha256=fqRlrFjRgSrb5JR4IP9fKb2-3snWQ7NtMprt21IqxTo,29557
- torchrl\objectives\ddpg.py,sha256=ceQW0cg6_5xxv1JVX_AS60E2seV7SHn6nUgYrqoRh9w,18214
- torchrl\objectives\decision_transformer.py,sha256=-U--UUC3UfbZi9MfOCugmjD1_xdohNoyACJmDkKAPgo,13405
- torchrl\objectives\deprecated.py,sha256=REfAaUeWcZ_b8NpVHBGa1U9NTO5zRZBc5vNQDdS_fGA,21253
- torchrl\objectives\dqn.py,sha256=qviNEbxiNQRKVDj-V4ibzeu0JbZvKs3MkEipNbc9UDM,28803
+ torchrl\objectives\cql.py,sha256=a6Fltgf05z_bP3YaXIydYtD4VZNSVhbHgMSxZsynHAg,56266
+ torchrl\objectives\crossq.py,sha256=KLG7DGZApOZWAktYYV-HLRh2C070zLAjcUh24xPqkbI,29571
+ torchrl\objectives\ddpg.py,sha256=KxZDlsNehydFh7-4oVVouGK6Dz8CO19uEpgHayd53mM,18221
+ torchrl\objectives\decision_transformer.py,sha256=vX-Gr8bXSQwq-gyPtFQWcrL_SO5ELuQfmBZm1uWzNzc,13419
+ torchrl\objectives\deprecated.py,sha256=jm77VQqK68nZziTlL14VZ3FHIAPVcUv1DUMGW33diBo,21267
+ torchrl\objectives\dqn.py,sha256=XhtPvhtNty1Gka9swHIJIF7HOM3huZqk0XTJ8RXVJcw,28831
  torchrl\objectives\dreamer.py,sha256=65EntKqou3auLMYxD1uaKGNyucfktabqaATNT1bExQc,18497
  torchrl\objectives\functional.py,sha256=0Pr_debAMM2bp06HPGVIpLTcyBue4DvcyUJVsaa6AjE,2154
- torchrl\objectives\gail.py,sha256=6UQHluezDA3fT7clTDwCT41xAcZMyHrxK4XR0juSxOc,9848
- torchrl\objectives\iql.py,sha256=VgCjfjKu91WCMCOK78vuc4k4kg1G3hPgOmjiKpozRM0,43976
+ torchrl\objectives\gail.py,sha256=MCJ-TE_asCp-NTfSgrqkUx9DWrR1GXth7VqrH46lndA,9855
+ torchrl\objectives\iql.py,sha256=CwymtcSV3RDktRVCbsQihST5PFik382-5J5qnRA3U8E,44004
  torchrl\objectives\ppo.py,sha256=l2fGbQ45Zd0mwSijLtHYk3rvaoOR451LPXwcVq1L7Zw,82038
- torchrl\objectives\redq.py,sha256=qRN5WyA6YHh7GcKX9n5GinXyETssAXJkiH0HuOx6Uss,29177
- torchrl\objectives\reinforce.py,sha256=EnUjqDSiTla3CuHg9rspQlvecd-VXZrPZxg4rGECZ8w,22861
- torchrl\objectives\sac.py,sha256=wKpfdm2y8Udp100PVp4bC0ljkdPwPQQNlYg5ZqvVO1M,70861
- torchrl\objectives\td3.py,sha256=Rq2q5gXo3AMuHm2OjRZvpfvKsAl1lIK5ALh2_sZM1ZE,23743
- torchrl\objectives\td3_bc.py,sha256=1pjB8mjCT2CLvQzjnqwAfZoc7yhjMB9UQjuJ5wZfTUY,26558
+ torchrl\objectives\redq.py,sha256=tLGi5wh8gErf0Ds725n_mwd3iSvpCXSh2w0WFJoqQvY,29191
+ torchrl\objectives\reinforce.py,sha256=Fzu-7VBiepxZT_bXX17yc4fyPl7BFjqQerCokVR29E8,22882
+ torchrl\objectives\sac.py,sha256=Djj2yGhqu1cglIVoOvq232Q9fkoNq5D1P7z9Flis9GI,70910
+ torchrl\objectives\td3.py,sha256=vixzm2sITvLVrojCnIYGopyimvcL8o_dML4nc93-WJY,23757
+ torchrl\objectives\td3_bc.py,sha256=oKZm3BRHsg-DK5fzFw1DC6fX8hrzbh8tmCOGvw5a62Q,26572
  torchrl\objectives\utils.py,sha256=CfEk41IWgVpzgZ7jq7rWS7FZbh4ymYsv92Td2rCWFRE,25990
  torchrl\objectives\__init__.py,sha256=Ug1FX1kFbTSz_i51uaDw7pOBIXSUIbH7BE5_8PZNbHM,2245
  torchrl\objectives\llm\grpo.py,sha256=kXcHSA_uPAEqvqRjjLTrzyupqMG9yuiRh_G0AxMAzaQ,25307
  torchrl\objectives\llm\sft.py,sha256=U1jtwZfDYLaU1YzA6edGhRo4t00dHpWbVKQLyyF2f1g,21352
  torchrl\objectives\llm\__init__.py,sha256=tZmIz3rkeclw3MzJoOWEs2gkewjx2USKrKJbWdyiiaQ,406
- torchrl\objectives\multiagent\qmixer.py,sha256=yttOxc5FNylKw4iMnYSG1qO8EbHvx8imAhxNxW9_iLw,17362
+ torchrl\objectives\multiagent\qmixer.py,sha256=MQST8UvktQLG9Z1b12Fj0RdxloWKj87paxL0DWwSucc,17369
  torchrl\objectives\multiagent\__init__.py,sha256=5uebDe5KrvlzeYV_BSd5vdmfruJQYMeDVVbU4iHErEg,245
  torchrl\objectives\value\advantages.py,sha256=Nz0IANqvV7uAMzghxA6Ta1DdkEfLo2w2Z21wd2w1duo,86865
  torchrl\objectives\value\functional.py,sha256=bgZiXJKuOmqlKdtTzWXvbgab4yT01xik-cTW2YQkTNU,51091
@@ -317,7 +317,7 @@ torchrl\record\loggers\csv.py,sha256=uNFjiPLq7mMr5z2WPyjyr9HGexu4ZkUwbX09FsV1mJ4
  torchrl\record\loggers\mlflow.py,sha256=9N-a5OUJJGwYej0WvTxQPkrazsahhmgof8seMDCnjM0,5098
  torchrl\record\loggers\tensorboard.py,sha256=x1Mo7KE4-iGG5NVToAP-1XceG_F6Vipr26B_1CK9Tg0,5005
  torchrl\record\loggers\utils.py,sha256=rZqyZi-ebLozHh8pbV-7m23W27zXsbnR04Rh-yWMPV8,2432
- torchrl\record\loggers\wandb.py,sha256=OJSDhMuT0PjfPHJdpthPZt9mJwqcf0lLQQDWdmjvENc,7277
+ torchrl\record\loggers\wandb.py,sha256=tkedfdd4RqTPFfe6xnyIhPm8T5a4Lzi9wW0ww1PY4iE,7298
  torchrl\record\loggers\__init__.py,sha256=pa6ttxj0FORHS6MgiYg05iFoABwJ8vqBHn45wkqshT4,568
  torchrl\trainers\trainers.py,sha256=_1SvHfNGwUd8w_YeRJBpJPm4Bvrv-5u7HsyKk5V3FCA,66905
  torchrl\trainers\__init__.py,sha256=LEUdW1zV5jydpMjZqnJ7XZW77MBIVv8dZ1nGv-m89RQ,853
@@ -344,8 +344,9 @@ torchrl\trainers\helpers\models.py,sha256=VujBq9H92sEzpCtU1iTrJQNlwvyOO-Rho4bzsM
  torchrl\trainers\helpers\replay_buffer.py,sha256=RaZqXnHimmadiibvDBcLbtIhpPaVMTPhYMOBvX4v3CA,2060
  torchrl\trainers\helpers\trainers.py,sha256=VVhAXHcutHyVa7kJEo_RtaI9U5h0Hk2qLEnONXFpPQ8,12350
  torchrl\trainers\helpers\__init__.py,sha256=sCBIXQqFQKRrbcNojgPxIh82HpXnXKgA_kMa3uZESSk,1137
- torchrl-0.10.0.dist-info\LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
- torchrl-0.10.0.dist-info\METADATA,sha256=VTYNsSJuSqtVhuiw1P4bZbaSm1UKjh9NR7OgMabJvYM,50106
- torchrl-0.10.0.dist-info\RECORD,,
- torchrl-0.10.0.dist-info\top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
- torchrl-0.10.0.dist-info\WHEEL,sha256=DmZ7B4aiAganfWOCUyjXG_z2uvUu-tkD3rVXMivgyOM,99
+ torchrl-0.10.1.dist-info\entry_points.txt,sha256=kjqZUboF3jzU21uy15NPn2WDfbwGE21Ls0fmfhqhmy4,110
+ torchrl-0.10.1.dist-info\LICENSE,sha256=PGO-oZsq4EzhE1-WQS2xGiEF3UCVb9YawfQ09cIMV_8,1119
+ torchrl-0.10.1.dist-info\METADATA,sha256=S5rAbED_7qvdSpMbLin0amLt_NkQ1aFqx93zdIPNz2g,50106
+ torchrl-0.10.1.dist-info\RECORD,,
+ torchrl-0.10.1.dist-info\top_level.txt,sha256=-5FcSdmJ9DwdHF8aOIaofsPbz4Gm8G1eo7r7Sc2CHgE,59
+ torchrl-0.10.1.dist-info\WHEEL,sha256=DmZ7B4aiAganfWOCUyjXG_z2uvUu-tkD3rVXMivgyOM,99
torchrl-0.10.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [vllm.general_plugins]
+ fp32_overrides = torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides
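
Note: the new ``entry_points.txt`` registers a vLLM general plugin with the wheel. One way to list what is registered under that group from an installed environment (sketch; ``EntryPoints.select`` needs Python 3.10+, the ``dict``-style fallback covers 3.9):

    from importlib.metadata import entry_points

    eps = entry_points()
    group = eps.select(group="vllm.general_plugins") if hasattr(eps, "select") else eps.get("vllm.general_plugins", [])
    for ep in group:
        print(ep.name, "->", ep.value)
    # expected: fp32_overrides -> torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides
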