metacontroller-pytorch 0.0.33__tar.gz → 0.0.35__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (17)
  1. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/PKG-INFO +1 -1
  2. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/metacontroller/metacontroller.py +6 -1
  3. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/metacontroller/metacontroller_with_binary_mapper.py +22 -5
  4. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/pyproject.toml +1 -1
  5. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/tests/test_metacontroller.py +17 -20
  6. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/.github/workflows/python-publish.yml +0 -0
  7. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/.github/workflows/test.yml +0 -0
  8. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/.gitignore +0 -0
  9. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/LICENSE +0 -0
  10. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/README.md +0 -0
  11. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/fig1.png +0 -0
  12. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/gather_babyai_trajs.py +0 -0
  13. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/metacontroller/__init__.py +0 -0
  14. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/metacontroller/transformer_with_resnet.py +0 -0
  15. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/test_babyai_e2e.sh +0 -0
  16. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/train_babyai.py +0 -0
  17. {metacontroller_pytorch-0.0.33 → metacontroller_pytorch-0.0.35}/train_behavior_clone_babyai.py +0 -0

PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.33
+Version: 0.0.35
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller

metacontroller/metacontroller.py

@@ -26,7 +26,7 @@ from discrete_continuous_embed_readout import Embed, Readout, EmbedAndReadout
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left
 from torch_einops_utils.save_load import save_load
 
 # constants

@@ -329,6 +329,11 @@ class MetaController(Module):
             sampled_latent_action[:, -1:]
         )
 
+        # squeeze out the last dimension of switch_beta if single gate for all latent dimensions
+
+        if not self.switch_per_latent_dim:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, action_dist, sampled_latent_action, switch_beta, kl_loss, switch_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
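
The squeeze added above (mirrored for the binary-mapper variant further down) drops a trailing singleton axis, so callers receive a gate of shape (batch, time) rather than (batch, time, 1) when a single switch is shared across all latent dimensions. A minimal standalone sketch of the einops pattern, with made-up shapes:

import torch
from einops import rearrange

# hypothetical gate: one switch shared across all latent dimensions,
# leaving a trailing singleton axis that downstream code should not carry
switch_beta = torch.rand(2, 16, 1)                    # (batch, time, 1)

switch_beta = rearrange(switch_beta, '... 1 -> ...')  # (batch, time)
assert switch_beta.shape == (2, 16)

With a gate per latent dimension (switch_per_latent_dim = True) the last axis is a full latent axis and is left untouched.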

metacontroller/metacontroller_with_binary_mapper.py

@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, masked_mean, align_dims_left, pad_right_ndim_to
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
 from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper

@@ -143,22 +143,34 @@ class MetaControllerWithBinaryMapper(Module):
             *self.proposer_to_binary_logits.parameters()
         ]
 
+    def get_action_dist_for_internal_rl(
+        self,
+        residual_stream
+    ):
+        meta_embed = self.model_to_meta(residual_stream)
+
+        proposed_action_hidden, _ = self.action_proposer(meta_embed)
+
+        return self.proposer_to_binary_logits(proposed_action_hidden)
+
     def log_prob(
         self,
         action_dist,
         sampled_latent_action
     ):
-        action_prob = action_dist.sigmoid()
-        probs = stack((action_prob, 1. - action_prob), dim = -1)
-        log_probs = log(probs)
+        log_probs = stack((
+            F.logsigmoid(action_dist),
+            F.logsigmoid(-action_dist)
+        ), dim = -1)
 
         indices = sampled_latent_action.argmax(dim = -1)
         codes = self.binary_mapper.codes[indices].long()
 
         codes = rearrange(codes, '... -> ... 1')
         action_log_probs = log_probs.gather(-1, codes)
+        action_log_probs = rearrange(action_log_probs, '... 1 -> ...')
 
-        return rearrange(action_log_probs, '... 1 -> ...')
+        return action_log_probs.sum(dim = -1)
 
     def forward(
         self,
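
The log_prob rewrite above swaps log(sigmoid(x)) for F.logsigmoid(x). The two are mathematically identical, but sigmoid underflows to 0 for very negative logits and log then yields -inf, which poisons the policy gradient; F.logsigmoid stays finite. The second stack entry relies on the identity 1 - sigmoid(x) = sigmoid(-x). A small sketch (not the package's code) demonstrating both points:

import torch
import torch.nn.functional as F

logits = torch.tensor([-200., 0., 200.])

naive = torch.log(logits.sigmoid())   # sigmoid(-200.) underflows to 0., log gives -inf
stable = F.logsigmoid(logits)         # stays finite: roughly [-200., -0.6931, 0.]

print(naive)   # tensor([  -inf, -0.6931,  0.0000])

# "off" branch: log(1 - sigmoid(x)) == logsigmoid(-x)
x = torch.randn(5)
assert torch.allclose(torch.log(1. - x.sigmoid()), F.logsigmoid(-x), atol = 1e-6)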

@@ -284,4 +296,9 @@ class MetaControllerWithBinaryMapper(Module):
             sampled_codes[:, -1:]
         )
 
+        # squeeze out the last dimension of switch_beta if single gate for all codes
+
+        if not self.switch_per_code:
+            switch_beta = rearrange(switch_beta, '... 1 -> ...')
+
         return control_signal, MetaControllerOutput(next_hiddens, residual_stream, binary_logits, sampled_codes, switch_beta, kl_loss, switch_loss)
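
With the rearrange and final sum in the log_prob hunk above, the method now returns the joint log-probability of the whole sampled code, treating its bits as independent Bernoullis. A generic cross-check of that arithmetic against torch.distributions; the shapes and the bit-to-index convention here (index 1 = "on") are assumptions for illustration, not taken from the package's code table:

import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli

logits = torch.randn(2, 4, 8)                      # hypothetical (batch, time, bits)
bits = torch.randint(0, 2, (2, 4, 8)).float()

# stacked per-bit log-probs, then gather by bit value
log_probs = torch.stack((F.logsigmoid(-logits), F.logsigmoid(logits)), dim = -1)
per_bit = log_probs.gather(-1, bits.long().unsqueeze(-1)).squeeze(-1)

# joint log-prob of the code = sum over independent bits
joint = per_bit.sum(dim = -1)

reference = Bernoulli(logits = logits).log_prob(bits).sum(dim = -1)
assert torch.allclose(joint, reference, atol = 1e-6)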

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.33"
+version = "0.0.35"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_metacontroller.py

@@ -2,6 +2,7 @@ import pytest
 param = pytest.mark.parametrize
 
 from pathlib import Path
+from functools import partial
 
 import torch
 from torch import cat

@@ -116,11 +117,11 @@ def test_metacontroller(
 
     # accumulate across time for the episode data
 
-    all_episodes.append(dict(
-        states = cat(states, dim = 1),
-        log_probs = cat(log_probs, dim = 1),
-        switch_betas = cat(switch_betas, dim = 1),
-        latent_actions = cat(latent_actions, dim = 1)
+    all_episodes.append((
+        cat(states, dim = 1),
+        cat(log_probs, dim = 1),
+        cat(switch_betas, dim = 1),
+        cat(latent_actions, dim = 1)
     ))
 
     all_rewards.append(torch.randn(1))

@@ -134,23 +135,19 @@ def test_metacontroller(
 
     # simulate a policy loss update over the entire group
 
-    group_states = cat([e['states'] for e in all_episodes], dim = 0)
-    group_log_probs = cat([e['log_probs'] for e in all_episodes], dim = 0)
-    group_latent_actions = cat([e['latent_actions'] for e in all_episodes], dim = 0)
-    group_switch_betas = cat([e['switch_betas'] for e in all_episodes], dim = 0)
+    group_states, group_log_probs, group_switch_betas, group_latent_actions = map(partial(cat, dim = 0), zip(*all_episodes))
 
-    if not use_binary_mapper_variant:
-        loss = policy_loss(
-            meta_controller,
-            group_states,
-            group_log_probs,
-            group_latent_actions,
-            advantages,
-            group_switch_betas == 1.,
-            episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
-        )
+    loss = policy_loss(
+        meta_controller,
+        group_states,
+        group_log_probs,
+        group_latent_actions,
+        advantages,
+        group_switch_betas == 1.,
+        episode_lens = episode_lens[:1].repeat(3) if exists(episode_lens) else None
+    )
 
-        loss.backward()
+    loss.backward()
 
 # evolutionary strategies over grpo
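
The test refactor above stores each episode as a plain tuple so that zip(*all_episodes) can regroup the fields across episodes, after which partial(cat, dim = 0) concatenates every field along the batch dimension in one pass. A standalone sketch with made-up shapes:

from functools import partial
import torch
from torch import cat

# three hypothetical episodes of (states, log_probs), each with batch dim 1
all_episodes = [
    (torch.randn(1, 4, 8), torch.randn(1, 4))
    for _ in range(3)
]

# zip(*...) : list of per-episode tuples -> one tuple per field
# partial(cat, dim = 0) then stacks each field across the group
states, log_probs = map(partial(cat, dim = 0), zip(*all_episodes))

assert states.shape == (3, 4, 8)
assert log_probs.shape == (3, 4)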