torchrl 0.11.0__cp313-cp313-win_amd64.whl → 0.11.1__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sota-implementations/dreamer/dreamer.py +21 -5
- sota-implementations/dreamer/dreamer_utils.py +160 -22
- torchrl/_torchrl.cp313-win_amd64.pyd +0 -0
- torchrl/collectors/_base.py +10 -0
- torchrl/collectors/_multi_base.py +19 -5
- torchrl/collectors/utils.py +2 -1
- torchrl/data/replay_buffers/writers.py +2 -1
- torchrl/data/tensor_specs.py +2 -2
- torchrl/envs/batched_envs.py +16 -0
- torchrl/envs/transforms/transforms.py +91 -0
- torchrl/objectives/cql.py +31 -3
- torchrl/objectives/crossq.py +31 -3
- torchrl/objectives/decision_transformer.py +35 -8
- torchrl/objectives/iql.py +45 -2
- torchrl/objectives/redq.py +46 -10
- torchrl/objectives/sac.py +69 -6
- torchrl/weight_update/_shared.py +189 -44
- {torchrl-0.11.0.dist-info → torchrl-0.11.1.dist-info}/METADATA +2 -1
- {torchrl-0.11.0.dist-info → torchrl-0.11.1.dist-info}/RECORD +23 -23
- {torchrl-0.11.0.dist-info → torchrl-0.11.1.dist-info}/LICENSE +0 -0
- {torchrl-0.11.0.dist-info → torchrl-0.11.1.dist-info}/WHEEL +0 -0
- {torchrl-0.11.0.dist-info → torchrl-0.11.1.dist-info}/entry_points.txt +0 -0
- {torchrl-0.11.0.dist-info → torchrl-0.11.1.dist-info}/top_level.txt +0 -0
|
@@ -140,17 +140,21 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
140
140
|
buffer_size=buffer_size,
|
|
141
141
|
buffer_scratch_dir=scratch_dir,
|
|
142
142
|
device=device,
|
|
143
|
-
prefetch=prefetch
|
|
143
|
+
prefetch=prefetch, # Always use prefetch for better throughput
|
|
144
144
|
pixel_obs=cfg.env.from_pixels,
|
|
145
145
|
grayscale=cfg.env.grayscale,
|
|
146
146
|
image_size=cfg.env.image_size,
|
|
147
147
|
)
|
|
148
148
|
|
|
149
149
|
# Create storage transform for extend-time processing (applied once per frame)
|
|
150
|
+
# When GPU is available, GPUImageTransform handles image processing in the env,
|
|
151
|
+
# so we skip the heavy CPU transforms in storage_transform
|
|
152
|
+
gpu_transforms = device.type == "cuda"
|
|
150
153
|
storage_transform = make_storage_transform(
|
|
151
154
|
pixel_obs=cfg.env.from_pixels,
|
|
152
155
|
grayscale=cfg.env.grayscale,
|
|
153
156
|
image_size=cfg.env.image_size,
|
|
157
|
+
gpu_transforms=gpu_transforms,
|
|
154
158
|
)
|
|
155
159
|
|
|
156
160
|
# Create policy version tracker for async collection
|
|
@@ -247,7 +251,11 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
247
251
|
compile_warmup = 3
|
|
248
252
|
torchrl_logger.info(f"Compiling loss modules with warmup={compile_warmup}")
|
|
249
253
|
backend = compile_cfg.backend
|
|
250
|
-
|
|
254
|
+
cudagraphs = compile_cfg.cudagraphs
|
|
255
|
+
|
|
256
|
+
# Build compile options - disable CUDA graphs if configured (default)
|
|
257
|
+
# CUDA graphs conflict with dynamic RSSM rollout loop
|
|
258
|
+
compile_options = {"triton.cudagraphs": cudagraphs}
|
|
251
259
|
|
|
252
260
|
# Note: We do NOT compile rssm_prior/rssm_posterior here because they are
|
|
253
261
|
# shared with the policy used in the collector. Compiling them would cause
|
|
@@ -260,17 +268,25 @@ def main(cfg: DictConfig): # noqa: F821
|
|
|
260
268
|
world_model_loss = compile_with_warmup(
|
|
261
269
|
world_model_loss,
|
|
262
270
|
backend=backend,
|
|
263
|
-
mode=mode,
|
|
264
271
|
fullgraph=False,
|
|
265
272
|
warmup=compile_warmup,
|
|
273
|
+
options=compile_options,
|
|
266
274
|
)
|
|
267
275
|
if "actor" in compile_losses:
|
|
268
276
|
actor_loss = compile_with_warmup(
|
|
269
|
-
actor_loss,
|
|
277
|
+
actor_loss,
|
|
278
|
+
backend=backend,
|
|
279
|
+
fullgraph=False,
|
|
280
|
+
warmup=compile_warmup,
|
|
281
|
+
options=compile_options,
|
|
270
282
|
)
|
|
271
283
|
if "value" in compile_losses:
|
|
272
284
|
value_loss = compile_with_warmup(
|
|
273
|
-
value_loss,
|
|
285
|
+
value_loss,
|
|
286
|
+
backend=backend,
|
|
287
|
+
fullgraph=False,
|
|
288
|
+
warmup=compile_warmup,
|
|
289
|
+
options=compile_options,
|
|
274
290
|
)
|
|
275
291
|
else:
|
|
276
292
|
compile_warmup = 0
|
|
@@ -10,7 +10,7 @@ from contextlib import nullcontext
|
|
|
10
10
|
|
|
11
11
|
import torch
|
|
12
12
|
import torch.nn as nn
|
|
13
|
-
from tensordict import NestedKey
|
|
13
|
+
from tensordict import NestedKey, TensorDictBase
|
|
14
14
|
from tensordict.nn import (
|
|
15
15
|
InteractionType,
|
|
16
16
|
ProbabilisticTensorDictModule,
|
|
@@ -38,7 +38,6 @@ from torchrl.envs import (
|
|
|
38
38
|
DreamerEnv,
|
|
39
39
|
EnvCreator,
|
|
40
40
|
ExcludeTransform,
|
|
41
|
-
# ExcludeTransform,
|
|
42
41
|
FrameSkipTransform,
|
|
43
42
|
GrayScale,
|
|
44
43
|
GymEnv,
|
|
@@ -50,6 +49,7 @@ from torchrl.envs import (
|
|
|
50
49
|
StepCounter,
|
|
51
50
|
TensorDictPrimer,
|
|
52
51
|
ToTensorImage,
|
|
52
|
+
Transform,
|
|
53
53
|
TransformedEnv,
|
|
54
54
|
)
|
|
55
55
|
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
|
|
@@ -260,6 +260,89 @@ class DreamerProfiler:
|
|
|
260
260
|
return self.total_optim_steps >= target_steps
|
|
261
261
|
|
|
262
262
|
|
|
263
|
+
class GPUImageTransform(Transform):
|
|
264
|
+
"""Composite transform that processes images on GPU for faster execution.
|
|
265
|
+
|
|
266
|
+
This transform:
|
|
267
|
+
1. Moves pixels_int to GPU
|
|
268
|
+
2. Runs ToTensorImage (permute + divide by 255)
|
|
269
|
+
3. Optionally runs GrayScale
|
|
270
|
+
4. Runs Resize
|
|
271
|
+
5. Keeps output on GPU for fast policy inference
|
|
272
|
+
|
|
273
|
+
This avoids device mismatch issues by not using DeviceCastTransform on the
|
|
274
|
+
full tensordict - only the pixel processing happens on GPU.
|
|
275
|
+
"""
|
|
276
|
+
|
|
277
|
+
def __init__(
|
|
278
|
+
self,
|
|
279
|
+
device: torch.device,
|
|
280
|
+
image_size: int,
|
|
281
|
+
grayscale: bool = False,
|
|
282
|
+
in_key: str = "pixels_int",
|
|
283
|
+
out_key: str = "pixels",
|
|
284
|
+
):
|
|
285
|
+
super().__init__(in_keys=[in_key], out_keys=[out_key])
|
|
286
|
+
self.device = device
|
|
287
|
+
self.image_size = image_size
|
|
288
|
+
self.grayscale = grayscale
|
|
289
|
+
self.in_key = in_key
|
|
290
|
+
self.out_key = out_key
|
|
291
|
+
|
|
292
|
+
def _apply_transform(self, pixels_int: torch.Tensor) -> torch.Tensor:
|
|
293
|
+
# Move to GPU
|
|
294
|
+
pixels = pixels_int.to(self.device)
|
|
295
|
+
# ToTensorImage: permute W x H x C -> C x W x H and normalize
|
|
296
|
+
pixels = pixels.permute(*list(range(pixels.ndimension() - 3)), -1, -3, -2)
|
|
297
|
+
pixels = pixels.float().div(255)
|
|
298
|
+
# GrayScale
|
|
299
|
+
if self.grayscale:
|
|
300
|
+
pixels = pixels.mean(dim=-3, keepdim=True)
|
|
301
|
+
# Resize using interpolate
|
|
302
|
+
if pixels.shape[-2:] != (self.image_size, self.image_size):
|
|
303
|
+
# Add batch dim if needed for interpolate
|
|
304
|
+
needs_squeeze = pixels.ndim == 3
|
|
305
|
+
if needs_squeeze:
|
|
306
|
+
pixels = pixels.unsqueeze(0)
|
|
307
|
+
pixels = torch.nn.functional.interpolate(
|
|
308
|
+
pixels,
|
|
309
|
+
size=(self.image_size, self.image_size),
|
|
310
|
+
mode="bilinear",
|
|
311
|
+
align_corners=False,
|
|
312
|
+
antialias=True,
|
|
313
|
+
)
|
|
314
|
+
if needs_squeeze:
|
|
315
|
+
pixels = pixels.squeeze(0)
|
|
316
|
+
return pixels
|
|
317
|
+
|
|
318
|
+
def _reset(
|
|
319
|
+
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
|
|
320
|
+
) -> TensorDictBase:
|
|
321
|
+
return self._call(tensordict_reset)
|
|
322
|
+
|
|
323
|
+
def transform_observation_spec(self, observation_spec):
|
|
324
|
+
# Update the spec for the output key
|
|
325
|
+
# Note: Keep spec on CPU to match other specs in Composite
|
|
326
|
+
# The actual transform will put data on GPU, but spec device must be uniform
|
|
327
|
+
from torchrl.data import Unbounded
|
|
328
|
+
|
|
329
|
+
in_spec = observation_spec[self.in_key]
|
|
330
|
+
# Output shape: (C, H, W) where C=1 if grayscale else 3
|
|
331
|
+
out_channels = 1 if self.grayscale else 3
|
|
332
|
+
out_shape = (
|
|
333
|
+
*in_spec.shape[:-3],
|
|
334
|
+
out_channels,
|
|
335
|
+
self.image_size,
|
|
336
|
+
self.image_size,
|
|
337
|
+
)
|
|
338
|
+
# Use in_spec.device to maintain device consistency in Composite
|
|
339
|
+
out_spec = Unbounded(
|
|
340
|
+
shape=out_shape, dtype=torch.float32, device=in_spec.device
|
|
341
|
+
)
|
|
342
|
+
observation_spec[self.out_key] = out_spec
|
|
343
|
+
return observation_spec
|
|
344
|
+
|
|
345
|
+
|
|
263
346
|
def _make_env(cfg, device, from_pixels=False):
|
|
264
347
|
lib = cfg.env.backend
|
|
265
348
|
if lib in ("gym", "gymnasium"):
|
|
@@ -294,22 +377,44 @@ def _make_env(cfg, device, from_pixels=False):
|
|
|
294
377
|
return env
|
|
295
378
|
|
|
296
379
|
|
|
297
|
-
def transform_env(cfg, env):
|
|
380
|
+
def transform_env(cfg, env, device=None):
|
|
381
|
+
"""Apply transforms to environment.
|
|
382
|
+
|
|
383
|
+
Args:
|
|
384
|
+
cfg: Config object
|
|
385
|
+
env: The environment to transform
|
|
386
|
+
device: If specified and is a CUDA device, use GPU-accelerated image
|
|
387
|
+
processing which is ~50-100x faster than CPU.
|
|
388
|
+
"""
|
|
298
389
|
if not isinstance(env, TransformedEnv):
|
|
299
390
|
env = TransformedEnv(env)
|
|
300
391
|
if cfg.env.from_pixels:
|
|
301
|
-
#
|
|
392
|
+
# Rename original pixels for processing
|
|
302
393
|
env.append_transform(
|
|
303
394
|
RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"])
|
|
304
395
|
)
|
|
305
|
-
env.append_transform(
|
|
306
|
-
ToTensorImage(from_int=True, in_keys=["pixels_int"], out_keys=["pixels"])
|
|
307
|
-
)
|
|
308
|
-
if cfg.env.grayscale:
|
|
309
|
-
env.append_transform(GrayScale())
|
|
310
396
|
|
|
311
|
-
|
|
312
|
-
|
|
397
|
+
# Use GPU-accelerated image processing if device is CUDA
|
|
398
|
+
if device is not None and str(device).startswith("cuda"):
|
|
399
|
+
env.append_transform(
|
|
400
|
+
GPUImageTransform(
|
|
401
|
+
device=device,
|
|
402
|
+
image_size=cfg.env.image_size,
|
|
403
|
+
grayscale=cfg.env.grayscale,
|
|
404
|
+
in_key="pixels_int",
|
|
405
|
+
out_key="pixels",
|
|
406
|
+
)
|
|
407
|
+
)
|
|
408
|
+
else:
|
|
409
|
+
# CPU fallback: use standard transforms
|
|
410
|
+
env.append_transform(
|
|
411
|
+
ToTensorImage(
|
|
412
|
+
from_int=True, in_keys=["pixels_int"], out_keys=["pixels"]
|
|
413
|
+
)
|
|
414
|
+
)
|
|
415
|
+
if cfg.env.grayscale:
|
|
416
|
+
env.append_transform(GrayScale())
|
|
417
|
+
env.append_transform(Resize(cfg.env.image_size, cfg.env.image_size))
|
|
313
418
|
|
|
314
419
|
env.append_transform(DoubleToFloat())
|
|
315
420
|
env.append_transform(RewardSum())
|
|
@@ -329,24 +434,38 @@ def make_environments(cfg, parallel_envs=1, logger=None):
|
|
|
329
434
|
"""
|
|
330
435
|
|
|
331
436
|
def train_env_factory():
|
|
332
|
-
"""Factory function for creating training environments.
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
437
|
+
"""Factory function for creating training environments.
|
|
438
|
+
|
|
439
|
+
Note: This factory runs inside collector worker processes. We use
|
|
440
|
+
CUDA if available for GPU-accelerated image transforms (ToTensorImage,
|
|
441
|
+
Resize) which are ~50-100x faster than CPU. The cfg.env.device setting
|
|
442
|
+
is ignored in favor of auto-detecting CUDA availability.
|
|
443
|
+
"""
|
|
444
|
+
# Use CUDA for transforms if available, regardless of cfg.env.device
|
|
445
|
+
# This is critical: image transforms (Resize, ToTensorImage) are ~50-100x
|
|
446
|
+
# faster on GPU. DMControl/Gym render on CPU, but we move to GPU for transforms.
|
|
447
|
+
transform_device = _default_device(None) # Returns CUDA if available
|
|
448
|
+
# Base env still uses cfg.env.device for compatibility
|
|
449
|
+
env_device = _default_device(cfg.env.device)
|
|
450
|
+
func = functools.partial(_make_env, cfg=cfg, device=env_device)
|
|
336
451
|
train_env = ParallelEnv(
|
|
337
452
|
parallel_envs,
|
|
338
453
|
EnvCreator(func),
|
|
339
454
|
serial_for_single=True,
|
|
340
455
|
)
|
|
341
|
-
|
|
456
|
+
# Pass transform_device to enable GPU-accelerated image transforms
|
|
457
|
+
train_env = transform_env(cfg, train_env, device=transform_device)
|
|
342
458
|
train_env.set_seed(cfg.env.seed)
|
|
343
459
|
return train_env
|
|
344
460
|
|
|
345
461
|
# Create eval env directly (not a factory)
|
|
462
|
+
# Use CUDA for transforms if available, regardless of cfg.env.device
|
|
463
|
+
transform_device = _default_device(None) # Returns CUDA if available
|
|
464
|
+
env_device = _default_device(cfg.env.device)
|
|
346
465
|
func = functools.partial(
|
|
347
466
|
_make_env,
|
|
348
467
|
cfg=cfg,
|
|
349
|
-
device=
|
|
468
|
+
device=env_device,
|
|
350
469
|
from_pixels=cfg.logger.video,
|
|
351
470
|
)
|
|
352
471
|
eval_env = ParallelEnv(
|
|
@@ -354,7 +473,8 @@ def make_environments(cfg, parallel_envs=1, logger=None):
|
|
|
354
473
|
EnvCreator(func),
|
|
355
474
|
serial_for_single=True,
|
|
356
475
|
)
|
|
357
|
-
|
|
476
|
+
# Pass transform_device to enable GPU-accelerated image transforms
|
|
477
|
+
eval_env = transform_env(cfg, eval_env, device=transform_device)
|
|
358
478
|
eval_env.set_seed(cfg.env.seed + 1)
|
|
359
479
|
if cfg.logger.video:
|
|
360
480
|
eval_env.insert_transform(
|
|
@@ -681,15 +801,32 @@ def make_storage_transform(
|
|
|
681
801
|
pixel_obs=True,
|
|
682
802
|
grayscale=True,
|
|
683
803
|
image_size,
|
|
804
|
+
gpu_transforms=False,
|
|
684
805
|
):
|
|
685
806
|
"""Create transforms to be applied at extend-time (once per frame).
|
|
686
807
|
|
|
687
|
-
|
|
688
|
-
|
|
808
|
+
Args:
|
|
809
|
+
pixel_obs: Whether observations are pixel-based.
|
|
810
|
+
grayscale: Whether to convert to grayscale.
|
|
811
|
+
image_size: Target image size.
|
|
812
|
+
gpu_transforms: If True, skip heavy image transforms (ToTensorImage,
|
|
813
|
+
GrayScale, Resize) since they're already applied by GPUImageTransform
|
|
814
|
+
in the environment. Only ExcludeTransform is applied to filter keys.
|
|
689
815
|
"""
|
|
690
816
|
if not pixel_obs:
|
|
691
817
|
return None
|
|
692
818
|
|
|
819
|
+
# When GPU transforms are enabled, GPUImageTransform already processes
|
|
820
|
+
# pixels_int -> pixels with normalization, grayscale, and resize.
|
|
821
|
+
# We only need to filter out the intermediate pixels_int key.
|
|
822
|
+
if gpu_transforms:
|
|
823
|
+
storage_transforms = Compose(
|
|
824
|
+
# Just exclude pixels_int, keep everything else including processed pixels
|
|
825
|
+
ExcludeTransform("pixels_int", ("next", "pixels_int")),
|
|
826
|
+
)
|
|
827
|
+
return storage_transforms
|
|
828
|
+
|
|
829
|
+
# CPU fallback: apply heavy transforms at storage time
|
|
693
830
|
storage_transforms = Compose(
|
|
694
831
|
ExcludeTransform("pixels", ("next", "pixels"), inverse=True),
|
|
695
832
|
ToTensorImage(
|
|
@@ -741,7 +878,6 @@ def make_replay_buffer(
|
|
|
741
878
|
)
|
|
742
879
|
|
|
743
880
|
replay_buffer = TensorDictReplayBuffer(
|
|
744
|
-
pin_memory=False,
|
|
745
881
|
prefetch=prefetch,
|
|
746
882
|
storage=LazyMemmapStorage(
|
|
747
883
|
buffer_size,
|
|
@@ -755,7 +891,9 @@ def make_replay_buffer(
|
|
|
755
891
|
strict_length=False,
|
|
756
892
|
traj_key=("collector", "traj_ids"),
|
|
757
893
|
cache_values=False, # Disabled for async collection (cache not synced across processes)
|
|
758
|
-
|
|
894
|
+
use_gpu=device.type == "cuda"
|
|
895
|
+
if device is not None
|
|
896
|
+
else False, # Speed up trajectory computation on GPU
|
|
759
897
|
),
|
|
760
898
|
transform=sample_transforms,
|
|
761
899
|
batch_size=batch_size,
|
|
Binary file
|
torchrl/collectors/_base.py
CHANGED
|
@@ -585,6 +585,14 @@ class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
|
|
|
585
585
|
... "actor": actor_weights,
|
|
586
586
|
... "critic": critic_weights,
|
|
587
587
|
... })
|
|
588
|
+
>>>
|
|
589
|
+
>>> # Per-worker weight updates (for distinct policy factories)
|
|
590
|
+
>>> # Each worker can have independently updated weights
|
|
591
|
+
>>> collector.update_policy_weights_({
|
|
592
|
+
... 0: worker_0_weights,
|
|
593
|
+
... 1: worker_1_weights,
|
|
594
|
+
... 2: worker_2_weights,
|
|
595
|
+
... })
|
|
588
596
|
|
|
589
597
|
Args:
|
|
590
598
|
policy_or_weights: The weights to update with. Can be:
|
|
@@ -593,6 +601,8 @@ class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
|
|
|
593
601
|
- ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
|
|
594
602
|
- ``TensorDictBase``: A TensorDict containing weights
|
|
595
603
|
- ``dict``: A regular dict containing weights
|
|
604
|
+
- ``dict[int, TensorDictBase]``: Per-worker weights where keys are worker indices.
|
|
605
|
+
This is used with distinct policy factories where each worker has independent weights.
|
|
596
606
|
- ``None``: Will try to get weights from server using ``_get_server_weights()``
|
|
597
607
|
|
|
598
608
|
Keyword Args:
|
|
@@ -429,16 +429,30 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
|
|
|
429
429
|
raise TypeError(
|
|
430
430
|
"Cannot specify both weight_sync_schemes and weight_updater."
|
|
431
431
|
)
|
|
432
|
+
# Check if policy_factory entries are all the same (replicated from single factory)
|
|
433
|
+
# vs different factories per worker.
|
|
434
|
+
has_uniform_policy_factory = any(policy_factory) and all(
|
|
435
|
+
f is policy_factory[0] for f in policy_factory
|
|
436
|
+
)
|
|
437
|
+
has_distinct_policy_factory = (
|
|
438
|
+
any(policy_factory) and not has_uniform_policy_factory
|
|
439
|
+
)
|
|
432
440
|
if (
|
|
433
441
|
weight_sync_schemes is not None
|
|
434
442
|
and not weight_sync_schemes
|
|
435
443
|
and weight_updater is None
|
|
436
|
-
and (isinstance(policy, nn.Module) or any(policy_factory))
|
|
437
444
|
):
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
445
|
+
if isinstance(policy, nn.Module) or has_uniform_policy_factory:
|
|
446
|
+
# Set up a default local shared-memory sync scheme for the policy.
|
|
447
|
+
# This is used to propagate weights from the orchestrator policy
|
|
448
|
+
# (possibly combined with a policy_factory) down to worker policies.
|
|
449
|
+
weight_sync_schemes["policy"] = SharedMemWeightSyncScheme()
|
|
450
|
+
elif has_distinct_policy_factory:
|
|
451
|
+
# Distinct factories: set up per-worker weight sync scheme.
|
|
452
|
+
# Each worker maintains independent weights that can be updated individually.
|
|
453
|
+
weight_sync_schemes["policy"] = SharedMemWeightSyncScheme(
|
|
454
|
+
per_worker_weights=True
|
|
455
|
+
)
|
|
442
456
|
|
|
443
457
|
self._setup_multi_weight_sync(weight_updater, weight_sync_schemes)
|
|
444
458
|
|
torchrl/collectors/utils.py
CHANGED
|
@@ -417,7 +417,8 @@ def _make_policy_factory(
|
|
|
417
417
|
raise ValueError("policy cannot be used with policy_factory")
|
|
418
418
|
elif has_policy_factory:
|
|
419
419
|
if isinstance(policy_factory, Sequence):
|
|
420
|
-
|
|
420
|
+
# Use worker_idx to get the correct factory for this worker
|
|
421
|
+
policy = policy_factory[worker_idx]()
|
|
421
422
|
else:
|
|
422
423
|
policy = policy_factory()
|
|
423
424
|
|
|
@@ -93,7 +93,8 @@ class Writer(ABC):
|
|
|
93
93
|
)
|
|
94
94
|
mesh = torch.stack(
|
|
95
95
|
torch.meshgrid(
|
|
96
|
-
*(torch.arange(dim, device=device) for dim in self._storage.shape[1:])
|
|
96
|
+
*(torch.arange(dim, device=device) for dim in self._storage.shape[1:]),
|
|
97
|
+
indexing="ij",
|
|
97
98
|
),
|
|
98
99
|
-1,
|
|
99
100
|
).flatten(0, -2)
|
torchrl/data/tensor_specs.py
CHANGED
|
@@ -5601,7 +5601,7 @@ class Composite(TensorSpec):
|
|
|
5601
5601
|
elif self.data_cls is not None:
|
|
5602
5602
|
out = {}
|
|
5603
5603
|
else:
|
|
5604
|
-
out = TensorDict._new_unsafe({},
|
|
5604
|
+
out = TensorDict._new_unsafe({}, self.shape)
|
|
5605
5605
|
for key, item in vals.items():
|
|
5606
5606
|
if item is None:
|
|
5607
5607
|
raise RuntimeError(
|
|
@@ -5644,7 +5644,7 @@ class Composite(TensorSpec):
|
|
|
5644
5644
|
else:
|
|
5645
5645
|
|
|
5646
5646
|
def empty(vals):
|
|
5647
|
-
out = TensorDict._new_unsafe({},
|
|
5647
|
+
out = TensorDict._new_unsafe({}, self.shape)
|
|
5648
5648
|
return vals, out
|
|
5649
5649
|
|
|
5650
5650
|
funcs.append(empty)
|
torchrl/envs/batched_envs.py
CHANGED
|
@@ -96,6 +96,22 @@ class _dispatch_caller_parallel:
|
|
|
96
96
|
# if the object returned is not a callable
|
|
97
97
|
return iter(self.__call__())
|
|
98
98
|
|
|
99
|
+
def __getattr__(self, name):
|
|
100
|
+
"""Support chained attribute access: env_parallel.a.b -> sends ('a','b') to workers."""
|
|
101
|
+
# Don't chain special/dunder methods - these are often called by
|
|
102
|
+
# display systems (e.g., Jupyter's _repr_html_) and shouldn't be
|
|
103
|
+
# dispatched to workers
|
|
104
|
+
if name.startswith("_"):
|
|
105
|
+
raise AttributeError(
|
|
106
|
+
f"Accessing private/special attribute {name!r} is not supported "
|
|
107
|
+
f"on dispatched parallel env attributes."
|
|
108
|
+
)
|
|
109
|
+
if isinstance(self.attr, tuple):
|
|
110
|
+
new_attr = self.attr + (name,)
|
|
111
|
+
else:
|
|
112
|
+
new_attr = (self.attr, name)
|
|
113
|
+
return _dispatch_caller_parallel(new_attr, self.parallel_env)
|
|
114
|
+
|
|
99
115
|
|
|
100
116
|
class _dispatch_caller_serial:
|
|
101
117
|
def __init__(self, list_callable: list[Callable, Any]):
|
|
@@ -7673,6 +7673,26 @@ class StepCounter(Transform):
|
|
|
7673
7673
|
self._truncated_keys = truncated_keys
|
|
7674
7674
|
return truncated_keys
|
|
7675
7675
|
|
|
7676
|
+
@property
|
|
7677
|
+
def all_truncated_keys(self) -> list[NestedKey]:
|
|
7678
|
+
"""Returns truncated keys for ALL reset keys (including nested ones).
|
|
7679
|
+
|
|
7680
|
+
Used for propagating truncated to nested agent-level keys in MARL envs.
|
|
7681
|
+
"""
|
|
7682
|
+
all_truncated_keys = self.__dict__.get("_all_truncated_keys", None)
|
|
7683
|
+
if all_truncated_keys is None:
|
|
7684
|
+
all_truncated_keys = []
|
|
7685
|
+
if self.parent is None:
|
|
7686
|
+
return self.truncated_keys
|
|
7687
|
+
for reset_key in self.parent.reset_keys:
|
|
7688
|
+
if isinstance(reset_key, str):
|
|
7689
|
+
key = self.truncated_key
|
|
7690
|
+
else:
|
|
7691
|
+
key = (*reset_key[:-1], self.truncated_key)
|
|
7692
|
+
all_truncated_keys.append(key)
|
|
7693
|
+
self.__dict__["_all_truncated_keys"] = all_truncated_keys
|
|
7694
|
+
return all_truncated_keys
|
|
7695
|
+
|
|
7676
7696
|
@property
|
|
7677
7697
|
def done_keys(self) -> list[NestedKey]:
|
|
7678
7698
|
done_keys = self.__dict__.get("_done_keys", None)
|
|
@@ -7688,6 +7708,26 @@ class StepCounter(Transform):
|
|
|
7688
7708
|
self.__dict__["_done_keys"] = done_keys
|
|
7689
7709
|
return done_keys
|
|
7690
7710
|
|
|
7711
|
+
@property
|
|
7712
|
+
def all_done_keys(self) -> list[NestedKey]:
|
|
7713
|
+
"""Returns done keys for ALL reset keys (including nested ones).
|
|
7714
|
+
|
|
7715
|
+
Used for propagating done to nested agent-level keys in MARL envs.
|
|
7716
|
+
"""
|
|
7717
|
+
all_done_keys = self.__dict__.get("_all_done_keys", None)
|
|
7718
|
+
if all_done_keys is None:
|
|
7719
|
+
all_done_keys = []
|
|
7720
|
+
if self.parent is None:
|
|
7721
|
+
return self.done_keys
|
|
7722
|
+
for reset_key in self.parent.reset_keys:
|
|
7723
|
+
if isinstance(reset_key, str):
|
|
7724
|
+
key = "done"
|
|
7725
|
+
else:
|
|
7726
|
+
key = (*reset_key[:-1], "done")
|
|
7727
|
+
all_done_keys.append(key)
|
|
7728
|
+
self.__dict__["_all_done_keys"] = all_done_keys
|
|
7729
|
+
return all_done_keys
|
|
7730
|
+
|
|
7691
7731
|
@property
|
|
7692
7732
|
def terminated_keys(self) -> list[NestedKey]:
|
|
7693
7733
|
terminated_keys = self.__dict__.get("_terminated_keys", None)
|
|
@@ -7803,8 +7843,59 @@ class StepCounter(Transform):
|
|
|
7803
7843
|
done = truncated | done # we assume no done after reset
|
|
7804
7844
|
next_tensordict.set(done_key, done)
|
|
7805
7845
|
next_tensordict.set(truncated_key, truncated)
|
|
7846
|
+
|
|
7847
|
+
# Propagate truncated/done to nested agent-level keys in MARL envs
|
|
7848
|
+
# This ensures that when max_steps is reached, all agent truncated/done keys are updated
|
|
7849
|
+
if self.max_steps is not None:
|
|
7850
|
+
self._propagate_to_nested_keys(next_tensordict)
|
|
7851
|
+
|
|
7806
7852
|
return next_tensordict
|
|
7807
7853
|
|
|
7854
|
+
def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
|
|
7855
|
+
"""Propagate truncated and done values to nested agent-level keys.
|
|
7856
|
+
|
|
7857
|
+
In MARL envs, there may be nested agent-level truncated/done keys that
|
|
7858
|
+
are children of the root truncated/done. When StepCounter sets truncated
|
|
7859
|
+
at the root level, we need to propagate this to nested keys.
|
|
7860
|
+
"""
|
|
7861
|
+
# Get the set of keys we already updated (filtered keys)
|
|
7862
|
+
updated_truncated = set(self.truncated_keys)
|
|
7863
|
+
updated_done = set(self.done_keys)
|
|
7864
|
+
|
|
7865
|
+
# Propagate truncated to nested keys
|
|
7866
|
+
for nested_key in self.all_truncated_keys:
|
|
7867
|
+
if nested_key in updated_truncated:
|
|
7868
|
+
continue
|
|
7869
|
+
# Find the parent truncated key that should be propagated
|
|
7870
|
+
nested_truncated = next_tensordict.get(nested_key, None)
|
|
7871
|
+
if nested_truncated is None:
|
|
7872
|
+
continue
|
|
7873
|
+
# Find a parent truncated key to propagate from
|
|
7874
|
+
for parent_key in self.truncated_keys:
|
|
7875
|
+
parent_truncated = next_tensordict.get(parent_key, None)
|
|
7876
|
+
if parent_truncated is not None:
|
|
7877
|
+
# Expand parent truncated to match nested shape and apply OR
|
|
7878
|
+
expanded = parent_truncated.expand_as(nested_truncated)
|
|
7879
|
+
next_tensordict.set(nested_key, nested_truncated | expanded)
|
|
7880
|
+
break
|
|
7881
|
+
|
|
7882
|
+
# Propagate done to nested keys if update_done is True
|
|
7883
|
+
if self.update_done:
|
|
7884
|
+
for nested_key in self.all_done_keys:
|
|
7885
|
+
if nested_key in updated_done:
|
|
7886
|
+
continue
|
|
7887
|
+
nested_done = next_tensordict.get(nested_key, None)
|
|
7888
|
+
if nested_done is None:
|
|
7889
|
+
continue
|
|
7890
|
+
# Find a parent done key to propagate from
|
|
7891
|
+
for parent_key in self.done_keys:
|
|
7892
|
+
parent_done = next_tensordict.get(parent_key, None)
|
|
7893
|
+
if parent_done is not None:
|
|
7894
|
+
# Expand parent done to match nested shape and apply OR
|
|
7895
|
+
expanded = parent_done.expand_as(nested_done)
|
|
7896
|
+
next_tensordict.set(nested_key, nested_done | expanded)
|
|
7897
|
+
break
|
|
7898
|
+
|
|
7808
7899
|
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
|
|
7809
7900
|
if not isinstance(observation_spec, Composite):
|
|
7810
7901
|
raise ValueError(
|
torchrl/objectives/cql.py
CHANGED
|
@@ -298,6 +298,7 @@ class CQLLoss(LossModule):
|
|
|
298
298
|
lagrange_thresh: float = 0.0,
|
|
299
299
|
reduction: str | None = None,
|
|
300
300
|
deactivate_vmap: bool = False,
|
|
301
|
+
scalar_output_mode: str | None = None,
|
|
301
302
|
) -> None:
|
|
302
303
|
self._out_keys = None
|
|
303
304
|
if reduction is None:
|
|
@@ -381,6 +382,23 @@ class CQLLoss(LossModule):
|
|
|
381
382
|
)
|
|
382
383
|
self._make_vmap()
|
|
383
384
|
self.reduction = reduction
|
|
385
|
+
|
|
386
|
+
# Handle scalar_output_mode for reduction="none"
|
|
387
|
+
if reduction == "none" and scalar_output_mode is None:
|
|
388
|
+
warnings.warn(
|
|
389
|
+
"CQLLoss with reduction='none' cannot include scalar values (alpha, entropy) "
|
|
390
|
+
"in the output TensorDict without changing their shape. These values will be "
|
|
391
|
+
"excluded from the output. You can access them via `loss_module._alpha` and "
|
|
392
|
+
"compute entropy from the log_prob in the actor loss metadata. "
|
|
393
|
+
"To suppress this warning, pass `scalar_output_mode='exclude'` to the constructor. "
|
|
394
|
+
"Alternatively, pass `scalar_output_mode='non_tensor'` to include them as non-tensor data. "
|
|
395
|
+
"This is a known limitation we're working on improving.",
|
|
396
|
+
category=UserWarning,
|
|
397
|
+
stacklevel=2,
|
|
398
|
+
)
|
|
399
|
+
scalar_output_mode = "exclude"
|
|
400
|
+
self.scalar_output_mode = scalar_output_mode
|
|
401
|
+
|
|
384
402
|
_ = self.target_entropy
|
|
385
403
|
|
|
386
404
|
def _make_vmap(self):
|
|
@@ -548,18 +566,28 @@ class CQLLoss(LossModule):
|
|
|
548
566
|
tensordict.set(
|
|
549
567
|
self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
|
|
550
568
|
)
|
|
569
|
+
entropy = -actor_metadata.get(self.tensor_keys.log_prob)
|
|
551
570
|
out = {
|
|
552
571
|
"loss_actor": loss_actor,
|
|
553
572
|
"loss_actor_bc": loss_actor_bc,
|
|
554
573
|
"loss_qvalue": q_loss,
|
|
555
574
|
"loss_cql": cql_loss,
|
|
556
575
|
"loss_alpha": loss_alpha,
|
|
557
|
-
"alpha": self._alpha,
|
|
558
|
-
"entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
|
|
559
576
|
}
|
|
560
577
|
if self.with_lagrange:
|
|
561
578
|
out["loss_alpha_prime"] = alpha_prime_loss.mean()
|
|
562
|
-
|
|
579
|
+
|
|
580
|
+
# Handle batch_size and scalar values (alpha, entropy) based on reduction mode
|
|
581
|
+
if self.reduction == "none":
|
|
582
|
+
batch_size = tensordict.batch_size
|
|
583
|
+
td_loss = TensorDict(out, batch_size=batch_size)
|
|
584
|
+
if self.scalar_output_mode == "non_tensor":
|
|
585
|
+
td_loss.set_non_tensor("alpha", self._alpha)
|
|
586
|
+
td_loss.set_non_tensor("entropy", entropy.detach().mean())
|
|
587
|
+
else:
|
|
588
|
+
out["alpha"] = self._alpha
|
|
589
|
+
out["entropy"] = entropy.detach().mean()
|
|
590
|
+
td_loss = TensorDict(out)
|
|
563
591
|
self._clear_weakrefs(
|
|
564
592
|
tensordict,
|
|
565
593
|
td_loss,
|