evolutionary-policy-optimization 0.0.10__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -163,7 +163,6 @@ class MLP(Module):
163
163
  self,
164
164
  dims: tuple[int, ...],
165
165
  dim_latent = 0,
166
- num_latent_sets = 1
167
166
  ):
168
167
  super().__init__()
169
168
  assert len(dims) >= 2, 'must have at least two dimensions'
@@ -173,24 +172,15 @@ class MLP(Module):
173
172
  first_dim, *rest_dims = dims
174
173
  dims = (first_dim + dim_latent, *rest_dims)
175
174
 
176
- assert num_latent_sets >= 1
177
-
178
175
  self.dim_latent = dim_latent
179
- self.num_latent_sets = num_latent_sets
180
176
 
181
177
  self.needs_latent = dim_latent > 0
182
- self.needs_latent_gate = num_latent_sets > 1
183
178
 
184
179
  self.encode_latent = nn.Sequential(
185
180
  Linear(dim_latent, dim_latent),
186
181
  nn.SiLU()
187
182
  ) if self.needs_latent else None
188
183
 
189
- self.to_latent_gate = nn.Sequential(
190
- Linear(first_dim, num_latent_sets),
191
- nn.Softmax(dim = -1)
192
- ) if self.needs_latent_gate else None
193
-
194
184
  # pairs of dimension
195
185
 
196
186
  dim_pairs = tuple(zip(dims[:-1], dims[1:]))
@@ -210,15 +200,6 @@ class MLP(Module):
210
200
 
211
201
  assert xnor(self.needs_latent, exists(latent))
212
202
 
213
- if exists(latent) and self.needs_latent_gate:
214
- # an improvisation where set of genes with controlled expression by environment
215
-
216
- gates = self.to_latent_gate(x)
217
- latent = einsum(latent, gates, 'n g, b n -> b g')
218
- else:
219
- assert latent.shape[0] == 1
220
- latent = latent[0]
221
-
222
203
  if exists(latent):
223
204
  # start with naive concatenative conditioning
224
205
  # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)
@@ -248,7 +229,7 @@ class MLP(Module):
248
229
  class Actor(Module):
249
230
  def __init__(
250
231
  self,
251
- dim_in,
232
+ dim_state,
252
233
  num_actions,
253
234
  dim_hiddens: tuple[int, ...],
254
235
  dim_latent = 0,
@@ -259,7 +240,7 @@ class Actor(Module):
259
240
  dim_first, *_, dim_last = dim_hiddens
260
241
 
261
242
  self.init_layer = nn.Sequential(
262
- nn.Linear(dim_in, dim_first),
243
+ nn.Linear(dim_state, dim_first),
263
244
  nn.SiLU()
264
245
  )
265
246
 
@@ -285,7 +266,7 @@ class Actor(Module):
285
266
  class Critic(Module):
286
267
  def __init__(
287
268
  self,
288
- dim_in,
269
+ dim_state,
289
270
  dim_hiddens: tuple[int, ...],
290
271
  dim_latent = 0,
291
272
  ):
@@ -295,7 +276,7 @@ class Critic(Module):
295
276
  dim_first, *_, dim_last = dim_hiddens
296
277
 
297
278
  self.init_layer = nn.Sequential(
298
- nn.Linear(dim_in, dim_first),
279
+ nn.Linear(dim_state, dim_first),
299
280
  nn.SiLU()
300
281
  )
301
282
 
@@ -346,6 +327,7 @@ class LatentGenePool(Module):
346
327
  num_latents, # same as gene pool size
347
328
  dim_latent, # gene dimension
348
329
  num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
330
+ dim_state = None,
349
331
  crossover_random = True, # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
350
332
  l2norm_latent = False, # whether to enforce latents on hypersphere,
351
333
  frac_tournaments = 0.25, # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -365,11 +347,23 @@ class LatentGenePool(Module):
365
347
  latents = maybe_l2norm(latents, dim = -1)
366
348
 
367
349
  self.num_latents = num_latents
368
- self.num_latent_sets = num_latent_sets
350
+ self.needs_latent_gate = num_latent_sets > 1
369
351
  self.latents = nn.Parameter(latents, requires_grad = False)
370
352
 
371
353
  self.maybe_l2norm = maybe_l2norm
372
354
 
355
+ # gene expression as a function of environment
356
+
357
+ self.num_latent_sets = num_latent_sets
358
+
359
+ if self.needs_latent_gate:
360
+ assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
361
+
362
+ self.to_latent_gate = nn.Sequential(
363
+ Linear(dim_state, num_latent_sets),
364
+ nn.Softmax(dim = -1)
365
+ ) if self.needs_latent_gate else None
366
+
373
367
  # some derived values
374
368
 
375
369
  assert 0. < frac_tournaments < 1.
@@ -471,6 +465,7 @@ class LatentGenePool(Module):
471
465
  def forward(
472
466
  self,
473
467
  *args,
468
+ state: Tensor | None = None,
474
469
  latent_id: int | None = None,
475
470
  net: Module | None = None,
476
471
  **kwargs,
@@ -487,6 +482,15 @@ class LatentGenePool(Module):
487
482
 
488
483
  latent = self.latents[latent_id]
489
484
 
485
+ if self.needs_latent_gate:
486
+ assert exists(state), 'state must be passed in if greater than number of 1 latent set'
487
+
488
+ gates = self.to_latent_gate(state)
489
+ latent = einsum(latent, gates, 'n g, b n -> b g')
490
+ else:
491
+ assert latent.shape[0] == 1
492
+ latent = latent[0]
493
+
490
494
  if not exists(net):
491
495
  return latent
492
496
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.0.10
3
+ Version: 0.0.11
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -0,0 +1,7 @@
1
+ evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
2
+ evolutionary_policy_optimization/epo.py,sha256=JGow9ofx7IgFy7QNL0dL0K_SCL_bVkBUznMG8aSGM9Q,15591
3
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
+ evolutionary_policy_optimization-0.0.11.dist-info/METADATA,sha256=fkouRBZU5nrPgHt0eT5izSHdOiYGAg67N5Gn3t039mQ,4357
5
+ evolutionary_policy_optimization-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ evolutionary_policy_optimization-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ evolutionary_policy_optimization-0.0.11.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- evolutionary_policy_optimization/__init__.py,sha256=A07bhbBI_p-GlSTkI15pioQ1XgtJ0V4tBN6v3vs2nuU,115
2
- evolutionary_policy_optimization/epo.py,sha256=66GOQq8_s5kmQI7G-2Z0J_0g4E5QarjQPJfWEP7mmKg,15442
3
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
4
- evolutionary_policy_optimization-0.0.10.dist-info/METADATA,sha256=bD3fw2Zw1IxhfkCvzjsRhODyL_XIC5ZsvNQqFbZXNc4,4357
5
- evolutionary_policy_optimization-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- evolutionary_policy_optimization-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- evolutionary_policy_optimization-0.0.10.dist-info/RECORD,,