evolutionary-policy-optimization 0.2.1-py3-none-any.whl → 0.2.5-py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
@@ -1044,7 +1044,14 @@ class Agent(Module):
         actor.state_norm = critic.state_norm = state_norm
 
         self.use_critic_ema = use_critic_ema
-        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
+
+        self.critic_ema = EMA(
+            critic,
+            beta = critic_ema_beta,
+            include_online_model = False,
+            ignore_startswith_names = {'state_norm'},
+            **ema_kwargs
+        ) if use_critic_ema else None
 
         self.latent_gene_pool = latent_gene_pool
         self.num_latents = latent_gene_pool.num_latents if exists(latent_gene_pool) else 1
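
The hunk above (in epo.py, judging by the RECORD changes at the bottom of this diff) adds `ignore_startswith_names = {'state_norm'}` to the critic's EMA wrapper. Below is a minimal sketch, not taken from the package, of what that argument does, assuming the `EMA` class is lucidrains' ema-pytorch: any parameter or buffer whose name starts with 'state_norm' is left out of the exponential moving average, so the critic EMA no longer keeps a shadow copy of the state normalization shared between actor and critic. `TinyCritic` is a made-up stand-in for the package's critic module.

import torch.nn as nn
from ema_pytorch import EMA

class TinyCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.state_norm = nn.LayerNorm(8)   # shared state normalization, skipped by the EMA
        self.to_value = nn.Linear(8, 1)

    def forward(self, x):
        return self.to_value(self.state_norm(x))

critic = TinyCritic()

critic_ema = EMA(
    critic,
    beta = 0.99,
    include_online_model = False,
    ignore_startswith_names = {'state_norm'}   # parameters/buffers named 'state_norm*' are not averaged
)

critic_ema.update()   # called after each critic optimizer step
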
@@ -1,11 +1,19 @@
+from random import uniform
+
 import torch
 import torch.nn.functional as F
-from einops import rearrange
+from torch.func import vmap, functional_call
+from torch.nn import Module, ParameterList
+
+from einops import rearrange, reduce, repeat
+
+def exists(v):
+    return v is not None
 
 def l2norm(t, dim = -1):
     return F.normalize(t, dim = dim)
 
-def crossover_weights(w1, w2, transpose = False):
+def crossover_weights(w1, w2):
     assert w2.shape == w2.shape
 
     no_batch = w1.ndim == 2
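
The new `torch.func` imports above are the machinery behind the `PopulationWrapper` added further down: `functional_call` runs a module with externally supplied parameters, and `vmap` over a leading population dimension evaluates every candidate network in one batched call. A small sketch of just that mechanism, with hypothetical names (`net`, `pop_size`, `x`) that are not taken from the diff:

import torch
from torch import nn
from torch.func import vmap, functional_call

net = nn.Linear(4, 2)
pop_size = 8

# one stacked copy of every parameter, leading dim = population
pop_params = {
    name: torch.randn(pop_size, *param.shape) * 1e-1
    for name, param in net.named_parameters()
}

x = torch.randn(16, 4)

forward_pop = vmap(lambda params, data: functional_call(net, params, data), in_dims = (0, None))
out = forward_pop(pop_params, x)   # shape (8, 16, 2): one output per population member
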
@@ -15,6 +23,9 @@ def crossover_weights(w1, w2, transpose = False):
 
     assert w1.ndim == 3
 
+    i, j = w1.shape[-2:]
+    transpose = i < j
+
     if transpose:
         w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
 
@@ -45,14 +56,16 @@ def crossover_weights(w1, w2, transpose = False):
 
 def mutate_weight(
     w,
-    transpose = False,
     mutation_strength = 1.
 ):
 
+    i, j = w.shape[-2:]
+    transpose = i < j
+
     if transpose:
         w = w.transpose(-1, -2)
 
-    rank = min(w2.shape[1:])
+    rank = min(w.shape[1:])
     assert rank >= 2
 
     u, s, v = torch.svd(w)
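
Per the two hunks above, both `crossover_weights` and `mutate_weight` drop their `transpose` argument in this release: the orientation is now inferred from the trailing weight shape (wide matrices, with fewer rows than columns, are flipped before the SVD), and the stray `w2` reference in mutate_weight's rank computation is corrected to `w`. A hedged call sketch with arbitrary shapes, not taken from the package:

import torch
from evolutionary_policy_optimization.experimental import crossover_weights, mutate_weight

w1 = torch.randn(2, 16, 32)   # wide (rows < cols) -> transposed internally, no flag needed
w2 = torch.randn(2, 16, 32)

child   = crossover_weights(w1, w2)
mutated = mutate_weight(w1, mutation_strength = 0.5)
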
@@ -70,9 +83,114 @@ def mutate_weight(
 
     return out
 
+# wrapper that manages network to population
+# able to receive fitness and employ selection + crossover
+
+class PopulationWrapper(Module):
+    def __init__(
+        self,
+        net: Module,
+        pop_size,
+        num_selected,
+        tournament_size,
+        learning_rate = 1e-3,
+        init_std_dev = 1e-1
+    ):
+        super().__init__()
+        assert num_selected < pop_size
+        assert tournament_size < num_selected
+
+        self.num_selected = num_selected
+        self.tournament_size = tournament_size
+        self.num_offsprings = pop_size - num_selected
+
+        self.net = net
+
+        params = dict(net.named_parameters())
+        device = next(iter(params.values())).device
+
+        pop_params = {name: (torch.randn((pop_size, *param.shape), device = device) * init_std_dev).requires_grad_() for name, param in params.items()}
+
+        self.param_names = pop_params.keys()
+        self.param_values = ParameterList(list(pop_params.values()))
+
+        def _forward(params, data):
+            return functional_call(net, params, data)
+
+        self.forward_pop_nets = vmap(_forward, in_dims = (0, None))
+
+    @property
+    def pop_params(self):
+        return dict(zip(self.param_names, self.param_values))
+
+    def parameters(self):
+        return self.pop_params.values()
+
+    def genetic_algorithm_step_(
+        self,
+        fitnesses
+    ):
+        fitnesses = reduce(fitnesses, 'b p -> p', 'mean') # average across samples
+
+        num_selected = self.num_selected
+
+        # selection
+
+        sel_fitnesses, sel_indices = fitnesses.topk(num_selected, dim = -1)
+
+        # tournaments
+
+        tourn_ids = torch.randn((self.num_offsprings, self.tournament_size)).argsort(dim = -1)
+        tourn_scores = sel_fitnesses[tourn_ids]
+
+        winner_ids = tourn_scores.topk(2, dim = -1).indices
+        winner_ids = rearrange(winner_ids, 'offsprings couple -> couple offsprings')
+        parent_ids = sel_indices[winner_ids]
+
+        # crossover
+
+        for param in self.param_values:
+            parents = param[sel_indices]
+            parent1, parent2 = param[parent_ids]
+
+            children = parent1.lerp_(parent2, uniform(0.25, 0.75))
+
+            pop = torch.cat((parents, children))
+
+            param.data.copy_(pop)
+
+    def forward(
+        self,
+        data,
+        labels = None,
+        return_logits_with_loss = False
+    ):
+        out = self.forward_pop_nets(dict(self.pop_params), data)
+
+        if not exists(labels):
+            return out
+
+        logits = out
+        pop_size = logits.shape[0]
+
+        losses = F.cross_entropy(
+            rearrange(logits, 'p b ... l -> (p b) l ...'),
+            repeat(labels, 'b ... -> (p b) ...', p = pop_size),
+            reduction = 'none'
+        )
+
+        losses = rearrange(losses, '(p b) ... -> p b ...', p = pop_size)
+
+        if not return_logits_with_loss:
+            return losses
+
+        return losses, logits
+
+# test
+
 if __name__ == '__main__':
-    w1 = torch.randn(32, 16)
-    w2 = torch.randn(32, 16)
+    w1 = torch.randn(2, 32, 16)
+    w2 = torch.randn(2, 32, 16)
 
     child = crossover_weights(w1, w2)
     mutated_w1 = mutate_weight(w1)
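
The bulk of the new experimental.py code is the `PopulationWrapper` shown above: it stacks a randomly initialized copy of every parameter per population member, evaluates the whole population with `vmap` + `functional_call`, and evolves it in place with top-k selection, tournaments, and linear-interpolation crossover. Below is a hedged usage sketch, not from the package; `base_net`, `data`, and `labels` are made-up names, and fitnesses are passed as (batch, population) to match the `'b p -> p'` reduce in `genetic_algorithm_step_`.

import torch
from torch import nn
from einops import rearrange
from evolutionary_policy_optimization.experimental import PopulationWrapper

base_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

pop_net = PopulationWrapper(
    base_net,
    pop_size = 10,
    num_selected = 5,
    tournament_size = 3
)

data = torch.randn(32, 8)
labels = torch.randint(0, 4, (32,))

with torch.no_grad():
    losses = pop_net(data, labels = labels)    # (pop, batch) cross-entropy per member

fitnesses = rearrange(-losses, 'p b -> b p')   # lower loss -> higher fitness
pop_net.genetic_algorithm_step_(fitnesses)     # selection + tournament crossover, in-place
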
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.2.1
+Version: 0.2.5
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -47,6 +47,9 @@ Provides-Extra: examples
 Requires-Dist: numpy; extra == 'examples'
 Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
+Provides-Extra: experimental
+Requires-Dist: tensordict; extra == 'experimental'
+Requires-Dist: torchvision; extra == 'experimental'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Requires-Dist: ruff>=0.4.2; extra == 'test'
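
The new extra can be installed with `pip install evolutionary-policy-optimization[experimental]`, which pulls in tensordict and torchvision on top of the base requirements for the experimental population-training code.
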
@@ -56,7 +59,7 @@ Description-Content-Type: text/markdown
 
 ## Evolutionary Policy Optimization
 
-Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University
 
 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.
 
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=MxyxqxANAuOm8GYb0Yu09EHd_aVLhK2uwgrfuVWciPU,2342
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=3uKuNrDgciZ2uxccOviPhXkuf6YTvXt-aQt-h7Uu8i0,51725
-evolutionary_policy_optimization/experimental.py,sha256=ZyOGHbE4dXmt4zCljSzcUklua4vlOwQtslhFEm0JN94,1716
+evolutionary_policy_optimization/epo.py,sha256=81rf249ykaPrAEMGk9KsF98qDkCUhW8xL3-2UXIvI2E,51838
+evolutionary_policy_optimization/experimental.py,sha256=NKtAeOS8AVoX5HHwwxAl1ngUi7ZE19uV1-NILFK6Tu8,4877
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.2.1.dist-info/METADATA,sha256=FUZiK2msH6WQip1mjfb-lWnqPKbf50ciDqKlP-8aqsQ,8697
-evolutionary_policy_optimization-0.2.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.2.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.2.1.dist-info/RECORD,,
+evolutionary_policy_optimization-0.2.5.dist-info/METADATA,sha256=WzU43HLqNw6kjQ-s7IMW9bFv9nojiQsLEd2pHTvKbfw,8858
+evolutionary_policy_optimization-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.2.5.dist-info/RECORD,,