evolutionary-policy-optimization 0.0.37__py3-none-any.whl → 0.0.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/__init__.py +4 -1
- evolutionary_policy_optimization/epo.py +180 -51
- evolutionary_policy_optimization/experimental.py +27 -7
- evolutionary_policy_optimization/mock_env.py +8 -3
- {evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.39.dist-info}/METADATA +40 -3
- evolutionary_policy_optimization-0.0.39.dist-info/RECORD +8 -0
- evolutionary_policy_optimization-0.0.37.dist-info/RECORD +0 -8
- {evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.39.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.39.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -5,13 +5,13 @@ from pathlib import Path
 from collections import namedtuple
 
 import torch
-from torch import nn, cat, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
 
 import einx
-from einops import rearrange, repeat, einsum
+from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
@@ -22,6 +22,8 @@ from hl_gauss_pytorch import HLGaussLayer
 
 from ema_pytorch import EMA
 
+from tqdm import tqdm
+
 # helpers
 
 def exists(v):
@@ -47,9 +49,20 @@ def l2norm(t):
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
+def gumbel_noise(t):
+    return -log(-log(torch.rand_like(t)))
+
+def gumbel_sample(t, temperature = 1.):
+    is_greedy = temperature <= 0.
+
+    if not is_greedy:
+        t = (t / temperature) + gumbel_noise(t)
+
+    return t.argmax(dim = -1)
+
 def calc_entropy(logits):
     prob = logits.softmax(dim = -1)
-    return -prob * log(prob)
+    return -(prob * log(prob)).sum(dim = -1)
 
 def gather_log_prob(
     logits, # Float[b l]
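The new `gumbel_sample` helper draws a discrete action by adding Gumbel noise to temperature-scaled logits and taking an argmax (`temperature <= 0` degenerates to a greedy argmax), and `calc_entropy` now reduces over the action dimension. A minimal usage sketch, assuming version 0.0.39 is installed and the helpers are importable from `evolutionary_policy_optimization.epo`; the shapes are illustrative:

```python
import torch

# illustrative usage of the new helpers, assuming 0.0.39 is installed
from evolutionary_policy_optimization.epo import gumbel_sample, calc_entropy

logits = torch.randn(4, 6)    # 4 states, 6 discrete actions

sampled = gumbel_sample(logits)                    # Long[4], one stochastic action per row
greedy = gumbel_sample(logits, temperature = 0.)   # temperature <= 0 falls back to argmax

entropy = calc_entropy(logits)                     # Float[4], now summed over the action dim
```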
@@ -63,8 +76,8 @@ def gather_log_prob(
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
-    rewards, # Float[
-    values, # Float[
+    rewards, # Float[n]
+    values, # Float[n+1]
     masks, # Bool[n]
     gamma = 0.99,
     lam = 0.95,
@@ -75,9 +88,7 @@ def calc_generalized_advantage_estimate(
     use_accelerated = default(use_accelerated, rewards.is_cuda)
     device = rewards.device
 
-
-
-    values, values_next = values[:, :-1], values[:, 1:]
+    values, values_next = values[:-1], values[1:]
 
     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks
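With the reshaped signature, `rewards` and `masks` carry `n` steps while `values` carries `n + 1` entries (the extra bootstrap value appended at rollout time). A plain-loop sketch of the recurrence that `delta` and `gates` set up, shown here for illustration only (the package itself appears to hand these to the imported associative scan):

```python
import torch

# reference loop for generalized advantage estimation, matching the shapes in
# the comments above: rewards Float[n], values Float[n+1], masks Bool/Float[n]
def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    values, values_next = values[:-1], values[1:]

    delta = rewards + gamma * values_next * masks - values
    gates = gamma * lam * masks

    advantages = torch.zeros_like(delta)
    running = torch.tensor(0.)

    # scan the recurrence A_t = delta_t + gates_t * A_{t+1} from the back
    for t in reversed(range(delta.shape[0])):
        running = delta[t] + gates[t] * running
        advantages[t] = running

    return advantages

rewards = torch.randn(5)
values = torch.randn(6)       # one extra bootstrap value for the final step
masks = torch.ones(5)         # 1. wherever the episode has not terminated

adv = gae_reference(rewards, values, masks)   # Float[5]
```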
@@ -319,7 +330,6 @@ class LatentGenePool(Module):
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
         num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
-        dim_state = None,
         frozen_latents = True,
         crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False, # whether to enforce latents on hypersphere,
@@ -384,7 +394,6 @@ class LatentGenePool(Module):
         fitness,
         beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
         gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
-        alpha = 0.1, # exploration factor
         inplace = True,
     ):
         islands = self.num_islands
@@ -555,7 +564,6 @@ class LatentGenePool(Module):
     def forward(
         self,
         *args,
-        state: Tensor | None = None,
         latent_id: int | None = None,
         net: Module | None = None,
         net_latent_kwarg_name = 'latent',
@@ -568,6 +576,8 @@ class LatentGenePool(Module):
         if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0
 
+        assert exists(latent_id)
+
         if not is_tensor(latent_id):
             latent_id = tensor(latent_id, device = device)
 
@@ -575,8 +585,6 @@ class LatentGenePool(Module):
 
         # fetch latent
 
-        fetching_multiple_latents = latent_id.numel() > 1
-
         latent = self.latents[latent_id]
 
         latent = self.maybe_l2norm(latent)
@@ -686,17 +694,38 @@ class Agent(Module):
     def get_actor_actions(
         self,
         state,
-        latent_id
+        latent_id = None,
+        latent = None,
+        sample = False,
+        temperature = 1.
     ):
-
-
+        assert exists(latent_id) or exists(latent)
+
+        if not exists(latent):
+            latent = self.latent_gene_pool(latent_id = latent_id)
+
+        logits = self.actor(state, latent)
+
+        if not sample:
+            return logits
+
+        actions = gumbel_sample(logits, temperature = temperature)
+
+        log_probs = gather_log_prob(logits, actions)
+
+        return actions, log_probs
 
     def get_critic_values(
         self,
         state,
-        latent_id
+        latent_id = None,
+        latent = None
     ):
-
+        assert exists(latent_id) or exists(latent)
+
+        if not exists(latent):
+            latent = self.latent_gene_pool(latent_id = latent_id)
+
         return self.critic(state, latent)
 
     def update_latent_gene_pool_(
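The rollout API is widened: a latent vector can now be passed directly instead of a `latent_id`, and `sample = True` returns sampled actions with their log probabilities rather than raw logits. A hedged usage sketch; the `agent` construction and tensor shapes are assumptions for illustration:

```python
import torch

# assumes an `agent` built with create_agent(...) as shown later in this diff,
# with dim_state = 512 and 5 actions; the batch of one is illustrative

state = torch.randn(1, 512)

# option 1: let the agent look the latent up by id (returns raw logits)
logits = agent.get_actor_actions(state, latent_id = 0)

# option 2: fetch the latent once and reuse it, sampling an action
latent = agent.latent_gene_pool(latent_id = 0)

action, log_prob = agent.get_actor_actions(
    state,
    latent = latent,
    sample = True,
    temperature = 1.
)

value = agent.get_critic_values(state, latent = latent)
```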
@@ -707,12 +736,13 @@ class Agent(Module):
 
     def forward(
         self,
-
+        memories_and_fitness_scores: MemoriesAndFitnessScores,
         epochs = 2
     ):
-        memories,
+        memories, fitness_scores = memories_and_fitness_scores
 
         (
+            episode_ids,
             states,
             latent_gene_ids,
             actions,
@@ -722,35 +752,46 @@ class Agent(Module):
             dones
         ) = map(stack, zip(*memories))
 
-
+        advantages = self.calc_gae(
+            rewards[:-1],
+            values,
+            dones[:-1],
+        )
 
-
+        valid_episode = episode_ids >= 0
 
-        dataset = TensorDataset(
+        dataset = TensorDataset(
+            *[
+                advantages[valid_episode[:-1]],
+                *[t[valid_episode] for t in (states, latent_gene_ids, actions, log_probs, values)]
+            ]
+        )
 
         dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
 
         self.actor.train()
         self.critic.train()
 
-        for _ in range(epochs):
+        for _ in tqdm(range(epochs), desc = 'learning actor/critic epoch'):
             for (
+                advantages,
                 states,
                 latent_gene_ids,
                 actions,
                 log_probs,
-                advantages,
                 old_values
             ) in dataloader:
 
-                latents = self.latent_gene_pool(latent_gene_ids)
+                latents = self.latent_gene_pool(latent_id = latent_gene_ids)
 
                 # learn actor
 
                 logits = self.actor(states, latents)
+
                 actor_loss = self.actor_loss(logits, log_probs, actions, advantages)
 
                 actor_loss.backward()
+
                 self.actor_optim.step()
                 self.actor_optim.zero_grad()
 
@@ -759,7 +800,7 @@ class Agent(Module):
                 critic_loss = self.critic(
                     states,
                     latents,
-
+                    target = advantages + old_values
                 )
 
                 critic_loss.backward()
@@ -767,6 +808,10 @@ class Agent(Module):
                 self.critic_optim.step()
                 self.critic_optim.zero_grad()
 
+        # apply evolution
+
+        self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
+
     # reinforcement learning related - ppo
 
 def actor_loss(
@@ -785,7 +830,7 @@ def actor_loss(
 
     clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
 
-    actor_loss = -torch.min(clipped_ratio *
+    actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
 
     # add entropy loss for exploration
 
@@ -793,15 +838,7 @@ def actor_loss(
 
     entropy_aux_loss = -entropy_weight * entropy
 
-    return actor_loss + entropy_aux_loss
-
-def critic_loss(
-    pred_values, # Float[b]
-    advantages, # Float[b]
-    old_values # Float[b]
-):
-    discounted_values = advantages + old_values
-    return F.mse_loss(pred_values, discounted_values)
+    return (actor_loss + entropy_aux_loss).mean()
 
 # agent contains the actor, critic, and the latent genetic pool
 
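The completed line is the standard PPO clipped surrogate, the loss is now reduced with `.mean()`, and the separate `critic_loss` helper is gone (the critic is instead called with `target = advantages + old_values`, as in the earlier hunk). A self-contained sketch of the surrogate with hypothetical inputs and hypothetical default hyperparameters (`eps_clip`, `entropy_weight`), not the package's exact `actor_loss` signature:

```python
import torch

# standalone sketch of the clipped surrogate; eps_clip and entropy_weight
# defaults here are hypothetical, not the package's
def ppo_actor_loss(log_probs, old_log_probs, advantages, entropy, eps_clip = 0.2, entropy_weight = 0.01):
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

    # take the pessimistic of the unclipped and clipped surrogate per sample
    actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)

    # entropy bonus for exploration, subtracted from the loss
    entropy_aux_loss = -entropy_weight * entropy

    return (actor_loss + entropy_aux_loss).mean()

batch = 8
loss = ppo_actor_loss(
    torch.randn(batch),
    torch.randn(batch),
    torch.randn(batch),
    torch.rand(batch)
)
```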
@@ -814,6 +851,11 @@ def create_agent(
     critic_dim_hiddens: int | tuple[int, ...],
 ) -> Agent:
 
+    latent_gene_pool = LatentGenePool(
+        num_latents = num_latents,
+        dim_latent = dim_latent
+    )
+
     actor = Actor(
         num_actions = actor_num_actions,
         dim_state = dim_state,
@@ -825,46 +867,133 @@ def create_agent(
         dim_state = dim_state,
         dim_latent = dim_latent,
         dim_hiddens = critic_dim_hiddens
-    )
-
-    latent_gene_pool = LatentGenePool(
-        dim_state = dim_state,
-        num_latents = num_latents,
-        dim_latent = dim_latent,
-    )
-
+    )
     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
 
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
 Memory = namedtuple('Memory', [
+    'episode_id',
     'state',
     'latent_gene_id',
     'action',
     'log_prob',
     'reward',
-    '
+    'value',
     'done'
 ])
 
-
+MemoriesAndFitnessScores = namedtuple('MemoriesAndFitnessScores', [
     'memories',
-    '
+    'fitness_scores'
 ])
 
 class EPO(Module):
 
     def __init__(
         self,
-        agent: Agent
+        agent: Agent,
+        episodes_per_latent,
+        max_episode_length,
+        action_sample_temperature = 1.
     ):
         super().__init__()
         self.agent = agent
+        self.action_sample_temperature = action_sample_temperature
+
+        self.num_latents = agent.latent_gene_pool.num_latents
+        self.episodes_per_latent = episodes_per_latent
+        self.max_episode_length = max_episode_length
 
+    @torch.no_grad()
     def forward(
         self,
         env
-    ) ->
+    ) -> MemoriesAndFitnessScores:
+
+        self.agent.eval()
+
+        invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
+
+        memories: list[Memory] = []
+
+        fitness_scores = torch.zeros((self.num_latents))
+
+        for episode_id in tqdm(range(self.episodes_per_latent), desc = 'episode'):
+
+            for latent_id in tqdm(range(self.num_latents), desc = 'latent'):
+                time = 0
+
+                # initial state
+
+                state = env.reset()
+
+                # get latent from pool
+
+                latent = self.agent.latent_gene_pool(latent_id = latent_id)
+
+                # until maximum episode length
+
+                done = tensor(False)
+
+                while time < self.max_episode_length:
+
+                    batched_state = rearrange(state, '... -> 1 ...')
+
+                    # sample action
 
-
+                    action, log_prob = self.agent.get_actor_actions(batched_state, latent = latent, sample = True, temperature = self.action_sample_temperature)
+
+                    action = rearrange(action, '1 ... -> ...')
+                    log_prob = rearrange(log_prob, '1 ... -> ...')
+
+                    # values
+
+                    value = self.agent.get_critic_values(batched_state, latent = latent)
+
+                    value = rearrange(value, '1 ... -> ...')
+
+                    # get the next state, action, and reward
+
+                    state, reward, done = env(action)
+
+                    # update fitness for each gene as cumulative reward received, but make this customizable at some point
+
+                    fitness_scores[latent_id] += reward
+
+                    # store memories
+
+                    memory = Memory(
+                        tensor(episode_id),
+                        state,
+                        tensor(latent_id),
+                        action,
+                        log_prob,
+                        reward,
+                        value,
+                        done
+                    )
+
+                    memories.append(memory)
+
+                    time += 1
+
+                # need the final next value for GAE, iiuc
+
+                batched_state = rearrange(state, '... -> 1 ...')
+
+                next_value = self.agent.get_critic_values(batched_state, latent = latent)
+                next_value = rearrange(next_value, '1 ... -> ...')
+
+                memory_for_gae = memory._replace(
+                    episode_id = invalid_episode,
+                    value = next_value
+                )
+
+                memories.append(memory_for_gae)
+
+        return MemoriesAndFitnessScores(
+            memories = memories,
+            fitness_scores = fitness_scores
+        )
evolutionary_policy_optimization/experimental.py

@@ -1,27 +1,47 @@
 import torch
+from einops import rearrange
 
 def crossover_weights(w1, w2, transpose = False):
     assert w2.shape == w2.shape
-
+
+    no_batch = w1.ndim == 2
+
+    if no_batch:
+        w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+    assert w1.ndim == 3
 
     if transpose:
-        w1, w2 =
+        w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
 
-    rank = min(w2.shape)
+    rank = min(w2.shape[1:])
     assert rank >= 2
 
+    batch = w1.shape[0]
+
     u1, s1, v1 = torch.svd(w1)
     u2, s2, v2 = torch.svd(w2)
 
-
+    batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+    mask = batch_randperm < (rank // 2)
 
-    u = torch.where(mask[None, :], u1, u2)
+    u = torch.where(mask[:, None, :], u1, u2)
     s = torch.where(mask, s1, s2)
-    v = torch.where(mask[None, :], v1, v2)
+    v = torch.where(mask[:, None, :], v1, v2)
 
     out = u @ torch.diag_embed(s) @ v.mT
 
     if transpose:
-        out = out
+        out = rearrange(out, 'b j i -> b i j')
+
+    if no_batch:
+        out = rearrange(out, '1 ... -> ...')
 
     return out
+
+if __name__ == '__main__':
+    w1 = torch.randn(32, 16)
+    w2 = torch.randn(32, 16)
+    child = crossover_weights(w2, w2)
+
+    assert child.shape == w2.shape
evolutionary_policy_optimization/mock_env.py

@@ -4,15 +4,20 @@ import torch
 from torch import tensor, randn, randint
 from torch.nn import Module
 
+# functions
+
+def cast_tuple(v):
+    return v if isinstance(v, tuple) else (v,)
+
 # mock env
 
 class Env(Module):
     def __init__(
         self,
-        state_shape: tuple[int, ...]
+        state_shape: int | tuple[int, ...]
     ):
         super().__init__()
-        self.state_shape = state_shape
+        self.state_shape = cast_tuple(state_shape)
         self.register_buffer('dummy', tensor(0))
 
     @property
@@ -31,6 +36,6 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         reward = randint(0, 5, (), device = self.device).float()
-        done = zeros((), device = self.device, dtype = torch.bool)
+        done = torch.zeros((), device = self.device, dtype = torch.bool)
 
         return state, reward, done
{evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.39.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.37
+Version: 0.0.39
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,8 +60,6 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
-
 ## Install
 
 ```bash
@@ -103,6 +101,45 @@ fitness = torch.randn(128)
 latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
 ```
 
+End to end learning
+
+```python
+import torch
+
+from evolutionary_policy_optimization import (
+    create_agent,
+    EPO,
+    Env
+)
+
+agent = create_agent(
+    dim_state = 512,
+    num_latents = 8,
+    dim_latent = 32,
+    actor_num_actions = 5,
+    actor_dim_hiddens = (256, 128),
+    critic_dim_hiddens = (256, 128, 64)
+)
+
+epo = EPO(
+    agent,
+    episodes_per_latent = 1,
+    max_episode_length = 10,
+    action_sample_temperature = 1.
+)
+
+env = Env((512,))
+
+memories = epo(env)
+
+agent(memories)
+
+# saving and loading
+
+agent.save('./agent.pt', overwrite = True)
+agent.load('./agent.pt')
+```
+
 ## Citations
 
 ```bibtex
evolutionary_policy_optimization-0.0.39.dist-info/RECORD

@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=lzxPamJahE5KqBwzyYlGOwNeUoB2vONLwtRcWqCI_Jw,29800
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
+evolutionary_policy_optimization-0.0.39.dist-info/METADATA,sha256=TTNQD7sTWIgpVwnrQrFFBD-cyySkvwJr_J3ABxTpor8,5409
+evolutionary_policy_optimization-0.0.39.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.39.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.39.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.37.dist-info/RECORD

@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=onIGNWHg1EGQwJ9TfkkJ8Yz8_S-BPoaqrxJwq54BXp0,25992
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
-evolutionary_policy_optimization-0.0.37.dist-info/METADATA,sha256=nPWBCvx02MHWdKu5cEoPmHFMFKhwepOfStkXIXR2NHc,4992
-evolutionary_policy_optimization-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.37.dist-info/RECORD,,
File without changes
|