evolutionary-policy-optimization 0.0.4-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +131 -3
- {evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.5.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.4.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -1,5 +1,7 @@
 from __future__ import annotations

+from collections import namedtuple
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -85,9 +87,9 @@ def critic_loss(
 # generalized advantage estimate

 def calc_generalized_advantage_estimate(
-    rewards,
-    values,
-    masks,
+    rewards, # Float[g n]
+    values,  # Float[g n+1]
+    masks,   # Bool[n]
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
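
For reference, a minimal sketch (not part of the package) of how the newly annotated shapes line up when calling calc_generalized_advantage_estimate; the group count g = 2 and trajectory length n = 5 below are hypothetical, and only the shape convention in the comments above comes from the diff:

import torch
from evolutionary_policy_optimization.epo import calc_generalized_advantage_estimate

g, n = 2, 5                       # hypothetical group count and steps per trajectory
rewards = torch.randn(g, n)       # Float[g n]
values  = torch.randn(g, n + 1)   # Float[g n+1], one extra bootstrap value
masks   = torch.ones(n).bool()    # Bool[n], False once an episode has terminated

advantages = calc_generalized_advantage_estimate(
    rewards, values, masks,
    gamma = 0.99,
    lam = 0.95
)
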
@@ -218,6 +220,100 @@ class MLP(Module):

         return x

+# actor, critic, and agent (actor + critic)
+# eventually, should just create a separate repo and aggregate all the MLP related architectures
+
+class Actor(Module):
+    def __init__(
+        self,
+        dim_in,
+        num_actions,
+        dim_hiddens: tuple[int, ...],
+        dim_latent = 0,
+    ):
+        super().__init__()
+
+        assert len(dim_hiddens) >= 2
+        dim_first, *_, dim_last = dim_hiddens
+
+        self.init_layer = nn.Sequential(
+            nn.Linear(dim_in, dim_first),
+            nn.SiLU()
+        )
+
+        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+        self.to_out = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim_last, num_actions),
+        )
+
+    def forward(
+        self,
+        state,
+        latent
+    ):
+
+        hidden = self.init_layer(state)
+
+        hidden = self.mlp(state, latent)
+
+        return self.to_out(hidden)
+
+class Critic(Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_hiddens: tuple[int, ...],
+        dim_latent = 0,
+    ):
+        super().__init__()
+
+        assert len(dim_hiddens) >= 2
+        dim_first, *_, dim_last = dim_hiddens
+
+        self.init_layer = nn.Sequential(
+            nn.Linear(dim_in, dim_first),
+            nn.SiLU()
+        )
+
+        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+        self.to_out = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim_last, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+    def forward(
+        self,
+        state,
+        latent
+    ):
+
+        hidden = self.init_layer(state)
+
+        hidden = self.mlp(state, latent)
+
+        return self.to_out(hidden)
+
+class Agent(Module):
+    def __init__(
+        self,
+        actor: Actor,
+        critic: Critic,
+    ):
+        super().__init__()
+
+        self.actor = actor
+        self.critic = critic
+
+    def forward(
+        self,
+        memories: list[Memory]
+    ):
+        raise NotImplementedError
+
 # classes

 class LatentGenePool(Module):
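
As an illustrative sketch only (not part of the release), the new classes could be wired together as below; every dimension is hypothetical, and only the constructor signatures are taken from the diff above. Note that in 0.0.5 Actor.forward and Critic.forward pass the raw state (rather than the init_layer output) into the MLP, and Agent.forward is still a stub:

from evolutionary_policy_optimization.epo import Actor, Critic, Agent

actor = Actor(
    dim_in = 256,               # hypothetical state dimension
    num_actions = 5,            # hypothetical discrete action count
    dim_hiddens = (256, 128),   # at least two hidden dims are asserted
    dim_latent = 32             # dimension of the conditioning latent gene
)

critic = Critic(
    dim_in = 256,
    dim_hiddens = (256, 128),
    dim_latent = 32
)

agent = Agent(actor, critic)    # Agent.forward(memories) raises NotImplementedError in this version
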
@@ -368,3 +464,35 @@ class LatentGenePool(Module):
             latent = latent,
             **kwargs
         )
+
+# EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
+# the tricky part is that the latent ids for each episode / trajectory needs to be tracked
+
+Memory = namedtuple('Memory', [
+    'state',
+    'latent_gene_id',
+    'action',
+    'log_prob',
+    'reward',
+    'values',
+    'done'
+])
+
+class EPO(Module):
+
+    def __init__(
+        self,
+        agent: Agent,
+        latent_gene_pool: LatentGenePool
+    ):
+        super().__init__()
+
+        self.agent = agent
+        self.latent_gene_pool = latent_gene_pool
+
+    def forward(
+        self,
+        env
+    ) -> list[Memory]:
+
+        raise NotImplementedError
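
Again purely as a hedged sketch, here is the Memory record the new EPO wrapper is meant to consume; the field values are hypothetical, and only the field names and the EPO constructor signature come from the diff above:

import torch
from evolutionary_policy_optimization.epo import Memory

# one transition as the new Memory namedtuple records it - the latent_gene_id
# field ties each trajectory back to the latent gene it was conditioned on
memory = Memory(
    state = torch.randn(256),    # hypothetical state
    latent_gene_id = 0,          # index into the LatentGenePool
    action = torch.tensor(2),
    log_prob = torch.tensor(-1.2),
    reward = torch.tensor(1.0),
    values = torch.tensor(0.5),
    done = torch.tensor(False)
)

# EPO itself only stores the two halves for now, per the constructor above:
#     epo = EPO(agent = agent, latent_gene_pool = latent_gene_pool)
# and epo(env) is typed to return list[Memory] but raises NotImplementedError in 0.0.5
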
{evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.5.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.4
+Version: 0.0.5
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.5.dist-info/RECORD

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
+evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.4.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=jW6wZ_IbTdO05agc9AghDHawLb0rStfOzHKpSh-vEe0,10783
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.4.dist-info/METADATA,sha256=ZmVUGRQkqOYs1fAyPXjyvIeyc_mShKVTfRVZsIE_Z1Q,4098
-evolutionary_policy_optimization-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.4.dist-info/RECORD,,