metacontroller-pytorch 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +19 -12
- {metacontroller_pytorch-0.0.2.dist-info → metacontroller_pytorch-0.0.4.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.4.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.2.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.2.dist-info → metacontroller_pytorch-0.0.4.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.2.dist-info → metacontroller_pytorch-0.0.4.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from functools import partial
|
|
3
|
+
from collections import namedtuple
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
from torch import nn, cat, stack, tensor
|
|
@@ -62,12 +63,12 @@ class MetaController(Module):
|
|
|
62
63
|
self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
|
|
63
64
|
|
|
64
65
|
self.emitter = GRU(dim_latent * 2, dim_latent * 2)
|
|
65
|
-
self.emitter_to_action_mean_log_var =
|
|
66
|
+
self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
|
|
66
67
|
|
|
67
68
|
# internal rl phase substitutes the acausal + emitter with a causal ssm
|
|
68
69
|
|
|
69
70
|
self.action_proposer = GRU(dim_latent, dim_latent)
|
|
70
|
-
self.action_proposer_mean_log_var =
|
|
71
|
+
self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
|
|
71
72
|
|
|
72
73
|
# switching unit
|
|
73
74
|
|
|
@@ -123,24 +124,25 @@ class MetaController(Module):
|
|
|
123
124
|
temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
|
|
124
125
|
|
|
125
126
|
proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
|
|
126
|
-
|
|
127
|
+
readout = self.emitter_to_action_mean_log_var
|
|
127
128
|
|
|
128
129
|
else: # else internal rl phase
|
|
129
130
|
proposed_action_hidden, _ = self.action_proposer(residual_stream)
|
|
130
|
-
|
|
131
|
+
readout = self.action_proposer_mean_log_var
|
|
131
132
|
|
|
132
133
|
# sample from the gaussian as the action from the meta controller
|
|
133
134
|
|
|
134
|
-
|
|
135
|
+
action_dist = readout(proposed_action_hidden)
|
|
135
136
|
|
|
136
|
-
|
|
137
|
-
sampled_action_intents = mean + torch.randn_like(mean) * std
|
|
137
|
+
sampled_action = readout.sample(action_dist)
|
|
138
138
|
|
|
139
139
|
# need to encourage normal distribution
|
|
140
140
|
|
|
141
141
|
vae_kl_loss = self.zero
|
|
142
142
|
|
|
143
143
|
if discovery_phase:
|
|
144
|
+
mean, log_var = action_dist.unbind(dim = -1)
|
|
145
|
+
|
|
144
146
|
vae_kl_loss = (0.5 * (
|
|
145
147
|
log_var.exp()
|
|
146
148
|
+ mean.square()
|
|
@@ -150,13 +152,13 @@ class MetaController(Module):
|
|
|
150
152
|
|
|
151
153
|
# switching unit timer
|
|
152
154
|
|
|
153
|
-
batch, _, dim =
|
|
155
|
+
batch, _, dim = sampled_action.shape
|
|
154
156
|
|
|
155
157
|
switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
|
|
156
158
|
|
|
157
159
|
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
158
160
|
|
|
159
|
-
action_intent_for_gating = rearrange(
|
|
161
|
+
action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
|
|
160
162
|
switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
|
|
161
163
|
|
|
162
164
|
forget = 1. - switch_beta
|
|
@@ -177,7 +179,7 @@ class MetaController(Module):
|
|
|
177
179
|
|
|
178
180
|
modified_residual_stream = residual_stream + control_signal
|
|
179
181
|
|
|
180
|
-
return modified_residual_stream, vae_kl_loss
|
|
182
|
+
return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
|
|
181
183
|
|
|
182
184
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
183
185
|
|
|
@@ -215,6 +217,8 @@ class Transformer(Module):
|
|
|
215
217
|
|
|
216
218
|
self.meta_controller = meta_controller
|
|
217
219
|
|
|
220
|
+
self.register_buffer('zero', tensor(0.), persistent = False)
|
|
221
|
+
|
|
218
222
|
def evolve(
|
|
219
223
|
self,
|
|
220
224
|
environment,
|
|
@@ -238,7 +242,7 @@ class Transformer(Module):
|
|
|
238
242
|
discovery_phase = False,
|
|
239
243
|
return_latents = False
|
|
240
244
|
):
|
|
241
|
-
meta_controller = default(meta_controller, self.meta_controller
|
|
245
|
+
meta_controller = default(meta_controller, self.meta_controller)
|
|
242
246
|
|
|
243
247
|
embed = self.embed(ids)
|
|
244
248
|
|
|
@@ -246,7 +250,10 @@ class Transformer(Module):
|
|
|
246
250
|
|
|
247
251
|
# meta controller acts on residual stream here
|
|
248
252
|
|
|
249
|
-
|
|
253
|
+
if exists(meta_controller):
|
|
254
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
|
|
255
|
+
else:
|
|
256
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
|
|
250
257
|
|
|
251
258
|
# modified residual stream sent back
|
|
252
259
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=fGbNUdxTYGgHBdINJziUINDYmANkNLy0yiTIt4dycKM,8162
|
|
3
|
+
metacontroller_pytorch-0.0.4.dist-info/METADATA,sha256=wfeiKctuqzj_NlWq2Xg5hbgjs6bzMmgL-VdTCzgceS8,3706
|
|
4
|
+
metacontroller_pytorch-0.0.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.4.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=VkTl7xHwAW1RSkWNdaMpcJ9kSqcIME-5vK9otEzGocc,7899
|
|
3
|
-
metacontroller_pytorch-0.0.2.dist-info/METADATA,sha256=76qLimsKlE06NLg6XdyE_eKsmBzqxwjL-F0zW0XEoGk,3706
|
|
4
|
-
metacontroller_pytorch-0.0.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.2.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.2.dist-info → metacontroller_pytorch-0.0.4.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|