evolutionary-policy-optimization 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
- evolutionary_policy_optimization/epo.py +111 -55
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.15.dist-info}/METADATA +10 -1
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.15.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.15.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.15.dist-info}/licenses/LICENSE +0 -0
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -76,6 +76,14 @@ def maybe(fn):
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
+    def to_torch_tensor(t):
+        if isinstance(t, (np.ndarray, np.float64)):
+            t = from_numpy(np.array(t))
+        elif isinstance(t, (float, int, bool)):
+            t = tensor(t)
+
+        return t.to(device)
+
     @wraps(fn)
     def decorated_fn(*args, **kwargs):
 
@@ -83,7 +91,7 @@ def interface_torch_numpy(fn, device):
 
         out = fn(*args, **kwargs)
 
-        out = tree_map(…
+        out = tree_map(to_torch_tensor, out)
         return out
 
     return decorated_fn
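Taken together, these two hunks replace the inline conversion previously passed to `tree_map` with a named `to_torch_tensor` closure that also covers plain Python scalars (`float`, `int`, `bool`). A self-contained sketch of the resulting decorator, output side only (the torch-to-numpy input conversion mentioned in the comment is elided here), assuming `tree_map` is `torch.utils._pytree.tree_map`:

```python
import numpy as np
from functools import wraps

from torch import tensor, from_numpy
from torch.utils._pytree import tree_map  # assumption: the tree_map used by epo.py

def interface_torch_numpy(fn, device):
    # convert all outputs from numpy arrays / python scalars to torch tensors on `device`
    # (the input-side torch -> numpy conversion is omitted in this sketch)

    def to_torch_tensor(t):
        if isinstance(t, (np.ndarray, np.float64)):
            t = from_numpy(np.array(t))
        elif isinstance(t, (float, int, bool)):
            t = tensor(t)
        return t.to(device)

    @wraps(fn)
    def decorated_fn(*args, **kwargs):
        out = fn(*args, **kwargs)
        return tree_map(to_torch_tensor, out)

    return decorated_fn

# a mock env step returning numpy + python scalars now comes back as tensors
step = interface_torch_numpy(lambda: (np.zeros(3), 1.0, False), device = 'cpu')
state, reward, done = step()
```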
@@ -279,43 +287,87 @@ class PowerLawDist(Module):
 
         return self.values[sampled]
 
+# layer integrated memory
+
+class DynamicLIMe(Module):
+    def __init__(
+        self,
+        dim,
+        num_layers
+    ):
+        super().__init__()
+        self.num_layers = num_layers
+
+        self.to_weights = nn.Sequential(
+            nn.RMSNorm(dim),
+            nn.Linear(dim, num_layers),
+            nn.ReLU()
+        )
+
+    def forward(
+        self,
+        x,
+        hiddens
+    ):
+
+        if not is_tensor(hiddens):
+            hiddens = stack(hiddens)
+
+        assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+        weights = self.to_weights(x)
+
+        return einsum(hiddens, weights, 'l b d, b l -> b d')
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
 
 class MLP(Module):
     def __init__(
         self,
-        …
+        dim,
+        depth,
         dim_latent = 0,
+        expansion_factor = 2.
     ):
         super().__init__()
         dim_latent = default(dim_latent, 0)
 
-        assert len(dims) >= 2, 'must have at least two dimensions'
-
-        # add the latent to the first dim
-
-        first_dim, *rest_dims = dims
-        dims = (first_dim + dim_latent, *rest_dims)
-
         self.dim_latent = dim_latent
 
         self.needs_latent = dim_latent > 0
 
         self.encode_latent = nn.Sequential(
-            Linear(dim_latent, …
+            Linear(dim_latent, dim),
             nn.SiLU()
         ) if self.needs_latent else None
 
-        …
+        dim_hidden = int(dim * expansion_factor)
 
-        …
+        # layers
 
-        …
+        layers = []
+
+        for ind in range(depth):
+            is_first = ind == 0
 
-        …
+            lime = DynamicLIMe(dim, num_layers = ind + 1) if not is_first else None
+
+            layer = nn.Sequential(
+                nn.RMSNorm(dim),
+                nn.Linear(dim, dim_hidden),
+                nn.SiLU(),
+                nn.Linear(dim_hidden, dim),
+            )
 
-        …
+            layers.append(ModuleList([
+                lime,
+                layer
+            ]))
+
+        # modules across layers
+
+        self.layers = ModuleList(layers)
 
     def forward(
         self,
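The new `DynamicLIMe` block ("layer integrated memory", per the Gerasimov et al. citation added to the README below) lets each layer read a learned, input-conditioned mixture of all previous layers' hidden states. A minimal shape-check sketch, with arbitrary sizes, assuming the class is importable from `evolutionary_policy_optimization.epo` as defined above:

```python
import torch
from evolutionary_policy_optimization.epo import DynamicLIMe

lime = DynamicLIMe(dim = 16, num_layers = 3)

x = torch.randn(2, 16)                            # current input, (batch, dim)
hiddens = [torch.randn(2, 16) for _ in range(3)]  # one hidden state per earlier layer

mixed = lime(x, hiddens)  # (batch, dim): per-sample weighted sum across the 3 layers
assert mixed.shape == (2, 16)
```

Note that `hiddens` may be a list (stacked internally) or an already-stacked tensor with layers as the first dimension, and the ReLU keeps the mixing weights nonnegative.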
@@ -337,17 +389,22 @@ class MLP(Module):
 
         assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
 
-        x = …
+        x = x * latent
 
         # layers
 
-        …
-        …
+        prev_layer_inputs = [x]
+
+        for lime, layer in self.layers:
 
-        …
+            layer_inp = x
 
-            if …
-        …
+            if exists(lime):
+                layer_inp = lime(x, prev_layer_inputs)
+
+            x = layer(layer_inp) + x
+
+            prev_layer_inputs.append(x)
 
         return x
 
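The reworked `MLP` is sized by a single width `dim` plus `depth` residual blocks (hidden width `dim * expansion_factor`) instead of a tuple of hidden dimensions, and the latent is now merged multiplicatively (`x = x * latent`) rather than concatenated onto the first dimension. A construction sketch with illustrative sizes, assuming `forward` takes the raw latent as its second argument, as in the Critic's `self.mlp(hidden, latent)` call further down:

```python
import torch
from evolutionary_policy_optimization.epo import MLP

mlp = MLP(dim = 32, depth = 2, dim_latent = 8)

x = torch.randn(4, 32)      # (batch, dim)
latent = torch.randn(4, 8)  # (batch, dim_latent), projected to dim then multiplied in

out = mlp(x, latent)        # (4, 32): residual blocks keep the width at dim
```

Keeping the width fixed across depth is what allows every non-first layer to feed a `DynamicLIMe` over all earlier hidden states, since the mixed hiddens must share one dimension.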
@@ -359,26 +416,24 @@ class Actor(Module):
         self,
         dim_state,
         num_actions,
-        …
+        dim,
+        mlp_depth,
         dim_latent = 0,
     ):
         super().__init__()
 
-        assert len(dim_hiddens) >= 2
-        dim_first, *_, dim_last = dim_hiddens
-
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
-            nn.Linear(dim_state, …
+            nn.Linear(dim_state, dim),
             nn.SiLU()
         )
 
-        self.mlp = MLP(…
+        self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)
 
         self.to_out = nn.Sequential(
-            nn.…
-            nn.Linear(…
+            nn.RMSNorm(dim),
+            nn.Linear(dim, num_actions, bias = False),
         )
 
     def forward(
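`Actor` inherits the same signature change: a single `dim` and `mlp_depth` replace the old `dim_hiddens` tuple. A construction sketch with illustrative values:

```python
from evolutionary_policy_optimization.epo import Actor

actor = Actor(
    dim_state = 512,
    num_actions = 5,
    dim = 256,
    mlp_depth = 2,
    dim_latent = 32
)
```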
@@ -397,34 +452,31 @@ class Critic(Module):
     def __init__(
         self,
         dim_state,
-        …
+        dim,
+        mlp_depth,
         dim_latent = 0,
         use_regression = False,
         hl_gauss_loss_kwargs: dict = dict(
-            min_value = -…
-            max_value = …
-            num_bins = …
-            sigma = 0.5
+            min_value = -100.,
+            max_value = 100.,
+            num_bins = 200
         )
     ):
         super().__init__()
 
-        assert len(dim_hiddens) >= 2
-        dim_first, *_, dim_last = dim_hiddens
-
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
-            nn.Linear(dim_state, …
+            nn.Linear(dim_state, dim),
             nn.SiLU()
         )
 
-        self.mlp = MLP(…
+        self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)
 
-        self.…
+        self.final_norm = nn.RMSNorm(dim)
 
         self.to_pred = HLGaussLayer(
-            dim = …
+            dim = dim,
             use_regression = use_regression,
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
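The `Critic` gets the same `dim` / `mlp_depth` treatment, plus a named `final_norm` and retuned HL-Gauss histogram-loss defaults (an explicit `sigma` is no longer passed). Another illustrative construction:

```python
from evolutionary_policy_optimization.epo import Critic

critic = Critic(
    dim_state = 512,
    dim = 256,
    mlp_depth = 3,
    dim_latent = 32
)
```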
@@ -488,7 +540,7 @@ class Critic(Module):
 
         hidden = self.mlp(hidden, latent)
 
-        hidden = self.…
+        hidden = self.final_norm(hidden)
 
         pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
         return self.to_pred(hidden, **pred_kwargs)
@@ -843,16 +895,16 @@ class Agent(Module):
         critic: Critic,
         latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
-        actor_lr = …
-        critic_lr = …
+        actor_lr = 8e-4,
+        critic_lr = 8e-4,
         latent_lr = 1e-5,
-        actor_weight_decay = …
-        critic_weight_decay = …
+        actor_weight_decay = 5e-4,
+        critic_weight_decay = 5e-4,
         diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
-        critic_ema_beta = 0.…
-        max_grad_norm = 0…
-        batch_size = …
+        critic_ema_beta = 0.95,
+        max_grad_norm = 1.0,
+        batch_size = 32,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
             gamma = 0.99,
@@ -1269,8 +1321,10 @@ def create_agent(
     num_latents,
     dim_latent,
     actor_num_actions,
-    …
-    …
+    actor_dim,
+    actor_mlp_depth,
+    critic_dim,
+    critic_mlp_depth,
     use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
@@ -1293,14 +1347,16 @@ def create_agent(
         num_actions = actor_num_actions,
         dim_state = dim_state,
         dim_latent = dim_latent,
-        …
+        dim = actor_dim,
+        mlp_depth = actor_mlp_depth,
         **actor_kwargs
     )
 
     critic = Critic(
         dim_state = dim_state,
         dim_latent = dim_latent,
-        …
+        dim = critic_dim,
+        mlp_depth = critic_mlp_depth,
         **critic_kwargs
     )
 
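For downstream users, the visible API change is in `create_agent`: the old per-network hidden-dimension tuples give way to an explicit width and depth for both actor and critic. A call sketch with illustrative values, assuming `create_agent` is re-exported at the package root as in the README:

```python
from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim = 256,
    actor_mlp_depth = 2,
    critic_dim = 256,
    critic_mlp_depth = 3
)
```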
@@ -1475,7 +1531,7 @@ class EPO(Module):
                     log_prob,
                     reward,
                     value,
-                    …
+                    terminated
                 )
 
                 memory = Memory(*tuple(t.cpu() for t in memory))
@@ -1487,7 +1543,7 @@ class EPO(Module):
                 if not terminated:
                     # add bootstrap value if truncated
 
-                    next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+                    next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
             memory_for_gae = memory._replace(
                 episode_id = invalid_episode,
--- a/evolutionary_policy_optimization-0.1.12.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.1.15.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.12
+Version: 0.1.15
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -226,4 +226,13 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@inproceedings{Gerasimov2025YouDN,
+    title   = {You Do Not Fully Utilize Transformer's Representation Capacity},
+    author  = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+    year    = {2025},
+    url     = {https://api.semanticscholar.org/CorpusID:276317819}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
--- a/evolutionary_policy_optimization-0.1.12.dist-info/RECORD
+++ b/evolutionary_policy_optimization-0.1.15.dist-info/RECORD
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=…
+evolutionary_policy_optimization/epo.py,sha256=iUxd7gbT1GPGwso4utTaxgtjcxvvGNA8AGGUpSOImTM,47108
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.12.dist-info/METADATA,sha256=…
-evolutionary_policy_optimization-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.12.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.15.dist-info/METADATA,sha256=e8ofJe5rpGIyEiMd3mJBU-2VjOfFJ8TpGGv7adSKjRM,7962
+evolutionary_policy_optimization-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.15.dist-info/RECORD,,