metacontroller_pytorch-0.0.24-py3-none-any.whl → metacontroller_pytorch-0.0.25-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- a/metacontroller/metacontroller.py
+++ b/metacontroller/metacontroller.py
@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import pad_at_dim
+from torch_einops_utils import pad_at_dim, lens_to_mask
 from torch_einops_utils.save_load import save_load
 
 # constants
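
The newly imported lens_to_mask is what builds the loss mask further down. Its actual implementation lives in torch_einops_utils and is not shown in this diff; a minimal sketch of the conventional lengths-to-boolean-mask behavior it presumably follows:

import torch

def lens_to_mask(lens, max_len):
    # a position is valid when its index is below that sequence's length
    positions = torch.arange(max_len, device = lens.device)
    return positions[None, :] < lens[:, None]

lens_to_mask(torch.tensor([2, 4]), max_len = 5)
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])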
@@ -335,6 +335,7 @@ class Transformer(Module):
         return_raw_action_dist = False,
         return_latents = False,
         return_cache = False,
+        episode_lens: Tensor | None = None
     ):
         device = state.device
 
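
forward now accepts an optional episode_lens tensor giving the true (pre-padding) length of each episode in the batch. A hypothetical call, with shapes and argument names inferred only from identifiers visible in this diff (the real signature may differ):

import torch

state = torch.randn(2, 10, 16)          # (batch, time, state dim) - shapes assumed
actions = torch.randint(0, 4, (2, 10))  # padded discrete actions - shape assumed
episode_lens = torch.tensor([10, 6])    # episode 2 carries 4 padded timesteps

# state_loss, action_loss = transformer(state, actions, episode_lens = episode_lens)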
@@ -362,6 +363,9 @@ class Transformer(Module):
         state, target_state = state[:, :-1], state[:, 1:]
         actions, target_actions = actions[:, :-1], actions[:, 1:]
 
+        if exists(episode_lens):
+            episode_lens = (episode_lens - 1).clamp(min = 0)
+
         # transformer lower body
 
         with lower_transformer_context():
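
The decrement mirrors the teacher-forcing shift just above: once state and actions are split into inputs ([:, :-1]) and next-step targets ([:, 1:]), an episode of length L yields only L - 1 prediction positions, and the clamp keeps zero-length episodes from going negative. A toy check of the arithmetic:

import torch

episode_lens = torch.tensor([5, 1, 0])
target_lens = (episode_lens - 1).clamp(min = 0)
# tensor([4, 0, 0]) - 5 steps give 4 next-step targets; empty episodes stay empty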
@@ -406,10 +410,14 @@ class Transformer(Module):
         # maybe return behavior cloning loss
 
         if behavioral_cloning:
+            loss_mask = None
+            if exists(episode_lens):
+                loss_mask = lens_to_mask(episode_lens, state.shape[1])
+
             state_dist_params = self.state_readout(attended)
-            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
+            state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
 
-            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
+            action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
 
             return state_clone_loss, action_clone_loss
 
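
Threading loss_mask through both calculate_loss calls keeps padded timesteps out of the behavior-cloning losses. The readouts' calculate_loss is defined in discrete_continuous_embed_readout and not shown here; a self-contained sketch of the standard masked-loss pattern it presumably follows, for a discrete readout:

import torch
import torch.nn.functional as F

def masked_clone_loss(logits, targets, mask = None):
    # per-position cross entropy, shape (batch, time)
    losses = F.cross_entropy(logits.transpose(1, 2), targets, reduction = 'none')

    if mask is None:
        return losses.mean()

    # average only over positions inside each episode
    return losses[mask].mean()

logits = torch.randn(2, 5, 4)   # (batch, time, num actions)
targets = torch.randint(0, 4, (2, 5))
mask = torch.arange(5)[None, :] < torch.tensor([2, 4])[:, None]
loss = masked_clone_loss(logits, targets, mask)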
--- a/metacontroller_pytorch-0.0.24.dist-info/METADATA
+++ b/metacontroller_pytorch-0.0.25.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.24
+Version: 0.0.25
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
--- /dev/null
+++ b/metacontroller_pytorch-0.0.25.dist-info/RECORD
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=icToDxXPknHG5C5hTzVaVOCibYbJ3aDLmZlaMc3Xge0,14275
+metacontroller_pytorch-0.0.25.dist-info/METADATA,sha256=HItPrlXUrJhZ1ZmpVU8JNftpyazBvJ3GVlOJPWL8NKE,4363
+metacontroller_pytorch-0.0.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.25.dist-info/RECORD,,

--- a/metacontroller_pytorch-0.0.24.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=Ocm_2hCBvV2coYg4tQ4kYd0LQHgHWiz1l-c9lR7Z_fM,13941
-metacontroller_pytorch-0.0.24.dist-info/METADATA,sha256=5xKHBecV3iRSK-JbNwOQ0iv6KSz_sIN--ar_M05-EWQ,4363
-metacontroller_pytorch-0.0.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.24.dist-info/RECORD,,