evolutionary-policy-optimization 0.0.2-py3-none-any.whl → 0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- evolutionary_policy_optimization/epo.py
+++ evolutionary_policy_optimization/epo.py
@@ -9,6 +9,8 @@ from torch.nn import Linear, Module, ModuleList

 from einops import rearrange, repeat

+from assoc_scan import AssocScan
+
 # helpers

 def exists(v):
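
The only functional change in this hunk is the new assoc_scan import: AssocScan is an associative-scan primitive that the generalized advantage estimate helper added further down uses to evaluate its recurrence, and assoc-scan is declared as a new runtime dependency in the METADATA hunk below.
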
@@ -56,18 +58,18 @@ def actor_loss(
 ):
     log_probs = gather_log_prob(logits, actions)

-    entropy = calc_entropy(logits)
-
     ratio = (log_probs - old_log_probs).exp()

-    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
-
     # classic clipped surrogate loss from ppo

+    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
+
     actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)

     # add entropy loss for exploration

+    entropy = calc_entropy(logits)
+
     entropy_aux_loss = -entropy_weight * entropy

     return actor_loss + entropy_aux_loss
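
The actor loss hunk is a pure reordering: the clipped ratio now sits under the surrogate-loss comment and the entropy computation under the exploration comment, so the returned loss is unchanged. For reference, a minimal self-contained sketch of the same clipped-surrogate-plus-entropy objective (the log-prob gathering and entropy below are rewritten from their standard definitions rather than copied from epo.py, and the default hyperparameters are illustrative):

    import torch

    def ppo_actor_loss(logits, old_log_probs, actions, advantage, eps_clip = 0.2, entropy_weight = 0.01):
        # log probability of the taken actions under the current policy
        log_probs = logits.log_softmax(dim = -1)
        log_probs = log_probs.gather(-1, actions[..., None])[..., 0]

        # probability ratio between current and old policy
        ratio = (log_probs - old_log_probs).exp()

        # classic clipped surrogate loss from ppo
        clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
        actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)

        # entropy bonus for exploration
        probs = logits.softmax(dim = -1)
        entropy = -(probs * probs.clamp(min = 1e-20).log()).sum(dim = -1)

        return actor_loss - entropy_weight * entropy
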
@@ -80,6 +82,39 @@ def critic_loss(
     discounted_values = advantages + old_values
     return F.mse_loss(pred_values, discounted_values)

+# generalized advantage estimate
+
+def calc_generalized_advantage_estimate(
+    rewards: Float['g n'],
+    values: Float['g n+1'],
+    masks: Bool['n'],
+    gamma = 0.99,
+    lam = 0.95,
+    use_accelerated = None
+
+):
+    assert values.shape[-1] == (rewards.shape[-1] + 1)
+
+    use_accelerated = default(use_accelerated, rewards.is_cuda)
+    device = rewards.device
+
+    masks = repeat(masks, 'n -> g n', g = rewards.shape[0])
+
+    values, values_next = values[:, :-1], values[:, 1:]
+
+    delta = rewards + gamma * values_next * masks - values
+    gates = gamma * lam * masks
+
+    gates, delta = gates[..., :, None], delta[..., :, None]
+
+    scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
+
+    gae = scan(gates, delta)
+
+    gae = gae[..., :, 0]
+
+    return gae
+
 # evolution related functions

 def crossover_latents(
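
The new calc_generalized_advantage_estimate evaluates the GAE recursion A_t = delta_t + gamma * lam * mask_t * A_{t+1}, with delta_t = r_t + gamma * V_{t+1} * mask_t - V_t, as a reverse associative scan over the trajectory, so no Python loop over timesteps is needed and the accelerated path is picked up automatically when the tensors live on CUDA. A plain sequential reference of the same recurrence, useful for sanity-checking the scan output (a sketch only, following the (groups, timesteps) shapes of the function above):

    import torch

    def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
        # rewards: (groups, n), values: (groups, n + 1), masks: (n,) bool
        values, values_next = values[:, :-1], values[:, 1:]
        delta = rewards + gamma * values_next * masks - values

        gae = torch.zeros_like(rewards)
        running = torch.zeros(rewards.shape[0], device = rewards.device)

        # walk the trajectory backwards, zeroing the carry across episode boundaries
        for t in reversed(range(rewards.shape[-1])):
            running = delta[:, t] + gamma * lam * masks[t] * running
            gae[:, t] = running

        return gae
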
@@ -319,7 +354,7 @@ class LatentGenePool(Module):

         # if only 1 latent, assume doing ablation and get lone gene

-        if self.num_latents == 1:
+        if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0

         assert 0 <= latent_id < self.num_latents
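
This guard previously overwrote any caller-supplied latent_id whenever the pool held a single latent; with the added not exists(latent_id) check, latent_id only defaults to 0 when the caller omits it, and an explicitly passed id still has to satisfy the bounds assertion that follows.
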
--- evolutionary_policy_optimization-0.0.2.dist-info/METADATA
+++ evolutionary_policy_optimization-0.0.4.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.2
+Version: 0.0.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
+Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.0
 Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
--- /dev/null
+++ evolutionary_policy_optimization-0.0.4.dist-info/RECORD
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=jW6wZ_IbTdO05agc9AghDHawLb0rStfOzHKpSh-vEe0,10783
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.4.dist-info/METADATA,sha256=ZmVUGRQkqOYs1fAyPXjyvIeyc_mShKVTfRVZsIE_Z1Q,4098
+evolutionary_policy_optimization-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.4.dist-info/RECORD,,
--- evolutionary_policy_optimization-0.0.2.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=6IcrAooUY8csFMv4ho5bf6TVAsk1cYyXB5hG3NA-jbA,9941
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.2.dist-info/METADATA,sha256=XqR4EZUWWXRjskU3_eHlwNzCeEddjm20un4E04xmNLk,4072
-evolutionary_policy_optimization-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.2.dist-info/RECORD,,