locoformer 0.0.7__py3-none-any.whl → 0.0.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locoformer/locoformer.py +600 -37
- {locoformer-0.0.7.dist-info → locoformer-0.0.29.dist-info}/METADATA +4 -2
- locoformer-0.0.29.dist-info/RECORD +6 -0
- locoformer-0.0.7.dist-info/RECORD +0 -6
- {locoformer-0.0.7.dist-info → locoformer-0.0.29.dist-info}/WHEEL +0 -0
- {locoformer-0.0.7.dist-info → locoformer-0.0.29.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
@@ -1,11 +1,25 @@
 from __future__ import annotations
+from typing import Callable
 from functools import partial

+from pathlib import Path
+from contextlib import contextmanager
+from collections import namedtuple
+
+import numpy as np
+from numpy import ndarray
+from numpy.lib.format import open_memmap
+
+from beartype import beartype
+from beartype.door import is_bearable
+
 import torch
-from torch import nn, cat, stack, arange, Tensor, is_tensor
+from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import Optimizer

 import einx
 from einops import rearrange, einsum
@@ -13,10 +27,16 @@ from einops.layers.torch import Rearrange

 from rotary_embedding_torch import RotaryEmbedding

+from hl_gauss_pytorch import HLGaussLoss
+
 from assoc_scan import AssocScan

+# constants
+
 LinearNoBias = partial(Linear, bias = False)

+Cache = namedtuple('Cache', ('curr_timestep', 'kv_cache')) # (int, Tensor)
+
 # helper functions

 def exists(v):
@@ -31,20 +51,36 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0

+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
+def is_empty(t):
+    return t.numel() == 0
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

-def
-
+def pad_at_dim(
+    t,
+    pad: tuple[int, int],
+    dim = -1,
+    value = 0.
+):
+    if pad == (0, 0):
+        return t

-
-
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)

-
-
-        combined_cache.append(next_cache)
+def normalize(t, eps = 1e-5):
+    return (t - t.mean()) / t.std().clamp_min(eps)

-
+def calc_entropy(logits):
+    prob = logits.softmax(dim = -1)
+    return -(prob * log(prob)).sum(dim = -1)

 # generalized advantage estimate

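The new tensor helpers are small but load-bearing: `pad_at_dim` backs the variable-length episode collation added further down. A quick sketch of its behavior, assuming the module import path `locoformer.locoformer`:

```python
import torch
from locoformer.locoformer import pad_at_dim

t = torch.ones(2, 3)

# pad five zeros onto the end of dimension 0; other dims untouched
padded = pad_at_dim(t, (0, 5), dim = 0)
assert padded.shape == (7, 3)
```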
@@ -52,7 +88,7 @@ def combine_kv_cache(cache1, cache2):
 def calc_gae(
     rewards,
     values,
-    masks,
+    masks = None,
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
@@ -63,6 +99,9 @@ def calc_gae(
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[..., :-1], values[..., 1:]

+    if not exists(masks):
+        masks = torch.ones_like(values)
+
     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks

@@ -72,7 +111,7 @@ def calc_gae(

     returns = gae + values

-    return returns
+    return gae, returns

 # transformer-xl mask w/ flex attn

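`calc_gae` now makes `masks` optional and returns the advantage alongside the bootstrapped returns, so call sites unpack two values. A minimal sketch with hypothetical reward and value tensors:

```python
import torch
from locoformer.locoformer import calc_gae

# hypothetical (batch, time) rewards and value predictions
rewards = torch.randn(1, 8)
values = torch.randn(1, 8)

# masks may now be omitted (defaults to all ones), and gae is returned too
gae, returns = calc_gae(rewards, values, gamma = 0.99, lam = 0.95)
assert gae.shape == returns.shape == rewards.shape
```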
@@ -114,8 +153,8 @@ def create_xl_mask(
     # handle intra-episodic attention if needed

     if exists(episode_ids):
-        q_episode =
-        k_episode =
+        q_episode = episode_ids[b, q + offset]
+        k_episode = episode_ids[b, k]

         intra_episode_mask = q_episode == k_episode
         mask = mask & intra_episode_mask
@@ -146,6 +185,284 @@ def create_sliding_mask(
     create_kwargs = dict(device = device) if exists(device) else dict()
     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)

+# data
+
+def collate_var_time(data):
+
+    datum = first(data)
+    keys = datum.keys()
+
+    all_tensors = zip(*[datum.values() for datum in data])
+
+    collated_values = []
+
+    for key, tensors in zip(keys, all_tensors):
+
+        # the episode lens have zero dimension - think of a cleaner way to handle this later
+
+        if key != '_lens':
+
+            times = [t.shape[0] for t in tensors]
+            max_time = max(times)
+            tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
+
+        collated_values.append(stack(tensors))
+
+    return dict(zip(keys, collated_values))
+
+class ReplayDataset(Dataset):
+    def __init__(
+        self,
+        folder: str | Path,
+        fields: tuple[str, ...] | None = None
+    ):
+        if isinstance(folder, str):
+            folder = Path(folder)
+
+        episode_lens = folder / 'episode_lens.npy'
+        self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
+
+        # get indices of non-zero lengthed episodes
+
+        nonzero_episodes = self.episode_lens > 0
+        self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
+
+        # get all data files
+
+        filepaths = [*folder.glob('*.data.npy')]
+        assert len(filepaths) > 0
+
+        fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
+
+        fieldnames_from_files = set(fieldname_to_filepath.keys())
+
+        fields = default(fields, fieldnames_from_files)
+
+        self.memmaps = dict()
+
+        for field in fields:
+            assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
+
+            path = fieldname_to_filepath[field]
+
+            self.memmaps[field] = open_memmap(str(path), mode = 'r')
+
+    def __len__(self):
+        return len(self.indices)
+
+    def __getitem__(self, idx):
+        episode_index = self.indices[idx]
+
+        episode_len = self.episode_lens[episode_index]
+
+        data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
+
+        data['_lens'] = tensor(episode_len)
+
+        return data
+
+class RemappedReplayDataset(Dataset):
+    def __init__(
+        self,
+        dataset: ReplayDataset,
+        episode_mapping: Tensor | list[list[int]],
+        shuffle_episodes = False
+    ):
+        assert len(dataset) > 0
+        self.dataset = dataset
+
+        if is_tensor(episode_mapping):
+            assert episode_mapping.dtype in (torch.int, torch.long) and episode_mapping.ndim == 2
+            episode_mapping = episode_mapping.tolist()
+
+        self.episode_mapping = episode_mapping
+        self.shuffle_episodes = shuffle_episodes
+
+    def __len__(self):
+        return len(self.episode_mapping)
+
+    def __getitem__(self, idx):
+
+        episode_indices = self.episode_mapping[idx]
+
+        episode_indices = tensor(episode_indices)
+        episode_indices = episode_indices[(episode_indices >= 0) & (episode_indices < len(self.dataset))]
+
+        assert not is_empty(episode_indices)
+
+        if self.shuffle_episodes and episode_indices.numel() > 1:
+            num_episodes = len(episode_indices)
+            episode_indices = episode_indices[torch.randperm(num_episodes)]
+
+        episode_data = [self.dataset[i] for i in episode_indices.tolist()]
+
+        episode_lens = stack([data.pop('_lens') for data in episode_data])
+
+        keys = first(episode_data).keys()
+
+        values = [list(data.values()) for data in episode_data]
+
+        values = [cat(field_values) for field_values in zip(*values)] # concat across time
+
+        multi_episode_data = dict(zip(keys, values))
+
+        multi_episode_data['_lens'] = episode_lens.sum()
+
+        multi_episode_data['_episode_indices'] = cat([torch.full((episode_len,), episode_index) for episode_len, episode_index in zip(episode_lens, episode_indices)])
+
+        return multi_episode_data
+
+class ReplayBuffer:
+
+    @beartype
+    def __init__(
+        self,
+        folder: str | Path,
+        max_episodes: int,
+        max_timesteps: int,
+        fields: dict[
+            str,
+            str | tuple[str, int | tuple[int, ...]]
+        ]
+    ):
+
+        # folder for data
+
+        if not isinstance(folder, Path):
+            folder = Path(folder)
+        folder.mkdir(exist_ok = True)
+
+        self.folder = folder
+        assert folder.is_dir()
+
+        # keeping track of episode length
+
+        episode_lens = folder / 'episode_lens.npy'
+
+        self.episode_index = 0
+        self.timestep_index = 0
+
+        self.max_episodes = max_episodes
+        self.max_timesteps = max_timesteps
+
+        self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
+
+        # create the memmap for individual data tracks
+
+        self.shapes = dict()
+        self.dtypes = dict()
+        self.memmaps = dict()
+        self.fieldnames = set(fields.keys())
+
+        for field_name, field_info in fields.items():
+
+            # some flexibility
+
+            field_info = (field_info, ()) if isinstance(field_info, str) else field_info
+
+            dtype_str, shape = field_info
+            assert dtype_str in {'int', 'float', 'bool'}
+
+            dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
+
+            # memmap file
+
+            filepath = folder / f'{field_name}.data.npy'
+            memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
+
+            self.memmaps[field_name] = memmap
+            self.shapes[field_name] = shape
+            self.dtypes[field_name] = dtype
+
+        self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+    def __len__(self):
+        return (self.episode_lens > 0).sum().item()
+
+    def reset_(self):
+        self.episode_lens[:] = 0
+        self.episode_index = 0
+        self.timestep_index = 0
+
+    def advance_episode(self):
+        self.episode_index = (self.episode_index + 1) % self.max_episodes
+        self.timestep_index = 0
+
+    def flush(self):
+        self.episode_lens[self.episode_index] = self.timestep_index
+
+        for memmap in self.memmaps.values():
+            memmap.flush()
+
+        self.episode_lens.flush()
+
+    @contextmanager
+    def one_episode(self):
+
+        yield
+
+        self.flush()
+        self.advance_episode()
+
+    @beartype
+    def store_datapoint(
+        self,
+        episode_index: int,
+        timestep_index: int,
+        name: str,
+        datapoint: Tensor | ndarray
+    ):
+        assert 0 <= episode_index < self.max_episodes
+        assert 0 <= timestep_index < self.max_timesteps
+
+        if is_tensor(datapoint):
+            datapoint = datapoint.detach().cpu().numpy()
+
+        assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
+
+        assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
+
+        self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
+
+    def store(
+        self,
+        **data
+    ):
+        assert is_bearable(data, dict[str, Tensor | ndarray])
+
+        assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+
+        for name, datapoint in data.items():
+
+            self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
+
+        self.timestep_index += 1
+
+        return self.memory_namedtuple(**data)
+
+    def dataset(
+        self,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+    ) -> Dataset:
+        self.flush()
+
+        dataset = ReplayDataset(self.folder)
+
+        if not exists(episode_mapping):
+            return dataset
+
+        return RemappedReplayDataset(dataset, episode_mapping)
+
+    def dataloader(
+        self,
+        batch_size,
+        episode_mapping: Tensor | list[list[int]] | None = None,
+        **kwargs
+    ) -> DataLoader:
+        self.flush()
+
+        return DataLoader(self.dataset(episode_mapping), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
 # transformer-xl with ppo

 class Attention(Module):
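The replay machinery above is new in this release: `ReplayBuffer` persists per-field memmaps to disk, `ReplayDataset` reads non-empty episodes back, and `collate_var_time` zero-pads episodes of different lengths into a batch. A minimal usage sketch; the folder, field names, and shapes here are hypothetical, not prescribed by the package:

```python
import torch
from locoformer.locoformer import ReplayBuffer

buffer = ReplayBuffer(
    folder = './replay-data',       # episode_lens.npy and <field>.data.npy memmaps live here
    max_episodes = 4,
    max_timesteps = 16,
    fields = dict(
        state = ('float', (5,)),    # (dtype string, per-timestep shape)
        action = 'int',             # a bare dtype string means a scalar per timestep
        reward = 'float'
    )
)

# each `store` call writes one timestep; `one_episode` flushes then advances the episode pointer

with buffer.one_episode():
    for _ in range(10):
        buffer.store(
            state = torch.randn(5),
            action = torch.randint(0, 3, ()),
            reward = torch.randn(())
        )

# episodes of varying length are zero-padded and stacked by `collate_var_time`

dataloader = buffer.dataloader(batch_size = 2)
```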
@@ -204,7 +521,6 @@ class Attention(Module):
         return_kv_cache = False,
     ):
         seq_len = tokens.shape[-2]
-        assert seq_len <= self.window_size

         device = tokens.device

@@ -365,7 +681,21 @@ class Locoformer(Module):
         embedder: Module,
         unembedder: Module,
         transformer: dict | TransformerXL,
-
+        discount_factor = 0.999,
+        gae_lam = 0.95,
+        ppo_eps_clip = 0.2,
+        ppo_entropy_weight = 0.01,
+        ppo_value_clip = 0.4,
+        dim_value_input = None, # needs to be set for value network to be available
+        value_network: Module = nn.Identity(),
+        reward_range: tuple[float, float] | None = None,
+        reward_shaping_fns: list[Callable[[Tensor], float | Tensor]] | None = None,
+        num_reward_bins = 32,
+        hl_gauss_loss_kwargs = dict(),
+        value_loss_weight = 0.5,
+        calc_gae_kwargs: dict = dict(),
+        recurrent_kv_cache = True,
+        use_spo = False # simple policy optimization https://arxiv.org/abs/2401.16025 - Levine's group (PI) verified it is more stable than PPO
     ):
         super().__init__()

@@ -377,11 +707,58 @@ class Locoformer(Module):
         self.embedder = embedder
         self.unembedder = unembedder

-        self.value_network = value_network
-
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size

+        # determine value network, using HL Gauss Layer
+
+        self.to_value_pred = None
+
+        if exists(dim_value_input):
+            assert exists(reward_range)
+
+            self.to_value_pred = nn.Sequential(
+                value_network,
+                LinearNoBias(dim_value_input, num_reward_bins)
+            )
+
+            reward_min, reward_max = reward_range
+
+            self.hl_gauss_loss = HLGaussLoss(
+                min_value = reward_min,
+                max_value = reward_max,
+                num_bins = num_reward_bins,
+                **hl_gauss_loss_kwargs
+            )
+
+        # ppo related
+
+        self.discount_factor = discount_factor
+        self.gae_lam = gae_lam
+        self.ppo_eps_clip = ppo_eps_clip
+        self.ppo_entropy_weight = ppo_entropy_weight
+        self.ppo_value_clip = ppo_value_clip
+        self.value_loss_weight = value_loss_weight
+
+        self.calc_gae_kwargs = calc_gae_kwargs
+
+        # maybe use spo
+
+        self.use_spo = use_spo
+
+        # maybe recurrent kv cache (todo: find and cite this paper from ages ago)
+
+        self.recurrent_kv_cache = recurrent_kv_cache
+
+        # reward shaping function
+
+        self.has_reward_shaping = exists(reward_shaping_fns)
+        self.reward_shaping_fns = reward_shaping_fns
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     @property
     def device(self):
         return next(self.parameters()).device
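The scalar value head from 0.0.7 is replaced by a distributional one: the critic now emits logits over `num_reward_bins`, and `HLGaussLoss` serves double duty, as a classification loss when given a target and as a logits-to-scalar decoder when given logits alone (both call forms appear later in this diff). A standalone sketch under that assumed `hl-gauss-pytorch` API:

```python
import torch
from hl_gauss_pytorch import HLGaussLoss

hl_gauss = HLGaussLoss(min_value = -10., max_value = 10., num_bins = 32)

value_logits = torch.randn(2, 8, 32)    # (batch, time, num_reward_bins)
returns = torch.randn(2, 8)

loss = hl_gauss(value_logits, returns)  # cross entropy against gaussian-smoothed bin targets
values = hl_gauss(value_logits)         # decode logits back to scalar values, shape (2, 8)
```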
@@ -390,7 +767,165 @@ class Locoformer(Module):
         return self.unembedder.parameters()

     def critic_parameters(self):
-
+        if not exists(self.to_value_pred):
+            return []
+
+        return self.to_value_pred.parameters()
+
+    def ppo(
+        self,
+        state,
+        action,
+        old_action_log_prob,
+        reward,
+        old_value,
+        mask,
+        episode_lens,
+        actor_optim: Optimizer | None = None,
+        critic_optim: Optimizer | None = None
+    ):
+        window_size = self.window_size
+        total_learnable_tokens = mask.sum().item()
+
+        seq_len = state.shape[1]
+        gae_mask = einx.less('j, i -> i j', arange(seq_len, device = self.device), episode_lens)
+
+        advantage, returns = calc_gae(reward, old_value, masks = gae_mask, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+        advantage = normalize(advantage)
+
+        windowed_tensors = [
+            t.split(window_size, dim = 1) for t in
+            (
+                state,
+                action,
+                old_action_log_prob,
+                reward,
+                old_value,
+                mask,
+                advantage,
+                returns
+            )
+        ]
+
+        mean_actor_loss = self.zero.clone()
+        mean_critic_loss = self.zero.clone()
+
+        # learn across windows
+
+        cache = None
+
+        for (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask,
+            advantage,
+            returns
+        ) in zip(*windowed_tensors):
+
+            (action_logits, value_logits), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True, return_raw_value_logits = True)
+            entropy = calc_entropy(action_logits)
+
+            action = rearrange(action, 'b t -> b t 1')
+            log_prob = action_logits.gather(-1, action)
+            log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+            # update actor, classic clipped surrogate loss
+
+            eps_clip = self.ppo_eps_clip
+            ratio = (log_prob - old_action_log_prob).exp()
+
+            if self.use_spo:
+                actor_loss = -(ratio * advantage - (advantage.abs() * (ratio - 1.).square()) / (2 * eps_clip))
+            else:
+                actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+            actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+            windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+            windowed_actor_loss.backward(retain_graph = True)
+
+            # update critic
+
+            value_loss = self.hl_gauss_loss(value_logits, returns, reduction = 'none')
+
+            value_clip = self.ppo_value_clip
+            value = self.hl_gauss_loss(value_logits)
+
+            clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+            clipped_value_loss = self.hl_gauss_loss(clipped_value, returns, reduction = 'none')
+
+            critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+            windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+            windowed_critic_loss.backward(retain_graph = True)
+
+            # accumulate
+
+            mean_actor_loss.add_(windowed_actor_loss)
+            mean_critic_loss.add_(windowed_critic_loss)
+
+        # optimizer update
+
+        if exists(actor_optim):
+            actor_optim.step()
+            actor_optim.zero_grad()
+
+        if exists(critic_optim):
+            critic_optim.step()
+            critic_optim.zero_grad()
+
+        # return losses for logging
+
+        return mean_actor_loss.detach(), mean_critic_loss.detach()
+
+    def state_to_rewards(
+        self,
+        state
+    ) -> Tensor:
+
+        assert self.has_reward_shaping
+
+        rewards = [fn(state) for fn in self.reward_shaping_fns]
+
+        rewards = [tensor(reward) if not is_tensor(reward) else reward for reward in rewards]
+        return stack(rewards)
+
+    def wrap_env_functions(self, env):
+
+        def transform_output(el):
+            if isinstance(el, ndarray):
+                return from_numpy(el)
+            elif isinstance(el, (int, bool, float)):
+                return tensor(el)
+            else:
+                return el
+
+        def wrapped_reset(*args, **kwargs):
+            env_reset_out = env.reset(*args, **kwargs)
+
+            return tree_map(transform_output, env_reset_out)
+
+        def wrapped_step(action, *args, **kwargs):
+
+            if is_tensor(action):
+                action = action.item()
+
+            env_step_out = env.step(action, *args, **kwargs)
+
+            env_step_out_torch = tree_map(transform_output, env_step_out)
+
+            if not self.has_reward_shaping:
+                return env_step_out_torch
+
+            shaped_rewards = self.state_to_rewards(env_step_out_torch)
+
+            return env_step_out_torch, shaped_rewards
+
+        return wrapped_reset, wrapped_step

     def get_stateful_forward(
         self,
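The new `ppo` method consumes a full (batch, time) rollout, computes GAE once over the whole sequence, then replays it window by window through the XL cache, stepping the optimizers at the end. A sketch of how it might be driven, assuming a configured `model` with a value head and a `buffer` as in the earlier sketch; the `log_prob`/`value`/`mask` field names are hypothetical user-chosen replay fields, and the actor parameter group is a stand-in:

```python
from torch.optim import Adam

critic_optim = Adam(model.critic_parameters(), lr = 3e-4)
actor_optim = Adam(model.parameters(), lr = 3e-4)  # stand-in; use the actual actor parameter group

for batch in buffer.dataloader(batch_size = 2):
    actor_loss, critic_loss = model.ppo(
        batch['state'],
        batch['action'],
        batch['log_prob'],   # old action log probs, stored at rollout time
        batch['reward'],
        batch['value'],      # old value predictions, stored at rollout time
        batch['mask'],
        batch['_lens'],      # per-sample episode lengths, added by the collation
        actor_optim = actor_optim,
        critic_optim = critic_optim
    )
```

Rollouts themselves can be gathered with the new `wrap_env_functions`, which converts numpy and scalar outputs of `env.reset` / `env.step` to tensors and, when `reward_shaping_fns` is set, appends shaped rewards to each step's output.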
@@ -398,6 +933,7 @@ class Locoformer(Module):
         inference_mode = False,
         has_batch_dim = False,
         has_time_dim = False,
+        state_time_dim = 1,
         **kwargs
     ):
         window_size = self.window_size
@@ -413,23 +949,16 @@
             state = rearrange(state, '... -> 1 ...')

         if not has_time_dim:
-            state =
+            state = state.unsqueeze(state_time_dim)

         # forwards

         out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})

-        # handle cache
-
-        cache_len = cache.shape[-2]
-
-        if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
-            cache = cache[..., -window_size:, :]
-
         # maybe remove batch or time

         if not has_time_dim:
-            out = tree_map_tensor(out, lambda t:
+            out = tree_map_tensor(out, lambda t: t.squeeze(state_time_dim))

         if not has_batch_dim:
             out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -458,14 +987,35 @@
     def forward(
         self,
         state: Tensor,
-        cache:
+        cache: Cache | None = None,
         detach_cache = False,
-        return_values = False
+        return_values = False,
+        return_raw_value_logits = False
     ):

+        state = state.to(self.device)
+
         tokens = self.embedder(state)

-
+        # time
+
+        time = tokens.shape[-2]
+
+        # destruct the cache for the current timestep and the cache
+
+        prev_kv_cache = None
+        timestep_start = 0
+
+        if exists(cache):
+            timestep_start, prev_kv_cache = cache
+
+        # an assert - make sure during training or inference, forward never gets anything that crosses the window segment boundary, to open up some possibilities with extending memory
+
+        assert ((timestep_start % self.window_size) + time) <= self.window_size
+
+        # attention
+
+        embed, kv_cache = self.transformer(tokens, cache = prev_kv_cache, return_kv_cache = True)

 # unembed to actions - in language models this would be the next state

@@ -476,21 +1026,34 @@
         # maybe detach cache

         if detach_cache:
-            kv_cache =
+            kv_cache = kv_cache.detach()

         # handle returning of values

         if return_values:
-            assert exists(self.
+            assert exists(self.to_value_pred)

-            values = self.
+            values = self.to_value_pred(embed)

-            if
-
-                values = rearrange(values, '... 1 -> ...')
+            if not return_raw_value_logits:
+                values = self.hl_gauss_loss(values) # converts the value logits to scalar values

             out = (out, values)

         # output and cache

-
+        next_timestep = time + timestep_start
+
+        # handle curtailing kv cache at the right intervals
+
+        window_size = self.window_size
+
+        if self.fixed_window_size or divisible_by(next_timestep, window_size * 2):
+            kv_cache = kv_cache[..., -window_size:, :]
+
+        # maybe recurrent cache - shift the kv cache from one layer above to the one below, for extending on receptive field of past
+
+        if self.recurrent_kv_cache and divisible_by(next_timestep, window_size):
+            kv_cache = torch.roll(kv_cache, shifts = -1, dims = 0)
+
+        return out, (next_timestep, kv_cache)
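`forward` now threads a `(next_timestep, kv_cache)` tuple (the fields of the `Cache` namedtuple) instead of raw key/values, and the cache trimming that previously lived in `get_stateful_forward` happens inside `forward`, along with the optional layer-rolled recurrent cache at window boundaries. A sketch of single-step inference under that contract, assuming a configured `model` and a hypothetical flat state of shape `(5,)`:

```python
import torch

cache = None

for _ in range(3):
    state = torch.randn(1, 1, 5)  # (batch, time = 1, state dim)

    # each call advances the timestep counter and returns the trimmed kv cache
    action_logits, cache = model.forward(state, cache = cache)

next_timestep, kv_cache = cache
assert next_timestep == 3
```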
{locoformer-0.0.7.dist-info → locoformer-0.0.29.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.7
+Version: 0.0.29
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,8 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
+Requires-Dist: hl-gauss-pytorch>=0.2.0
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: torch>=2.4
 Requires-Dist: x-mlps-pytorch
@@ -53,7 +55,7 @@ Description-Content-Type: text/markdown

 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment)
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) and extreme domain randomization. When transferred to the real world, the robot gained the ability to adapt to insults. The XL memories span multiple trials, which allowed the robot to learn in-context adaptation.

 ## Sponsors

locoformer-0.0.29.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+locoformer/locoformer.py,sha256=Tr_1btuoTZ0huXeDcAeuHxTPaVeCUEGc5iLvMYGDLck,29982
+locoformer-0.0.29.dist-info/METADATA,sha256=5Fi3EOsgpBvpzAFVZQyrlink-HcHE8EgFl10Y5l8mqM,3256
+locoformer-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+locoformer-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+locoformer-0.0.29.dist-info/RECORD,,
locoformer-0.0.7.dist-info/RECORD
REMOVED
@@ -1,6 +0,0 @@
-locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
-locoformer/locoformer.py,sha256=lJQs0CKr9iztF8tie1FRUVEItCt-IZbIILQqKcgK2sI,13142
-locoformer-0.0.7.dist-info/METADATA,sha256=PZ_phKV3t4Bha0GnUB5HPmE9w8A5fvNevsuN532Ls3s,3193
-locoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-locoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-locoformer-0.0.7.dist-info/RECORD,,
{locoformer-0.0.7.dist-info → locoformer-0.0.29.dist-info}/WHEEL
File without changes

{locoformer-0.0.7.dist-info → locoformer-0.0.29.dist-info}/licenses/LICENSE
File without changes