evolutionary-policy-optimization 0.0.17__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections import namedtuple
 
 import torch
-from torch import nn, cat
+from torch import nn, cat, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
@@ -176,6 +176,8 @@ class MLP(Module):
         if latent.ndim == 1:
             latent = repeat(latent, 'd -> b d', b = batch)
 
+        assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
+
         x = cat((x, latent), dim = -1)
 
         # layers
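A quick, self-contained sketch of the shapes the new assert guards, with made-up dimensions (illustrative only, not code from the package): a single latent gene is broadcast across the batch, and the assert then requires the latent batch to match the state batch before concatenation.

```python
# Minimal sketch of the shape handling the new assert protects (dimensions are made up).
import torch
from einops import repeat

batch, dim_state, dim_latent = 4, 16, 8

x = torch.randn(batch, dim_state)    # per-sample states
latent = torch.randn(dim_latent)     # a single latent gene, no batch dim

if latent.ndim == 1:
    latent = repeat(latent, 'd -> b d', b = batch)  # broadcast latent across the batch

# the added check: latent batch must line up with the state batch
assert latent.shape[0] == x.shape[0]

x = torch.cat((x, latent), dim = -1)  # concat state and latent before the MLP layers
print(x.shape)                        # torch.Size([4, 24])
```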
@@ -206,6 +208,8 @@ class Actor(Module):
         assert len(dim_hiddens) >= 2
         dim_first, *_, dim_last = dim_hiddens
 
+        self.dim_latent = dim_latent
+
         self.init_layer = nn.Sequential(
             nn.Linear(dim_state, dim_first),
             nn.SiLU()
@@ -242,6 +246,8 @@ class Critic(Module):
         assert len(dim_hiddens) >= 2
         dim_first, *_, dim_last = dim_hiddens
 
+        self.dim_latent = dim_latent
+
         self.init_layer = nn.Sequential(
             nn.Linear(dim_state, dim_first),
             nn.SiLU()
@@ -437,25 +443,38 @@ class LatentGenePool(Module):
         net: Module | None = None,
         **kwargs,
     ):
+        device = self.latents.device
 
         # if only 1 latent, assume doing ablation and get lone gene
 
         if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0
 
-        assert 0 <= latent_id < self.num_latents
+        if not is_tensor(latent_id):
+            latent_id = tensor(latent_id, device = device)
+
+        assert (0 <= latent_id).all() and (latent_id < self.num_latents).all()
 
         # fetch latent
 
+        fetching_multiple_latents = latent_id.numel() > 1
+
         latent = self.latents[latent_id]
 
         if self.needs_latent_gate:
             assert exists(state), 'state must be passed in if greater than number of 1 latent set'
 
+            if not fetching_multiple_latents:
+                latent = repeat(latent, '... -> b ...', b = state.shape[0])
+
+            assert latent.shape[0] == state.shape[0]
+
             gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'n g, b n -> b g')
+            latent = einsum(latent, gates, 'b n g, b n -> b g')
+
+        elif fetching_multiple_latents:
+            latent = latent[:, 0]
         else:
-            assert latent.shape[0] == 1
             latent = latent[0]
 
         if not exists(net):
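The upshot of this hunk is that `latent_id` may now be a plain int or a tensor of per-sample ids, with the bounds check vectorized accordingly. A rough standalone sketch of that normalization, using made-up sizes and omitting the `to_latent_gate` gating branch:

```python
# Rough sketch of the new latent_id normalization (sizes are made up;
# the gating branch via to_latent_gate is omitted).
import torch
from torch import is_tensor, tensor

num_latents, dim_latent, batch = 8, 32, 4
latents = torch.randn(num_latents, dim_latent)   # stand-in for self.latents

def fetch(latent_id):
    device = latents.device
    if not is_tensor(latent_id):
        latent_id = tensor(latent_id, device = device)   # ints become 0-dim tensors

    # vectorized bounds check, works for a single id or a batch of ids
    assert (0 <= latent_id).all() and (latent_id < num_latents).all()

    return latents[latent_id]

print(fetch(3).shape)                                          # torch.Size([32])   - single id
print(fetch(torch.randint(0, num_latents, (batch,))).shape)    # torch.Size([4, 32]) - per-sample ids
```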
@@ -490,6 +509,8 @@ class Agent(Module):
 
         self.latent_gene_pool = latent_gene_pool
 
+        assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
+
         # optimizers
 
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
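This guard consumes the `self.dim_latent` attributes added to `Actor` and `Critic` above, so mismatched latent dimensions now fail at `Agent` construction rather than deep inside a forward pass. A tiny illustration with lightweight stand-ins (it does not construct the package's real classes):

```python
# Illustration of the new consistency guard using stand-in objects;
# the package's real Actor / Critic / LatentGenePool are not constructed here.
from types import SimpleNamespace

actor = SimpleNamespace(dim_latent = 32)
critic = SimpleNamespace(dim_latent = 32)
latent_gene_pool = SimpleNamespace(dim_latent = 64)   # deliberately mismatched

try:
    assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
except AssertionError:
    print('latent dims disagree - Agent construction would now fail fast')
```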
--- evolutionary_policy_optimization-0.0.17.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.20.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.17
+Version: 0.0.20
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -33,7 +33,7 @@ Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.8
+Requires-Python: >=3.9
 Requires-Dist: adam-atan2-pytorch
 Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.0
@@ -44,10 +44,6 @@ Provides-Extra: examples
 Requires-Dist: numpy; extra == 'examples'
 Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
-Provides-Extra: examples-gym
-Requires-Dist: box2d-py; extra == 'examples-gym'
-Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
-Requires-Dist: tqdm; extra == 'examples-gym'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
@@ -64,10 +60,15 @@ Besides their latent variable strategy, I'll also throw in some attempts with cr
 
 Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
 
+## Install
+
+```bash
+$ pip install evolutionary-policy-optimization
+```
+
 ## Usage
 
 ```python
-
 import torch
 
 from evolutionary_policy_optimization import (
--- /dev/null
+++ evolutionary_policy_optimization-0.0.20.dist-info/RECORD
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=BTBqkgDq-x4dUMlKdSojvV2Yjzf9pDUZGMik32WjdHQ,18361
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.20.dist-info/METADATA,sha256=0QNTGATtchVuxVplbrfXAtupcrMKEQD-uisM7CFm7qE,4931
+evolutionary_policy_optimization-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.20.dist-info/RECORD,,
--- evolutionary_policy_optimization-0.0.17.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=U1iROmPdJjU_tqd50XtBUibfOHtYUE7MzfPu-6bU2Pw,17586
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.17.dist-info/METADATA,sha256=okvM0b28MQBex5XUXVWwflYcf7hqG3I5dAh8PxWGhrM,5047
-evolutionary_policy_optimization-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.17.dist-info/RECORD,,