evolutionary-policy-optimization 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/__init__.py

@@ -1,4 +1,7 @@
 from evolutionary_policy_optimization.epo import (
     MLP,
+    Actor,
+    Critic,
+    Agent,
     LatentGenePool
 )
evolutionary_policy_optimization/epo.py

@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 
 from einops import rearrange, repeat, einsum
+from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
 
@@ -277,7 +278,7 @@ class Actor(Module):
 
         hidden = self.init_layer(state)
 
-        hidden = self.mlp(state, latent)
+        hidden = self.mlp(hidden, latent)
 
         return self.to_out(hidden)
 
@@ -314,27 +315,10 @@ class Critic(Module):
 
         hidden = self.init_layer(state)
 
-        hidden = self.mlp(state, latent)
+        hidden = self.mlp(hidden, latent)
 
         return self.to_out(hidden)
 
-class Agent(Module):
-    def __init__(
-        self,
-        actor: Actor,
-        critic: Critic,
-    ):
-        super().__init__()
-
-        self.actor = actor
-        self.critic = critic
-
-    def forward(
-        self,
-        memories: list[Memory]
-    ):
-        raise NotImplementedError
-
 # criteria for running genetic algorithm
 
 class ShouldRunGeneticAlgorithm(Module):
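Both one-line fixes above (this hunk and the Actor hunk before it) address the same slip: the output of `init_layer` was computed and then thrown away, because `self.mlp` re-read the raw `state`. A minimal sketch of the corrected dataflow; `TinyNet` and all of its dims are hypothetical stand-ins, not the package's actual `Actor`/`Critic`:

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    # hypothetical stand-in, just to show the shape of the bug
    def __init__(self):
        super().__init__()
        self.init_layer = nn.Linear(512, 256)
        self.mlp = nn.Linear(256, 256)
        self.to_out = nn.Linear(256, 4)

    def forward(self, state):
        hidden = self.init_layer(state)
        # the 0.0.9 bug was self.mlp(state) here, which either crashes on the
        # shape mismatch or silently skips init_layer; 0.0.10 feeds `hidden`
        hidden = self.mlp(hidden)
        return self.to_out(hidden)

out = TinyNet()(torch.randn(1, 512))
assert out.shape == (1, 4)
```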
@@ -368,7 +352,6 @@ class LatentGenePool(Module):
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
-        net: MLP | Module | dict | None = None,
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5
     ):
@@ -405,22 +388,6 @@
         self.num_elites = int(frac_elitism * num_latents)
         self.has_elites = self.num_elites > 0
 
-        # network for the latent / gene
-
-        if isinstance(net, dict):
-            assert 'dim_latent' not in net
-            assert 'num_latent_sets' not in net
-
-            net.update(dim_latent = dim_latent)
-            net.update(num_latent_sets = num_latent_sets)
-
-            net = MLP(**net)
-
-        assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
-        assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and latent gene pool container'
-
-        self.net = net
-
         if not exists(should_run_genetic_algorithm):
             should_run_genetic_algorithm = ShouldRunGeneticAlgorithm(gamma = default_should_run_ga_gamma)
 
@@ -467,7 +434,7 @@
 
         tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
 
-        parents = participants.gather(-2, tournament_winner_indices)
+        parents = participants.gather(-3, tournament_winner_indices)
 
         # 3. do a crossover of the parents - in their case they went for a simple averaging, but since we are doing tournament style and the same pair of parents may be re-selected, lets make it random interpolation
 
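The `-2` → `-3` fix matters because `tournament_winner_indices` is repeated out to trailing `(num_latent_sets, dim_latent)` dims, so the tournament dimension that `gather` must index sits at `-3`, not `-2`. A standalone sketch with assumed, illustrative shapes (variable names mirror the diff):

```python
import torch

# assumed shapes mirroring the tournament selection above:
# participants: (num_children, tournament_size, num_latent_sets, dim_latent)
num_children, tournament_size, n, g = 4, 5, 2, 32
participants = torch.randn(num_children, tournament_size, n, g)

# two parents win each tournament; indices point into the tournament dim
winner_indices = torch.randint(0, tournament_size, (num_children, 2))

# expand to match the trailing (n, g) dims, as the diff does with einops.repeat
winner_indices = winner_indices[..., None, None].expand(-1, -1, n, g)

# gather must therefore index dim -3 (the tournament dim), not -2 (latent sets)
parents = participants.gather(-3, winner_indices)
assert parents.shape == (num_children, 2, n, g)
```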
@@ -505,11 +472,10 @@
         self,
         *args,
         latent_id: int | None = None,
+        net: Module | None = None,
         **kwargs,
     ):
 
-        assert exists(self.net)
-
         # if only 1 latent, assume doing ablation and get lone gene
 
         if not exists(latent_id) and self.num_latents == 1:
@@ -521,12 +487,37 @@
 
         latent = self.latents[latent_id]
 
-        return self.net(
+        if not exists(net):
+            return latent
+
+        return net(
             *args,
             latent = latent,
             **kwargs
         )
 
+# agent contains the actor, critic, and the latent genetic pool
+
+class Agent(Module):
+    def __init__(
+        self,
+        actor: Actor,
+        critic: Critic,
+        latent_gene_pool: LatentGenePool
+    ):
+        super().__init__()
+
+        self.actor = actor
+        self.critic = critic
+
+        self.latent_gene_pool = latent_gene_pool
+
+    def forward(
+        self,
+        memories: list[Memory]
+    ):
+        raise NotImplementedError
+
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
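Taken together with the two hunks above, this is the API rework of 0.0.10: `LatentGenePool` no longer owns a network; calling it with no `net` returns the selected gene, while passing `net` forwards to that module with `latent` injected, and the relocated `Agent` now bundles actor, critic, and pool. A sketch of the resulting call pattern, assuming the constructor signatures shown in the README diff below:

```python
import torch
from evolutionary_policy_optimization import Actor, LatentGenePool

latent_pool = LatentGenePool(num_latents = 128, dim_latent = 32)
actor = Actor(512, dim_hiddens = (256, 128), num_actions = 4, dim_latent = 32)

state = torch.randn(1, 512)

# without `net`, the pool simply hands back the selected gene / latent
latent = latent_pool(latent_id = 2)

# with `net`, the pool calls net(*args, latent = latent, **kwargs) itself
actions = latent_pool(state, latent_id = 2, net = actor)
```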
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.9
3
+ Version: 0.0.10
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -60,37 +60,40 @@ This paper stands out, as I have witnessed the positive effects first hand in an
 
 Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
 
-Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). This is also incidentally what I have concluded what Science is. I am in direct exposure to this phenomenon on a daily basis
+Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm)
 
 ## Usage
 
 ```python
+
 import torch
 
 from evolutionary_policy_optimization import (
     LatentGenePool,
-    MLP
+    Actor,
+    Critic
 )
 
 latent_pool = LatentGenePool(
-    num_latents = 32,
+    num_latents = 128,
     dim_latent = 32,
-    net = MLP(
-        dims = (512, 256),
-        dim_latent = 32,
-    )
 )
 
 state = torch.randn(1, 512)
-action = latent_pool(state, latent_id = 3) # use latent / gene 4
+
+actor = Actor(512, dim_hiddens = (256, 128), num_actions = 4, dim_latent = 32)
+critic = Critic(512, dim_hiddens = (256, 128, 64), dim_latent = 32)
+
+latent = latent_pool(latent_id = 2)
+
+actions = actor(state, latent)
+value = critic(state, latent)
 
 # interact with environment and receive rewards, termination etc
 
 # derive a fitness score for each gene / latent
 
-fitness = torch.randn(32)
-
-latent_pool.genetic_algorithm_step(fitness) # update latents using one generation of genetic algorithm
+fitness = torch.randn(128)
 
 ```
 
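The updated snippet now ends at the fitness tensor. `genetic_algorithm_step` itself is untouched by this diff, so presumably the evolution step from the 0.0.9 README still applies, resized to the 128-wide pool:

```python
# update latents with one generation of the genetic algorithm
# (the call removed from the README above, kept here for completeness)
latent_pool.genetic_algorithm_step(fitness)
```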
evolutionary_policy_optimization-0.0.10.dist-info/RECORD (added)

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
+evolutionary_policy_optimization/epo.py,sha256=66GOQq8_s5kmQI7G-2Z0J_0g4E5QarjQPJfWEP7mmKg,15442
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.10.dist-info/METADATA,sha256=bD3fw2Zw1IxhfkCvzjsRhODyL_XIC5ZsvNQqFbZXNc4,4357
+evolutionary_policy_optimization-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.10.dist-info/RECORD,,

evolutionary_policy_optimization-0.0.9.dist-info/RECORD (removed)

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=eiOJg0J14miB3ZWpcTD1dMC7M6abxtVaMD_Oxza0cYI,15880
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.9.dist-info/METADATA,sha256=MT4_JXsUQCrcBWB-0m9uJZHYtGnSFMbQzclZ32HZKnQ,4460
-evolutionary_policy_optimization-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.9.dist-info/RECORD,,