evolutionary-policy-optimization 0.0.18__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +23 -5
- {evolutionary_policy_optimization-0.0.18.dist-info → evolutionary_policy_optimization-0.0.22.dist-info}/METADATA +7 -2
- evolutionary_policy_optimization-0.0.22.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.18.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.18.dist-info → evolutionary_policy_optimization-0.0.22.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.18.dist-info → evolutionary_policy_optimization-0.0.22.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py:

```diff
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections import namedtuple
 
 import torch
-from torch import nn, cat
+from torch import nn, cat, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
```
```diff
@@ -176,6 +176,8 @@ class MLP(Module):
         if latent.ndim == 1:
             latent = repeat(latent, 'd -> b d', b = batch)
 
+        assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
+
         x = cat((x, latent), dim = -1)
 
         # layers
```
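The new assertion guards against a silent shape mismatch: a 1-D latent is broadcast across the batch, but an already-batched latent must agree with the state's batch dimension before concatenation. A minimal standalone sketch of that shape logic in plain PyTorch (the names `x` and `latent` are illustrative, not the package's `MLP` class):

```python
import torch
from einops import repeat

x = torch.randn(4, 16)   # state with batch size 4
latent = torch.randn(8)  # a single gene / latent vector

if latent.ndim == 1:
    # broadcast the lone latent over the batch
    latent = repeat(latent, 'd -> b d', b = x.shape[0])

# a batched latent must match the state's batch size
assert latent.shape[0] == x.shape[0]

# condition the state on the gene
x = torch.cat((x, latent), dim = -1)
print(x.shape)  # torch.Size([4, 24])
```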
```diff
@@ -299,6 +301,7 @@ class LatentGenePool(Module):
         dim_latent,            # gene dimension
         num_latent_sets = 1,   # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         dim_state = None,
+        frozen_latents = True,
         crossover_random = True,  # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,    # whether to enforce latents on hypersphere,
         frac_tournaments = 0.25,  # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
```
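A hedged sketch of how the new flag might be used when constructing the gene pool. Only the keyword names come from the hunk above; the import path points at `epo.py` where the class is defined, and the values are illustrative:

```python
from evolutionary_policy_optimization.epo import LatentGenePool

# illustrative construction, assuming the constructor keywords shown in the diff
latent_pool = LatentGenePool(
    num_latents = 32,         # population size
    dim_latent = 16,          # gene dimension
    frozen_latents = False    # new in 0.0.22: allow the genes to also receive gradients
)
```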
```diff
@@ -319,7 +322,7 @@ class LatentGenePool(Module):
 
         self.num_latents = num_latents
         self.needs_latent_gate = num_latent_sets > 1
-        self.latents = nn.Parameter(latents, requires_grad = False)
+        self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)
 
         self.maybe_l2norm = maybe_l2norm
 
```
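The flag is just the standard `requires_grad` toggle on an `nn.Parameter`: with the default `frozen_latents = True` the genes receive no gradients (presumably changing only through the pool's evolutionary operations), while `frozen_latents = False` lets a gradient optimizer update them as well. A minimal standalone illustration with hypothetical shapes, not the package's own code:

```python
import torch
from torch import nn

frozen_latents = True
latents = nn.Parameter(torch.randn(32, 16), requires_grad = not frozen_latents)

print(latents.requires_grad)  # False -> excluded from gradient updates

# an optimizer built over such parameters simply has nothing to train for frozen genes
trainable = [p for p in [latents] if p.requires_grad]
print(len(trainable))         # 0
```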
```diff
@@ -441,27 +444,42 @@ class LatentGenePool(Module):
         net: Module | None = None,
         **kwargs,
     ):
+        device = self.latents.device
 
         # if only 1 latent, assume doing ablation and get lone gene
 
         if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0
 
-        assert 0 <= latent_id < self.num_latents
+        if not is_tensor(latent_id):
+            latent_id = tensor(latent_id, device = device)
+
+        assert (0 <= latent_id).all() and (latent_id < self.num_latents).all()
 
         # fetch latent
 
+        fetching_multiple_latents = latent_id.numel() > 1
+
         latent = self.latents[latent_id]
 
         if self.needs_latent_gate:
             assert exists(state), 'state must be passed in if greater than number of 1 latent set'
 
+            if not fetching_multiple_latents:
+                latent = repeat(latent, '... -> b ...', b = state.shape[0])
+
+            assert latent.shape[0] == state.shape[0]
+
             gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'n g, b n -> b g')
+            latent = einsum(latent, gates, 'b n g, b n -> b g')
+
+        elif fetching_multiple_latents:
+            latent = latent[:, 0]
         else:
-            assert latent.shape[0] == 1
             latent = latent[0]
 
+        latent = self.maybe_l2norm(latent)
+
         if not exists(net):
             return latent
 
```
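Taken together, the forward-pass changes let `latent_id` be passed as a plain int, a list, or a tensor of ids: anything that is not already a tensor is wrapped with `tensor(...)` on the parameter's device, the bounds check becomes elementwise, and a single latent is repeated across the state batch before gating, which is why the einsum pattern gains a leading batch axis (`'n g, b n -> b g'` → `'b n g, b n -> b g'`). A hedged sketch that mirrors the new coercion and lookup in plain PyTorch, outside the class (shapes and id values are illustrative):

```python
import torch
from torch import is_tensor, tensor

num_latents = 32
latents = torch.randn(num_latents, 16)  # stands in for the pool's latent parameter

def fetch(latent_id):
    # mirrors the coercion added in 0.0.22: ints / lists are wrapped into a tensor first
    if not is_tensor(latent_id):
        latent_id = tensor(latent_id, device = latents.device)

    # elementwise bounds check, valid for a single id or a batch of ids
    assert (0 <= latent_id).all() and (latent_id < num_latents).all()
    return latents[latent_id]

print(fetch(3).shape)          # torch.Size([16])    - a single gene
print(fetch([0, 5, 9]).shape)  # torch.Size([3, 16]) - a batch of genes
```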
evolutionary_policy_optimization-0.0.22.dist-info/METADATA:

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.18
+Version: 0.0.22
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
````diff
@@ -60,10 +60,15 @@ Besides their latent variable strategy, I'll also throw in some attempts with cr
 
 Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
 
+## Install
+
+```bash
+$ pip install evolutionary-policy-optimization
+```
+
 ## Usage
 
 ```python
-
 import torch
 
 from evolutionary_policy_optimization import (
````
evolutionary_policy_optimization-0.0.22.dist-info/RECORD (added):

```diff
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=TbUX2L-Wa2zIZ2b7iHmBtaym-qDSLAFrC7iU7xReX_k,18449
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.22.dist-info/METADATA,sha256=L3G-tesSEyhrc_SbTN6HuJQlXfogEUvr3W9SXPcnRVw,4931
+evolutionary_policy_optimization-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.22.dist-info/RECORD,,
```
evolutionary_policy_optimization-0.0.18.dist-info/RECORD (removed):

```diff
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=3pvYPwAEZdrxwwV95Ea1qG4CQjLnyaxAr40opk07LDw,17747
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.18.dist-info/METADATA,sha256=BIyCXw2IbMs-x2hDbFs9NR5s2dYEbfbeK_LadUeUc8Q,4860
-evolutionary_policy_optimization-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.18.dist-info/RECORD,,
```
File without changes