metacontroller-pytorch 0.0.21__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,9 +46,6 @@ def default(*args):
46
46
  return arg
47
47
  return None
48
48
 
49
- def is_empty(t):
50
- return t.numel() == 0
51
-
52
49
  def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
53
50
  if pad == (0, 0):
54
51
  return t
@@ -333,7 +330,7 @@ class Transformer(Module):
333
330
  def forward(
334
331
  self,
335
332
  state,
336
- action_ids: Tensor | None = None,
333
+ actions: Tensor | None = None,
337
334
  meta_controller: Module | None = None,
338
335
  cache: TransformerOutput | None = None,
339
336
  discovery_phase = False,
@@ -363,10 +360,10 @@ class Transformer(Module):
363
360
  # handle maybe behavioral cloning
364
361
 
365
362
  if behavioral_cloning or (meta_controlling and discovery_phase):
366
- assert not is_empty(action_ids), f'`action_ids` cannot be empty when doing discovery or behavioral cloning'
363
+ assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
367
364
 
368
365
  state, target_state = state[:, :-1], state[:, 1:]
369
- action_ids, target_action_ids = action_ids[:, :-1], action_ids[:, 1:]
366
+ actions, target_actions = actions[:, :-1], actions[:, 1:]
370
367
 
371
368
  # transformer lower body
372
369
 
@@ -376,8 +373,8 @@ class Transformer(Module):
376
373
 
377
374
  # handle no past action for first timestep
378
375
 
379
- if exists(action_ids):
380
- action_embed = self.action_embed(action_ids)
376
+ if exists(actions):
377
+ action_embed = self.action_embed(actions)
381
378
  else:
382
379
  action_embed = state_embed[:, 0:0] # empty action embed
383
380
 
@@ -415,13 +412,13 @@ class Transformer(Module):
415
412
  state_dist_params = self.state_readout(attended)
416
413
  state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
417
414
 
418
- action_clone_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
415
+ action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
419
416
 
420
417
  return state_clone_loss, action_clone_loss
421
418
 
422
419
  elif meta_controlling and discovery_phase:
423
420
 
424
- action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
421
+ action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
425
422
 
426
423
  return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
427
424
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.21
3
+ Version: 0.0.22
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=kmfq4XSeh_tT30n3ctG1ggTgA0hVtV2eNKMqKr_jChc,14084
3
+ metacontroller_pytorch-0.0.22.dist-info/METADATA,sha256=zxsjz3enrR60enivkdcCYLd2Xjpg8d5EUydakRjYOAQ,4320
4
+ metacontroller_pytorch-0.0.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.22.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
3
- metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
4
- metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.21.dist-info/RECORD,,