torchrl-nightly 2025.8.8__cp312-cp312-manylinux1_x86_64.whl → 2025.8.10__cp312-cp312-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/ecosystem/vmas_rllib_vs_torchrl_sampling_performance.py +1 -2
- sota-implementations/grpo/grpo_utils.py +2 -1
- sota-implementations/redq/utils.py +2 -1
- torchrl/_torchrl.cpython-312-x86_64-linux-gnu.so +0 -0
- torchrl/_utils.py +2 -1
- torchrl/collectors/collectors.py +2 -1
- torchrl/collectors/distributed/generic.py +3 -1
- torchrl/collectors/distributed/ray.py +3 -1
- torchrl/collectors/distributed/rpc.py +3 -1
- torchrl/collectors/distributed/sync.py +3 -1
- torchrl/collectors/llm/base.py +2 -1
- torchrl/collectors/llm/ray_collector.py +2 -1
- torchrl/collectors/utils.py +1 -1
- torchrl/collectors/weight_update.py +2 -1
- torchrl/data/datasets/atari_dqn.py +1 -1
- torchrl/data/datasets/common.py +1 -1
- torchrl/data/datasets/d4rl.py +1 -1
- torchrl/data/datasets/minari_data.py +1 -1
- torchrl/data/datasets/openml.py +1 -1
- torchrl/data/datasets/openx.py +2 -1
- torchrl/data/datasets/roboset.py +1 -1
- torchrl/data/datasets/vd4rl.py +1 -1
- torchrl/data/llm/dataset.py +1 -1
- torchrl/data/map/hash.py +1 -1
- torchrl/data/map/query.py +4 -2
- torchrl/data/map/tdstorage.py +2 -1
- torchrl/data/map/tree.py +2 -1
- torchrl/data/map/utils.py +1 -1
- torchrl/data/replay_buffers/ray_buffer.py +2 -1
- torchrl/data/replay_buffers/replay_buffers.py +2 -1
- torchrl/data/replay_buffers/scheduler.py +2 -1
- torchrl/data/replay_buffers/storages.py +2 -1
- torchrl/data/replay_buffers/utils.py +2 -1
- torchrl/data/replay_buffers/writers.py +2 -1
- torchrl/data/tensor_specs.py +8 -19
- torchrl/data/utils.py +3 -2
- torchrl/envs/async_envs.py +2 -1
- torchrl/envs/batched_envs.py +2 -1
- torchrl/envs/common.py +2 -1
- torchrl/envs/custom/llm.py +1 -1
- torchrl/envs/env_creator.py +1 -1
- torchrl/envs/gym_like.py +2 -1
- torchrl/envs/libs/dm_control.py +2 -2
- torchrl/envs/libs/gym.py +2 -3
- torchrl/envs/libs/meltingpot.py +1 -1
- torchrl/envs/libs/pettingzoo.py +2 -3
- torchrl/envs/libs/smacv2.py +8 -10
- torchrl/envs/llm/chat.py +3 -1
- torchrl/envs/llm/datasets/gsm8k.py +2 -1
- torchrl/envs/llm/datasets/ifeval.py +3 -1
- torchrl/envs/llm/envs.py +2 -1
- torchrl/envs/llm/reward/ifeval/_instructions.py +3 -2
- torchrl/envs/llm/reward/ifeval/_instructions_util.py +1 -1
- torchrl/envs/llm/reward/ifeval/_scorer.py +1 -1
- torchrl/envs/llm/transforms/dataloading.py +2 -2
- torchrl/envs/llm/transforms/reason.py +2 -1
- torchrl/envs/llm/transforms/tokenizer.py +1 -1
- torchrl/envs/transforms/transforms.py +3 -10
- torchrl/envs/transforms/vecnorm.py +3 -1
- torchrl/modules/distributions/continuous.py +1 -1
- torchrl/modules/distributions/discrete.py +2 -1
- torchrl/modules/models/exploration.py +1 -1
- torchrl/modules/models/models.py +1 -1
- torchrl/modules/models/multiagent.py +1 -1
- torchrl/modules/models/utils.py +1 -1
- torchrl/modules/tensordict_module/actors.py +1 -1
- torchrl/modules/tensordict_module/common.py +1 -1
- torchrl/objectives/common.py +1 -1
- torchrl/objectives/ppo.py +1 -1
- torchrl/objectives/utils.py +2 -1
- torchrl/objectives/value/advantages.py +1 -1
- torchrl/record/loggers/common.py +1 -1
- torchrl/record/loggers/csv.py +1 -1
- torchrl/record/loggers/mlflow.py +2 -1
- torchrl/record/loggers/tensorboard.py +1 -1
- torchrl/record/loggers/wandb.py +1 -1
- torchrl/record/recorder.py +1 -1
- torchrl/trainers/helpers/collectors.py +3 -1
- torchrl/trainers/helpers/envs.py +14 -13
- torchrl/trainers/trainers.py +5 -4
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/METADATA +1 -1
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/RECORD +85 -85
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/WHEEL +0 -0
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/licenses/LICENSE +0 -0
- {torchrl_nightly-2025.8.8.dist-info → torchrl_nightly-2025.8.10.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,6 @@ import pickle
|
|
9
9
|
|
10
10
|
import time
|
11
11
|
from pathlib import Path
|
12
|
-
from typing import Dict
|
13
12
|
|
14
13
|
import numpy as np
|
15
14
|
|
@@ -93,7 +92,7 @@ def run_vmas_rllib(
|
|
93
92
|
- result["timers"]["learn_time_ms"]
|
94
93
|
)
|
95
94
|
|
96
|
-
def env_creator(config:
|
95
|
+
def env_creator(config: dict):
|
97
96
|
env = vmas.make_env(
|
98
97
|
scenario=config["scenario_name"],
|
99
98
|
num_envs=config["num_envs"],
|
@@ -4,8 +4,9 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
+
from collections.abc import Callable, Sequence
|
8
|
+
|
7
9
|
from copy import copy
|
8
|
-
from typing import Callable, Sequence
|
9
10
|
|
10
11
|
import torch
|
11
12
|
from omegaconf import OmegaConf
|
Binary file
|
torchrl/_utils.py
CHANGED
@@ -16,12 +16,13 @@ import threading
|
|
16
16
|
import time
|
17
17
|
import traceback
|
18
18
|
import warnings
|
19
|
+
from collections.abc import Callable
|
19
20
|
from contextlib import nullcontext
|
20
21
|
from copy import copy
|
21
22
|
from functools import wraps
|
22
23
|
from importlib import import_module
|
23
24
|
from textwrap import indent
|
24
|
-
from typing import Any,
|
25
|
+
from typing import Any, cast, TypeVar
|
25
26
|
|
26
27
|
import numpy as np
|
27
28
|
import torch
|
torchrl/collectors/collectors.py
CHANGED
@@ -17,12 +17,13 @@ import time
|
|
17
17
|
import typing
|
18
18
|
import warnings
|
19
19
|
from collections import defaultdict, OrderedDict
|
20
|
+
from collections.abc import Callable, Iterator, Mapping, Sequence
|
20
21
|
from copy import deepcopy
|
21
22
|
from multiprocessing import connection, queues
|
22
23
|
from multiprocessing.managers import SyncManager
|
23
24
|
from queue import Empty
|
24
25
|
from textwrap import indent
|
25
|
-
from typing import Any,
|
26
|
+
from typing import Any, TypeVar
|
26
27
|
|
27
28
|
import numpy as np
|
28
29
|
import torch
|
@@ -9,9 +9,11 @@ from __future__ import annotations
|
|
9
9
|
import os
|
10
10
|
import socket
|
11
11
|
import warnings
|
12
|
+
from collections import OrderedDict
|
13
|
+
from collections.abc import Callable, Sequence
|
12
14
|
from copy import copy, deepcopy
|
13
15
|
from datetime import timedelta
|
14
|
-
from typing import Any
|
16
|
+
from typing import Any
|
15
17
|
|
16
18
|
import torch.cuda
|
17
19
|
from tensordict import TensorDict, TensorDictBase
|
@@ -7,7 +7,9 @@ from __future__ import annotations
|
|
7
7
|
|
8
8
|
import asyncio
|
9
9
|
import warnings
|
10
|
-
from
|
10
|
+
from collections import OrderedDict
|
11
|
+
from collections.abc import Callable, Iterator, Sequence
|
12
|
+
from typing import Any
|
11
13
|
|
12
14
|
import torch
|
13
15
|
import torch.nn as nn
|
@@ -11,8 +11,10 @@ import os
|
|
11
11
|
import socket
|
12
12
|
import time
|
13
13
|
import warnings
|
14
|
+
from collections import OrderedDict
|
15
|
+
from collections.abc import Callable, Sequence
|
14
16
|
from copy import copy, deepcopy
|
15
|
-
from typing import Any
|
17
|
+
from typing import Any
|
16
18
|
|
17
19
|
import torch.cuda
|
18
20
|
|
@@ -9,9 +9,11 @@ from __future__ import annotations
|
|
9
9
|
import os
|
10
10
|
import socket
|
11
11
|
import warnings
|
12
|
+
from collections import OrderedDict
|
13
|
+
from collections.abc import Callable, Sequence
|
12
14
|
from copy import copy, deepcopy
|
13
15
|
from datetime import timedelta
|
14
|
-
from typing import Any,
|
16
|
+
from typing import Any, Literal
|
15
17
|
|
16
18
|
import torch.cuda
|
17
19
|
from tensordict import TensorDict, TensorDictBase
|
torchrl/collectors/llm/base.py
CHANGED
torchrl/collectors/utils.py
CHANGED
torchrl/data/datasets/common.py
CHANGED
torchrl/data/datasets/d4rl.py
CHANGED
@@ -10,10 +10,10 @@ import os.path
|
|
10
10
|
import shutil
|
11
11
|
import tempfile
|
12
12
|
from collections import defaultdict
|
13
|
+
from collections.abc import Callable
|
13
14
|
from contextlib import nullcontext
|
14
15
|
from dataclasses import asdict
|
15
16
|
from pathlib import Path
|
16
|
-
from typing import Callable
|
17
17
|
|
18
18
|
import torch
|
19
19
|
from tensordict import is_non_tensor, PersistentTensorDict, TensorDict
|
torchrl/data/datasets/openml.py
CHANGED
torchrl/data/datasets/openx.py
CHANGED
@@ -10,8 +10,9 @@ import json
|
|
10
10
|
import os
|
11
11
|
import shutil
|
12
12
|
import tempfile
|
13
|
+
from collections.abc import Callable
|
13
14
|
from pathlib import Path
|
14
|
-
from typing import Any
|
15
|
+
from typing import Any
|
15
16
|
|
16
17
|
import torch
|
17
18
|
from tensordict import make_tensordict, NonTensorData, pad, TensorDict
|
torchrl/data/datasets/roboset.py
CHANGED
@@ -8,9 +8,9 @@ import importlib.util
|
|
8
8
|
import os.path
|
9
9
|
import shutil
|
10
10
|
import tempfile
|
11
|
+
from collections.abc import Callable
|
11
12
|
from contextlib import nullcontext
|
12
13
|
from pathlib import Path
|
13
|
-
from typing import Callable
|
14
14
|
|
15
15
|
import torch
|
16
16
|
from tensordict import PersistentTensorDict, TensorDict
|
torchrl/data/datasets/vd4rl.py
CHANGED
torchrl/data/llm/dataset.py
CHANGED
torchrl/data/map/hash.py
CHANGED
torchrl/data/map/query.py
CHANGED
@@ -4,8 +4,10 @@
|
|
4
4
|
# LICENSE file in the root directory of this source tree.
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
+
from collections.abc import Callable, Mapping
|
8
|
+
|
7
9
|
from copy import deepcopy
|
8
|
-
from typing import Any,
|
10
|
+
from typing import Any, TypeVar
|
9
11
|
|
10
12
|
import torch
|
11
13
|
import torch.nn as nn
|
@@ -122,7 +124,7 @@ class QueryModule(TensorDictModuleBase):
|
|
122
124
|
):
|
123
125
|
if len(in_keys) == 0:
|
124
126
|
raise ValueError("`in_keys` cannot be empty.")
|
125
|
-
in_keys = in_keys if isinstance(in_keys,
|
127
|
+
in_keys = in_keys if isinstance(in_keys, list) else [in_keys]
|
126
128
|
|
127
129
|
super().__init__()
|
128
130
|
in_keys = self.in_keys = in_keys
|
torchrl/data/map/tdstorage.py
CHANGED
@@ -7,7 +7,8 @@ from __future__ import annotations
|
|
7
7
|
import abc
|
8
8
|
import functools
|
9
9
|
from abc import abstractmethod
|
10
|
-
from
|
10
|
+
from collections.abc import Callable
|
11
|
+
from typing import Any, Generic, TypeVar
|
11
12
|
|
12
13
|
import torch
|
13
14
|
from tensordict import is_tensor_collection, NestedKey, TensorDictBase
|
torchrl/data/map/tree.py
CHANGED
torchrl/data/map/utils.py
CHANGED
@@ -6,8 +6,9 @@ from __future__ import annotations
|
|
6
6
|
|
7
7
|
import contextlib
|
8
8
|
import importlib
|
9
|
+
from collections.abc import Callable, Iterator
|
9
10
|
|
10
|
-
from typing import Any
|
11
|
+
from typing import Any
|
11
12
|
|
12
13
|
import torch
|
13
14
|
from torchrl._utils import logger as torchrl_logger
|
@@ -11,9 +11,10 @@ import multiprocessing
|
|
11
11
|
import textwrap
|
12
12
|
import threading
|
13
13
|
import warnings
|
14
|
+
from collections.abc import Callable, Sequence
|
14
15
|
from concurrent.futures import ThreadPoolExecutor
|
15
16
|
from pathlib import Path
|
16
|
-
from typing import Any
|
17
|
+
from typing import Any
|
17
18
|
|
18
19
|
import numpy as np
|
19
20
|
import torch
|
@@ -11,9 +11,10 @@ import sys
|
|
11
11
|
import textwrap
|
12
12
|
import warnings
|
13
13
|
from collections import OrderedDict
|
14
|
+
from collections.abc import Callable, Mapping, Sequence
|
14
15
|
from copy import copy
|
15
16
|
from multiprocessing.context import get_spawning_popen
|
16
|
-
from typing import Any
|
17
|
+
from typing import Any
|
17
18
|
|
18
19
|
import numpy as np
|
19
20
|
import tensordict
|
@@ -8,10 +8,11 @@ import heapq
|
|
8
8
|
import json
|
9
9
|
import textwrap
|
10
10
|
from abc import ABC, abstractmethod
|
11
|
+
from collections.abc import Sequence
|
11
12
|
from copy import copy
|
12
13
|
from multiprocessing.context import get_spawning_popen
|
13
14
|
from pathlib import Path
|
14
|
-
from typing import Any
|
15
|
+
from typing import Any
|
15
16
|
|
16
17
|
import numpy as np
|
17
18
|
import torch
|
torchrl/data/tensor_specs.py
CHANGED
@@ -12,23 +12,12 @@ import gc
|
|
12
12
|
import math
|
13
13
|
import warnings
|
14
14
|
import weakref
|
15
|
-
from collections.abc import Iterable
|
15
|
+
from collections.abc import Callable, Iterable, Sequence
|
16
16
|
from copy import deepcopy
|
17
17
|
from dataclasses import dataclass, field
|
18
18
|
from functools import wraps
|
19
19
|
from textwrap import indent
|
20
|
-
from typing import
|
21
|
-
Any,
|
22
|
-
Callable,
|
23
|
-
Dict,
|
24
|
-
Generic,
|
25
|
-
List,
|
26
|
-
overload,
|
27
|
-
Sequence,
|
28
|
-
Tuple,
|
29
|
-
TypeVar,
|
30
|
-
Union,
|
31
|
-
)
|
20
|
+
from typing import Any, Generic, overload, TypeVar, Union
|
32
21
|
|
33
22
|
import numpy as np
|
34
23
|
|
@@ -61,27 +50,27 @@ except ImportError:
|
|
61
50
|
|
62
51
|
DEVICE_TYPING = Union[torch.device, str, int]
|
63
52
|
|
64
|
-
INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice,
|
53
|
+
INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, list]
|
65
54
|
|
66
55
|
SHAPE_INDEX_TYPING = Union[
|
67
56
|
int,
|
68
57
|
range,
|
69
|
-
|
58
|
+
list[int],
|
70
59
|
np.ndarray,
|
71
60
|
slice,
|
72
61
|
None,
|
73
62
|
torch.Tensor,
|
74
63
|
type(...),
|
75
|
-
|
64
|
+
tuple[
|
76
65
|
int,
|
77
66
|
range,
|
78
|
-
|
67
|
+
list[int],
|
79
68
|
np.ndarray,
|
80
69
|
slice,
|
81
70
|
None,
|
82
71
|
torch.Tensor,
|
83
72
|
type(...),
|
84
|
-
|
73
|
+
tuple[Any],
|
85
74
|
],
|
86
75
|
]
|
87
76
|
|
@@ -6273,7 +6262,7 @@ class StackedComposite(_LazyStackedMixin[Composite], Composite):
|
|
6273
6262
|
def update(self, dict) -> None:
|
6274
6263
|
for key, item in dict.items():
|
6275
6264
|
if key in self.keys() and isinstance(
|
6276
|
-
item, (
|
6265
|
+
item, (dict, Composite, StackedComposite)
|
6277
6266
|
):
|
6278
6267
|
for spec, sub_item in zip(self._specs, item.unbind(self.dim)):
|
6279
6268
|
spec[key].update(sub_item)
|
torchrl/data/utils.py
CHANGED
@@ -6,7 +6,8 @@ from __future__ import annotations
|
|
6
6
|
|
7
7
|
import functools
|
8
8
|
import typing
|
9
|
-
from
|
9
|
+
from collections.abc import Callable
|
10
|
+
from typing import Any, Union
|
10
11
|
|
11
12
|
import numpy as np
|
12
13
|
import torch
|
@@ -45,7 +46,7 @@ if hasattr(typing, "get_args"):
|
|
45
46
|
else:
|
46
47
|
DEVICE_TYPING_ARGS = (torch.device, str, int)
|
47
48
|
|
48
|
-
INDEX_TYPING = Union[None, int, slice, str, Tensor,
|
49
|
+
INDEX_TYPING = Union[None, int, slice, str, Tensor, list[Any], tuple[Any, ...]]
|
49
50
|
|
50
51
|
|
51
52
|
ACTION_SPACE_MAP = {
|
torchrl/envs/async_envs.py
CHANGED
@@ -7,12 +7,13 @@ from __future__ import annotations
|
|
7
7
|
import abc
|
8
8
|
|
9
9
|
import multiprocessing
|
10
|
+
from collections.abc import Callable, Sequence
|
10
11
|
from concurrent.futures import as_completed, ThreadPoolExecutor
|
11
12
|
|
12
13
|
# import queue
|
13
14
|
from multiprocessing import Queue
|
14
15
|
from queue import Empty
|
15
|
-
from typing import
|
16
|
+
from typing import Literal
|
16
17
|
|
17
18
|
import torch
|
18
19
|
from tensordict import (
|
torchrl/envs/batched_envs.py
CHANGED
@@ -11,11 +11,12 @@ import os
|
|
11
11
|
import time
|
12
12
|
import weakref
|
13
13
|
from collections import OrderedDict
|
14
|
+
from collections.abc import Callable, Sequence
|
14
15
|
from copy import deepcopy
|
15
16
|
from functools import wraps
|
16
17
|
from multiprocessing import connection
|
17
18
|
from multiprocessing.synchronize import Lock as MpLock
|
18
|
-
from typing import Any
|
19
|
+
from typing import Any
|
19
20
|
from warnings import warn
|
20
21
|
|
21
22
|
import torch
|
torchrl/envs/common.py
CHANGED
@@ -9,9 +9,10 @@ import abc
|
|
9
9
|
import re
|
10
10
|
import warnings
|
11
11
|
import weakref
|
12
|
+
from collections.abc import Callable, Iterator
|
12
13
|
from copy import deepcopy
|
13
14
|
from functools import partial, wraps
|
14
|
-
from typing import Any
|
15
|
+
from typing import Any
|
15
16
|
|
16
17
|
import numpy as np
|
17
18
|
import torch
|
torchrl/envs/custom/llm.py
CHANGED
torchrl/envs/env_creator.py
CHANGED
@@ -6,8 +6,8 @@
|
|
6
6
|
from __future__ import annotations
|
7
7
|
|
8
8
|
from collections import OrderedDict
|
9
|
+
from collections.abc import Callable
|
9
10
|
from multiprocessing.sharedctypes import Synchronized
|
10
|
-
from typing import Callable
|
11
11
|
|
12
12
|
import torch
|
13
13
|
from tensordict import TensorDictBase
|
torchrl/envs/gym_like.py
CHANGED
torchrl/envs/libs/dm_control.py
CHANGED
@@ -7,7 +7,7 @@ from __future__ import annotations
|
|
7
7
|
import collections
|
8
8
|
import importlib
|
9
9
|
import os
|
10
|
-
from typing import Any
|
10
|
+
from typing import Any
|
11
11
|
|
12
12
|
import numpy as np
|
13
13
|
import torch
|
@@ -44,7 +44,7 @@ def _dmcontrol_to_torchrl_spec_transform(
|
|
44
44
|
) -> TensorSpec:
|
45
45
|
import dm_env
|
46
46
|
|
47
|
-
if isinstance(spec, collections.OrderedDict) or isinstance(spec,
|
47
|
+
if isinstance(spec, collections.OrderedDict) or isinstance(spec, dict):
|
48
48
|
spec = {
|
49
49
|
k: _dmcontrol_to_torchrl_spec_transform(
|
50
50
|
item,
|
torchrl/envs/libs/gym.py
CHANGED
@@ -10,7 +10,6 @@ import importlib
|
|
10
10
|
import warnings
|
11
11
|
from copy import copy
|
12
12
|
from types import ModuleType
|
13
|
-
from typing import Dict
|
14
13
|
from warnings import warn
|
15
14
|
|
16
15
|
import numpy as np
|
@@ -510,7 +509,7 @@ def convert_sequence_spec(
|
|
510
509
|
return out
|
511
510
|
|
512
511
|
|
513
|
-
@register_gym_spec_conversion(
|
512
|
+
@register_gym_spec_conversion(dict)
|
514
513
|
def convert_dict_spec(
|
515
514
|
spec,
|
516
515
|
dtype=None,
|
@@ -765,7 +764,7 @@ def _is_from_pixels(env):
|
|
765
764
|
gDict = gym_backend("spaces").dict.Dict
|
766
765
|
Box = gym_backend("spaces").Box
|
767
766
|
|
768
|
-
if isinstance(observation_spec, (
|
767
|
+
if isinstance(observation_spec, (dict,)):
|
769
768
|
if "pixels" in set(observation_spec.keys()):
|
770
769
|
return True
|
771
770
|
if isinstance(observation_spec, (gDict,)):
|
torchrl/envs/libs/meltingpot.py
CHANGED
torchrl/envs/libs/pettingzoo.py
CHANGED
@@ -7,7 +7,6 @@ from __future__ import annotations
|
|
7
7
|
import copy
|
8
8
|
import importlib
|
9
9
|
import warnings
|
10
|
-
from typing import Dict
|
11
10
|
|
12
11
|
import numpy as np
|
13
12
|
import packaging
|
@@ -807,7 +806,7 @@ class PettingZooWrapper(_EnvWrapper):
|
|
807
806
|
for index, agent in enumerate(agents):
|
808
807
|
agent_obs = observation_dict[agent]
|
809
808
|
agent_info = info_dict[agent]
|
810
|
-
if isinstance(agent_obs,
|
809
|
+
if isinstance(agent_obs, dict) and "action_mask" in agent_obs:
|
811
810
|
if agent in agents_acting:
|
812
811
|
group_mask[index] = torch.tensor(
|
813
812
|
agent_obs["action_mask"],
|
@@ -815,7 +814,7 @@ class PettingZooWrapper(_EnvWrapper):
|
|
815
814
|
dtype=torch.bool,
|
816
815
|
)
|
817
816
|
del agent_obs["action_mask"]
|
818
|
-
elif isinstance(agent_info,
|
817
|
+
elif isinstance(agent_info, dict) and "action_mask" in agent_info:
|
819
818
|
if agent in agents_acting:
|
820
819
|
group_mask[index] = torch.tensor(
|
821
820
|
agent_info["action_mask"],
|