torchrl-0.11.0-cp313-cp313-win_amd64.whl → torchrl-0.11.1-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -140,17 +140,21 @@ def main(cfg: DictConfig):  # noqa: F821
         buffer_size=buffer_size,
         buffer_scratch_dir=scratch_dir,
         device=device,
-        prefetch=prefetch if not profiling_enabled else None,
+        prefetch=prefetch,  # Always use prefetch for better throughput
         pixel_obs=cfg.env.from_pixels,
         grayscale=cfg.env.grayscale,
         image_size=cfg.env.image_size,
     )
 
     # Create storage transform for extend-time processing (applied once per frame)
+    # When GPU is available, GPUImageTransform handles image processing in the env,
+    # so we skip the heavy CPU transforms in storage_transform
+    gpu_transforms = device.type == "cuda"
     storage_transform = make_storage_transform(
         pixel_obs=cfg.env.from_pixels,
         grayscale=cfg.env.grayscale,
         image_size=cfg.env.image_size,
+        gpu_transforms=gpu_transforms,
     )
 
     # Create policy version tracker for async collection
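
Note on the prefetch change above: with prefetch enabled, the replay buffer prepares upcoming batches in background threads so sampling overlaps with the optimizer step. A minimal sketch of the mechanism, with an illustrative buffer configuration (storage size and batch size are placeholders):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

    # prefetch=3 keeps up to three sampled batches ready ahead of time
    rb = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(10_000),
        prefetch=3,
        batch_size=32,
    )
    rb.extend(TensorDict({"obs": torch.randn(100, 4)}, batch_size=[100]))
    batch = rb.sample()  # served from the prefetch queue once it is warm
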
@@ -247,7 +251,11 @@ def main(cfg: DictConfig):  # noqa: F821
         compile_warmup = 3
         torchrl_logger.info(f"Compiling loss modules with warmup={compile_warmup}")
         backend = compile_cfg.backend
-        mode = compile_cfg.mode
+        cudagraphs = compile_cfg.cudagraphs
+
+        # Build compile options - disable CUDA graphs if configured (default)
+        # CUDA graphs conflict with dynamic RSSM rollout loop
+        compile_options = {"triton.cudagraphs": cudagraphs}
 
         # Note: We do NOT compile rssm_prior/rssm_posterior here because they are
         # shared with the policy used in the collector. Compiling them would cause
@@ -260,17 +268,25 @@ def main(cfg: DictConfig):  # noqa: F821
         world_model_loss = compile_with_warmup(
             world_model_loss,
             backend=backend,
-            mode=mode,
             fullgraph=False,
             warmup=compile_warmup,
+            options=compile_options,
         )
         if "actor" in compile_losses:
             actor_loss = compile_with_warmup(
-                actor_loss, backend=backend, mode=mode, warmup=compile_warmup
+                actor_loss,
+                backend=backend,
+                fullgraph=False,
+                warmup=compile_warmup,
+                options=compile_options,
             )
         if "value" in compile_losses:
             value_loss = compile_with_warmup(
-                value_loss, backend=backend, mode=mode, warmup=compile_warmup
+                value_loss,
+                backend=backend,
+                fullgraph=False,
+                warmup=compile_warmup,
+                options=compile_options,
             )
     else:
         compile_warmup = 0
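
The options dict introduced above is forwarded to the inductor backend, where "triton.cudagraphs" toggles CUDA-graph capture. A self-contained sketch of the same pattern with plain torch.compile (fn is an illustrative stand-in for the loss modules):

    import torch

    def fn(x: torch.Tensor) -> torch.Tensor:
        return torch.sin(x) + x

    compiled = torch.compile(
        fn,
        backend="inductor",
        fullgraph=False,
        # disable CUDA-graph capture, which assumes static shapes/control flow
        options={"triton.cudagraphs": False},
    )
    print(compiled(torch.randn(8)))
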
@@ -10,7 +10,7 @@ from contextlib import nullcontext
 
 import torch
 import torch.nn as nn
-from tensordict import NestedKey
+from tensordict import NestedKey, TensorDictBase
 from tensordict.nn import (
     InteractionType,
     ProbabilisticTensorDictModule,
@@ -38,7 +38,6 @@ from torchrl.envs import (
     DreamerEnv,
     EnvCreator,
     ExcludeTransform,
-    # ExcludeTransform,
     FrameSkipTransform,
     GrayScale,
     GymEnv,
@@ -50,6 +49,7 @@ from torchrl.envs import (
     StepCounter,
     TensorDictPrimer,
     ToTensorImage,
+    Transform,
     TransformedEnv,
 )
 from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
@@ -260,6 +260,89 @@ class DreamerProfiler:
         return self.total_optim_steps >= target_steps
 
 
+class GPUImageTransform(Transform):
+    """Composite transform that processes images on GPU for faster execution.
+
+    This transform:
+    1. Moves pixels_int to GPU
+    2. Runs ToTensorImage (permute + divide by 255)
+    3. Optionally runs GrayScale
+    4. Runs Resize
+    5. Keeps output on GPU for fast policy inference
+
+    This avoids device mismatch issues by not using DeviceCastTransform on the
+    full tensordict - only the pixel processing happens on GPU.
+    """
+
+    def __init__(
+        self,
+        device: torch.device,
+        image_size: int,
+        grayscale: bool = False,
+        in_key: str = "pixels_int",
+        out_key: str = "pixels",
+    ):
+        super().__init__(in_keys=[in_key], out_keys=[out_key])
+        self.device = device
+        self.image_size = image_size
+        self.grayscale = grayscale
+        self.in_key = in_key
+        self.out_key = out_key
+
+    def _apply_transform(self, pixels_int: torch.Tensor) -> torch.Tensor:
+        # Move to GPU
+        pixels = pixels_int.to(self.device)
+        # ToTensorImage: permute W x H x C -> C x W x H and normalize
+        pixels = pixels.permute(*list(range(pixels.ndimension() - 3)), -1, -3, -2)
+        pixels = pixels.float().div(255)
+        # GrayScale
+        if self.grayscale:
+            pixels = pixels.mean(dim=-3, keepdim=True)
+        # Resize using interpolate
+        if pixels.shape[-2:] != (self.image_size, self.image_size):
+            # Add batch dim if needed for interpolate
+            needs_squeeze = pixels.ndim == 3
+            if needs_squeeze:
+                pixels = pixels.unsqueeze(0)
+            pixels = torch.nn.functional.interpolate(
+                pixels,
+                size=(self.image_size, self.image_size),
+                mode="bilinear",
+                align_corners=False,
+                antialias=True,
+            )
+            if needs_squeeze:
+                pixels = pixels.squeeze(0)
+        return pixels
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        return self._call(tensordict_reset)
+
+    def transform_observation_spec(self, observation_spec):
+        # Update the spec for the output key
+        # Note: Keep spec on CPU to match other specs in Composite
+        # The actual transform will put data on GPU, but spec device must be uniform
+        from torchrl.data import Unbounded
+
+        in_spec = observation_spec[self.in_key]
+        # Output shape: (C, H, W) where C=1 if grayscale else 3
+        out_channels = 1 if self.grayscale else 3
+        out_shape = (
+            *in_spec.shape[:-3],
+            out_channels,
+            self.image_size,
+            self.image_size,
+        )
+        # Use in_spec.device to maintain device consistency in Composite
+        out_spec = Unbounded(
+            shape=out_shape, dtype=torch.float32, device=in_spec.device
+        )
+        observation_spec[self.out_key] = out_spec
+        return observation_spec
+
+
 def _make_env(cfg, device, from_pixels=False):
     lib = cfg.env.backend
     if lib in ("gym", "gymnasium"):
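
The tensor operations inside GPUImageTransform._apply_transform can be exercised standalone; the sketch below runs the same HWC-to-CHW permute, normalization, grayscale reduction, and antialiased resize on CPU so it stays self-contained (the 84-pixel input and 64-pixel target are illustrative):

    import torch
    import torch.nn.functional as F

    frame = torch.randint(0, 256, (84, 84, 3), dtype=torch.uint8)  # H x W x C
    pixels = frame.permute(-1, -3, -2).float().div(255)  # C x H x W in [0, 1]
    pixels = pixels.mean(dim=-3, keepdim=True)           # grayscale: 1 x H x W
    pixels = F.interpolate(
        pixels.unsqueeze(0),  # interpolate expects a leading batch dim
        size=(64, 64),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    ).squeeze(0)
    assert pixels.shape == (1, 64, 64)
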
@@ -294,22 +377,44 @@ def _make_env(cfg, device, from_pixels=False):
     return env
 
 
-def transform_env(cfg, env):
+def transform_env(cfg, env, device=None):
+    """Apply transforms to environment.
+
+    Args:
+        cfg: Config object
+        env: The environment to transform
+        device: If specified and is a CUDA device, use GPU-accelerated image
+            processing which is ~50-100x faster than CPU.
+    """
     if not isinstance(env, TransformedEnv):
         env = TransformedEnv(env)
     if cfg.env.from_pixels:
-        # transforms pixel from 0-255 to 0-1 (uint8 to float32)
+        # Rename original pixels for processing
         env.append_transform(
             RenameTransform(in_keys=["pixels"], out_keys=["pixels_int"])
         )
-        env.append_transform(
-            ToTensorImage(from_int=True, in_keys=["pixels_int"], out_keys=["pixels"])
-        )
-        if cfg.env.grayscale:
-            env.append_transform(GrayScale())
 
-        image_size = cfg.env.image_size
-        env.append_transform(Resize(image_size, image_size))
+        # Use GPU-accelerated image processing if device is CUDA
+        if device is not None and str(device).startswith("cuda"):
+            env.append_transform(
+                GPUImageTransform(
+                    device=device,
+                    image_size=cfg.env.image_size,
+                    grayscale=cfg.env.grayscale,
+                    in_key="pixels_int",
+                    out_key="pixels",
+                )
+            )
+        else:
+            # CPU fallback: use standard transforms
+            env.append_transform(
+                ToTensorImage(
+                    from_int=True, in_keys=["pixels_int"], out_keys=["pixels"]
+                )
+            )
+            if cfg.env.grayscale:
+                env.append_transform(GrayScale())
+            env.append_transform(Resize(cfg.env.image_size, cfg.env.image_size))
 
     env.append_transform(DoubleToFloat())
     env.append_transform(RewardSum())
@@ -329,24 +434,38 @@ def make_environments(cfg, parallel_envs=1, logger=None):
     """
 
     def train_env_factory():
-        """Factory function for creating training environments."""
-        func = functools.partial(
-            _make_env, cfg=cfg, device=_default_device(cfg.env.device)
-        )
+        """Factory function for creating training environments.
+
+        Note: This factory runs inside collector worker processes. We use
+        CUDA if available for GPU-accelerated image transforms (ToTensorImage,
+        Resize) which are ~50-100x faster than CPU. The cfg.env.device setting
+        is ignored in favor of auto-detecting CUDA availability.
+        """
+        # Use CUDA for transforms if available, regardless of cfg.env.device
+        # This is critical: image transforms (Resize, ToTensorImage) are ~50-100x
+        # faster on GPU. DMControl/Gym render on CPU, but we move to GPU for transforms.
+        transform_device = _default_device(None)  # Returns CUDA if available
+        # Base env still uses cfg.env.device for compatibility
+        env_device = _default_device(cfg.env.device)
+        func = functools.partial(_make_env, cfg=cfg, device=env_device)
         train_env = ParallelEnv(
             parallel_envs,
             EnvCreator(func),
             serial_for_single=True,
         )
-        train_env = transform_env(cfg, train_env)
+        # Pass transform_device to enable GPU-accelerated image transforms
+        train_env = transform_env(cfg, train_env, device=transform_device)
        train_env.set_seed(cfg.env.seed)
        return train_env
 
     # Create eval env directly (not a factory)
+    # Use CUDA for transforms if available, regardless of cfg.env.device
+    transform_device = _default_device(None)  # Returns CUDA if available
+    env_device = _default_device(cfg.env.device)
     func = functools.partial(
         _make_env,
         cfg=cfg,
-        device=_default_device(cfg.env.device),
+        device=env_device,
         from_pixels=cfg.logger.video,
     )
     eval_env = ParallelEnv(
@@ -354,7 +473,8 @@ def make_environments(cfg, parallel_envs=1, logger=None):
         EnvCreator(func),
         serial_for_single=True,
     )
-    eval_env = transform_env(cfg, eval_env)
+    # Pass transform_device to enable GPU-accelerated image transforms
+    eval_env = transform_env(cfg, eval_env, device=transform_device)
     eval_env.set_seed(cfg.env.seed + 1)
     if cfg.logger.video:
         eval_env.insert_transform(
@@ -681,15 +801,32 @@ def make_storage_transform(
     pixel_obs=True,
     grayscale=True,
     image_size,
+    gpu_transforms=False,
 ):
     """Create transforms to be applied at extend-time (once per frame).
 
-    These heavy transforms (ToTensorImage, GrayScale, Resize) are applied once
-    when data is added to the buffer, rather than on every sample.
+    Args:
+        pixel_obs: Whether observations are pixel-based.
+        grayscale: Whether to convert to grayscale.
+        image_size: Target image size.
+        gpu_transforms: If True, skip heavy image transforms (ToTensorImage,
+            GrayScale, Resize) since they're already applied by GPUImageTransform
+            in the environment. Only ExcludeTransform is applied to filter keys.
     """
     if not pixel_obs:
         return None
 
+    # When GPU transforms are enabled, GPUImageTransform already processes
+    # pixels_int -> pixels with normalization, grayscale, and resize.
+    # We only need to filter out the intermediate pixels_int key.
+    if gpu_transforms:
+        storage_transforms = Compose(
+            # Just exclude pixels_int, keep everything else including processed pixels
+            ExcludeTransform("pixels_int", ("next", "pixels_int")),
+        )
+        return storage_transforms
+
+    # CPU fallback: apply heavy transforms at storage time
     storage_transforms = Compose(
         ExcludeTransform("pixels", ("next", "pixels"), inverse=True),
         ToTensorImage(
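
The gpu_transforms branch reduces the storage transform to key filtering only. A small sketch of what ExcludeTransform does to a tensordict (key names follow the diff; shapes are illustrative):

    import torch
    from tensordict import TensorDict
    from torchrl.envs import Compose, ExcludeTransform

    td = TensorDict(
        {
            "pixels": torch.rand(1, 64, 64),
            "pixels_int": torch.randint(0, 256, (84, 84, 3), dtype=torch.uint8),
            "next": {"pixels_int": torch.randint(0, 256, (84, 84, 3), dtype=torch.uint8)},
        },
        batch_size=[],
    )
    t = Compose(ExcludeTransform("pixels_int", ("next", "pixels_int")))
    out = t(td)
    assert "pixels_int" not in out.keys()  # processed "pixels" is kept
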
@@ -741,7 +878,6 @@ def make_replay_buffer(
     )
 
     replay_buffer = TensorDictReplayBuffer(
-        pin_memory=False,
         prefetch=prefetch,
         storage=LazyMemmapStorage(
             buffer_size,
@@ -755,7 +891,9 @@ def make_replay_buffer(
             strict_length=False,
             traj_key=("collector", "traj_ids"),
             cache_values=False,  # Disabled for async collection (cache not synced across processes)
-            # Don't compile the sampler - inductor has C++ codegen bugs for int64 ops
+            use_gpu=device.type == "cuda"
+            if device is not None
+            else False,  # Speed up trajectory computation on GPU
         ),
         transform=sample_transforms,
         batch_size=batch_size,
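
The use_gpu argument above lands on the buffer's sampler. Assuming that sampler is a SliceSampler (consistent with the traj_key/strict_length/cache_values arguments in this hunk), a hedged construction sketch:

    import torch
    from torchrl.data.replay_buffers import SliceSampler

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sampler = SliceSampler(
        slice_len=64,  # illustrative; the script's actual value is not shown here
        traj_key=("collector", "traj_ids"),
        strict_length=False,
        cache_values=False,
        use_gpu=device.type == "cuda",  # trajectory-start computation on GPU
    )
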
Binary file
@@ -585,6 +585,14 @@ class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
             ...     "actor": actor_weights,
             ...     "critic": critic_weights,
             ... })
+            >>>
+            >>> # Per-worker weight updates (for distinct policy factories)
+            >>> # Each worker can have independently updated weights
+            >>> collector.update_policy_weights_({
+            ...     0: worker_0_weights,
+            ...     1: worker_1_weights,
+            ...     2: worker_2_weights,
+            ... })
 
         Args:
             policy_or_weights: The weights to update with. Can be:
@@ -593,6 +601,8 @@ class BaseCollector(IterableDataset, metaclass=abc.ABCMeta):
             - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted
             - ``TensorDictBase``: A TensorDict containing weights
             - ``dict``: A regular dict containing weights
+            - ``dict[int, TensorDictBase]``: Per-worker weights where keys are worker indices.
+              This is used with distinct policy factories where each worker has independent weights.
             - ``None``: Will try to get weights from server using ``_get_server_weights()``
 
         Keyword Args:
@@ -429,16 +429,30 @@ class MultiCollector(BaseCollector, metaclass=_MultiCollectorMeta):
             raise TypeError(
                 "Cannot specify both weight_sync_schemes and weight_updater."
             )
+        # Check if policy_factory entries are all the same (replicated from single factory)
+        # vs different factories per worker.
+        has_uniform_policy_factory = any(policy_factory) and all(
+            f is policy_factory[0] for f in policy_factory
+        )
+        has_distinct_policy_factory = (
+            any(policy_factory) and not has_uniform_policy_factory
+        )
         if (
             weight_sync_schemes is not None
             and not weight_sync_schemes
             and weight_updater is None
-            and (isinstance(policy, nn.Module) or any(policy_factory))
         ):
-            # Set up a default local shared-memory sync scheme for the policy.
-            # This is used to propagate weights from the orchestrator policy
-            # (possibly combined with a policy_factory) down to worker policies.
-            weight_sync_schemes["policy"] = SharedMemWeightSyncScheme()
+            if isinstance(policy, nn.Module) or has_uniform_policy_factory:
+                # Set up a default local shared-memory sync scheme for the policy.
+                # This is used to propagate weights from the orchestrator policy
+                # (possibly combined with a policy_factory) down to worker policies.
+                weight_sync_schemes["policy"] = SharedMemWeightSyncScheme()
+            elif has_distinct_policy_factory:
+                # Distinct factories: set up per-worker weight sync scheme.
+                # Each worker maintains independent weights that can be updated individually.
+                weight_sync_schemes["policy"] = SharedMemWeightSyncScheme(
+                    per_worker_weights=True
+                )
 
         self._setup_multi_weight_sync(weight_updater, weight_sync_schemes)
 
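
The uniform-vs-distinct check above relies on object identity: replicating one factory across workers yields a list whose entries all pass an `is` comparison against the first element. A pure-Python illustration (the factories are stand-ins):

    def make_policy():
        return object()  # placeholder for a real policy constructor

    replicated = [make_policy] * 3                 # one factory, three workers
    distinct = [make_policy, make_policy, lambda: object()]

    def is_uniform(factories):
        return any(factories) and all(f is factories[0] for f in factories)

    assert is_uniform(replicated) is True
    assert is_uniform(distinct) is False
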
@@ -417,7 +417,8 @@ def _make_policy_factory(
         raise ValueError("policy cannot be used with policy_factory")
     elif has_policy_factory:
         if isinstance(policy_factory, Sequence):
-            return policy_factory
+            # Use worker_idx to get the correct factory for this worker
+            policy = policy_factory[worker_idx]()
         else:
             policy = policy_factory()
 
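
The one-line fix above matters because the pre-fix code handed the entire factory list back instead of building a policy. A tiny illustration of the corrected behavior (the factories and worker_idx are stand-ins):

    policy_factory = [lambda i=i: f"policy-{i}" for i in range(3)]

    def make_policy_for(worker_idx: int) -> str:
        # fixed behavior: index the sequence, then call that worker's factory
        return policy_factory[worker_idx]()

    assert make_policy_for(2) == "policy-2"
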
@@ -93,7 +93,8 @@ class Writer(ABC):
         )
         mesh = torch.stack(
             torch.meshgrid(
-                *(torch.arange(dim, device=device) for dim in self._storage.shape[1:])
+                *(torch.arange(dim, device=device) for dim in self._storage.shape[1:]),
+                indexing="ij",
             ),
             -1,
         ).flatten(0, -2)
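
torch.meshgrid warns when indexing is left implicit, and the "xy" mode would transpose the first two grids, so pinning indexing="ij" preserves the Writer's coordinate enumeration. A self-contained check with illustrative shapes:

    import torch

    rows, cols = torch.arange(2), torch.arange(3)
    gi, gj = torch.meshgrid(rows, cols, indexing="ij")  # matrix-style indexing
    assert gi.shape == (2, 3)
    # Stack and flatten to enumerate coordinates, mirroring the Writer code:
    coords = torch.stack((gi, gj), -1).flatten(0, -2)   # shape (6, 2)
    assert coords[:3].tolist() == [[0, 0], [0, 1], [0, 2]]
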
@@ -5601,7 +5601,7 @@ class Composite(TensorSpec):
         elif self.data_cls is not None:
             out = {}
         else:
-            out = TensorDict._new_unsafe({}, _size([]))
+            out = TensorDict._new_unsafe({}, self.shape)
         for key, item in vals.items():
             if item is None:
                 raise RuntimeError(
@@ -5644,7 +5644,7 @@ class Composite(TensorSpec):
         else:
 
             def empty(vals):
-                out = TensorDict._new_unsafe({}, _size([]))
+                out = TensorDict._new_unsafe({}, self.shape)
                 return vals, out
 
             funcs.append(empty)
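
Both hunks above replace an empty batch size with the spec's own shape, so tensordicts built from a batched Composite inherit its batch dimensions. A sketch through the public API (the private _new_unsafe call itself is internal):

    import torch
    from torchrl.data import Composite, Unbounded

    spec = Composite(
        {"obs": Unbounded(shape=(3, 4))},
        shape=(3,),  # batched spec: leading dim shared by all entries
    )
    td = spec.zero()  # materialize a zeroed tensordict from the spec
    assert td.batch_size == torch.Size([3])
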
@@ -96,6 +96,22 @@ class _dispatch_caller_parallel:
         # if the object returned is not a callable
         return iter(self.__call__())
 
+    def __getattr__(self, name):
+        """Support chained attribute access: env_parallel.a.b -> sends ('a', 'b') to workers."""
+        # Don't chain special/dunder methods - these are often called by
+        # display systems (e.g., Jupyter's _repr_html_) and shouldn't be
+        # dispatched to workers
+        if name.startswith("_"):
+            raise AttributeError(
+                f"Accessing private/special attribute {name!r} is not supported "
+                f"on dispatched parallel env attributes."
+            )
+        if isinstance(self.attr, tuple):
+            new_attr = self.attr + (name,)
+        else:
+            new_attr = (self.attr, name)
+        return _dispatch_caller_parallel(new_attr, self.parallel_env)
+
 
 class _dispatch_caller_serial:
     def __init__(self, list_callable: list[Callable, Any]):
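
The __getattr__ added above turns attribute chains into key-path tuples that are shipped to workers in one message. A pure-Python sketch of the pattern (Dispatcher and backend are illustrative, not torchrl API):

    class Dispatcher:
        def __init__(self, attr, backend):
            self.attr = attr
            self.backend = backend

        def __getattr__(self, name):
            if name.startswith("_"):
                # keep repr/display hooks from being dispatched to workers
                raise AttributeError(name)
            path = self.attr + (name,) if isinstance(self.attr, tuple) else (self.attr, name)
            return Dispatcher(path, self.backend)

        def __call__(self, *args, **kwargs):
            return self.backend(self.attr, *args, **kwargs)

    def backend(path, *args, **kwargs):
        return f"dispatched {path} with args={args}"

    d = Dispatcher("stats", backend)
    print(d.rewards.mean(0))  # dispatched ('stats', 'rewards', 'mean') with args=(0,)
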
@@ -7673,6 +7673,26 @@ class StepCounter(Transform):
             self._truncated_keys = truncated_keys
         return truncated_keys
 
+    @property
+    def all_truncated_keys(self) -> list[NestedKey]:
+        """Returns truncated keys for ALL reset keys (including nested ones).
+
+        Used for propagating truncated to nested agent-level keys in MARL envs.
+        """
+        all_truncated_keys = self.__dict__.get("_all_truncated_keys", None)
+        if all_truncated_keys is None:
+            all_truncated_keys = []
+            if self.parent is None:
+                return self.truncated_keys
+            for reset_key in self.parent.reset_keys:
+                if isinstance(reset_key, str):
+                    key = self.truncated_key
+                else:
+                    key = (*reset_key[:-1], self.truncated_key)
+                all_truncated_keys.append(key)
+            self.__dict__["_all_truncated_keys"] = all_truncated_keys
+        return all_truncated_keys
+
     @property
     def done_keys(self) -> list[NestedKey]:
         done_keys = self.__dict__.get("_done_keys", None)
@@ -7688,6 +7708,26 @@ class StepCounter(Transform):
         self.__dict__["_done_keys"] = done_keys
         return done_keys
 
+    @property
+    def all_done_keys(self) -> list[NestedKey]:
+        """Returns done keys for ALL reset keys (including nested ones).
+
+        Used for propagating done to nested agent-level keys in MARL envs.
+        """
+        all_done_keys = self.__dict__.get("_all_done_keys", None)
+        if all_done_keys is None:
+            all_done_keys = []
+            if self.parent is None:
+                return self.done_keys
+            for reset_key in self.parent.reset_keys:
+                if isinstance(reset_key, str):
+                    key = "done"
+                else:
+                    key = (*reset_key[:-1], "done")
+                all_done_keys.append(key)
+            self.__dict__["_all_done_keys"] = all_done_keys
+        return all_done_keys
+
     @property
     def terminated_keys(self) -> list[NestedKey]:
         terminated_keys = self.__dict__.get("_terminated_keys", None)
@@ -7803,8 +7843,59 @@ class StepCounter(Transform):
             done = truncated | done  # we assume no done after reset
             next_tensordict.set(done_key, done)
             next_tensordict.set(truncated_key, truncated)
+
+        # Propagate truncated/done to nested agent-level keys in MARL envs
+        # This ensures that when max_steps is reached, all agent truncated/done keys are updated
+        if self.max_steps is not None:
+            self._propagate_to_nested_keys(next_tensordict)
+
         return next_tensordict
 
+    def _propagate_to_nested_keys(self, next_tensordict: TensorDictBase) -> None:
+        """Propagate truncated and done values to nested agent-level keys.
+
+        In MARL envs, there may be nested agent-level truncated/done keys that
+        are children of the root truncated/done. When StepCounter sets truncated
+        at the root level, we need to propagate this to nested keys.
+        """
+        # Get the set of keys we already updated (filtered keys)
+        updated_truncated = set(self.truncated_keys)
+        updated_done = set(self.done_keys)
+
+        # Propagate truncated to nested keys
+        for nested_key in self.all_truncated_keys:
+            if nested_key in updated_truncated:
+                continue
+            # Find the parent truncated key that should be propagated
+            nested_truncated = next_tensordict.get(nested_key, None)
+            if nested_truncated is None:
+                continue
+            # Find a parent truncated key to propagate from
+            for parent_key in self.truncated_keys:
+                parent_truncated = next_tensordict.get(parent_key, None)
+                if parent_truncated is not None:
+                    # Expand parent truncated to match nested shape and apply OR
+                    expanded = parent_truncated.expand_as(nested_truncated)
+                    next_tensordict.set(nested_key, nested_truncated | expanded)
+                    break
+
+        # Propagate done to nested keys if update_done is True
+        if self.update_done:
+            for nested_key in self.all_done_keys:
+                if nested_key in updated_done:
+                    continue
+                nested_done = next_tensordict.get(nested_key, None)
+                if nested_done is None:
+                    continue
+                # Find a parent done key to propagate from
+                for parent_key in self.done_keys:
+                    parent_done = next_tensordict.get(parent_key, None)
+                    if parent_done is not None:
+                        # Expand parent done to match nested shape and apply OR
+                        expanded = parent_done.expand_as(nested_done)
+                        next_tensordict.set(nested_key, nested_done | expanded)
+                        break
+
     def transform_observation_spec(self, observation_spec: Composite) -> Composite:
         if not isinstance(observation_spec, Composite):
             raise ValueError(
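
The propagation above broadcasts a root-level flag onto each nested agent-level flag and ORs them together. A tensor-level sketch with illustrative shapes (batch=2, agents=3); note this sketch adds an explicit unsqueeze for the agent dimension, which expand_as requires:

    import torch

    root_truncated = torch.tensor([[True], [False]])           # (2, 1)
    agent_truncated = torch.zeros(2, 3, 1, dtype=torch.bool)   # (2, agents, 1)

    expanded = root_truncated.unsqueeze(1).expand_as(agent_truncated)
    agent_truncated = agent_truncated | expanded
    assert agent_truncated[0].all() and not agent_truncated[1].any()
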
torchrl/objectives/cql.py CHANGED
@@ -298,6 +298,7 @@ class CQLLoss(LossModule):
         lagrange_thresh: float = 0.0,
         reduction: str | None = None,
         deactivate_vmap: bool = False,
+        scalar_output_mode: str | None = None,
     ) -> None:
         self._out_keys = None
         if reduction is None:
@@ -381,6 +382,23 @@ class CQLLoss(LossModule):
         )
         self._make_vmap()
         self.reduction = reduction
+
+        # Handle scalar_output_mode for reduction="none"
+        if reduction == "none" and scalar_output_mode is None:
+            warnings.warn(
+                "CQLLoss with reduction='none' cannot include scalar values (alpha, entropy) "
+                "in the output TensorDict without changing their shape. These values will be "
+                "excluded from the output. You can access them via `loss_module._alpha` and "
+                "compute entropy from the log_prob in the actor loss metadata. "
+                "To suppress this warning, pass `scalar_output_mode='exclude'` to the constructor. "
+                "Alternatively, pass `scalar_output_mode='non_tensor'` to include them as non-tensor data. "
+                "This is a known limitation we're working on improving.",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            scalar_output_mode = "exclude"
+        self.scalar_output_mode = scalar_output_mode
+
         _ = self.target_entropy
 
     def _make_vmap(self):
@@ -548,18 +566,28 @@ class CQLLoss(LossModule):
         tensordict.set(
             self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
         )
+        entropy = -actor_metadata.get(self.tensor_keys.log_prob)
         out = {
             "loss_actor": loss_actor,
             "loss_actor_bc": loss_actor_bc,
             "loss_qvalue": q_loss,
             "loss_cql": cql_loss,
             "loss_alpha": loss_alpha,
-            "alpha": self._alpha,
-            "entropy": -actor_metadata.get(self.tensor_keys.log_prob).mean().detach(),
         }
         if self.with_lagrange:
             out["loss_alpha_prime"] = alpha_prime_loss.mean()
-        td_loss = TensorDict(out)
+
+        # Handle batch_size and scalar values (alpha, entropy) based on reduction mode
+        if self.reduction == "none":
+            batch_size = tensordict.batch_size
+            td_loss = TensorDict(out, batch_size=batch_size)
+            if self.scalar_output_mode == "non_tensor":
+                td_loss.set_non_tensor("alpha", self._alpha)
+                td_loss.set_non_tensor("entropy", entropy.detach().mean())
+        else:
+            out["alpha"] = self._alpha
+            out["entropy"] = entropy.detach().mean()
+            td_loss = TensorDict(out)
         self._clear_weakrefs(
             tensordict,
             td_loss,
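
The non_tensor mode above exists because a TensorDict with a real batch_size rejects 0-dim scalar entries; non-tensor storage side-steps the shape check. A minimal sketch with illustrative keys and values:

    import torch
    from tensordict import TensorDict

    td_loss = TensorDict(
        {"loss_actor": torch.randn(32)},  # per-sample losses, reduction="none"
        batch_size=[32],
    )
    td_loss.set_non_tensor("alpha", 0.2)  # scalar stored out of band
    assert td_loss.get_non_tensor("alpha") == 0.2
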