evolutionary-policy-optimization 0.0.25-py3-none-any.whl → 0.0.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -303,7 +303,6 @@ class LatentGenePool(Module):
         self,
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
-        num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
         dim_state = None,
         frozen_latents = True,
@@ -320,29 +319,17 @@ class LatentGenePool(Module):
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
-        latents = torch.randn(num_latents, num_latent_sets, dim_latent)
+        latents = torch.randn(num_latents, dim_latent)
 
         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)
 
         self.num_latents = num_latents
-        self.needs_latent_gate = num_latent_sets > 1
+        self.frozen_latents = frozen_latents
         self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)
 
         self.maybe_l2norm = maybe_l2norm
 
-        # gene expression as a function of environment
-
-        self.num_latent_sets = num_latent_sets
-
-        if self.needs_latent_gate:
-            assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
-
-        self.to_latent_gate = nn.Sequential(
-            Linear(dim_state, num_latent_sets),
-            nn.Softmax(dim = -1)
-        ) if self.needs_latent_gate else None
-
         # some derived values
 
         assert num_islands >= 1
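
With this hunk, each individual's gene collapses from a gated set of latents to a single latent vector, and the pool now records whether it is frozen. A minimal standalone sketch of the simplified construction, using `torch.nn.functional.normalize` in place of the package's `l2norm` helper (which this diff does not show); sizes are illustrative:

    import torch
    import torch.nn.functional as F
    from torch import nn

    num_latents, dim_latent = 128, 32          # illustrative sizes
    l2norm_latent, frozen_latents = True, True

    # one gene (latent vector) per individual - no more latent sets
    latents = torch.randn(num_latents, dim_latent)

    if l2norm_latent:
        # start the genes on the unit hypersphere
        latents = F.normalize(latents, dim = -1)

    # frozen by default: the pool is updated by evolution, not by gradients
    latents = nn.Parameter(latents, requires_grad = not frozen_latents)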
@@ -410,6 +397,10 @@ class LatentGenePool(Module):
 
         fireflies = rearrange(fireflies, 'i p ... -> (i p) ...')
 
+        # maybe fireflies on hypersphere
+
+        fireflies = self.maybe_l2norm(fireflies)
+
         if not inplace:
             return fireflies
 
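The added `maybe_l2norm` call projects the fireflies (genes) back onto the unit hypersphere after the firefly update, since the additive attraction and random-walk steps would otherwise drift them off it when `l2norm_latent` is enabled. A toy illustration of why the re-projection is needed, standalone, with `F.normalize` standing in for `l2norm`:

    import torch
    import torch.nn.functional as F

    # genes start on the unit hypersphere
    fireflies = F.normalize(torch.randn(64, 32), dim = -1)

    # an additive step (attraction + random walk) breaks the unit norm ...
    fireflies = fireflies + 0.1 * torch.randn_like(fireflies)

    # ... so the update ends by normalizing back, as the hunk above does
    fireflies = F.normalize(fireflies, dim = -1)
    assert torch.allclose(fireflies.norm(dim = -1), torch.ones(64), atol = 1e-5)
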
@@ -456,7 +447,7 @@ class LatentGenePool(Module):
 
             return genes
 
-        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+        genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)
 
         orig_genes = genes
 
@@ -465,7 +456,7 @@ class LatentGenePool(Module):
 
         sorted_indices = fitness.sort(dim = -1).indices
         natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
-        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
+        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... g', g = genes.shape[-1])
 
         genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
 
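With the latent-sets axis gone, `genes` is now `(islands, population, dim_latent)` rather than `(islands, population, sets, dim_latent)`, so the selection indices only need broadcasting over the final gene dimension before the gather. A small demonstration of the pattern (einops required; shapes are illustrative):

    import torch
    from einops import repeat

    islands, pop, dim_latent, num_selected = 2, 8, 4, 3

    genes = torch.randn(islands, pop, dim_latent)
    fitness = torch.randn(islands, pop)

    # indices of the fittest `num_selected` individuals per island
    selected = fitness.sort(dim = -1).indices[..., -num_selected:]

    # broadcast the (island, individual) indices across the gene dimension
    gene_indices = repeat(selected, '... -> ... g', g = dim_latent)

    survivors = genes.gather(1, gene_indices)
    assert survivors.shape == (islands, num_selected, dim_latent)
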
@@ -480,7 +471,7 @@ class LatentGenePool(Module):
         parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
         parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
 
-        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
+        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) g', g = genes.shape[-1])
 
         parents = genes.gather(1, parent_gene_ids_for_gather)
         parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
@@ -551,22 +542,6 @@ class LatentGenePool(Module):
 
         latent = self.latents[latent_id]
 
-        if self.needs_latent_gate:
-            assert exists(state), 'state must be passed in if greater than number of 1 latent set'
-
-            if not fetching_multiple_latents:
-                latent = repeat(latent, '... -> b ...', b = state.shape[0])
-
-            assert latent.shape[0] == state.shape[0]
-
-            gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'b n g, b n -> b g')
-
-        elif fetching_multiple_latents:
-            latent = latent[:, 0]
-        else:
-            latent = latent[0]
-
         latent = self.maybe_l2norm(latent)
 
         if not exists(net):
@@ -608,7 +583,7 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
 
-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.needs_latent_gate else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
     def get_actor_actions(
         self,
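
With the gate module gone, the latent optimizer is keyed off `frozen_latents`: a frozen pool has no gradient-trained parameters, so no optimizer is built. A sketch of the same conditional-optimizer pattern (names here are illustrative, not the package API):

    import torch
    from torch import nn

    # a frozen pool - evolution updates it, so gradients are unnecessary
    latents = nn.Parameter(torch.randn(128, 32), requires_grad = False)
    frozen_latents = not latents.requires_grad

    # only build an optimizer when gradients will actually flow to the latents
    latent_optim = torch.optim.Adam([latents], lr = 1e-4) if not frozen_latents else None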
@@ -683,7 +658,6 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
-    num_latent_sets = 1
 ) -> Agent:
 
     actor = Actor(
@@ -703,7 +677,6 @@ def create_agent(
         dim_state = dim_state,
         num_latents = num_latents,
         dim_latent = dim_latent,
-        num_latent_sets = num_latent_sets
     )
 
     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
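
After this change `create_agent` no longer accepts `num_latent_sets`. A hypothetical call with the trimmed signature (argument values are made up, and it is assumed here that `create_agent` is importable from the package root):

    from evolutionary_policy_optimization import create_agent

    agent = create_agent(
        dim_state = 512,
        num_latents = 128,
        dim_latent = 32,
        actor_num_actions = 5,
        actor_dim_hiddens = (256, 128),
        critic_dim_hiddens = (256, 128)
    )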
--- a/evolutionary_policy_optimization-0.0.25.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.27.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.25
+Version: 0.0.27
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- /dev/null
+++ b/evolutionary_policy_optimization-0.0.27.dist-info/RECORD
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=UCCwYK-b20X-5Cq-pah1NTeHFc_35b4xZ3y0aSR8aaI,20783
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.27.dist-info/METADATA,sha256=pJ2kQD5YtKDSUp1TCO_hsrRMh6FCMm8dyu6WrpVHiQk,4958
+evolutionary_policy_optimization-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.27.dist-info/RECORD,,
--- a/evolutionary_policy_optimization-0.0.25.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=BLwy7PBZOjw6H7MFvMq9CC7Mdm3K8fpzBNH6HbNu6LY,21927
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.25.dist-info/METADATA,sha256=p3-_SuLvKs8E0z1l567qA0Pbsv2dOLlrJPX4WYoZaB4,4958
-evolutionary_policy_optimization-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.25.dist-info/RECORD,,