evolutionary-policy-optimization 0.0.28__tar.gz → 0.0.31__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (13)
  1. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/epo.py +68 -10
  3. evolutionary_policy_optimization-0.0.31/evolutionary_policy_optimization/mock_env.py +36 -0
  4. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/pyproject.toml +1 -1
  5. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/tests/test_epo.py +5 -0
  6. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/README.md +0 -0
  11. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/__init__.py +0 -0
  12. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/experimental.py +0 -0
  13. {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/requirements.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.28
+Version: 0.0.31
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from pathlib import Path
 from collections import namedtuple

 import torch
@@ -244,6 +245,13 @@ class Critic(Module):
         dim_state,
         dim_hiddens: tuple[int, ...],
         dim_latent = 0,
+        use_regression = False,
+        hl_gauss_loss_kwargs: dict = dict(
+            min_value = -10.,
+            max_value = 10.,
+            num_bins = 25,
+            sigma = 0.5
+        )
     ):
         super().__init__()

@@ -259,23 +267,28 @@ class Critic(Module):

         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)

-        self.to_out = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(dim_last, 1),
-            Rearrange('... 1 -> ...')
+        self.final_act = nn.SiLU()
+
+        self.to_pred = HLGaussLayer(
+            dim = dim_last,
+            use_regression = False,
+            hl_gauss_loss = hl_gauss_loss_kwargs
         )

     def forward(
         self,
         state,
-        latent
+        latent,
+        target = None
     ):

         hidden = self.init_layer(state)

         hidden = self.mlp(hidden, latent)

-        return self.to_out(hidden)
+        hidden = self.final_act(hidden)
+
+        return self.to_pred(hidden, target = target)

 # criteria for running genetic algorithm

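The critic's scalar regression head is replaced by an HLGaussLayer, so value prediction becomes a classification over binned values (the HL-Gauss approach), and the forward pass now also accepts a target. A minimal usage sketch, assuming the layer's usual behavior of returning a loss when a target is supplied and a scalar prediction otherwise; the dimensions and batch size below are arbitrary choices, not from the package:

import torch
from evolutionary_policy_optimization.epo import Critic

critic = Critic(
    dim_state = 512,
    dim_hiddens = (256, 128),
    dim_latent = 32
)

state   = torch.randn(2, 512)
latent  = torch.randn(2, 32)
returns = torch.randn(2).clamp(-10., 10.)         # targets should lie inside [min_value, max_value]

value = critic(state, latent)                     # no target -> expected value per sample
loss  = critic(state, latent, target = returns)   # target given -> HL-Gauss classification loss
loss.backward()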
@@ -508,8 +521,12 @@ class LatentGenePool(Module):
         randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)

         migrate_mask = randperm < self.num_migrate
-        maybe_migrated_genes = torch.roll(genes, 1, dims = 0)
-        genes = einx.where('i p, i p g, i p g', migrate_mask, maybe_migrated_genes, genes)
+
+        nonmigrants = rearrange(genes[~migrate_mask], '(i p) g -> i p g', i = islands)
+        migrants = rearrange(genes[migrate_mask], '(i p) g -> i p g', i = islands)
+        migrants = torch.roll(migrants, 1, dims = 0)
+
+        genes = cat((nonmigrants, migrants), dim = 1)

         # add back the elites

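Island migration no longer rolls whole populations through einx.where; instead exactly num_migrate genes per island are gathered, rolled one island forward, and concatenated back onto the non-migrants. A self-contained sketch of that step with made-up sizes (every name and shape below is illustrative):

import torch
from torch import cat
from einops import rearrange

islands, pop, gene_dim, num_migrate = 4, 8, 16, 2

genes = torch.randn(islands, pop, gene_dim)

# choose exactly num_migrate members per island
randperm = torch.randn(genes.shape[:-1]).argsort(dim = -1)
migrate_mask = randperm < num_migrate

# split migrants from stay-at-home genes, island by island
nonmigrants = rearrange(genes[~migrate_mask], '(i p) g -> i p g', i = islands)
migrants    = rearrange(genes[migrate_mask], '(i p) g -> i p g', i = islands)

# ring topology: island i receives the migrants of island i - 1
migrants = torch.roll(migrants, 1, dims = 0)

genes = cat((nonmigrants, migrants), dim = 1)
assert genes.shape == (islands, pop, gene_dim)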
@@ -591,6 +608,7 @@ class Agent(Module):
         self.actor = actor
         self.critic = critic

+        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool

         assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
@@ -602,6 +620,39 @@ class Agent(Module):

         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None

+    def save(self, path, overwrite = False):
+        path = Path(path)
+
+        assert not path.exists() or overwrite
+
+        pkg = dict(
+            actor = self.actor.state_dict(),
+            critic = self.critic.state_dict(),
+            latents = self.latent_gene_pool.state_dict(),
+            actor_optim = self.actor_optim.state_dict(),
+            critic_optim = self.critic_optim.state_dict(),
+            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+        )
+
+        torch.save(pkg, str(path))
+
+    def load(self, path):
+        path = Path(path)
+
+        assert path.exists()
+
+        pkg = torch.load(str(path), weights_only = True)
+
+        self.actor.load_state_dict(pkg['actor'])
+        self.critic.load_state_dict(pkg['critic'])
+        self.latent_gene_pool.load_state_dict(pkg['latents'])
+
+        self.actor_optim.load_state_dict(pkg['actor_optim'])
+        self.critic_optim.load_state_dict(pkg['critic_optim'])
+
+        if exists(pkg.get('latent_optim', None)):
+            self.latent_optim.load_state_dict(pkg['latent_optim'])
+
     def get_actor_actions(
         self,
         state,
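Agent gains save and load: a single torch.save package holding actor, critic, latent pool, and optimizer states, written only if the path is new or overwrite = True, and read back with weights_only = True. A brief hedged usage sketch; the file name and the two agent instances below are assumptions, not from this diff:

# `agent` and `fresh_agent` are assumed to be two Agent instances built with the same configuration
agent.save('./agent.pt', overwrite = True)   # one torch.save package: actor, critic, latents, optimizer states
fresh_agent.load('./agent.pt')               # asserts the file exists, then restores modules and optimizers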
@@ -626,8 +677,10 @@ class Agent(Module):

     def forward(
         self,
-        memories: list[Memory]
+        memories_and_next_value: MemoriesAndNextValue
     ):
+        memories, next_value = memories_and_next_value
+
         raise NotImplementedError

 # reinforcement learning related - ppo
@@ -711,6 +764,11 @@ Memory = namedtuple('Memory', [
     'done'
 ])

+MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+    'memories',
+    'next_value'
+])
+
 class EPO(Module):

     def __init__(
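The rollout contract changes accordingly: EPO.forward now hands back a MemoriesAndNextValue bundle rather than a bare list of Memory, pairing the collected memories with a bootstrap value for the final state, and Agent.forward unpacks it first thing. A toy construction of the bundle, using only the namedtuple above plus a placeholder next value:

import torch
from collections import namedtuple

MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', ['memories', 'next_value'])

memories = []                    # would hold the Memory tuples gathered during a rollout
next_value = torch.zeros(())     # placeholder bootstrap value for the state after the final step

rollout = MemoriesAndNextValue(memories, next_value)
memories, next_value = rollout   # Agent.forward now unpacks the bundle the same way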
@@ -723,6 +781,6 @@ class EPO(Module):
     def forward(
         self,
         env
-    ) -> list[Memory]:
+    ) -> MemoriesAndNextValue:

         raise NotImplementedError
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import torch
+from torch import tensor, randn, randint, zeros
+from torch.nn import Module
+
+# mock env
+
+class Env(Module):
+    def __init__(
+        self,
+        state_shape: tuple[int, ...]
+    ):
+        super().__init__()
+        self.state_shape = state_shape
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def reset(
+        self
+    ):
+        state = randn(self.state_shape, device = self.device)
+        return state
+
+    def forward(
+        self,
+        actions,
+    ):
+        state = randn(self.state_shape, device = self.device)
+        reward = randint(0, 5, (), device = self.device).float()
+        done = zeros((), device = self.device, dtype = torch.bool)
+
+        return state, reward, done
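The new mock environment lets the rollout loop run without a real simulator: reset returns a random state of the configured shape, and stepping ignores the given actions, returning a fresh random state, an integer reward in [0, 5), and done = False. A usage sketch (the state shape and the action tensor are arbitrary choices):

import torch
from evolutionary_policy_optimization.mock_env import Env

env = Env((512,))                       # states will have shape (512,)

state = env.reset()

for _ in range(4):
    action = torch.randint(0, 5, ())    # the mock env ignores the action, any tensor works
    state, reward, done = env(action)   # fresh random state, reward in [0, 5), done always False
    if done.item():
        break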
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.0.28"
+version = "0.0.31"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -68,3 +68,8 @@ def test_create_agent(
    fitness = torch.randn(128)

    agent.update_latent_gene_pool_(fitness) # update once
+
+    # saving and loading
+
+    agent.save('./agent.pt', overwrite = True)
+    agent.load('./agent.pt')