evolutionary-policy-optimization 0.0.12__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries; it is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
evolutionary_policy_optimization/epo.py

@@ -4,16 +4,17 @@ from collections import namedtuple
 
 import torch
 from torch import nn, cat
-import torch.nn.functional as F
-
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
+from torch.utils.data import TensorDataset, DataLoader
 
 from einops import rearrange, repeat, einsum
 from einops.layers.torch import Rearrange
 
 from assoc_scan import AssocScan
 
+from adam_atan2_pytorch import AdoptAtan2
+
 # helpers
 
 def exists(v):
@@ -49,42 +50,6 @@ def gather_log_prob(
     log_prob = log_probs.gather(-1, indices)
     return rearrange(log_prob, '... 1 -> ...')
 
-# reinforcement learning related - ppo
-
-def actor_loss(
-    logits, # Float[b l]
-    old_log_probs, # Float[b]
-    actions, # Int[b]
-    advantages, # Float[b]
-    eps_clip = 0.2,
-    entropy_weight = .01,
-):
-    log_probs = gather_log_prob(logits, actions)
-
-    ratio = (log_probs - old_log_probs).exp()
-
-    # classic clipped surrogate loss from ppo
-
-    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
-
-    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
-
-    # add entropy loss for exploration
-
-    entropy = calc_entropy(logits)
-
-    entropy_aux_loss = -entropy_weight * entropy
-
-    return actor_loss + entropy_aux_loss
-
-def critic_loss(
-    pred_values, # Float[b]
-    advantages, # Float[b]
-    old_values # Float[b]
-):
-    discounted_values = advantages + old_values
-    return F.mse_loss(pred_values, discounted_values)
-
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
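
For intuition, the `calc_generalized_advantage_estimate` context line above refers to the standard GAE recurrence, which the package presumably evaluates with an associative scan (hence the `AssocScan` import in the first hunk). A loop-based reference sketch follows; `gae_reference` and its signature are illustrative, not the package's own function:

```python
import torch

# loop-based reference for generalized advantage estimation (GAE), for intuition only;
# the name and signature are illustrative, not the package's calc_generalized_advantage_estimate
def gae_reference(rewards, values, gamma = 0.99, lam = 0.95):
    # rewards: Float[t], values: Float[t + 1] (last entry is the bootstrap value)
    deltas = rewards + gamma * values[1:] - values[:-1]

    advantages = torch.zeros_like(deltas)
    running = torch.tensor(0.)

    # classic backward recurrence: A_t = delta_t + gamma * lam * A_{t+1}
    for t in reversed(range(deltas.shape[0])):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running

    return advantages
```

Note also how `critic_loss` above regresses the critic toward `advantages + old_values`: the GAE advantages plus the old value estimates reconstruct the return targets.
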
@@ -500,46 +465,21 @@ class LatentGenePool(Module):
             **kwargs
         )
 
-# agent contains the actor, critic, and the latent genetic pool
-
-def create_agent(
-    dim_state,
-    num_latents,
-    dim_latent,
-    actor_num_actions,
-    actor_dim_hiddens: int | tuple[int, ...],
-    critic_dim_hiddens: int | tuple[int, ...],
-    num_latent_sets = 1
-) -> Agent:
-
-    actor = Actor(
-        num_actions = actor_num_actions,
-        dim_state = dim_state,
-        dim_latent = dim_latent,
-        dim_hiddens = actor_dim_hiddens
-    )
-
-    critic = Critic(
-        dim_state = dim_state,
-        dim_latent = dim_latent,
-        dim_hiddens = critic_dim_hiddens
-    )
-
-    latent_gene_pool = LatentGenePool(
-        dim_state = dim_state,
-        num_latents = num_latents,
-        dim_latent = dim_latent,
-        num_latent_sets = num_latent_sets
-    )
-
-    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+# agent class
 
 class Agent(Module):
     def __init__(
         self,
         actor: Actor,
         critic: Critic,
-        latent_gene_pool: LatentGenePool
+        latent_gene_pool: LatentGenePool,
+        optim_klass = AdoptAtan2,
+        actor_lr = 1e-4,
+        critic_lr = 1e-4,
+        latent_lr = 1e-5,
+        actor_optim_kwargs: dict = dict(),
+        critic_optim_kwargs: dict = dict(),
+        latent_optim_kwargs: dict = dict(),
     ):
         super().__init__()
 
@@ -548,6 +488,13 @@ class Agent(Module):
 
         self.latent_gene_pool = latent_gene_pool
 
+        # optimizers
+
+        self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
+        self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
+
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.needs_latent_gate else None
+
     def get_actor_actions(
         self,
         state,
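
With this change the `Agent` constructs and owns its optimizers: one for the actor, one for the critic, and one for the latent genes when the pool reports `needs_latent_gate`, defaulting to `AdoptAtan2` from `adam_atan2_pytorch`. A minimal construction sketch, assuming the classes are importable from the package root and using illustrative dimensions; passing `torch.optim.AdamW` only shows that any optimizer class whose constructor takes `(params, lr = ...)` fits `optim_klass`:

```python
import torch

# the import path and all dimensions below are assumptions for illustration;
# the constructor arguments mirror the create_agent factory shown elsewhere in this diff
from evolutionary_policy_optimization import Actor, Critic, LatentGenePool, Agent

actor = Actor(num_actions = 5, dim_state = 512, dim_latent = 32, dim_hiddens = (256, 128))
critic = Critic(dim_state = 512, dim_latent = 32, dim_hiddens = (256, 128))

latent_pool = LatentGenePool(
    dim_state = 512,
    num_latents = 128,
    dim_latent = 32,
    num_latent_sets = 1
)

agent = Agent(
    actor = actor,
    critic = critic,
    latent_gene_pool = latent_pool,
    optim_klass = torch.optim.AdamW,  # any optimizer class taking (params, lr = ...); AdoptAtan2 is the default
    actor_lr = 3e-4,
    critic_lr = 3e-4,
)
```
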
@@ -576,6 +523,76 @@ class Agent(Module):
     ):
         raise NotImplementedError
 
+# reinforcement learning related - ppo
+
+def actor_loss(
+    logits, # Float[b l]
+    old_log_probs, # Float[b]
+    actions, # Int[b]
+    advantages, # Float[b]
+    eps_clip = 0.2,
+    entropy_weight = .01,
+):
+    log_probs = gather_log_prob(logits, actions)
+
+    ratio = (log_probs - old_log_probs).exp()
+
+    # classic clipped surrogate loss from ppo
+
+    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
+
+    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
+
+    # add entropy loss for exploration
+
+    entropy = calc_entropy(logits)
+
+    entropy_aux_loss = -entropy_weight * entropy
+
+    return actor_loss + entropy_aux_loss
+
+def critic_loss(
+    pred_values, # Float[b]
+    advantages, # Float[b]
+    old_values # Float[b]
+):
+    discounted_values = advantages + old_values
+    return F.mse_loss(pred_values, discounted_values)
+
+# agent contains the actor, critic, and the latent genetic pool
+
+def create_agent(
+    dim_state,
+    num_latents,
+    dim_latent,
+    actor_num_actions,
+    actor_dim_hiddens: int | tuple[int, ...],
+    critic_dim_hiddens: int | tuple[int, ...],
+    num_latent_sets = 1
+) -> Agent:
+
+    actor = Actor(
+        num_actions = actor_num_actions,
+        dim_state = dim_state,
+        dim_latent = dim_latent,
+        dim_hiddens = actor_dim_hiddens
+    )
+
+    critic = Critic(
+        dim_state = dim_state,
+        dim_latent = dim_latent,
+        dim_hiddens = critic_dim_hiddens
+    )
+
+    latent_gene_pool = LatentGenePool(
+        dim_state = dim_state,
+        num_latents = num_latents,
+        dim_latent = dim_latent,
+        num_latent_sets = num_latent_sets
+    )
+
+    return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
+
 # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
 # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
 
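
One thing to note about the relocated `actor_loss`: its body still references `advantage` while the parameter is named `advantages` (the same mismatch is visible in the removed copy earlier in the diff), so calling it as written would presumably raise a `NameError`. The underlying objective is the standard PPO clipped surrogate; below is a corrected standalone sketch of just that term, under an illustrative name rather than the package's:

```python
import torch

# standalone PPO clipped surrogate term with the advantage name fixed;
# `clipped_surrogate` is an illustrative name, not part of the package
def clipped_surrogate(log_probs, old_log_probs, advantages, eps_clip = 0.2):
    ratio = (log_probs - old_log_probs).exp()
    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)

    # pessimistic (elementwise minimum) of clipped and unclipped surrogate, negated for gradient descent
    return -torch.min(clipped_ratio * advantages, ratio * advantages)
```
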
@@ -593,13 +610,10 @@ class EPO(Module):
 
     def __init__(
         self,
-        agent: Agent,
-        latent_gene_pool: LatentGenePool
+        agent: Agent
     ):
         super().__init__()
-
         self.agent = agent
-        self.latent_gene_pool = latent_gene_pool
 
     def forward(
         self,
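
The `EPO` wrapper now takes only the agent, since the agent already carries the latent gene pool; construction is simply `epo = EPO(agent)`. As the comments above the class point out, the remaining bookkeeping is to remember which latent id conditioned each trajectory. A purely hypothetical shape for that rollout record (none of these names exist in the package, though `namedtuple` is already imported at the top of `epo.py`):

```python
from collections import namedtuple

# hypothetical per-step rollout record: the usual PPO fields plus the latent id
# that conditioned the episode, so updates can be routed back to the right gene
Memory = namedtuple('Memory', [
    'latent_id',
    'state',
    'action',
    'log_prob',
    'reward',
    'value',
    'done'
])

memories: list[Memory] = []  # appended to once per environment step during a rollout
```
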
evolutionary_policy_optimization-{0.0.12 → 0.0.15}.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.12
+Version: 0.0.15
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -95,6 +95,7 @@ value = critic(state, latent)
 
 fitness = torch.randn(128)
 
+latent_pool.genetic_algorithm_step(fitness) # update latent genes with genetic algorithm
 ```
 
 ## Citations
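
The new README line is the call that actually evolves the population. A rough sketch of the outer loop it implies, continuing from the README snippet's `latent_pool`; the random fitness is a placeholder (as in the README) for rolling out episodes with each latent and scoring them, and only `genetic_algorithm_step(fitness)` is the package's real API here:

```python
import torch

# rough outer loop implied by the README snippet above; in practice the random fitness
# stand-in would be replaced by per-latent episode returns
for generation in range(10):
    fitness = torch.randn(128)  # placeholder fitness, one scalar per latent

    latent_pool.genetic_algorithm_step(fitness)  # update latent genes with the genetic algorithm
```
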
evolutionary_policy_optimization-0.0.15.dist-info/RECORD (added)

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=f_e-TkJRFF1VHG3psJDgLGNIzlEvDSjX0nOsbLaOBrw,17543
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.15.dist-info/METADATA,sha256=dGjFvEBt10Ac9PhsYSnluSbw_BixLTA_pEekW7TPF3U,4446
+evolutionary_policy_optimization-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.15.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.12.dist-info/RECORD (removed)

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=dX7UANZk4RI_rwoSPi8GnZDoF9H1EdVT6Z7WA30cZ3Q,16934
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.12.dist-info/METADATA,sha256=Uyee4UOgNs04nrJlU6m7drQFTUuka3s8nLhblHxfiEg,4357
-evolutionary_policy_optimization-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.12.dist-info/RECORD,,