evolutionary-policy-optimization 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -3,7 +3,7 @@ from __future__ import annotations
 from collections import namedtuple
 
 import torch
-from torch import nn, cat
+from torch import nn, cat, is_tensor, tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 from torch.utils.data import TensorDataset, DataLoader
@@ -176,6 +176,8 @@ class MLP(Module):
         if latent.ndim == 1:
             latent = repeat(latent, 'd -> b d', b = batch)
 
+        assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
+
         x = cat((x, latent), dim = -1)
 
         # layers
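
The added assertion guards the concatenation that follows: the latent must carry one row per input state. A minimal sketch of that shape discipline with plain tensors and einops (illustrative only, not the package's MLP):

```python
import torch
from einops import repeat

batch, dim_state, dim_latent = 8, 16, 32

x = torch.randn(batch, dim_state)
latent = torch.randn(dim_latent)                     # a lone latent without a batch dimension

if latent.ndim == 1:
    latent = repeat(latent, 'd -> b d', b = batch)   # broadcast it over the batch

assert latent.shape[0] == x.shape[0]                 # the check the new assert enforces

x = torch.cat((x, latent), dim = -1)                 # conditioned input: (batch, dim_state + dim_latent)
```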
@@ -441,25 +443,38 @@ class LatentGenePool(Module):
         net: Module | None = None,
         **kwargs,
     ):
+        device = self.latents.device
 
         # if only 1 latent, assume doing ablation and get lone gene
 
         if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0
 
-        assert 0 <= latent_id < self.num_latents
+        if not is_tensor(latent_id):
+            latent_id = tensor(latent_id, device = device)
+
+        assert (0 <= latent_id).all() and (latent_id < self.num_latents).all()
 
         # fetch latent
 
+        fetching_multiple_latents = latent_id.numel() > 1
+
         latent = self.latents[latent_id]
 
         if self.needs_latent_gate:
             assert exists(state), 'state must be passed in if greater than number of 1 latent set'
 
+            if not fetching_multiple_latents:
+                latent = repeat(latent, '... -> b ...', b = state.shape[0])
+
+            assert latent.shape[0] == state.shape[0]
+
             gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'n g, b n -> b g')
+            latent = einsum(latent, gates, 'b n g, b n -> b g')
+
+        elif fetching_multiple_latents:
+            latent = latent[:, 0]
         else:
-            assert latent.shape[0] == 1
             latent = latent[0]
 
         if not exists(net):
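
The substantive change here: `latent_id` may now be a plain int or a tensor of ids (one per sample), it is tensorized and bounds-checked elementwise, and the gating einsum mixes latents per sample (`'b n g, b n -> b g'` rather than `'n g, b n -> b g'`). Below is a minimal sketch of that gating path, assuming a latent table of shape `(num_latents, num_latent_sets, dim_latent)` and a simple linear-softmax gate; both stand in for the module's own `self.latents` and `self.to_latent_gate` and are assumptions, not the package's definitions:

```python
import torch
from torch import nn, is_tensor, tensor
from einops import repeat, einsum

num_latents, num_latent_sets, dim_latent, dim_state, batch = 16, 3, 32, 24, 8

latents = torch.randn(num_latents, num_latent_sets, dim_latent)   # stand-in for self.latents
to_latent_gate = nn.Sequential(nn.Linear(dim_state, num_latent_sets), nn.Softmax(dim = -1))

state = torch.randn(batch, dim_state)
latent_id = 2                                   # could also be a tensor of ids, one per sample

if not is_tensor(latent_id):
    latent_id = tensor(latent_id, device = latents.device)

assert (0 <= latent_id).all() and (latent_id < num_latents).all()

latent = latents[latent_id]                     # (n, g) for one id, (b, n, g) for a batch of ids

if latent_id.numel() == 1:
    latent = repeat(latent, '... -> b ...', b = state.shape[0])    # broadcast the lone gene over the batch

gates = to_latent_gate(state)                                      # (b, n)
latent = einsum(latent, gates, 'b n g, b n -> b g')                # per-sample mix of the latent sets
```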
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.18
+Version: 0.0.20
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,10 +60,15 @@ Besides their latent variable strategy, I'll also throw in some attempts with cr
 
 Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
 
+## Install
+
+```bash
+$ pip install evolutionary-policy-optimization
+```
+
 ## Usage
 
 ```python
-
 import torch
 
 from evolutionary_policy_optimization import (
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=BTBqkgDq-x4dUMlKdSojvV2Yjzf9pDUZGMik32WjdHQ,18361
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.20.dist-info/METADATA,sha256=0QNTGATtchVuxVplbrfXAtupcrMKEQD-uisM7CFm7qE,4931
+evolutionary_policy_optimization-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.20.dist-info/RECORD,,
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=3pvYPwAEZdrxwwV95Ea1qG4CQjLnyaxAr40opk07LDw,17747
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.18.dist-info/METADATA,sha256=BIyCXw2IbMs-x2hDbFs9NR5s2dYEbfbeK_LadUeUc8Q,4860
-evolutionary_policy_optimization-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.18.dist-info/RECORD,,