metacontroller-pytorch 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,8 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
26
26
 
27
27
  from assoc_scan import AssocScan
28
28
 
29
- from torch_einops_utils import pad_at_dim
29
+ from torch_einops_utils import pad_at_dim, lens_to_mask
30
+ from torch_einops_utils.save_load import save_load
30
31
 
31
32
  # constants
32
33
 
@@ -63,6 +64,7 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
63
64
  'switch_loss'
64
65
  ))
65
66
 
67
+ @save_load()
66
68
  class MetaController(Module):
67
69
  def __init__(
68
70
  self,
@@ -272,6 +274,7 @@ TransformerOutput = namedtuple('TransformerOutput', (
272
274
  'prev_hiddens'
273
275
  ))
274
276
 
277
+ @save_load()
275
278
  class Transformer(Module):
276
279
  def __init__(
277
280
  self,
@@ -332,6 +335,7 @@ class Transformer(Module):
332
335
  return_raw_action_dist = False,
333
336
  return_latents = False,
334
337
  return_cache = False,
338
+ episode_lens: Tensor | None = None
335
339
  ):
336
340
  device = state.device
337
341
 
@@ -359,6 +363,9 @@ class Transformer(Module):
359
363
  state, target_state = state[:, :-1], state[:, 1:]
360
364
  actions, target_actions = actions[:, :-1], actions[:, 1:]
361
365
 
366
+ if exists(episode_lens):
367
+ episode_lens = (episode_lens - 1).clamp(min = 0)
368
+
362
369
  # transformer lower body
363
370
 
364
371
  with lower_transformer_context():
@@ -403,10 +410,14 @@ class Transformer(Module):
403
410
  # maybe return behavior cloning loss
404
411
 
405
412
  if behavioral_cloning:
413
+ loss_mask = None
414
+ if exists(episode_lens):
415
+ loss_mask = lens_to_mask(episode_lens, state.shape[1])
416
+
406
417
  state_dist_params = self.state_readout(attended)
407
- state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
418
+ state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state, mask = loss_mask)
408
419
 
409
- action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
420
+ action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions, mask = loss_mask)
410
421
 
411
422
  return state_clone_loss, action_clone_loss
412
423
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.23
3
+ Version: 0.0.25
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,8 +39,8 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: loguru
42
- Requires-Dist: memmap-replay-buffer>=0.0.1
43
- Requires-Dist: torch-einops-utils>=0.0.7
42
+ Requires-Dist: memmap-replay-buffer>=0.0.23
43
+ Requires-Dist: torch-einops-utils>=0.0.16
44
44
  Requires-Dist: torch>=2.5
45
45
  Requires-Dist: x-evolution>=0.1.23
46
46
  Requires-Dist: x-mlps-pytorch
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=icToDxXPknHG5C5hTzVaVOCibYbJ3aDLmZlaMc3Xge0,14275
3
+ metacontroller_pytorch-0.0.25.dist-info/METADATA,sha256=HItPrlXUrJhZ1ZmpVU8JNftpyazBvJ3GVlOJPWL8NKE,4363
4
+ metacontroller_pytorch-0.0.25.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.25.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=mxXLeUTJ2AeYYUzpJQeE3TOrtZG_D45vfkb3jTNeEb8,13864
3
- metacontroller_pytorch-0.0.23.dist-info/METADATA,sha256=md2Ew7uonopeMK6YFkgfcXrs_3AoIQ_5gzu9Om3Fhuc,4361
4
- metacontroller_pytorch-0.0.23.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.23.dist-info/RECORD,,