evolutionary-policy-optimization 0.0.37.tar.gz → 0.0.38.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (14)
  1. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/PKG-INFO +1 -3
  2. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/README.md +0 -2
  3. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/evolutionary_policy_optimization/__init__.py +4 -1
  4. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/evolutionary_policy_optimization/epo.py +19 -11
  5. evolutionary_policy_optimization-0.0.38/evolutionary_policy_optimization/experimental.py +47 -0
  6. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/evolutionary_policy_optimization/mock_env.py +8 -3
  7. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/pyproject.toml +1 -1
  8. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/tests/test_epo.py +1 -2
  9. evolutionary_policy_optimization-0.0.37/evolutionary_policy_optimization/experimental.py +0 -27
  10. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/.github/workflows/python-publish.yml +0 -0
  11. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/.github/workflows/test.yml +0 -0
  12. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/.gitignore +0 -0
  13. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/LICENSE +0 -0
  14. {evolutionary_policy_optimization-0.0.37 → evolutionary_policy_optimization-0.0.38}/requirements.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.37
+Version: 0.0.38
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,8 +60,6 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
-
 ## Install
 
 ```bash
README.md
@@ -8,8 +8,6 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
-
 ## Install
 
 ```bash
evolutionary_policy_optimization/__init__.py
@@ -4,5 +4,8 @@ from evolutionary_policy_optimization.epo import (
     Critic,
     create_agent,
     Agent,
-    LatentGenePool
+    LatentGenePool,
+    EPO
 )
+
+from evolutionary_policy_optimization.mock_env import Env
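With this change, `EPO` and the mock `Env` are re-exported from the package root alongside the existing names. A quick import sketch of the 0.0.38 surface (the `512` passed to `Env` is an arbitrary illustrative state size, not something this diff prescribes):

```python
# hypothetical smoke test of the widened top-level exports in 0.0.38
from evolutionary_policy_optimization import (
    create_agent,
    Agent,
    LatentGenePool,
    EPO,   # newly re-exported rollout orchestrator (still a stub, see the epo.py hunks below)
    Env,   # mock environment, newly re-exported from mock_env.py
)

env = Env(512)  # illustrative state size; see the mock_env.py hunk for the int | tuple signature
```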
evolutionary_policy_optimization/epo.py
@@ -5,13 +5,13 @@ from pathlib import Path
 from collections import namedtuple
 
 import torch
-from torch import nn, cat, is_tensor, tensor
+from torch import nn, cat, stack, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
 
 import einx
-from einops import rearrange, repeat, einsum
+from einops import rearrange, repeat, einsum, pack
 from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
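`stack` (torch) and `pack` (einops) are newly imported here. As a standalone reminder of what `einops.pack` does, unrelated to epo.py's internals:

```python
import torch
from einops import pack, unpack

a = torch.randn(3, 5)
b = torch.randn(5)

packed, ps = pack([a, b], '* d')    # concatenate along the starred axes -> shape (4, 5)
a2, b2 = unpack(packed, ps, '* d')  # recover the original shapes

assert a2.shape == a.shape and b2.shape == b.shape
```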
@@ -319,7 +319,6 @@ class LatentGenePool(Module):
         num_latents,              # same as gene pool size
         dim_latent,               # gene dimension
         num_islands = 1,          # add the island strategy, which has been effectively used in a few recent works
-        dim_state = None,
         frozen_latents = True,
         crossover_random = True,  # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,    # whether to enforce latents on hypersphere,
@@ -384,7 +383,6 @@ class LatentGenePool(Module):
         fitness,
         beta0 = 2.,   # exploitation factor, moving fireflies of low light intensity to high
         gamma = 1.,   # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
-        alpha = 0.1,  # exploration factor
         inplace = True,
     ):
         islands = self.num_islands
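The comments above reference the firefly update: lower-fitness latents are attracted toward higher-fitness ones with strength `beta0`, decaying with distance via `gamma` (and, after this change, without a separate `alpha` exploration-noise argument). A self-contained illustration of that attraction step, not `LatentGenePool`'s actual implementation:

```python
import torch

def firefly_attraction_step(latents, fitness, beta0 = 2., gamma = 1.):
    # pull each latent toward every fitter latent; attraction decays with squared distance
    delta = latents.unsqueeze(0) - latents.unsqueeze(1)             # (n, n, d): latents[j] - latents[i]
    dist_sq = delta.pow(2).sum(dim = -1)                            # (n, n)
    attraction = beta0 * torch.exp(-gamma * dist_sq)                # gamma = 0 gives a distance-free, PSO-like pull
    fitter = (fitness.unsqueeze(0) > fitness.unsqueeze(1)).float()  # (n, n): 1 where latent j is fitter than latent i
    step = (attraction * fitter).unsqueeze(-1) * delta
    return latents + step.mean(dim = 1)                             # averaged pull; the normalization is an illustrative choice

latents, fitness = torch.randn(8, 16), torch.randn(8)
assert firefly_attraction_step(latents, fitness).shape == latents.shape
```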
@@ -555,7 +553,6 @@ class LatentGenePool(Module):
     def forward(
         self,
         *args,
-        state: Tensor | None = None,
         latent_id: int | None = None,
         net: Module | None = None,
         net_latent_kwarg_name = 'latent',
@@ -575,8 +572,6 @@ class LatentGenePool(Module):
 
         # fetch latent
 
-        fetching_multiple_latents = latent_id.numel() > 1
-
         latent = self.latents[latent_id]
 
         latent = self.maybe_l2norm(latent)
@@ -713,6 +708,7 @@ class Agent(Module):
         memories, next_value = memories_and_next_value
 
         (
+            _,
             states,
             latent_gene_ids,
             actions,
@@ -785,7 +781,7 @@ def actor_loss(
 
     clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
 
-    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
+    actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
 
     # add entropy loss for exploration
 
@@ -828,7 +824,6 @@ def create_agent(
     )
 
     latent_gene_pool = LatentGenePool(
-        dim_state = dim_state,
        num_latents = num_latents,
        dim_latent = dim_latent,
    )
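An aside on the `actor_loss` hunk above: the corrected line is the standard PPO clipped surrogate. A self-contained sketch of that computation, with names chosen here for illustration rather than taken from epo.py:

```python
import torch

def clipped_surrogate_loss(old_log_probs, log_probs, advantages, eps_clip = 0.2):
    ratio = (log_probs - old_log_probs).exp()  # new / old policy probability ratio
    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
    # pessimistic (elementwise min) objective, negated so it can be minimized
    return -torch.min(clipped_ratio * advantages, ratio * advantages).mean()
```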
@@ -839,6 +834,7 @@
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
 Memory = namedtuple('Memory', [
+    'episode_id',
     'state',
     'latent_gene_id',
     'action',
@@ -850,21 +846,33 @@ Memory = namedtuple('Memory', [
 
 MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
     'memories',
-    'next_value'
+    'next_value',
+    'cumulative_rewards'
 ])
 
 class EPO(Module):
 
     def __init__(
         self,
-        agent: Agent
+        agent: Agent,
+        episodes_per_latent,
+        max_episode_length
     ):
         super().__init__()
         self.agent = agent
 
+        self.num_latents = agent.latent_gene_pool.num_latents
+        self.episodes_per_latent = episodes_per_latent
+        self.max_episode_length = max_episode_length
+
+    @torch.no_grad()
     def forward(
         self,
         env
     ) -> MemoriesAndNextValue:
 
+        self.agent.eval()
+
+        memories: list[Memory] = []
+
         raise NotImplementedError
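`EPO.forward` still ends in `NotImplementedError`, but the new constructor arguments imply a rollout budget of `num_latents × episodes_per_latent` episodes, each capped at `max_episode_length` steps. A toy, self-contained sketch of that loop nesting only; the step and termination logic below is invented for illustration and is not epo.py's implementation:

```python
import random
from collections import namedtuple

# toy stand-in for the package's Memory namedtuple (the real field list is longer)
ToyMemory = namedtuple('ToyMemory', ['episode_id', 'latent_gene_id', 'step', 'reward'])

num_latents, episodes_per_latent, max_episode_length = 4, 2, 8
memories: list[ToyMemory] = []

for latent_id in range(num_latents):
    for episode_id in range(episodes_per_latent):
        for step in range(max_episode_length):
            memories.append(ToyMemory(episode_id, latent_id, step, random.random()))
            if random.random() < 0.1:  # dummy early termination
                break

# one rollout per (latent, episode) pair, each capped at max_episode_length steps
assert len(memories) <= num_latents * episodes_per_latent * max_episode_length
```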
evolutionary_policy_optimization-0.0.38/evolutionary_policy_optimization/experimental.py (new file)
@@ -0,0 +1,47 @@
+import torch
+from einops import rearrange
+
+def crossover_weights(w1, w2, transpose = False):
+    assert w2.shape == w2.shape
+
+    no_batch = w1.ndim == 2
+
+    if no_batch:
+        w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+    assert w1.ndim == 3
+
+    if transpose:
+        w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
+
+    rank = min(w2.shape[1:])
+    assert rank >= 2
+
+    batch = w1.shape[0]
+
+    u1, s1, v1 = torch.svd(w1)
+    u2, s2, v2 = torch.svd(w2)
+
+    batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+    mask = batch_randperm < (rank // 2)
+
+    u = torch.where(mask[:, None, :], u1, u2)
+    s = torch.where(mask, s1, s2)
+    v = torch.where(mask[:, None, :], v1, v2)
+
+    out = u @ torch.diag_embed(s) @ v.mT
+
+    if transpose:
+        out = rearrange(out, 'b j i -> b i j')
+
+    if no_batch:
+        out = rearrange(out, '1 ... -> ...')
+
+    return out
+
+if __name__ == '__main__':
+    w1 = torch.randn(32, 16)
+    w2 = torch.randn(32, 16)
+    child = crossover_weights(w2, w2)
+
+    assert child.shape == w2.shape
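Compared with the 0.0.37 version removed at the bottom of this diff, the rewrite accepts an optional leading batch dimension and samples an independent singular-vector split per batch element. A small usage sketch under that reading (the batch size and shapes are arbitrary):

```python
import torch
from evolutionary_policy_optimization.experimental import crossover_weights

parent1 = torch.randn(8, 32, 16)  # batch of 8 weight matrices
parent2 = torch.randn(8, 32, 16)

# each child mixes roughly half of its singular triplets from each parent
children = crossover_weights(parent1, parent2)
assert children.shape == parent1.shape
```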
evolutionary_policy_optimization/mock_env.py
@@ -4,15 +4,20 @@ import torch
 from torch import tensor, randn, randint
 from torch.nn import Module
 
+# functions
+
+def cast_tuple(v):
+    return v if isinstance(v, tuple) else (v,)
+
 # mock env
 
 class Env(Module):
     def __init__(
         self,
-        state_shape: tuple[int, ...]
+        state_shape: int | tuple[int, ...]
     ):
         super().__init__()
-        self.state_shape = state_shape
+        self.state_shape = cast_tuple(state_shape)
         self.register_buffer('dummy', tensor(0))
 
     @property
@@ -31,6 +36,6 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         reward = randint(0, 5, (), device = self.device).float()
-        done = zeros((), device = self.device, dtype = torch.bool)
+        done = torch.zeros((), device = self.device, dtype = torch.bool)
 
         return state, reward, done
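These two hunks let the mock env accept a bare int for `state_shape` (normalized through `cast_tuple`) and qualify the previously unresolved `zeros` call. A construction-only sketch of the two accepted forms, assuming the root re-export added in `__init__.py` above:

```python
from evolutionary_policy_optimization import Env

env_flat = Env(512)           # int is wrapped into (512,) by cast_tuple
env_image = Env((3, 64, 64))  # tuples pass through unchanged

assert env_flat.state_shape == (512,)
```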
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.37"
+version = "0.0.38"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
tests/test_epo.py
@@ -16,8 +16,7 @@ def test_readme(
 
     latent_pool = LatentGenePool(
         num_latents = 128,
-        dim_latent = 32,
-        dim_state = 512,
+        dim_latent = 32,
         num_islands = num_islands,
     )
 
evolutionary_policy_optimization-0.0.37/evolutionary_policy_optimization/experimental.py (removed)
@@ -1,27 +0,0 @@
-import torch
-
-def crossover_weights(w1, w2, transpose = False):
-    assert w2.shape == w2.shape
-    assert w1.ndim == 2
-
-    if transpose:
-        w1, w2 = w1.t(), w2.t()
-
-    rank = min(w2.shape)
-    assert rank >= 2
-
-    u1, s1, v1 = torch.svd(w1)
-    u2, s2, v2 = torch.svd(w2)
-
-    mask = torch.randperm(rank) < (rank // 2)
-
-    u = torch.where(mask[None, :], u1, u2)
-    s = torch.where(mask, s1, s2)
-    v = torch.where(mask[None, :], v1, v2)
-
-    out = u @ torch.diag_embed(s) @ v.mT
-
-    if transpose:
-        out = out.t()
-
-    return out