evolutionary-policy-optimization 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +40 -8
- {evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/METADATA +6 -2
- evolutionary_policy_optimization-0.0.6.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.5.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -9,7 +9,7 @@ import torch.nn.functional as F
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList

-from einops import rearrange, repeat
+from einops import rearrange, repeat, einsum

 from assoc_scan import AssocScan

@@ -162,6 +162,7 @@ class MLP(Module):
         self,
         dims: tuple[int, ...],
         dim_latent = 0,
+        num_latent_sets = 1
     ):
         super().__init__()
         assert len(dims) >= 2, 'must have at least two dimensions'
@@ -169,17 +170,26 @@ class MLP(Module):
         # add the latent to the first dim

         first_dim, *rest_dims = dims
-        first_dim
-
+        dims = (first_dim + dim_latent, *rest_dims)
+
+        assert num_latent_sets >= 1

         self.dim_latent = dim_latent
+        self.num_latent_sets = num_latent_sets
+
         self.needs_latent = dim_latent > 0
+        self.needs_latent_gate = num_latent_sets > 1

         self.encode_latent = nn.Sequential(
             Linear(dim_latent, dim_latent),
             nn.SiLU()
         ) if self.needs_latent else None

+        self.to_latent_gate = nn.Sequential(
+            Linear(first_dim, num_latent_sets),
+            nn.Softmax(dim = -1)
+        ) if self.needs_latent_gate else None
+
         # pairs of dimension

         dim_pairs = tuple(zip(dims[:-1], dims[1:]))
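For orientation, here is a minimal standalone sketch of what the new `to_latent_gate` head computes, built from the same `Linear -> Softmax` stack shown above; the sizes (`first_dim = 16`, `num_latent_sets = 4`, batch of 2) are illustrative values, not defaults from the package.

```python
# minimal sketch of the gate head added in MLP.__init__
# first_dim = 16 and num_latent_sets = 4 are made-up sizes for illustration

import torch
from torch import nn

first_dim = 16
num_latent_sets = 4

to_latent_gate = nn.Sequential(
    nn.Linear(first_dim, num_latent_sets),
    nn.Softmax(dim = -1)
)

x = torch.randn(2, first_dim)      # a batch of 2 inputs
gates = to_latent_gate(x)          # (2, 4): one convex weighting over latent sets per sample

assert gates.shape == (2, num_latent_sets)
assert torch.allclose(gates.sum(dim = -1), torch.ones(2))  # softmax rows sum to 1
```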
@@ -195,16 +205,27 @@ class MLP(Module):
         x,
         latent = None
     ):
+        batch = x.shape[0]
+
         assert xnor(self.needs_latent, exists(latent))

+        if exists(latent) and self.needs_latent_gate:
+            # an improvisation where set of genes with controlled expression by environment
+
+            gates = self.to_latent_gate(x)
+            latent = einsum(latent, gates, 'n g, b n -> b g')
+        else:
+            assert latent.shape[0] == 1
+            latent = latent[0]
+
         if exists(latent):
             # start with naive concatenative conditioning
             # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)

-            batch = x.shape[0]
-
             latent = self.encode_latent(latent)
-
+
+            if latent.ndim == 1:
+                latent = repeat(latent, 'd -> b d', b = batch)

             x = cat((x, latent), dim = -1)

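Below is an isolated sketch of the gating step in the new forward branch, with toy shapes (batch 2, 4 latent sets, 8 latent dims, all illustrative): the softmax gates mix one individual's set of gene vectors into a single per-sample latent, which amounts to a batched matrix product.

```python
# sketch of latent = einsum(latent, gates, 'n g, b n -> b g')
# shapes are illustrative: n = 4 latent sets, g = 8 latent dims, b = 2 batch

import torch
from einops import einsum

n, g, b = 4, 8, 2

latent = torch.randn(n, g)                              # one individual's n gene vectors
gates = torch.softmax(torch.randn(b, n), dim = -1)      # per-sample weights over the n sets

mixed = einsum(latent, gates, 'n g, b n -> b g')        # (b, g)

# equivalent plain matmul, as a sanity check
assert torch.allclose(mixed, gates @ latent, atol = 1e-6)
```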
@@ -321,6 +342,7 @@ class LatentGenePool(Module):
         self,
         num_latents,              # same as gene pool size
         dim_latent,               # gene dimension
+        num_latent_sets = 1,      # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         crossover_random = True,  # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,    # whether to enforce latents on hypersphere,
         frac_tournaments = 0.25,  # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -333,12 +355,13 @@ class LatentGenePool(Module):

         maybe_l2norm = l2norm if l2norm_latent else identity

-        latents = torch.randn(num_latents, dim_latent)
+        latents = torch.randn(num_latents, num_latent_sets, dim_latent)

         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)

         self.num_latents = num_latents
+        self.num_latent_sets = num_latent_sets
         self.latents = nn.Parameter(latents, requires_grad = False)

         self.maybe_l2norm = maybe_l2norm
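For reference, a small sketch of the new gene pool layout (toy sizes, not package defaults): each of the `num_latents` individuals now carries `num_latent_sets` gene vectors instead of a single one.

```python
# illustrative sizes only: 128 individuals, 4 gene sets each, 32-dim genes
import torch

num_latents, num_latent_sets, dim_latent = 128, 4, 32

latents = torch.randn(num_latents, num_latent_sets, dim_latent)   # 0.0.6 layout: (p, n, g)

# the 0.0.5 layout was (p, g); indexing one individual now yields its full set of genes
one_individual = latents[0]
assert one_individual.shape == (num_latent_sets, dim_latent)
```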
@@ -364,9 +387,17 @@ class LatentGenePool(Module):
         # network for the latent / gene

         if isinstance(net, dict):
+            assert 'dim_latent' not in net
+            assert 'num_latent_sets' not in net
+
+            net.update(dim_latent = dim_latent)
+            net.update(num_latent_sets = num_latent_sets)
+
             net = MLP(**net)

         assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
+        assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and and latent gene pool container'
+
         self.net = net

     @torch.no_grad()
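A hedged usage sketch of the dict branch above, assuming only what the hunks show (the `MLP` signature in `evolutionary_policy_optimization/epo.py` and the two injected keywords); the `dims` and size values are illustrative, not defaults.

```python
# sketch: user-supplied MLP kwargs get dim_latent / num_latent_sets injected before construction
from evolutionary_policy_optimization.epo import MLP

dim_latent = 32        # illustrative
num_latent_sets = 4    # illustrative

net = dict(dims = (512, 256))   # hypothetical user config for the gene-conditioned MLP

assert 'dim_latent' not in net and 'num_latent_sets' not in net

net.update(dim_latent = dim_latent, num_latent_sets = num_latent_sets)
net = MLP(**net)

# both settings now agree between the MLP and the gene pool that owns it
assert net.dim_latent == dim_latent
assert net.num_latent_sets == num_latent_sets
```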
@@ -379,6 +410,7 @@ class LatentGenePool(Module):
         """
         p - population
         g - gene dimension
+        n - number of genes per individual
         """
         assert self.num_latents > 1

@@ -403,7 +435,7 @@ class LatentGenePool(Module):

         tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices

-        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... g', g = self.dim_latent)
+        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)

         parents = participants.gather(-2, tournament_winner_indices)

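To make the index bookkeeping concrete, here is an isolated sketch of the broadcast-then-gather pattern with toy sizes (a single tournament of 6 participants, 4 latent sets, 8 gene dims, all illustrative, and not the library's exact batched shapes): the top-2 indices over the participant axis are expanded across the new `n` and `g` axes so `gather` can pull out full parent genomes.

```python
# sketch of expanding winner indices over trailing (n, g) axes, then gathering parents
import torch
from einops import repeat

t, n, g = 6, 4, 8                      # tournament participants, latent sets, gene dim (toy sizes)

participants = torch.randn(t, n, g)    # candidate genomes in one tournament
fitness = torch.randn(t)

winner_indices = fitness.topk(2, dim = -1).indices                       # (2,)
winner_indices = repeat(winner_indices, '... -> ... n g', n = n, g = g)  # (2, n, g)

parents = participants.gather(0, winner_indices)                         # (2, n, g)

# the fittest participant's full genome is the first parent
assert torch.allclose(parents[0], participants[fitness.argmax()])
```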
{evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.5
+Version: 0.0.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -46,6 +46,8 @@ Provides-Extra: examples-gym
 Requires-Dist: box2d-py; extra == 'examples-gym'
 Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
 Requires-Dist: tqdm; extra == 'examples-gym'
+Provides-Extra: test
+Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown

 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -56,7 +58,9 @@ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.
 
 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.
 
-Besides their latent variable
+Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
+
+Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). This is also incidentally what I have concluded what Science is. I am in direct exposure to this phenomenon on a daily basis
 
 ## Usage
 
evolutionary_policy_optimization-0.0.6.dist-info/RECORD

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=vXkwsQE0CNEUPpguZP-XXsuDyIBN-bS3xDJDXpYlTHM,14772
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.6.dist-info/METADATA,sha256=M_0SbTqdifHQ_R9LWIe7ZfHMXgCiFDJ0sDpD29ctiNk,4460
+evolutionary_policy_optimization-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.6.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.5.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
-evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,