evolutionary-policy-optimization 0.0.29__py3-none-any.whl → 0.0.31__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- evolutionary_policy_optimization/epo.py +62 -8
- evolutionary_policy_optimization/mock_env.py +36 -0
- {evolutionary_policy_optimization-0.0.29.dist-info → evolutionary_policy_optimization-0.0.31.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.31.dist-info/RECORD +8 -0
- evolutionary_policy_optimization-0.0.29.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.29.dist-info → evolutionary_policy_optimization-0.0.31.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.29.dist-info → evolutionary_policy_optimization-0.0.31.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -1,5 +1,6 @@
 from __future__ import annotations

+from pathlib import Path
 from collections import namedtuple

 import torch
@@ -244,6 +245,13 @@ class Critic(Module):
         dim_state,
         dim_hiddens: tuple[int, ...],
         dim_latent = 0,
+        use_regression = False,
+        hl_gauss_loss_kwargs: dict = dict(
+            min_value = -10.,
+            max_value = 10.,
+            num_bins = 25,
+            sigma = 0.5
+        )
     ):
         super().__init__()

@@ -259,23 +267,28 @@ class Critic(Module):

         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)

-        self.
-
-
-
+        self.final_act = nn.SiLU()
+
+        self.to_pred = HLGaussLayer(
+            dim = dim_last,
+            use_regression = False,
+            hl_gauss_loss = hl_gauss_loss_kwargs
         )

     def forward(
         self,
         state,
-        latent
+        latent,
+        target = None
     ):

         hidden = self.init_layer(state)

         hidden = self.mlp(hidden, latent)

-
+        hidden = self.final_act(hidden)
+
+        return self.to_pred(hidden, target = target)

 # criteria for running genetic algorithm

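The critic's value head now goes through an HL-Gauss classification layer instead of a plain scalar regression. A minimal sketch of the new pattern, assuming the HLGaussLayer import path of the hl-gauss-pytorch package (the diff only shows the class name) and an illustrative hidden width:

import torch
from torch import nn
from hl_gauss_pytorch import HLGaussLayer  # assumed import path; not shown in this diff

dim_last = 64  # illustrative width, standing in for the critic MLP's last dimension

to_pred = HLGaussLayer(
    dim = dim_last,
    use_regression = False,
    hl_gauss_loss = dict(min_value = -10., max_value = 10., num_bins = 25, sigma = 0.5)
)

hidden = nn.SiLU()(torch.randn(2, dim_last))

value = to_pred(hidden)                           # no target: predicted value per row
loss  = to_pred(hidden, target = torch.randn(2))  # target given: HL-Gauss loss, mirroring Critic.forward above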
@@ -595,6 +608,7 @@ class Agent(Module):
         self.actor = actor
         self.critic = critic

+        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool

         assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
@@ -606,6 +620,39 @@ class Agent(Module):

         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None

+    def save(self, path, overwrite = False):
+        path = Path(path)
+
+        assert not path.exists() or overwrite
+
+        pkg = dict(
+            actor = self.actor.state_dict(),
+            critic = self.critic.state_dict(),
+            latents = self.latent_gene_pool.state_dict(),
+            actor_optim = self.actor_optim.state_dict(),
+            critic_optim = self.critic_optim.state_dict(),
+            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+        )
+
+        torch.save(pkg, str(path))
+
+    def load(self, path):
+        path = Path(path)
+
+        assert path.exists()
+
+        pkg = torch.load(str(path), weights_only = True)
+
+        self.actor.load_state_dict(pkg['actor'])
+        self.critic.load_state_dict(pkg['critic'])
+        self.latent_gene_pool.load_state_dict(pkg['latents'])
+
+        self.actor_optim.load_state_dict(pkg['actor_optim'])
+        self.critic_optim.load_state_dict(pkg['critic_optim'])
+
+        if exists(pkg.get('latent_optim', None)):
+            self.latent_optim.load_state_dict(pkg['latent_optim'])
+
     def get_actor_actions(
         self,
         state,
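Intended usage of the new checkpointing methods, sketched under the assumption that `agent` is an already constructed Agent instance (construction is not shown in this diff):

# `agent` is a hypothetical, already-constructed Agent instance
agent.save('./agent-checkpoint.pt', overwrite = True)  # bundles actor, critic, latent pool and optimizer states into one file
agent.load('./agent-checkpoint.pt')                    # restores them; loading uses weights_only = True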
@@ -630,8 +677,10 @@ class Agent(Module):

     def forward(
         self,
-
+        memories_and_next_value: MemoriesAndNextValue
     ):
+        memories, next_value = memories_and_next_value
+
         raise NotImplementedError

 # reinforcement learning related - ppo
@@ -715,6 +764,11 @@ Memory = namedtuple('Memory', [
     'done'
 ])

+MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+    'memories',
+    'next_value'
+])
+
 class EPO(Module):

     def __init__(
@@ -727,6 +781,6 @@ class EPO(Module):
     def forward(
         self,
         env
-    ) ->
+    ) -> MemoriesAndNextValue:

         raise NotImplementedError
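The new MemoriesAndNextValue container ties the two abstract interfaces together: EPO.forward is now annotated to return it, and Agent.forward unpacks it. An illustrative round trip with placeholder contents (the full Memory fields are not shown in this diff):

from collections import namedtuple

MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', ['memories', 'next_value'])

rollout = MemoriesAndNextValue(memories = [], next_value = 0.)  # placeholder rollout
memories, next_value = rollout                                  # same unpacking as Agent.forward above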
evolutionary_policy_optimization/mock_env.py

@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import torch
+from torch import tensor, randn, randint, zeros
+from torch.nn import Module
+
+# mock env
+
+class Env(Module):
+    def __init__(
+        self,
+        state_shape: tuple[int, ...]
+    ):
+        super().__init__()
+        self.state_shape = state_shape
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def reset(
+        self
+    ):
+        state = randn(self.state_shape, device = self.device)
+        return state
+
+    def forward(
+        self,
+        actions,
+    ):
+        state = randn(self.state_shape, device = self.device)
+        reward = randint(0, 5, (), device = self.device).float()
+        done = zeros((), device = self.device, dtype = torch.bool)
+
+        return state, reward, done
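A quick usage sketch of the new mock environment, with an illustrative state shape:

import torch
from evolutionary_policy_optimization.mock_env import Env

env = Env((5,))                    # mock environment emitting random 5-dim states
state = env.reset()

action = torch.tensor(0)           # contents are ignored by the mock env
state, reward, done = env(action)  # random next state, integer reward in [0, 5), done = False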
{evolutionary_policy_optimization-0.0.29.dist-info → evolutionary_policy_optimization-0.0.31.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.29
+Version: 0.0.31
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.31.dist-info/RECORD

@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=jjlCgB-oyowMtlRUWg32YGAh7-yN97yl0qCq-Lah4lE,23516
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
+evolutionary_policy_optimization-0.0.31.dist-info/METADATA,sha256=a8VwzC6q__7tmPTb35bn86zaJfvpnzdHcyG_mioUEQs,4958
+evolutionary_policy_optimization-0.0.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.31.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.29.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=4iuro11yTpRNzFfSoRZARnOiTDIJYndWmVaUAqk3--E,21826
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.29.dist-info/METADATA,sha256=C4gxOaspzHqA7TN5iQ8cDIFkg8llS8kg4y5Xg_ke2Qc,4958
-evolutionary_policy_optimization-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.29.dist-info/RECORD,,
File without changes