evolutionary-policy-optimization 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +19 -4
- {evolutionary_policy_optimization-0.0.18.dist-info → evolutionary_policy_optimization-0.0.20.dist-info}/METADATA +7 -2
- evolutionary_policy_optimization-0.0.20.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.18.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.18.dist-info → evolutionary_policy_optimization-0.0.20.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.18.dist-info → evolutionary_policy_optimization-0.0.20.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
from collections import namedtuple
|
4
4
|
|
5
5
|
import torch
|
6
|
-
from torch import nn, cat
|
6
|
+
from torch import nn, cat, is_tensor, tensor
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from torch.nn import Linear, Module, ModuleList
|
9
9
|
from torch.utils.data import TensorDataset, DataLoader
|
@@ -176,6 +176,8 @@ class MLP(Module):
|
|
176
176
|
if latent.ndim == 1:
|
177
177
|
latent = repeat(latent, 'd -> b d', b = batch)
|
178
178
|
|
179
|
+
assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
|
180
|
+
|
179
181
|
x = cat((x, latent), dim = -1)
|
180
182
|
|
181
183
|
# layers
|
@@ -441,25 +443,38 @@ class LatentGenePool(Module):
|
|
441
443
|
net: Module | None = None,
|
442
444
|
**kwargs,
|
443
445
|
):
|
446
|
+
device = self.latents.device
|
444
447
|
|
445
448
|
# if only 1 latent, assume doing ablation and get lone gene
|
446
449
|
|
447
450
|
if not exists(latent_id) and self.num_latents == 1:
|
448
451
|
latent_id = 0
|
449
452
|
|
450
|
-
|
453
|
+
if not is_tensor(latent_id):
|
454
|
+
latent_id = tensor(latent_id, device = device)
|
455
|
+
|
456
|
+
assert (0 <= latent_id).all() and (latent_id < self.num_latents).all()
|
451
457
|
|
452
458
|
# fetch latent
|
453
459
|
|
460
|
+
fetching_multiple_latents = latent_id.numel() > 1
|
461
|
+
|
454
462
|
latent = self.latents[latent_id]
|
455
463
|
|
456
464
|
if self.needs_latent_gate:
|
457
465
|
assert exists(state), 'state must be passed in if greater than number of 1 latent set'
|
458
466
|
|
467
|
+
if not fetching_multiple_latents:
|
468
|
+
latent = repeat(latent, '... -> b ...', b = state.shape[0])
|
469
|
+
|
470
|
+
assert latent.shape[0] == state.shape[0]
|
471
|
+
|
459
472
|
gates = self.to_latent_gate(state)
|
460
|
-
latent = einsum(latent, gates, 'n g, b n -> b g')
|
473
|
+
latent = einsum(latent, gates, 'b n g, b n -> b g')
|
474
|
+
|
475
|
+
elif fetching_multiple_latents:
|
476
|
+
latent = latent[:, 0]
|
461
477
|
else:
|
462
|
-
assert latent.shape[0] == 1
|
463
478
|
latent = latent[0]
|
464
479
|
|
465
480
|
if not exists(net):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.18
|
3
|
+
Version: 0.0.20
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -60,10 +60,15 @@ Besides their latent variable strategy, I'll also throw in some attempts with cr
|
|
60
60
|
|
61
61
|
Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
|
62
62
|
|
63
|
+
## Install
|
64
|
+
|
65
|
+
```bash
|
66
|
+
$ pip install evolutionary-policy-optimization
|
67
|
+
```
|
68
|
+
|
63
69
|
## Usage
|
64
70
|
|
65
71
|
```python
|
66
|
-
|
67
72
|
import torch
|
68
73
|
|
69
74
|
from evolutionary_policy_optimization import (
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=BTBqkgDq-x4dUMlKdSojvV2Yjzf9pDUZGMik32WjdHQ,18361
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.20.dist-info/METADATA,sha256=0QNTGATtchVuxVplbrfXAtupcrMKEQD-uisM7CFm7qE,4931
|
5
|
+
evolutionary_policy_optimization-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.20.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=3pvYPwAEZdrxwwV95Ea1qG4CQjLnyaxAr40opk07LDw,17747
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.18.dist-info/METADATA,sha256=BIyCXw2IbMs-x2hDbFs9NR5s2dYEbfbeK_LadUeUc8Q,4860
|
5
|
-
evolutionary_policy_optimization-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.18.dist-info/RECORD,,
|
File without changes
|