evolutionary-policy-optimization 0.0.36__py3-none-any.whl → 0.0.38__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
evolutionary_policy_optimization/__init__.py

@@ -4,5 +4,8 @@ from evolutionary_policy_optimization.epo import (
  Critic,
  create_agent,
  Agent,
- LatentGenePool
+ LatentGenePool,
+ EPO
  )
+
+ from evolutionary_policy_optimization.mock_env import Env
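Two names are added to the package's public surface: the new `EPO` rollout wrapper and the mock `Env`, both importable from the top level as of 0.0.38. A minimal sketch of the new import path (the `(5,)` state shape is only an illustrative value):

```python
# minimal sketch of the 0.0.38 top-level imports
from evolutionary_policy_optimization import (
    Critic,
    create_agent,
    Agent,
    LatentGenePool,
    EPO,   # newly exported
)

from evolutionary_policy_optimization import Env  # mock environment, also newly exported

env = Env(state_shape = (5,))
```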
evolutionary_policy_optimization/epo.py

@@ -5,13 +5,13 @@ from pathlib import Path
  from collections import namedtuple

  import torch
- from torch import nn, cat, is_tensor, tensor
+ from torch import nn, cat, stack, is_tensor, tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader

  import einx
- from einops import rearrange, repeat, einsum
+ from einops import rearrange, repeat, einsum, pack
  from einops.layers.torch import Rearrange

  from assoc_scan import AssocScan
@@ -319,7 +319,6 @@ class LatentGenePool(Module):
  num_latents, # same as gene pool size
  dim_latent, # gene dimension
  num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
- dim_state = None,
  frozen_latents = True,
  crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
  l2norm_latent = False, # whether to enforce latents on hypersphere,
@@ -384,7 +383,6 @@ class LatentGenePool(Module):
  fitness,
  beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
  gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
- alpha = 0.1, # exploration factor
  inplace = True,
  ):
  islands = self.num_islands
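For context on the two surviving hyperparameters: in a firefly update each latent is pulled toward fitter ("brighter") latents with strength `beta0 * exp(-gamma * distance**2)`, and with the `alpha` exploration term removed the move is pure attraction. A standalone sketch of that textbook update rule, assuming Euclidean distance between latents; this is not the library's exact implementation:

```python
import torch

def firefly_step(latents, fitness, beta0 = 2., gamma = 1.):
    # pairwise squared distances between "fireflies" (latents)
    dist_sq = torch.cdist(latents, latents) ** 2

    # attraction decays with distance; gamma = 0 gives a constant, PSO-like pull
    attraction = beta0 * torch.exp(-gamma * dist_sq)

    # each latent only moves toward latents with higher fitness
    brighter = fitness[None, :] > fitness[:, None]
    moves = (attraction * brighter)[..., None] * (latents[None, :, :] - latents[:, None, :])

    return latents + moves.sum(dim = 1)

latents = torch.randn(8, 16)
fitness = torch.randn(8)
latents = firefly_step(latents, fitness)
```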
@@ -555,7 +553,6 @@ class LatentGenePool(Module):
  def forward(
  self,
  *args,
- state: Tensor | None = None,
  latent_id: int | None = None,
  net: Module | None = None,
  net_latent_kwarg_name = 'latent',
@@ -575,8 +572,6 @@ class LatentGenePool(Module):

  # fetch latent

- fetching_multiple_latents = latent_id.numel() > 1
-
  latent = self.latents[latent_id]

  latent = self.maybe_l2norm(latent)
@@ -713,6 +708,7 @@ class Agent(Module):
  memories, next_value = memories_and_next_value

  (
+ _,
  states,
  latent_gene_ids,
  actions,
@@ -743,9 +739,11 @@ class Agent(Module):
  old_values
  ) in dataloader:

+ latents = self.latent_gene_pool(latent_gene_ids)
+
  # learn actor

- logits = self.actor(states)
+ logits = self.actor(states, latents)
  actor_loss = self.actor_loss(logits, log_probs, actions, advantages)

  actor_loss.backward()
@@ -754,7 +752,12 @@ class Agent(Module):

  # learn critic with maybe classification loss

- critic_loss = self.critic(states, advantages + old_values)
+ critic_loss = self.critic(
+ states,
+ latents,
+ targets = advantages + old_values
+ )
+
  critic_loss.backward()

  self.critic_optim.step()
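With this change the learn loop looks up one latent gene per stored `latent_gene_id` and conditions both actor and critic on it. A self-contained sketch of that lookup-and-condition pattern, using a plain `nn.Embedding` as a stand-in for the gene pool and a toy MLP; the names and shapes here are illustrative, not the library's API:

```python
import torch
from torch import nn

# stand-in for LatentGenePool: one learned latent ("gene") per id
gene_pool = nn.Embedding(num_embeddings = 16, embedding_dim = 32)

# a latent-conditioned network: state and latent concatenated before the MLP
net = nn.Sequential(nn.Linear(8 + 32, 64), nn.SiLU(), nn.Linear(64, 4))

states = torch.randn(2, 8)
latent_gene_ids = torch.tensor([3, 7])

latents = gene_pool(latent_gene_ids)                  # (2, 32) - one latent per sample
logits = net(torch.cat((states, latents), dim = -1))  # (2, 4)
```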
@@ -778,7 +781,7 @@ def actor_loss(

  clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

- actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
+ actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)

  # add entropy loss for exploration

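The fixed line is the standard PPO clipped surrogate; the renamed `advantages` now matches the argument passed at the call site above. A tiny self-contained check of the objective exactly as written in the hunk, on dummy tensors:

```python
import torch

ratio = torch.tensor([0.5, 1.0, 1.5])
advantages = torch.tensor([1.0, -1.0, 2.0])
eps_clip = 0.2

clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

# elementwise pessimistic choice: clipping can only remove incentive, never add it
actor_loss = -torch.min(clipped_ratio * advantages, ratio * advantages)
```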
@@ -821,7 +824,6 @@ def create_agent(
  )

  latent_gene_pool = LatentGenePool(
- dim_state = dim_state,
  num_latents = num_latents,
  dim_latent = dim_latent,
  )
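With `dim_state` dropped, the gene pool is specified purely by how many latents it holds and their dimension, as the `create_agent` call above now does. A minimal direct construction under the same two arguments (values illustrative; all remaining keyword arguments keep their defaults):

```python
from evolutionary_policy_optimization import LatentGenePool

# dim_state is no longer an accepted argument as of 0.0.38
pool = LatentGenePool(
    num_latents = 16,
    dim_latent = 32,
)
```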
@@ -832,6 +834,7 @@
  # the tricky part is that the latent ids for each episode / trajectory needs to be tracked

  Memory = namedtuple('Memory', [
+ 'episode_id',
  'state',
  'latent_gene_id',
  'action',
@@ -843,21 +846,33 @@ Memory = namedtuple('Memory', [

  MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
  'memories',
- 'next_value'
+ 'next_value',
+ 'cumulative_rewards'
  ])

  class EPO(Module):

  def __init__(
  self,
- agent: Agent
+ agent: Agent,
+ episodes_per_latent,
+ max_episode_length
  ):
  super().__init__()
  self.agent = agent

+ self.num_latents = agent.latent_gene_pool.num_latents
+ self.episodes_per_latent = episodes_per_latent
+ self.max_episode_length = max_episode_length
+
+ @torch.no_grad()
  def forward(
  self,
  env
  ) -> MemoriesAndNextValue:

+ self.agent.eval()
+
+ memories: list[Memory] = []
+
  raise NotImplementedError
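`MemoriesAndNextValue` gains a third `cumulative_rewards` slot, and `EPO` now takes its rollout budget (`episodes_per_latent`, `max_episode_length`) at construction, switching the agent to eval mode under `@torch.no_grad()` before gathering; its `forward` still ends in `NotImplementedError` in this release. A small sketch of the new tuple layout with dummy values (the shape given to `cumulative_rewards` is an assumption, not taken from the diff):

```python
import torch
from evolutionary_policy_optimization.epo import MemoriesAndNextValue

# the rollout return type now has three slots instead of two
rollout = MemoriesAndNextValue(
    memories = [],                          # list[Memory], each with a leading episode_id
    next_value = torch.tensor(0.),          # dummy bootstrap value
    cumulative_rewards = torch.zeros(16),   # dummy; assumed one running total per latent
)

memories, next_value, cumulative_rewards = rollout
```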
evolutionary_policy_optimization/experimental.py

@@ -1,27 +1,47 @@
  import torch
+ from einops import rearrange

  def crossover_weights(w1, w2, transpose = False):
  assert w2.shape == w2.shape
- assert w1.ndim == 2
+
+ no_batch = w1.ndim == 2
+
+ if no_batch:
+ w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+ assert w1.ndim == 3

  if transpose:
- w1, w2 = w1.t(), w2.t()
+ w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))

- rank = min(w2.shape)
+ rank = min(w2.shape[1:])
  assert rank >= 2

+ batch = w1.shape[0]
+
  u1, s1, v1 = torch.svd(w1)
  u2, s2, v2 = torch.svd(w2)

- mask = torch.randperm(rank) < (rank // 2)
+ batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+ mask = batch_randperm < (rank // 2)

- u = torch.where(mask[None, :], u1, u2)
+ u = torch.where(mask[:, None, :], u1, u2)
  s = torch.where(mask, s1, s2)
- v = torch.where(mask[None, :], v1, v2)
+ v = torch.where(mask[:, None, :], v1, v2)

  out = u @ torch.diag_embed(s) @ v.mT

  if transpose:
- out = out.t()
+ out = rearrange(out, 'b j i -> b i j')
+
+ if no_batch:
+ out = rearrange(out, '1 ... -> ...')

  return out
+
+ if __name__ == '__main__':
+ w1 = torch.randn(32, 16)
+ w2 = torch.randn(32, 16)
+ child = crossover_weights(w2, w2)
+
+ assert child.shape == w2.shape
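`crossover_weights` now accepts a batch of weight matrices: each parent pair is decomposed with SVD, roughly half of the singular triplets are randomly taken from each parent per batch element, and the child is reassembled. A short usage sketch of the batched and unbatched paths (shapes illustrative):

```python
import torch
from evolutionary_policy_optimization.experimental import crossover_weights

# batched: each of the 4 weight pairs gets its own random mix of singular components
w1 = torch.randn(4, 32, 16)
w2 = torch.randn(4, 32, 16)

child = crossover_weights(w1, w2)
assert child.shape == (4, 32, 16)

# the original unbatched (2D) path still round-trips the shape
child_2d = crossover_weights(torch.randn(32, 16), torch.randn(32, 16))
assert child_2d.shape == (32, 16)
```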
evolutionary_policy_optimization/mock_env.py

@@ -4,15 +4,20 @@ import torch
  from torch import tensor, randn, randint
  from torch.nn import Module

+ # functions
+
+ def cast_tuple(v):
+ return v if isinstance(v, tuple) else (v,)
+
  # mock env

  class Env(Module):
  def __init__(
  self,
- state_shape: tuple[int, ...]
+ state_shape: int | tuple[int, ...]
  ):
  super().__init__()
- self.state_shape = state_shape
+ self.state_shape = cast_tuple(state_shape)
  self.register_buffer('dummy', tensor(0))

  @property
@@ -31,6 +36,6 @@ class Env(Module):
  ):
  state = randn(self.state_shape, device = self.device)
  reward = randint(0, 5, (), device = self.device).float()
- done = zeros((), device = self.device, dtype = torch.bool)
+ done = torch.zeros((), device = self.device, dtype = torch.bool)

  return state, reward, done
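The mock `Env` now normalizes a bare integer `state_shape` into a one-element tuple via the new `cast_tuple` helper, so both spellings below construct the same environment; the step/reset method names are not visible in these hunks, so only construction is shown:

```python
from evolutionary_policy_optimization import Env

# as of 0.0.38 an int is accepted and cast to a 1-tuple internally
env_a = Env(state_shape = (5,))
env_b = Env(state_shape = 5)
```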
evolutionary_policy_optimization-0.0.36.dist-info/METADATA → evolutionary_policy_optimization-0.0.38.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.36
+ Version: 0.0.38
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,8 +60,6 @@ This paper stands out, as I have witnessed the positive effects first hand in an

  Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space

- Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
-
  ## Install

  ```bash
evolutionary_policy_optimization-0.0.38.dist-info/RECORD (new file)

@@ -0,0 +1,8 @@
+ evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+ evolutionary_policy_optimization/epo.py,sha256=z41p5LmvOHULq6o5aIj9Q6lpyka5DvkqsJ493-WL-EQ,26175
+ evolutionary_policy_optimization/experimental.py,sha256=9FrJGviLESlYysHI3i83efT9g2ZB9ha4u3K9HXN98_w,1100
+ evolutionary_policy_optimization/mock_env.py,sha256=QqVPZVJtrvQmSDcnYDTob_A5sDwiUzGj6_tmo6BII5c,918
+ evolutionary_policy_optimization-0.0.38.dist-info/METADATA,sha256=Lofrc6waEB8qBD19pjBKQjbKBYMNUyjYZZrJCO1fji8,4818
+ evolutionary_policy_optimization-0.0.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.38.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.36.dist-info/RECORD (removed file)

@@ -1,8 +0,0 @@
- evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
- evolutionary_policy_optimization/epo.py,sha256=wRqGjoiksWY33BQc9jypJbKWroHm3i_aEPNx1twVjWk,25819
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
- evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
- evolutionary_policy_optimization-0.0.36.dist-info/METADATA,sha256=WQpJa1PuiQx1qANilbJ0E7tZoKHDm2wAvjMccQoPH5Q,4992
- evolutionary_policy_optimization-0.0.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.36.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.36.dist-info/RECORD,,