metacontroller-pytorch 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller_with_binary_mapper.py +266 -0
- {metacontroller_pytorch-0.0.26.dist-info → metacontroller_pytorch-0.0.27.dist-info}/METADATA +2 -1
- metacontroller_pytorch-0.0.27.dist-info/RECORD +7 -0
- metacontroller_pytorch-0.0.26.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.26.dist-info → metacontroller_pytorch-0.0.27.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.26.dist-info → metacontroller_pytorch-0.0.27.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller_with_binary_mapper.py
ADDED

@@ -0,0 +1,266 @@
+from __future__ import annotations
+from contextlib import nullcontext
+
+from functools import partial
+from collections import namedtuple
+from loguru import logger
+
+import torch
+from torch import nn, cat, stack, tensor, Tensor
+from torch.nn import Module, GRU, Linear, Identity
+import torch.nn.functional as F
+
+# einops
+
+import einx
+from einops import einsum, rearrange, repeat, reduce
+from einops.layers.torch import Rearrange
+
+# external modules
+
+from x_transformers import Encoder, Decoder
+from x_mlps_pytorch import Feedforwards
+
+from assoc_scan import AssocScan
+
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
+from torch_einops_utils.save_load import save_load
+
+from vector_quantize_pytorch import BinaryMapper
+
+# constants
+
+LinearNoBias = partial(Linear, bias = False)
+
+GRU = partial(GRU, batch_first = True)
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
+
+def straight_through(src, tgt):
+    return tgt + src - src.detach()
+
+# meta controller
+
+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'codes',
+    'kl_loss',
+    'switch_loss'
+))
+
+@save_load()
+class MetaControllerWithBinaryMapper(Module):
+    def __init__(
+        self,
+        dim_model,
+        *,
+        dim_meta_controller = 256,
+        dim_code_bits = 4,
+        switch_per_code = False,
+        decoder_expansion_factor = 2.,
+        decoder_depth = 1,
+        hypernetwork_low_rank = 16,
+        assoc_scan_kwargs: dict = dict(),
+        bidirectional_temporal_encoder_kwargs: dict = dict(
+            attn_dim_head = 32, heads = 8
+        ),
+        kl_loss_threshold = 0.
+    ):
+        super().__init__()
+        dim_meta = default(dim_meta_controller, dim_model)
+
+        self.model_to_meta = Linear(dim_model, dim_meta)
+
+        self.bidirectional_temporal_encoder = Encoder(dim = dim_meta, depth = 1, **bidirectional_temporal_encoder_kwargs)
+
+        self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+        self.emitter_to_binary_logits = Linear(dim_meta * 2, dim_code_bits)
+
+        self.action_proposer = GRU(dim_meta, dim_meta)
+        self.proposer_to_binary_logits = Linear(dim_meta, dim_code_bits)
+
+        # binary mapper
+        # proposed in https://arxiv.org/abs/2510.17558 as a more stable alternative to VAE by François Fleuret
+
+        self.binary_mapper = BinaryMapper(
+            bits = dim_code_bits,
+            kl_loss_threshold = kl_loss_threshold
+        )
+
+        self.dim_code_bits = dim_code_bits
+        self.num_codes = self.binary_mapper.num_codes
+
+        # switching unit
+
+        self.switch_per_code = switch_per_code
+
+        self.switching_unit = GRU(dim_meta + self.num_codes, dim_meta)
+        self.to_switching_unit_beta = nn.Linear(dim_meta, self.num_codes if switch_per_code else 1, bias = False)
+
+        self.switch_gating = AssocScan(**assoc_scan_kwargs)
+
+        # decoder
+
+        assert hypernetwork_low_rank < self.num_codes
+
+        dim_decoder_hidden = int(self.num_codes * decoder_expansion_factor)
+
+        self.decoder = Feedforwards(
+            dim_in = self.num_codes,
+            dim = dim_decoder_hidden,
+            depth = decoder_depth,
+            dim_out = 2 * hypernetwork_low_rank * dim_model
+        )
+
+        self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
+    def discovery_parameters(self):
+        return [
+            *self.model_to_meta.parameters(),
+            *self.bidirectional_temporal_encoder.parameters(),
+            *self.emitter.parameters(),
+            *self.emitter_to_binary_logits.parameters(),
+            *self.binary_mapper.parameters(),
+            *self.decoder.parameters(),
+            *self.switch_gating.parameters()
+        ]
+
+    def internal_rl_parameters(self):
+        return [
+            *self.action_proposer.parameters(),
+            *self.proposer_to_binary_logits.parameters()
+        ]
+
+    def forward(
+        self,
+        residual_stream,
+        cache: MetaControllerOutput | None = None,
+        discovery_phase = False,
+        hard_switch = False,
+        temperature = 1.,
+        episode_lens: Tensor | None = None
+    ):
+        device = residual_stream.device
+
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens, prev_sampled_code = cache.prev_hiddens if exists(cache) else ((None,) * 4)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
+        meta_embed = self.model_to_meta(residual_stream)
+
+        if discovery_phase:
+            mask = maybe(lens_to_mask)(episode_lens, meta_embed.shape[1])
+
+            encoded_temporal = self.bidirectional_temporal_encoder(meta_embed, mask = mask)
+
+            proposed_action_hidden, _ = self.emitter(cat((encoded_temporal, meta_embed), dim = -1))
+            to_logits = self.emitter_to_binary_logits
+
+        else: # else internal rl phase
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
+            to_logits = self.proposer_to_binary_logits
+
+        # sample from the binary mapper
+
+        binary_logits = to_logits(proposed_action_hidden)
+
+        one_hot, kl_loss = self.binary_mapper(
+            binary_logits,
+            temperature = temperature,
+            reduce_aux_kl_loss = False
+        )
+
+        # bottled action is now the one-hot sparse codes (with straight-through)
+
+        sampled_codes = one_hot
+
+        # switching unit timer
+
+        batch, seq_len, dim = sampled_codes.shape
+
+        if not exists(prev_sampled_code):
+            prev_sampled_code = torch.zeros(batch, 1, self.num_codes, device = device)
+
+        if discovery_phase:
+            z_prev = cat((prev_sampled_code, sampled_codes[:, :-1]), dim = 1)
+        else:
+            assert seq_len == 1, f'inference RL phase must be done one token at a time'
+            z_prev = prev_sampled_code
+
+        switch_input = torch.cat((meta_embed, z_prev), dim=-1)
+
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+            switch_input,
+            prev_switching_unit_gru_hidden
+        )
+
+        switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+        # losses
+
+        switch_loss = self.zero
+
+        if discovery_phase:
+            # weight unreduced kl loss by switch gates
+
+            weighted_kl_loss = kl_loss * switch_beta
+            kl_loss = weighted_kl_loss.sum(dim = -1).mean()
+
+            # encourage less switching
+
+            switch_loss = switch_beta.mean()
+        else:
+            kl_loss = self.zero
+
+        # maybe hard switch, then use associative scan
+
+        if hard_switch:
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
+
+        forget = 1. - switch_beta
+
+        # gated codes (or soft distribution)
+
+        gated_codes = self.switch_gating(switch_beta, sampled_codes * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_codes = gated_codes[:, -1]
+
+        # decoder
+
+        decoder_out = self.decoder(gated_codes)
+
+        w1, w2 = self.to_hyper_network_weights(decoder_out)
+        hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
+
+        # generating the residual stream controlling signal
+
+        control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
+
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_codes,
+            sampled_codes[:, -1:]
+        )
+
+        return control_signal, MetaControllerOutput(next_hiddens, binary_logits, sampled_codes, kl_loss, switch_loss)
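For orientation, here is a minimal usage sketch of the new module. It is not part of the package or its diff: the tensor shapes, hyperparameter values, and the pattern of feeding the returned MetaControllerOutput back in as cache during the step-wise internal-RL phase are inferred from the code above, and it assumes BinaryMapper with bits = 4 exposes 2**4 = 16 codes.

```python
# illustrative sketch only; shapes and hyperparameters are assumptions, not documented usage
import torch
from metacontroller.metacontroller_with_binary_mapper import MetaControllerWithBinaryMapper

meta = MetaControllerWithBinaryMapper(
    dim_model = 512,
    dim_code_bits = 4,          # assuming BinaryMapper yields 2**4 = 16 one-hot codes
    hypernetwork_low_rank = 8   # must stay below num_codes per the assert in __init__
)

residual_stream = torch.randn(2, 16, 512)   # (batch, seq, dim_model)

# discovery phase: full episode at once, bidirectional encoder + emitter path,
# returns the per-token control signal plus kl / switch auxiliary losses
control_signal, out = meta(residual_stream, discovery_phase = True)
(out.kl_loss + out.switch_loss).backward()

# internal RL phase: one token at a time, carrying the returned namedtuple forward as cache
cache = None
for t in range(residual_stream.shape[1]):
    control_signal, cache = meta(residual_stream[:, t:t + 1], cache = cache)
```

In both phases the gated codes pass through a small feedforward whose output is split into two low-rank factors; their outer product forms the per-token hypernetwork weight that is applied to the residual stream to produce the control signal.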
{metacontroller_pytorch-0.0.26.dist-info → metacontroller_pytorch-0.0.27.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.26
+Version: 0.0.27
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -42,6 +42,7 @@ Requires-Dist: loguru
 Requires-Dist: memmap-replay-buffer>=0.0.23
 Requires-Dist: torch-einops-utils>=0.0.16
 Requires-Dist: torch>=2.5
+Requires-Dist: vector-quantize-pytorch>=1.27.20
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
 Requires-Dist: x-transformers
metacontroller_pytorch-0.0.27.dist-info/RECORD
ADDED

@@ -0,0 +1,7 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
+metacontroller/metacontroller_with_binary_mapper.py,sha256=GCvyF-5XILiexKQKu26h8NroTyeS7ksS1Q02mN5EGVw,8014
+metacontroller_pytorch-0.0.27.dist-info/METADATA,sha256=X_cwahEVbf7nS7c7QEi1t-kBqezkjycP4fSFKL1D-rk,4411
+metacontroller_pytorch-0.0.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.27.dist-info/RECORD,,
metacontroller_pytorch-0.0.26.dist-info/RECORD
DELETED

@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
-metacontroller_pytorch-0.0.26.dist-info/METADATA,sha256=E00jJkfHS_wsEuh-a4iIo42fQQ1NhX7r-HuSWtyimUQ,4363
-metacontroller_pytorch-0.0.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.26.dist-info/RECORD,,
{metacontroller_pytorch-0.0.26.dist-info → metacontroller_pytorch-0.0.27.dist-info}/WHEEL
RENAMED

File without changes

{metacontroller_pytorch-0.0.26.dist-info → metacontroller_pytorch-0.0.27.dist-info}/licenses/LICENSE
RENAMED

File without changes