metacontroller_pytorch-0.0.34-py3-none-any.whl → metacontroller_pytorch-0.0.36-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

--- a/metacontroller/metacontroller.py
+++ b/metacontroller/metacontroller.py
@@ -329,8 +329,15 @@ class MetaController(Module):
             sampled_latent_action[:, -1:]
         )
 
+        # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
+
+        if not self.switch_per_latent_dim:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
 
+MetaController.policy_loss = policy_loss
+
 # main transformer, which is subsumed into the environment after behavioral cloning
 
 Hiddens = namedtuple('Hiddens', (
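Two things happen in this hunk: `switch_beta` loses its trailing singleton axis when a single gate is shared across all latent dimensions, and the module-level `policy_loss` function is attached to `MetaController` after the class body. A minimal sketch of the einops squeeze, with made-up shapes:

```python
import torch
from einops import rearrange

# hypothetical (batch, seq, gates) tensor with a single shared gate
switch_beta = torch.rand(2, 16, 1)

# '... 1 -> ...' drops the trailing axis, but only if its size is exactly 1
switch_beta = rearrange(switch_beta, '... 1 -> ...')

print(switch_beta.shape)  # torch.Size([2, 16])
```

Unlike `Tensor.squeeze(-1)`, the einops pattern raises if the last axis is not singleton, so a per-latent-dim `switch_beta` can never be flattened by accident.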
@@ -409,17 +416,21 @@ class Transformer(Module):
     ):
         device = state.device
 
+        # meta controller is either given or already given at init
+
         meta_controller = default(meta_controller, self.meta_controller)
 
-        meta_controlling = exists(meta_controller)
+        has_meta_controller = exists(meta_controller)
+
+        assert not (discovery_phase and not has_meta_controller), 'meta controller must be made available during discovery phase'
 
-        behavioral_cloning = not meta_controlling and not return_raw_action_dist
+        behavioral_cloning = not has_meta_controller and not return_raw_action_dist
 
         # by default, if meta controller is passed in, transformer is no grad
 
-        lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
-        meta_controller_context = nullcontext if meta_controlling else torch.no_grad
-        upper_transformer_context = nullcontext if (not meta_controlling or discovery_phase) else torch.no_grad
+        lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
+        meta_controller_context = nullcontext if has_meta_controller else torch.no_grad
+        upper_transformer_context = nullcontext if (not has_meta_controller or discovery_phase) else torch.no_grad
 
         # handle cache
 
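The rename from `meta_controlling` to `has_meta_controller` feeds the same gradient-gating trick as before: each sub-module gets a context-manager class, either `nullcontext` or `torch.no_grad`, which is instantiated at the `with` site so whichever side is frozen runs without building an autograd graph. A short runnable sketch of the pattern:

```python
import torch
from contextlib import nullcontext

has_meta_controller = True

# store the context-manager class, instantiate at the `with` site
lower_transformer_context = nullcontext if not has_meta_controller else torch.no_grad
meta_controller_context = nullcontext if has_meta_controller else torch.no_grad

x = torch.randn(2, 8, requires_grad = True)

with lower_transformer_context():
    y = x * 2  # stands in for the lower transformer forward pass

print(y.requires_grad)  # False: no graph was built under no_grad
```

The new assert makes the implicit contract explicit: the discovery phase is meaningless without a meta controller to discover latent actions with.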
@@ -427,7 +438,8 @@ class Transformer(Module):
 
         # handle maybe behavioral cloning
 
-        if behavioral_cloning or (meta_controlling and discovery_phase):
+        if behavioral_cloning or discovery_phase: # during behavior cloning and discovery phase, the network is predicting / reconstructing the next token
+
             assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
 
             state, target_state = state[:, :-1], state[:, 1:]
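With the simplified condition, both behavioral cloning and the discovery phase take the next-token path, splitting the rollout into a one-step-shifted input/target pair. A small illustration with assumed `(batch, time, state_dim)` shapes:

```python
import torch

state = torch.randn(2, 10, 4)  # hypothetical rollout

# the model reads steps 0..T-2 and predicts / reconstructs steps 1..T-1
state_in, target_state = state[:, :-1], state[:, 1:]

assert state_in.shape[1] == target_state.shape[1] == state.shape[1] - 1
```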
@@ -490,7 +502,7 @@ class Transformer(Module):
 
             return state_clone_loss, action_clone_loss
 
-        elif meta_controlling and discovery_phase:
+        elif discovery_phase:
 
             action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
 
--- a/metacontroller/metacontroller_with_binary_mapper.py
+++ b/metacontroller/metacontroller_with_binary_mapper.py
@@ -28,7 +28,7 @@ from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
 
-from metacontroller.metacontroller import MetaControllerOutput
+from metacontroller.metacontroller import MetaControllerOutput, policy_loss
 
 # constants
 
@@ -170,7 +170,7 @@ class MetaControllerWithBinaryMapper(Module):
         action_log_probs = log_probs.gather(-1, codes)
         action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
 
-        return action_log_probs.sum(dim = -1)
+        return action_log_probs
 
     def forward(
         self,
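The behavioral change in this hunk is the dropped `.sum(dim = -1)`: `gather` picks out the log-probability of each sampled binary code, and 0.0.36 now returns those per-code values instead of reducing over the code axis, leaving any reduction to the caller (presumably the shared `policy_loss`). A sketch under assumed `(batch, seq, num_codes, 2)` shapes:

```python
import torch
from einops import rearrange

binary_logits = torch.randn(2, 16, 8, 2)    # hypothetical shapes
codes = torch.randint(0, 2, (2, 16, 8, 1))  # sampled code per position

log_probs = binary_logits.log_softmax(dim = -1)

# select the log-probability of the code that was actually sampled
action_log_probs = log_probs.gather(-1, codes)
action_log_probs = rearrange(action_log_probs, '... 1 -> ...')

print(action_log_probs.shape)  # torch.Size([2, 16, 8]) rather than (2, 16)
```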
@@ -296,4 +296,11 @@ class MetaControllerWithBinaryMapper(Module):
             sampled_codes[:, -1:]
         )
 
+        # squeeze out the last dimension of switch_beta if single gate for all codes
+
+        if not self.switch_per_code:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
+
+MetaControllerWithBinaryMapper.policy_loss = policy_loss
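Both controller variants now share a single `policy_loss`: the function is defined once in `metacontroller.metacontroller`, imported here, and assigned onto the class after its definition. Because its first parameter is `self`, it binds like an ordinary method. A toy illustration of the pattern, with stub names rather than the real loss:

```python
import torch

class ControllerStub:
    def __init__(self):
        self.scale = 2.0

# a free function whose first parameter is `self` behaves like a method
# once assigned onto the class
def policy_loss(self, advantages):
    return -(advantages * self.scale).mean()

ControllerStub.policy_loss = policy_loss

controller = ControllerStub()
print(controller.policy_loss(torch.tensor([1.0, -0.5])))  # tensor(-0.5000)
```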
--- a/metacontroller_pytorch-0.0.34.dist-info/METADATA
+++ b/metacontroller_pytorch-0.0.36.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.34
+Version: 0.0.36
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -39,7 +39,7 @@ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
 Requires-Dist: loguru
-Requires-Dist: memmap-replay-buffer>=0.0.23
+Requires-Dist: memmap-replay-buffer>=0.0.25
 Requires-Dist: torch-einops-utils>=0.0.19
 Requires-Dist: torch>=2.5
 Requires-Dist: vector-quantize-pytorch>=1.27.20
--- /dev/null
+++ b/metacontroller_pytorch-0.0.36.dist-info/RECORD
@@ -0,0 +1,8 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=rVYzBJ8jQx9tfkZ3B9NdxKTI7dyBxXtTl4kwfizYuis,16728
+metacontroller/metacontroller_with_binary_mapper.py,sha256=odZs49ZWY7_FfEweYkD0moX7Vn0jGd91FjFTxzjLyr8,9480
+metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
+metacontroller_pytorch-0.0.36.dist-info/METADATA,sha256=eLKG8B0gSZyIMkaLvjYE8SvWVN387BuuNFoOC_6lmT4,4747
+metacontroller_pytorch-0.0.36.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.36.dist-info/RECORD,,
--- a/metacontroller_pytorch-0.0.34.dist-info/RECORD
+++ /dev/null
@@ -1,8 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=Ii4Z2MuVMXJeWxAnZnRfSKrGv1t6_Y4R5BgvmtDrSW8,16203
-metacontroller/metacontroller_with_binary_mapper.py,sha256=PgXK7uk--gvZPk1h4WLdCBEA7m9Ji12xixa2wqLmsLY,9234
-metacontroller/transformer_with_resnet.py,sha256=R49ycusbq3kEX97WHZ41WY2ONc2mYPOuRUCmaFcBOEo,5546
-metacontroller_pytorch-0.0.34.dist-info/METADATA,sha256=xBFG34yRWTkcfJLFvC22jARLkR_W_6kcSsAV8r5UFWY,4747
-metacontroller_pytorch-0.0.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.34.dist-info/RECORD,,