torchrl-nightly 2025.8.8__cp39-cp39-macosx_10_9_universal2.whl → 2025.8.10__cp39-cp39-macosx_10_9_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +1 -2
  2. sota-implementations/grpo/grpo_utils.py +2 -1
  3. sota-implementations/redq/utils.py +2 -1
  4. torchrl/_torchrl.cpython-39-darwin.so +0 -0
  5. torchrl/_utils.py +2 -1
  6. torchrl/collectors/collectors.py +2 -1
  7. torchrl/collectors/distributed/generic.py +3 -1
  8. torchrl/collectors/distributed/ray.py +3 -1
  9. torchrl/collectors/distributed/rpc.py +3 -1
  10. torchrl/collectors/distributed/sync.py +3 -1
  11. torchrl/collectors/llm/base.py +2 -1
  12. torchrl/collectors/llm/ray_collector.py +2 -1
  13. torchrl/collectors/utils.py +1 -1
  14. torchrl/collectors/weight_update.py +2 -1
  15. torchrl/data/datasets/atari_dqn.py +1 -1
  16. torchrl/data/datasets/common.py +1 -1
  17. torchrl/data/datasets/d4rl.py +1 -1
  18. torchrl/data/datasets/minari_data.py +1 -1
  19. torchrl/data/datasets/openml.py +1 -1
  20. torchrl/data/datasets/openx.py +2 -1
  21. torchrl/data/datasets/roboset.py +1 -1
  22. torchrl/data/datasets/vd4rl.py +1 -1
  23. torchrl/data/llm/dataset.py +1 -1
  24. torchrl/data/map/hash.py +1 -1
  25. torchrl/data/map/query.py +4 -2
  26. torchrl/data/map/tdstorage.py +2 -1
  27. torchrl/data/map/tree.py +2 -1
  28. torchrl/data/map/utils.py +1 -1
  29. torchrl/data/replay_buffers/ray_buffer.py +2 -1
  30. torchrl/data/replay_buffers/replay_buffers.py +2 -1
  31. torchrl/data/replay_buffers/scheduler.py +2 -1
  32. torchrl/data/replay_buffers/storages.py +2 -1
  33. torchrl/data/replay_buffers/utils.py +2 -1
  34. torchrl/data/replay_buffers/writers.py +2 -1
  35. torchrl/data/tensor_specs.py +8 -19
  36. torchrl/data/utils.py +3 -2
  37. torchrl/envs/async_envs.py +2 -1
  38. torchrl/envs/batched_envs.py +2 -1
  39. torchrl/envs/common.py +2 -1
  40. torchrl/envs/custom/llm.py +1 -1
  41. torchrl/envs/env_creator.py +1 -1
  42. torchrl/envs/gym_like.py +2 -1
  43. torchrl/envs/libs/dm_control.py +2 -2
  44. torchrl/envs/libs/gym.py +2 -3
  45. torchrl/envs/libs/meltingpot.py +1 -1
  46. torchrl/envs/libs/pettingzoo.py +2 -3
  47. torchrl/envs/libs/smacv2.py +8 -10
  48. torchrl/envs/llm/chat.py +3 -1
  49. torchrl/envs/llm/datasets/gsm8k.py +2 -1
  50. torchrl/envs/llm/datasets/ifeval.py +3 -1
  51. torchrl/envs/llm/envs.py +2 -1
  52. torchrl/envs/llm/reward/ifeval/_instructions.py +3 -2
  53. torchrl/envs/llm/reward/ifeval/_instructions_util.py +1 -1
  54. torchrl/envs/llm/reward/ifeval/_scorer.py +1 -1
  55. torchrl/envs/llm/transforms/dataloading.py +2 -2
  56. torchrl/envs/llm/transforms/reason.py +2 -1
  57. torchrl/envs/llm/transforms/tokenizer.py +1 -1
  58. torchrl/envs/transforms/transforms.py +3 -10
  59. torchrl/envs/transforms/vecnorm.py +3 -1
  60. torchrl/modules/distributions/continuous.py +1 -1
  61. torchrl/modules/distributions/discrete.py +2 -1
  62. torchrl/modules/models/exploration.py +1 -1
  63. torchrl/modules/models/models.py +1 -1
  64. torchrl/modules/models/multiagent.py +1 -1
  65. torchrl/modules/models/utils.py +1 -1
  66. torchrl/modules/tensordict_module/actors.py +1 -1
  67. torchrl/modules/tensordict_module/common.py +1 -1
  68. torchrl/objectives/common.py +1 -1
  69. torchrl/objectives/ppo.py +1 -1
  70. torchrl/objectives/utils.py +2 -1
  71. torchrl/objectives/value/advantages.py +1 -1
  72. torchrl/record/loggers/common.py +1 -1
  73. torchrl/record/loggers/csv.py +1 -1
  74. torchrl/record/loggers/mlflow.py +2 -1
  75. torchrl/record/loggers/tensorboard.py +1 -1
  76. torchrl/record/loggers/wandb.py +1 -1
  77. torchrl/record/recorder.py +1 -1
  78. torchrl/trainers/helpers/collectors.py +3 -1
  79. torchrl/trainers/helpers/envs.py +14 -13
  80. torchrl/trainers/trainers.py +5 -4
  81. {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/METADATA +1 -1
  82. {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/RECORD +85 -85
  83. {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/LICENSE +0 -0
  84. {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/WHEEL +0 -0
  85. {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,6 @@ import pickle
9
9
 
10
10
  import time
11
11
  from pathlib import Path
12
- from typing import Dict
13
12
 
14
13
  import numpy as np
15
14
 
@@ -93,7 +92,7 @@ def run_vmas_rllib(
93
92
  - result["timers"]["learn_time_ms"]
94
93
  )
95
94
 
96
- def env_creator(config: Dict):
95
+ def env_creator(config: dict):
97
96
  env = vmas.make_env(
98
97
  scenario=config["scenario_name"],
99
98
  num_envs=config["num_envs"],
@@ -5,7 +5,8 @@
5
5
  from __future__ import annotations
6
6
 
7
7
  import time
8
- from typing import Any, Callable, Literal
8
+ from collections.abc import Callable
9
+ from typing import Any, Literal
9
10
 
10
11
  import torch
11
12
  from omegaconf import DictConfig
@@ -4,8 +4,9 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
+ from collections.abc import Callable, Sequence
8
+
7
9
  from copy import copy
8
- from typing import Callable, Sequence
9
10
 
10
11
  import torch
11
12
  from omegaconf import OmegaConf
Binary file
torchrl/_utils.py CHANGED
@@ -16,12 +16,13 @@ import threading
16
16
  import time
17
17
  import traceback
18
18
  import warnings
19
+ from collections.abc import Callable
19
20
  from contextlib import nullcontext
20
21
  from copy import copy
21
22
  from functools import wraps
22
23
  from importlib import import_module
23
24
  from textwrap import indent
24
- from typing import Any, Callable, cast, TypeVar
25
+ from typing import Any, cast, TypeVar
25
26
 
26
27
  import numpy as np
27
28
  import torch
@@ -17,12 +17,13 @@ import time
17
17
  import typing
18
18
  import warnings
19
19
  from collections import defaultdict, OrderedDict
20
+ from collections.abc import Callable, Iterator, Mapping, Sequence
20
21
  from copy import deepcopy
21
22
  from multiprocessing import connection, queues
22
23
  from multiprocessing.managers import SyncManager
23
24
  from queue import Empty
24
25
  from textwrap import indent
25
- from typing import Any, Callable, Iterator, Mapping, Sequence, TypeVar
26
+ from typing import Any, TypeVar
26
27
 
27
28
  import numpy as np
28
29
  import torch
@@ -9,9 +9,11 @@ from __future__ import annotations
9
9
  import os
10
10
  import socket
11
11
  import warnings
12
+ from collections import OrderedDict
13
+ from collections.abc import Callable, Sequence
12
14
  from copy import copy, deepcopy
13
15
  from datetime import timedelta
14
- from typing import Any, Callable, OrderedDict, Sequence
16
+ from typing import Any
15
17
 
16
18
  import torch.cuda
17
19
  from tensordict import TensorDict, TensorDictBase
@@ -7,7 +7,9 @@ from __future__ import annotations
7
7
 
8
8
  import asyncio
9
9
  import warnings
10
- from typing import Any, Callable, Iterator, OrderedDict, Sequence
10
+ from collections import OrderedDict
11
+ from collections.abc import Callable, Iterator, Sequence
12
+ from typing import Any
11
13
 
12
14
  import torch
13
15
  import torch.nn as nn
@@ -11,8 +11,10 @@ import os
11
11
  import socket
12
12
  import time
13
13
  import warnings
14
+ from collections import OrderedDict
15
+ from collections.abc import Callable, Sequence
14
16
  from copy import copy, deepcopy
15
- from typing import Any, Callable, OrderedDict, Sequence
17
+ from typing import Any
16
18
 
17
19
  import torch.cuda
18
20
 
@@ -9,9 +9,11 @@ from __future__ import annotations
9
9
  import os
10
10
  import socket
11
11
  import warnings
12
+ from collections import OrderedDict
13
+ from collections.abc import Callable, Sequence
12
14
  from copy import copy, deepcopy
13
15
  from datetime import timedelta
14
- from typing import Any, Callable, Literal, OrderedDict, Sequence
16
+ from typing import Any, Literal
15
17
 
16
18
  import torch.cuda
17
19
  from tensordict import TensorDict, TensorDictBase
@@ -5,7 +5,8 @@
5
5
  from __future__ import annotations
6
6
 
7
7
  from collections import deque
8
- from typing import Any, Callable
8
+ from collections.abc import Callable
9
+ from typing import Any
9
10
 
10
11
  import torch
11
12
 
@@ -7,7 +7,8 @@ from __future__ import annotations
7
7
  import copy
8
8
 
9
9
  import warnings
10
- from typing import Any, Callable, Iterator
10
+ from collections.abc import Callable, Iterator
11
+ from typing import Any
11
12
 
12
13
  import torch
13
14
  from tensordict import TensorDictBase
@@ -4,7 +4,7 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
- from typing import Callable
7
+ from collections.abc import Callable
8
8
 
9
9
  import torch
10
10
 
@@ -6,7 +6,8 @@ from __future__ import annotations
6
6
 
7
7
  import abc
8
8
  import weakref
9
- from typing import Any, Callable, TypeVar
9
+ from collections.abc import Callable
10
+ from typing import Any, TypeVar
10
11
 
11
12
  import torch
12
13
  from tensordict import TensorDict, TensorDictBase
@@ -13,8 +13,8 @@ import shutil
13
13
  import subprocess
14
14
  import tempfile
15
15
  from collections import defaultdict
16
+ from collections.abc import Callable
16
17
  from pathlib import Path
17
- from typing import Callable
18
18
 
19
19
  import numpy as np
20
20
  import torch
@@ -6,8 +6,8 @@ from __future__ import annotations
6
6
 
7
7
  import abc
8
8
  import shutil
9
+ from collections.abc import Callable
9
10
  from pathlib import Path
10
- from typing import Callable
11
11
 
12
12
  import torch
13
13
  from tensordict import TensorDictBase
@@ -10,8 +10,8 @@ import shutil
10
10
  import tempfile
11
11
  import urllib
12
12
  import warnings
13
+ from collections.abc import Callable
13
14
  from pathlib import Path
14
- from typing import Callable
15
15
 
16
16
  import numpy as np
17
17
  import torch
@@ -10,10 +10,10 @@ import os.path
10
10
  import shutil
11
11
  import tempfile
12
12
  from collections import defaultdict
13
+ from collections.abc import Callable
13
14
  from contextlib import nullcontext
14
15
  from dataclasses import asdict
15
16
  from pathlib import Path
16
- from typing import Callable
17
17
 
18
18
  import torch
19
19
  from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
@@ -5,8 +5,8 @@
5
5
  from __future__ import annotations
6
6
 
7
7
  import os
8
+ from collections.abc import Callable
8
9
  from pathlib import Path
9
- from typing import Callable
10
10
 
11
11
  import numpy as np
12
12
  from tensordict import TensorDict
@@ -10,8 +10,9 @@ import json
10
10
  import os
11
11
  import shutil
12
12
  import tempfile
13
+ from collections.abc import Callable
13
14
  from pathlib import Path
14
- from typing import Any, Callable
15
+ from typing import Any
15
16
 
16
17
  import torch
17
18
  from tensordict import make_tensordict, NonTensorData, pad, TensorDict
@@ -8,9 +8,9 @@ import importlib.util
8
8
  import os.path
9
9
  import shutil
10
10
  import tempfile
11
+ from collections.abc import Callable
11
12
  from contextlib import nullcontext
12
13
  from pathlib import Path
13
- from typing import Callable
14
14
 
15
15
  import torch
16
16
  from tensordict import PersistentTensorDict, TensorDict
@@ -12,8 +12,8 @@ import pathlib
12
12
  import shutil
13
13
  import tempfile
14
14
  from collections import defaultdict
15
+ from collections.abc import Callable
15
16
  from pathlib import Path
16
- from typing import Callable
17
17
 
18
18
  import numpy as np
19
19
  import torch
@@ -6,8 +6,8 @@ from __future__ import annotations
6
6
 
7
7
  import importlib.util
8
8
  import os
9
+ from collections.abc import Sequence
9
10
  from pathlib import Path
10
- from typing import Sequence
11
11
 
12
12
  import torch
13
13
  from tensordict import TensorDict, TensorDictBase
torchrl/data/map/hash.py CHANGED
@@ -4,7 +4,7 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
- from typing import Callable
7
+ from collections.abc import Callable
8
8
 
9
9
  import torch
10
10
  from torch.nn import Module
torchrl/data/map/query.py CHANGED
@@ -4,8 +4,10 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
+ from collections.abc import Callable, Mapping
8
+
7
9
  from copy import deepcopy
8
- from typing import Any, Callable, List, Mapping, TypeVar
10
+ from typing import Any, TypeVar
9
11
 
10
12
  import torch
11
13
  import torch.nn as nn
@@ -122,7 +124,7 @@ class QueryModule(TensorDictModuleBase):
122
124
  ):
123
125
  if len(in_keys) == 0:
124
126
  raise ValueError("`in_keys` cannot be empty.")
125
- in_keys = in_keys if isinstance(in_keys, List) else [in_keys]
127
+ in_keys = in_keys if isinstance(in_keys, list) else [in_keys]
126
128
 
127
129
  super().__init__()
128
130
  in_keys = self.in_keys = in_keys
@@ -7,7 +7,8 @@ from __future__ import annotations
7
7
  import abc
8
8
  import functools
9
9
  from abc import abstractmethod
10
- from typing import Any, Callable, Generic, TypeVar
10
+ from collections.abc import Callable
11
+ from typing import Any, Generic, TypeVar
11
12
 
12
13
  import torch
13
14
  from tensordict import is_tensor_collection, NestedKey, TensorDictBase
torchrl/data/map/tree.py CHANGED
@@ -6,7 +6,8 @@ from __future__ import annotations
6
6
 
7
7
  import weakref
8
8
  from collections import deque
9
- from typing import Any, Callable, Literal
9
+ from collections.abc import Callable
10
+ from typing import Any, Literal
10
11
 
11
12
  import torch
12
13
  from tensordict import (
torchrl/data/map/utils.py CHANGED
@@ -4,7 +4,7 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
- from typing import Callable
7
+ from collections.abc import Callable
8
8
 
9
9
  from tensordict import NestedKey
10
10
 
@@ -6,8 +6,9 @@ from __future__ import annotations
6
6
 
7
7
  import contextlib
8
8
  import importlib
9
+ from collections.abc import Callable, Iterator
9
10
 
10
- from typing import Any, Callable, Iterator
11
+ from typing import Any
11
12
 
12
13
  import torch
13
14
  from torchrl._utils import logger as torchrl_logger
@@ -11,9 +11,10 @@ import multiprocessing
11
11
  import textwrap
12
12
  import threading
13
13
  import warnings
14
+ from collections.abc import Callable, Sequence
14
15
  from concurrent.futures import ThreadPoolExecutor
15
16
  from pathlib import Path
16
- from typing import Any, Callable, Sequence
17
+ from typing import Any
17
18
 
18
19
  import numpy as np
19
20
  import torch
@@ -5,7 +5,8 @@
5
5
  from __future__ import annotations
6
6
 
7
7
  from abc import ABC, abstractmethod
8
- from typing import Any, Callable
8
+ from collections.abc import Callable
9
+ from typing import Any
9
10
 
10
11
  import numpy as np
11
12
  import torch
@@ -11,9 +11,10 @@ import sys
11
11
  import textwrap
12
12
  import warnings
13
13
  from collections import OrderedDict
14
+ from collections.abc import Callable, Mapping, Sequence
14
15
  from copy import copy
15
16
  from multiprocessing.context import get_spawning_popen
16
- from typing import Any, Callable, Mapping, Sequence
17
+ from typing import Any
17
18
 
18
19
  import numpy as np
19
20
  import tensordict
@@ -11,8 +11,9 @@ import math
11
11
  import operator
12
12
  import os
13
13
  import typing
14
+ from collections.abc import Callable
14
15
  from pathlib import Path
15
- from typing import Any, Callable, Union
16
+ from typing import Any, Union
16
17
 
17
18
  import numpy as np
18
19
  import torch
@@ -8,10 +8,11 @@ import heapq
8
8
  import json
9
9
  import textwrap
10
10
  from abc import ABC, abstractmethod
11
+ from collections.abc import Sequence
11
12
  from copy import copy
12
13
  from multiprocessing.context import get_spawning_popen
13
14
  from pathlib import Path
14
- from typing import Any, Sequence
15
+ from typing import Any
15
16
 
16
17
  import numpy as np
17
18
  import torch
@@ -12,23 +12,12 @@ import gc
12
12
  import math
13
13
  import warnings
14
14
  import weakref
15
- from collections.abc import Iterable
15
+ from collections.abc import Callable, Iterable, Sequence
16
16
  from copy import deepcopy
17
17
  from dataclasses import dataclass, field
18
18
  from functools import wraps
19
19
  from textwrap import indent
20
- from typing import (
21
- Any,
22
- Callable,
23
- Dict,
24
- Generic,
25
- List,
26
- overload,
27
- Sequence,
28
- Tuple,
29
- TypeVar,
30
- Union,
31
- )
20
+ from typing import Any, Generic, overload, TypeVar, Union
32
21
 
33
22
  import numpy as np
34
23
 
@@ -61,27 +50,27 @@ except ImportError:
61
50
 
62
51
  DEVICE_TYPING = Union[torch.device, str, int]
63
52
 
64
- INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List]
53
+ INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, list]
65
54
 
66
55
  SHAPE_INDEX_TYPING = Union[
67
56
  int,
68
57
  range,
69
- List[int],
58
+ list[int],
70
59
  np.ndarray,
71
60
  slice,
72
61
  None,
73
62
  torch.Tensor,
74
63
  type(...),
75
- Tuple[
64
+ tuple[
76
65
  int,
77
66
  range,
78
- List[int],
67
+ list[int],
79
68
  np.ndarray,
80
69
  slice,
81
70
  None,
82
71
  torch.Tensor,
83
72
  type(...),
84
- Tuple[Any],
73
+ tuple[Any],
85
74
  ],
86
75
  ]
87
76
 
@@ -6273,7 +6262,7 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
6273
6262
  def update(self, dict) -> None:
6274
6263
  for key, item in dict.items():
6275
6264
  if key in self.keys() and isinstance(
6276
- item, (Dict, Composite, StackedComposite)
6265
+ item, (dict, Composite, StackedComposite)
6277
6266
  ):
6278
6267
  for spec, sub_item in zip(self._specs, item.unbind(self.dim)):
6279
6268
  spec[key].update(sub_item)
torchrl/data/utils.py CHANGED
@@ -6,7 +6,8 @@ from __future__ import annotations
6
6
 
7
7
  import functools
8
8
  import typing
9
- from typing import Any, Callable, List, Tuple, Union
9
+ from collections.abc import Callable
10
+ from typing import Any, Union
10
11
 
11
12
  import numpy as np
12
13
  import torch
@@ -45,7 +46,7 @@ if hasattr(typing, "get_args"):
45
46
  else:
46
47
  DEVICE_TYPING_ARGS = (torch.device, str, int)
47
48
 
48
- INDEX_TYPING = Union[None, int, slice, str, Tensor, List[Any], Tuple[Any, ...]]
49
+ INDEX_TYPING = Union[None, int, slice, str, Tensor, list[Any], tuple[Any, ...]]
49
50
 
50
51
 
51
52
  ACTION_SPACE_MAP = {
@@ -7,12 +7,13 @@ from __future__ import annotations
7
7
  import abc
8
8
 
9
9
  import multiprocessing
10
+ from collections.abc import Callable, Sequence
10
11
  from concurrent.futures import as_completed, ThreadPoolExecutor
11
12
 
12
13
  # import queue
13
14
  from multiprocessing import Queue
14
15
  from queue import Empty
15
- from typing import Callable, Literal, Sequence
16
+ from typing import Literal
16
17
 
17
18
  import torch
18
19
  from tensordict import (
@@ -11,11 +11,12 @@ import os
11
11
  import time
12
12
  import weakref
13
13
  from collections import OrderedDict
14
+ from collections.abc import Callable, Sequence
14
15
  from copy import deepcopy
15
16
  from functools import wraps
16
17
  from multiprocessing import connection
17
18
  from multiprocessing.synchronize import Lock as MpLock
18
- from typing import Any, Callable, Sequence
19
+ from typing import Any
19
20
  from warnings import warn
20
21
 
21
22
  import torch
torchrl/envs/common.py CHANGED
@@ -9,9 +9,10 @@ import abc
9
9
  import re
10
10
  import warnings
11
11
  import weakref
12
+ from collections.abc import Callable, Iterator
12
13
  from copy import deepcopy
13
14
  from functools import partial, wraps
14
- from typing import Any, Callable, Iterator
15
+ from typing import Any
15
16
 
16
17
  import numpy as np
17
18
  import torch
@@ -4,7 +4,7 @@
4
4
  # LICENSE file in the root directory of this source tree.
5
5
  from __future__ import annotations
6
6
 
7
- from typing import Callable
7
+ from collections.abc import Callable
8
8
 
9
9
  import torch
10
10
 
@@ -6,8 +6,8 @@
6
6
  from __future__ import annotations
7
7
 
8
8
  from collections import OrderedDict
9
+ from collections.abc import Callable
9
10
  from multiprocessing.sharedctypes import Synchronized
10
- from typing import Callable
11
11
 
12
12
  import torch
13
13
  from tensordict import TensorDictBase
torchrl/envs/gym_like.py CHANGED
@@ -9,7 +9,8 @@ import abc
9
9
  import functools
10
10
  import re
11
11
  import warnings
12
- from typing import Any, Callable, Mapping, Sequence, TypeVar
12
+ from collections.abc import Callable, Mapping, Sequence
13
+ from typing import Any, TypeVar
13
14
 
14
15
  import numpy as np
15
16
  import torch
@@ -7,7 +7,7 @@ from __future__ import annotations
7
7
  import collections
8
8
  import importlib
9
9
  import os
10
- from typing import Any, Dict
10
+ from typing import Any
11
11
 
12
12
  import numpy as np
13
13
  import torch
@@ -44,7 +44,7 @@ def _dmcontrol_to_torchrl_spec_transform(
44
44
  ) -> TensorSpec:
45
45
  import dm_env
46
46
 
47
- if isinstance(spec, collections.OrderedDict) or isinstance(spec, Dict):
47
+ if isinstance(spec, collections.OrderedDict) or isinstance(spec, dict):
48
48
  spec = {
49
49
  k: _dmcontrol_to_torchrl_spec_transform(
50
50
  item,
torchrl/envs/libs/gym.py CHANGED
@@ -10,7 +10,6 @@ import importlib
10
10
  import warnings
11
11
  from copy import copy
12
12
  from types import ModuleType
13
- from typing import Dict
14
13
  from warnings import warn
15
14
 
16
15
  import numpy as np
@@ -510,7 +509,7 @@ def convert_sequence_spec(
510
509
  return out
511
510
 
512
511
 
513
- @register_gym_spec_conversion(Dict)
512
+ @register_gym_spec_conversion(dict)
514
513
  def convert_dict_spec(
515
514
  spec,
516
515
  dtype=None,
@@ -765,7 +764,7 @@ def _is_from_pixels(env):
765
764
  gDict = gym_backend("spaces").dict.Dict
766
765
  Box = gym_backend("spaces").Box
767
766
 
768
- if isinstance(observation_spec, (Dict,)):
767
+ if isinstance(observation_spec, (dict,)):
769
768
  if "pixels" in set(observation_spec.keys()):
770
769
  return True
771
770
  if isinstance(observation_spec, (gDict,)):
@@ -5,7 +5,7 @@
5
5
  from __future__ import annotations
6
6
 
7
7
  import importlib
8
- from typing import Mapping, Sequence
8
+ from collections.abc import Mapping, Sequence
9
9
 
10
10
  import torch
11
11
  from tensordict import TensorDict, TensorDictBase
@@ -7,7 +7,6 @@ from __future__ import annotations
7
7
  import copy
8
8
  import importlib
9
9
  import warnings
10
- from typing import Dict
11
10
 
12
11
  import numpy as np
13
12
  import packaging
@@ -807,7 +806,7 @@ class PettingZooWrapper(_EnvWrapper):
807
806
  for index, agent in enumerate(agents):
808
807
  agent_obs = observation_dict[agent]
809
808
  agent_info = info_dict[agent]
810
- if isinstance(agent_obs, Dict) and "action_mask" in agent_obs:
809
+ if isinstance(agent_obs, dict) and "action_mask" in agent_obs:
811
810
  if agent in agents_acting:
812
811
  group_mask[index] = torch.tensor(
813
812
  agent_obs["action_mask"],
@@ -815,7 +814,7 @@ class PettingZooWrapper(_EnvWrapper):
815
814
  dtype=torch.bool,
816
815
  )
817
816
  del agent_obs["action_mask"]
818
- elif isinstance(agent_info, Dict) and "action_mask" in agent_info:
817
+ elif isinstance(agent_info, dict) and "action_mask" in agent_info:
819
818
  if agent in agents_acting:
820
819
  group_mask[index] = torch.tensor(
821
820
  agent_info["action_mask"],