locoformer 0.0.6__tar.gz → 0.0.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {locoformer-0.0.6 → locoformer-0.0.11}/.gitignore +2 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/PKG-INFO +3 -2
- {locoformer-0.0.6 → locoformer-0.0.11}/README.md +1 -1
- {locoformer-0.0.6 → locoformer-0.0.11}/locoformer/locoformer.py +270 -10
- {locoformer-0.0.6 → locoformer-0.0.11}/pyproject.toml +2 -1
- locoformer-0.0.11/tests/test_locoformer.py +86 -0
- locoformer-0.0.11/train_gym.py +193 -0
- locoformer-0.0.6/tests/test_locoformer.py +0 -38
- {locoformer-0.0.6 → locoformer-0.0.11}/.github/workflows/python-publish.yml +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/.github/workflows/test.yml +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/LICENSE +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/data/README.md +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/data/enwik8.gz +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/fig3.png +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/locoformer/__init__.py +0 -0
- {locoformer-0.0.6 → locoformer-0.0.11}/train.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: locoformer
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.11
|
|
4
4
|
Summary: LocoFormer
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/locoformer/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/locoformer
|
|
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: assoc-scan
|
|
38
|
+
Requires-Dist: beartype
|
|
38
39
|
Requires-Dist: einops>=0.8.0
|
|
39
40
|
Requires-Dist: einx>=0.3.0
|
|
40
41
|
Requires-Dist: rotary-embedding-torch
|
|
@@ -53,7 +54,7 @@ Description-Content-Type: text/markdown
|
|
|
53
54
|
|
|
54
55
|
[LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
|
|
55
56
|
|
|
56
|
-
The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
|
|
57
|
+
The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment), using extreme domain randomization. When transferring to the real world, they noticed that the robot had gained the ability to adapt to insults (perturbations). The XL memories span multiple trials, which allows the robot to learn in-context adaptation.
|
|
57
58
|
|
|
58
59
|
## Sponsors
|
|
59
60
|
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
[LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
|
|
6
6
|
|
|
7
|
-
The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
|
|
7
|
+
The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment), using extreme domain randomization. When transferring to the real world, they noticed that the robot had gained the ability to adapt to insults (perturbations). The XL memories span multiple trials, which allows the robot to learn in-context adaptation.
|
|
8
8
|
|
|
9
9
|
## Sponsors
|
|
10
10
|
|
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from functools import partial
|
|
3
3
|
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from numpy import ndarray
|
|
9
|
+
from numpy.lib.format import open_memmap
|
|
10
|
+
|
|
11
|
+
from beartype import beartype
|
|
12
|
+
from beartype.door import is_bearable
|
|
13
|
+
|
|
4
14
|
import torch
|
|
5
|
-
from torch import nn, cat, stack, arange, is_tensor
|
|
15
|
+
from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
|
|
6
16
|
import torch.nn.functional as F
|
|
7
17
|
from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
|
|
8
18
|
from torch.utils._pytree import tree_map
|
|
19
|
+
from torch.utils.data import Dataset, DataLoader
|
|
9
20
|
|
|
10
21
|
import einx
|
|
11
22
|
from einops import rearrange, einsum
|
|
@@ -37,14 +48,18 @@ def tree_map_tensor(x, fn):
|
|
|
37
48
|
def detach_all(x):
    # detach every tensor leaf inside the (possibly nested) structure `x`, via `tree_map_tensor`

    def _detach(tensor_leaf):
        return tensor_leaf.detach()

    return tree_map_tensor(x, _detach)
|
|
39
50
|
|
|
40
|
-
def
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
51
|
+
def pad_at_dim(
    t,
    pad: tuple[int, int],
    dim = -1,
    value = 0.
):
    """Pad tensor `t` along the single dimension `dim`.

    `pad` is the (before, after) amount for that dimension and `value` is the fill value.
    `F.pad` consumes padding pairs starting from the last dimension, so a (0, 0) pair is
    emitted for every dimension to the right of `dim` first.
    Returns `t` itself, unpadded, when `pad == (0, 0)`.
    """
    if pad == (0, 0):
        return t

    num_trailing_dims = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    padding = (0, 0) * num_trailing_dims + tuple(pad)

    return F.pad(t, padding, value = value)
|
|
48
63
|
|
|
49
64
|
# generalized advantage estimate
|
|
50
65
|
|
|
@@ -146,6 +161,208 @@ def create_sliding_mask(
|
|
|
146
161
|
create_kwargs = dict(device = device) if exists(device) else dict()
|
|
147
162
|
return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
|
|
148
163
|
|
|
164
|
+
# data
|
|
165
|
+
|
|
166
|
+
def collate_var_time(data):
    """Collate a list of episode dicts into one batch, right-padding along time.

    Each element of `data` is a dict of tensors (as produced by `ReplayDataset.__getitem__`).
    Every field except `_lens` is zero-padded along dim 0 to the longest episode in the
    batch. All fields are then stacked along a new leading batch dimension.

    Values are looked up by key, so the dicts do not need to share insertion order.
    A dict missing one of the first dict's keys raises KeyError.
    """
    datum = first(data)
    keys = tuple(datum.keys())

    collated = dict()

    for key in keys:
        tensors = [d[key] for d in data]

        # the episode lens are 0-dim scalars with no time axis, so they are stacked as-is
        # think of a cleaner way to handle this later

        if key != '_lens':
            max_time = max(t.shape[0] for t in tensors)
            tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]

        collated[key] = stack(tensors)

    return collated
|
|
188
|
+
|
|
189
|
+
class ReplayDataset(Dataset):
    """Read-only view over episodes written to disk by `ReplayBuffer`.

    `folder` must contain `episode_lens.npy` plus one `<field>.data.npy` memmap per field,
    each indexed as [episode, timestep, ...]. Only episodes with a non-zero recorded length
    are exposed. `fields` restricts which data files are loaded; by default, every
    `*.data.npy` in the folder is loaded.
    """

    def __init__(
        self,
        folder: str | Path,
        fields: tuple[str, ...] | None = None
    ):
        if isinstance(folder, str):
            folder = Path(folder)

        episode_lens = folder / 'episode_lens.npy'
        self.episode_lens = open_memmap(str(episode_lens), mode = 'r')

        # indices of episodes that actually contain data

        self.indices = np.flatnonzero(self.episode_lens > 0)

        # get all data files

        filepaths = [*folder.glob('*.data.npy')]
        assert len(filepaths) > 0

        fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}

        fieldnames_from_files = set(fieldname_to_filepath.keys())

        fields = default(fields, fieldnames_from_files)

        self.memmaps = dict()

        for field in fields:
            assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'

            path = fieldname_to_filepath[field]

            self.memmaps[field] = open_memmap(str(path), mode = 'r')

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        episode_index = self.indices[idx]

        episode_len = self.episode_lens[episode_index]

        # the memmaps are opened read-only - copy the slice so each returned tensor owns
        # writable memory (torch.from_numpy warns on non-writable arrays, and writing
        # through such a tensor would touch the read-only mapping)

        data = {field: torch.from_numpy(np.array(memmap[episode_index, :episode_len])) for field, memmap in self.memmaps.items()}

        data['_lens'] = tensor(episode_len)

        return data
|
|
239
|
+
|
|
240
|
+
class ReplayBuffer:
    """Episode storage backed by numpy memmaps on disk.

    Inside `folder` this creates (opened with mode 'w+', which recreates existing files):
    - `episode_lens.npy`: int32 array of shape (max_episodes,) holding each episode's stored length
    - `<field>.data.npy`: one memmap per field, shaped (max_episodes, max_timesteps, *field_shape)

    `fields` maps a field name to either a dtype string ('int', 'float' or 'bool'), meaning
    one scalar per timestep, or a `(dtype string, shape)` tuple, where shape is an int or a
    tuple of ints. Episode slots are reused round-robin once `max_episodes` is reached.
    """

    @beartype
    def __init__(
        self,
        folder: str | Path,
        max_episodes: int,
        max_timesteps: int,
        fields: dict[
            str,
            str | tuple[str, int | tuple[int, ...]]
        ]
    ):

        # folder for data

        if not isinstance(folder, Path):
            folder = Path(folder)

        folder.mkdir(parents = True, exist_ok = True)

        self.folder = folder
        assert folder.is_dir()

        # write cursor - the episode slot and timestep the next `store` writes to

        self.episode_index = 0
        self.timestep_index = 0

        self.max_episodes = max_episodes
        self.max_timesteps = max_timesteps

        # keeping track of episode length

        episode_lens = folder / 'episode_lens.npy'
        self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))

        # create the memmap for individual data tracks

        self.shapes = dict()
        self.dtypes = dict()
        self.memmaps = dict()
        self.fieldnames = set(fields.keys())

        for field_name, field_info in fields.items():

            # some flexibility - a bare dtype string means a scalar per timestep

            field_info = (field_info, ()) if isinstance(field_info, str) else field_info

            dtype_str, shape = field_info
            assert dtype_str in {'int', 'float', 'bool'}, f'invalid dtype {dtype_str} - must be one of int, float, bool'

            # an int shape (allowed by the annotation) is a 1-d field of that size

            if isinstance(shape, int):
                shape = (shape,)

            dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]

            # memmap file

            filepath = folder / f'{field_name}.data.npy'
            memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))

            self.memmaps[field_name] = memmap
            self.shapes[field_name] = shape
            self.dtypes[field_name] = dtype

    def advance_episode(self):
        # move the cursor to the next episode slot, wrapping around after `max_episodes`
        self.episode_index = (self.episode_index + 1) % self.max_episodes
        self.timestep_index = 0

    def flush(self):
        # record how many timesteps the current episode holds, then push all writes to disk
        self.episode_lens[self.episode_index] = self.timestep_index

        for memmap in self.memmaps.values():
            memmap.flush()

        self.episode_lens.flush()

    @contextmanager
    def one_episode(self):
        # scope one episode of `store` calls. On exit - including exit via an exception -
        # the stored timesteps are committed and the cursor moves to a fresh slot, so a failed
        # rollout never has the next episode's data appended onto its partial data

        try:
            yield
        finally:
            self.flush()
            self.advance_episode()

    @beartype
    def store_datapoint(
        self,
        episode_index: int,
        timestep_index: int,
        name: str,
        datapoint: Tensor | ndarray
    ):
        assert 0 <= episode_index < self.max_episodes
        assert 0 <= timestep_index < self.max_timesteps

        if is_tensor(datapoint):
            datapoint = datapoint.detach().cpu().numpy()

        assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'

        assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'

        # write at the (validated) indices passed in, not at the internal cursor

        self.memmaps[name][episode_index, timestep_index] = datapoint

    def store(
        self,
        **data
    ):
        # store one timestep's worth of fields at the cursor, then advance the timestep

        assert is_bearable(data, dict[str, Tensor | ndarray])

        assert self.timestep_index < self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'

        for name, datapoint in data.items():

            self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)

        self.timestep_index += 1

    def dataset(self) -> Dataset:
        # flush first so the reader sees everything stored so far; only this buffer's fields
        # are loaded, so stale `*.data.npy` files left in the folder are ignored
        self.flush()

        return ReplayDataset(self.folder, fields = tuple(sorted(self.fieldnames)))

    def dataloader(self, **kwargs) -> DataLoader:
        # `dataset()` already flushes. The variable-length collate is the default, but the
        # caller may override it
        kwargs.setdefault('collate_fn', collate_var_time)

        return DataLoader(self.dataset(), **kwargs)
|
|
365
|
+
|
|
149
366
|
# transformer-xl with ppo
|
|
150
367
|
|
|
151
368
|
class Attention(Module):
|
|
@@ -386,11 +603,46 @@ class Locoformer(Module):
|
|
|
386
603
|
def device(self):
    # report the device of the first registered parameter
    # (assumes all parameters share one device - TODO confirm for sharded setups)
    first_param = next(self.parameters())
    return first_param.device
|
|
388
605
|
|
|
606
|
+
def actor_parameters(self):
    # parameters of the actor head only (the unembedder) - the shared transformer is not included
    unembedder = self.unembedder
    return unembedder.parameters()
|
|
608
|
+
|
|
609
|
+
def critic_parameters(self):
    # parameters of the value head only - an empty list when the model has no value network
    value_network = self.value_network

    if exists(value_network):
        return value_network.parameters()

    return []
|
|
614
|
+
|
|
615
|
+
def wrap_env_functions(self, env):
    """Wrap a gymnasium-style `env` so that its inputs and outputs are torch tensors.

    Returns `(wrapped_reset, wrapped_step)`:

    - `wrapped_reset(*args, **kwargs)` calls `env.reset` and returns `(state, info)`.
      `state` is converted to a tensor when it is a numpy array.
    - `wrapped_step(action, *args, **kwargs)` converts `action` for the env before calling
      `env.step`: a single-element tensor becomes a python scalar, a multi-element tensor
      becomes a numpy array, and a non-tensor is passed through. In the (possibly nested)
      output, every numpy array, numpy scalar and python int / bool / float becomes a tensor.
    """

    def to_tensor(el):
        if isinstance(el, ndarray):
            return from_numpy(el)

        # numpy scalars (np.bool_, np.int64, ...) are neither ndarrays nor python bool / int,
        # so they need an explicit case to be converted
        if isinstance(el, (int, bool, float, np.bool_, np.number)):
            return tensor(el)

        return el

    def to_env_action(action):
        if not is_tensor(action):
            return action

        if action.numel() == 1:
            return action.item()

        return action.detach().cpu().numpy()

    def wrapped_reset(*args, **kwargs):
        state, info = env.reset(*args, **kwargs)

        if isinstance(state, ndarray):
            state = from_numpy(state)

        return state, info

    def wrapped_step(action, *args, **kwargs):
        out = env.step(to_env_action(action), *args, **kwargs)

        return tree_map(to_tensor, out)

    return wrapped_reset, wrapped_step
|
|
639
|
+
|
|
389
640
|
def get_stateful_forward(
|
|
390
641
|
self,
|
|
391
642
|
initial_states: Tensor | None = None,
|
|
392
643
|
inference_mode = False,
|
|
393
644
|
has_batch_dim = False,
|
|
645
|
+
has_time_dim = False,
|
|
394
646
|
**kwargs
|
|
395
647
|
):
|
|
396
648
|
window_size = self.window_size
|
|
@@ -400,11 +652,14 @@ class Locoformer(Module):
|
|
|
400
652
|
def stateful_forward(state: Tensor, **override_kwargs):
|
|
401
653
|
nonlocal cache
|
|
402
654
|
|
|
403
|
-
# handle no batch, for easier time rolling out against envs
|
|
655
|
+
# handle no batch or time, for easier time rolling out against envs
|
|
404
656
|
|
|
405
657
|
if not has_batch_dim:
|
|
406
658
|
state = rearrange(state, '... -> 1 ...')
|
|
407
659
|
|
|
660
|
+
if not has_time_dim:
|
|
661
|
+
state = rearrange(state, '... d -> ... 1 d')
|
|
662
|
+
|
|
408
663
|
# forwards
|
|
409
664
|
|
|
410
665
|
out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
|
|
@@ -416,7 +671,10 @@ class Locoformer(Module):
|
|
|
416
671
|
if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
|
|
417
672
|
cache = cache[..., -window_size:, :]
|
|
418
673
|
|
|
419
|
-
# maybe remove batch
|
|
674
|
+
# maybe remove batch or time
|
|
675
|
+
|
|
676
|
+
if not has_time_dim:
|
|
677
|
+
out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
|
|
420
678
|
|
|
421
679
|
if not has_batch_dim:
|
|
422
680
|
out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
|
|
@@ -450,6 +708,8 @@ class Locoformer(Module):
|
|
|
450
708
|
return_values = False
|
|
451
709
|
):
|
|
452
710
|
|
|
711
|
+
state = state.to(self.device)
|
|
712
|
+
|
|
453
713
|
tokens = self.embedder(state)
|
|
454
714
|
|
|
455
715
|
embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "locoformer"
|
|
3
|
-
version = "0.0.6"
|
|
3
|
+
version = "0.0.11"
|
|
4
4
|
description = "LocoFormer"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -27,6 +27,7 @@ classifiers=[
|
|
|
27
27
|
|
|
28
28
|
dependencies = [
|
|
29
29
|
"assoc-scan",
|
|
30
|
+
"beartype",
|
|
30
31
|
"einx>=0.3.0",
|
|
31
32
|
"einops>=0.8.0",
|
|
32
33
|
"rotary-embedding-torch",
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
param = pytest.mark.parametrize
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from x_mlps_pytorch import MLP
|
|
6
|
+
|
|
7
|
+
from einops import rearrange
|
|
8
|
+
|
|
9
|
+
def test_locoformer():
    from torch import nn
    from locoformer.locoformer import Locoformer

    model = Locoformer(
        embedder = nn.Embedding(256, 128),
        unembedder = nn.Linear(128, 256, bias = False),
        value_network = MLP(128, 32, 1),
        transformer = dict(dim = 128, depth = 1, window_size = 512)
    )

    seq = torch.randint(0, 256, (3, 512))

    # one fresh forward, then two more that thread the returned cache back in

    (logits, values), cache = model(seq, return_values = True)

    for _ in range(2):
        (logits, values), cache = model(seq, return_values = True, cache = cache)

    assert logits.shape == (3, 512, 256)

    stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)

    # step through the sequence one token at a time, keeping explicit batch and time dims

    for token_ids in seq.unbind(dim = -1):
        logits, values = stateful_forward(rearrange(token_ids, 'b -> b 1'))
        assert logits.shape == (3, 1, 256)
|
|
39
|
+
|
|
40
|
+
def test_replay(tmp_path):
    from locoformer.locoformer import ReplayBuffer

    lens = [3, 5, 4]

    # pytest's tmp_path keeps the memmap files out of the working directory.
    # Only a few episode slots are needed, so the preallocated files stay small

    replay_buffer = ReplayBuffer(
        tmp_path / 'replay_data',
        max_episodes = 10,
        max_timesteps = 501,
        fields = dict(
            state = ('float', (8,)),
            action = 'int',
            action_log_prob = 'float',
            reward = 'float',
            value = 'float',
            done = 'bool'
        )
    )

    for episode_len in lens:
        with replay_buffer.one_episode():
            for _ in range(episode_len):
                replay_buffer.store(
                    state = torch.randn((8,)),
                    action = torch.randint(0, 4, ()),
                    action_log_prob = torch.randn(()),
                    reward = torch.randn(()),
                    value = torch.randn(()),
                    done = torch.randint(0, 2, ()).bool()
                )

    dataset = replay_buffer.dataset()

    assert len(dataset) == 3

    assert torch.is_tensor(dataset[0]['state'])

    # every episode is read back with the length it was recorded with
    assert [dataset[i]['_lens'].item() for i in range(len(dataset))] == lens

    dataloader = replay_buffer.dataloader(batch_size = 3)

    batch = next(iter(dataloader))

    # variable-length episodes are right-padded to the longest one in the batch
    assert batch['state'].shape == (3, max(lens), 8)
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# dependencies = [
|
|
3
|
+
# "accelerate",
|
|
4
|
+
# "fire",
|
|
5
|
+
# "gymnasium[box2d]>=1.0.0",
|
|
6
|
+
# "locoformer",
|
|
7
|
+
# "moviepy",
|
|
8
|
+
# "tqdm"
|
|
9
|
+
# ]
|
|
10
|
+
# ///
|
|
11
|
+
|
|
12
|
+
from fire import Fire
|
|
13
|
+
from shutil import rmtree
|
|
14
|
+
from tqdm import tqdm
|
|
15
|
+
from collections import deque
|
|
16
|
+
|
|
17
|
+
from accelerate import Accelerator
|
|
18
|
+
|
|
19
|
+
import gymnasium as gym
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from torch import from_numpy, randint, tensor, stack
|
|
23
|
+
import torch.nn.functional as F
|
|
24
|
+
from torch.utils.data import TensorDataset, DataLoader
|
|
25
|
+
from torch.optim import Adam
|
|
26
|
+
|
|
27
|
+
from einops import rearrange
|
|
28
|
+
|
|
29
|
+
from locoformer.locoformer import Locoformer, ReplayBuffer
|
|
30
|
+
from x_mlps_pytorch import MLP
|
|
31
|
+
|
|
32
|
+
# helper functions
|
|
33
|
+
|
|
34
|
+
def exists(v):
    # only `None` counts as missing - falsy values like 0 or '' still exist
    return not (v is None)
|
|
36
|
+
|
|
37
|
+
def divisible_by(num, den):
    # true when `den` divides `num` with no remainder
    remainder = num % den
    return remainder == 0
|
|
39
|
+
|
|
40
|
+
def log(t, eps = 1e-20):
    # numerically safe log - clamping below at `eps` keeps zeros from producing -inf
    clamped = torch.clamp(t, min = eps)
    return torch.log(clamped)
|
|
42
|
+
|
|
43
|
+
def gumbel_noise(t):
    # draw u ~ U(0, 1) shaped like `t`, then apply the inverse gumbel cdf: -log(-log(u))
    uniform = torch.rand_like(t)
    neg_log_uniform = -log(uniform)
    return -log(neg_log_uniform)
|
|
45
|
+
|
|
46
|
+
def gumbel_sample(logits, temperature = 1., eps = 1e-6):
    # gumbel-max trick: argmax over temperature-scaled logits plus gumbel noise
    # (temperature is floored at `eps` to avoid dividing by zero)
    noise = gumbel_noise(logits)
    scaled_logits = logits / max(temperature, eps)
    perturbed = scaled_logits + noise
    return perturbed.argmax(dim = -1)
|
|
49
|
+
|
|
50
|
+
# main function
|
|
51
|
+
|
|
52
|
+
def main(
    env_name = 'LunarLander-v3',
    num_episodes = 50_000,
    max_timesteps = 500,
    num_timestep_before_learn = 5000,
    clear_video = True,
    video_folder = 'recordings',
    record_every_episode = 250,
    discount_factor = 0.99,
    learning_rate = 1e-4,
    batch_size = 16,
    epochs = 2
):
    """Roll out a Locoformer policy in a gymnasium env, storing each transition in an on-disk replay buffer.

    Every `record_every_episode` episodes are recorded to `video_folder`. The learning step
    is still a todo, so `discount_factor`, `batch_size`, `epochs` and the optimizers are not
    used yet.
    """

    # accelerate

    accelerate = Accelerator()
    device = accelerate.device

    # environment

    env = gym.make(env_name, render_mode = 'rgb_array')

    if clear_video:
        rmtree(video_folder, ignore_errors = True)

    env = gym.wrappers.RecordVideo(
        env = env,
        video_folder = video_folder,
        name_prefix = 'lunar-video',
        episode_trigger = lambda eps: divisible_by(eps, record_every_episode),
        disable_logger = True
    )

    dim_state = env.observation_space.shape[0]
    num_actions = env.action_space.n

    # memory

    replay = ReplayBuffer(
        'replay',
        num_episodes,
        max_timesteps,
        fields = dict(
            state = ('float', (dim_state,)),
            action = 'int',
            action_log_prob = 'float',
            reward = 'float',
            value = 'float',
            done = 'bool'
        )
    )

    # networks

    locoformer = Locoformer(
        embedder = MLP(dim_state, 64, bias = False),
        unembedder = MLP(64, num_actions, bias = False),
        value_network = MLP(64, 1, bias = False),
        transformer = dict(
            dim = 64,
            dim_head = 32,
            heads = 4,
            depth = 4,
            window_size = 16
        )
    ).to(device)

    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate)
    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate)

    timesteps_learn = 0

    # able to wrap the env for all values to torch tensors and back
    # all environments should follow usual MDP interface, domain randomization should be given at instantiation

    env_reset, env_step = locoformer.wrap_env_functions(env)

    # loop

    for _ in tqdm(range(num_episodes)):
        state, *_ = env_reset()

        timestep = 0

        stateful_forward = locoformer.get_stateful_forward(has_batch_dim = False, has_time_dim = False, inference_mode = True)

        with replay.one_episode():
            while True:

                # predict next action

                action_logits, value = stateful_forward(state, return_values = True)

                action = gumbel_sample(action_logits)

                # pass to environment - gymnasium returns (observation, reward, terminated, truncated, info)

                next_state, reward, terminated, truncated, *_ = env_step(action)

                # append to memory

                done = terminated or truncated

                # log probability of the sampled action - the logits must be normalized first,
                # raw logits are not log probabilities

                action_log_probs = action_logits.log_softmax(dim = -1)
                action_log_prob = action_log_probs.gather(-1, rearrange(action, '-> 1'))
                action_log_prob = rearrange(action_log_prob, '1 ->')

                replay.store(
                    state = state,
                    action = action,
                    action_log_prob = action_log_prob,
                    reward = reward,
                    value = value,
                    done = done
                )

                # increment counters

                timestep += 1
                timesteps_learn += 1

                # learn if hit the number of learn timesteps

                if timesteps_learn >= num_timestep_before_learn:
                    # todo - carry out learning from the `replay` buffer

                    timesteps_learn = 0

                # break if done or exceed max timestep

                if done or timestep >= max_timesteps:
                    break

                state = next_state

# main

if __name__ == '__main__':
    Fire(main)
|
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
import pytest
|
|
2
|
-
param = pytest.mark.parametrize
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
from x_mlps_pytorch import MLP
|
|
6
|
-
|
|
7
|
-
from einops import rearrange
|
|
8
|
-
|
|
9
|
-
def test_locoformer():
|
|
10
|
-
from locoformer.locoformer import Locoformer
|
|
11
|
-
from torch import nn
|
|
12
|
-
|
|
13
|
-
model = Locoformer(
|
|
14
|
-
embedder = nn.Embedding(256, 128),
|
|
15
|
-
unembedder = nn.Linear(128, 256, bias = False),
|
|
16
|
-
value_network = MLP(128, 32, 1),
|
|
17
|
-
transformer = dict(
|
|
18
|
-
dim = 128,
|
|
19
|
-
depth = 1,
|
|
20
|
-
window_size = 256
|
|
21
|
-
)
|
|
22
|
-
)
|
|
23
|
-
|
|
24
|
-
seq = torch.randint(0, 256, (3, 512))
|
|
25
|
-
|
|
26
|
-
(logits, values), cache = model(seq, return_values = True)
|
|
27
|
-
(logits, values), cache = model(seq, return_values = True, cache = cache)
|
|
28
|
-
(logits, values), cache = model(seq, return_values = True, cache = cache)
|
|
29
|
-
|
|
30
|
-
assert logits.shape == (3, 512, 256)
|
|
31
|
-
|
|
32
|
-
stateful_forward = model.get_stateful_forward(256, has_batch_dim = True, return_values = True, inference_mode = True)
|
|
33
|
-
|
|
34
|
-
for state in seq.unbind(dim = -1):
|
|
35
|
-
state = rearrange(state, 'b -> b 1')
|
|
36
|
-
|
|
37
|
-
logits, values = stateful_forward(state)
|
|
38
|
-
assert logits.shape == (3, 1, 256)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|