metacontroller-pytorch 0.0.27__py3-none-any.whl → 0.0.29__py3-none-any.whl

This diff shows the changes between package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in those public registries.
--- a/metacontroller/metacontroller_with_binary_mapper.py
+++ b/metacontroller/metacontroller_with_binary_mapper.py
@@ -23,7 +23,7 @@ from x_mlps_pytorch import Feedforwards
 
 from assoc_scan import AssocScan
 
-from torch_einops_utils import maybe, pad_at_dim, lens_to_mask
+from torch_einops_utils import maybe, pad_at_dim, lens_to_mask, align_dims_left
 
 from torch_einops_utils.save_load import save_load
 
 from vector_quantize_pytorch import BinaryMapper
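The newly imported `align_dims_left` is applied in the hunk that follows, where an unreduced KL loss carries one more trailing dimension than the per-token switch gates, so multiplying them directly would not broadcast. Below is a hedged sketch of the semantics the call site suggests, assuming the helper right-pads lower-rank tensors with singleton dimensions so their leading dimensions line up; the real implementation lives in `torch_einops_utils`, and the shapes used here are illustrative assumptions, not taken from the package.

```python
import torch

def align_dims_left(tensors):
    # illustrative stand-in for torch_einops_utils.align_dims_left (assumed
    # semantics): append trailing singleton dims so every tensor reaches the
    # max rank, letting the leading ("left") dimensions broadcast together
    max_ndim = max(t.ndim for t in tensors)
    return tuple(t.reshape(*t.shape, *((1,) * (max_ndim - t.ndim))) for t in tensors)

# hypothetical shapes: an unreduced kl loss per latent dim, a per-token gate
kl_loss = torch.randn(2, 16, 8)   # (batch, seq, latent dim)
switch_beta = torch.rand(2, 16)   # (batch, seq)

kl_loss, switch_beta = align_dims_left((kl_loss, switch_beta))
assert switch_beta.shape == (2, 16, 1)

weighted_kl_loss = kl_loss * switch_beta      # now broadcasts cleanly
loss = weighted_kl_loss.sum(dim = -1).mean()
```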
@@ -220,6 +220,8 @@ class MetaControllerWithBinaryMapper(Module):
         if discovery_phase:
             # weight unreduced kl loss by switch gates
 
+            kl_loss, switch_beta = align_dims_left((kl_loss, switch_beta))
+
             weighted_kl_loss = kl_loss * switch_beta
             kl_loss = weighted_kl_loss.sum(dim = -1).mean()
 
--- /dev/null
+++ b/metacontroller/metacontroller_with_resnet.py
@@ -0,0 +1,250 @@
+from typing import Any, List, Type, Union, Optional
+
+import torch
+from torch import Tensor
+from torch import nn
+from einops import rearrange
+from metacontroller.metacontroller import Transformer
+
+class TransformerWithResnetEncoder(Transformer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.resnet_dim = kwargs["state_embed_readout"]["num_continuous"]
+        self.visual_encoder = resnet18(out_dim=self.resnet_dim)
+
+    def visual_encode(self, x: torch.Tensor) -> torch.Tensor:
+        b, t = x.shape[:2]
+        x = rearrange(x, 'b t h w c -> (b t) c h w')
+        h = self.visual_encoder(x)
+        h = rearrange(h, '(b t) d -> b t d', b=b, t=t, d = self.resnet_dim)
+        return h
+
+# resnet components taken from https://github.com/Lornatang/ResNet-PyTorch
+
+class _BasicBlock(nn.Module):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_channels: int = 64,
+    ) -> None:
+        super(_BasicBlock, self).__init__()
+        self.stride = stride
+        self.downsample = downsample
+        self.groups = groups
+        self.base_channels = base_channels
+
+        self.conv1 = nn.Conv2d(in_channels, out_channels, (3, 3), (stride, stride), (1, 1), bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, (3, 3), (1, 1), (1, 1), bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out = torch.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class _Bottleneck(nn.Module):
+    expansion: int = 4
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_channels: int = 64,
+    ) -> None:
+        super(_Bottleneck, self).__init__()
+        self.stride = stride
+        self.downsample = downsample
+        self.groups = groups
+        self.base_channels = base_channels
+
+        channels = int(out_channels * (base_channels / 64.0)) * groups
+
+        self.conv1 = nn.Conv2d(in_channels, channels, (1, 1), (1, 1), (0, 0), bias=False)
+        self.bn1 = nn.BatchNorm2d(channels)
+        self.conv2 = nn.Conv2d(channels, channels, (3, 3), (stride, stride), (1, 1), groups=groups, bias=False)
+        self.bn2 = nn.BatchNorm2d(channels)
+        self.conv3 = nn.Conv2d(channels, int(out_channels * self.expansion), (1, 1), (1, 1), (0, 0), bias=False)
+        self.bn3 = nn.BatchNorm2d(int(out_channels * self.expansion))
+        self.relu = nn.ReLU(True)
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out = torch.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(
+        self,
+        arch_cfg: List[int],
+        block: Type[Union[_BasicBlock, _Bottleneck]],
+        groups: int = 1,
+        channels_per_group: int = 64,
+        out_dim: int = 1000,
+    ) -> None:
+        super(ResNet, self).__init__()
+        self.in_channels = 64
+        self.dilation = 1
+        self.groups = groups
+        self.base_channels = channels_per_group
+
+        self.conv1 = nn.Conv2d(3, self.in_channels, (7, 7), (2, 2), (3, 3), bias=False)
+        self.bn1 = nn.BatchNorm2d(self.in_channels)
+        self.relu = nn.ReLU(True)
+        self.maxpool = nn.MaxPool2d((3, 3), (2, 2), (1, 1))
+
+        self.layer1 = self._make_layer(arch_cfg[0], block, 64, 1)
+        self.layer2 = self._make_layer(arch_cfg[1], block, 128, 2)
+        self.layer3 = self._make_layer(arch_cfg[2], block, 256, 2)
+        self.layer4 = self._make_layer(arch_cfg[3], block, 512, 2)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+        self.fc = nn.Linear(512 * block.expansion, out_dim)
+
+        # Initialize neural network weights
+        self._initialize_weights()
+
+    def _make_layer(
+        self,
+        repeat_times: int,
+        block: Type[Union[_BasicBlock, _Bottleneck]],
+        channels: int,
+        stride: int = 1,
+    ) -> nn.Sequential:
+        downsample = None
+
+        if stride != 1 or self.in_channels != channels * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.in_channels, channels * block.expansion, (1, 1), (stride, stride), (0, 0), bias=False),
+                nn.BatchNorm2d(channels * block.expansion),
+            )
+
+        layers = [
+            block(
+                self.in_channels,
+                channels,
+                stride,
+                downsample,
+                self.groups,
+                self.base_channels
+            )
+        ]
+        self.in_channels = channels * block.expansion
+        for _ in range(1, repeat_times):
+            layers.append(
+                block(
+                    self.in_channels,
+                    channels,
+                    1,
+                    None,
+                    self.groups,
+                    self.base_channels,
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        out = self._forward_impl(x)
+
+        return out
+
+    # Support torch.script function
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.maxpool(out)
+
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+
+        out = self.avgpool(out)
+        out = torch.flatten(out, 1)
+        out = self.fc(out)
+
+        return out
+
+    def _initialize_weights(self) -> None:
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d):
+                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(module.weight, 1)
+                nn.init.constant_(module.bias, 0)
+
+
+def resnet18(**kwargs: Any) -> ResNet:
+    model = ResNet([2, 2, 2, 2], _BasicBlock, **kwargs)
+
+    return model
+
+
+def resnet34(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 4, 6, 3], _BasicBlock, **kwargs)
+
+    return model
+
+
+def resnet50(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 4, 6, 3], _Bottleneck, **kwargs)
+
+    return model
+
+
+def resnet101(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 4, 23, 3], _Bottleneck, **kwargs)
+
+    return model
+
+
+def resnet152(**kwargs: Any) -> ResNet:
+    model = ResNet([3, 8, 36, 3], _Bottleneck, **kwargs)
+
+    return model
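The new module's integration point with the rest of the package is `visual_encode`, which folds the time axis into the batch before running frames through the ResNet, then unfolds it afterward. Below is a hedged usage sketch of just that encoding path; the tensor sizes are made up for illustration, and constructing the full `TransformerWithResnetEncoder` is skipped since `Transformer`'s keyword arguments are defined in metacontroller.py rather than in this diff.

```python
import torch
from einops import rearrange
from metacontroller.metacontroller_with_resnet import resnet18

encoder = resnet18(out_dim = 512)  # out_dim becomes the per-frame embedding size

# hypothetical batch: 2 trajectories of 10 channels-last frames each
frames = torch.randn(2, 10, 224, 224, 3)

b, t = frames.shape[:2]
x = rearrange(frames, 'b t h w c -> (b t) c h w')   # fold time into batch
h = encoder(x)                                      # (b * t, 512)
h = rearrange(h, '(b t) d -> b t d', b = b, t = t)  # restore the time axis

assert h.shape == (2, 10, 512)
```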
--- a/metacontroller_pytorch-0.0.27.dist-info/METADATA
+++ b/metacontroller_pytorch-0.0.29.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.27
+Version: 0.0.29
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -67,6 +67,8 @@ $ pip install metacontroller-pytorch
 
 - [Pranoy](https://github.com/pranoyr) for submitting a pull request for fixing the previous latent action not being included in the inputs to the switching unit
 
+- [Diego Calanzone](https://github.com/ddidacus) for proposing testing on the BabyAI gridworld task, and submitting the [pull request](https://github.com/lucidrains/metacontroller/pull/3) for behavior cloning and discovery phase training for it!
+
 ## Citations
 
 ```bibtex
@@ -103,3 +105,5 @@ $ pip install metacontroller-pytorch
     url = {https://arxiv.org/abs/2510.17558},
 }
 ```
+
+*Life can only be understood backwards; but it must be lived forwards* - Søren Kierkegaard
--- /dev/null
+++ b/metacontroller_pytorch-0.0.29.dist-info/RECORD
@@ -0,0 +1,8 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
+metacontroller/metacontroller_with_binary_mapper.py,sha256=uUFCSIRq20TdctRd7O20A_I2SiB9AgYS6z5iQMFqf2Q,8107
+metacontroller/metacontroller_with_resnet.py,sha256=YKHcazRZrrRParHRH-H_EPvT1-55LHKAs5pM6gwuT20,7394
+metacontroller_pytorch-0.0.29.dist-info/METADATA,sha256=8zeOj2sUZ-5V_qGXvzXoBH3lpCJqHgPfZq0-YllSrTs,4747
+metacontroller_pytorch-0.0.29.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.29.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=LWEq069EnBP3Sr6FTiDtz0cM5SFFT1zl35WkU6_kWGA,14451
3
- metacontroller/metacontroller_with_binary_mapper.py,sha256=GCvyF-5XILiexKQKu26h8NroTyeS7ksS1Q02mN5EGVw,8014
4
- metacontroller_pytorch-0.0.27.dist-info/METADATA,sha256=X_cwahEVbf7nS7c7QEi1t-kBqezkjycP4fSFKL1D-rk,4411
5
- metacontroller_pytorch-0.0.27.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
6
- metacontroller_pytorch-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- metacontroller_pytorch-0.0.27.dist-info/RECORD,,