evolutionary-policy-optimization 0.0.11-py3-none-any.whl → 0.0.14-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/__init__.py +1 -0
- evolutionary_policy_optimization/epo.py +55 -4
- {evolutionary_policy_optimization-0.0.11.dist-info → evolutionary_policy_optimization-0.0.14.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.14.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.11.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.11.dist-info → evolutionary_policy_optimization-0.0.14.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.11.dist-info → evolutionary_policy_optimization-0.0.14.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

```diff
@@ -502,6 +502,38 @@ class LatentGenePool(Module):
 
 # agent contains the actor, critic, and the latent genetic pool
 
+def create_agent(
+    dim_state,
+    num_latents,
+    dim_latent,
+    actor_num_actions,
+    actor_dim_hiddens: int | tuple[int, ...],
+    critic_dim_hiddens: int | tuple[int, ...],
+    num_latent_sets = 1
+) -> Agent:
+
+    actor = Actor(
+        num_actions = actor_num_actions,
+        dim_state = dim_state,
+        dim_latent = dim_latent,
+        dim_hiddens = actor_dim_hiddens
+    )
+
+    critic = Critic(
+        dim_state = dim_state,
+        dim_latent = dim_latent,
+        dim_hiddens = critic_dim_hiddens
+    )
+
+    latent_gene_pool = LatentGenePool(
+        dim_state = dim_state,
+        num_latents = num_latents,
+        dim_latent = dim_latent,
+        num_latent_sets = num_latent_sets
+    )
+
+    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+
 class Agent(Module):
     def __init__(
         self,
```
evolutionary_policy_optimization/epo.py (continued)

```diff
@@ -516,6 +548,28 @@ class Agent(Module):
 
         self.latent_gene_pool = latent_gene_pool
 
+    def get_actor_actions(
+        self,
+        state,
+        latent_id
+    ):
+        latent = self.latent_gene_pool(latent_id = latent_id, state = state)
+        return self.actor(state, latent)
+
+    def get_critic_values(
+        self,
+        state,
+        latent_id
+    ):
+        latent = self.latent_gene_pool(latent_id = latent_id, state = state)
+        return self.critic(state, latent)
+
+    def update_latent_gene_pool_(
+        self,
+        fitnesses
+    ):
+        return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
+
     def forward(
         self,
         memories: list[Memory]
@@ -539,13 +593,10 @@ class EPO(Module):
 
     def __init__(
         self,
-        agent: Agent,
-        latent_gene_pool: LatentGenePool
+        agent: Agent
     ):
         super().__init__()
-
         self.agent = agent
-        self.latent_gene_pool = latent_gene_pool
 
     def forward(
         self,
```
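The new `Agent` methods resolve the latent vector for a given `latent_id` through the gene pool before invoking the actor or critic, and `EPO` now takes only the agent, since the agent owns the gene pool. A hedged continuation of the sketch above; the state shape and the per-latent fitness tensor are assumptions:

```python
import torch

from evolutionary_policy_optimization.epo import EPO

state = torch.randn(1, 512)  # one state with dim_state = 512, matching the sketch above

logits = agent.get_actor_actions(state, latent_id = 3)  # actor output under latent 3
values = agent.get_critic_values(state, latent_id = 3)  # critic value under the same latent

# assumed: one fitness score per latent in the pool (num_latents = 32 above)
fitnesses = torch.randn(32)
agent.update_latent_gene_pool_(fitnesses)  # runs one genetic_algorithm_step on the pool

# EPO is now constructed from the agent alone
epo = EPO(agent = agent)
```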
{evolutionary_policy_optimization-0.0.11.dist-info → evolutionary_policy_optimization-0.0.14.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.11
+Version: 0.0.14
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
```
evolutionary_policy_optimization-0.0.14.dist-info/RECORD

```diff
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=vGSq_KCHxyJClzU23kWRFjwBv511JZN43bBd7KOpWGo,16842
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.14.dist-info/METADATA,sha256=KV5cnqxeUqEAwNYj03sHivKo-cpd12RL_I2Wfu3KLaM,4357
+evolutionary_policy_optimization-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.14.dist-info/RECORD,,
```
evolutionary_policy_optimization-0.0.11.dist-info/RECORD

```diff
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
-evolutionary_policy_optimization/epo.py,sha256=JGow9ofx7IgFy7QNL0dL0K_SCL_bVkBUznMG8aSGM9Q,15591
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.11.dist-info/METADATA,sha256=fkouRBZU5nrPgHt0eT5izSHdOiYGAg67N5Gn3t039mQ,4357
-evolutionary_policy_optimization-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.11.dist-info/RECORD,,
```
{evolutionary_policy_optimization-0.0.11.dist-info → evolutionary_policy_optimization-0.0.14.dist-info}/WHEEL
{evolutionary_policy_optimization-0.0.11.dist-info → evolutionary_policy_optimization-0.0.14.dist-info}/licenses/LICENSE

File without changes