evolutionary-policy-optimization 0.0.5__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +66 -8
- {evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.8.dist-info}/METADATA +6 -2
- evolutionary_policy_optimization-0.0.8.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.5.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.8.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.5.dist-info → evolutionary_policy_optimization-0.0.8.dist-info}/licenses/LICENSE +0 -0
@@ -9,7 +9,7 @@ import torch.nn.functional as F
|
|
9
9
|
import torch.nn.functional as F
|
10
10
|
from torch.nn import Linear, Module, ModuleList
|
11
11
|
|
12
|
-
from einops import rearrange, repeat
|
12
|
+
from einops import rearrange, repeat, einsum
|
13
13
|
|
14
14
|
from assoc_scan import AssocScan
|
15
15
|
|
@@ -162,6 +162,7 @@ class MLP(Module):
|
|
162
162
|
self,
|
163
163
|
dims: tuple[int, ...],
|
164
164
|
dim_latent = 0,
|
165
|
+
num_latent_sets = 1
|
165
166
|
):
|
166
167
|
super().__init__()
|
167
168
|
assert len(dims) >= 2, 'must have at least two dimensions'
|
@@ -169,17 +170,26 @@ class MLP(Module):
|
|
169
170
|
# add the latent to the first dim
|
170
171
|
|
171
172
|
first_dim, *rest_dims = dims
|
172
|
-
first_dim
|
173
|
-
|
173
|
+
dims = (first_dim + dim_latent, *rest_dims)
|
174
|
+
|
175
|
+
assert num_latent_sets >= 1
|
174
176
|
|
175
177
|
self.dim_latent = dim_latent
|
178
|
+
self.num_latent_sets = num_latent_sets
|
179
|
+
|
176
180
|
self.needs_latent = dim_latent > 0
|
181
|
+
self.needs_latent_gate = num_latent_sets > 1
|
177
182
|
|
178
183
|
self.encode_latent = nn.Sequential(
|
179
184
|
Linear(dim_latent, dim_latent),
|
180
185
|
nn.SiLU()
|
181
186
|
) if self.needs_latent else None
|
182
187
|
|
188
|
+
self.to_latent_gate = nn.Sequential(
|
189
|
+
Linear(first_dim, num_latent_sets),
|
190
|
+
nn.Softmax(dim = -1)
|
191
|
+
) if self.needs_latent_gate else None
|
192
|
+
|
183
193
|
# pairs of dimension
|
184
194
|
|
185
195
|
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
@@ -195,16 +205,27 @@ class MLP(Module):
|
|
195
205
|
x,
|
196
206
|
latent = None
|
197
207
|
):
|
208
|
+
batch = x.shape[0]
|
209
|
+
|
198
210
|
assert xnor(self.needs_latent, exists(latent))
|
199
211
|
|
212
|
+
if exists(latent) and self.needs_latent_gate:
|
213
|
+
# an improvisation where set of genes with controlled expression by environment
|
214
|
+
|
215
|
+
gates = self.to_latent_gate(x)
|
216
|
+
latent = einsum(latent, gates, 'n g, b n -> b g')
|
217
|
+
else:
|
218
|
+
assert latent.shape[0] == 1
|
219
|
+
latent = latent[0]
|
220
|
+
|
200
221
|
if exists(latent):
|
201
222
|
# start with naive concatenative conditioning
|
202
223
|
# but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)
|
203
224
|
|
204
|
-
batch = x.shape[0]
|
205
|
-
|
206
225
|
latent = self.encode_latent(latent)
|
207
|
-
|
226
|
+
|
227
|
+
if latent.ndim == 1:
|
228
|
+
latent = repeat(latent, 'd -> b d', b = batch)
|
208
229
|
|
209
230
|
x = cat((x, latent), dim = -1)
|
210
231
|
|
@@ -314,6 +335,25 @@ class Agent(Module):
|
|
314
335
|
):
|
315
336
|
raise NotImplementedError
|
316
337
|
|
338
|
+
# criteria for running genetic algorithm
|
339
|
+
|
340
|
+
class ShouldRunGeneticAlgorithm(Module):
|
341
|
+
def __init__(
|
342
|
+
self,
|
343
|
+
gamma = 2. # not sure what the value is
|
344
|
+
):
|
345
|
+
super().__init__()
|
346
|
+
self.gamma = gamma
|
347
|
+
|
348
|
+
def forward(self, fitnesses):
|
349
|
+
# equation (3)
|
350
|
+
|
351
|
+
# max(fitness) - min(fitness) > gamma * median(fitness)
|
352
|
+
# however, this equation does not make much sense to me if fitness increases unbounded
|
353
|
+
# just let it be customizable, and offer a variant where mean and variance is over some threshold (could account for skew too)
|
354
|
+
|
355
|
+
return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
|
356
|
+
|
317
357
|
# classes
|
318
358
|
|
319
359
|
class LatentGenePool(Module):
|
@@ -321,6 +361,7 @@ class LatentGenePool(Module):
|
|
321
361
|
self,
|
322
362
|
num_latents, # same as gene pool size
|
323
363
|
dim_latent, # gene dimension
|
364
|
+
num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
|
324
365
|
crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
|
325
366
|
l2norm_latent = False, # whether to enforce latents on hypersphere,
|
326
367
|
frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
|
@@ -328,17 +369,19 @@ class LatentGenePool(Module):
|
|
328
369
|
frac_elitism = 0.1, # frac of population to preserve from being noised
|
329
370
|
mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
|
330
371
|
net: MLP | Module | dict | None = None,
|
372
|
+
should_run_genetic_algorithm: Module = ShouldRunGeneticAlgorithm() # eq (3) in paper
|
331
373
|
):
|
332
374
|
super().__init__()
|
333
375
|
|
334
376
|
maybe_l2norm = l2norm if l2norm_latent else identity
|
335
377
|
|
336
|
-
latents = torch.randn(num_latents, dim_latent)
|
378
|
+
latents = torch.randn(num_latents, num_latent_sets, dim_latent)
|
337
379
|
|
338
380
|
if l2norm_latent:
|
339
381
|
latents = maybe_l2norm(latents, dim = -1)
|
340
382
|
|
341
383
|
self.num_latents = num_latents
|
384
|
+
self.num_latent_sets = num_latent_sets
|
342
385
|
self.latents = nn.Parameter(latents, requires_grad = False)
|
343
386
|
|
344
387
|
self.maybe_l2norm = maybe_l2norm
|
@@ -364,11 +407,21 @@ class LatentGenePool(Module):
|
|
364
407
|
# network for the latent / gene
|
365
408
|
|
366
409
|
if isinstance(net, dict):
|
410
|
+
assert 'dim_latent' not in net
|
411
|
+
assert 'num_latent_sets' not in net
|
412
|
+
|
413
|
+
net.update(dim_latent = dim_latent)
|
414
|
+
net.update(num_latent_sets = num_latent_sets)
|
415
|
+
|
367
416
|
net = MLP(**net)
|
368
417
|
|
369
418
|
assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
|
419
|
+
assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and and latent gene pool container'
|
420
|
+
|
370
421
|
self.net = net
|
371
422
|
|
423
|
+
self.should_run_genetic_algorithm = should_run_genetic_algorithm
|
424
|
+
|
372
425
|
@torch.no_grad()
|
373
426
|
# non-gradient optimization, at least, not on the individual level (taken care of by rl component)
|
374
427
|
def genetic_algorithm_step(
|
@@ -379,7 +432,12 @@ class LatentGenePool(Module):
|
|
379
432
|
"""
|
380
433
|
p - population
|
381
434
|
g - gene dimension
|
435
|
+
n - number of genes per individual
|
382
436
|
"""
|
437
|
+
|
438
|
+
if not self.should_run_genetic_algorithm(fitness):
|
439
|
+
return
|
440
|
+
|
383
441
|
assert self.num_latents > 1
|
384
442
|
|
385
443
|
genes = self.latents # the latents are the genes
|
@@ -403,7 +461,7 @@ class LatentGenePool(Module):
|
|
403
461
|
|
404
462
|
tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices
|
405
463
|
|
406
|
-
tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... g', g = self.dim_latent)
|
464
|
+
tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
|
407
465
|
|
408
466
|
parents = participants.gather(-2, tournament_winner_indices)
|
409
467
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.5
|
3
|
+
Version: 0.0.8
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -46,6 +46,8 @@ Provides-Extra: examples-gym
|
|
46
46
|
Requires-Dist: box2d-py; extra == 'examples-gym'
|
47
47
|
Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
|
48
48
|
Requires-Dist: tqdm; extra == 'examples-gym'
|
49
|
+
Provides-Extra: test
|
50
|
+
Requires-Dist: pytest; extra == 'test'
|
49
51
|
Description-Content-Type: text/markdown
|
50
52
|
|
51
53
|
<img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
|
@@ -56,7 +58,9 @@ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.
|
|
56
58
|
|
57
59
|
This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.
|
58
60
|
|
59
|
-
Besides their latent variable
|
61
|
+
Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
|
62
|
+
|
63
|
+
Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). This is also incidentally what I have concluded what Science is. I am in direct exposure to this phenomenon on a daily basis
|
60
64
|
|
61
65
|
## Usage
|
62
66
|
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=LA81Yi6o3EFbJZHkxx1vyBFZWvNqpZ9mGhEauLZu9Ig,15692
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.8.dist-info/METADATA,sha256=42kS9DROtA90mUCJhim940ysydx3apEerwNbNs1wj_A,4460
|
5
|
+
evolutionary_policy_optimization-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.8.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
|
5
|
-
evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,
|
File without changes
|