evolutionary_policy_optimization-0.0.26-py3-none-any.whl → evolutionary_policy_optimization-0.0.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/epo.py
@@ -303,7 +303,6 @@ class LatentGenePool(Module):
         self,
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
-        num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
         dim_state = None,
         frozen_latents = True,
@@ -320,29 +319,17 @@ class LatentGenePool(Module):

         maybe_l2norm = l2norm if l2norm_latent else identity

-        latents = torch.randn(num_latents, num_latent_sets, dim_latent)
+        latents = torch.randn(num_latents, dim_latent)

         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)

         self.num_latents = num_latents
-        self.needs_latent_gate = num_latent_sets > 1
+        self.frozen_latents = frozen_latents
         self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)

         self.maybe_l2norm = maybe_l2norm

-        # gene expression as a function of environment
-
-        self.num_latent_sets = num_latent_sets
-
-        if self.needs_latent_gate:
-            assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
-
-        self.to_latent_gate = nn.Sequential(
-            Linear(dim_state, num_latent_sets),
-            nn.Softmax(dim = -1)
-        ) if self.needs_latent_gate else None
-
         # some derived values

         assert num_islands >= 1
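
With `num_latent_sets` and the `to_latent_gate` network gone, each individual now carries a single flat gene vector. A minimal sketch of constructing the simplified pool, assuming `LatentGenePool` is exported from the package root and that the remaining constructor arguments keep the defaults shown above:

    import torch
    from evolutionary_policy_optimization import LatentGenePool  # export path assumed

    pool = LatentGenePool(
        num_latents = 128,  # gene pool size
        dim_latent = 32     # gene dimension
    )

    # latents are now a flat (num_latents, dim_latent) parameter, no set axis
    assert pool.latents.shape == (128, 32)
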
@@ -460,7 +447,7 @@ class LatentGenePool(Module):

             return genes

-        genes = rearrange(genes, '(i p) n g -> i p n g', i = islands)
+        genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)

         orig_genes = genes

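Switching the rearrange pattern from an explicit `n g` to `...` makes the island reshape agnostic to the trailing gene shape, which is what lets the set axis disappear without touching this line again. A self-contained illustration with einops:

    import torch
    from einops import rearrange

    islands, pop_per_island, dim_latent = 2, 4, 8
    genes = torch.randn(islands * pop_per_island, dim_latent)

    # '(i p) ... -> i p ...' splits the population axis into (islands, per-island)
    # and leaves however many trailing gene axes there are untouched
    genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)
    assert genes.shape == (2, 4, 8)
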
@@ -469,7 +456,7 @@ class LatentGenePool(Module):

         sorted_indices = fitness.sort(dim = -1).indices
         natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
-        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... n g', n = genes.shape[-2], g = genes.shape[-1])
+        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... g', g = genes.shape[-1])

         genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)

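Because genes lost their set axis, the gather index only needs broadcasting over the final gene dimension. A small, runnable sketch of the natural selection step under hypothetical shapes:

    import torch
    from einops import repeat

    islands, pop, dim_latent, num_selected = 1, 6, 4, 3
    genes = torch.randn(islands, pop, dim_latent)
    fitness = torch.randn(islands, pop)

    sorted_indices = fitness.sort(dim = -1).indices
    selected = sorted_indices[..., -num_selected:]  # fittest individuals sort last

    # expand the indices over the gene dimension so gather picks whole gene vectors
    gene_indices = repeat(selected, '... -> ... g', g = genes.shape[-1])
    top_genes = genes.gather(1, gene_indices)
    assert top_genes.shape == (islands, num_selected, dim_latent)
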
@@ -484,7 +471,7 @@ class LatentGenePool(Module):
         parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
         parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)

-        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) n g', n = genes.shape[-2], g = genes.shape[-1])
+        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) g', g = genes.shape[-1])

         parents = genes.gather(1, parent_gene_ids_for_gather)
         parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
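
The tournament gather is simplified the same way. A sketch, with hypothetical shapes, of how the two fittest participants of each random tournament become parent gene vectors:

    import torch
    from einops import rearrange, repeat

    islands, pop, dim_latent, tournaments, participants = 1, 6, 4, 5, 3
    genes = torch.randn(islands, pop, dim_latent)
    fitness = torch.randn(islands, pop)

    # draw random participants, then look up their fitness
    rand_gene_ids = torch.randint(0, pop, (islands, tournaments, participants))
    participant_fitness = fitness.gather(1, rearrange(rand_gene_ids, 'i t p -> i (t p)'))
    participant_fitness = rearrange(participant_fitness, 'i (t p) -> i t p', t = tournaments)

    # the top 2 fitness values per tournament pick the parent gene ids
    winner_indices = participant_fitness.topk(2, dim = -1).indices
    parent_gene_ids = rand_gene_ids.gather(-1, winner_indices)

    ids = repeat(parent_gene_ids, 'i t two -> i (t two) g', g = dim_latent)
    parents = rearrange(genes.gather(1, ids), 'i (t two) g -> i t two g', two = 2)
    assert parents.shape == (islands, tournaments, 2, dim_latent)
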
@@ -555,22 +542,6 @@ class LatentGenePool(Module):

         latent = self.latents[latent_id]

-        if self.needs_latent_gate:
-            assert exists(state), 'state must be passed in if greater than number of 1 latent set'
-
-            if not fetching_multiple_latents:
-                latent = repeat(latent, '... -> b ...', b = state.shape[0])
-
-            assert latent.shape[0] == state.shape[0]
-
-            gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'b n g, b n -> b g')
-
-        elif fetching_multiple_latents:
-            latent = latent[:, 0]
-        else:
-            latent = latent[0]
-
         latent = self.maybe_l2norm(latent)

         if not exists(net):
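
With the gate removed, fetching a gene reduces to a plain index plus the optional l2norm; `state` no longer influences which latent is expressed. A tiny sketch, using torch's functional normalize as a stand-in for the library's `l2norm`:

    import torch
    import torch.nn.functional as F

    latents = torch.randn(128, 32)
    latent = latents[3]                     # latent_id = 3
    latent = F.normalize(latent, dim = -1)  # applied only when l2norm_latent = True
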
@@ -612,7 +583,7 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)

-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.needs_latent_gate else None
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None

     def get_actor_actions(
         self,
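
The guard on the latent optimizer is inverted along with the removal: it used to exist only when the gating network needed training, and now exists whenever the latents themselves are trainable, i.e. `frozen_latents = False`. A minimal stand-in illustration (the `Pool` class here is hypothetical, not the library's):

    import torch
    from torch import nn
    from torch.optim import Adam

    class Pool(nn.Module):  # minimal stand-in for LatentGenePool
        def __init__(self, frozen_latents = True):
            super().__init__()
            self.frozen_latents = frozen_latents
            self.latents = nn.Parameter(torch.randn(4, 8), requires_grad = not frozen_latents)

    pool = Pool(frozen_latents = False)
    latent_optim = Adam(pool.parameters(), lr = 1e-4) if not pool.frozen_latents else None
    assert latent_optim is not None
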
@@ -687,7 +658,6 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
-    num_latent_sets = 1
 ) -> Agent:

     actor = Actor(
@@ -707,7 +677,6 @@ def create_agent(
         dim_state = dim_state,
         num_latents = num_latents,
         dim_latent = dim_latent,
-        num_latent_sets = num_latent_sets
     )

     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
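
With `num_latent_sets` dropped from `create_agent`, a 0.0.27 call looks like the sketch below. Only parameters visible in the hunks above are shown, the values are placeholders, and the export path is assumed:

    from evolutionary_policy_optimization import create_agent  # export path assumed

    agent = create_agent(
        dim_state = 512,
        num_latents = 128,
        dim_latent = 32,
        actor_num_actions = 5,
        actor_dim_hiddens = (256, 128),
        critic_dim_hiddens = (256, 128)
    )
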
evolutionary_policy_optimization-0.0.26.dist-info/METADATA → evolutionary_policy_optimization-0.0.27.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.26
+Version: 0.0.27
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.27.dist-info/RECORD (added)
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=UCCwYK-b20X-5Cq-pah1NTeHFc_35b4xZ3y0aSR8aaI,20783
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.27.dist-info/METADATA,sha256=pJ2kQD5YtKDSUp1TCO_hsrRMh6FCMm8dyu6WrpVHiQk,4958
+evolutionary_policy_optimization-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.27.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.26.dist-info/RECORD (removed)
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=zYKRKUkvFdxgHkc2yduN76Hph3asWX33mnpDF3isDfo,22019
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.26.dist-info/METADATA,sha256=l24aFXZu4kp1oxZeIdFTUw1mwkyzln9C64S3HNqebF4,4958
-evolutionary_policy_optimization-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.26.dist-info/RECORD,,