evolutionary-policy-optimization 0.0.17-py3-none-any.whl → 0.0.20-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- evolutionary_policy_optimization/epo.py +25 -4
- {evolutionary_policy_optimization-0.0.17.dist-info → evolutionary_policy_optimization-0.0.20.dist-info}/METADATA +8 -7
- evolutionary_policy_optimization-0.0.20.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.17.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.17.dist-info → evolutionary_policy_optimization-0.0.20.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.17.dist-info → evolutionary_policy_optimization-0.0.20.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections import namedtuple
 
 import torch
-from torch import nn, cat
+from torch import nn, cat, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
```
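The two new imports back a behavior change in `LatentGenePool` (see the hunk at old line 437 below): `latent_id` may now arrive as a plain Python int or as a tensor of ids. A minimal sketch of the coercion this enables, using only standard `torch` calls:

```python
import torch
from torch import is_tensor, tensor

latent_id = 3                      # a plain int is now acceptable...

if not is_tensor(latent_id):
    latent_id = tensor(latent_id)  # ...coerced to a 0-dim integer tensor

assert latent_id.numel() == 1      # i.e. a single latent was requested
```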
```diff
@@ -176,6 +176,8 @@ class MLP(Module):
         if latent.ndim == 1:
             latent = repeat(latent, 'd -> b d', b = batch)
 
+        assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
+
         x = cat((x, latent), dim = -1)
 
         # layers
```
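The added assertion documents the shape contract enforced just before the concatenation: the latent batch must match the state batch. A self-contained sketch of that contract (all dimensions here are made up for illustration):

```python
import torch
from einops import repeat

batch, dim_state, dim_latent = 4, 512, 32

x = torch.randn(batch, dim_state)   # a batch of states
latent = torch.randn(dim_latent)    # a single gene, no batch dimension

if latent.ndim == 1:
    latent = repeat(latent, 'd -> b d', b = batch)  # broadcast over the batch

assert latent.shape[0] == x.shape[0]  # the check the diff introduces

x = torch.cat((x, latent), dim = -1)  # (batch, dim_state + dim_latent)
```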
```diff
@@ -206,6 +208,8 @@ class Actor(Module):
         assert len(dim_hiddens) >= 2
         dim_first, *_, dim_last = dim_hiddens
 
+        self.dim_latent = dim_latent
+
         self.init_layer = nn.Sequential(
             nn.Linear(dim_state, dim_first),
             nn.SiLU()
```
```diff
@@ -242,6 +246,8 @@ class Critic(Module):
         assert len(dim_hiddens) >= 2
         dim_first, *_, dim_last = dim_hiddens
 
+        self.dim_latent = dim_latent
+
         self.init_layer = nn.Sequential(
             nn.Linear(dim_state, dim_first),
             nn.SiLU()
```
```diff
@@ -437,25 +443,38 @@ class LatentGenePool(Module):
         net: Module | None = None,
         **kwargs,
     ):
+        device = self.latents.device
 
         # if only 1 latent, assume doing ablation and get lone gene
 
         if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0
 
-        assert 0 <= latent_id < self.num_latents
+        if not is_tensor(latent_id):
+            latent_id = tensor(latent_id, device = device)
+
+        assert (0 <= latent_id).all() and (latent_id < self.num_latents).all()
 
         # fetch latent
 
+        fetching_multiple_latents = latent_id.numel() > 1
+
         latent = self.latents[latent_id]
 
         if self.needs_latent_gate:
             assert exists(state), 'state must be passed in if greater than number of 1 latent set'
 
+            if not fetching_multiple_latents:
+                latent = repeat(latent, '... -> b ...', b = state.shape[0])
+
+            assert latent.shape[0] == state.shape[0]
+
             gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'n g, b n -> b g')
+            latent = einsum(latent, gates, 'b n g, b n -> b g')
+
+        elif fetching_multiple_latents:
+            latent = latent[:, 0]
         else:
-            assert latent.shape[0] == 1
             latent = latent[0]
 
         if not exists(net):
```
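Two things change in this hunk. First, `latent_id` is normalized to a tensor on the pool's device, so the range check works for both a single int and a batch of ids. Second, the gating einsum moves from one shared latent stack (`'n g, b n -> b g'`) to per-sample latent stacks (`'b n g, b n -> b g'`), with a lone latent repeated across the batch first. A standalone sketch of the batched mixing, with illustrative shapes:

```python
import torch
from einops import einsum, repeat

b, n, g = 4, 8, 32                    # batch, number of latents, latent dim

latent = torch.randn(n, g)            # shared stack of n latents
latent = repeat(latent, '... -> b ...', b = b)      # (b, n, g), as in the diff

gates = torch.softmax(torch.randn(b, n), dim = -1)  # per-sample mixing weights

mixed = einsum(latent, gates, 'b n g, b n -> b g')  # weighted sum over n

# identical to the plain torch.einsum formulation
assert torch.allclose(mixed, torch.einsum('bng,bn->bg', latent, gates), atol = 1e-6)
```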
```diff
@@ -490,6 +509,8 @@ class Agent(Module):
 
         self.latent_gene_pool = latent_gene_pool
 
+        assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
+
         # optimizers
 
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
```
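This guard makes `Agent` fail at construction when its components disagree on latent width, rather than deep inside a forward pass. A hedged sketch of the invariant; apart from `dim_latent` (confirmed by the hunks above), the constructor arguments are hypothetical placeholders, not the library's exact signatures:

```python
from evolutionary_policy_optimization import Actor, Critic, LatentGenePool

DIM_LATENT = 32

# hypothetical arguments; only dim_latent is confirmed by this diff
actor  = Actor(dim_state = 512, dim_hiddens = (256, 128), num_actions = 5, dim_latent = DIM_LATENT)
critic = Critic(dim_state = 512, dim_hiddens = (256, 128), dim_latent = DIM_LATENT)
pool   = LatentGenePool(num_latents = 128, dim_latent = DIM_LATENT)

# the invariant 0.0.20 now asserts inside Agent.__init__
assert actor.dim_latent == critic.dim_latent == pool.dim_latent
```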
```diff
--- a/evolutionary_policy_optimization-0.0.17.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.20.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.17
+Version: 0.0.20
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
```diff
@@ -33,7 +33,7 @@ Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.8
+Requires-Python: >=3.9
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.0
```
```diff
@@ -44,10 +44,6 @@ Provides-Extra: examples
 Requires-Dist: numpy; extra == 'examples'
 Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
-Provides-Extra: examples-gym
-Requires-Dist: box2d-py; extra == 'examples-gym'
-Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
-Requires-Dist: tqdm; extra == 'examples-gym'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
```
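With the `examples-gym` extra gone, the extras still declared are `examples` and `test`. They install the standard way (shell sketch, mirroring the README's install block):

```bash
$ pip install 'evolutionary-policy-optimization[examples]'  # numpy, pufferlib, tqdm
$ pip install 'evolutionary-policy-optimization[test]'      # pytest
```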
````diff
@@ -64,10 +60,15 @@ Besides their latent variable strategy, I'll also throw in some attempts with cr
 
 Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
 
+## Install
+
+```bash
+$ pip install evolutionary-policy-optimization
+```
+
 ## Usage
 
 ```python
-
 import torch
 
 from evolutionary_policy_optimization import (
````
```diff
--- /dev/null
+++ b/evolutionary_policy_optimization-0.0.20.dist-info/RECORD
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=BTBqkgDq-x4dUMlKdSojvV2Yjzf9pDUZGMik32WjdHQ,18361
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.20.dist-info/METADATA,sha256=0QNTGATtchVuxVplbrfXAtupcrMKEQD-uisM7CFm7qE,4931
+evolutionary_policy_optimization-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.20.dist-info/RECORD,,
```
```diff
--- a/evolutionary_policy_optimization-0.0.17.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=U1iROmPdJjU_tqd50XtBUibfOHtYUE7MzfPu-6bU2Pw,17586
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.17.dist-info/METADATA,sha256=okvM0b28MQBex5XUXVWwflYcf7hqG3I5dAh8PxWGhrM,5047
-evolutionary_policy_optimization-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.17.dist-info/RECORD,,
```
The WHEEL and licenses/LICENSE files are unchanged between the two versions.