metacontroller-pytorch 0.0.3__tar.gz → 0.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/metacontroller/metacontroller.py +37 -20
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/README.md +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/tests/test_metacontroller.py +0 -0
{metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/metacontroller/metacontroller.py
RENAMED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from functools import partial
|
|
3
|
+
from collections import namedtuple
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
from torch import nn, cat, stack, tensor
|
|
@@ -42,6 +43,11 @@ def default(*args):
|
|
|
42
43
|
return arg
|
|
43
44
|
return None
|
|
44
45
|
|
|
46
|
+
# tensor helpers
|
|
47
|
+
|
|
48
|
+
def straight_through(src, tgt):
|
|
49
|
+
return tgt + src - src.detach()
|
|
50
|
+
|
|
45
51
|
# meta controller
|
|
46
52
|
|
|
47
53
|
class MetaController(Module):
|
|
@@ -62,12 +68,12 @@ class MetaController(Module):
|
|
|
62
68
|
self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
|
|
63
69
|
|
|
64
70
|
self.emitter = GRU(dim_latent * 2, dim_latent * 2)
|
|
65
|
-
self.emitter_to_action_mean_log_var =
|
|
71
|
+
self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
|
|
66
72
|
|
|
67
73
|
# internal rl phase substitutes the acausal + emitter with a causal ssm
|
|
68
74
|
|
|
69
75
|
self.action_proposer = GRU(dim_latent, dim_latent)
|
|
70
|
-
self.action_proposer_mean_log_var =
|
|
76
|
+
self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
|
|
71
77
|
|
|
72
78
|
# switching unit
|
|
73
79
|
|
|
@@ -115,7 +121,8 @@ class MetaController(Module):
|
|
|
115
121
|
def forward(
|
|
116
122
|
self,
|
|
117
123
|
residual_stream,
|
|
118
|
-
discovery_phase = False
|
|
124
|
+
discovery_phase = False,
|
|
125
|
+
hard_switch = False
|
|
119
126
|
):
|
|
120
127
|
|
|
121
128
|
if discovery_phase:
|
|
@@ -123,41 +130,51 @@ class MetaController(Module):
|
|
|
123
130
|
temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
|
|
124
131
|
|
|
125
132
|
proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
|
|
126
|
-
|
|
133
|
+
readout = self.emitter_to_action_mean_log_var
|
|
127
134
|
|
|
128
135
|
else: # else internal rl phase
|
|
129
136
|
proposed_action_hidden, _ = self.action_proposer(residual_stream)
|
|
130
|
-
|
|
137
|
+
readout = self.action_proposer_mean_log_var
|
|
131
138
|
|
|
132
139
|
# sample from the gaussian as the action from the meta controller
|
|
133
140
|
|
|
134
|
-
|
|
141
|
+
action_dist = readout(proposed_action_hidden)
|
|
142
|
+
|
|
143
|
+
sampled_action = readout.sample(action_dist)
|
|
144
|
+
|
|
145
|
+
# switching unit timer
|
|
146
|
+
|
|
147
|
+
batch, _, dim = sampled_action.shape
|
|
135
148
|
|
|
136
|
-
|
|
137
|
-
|
|
149
|
+
switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
|
|
150
|
+
|
|
151
|
+
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
152
|
+
|
|
153
|
+
action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
|
|
154
|
+
switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
|
|
138
155
|
|
|
139
156
|
# need to encourage normal distribution
|
|
140
157
|
|
|
141
158
|
vae_kl_loss = self.zero
|
|
142
159
|
|
|
143
160
|
if discovery_phase:
|
|
161
|
+
mean, log_var = action_dist.unbind(dim = -1)
|
|
162
|
+
|
|
144
163
|
vae_kl_loss = (0.5 * (
|
|
145
164
|
log_var.exp()
|
|
146
165
|
+ mean.square()
|
|
147
166
|
- log_var
|
|
148
167
|
- 1.
|
|
149
|
-
)).sum(dim = -1)
|
|
168
|
+
)).sum(dim = -1)
|
|
150
169
|
|
|
151
|
-
|
|
170
|
+
vae_kl_loss = vae_kl_loss * switch_beta
|
|
171
|
+
vae_kl_loss = vae_kl_loss.mean()
|
|
152
172
|
|
|
153
|
-
|
|
173
|
+
# maybe hard switch, then use associative scan
|
|
154
174
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
action_intent_for_gating = rearrange(sampled_action_intents, 'b n d -> (b d) n')
|
|
160
|
-
switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
|
|
175
|
+
if hard_switch:
|
|
176
|
+
hard_switch = (switch_beta > 0.5).float()
|
|
177
|
+
switch_beta = straight_through(switch_beta, hard_switch)
|
|
161
178
|
|
|
162
179
|
forget = 1. - switch_beta
|
|
163
180
|
gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
|
|
@@ -177,7 +194,7 @@ class MetaController(Module):
|
|
|
177
194
|
|
|
178
195
|
modified_residual_stream = residual_stream + control_signal
|
|
179
196
|
|
|
180
|
-
return modified_residual_stream, vae_kl_loss
|
|
197
|
+
return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
|
|
181
198
|
|
|
182
199
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
183
200
|
|
|
@@ -249,9 +266,9 @@ class Transformer(Module):
|
|
|
249
266
|
# meta controller acts on residual stream here
|
|
250
267
|
|
|
251
268
|
if exists(meta_controller):
|
|
252
|
-
modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
|
|
269
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
|
|
253
270
|
else:
|
|
254
|
-
modified_residual_stream, vae_aux_loss = residual_stream, self.zero
|
|
271
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
|
|
255
272
|
|
|
256
273
|
# modified residual stream sent back
|
|
257
274
|
|
{metacontroller_pytorch-0.0.3 → metacontroller_pytorch-0.0.5}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|