dreamer4 0.0.91__tar.gz → 0.0.92__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.91
3
+ Version: 0.0.92
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -11,7 +11,7 @@ from dataclasses import dataclass, asdict
11
11
  import torch
12
12
  import torch.nn.functional as F
13
13
  from torch.nested import nested_tensor
14
- from torch.distributions import Normal
14
+ from torch.distributions import Normal, kl
15
15
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
16
16
  from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
17
17
  from torch.utils._pytree import tree_flatten, tree_unflatten
@@ -198,6 +198,14 @@ def masked_mean(t, mask = None):
198
198
  def log(t, eps = 1e-20):
199
199
  return t.clamp(min = eps).log()
200
200
 
201
+ def mean_log_var_to_distr(
202
+ mean_log_var: Tensor
203
+ ) -> Normal:
204
+
205
+ mean, log_var = mean_log_var.unbind(dim = -1)
206
+ std = (0.5 * log_var).exp()
207
+ return Normal(mean, std)
208
+
201
209
  def safe_cat(tensors, dim):
202
210
  tensors = [*filter(exists, tensors)]
203
211
 
@@ -824,10 +832,7 @@ class ActionEmbedder(Module):
824
832
  continuous_entropies = None
825
833
 
826
834
  if exists(continuous_targets):
827
- mean, log_var = continuous_action_mean_log_var.unbind(dim = -1)
828
- std = (0.5 * log_var).exp()
829
-
830
- distr = Normal(mean, std)
835
+ distr = mean_log_var_to_distr(continuous_action_mean_log_var)
831
836
  continuous_log_probs = distr.log_prob(continuous_targets)
832
837
 
833
838
  if return_entropies:
@@ -842,6 +847,54 @@ class ActionEmbedder(Module):
842
847
 
843
848
  return log_probs, entropies
844
849
 
850
+ def kl_div(
851
+ self,
852
+ src: tuple[Tensor | None, Tensor | None],
853
+ tgt: tuple[Tensor | None, Tensor | None]
854
+ ) -> tuple[Tensor | None, Tensor | None]:
855
+
856
+ src_discrete, src_continuous = src
857
+ tgt_discrete, tgt_continuous = tgt
858
+
859
+ discrete_kl_div = None
860
+
861
+ # split discrete if it is not already (multiple discrete actions)
862
+
863
+ if exists(src_discrete):
864
+
865
+ discrete_split = self.num_discrete_actions.tolist()
866
+
867
+ if is_tensor(src_discrete):
868
+ src_discrete = src_discrete.split(discrete_split, dim = -1)
869
+
870
+ if is_tensor(tgt_discrete):
871
+ tgt_discrete = tgt_discrete.split(discrete_split, dim = -1)
872
+
873
+ discrete_kl_divs = []
874
+
875
+ for src_logit, tgt_logit in zip(src_discrete, tgt_discrete):
876
+
877
+ src_log_probs = src_logit.log_softmax(dim = -1)
878
+ tgt_prob = tgt_logit.softmax(dim = -1)
879
+
880
+ one_discrete_kl_div = F.kl_div(src_log_probs, tgt_prob, reduction = 'none')
881
+
882
+ discrete_kl_divs.append(one_discrete_kl_div.sum(dim = -1))
883
+
884
+ discrete_kl_div = stack(discrete_kl_divs, dim = -1)
885
+
886
+ # calculate kl divergence for continuous
887
+
888
+ continuous_kl_div = None
889
+
890
+ if exists(src_continuous):
891
+ src_normal = mean_log_var_to_distr(src_continuous)
892
+ tgt_normal = mean_log_var_to_distr(tgt_continuous)
893
+
894
+ continuous_kl_div = kl.kl_divergence(src_normal, tgt_normal)
895
+
896
+ return discrete_kl_div, continuous_kl_div
897
+
845
898
  def forward(
846
899
  self,
847
900
  *,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dreamer4"
3
- version = "0.0.91"
3
+ version = "0.0.92"
4
4
  description = "Dreamer 4"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -346,6 +346,15 @@ def test_action_embedder():
346
346
  assert discrete_logits.shape == (2, 3, 8)
347
347
  assert continuous_mean_log_var.shape == (2, 3, 2, 2)
348
348
 
349
+ # test kl div
350
+
351
+ discrete_logits_tgt, continuous_mean_log_var_tgt = embedder.unembed(action_embed)
352
+
353
+ discrete_kl_div, continuous_kl_div = embedder.kl_div((discrete_logits, continuous_mean_log_var), (discrete_logits_tgt, continuous_mean_log_var_tgt))
354
+
355
+ assert discrete_kl_div.shape == (2, 3, 2)
356
+ assert continuous_kl_div.shape == (2, 3, 2)
357
+
349
358
  # return discrete split by number of actions
350
359
 
351
360
  discrete_logits, continuous_mean_log_var = embedder.unembed(action_embed, return_split_discrete = True)
File without changes
File without changes
File without changes
File without changes
File without changes