evolutionary-policy-optimization 0.0.29-py3-none-any.whl → 0.0.32-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from pathlib import Path
 from collections import namedtuple
 
 import torch
@@ -244,6 +245,13 @@ class Critic(Module):
         dim_state,
         dim_hiddens: tuple[int, ...],
         dim_latent = 0,
+        use_regression = False,
+        hl_gauss_loss_kwargs: dict = dict(
+            min_value = -10.,
+            max_value = 10.,
+            num_bins = 25,
+            sigma = 0.5
+        )
     ):
         super().__init__()
 
@@ -259,23 +267,28 @@ class Critic(Module):
 
         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
 
-        self.to_out = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(dim_last, 1),
-            Rearrange('... 1 -> ...')
+        self.final_act = nn.SiLU()
+
+        self.to_pred = HLGaussLayer(
+            dim = dim_last,
+            use_regression = use_regression,
+            hl_gauss_loss = hl_gauss_loss_kwargs
         )
 
     def forward(
         self,
         state,
-        latent
+        latent,
+        target = None
     ):
 
         hidden = self.init_layer(state)
 
         hidden = self.mlp(hidden, latent)
 
-        return self.to_out(hidden)
+        hidden = self.final_act(hidden)
+
+        return self.to_pred(hidden, target = target)
 
 # criteria for running genetic algorithm
 
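The critic's scalar value head (SiLU → Linear → squeeze) is replaced by an HLGaussLayer, so the value can be learned either as plain regression (use_regression = True) or as an HL-Gauss histogram classification over num_bins bins spanning [min_value, max_value] with smoothing sigma, and passing target into forward routes loss computation through that layer. A minimal usage sketch follows; the shapes, any constructor arguments not visible in this hunk, and the exact return convention of HLGaussLayer when a target is given are assumptions, not something this diff shows:

    # hedged sketch: only the constructor arguments visible in the hunk are used;
    # shapes and return values are illustrative assumptions

    critic = Critic(
        dim_state = 32,
        dim_hiddens = (32, 64, 64),
        dim_latent = 16,
        use_regression = False    # keep the HL-Gauss classification head
    )

    state = torch.randn(8, 32)
    latent = torch.randn(8, 16)

    value = critic(state, latent)                            # no target: value prediction
    loss = critic(state, latent, target = torch.randn(8))    # target given: loss from HLGaussLayer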
@@ -595,6 +608,7 @@ class Agent(Module):
         self.actor = actor
         self.critic = critic
 
+        self.num_latents = latent_gene_pool.num_latents
         self.latent_gene_pool = latent_gene_pool
 
         assert actor.dim_latent == critic.dim_latent == latent_gene_pool.dim_latent
@@ -606,6 +620,39 @@ class Agent(Module):
 
         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
 
+    def save(self, path, overwrite = False):
+        path = Path(path)
+
+        assert not path.exists() or overwrite
+
+        pkg = dict(
+            actor = self.actor.state_dict(),
+            critic = self.critic.state_dict(),
+            latents = self.latent_gene_pool.state_dict(),
+            actor_optim = self.actor_optim.state_dict(),
+            critic_optim = self.critic_optim.state_dict(),
+            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+        )
+
+        torch.save(pkg, str(path))
+
+    def load(self, path):
+        path = Path(path)
+
+        assert path.exists()
+
+        pkg = torch.load(str(path), weights_only = True)
+
+        self.actor.load_state_dict(pkg['actor'])
+        self.critic.load_state_dict(pkg['critic'])
+        self.latent_gene_pool.load_state_dict(pkg['latents'])
+
+        self.actor_optim.load_state_dict(pkg['actor_optim'])
+        self.critic_optim.load_state_dict(pkg['critic_optim'])
+
+        if exists(pkg.get('latent_optim', None)):
+            self.latent_optim.load_state_dict(pkg['latent_optim'])
+
     def get_actor_actions(
         self,
         state,
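Agent gains simple checkpointing: save bundles the actor, critic, and latent gene pool state dicts plus the optimizer states into one file via torch.save (refusing to overwrite an existing path unless overwrite = True), and load restores them with torch.load(..., weights_only = True). A minimal usage sketch, assuming agent is an already constructed Agent and the path is arbitrary:

    # hedged sketch: `agent` is assumed to be an existing Agent instance
    agent.save('./agent.pt', overwrite = True)

    # later, on an Agent constructed with the same hyperparameters
    agent.load('./agent.pt')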
@@ -630,8 +677,10 @@ class Agent(Module):
 
     def forward(
         self,
-        memories: list[Memory]
+        memories_and_next_value: MemoriesAndNextValue
     ):
+        memories, next_value = memories_and_next_value
+
         raise NotImplementedError
 
 # reinforcement learning related - ppo
@@ -715,6 +764,11 @@ Memory = namedtuple('Memory', [
    'done'
 ])
 
+MemoriesAndNextValue = namedtuple('MemoriesAndNextValue', [
+    'memories',
+    'next_value'
+])
+
 class EPO(Module):
 
     def __init__(
@@ -727,6 +781,6 @@ class EPO(Module):
     def forward(
         self,
         env
-    ) -> list[Memory]:
+    ) -> MemoriesAndNextValue:
 
         raise NotImplementedError
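Rollouts now return the collected memories together with a bootstrap value for the final next state, bundled in the new MemoriesAndNextValue namedtuple; Agent.forward unpacks it before the (still unimplemented) learning step, and EPO.forward is annotated to return it. A minimal sketch of constructing and unpacking the tuple; the Memory fields other than 'done' and the actual rollout logic are not shown in this diff, so rollout_memories and bootstrap_value below are stand-ins:

    # rollout_memories: list[Memory] gathered from the environment (stand-in)
    # bootstrap_value: critic value estimate for the state after the last step (stand-in)

    mem_and_next = MemoriesAndNextValue(
        memories = rollout_memories,
        next_value = bootstrap_value
    )

    memories, next_value = mem_and_next   # same unpacking Agent.forward performs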
--- /dev/null
+++ evolutionary_policy_optimization/mock_env.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import torch
+from torch import tensor, randn, randint, zeros
+from torch.nn import Module
+
+# mock env
+
+class Env(Module):
+    def __init__(
+        self,
+        state_shape: tuple[int, ...]
+    ):
+        super().__init__()
+        self.state_shape = state_shape
+        self.register_buffer('dummy', tensor(0))
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def reset(
+        self
+    ):
+        state = randn(self.state_shape, device = self.device)
+        return state
+
+    def forward(
+        self,
+        actions,
+    ):
+        state = randn(self.state_shape, device = self.device)
+        reward = randint(0, 5, (), device = self.device).float()
+        done = zeros((), device = self.device, dtype = torch.bool)
+
+        return state, reward, done
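The new mock environment produces random states and rewards and never signals done, which is enough to smoke-test the rollout plumbing without a real environment. A minimal usage sketch (the module path matches the RECORD entry below; the action passed in is arbitrary, since the mock ignores it):

    from evolutionary_policy_optimization.mock_env import Env

    env = Env((5,))                          # state_shape = (5,)

    state = env.reset()                      # random initial state of shape (5,)
    next_state, reward, done = env(None)     # action is ignored by the mock

    assert next_state.shape == (5,)
    assert not done.item()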
--- evolutionary_policy_optimization-0.0.29.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.32.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.29
+Version: 0.0.32
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- /dev/null
+++ evolutionary_policy_optimization-0.0.32.dist-info/RECORD
@@ -0,0 +1,8 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=MUcCJLE9cNZS84m5Dhl9qD2ygptvJSuDe6ElwardtgA,23525
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization/mock_env.py,sha256=3xrd-gwjZeVd_sEvxIyX0lppnMWcfQGOapO-XjKmExI,816
+evolutionary_policy_optimization-0.0.32.dist-info/METADATA,sha256=NfF4ogDZA7ea4vLWHO_rl1ixapXuKIBeuy7tKzEFCTY,4958
+evolutionary_policy_optimization-0.0.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.32.dist-info/RECORD,,

--- evolutionary_policy_optimization-0.0.29.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=4iuro11yTpRNzFfSoRZARnOiTDIJYndWmVaUAqk3--E,21826
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.29.dist-info/METADATA,sha256=C4gxOaspzHqA7TN5iQ8cDIFkg8llS8kg4y5Xg_ke2Qc,4958
-evolutionary_policy_optimization-0.0.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.29.dist-info/RECORD,,