evolutionary-policy-optimization 0.0.5__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ import torch.nn.functional as F
9
9
  import torch.nn.functional as F
10
10
  from torch.nn import Linear, Module, ModuleList
11
11
 
12
- from einops import rearrange, repeat
12
+ from einops import rearrange, repeat, einsum
13
13
 
14
14
  from assoc_scan import AssocScan
15
15
 
@@ -162,6 +162,7 @@ class MLP(Module):
162
162
  self,
163
163
  dims: tuple[int, ...],
164
164
  dim_latent = 0,
165
+ num_latent_sets = 1
165
166
  ):
166
167
  super().__init__()
167
168
  assert len(dims) >= 2, 'must have at least two dimensions'
@@ -169,17 +170,26 @@ class MLP(Module):
169
170
  # add the latent to the first dim
170
171
 
171
172
  first_dim, *rest_dims = dims
172
- first_dim += dim_latent
173
- dims = (first_dim, *rest_dims)
173
+ dims = (first_dim + dim_latent, *rest_dims)
174
+
175
+ assert num_latent_sets >= 1
174
176
 
175
177
  self.dim_latent = dim_latent
178
+ self.num_latent_sets = num_latent_sets
179
+
176
180
  self.needs_latent = dim_latent > 0
181
+ self.needs_latent_gate = num_latent_sets > 1
177
182
 
178
183
  self.encode_latent = nn.Sequential(
179
184
  Linear(dim_latent, dim_latent),
180
185
  nn.SiLU()
181
186
  ) if self.needs_latent else None
182
187
 
188
+ self.to_latent_gate = nn.Sequential(
189
+ Linear(first_dim, num_latent_sets),
190
+ nn.Softmax(dim = -1)
191
+ ) if self.needs_latent_gate else None
192
+
183
193
  # pairs of dimension
184
194
 
185
195
  dim_pairs = tuple(zip(dims[:-1], dims[1:]))
@@ -195,16 +205,27 @@ class MLP(Module):
195
205
  x,
196
206
  latent = None
197
207
  ):
208
+ batch = x.shape[0]
209
+
198
210
  assert xnor(self.needs_latent, exists(latent))
199
211
 
212
+ if exists(latent) and self.needs_latent_gate:
213
+ # an improvisation: a set of genes per individual, whose expression is controlled by the environment
214
+
215
+ gates = self.to_latent_gate(x)
216
+ latent = einsum(latent, gates, 'n g, b n -> b g')
217
+ else:
218
+ assert latent.shape[0] == 1
219
+ latent = latent[0]
220
+
200
221
  if exists(latent):
201
222
  # start with naive concatenative conditioning
202
223
  # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)
203
224
 
204
- batch = x.shape[0]
205
-
206
225
  latent = self.encode_latent(latent)
207
- latent = repeat(latent, 'd -> b d', b = batch)
226
+
227
+ if latent.ndim == 1:
228
+ latent = repeat(latent, 'd -> b d', b = batch)
208
229
 
209
230
  x = cat((x, latent), dim = -1)
210
231
 
@@ -314,6 +335,25 @@ class Agent(Module):
314
335
  ):
315
336
  raise NotImplementedError
316
337
 
338
+ # criteria for running genetic algorithm
339
+
340
+ class ShouldRunGeneticAlgorithm(Module):
341
+ def __init__(
342
+ self,
343
+ gamma = 2. # not sure what the value is
344
+ ):
345
+ super().__init__()
346
+ self.gamma = gamma
347
+
348
+ def forward(self, fitnesses):
349
+ # equation (3)
350
+
351
+ # max(fitness) - min(fitness) > gamma * median(fitness)
352
+ # however, this equation does not make much sense to me if fitness increases unbounded
353
+ # just let it be customizable, and offer a variant where mean and variance are over some threshold (could account for skew too)
354
+
355
+ return (fitnesses.amax() - fitnesses.amin()) > (self.gamma * torch.median(fitnesses))
356
+
317
357
  # classes
318
358
 
319
359
  class LatentGenePool(Module):
@@ -321,6 +361,7 @@ class LatentGenePool(Module):
321
361
  self,
322
362
  num_latents, # same as gene pool size
323
363
  dim_latent, # gene dimension
364
+ num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
324
365
  crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
325
366
  l2norm_latent = False, # whether to enforce latents on hypersphere,
326
367
  frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -328,17 +369,19 @@ class LatentGenePool(Module):
328
369
  frac_elitism = 0.1, # frac of population to preserve from being noised
329
370
  mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
330
371
  net: MLP | Module | dict | None = None,
372
+ should_run_genetic_algorithm: Module = ShouldRunGeneticAlgorithm() # eq (3) in paper
331
373
  ):
332
374
  super().__init__()
333
375
 
334
376
  maybe_l2norm = l2norm if l2norm_latent else identity
335
377
 
336
- latents = torch.randn(num_latents, dim_latent)
378
+ latents = torch.randn(num_latents, num_latent_sets, dim_latent)
337
379
 
338
380
  if l2norm_latent:
339
381
  latents = maybe_l2norm(latents, dim = -1)
340
382
 
341
383
  self.num_latents = num_latents
384
+ self.num_latent_sets = num_latent_sets
342
385
  self.latents = nn.Parameter(latents, requires_grad = False)
343
386
 
344
387
  self.maybe_l2norm = maybe_l2norm
@@ -364,11 +407,21 @@ class LatentGenePool(Module):
364
407
  # network for the latent / gene
365
408
 
366
409
  if isinstance(net, dict):
410
+ assert 'dim_latent' not in net
411
+ assert 'num_latent_sets' not in net
412
+
413
+ net.update(dim_latent = dim_latent)
414
+ net.update(num_latent_sets = num_latent_sets)
415
+
367
416
  net = MLP(**net)
368
417
 
369
418
  assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
419
+ assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and and latent gene pool container'
420
+
370
421
  self.net = net
371
422
 
423
+ self.should_run_genetic_algorithm = should_run_genetic_algorithm
424
+
372
425
  @torch.no_grad()
373
426
  # non-gradient optimization, at least, not on the individual level (taken care of by rl component)
374
427
  def genetic_algorithm_step(
@@ -379,7 +432,12 @@ class LatentGenePool(Module):
379
432
  """
380
433
  p - population
381
434
  g - gene dimension
435
+ n - number of genes per individual
382
436
  """
437
+
438
+ if not self.should_run_genetic_algorithm(fitness):
439
+ return
440
+
383
441
  assert self.num_latents > 1
384
442
 
385
443
  genes = self.latents # the latents are the genes
@@ -403,7 +461,7 @@ class LatentGenePool(Module):
403
461
 
404
462
  tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices
405
463
 
406
- tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... g', g = self.dim_latent)
464
+ tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
407
465
 
408
466
  parents = participants.gather(-2, tournament_winner_indices)
409
467
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.5
3
+ Version: 0.0.8
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -46,6 +46,8 @@ Provides-Extra: examples-gym
46
46
  Requires-Dist: box2d-py; extra == 'examples-gym'
47
47
  Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
48
48
  Requires-Dist: tqdm; extra == 'examples-gym'
49
+ Provides-Extra: test
50
+ Requires-Dist: pytest; extra == 'test'
49
51
  Description-Content-Type: text/markdown
50
52
 
51
53
  <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -56,7 +58,9 @@ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.
56
58
 
57
59
  This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.
58
60
 
59
- Besides their latent variable method, I'll also throw in some attempts with crossover in weight space
61
+ Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
62
+
63
+ Update: I see, mixing genetic algorithms with gradient-based methods is already a research field, known as [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). This is also, incidentally, what I have concluded Science is. I am directly exposed to this phenomenon on a daily basis.
60
64
 
61
65
  ## Usage
62
66
 
@@ -0,0 +1,7 @@
1
+ evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
2
+ evolutionary_policy_optimization/epo.py,sha256=LA81Yi6o3EFbJZHkxx1vyBFZWvNqpZ9mGhEauLZu9Ig,15692
3
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
+ evolutionary_policy_optimization-0.0.8.dist-info/METADATA,sha256=42kS9DROtA90mUCJhim940ysydx3apEerwNbNs1wj_A,4460
5
+ evolutionary_policy_optimization-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ evolutionary_policy_optimization-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ evolutionary_policy_optimization-0.0.8.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
2
- evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
3
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
- evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
5
- evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,