evolutionary-policy-optimization 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +33 -10
- {evolutionary_policy_optimization-0.0.44.dist-info → evolutionary_policy_optimization-0.0.46.dist-info}/METADATA +1 -1
- {evolutionary_policy_optimization-0.0.44.dist-info → evolutionary_policy_optimization-0.0.46.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.0.44.dist-info → evolutionary_policy_optimization-0.0.46.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.44.dist-info → evolutionary_policy_optimization-0.0.46.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -87,6 +87,11 @@ def temp_batch_dim(fn):
 
     return inner
 
+# fitness related
+
+def get_fitness_scores(cum_rewards, memories):
+    return cum_rewards
+
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
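The new default fitness function is a pass-through: each latent's cumulative reward is used directly as its fitness. Because Agent (below) accepts any Callable[..., Tensor] in its place, a custom scorer with the same signature can post-process the rewards. A minimal sketch of such a scorer; the z-normalization is purely illustrative and not part of the package:

import torch
from torch import Tensor

# hypothetical custom fitness scorer with the same (cum_rewards, memories) signature
# as the new default; normalizes rewards across latents so selection pressure is
# scale-invariant (illustrative only)
def normalized_fitness_scores(cum_rewards: Tensor, memories) -> Tensor:
    return (cum_rewards - cum_rewards.mean()) / cum_rewards.std().clamp(min = 1e-6)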
@@ -628,6 +633,7 @@ class Agent(Module):
         latent_lr = 1e-5,
         use_critic_ema = True,
         critic_ema_beta = 0.99,
+        max_grad_norm = 0.5,
         batch_size = 16,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
@@ -642,6 +648,7 @@ class Agent(Module):
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
         latent_optim_kwargs: dict = dict(),
+        get_fitness_scores: Callable[..., Tensor] = get_fitness_scores
     ):
         super().__init__()
 
@@ -662,7 +669,15 @@ class Agent(Module):
         self.actor_loss = partial(actor_loss, **actor_loss_kwargs)
         self.calc_gae = partial(calc_generalized_advantage_estimate, **calc_gae_kwargs)
 
+        # fitness score related
+
+        self.get_fitness_scores = get_fitness_scores
+
+        # learning hparams
+
         self.batch_size = batch_size
+        self.max_grad_norm = max_grad_norm
+        self.has_grad_clip = exists(max_grad_norm)
 
         # optimizers
 
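Here exists is assumed to be the usual "is not None" helper used throughout this codebase, so has_grad_clip simply records whether a clipping threshold was configured. A minimal standalone sketch of that gating, using the new default value:

def exists(v):
    # a config value counts as "set" if it is not None
    return v is not None

max_grad_norm = 0.5                    # new Agent kwarg; None would mean no threshold configured
has_grad_clip = exists(max_grad_norm)  # True -> gradients are clipped before each optimizer step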
@@ -761,10 +776,12 @@ class Agent(Module):
 
     def forward(
         self,
-
+        memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
     ):
-        memories,
+        memories, cumulative_rewards = memories_and_cumulative_rewards
+
+        fitness_scores = self.get_fitness_scores(cumulative_rewards, memories)
 
         (
             episode_ids,
@@ -821,6 +838,9 @@ class Agent(Module):
 
             actor_loss.backward()
 
+            if exists(self.has_grad_clip):
+                nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
+
             self.actor_optim.step()
             self.actor_optim.zero_grad()
 
@@ -834,6 +854,9 @@ class Agent(Module):
 
             critic_loss.backward()
 
+            if exists(self.has_grad_clip):
+                nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
+
             self.critic_optim.step()
             self.critic_optim.zero_grad()
 
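Both the actor and critic updates now follow the same backward → clip → step sequence. A self-contained sketch of that pattern with a toy module standing in for either network; the module, optimizer, and loss below are placeholders, not the package's:

import torch
from torch import nn

model = nn.Linear(4, 2)                                   # stand-in for the actor or critic
optim = torch.optim.Adam(model.parameters(), lr = 1e-3)
max_grad_norm = 0.5

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# rescale gradients in-place so their global L2 norm is at most max_grad_norm
nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

optim.step()
optim.zero_grad()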
@@ -941,9 +964,9 @@ Memory = namedtuple('Memory', [
     'done'
 ])
 
-
+MemoriesAndCumulativeRewards = namedtuple('MemoriesAndCumulativeRewards', [
     'memories',
-    '
+    'cumulative_rewards'
 ])
 
 class EPO(Module):
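EPO.forward now returns this two-field namedtuple, so the rollout memories and the per-latent returns travel together into Agent.forward. A quick illustration of constructing and unpacking it; the latent count and the empty memory list are made up for the example:

from collections import namedtuple
import torch

MemoriesAndCumulativeRewards = namedtuple('MemoriesAndCumulativeRewards', [
    'memories',
    'cumulative_rewards'
])

rollout = MemoriesAndCumulativeRewards(
    memories = [],                        # list[Memory] gathered during rollouts
    cumulative_rewards = torch.zeros(8)   # one running return per latent (8 latents assumed)
)

# Agent.forward unpacks it the same way
memories, cumulative_rewards = rollout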
@@ -967,7 +990,7 @@ class EPO(Module):
     def forward(
         self,
         env
-    ) ->
+    ) -> MemoriesAndCumulativeRewards:
 
         self.agent.eval()
 
@@ -975,7 +998,7 @@ class EPO(Module):
 
         memories: list[Memory] = []
 
-
+        cumulative_rewards = torch.zeros((self.num_latents))
 
         for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
 
@@ -1008,9 +1031,9 @@ class EPO(Module):
 
                 state, reward, done = env(action)
 
-                # update
+                # update cumulative rewards per latent, to be used as default fitness score
 
-
+                cumulative_rewards[latent_id] += reward
 
                 # store memories
 
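The per-latent bookkeeping is a single tensor indexed by latent id, summed over every step that latent plays. A standalone sketch of that accumulation; the latent count and reward values are invented for the example:

import torch

num_latents = 4
cumulative_rewards = torch.zeros((num_latents))

# pretend rollout: latent 2 receives two rewards, latent 0 one
for latent_id, reward in [(2, 1.0), (0, 0.5), (2, 0.25)]:
    cumulative_rewards[latent_id] += reward

print(cumulative_rewards)   # tensor([0.5000, 0.0000, 1.2500, 0.0000])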
@@ -1040,7 +1063,7 @@ class EPO(Module):
 
             memories.append(memory_for_gae)
 
-        return
+        return MemoriesAndCumulativeRewards(
             memories = memories,
-
+            cumulative_rewards = cumulative_rewards
         )
{evolutionary_policy_optimization-0.0.44.dist-info → evolutionary_policy_optimization-0.0.46.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.44
+Version: 0.0.46
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.44.dist-info → evolutionary_policy_optimization-0.0.46.dist-info}/RECORD

@@ -1,8 +1,8 @@
 evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=SAhWgRY8uPQEKFg1_nz1mvh8A6S_sHwnDykhd0F5xEI,31853
 evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
 evolutionary_policy_optimization/mock_env.py,sha256=6AIc4mwL_C6JkAxwESJgCLxXHMzCAu2FcffVg3HkSm0,920
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
-evolutionary_policy_optimization-0.0.
+evolutionary_policy_optimization-0.0.46.dist-info/METADATA,sha256=xP2kdKo52-X4Z5XXTPpW0M_NFI0spuigeL7fvqFlsRM,6213
+evolutionary_policy_optimization-0.0.46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.46.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.46.dist-info/RECORD,,