metacontroller-pytorch 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,6 +26,8 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
26
26
 
27
27
  from assoc_scan import AssocScan
28
28
 
29
+ from torch_einops_utils import pad_at_dim
30
+
29
31
  # constants
30
32
 
31
33
  LinearNoBias = partial(Linear, bias = False)
@@ -46,17 +48,6 @@ def default(*args):
46
48
  return arg
47
49
  return None
48
50
 
49
- def is_empty(t):
50
- return t.numel() == 0
51
-
52
- def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
53
- if pad == (0, 0):
54
- return t
55
-
56
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
57
- zeros = ((0, 0) * dims_from_right)
58
- return F.pad(t, (*zeros, *pad), value = value)
59
-
60
51
  # tensor helpers
61
52
 
62
53
  def straight_through(src, tgt):
@@ -333,7 +324,7 @@ class Transformer(Module):
333
324
  def forward(
334
325
  self,
335
326
  state,
336
- action_ids: Tensor | None = None,
327
+ actions: Tensor | None = None,
337
328
  meta_controller: Module | None = None,
338
329
  cache: TransformerOutput | None = None,
339
330
  discovery_phase = False,
@@ -363,10 +354,10 @@ class Transformer(Module):
363
354
  # handle maybe behavioral cloning
364
355
 
365
356
  if behavioral_cloning or (meta_controlling and discovery_phase):
366
- assert not is_empty(action_ids), f'`action_ids` cannot be empty when doing discovery or behavioral cloning'
357
+ assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
367
358
 
368
359
  state, target_state = state[:, :-1], state[:, 1:]
369
- action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
360
+ actions, target_actions = actions[:, :-1], actions[:, 1:]
370
361
 
371
362
  # transformer lower body
372
363
 
@@ -376,8 +367,8 @@ class Transformer(Module):
376
367
 
377
368
  # handle no past action for first timestep
378
369
 
379
- if exists(action_ids):
380
- action_embed = self.action_embed(action_ids)
370
+ if exists(actions):
371
+ action_embed = self.action_embed(actions)
381
372
  else:
382
373
  action_embed = state_embed[:, 0:0] # empty action embed
383
374
 
@@ -415,13 +406,13 @@ class Transformer(Module):
415
406
  state_dist_params = self.state_readout(attended)
416
407
  state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
417
408
 
418
- action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
409
+ action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
419
410
 
420
411
  return state_clone_loss, action_clone_loss
421
412
 
422
413
  elif meta_controlling and discovery_phase:
423
414
 
424
- action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
415
+ action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
425
416
 
426
417
  return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
427
418
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.21
3
+ Version: 0.0.23
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -40,6 +40,7 @@ Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: loguru
42
42
  Requires-Dist: memmap-replay-buffer>=0.0.1
43
+ Requires-Dist: torch-einops-utils>=0.0.7
43
44
  Requires-Dist: torch>=2.5
44
45
  Requires-Dist: x-evolution>=0.1.23
45
46
  Requires-Dist: x-mlps-pytorch
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=mxXLeUTJ2AeYYUzpJQeE3TOrtZG_D45vfkb3jTNeEb8,13864
3
+ metacontroller_pytorch-0.0.23.dist-info/METADATA,sha256=md2Ew7uonopeMK6YFkgfcXrs_3AoIQ_5gzu9Om3Fhuc,4361
4
+ metacontroller_pytorch-0.0.23.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.23.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
3
- metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
4
- metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.21.dist-info/RECORD,,