metacontroller-pytorch 0.0.9-py3-none-any.whl → 0.0.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
metacontroller/metacontroller.py

@@ -52,6 +52,13 @@ def straight_through(src, tgt):
 
 # meta controller
 
+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'actions',
+    'kl_loss'
+))
+
 class MetaController(Module):
     def __init__(
         self,
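The new MetaControllerOutput return type doubles as the recurrent cache: its prev_hiddens field is what the later hunks thread back in through the cache argument. A runnable sketch of just the namedtuple, copied from the addition above:

from collections import namedtuple

MetaControllerOutput = namedtuple('MetaControllerOutput', (
    'prev_hiddens',
    'action_dist',
    'actions',
    'kl_loss'
))

# named field access at call sites, while positional unpacking keeps working
out = MetaControllerOutput((None,) * 3, None, None, 0.)
assert out.kl_loss == 0. and out[-1] == 0.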
@@ -107,9 +114,9 @@ class MetaController(Module):
         return [
             *self.bidirectional_temporal_compressor.parameters(),
             *self.emitter.parameters(),
-            *self.emitter_to_action_mean_log_var.parameters()
+            *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
-            *self.switch_gating
+            *self.switch_gating.parameters()
         ]
 
     def internal_rl_parameters(self):
@@ -121,10 +128,19 @@ class MetaController(Module):
     def forward(
         self,
         residual_stream,
+        cache: MetaControllerOutput | None = None,
         discovery_phase = False,
         hard_switch = False
     ):
 
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
         if discovery_phase:
             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
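Note the first-call fallback: with no cache yet, all three previous hiddens come back as None so each recurrent sub-module starts fresh. The exists helper is not shown in this diff; presumably it is the usual null check from this codebase. A minimal sketch under that assumption:

def exists(v):
    # assumed definition, not part of this diff
    return v is not None

cache = None  # first forward call, nothing cached yet

prev_a, prev_b, prev_c = cache.prev_hiddens if exists(cache) else ((None,) * 3)
assert (prev_a, prev_b, prev_c) == (None, None, None)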
@@ -133,7 +149,8 @@ class MetaController(Module):
             readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
-            proposed_action_hidden, _ = self.action_proposer(residual_stream)
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
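Both branches end by picking a readout that projects hidden states to a Gaussian mean and log-variance, and the trailing comment says the action is sampled from that Gaussian. The sampling step itself sits between these hunks; what such a step typically looks like with the reparameterization trick (names illustrative, not from the diff):

import torch

def sample_gaussian(mean, log_var):
    # reparameterization trick: the sample stays differentiable w.r.t. mean and log_var
    std = (0.5 * log_var).exp()
    return mean + std * torch.randn_like(mean)

mean = torch.zeros(2, 8)
log_var = torch.zeros(2, 8)
sampled_action = sample_gaussian(mean, log_var)  # shape (2, 8)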
@@ -146,35 +163,37 @@ class MetaController(Module):
 
         batch, _, dim = sampled_action.shape
 
-        switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
         # need to encourage normal distribution
 
-        vae_kl_loss = self.zero
+        kl_loss = self.zero
 
         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
 
-            vae_kl_loss = (0.5 * (
+            kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
                 - log_var
                 - 1.
             ))
 
-            vae_kl_loss = vae_kl_loss * switch_beta
-            vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
+            kl_loss = kl_loss * switch_beta
+            kl_loss = kl_loss.sum(dim = -1).mean()
 
         # maybe hard switch, then use associative scan
 
         if hard_switch:
-            hard_switch = (switch_beta > 0.5).float()
-            switch_beta = straight_through(switch_beta, hard_switch)
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget)
+        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_action = gated_action[:, -1]
 
         # decoder
 
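Two things in this hunk deserve a note. First, the renamed kl_loss is the standard closed-form KL divergence between the predicted Gaussian and a unit Gaussian, weighted per position by switch_beta before reduction:

$$\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = \tfrac{1}{2}\big(\sigma^2 + \mu^2 - \log \sigma^2 - 1\big)$$

which, with log_var standing for $\log \sigma^2$, is exactly the log_var.exp() + mean.square() - log_var - 1. expression above. Second, the rename to hard_switch_beta fixes 0.0.9 silently overwriting the hard_switch flag with a tensor, and straight_through (its def appears in the first hunk header) is presumably the usual straight-through estimator, forwarding the hard binarized value while letting gradients flow through the soft one. A sketch assuming that implementation:

import torch

def straight_through(src, tgt):
    # assumed implementation, consistent with the call sites in this diff:
    # value of tgt on the forward pass, gradient of src on the backward pass
    return src + (tgt - src).detach()

switch_beta = torch.tensor([0.2, 0.7], requires_grad = True)
hard_switch_beta = (switch_beta > 0.5).float()
out = straight_through(switch_beta, hard_switch_beta)

assert torch.equal(out.detach(), hard_switch_beta)  # values are hard 0 / 1
out.sum().backward()
assert torch.equal(switch_beta.grad, torch.ones(2))  # gradient passed straight through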
@@ -189,10 +208,23 @@ class MetaController(Module):
 
         modified_residual_stream = residual_stream + control_signal
 
-        return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_action
+        )
+
+        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
+TransformerOutput = namedtuple('TransformerOutput', (
+    'residual_stream_latent',
+    'prev_hiddens'
+))
+
 class Transformer(Module):
     def __init__(
         self,
@@ -251,10 +283,12 @@ class Transformer(Module):
         self,
         ids,
         meta_controller: Module | None = None,
+        cache: TransformerOutput | None = None,
         discovery_phase = False,
-        return_latents = False,
         no_grad_transformer = None,
-        no_grad_meta_controller = None
+        no_grad_meta_controller = None,
+        return_latents = False,
+        return_cache = False
     ):
         meta_controller = default(meta_controller, self.meta_controller)
 
@@ -268,28 +302,32 @@ class Transformer(Module):
         transformer_context = torch.no_grad if no_grad_transformer else nullcontext
         meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
 
+        # handle cache
+
+        lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
         # transformer lower body
 
         with transformer_context():
 
             embed = self.embed(ids)
 
-            residual_stream = self.lower_body(embed)
+            residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)
 
         # meta controller acts on residual stream here
 
         with meta_controller_context():
 
             if exists(meta_controller):
-                modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
             else:
-                modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
+                modified_residual_stream, next_meta_hiddens = residual_stream, None
 
         # modified residual stream sent back to transformer upper body
 
         with transformer_context():
 
-            attended = self.upper_body(modified_residual_stream)
+            attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
 
         # head readout
 
@@ -297,7 +335,9 @@ class Transformer(Module):
 
         # returning
 
-        if not return_latents:
+        return_one = not (return_latents or return_cache)
+
+        if return_one:
             return dist_params
 
-        return dist_params, latents
+        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
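Taken together, the forward changes make both modules steppable: pass return_cache = True, then feed the returned TransformerOutput back in as cache on the next call. A hedged usage sketch; the constructor keywords and the id chunks are illustrative stand-ins, and only the cache / return_cache keywords come from this diff:

import torch
from metacontroller import Transformer  # assumed export

model = Transformer(num_tokens = 256, dim = 512)  # illustrative kwargs, constructor not shown here

cache = None
chunks = torch.randint(0, 256, (1, 64)).chunk(4, dim = -1)  # hypothetical id stream

for ids in chunks:
    dist_params, cache = model(ids, cache = cache, return_cache = True)
    # cache.prev_hiddens carries (lower body, meta controller, upper body) hiddens forward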
metacontroller_pytorch-0.0.9.dist-info/METADATA → metacontroller_pytorch-0.0.10.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.9
+Version: 0.0.10
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
metacontroller_pytorch-0.0.10.dist-info/RECORD (added)

@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=H-bZi70445-4JlhUFL8x_fgePY7bTxkDO4CCdItKao4,10642
+metacontroller_pytorch-0.0.10.dist-info/METADATA,sha256=AFk9SUK6TGSG1APtt51yiASCEWIOTIvzAhtJJnS-Dsc,3714
+metacontroller_pytorch-0.0.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.10.dist-info/RECORD,,
metacontroller_pytorch-0.0.9.dist-info/RECORD (removed)

@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=V2Nb7ByGj310CalTzho-grwNsoHMp55oN5spkedJihc,9189
-metacontroller_pytorch-0.0.9.dist-info/METADATA,sha256=BA4AHlFW8DsD_NPXNv8N8rmRPISZNTkcjvGautB7xJA,3713
-metacontroller_pytorch-0.0.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.9.dist-info/RECORD,,