evolutionary-policy-optimization 0.0.2-py3-none-any.whl → 0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +40 -5
- {evolutionary_policy_optimization-0.0.2.dist-info → evolutionary_policy_optimization-0.0.4.dist-info}/METADATA +2 -1
- evolutionary_policy_optimization-0.0.4.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.2.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.2.dist-info → evolutionary_policy_optimization-0.0.4.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.2.dist-info → evolutionary_policy_optimization-0.0.4.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py:

```diff
@@ -9,6 +9,8 @@ from torch.nn import Linear, Module, ModuleList

 from einops import rearrange, repeat

+from assoc_scan import AssocScan
+
 # helpers

 def exists(v):
@@ -56,18 +58,18 @@ def actor_loss(
 ):
     log_probs = gather_log_prob(logits, actions)

-    entropy = calc_entropy(logits)
-
     ratio = (log_probs - old_log_probs).exp()

-    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
-
     # classic clipped surrogate loss from ppo

+    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
+
     actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)

     # add entropy loss for exploration

+    entropy = calc_entropy(logits)
+
     entropy_aux_loss = -entropy_weight * entropy

     return actor_loss + entropy_aux_loss
```
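The `actor_loss` hunk above only moves `entropy` and `clipped_ratio` next to the lines that consume them; the objective itself is unchanged. For orientation, here is a minimal, self-contained sketch of the same PPO clipped-surrogate loss with an entropy bonus. The signature and the bodies of `gather_log_prob` / `calc_entropy` are reconstructed assumptions, not code copied from `epo.py`.

```python
import torch

def gather_log_prob(logits, actions):
    # log-probability of each taken action under the current policy
    log_probs = logits.log_softmax(dim = -1)
    return log_probs.gather(-1, actions[..., None]).squeeze(-1)

def calc_entropy(logits):
    prob = logits.softmax(dim = -1)
    return -(prob * prob.clamp(min = 1e-20).log()).sum(dim = -1)

def actor_loss_sketch(
    logits,          # (batch, num_actions) current policy logits
    old_log_probs,   # (batch,) log prob of each taken action under the old policy
    actions,         # (batch,) sampled action indices
    advantage,       # (batch,) advantage estimates
    eps_clip = 0.2,
    entropy_weight = 0.01
):
    log_probs = gather_log_prob(logits, actions)

    ratio = (log_probs - old_log_probs).exp()

    # classic clipped surrogate loss from ppo
    clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
    actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)

    # entropy bonus to encourage exploration
    entropy_aux_loss = -entropy_weight * calc_entropy(logits)

    return actor_loss + entropy_aux_loss
```

Taking the `torch.min` of the clipped and unclipped terms is what keeps a single update from moving the policy too far from the one that collected the data.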
evolutionary_policy_optimization/epo.py (continued):

```diff
@@ -80,6 +82,39 @@ def critic_loss(
     discounted_values = advantages + old_values
     return F.mse_loss(pred_values, discounted_values)

+# generalized advantage estimate
+
+def calc_generalized_advantage_estimate(
+    rewards: Float['g n'],
+    values: Float['g n+1'],
+    masks: Bool['n'],
+    gamma = 0.99,
+    lam = 0.95,
+    use_accelerated = None
+
+):
+    assert values.shape[-1] == (rewards.shape[-1] + 1)
+
+    use_accelerated = default(use_accelerated, rewards.is_cuda)
+    device = rewards.device
+
+    masks = repeat(masks, 'n -> g n', g = rewards.shape[0])
+
+    values, values_next = values[:, :-1], values[:, 1:]
+
+    delta = rewards + gamma * values_next * masks - values
+    gates = gamma * lam * masks
+
+    gates, delta = gates[..., :, None], delta[..., :, None]
+
+    scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
+
+    gae = scan(gates, delta)
+
+    gae = gae[..., :, 0]
+
+    return gae
+
 # evolution related functions

 def crossover_latents(
```
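The new `calc_generalized_advantage_estimate` rewrites GAE as a first-order linear recurrence so it can be evaluated with `assoc_scan`'s parallel, optionally CUDA-accelerated, reverse scan: the inputs are the TD residuals `delta` and the gates `gamma * lam * masks`. The sketch below is an equivalent sequential reference using an explicit backward loop; shapes and names mirror the hunk above, but the loop itself is illustrative and not part of the package.

```python
import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: (g, n)   values: (g, n + 1)   masks: (n,) boolean, False where an episode ends
    assert values.shape[-1] == rewards.shape[-1] + 1

    masks = masks.float()
    values, values_next = values[:, :-1], values[:, 1:]

    # td residual: delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t)
    delta = rewards + gamma * values_next * masks - values

    # backward recursion that the associative scan evaluates in parallel:
    #   A_t = delta_t + gamma * lam * mask_t * A_{t+1},   with A_n = 0
    gae = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(rewards.shape[-1])):
        running = delta[:, t] + gamma * lam * masks[t] * running
        gae[:, t] = running

    return gae
```

Masking works the same way in both formulations: wherever `mask_t` is zero, the TD residual and the recursion stop propagating credit across the episode boundary, which is exactly what the gates `gamma * lam * masks` encode for the scan.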
evolutionary_policy_optimization/epo.py (continued):

```diff
@@ -319,7 +354,7 @@ class LatentGenePool(Module):

         # if only 1 latent, assume doing ablation and get lone gene

-        if self.num_latents == 1:
+        if not exists(latent_id) and self.num_latents == 1:
             latent_id = 0

         assert 0 <= latent_id < self.num_latents
```
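The last `epo.py` hunk changes when the single-latent fallback applies: previously an explicitly passed `latent_id` was overwritten whenever `num_latents == 1`; now the default of `0` is used only when no id was supplied. A small sketch of the guard pattern, with the surrounding `LatentGenePool` method paraphrased and `exists` assumed to be the package's usual `is not None` helper:

```python
def exists(v):
    return v is not None

def resolve_latent_id(latent_id, num_latents):
    # before: any latent_id passed by the caller was overwritten when num_latents == 1
    # after: the single-latent default applies only when no id was supplied
    if not exists(latent_id) and num_latents == 1:
        latent_id = 0

    assert 0 <= latent_id < num_latents
    return latent_id
```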
{evolutionary_policy_optimization-0.0.2.dist-info → evolutionary_policy_optimization-0.0.4.dist-info}/METADATA:

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.2
+Version: 0.0.4
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.8
+Requires-Dist: assoc-scan
 Requires-Dist: einops>=0.8.0
 Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
```
evolutionary_policy_optimization-0.0.4.dist-info/RECORD (added):

```diff
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=jW6wZ_IbTdO05agc9AghDHawLb0rStfOzHKpSh-vEe0,10783
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.4.dist-info/METADATA,sha256=ZmVUGRQkqOYs1fAyPXjyvIeyc_mShKVTfRVZsIE_Z1Q,4098
+evolutionary_policy_optimization-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.4.dist-info/RECORD,,
```
evolutionary_policy_optimization-0.0.2.dist-info/RECORD (removed):

```diff
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=6IcrAooUY8csFMv4ho5bf6TVAsk1cYyXB5hG3NA-jbA,9941
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.2.dist-info/METADATA,sha256=XqR4EZUWWXRjskU3_eHlwNzCeEddjm20un4E04xmNLk,4072
-evolutionary_policy_optimization-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.2.dist-info/RECORD,,
```
The remaining files (WHEEL and licenses/LICENSE) are unchanged between 0.0.2 and 0.0.4.