locoformer 0.0.6__tar.gz → 0.0.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {locoformer-0.0.6 → locoformer-0.0.17}/.gitignore +2 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/PKG-INFO +3 -2
- {locoformer-0.0.6 → locoformer-0.0.17}/README.md +1 -1
- {locoformer-0.0.6 → locoformer-0.0.17}/locoformer/locoformer.py +417 -19
- {locoformer-0.0.6 → locoformer-0.0.17}/pyproject.toml +2 -1
- locoformer-0.0.17/tests/test_locoformer.py +86 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/train.py +1 -1
- locoformer-0.0.17/train_gym.py +262 -0
- locoformer-0.0.6/tests/test_locoformer.py +0 -38
- {locoformer-0.0.6 → locoformer-0.0.17}/.github/workflows/python-publish.yml +0 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/.github/workflows/test.yml +0 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/LICENSE +0 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/data/README.md +0 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/data/enwik8.gz +0 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/fig3.png +0 -0
- {locoformer-0.0.6 → locoformer-0.0.17}/locoformer/__init__.py +0 -0
{locoformer-0.0.6 → locoformer-0.0.17}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.6
+Version: 0.0.17
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: rotary-embedding-torch
@@ -53,7 +54,7 @@ Description-Content-Type: text/markdown
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
{locoformer-0.0.6 → locoformer-0.0.17}/README.md

@@ -4,7 +4,7 @@
 
 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
 
-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
 
 ## Sponsors
 
{locoformer-0.0.6 → locoformer-0.0.17}/locoformer/locoformer.py

@@ -1,11 +1,24 @@
 from __future__ import annotations
 from functools import partial
 
+from pathlib import Path
+from contextlib import contextmanager
+from collections import namedtuple
+
+import numpy as np
+from numpy import ndarray
+from numpy.lib.format import open_memmap
+
+from beartype import beartype
+from beartype.door import is_bearable
+
 import torch
-from torch import nn, cat, stack, arange, is_tensor
+from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import Optimizer
 
 import einx
 from einops import rearrange, einsum
@@ -15,6 +28,8 @@ from rotary_embedding_torch import RotaryEmbedding
 
 from assoc_scan import AssocScan
 
+# constants
+
 LinearNoBias = partial(Linear, bias = False)
 
 # helper functions
@@ -31,20 +46,30 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0
 
+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
-def
-
-
-
-
+def pad_at_dim(
+    t,
+    pad: tuple[int, int],
+    dim = -1,
+    value = 0.
+):
+    if pad == (0, 0):
+        return t
 
-
-
-
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
 
-
+def calc_entropy(logits):
+    prob = logits.softmax(dim = -1)
+    return -(prob * log(prob)).sum(dim = -1)
 
 # generalized advantage estimate
 
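For orientation, a minimal sketch of what the two new helpers do, assuming the definitions above are importable from `locoformer.locoformer` (shapes chosen arbitrarily):

import torch
from locoformer.locoformer import pad_at_dim, calc_entropy

t = torch.randn(2, 3, 4)
padded = pad_at_dim(t, (0, 5), dim = 1)   # pad 5 zeros on the right of dim 1
assert padded.shape == (2, 8, 4)

logits = torch.randn(2, 8, 6)
entropy = calc_entropy(logits)            # entropy of softmax(logits) over the last dim
assert entropy.shape == (2, 8)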
@@ -52,7 +77,7 @@ def combine_kv_cache(cache1, cache2):
 def calc_gae(
     rewards,
     values,
-    masks,
+    masks = None,
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
@@ -63,6 +88,9 @@ def calc_gae(
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[..., :-1], values[..., 1:]
 
+    if not exists(masks):
+        masks = torch.ones_like(values)
+
     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks
 
@@ -72,7 +100,7 @@ def calc_gae(
 
     returns = gae + values
 
-    return returns
+    return gae, returns
 
 # transformer-xl mask w/ flex attn
 
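The signature change matters for callers: `masks` is now optional, and the function returns the advantages alongside the returns. A minimal sketch, assuming `(batch, time)` float tensors and the non-accelerated scan so it runs on CPU:

import torch
from locoformer.locoformer import calc_gae

rewards = torch.randn(1, 10)
values = torch.randn(1, 10)

advantages, returns = calc_gae(rewards, values, gamma = 0.99, lam = 0.95, use_accelerated = False)
assert advantages.shape == returns.shape == (1, 10)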
@@ -114,8 +142,8 @@ def create_xl_mask(
     # handle intra-episodic attention if needed
 
     if exists(episode_ids):
-        q_episode =
-        k_episode =
+        q_episode = episode_ids[b, q + offset]
+        k_episode = episode_ids[b, k]
 
         intra_episode_mask = q_episode == k_episode
         mask = mask & intra_episode_mask
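The two lines being filled in simply look up the episode id of the query and key positions and compare them, so intra-episode attention reduces to an equality mask. An illustrative sketch in plain torch (the `b`, `q`, `k`, `offset` variables above belong to the mask callback and are not reproduced here):

import torch

episode_ids = torch.tensor([0, 0, 0, 1, 1])                           # per-timestep episode membership
intra_episode_mask = episode_ids[:, None] == episode_ids[None, :]     # (5, 5) bool, True only within the same episode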
@@ -146,6 +174,217 @@ def create_sliding_mask(
     create_kwargs = dict(device = device) if exists(device) else dict()
     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
 
+# data
+
+def collate_var_time(data):
+
+    datum = first(data)
+    keys = datum.keys()
+
+    all_tensors = zip(*[datum.values() for datum in data])
+
+    collated_values = []
+
+    for key, tensors in zip(keys, all_tensors):
+
+        # the episode lens have zero dimension - think of a cleaner way to handle this later
+
+        if key != '_lens':
+
+            times = [t.shape[0] for t in tensors]
+            max_time = max(times)
+            tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
+
+        collated_values.append(stack(tensors))
+
+    return dict(zip(keys, collated_values))
+
+class ReplayDataset(Dataset):
+    def __init__(
+        self,
+        folder: str | Path,
+        fields: tuple[str, ...] | None = None
+    ):
+        if isinstance(folder, str):
+            folder = Path(folder)
+
+        episode_lens = folder / 'episode_lens.npy'
+        self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
+
+        # get indices of non-zero lengthed episodes
+
+        nonzero_episodes = self.episode_lens > 0
+        self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
+
+        # get all data files
+
+        filepaths = [*folder.glob('*.data.npy')]
+        assert len(filepaths) > 0
+
+        fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
+
+        fieldnames_from_files = set(fieldname_to_filepath.keys())
+
+        fields = default(fields, fieldnames_from_files)
+
+        self.memmaps = dict()
+
+        for field in fields:
+            assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
+
+            path = fieldname_to_filepath[field]
+
+            self.memmaps[field] = open_memmap(str(path), mode = 'r')
+
+    def __len__(self):
+        return len(self.indices)
+
+    def __getitem__(self, idx):
+        episode_index = self.indices[idx]
+
+        episode_len = self.episode_lens[episode_index]
+
+        data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
+
+        data['_lens'] = tensor(episode_len)
+
+        return data
+
+class ReplayBuffer:
+
+    @beartype
+    def __init__(
+        self,
+        folder: str | Path,
+        max_episodes: int,
+        max_timesteps: int,
+        fields: dict[
+            str,
+            str | tuple[str, int | tuple[int, ...]]
+        ]
+    ):
+
+        # folder for data
+
+        if not isinstance(folder, Path):
+            folder = Path(folder)
+            folder.mkdir(exist_ok = True)
+
+        self.folder = folder
+        assert folder.is_dir()
+
+        # keeping track of episode length
+
+        episode_lens = folder / 'episode_lens.npy'
+
+        self.episode_index = 0
+        self.timestep_index = 0
+
+        self.max_episodes = max_episodes
+        self.max_timesteps = max_timesteps
+
+        self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
+
+        # create the memmap for individual data tracks
+
+        self.shapes = dict()
+        self.dtypes = dict()
+        self.memmaps = dict()
+        self.fieldnames = set(fields.keys())
+
+        for field_name, field_info in fields.items():
+
+            # some flexibility
+
+            field_info = (field_info, ()) if isinstance(field_info, str) else field_info
+
+            dtype_str, shape = field_info
+            assert dtype_str in {'int', 'float', 'bool'}
+
+            dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
+
+            # memmap file
+
+            filepath = folder / f'{field_name}.data.npy'
+            memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
+
+            self.memmaps[field_name] = memmap
+            self.shapes[field_name] = shape
+            self.dtypes[field_name] = dtype
+
+        self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+    def reset_(self):
+        self.episode_lens[:] = 0
+        self.episode_index = 0
+        self.timestep_index = 0
+
+    def advance_episode(self):
+        self.episode_index = (self.episode_index + 1) % self.max_episodes
+        self.timestep_index = 0
+
+    def flush(self):
+        self.episode_lens[self.episode_index] = self.timestep_index
+
+        for memmap in self.memmaps.values():
+            memmap.flush()
+
+        self.episode_lens.flush()
+
+    @contextmanager
+    def one_episode(self):
+
+        yield
+
+        self.flush()
+        self.advance_episode()
+
+    @beartype
+    def store_datapoint(
+        self,
+        episode_index: int,
+        timestep_index: int,
+        name: str,
+        datapoint: Tensor | ndarray
+    ):
+        assert 0 <= episode_index < self.max_episodes
+        assert 0 <= timestep_index < self.max_timesteps
+
+        if is_tensor(datapoint):
+            datapoint = datapoint.detach().cpu().numpy()
+
+        assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
+
+        assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
+
+        self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
+
+    def store(
+        self,
+        **data
+    ):
+        assert is_bearable(data, dict[str, Tensor | ndarray])
+
+        assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+
+        for name, datapoint in data.items():
+
+            self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
+
+        self.timestep_index += 1
+
+        return self.memory_namedtuple(**data)
+
+    def dataset(self) -> Dataset:
+        self.flush()
+
+        return ReplayDataset(self.folder)
+
+    def dataloader(self, batch_size, **kwargs) -> DataLoader:
+        self.flush()
+
+        return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
 # transformer-xl with ppo
 
 class Attention(Module):
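A short usage sketch of the new memmap-backed replay buffer, mirroring the new test further down (the folder name and field layout here are arbitrary):

from torch import randn, randint, tensor
from locoformer.locoformer import ReplayBuffer

buffer = ReplayBuffer(
    './replay_data',                     # any writable folder
    max_episodes = 100,
    max_timesteps = 64,
    fields = dict(
        state = ('float', (8,)),         # (dtype, trailing shape) per timestep
        action = 'int',                  # scalar fields may pass just the dtype string
        reward = 'float',
        done = 'bool'
    )
)

with buffer.one_episode():               # flushes and advances the episode index on exit
    for _ in range(10):
        buffer.store(
            state = randn(8),
            action = randint(0, 4, ()),
            reward = randn(()),
            done = tensor(False)
        )

dl = buffer.dataloader(batch_size = 4)   # episodes of different lengths are padded by collate_var_time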
@@ -204,7 +443,6 @@ class Attention(Module):
         return_kv_cache = False,
     ):
         seq_len = tokens.shape[-2]
-        assert seq_len <= self.window_size
 
         device = tokens.device
 
@@ -365,7 +603,14 @@ class Locoformer(Module):
         embedder: Module,
         unembedder: Module,
         transformer: dict | TransformerXL,
-        value_network: Module | None = None
+        value_network: Module | None = None,
+        discount_factor = 0.999,
+        gae_lam = 0.95,
+        ppo_eps_clip = 0.2,
+        ppo_entropy_weight = 0.01,
+        ppo_value_clip = 0.4,
+        value_loss_weight = 0.5,
+        calc_gae_kwargs: dict = dict()
     ):
         super().__init__()
 
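Construction stays the same as before; the PPO hyperparameters are simply extra keyword arguments with defaults. A sketch matching the updated test (the `MLP` comes from `x_mlps_pytorch`, as in the tests):

from torch import nn
from x_mlps_pytorch import MLP
from locoformer.locoformer import Locoformer

model = Locoformer(
    embedder = nn.Embedding(256, 128),
    unembedder = nn.Linear(128, 256, bias = False),
    value_network = MLP(128, 32, 1),
    transformer = dict(dim = 128, depth = 1, window_size = 512),
    discount_factor = 0.99,     # PPO-related settings, all optional
    gae_lam = 0.95,
    ppo_eps_clip = 0.2,
    ppo_entropy_weight = 0.01,
    ppo_value_clip = 0.4,
    value_loss_weight = 0.5
)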
@@ -382,15 +627,160 @@ class Locoformer(Module):
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
 
+        # ppo related
+
+        self.discount_factor = discount_factor
+        self.gae_lam = gae_lam
+        self.ppo_eps_clip = ppo_eps_clip
+        self.ppo_entropy_weight = ppo_entropy_weight
+        self.ppo_value_clip = ppo_value_clip
+        self.value_loss_weight = value_loss_weight
+
+        self.calc_gae_kwargs = calc_gae_kwargs
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     @property
     def device(self):
         return next(self.parameters()).device
 
+    def actor_parameters(self):
+        return self.unembedder.parameters()
+
+    def critic_parameters(self):
+        if not exists(self.value_network):
+            return []
+
+        return self.value_network.parameters()
+
+    def ppo(
+        self,
+        state,
+        action,
+        old_action_log_prob,
+        reward,
+        old_value,
+        mask,
+        actor_optim: Optimizer | None = None,
+        critic_optim: Optimizer | None = None
+    ):
+        window_size = self.window_size
+        total_learnable_tokens = mask.sum().item()
+
+        windowed_tensors = [
+            t.split(window_size, dim = 1) for t in
+            (
+                state,
+                action,
+                old_action_log_prob,
+                reward,
+                old_value,
+                mask
+            )
+        ]
+
+        mean_actor_loss = self.zero.clone()
+        mean_critic_loss = self.zero.clone()
+
+        # learn across windows
+
+        cache = None
+
+        for (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask
+        ) in zip(*windowed_tensors):
+
+            (action_logits, value), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True)
+            entropy = calc_entropy(action_logits)
+
+            action = rearrange(action, 'b t -> b t 1')
+            log_prob = action_logits.gather(-1, action)
+            log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+            # update actor, classic clipped surrogate loss
+
+            eps_clip = self.ppo_eps_clip
+            ratio = (log_prob - old_action_log_prob).exp()
+
+            advantage, returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+            actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+            actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+            windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+            windowed_actor_loss.backward(retain_graph = True)
+
+            # update critic
+
+            value_loss = F.mse_loss(returns, value, reduction = 'none')
+
+            value_clip = self.ppo_value_clip
+            clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+            clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+
+            critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+            windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+            windowed_critic_loss.backward(retain_graph = True)
+
+            # accumulate
+
+            mean_actor_loss.add_(windowed_actor_loss)
+            mean_critic_loss.add_(windowed_critic_loss)
+
+        # optimizer update
+
+        if exists(actor_optim):
+            actor_optim.step()
+            actor_optim.zero_grad()
+
+        if exists(critic_optim):
+            critic_optim.step()
+            critic_optim.zero_grad()
+
+        # return losses for logging
+
+        return mean_actor_loss.detach(), mean_critic_loss.detach()
+
+    def wrap_env_functions(self, env):
+
+        def wrapped_reset(*args, **kwargs):
+            state, _ = env.reset(*args, **kwargs)
+
+            if isinstance(state, ndarray):
+                state = from_numpy(state)
+
+            return state, _
+
+        def wrapped_step(action, *args, **kwargs):
+            out = env.step(action.item(), *args, **kwargs)
+
+            def transform_output(el):
+                if isinstance(el, ndarray):
+                    return from_numpy(el)
+                elif isinstance(el, (int, bool, float)):
+                    return tensor(el)
+                else:
+                    return el
+
+            return tree_map(transform_output, out)
+
+        return wrapped_reset, wrapped_step
+
     def get_stateful_forward(
         self,
         initial_states: Tensor | None = None,
         inference_mode = False,
         has_batch_dim = False,
+        has_time_dim = False,
         **kwargs
     ):
         window_size = self.window_size
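Roughly how the new train_gym.py script drives these two methods; `env` is any Gymnasium-style environment and `batch` is one collated dict from the replay dataloader (both assumed here, not defined in this snippet):

from torch.optim import Adam

env_reset, env_step = model.wrap_env_functions(env)     # numpy in/out converted to tensors

actor_optim = Adam([*model.transformer.parameters(), *model.actor_parameters()], lr = 8e-4)
critic_optim = Adam([*model.transformer.parameters(), *model.critic_parameters()], lr = 8e-4)

actor_loss, critic_loss = model.ppo(
    state = batch['state'],                  # (batch, time, dim_state)
    action = batch['action'],                # (batch, time)
    old_action_log_prob = batch['action_log_prob'],
    reward = batch['reward'],
    old_value = batch['value'],
    mask = batch['learnable'],               # bool mask over learnable timesteps
    actor_optim = actor_optim,
    critic_optim = critic_optim
)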
@@ -400,11 +790,14 @@ class Locoformer(Module):
         def stateful_forward(state: Tensor, **override_kwargs):
             nonlocal cache
 
-            # handle no batch, for easier time rolling out against envs
+            # handle no batch or time, for easier time rolling out against envs
 
             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')
 
+            if not has_time_dim:
+                state = rearrange(state, '... d -> ... 1 d')
+
             # forwards
 
             out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
@@ -416,7 +809,10 @@
             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
                 cache = cache[..., -window_size:, :]
 
-            # maybe remove batch
+            # maybe remove batch or time
+
+            if not has_time_dim:
+                out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
 
             if not has_batch_dim:
                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
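With `has_time_dim` added, a rollout against a single environment can feed one unbatched state at a time, as train_gym.py does (the `model` and `state` names are assumed here):

stateful_forward = model.get_stateful_forward(
    has_batch_dim = False,
    has_time_dim = False,
    inference_mode = True
)

action_logits, value = stateful_forward(state, return_values = True)   # state: (dim_state,) -> logits: (num_actions,)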
@@ -450,6 +846,8 @@
         return_values = False
     ):
 
+        state = state.to(self.device)
+
         tokens = self.embedder(state)
 
         embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
@@ -463,7 +861,7 @@
         # maybe detach cache
 
         if detach_cache:
-            kv_cache =
+            kv_cache = kv_cache.detach()
 
         # handle returning of values
 
{locoformer-0.0.6 → locoformer-0.0.17}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "locoformer"
-version = "0.0.6"
+version = "0.0.17"
 description = "LocoFormer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,6 +27,7 @@ classifiers=[
 
 dependencies = [
     "assoc-scan",
+    "beartype",
     "einx>=0.3.0",
     "einops>=0.8.0",
    "rotary-embedding-torch",
locoformer-0.0.17/tests/test_locoformer.py (new file)

@@ -0,0 +1,86 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from x_mlps_pytorch import MLP
+
+from einops import rearrange
+
+def test_locoformer():
+    from locoformer.locoformer import Locoformer
+    from torch import nn
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 32, 1),
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512
+        )
+    )
+
+    seq = torch.randint(0, 256, (3, 512))
+
+    (logits, values), cache = model(seq, return_values = True)
+    (logits, values), cache = model(seq, return_values = True, cache = cache)
+    (logits, values), cache = model(seq, return_values = True, cache = cache)
+
+    assert logits.shape == (3, 512, 256)
+
+    stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
+
+    for state in seq.unbind(dim = -1):
+        state = rearrange(state, 'b -> b 1')
+
+        logits, values = stateful_forward(state)
+        assert logits.shape == (3, 1, 256)
+
+def test_replay():
+    from locoformer.locoformer import ReplayBuffer
+
+    replay_buffer = ReplayBuffer(
+        './replay_data',
+        max_episodes = 10_000,
+        max_timesteps = 501,
+        fields = dict(
+            state = ('float', (8,)),
+            action = 'int',
+            action_log_prob = 'float',
+            reward = 'float',
+            value = 'float',
+            done = 'bool'
+        )
+    )
+
+    lens = [3, 5, 4]
+
+    for episode_len in lens:
+        with replay_buffer.one_episode():
+            for _ in range(episode_len):
+                state = torch.randn((8,))
+                action = torch.randint(0, 4, ())
+                log_prob = torch.randn(())
+                reward = torch.randn(())
+                value = torch.randn(())
+                done = torch.randint(0, 2, ()).bool()
+
+                replay_buffer.store(
+                    state = state,
+                    action = action,
+                    action_log_prob = log_prob,
+                    reward = reward,
+                    value = value,
+                    done = done
+                )
+
+    dataset = replay_buffer.dataset()
+
+    assert len(dataset) == 3
+
+    assert torch.is_tensor(dataset[0]['state'])
+
+    dataloader = replay_buffer.dataloader(batch_size = 3)
+
+    assert next(iter(dataloader))['state'].shape[0] == 3
{locoformer-0.0.6 → locoformer-0.0.17}/train.py

@@ -169,7 +169,7 @@ for i in range(NUM_BATCHES):
     prime = prime.to(model.device)
     out = prime
 
-    stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
+    stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, has_time_dim = True, initial_states = prime, inference_mode = True)
 
     # sample
 
locoformer-0.0.17/train_gym.py (new file)

@@ -0,0 +1,262 @@
+# /// script
+# dependencies = [
+#   "accelerate",
+#   "fire",
+#   "gymnasium[box2d]>=1.0.0",
+#   "locoformer>=0.0.12",
+#   "moviepy",
+#   "tqdm"
+# ]
+# ///
+
+from fire import Fire
+from shutil import rmtree
+from tqdm import tqdm
+from collections import deque
+from types import SimpleNamespace
+
+from accelerate import Accelerator
+
+import gymnasium as gym
+
+import torch
+from torch import from_numpy, randint, tensor, stack, arange
+import torch.nn.functional as F
+from torch.utils.data import TensorDataset, DataLoader
+from torch.optim import Adam
+
+import einx
+from einops import rearrange
+
+from locoformer.locoformer import Locoformer, ReplayBuffer
+from x_mlps_pytorch import MLP
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
+def gumbel_noise(t):
+    return -log(-log(torch.rand_like(t)))
+
+def gumbel_sample(logits, temperature = 1., eps = 1e-6):
+    noise = gumbel_noise(logits)
+    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
+
+# learn
+
+def learn(
+    model,
+    actor_optim,
+    critic_optim,
+    accelerator,
+    replay,
+    batch_size = 16,
+    epochs = 2,
+):
+    device = accelerator.device
+
+    dl = replay.dataloader(batch_size = batch_size, shuffle = True)
+    model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
+
+    for _ in range(epochs):
+        for data in dl:
+
+            data = SimpleNamespace(**data)
+
+            seq_len = data.state.shape[1]
+
+            value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
+            value = torch.where(value_mask, data.value, 0.)
+
+            actor_loss, critic_loss = model.ppo(
+                state = data.state,
+                action = data.action,
+                old_action_log_prob = data.action_log_prob,
+                reward = data.reward,
+                old_value = value,
+                mask = data.learnable,
+                actor_optim = actor_optim,
+                critic_optim = critic_optim
+            )
+
+            accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
+# main function
+
+def main(
+    env_name = 'LunarLander-v3',
+    num_episodes = 50_000,
+    max_timesteps = 500,
+    num_episodes_before_learn = 32,
+    clear_video = True,
+    video_folder = 'recordings',
+    record_every_episode = 250,
+    learning_rate = 8e-4,
+    discount_factor = 0.99,
+    betas = (0.9, 0.99),
+    gae_lam = 0.95,
+    ppo_eps_clip = 0.2,
+    ppo_entropy_weight = .01,
+    batch_size = 16,
+    epochs = 2
+):
+
+    # accelerate
+
+    accelerator = Accelerator()
+    device = accelerator.device
+
+    # environment
+
+    env = gym.make(env_name, render_mode = 'rgb_array')
+
+    if clear_video:
+        rmtree(video_folder, ignore_errors = True)
+
+    env = gym.wrappers.RecordVideo(
+        env = env,
+        video_folder = video_folder,
+        name_prefix = 'lunar-video',
+        episode_trigger = lambda eps: divisible_by(eps, record_every_episode),
+        disable_logger = True
+    )
+
+    dim_state = env.observation_space.shape[0]
+    num_actions = env.action_space.n
+
+    # memory
+
+    replay = ReplayBuffer(
+        'replay',
+        num_episodes,
+        max_timesteps + 1, # one extra node for bootstrap node - not relevant for locoformer, but for completeness
+        fields = dict(
+            state = ('float', (dim_state,)),
+            action = 'int',
+            action_log_prob = 'float',
+            reward = 'float',
+            value = 'float',
+            done = 'bool',
+            learnable = 'bool'
+        )
+    )
+
+    # networks
+
+    locoformer = Locoformer(
+        embedder = MLP(dim_state, 64, bias = False),
+        unembedder = MLP(64, num_actions, bias = False),
+        value_network = MLP(64, 1, bias = False),
+        transformer = dict(
+            dim = 64,
+            dim_head = 32,
+            heads = 4,
+            depth = 4,
+            window_size = 16
+        ),
+        discount_factor = discount_factor,
+        gae_lam = gae_lam,
+        ppo_eps_clip = ppo_eps_clip,
+        ppo_entropy_weight = ppo_entropy_weight,
+        calc_gae_kwargs = dict(
+            use_accelerated = False
+        )
+    ).to(device)
+
+    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
+    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)
+
+    timesteps_learn = 0
+
+    # able to wrap the env for all values to torch tensors and back
+    # all environments should follow usual MDP interface, domain randomization should be given at instantiation
+
+    env_reset, env_step = locoformer.wrap_env_functions(env)
+
+    # loop
+
+    for episodes_index in tqdm(range(num_episodes)):
+
+        state, *_ = env_reset()
+
+        timestep = 0
+
+        stateful_forward = locoformer.get_stateful_forward(has_batch_dim = False, has_time_dim = False, inference_mode = True)
+
+        with replay.one_episode():
+            while True:
+
+                # predict next action
+
+                action_logits, value = stateful_forward(state, return_values = True)
+
+                action = gumbel_sample(action_logits)
+
+                # pass to environment
+
+                next_state, reward, truncated, terminated, *_ = env_step(action)
+
+                # append to memory
+
+                done = truncated or terminated
+
+                # get log prob of action
+
+                action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
+                action_log_prob = rearrange(action_log_prob, '1 ->')
+
+                memory = replay.store(
+                    state = state,
+                    action = action,
+                    action_log_prob = action_log_prob,
+                    reward = reward,
+                    value = value,
+                    done = done,
+                    learnable = tensor(True)
+                )
+
+                # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+                # only if terminated signal not detected
+
+                if not terminated:
+                    _, next_value = stateful_forward(next_state, return_values = True)
+
+                    memory._replace(value = next_value, learnable = False)
+
+                    replay.store(**memory._asdict())
+
+                # increment counters
+
+                timestep += 1
+
+                # break if done or exceed max timestep
+
+                if done or timestep >= max_timesteps:
+                    break
+
+                state = next_state
+
+        # learn if hit the number of learn timesteps
+
+        if divisible_by(episodes_index + 1, num_episodes_before_learn):
+
+            learn(
+                locoformer,
+                optim_actor,
+                optim_critic,
+                accelerator,
+                replay,
+                batch_size,
+                epochs,
+            )
+# main
+
+if __name__ == '__main__':
+    Fire(main)
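Note that the `# /// script` block at the top of train_gym.py is PEP 723 inline script metadata, so a PEP 723-aware runner should be able to resolve the listed dependencies before executing it (for example `uv run train_gym.py`); and since the entry point is wrapped in `Fire(main)`, the keyword arguments of `main` double as command-line flags such as `--env_name` or `--num_episodes`.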
locoformer-0.0.6/tests/test_locoformer.py (removed)

@@ -1,38 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from x_mlps_pytorch import MLP
-
-from einops import rearrange
-
-def test_locoformer():
-    from locoformer.locoformer import Locoformer
-    from torch import nn
-
-    model = Locoformer(
-        embedder = nn.Embedding(256, 128),
-        unembedder = nn.Linear(128, 256, bias = False),
-        value_network = MLP(128, 32, 1),
-        transformer = dict(
-            dim = 128,
-            depth = 1,
-            window_size = 256
-        )
-    )
-
-    seq = torch.randint(0, 256, (3, 512))
-
-    (logits, values), cache = model(seq, return_values = True)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-
-    assert logits.shape == (3, 512, 256)
-
-    stateful_forward = model.get_stateful_forward(256, has_batch_dim = True, return_values = True, inference_mode = True)
-
-    for state in seq.unbind(dim = -1):
-        state = rearrange(state, 'b -> b 1')
-
-        logits, values = stateful_forward(state)
-        assert logits.shape == (3, 1, 256)