evolutionary-policy-optimization 0.0.12__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +90 -76
- {evolutionary_policy_optimization-0.0.12.dist-info → evolutionary_policy_optimization-0.0.15.dist-info}/METADATA +2 -1
- evolutionary_policy_optimization-0.0.15.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.12.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.12.dist-info → evolutionary_policy_optimization-0.0.15.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.12.dist-info → evolutionary_policy_optimization-0.0.15.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -4,16 +4,17 @@ from collections import namedtuple
 
 import torch
 from torch import nn, cat
-import torch.nn.functional as F
-
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
+from torch.utils.data import TensorDataset, DataLoader
 
 from einops import rearrange, repeat, einsum
 from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
 
+from adam_atan2_pytorch import AdoptAtan2
+
 # helpers
 
 def exists(v):
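The new `adam_atan2_pytorch` import supplies the default optimizer class that `Agent` wires up later in this diff. As a minimal sketch of how such an optimizer plugs in, assuming `AdoptAtan2` follows the standard `torch.optim.Optimizer` interface (the toy `nn.Linear` module and the hyperparameters here are illustrative, not taken from the package):

```python
import torch
from torch import nn
from adam_atan2_pytorch import AdoptAtan2

net = nn.Linear(16, 4)                           # toy module, stand-in for Actor / Critic
optim = AdoptAtan2(net.parameters(), lr = 1e-4)  # same call pattern used in Agent.__init__ below

loss = net(torch.randn(2, 16)).sum()             # dummy loss for illustration
loss.backward()

optim.step()                                     # usual step / zero_grad cycle, assuming the
optim.zero_grad()                                # standard torch optimizer interface
```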
@@ -49,42 +50,6 @@ def gather_log_prob(
     log_prob = log_probs.gather(-1, indices)
     return rearrange(log_prob, '... 1 -> ...')
 
-# reinforcement learning related - ppo
-
-def actor_loss(
-    logits, # Float[b l]
-    old_log_probs, # Float[b]
-    actions, # Int[b]
-    advantages, # Float[b]
-    eps_clip = 0.2,
-    entropy_weight = .01,
-):
-    log_probs = gather_log_prob(logits, actions)
-
-    ratio = (log_probs - old_log_probs).exp()
-
-    # classic clipped surrogate loss from ppo
-
-    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
-
-    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
-
-    # add entropy loss for exploration
-
-    entropy = calc_entropy(logits)
-
-    entropy_aux_loss = -entropy_weight * entropy
-
-    return actor_loss + entropy_aux_loss
-
-def critic_loss(
-    pred_values, # Float[b]
-    advantages, # Float[b]
-    old_values # Float[b]
-):
-    discounted_values = advantages + old_values
-    return F.mse_loss(pred_values, discounted_values)
-
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
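For reference, these removed helpers (they reappear further down in this diff, below the `Agent` class) implement the standard PPO clipped surrogate objective with an entropy bonus. In the usual notation, with probability ratio $r_t(\theta) = \exp\big(\log\pi_\theta(a_t\mid s_t) - \log\pi_{\theta_\text{old}}(a_t\mid s_t)\big)$ and advantage estimate $\hat A_t$:

$$
L_{\text{actor}}(\theta) \;=\; -\,\mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat A_t,\; \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat A_t\big)\Big] \;-\; \beta\,\mathbb{E}_t\Big[\mathcal{H}\big(\pi_\theta(\cdot\mid s_t)\big)\Big]
$$

where $\epsilon$ corresponds to `eps_clip` and $\beta$ to `entropy_weight`; `critic_loss` is simply the MSE between the predicted values and `advantages + old_values`.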
@@ -500,46 +465,21 @@ class LatentGenePool(Module):
            **kwargs
        )
 
-# agent
-
-def create_agent(
-    dim_state,
-    num_latents,
-    dim_latent,
-    actor_num_actions,
-    actor_dim_hiddens: int | tuple[int, ...],
-    critic_dim_hiddens: int | tuple[int, ...],
-    num_latent_sets = 1
-) -> Agent:
-
-    actor = Actor(
-        num_actions = actor_num_actions,
-        dim_state = dim_state,
-        dim_latent = dim_latent,
-        dim_hiddens = actor_dim_hiddens
-    )
-
-    critic = Critic(
-        dim_state = dim_state,
-        dim_latent = dim_latent,
-        dim_hiddens = critic_dim_hiddens
-    )
-
-    latent_gene_pool = LatentGenePool(
-        dim_state = dim_state,
-        num_latents = num_latents,
-        dim_latent = dim_latent,
-        num_latent_sets = num_latent_sets
-    )
-
-    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+# agent class
 
 class Agent(Module):
     def __init__(
         self,
         actor: Actor,
         critic: Critic,
-        latent_gene_pool: LatentGenePool
+        latent_gene_pool: LatentGenePool,
+        optim_klass = AdoptAtan2,
+        actor_lr = 1e-4,
+        critic_lr = 1e-4,
+        latent_lr = 1e-5,
+        actor_optim_kwargs: dict = dict(),
+        critic_optim_kwargs: dict = dict(),
+        latent_optim_kwargs: dict = dict(),
     ):
         super().__init__()
 
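With the signature above, optimizer hyperparameters are now passed directly to `Agent`. A hypothetical construction sketch, assuming `actor`, `critic`, and `latent_gene_pool` have already been built as in the `create_agent` factory shown later in this diff (the learning rates shown are just the defaults):

```python
# hypothetical sketch; `actor`, `critic`, `latent_gene_pool` assumed to exist already
agent = Agent(
    actor = actor,
    critic = critic,
    latent_gene_pool = latent_gene_pool,
    optim_klass = AdoptAtan2,   # default optimizer class per this diff
    actor_lr = 1e-4,
    critic_lr = 1e-4,
    latent_lr = 1e-5,
)
```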
@@ -548,6 +488,13 @@ class Agent(Module):
 
         self.latent_gene_pool = latent_gene_pool
 
+        # optimizers
+
+        self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
+        self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
+
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.needs_latent_gate else None
+
     def get_actor_actions(
         self,
         state,
@@ -576,6 +523,76 @@ class Agent(Module):
     ):
         raise NotImplementedError
 
+# reinforcement learning related - ppo
+
+def actor_loss(
+    logits, # Float[b l]
+    old_log_probs, # Float[b]
+    actions, # Int[b]
+    advantages, # Float[b]
+    eps_clip = 0.2,
+    entropy_weight = .01,
+):
+    log_probs = gather_log_prob(logits, actions)
+
+    ratio = (log_probs - old_log_probs).exp()
+
+    # classic clipped surrogate loss from ppo
+
+    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
+
+    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
+
+    # add entropy loss for exploration
+
+    entropy = calc_entropy(logits)
+
+    entropy_aux_loss = -entropy_weight * entropy
+
+    return actor_loss + entropy_aux_loss
+
+def critic_loss(
+    pred_values, # Float[b]
+    advantages, # Float[b]
+    old_values # Float[b]
+):
+    discounted_values = advantages + old_values
+    return F.mse_loss(pred_values, discounted_values)
+
+# agent contains the actor, critic, and the latent genetic pool
+
+def create_agent(
+    dim_state,
+    num_latents,
+    dim_latent,
+    actor_num_actions,
+    actor_dim_hiddens: int | tuple[int, ...],
+    critic_dim_hiddens: int | tuple[int, ...],
+    num_latent_sets = 1
+) -> Agent:
+
+    actor = Actor(
+        num_actions = actor_num_actions,
+        dim_state = dim_state,
+        dim_latent = dim_latent,
+        dim_hiddens = actor_dim_hiddens
+    )
+
+    critic = Critic(
+        dim_state = dim_state,
+        dim_latent = dim_latent,
+        dim_hiddens = critic_dim_hiddens
+    )
+
+    latent_gene_pool = LatentGenePool(
+        dim_state = dim_state,
+        num_latents = num_latents,
+        dim_latent = dim_latent,
+        num_latent_sets = num_latent_sets
+    )
+
+    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
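A hypothetical usage sketch of the relocated `create_agent` factory, based only on the signature added above; the dimensions and action count are made-up placeholder values, not taken from the package:

```python
# placeholder hyperparameters for illustration only
agent = create_agent(
    dim_state = 512,
    num_latents = 128,
    dim_latent = 32,
    actor_num_actions = 5,
    actor_dim_hiddens = (256, 128),
    critic_dim_hiddens = (256, 128),
)
```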
@@ -593,13 +610,10 @@ class EPO(Module):
 
     def __init__(
         self,
-        agent: Agent
-        latent_gene_pool: LatentGenePool
+        agent: Agent
    ):
        super().__init__()
-
        self.agent = agent
-        self.latent_gene_pool = latent_gene_pool
 
     def forward(
         self,
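So the `EPO` wrapper now receives only the agent; a one-line sketch, on the assumption that the latent pool stays reachable as `agent.latent_gene_pool` (the attribute kept in `Agent.__init__` above):

```python
epo = EPO(agent)  # the latent gene pool now lives on the agent, not on EPO itself
```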
{evolutionary_policy_optimization-0.0.12.dist-info → evolutionary_policy_optimization-0.0.15.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.12
+Version: 0.0.15
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -95,6 +95,7 @@ value = critic(state, latent)
 
 fitness = torch.randn(128)
 
+latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
 ```
 
 ## Citations
evolutionary_policy_optimization-0.0.15.dist-info/RECORD

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=f_e-TkJRFF1VHG3psJDgLGNIzlEvDSjX0nOsbLaOBrw,17543
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.15.dist-info/METADATA,sha256=dGjFvEBt10Ac9PhsYSnluSbw_BixLTA_pEekW7TPF3U,4446
+evolutionary_policy_optimization-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.15.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.12.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=dX7UANZk4RI_rwoSPi8GnZDoF9H1EdVT6Z7WA30cZ3Q,16934
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.12.dist-info/METADATA,sha256=Uyee4UOgNs04nrJlU6m7drQFTUuka3s8nLhblHxfiEg,4357
-evolutionary_policy_optimization-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.12.dist-info/RECORD,,
File without changes: the WHEEL and licenses/LICENSE files listed above were renamed for 0.0.15 but their contents are unchanged.