evolutionary-policy-optimization 0.2.3__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/experimental.py +143 -6
- {evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/METADATA +5 -2
- {evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/experimental.py

@@ -1,11 +1,20 @@
+from random import uniform
+from copy import deepcopy
+
 import torch
 import torch.nn.functional as F
-from
+from torch.func import vmap, functional_call
+from torch.nn import Module, ParameterList
+
+from einops import rearrange, reduce, repeat
+
+def exists(v):
+    return v is not None
 
 def l2norm(t, dim = -1):
     return F.normalize(t, dim = dim)
 
-def crossover_weights(w1, w2, transpose = False):
+def crossover_weights(w1, w2):
     assert w2.shape == w2.shape
 
     no_batch = w1.ndim == 2
@@ -15,6 +24,9 @@ def crossover_weights(w1, w2, transpose = False):
 
     assert w1.ndim == 3
 
+    i, j = w1.shape[-2:]
+    transpose = i < j
+
     if transpose:
         w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
 
@@ -45,14 +57,16 @@ def crossover_weights(w1, w2, transpose = False):
 
 def mutate_weight(
     w,
-    transpose = False,
     mutation_strength = 1.
 ):
 
+    i, j = w.shape[-2:]
+    transpose = i < j
+
     if transpose:
         w = w.transpose(-1, -2)
 
-    rank = min(
+    rank = min(w.shape[1:])
     assert rank >= 2
 
     u, s, v = torch.svd(w)
@@ -70,9 +84,132 @@ def mutate_weight(
 
     return out
 
+# wrapper that manages network to population
+# able to receive fitness and employ selection + crossover
+
+class PopulationWrapper(Module):
+    def __init__(
+        self,
+        net: Module,
+        pop_size,
+        num_selected,
+        tournament_size,
+        learning_rate = 1e-3,
+        init_std_dev = 1e-1
+    ):
+        super().__init__()
+        assert num_selected < pop_size
+        assert tournament_size < num_selected
+
+        self.pop_size = pop_size
+        self.num_selected = num_selected
+        self.tournament_size = tournament_size
+        self.num_offsprings = pop_size - num_selected
+
+        self.net = net
+
+        params = dict(net.named_parameters())
+        device = next(iter(params.values())).device
+
+        pop_params = {name: (torch.randn((pop_size, *param.shape), device = device) * init_std_dev).requires_grad_() for name, param in params.items()}
+
+        self.param_names = pop_params.keys()
+        self.param_values = ParameterList(list(pop_params.values()))
+
+        def _forward(params, data):
+            return functional_call(net, params, data)
+
+        self.forward_pop_nets = vmap(_forward, in_dims = (0, None))
+
+    @property
+    def pop_params(self):
+        return dict(zip(self.param_names, self.param_values))
+
+    def individual(self, id) -> Module:
+        assert 0 <= id < self.pop_size
+        state_dict = {key: param[id] for key, param in self.pop_params.items()}
+
+        net = deepcopy(self.net)
+        net.load_state_dict(state_dict)
+        return net
+
+    def parameters(self):
+        return self.pop_params.values()
+
+    def genetic_algorithm_step_(
+        self,
+        fitnesses
+    ):
+        fitnesses = reduce(fitnesses, 'b p -> p', 'mean') # average across samples
+
+        num_selected = self.num_selected
+
+        # selection
+
+        sel_fitnesses, sel_indices = fitnesses.topk(num_selected, dim = -1)
+
+        # tournaments
+
+        tourn_ids = torch.randn((self.num_offsprings, self.tournament_size)).argsort(dim = -1)
+        tourn_scores = sel_fitnesses[tourn_ids]
+
+        winner_ids = tourn_scores.topk(2, dim = -1).indices
+        winner_ids = rearrange(winner_ids, 'offsprings couple -> couple offsprings')
+        parent_ids = sel_indices[winner_ids]
+
+        # crossover
+
+        for param in self.param_values:
+            parents = param[sel_indices]
+            parent1, parent2 = param[parent_ids]
+
+            children = parent1.lerp_(parent2, uniform(0.25, 0.75))
+
+            pop = torch.cat((parents, children))
+
+            param.data.copy_(pop)
+
+    def forward(
+        self,
+        data,
+        *,
+        individual_id = None,
+        labels = None,
+        return_logits_with_loss = False
+    ):
+        # if `individual_id` passed in, will forward for only that one network
+
+        if exists(individual_id):
+            assert 0 <= individual_id < self.pop_size
+            params = {key: param[individual_id] for key, param in self.pop_params.items()}
+            return functional_call(self.net, params, data)
+
+        out = self.forward_pop_nets(dict(self.pop_params), data)
+
+        if not exists(labels):
+            return out
+
+        logits = out
+        pop_size = logits.shape[0]
+
+        losses = F.cross_entropy(
+            rearrange(logits, 'p b ... l -> (p b) l ...'),
+            repeat(labels, 'b ... -> (p b) ...', p = pop_size),
+            reduction = 'none'
+        )
+
+        losses = rearrange(losses, '(p b) ... -> p b ...', p = pop_size)
+
+        if not return_logits_with_loss:
+            return losses
+
+        return losses, logits
+
+# test
+
 if __name__ == '__main__':
-    w1 = torch.randn(32, 16)
-    w2 = torch.randn(32, 16)
+    w1 = torch.randn(2, 32, 16)
+    w2 = torch.randn(2, 32, 16)
 
     child = crossover_weights(w1, w2)
     mutated_w1 = mutate_weight(w1)
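Notes on the changes above: `crossover_weights` and `mutate_weight` no longer take a `transpose` flag; both now infer it from the weight's shape (`transpose = i < j`), transposing wide matrices so the SVD-based operations always see tall ones, and the `__main__` test is updated to batched weights accordingly.

Below is a minimal usage sketch of the new `PopulationWrapper`, not taken from the package: the toy MLP, the random data, and the convention of using negative cross-entropy loss as fitness are assumptions for illustration. It follows the shapes in the diff — `forward(data, labels = ...)` returns per-sample losses of shape `(pop, batch)`, while `genetic_algorithm_step_` expects fitnesses of shape `(batch, pop)`.

```python
import torch
from torch import nn

# PopulationWrapper as added in 0.2.6
from evolutionary_policy_optimization.experimental import PopulationWrapper

# hypothetical toy task: a small MLP classifier on random data

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))

pop = PopulationWrapper(
    net,
    pop_size = 100,
    num_selected = 25,
    tournament_size = 5
)

data = torch.randn(8, 16)
labels = torch.randint(0, 10, (8,))

# one vmapped forward evaluates all 100 candidate parameter sets at once
losses = pop(data, labels = labels)   # (pop, batch)

# assumption of this sketch: treat negative loss as fitness; transpose to (batch, pop)
fitnesses = -losses.detach().t()

# selection + tournament crossover, updating the population parameters in-place
pop.genetic_algorithm_step_(fitnesses)

# extract one member of the population as a plain nn.Module
member = pop.individual(0)
```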
{evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.2.3
+Version: 0.2.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -47,6 +47,9 @@ Provides-Extra: examples
 Requires-Dist: numpy; extra == 'examples'
 Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
+Provides-Extra: experimental
+Requires-Dist: tensordict; extra == 'experimental'
+Requires-Dist: torchvision; extra == 'experimental'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Requires-Dist: ruff>=0.4.2; extra == 'test'
@@ -56,7 +59,7 @@ Description-Content-Type: text/markdown
 
 ## Evolutionary Policy Optimization
 
-Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University
 
 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.
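The 0.2.6 metadata adds an `experimental` extra pulling in `tensordict` and `torchvision`. Assuming standard pip extras syntax, it would be installed with:

```
pip install 'evolutionary-policy-optimization[experimental]'
```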
{evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/RECORD

@@ -2,9 +2,9 @@ evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX
 evolutionary_policy_optimization/distributed.py,sha256=MxyxqxANAuOm8GYb0Yu09EHd_aVLhK2uwgrfuVWciPU,2342
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
 evolutionary_policy_optimization/epo.py,sha256=81rf249ykaPrAEMGk9KsF98qDkCUhW8xL3-2UXIvI2E,51838
-evolutionary_policy_optimization/experimental.py,sha256=
+evolutionary_policy_optimization/experimental.py,sha256=7LOrMIaU4fr2Vme1ZpHNIvlvFEIdWj0-uemhQoNJcPQ,5549
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.2.3.dist-info/METADATA,sha256=
-evolutionary_policy_optimization-0.2.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.2.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.2.3.dist-info/RECORD,,
+evolutionary_policy_optimization-0.2.6.dist-info/METADATA,sha256=fD68WuKAl76bc0O8w9ThQupw9OpEep0brkPgBlhGBhk,8858
+evolutionary_policy_optimization-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.2.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.2.6.dist-info/RECORD,,

{evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/WHEEL: file without changes
{evolutionary_policy_optimization-0.2.3.dist-info → evolutionary_policy_optimization-0.2.6.dist-info}/licenses/LICENSE: file without changes