evolutionary-policy-optimization 0.2.3 (py3-none-any.whl) → 0.2.6 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,20 @@
+ from random import uniform
+ from copy import deepcopy
+
  import torch
  import torch.nn.functional as F
- from einops import rearrange
+ from torch.func import vmap, functional_call
+ from torch.nn import Module, ParameterList
+
+ from einops import rearrange, reduce, repeat
+
+ def exists(v):
+     return v is not None

  def l2norm(t, dim = -1):
      return F.normalize(t, dim = dim)

- def crossover_weights(w1, w2, transpose = False):
+ def crossover_weights(w1, w2):
      assert w2.shape == w2.shape

      no_batch = w1.ndim == 2
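The headline change in this hunk is the switch to `torch.func`: `functional_call` runs a module with externally supplied parameters, and `vmap` maps that call over a leading population dimension, which is the mechanism the `PopulationWrapper` added later in this diff relies on. A minimal sketch of that pattern (the toy `nn.Linear`, population size of 8, and helper names are illustrative, not part of the package):

```python
import torch
from torch import nn
from torch.func import functional_call, vmap

net = nn.Linear(16, 4)  # shared architecture

# one stacked parameter set per individual, on a new leading dim
pop_params = {
    name: torch.randn(8, *p.shape) * 0.1
    for name, p in net.named_parameters()
}

def run_one(params, x):
    # evaluate the shared architecture with one individual's weights
    return functional_call(net, params, x)

# map over the population dim of the params, broadcast the data
run_pop = vmap(run_one, in_dims = (0, None))

out = run_pop(pop_params, torch.randn(32, 16))  # (8, 32, 4)
```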
@@ -15,6 +24,9 @@ def crossover_weights(w1, w2, transpose = False):

      assert w1.ndim == 3

+     i, j = w1.shape[-2:]
+     transpose = i < j
+
      if transpose:
          w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))

@@ -45,14 +57,16 @@ def crossover_weights(w1, w2, transpose = False):

  def mutate_weight(
      w,
-     transpose = False,
      mutation_strength = 1.
  ):

+     i, j = w.shape[-2:]
+     transpose = i < j
+
      if transpose:
          w = w.transpose(-1, -2)

-     rank = min(w2.shape[1:])
+     rank = min(w.shape[1:])
      assert rank >= 2

      u, s, v = torch.svd(w)
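These two hunks make the same change in `crossover_weights` and `mutate_weight`: the explicit `transpose` keyword is dropped, and each function now decides on its own, transposing whenever a weight has fewer rows than columns so the downstream logic always operates on a tall (or square) matrix. Standalone, the rule amounts to something like this (the helper name is illustrative):

```python
def maybe_transpose(w):
    # transpose wide matrices so that rows >= columns
    i, j = w.shape[-2:]
    if i < j:
        w = w.transpose(-1, -2)
    return w
```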
@@ -70,9 +84,132 @@ def mutate_weight(

      return out

+ # wrapper that manages network to population
+ # able to receive fitness and employ selection + crossover
+
+ class PopulationWrapper(Module):
+     def __init__(
+         self,
+         net: Module,
+         pop_size,
+         num_selected,
+         tournament_size,
+         learning_rate = 1e-3,
+         init_std_dev = 1e-1
+     ):
+         super().__init__()
+         assert num_selected < pop_size
+         assert tournament_size < num_selected
+
+         self.pop_size = pop_size
+         self.num_selected = num_selected
+         self.tournament_size = tournament_size
+         self.num_offsprings = pop_size - num_selected
+
+         self.net = net
+
+         params = dict(net.named_parameters())
+         device = next(iter(params.values())).device
+
+         pop_params = {name: (torch.randn((pop_size, *param.shape), device = device) * init_std_dev).requires_grad_() for name, param in params.items()}
+
+         self.param_names = pop_params.keys()
+         self.param_values = ParameterList(list(pop_params.values()))
+
+         def _forward(params, data):
+             return functional_call(net, params, data)
+
+         self.forward_pop_nets = vmap(_forward, in_dims = (0, None))
+
+     @property
+     def pop_params(self):
+         return dict(zip(self.param_names, self.param_values))
+
+     def individual(self, id) -> Module:
+         assert 0 <= id < self.pop_size
+         state_dict = {key: param[id] for key, param in self.pop_params.items()}
+
+         net = deepcopy(self.net)
+         net.load_state_dict(state_dict)
+         return net
+
+     def parameters(self):
+         return self.pop_params.values()
+
+     def genetic_algorithm_step_(
+         self,
+         fitnesses
+     ):
+         fitnesses = reduce(fitnesses, 'b p -> p', 'mean') # average across samples
+
+         num_selected = self.num_selected
+
+         # selection
+
+         sel_fitnesses, sel_indices = fitnesses.topk(num_selected, dim = -1)
+
+         # tournaments
+
+         tourn_ids = torch.randn((self.num_offsprings, self.tournament_size)).argsort(dim = -1)
+         tourn_scores = sel_fitnesses[tourn_ids]
+
+         winner_ids = tourn_scores.topk(2, dim = -1).indices
+         winner_ids = rearrange(winner_ids, 'offsprings couple -> couple offsprings')
+         parent_ids = sel_indices[winner_ids]
+
+         # crossover
+
+         for param in self.param_values:
+             parents = param[sel_indices]
+             parent1, parent2 = param[parent_ids]
+
+             children = parent1.lerp_(parent2, uniform(0.25, 0.75))
+
+             pop = torch.cat((parents, children))
+
+             param.data.copy_(pop)
+
+     def forward(
+         self,
+         data,
+         *,
+         individual_id = None,
+         labels = None,
+         return_logits_with_loss = False
+     ):
+         # if `individual_id` passed in, will forward for only that one network
+
+         if exists(individual_id):
+             assert 0 <= individual_id < self.pop_size
+             params = {key: param[individual_id] for key, param in self.pop_params.items()}
+             return functional_call(self.net, params, data)
+
+         out = self.forward_pop_nets(dict(self.pop_params), data)
+
+         if not exists(labels):
+             return out
+
+         logits = out
+         pop_size = logits.shape[0]
+
+         losses = F.cross_entropy(
+             rearrange(logits, 'p b ... l -> (p b) l ...'),
+             repeat(labels, 'b ... -> (p b) ...', p = pop_size),
+             reduction = 'none'
+         )
+
+         losses = rearrange(losses, '(p b) ... -> p b ...', p = pop_size)
+
+         if not return_logits_with_loss:
+             return losses
+
+         return losses, logits
+
+ # test
+
  if __name__ == '__main__':
-     w1 = torch.randn(32, 16)
-     w2 = torch.randn(32, 16)
+     w1 = torch.randn(2, 32, 16)
+     w2 = torch.randn(2, 32, 16)

      child = crossover_weights(w1, w2)
      mutated_w1 = mutate_weight(w1)
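Taken together, the new class turns any `Module` into a population of parameter sets that is evaluated in one vmapped forward pass and evolved with tournament selection plus crossover. A minimal usage sketch, assuming the class lives in `evolutionary_policy_optimization/experimental.py` (as the RECORD change below suggests) and using a toy classifier and random data as stand-ins:

```python
import torch
from torch import nn
from evolutionary_policy_optimization.experimental import PopulationWrapper

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))

pop = PopulationWrapper(
    net,
    pop_size = 10,
    num_selected = 5,
    tournament_size = 3
)

data = torch.randn(64, 16)
labels = torch.randint(0, 10, (64,))

# per-individual, per-sample cross entropy: (pop_size, batch)
losses = pop(data, labels = labels)

# genetic_algorithm_step_ averages fitness over the sample dim ('b p -> p'),
# so hand it (batch, pop) scores - here simply the negated losses
pop.genetic_algorithm_step_(-losses.detach().t())

# materialize one individual as an ordinary nn.Module
individual = pop.individual(0)
```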
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.2.3
+ Version: 0.2.6
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -47,6 +47,9 @@ Provides-Extra: examples
  Requires-Dist: numpy; extra == 'examples'
  Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
  Requires-Dist: tqdm; extra == 'examples'
+ Provides-Extra: experimental
+ Requires-Dist: tensordict; extra == 'experimental'
+ Requires-Dist: torchvision; extra == 'experimental'
  Provides-Extra: test
  Requires-Dist: pytest; extra == 'test'
  Requires-Dist: ruff>=0.4.2; extra == 'test'
@@ -56,7 +59,7 @@ Description-Content-Type: text/markdown

  ## Evolutionary Policy Optimization

- Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University

  This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

@@ -2,9 +2,9 @@ evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX
  evolutionary_policy_optimization/distributed.py,sha256=MxyxqxANAuOm8GYb0Yu09EHd_aVLhK2uwgrfuVWciPU,2342
  evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
  evolutionary_policy_optimization/epo.py,sha256=81rf249ykaPrAEMGk9KsF98qDkCUhW8xL3-2UXIvI2E,51838
- evolutionary_policy_optimization/experimental.py,sha256=ZyOGHbE4dXmt4zCljSzcUklua4vlOwQtslhFEm0JN94,1716
+ evolutionary_policy_optimization/experimental.py,sha256=7LOrMIaU4fr2Vme1ZpHNIvlvFEIdWj0-uemhQoNJcPQ,5549
  evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
- evolutionary_policy_optimization-0.2.3.dist-info/METADATA,sha256=syjis1-9dDCEwfGt7CeMfhC5k7OegIl4BBsKaTnYssQ,8697
- evolutionary_policy_optimization-0.2.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.2.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.2.3.dist-info/RECORD,,
+ evolutionary_policy_optimization-0.2.6.dist-info/METADATA,sha256=fD68WuKAl76bc0O8w9ThQupw9OpEep0brkPgBlhGBhk,8858
+ evolutionary_policy_optimization-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.2.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.2.6.dist-info/RECORD,,