evolutionary-policy-optimization 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl

evolutionary_policy_optimization/epo.py:

@@ -303,7 +303,6 @@ class LatentGenePool(Module):
         self,
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
-        num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
         dim_state = None,
         frozen_latents = True,
@@ -312,6 +311,7 @@ class LatentGenePool(Module):
         frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
         frac_natural_selected = 0.25, # number of least fit genes to remove from the pool
         frac_elitism = 0.1, # frac of population to preserve from being noised
+        frac_migrate = 0.1, # frac of population, excluding elites, that migrate between islands randomly. will use a designated set migration pattern (since for some reason using random it seems to be worse for me)
         mutation_strength = 1., # factor to multiply to gaussian noise as mutation to latents
         should_run_genetic_algorithm: Module | None = None, # eq (3) in paper
         default_should_run_ga_gamma = 1.5
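Taken together with the first hunk, the constructor now carries one flat gene per individual and gains a `frac_migrate` fraction. Below is a minimal sketch of constructing the pool against the 0.0.28 signature; the import path and the concrete numbers are illustrative assumptions, not taken from this diff.

```python
import torch
from evolutionary_policy_optimization import LatentGenePool  # assumed export path

pool = LatentGenePool(
    num_latents = 128,   # gene pool size
    dim_latent = 32,     # gene dimension
    num_islands = 4,     # island strategy
    frac_migrate = 0.1   # new in 0.0.28: fraction of the pool, excluding elites, migrated between islands
)

fitness = torch.randn(128)            # stand-in fitness, one score per gene
pool.genetic_algorithm_step(fitness)  # evolve the pool one step
```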
@@ -320,29 +320,17 @@ class LatentGenePool(Module):
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
-        latents = torch.randn(num_latents, num_latent_sets, dim_latent)
+        latents = torch.randn(num_latents, dim_latent)
 
         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)
 
         self.num_latents = num_latents
-        self.needs_latent_gate = num_latent_sets > 1
+        self.frozen_latents = frozen_latents
         self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)
 
         self.maybe_l2norm = maybe_l2norm
 
-        # gene expression as a function of environment
-
-        self.num_latent_sets = num_latent_sets
-
-        if self.needs_latent_gate:
-            assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
-
-        self.to_latent_gate = nn.Sequential(
-            Linear(dim_state, num_latent_sets),
-            nn.Softmax(dim = -1)
-        ) if self.needs_latent_gate else None
-
         # some derived values
 
         assert num_islands >= 1
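For reference, the machinery deleted here gated "gene expression" on the environment: a linear layer plus softmax turned the state into per-set weights, and the latent sets were mixed accordingly. A standalone sketch of roughly what was removed (shapes are illustrative; this is not part of 0.0.28):

```python
import torch
from torch import nn
from einops import einsum

dim_state, num_latent_sets, dim_latent = 16, 4, 32

# the removed gate: project state to softmax weights, one weight per latent set
to_latent_gate = nn.Sequential(
    nn.Linear(dim_state, num_latent_sets),
    nn.Softmax(dim = -1)
)

state = torch.randn(2, dim_state)                       # batch of environment states
latent = torch.randn(2, num_latent_sets, dim_latent)    # latent sets per batch element

gates = to_latent_gate(state)                           # (batch, num_latent_sets)
expressed = einsum(latent, gates, 'b n g, b n -> b g')  # mix sets into one expressed gene
print(expressed.shape)                                  # torch.Size([2, 32])
```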
@@ -361,12 +349,16 @@ class LatentGenePool(Module):
         self.num_natural_selected = int(frac_natural_selected * latents_per_island)
 
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
+
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
         self.num_elites = int(frac_elitism * latents_per_island)
         self.has_elites = self.num_elites > 0
 
+        latents_without_elites = num_latents - self.num_elites
+        self.num_migrate = int(frac_migrate * latents_without_elites)
+
         if not exists(should_run_genetic_algorithm):
             should_run_genetic_algorithm = ShouldRunGeneticAlgorithm(gamma = default_should_run_ga_gamma)
 
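Plugging the defaults from the hunks above into an assumed pool of 128 latents spread over 4 islands, the derived counts work out as follows. Note that `num_migrate` is computed from the whole pool net of elites, not per island:

```python
num_latents, num_islands = 128, 4
latents_per_island = num_latents // num_islands                  # 32

num_natural_selected = int(0.25 * latents_per_island)            # 8 survivors per island
num_tournament_participants = int(0.25 * num_natural_selected)   # 2 genes per tournament
num_elites = int(0.1 * latents_per_island)                       # 3 elites shielded from mutation

latents_without_elites = num_latents - num_elites                # 125
num_migrate = int(0.1 * latents_without_elites)                  # 12 genes migrate per migration event
```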
@@ -378,7 +370,6 @@ class LatentGenePool(Module):
         beta0 = 2., # exploitation factor, moving fireflies of low light intensity to high
         gamma = 1., # controls light intensity decay over distance - setting this to zero will make firefly equivalent to vanilla PSO
         alpha = 0.1, # exploration factor
-        alpha_decay = 0.995, # exploration decay each step
         inplace = True,
     ):
         islands = self.num_islands
@@ -424,8 +415,11 @@ class LatentGenePool(Module):
     def genetic_algorithm_step(
         self,
         fitness, # Float['p'],
-        inplace = True
+        inplace = True,
+        migrate = False # trigger a migration in the setting of multiple islands, the loop outside will need to have some `migrate_every` hyperparameter
     ):
+        device = self.latents.device
+
         """
         i - islands
         p - population
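The comment on `migrate` says the schedule belongs to the caller. A minimal sketch of such an outer loop, reusing `pool` from the earlier sketch; `migrate_every` and the random fitness are stand-ins, not part of the library:

```python
import torch

migrate_every = 10  # hypothetical hyperparameter, owned by the training loop

for step in range(100):
    fitness = torch.randn(128)  # stand-in for a real fitness evaluation, one score per gene

    pool.genetic_algorithm_step(
        fitness,
        migrate = (step + 1) % migrate_every == 0  # trigger migration every `migrate_every` steps
    )
```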
@@ -460,7 +454,7 @@ class LatentGenePool(Module):
 
             return genes
 
-        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+        genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)
 
         orig_genes = genes
 
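The old pattern `(i p) n g -> i p n g` hard-coded the latent-sets axis; `...` lets einops carry along whatever trailing shape the genes have, so the same line now serves the flat `(i p) g` layout too. A quick demonstration with arbitrary sizes:

```python
import torch
from einops import rearrange

islands = 4

flat = torch.randn(128, 32)  # 0.0.28 layout: (i p) g
print(rearrange(flat, '(i p) ... -> i p ...', i = islands).shape)       # torch.Size([4, 32, 32])

with_sets = torch.randn(128, 3, 32)  # old layout: (i p) n g
print(rearrange(with_sets, '(i p) ... -> i p ...', i = islands).shape)  # torch.Size([4, 32, 3, 32])
```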
@@ -469,13 +463,13 @@ class LatentGenePool(Module):
 
         sorted_indices = fitness.sort(dim = -1).indices
         natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
-        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
+        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... g', g = genes.shape[-1])
 
         genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
 
         # 2. for finding pairs of parents to replete gene pool, we will go with the popular tournament strategy
 
-        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants)).argsort(dim = -1)
+        rand_tournament_gene_ids = torch.randn((islands, pop_size_per_island - self.num_natural_selected, tournament_participants), device = device).argsort(dim = -1)
         rand_tournament_gene_ids_for_gather = rearrange(rand_tournament_gene_ids, 'i p t -> i (p t)')
 
         participant_fitness = fitness.gather(1, rand_tournament_gene_ids_for_gather)
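The functional change on the second line is just the added `device = device`, which keeps the sampling on the pool's device. The idiom itself deserves a note: argsorting i.i.d. Gaussian noise along the last dimension yields an independent uniform random permutation per row, i.e. sampling without replacement. A small demonstration with made-up sizes:

```python
import torch

islands, num_offspring_needed, tournament_participants = 2, 3, 4

# each (island, offspring slot) row gets its own random permutation of 0..3,
# so the tournament participants within a row are distinct
ids = torch.randn((islands, num_offspring_needed, tournament_participants)).argsort(dim = -1)
print(ids[0, 0])  # e.g. tensor([2, 0, 3, 1])
```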
@@ -484,7 +478,7 @@ class LatentGenePool(Module):
         parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
         parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
 
-        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
+        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) g', g = genes.shape[-1])
 
         parents = genes.gather(1, parent_gene_ids_for_gather)
         parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
@@ -507,6 +501,16 @@ class LatentGenePool(Module):
 
         genes = mutation(genes, mutation_strength = self.mutation_strength)
 
+        # 6. maybe migration
+
+        if migrate:
+            assert self.num_islands > 1
+            randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
+
+            migrate_mask = randperm < self.num_migrate
+            maybe_migrated_genes = torch.roll(genes, 1, dims = 0)
+            genes = einx.where('i p, i p g, i p g', migrate_mask, maybe_migrated_genes, genes)
+
         # add back the elites
 
         if self.has_elites:
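Migration follows a ring rather than random island pairs: `torch.roll(genes, 1, dims = 0)` hands every island the previous island's population, and the mask splices exactly `num_migrate` rolled genes per island back in. For readers without `einx`, a plain-torch sketch of the same computation, with sizes assumed for illustration:

```python
import torch

num_islands, pop_per_island, dim_gene, num_migrate = 4, 32, 16, 3

genes = torch.randn(num_islands, pop_per_island, dim_gene)

# per-island random permutation of population slots; ranks below num_migrate migrate
randperm = torch.randn(genes.shape[:-1]).argsort(dim = -1)
migrate_mask = randperm < num_migrate    # (i, p), exactly num_migrate True per island

rolled = torch.roll(genes, 1, dims = 0)  # ring: island i receives island i - 1's genes

# equivalent to einx.where('i p, i p g, i p g', ...): broadcast mask over the gene dim
genes = torch.where(migrate_mask.unsqueeze(-1), rolled, genes)
```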
@@ -555,22 +559,6 @@ class LatentGenePool(Module):
 
         latent = self.latents[latent_id]
 
-        if self.needs_latent_gate:
-            assert exists(state), 'state must be passed in if greater than number of 1 latent set'
-
-            if not fetching_multiple_latents:
-                latent = repeat(latent, '... -> b ...', b = state.shape[0])
-
-            assert latent.shape[0] == state.shape[0]
-
-            gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'b n g, b n -> b g')
-
-        elif fetching_multiple_latents:
-            latent = latent[:, 0]
-        else:
-            latent = latent[0]
-
         latent = self.maybe_l2norm(latent)
 
         if not exists(net):
@@ -612,7 +600,7 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
 
-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.needs_latent_gate else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
     def get_actor_actions(
         self,
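With the gate network gone, the pool has trainable parameters exactly when the latents are unfrozen, which is what the rewritten condition checks. A sketch of the two configurations, reusing the assumed import from the first example:

```python
# frozen (default): latents evolve only through the genetic algorithm, no optimizer is built
pool = LatentGenePool(num_latents = 128, dim_latent = 32)  # frozen_latents = True
assert not any(p.requires_grad for p in pool.parameters())

# unfrozen: latents also learn by gradient descent, so Agent constructs latent_optim
pool = LatentGenePool(num_latents = 128, dim_latent = 32, frozen_latents = False)
assert all(p.requires_grad for p in pool.parameters())
```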
@@ -687,7 +675,6 @@ def create_agent(
    actor_num_actions,
    actor_dim_hiddens: int | tuple[int, ...],
    critic_dim_hiddens: int | tuple[int, ...],
-    num_latent_sets = 1
 ) -> Agent:
 
     actor = Actor(
@@ -707,7 +694,6 @@ def create_agent(
         dim_state = dim_state,
         num_latents = num_latents,
         dim_latent = dim_latent,
-        num_latent_sets = num_latent_sets
     )
 
     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
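End to end, `create_agent` drops the `num_latent_sets` plumbing. A 0.0.28-style call might look like the following; the keyword names come from the hunks above, while the sizes and import path are assumptions:

```python
from evolutionary_policy_optimization import create_agent  # assumed export path

agent = create_agent(
    dim_state = 16,
    num_latents = 128,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128)
    # num_latent_sets is gone in 0.0.28
)
```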
evolutionary_policy_optimization-0.0.28.dist-info/METADATA:

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.26
+Version: 0.0.28
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.28.dist-info/RECORD (added):

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=GckXFGdRoZT149cOlMqLUVe9oXr1QXP-gPZTv4H_HFU,21692
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.28.dist-info/METADATA,sha256=Fn846Lxaxo_OrXFD-_8IECOJ9fZL2JosriGUKMO0CfQ,4958
+evolutionary_policy_optimization-0.0.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.28.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.26.dist-info/RECORD (removed):

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=zYKRKUkvFdxgHkc2yduN76Hph3asWX33mnpDF3isDfo,22019
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.26.dist-info/METADATA,sha256=l24aFXZu4kp1oxZeIdFTUw1mwkyzln9C64S3HNqebF4,4958
-evolutionary_policy_optimization-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.26.dist-info/RECORD,,