locoformer 0.0.5__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- locoformer/locoformer.py +454 -30
- {locoformer-0.0.5.dist-info → locoformer-0.0.15.dist-info}/METADATA +3 -2
- locoformer-0.0.15.dist-info/RECORD +6 -0
- locoformer-0.0.5.dist-info/RECORD +0 -6
- {locoformer-0.0.5.dist-info → locoformer-0.0.15.dist-info}/WHEEL +0 -0
- {locoformer-0.0.5.dist-info → locoformer-0.0.15.dist-info}/licenses/LICENSE +0 -0
locoformer/locoformer.py
CHANGED
|
@@ -1,12 +1,25 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from functools import partial
|
|
3
3
|
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from collections import namedtuple
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from numpy import ndarray
|
|
10
|
+
from numpy.lib.format import open_memmap
|
|
11
|
+
|
|
12
|
+
from beartype import beartype
|
|
13
|
+
from beartype.door import is_bearable
|
|
14
|
+
|
|
4
15
|
import torch
|
|
5
|
-
from torch import cat, stack, is_tensor
|
|
16
|
+
from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
|
|
6
17
|
import torch.nn.functional as F
|
|
7
|
-
from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
|
|
18
|
+
from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
|
|
8
19
|
from torch.utils._pytree import tree_map
|
|
20
|
+
from torch.utils.data import Dataset, DataLoader
|
|
9
21
|
|
|
22
|
+
import einx
|
|
10
23
|
from einops import rearrange, einsum
|
|
11
24
|
from einops.layers.torch import Rearrange
|
|
12
25
|
|
|
@@ -24,23 +37,39 @@ def exists(v):
|
|
|
24
37
|
def default(v, d):
|
|
25
38
|
return v if exists(v) else d
|
|
26
39
|
|
|
40
|
+
def first(arr):
|
|
41
|
+
return arr[0]
|
|
42
|
+
|
|
27
43
|
def divisible_by(num, den):
|
|
28
44
|
return (num % den) == 0
|
|
29
45
|
|
|
46
|
+
# tensor helpers
|
|
47
|
+
|
|
48
|
+
def log(t, eps = 1e-20):
|
|
49
|
+
return t.clamp_min(eps).log()
|
|
50
|
+
|
|
30
51
|
def tree_map_tensor(x, fn):
|
|
31
52
|
return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
|
|
32
53
|
|
|
33
54
|
def detach_all(x):
|
|
34
55
|
return tree_map_tensor(x, lambda t: t.detach())
|
|
35
56
|
|
|
36
|
-
def
|
|
37
|
-
|
|
57
|
+
def pad_at_dim(
|
|
58
|
+
t,
|
|
59
|
+
pad: tuple[int, int],
|
|
60
|
+
dim = -1,
|
|
61
|
+
value = 0.
|
|
62
|
+
):
|
|
63
|
+
if pad == (0, 0):
|
|
64
|
+
return t
|
|
38
65
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
66
|
+
dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
|
|
67
|
+
zeros = ((0, 0) * dims_from_right)
|
|
68
|
+
return F.pad(t, (*zeros, *pad), value = value)
|
|
42
69
|
|
|
43
|
-
|
|
70
|
+
def calc_entropy(logits):
|
|
71
|
+
prob = logits.softmax(dim = -1)
|
|
72
|
+
return -(prob * log(prob)).sum(dim = -1)
|
|
44
73
|
|
|
45
74
|
# generalized advantage estimate
|
|
46
75
|
|
|
@@ -48,7 +77,7 @@ def combine_kv_cache(cache1, cache2):
|
|
|
48
77
|
def calc_gae(
|
|
49
78
|
rewards,
|
|
50
79
|
values,
|
|
51
|
-
masks,
|
|
80
|
+
masks = None,
|
|
52
81
|
gamma = 0.99,
|
|
53
82
|
lam = 0.95,
|
|
54
83
|
use_accelerated = None
|
|
@@ -59,6 +88,9 @@ def calc_gae(
|
|
|
59
88
|
values = F.pad(values, (0, 1), value = 0.)
|
|
60
89
|
values, values_next = values[..., :-1], values[..., 1:]
|
|
61
90
|
|
|
91
|
+
if not exists(masks):
|
|
92
|
+
masks = torch.ones_like(values)
|
|
93
|
+
|
|
62
94
|
delta = rewards + gamma * values_next * masks - values
|
|
63
95
|
gates = gamma * lam * masks
|
|
64
96
|
|
|
@@ -110,8 +142,8 @@ def create_xl_mask(
|
|
|
110
142
|
# handle intra-episodic attention if needed
|
|
111
143
|
|
|
112
144
|
if exists(episode_ids):
|
|
113
|
-
q_episode =
|
|
114
|
-
k_episode =
|
|
145
|
+
q_episode = episode_ids[b, q + offset]
|
|
146
|
+
k_episode = episode_ids[b, k]
|
|
115
147
|
|
|
116
148
|
intra_episode_mask = q_episode == k_episode
|
|
117
149
|
mask = mask & intra_episode_mask
|
|
@@ -142,15 +174,229 @@ def create_sliding_mask(
|
|
|
142
174
|
create_kwargs = dict(device = device) if exists(device) else dict()
|
|
143
175
|
return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
|
|
144
176
|
|
|
177
|
+
# data
|
|
178
|
+
|
|
179
|
+
def collate_var_time(data):
|
|
180
|
+
|
|
181
|
+
datum = first(data)
|
|
182
|
+
keys = datum.keys()
|
|
183
|
+
|
|
184
|
+
all_tensors = zip(*[datum.values() for datum in data])
|
|
185
|
+
|
|
186
|
+
collated_values = []
|
|
187
|
+
|
|
188
|
+
for key, tensors in zip(keys, all_tensors):
|
|
189
|
+
|
|
190
|
+
# the episode lens have zero dimension - think of a cleaner way to handle this later
|
|
191
|
+
|
|
192
|
+
if key != '_lens':
|
|
193
|
+
|
|
194
|
+
times = [t.shape[0] for t in tensors]
|
|
195
|
+
max_time = max(times)
|
|
196
|
+
tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
|
|
197
|
+
|
|
198
|
+
collated_values.append(stack(tensors))
|
|
199
|
+
|
|
200
|
+
return dict(zip(keys, collated_values))
|
|
201
|
+
|
|
202
|
+
class ReplayDataset(Dataset):
|
|
203
|
+
def __init__(
|
|
204
|
+
self,
|
|
205
|
+
folder: str | Path,
|
|
206
|
+
fields: tuple[str, ...] | None = None
|
|
207
|
+
):
|
|
208
|
+
if isinstance(folder, str):
|
|
209
|
+
folder = Path(folder)
|
|
210
|
+
|
|
211
|
+
episode_lens = folder / 'episode_lens.npy'
|
|
212
|
+
self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
|
|
213
|
+
|
|
214
|
+
# get indices of non-zero lengthed episodes
|
|
215
|
+
|
|
216
|
+
nonzero_episodes = self.episode_lens > 0
|
|
217
|
+
self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
|
|
218
|
+
|
|
219
|
+
# get all data files
|
|
220
|
+
|
|
221
|
+
filepaths = [*folder.glob('*.data.npy')]
|
|
222
|
+
assert len(filepaths) > 0
|
|
223
|
+
|
|
224
|
+
fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
|
|
225
|
+
|
|
226
|
+
fieldnames_from_files = set(fieldname_to_filepath.keys())
|
|
227
|
+
|
|
228
|
+
fields = default(fields, fieldnames_from_files)
|
|
229
|
+
|
|
230
|
+
self.memmaps = dict()
|
|
231
|
+
|
|
232
|
+
for field in fields:
|
|
233
|
+
assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
|
|
234
|
+
|
|
235
|
+
path = fieldname_to_filepath[field]
|
|
236
|
+
|
|
237
|
+
self.memmaps[field] = open_memmap(str(path), mode = 'r')
|
|
238
|
+
|
|
239
|
+
def __len__(self):
|
|
240
|
+
return len(self.indices)
|
|
241
|
+
|
|
242
|
+
def __getitem__(self, idx):
|
|
243
|
+
episode_index = self.indices[idx]
|
|
244
|
+
|
|
245
|
+
episode_len = self.episode_lens[episode_index]
|
|
246
|
+
|
|
247
|
+
data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
|
|
248
|
+
|
|
249
|
+
data['_lens'] = tensor(episode_len)
|
|
250
|
+
|
|
251
|
+
return data
|
|
252
|
+
|
|
253
|
+
class ReplayBuffer:
|
|
254
|
+
|
|
255
|
+
@beartype
|
|
256
|
+
def __init__(
|
|
257
|
+
self,
|
|
258
|
+
folder: str | Path,
|
|
259
|
+
max_episodes: int,
|
|
260
|
+
max_timesteps: int,
|
|
261
|
+
fields: dict[
|
|
262
|
+
str,
|
|
263
|
+
str | tuple[str, int | tuple[int, ...]]
|
|
264
|
+
]
|
|
265
|
+
):
|
|
266
|
+
|
|
267
|
+
# folder for data
|
|
268
|
+
|
|
269
|
+
if not isinstance(folder, Path):
|
|
270
|
+
folder = Path(folder)
|
|
271
|
+
folder.mkdir(exist_ok = True)
|
|
272
|
+
|
|
273
|
+
self.folder = folder
|
|
274
|
+
assert folder.is_dir()
|
|
275
|
+
|
|
276
|
+
# keeping track of episode length
|
|
277
|
+
|
|
278
|
+
episode_lens = folder / 'episode_lens.npy'
|
|
279
|
+
|
|
280
|
+
self.episode_index = 0
|
|
281
|
+
self.timestep_index = 0
|
|
282
|
+
|
|
283
|
+
self.max_episodes = max_episodes
|
|
284
|
+
self.max_timesteps= max_timesteps
|
|
285
|
+
|
|
286
|
+
self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
|
|
287
|
+
|
|
288
|
+
# create the memmap for individual data tracks
|
|
289
|
+
|
|
290
|
+
self.shapes = dict()
|
|
291
|
+
self.dtypes = dict()
|
|
292
|
+
self.memmaps = dict()
|
|
293
|
+
self.fieldnames = set(fields.keys())
|
|
294
|
+
|
|
295
|
+
for field_name, field_info in fields.items():
|
|
296
|
+
|
|
297
|
+
# some flexibility
|
|
298
|
+
|
|
299
|
+
field_info = (field_info, ()) if isinstance(field_info, str) else field_info
|
|
300
|
+
|
|
301
|
+
dtype_str, shape = field_info
|
|
302
|
+
assert dtype_str in {'int', 'float', 'bool'}
|
|
303
|
+
|
|
304
|
+
dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
|
|
305
|
+
|
|
306
|
+
# memmap file
|
|
307
|
+
|
|
308
|
+
filepath = folder / f'{field_name}.data.npy'
|
|
309
|
+
memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
|
|
310
|
+
|
|
311
|
+
self.memmaps[field_name] = memmap
|
|
312
|
+
self.shapes[field_name] = shape
|
|
313
|
+
self.dtypes[field_name] = dtype
|
|
314
|
+
|
|
315
|
+
self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
|
|
316
|
+
|
|
317
|
+
def reset_(self):
|
|
318
|
+
self.episode_lens[:] = 0
|
|
319
|
+
self.episode_index = 0
|
|
320
|
+
self.timestep_index = 0
|
|
321
|
+
|
|
322
|
+
def advance_episode(self):
|
|
323
|
+
self.episode_index = (self.episode_index + 1) % self.max_episodes
|
|
324
|
+
self.timestep_index = 0
|
|
325
|
+
|
|
326
|
+
def flush(self):
|
|
327
|
+
self.episode_lens[self.episode_index] = self.timestep_index
|
|
328
|
+
|
|
329
|
+
for memmap in self.memmaps.values():
|
|
330
|
+
memmap.flush()
|
|
331
|
+
|
|
332
|
+
self.episode_lens.flush()
|
|
333
|
+
|
|
334
|
+
@contextmanager
|
|
335
|
+
def one_episode(self):
|
|
336
|
+
|
|
337
|
+
yield
|
|
338
|
+
|
|
339
|
+
self.flush()
|
|
340
|
+
self.advance_episode()
|
|
341
|
+
|
|
342
|
+
@beartype
|
|
343
|
+
def store_datapoint(
|
|
344
|
+
self,
|
|
345
|
+
episode_index: int,
|
|
346
|
+
timestep_index: int,
|
|
347
|
+
name: str,
|
|
348
|
+
datapoint: Tensor | ndarray
|
|
349
|
+
):
|
|
350
|
+
assert 0 <= episode_index < self.max_episodes
|
|
351
|
+
assert 0 <= timestep_index < self.max_timesteps
|
|
352
|
+
|
|
353
|
+
if is_tensor(datapoint):
|
|
354
|
+
datapoint = datapoint.detach().cpu().numpy()
|
|
355
|
+
|
|
356
|
+
assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
|
|
357
|
+
|
|
358
|
+
assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
|
|
359
|
+
|
|
360
|
+
self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
|
|
361
|
+
|
|
362
|
+
def store(
|
|
363
|
+
self,
|
|
364
|
+
**data
|
|
365
|
+
):
|
|
366
|
+
assert is_bearable(data, dict[str, Tensor | ndarray])
|
|
367
|
+
|
|
368
|
+
assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
|
|
369
|
+
|
|
370
|
+
for name, datapoint in data.items():
|
|
371
|
+
|
|
372
|
+
self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
|
|
373
|
+
|
|
374
|
+
self.timestep_index += 1
|
|
375
|
+
|
|
376
|
+
return self.memory_namedtuple(**data)
|
|
377
|
+
|
|
378
|
+
def dataset(self) -> Dataset:
|
|
379
|
+
self.flush()
|
|
380
|
+
|
|
381
|
+
return ReplayDataset(self.folder)
|
|
382
|
+
|
|
383
|
+
def dataloader(self, batch_size, **kwargs) -> DataLoader:
|
|
384
|
+
self.flush()
|
|
385
|
+
|
|
386
|
+
return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
|
|
387
|
+
|
|
145
388
|
# transformer-xl with ppo
|
|
146
389
|
|
|
147
390
|
class Attention(Module):
|
|
148
391
|
def __init__(
|
|
149
392
|
self,
|
|
150
393
|
dim,
|
|
394
|
+
window_size,
|
|
151
395
|
dim_head = 64,
|
|
152
396
|
heads = 8,
|
|
153
|
-
pre_rmsnorm = True
|
|
397
|
+
pre_rmsnorm = True,
|
|
398
|
+
fixed_window_size = False,
|
|
399
|
+
accept_value_residual = False
|
|
154
400
|
):
|
|
155
401
|
super().__init__()
|
|
156
402
|
self.scale = dim_head ** -0.5
|
|
@@ -167,20 +413,54 @@ class Attention(Module):
|
|
|
167
413
|
self.to_kv = LinearNoBias(dim, dim_inner * 2)
|
|
168
414
|
self.to_out = LinearNoBias(dim_inner, dim)
|
|
169
415
|
|
|
416
|
+
self.to_v_gates = Sequential(
|
|
417
|
+
LinearNoBias(dim, heads),
|
|
418
|
+
Rearrange('b n h -> b h n 1'),
|
|
419
|
+
nn.Sigmoid()
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
# value residual
|
|
423
|
+
|
|
424
|
+
self.accept_value_residual = accept_value_residual
|
|
425
|
+
|
|
426
|
+
if accept_value_residual:
|
|
427
|
+
self.to_value_residual_mix = Sequential(
|
|
428
|
+
LinearNoBias(dim, heads),
|
|
429
|
+
Rearrange('b n h -> b h n 1'),
|
|
430
|
+
nn.Sigmoid()
|
|
431
|
+
)
|
|
432
|
+
|
|
433
|
+
# fixed window size
|
|
434
|
+
|
|
435
|
+
self.fixed_window_size = fixed_window_size
|
|
436
|
+
self.window_size = window_size
|
|
437
|
+
|
|
170
438
|
def forward(
|
|
171
439
|
self,
|
|
172
440
|
tokens,
|
|
441
|
+
value_residual = None,
|
|
173
442
|
kv_cache = None,
|
|
174
|
-
return_kv_cache = False
|
|
443
|
+
return_kv_cache = False,
|
|
175
444
|
):
|
|
445
|
+
seq_len = tokens.shape[-2]
|
|
446
|
+
|
|
447
|
+
device = tokens.device
|
|
448
|
+
|
|
176
449
|
tokens = self.norm(tokens)
|
|
177
450
|
|
|
178
451
|
q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))
|
|
179
452
|
|
|
180
453
|
q, k, v = map(self.split_heads, (q, k, v))
|
|
181
454
|
|
|
455
|
+
orig_v = v
|
|
456
|
+
|
|
182
457
|
q = q * self.scale
|
|
183
458
|
|
|
459
|
+
if exists(value_residual):
|
|
460
|
+
assert self.accept_value_residual
|
|
461
|
+
mix = self.to_value_residual_mix(tokens)
|
|
462
|
+
v = v.lerp(value_residual, mix)
|
|
463
|
+
|
|
184
464
|
if exists(kv_cache):
|
|
185
465
|
ck, cv = kv_cache
|
|
186
466
|
k = cat((ck, k), dim = -2)
|
|
@@ -195,7 +475,13 @@ class Attention(Module):
|
|
|
195
475
|
|
|
196
476
|
i, j = sim.shape[-2:]
|
|
197
477
|
|
|
198
|
-
|
|
478
|
+
if self.fixed_window_size:
|
|
479
|
+
i_seq = arange(i, device = device)
|
|
480
|
+
j_seq = arange(j, device = device) - (j - i)
|
|
481
|
+
dist = einx.subtract('i, j -> i j', i_seq, j_seq)
|
|
482
|
+
causal_mask = (dist < 0) | (dist > self.window_size)
|
|
483
|
+
else:
|
|
484
|
+
causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
|
|
199
485
|
|
|
200
486
|
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
|
|
201
487
|
|
|
@@ -203,6 +489,8 @@ class Attention(Module):
|
|
|
203
489
|
|
|
204
490
|
out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
|
|
205
491
|
|
|
492
|
+
out = out * self.to_v_gates(tokens)
|
|
493
|
+
|
|
206
494
|
out = self.merge_heads(out)
|
|
207
495
|
|
|
208
496
|
out = self.to_out(out)
|
|
@@ -210,7 +498,7 @@ class Attention(Module):
|
|
|
210
498
|
if not return_kv_cache:
|
|
211
499
|
return out
|
|
212
500
|
|
|
213
|
-
return out, next_kv_cache
|
|
501
|
+
return out, (next_kv_cache, orig_v)
|
|
214
502
|
|
|
215
503
|
class FeedForward(Module):
|
|
216
504
|
def __init__(
|
|
@@ -244,17 +532,21 @@ class TransformerXL(Module):
|
|
|
244
532
|
self,
|
|
245
533
|
dim,
|
|
246
534
|
depth,
|
|
535
|
+
window_size,
|
|
247
536
|
dim_head = 64,
|
|
248
537
|
heads = 8,
|
|
249
538
|
expansion_factor = 4.,
|
|
250
|
-
final_norm = True
|
|
539
|
+
final_norm = True,
|
|
540
|
+
fixed_window_size = False,
|
|
251
541
|
):
|
|
252
542
|
super().__init__()
|
|
253
543
|
|
|
254
544
|
layers = ModuleList([])
|
|
255
545
|
|
|
256
|
-
for
|
|
257
|
-
|
|
546
|
+
for i in range(depth):
|
|
547
|
+
is_first = i == 0
|
|
548
|
+
|
|
549
|
+
attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)
|
|
258
550
|
|
|
259
551
|
ff = FeedForward(dim = dim, expansion_factor = expansion_factor)
|
|
260
552
|
|
|
@@ -265,6 +557,11 @@ class TransformerXL(Module):
|
|
|
265
557
|
self.layers = layers
|
|
266
558
|
self.norm = RMSNorm(dim) if final_norm else Identity()
|
|
267
559
|
|
|
560
|
+
# fixed window size
|
|
561
|
+
|
|
562
|
+
self.fixed_window_size = fixed_window_size
|
|
563
|
+
self.window_size = window_size
|
|
564
|
+
|
|
268
565
|
def forward(
|
|
269
566
|
self,
|
|
270
567
|
x,
|
|
@@ -275,22 +572,28 @@ class TransformerXL(Module):
|
|
|
275
572
|
cache = default(cache, (None,) * len(self.layers))
|
|
276
573
|
|
|
277
574
|
next_kv_caches = []
|
|
575
|
+
value_residual = None
|
|
278
576
|
|
|
279
577
|
for (attn, ff), kv_cache in zip(self.layers, cache):
|
|
280
578
|
|
|
281
|
-
attn_out, next_kv_cache = attn(x, kv_cache = kv_cache, return_kv_cache = True)
|
|
282
|
-
|
|
283
|
-
next_kv_caches.append(next_kv_cache)
|
|
579
|
+
attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)
|
|
284
580
|
|
|
285
581
|
x = attn_out + x
|
|
286
582
|
x = ff(x) + x
|
|
287
583
|
|
|
584
|
+
next_kv_caches.append(next_kv_cache)
|
|
585
|
+
value_residual = default(value_residual, values)
|
|
586
|
+
|
|
288
587
|
embed = self.norm(x)
|
|
289
588
|
|
|
290
589
|
if not return_kv_cache:
|
|
291
590
|
return embed
|
|
292
591
|
|
|
293
|
-
|
|
592
|
+
next_kv_cache = stack(next_kv_caches)
|
|
593
|
+
|
|
594
|
+
next_kv_cache = next_kv_cache[..., -self.window_size:, :]
|
|
595
|
+
|
|
596
|
+
return embed, next_kv_cache
|
|
294
597
|
|
|
295
598
|
# class
|
|
296
599
|
|
|
@@ -300,7 +603,13 @@ class Locoformer(Module):
|
|
|
300
603
|
embedder: Module,
|
|
301
604
|
unembedder: Module,
|
|
302
605
|
transformer: dict | TransformerXL,
|
|
303
|
-
value_network: Module | None = None
|
|
606
|
+
value_network: Module | None = None,
|
|
607
|
+
discount_factor = 0.999,
|
|
608
|
+
gae_lam = 0.95,
|
|
609
|
+
ppo_eps_clip = 0.2,
|
|
610
|
+
ppo_entropy_weight = 0.01,
|
|
611
|
+
ppo_value_clip = 0.4,
|
|
612
|
+
value_loss_weight = 0.5
|
|
304
613
|
):
|
|
305
614
|
super().__init__()
|
|
306
615
|
|
|
@@ -314,28 +623,138 @@ class Locoformer(Module):
|
|
|
314
623
|
|
|
315
624
|
self.value_network = value_network
|
|
316
625
|
|
|
626
|
+
self.fixed_window_size = transformer.fixed_window_size
|
|
627
|
+
self.window_size = transformer.window_size
|
|
628
|
+
|
|
629
|
+
# ppo related
|
|
630
|
+
|
|
631
|
+
self.discount_factor = discount_factor
|
|
632
|
+
self.gae_lam = gae_lam
|
|
633
|
+
self.ppo_eps_clip = ppo_eps_clip
|
|
634
|
+
self.ppo_entropy_weight = ppo_entropy_weight
|
|
635
|
+
self.ppo_value_clip = ppo_value_clip
|
|
636
|
+
self.value_loss_weight = value_loss_weight
|
|
637
|
+
|
|
317
638
|
@property
|
|
318
639
|
def device(self):
|
|
319
640
|
return next(self.parameters()).device
|
|
320
641
|
|
|
642
|
+
def actor_parameters(self):
|
|
643
|
+
return self.unembedder.parameters()
|
|
644
|
+
|
|
645
|
+
def critic_parameters(self):
|
|
646
|
+
if not exists(self.value_network):
|
|
647
|
+
return []
|
|
648
|
+
|
|
649
|
+
return self.value_network.parameters()
|
|
650
|
+
|
|
651
|
+
def ppo(
|
|
652
|
+
self,
|
|
653
|
+
state,
|
|
654
|
+
action,
|
|
655
|
+
old_action_log_prob,
|
|
656
|
+
reward,
|
|
657
|
+
old_value,
|
|
658
|
+
mask,
|
|
659
|
+
actor_optim,
|
|
660
|
+
critic_optim
|
|
661
|
+
):
|
|
662
|
+
|
|
663
|
+
(action_logits, value), _ = self.forward(state, return_values = True)
|
|
664
|
+
entropy = calc_entropy(action_logits)
|
|
665
|
+
|
|
666
|
+
action = rearrange(action, 'b t -> b t 1')
|
|
667
|
+
log_prob = action_logits.gather(-1, action)
|
|
668
|
+
log_prob = rearrange(log_prob, 'b t 1 -> b t')
|
|
669
|
+
|
|
670
|
+
# update actor, classic clipped surrogate loss
|
|
671
|
+
|
|
672
|
+
eps_clip = self.ppo_eps_clip
|
|
673
|
+
ratio = (log_prob - old_action_log_prob).exp()
|
|
674
|
+
|
|
675
|
+
returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor)
|
|
676
|
+
advantage = returns - old_value
|
|
677
|
+
|
|
678
|
+
actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
|
|
679
|
+
|
|
680
|
+
actor_loss = actor_loss - self.ppo_entropy_weight * entropy
|
|
681
|
+
|
|
682
|
+
mean_actor_loss = actor_loss[mask].mean()
|
|
683
|
+
mean_actor_loss.backward(retain_graph = True)
|
|
684
|
+
|
|
685
|
+
# update critic
|
|
686
|
+
|
|
687
|
+
value_loss = F.mse_loss(returns, value, reduction = 'none')
|
|
688
|
+
|
|
689
|
+
value_clip = self.ppo_value_clip
|
|
690
|
+
clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
|
|
691
|
+
clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
|
|
692
|
+
|
|
693
|
+
critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
|
|
694
|
+
|
|
695
|
+
mean_critic_loss = critic_loss[mask].mean()
|
|
696
|
+
mean_critic_loss.backward()
|
|
697
|
+
|
|
698
|
+
# optimizer update
|
|
699
|
+
|
|
700
|
+
actor_optim.step()
|
|
701
|
+
actor_optim.zero_grad()
|
|
702
|
+
|
|
703
|
+
critic_optim.step()
|
|
704
|
+
critic_optim.zero_grad()
|
|
705
|
+
|
|
706
|
+
# return losses for logging
|
|
707
|
+
|
|
708
|
+
return mean_actor_loss.detach(), mean_critic_loss.detach()
|
|
709
|
+
|
|
710
|
+
def wrap_env_functions(self, env):
|
|
711
|
+
|
|
712
|
+
def wrapped_reset(*args, **kwargs):
|
|
713
|
+
state, _ = env.reset(*args, **kwargs)
|
|
714
|
+
|
|
715
|
+
if isinstance(state, ndarray):
|
|
716
|
+
state = from_numpy(state)
|
|
717
|
+
|
|
718
|
+
return state, _
|
|
719
|
+
|
|
720
|
+
def wrapped_step(action, *args, **kwargs):
|
|
721
|
+
out = env.step(action.item(), *args, **kwargs)
|
|
722
|
+
|
|
723
|
+
def transform_output(el):
|
|
724
|
+
if isinstance(el, ndarray):
|
|
725
|
+
return from_numpy(el)
|
|
726
|
+
elif isinstance(el, (int, bool, float)):
|
|
727
|
+
return tensor(el)
|
|
728
|
+
else:
|
|
729
|
+
return el
|
|
730
|
+
|
|
731
|
+
return tree_map(transform_output, out)
|
|
732
|
+
|
|
733
|
+
return wrapped_reset, wrapped_step
|
|
734
|
+
|
|
321
735
|
def get_stateful_forward(
|
|
322
736
|
self,
|
|
323
|
-
segment_size,
|
|
324
737
|
initial_states: Tensor | None = None,
|
|
325
738
|
inference_mode = False,
|
|
326
739
|
has_batch_dim = False,
|
|
740
|
+
has_time_dim = False,
|
|
327
741
|
**kwargs
|
|
328
742
|
):
|
|
743
|
+
window_size = self.window_size
|
|
744
|
+
|
|
329
745
|
cache = None
|
|
330
746
|
|
|
331
|
-
def stateful_forward(state: Tensor, override_kwargs
|
|
747
|
+
def stateful_forward(state: Tensor, **override_kwargs):
|
|
332
748
|
nonlocal cache
|
|
333
749
|
|
|
334
|
-
# handle no batch, for easier time rolling out against envs
|
|
750
|
+
# handle no batch or time, for easier time rolling out against envs
|
|
335
751
|
|
|
336
752
|
if not has_batch_dim:
|
|
337
753
|
state = rearrange(state, '... -> 1 ...')
|
|
338
754
|
|
|
755
|
+
if not has_time_dim:
|
|
756
|
+
state = rearrange(state, '... d -> ... 1 d')
|
|
757
|
+
|
|
339
758
|
# forwards
|
|
340
759
|
|
|
341
760
|
out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
|
|
@@ -344,10 +763,13 @@ class Locoformer(Module):
|
|
|
344
763
|
|
|
345
764
|
cache_len = cache.shape[-2]
|
|
346
765
|
|
|
347
|
-
if divisible_by(cache_len,
|
|
348
|
-
cache = cache[..., -
|
|
766
|
+
if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
|
|
767
|
+
cache = cache[..., -window_size:, :]
|
|
349
768
|
|
|
350
|
-
# maybe remove batch
|
|
769
|
+
# maybe remove batch or time
|
|
770
|
+
|
|
771
|
+
if not has_time_dim:
|
|
772
|
+
out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
|
|
351
773
|
|
|
352
774
|
if not has_batch_dim:
|
|
353
775
|
out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
|
|
@@ -364,7 +786,7 @@ class Locoformer(Module):
|
|
|
364
786
|
|
|
365
787
|
initial_logits = []
|
|
366
788
|
|
|
367
|
-
for state_segments in initial_states.split(
|
|
789
|
+
for state_segments in initial_states.split(self.window_size, dim = -1):
|
|
368
790
|
|
|
369
791
|
logits = stateful_forward(state_segments, return_values = False)
|
|
370
792
|
initial_logits.append(logits)
|
|
@@ -381,6 +803,8 @@ class Locoformer(Module):
|
|
|
381
803
|
return_values = False
|
|
382
804
|
):
|
|
383
805
|
|
|
806
|
+
state = state.to(self.device)
|
|
807
|
+
|
|
384
808
|
tokens = self.embedder(state)
|
|
385
809
|
|
|
386
810
|
embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: locoformer
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.15
|
|
4
4
|
Summary: LocoFormer
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/locoformer/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/locoformer
|
|
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: assoc-scan
|
|
38
|
+
Requires-Dist: beartype
|
|
38
39
|
Requires-Dist: einops>=0.8.0
|
|
39
40
|
Requires-Dist: einx>=0.3.0
|
|
40
41
|
Requires-Dist: rotary-embedding-torch
|
|
@@ -53,7 +54,7 @@ Description-Content-Type: text/markdown
|
|
|
53
54
|
|
|
54
55
|
[LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
|
|
55
56
|
|
|
56
|
-
The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
|
|
57
|
+
The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
|
|
57
58
|
|
|
58
59
|
## Sponsors
|
|
59
60
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
|
|
2
|
+
locoformer/locoformer.py,sha256=1jPK41G4HB1PEPtlusQxcrne489E-3QKXAULZ20FEZM,22740
|
|
3
|
+
locoformer-0.0.15.dist-info/METADATA,sha256=IHtK7NvVQewYQ0GBB7v1KG90_H2Jakxir0MakUIA-jU,3218
|
|
4
|
+
locoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
locoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
locoformer-0.0.15.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
|
|
2
|
-
locoformer/locoformer.py,sha256=Yoh3hrj2E_91YLoYRa73wGzjdIiMdcd5ofNjkiVlogI,10570
|
|
3
|
-
locoformer-0.0.5.dist-info/METADATA,sha256=oe6HfOwWKQvusiJl1ukmNFcrGRhdDZ6NcKZi3upv-SY,3159
|
|
4
|
-
locoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
locoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
locoformer-0.0.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|