evolutionary-policy-optimization 0.0.64__py3-none-any.whl → 0.0.65__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -114,18 +114,17 @@ def get_fitness_scores(
  # generalized advantage estimate

  def calc_generalized_advantage_estimate(
- rewards, # Float[n]
- values, # Float[n+1]
- masks, # Bool[n]
+ rewards,
+ values,
+ masks,
  gamma = 0.99,
  lam = 0.95,
  use_accelerated = None
  ):
- assert values.shape[-1] == (rewards.shape[-1] + 1)
-
  use_accelerated = default(use_accelerated, rewards.is_cuda)
  device = rewards.device

+ values = F.pad(values, (0, 1), value = 0.)
  values, values_next = values[:-1], values[1:]

  delta = rewards + gamma * values_next * masks - values
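
Note on the hunk above: calc_generalized_advantage_estimate no longer expects values to carry an extra bootstrap entry of length n + 1; it right-pads values with a zero internally, so rewards, values, and masks all share length n. Below is a minimal reference sketch of the resulting computation, assuming 1-D tensors and using an explicit Python loop in place of the accelerated scan; the function name is illustrative and not part of the package.

import torch
import torch.nn.functional as F

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # pad with a zero bootstrap value so values lines up with rewards (length n -> n + 1)
    values = F.pad(values, (0, 1), value = 0.)
    values, values_next = values[:-1], values[1:]

    # one-step TD residuals
    delta = rewards + gamma * values_next * masks - values

    # discounted, lambda-weighted sum of residuals, accumulated right to left
    advantages = torch.zeros_like(rewards)
    gae = 0.
    for t in reversed(range(rewards.shape[-1])):
        gae = delta[..., t] + gamma * lam * masks[..., t] * gae
        advantages[..., t] = gae

    return advantages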
@@ -866,21 +865,16 @@ class Agent(Module):
  # generalized advantage estimate

  advantages = self.calc_gae(
- rewards[:-1],
+ rewards,
  values,
- masks[:-1],
+ masks,
  )

  # dataset and dataloader

  valid_episode = episode_ids >= 0

- dataset = TensorDataset(
- *[
- advantages[valid_episode[:-1]],
- *[t[valid_episode] for t in (states, latent_gene_ids, actions, log_probs, values)]
- ]
- )
+ dataset = TensorDataset(*[t[valid_episode] for t in (advantages, states, latent_gene_ids, actions, log_probs, values)])

  dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)

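With the padding now handled inside the GAE helper, advantages has the same length as rewards and masks, so the [:-1] slices and the special-cased advantages[valid_episode[:-1]] indexing above are no longer needed and every tensor can be filtered with one boolean mask. A small self-contained sketch of that pattern, with made-up shapes and only two tensors for brevity:

import torch
from torch.utils.data import TensorDataset, DataLoader

n = 8                                                     # hypothetical rollout length
advantages  = torch.randn(n)                              # same length as rewards / masks now
states      = torch.randn(n, 4)                           # made-up state dimension
episode_ids = torch.tensor([0, 0, 0, 0, 1, 1, -1, -1])    # -1 marks padding / invalid steps

valid_episode = episode_ids >= 0

# every tensor is filtered with the same mask, no off-by-one slicing required
dataset = TensorDataset(*[t[valid_episode] for t in (advantages, states)])
dataloader = DataLoader(dataset, batch_size = 4, shuffle = True)
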
@@ -1183,12 +1177,14 @@ class EPO(Module):

  # get the next state, action, and reward

- state, reward, done = env(action)
+ state, reward, truncated, terminated = env(action)
+
+ done = truncated or terminated

  # update cumulative rewards per latent, to be used as default fitness score

  rewards_per_latent_episode[latent_id, episode_id] += reward
-
+
  # store memories

  memory = Memory(
@@ -1199,14 +1195,14 @@ class EPO(Module):
  log_prob,
  reward,
  value,
- done
+ terminated
  )

  memories.append(memory)

  time += 1

- if not done:
+ if not terminated:
  # add bootstrap value if truncated

  next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
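
The two hunks above switch the rollout to a Gymnasium-style split between truncation and termination: the environment returns a 4-tuple, the loop exits on either flag, only terminated is stored in the memory, and a bootstrap value from the critic is added whenever the episode did not truly terminate. A minimal sketch of that logic with illustrative names (nothing below is the package's API):

def handle_step(state, reward, truncated, terminated, critic_value_fn):
    # the rollout loop ends on either flag
    done = truncated or terminated

    # only a true termination means "no future reward"; a truncated episode is
    # bootstrapped with the critic's estimate of the last state's value
    bootstrap_value = None
    if not terminated:
        bootstrap_value = critic_value_fn(state)

    return done, bootstrap_value
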
--- evolutionary_policy_optimization/mock_env.py
+++ evolutionary_policy_optimization/mock_env.py
@@ -1,4 +1,5 @@
  from __future__ import annotations
+ from random import choice

  import torch
  from torch import tensor, randn, randint
@@ -14,21 +15,25 @@ def cast_tuple(v):
  class Env(Module):
  def __init__(
  self,
- state_shape: int | tuple[int, ...]
+ state_shape: int | tuple[int, ...],
+ can_terminate_after = 2
  ):
  super().__init__()
  self.state_shape = cast_tuple(state_shape)
- self.register_buffer('dummy', tensor(0))
+
+ self.can_terminate_after = can_terminate_after
+ self.register_buffer('step', tensor(0))

  @property
  def device(self):
- return self.dummy.device
+ return self.step.device

  def reset(
  self,
  seed = None
  ):
  state = randn(self.state_shape, device = self.device)
+ self.step.zero_()
  return state

  def forward(
@@ -37,6 +42,13 @@ class Env(Module):
  ):
  state = randn(self.state_shape, device = self.device)
  reward = randint(0, 5, (), device = self.device).float()
- done = torch.zeros((), device = self.device, dtype = torch.bool)

- return state, reward, done
+ if self.step > self.can_terminate_after:
+ truncated = tensor(choice((True, False)), device = self.device)
+ terminated = tensor(choice((True, False)), device = self.device)
+ else:
+ truncated = terminated = tensor(False, device = self.device)
+
+ self.step.add_(1)
+
+ return state, reward, truncated, terminated
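
The mock environment now tracks a step buffer (zeroed in reset()) and, once can_terminate_after steps have elapsed, randomly raises truncated and/or terminated, so the new 4-tuple interface can be exercised end to end. A short usage sketch, assuming forward ignores the action's value and importing Env by the module path listed in RECORD:

import torch
from evolutionary_policy_optimization.mock_env import Env

env = Env((5,), can_terminate_after = 2)
state = env.reset()

for _ in range(10):
    action = torch.randint(0, 2, ())                       # dummy action (assumed to be ignored)
    state, reward, truncated, terminated = env(action)

    if truncated or terminated:
        state = env.reset()                                # zeroes the internal step counter
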
--- evolutionary_policy_optimization-0.0.64.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.65.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.64
+ Version: 0.0.65
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- /dev/null
+++ evolutionary_policy_optimization-0.0.65.dist-info/RECORD
@@ -0,0 +1,9 @@
+ evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+ evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+ evolutionary_policy_optimization/epo.py,sha256=IVyzLUH2h83_T2h8bloUM00q5GKuAHnzZu2QzVzTSDk,36912
+ evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+ evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
+ evolutionary_policy_optimization-0.0.65.dist-info/METADATA,sha256=2e4tKSTSxwYRCynvtkowFpIFyUkyh4oLXUyX8PhMZWg,6220
+ evolutionary_policy_optimization-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.65.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.65.dist-info/RECORD,,
--- evolutionary_policy_optimization-0.0.64.dist-info/RECORD
+++ /dev/null
@@ -1,9 +0,0 @@
- evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
- evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
- evolutionary_policy_optimization/epo.py,sha256=0_jC9Tbl6FiscLHklvTKtuQTwZL8egqFKW-4JUxxwvw,37001
- evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
- evolutionary_policy_optimization/mock_env.py,sha256=gvATGA51Ym5sf3jiR2VmlpjiCcT7KCDDY_SrR-MEwsU,941
- evolutionary_policy_optimization-0.0.64.dist-info/METADATA,sha256=vWdnTe2a86wTenEh29TNJlYEjD8A5CPtsyylxh4XsE0,6220
- evolutionary_policy_optimization-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.64.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.64.dist-info/RECORD,,