evolutionary-policy-optimization 0.0.11__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ from evolutionary_policy_optimization.epo import (
2
2
  MLP,
3
3
  Actor,
4
4
  Critic,
5
+ create_agent,
5
6
  Agent,
6
7
  LatentGenePool
7
8
  )
@@ -502,6 +502,38 @@ class LatentGenePool(Module):
502
502
 
503
503
  # agent contains the actor, critic, and the latent genetic pool
504
504
 
505
+ def create_agent(
506
+ dim_state,
507
+ num_latents,
508
+ dim_latent,
509
+ actor_num_actions,
510
+ actor_dim_hiddens: int | tuple[int, ...],
511
+ critic_dim_hiddens: int | tuple[int, ...],
512
+ num_latent_sets = 1
513
+ ) -> Agent:
514
+
515
+ actor = Actor(
516
+ num_actions = actor_num_actions,
517
+ dim_state = dim_state,
518
+ dim_latent = dim_latent,
519
+ dim_hiddens = actor_dim_hiddens
520
+ )
521
+
522
+ critic = Critic(
523
+ dim_state = dim_state,
524
+ dim_latent = dim_latent,
525
+ dim_hiddens = critic_dim_hiddens
526
+ )
527
+
528
+ latent_gene_pool = LatentGenePool(
529
+ dim_state = dim_state,
530
+ num_latents = num_latents,
531
+ dim_latent = dim_latent,
532
+ num_latent_sets = num_latent_sets
533
+ )
534
+
535
+ return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
536
+
505
537
  class Agent(Module):
506
538
  def __init__(
507
539
  self,
@@ -516,6 +548,28 @@ class Agent(Module):
516
548
 
517
549
  self.latent_gene_pool = latent_gene_pool
518
550
 
551
+ def get_actor_actions(
552
+ self,
553
+ state,
554
+ latent_id
555
+ ):
556
+ latent = self.latent_gene_pool(latent_id = latent_id, state = state)
557
+ return self.actor(state, latent)
558
+
559
+ def get_critic_values(
560
+ self,
561
+ state,
562
+ latent_id
563
+ ):
564
+ latent = self.latent_gene_pool(latent_id = latent_id, state = state)
565
+ return self.critic(state, latent)
566
+
567
+ def update_latent_gene_pool_(
568
+ self,
569
+ fitnesses
570
+ ):
571
+ return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
572
+
519
573
  def forward(
520
574
  self,
521
575
  memories: list[Memory]
@@ -539,13 +593,10 @@ class EPO(Module):
539
593
 
540
594
  def __init__(
541
595
  self,
542
- agent: Agent,
543
- latent_gene_pool: LatentGenePool
596
+ agent: Agent
544
597
  ):
545
598
  super().__init__()
546
-
547
599
  self.agent = agent
548
- self.latent_gene_pool = latent_gene_pool
549
600
 
550
601
  def forward(
551
602
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.11
3
+ Version: 0.0.14
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,7 @@
1
+ evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
2
+ evolutionary_policy_optimization/epo.py,sha256=vGSq_KCHxyJClzU23kWRFjwBv511JZN43bBd7KOpWGo,16842
3
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
+ evolutionary_policy_optimization-0.0.14.dist-info/METADATA,sha256=KV5cnqxeUqEAwNYj03sHivKo-cpd12RL_I2Wfu3KLaM,4357
5
+ evolutionary_policy_optimization-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ evolutionary_policy_optimization-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ evolutionary_policy_optimization-0.0.14.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
2
- evolutionary_policy_optimization/epo.py,sha256=JGow9ofx7IgFy7QNL0dL0K_SCL_bVkBUznMG8aSGM9Q,15591
3
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
- evolutionary_policy_optimization-0.0.11.dist-info/METADATA,sha256=fkouRBZU5nrPgHt0eT5izSHdOiYGAg67N5Gn3t039mQ,4357
5
- evolutionary_policy_optimization-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- evolutionary_policy_optimization-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- evolutionary_policy_optimization-0.0.11.dist-info/RECORD,,