metacontroller-pytorch 0.0.34__tar.gz → 0.0.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of metacontroller-pytorch has been flagged as potentially problematic; see the registry's advisory page for details.

Files changed (17):
  1. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/PKG-INFO +1 -1
  2. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/metacontroller/metacontroller.py +5 -0
  3. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/metacontroller/metacontroller_with_binary_mapper.py +5 -0
  4. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/pyproject.toml +1 -1
  5. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/tests/test_metacontroller.py +17 -20
  6. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/.github/workflows/python-publish.yml +0 -0
  7. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/.github/workflows/test.yml +0 -0
  8. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/.gitignore +0 -0
  9. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/LICENSE +0 -0
  10. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/README.md +0 -0
  11. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/fig1.png +0 -0
  12. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/gather_babyai_trajs.py +0 -0
  13. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/metacontroller/__init__.py +0 -0
  14. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/metacontroller/transformer_with_resnet.py +0 -0
  15. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/test_babyai_e2e.sh +0 -0
  16. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/train_babyai.py +0 -0
  17. {metacontroller_pytorch-0.0.34 → metacontroller_pytorch-0.0.35}/train_behavior_clone_babyai.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.34
3
+ Version: 0.0.35
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -329,6 +329,11 @@ class MetaController(Module):
329
329
  sampled_latent_action[:, -1:]
330
330
  )
331
331
 
332
+ # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
333
+
334
+ if not self.switch_per_latent_dim:
335
+ switch_beta = rearrange(switch_beta, '... 1 -> ...')
336
+
332
337
  return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
333
338
 
334
339
  # main transformer, which is subsumed into the environment after behavioral cloning
@@ -296,4 +296,9 @@ class MetaControllerWithBinaryMapper(Module):
296
296
  sampled_codes[:, -1:]
297
297
  )
298
298
 
299
+ # squeeze out the last dimension of switch_beta if single gate for all codes
300
+
301
+ if not self.switch_per_code:
302
+ switch_beta = rearrange(switch_beta, '... 1 -> ...')
303
+
299
304
  return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.34"
3
+ version = "0.0.35"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -2,6 +2,7 @@ import pytest
2
2
  param = pytest.mark.parametrize
3
3
 
4
4
  from pathlib import Path
5
+ from functools import partial
5
6
 
6
7
  import torch
7
8
  from torch import cat
@@ -116,11 +117,11 @@ def test_metacontroller(
116
117
 
117
118
  # accumulate across time for the episode data
118
119
 
119
- all_episodes.append(dict(
120
- states = cat(states, dim = 1),
121
- log_probs = cat(log_probs, dim = 1),
122
- switch_betas = cat(switch_betas, dim = 1),
123
- latent_actions = cat(latent_actions, dim = 1)
120
+ all_episodes.append((
121
+ cat(states, dim = 1),
122
+ cat(log_probs, dim = 1),
123
+ cat(switch_betas, dim = 1),
124
+ cat(latent_actions, dim = 1)
124
125
  ))
125
126
 
126
127
  all_rewards.append(torch.randn(1))
@@ -134,23 +135,19 @@ def test_metacontroller(
134
135
 
135
136
  # simulate a policy loss update over the entire group
136
137
 
137
- group_states = cat([e['states'] for e in all_episodes], dim = 0)
138
- group_log_probs = cat([e['log_probs'] for e in all_episodes], dim = 0)
139
- group_latent_actions = cat([e['latent_actions'] for e in all_episodes], dim = 0)
140
- group_switch_betas = cat([e['switch_betas'] for e in all_episodes], dim = 0)
138
+ group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
141
139
 
142
- if not use_binary_mapper_variant:
143
- loss = policy_loss(
144
- meta_controller,
145
- group_states,
146
- group_log_probs,
147
- group_latent_actions,
148
- advantages,
149
- group_switch_betas == 1.,
150
- episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
151
- )
140
+ loss = policy_loss(
141
+ meta_controller,
142
+ group_states,
143
+ group_log_probs,
144
+ group_latent_actions,
145
+ advantages,
146
+ group_switch_betas == 1.,
147
+ episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
148
+ )
152
149
 
153
- loss.backward()
150
+ loss.backward()
154
151
 
155
152
  # evolutionary strategies over grpo
156
153