metacontroller-pytorch 0.0.40__tar.gz → 0.0.41__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/PKG-INFO +86 -1
  2. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/README.md +85 -0
  3. metacontroller_pytorch-0.0.41/metacontroller/__init__.py +1 -0
  4. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/pyproject.toml +1 -1
  5. metacontroller_pytorch-0.0.40/metacontroller/__init__.py +0 -1
  6. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/.github/workflows/python-publish.yml +0 -0
  7. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/.github/workflows/test.yml +0 -0
  8. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/.gitignore +0 -0
  9. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/LICENSE +0 -0
  10. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/babyai_env.py +0 -0
  11. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/fig1.png +0 -0
  12. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/gather_babyai_trajs.py +0 -0
  13. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller.py +0 -0
  14. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/metacontroller/metacontroller_with_binary_mapper.py +0 -0
  15. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/metacontroller/transformer_with_resnet.py +0 -0
  16. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/test_babyai_e2e.sh +0 -0
  17. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/tests/test_metacontroller.py +0 -0
  18. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/train_babyai.py +0 -0
  19. {metacontroller_pytorch-0.0.40 → metacontroller_pytorch-0.0.41}/train_behavior_clone_babyai.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.40
3
+ Version: 0.0.41
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -69,6 +69,91 @@ $ pip install metacontroller-pytorch
69
69
 
70
70
  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
71
71
 
72
+ ## Usage
73
+
74
+ ```python
75
+ import torch
76
+ from metacontroller import Transformer, MetaController
77
+
78
+ # 1. initialize model
79
+
80
+ model = Transformer(
81
+ dim = 512,
82
+ action_embed_readout = dict(num_discrete = 4),
83
+ state_embed_readout = dict(num_continuous = 384),
84
+ lower_body = dict(depth = 2),
85
+ upper_body = dict(depth = 2)
86
+ )
87
+
88
+ state = torch.randn(2, 128, 384)
89
+ actions = torch.randint(0, 4, (2, 128))
90
+
91
+ # 2. behavioral cloning (BC)
92
+
93
+ state_loss, action_loss = model(state, actions)
94
+ (state_loss + action_loss).backward()
95
+
96
+ # 3. discovery phase
97
+
98
+ meta_controller = MetaController(
99
+ dim_model = 512,
100
+ dim_meta_controller = 256,
101
+ dim_latent = 128
102
+ )
103
+
104
+ action_recon_loss, kl_loss, switch_loss = model(
105
+ state,
106
+ actions,
107
+ meta_controller = meta_controller,
108
+ discovery_phase = True
109
+ )
110
+
111
+ (action_recon_loss + kl_loss + switch_loss).backward()
112
+
113
+ # 4. internal rl phase (GRPO)
114
+
115
+ # ... collect trajectories ...
116
+
117
+ logits, cache = model(
118
+ one_state,
119
+ past_action_id,
120
+ meta_controller = meta_controller,
121
+ return_cache = True
122
+ )
123
+
124
+ meta_output = cache.prev_hiddens.meta_controller
125
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
126
+
127
+ # ... calculate advantages ...
128
+
129
+ loss = meta_controller.policy_loss(
130
+ group_states,
131
+ group_old_log_probs,
132
+ group_latent_actions,
133
+ group_advantages,
134
+ group_switch_betas
135
+ )
136
+
137
+ loss.backward()
138
+ ```
139
+
140
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
141
+
142
+ ```python
143
+ # 5. evolve (ES over GRPO)
144
+
145
+ model.meta_controller = meta_controller
146
+
147
+ def environment_callable(model):
148
+ # return a fitness score
149
+ return 1.0
150
+
151
+ model.evolve(
152
+ num_generations = 10,
153
+ environment = environment_callable
154
+ )
155
+ ```
156
+
72
157
  ## Citations
73
158
 
74
159
  ```bibtex
@@ -16,6 +16,91 @@ $ pip install metacontroller-pytorch
16
16
 
17
17
  - [Diego Calanzone](https://github.com/ddidacus) for proposing testing on BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
18
18
 
19
+ ## Usage
20
+
21
+ ```python
22
+ import torch
23
+ from metacontroller import Transformer, MetaController
24
+
25
+ # 1. initialize model
26
+
27
+ model = Transformer(
28
+ dim = 512,
29
+ action_embed_readout = dict(num_discrete = 4),
30
+ state_embed_readout = dict(num_continuous = 384),
31
+ lower_body = dict(depth = 2),
32
+ upper_body = dict(depth = 2)
33
+ )
34
+
35
+ state = torch.randn(2, 128, 384)
36
+ actions = torch.randint(0, 4, (2, 128))
37
+
38
+ # 2. behavioral cloning (BC)
39
+
40
+ state_loss, action_loss = model(state, actions)
41
+ (state_loss + action_loss).backward()
42
+
43
+ # 3. discovery phase
44
+
45
+ meta_controller = MetaController(
46
+ dim_model = 512,
47
+ dim_meta_controller = 256,
48
+ dim_latent = 128
49
+ )
50
+
51
+ action_recon_loss, kl_loss, switch_loss = model(
52
+ state,
53
+ actions,
54
+ meta_controller = meta_controller,
55
+ discovery_phase = True
56
+ )
57
+
58
+ (action_recon_loss + kl_loss + switch_loss).backward()
59
+
60
+ # 4. internal rl phase (GRPO)
61
+
62
+ # ... collect trajectories ...
63
+
64
+ logits, cache = model(
65
+ one_state,
66
+ past_action_id,
67
+ meta_controller = meta_controller,
68
+ return_cache = True
69
+ )
70
+
71
+ meta_output = cache.prev_hiddens.meta_controller
72
+ old_log_probs = meta_controller.log_prob(meta_output.action_dist, meta_output.actions)
73
+
74
+ # ... calculate advantages ...
75
+
76
+ loss = meta_controller.policy_loss(
77
+ group_states,
78
+ group_old_log_probs,
79
+ group_latent_actions,
80
+ group_advantages,
81
+ group_switch_betas
82
+ )
83
+
84
+ loss.backward()
85
+ ```
86
+
87
+ Or using [evolutionary strategies](https://arxiv.org/abs/2511.16652) for the last portion
88
+
89
+ ```python
90
+ # 5. evolve (ES over GRPO)
91
+
92
+ model.meta_controller = meta_controller
93
+
94
+ def environment_callable(model):
95
+ # return a fitness score
96
+ return 1.0
97
+
98
+ model.evolve(
99
+ num_generations = 10,
100
+ environment = environment_callable
101
+ )
102
+ ```
103
+
19
104
  ## Citations
20
105
 
21
106
  ```bibtex
@@ -0,0 +1 @@
1
+ from metacontroller.metacontroller import MetaController, Transformer
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "metacontroller-pytorch"
3
- version = "0.0.40"
3
+ version = "0.0.41"
4
4
  description = "Transformer Metacontroller"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -1 +0,0 @@
1
- from metacontroller.metacontroller import MetaController