evolutionary-policy-optimization 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +167 -4
- {evolutionary_policy_optimization-0.0.3.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/METADATA +2 -1
- evolutionary_policy_optimization-0.0.5.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.3.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.3.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.3.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
from collections import namedtuple
|
4
|
+
|
3
5
|
import torch
|
4
6
|
from torch import nn, cat
|
5
7
|
import torch.nn.functional as F
|
@@ -9,6 +11,8 @@ from torch.nn import Linear, Module, ModuleList
|
|
9
11
|
|
10
12
|
from einops import rearrange, repeat
|
11
13
|
|
14
|
+
from assoc_scan import AssocScan
|
15
|
+
|
12
16
|
# helpers
|
13
17
|
|
14
18
|
def exists(v):
|
@@ -56,18 +60,18 @@ def actor_loss(
|
|
56
60
|
):
|
57
61
|
log_probs = gather_log_prob(logits, actions)
|
58
62
|
|
59
|
-
entropy = calc_entropy(logits)
|
60
|
-
|
61
63
|
ratio = (log_probs - old_log_probs).exp()
|
62
64
|
|
63
|
-
clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
|
64
|
-
|
65
65
|
# classic clipped surrogate loss from ppo
|
66
66
|
|
67
|
+
clipped_ratio = ratio.clamp(min = 1. - eps_clip, max = 1. + eps_clip)
|
68
|
+
|
67
69
|
actor_loss = -torch.min(clipped_ratio * advantage, ratio * advantage)
|
68
70
|
|
69
71
|
# add entropy loss for exploration
|
70
72
|
|
73
|
+
entropy = calc_entropy(logits)
|
74
|
+
|
71
75
|
entropy_aux_loss = -entropy_weight * entropy
|
72
76
|
|
73
77
|
return actor_loss + entropy_aux_loss
|
@@ -80,6 +84,39 @@ def critic_loss(
|
|
80
84
|
discounted_values = advantages + old_values
|
81
85
|
return F.mse_loss(pred_values, discounted_values)
|
82
86
|
|
87
|
+
# generalized advantage estimate
|
88
|
+
|
89
|
+
def calc_generalized_advantage_estimate(
    rewards,                # Float[g n]
    values,                 # Float[g n+1]
    masks,                  # Bool[n] (shared across g) or Bool[g n]
    gamma = 0.99,
    lam = 0.95,
    use_accelerated = None
):
    """Compute generalized advantage estimates (GAE) with a reverse associative scan.

    The recursion being evaluated is

        delta_t = r_t + gamma * mask_t * V_{t+1} - V_t
        A_t     = delta_t + gamma * lam * mask_t * A_{t+1}

    where a mask value of 0 / False cuts both the bootstrapped next value and
    the propagation of later advantages back through step t.

    Args:
        rewards: float tensor of shape (g, n).
        values: float tensor of shape (g, n + 1); the last column is the value
            of the state following the final step (used only as V_{t+1}).
        masks: tensor of shape (n,), shared by every row of `rewards`, or of
            shape (g, n) with one mask per row. Cast to the dtype of `rewards`.
        gamma: discount factor.
        lam: GAE lambda.
        use_accelerated: forwarded to `AssocScan`; when None it defaults to
            `rewards.is_cuda`.

    Returns:
        Advantages, shaped like `rewards`.
    """
    assert values.shape[-1] == (rewards.shape[-1] + 1), 'values must have exactly one more timestep than rewards'

    if use_accelerated is None:
        use_accelerated = rewards.is_cuda

    # broadcast a shared (n,) mask over the leading dim (a (g, n) mask passes through unchanged),
    # and match the rewards dtype so `gates` and `delta` reach the scan with the same dtype
    # (a bool mask times a python float would otherwise promote to the default float dtype)
    masks = masks.to(rewards.dtype).expand_as(rewards)

    values, values_next = values[:, :-1], values[:, 1:]

    delta = rewards + gamma * values_next * masks - values
    gates = gamma * lam * masks

    # add a trailing feature dim of size 1 for the scan - presumably AssocScan expects
    # (..., seq, dim) inputs; TODO confirm against the assoc_scan docs
    gates, delta = gates[..., :, None], delta[..., :, None]

    # reverse=True so each step accumulates the (gated) deltas of the steps after it,
    # assuming AssocScan computes h_t = gates_t * h_{t+1} + delta_t when reversed
    scan = AssocScan(reverse = True, use_accelerated = use_accelerated)

    gae = scan(gates, delta)

    return gae[..., :, 0]
|
119
|
+
|
83
120
|
# evolution related functions
|
84
121
|
|
85
122
|
def crossover_latents(
|
@@ -183,6 +220,100 @@ class MLP(Module):
|
|
183
220
|
|
184
221
|
return x
|
185
222
|
|
223
|
+
# actor, critic, and agent (actor + critic)
|
224
|
+
# eventually, should just create a separate repo and aggregate all the MLP related architectures
|
225
|
+
|
226
|
+
class Actor(Module):
    """Policy network: maps a state, conditioned on a latent, to `num_actions` outputs.

    Pipeline: state -> Linear(dim_in, dim_hiddens[0]) + SiLU
                    -> MLP over dim_hiddens (given the latent)
                    -> SiLU + Linear(dim_hiddens[-1], num_actions)

    The outputs are presumably action logits (the file's actor_loss consumes `logits`).
    """
    def __init__(
        self,
        dim_in,
        num_actions,
        dim_hiddens: tuple[int, ...],
        dim_latent = 0,
    ):
        super().__init__()

        assert len(dim_hiddens) >= 2
        dim_first, *_, dim_last = dim_hiddens

        # project the raw state up to the first hidden width the MLP is built on
        self.init_layer = nn.Sequential(
            nn.Linear(dim_in, dim_first),
            nn.SiLU()
        )

        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)

        self.to_out = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim_last, num_actions),
        )

    def forward(
        self,
        state,
        latent
    ):
        hidden = self.init_layer(state)

        # bug fix: feed the projected hidden state (not the raw state) into the MLP.
        # Previously init_layer's output was discarded, and the MLP (built on dim_hiddens,
        # so presumably expecting dim_hiddens[0]-wide input) received dim_in-wide input
        hidden = self.mlp(hidden, latent)

        return self.to_out(hidden)
|
262
|
+
|
263
|
+
class Critic(Module):
    """Value network: maps a state, conditioned on a latent, to one scalar per state.

    Pipeline: state -> Linear(dim_in, dim_hiddens[0]) + SiLU
                    -> MLP over dim_hiddens (given the latent)
                    -> SiLU + Linear(dim_hiddens[-1], 1) -> drop the trailing size-1 dim
    """
    def __init__(
        self,
        dim_in,
        dim_hiddens: tuple[int, ...],
        dim_latent = 0,
    ):
        super().__init__()

        # the einops import visible at the top of this module only brings in `rearrange`
        # and `repeat`; import the `Rearrange` layer used by the output head explicitly
        from einops.layers.torch import Rearrange

        assert len(dim_hiddens) >= 2
        dim_first, *_, dim_last = dim_hiddens

        # project the raw state up to the first hidden width the MLP is built on
        self.init_layer = nn.Sequential(
            nn.Linear(dim_in, dim_first),
            nn.SiLU()
        )

        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)

        self.to_out = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim_last, 1),
            Rearrange('... 1 -> ...')     # (..., 1) -> (...): one value per state
        )

    def forward(
        self,
        state,
        latent
    ):
        hidden = self.init_layer(state)

        # bug fix: feed the projected hidden state (not the raw state) into the MLP;
        # previously init_layer's output was computed and then thrown away
        hidden = self.mlp(hidden, latent)

        return self.to_out(hidden)
|
299
|
+
|
300
|
+
class Agent(Module):
    """Container pairing an actor (policy) with a critic (value function).

    Both are stored as attributes in the order actor, then critic. When they are
    nn.Modules this registers them as child modules, in that order.
    """

    def __init__(self, actor: Actor, critic: Critic):
        super().__init__()
        # tuple assignment still assigns left to right: actor first, then critic
        self.actor, self.critic = actor, critic

    def forward(self, memories: list[Memory]):
        # learning from collected Memory records is not implemented yet
        raise NotImplementedError
|
316
|
+
|
186
317
|
# classes
|
187
318
|
|
188
319
|
class LatentGenePool(Module):
|
@@ -333,3 +464,35 @@ class LatentGenePool(Module):
|
|
333
464
|
latent = latent,
|
334
465
|
**kwargs
|
335
466
|
)
|
467
|
+
|
468
|
+
# EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
|
469
|
+
# the tricky part is that the latent ids for each episode / trajectory need to be tracked
|
470
|
+
|
471
|
+
# one stored record per step; `latent_gene_id` keeps track of which latent from the
# gene pool was used (see the comment above about tracking latent ids per episode).
# Field semantics beyond their names are not established here - NOTE(review): e.g.
# whether `values` holds one value or several per record should be confirmed by callers
Memory = namedtuple(
    'Memory',
    'state latent_gene_id action log_prob reward values done'
)
|
480
|
+
|
481
|
+
class EPO(Module):
    """Top-level EPO module: an agent together with a population of latents.

    The agent and the latent gene pool are stored in that order. When they are
    nn.Modules they become child modules. `forward(env)` is annotated to return
    a list of `Memory` records; it is not implemented yet.
    """

    def __init__(self, agent: Agent, latent_gene_pool: LatentGenePool):
        super().__init__()
        # assigned left to right: agent first, then the latent gene pool
        self.agent, self.latent_gene_pool = agent, latent_gene_pool

    def forward(self, env) -> list[Memory]:
        # presumably will roll out the agent in `env` and collect Memory records - not implemented yet
        raise NotImplementedError
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.3
|
3
|
+
Version: 0.0.5
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -34,6 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.8
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
36
36
|
Requires-Python: >=3.8
|
37
|
+
Requires-Dist: assoc-scan
|
37
38
|
Requires-Dist: einops>=0.8.0
|
38
39
|
Requires-Dist: torch>=2.2
|
39
40
|
Requires-Dist: tqdm
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
|
5
|
+
evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=gB9_ZPtVjjFilZEeXX1ZWh67ZSd-OBt6Xs6WyfjRSxI,9967
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.3.dist-info/METADATA,sha256=qlQ1NxiA-AVWiNlpDcMI9hVqBg1g5DY9AvISpnrKkV0,4072
|
5
|
-
evolutionary_policy_optimization-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.3.dist-info/RECORD,,
|
File without changes
|