metacontroller-pytorch 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +30 -20
- {metacontroller_pytorch-0.0.17.dist-info → metacontroller_pytorch-0.0.19.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.19.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.17.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.17.dist-info → metacontroller_pytorch-0.0.19.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.17.dist-info → metacontroller_pytorch-0.0.19.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -64,8 +64,10 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
|
|
|
64
64
|
class MetaController(Module):
|
|
65
65
|
def __init__(
|
|
66
66
|
self,
|
|
67
|
-
|
|
67
|
+
dim_model,
|
|
68
68
|
*,
|
|
69
|
+
dim_meta_controller = 256,
|
|
70
|
+
dim_latent = 128,
|
|
69
71
|
switch_per_latent_dim = True,
|
|
70
72
|
decoder_expansion_factor = 2.,
|
|
71
73
|
decoder_depth = 1,
|
|
@@ -73,25 +75,30 @@ class MetaController(Module):
|
|
|
73
75
|
assoc_scan_kwargs: dict = dict()
|
|
74
76
|
):
|
|
75
77
|
super().__init__()
|
|
78
|
+
dim_meta = default(dim_meta_controller, dim_model)
|
|
79
|
+
|
|
80
|
+
# the linear that brings from model dimension
|
|
81
|
+
|
|
82
|
+
self.model_to_meta = Linear(dim_model, dim_meta)
|
|
76
83
|
|
|
77
84
|
# there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use a bidirectional GRU as placeholders
|
|
78
85
|
|
|
79
|
-
self.bidirectional_temporal_compressor = GRU(
|
|
86
|
+
self.bidirectional_temporal_compressor = GRU(dim_meta, dim_meta, bidirectional = True) # revisit naming
|
|
80
87
|
|
|
81
|
-
self.emitter = GRU(
|
|
82
|
-
self.emitter_to_action_mean_log_var = Readout(
|
|
88
|
+
self.emitter = GRU(dim_meta * 2, dim_meta * 2)
|
|
89
|
+
self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
|
|
83
90
|
|
|
84
91
|
# internal rl phase substitutes the acausal + emitter with a causal ssm
|
|
85
92
|
|
|
86
|
-
self.action_proposer = GRU(
|
|
87
|
-
self.action_proposer_mean_log_var = Readout(
|
|
93
|
+
self.action_proposer = GRU(dim_meta, dim_meta)
|
|
94
|
+
self.action_proposer_mean_log_var = Readout(dim_meta, num_continuous = dim_latent)
|
|
88
95
|
|
|
89
96
|
# switching unit
|
|
90
97
|
|
|
91
98
|
self.switch_per_latent_dim = switch_per_latent_dim
|
|
92
99
|
|
|
93
|
-
self.switching_unit = GRU(
|
|
94
|
-
self.to_switching_unit_beta = nn.Linear(
|
|
100
|
+
self.switching_unit = GRU(dim_meta, dim_meta)
|
|
101
|
+
self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
|
|
95
102
|
|
|
96
103
|
self.switch_gating = AssocScan(**assoc_scan_kwargs)
|
|
97
104
|
|
|
@@ -105,7 +112,7 @@ class MetaController(Module):
|
|
|
105
112
|
dim_in = dim_latent,
|
|
106
113
|
dim = dim_decoder_hidden,
|
|
107
114
|
depth = decoder_depth,
|
|
108
|
-
dim_out = 2 * hypernetwork_low_rank *
|
|
115
|
+
dim_out = 2 * hypernetwork_low_rank * dim_model
|
|
109
116
|
)
|
|
110
117
|
|
|
111
118
|
self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
|
|
@@ -114,6 +121,7 @@ class MetaController(Module):
|
|
|
114
121
|
|
|
115
122
|
def discovery_parameters(self):
|
|
116
123
|
return [
|
|
124
|
+
*self.model_to_meta.parameters(),
|
|
117
125
|
*self.bidirectional_temporal_compressor.parameters(),
|
|
118
126
|
*self.emitter.parameters(),
|
|
119
127
|
*self.emitter_to_action_mean_log_var.parameters(),
|
|
@@ -144,18 +152,20 @@ class MetaController(Module):
|
|
|
144
152
|
|
|
145
153
|
next_action_proposer_hidden = None
|
|
146
154
|
|
|
155
|
+
meta_embed = self.model_to_meta(residual_stream)
|
|
156
|
+
|
|
147
157
|
if discovery_phase:
|
|
148
158
|
logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
|
|
149
159
|
|
|
150
|
-
temporal_compressed, _ = self.bidirectional_temporal_compressor(
|
|
160
|
+
temporal_compressed, _ = self.bidirectional_temporal_compressor(meta_embed)
|
|
151
161
|
temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
|
|
152
162
|
|
|
153
|
-
proposed_action_hidden, _ = self.emitter(cat((temporal_compressed,
|
|
163
|
+
proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, meta_embed), dim = -1))
|
|
154
164
|
readout = self.emitter_to_action_mean_log_var
|
|
155
165
|
|
|
156
166
|
else: # else internal rl phase
|
|
157
167
|
|
|
158
|
-
proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(
|
|
168
|
+
proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
|
|
159
169
|
readout = self.action_proposer_mean_log_var
|
|
160
170
|
|
|
161
171
|
# sample from the gaussian as the action from the meta controller
|
|
@@ -168,7 +178,7 @@ class MetaController(Module):
|
|
|
168
178
|
|
|
169
179
|
batch, _, dim = sampled_action.shape
|
|
170
180
|
|
|
171
|
-
switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
|
|
181
|
+
switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)
|
|
172
182
|
|
|
173
183
|
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
174
184
|
|
|
@@ -213,9 +223,7 @@ class MetaController(Module):
|
|
|
213
223
|
|
|
214
224
|
# generating the residual stream controlling signal
|
|
215
225
|
|
|
216
|
-
control_signal = einsum(
|
|
217
|
-
|
|
218
|
-
modified_residual_stream = residual_stream + control_signal
|
|
226
|
+
control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
|
|
219
227
|
|
|
220
228
|
# returning
|
|
221
229
|
|
|
@@ -225,7 +233,7 @@ class MetaController(Module):
|
|
|
225
233
|
next_switch_gated_action
|
|
226
234
|
)
|
|
227
235
|
|
|
228
|
-
return
|
|
236
|
+
return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
|
|
229
237
|
|
|
230
238
|
# main transformer, which is subsumed into the environment after behavioral cloning
|
|
231
239
|
|
|
@@ -248,7 +256,7 @@ class Transformer(Module):
|
|
|
248
256
|
super().__init__()
|
|
249
257
|
|
|
250
258
|
if isinstance(lower_body, dict):
|
|
251
|
-
lower_body = Decoder(dim = dim, **lower_body)
|
|
259
|
+
lower_body = Decoder(dim = dim, pre_norm_has_final_norm = False, **lower_body)
|
|
252
260
|
|
|
253
261
|
if isinstance(upper_body, dict):
|
|
254
262
|
upper_body = Decoder(dim = dim, **upper_body)
|
|
@@ -334,9 +342,11 @@ class Transformer(Module):
|
|
|
334
342
|
with meta_controller_context():
|
|
335
343
|
|
|
336
344
|
if exists(meta_controller):
|
|
337
|
-
|
|
345
|
+
control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
|
|
338
346
|
else:
|
|
339
|
-
|
|
347
|
+
control_signal, next_meta_hiddens = self.zero, None
|
|
348
|
+
|
|
349
|
+
modified_residual_stream = residual_stream + control_signal
|
|
340
350
|
|
|
341
351
|
# modified residual stream sent back to transformer upper body
|
|
342
352
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=GTErzikqVd8XDY8pmDnY8t4uIjbGCUd1GZBJX13peo8,12339
|
|
3
|
+
metacontroller_pytorch-0.0.19.dist-info/METADATA,sha256=lX3L7J3CKoSyxvJniLdSJsCu0UMEbJTxQLEw6zzT7dY,3741
|
|
4
|
+
metacontroller_pytorch-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.19.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=blxDztbtXyP3cNbjnM3fEw_KZdLFJp_l1Sub6-7zIKg,12041
|
|
3
|
-
metacontroller_pytorch-0.0.17.dist-info/METADATA,sha256=_8hYYTO_ME23kgZXqSfhA1XXAA8W877F-AL8amA7LKM,3741
|
|
4
|
-
metacontroller_pytorch-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.17.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.17.dist-info → metacontroller_pytorch-0.0.19.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|