evolutionary-policy-optimization 0.0.48__tar.gz → 0.0.50__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/PKG-INFO +1 -1
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/epo.py +32 -7
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/tests/test_epo.py +2 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/requirements.txt +0 -0
{evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.48
+Version: 0.0.50
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -6,7 +6,7 @@ from collections import namedtuple
 from random import randrange
 
 import torch
-from torch import nn, cat, stack, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
@@ -412,7 +412,16 @@ class LatentGenePool(Module):
 
         self.can_migrate = num_islands > 1
         self.migrate_every = migrate_every
-        self.register_buffer('step', tensor(
+        self.register_buffer('step', tensor(1))
+
+    def get_distance(self):
+        # returns latent euclidean distance as proxy for diversity
+
+        latents = rearrange(self.latents, '(i p) g -> i p g', i = self.num_islands)
+
+        distance = torch.cdist(latents, latents)
+
+        return distance
 
     def advance_step_(self):
         self.step.add_(1)
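For context, here is a minimal standalone sketch (not part of the package; shapes and values are illustrative) of what the new get_distance step computes: the flat gene pool is grouped per island, and torch.cdist then yields pairwise euclidean distances within each island.

import torch
from einops import rearrange

num_islands, pop_size, dim_gene = 2, 4, 8
latents = torch.randn(num_islands * pop_size, dim_gene)             # flat gene pool, as stored on the module

grouped = rearrange(latents, '(i p) g -> i p g', i = num_islands)   # (islands, population, gene dim)
distance = torch.cdist(grouped, grouped)                            # (islands, population, population)

print(distance.shape)  # torch.Size([2, 4, 4]) - pairwise euclidean distances per island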
@@ -643,6 +652,7 @@ class Agent(Module):
         actor_lr = 1e-4,
         critic_lr = 1e-4,
         latent_lr = 1e-5,
+        diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
         max_grad_norm = 0.5,
@@ -698,6 +708,11 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
+        # promotes latents to be farther apart for diversity maintenance
+
+        self.has_diversity_loss = diversity_aux_loss_weight > 0.
+        self.diversity_aux_loss_weight = diversity_aux_loss_weight
+
     def save(self, path, overwrite = False):
         path = Path(path)
 
@@ -879,11 +894,19 @@ class Agent(Module):
 
             # maybe update latents, if not frozen
 
-            if
-
+            if self.latent_gene_pool.frozen_latents:
+                continue
+
+            orig_latents.backward(latents.grad)
+
+            if self.has_diversity_loss:
+                diversity = self.latent_gene_pool.get_distance()
+                diversity_loss = diversity.mul(-1).exp().mean()
+
+                (diversity_loss * self.diversity_aux_loss_weight).backward()
 
-
-
+            self.latent_optim.step()
+            self.latent_optim.zero_grad()
 
             # apply evolution
 
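The auxiliary term added above is exp(-distance) averaged over all pairs, so it shrinks as latents spread apart; backpropagating the weighted term therefore nudges latents away from each other. A small numerical sketch of that behavior (illustrative values only, not taken from the package):

import torch

# nearly coincident latents -> distances near 0 -> loss near 1
close = torch.cdist(torch.zeros(4, 8), torch.zeros(4, 8) + 0.01)
# well spread latents -> large distances -> loss near 0
spread = torch.cdist(torch.zeros(4, 8), torch.randn(4, 8) * 3.)

loss_close  = close.mul(-1).exp().mean()
loss_spread = spread.mul(-1).exp().mean()

print(loss_close.item(), loss_spread.item())  # ~0.97 vs. a value close to 0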
@@ -930,6 +953,7 @@ def create_agent(
    latent_gene_pool_kwargs: dict = dict(),
    actor_kwargs: dict = dict(),
    critic_kwargs: dict = dict(),
+   **kwargs
 ) -> Agent:
 
    latent_gene_pool = LatentGenePool(
@@ -957,7 +981,8 @@
        actor = actor,
        critic = critic,
        latent_gene_pool = latent_gene_pool,
-       use_critic_ema = use_critic_ema
+       use_critic_ema = use_critic_ema,
+       **kwargs
    )
 
    return agent
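With the **kwargs passthrough above, extra Agent arguments such as the new diversity_aux_loss_weight can be supplied directly to create_agent. A hedged usage sketch; the dimension arguments follow the project's README and tests and are illustrative, not taken from this diff:

from evolutionary_policy_optimization import create_agent

agent = create_agent(
    dim_state = 512,
    num_latents = 16,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128, 64),
    diversity_aux_loss_weight = 1e-3,   # forwarded to Agent via **kwargs
)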
{evolutionary_policy_optimization-0.0.48 → evolutionary_policy_optimization-0.0.50}/tests/test_epo.py
RENAMED
@@ -75,6 +75,7 @@ def test_create_agent(
 
 @pytest.mark.parametrize('frozen_latents', (False, True))
 @pytest.mark.parametrize('use_critic_ema', (False, True))
+@pytest.mark.parametrize('diversity_aux_loss_weight', (0., 1e-3))
 def test_e2e_with_mock_env(
     frozen_latents,
     use_critic_ema
@@ -89,6 +90,7 @@ def test_e2e_with_mock_env(
         actor_dim_hiddens = (256, 128),
         critic_dim_hiddens = (256, 128, 64),
         use_critic_ema = use_critic_ema,
+        diversity_aux_loss_weight = diversity_aux_loss_weight,
         latent_gene_pool_kwargs = dict(
             frozen_latents = frozen_latents,
         )