evolutionary-policy-optimization 0.0.18-py3-none-any.whl → 0.0.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -3,7 +3,7 @@ from __future__ import annotations
  from collections import namedtuple

  import torch
- from torch import nn, cat
+ from torch import nn, cat, is_tensor, tensor
  import torch.nn.functional as F
  from torch.nn import Linear, Module, ModuleList
  from torch.utils.data import TensorDataset, DataLoader
@@ -176,6 +176,8 @@ class MLP(Module):
  if latent.ndim == 1:
      latent = repeat(latent, 'd -> b d', b = batch)

+ assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
+
  x = cat((x, latent), dim = -1)

  # layers
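
The assertion added above guards against a silent shape mismatch: a single gene vector is first broadcast across the batch, and the latent batch must then line up with the state batch before the two are concatenated. A minimal, standalone sketch of that broadcast-then-check step using plain tensors (not the package's `MLP` class; shapes are illustrative):

```python
import torch
from einops import repeat

batch = 4
x = torch.randn(batch, 512)   # hypothetical state features
latent = torch.randn(32)      # a single gene / latent vector

# a 1-dimensional latent is broadcast across the batch, as in the hunk above
if latent.ndim == 1:
    latent = repeat(latent, 'd -> b d', b = batch)

# the newly added guard: latent batch must match the state batch
assert latent.shape[0] == x.shape[0]

x = torch.cat((x, latent), dim = -1)  # (4, 544)
```
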
@@ -299,6 +301,7 @@ class LatentGenePool(Module):
  dim_latent, # gene dimension
  num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
  dim_state = None,
+ frozen_latents = True,
  crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
  l2norm_latent = False, # whether to enforce latents on hypersphere,
  frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -319,7 +322,7 @@ class LatentGenePool(Module):

  self.num_latents = num_latents
  self.needs_latent_gate = num_latent_sets > 1
- self.latents = nn.Parameter(latents, requires_grad = False)
+ self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)

  self.maybe_l2norm = maybe_l2norm

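
Together, these two hunks let the latent gene pool opt out of freezing its genes: by default `frozen_latents = True` keeps `requires_grad = False`, so the latents are updated only by the evolutionary operators, while `frozen_latents = False` exposes them to gradient descent as well. A hedged usage sketch, assuming `LatentGenePool` is importable from the package top level and that `num_latents` / `dim_latent` are accepted as keyword arguments (both names appear in the hunks, but the full constructor signature is not shown here):

```python
from evolutionary_policy_optimization import LatentGenePool  # assumed top-level export

# default behaviour: latents stay frozen and are evolved, not trained
pool = LatentGenePool(num_latents = 128, dim_latent = 32)
print(pool.latents.requires_grad)  # False

# new in this release: allow gradients to flow into the latents as well
trainable_pool = LatentGenePool(num_latents = 128, dim_latent = 32, frozen_latents = False)
print(trainable_pool.latents.requires_grad)  # True
```
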
@@ -441,27 +444,42 @@ class LatentGenePool(Module):
      net: Module | None = None,
      **kwargs,
  ):
+     device = self.latents.device

      # if only 1 latent, assume doing ablation and get lone gene

      if not exists(latent_id) and self.num_latents == 1:
          latent_id = 0

-     assert 0 <= latent_id < self.num_latents
+     if not is_tensor(latent_id):
+         latent_id = tensor(latent_id, device = device)
+
+     assert (0 <= latent_id).all() and (latent_id < self.num_latents).all()

      # fetch latent

+     fetching_multiple_latents = latent_id.numel() > 1
+
      latent = self.latents[latent_id]

      if self.needs_latent_gate:
          assert exists(state), 'state must be passed in if greater than number of 1 latent set'

+         if not fetching_multiple_latents:
+             latent = repeat(latent, '... -> b ...', b = state.shape[0])
+
+         assert latent.shape[0] == state.shape[0]
+
          gates = self.to_latent_gate(state)
-         latent = einsum(latent, gates, 'n g, b n -> b g')
+         latent = einsum(latent, gates, 'b n g, b n -> b g')
+
+     elif fetching_multiple_latents:
+         latent = latent[:, 0]
      else:
-         assert latent.shape[0] == 1
          latent = latent[0]

+     latent = self.maybe_l2norm(latent)
+
      if not exists(net):
          return latent

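
The reworked forward pass accepts either a single integer `latent_id` or a tensor of ids (for example one id per individual in a rollout batch): non-tensor ids are wrapped with `tensor(...)`, the bounds check is applied elementwise, and the selected latent(s) pass through `maybe_l2norm` before being returned or handed to `net`. A hedged sketch of the two calling styles, assuming the same constructor as above and a single latent set (no gating); the returned shapes are inferred from the `latent[0]` / `latent[:, 0]` branches shown in the hunk:

```python
import torch
from evolutionary_policy_optimization import LatentGenePool  # assumed top-level export

pool = LatentGenePool(num_latents = 128, dim_latent = 32)

# single integer id: wrapped into a tensor internally, bounds-checked, lone gene returned
one_latent = pool(latent_id = 3)       # expected shape: (dim_latent,)

# a tensor of ids, one per individual being evaluated
ids = torch.randint(0, 128, (16,))
many_latents = pool(latent_id = ids)   # expected shape: (16, dim_latent)
```
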
--- evolutionary_policy_optimization-0.0.18.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.22.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.18
+ Version: 0.0.22
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,10 +60,15 @@ Besides their latent variable strategy, I'll also throw in some attempts with cr

  Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)

+ ## Install
+
+ ```bash
+ $ pip install evolutionary-policy-optimization
+ ```
+
  ## Usage

  ```python
-
  import torch

  from evolutionary_policy_optimization import (
--- /dev/null
+++ evolutionary_policy_optimization-0.0.22.dist-info/RECORD
@@ -0,0 +1,7 @@
+ evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+ evolutionary_policy_optimization/epo.py,sha256=TbUX2L-Wa2zIZ2b7iHmBtaym-qDSLAFrC7iU7xReX_k,18449
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+ evolutionary_policy_optimization-0.0.22.dist-info/METADATA,sha256=L3G-tesSEyhrc_SbTN6HuJQlXfogEUvr3W9SXPcnRVw,4931
+ evolutionary_policy_optimization-0.0.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.22.dist-info/RECORD,,
--- evolutionary_policy_optimization-0.0.18.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
- evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
- evolutionary_policy_optimization/epo.py,sha256=3pvYPwAEZdrxwwV95Ea1qG4CQjLnyaxAr40opk07LDw,17747
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
- evolutionary_policy_optimization-0.0.18.dist-info/METADATA,sha256=BIyCXw2IbMs-x2hDbFs9NR5s2dYEbfbeK_LadUeUc8Q,4860
- evolutionary_policy_optimization-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.18.dist-info/RECORD,,