evolutionary-policy-optimization 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +171 -11
- {evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/METADATA +6 -2
- evolutionary_policy_optimization-0.0.6.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.4.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from collections import namedtuple
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -7,7 +9,7 @@ import torch.nn.functional as F
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList
 
-from einops import rearrange, repeat
+from einops import rearrange, repeat, einsum
 
 from assoc_scan import AssocScan
 
@@ -85,9 +87,9 @@ def critic_loss(
 # generalized advantage estimate
 
 def calc_generalized_advantage_estimate(
-    rewards,
-    values,
-    masks,
+    rewards, # Float[g n]
+    values,  # Float[g n+1]
+    masks,   # Bool[n]
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
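The new shape comments suggest rewards and values now carry a leading group dimension g (one row per latent / gene in the population), with masks shared per timestep. As a reading aid only, here is a minimal plain-loop reference for generalized advantage estimation under those assumed shapes; the package computes this with an associative scan (AssocScan), so this is a sketch, not its implementation.

import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: (g, n), values: (g, n + 1), masks: (n,) with True where the episode continues
    g, n = rewards.shape
    masks = masks.float()

    advantages = torch.zeros_like(rewards)
    running = torch.zeros(g)

    # standard GAE recursion, applied to every group row at once
    for t in reversed(range(n)):
        delta = rewards[:, t] + gamma * values[:, t + 1] * masks[t] - values[:, t]
        running = delta + gamma * lam * masks[t] * running
        advantages[:, t] = running

    return advantages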
@@ -160,6 +162,7 @@ class MLP(Module):
         self,
         dims: tuple[int, ...],
         dim_latent = 0,
+        num_latent_sets = 1
     ):
         super().__init__()
         assert len(dims) >= 2, 'must have at least two dimensions'
@@ -167,17 +170,26 @@ class MLP(Module):
         # add the latent to the first dim
 
         first_dim, *rest_dims = dims
-        first_dim
-
+        dims = (first_dim + dim_latent, *rest_dims)
+
+        assert num_latent_sets >= 1
 
         self.dim_latent = dim_latent
+        self.num_latent_sets = num_latent_sets
+
         self.needs_latent = dim_latent > 0
+        self.needs_latent_gate = num_latent_sets > 1
 
         self.encode_latent = nn.Sequential(
             Linear(dim_latent, dim_latent),
             nn.SiLU()
         ) if self.needs_latent else None
 
+        self.to_latent_gate = nn.Sequential(
+            Linear(first_dim, num_latent_sets),
+            nn.Softmax(dim = -1)
+        ) if self.needs_latent_gate else None
+
         # pairs of dimension
 
         dim_pairs = tuple(zip(dims[:-1], dims[1:]))
@@ -193,16 +205,27 @@ class MLP(Module):
         x,
         latent = None
     ):
+        batch = x.shape[0]
+
         assert xnor(self.needs_latent, exists(latent))
 
+        if exists(latent) and self.needs_latent_gate:
+            # an improvisation where set of genes with controlled expression by environment
+
+            gates = self.to_latent_gate(x)
+            latent = einsum(latent, gates, 'n g, b n -> b g')
+        else:
+            assert latent.shape[0] == 1
+            latent = latent[0]
+
         if exists(latent):
             # start with naive concatenative conditioning
             # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)
 
-            batch = x.shape[0]
-
             latent = self.encode_latent(latent)
-
+
+            if latent.ndim == 1:
+                latent = repeat(latent, 'd -> b d', b = batch)
 
             x = cat((x, latent), dim = -1)
 
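To make the new gating path concrete: with num_latent_sets > 1, an individual carries several latent vectors, and a softmax gate computed from the input decides how strongly each set is expressed per sample. A small self-contained sketch of just that step, with made-up sizes (the surrounding MLP machinery is omitted):

import torch
from torch import nn
from einops import einsum

batch, dim_in, num_latent_sets, dim_latent = 4, 32, 3, 16

x = torch.randn(batch, dim_in)
latent = torch.randn(num_latent_sets, dim_latent)    # one individual's n latent sets

to_latent_gate = nn.Sequential(
    nn.Linear(dim_in, num_latent_sets),
    nn.Softmax(dim = -1)
)

gates = to_latent_gate(x)                            # (batch, num_latent_sets), sums to 1 per sample
mixed = einsum(latent, gates, 'n g, b n -> b g')     # (batch, dim_latent), state-dependent gene expression

assert mixed.shape == (batch, dim_latent)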
@@ -218,6 +241,100 @@ class MLP(Module):
 
         return x
 
+# actor, critic, and agent (actor + critic)
+# eventually, should just create a separate repo and aggregate all the MLP related architectures
+
+class Actor(Module):
+    def __init__(
+        self,
+        dim_in,
+        num_actions,
+        dim_hiddens: tuple[int, ...],
+        dim_latent = 0,
+    ):
+        super().__init__()
+
+        assert len(dim_hiddens) >= 2
+        dim_first, *_, dim_last = dim_hiddens
+
+        self.init_layer = nn.Sequential(
+            nn.Linear(dim_in, dim_first),
+            nn.SiLU()
+        )
+
+        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+        self.to_out = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim_last, num_actions),
+        )
+
+    def forward(
+        self,
+        state,
+        latent
+    ):
+
+        hidden = self.init_layer(state)
+
+        hidden = self.mlp(state, latent)
+
+        return self.to_out(hidden)
+
+class Critic(Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_hiddens: tuple[int, ...],
+        dim_latent = 0,
+    ):
+        super().__init__()
+
+        assert len(dim_hiddens) >= 2
+        dim_first, *_, dim_last = dim_hiddens
+
+        self.init_layer = nn.Sequential(
+            nn.Linear(dim_in, dim_first),
+            nn.SiLU()
+        )
+
+        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+        self.to_out = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(dim_last, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+    def forward(
+        self,
+        state,
+        latent
+    ):
+
+        hidden = self.init_layer(state)
+
+        hidden = self.mlp(state, latent)
+
+        return self.to_out(hidden)
+
+class Agent(Module):
+    def __init__(
+        self,
+        actor: Actor,
+        critic: Critic,
+    ):
+        super().__init__()
+
+        self.actor = actor
+        self.critic = critic
+
+    def forward(
+        self,
+        memories: list[Memory]
+    ):
+        raise NotImplementedError
+
 # classes
 
 class LatentGenePool(Module):
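The Actor and Critic added above share one pattern: encode the raw state, run a latent-conditioned trunk, then project to action logits or a scalar value. A condensed, self-contained sketch of that path with hypothetical dimensions, under the assumption that the encoded state is what feeds the latent-conditioned trunk (a single Linear stands in for the MLP here):

import torch
from torch import nn, cat

dim_in, dim_hidden, dim_latent, num_actions, batch = 8, 64, 16, 4, 2

init_layer = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.SiLU())
encode_latent = nn.Sequential(nn.Linear(dim_latent, dim_latent), nn.SiLU())
hidden_layer = nn.Linear(dim_hidden + dim_latent, dim_hidden)    # stands in for the latent-conditioned MLP
to_out = nn.Sequential(nn.SiLU(), nn.Linear(dim_hidden, num_actions))

state = torch.randn(batch, dim_in)
latent = torch.randn(dim_latent)                                 # one gene from the pool

hidden = init_layer(state)
latent = encode_latent(latent).expand(batch, -1)                 # broadcast the gene across the batch
action_logits = to_out(hidden_layer(cat((hidden, latent), dim = -1)))

assert action_logits.shape == (batch, num_actions)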
@@ -225,6 +342,7 @@ class LatentGenePool(Module):
         self,
         num_latents,                # same as gene pool size
         dim_latent,                 # gene dimension
+        num_latent_sets = 1,        # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         crossover_random = True,    # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,      # whether to enforce latents on hypersphere,
         frac_tournaments = 0.25,    # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -237,12 +355,13 @@ class LatentGenePool(Module):
 
         maybe_l2norm = l2norm if l2norm_latent else identity
 
-        latents = torch.randn(num_latents, dim_latent)
+        latents = torch.randn(num_latents, num_latent_sets, dim_latent)
 
         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)
 
         self.num_latents = num_latents
+        self.num_latent_sets = num_latent_sets
         self.latents = nn.Parameter(latents, requires_grad = False)
 
         self.maybe_l2norm = maybe_l2norm
@@ -268,9 +387,17 @@ class LatentGenePool(Module):
         # network for the latent / gene
 
         if isinstance(net, dict):
+            assert 'dim_latent' not in net
+            assert 'num_latent_sets' not in net
+
+            net.update(dim_latent = dim_latent)
+            net.update(num_latent_sets = num_latent_sets)
+
             net = MLP(**net)
 
         assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
+        assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and and latent gene pool container'
+
         self.net = net
 
     @torch.no_grad()
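The dict path above lets the pool own the latent hyperparameters: the caller supplies only the MLP config, and dim_latent / num_latent_sets are injected before the MLP is built. A tiny illustration with assumed values:

net = dict(dims = (512, 256))      # caller-provided MLP config, no latent settings

assert 'dim_latent' not in net
assert 'num_latent_sets' not in net

net.update(dim_latent = 32)        # injected from the gene pool's own settings
net.update(num_latent_sets = 4)

# the dict is then expanded into the conditioned network:
# MLP(dims = (512, 256), dim_latent = 32, num_latent_sets = 4)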
@@ -283,6 +410,7 @@ class LatentGenePool(Module):
         """
         p - population
         g - gene dimension
+        n - number of genes per individual
         """
         assert self.num_latents > 1
 
@@ -307,7 +435,7 @@ class LatentGenePool(Module):
 
         tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices
 
-        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... g', g = self.dim_latent)
+        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)
 
         parents = participants.gather(-2, tournament_winner_indices)
 
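The change above threads the new latent-set axis n through parent selection: the winner indices of each tournament are broadcast over both n and the gene dimension g before gathering. A toy sketch of that selection step, with made-up sizes and gathering explicitly along the participant axis (the layout here is illustrative, not the pool's exact internals):

import torch
from einops import repeat

tournaments, participants_per, n_sets, dim_gene = 3, 4, 2, 16

participants = torch.randn(tournaments, participants_per, n_sets, dim_gene)
participant_fitness = torch.randn(tournaments, participants_per)

# the two fittest participants of each tournament become the parents
winner_indices = participant_fitness.topk(2, dim = -1).indices                        # (tournaments, 2)
winner_indices = repeat(winner_indices, '... -> ... n g', n = n_sets, g = dim_gene)   # (tournaments, 2, n, g)

parents = participants.gather(1, winner_indices)                                      # (tournaments, 2, n, g)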
@@ -368,3 +496,35 @@ class LatentGenePool(Module):
             latent = latent,
             **kwargs
         )
+
+# EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
+# the tricky part is that the latent ids for each episode / trajectory needs to be tracked
+
+Memory = namedtuple('Memory', [
+    'state',
+    'latent_gene_id',
+    'action',
+    'log_prob',
+    'reward',
+    'values',
+    'done'
+])
+
+class EPO(Module):
+
+    def __init__(
+        self,
+        agent: Agent,
+        latent_gene_pool: LatentGenePool
+    ):
+        super().__init__()
+
+        self.agent = agent
+        self.latent_gene_pool = latent_gene_pool
+
+    def forward(
+        self,
+        env
+    ) -> list[Memory]:
+
+        raise NotImplementedError
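Taken together, the diff now sketches the full loop: a LatentGenePool of genes, an Actor / Critic pair conditioned on a sampled gene, an Agent wrapper, and an EPO orchestrator whose forward (like Agent.forward) is still a stub. A rough construction-only sketch with made-up sizes, under the assumption (not visible in this diff) that LatentGenePool still accepts its conditioned network as a net dict:

from evolutionary_policy_optimization.epo import (
    Actor, Critic, Agent, EPO, LatentGenePool, Memory
)

dim_state, dim_latent, num_actions = 8, 16, 4

# population of latent "genes", each possibly holding several latent sets
latent_pool = LatentGenePool(
    num_latents = 32,
    dim_latent = dim_latent,
    num_latent_sets = 1,
    net = dict(dims = (64, 64))    # assumed kwarg: built into the latent-conditioned MLP
)

actor = Actor(dim_state, num_actions, dim_hiddens = (64, 64), dim_latent = dim_latent)
critic = Critic(dim_state, dim_hiddens = (64, 64), dim_latent = dim_latent)

epo = EPO(Agent(actor, critic), latent_pool)

# rollouts would be stored as Memory tuples, keyed by the sampled gene's id,
# so each trajectory can later be credited to the latent that conditioned it
# memory = Memory(state, latent_gene_id, action, log_prob, reward, values, done)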
{evolutionary_policy_optimization-0.0.4.dist-info → evolutionary_policy_optimization-0.0.6.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.4
+Version: 0.0.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -46,6 +46,8 @@ Provides-Extra: examples-gym
 Requires-Dist: box2d-py; extra == 'examples-gym'
 Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
 Requires-Dist: tqdm; extra == 'examples-gym'
+Provides-Extra: test
+Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -56,7 +58,9 @@ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.
 
 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.
 
-Besides their latent variable
+Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
+
+Update: I see, mixing genetic algorithms with gradient based method is already a research field, under [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). This is also incidentally what I have concluded what Science is. I am in direct exposure to this phenomenon on a daily basis
 
 ## Usage
 
evolutionary_policy_optimization-0.0.6.dist-info/RECORD

@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=vXkwsQE0CNEUPpguZP-XXsuDyIBN-bS3xDJDXpYlTHM,14772
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.6.dist-info/METADATA,sha256=M_0SbTqdifHQ_R9LWIe7ZfHMXgCiFDJ0sDpD29ctiNk,4460
+evolutionary_policy_optimization-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.6.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.4.dist-info/RECORD

@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=jW6wZ_IbTdO05agc9AghDHawLb0rStfOzHKpSh-vEe0,10783
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.4.dist-info/METADATA,sha256=ZmVUGRQkqOYs1fAyPXjyvIeyc_mShKVTfRVZsIE_Z1Q,4098
-evolutionary_policy_optimization-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.4.dist-info/RECORD,,

File without changes