evolutionary-policy-optimization 0.0.10__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +28 -24
- {evolutionary_policy_optimization-0.0.10.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.11.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.10.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.10.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.10.dist-info → evolutionary_policy_optimization-0.0.11.dist-info}/licenses/LICENSE +0 -0
@@ -163,7 +163,6 @@ class MLP(Module):
|
|
163
163
|
self,
|
164
164
|
dims: tuple[int, ...],
|
165
165
|
dim_latent = 0,
|
166
|
-
num_latent_sets = 1
|
167
166
|
):
|
168
167
|
super().__init__()
|
169
168
|
assert len(dims) >= 2, 'must have at least two dimensions'
|
@@ -173,24 +172,15 @@ class MLP(Module):
|
|
173
172
|
first_dim, *rest_dims = dims
|
174
173
|
dims = (first_dim + dim_latent, *rest_dims)
|
175
174
|
|
176
|
-
assert num_latent_sets >= 1
|
177
|
-
|
178
175
|
self.dim_latent = dim_latent
|
179
|
-
self.num_latent_sets = num_latent_sets
|
180
176
|
|
181
177
|
self.needs_latent = dim_latent > 0
|
182
|
-
self.needs_latent_gate = num_latent_sets > 1
|
183
178
|
|
184
179
|
self.encode_latent = nn.Sequential(
|
185
180
|
Linear(dim_latent, dim_latent),
|
186
181
|
nn.SiLU()
|
187
182
|
) if self.needs_latent else None
|
188
183
|
|
189
|
-
self.to_latent_gate = nn.Sequential(
|
190
|
-
Linear(first_dim, num_latent_sets),
|
191
|
-
nn.Softmax(dim = -1)
|
192
|
-
) if self.needs_latent_gate else None
|
193
|
-
|
194
184
|
# pairs of dimension
|
195
185
|
|
196
186
|
dim_pairs = tuple(zip(dims[:-1], dims[1:]))
|
@@ -210,15 +200,6 @@ class MLP(Module):
|
|
210
200
|
|
211
201
|
assert xnor(self.needs_latent, exists(latent))
|
212
202
|
|
213
|
-
if exists(latent) and self.needs_latent_gate:
|
214
|
-
# an improvisation where set of genes with controlled expression by environment
|
215
|
-
|
216
|
-
gates = self.to_latent_gate(x)
|
217
|
-
latent = einsum(latent, gates, 'n g, b n -> b g')
|
218
|
-
else:
|
219
|
-
assert latent.shape[0] == 1
|
220
|
-
latent = latent[0]
|
221
|
-
|
222
203
|
if exists(latent):
|
223
204
|
# start with naive concatenative conditioning
|
224
205
|
# but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)
|
@@ -248,7 +229,7 @@ class MLP(Module):
|
|
248
229
|
class Actor(Module):
|
249
230
|
def __init__(
|
250
231
|
self,
|
251
|
-
|
232
|
+
dim_state,
|
252
233
|
num_actions,
|
253
234
|
dim_hiddens: tuple[int, ...],
|
254
235
|
dim_latent = 0,
|
@@ -259,7 +240,7 @@ class Actor(Module):
|
|
259
240
|
dim_first, *_, dim_last = dim_hiddens
|
260
241
|
|
261
242
|
self.init_layer = nn.Sequential(
|
262
|
-
nn.Linear(
|
243
|
+
nn.Linear(dim_state, dim_first),
|
263
244
|
nn.SiLU()
|
264
245
|
)
|
265
246
|
|
@@ -285,7 +266,7 @@ class Actor(Module):
|
|
285
266
|
class Critic(Module):
|
286
267
|
def __init__(
|
287
268
|
self,
|
288
|
-
|
269
|
+
dim_state,
|
289
270
|
dim_hiddens: tuple[int, ...],
|
290
271
|
dim_latent = 0,
|
291
272
|
):
|
@@ -295,7 +276,7 @@ class Critic(Module):
|
|
295
276
|
dim_first, *_, dim_last = dim_hiddens
|
296
277
|
|
297
278
|
self.init_layer = nn.Sequential(
|
298
|
-
nn.Linear(
|
279
|
+
nn.Linear(dim_state, dim_first),
|
299
280
|
nn.SiLU()
|
300
281
|
)
|
301
282
|
|
@@ -346,6 +327,7 @@ class LatentGenePool(Module):
|
|
346
327
|
num_latents, # same as gene pool size
|
347
328
|
dim_latent, # gene dimension
|
348
329
|
num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
|
330
|
+
dim_state = None,
|
349
331
|
crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
|
350
332
|
l2norm_latent = False, # whether to enforce latents on hypersphere,
|
351
333
|
frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
|
@@ -365,11 +347,23 @@ class LatentGenePool(Module):
|
|
365
347
|
latents = maybe_l2norm(latents, dim = -1)
|
366
348
|
|
367
349
|
self.num_latents = num_latents
|
368
|
-
self.num_latent_sets = num_latent_sets
|
350
|
+
self.needs_latent_gate = num_latent_sets > 1
|
369
351
|
self.latents = nn.Parameter(latents, requires_grad = False)
|
370
352
|
|
371
353
|
self.maybe_l2norm = maybe_l2norm
|
372
354
|
|
355
|
+
# gene expression as a function of environment
|
356
|
+
|
357
|
+
self.num_latent_sets = num_latent_sets
|
358
|
+
|
359
|
+
if self.needs_latent_gate:
|
360
|
+
assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
|
361
|
+
|
362
|
+
self.to_latent_gate = nn.Sequential(
|
363
|
+
Linear(dim_state, num_latent_sets),
|
364
|
+
nn.Softmax(dim = -1)
|
365
|
+
) if self.needs_latent_gate else None
|
366
|
+
|
373
367
|
# some derived values
|
374
368
|
|
375
369
|
assert 0. < frac_tournaments < 1.
|
@@ -471,6 +465,7 @@ class LatentGenePool(Module):
|
|
471
465
|
def forward(
|
472
466
|
self,
|
473
467
|
*args,
|
468
|
+
state: Tensor | None = None,
|
474
469
|
latent_id: int | None = None,
|
475
470
|
net: Module | None = None,
|
476
471
|
**kwargs,
|
@@ -487,6 +482,15 @@ class LatentGenePool(Module):
|
|
487
482
|
|
488
483
|
latent = self.latents[latent_id]
|
489
484
|
|
485
|
+
if self.needs_latent_gate:
|
486
|
+
assert exists(state), 'state must be passed in if greater than number of 1 latent set'
|
487
|
+
|
488
|
+
gates = self.to_latent_gate(state)
|
489
|
+
latent = einsum(latent, gates, 'n g, b n -> b g')
|
490
|
+
else:
|
491
|
+
assert latent.shape[0] == 1
|
492
|
+
latent = latent[0]
|
493
|
+
|
490
494
|
if not exists(net):
|
491
495
|
return latent
|
492
496
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.10
|
3
|
+
Version: 0.0.11
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=JGow9ofx7IgFy7QNL0dL0K_SCL_bVkBUznMG8aSGM9Q,15591
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.11.dist-info/METADATA,sha256=fkouRBZU5nrPgHt0eT5izSHdOiYGAg67N5Gn3t039mQ,4357
|
5
|
+
evolutionary_policy_optimization-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.11.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=66GOQq8_s5kmQI7G-2Z0J_0g4E5QarjQPJfWEP7mmKg,15442
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.10.dist-info/METADATA,sha256=bD3fw2Zw1IxhfkCvzjsRhODyL_XIC5ZsvNQqFbZXNc4,4357
|
5
|
-
evolutionary_policy_optimization-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.10.dist-info/RECORD,,
|
File without changes
|