torchrl-nightly 2025.8.8__cp312-cp312-manylinux1_x86_64.whl → 2025.8.10__cp312-cp312-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +1 -2
- sota-implementations/grpo/grpo_utils.py +2 -1
- sota-implementations/redq/utils.py +2 -1
- torchrl/_torchrl.cpython-312-x86_64-linux-gnu.so +0 -0
- torchrl/_utils.py +2 -1
- torchrl/collectors/collectors.py +2 -1
- torchrl/collectors/distributed/generic.py +3 -1
- torchrl/collectors/distributed/ray.py +3 -1
- torchrl/collectors/distributed/rpc.py +3 -1
- torchrl/collectors/distributed/sync.py +3 -1
- torchrl/collectors/llm/base.py +2 -1
- torchrl/collectors/llm/ray_collector.py +2 -1
- torchrl/collectors/utils.py +1 -1
- torchrl/collectors/weight_update.py +2 -1
- torchrl/data/datasets/atari_dqn.py +1 -1
- torchrl/data/datasets/common.py +1 -1
- torchrl/data/datasets/d4rl.py +1 -1
- torchrl/data/datasets/minari_data.py +1 -1
- torchrl/data/datasets/openml.py +1 -1
- torchrl/data/datasets/openx.py +2 -1
- torchrl/data/datasets/roboset.py +1 -1
- torchrl/data/datasets/vd4rl.py +1 -1
- torchrl/data/llm/dataset.py +1 -1
- torchrl/data/map/hash.py +1 -1
- torchrl/data/map/query.py +4 -2
- torchrl/data/map/tdstorage.py +2 -1
- torchrl/data/map/tree.py +2 -1
- torchrl/data/map/utils.py +1 -1
- torchrl/data/replay_buffers/ray_buffer.py +2 -1
- torchrl/data/replay_buffers/replay_buffers.py +2 -1
- torchrl/data/replay_buffers/scheduler.py +2 -1
- torchrl/data/replay_buffers/storages.py +2 -1
- torchrl/data/replay_buffers/utils.py +2 -1
- torchrl/data/replay_buffers/writers.py +2 -1
- torchrl/data/tensor_specs.py +8 -19
- torchrl/data/utils.py +3 -2
- torchrl/envs/async_envs.py +2 -1
- torchrl/envs/batched_envs.py +2 -1
- torchrl/envs/common.py +2 -1
- torchrl/envs/custom/llm.py +1 -1
- torchrl/envs/env_creator.py +1 -1
- torchrl/envs/gym_like.py +2 -1
- torchrl/envs/libs/dm_control.py +2 -2
- torchrl/envs/libs/gym.py +2 -3
- torchrl/envs/libs/meltingpot.py +1 -1
- torchrl/envs/libs/pettingzoo.py +2 -3
- torchrl/envs/libs/smacv2.py +8 -10
- torchrl/envs/llm/chat.py +3 -1
- torchrl/envs/llm/datasets/gsm8k.py +2 -1
- torchrl/envs/llm/datasets/ifeval.py +3 -1
- torchrl/envs/llm/envs.py +2 -1
- torchrl/envs/llm/reward/ifeval/_instructions.py +3 -2
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1 -1
- torchrl/envs/llm/reward/ifeval/_scorer.py +1 -1
- torchrl/envs/llm/transforms/dataloading.py +2 -2
- torchrl/envs/llm/transforms/reason.py +2 -1
- torchrl/envs/llm/transforms/tokenizer.py +1 -1
- torchrl/envs/transforms/transforms.py +3 -10
- torchrl/envs/transforms/vecnorm.py +3 -1
- torchrl/modules/distributions/continuous.py +1 -1
- torchrl/modules/distributions/discrete.py +2 -1
- torchrl/modules/models/exploration.py +1 -1
- torchrl/modules/models/models.py +1 -1
- torchrl/modules/models/multiagent.py +1 -1
- torchrl/modules/models/utils.py +1 -1
- torchrl/modules/tensordict_module/actors.py +1 -1
- torchrl/modules/tensordict_module/common.py +1 -1
- torchrl/objectives/common.py +1 -1
- torchrl/objectives/ppo.py +1 -1
- torchrl/objectives/utils.py +2 -1
- torchrl/objectives/value/advantages.py +1 -1
- torchrl/record/loggers/common.py +1 -1
- torchrl/record/loggers/csv.py +1 -1
- torchrl/record/loggers/mlflow.py +2 -1
- torchrl/record/loggers/tensorboard.py +1 -1
- torchrl/record/loggers/wandb.py +1 -1
- torchrl/record/recorder.py +1 -1
- torchrl/trainers/helpers/collectors.py +3 -1
- torchrl/trainers/helpers/envs.py +14 -13
- torchrl/trainers/trainers.py +5 -4
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/RECORD +85 -85
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/top_level.txt +0 -0
torchrl/envs/libs/smacv2.py
CHANGED
@@ -5,8 +5,6 @@
|
|
5
5
|
import importlib
|
6
6
|
import re
|
7
7
|
|
8
|
-
from typing import Dict, Optional
|
9
|
-
|
10
8
|
import torch
|
11
9
|
from tensordict import TensorDict, TensorDictBase
|
12
10
|
|
@@ -196,7 +194,7 @@ class SMACv2Wrapper(_EnvWrapper):
|
|
196
194
|
|
197
195
|
return smacv2
|
198
196
|
|
199
|
-
def _check_kwargs(self, kwargs: Dict):
|
197
|
+
def _check_kwargs(self, kwargs: dict):
|
200
198
|
import smacv2
|
201
199
|
|
202
200
|
if "env" not in kwargs:
|
@@ -311,7 +309,7 @@ class SMACv2Wrapper(_EnvWrapper):
|
|
311
309
|
)
|
312
310
|
return spec
|
313
311
|
|
314
|
-
def _set_seed(self, seed: Optional[int]) -> None:
|
312
|
+
def _set_seed(self, seed: int | None) -> None:
|
315
313
|
if seed is not None:
|
316
314
|
raise NotImplementedError(
|
317
315
|
"Seed cannot be changed once environment was created."
|
@@ -329,7 +327,7 @@ class SMACv2Wrapper(_EnvWrapper):
|
|
329
327
|
return torch.tensor(value, device=self.device, dtype=torch.float32)
|
330
328
|
|
331
329
|
def _reset(
|
332
|
-
self, tensordict: Optional[TensorDictBase] = None, **kwargs
|
330
|
+
self, tensordict: TensorDictBase | None = None, **kwargs
|
333
331
|
) -> TensorDictBase:
|
334
332
|
|
335
333
|
obs, state = self._env.reset()
|
@@ -602,8 +600,8 @@ class SMACv2Env(SMACv2Wrapper):
|
|
602
600
|
def __init__(
|
603
601
|
self,
|
604
602
|
map_name: str,
|
605
|
-
capability_config:
|
606
|
-
seed: Optional[int] = None,
|
603
|
+
capability_config: dict | None = None,
|
604
|
+
seed: int | None = None,
|
607
605
|
categorical_actions: bool = True,
|
608
606
|
**kwargs,
|
609
607
|
):
|
@@ -619,15 +617,15 @@ class SMACv2Env(SMACv2Wrapper):
|
|
619
617
|
|
620
618
|
super().__init__(**kwargs)
|
621
619
|
|
622
|
-
def _check_kwargs(self, kwargs: Dict):
|
620
|
+
def _check_kwargs(self, kwargs: dict):
|
623
621
|
if "map_name" not in kwargs:
|
624
622
|
raise TypeError("Expected 'map_name' to be part of kwargs")
|
625
623
|
|
626
624
|
def _build_env(
|
627
625
|
self,
|
628
626
|
map_name: str,
|
629
|
-
capability_config:
|
630
|
-
seed: Optional[int] = None,
|
627
|
+
capability_config: dict | None = None,
|
628
|
+
seed: int | None = None,
|
631
629
|
**kwargs,
|
632
630
|
) -> "smacv2.env.StarCraft2Env": # noqa: F821
|
633
631
|
import smacv2.env
|
torchrl/envs/llm/chat.py
CHANGED
@@ -4,7 +4,9 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
-
from
|
7
|
+
from collections.abc import Callable
|
8
|
+
|
9
|
+
from typing import Any, Literal, TYPE_CHECKING
|
8
10
|
|
9
11
|
import torch
|
10
12
|
from tensordict import lazy_stack, TensorDictBase
|
@@ -5,7 +5,8 @@
|
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
7
|
import warnings
|
8
|
-
from
|
8
|
+
from collections.abc import Callable
|
9
|
+
from typing import Any, Literal, TYPE_CHECKING
|
9
10
|
|
10
11
|
import torch
|
11
12
|
from tensordict import NestedKey, TensorDict, TensorDictBase
|
@@ -4,7 +4,9 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
-
from
|
7
|
+
from collections.abc import Callable
|
8
|
+
|
9
|
+
from typing import Any, Literal, TYPE_CHECKING
|
8
10
|
|
9
11
|
import torch
|
10
12
|
from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict
|
torchrl/envs/llm/envs.py
CHANGED
@@ -36,7 +36,8 @@ import json
|
|
36
36
|
import random
|
37
37
|
import re
|
38
38
|
import string
|
39
|
-
from
|
39
|
+
from collections.abc import Sequence
|
40
|
+
from typing import Any, Literal, Optional, Union
|
40
41
|
|
41
42
|
from torchrl._utils import logger as torchrl_logger
|
42
43
|
|
@@ -50,7 +51,7 @@ from ._instructions_util import (
|
|
50
51
|
)
|
51
52
|
|
52
53
|
|
53
|
-
_InstructionArgsDtype = Optional[
|
54
|
+
_InstructionArgsDtype = Optional[dict[str, Union[int, str, Sequence[str]]]]
|
54
55
|
|
55
56
|
_LANGUAGES = LANGUAGE_CODES
|
56
57
|
|
@@ -6,8 +6,8 @@ from __future__ import annotations
|
|
6
6
|
|
7
7
|
import warnings
|
8
8
|
from collections import deque
|
9
|
-
from collections.abc import Mapping
|
10
|
-
from typing import Any,
|
9
|
+
from collections.abc import Callable, Iterable, Mapping
|
10
|
+
from typing import Any, Literal
|
11
11
|
|
12
12
|
import torch
|
13
13
|
from tensordict import is_tensor_collection, lazy_stack, TensorDict, TensorDictBase
|
@@ -6,7 +6,8 @@
|
|
6
6
|
from __future__ import annotations
|
7
7
|
|
8
8
|
import re
|
9
|
-
from
|
9
|
+
from collections.abc import Callable
|
10
|
+
from typing import Literal
|
10
11
|
|
11
12
|
from tensordict import lazy_stack, TensorDictBase
|
12
13
|
from torchrl._utils import logger as torchrl_logger
|
@@ -14,20 +14,13 @@ import multiprocessing as mp
|
|
14
14
|
import time
|
15
15
|
import warnings
|
16
16
|
import weakref
|
17
|
+
from collections import OrderedDict
|
18
|
+
from collections.abc import Callable, Mapping, Sequence
|
17
19
|
from copy import copy
|
18
20
|
from enum import IntEnum
|
19
21
|
from functools import wraps
|
20
22
|
from textwrap import indent
|
21
|
-
from typing import
|
22
|
-
Any,
|
23
|
-
Callable,
|
24
|
-
Mapping,
|
25
|
-
OrderedDict,
|
26
|
-
Sequence,
|
27
|
-
TYPE_CHECKING,
|
28
|
-
TypeVar,
|
29
|
-
Union,
|
30
|
-
)
|
23
|
+
from typing import Any, TYPE_CHECKING, TypeVar, Union
|
31
24
|
|
32
25
|
import numpy as np
|
33
26
|
|
@@ -7,9 +7,11 @@ from __future__ import annotations
|
|
7
7
|
import math
|
8
8
|
import uuid
|
9
9
|
import warnings
|
10
|
+
from collections import OrderedDict
|
11
|
+
from collections.abc import Sequence
|
10
12
|
from copy import copy
|
11
13
|
|
12
|
-
from typing import Any
|
14
|
+
from typing import Any
|
13
15
|
|
14
16
|
import torch
|
15
17
|
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
|
@@ -4,9 +4,10 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
+
from collections.abc import Sequence
|
8
|
+
|
7
9
|
from enum import Enum
|
8
10
|
from functools import wraps
|
9
|
-
from typing import Sequence
|
10
11
|
|
11
12
|
import torch
|
12
13
|
import torch.distributions as D
|
torchrl/modules/models/models.py
CHANGED
torchrl/modules/models/utils.py
CHANGED
torchrl/objectives/common.py
CHANGED
@@ -8,9 +8,9 @@ from __future__ import annotations
|
|
8
8
|
import abc
|
9
9
|
import functools
|
10
10
|
import warnings
|
11
|
+
from collections.abc import Iterator
|
11
12
|
from copy import deepcopy
|
12
13
|
from dataclasses import dataclass
|
13
|
-
from typing import Iterator
|
14
14
|
|
15
15
|
import torch
|
16
16
|
from tensordict import is_tensor_collection, TensorDict, TensorDictBase
|
torchrl/objectives/ppo.py
CHANGED
torchrl/objectives/utils.py
CHANGED
@@ -7,9 +7,10 @@ from __future__ import annotations
|
|
7
7
|
import functools
|
8
8
|
import re
|
9
9
|
import warnings
|
10
|
+
from collections.abc import Callable, Iterable
|
10
11
|
from copy import copy
|
11
12
|
from enum import Enum
|
12
|
-
from typing import Any,
|
13
|
+
from typing import Any, TypeVar
|
13
14
|
|
14
15
|
import torch
|
15
16
|
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
|
@@ -7,10 +7,10 @@ from __future__ import annotations
|
|
7
7
|
import abc
|
8
8
|
import functools
|
9
9
|
import warnings
|
10
|
+
from collections.abc import Callable
|
10
11
|
from contextlib import nullcontext
|
11
12
|
from dataclasses import asdict, dataclass
|
12
13
|
from functools import wraps
|
13
|
-
from typing import Callable
|
14
14
|
|
15
15
|
import torch
|
16
16
|
from tensordict import is_tensor_collection, TensorDictBase
|
torchrl/record/loggers/common.py
CHANGED
torchrl/record/loggers/csv.py
CHANGED
torchrl/record/loggers/mlflow.py
CHANGED
torchrl/record/loggers/wandb.py
CHANGED
torchrl/record/recorder.py
CHANGED
@@ -4,8 +4,10 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
+
from collections.abc import Callable
|
8
|
+
|
7
9
|
from dataclasses import dataclass, field
|
8
|
-
from typing import Any
|
10
|
+
from typing import Any
|
9
11
|
|
10
12
|
from tensordict import TensorDictBase
|
11
13
|
|
torchrl/trainers/helpers/envs.py
CHANGED
@@ -8,9 +8,10 @@
|
|
8
8
|
# from __future__ import annotations
|
9
9
|
|
10
10
|
import importlib.util
|
11
|
+
from collections.abc import Callable, Sequence
|
11
12
|
from copy import copy
|
12
13
|
from dataclasses import dataclass, field as dataclass_field
|
13
|
-
from typing import Any
|
14
|
+
from typing import Any
|
14
15
|
|
15
16
|
import torch
|
16
17
|
from torchrl._utils import logger as torchrl_logger, VERBOSE
|
@@ -223,18 +224,18 @@ def get_norm_state_dict(env):
|
|
223
224
|
def transformed_env_constructor(
|
224
225
|
cfg: DictConfig, # noqa: F821
|
225
226
|
video_tag: str = "",
|
226
|
-
logger:
|
227
|
-
stats:
|
227
|
+
logger: Logger | None = None, # noqa
|
228
|
+
stats: dict | None = None,
|
228
229
|
norm_obs_only: bool = False,
|
229
230
|
use_env_creator: bool = False,
|
230
|
-
custom_env_maker:
|
231
|
-
custom_env:
|
231
|
+
custom_env_maker: Callable | None = None,
|
232
|
+
custom_env: EnvBase | None = None,
|
232
233
|
return_transformed_envs: bool = True,
|
233
|
-
action_dim_gsde:
|
234
|
-
state_dim_gsde:
|
235
|
-
batch_dims:
|
236
|
-
obs_norm_state_dict:
|
237
|
-
) ->
|
234
|
+
action_dim_gsde: int | None = None,
|
235
|
+
state_dim_gsde: int | None = None,
|
236
|
+
batch_dims: int | None = 0,
|
237
|
+
obs_norm_state_dict: dict | None = None,
|
238
|
+
) -> Callable | EnvCreator:
|
238
239
|
"""Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
|
239
240
|
|
240
241
|
Args:
|
@@ -340,7 +341,7 @@ def transformed_env_constructor(
|
|
340
341
|
|
341
342
|
def parallel_env_constructor(
|
342
343
|
cfg: DictConfig, **kwargs # noqa: F821
|
343
|
-
) ->
|
344
|
+
) -> ParallelEnv | EnvCreator:
|
344
345
|
"""Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
|
345
346
|
|
346
347
|
Args:
|
@@ -385,7 +386,7 @@ def parallel_env_constructor(
|
|
385
386
|
def get_stats_random_rollout(
|
386
387
|
cfg: DictConfig, # noqa: F821
|
387
388
|
proof_environment: EnvBase = None,
|
388
|
-
key:
|
389
|
+
key: str | None = None,
|
389
390
|
):
|
390
391
|
"""Gathers stas (loc and scale) from an environment using random rollouts.
|
391
392
|
|
@@ -463,7 +464,7 @@ def get_stats_random_rollout(
|
|
463
464
|
def initialize_observation_norm_transforms(
|
464
465
|
proof_environment: EnvBase,
|
465
466
|
num_iter: int = 1000,
|
466
|
-
key:
|
467
|
+
key: str | tuple[str, ...] = None,
|
467
468
|
):
|
468
469
|
"""Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
|
469
470
|
|
torchrl/trainers/trainers.py
CHANGED
@@ -9,9 +9,10 @@ import abc
|
|
9
9
|
import pathlib
|
10
10
|
import warnings
|
11
11
|
from collections import defaultdict, OrderedDict
|
12
|
+
from collections.abc import Callable, Sequence
|
12
13
|
from copy import deepcopy
|
13
14
|
from textwrap import indent
|
14
|
-
from typing import Any
|
15
|
+
from typing import Any
|
15
16
|
|
16
17
|
import numpy as np
|
17
18
|
import torch.nn
|
@@ -362,19 +363,19 @@ class Trainer:
|
|
362
363
|
|
363
364
|
elif dest == "pre_steps_log":
|
364
365
|
_check_input_output_typehint(
|
365
|
-
op, input=TensorDictBase, output=
|
366
|
+
op, input=TensorDictBase, output=tuple[str, float]
|
366
367
|
)
|
367
368
|
self._pre_steps_log_ops.append((op, kwargs))
|
368
369
|
|
369
370
|
elif dest == "post_steps_log":
|
370
371
|
_check_input_output_typehint(
|
371
|
-
op, input=TensorDictBase, output=
|
372
|
+
op, input=TensorDictBase, output=tuple[str, float]
|
372
373
|
)
|
373
374
|
self._post_steps_log_ops.append((op, kwargs))
|
374
375
|
|
375
376
|
elif dest == "post_optim_log":
|
376
377
|
_check_input_output_typehint(
|
377
|
-
op, input=TensorDictBase, output=
|
378
|
+
op, input=TensorDictBase, output=tuple[str, float]
|
378
379
|
)
|
379
380
|
self._post_optim_log_ops.append((op, kwargs))
|
380
381
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: torchrl-nightly
|
3
|
-
Version: 2025.8.8
|
3
|
+
Version: 2025.8.10
|
4
4
|
Summary: A modular, primitive-first, python-first PyTorch library for Reinforcement Learning
|
5
5
|
Author-email: torchrl contributors <vmoens@fb.com>
|
6
6
|
Maintainer-email: torchrl contributors <vmoens@fb.com>
|