evolutionary-policy-optimization 0.0.28__tar.gz → 0.0.31__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/PKG-INFO +1 -1
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/epo.py +68 -10
- evolutionary_policy_optimization-0.0.31/evolutionary_policy_optimization/mock_env.py +36 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/pyproject.toml +1 -1
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/tests/test_epo.py +5 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/README.md +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/experimental.py +0 -0
- {evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/requirements.txt +0 -0
{evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.28
+Version: 0.0.31
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
{evolutionary_policy_optimization-0.0.28 → evolutionary_policy_optimization-0.0.31}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from pathlib import Path
 from collections import namedtuple
 
 import torch
@@ -244,6 +245,13 @@ class Critic(Module):
         dim_state,
         dim_hiddens: tuple[int, ...],
         dim_latent = 0,
+        use_regression = False,
+        hl_gauss_loss_kwargs: dict = dict(
+            min_value = -10.,
+            max_value = 10.,
+            num_bins = 25,
+            sigma = 0.5
+        )
     ):
         super().__init__()
 
@@ -259,23 +267,28 @@ class Critic(Module):
 
         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
 
-        self.
-
-
-
+        self.final_act = nn.SiLU()
+
+        self.to_pred = HLGaussLayer(
+            dim = dim_last,
+            use_regression = False,
+            hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
     def forward(
         self,
         state,
-        latent
+        latent,
+        target = None
     ):
 
         hidden = self.init_layer(state)
 
         hidden = self.mlp(hidden, latent)
 
-
+        hidden = self.final_act(hidden)
+
+        return self.to_pred(hidden, target = target)
 
 # criteria for running genetic algorithm
 
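The critic's value head now runs through an HLGaussLayer: when a target is supplied it returns a classification-style loss over binned values, otherwise a value prediction. Below is a minimal, self-contained sketch of the underlying HL-Gauss idea, written independently of the HLGaussLayer API; the bin settings mirror hl_gauss_loss_kwargs above, while value_head, target_to_probs and the other helper names are illustrative only.

import torch
import torch.nn.functional as F
from torch import nn

# bin settings taken from hl_gauss_loss_kwargs in the diff above
min_value, max_value, num_bins, sigma = -10., 10., 25, 0.5

support = torch.linspace(min_value, max_value, num_bins + 1)  # bin edges

def target_to_probs(target):
    # smear each scalar target over the bins with a Gaussian (mass of N(target, sigma) per bin)
    cdf = torch.erf((support - target[:, None]) / (sigma * 2 ** 0.5))
    return (cdf[:, 1:] - cdf[:, :-1]) / 2

def hl_gauss_loss(logits, target):
    # cross entropy against the soft histogram, instead of MSE regression on the scalar
    return F.cross_entropy(logits, target_to_probs(target))

def logits_to_value(logits):
    # expected value under the predicted bin distribution
    centers = (support[:-1] + support[1:]) / 2
    return (logits.softmax(dim = -1) * centers).sum(dim = -1)

# illustrative value head producing one logit per bin
value_head = nn.Linear(64, num_bins)

hidden = torch.randn(4, 64)
logits = value_head(hidden)

loss = hl_gauss_loss(logits, torch.tensor([1., -2., 0.5, 3.]))   # training path (target given)
value = logits_to_value(logits)                                  # inference path (no target)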
@@ -508,8 +521,12 @@ class LatentGenePool(Module):
         randperm = torch.randn(genes.shape[:-1], device = device).argsort(dim = -1)
 
         migrate_mask = randperm < self.num_migrate
-
-
+
+        nonmigrants = rearrange(genes[~migrate_mask], '(i p) g -> i p g', i = islands)
+        migrants = rearrange(genes[migrate_mask], '(i p) g -> i p g', i = islands)
+        migrants = torch.roll(migrants, 1, dims = 0)
+
+        genes = cat((nonmigrants, migrants), dim = 1)
 
         # add back the elites
 
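The migration step now splits each island's population into migrants and non-migrants with einops, rolls the migrant groups one island forward, and concatenates them back. Below is a standalone illustration of that ring migration, assuming genes is shaped (islands, pop, dim) with an equal number of migrants chosen per island; the sizes are made up.

import torch
from einops import rearrange

islands, pop, dim, num_migrate = 4, 6, 8, 2   # hypothetical sizes

genes = torch.randn(islands, pop, dim)

# pick num_migrate migrants per island, as the randperm / migrate_mask lines above do
randperm = torch.randn(islands, pop).argsort(dim = -1)
migrate_mask = randperm < num_migrate

nonmigrants = rearrange(genes[~migrate_mask], '(i p) g -> i p g', i = islands)
migrants = rearrange(genes[migrate_mask], '(i p) g -> i p g', i = islands)

# shift the migrant groups by one island, so island i receives island (i - 1)'s migrants
migrants = torch.roll(migrants, 1, dims = 0)

genes = torch.cat((nonmigrants, migrants), dim = 1)

assert genes.shape == (islands, pop, dim)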
@@ -591,6 +608,7 @@ class Agent(Module):
         self.actor = actor
         self.critic = critic
 
+        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
 
         assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
@@ -602,6 +620,39 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
+    def save(self, path, overwrite = False):
+        path = Path(path)
+
+        assert not path.exists() or overwrite
+
+        pkg = dict(
+            actor = self.actor.state_dict(),
+            critic = self.critic.state_dict(),
+            latents = self.latent_gene_pool.state_dict(),
+            actor_optim = self.actor_optim.state_dict(),
+            critic_optim = self.critic_optim.state_dict(),
+            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+        )
+
+        torch.save(pkg, str(path))
+
+    def load(self, path):
+        path = Path(path)
+
+        assert path.exists()
+
+        pkg = torch.load(str(path), weights_only = True)
+
+        self.actor.load_state_dict(pkg['actor'])
+        self.critic.load_state_dict(pkg['critic'])
+        self.latent_gene_pool.load_state_dict(pkg['latents'])
+
+        self.actor_optim.load_state_dict(pkg['actor_optim'])
+        self.critic_optim.load_state_dict(pkg['critic_optim'])
+
+        if exists(pkg.get('latent_optim', None)):
+            self.latent_optim.load_state_dict(pkg['latent_optim'])
+
     def get_actor_actions(
         self,
         state,
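The new save and load methods follow the common PyTorch checkpoint pattern: a flat dict of state_dicts written with torch.save and restored with weights_only = True. A minimal standalone sketch of that pattern, with a stand-in module and optimizer in place of the actor, critic and latent gene pool:

import torch
from torch import nn
from pathlib import Path

# stand-ins for the actor / critic / latent gene pool and their optimizers
model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters())

path = Path('./checkpoint.pt')

pkg = dict(
    model = model.state_dict(),
    optim = optim.state_dict()
)

torch.save(pkg, str(path))

# restore without unpickling arbitrary objects
pkg = torch.load(str(path), weights_only = True)

model.load_state_dict(pkg['model'])
optim.load_state_dict(pkg['optim'])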
@@ -626,8 +677,10 @@ class Agent(Module):
 
     def forward(
         self,
-
+        memories_and_next_value: MemoriesAndNextValue
     ):
+        memories, next_value = memories_and_next_value
+
         raise NotImplementedError
 
 # reinforcement learning related - ppo
@@ -711,6 +764,11 @@ Memory = namedtuple('Memory', [
     'done'
 ])
 
+MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+    'memories',
+    'next_value'
+])
+
 class EPO(Module):
 
     def __init__(
@@ -723,6 +781,6 @@ class EPO(Module):
     def forward(
         self,
         env
-    ) ->
+    ) -> MemoriesAndNextValue:
 
         raise NotImplementedError
evolutionary_policy_optimization-0.0.31/evolutionary_policy_optimization/mock_env.py
ADDED
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import torch
+from torch import tensor, randn, randint
+from torch.nn import Module
+
+# mock env
+
+class Env(Module):
+    def __init__(
+        self,
+        state_shape: tuple[int, ...]
+    ):
+        super().__init__()
+        self.state_shape = state_shape
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def reset(
+        self
+    ):
+        state = randn(self.state_shape, device = self.device)
+        return state
+
+    def forward(
+        self,
+        actions,
+    ):
+        state = randn(self.state_shape, device = self.device)
+        reward = randint(0, 5, (), device = self.device).float()
+        done = zeros((), device = self.device, dtype = torch.bool)
+
+        return state, reward, done