metacontroller-pytorch 0.0.21__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- metacontroller/metacontroller.py +7 -10
- {metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.22.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.22.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.21.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.22.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.22.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -46,9 +46,6 @@ def default(*args):
|
|
|
46
46
|
return arg
|
|
47
47
|
return None
|
|
48
48
|
|
|
49
|
-
def is_empty(t):
|
|
50
|
-
return t.numel() == 0
|
|
51
|
-
|
|
52
49
|
def pad_at_dim(t, pad: tuple[int, int], dim = -1, value = 0.):
|
|
53
50
|
if pad == (0, 0):
|
|
54
51
|
return t
|
|
@@ -333,7 +330,7 @@ class Transformer(Module):
|
|
|
333
330
|
def forward(
|
|
334
331
|
self,
|
|
335
332
|
state,
|
|
336
|
-
|
|
333
|
+
actions: Tensor | None = None,
|
|
337
334
|
meta_controller: Module | None = None,
|
|
338
335
|
cache: TransformerOutput | None = None,
|
|
339
336
|
discovery_phase = False,
|
|
@@ -363,10 +360,10 @@ class Transformer(Module):
|
|
|
363
360
|
# handle maybe behavioral cloning
|
|
364
361
|
|
|
365
362
|
if behavioral_cloning or (meta_controlling and discovery_phase):
|
|
366
|
-
assert
|
|
363
|
+
assert exists(actions), f'`actions` cannot be empty when doing discovery or behavioral cloning'
|
|
367
364
|
|
|
368
365
|
state, target_state = state[:, :-1], state[:, 1:]
|
|
369
|
-
|
|
366
|
+
actions, target_actions = actions[:, :-1], actions[:, 1:]
|
|
370
367
|
|
|
371
368
|
# transformer lower body
|
|
372
369
|
|
|
@@ -376,8 +373,8 @@ class Transformer(Module):
|
|
|
376
373
|
|
|
377
374
|
# handle no past action for first timestep
|
|
378
375
|
|
|
379
|
-
if exists(
|
|
380
|
-
action_embed = self.action_embed(
|
|
376
|
+
if exists(actions):
|
|
377
|
+
action_embed = self.action_embed(actions)
|
|
381
378
|
else:
|
|
382
379
|
action_embed = state_embed[:, 0:0] # empty action embed
|
|
383
380
|
|
|
@@ -415,13 +412,13 @@ class Transformer(Module):
|
|
|
415
412
|
state_dist_params = self.state_readout(attended)
|
|
416
413
|
state_clone_loss = self.state_readout.calculate_loss(state_dist_params, target_state)
|
|
417
414
|
|
|
418
|
-
action_clone_loss = self.action_readout.calculate_loss(dist_params,
|
|
415
|
+
action_clone_loss = self.action_readout.calculate_loss(dist_params, target_actions)
|
|
419
416
|
|
|
420
417
|
return state_clone_loss, action_clone_loss
|
|
421
418
|
|
|
422
419
|
elif meta_controlling and discovery_phase:
|
|
423
420
|
|
|
424
|
-
action_recon_loss = self.action_readout.calculate_loss(dist_params,
|
|
421
|
+
action_recon_loss = self.action_readout.calculate_loss(dist_params, target_actions)
|
|
425
422
|
|
|
426
423
|
return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
|
|
427
424
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=kmfq4XSeh_tT30n3ctG1ggTgA0hVtV2eNKMqKr_jChc,14084
|
|
3
|
+
metacontroller_pytorch-0.0.22.dist-info/METADATA,sha256=zxsjz3enrR60enivkdcCYLd2Xjpg8d5EUydakRjYOAQ,4320
|
|
4
|
+
metacontroller_pytorch-0.0.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.22.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=EP2N1Qtw4WTNthQrMz6bBT9rxTtMFikdOyYtcwSPdHM,14167
|
|
3
|
-
metacontroller_pytorch-0.0.21.dist-info/METADATA,sha256=scUJVoSZ6Tl3RYNiNjK_wIeWVrpVLbQhya-XkCqdieQ,4320
|
|
4
|
-
metacontroller_pytorch-0.0.21.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.21.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.21.dist-info → metacontroller_pytorch-0.0.22.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|