evolutionary-policy-optimization 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +63 -54
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py
@@ -76,6 +76,14 @@ def maybe(fn):
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor

+    def to_torch_tensor(t):
+        if isinstance(t, (np.ndarray, np.float64)):
+            t = from_numpy(np.array(t))
+        elif isinstance(t, (float, int, bool)):
+            t = tensor(t)
+
+        return t.to(device)
+
     @wraps(fn)
     def decorated_fn(*args, **kwargs):

@@ -83,7 +91,7 @@ def interface_torch_numpy(fn, device):

         out = fn(*args, **kwargs)

-        out = tree_map(
+        out = tree_map(to_torch_tensor, out)
         return out

     return decorated_fn
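The added `to_torch_tensor` helper replaces the previous inline `tree_map` mapping, so numpy arrays, numpy scalars and plain Python numbers returned by the wrapped numpy-facing function all come back as torch tensors on the target device. Below is a self-contained sketch of the same conversion logic, written independently of the package; the `device` default is an illustrative assumption, not the library's signature.

import numpy as np
import torch
from torch import tensor, from_numpy

def to_torch_tensor(t, device = 'cpu'):
    # numpy arrays and numpy scalars -> torch tensors via from_numpy
    if isinstance(t, (np.ndarray, np.float64)):
        t = from_numpy(np.array(t))
    # plain python numbers and bools -> 0-dim tensors
    elif isinstance(t, (float, int, bool)):
        t = tensor(t)
    # anything that is already a tensor is simply moved to the device
    return t.to(device)

assert to_torch_tensor(np.ones(3)).dtype == torch.float64
assert to_torch_tensor(True).dtype == torch.bool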
@@ -285,37 +293,42 @@ class PowerLawDist(Module):
 class MLP(Module):
     def __init__(
         self,
-
+        dim,
+        depth,
         dim_latent = 0,
+        expansion_factor = 2.
     ):
         super().__init__()
         dim_latent = default(dim_latent, 0)

-        assert len(dims) >= 2, 'must have at least two dimensions'
-
-        # add the latent to the first dim
-
-        first_dim, *rest_dims = dims
-        dims = (first_dim + dim_latent, *rest_dims)
-
         self.dim_latent = dim_latent

         self.needs_latent = dim_latent > 0

         self.encode_latent = nn.Sequential(
-            Linear(dim_latent,
+            Linear(dim_latent, dim),
             nn.SiLU()
         ) if self.needs_latent else None

-
+        dim_hidden = int(dim * expansion_factor)

-
+        # layers

-
+        layers = []

-
+        for _ in range(depth):
+            layer = nn.Sequential(
+                nn.LayerNorm(dim, bias = False),
+                nn.Linear(dim, dim_hidden),
+                nn.SiLU(),
+                nn.Linear(dim_hidden, dim),
+            )
+
+            layers.append(layer)
+
+        # modules across layers

-        self.layers = layers
+        self.layers = ModuleList(layers)

     def forward(
         self,
@@ -337,17 +350,14 @@ class MLP(Module):

         assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'

-        x =
+        x = x * latent

         # layers

         for ind, layer in enumerate(self.layers, start = 1):
             is_last = ind == len(self.layers)

-            x = layer(x)
-
-            if not is_last:
-                x = F.silu(x)
+            x = layer(x) + x

         return x

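The MLP is rebuilt around a fixed width: instead of a `dims` tuple it now takes `dim` and `depth`, stacks `depth` pre-norm feedforward blocks (LayerNorm, Linear, SiLU, Linear) with an `expansion_factor` hidden width, and the forward pass wraps each block in a residual connection (`x = layer(x) + x`) in place of the old inter-layer SiLU. A minimal standalone sketch of that block structure follows; it is not an import of the library's class and omits the latent-conditioning path.

import torch
from torch import nn

class ResidualMLP(nn.Module):
    def __init__(self, dim, depth, expansion_factor = 2.):
        super().__init__()
        dim_hidden = int(dim * expansion_factor)

        # depth pre-norm feedforward blocks of constant width
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim, bias = False),
                nn.Linear(dim, dim_hidden),
                nn.SiLU(),
                nn.Linear(dim_hidden, dim),
            )
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x) + x   # residual, as in the new MLP.forward
        return x

x = torch.randn(2, 128)
assert ResidualMLP(dim = 128, depth = 3)(x).shape == x.shape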
@@ -359,26 +369,24 @@ class Actor(Module):
         self,
         dim_state,
         num_actions,
-
+        dim,
+        mlp_depth,
         dim_latent = 0,
     ):
         super().__init__()

-        assert len(dim_hiddens) >= 2
-        dim_first, *_, dim_last = dim_hiddens
-
         self.dim_latent = dim_latent

         self.init_layer = nn.Sequential(
-            nn.Linear(dim_state,
+            nn.Linear(dim_state, dim),
             nn.SiLU()
         )

-        self.mlp = MLP(
+        self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)

         self.to_out = nn.Sequential(
-            nn.
-            nn.Linear(
+            nn.LayerNorm(dim, bias = False),
+            nn.Linear(dim, num_actions, bias = False),
         )

     def forward(
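Actor follows the same API change: the hidden-dimensions tuple is gone, and the constructor now takes `dim` and `mlp_depth`, projecting the state into `dim`, running the residual MLP, then a LayerNorm plus bias-free Linear to the action logits. A hedged construction sketch follows; the argument values are purely illustrative and the import path assumes the class stays in epo.py, as this diff shows.

from evolutionary_policy_optimization.epo import Actor

actor = Actor(
    dim_state = 512,
    num_actions = 5,
    dim = 256,        # width of the internal residual MLP
    mlp_depth = 2,    # number of residual blocks
    dim_latent = 32   # latent conditioning from the gene pool
)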
@@ -397,34 +405,31 @@ class Critic(Module):
     def __init__(
         self,
         dim_state,
-
+        dim,
+        mlp_depth,
         dim_latent = 0,
         use_regression = False,
         hl_gauss_loss_kwargs: dict = dict(
-            min_value = -
-            max_value =
-            num_bins =
-            sigma = 0.5
+            min_value = -100.,
+            max_value = 100.,
+            num_bins = 200
         )
     ):
         super().__init__()

-        assert len(dim_hiddens) >= 2
-        dim_first, *_, dim_last = dim_hiddens
-
         self.dim_latent = dim_latent

         self.init_layer = nn.Sequential(
-            nn.Linear(dim_state,
+            nn.Linear(dim_state, dim),
             nn.SiLU()
         )

-        self.mlp = MLP(
+        self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)

-        self.
+        self.final_norm = nn.LayerNorm(dim, bias = False)

         self.to_pred = HLGaussLayer(
-            dim =
+            dim = dim,
             use_regression = use_regression,
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
@@ -488,7 +493,7 @@ class Critic(Module):

         hidden = self.mlp(hidden, latent)

-        hidden = self.
+        hidden = self.final_norm(hidden)

         pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
         return self.to_pred(hidden, **pred_kwargs)
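Critic mirrors the Actor change (`dim` plus `mlp_depth`), adds an explicit `final_norm` LayerNorm before the HL-Gauss prediction head, and the HL-Gauss defaults become min_value -100., max_value 100. and 200 bins, with the fixed sigma dropped. A hedged construction sketch with illustrative values, under the same import-path assumption as above:

from evolutionary_policy_optimization.epo import Critic

critic = Critic(
    dim_state = 512,
    dim = 256,
    mlp_depth = 4,
    dim_latent = 32,
    use_regression = False  # keep the default HL-Gauss classification head
)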
@@ -843,16 +848,16 @@ class Agent(Module):
         critic: Critic,
         latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
-        actor_lr =
-        critic_lr =
+        actor_lr = 8e-4,
+        critic_lr = 8e-4,
         latent_lr = 1e-5,
-        actor_weight_decay =
-        critic_weight_decay =
+        actor_weight_decay = 5e-4,
+        critic_weight_decay = 5e-4,
         diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
-        critic_ema_beta = 0.
-        max_grad_norm = 0
-        batch_size =
+        critic_ema_beta = 0.95,
+        max_grad_norm = 1.0,
+        batch_size = 32,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
             gamma = 0.99,
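The Agent constructor now carries concrete defaults for these hyperparameters: learning rates of 8e-4, weight decay of 5e-4, a critic EMA decay of 0.95, gradient clipping at 1.0 and a batch size of 32. As a toy illustration (not the library's code) of what two of those defaults govern, here is gradient-norm clipping before an optimizer step and an exponential moving average of critic weights:

import torch
from torch import nn

critic = nn.Linear(8, 1)
critic_ema = nn.Linear(8, 1)
critic_ema.load_state_dict(critic.state_dict())

loss = critic(torch.randn(4, 8)).pow(2).mean()
loss.backward()

# max_grad_norm = 1.0: clip the gradient norm before stepping the optimizer
torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm = 1.0)

# critic_ema_beta = 0.95: decay of the exponential moving average kept of the critic
with torch.no_grad():
    for p_ema, p in zip(critic_ema.parameters(), critic.parameters()):
        p_ema.lerp_(p, 1. - 0.95)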
@@ -1269,8 +1274,10 @@ def create_agent(
     num_latents,
     dim_latent,
     actor_num_actions,
-
-
+    actor_dim,
+    actor_mlp_depth,
+    critic_dim,
+    critic_mlp_depth,
     use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
@@ -1293,14 +1300,16 @@ def create_agent(
         num_actions = actor_num_actions,
         dim_state = dim_state,
         dim_latent = dim_latent,
-
+        dim = actor_dim,
+        mlp_depth = actor_mlp_depth,
         **actor_kwargs
     )

     critic = Critic(
         dim_state = dim_state,
         dim_latent = dim_latent,
-
+        dim = critic_dim,
+        mlp_depth = critic_mlp_depth,
         **critic_kwargs
     )

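`create_agent` exposes the new sizing knobs directly, so actor and critic widths and depths are configured independently. A hedged call sketch: the four new keyword names come from this diff, `dim_state` and the other names are read from the surrounding hunks, the values are illustrative, and the top-level import assumes the factory is re-exported from the package root (otherwise import it from evolutionary_policy_optimization.epo).

from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim = 256,
    actor_mlp_depth = 2,
    critic_dim = 256,
    critic_mlp_depth = 4
)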
@@ -1475,7 +1484,7 @@ class EPO(Module):
             log_prob,
             reward,
             value,
-
+            terminated
         )

         memory = Memory(*tuple(t.cpu() for t in memory))
@@ -1487,7 +1496,7 @@ class EPO(Module):
         if not terminated:
             # add bootstrap value if truncated

-            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+            next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)

             memory_for_gae = memory._replace(
                 episode_id = invalid_episode,
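The rollout memory now records `terminated`, and when an episode ends by truncation the bootstrap value is taken from the critic EMA / unwrapped model rather than the online wrapped critic. For context, a toy, self-contained GAE sketch (not the library's implementation) showing why a next-state value is bootstrapped when a rollout is truncated instead of terminated:

import torch

gamma, lam = 0.99, 0.95
rewards = torch.tensor([1., 1., 1.])
values  = torch.tensor([0.5, 0.6, 0.7])

terminated = False                       # rollout was cut off, episode did not actually end
next_value = torch.tensor(0.65)          # critic's value estimate of the final state
bootstrap = next_value if not terminated else torch.zeros(())

values_ext = torch.cat((values, bootstrap[None]))
advantages = torch.zeros(3)

gae = 0.
for t in reversed(range(3)):
    delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
    gae = delta + gamma * lam * gae
    advantages[t] = gae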
{evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.
+Version: 0.1.14
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/RECORD
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=RTMVCo1joMEKIkqTQLsLgTeOuJVbvkNbX9hcOOL0oCw,46088
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.14.dist-info/METADATA,sha256=TCj1gzhViuNYP-TnVHCcvWHkTIyqiz-AAO-xhoVahBo,7625
+evolutionary_policy_optimization-0.1.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.14.dist-info/RECORD,,
{evolutionary_policy_optimization-0.1.12.dist-info → evolutionary_policy_optimization-0.1.14.dist-info}/WHEEL and /licenses/LICENSE: files without changes.