evolutionary-policy-optimization 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/__init__.py +3 -0
- evolutionary_policy_optimization/epo.py +59 -64
- {evolutionary_policy_optimization-0.0.9.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/METADATA +15 -12
- evolutionary_policy_optimization-0.0.11.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.9.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.9.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.9.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 
 from einops import rearrange, repeat, einsum
+from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
 
@@ -162,7 +163,6 @@ class MLP(Module):
         self,
         dims: tuple[int, ...],
         dim_latent = 0,
-        num_latent_sets = 1
     ):
         super().__init__()
         assert len(dims) >= 2, 'must have at least two dimensions'
@@ -172,24 +172,15 @@ class MLP(Module):
         first_dim, *rest_dims = dims
         dims = (first_dim + dim_latent, *rest_dims)
 
-        assert num_latent_sets >= 1
-
         self.dim_latent = dim_latent
-        self.num_latent_sets = num_latent_sets
 
         self.needs_latent = dim_latent > 0
-        self.needs_latent_gate = num_latent_sets > 1
 
         self.encode_latent = nn.Sequential(
             Linear(dim_latent, dim_latent),
             nn.SiLU()
         ) if self.needs_latent else None
 
-        self.to_latent_gate = nn.Sequential(
-            Linear(first_dim, num_latent_sets),
-            nn.Softmax(dim = -1)
-        ) if self.needs_latent_gate else None
-
         # pairs of dimension
 
         dim_pairs = tuple(zip(dims[:-1], dims[1:]))
@@ -209,15 +200,6 @@ class MLP(Module):
 
         assert xnor(self.needs_latent, exists(latent))
 
-        if exists(latent) and self.needs_latent_gate:
-            # an improvisation where set of genes with controlled expression by environment
-
-            gates = self.to_latent_gate(x)
-            latent = einsum(latent, gates, 'n g, b n -> b g')
-        else:
-            assert latent.shape[0] == 1
-            latent = latent[0]
-
         if exists(latent):
             # start with naive concatenative conditioning
             # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)
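The comment retained above describes the conditioning as naive concatenation of the latent onto the MLP input. A minimal standalone sketch of that idea, with made-up shapes and variable names:

```python
import torch

# hypothetical shapes: a batch of hidden features and a latent already broadcast to the batch
hiddens = torch.randn(2, 512)   # (batch, first_dim)
latent  = torch.randn(2, 32)    # (batch, dim_latent)

# concatenative conditioning: the next linear layer simply sees first_dim + dim_latent features
conditioned = torch.cat((hiddens, latent), dim = -1)   # (2, 544)
```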
@@ -247,7 +229,7 @@ class MLP(Module):
 class Actor(Module):
     def __init__(
         self,
-
+        dim_state,
         num_actions,
         dim_hiddens: tuple[int, ...],
         dim_latent = 0,
@@ -258,7 +240,7 @@ class Actor(Module):
         dim_first, *_, dim_last = dim_hiddens
 
         self.init_layer = nn.Sequential(
-            nn.Linear(
+            nn.Linear(dim_state, dim_first),
             nn.SiLU()
         )
 
@@ -277,14 +259,14 @@ class Actor(Module):
 
         hidden = self.init_layer(state)
 
-        hidden = self.mlp(
+        hidden = self.mlp(hidden, latent)
 
         return self.to_out(hidden)
 
 class Critic(Module):
     def __init__(
         self,
-
+        dim_state,
         dim_hiddens: tuple[int, ...],
         dim_latent = 0,
     ):
@@ -294,7 +276,7 @@ class Critic(Module):
         dim_first, *_, dim_last = dim_hiddens
 
         self.init_layer = nn.Sequential(
-            nn.Linear(
+            nn.Linear(dim_state, dim_first),
             nn.SiLU()
         )
 
@@ -314,27 +296,10 @@ class Critic(Module):
 
         hidden = self.init_layer(state)
 
-        hidden = self.mlp(
+        hidden = self.mlp(hidden, latent)
 
         return self.to_out(hidden)
 
-class Agent(Module):
-    def __init__(
-        self,
-        actor: Actor,
-        critic: Critic,
-    ):
-        super().__init__()
-
-        self.actor = actor
-        self.critic = critic
-
-    def forward(
-        self,
-        memories: list[Memory]
-    ):
-        raise NotImplementedError
-
 # criteria for running genetic algorithm
 
 class ShouldRunGeneticAlgorithm(Module):
@@ -362,13 +327,13 @@ class LatentGenePool(Module):
         num_latents,                  # same as gene pool size
         dim_latent,                   # gene dimension
         num_latent_sets = 1,          # allow for sets of latents / gene per individual, expression of a set controlled by the environment
+        dim_state = None,
         crossover_random = True,      # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,        # whether to enforce latents on hypersphere,
         frac_tournaments = 0.25,      # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1,           # frac of population to preserve from being noised
         mutation_strength = 1.,       # factor to multiply to gaussian noise as mutation to latents
-        net: MLP | Module | dict | None = None,
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5
     ):
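With `num_latent_sets` and the new `dim_state` argument handled by `LatentGenePool` itself (the `net` argument is gone), constructing a pool with environment-gated gene expression would presumably look like the sketch below; the sizes are illustrative, not taken from the package.

```python
import torch
from evolutionary_policy_optimization import LatentGenePool

latent_pool = LatentGenePool(
    num_latents = 128,     # gene pool size
    dim_latent = 32,       # gene dimension
    num_latent_sets = 4,   # several latents per individual, expression gated by the environment
    dim_state = 512        # required when num_latent_sets > 1, per the assert added below
)
```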
@@ -382,11 +347,23 @@ class LatentGenePool(Module):
         latents = maybe_l2norm(latents, dim = -1)
 
         self.num_latents = num_latents
-        self.
+        self.needs_latent_gate = num_latent_sets > 1
         self.latents = nn.Parameter(latents, requires_grad = False)
 
         self.maybe_l2norm = maybe_l2norm
 
+        # gene expression as a function of environment
+
+        self.num_latent_sets = num_latent_sets
+
+        if self.needs_latent_gate:
+            assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
+
+            self.to_latent_gate = nn.Sequential(
+                Linear(dim_state, num_latent_sets),
+                nn.Softmax(dim = -1)
+            ) if self.needs_latent_gate else None
+
         # some derived values
 
         assert 0. < frac_tournaments < 1.
@@ -405,22 +382,6 @@ class LatentGenePool(Module):
         self.num_elites = int(frac_elitism * num_latents)
         self.has_elites = self.num_elites > 0
 
-        # network for the latent / gene
-
-        if isinstance(net, dict):
-            assert 'dim_latent' not in net
-            assert 'num_latent_sets' not in net
-
-            net.update(dim_latent = dim_latent)
-            net.update(num_latent_sets = num_latent_sets)
-
-            net = MLP(**net)
-
-        assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
-        assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and and latent gene pool container'
-
-        self.net = net
-
         if not exists(should_run_genetic_algorithm):
             should_run_genetic_algorithm = ShouldRunGeneticAlgorithm(gamma = default_should_run_ga_gamma)
 
@@ -467,7 +428,7 @@ class LatentGenePool(Module):
 
         tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
 
-        parents = participants.gather(-
+        parents = participants.gather(-3, tournament_winner_indices)
 
         # 3. do a crossover of the parents - in their case they went for a simple averaging, but since we are doing tournament style and the same pair of parents may be re-selected, lets make it random interpolation
 
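The comment above frames crossover as a random interpolation between two tournament-selected parents rather than a plain average (the behavior toggled by `crossover_random`). A standalone sketch of that operation, with hypothetical tensor names and shapes:

```python
import torch

# two selected parents per child: (num_children, num_latent_sets, dim_latent)
parent1 = torch.randn(64, 1, 32)
parent2 = torch.randn(64, 1, 32)

# crossover_random = True: a uniform random mixing coefficient per element
# crossover_random = False would correspond to a constant 0.5, i.e. simple averaging
mix = torch.rand_like(parent1)
children = parent1.lerp(parent2, mix)
```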
@@ -504,12 +465,12 @@ class LatentGenePool(Module):
     def forward(
         self,
         *args,
+        state: Tensor | None = None,
         latent_id: int | None = None,
+        net: Module | None = None,
         **kwargs,
     ):
 
-        assert exists(self.net)
-
         # if only 1 latent, assume doing ablation and get lone gene
 
         if not exists(latent_id) and self.num_latents == 1:
@@ -521,12 +482,46 @@ class LatentGenePool(Module):
 
         latent = self.latents[latent_id]
 
-
+        if self.needs_latent_gate:
+            assert exists(state), 'state must be passed in if greater than number of 1 latent set'
+
+            gates = self.to_latent_gate(state)
+            latent = einsum(latent, gates, 'n g, b n -> b g')
+        else:
+            assert latent.shape[0] == 1
+            latent = latent[0]
+
+        if not exists(net):
+            return latent
+
+        return net(
             *args,
             latent = latent,
             **kwargs
         )
 
+# agent contains the actor, critic, and the latent genetic pool
+
+class Agent(Module):
+    def __init__(
+        self,
+        actor: Actor,
+        critic: Critic,
+        latent_gene_pool: LatentGenePool
+    ):
+        super().__init__()
+
+        self.actor = actor
+        self.critic = critic
+
+        self.latent_gene_pool = latent_gene_pool
+
+    def forward(
+        self,
+        memories: list[Memory]
+    ):
+        raise NotImplementedError
+
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
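To make the new gated gene expression concrete: a gene is now a set of `num_latent_sets` latents, `to_latent_gate` maps the state to a softmax over those sets, and the einsum above mixes the sets per batch element. A minimal standalone sketch of just that mixing step, with made-up sizes:

```python
import torch
import torch.nn as nn
from einops import einsum

num_latent_sets, dim_latent, dim_state, batch = 4, 32, 512, 2

latent = torch.randn(num_latent_sets, dim_latent)        # one gene: (n, g)

to_latent_gate = nn.Sequential(
    nn.Linear(dim_state, num_latent_sets),
    nn.Softmax(dim = -1)
)

gates = to_latent_gate(torch.randn(batch, dim_state))     # (b, n), rows sum to 1

mixed = einsum(latent, gates, 'n g, b n -> b g')          # (b, g), as in the forward above
```

Passing `net = actor` (or `critic`) to `LatentGenePool.forward` would then hand this mixed latent on as the `latent` keyword argument; with no `net`, the latent itself is returned, which is what the README example below relies on.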
{evolutionary_policy_optimization-0.0.9.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.9
+Version: 0.0.11
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,37 +60,40 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
+Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
 
 ## Usage
 
 ```python
+
 import torch
 
 from evolutionary_policy_optimization import (
     LatentGenePool,
-
+    Actor,
+    Critic
 )
 
 latent_pool = LatentGenePool(
-    num_latents = 
+    num_latents = 128,
     dim_latent = 32,
-    net = MLP(
-        dims = (512, 256),
-        dim_latent = 32,
-    )
 )
 
 state = torch.randn(1, 512)
-
+
+actor = Actor(512, dim_hiddens = (256, 128), num_actions = 4, dim_latent = 32)
+critic = Critic(512, dim_hiddens = (256, 128, 64), dim_latent = 32)
+
+latent = latent_pool(latent_id = 2)
+
+actions = actor(state, latent)
+value = critic(state, latent)
 
 # interact with environment and receive rewards, termination etc
 
 # derive a fitness score for each gene / latent
 
-fitness = torch.randn(
-
-latent_pool.genetic_algorithm_step(fitness) # update latents using one generation of genetic algorithm
+fitness = torch.randn(128)
 
 ```
 
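The 0.0.9 README closed this example by feeding the fitness back into the pool, and the genetic-algorithm machinery is still present in `epo.py` above (tournament selection, crossover, mutation), so the snippet would presumably still be continued along these lines:

```python
# update latents using one generation of the genetic algorithm
latent_pool.genetic_algorithm_step(fitness)
```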
evolutionary_policy_optimization-0.0.11.dist-info/RECORD

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
+evolutionary_policy_optimization/epo.py,sha256=JGow9ofx7IgFy7QNL0dL0K_SCL_bVkBUznMG8aSGM9Q,15591
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.11.dist-info/METADATA,sha256=fkouRBZU5nrPgHt0eT5izSHdOiYGAg67N5Gn3t039mQ,4357
+evolutionary_policy_optimization-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.11.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.9.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=eiOJg0J14miB3ZWpcTD1dMC7M6abxtVaMD_Oxza0cYI,15880
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.9.dist-info/METADATA,sha256=MT4_JXsUQCrcBWB-0m9uJZHYtGnSFMbQzclZ32HZKnQ,4460
-evolutionary_policy_optimization-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.9.dist-info/RECORD,,
Files without changes: WHEEL, licenses/LICENSE