evolutionary-policy-optimization 0.0.37__py3-none-any.whl → 0.0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/__init__.py +4 -1
- evolutionary_policy_optimization/epo.py +19 -11
- evolutionary_policy_optimization/experimental.py +27 -7
- evolutionary_policy_optimization/mock_env.py +8 -3
- {evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.38.dist-info}/METADATA +1 -3
- evolutionary_policy_optimization-0.0.38.dist-info/RECORD +8 -0
- evolutionary_policy_optimization-0.0.37.dist-info/RECORD +0 -8
- {evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.38.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.37.dist-info → evolutionary_policy_optimization-0.0.38.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -5,13 +5,13 @@ from pathlib import Path
 from collections import namedtuple
 
 import torch
-from torch import nn, cat, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
 
 import einx
-from einops import rearrange, repeat, einsum
+from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
```
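The newly imported `pack` (and its counterpart `unpack`) from einops flattens and concatenates tensors along a wildcard axis while remembering how to split them back. The following is a standalone reference sketch of that einops API, not code from the package:

```python
import torch
from einops import pack, unpack

t1 = torch.randn(3, 8)
t2 = torch.randn(5, 8)

# flatten the leading dims of each tensor and concatenate them: (3 + 5, 8)
packed, packed_shapes = pack([t1, t2], '* d')
assert packed.shape == (8, 8)

# split back into the original pieces
out1, out2 = unpack(packed, packed_shapes, '* d')
assert out1.shape == (3, 8) and out2.shape == (5, 8)
```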
```diff
@@ -319,7 +319,6 @@ class LatentGenePool(Module):
         num_latents,                # same as gene pool size
         dim_latent,                 # gene dimension
         num_islands = 1,            # add the island strategy, which has been effectively used in a few recent works
-        dim_state = None,
         frozen_latents = True,
         crossover_random = True,    # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,      # whether to enforce latents on hypersphere,
```
```diff
@@ -384,7 +383,6 @@ class LatentGenePool(Module):
         fitness,
         beta0 = 2.,     # exploitation factor, moving fireflies of low light intensity to high
         gamma = 1.,     # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
-        alpha = 0.1,    # exploration factor
         inplace = True,
     ):
         islands = self.num_islands
```
```diff
@@ -555,7 +553,6 @@ class LatentGenePool(Module):
     def forward(
         self,
         *args,
-        state: Tensor | None = None,
         latent_id: int | None = None,
         net: Module | None = None,
         net_latent_kwarg_name = 'latent',
```
```diff
@@ -575,8 +572,6 @@ class LatentGenePool(Module):
 
         # fetch latent
 
-        fetching_multiple_latents = latent_id.numel() > 1
-
         latent = self.latents[latent_id]
 
         latent = self.maybe_l2norm(latent)
```
```diff
@@ -713,6 +708,7 @@ class Agent(Module):
         memories, next_value = memories_and_next_value
 
         (
+            _,
             states,
             latent_gene_ids,
             actions,
```
```diff
@@ -785,7 +781,7 @@ def actor_loss(
 
     clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
 
-    actor_loss = -torch.min(clipped_ratio *
+    actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
 
     # add entropy loss for exploration
 
```
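The completed line above is the standard PPO clipped surrogate objective, taking the elementwise minimum of the clipped and unclipped ratio-weighted advantages. A minimal self-contained sketch of that computation (tensor names like `old_log_probs` are illustrative, not taken from the package):

```python
import torch

log_probs = torch.randn(8)       # log-prob of taken actions under the current policy
old_log_probs = torch.randn(8)   # log-prob under the policy that collected the rollout
advantages = torch.randn(8)
eps_clip = 0.2

ratio = (log_probs - old_log_probs).exp()
clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

# pessimistic bound: take the smaller surrogate, negate for gradient descent
actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
```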
```diff
@@ -828,7 +824,6 @@ def create_agent(
     )
 
     latent_gene_pool = LatentGenePool(
-        dim_state = dim_state,
         num_latents = num_latents,
         dim_latent = dim_latent,
     )
```
```diff
@@ -839,6 +834,7 @@ def create_agent(
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
 Memory = namedtuple('Memory', [
+    'episode_id',
     'state',
     'latent_gene_id',
     'action',
```
```diff
@@ -850,21 +846,33 @@ Memory = namedtuple('Memory', [
 
 MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
     'memories',
-    'next_value'
+    'next_value',
+    'cumulative_rewards'
 ])
 
 class EPO(Module):
 
     def __init__(
         self,
-        agent: Agent
+        agent: Agent,
+        episodes_per_latent,
+        max_episode_length
     ):
         super().__init__()
         self.agent = agent
 
+        self.num_latents = agent.latent_gene_pool.num_latents
+        self.episodes_per_latent = episodes_per_latent
+        self.max_episode_length = max_episode_length
+
+    @torch.no_grad()
     def forward(
         self,
         env
     ) -> MemoriesAndNextValue:
 
+        self.agent.eval()
+
+        memories: list[Memory] = []
+
         raise NotImplementedError
```
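With this change, `EPO` also takes rollout-shaping arguments. A hedged construction sketch follows; the numeric values are illustrative, the `agent` is assumed to have been built elsewhere in the package (e.g. via `create_agent`), and the rollout body still raises `NotImplementedError` in this version, so the forward call is left commented out:

```python
# `agent` and `env` are assumed to exist already; values below are illustrative
epo = EPO(
    agent,
    episodes_per_latent = 16,
    max_episode_length = 250
)

# rollout collection runs under torch.no_grad() with the agent in eval mode,
# but the gathering loop itself is not yet implemented in 0.0.38
# memories_and_next_value = epo(env)
```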
```diff
--- a/evolutionary_policy_optimization/experimental.py
+++ b/evolutionary_policy_optimization/experimental.py
@@ -1,27 +1,47 @@
 import torch
+from einops import rearrange
 
 def crossover_weights(w1, w2, transpose = False):
     assert w2.shape == w2.shape
-
+
+    no_batch = w1.ndim == 2
+
+    if no_batch:
+        w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+    assert w1.ndim == 3
 
     if transpose:
-        w1, w2 =
+        w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
 
-    rank = min(w2.shape)
+    rank = min(w2.shape[1:])
     assert rank >= 2
 
+    batch = w1.shape[0]
+
     u1, s1, v1 = torch.svd(w1)
     u2, s2, v2 = torch.svd(w2)
 
-
+    batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+    mask = batch_randperm < (rank // 2)
 
-    u = torch.where(mask[None, :], u1, u2)
+    u = torch.where(mask[:, None, :], u1, u2)
     s = torch.where(mask, s1, s2)
-    v = torch.where(mask[None, :], v1, v2)
+    v = torch.where(mask[:, None, :], v1, v2)
 
     out = u @ torch.diag_embed(s) @ v.mT
 
     if transpose:
-        out = out
+        out = rearrange(out, 'b j i -> b i j')
+
+    if no_batch:
+        out = rearrange(out, '1 ... -> ...')
 
     return out
+
+if __name__ == '__main__':
+    w1 = torch.randn(32, 16)
+    w2 = torch.randn(32, 16)
+    child = crossover_weights(w2, w2)
+
+    assert child.shape == w2.shape
```
```diff
--- a/evolutionary_policy_optimization/mock_env.py
+++ b/evolutionary_policy_optimization/mock_env.py
@@ -4,15 +4,20 @@ import torch
 from torch import tensor, randn, randint
 from torch.nn import Module
 
+# functions
+
+def cast_tuple(v):
+    return v if isinstance(v, tuple) else (v,)
+
 # mock env
 
 class Env(Module):
     def __init__(
         self,
-        state_shape: tuple[int, ...]
+        state_shape: int | tuple[int, ...]
     ):
         super().__init__()
-        self.state_shape = state_shape
+        self.state_shape = cast_tuple(state_shape)
         self.register_buffer('dummy', tensor(0))
 
     @property
@@ -31,6 +36,6 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         reward = randint(0, 5, (), device = self.device).float()
-        done = zeros((), device = self.device, dtype = torch.bool)
+        done = torch.zeros((), device = self.device, dtype = torch.bool)
 
         return state, reward, done
```
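With the `cast_tuple` helper, the mock `Env` can now be constructed from a bare `int` as well as a tuple of dims. A brief sketch (the import path follows the package layout shown above):

```python
from evolutionary_policy_optimization.mock_env import Env

env = Env(512)                # an int is normalized to (512,) internally
env_image = Env((3, 64, 64))  # tuples are accepted as before

assert env.state_shape == (512,)
assert env_image.state_shape == (3, 64, 64)
```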
````diff
--- a/evolutionary_policy_optimization-0.0.37.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.38.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.37
+Version: 0.0.38
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,8 +60,6 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
-
 ## Install
 
 ```bash
````
```diff
--- /dev/null
+++ b/evolutionary_policy_optimization-0.0.38.dist-info/RECORD
@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/epo.py,sha256=z41p5LmvOHULq6o5aIj9Q6lpyka5DvkqsJ493-WL-EQ,26175
+evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
+evolutionary_policy_optimization-0.0.38.dist-info/METADATA,sha256=Lofrc6waEB8qBD19pjBKQjbKBYMNUyjYZZrJCO1fji8,4818
+evolutionary_policy_optimization-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.38.dist-info/RECORD,,
```
```diff
--- a/evolutionary_policy_optimization-0.0.37.dist-info/RECORD
+++ /dev/null
@@ -1,8 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=onIGNWHg1EGQwJ9TfkkJ8Yz8_S-BPoaqrxJwq54BXp0,25992
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
-evolutionary_policy_optimization-0.0.37.dist-info/METADATA,sha256=nPWBCvx02MHWdKu5cEoPmHFMFKhwepOfStkXIXR2NHc,4992
-evolutionary_policy_optimization-0.0.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.37.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.37.dist-info/RECORD,,
```
The remaining files (WHEEL and licenses/LICENSE) are unchanged between the two versions.