metacontroller-pytorch 0.0.2.tar.gz → 0.0.4.tar.gz

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.2
+Version: 0.0.4
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from functools import partial
+from collections import namedtuple
 
 import torch
 from torch import nn, cat, stack, tensor
@@ -62,12 +63,12 @@ class MetaController(Module):
         self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
 
         self.emitter = GRU(dim_latent * 2, dim_latent * 2)
-        self.emitter_to_action_mean_log_var = LinearNoBias(dim_latent * 2, dim_latent * 2)
+        self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
 
         # internal rl phase substitutes the acausal + emitter with a causal ssm
 
         self.action_proposer = GRU(dim_latent, dim_latent)
-        self.action_proposer_mean_log_var = LinearNoBias(dim_latent, dim_latent * 2)
+        self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
 
         # switching unit
 
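The `Readout` module that replaces `LinearNoBias` is not defined anywhere in this diff. A minimal sketch consistent with its call sites — construction as `Readout(dim, num_continuous = ...)`, a forward pass returning an `action_dist` that is later unbound into `(mean, log_var)`, and a `.sample` method — might look like this; the internals are an assumption, only the interface is taken from the diff:

import torch
from torch.nn import Module, Linear

class Readout(Module):
    # hypothetical reconstruction, not part of the diff
    def __init__(self, dim, num_continuous):
        super().__init__()
        # project the hidden state to a mean and log variance per continuous dim
        self.to_mean_log_var = Linear(dim, num_continuous * 2, bias = False)

    def forward(self, hidden):
        mean, log_var = self.to_mean_log_var(hidden).chunk(2, dim = -1)
        # stacked so callers can recover them with .unbind(dim = -1)
        return torch.stack((mean, log_var), dim = -1)

    def sample(self, mean_log_var):
        mean, log_var = mean_log_var.unbind(dim = -1)
        std = (0.5 * log_var).exp()
        # reparameterized gaussian sample, matching the inlined code this version removes
        return mean + torch.randn_like(mean) * std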
@@ -123,24 +124,25 @@ class MetaController(Module):
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
 
             proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
-            proposed_action = self.emitter_to_action_mean_log_var(proposed_action_hidden)
+            readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
             proposed_action_hidden, _ = self.action_proposer(residual_stream)
-            proposed_action = self.action_proposer_mean_log_var(proposed_action_hidden)
+            readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
 
-        mean, log_var = proposed_action.chunk(2, dim = -1)
+        action_dist = readout(proposed_action_hidden)
 
-        std = (0.5 * log_var).exp()
-        sampled_action_intents = mean + torch.randn_like(mean) * std
+        sampled_action = readout.sample(action_dist)
 
         # need to encourage normal distribution
 
         vae_kl_loss = self.zero
 
         if discovery_phase:
+            mean, log_var = action_dist.unbind(dim = -1)
+
             vae_kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
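The hunk is cut off mid-expression by the diff context window. The visible terms match the standard closed-form KL divergence from the predicted Gaussian to a unit Gaussian, so the expression presumably completes along these lines (the trailing reduction is an assumption):

# completing the truncated expression with KL( N(mean, var) || N(0, 1) );
# the final .sum/.mean reduction is an assumption
vae_kl_loss = (0.5 * (
    log_var.exp()
    + mean.square()
    - 1.
    - log_var
)).sum(dim = -1).mean()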
@@ -150,13 +152,13 @@ class MetaController(Module):
 
         # switching unit timer
 
-        batch, _, dim = sampled_action_intents.shape
+        batch, _, dim = sampled_action.shape
 
         switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
-        action_intent_for_gating = rearrange(sampled_action_intents, 'b n d -> (b d) n')
+        action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
         switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
 
         forget = 1. - switch_beta
@@ -177,7 +179,7 @@ class MetaController(Module):
 
         modified_residual_stream = residual_stream + control_signal
 
-        return modified_residual_stream, vae_kl_loss
+        return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
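With the return signature widened, callers now receive the distribution parameters and the sampled action alongside the modified stream. A sketch of consuming the new interface (tensor shapes and setup are illustrative assumptions, not part of the diff):

# hypothetical driver code
import torch

residual_stream = torch.randn(2, 16, 512)   # (batch, seq, dim_latent) assumed

modified_stream, action_dist, sampled_action, vae_kl_loss = meta_controller(
    residual_stream,
    discovery_phase = True
)

# action_dist stacks (mean, log_var) along the last axis
mean, log_var = action_dist.unbind(dim = -1)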
@@ -215,6 +217,8 @@ class Transformer(Module):
 
         self.meta_controller = meta_controller
 
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     def evolve(
         self,
         environment,
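The non-persistent `zero` buffer gives the forward pass below a scalar tensor that already lives on the module's device, standing in for the auxiliary loss when no meta controller is attached; `persistent = False` keeps it out of the state dict.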
@@ -238,7 +242,7 @@ class Transformer(Module):
         discovery_phase = False,
         return_latents = False
     ):
-        meta_controller = default(meta_controller, self.meta_controller, Identity())
+        meta_controller = default(meta_controller, self.meta_controller)
 
         embed = self.embed(ids)
 
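`default` drops `Identity()` as a terminal fallback, so a missing meta controller now stays `None` and is handled explicitly by the `exists` check in the next hunk. Neither helper appears in this diff; they are presumably the usual two-line utilities (a hypothetical reconstruction):

def exists(v):
    # a value was provided
    return v is not None

def default(v, d):
    # fall back to d when v is None
    return v if exists(v) else d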
@@ -246,7 +250,10 @@ class Transformer(Module):
 
         # meta controller acts on residual stream here
 
-        modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+        if exists(meta_controller):
+            modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+        else:
+            modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
 
         # modified residual stream sent back
 
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.2"
+version = "0.0.4"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }