lt-tensor 0.0.1a10__tar.gz → 0.0.1a12__tar.gz

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (39)
  1. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/PKG-INFO +3 -2
  2. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/__init__.py +2 -0
  3. lt_tensor-0.0.1a12/lt_tensor/config_templates.py +97 -0
  4. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/datasets/audio.py +21 -7
  5. lt_tensor-0.0.1a12/lt_tensor/losses.py +159 -0
  6. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/math_ops.py +1 -1
  7. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/misc_utils.py +94 -7
  8. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_base.py +298 -128
  9. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/__init__.py +2 -2
  10. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/bsc.py +25 -3
  11. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/disc.py +55 -51
  12. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/fsn.py +2 -2
  13. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/gns.py +4 -4
  14. lt_tensor-0.0.1a12/lt_tensor/model_zoo/istft/__init__.py +5 -0
  15. lt_tensor-0.0.1a10/lt_tensor/model_zoo/istft.py → lt_tensor-0.0.1a12/lt_tensor/model_zoo/istft/generator.py +67 -66
  16. lt_tensor-0.0.1a12/lt_tensor/model_zoo/istft/trainer.py +450 -0
  17. lt_tensor-0.0.1a12/lt_tensor/model_zoo/istft.py +591 -0
  18. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/pos.py +2 -2
  19. lt_tensor-0.0.1a12/lt_tensor/model_zoo/rsd.py +107 -0
  20. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/model_zoo/tfrms.py +4 -4
  21. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/noise_tools.py +3 -4
  22. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/processors/audio.py +87 -16
  23. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/transform.py +30 -61
  24. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor.egg-info/PKG-INFO +3 -2
  25. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor.egg-info/SOURCES.txt +4 -0
  26. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor.egg-info/requires.txt +2 -1
  27. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/setup.py +3 -2
  28. lt_tensor-0.0.1a10/lt_tensor/losses.py +0 -145
  29. lt_tensor-0.0.1a10/lt_tensor/model_zoo/rsd.py +0 -237
  30. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/LICENSE +0 -0
  31. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/README.md +0 -0
  32. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/datasets/__init__.py +0 -0
  33. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/lr_schedulers.py +0 -0
  34. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/monotonic_align.py +0 -0
  35. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/processors/__init__.py +0 -0
  36. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/torch_commons.py +0 -0
  37. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor.egg-info/dependency_links.txt +0 -0
  38. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor.egg-info/top_level.txt +0 -0
  39. {lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/setup.cfg +0 -0
{lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: lt-tensor
-Version: 0.0.1a10
+Version: 0.0.1a12
 Summary: General utilities for PyTorch and others. Built for general use.
 Home-page: https://github.com/gr1336/lt-tensor/
 Author: gr1336
@@ -17,8 +17,9 @@ Requires-Dist: numpy>=1.26.4
 Requires-Dist: tokenizers
 Requires-Dist: pyyaml>=6.0.0
 Requires-Dist: numba>0.60.0
-Requires-Dist: lt-utils==0.0.1
+Requires-Dist: lt-utils>=0.0.2a1
 Requires-Dist: librosa>=0.11.0
+Requires-Dist: plotly
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
{lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/__init__.py

@@ -12,6 +12,7 @@ from . import (
     losses,
     processors,
     datasets,
+    torch_commons,
 )

 __all__ = [
@@ -26,4 +27,5 @@ __all__ = [
     "losses",
     "processors",
     "datasets",
+    "torch_commons",
 ]
lt_tensor-0.0.1a12/lt_tensor/config_templates.py (new file)

@@ -0,0 +1,97 @@
+from lt_utils.common import *
+from lt_utils.file_ops import load_json, save_json, FileScan
+from lt_utils.misc_utils import log_traceback, get_current_time
+from lt_utils.type_utils import is_pathlike, is_file, is_dir, is_dict, is_str
+from lt_tensor.misc_utils import updateDict
+
+
+class ModelConfig(ABC, OrderedDict):
+    _default_settings: Dict[str, Any] = {}
+    _forbidden_list: List[str] = [
+        "_settings",
+    ]
+
+    def __init__(
+        self,
+        settings: Dict[str, Any] = None,
+        path_name: Optional[Union[str, PathLike]] = None,
+    ):
+        assert is_dict(settings)
+        self._default_settings = settings
+        if path_name is not None and is_pathlike(path_name):
+            if not str(path_name).endswith(".json"):
+                self.path_name = str(Path(path_name, "config.json")).replace("\\", "/")
+            else:
+                self.path_name = str(path_name).replace("\\", "/")
+        else:
+            self.path_name = "config.json"
+        self.reset_settings()
+
+    def _setup_path_name(self, path_name: Union[str, PathLike]):
+        if is_file(path_name):
+            self.from_path(path_name)
+            self.path_name = str(path_name).replace("\\", "/")
+        elif is_str(path_name):
+            self.path_name = str(path_name).replace("\\", "/")
+            if not self.path_name.endswith((".json")):
+                self.path_name += ".json"
+
+    def reset_settings(self):
+        for s_name, setting in self._default_settings.items():
+            if s_name in self._forbidden_list:
+                continue
+            updateDict(self, {s_name: setting})
+
+    def save_config(
+        self,
+        path_name: Union[PathLike, str],
+    ):
+        assert is_pathlike(
+            path_name, True
+        ), f"path_name should be a non-empty string or pathlike object! received instead: {path_name}"
+        self._setup_path_name(path_name)
+        base = {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
+        save_json(self.path_name, base, indent=2)
+
+    def to_dict(self):
+        return {k: y for k, y in self.__dict__.items() if k not in self._forbidden_list}
+
+    def set_value(self, var_name: str, value: str) -> None:
+        updateDict(self, {var_name: value})
+
+    def get_value(self, var_name: str) -> Any:
+        return self.__dict__.get(var_name)
+
+    @classmethod
+    def from_dict(
+        cls, dictionary: Dict[str, Any], path: Optional[Union[str, PathLike]] = None
+    ) -> "ModelConfig":
+        assert is_dict(dictionary)
+        return ModelConfig(dictionary, path)
+
+    @classmethod
+    def from_path(cls, path_name: PathLike) -> "ModelConfig":
+        assert is_file(path_name) or is_dir(path_name)
+        settings = {}
+
+        if is_file(path_name):
+            settings.update(load_json(path_name, {}, errors="ignore"))
+        else:
+            files = FileScan.files(
+                path_name,
+                [
+                    "*_config.json",
+                    "config_*.json",
+                    "*_config.json",
+                    "cfg_*.json",
+                    "*_cfg.json",
+                    "cfg.json",
+                    "config.json",
+                    "settings.json",
+                    "settings_*.json",
+                    "*_settings.json",
+                ],
+            )
+            assert files, "No config file found in the provided directory!"
+            settings.update(load_json(files[-1], {}, errors="ignore"))
+        return ModelConfig(settings, path_name)
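A minimal usage sketch of the new `ModelConfig` template, for orientation. The setting keys (`sample_rate`, `n_fft`) and file names are invented for illustration; only the methods exercised (`set_value`, `get_value`, `save_config`, `from_path`, `to_dict`) come from the hunk above.

```python
# Hypothetical ModelConfig usage; keys and paths are illustrative only.
from lt_tensor.config_templates import ModelConfig

cfg = ModelConfig({"sample_rate": 24000, "n_fft": 1024})  # defaults applied via reset_settings()
cfg.set_value("n_fft", 2048)           # setattr through updateDict
print(cfg.get_value("n_fft"))          # -> 2048
cfg.save_config("demo_cfg")            # _setup_path_name appends ".json"

# from_path accepts a .json file, or a directory that is scanned for
# config-like names (config.json, *_cfg.json, settings.json, ...).
restored = ModelConfig.from_path("demo_cfg.json")
print(restored.to_dict())
```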
{lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/datasets/audio.py

@@ -6,11 +6,10 @@ from torch.utils.data import Dataset, DataLoader, Sampler
 from lt_tensor.processors import AudioProcessor
 import torch.nn.functional as FT
 from lt_tensor.misc_utils import log_tensor
+from tqdm import tqdm


 class WaveMelDataset(Dataset):
-    """Untested!"""
-
     data: Union[list[dict[str, Tensor]], Tuple[Tensor, Tensor]] = []

     def __init__(
@@ -19,12 +18,16 @@ class WaveMelDataset(Dataset):
         path: PathLike,
         limit_files: Optional[int] = None,
         max_frame_length: Optional[int] = None,
+        randomize_ranges: bool = False,
     ):
         super().__init__()
         assert max_frame_length is None or max_frame_length >= (
             (audio_processor.n_fft // 2) + 1
         )
+
         self.post_n_fft = (audio_processor.n_fft // 2) + 1
+        if max_frame_length is not None:
+            self.r_range = max(self.post_n_fft + 1, max_frame_length // 3)
         self.ap = audio_processor
         self.files = self.ap.find_audios(path)
         if limit_files:
@@ -32,21 +35,32 @@ class WaveMelDataset(Dataset):
             self.files = self.files[:limit_files]
         self.data = []

-        for file in self.files:
-            results = self.load_data(file, max_frame_length)
+        for file in tqdm(self.files, "Loading files"):
+            results = self.load_data(file, max_frame_length, randomize_ranges)
             self.data.extend(results)

     def _add_dict(self, audio_raw: Tensor, audio_mel: Tensor, file: PathLike):
         return {"mel": audio_mel, "raw": audio_raw, "file": file}

-    def load_data(self, file: PathLike, audio_frames_limit: Optional[int] = None):
-        initial_audio = self.ap.load_audio(file)
+    def load_data(
+        self,
+        file: PathLike,
+        audio_frames_limit: Optional[int] = None,
+        randomize_ranges: bool = False,
+    ):
+        initial_audio = self.ap.rebuild_spectrogram(self.ap.load_audio(file))
         if not audio_frames_limit or initial_audio.shape[-1] <= audio_frames_limit:
+            if initial_audio.shape[-1] < self.post_n_fft:
+                return []
             audio_mel = self.ap.compute_mel(initial_audio, add_base=True)
             return [self._add_dict(initial_audio, audio_mel, file)]
         results = []
+        if randomize_ranges:
+            frame_limit = random.randint(self.r_range, audio_frames_limit)
+        else:
+            frame_limit = audio_frames_limit
         for fragment in torch.split(
-            initial_audio, split_size_or_sections=audio_frames_limit, dim=-1
+            initial_audio, split_size_or_sections=frame_limit, dim=-1
         ):
             if fragment.shape[-1] < self.post_n_fft:
                 # sometimes the tensor will be too small to be able to pass on mel
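A hedged construction sketch for the updated `WaveMelDataset`. The keyword names match the `__init__` signature above, but the default `AudioProcessor()` construction and the data path are assumptions not taken from this diff.

```python
# Sketch only: AudioProcessor's constructor arguments are not shown in this diff.
from torch.utils.data import DataLoader
from lt_tensor.processors import AudioProcessor
from lt_tensor.datasets.audio import WaveMelDataset

ap = AudioProcessor()  # assumed default construction
dataset = WaveMelDataset(
    ap,
    path="data/wavs",        # directory scanned by ap.find_audios()
    limit_files=100,         # optional cap on the number of files loaded
    max_frame_length=1024,   # longer clips are split into fragments
    randomize_ranges=True,   # new flag: per-file fragment size is drawn from
                             # [r_range, max_frame_length]
)
# Items are dicts {"mel": ..., "raw": ..., "file": ...} with variable lengths,
# so batch_size > 1 would need a custom collate_fn.
loader = DataLoader(dataset, batch_size=1, shuffle=True)
```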
lt_tensor-0.0.1a12/lt_tensor/losses.py (new file)

@@ -0,0 +1,159 @@
+__all__ = [
+    "masked_cross_entropy",
+    "adaptive_l1_loss",
+    "contrastive_loss",
+    "smooth_l1_loss",
+    "hybrid_loss",
+    "diff_loss",
+    "cosine_loss",
+    "gan_loss",
+    "ft_n_loss",
+]
+import math
+import random
+from lt_tensor.torch_commons import *
+from lt_utils.common import *
+import torch.nn.functional as F
+
+def ft_n_loss(output: Tensor, target: Tensor, weight: Optional[Tensor] = None):
+    if weight is not None:
+        return torch.mean((torch.abs(output - target) + weight) **0.5)
+    return torch.mean(torch.abs(output - target)**0.5)
+
+def adaptive_l1_loss(
+    inp: Tensor,
+    tgt: Tensor,
+    weight: Optional[Tensor] = None,
+    scale: float = 1.0,
+    inverted: bool = False,
+):
+
+    if weight is not None:
+        loss = torch.mean(torch.abs((inp - tgt) + weight.mean()))
+    else:
+        loss = torch.mean(torch.abs(inp - tgt))
+    loss *= scale
+    if inverted:
+        return -loss
+    return loss
+
+
+def smooth_l1_loss(inp: Tensor, tgt: Tensor, beta=1.0, weight=None):
+    diff = torch.abs(inp - tgt)
+    loss = torch.where(diff < beta, 0.5 * diff**2 / beta, diff - 0.5 * beta)
+    if weight is not None:
+        loss *= weight
+    return loss.mean()
+
+
+def contrastive_loss(x1: Tensor, x2: Tensor, label: Tensor, margin: float = 1.0):
+    # label == 1: similar, label == 0: dissimilar
+    dist = torch.nn.functional.pairwise_distance(x1, x2)
+    loss = label * dist**2 + (1 - label) * torch.clamp(margin - dist, min=0.0) ** 2
+    return loss.mean()
+
+
+def cosine_loss(inp, tgt):
+    cos = torch.nn.functional.cosine_similarity(inp, tgt, dim=-1)
+    return 1 - cos.mean()  # Lower is better
+
+
+class GanLosses:
+    @staticmethod
+    def get_loss(
+        pred: Tensor,
+        target_is_real: bool,
+        loss_type: Literal["bce", "mse", "hinge", "wasserstein"] = "bce",
+    ) -> Tensor:
+        if loss_type == "bce":  # Standard GAN
+            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
+            return F.binary_cross_entropy_with_logits(pred, target)
+
+        elif loss_type == "mse":  # LSGAN
+            target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
+            return F.mse_loss(torch.sigmoid(pred), target)
+
+        elif loss_type == "hinge":
+            if target_is_real:
+                return torch.mean(F.relu(1.0 - pred))
+            else:
+                return torch.mean(F.relu(1.0 + pred))
+
+        elif loss_type == "wasserstein":
+            return -pred.mean() if target_is_real else pred.mean()
+
+        else:
+            raise ValueError(f"Unknown loss_type: {loss_type}")
+
+    @staticmethod
+    def generator_loss(fake_pred: Tensor, loss_type: str = "bce") -> Tensor:
+        return GanLosses.get_loss(fake_pred, target_is_real=True, loss_type=loss_type)
+
+    @staticmethod
+    def discriminator_loss(
+        real_pred: Tensor, fake_pred: Tensor, loss_type: str = "bce"
+    ) -> Tensor:
+        real_loss = GanLosses.get_loss(
+            real_pred, target_is_real=True, loss_type=loss_type
+        )
+        fake_loss = GanLosses.get_loss(
+            fake_pred.detach(), target_is_real=False, loss_type=loss_type
+        )
+        return (real_loss + fake_loss) * 0.5
+
+
+def masked_cross_entropy(
+    logits: torch.Tensor,  # [B, T, V]
+    targets: torch.Tensor,  # [B, T]
+    lengths: torch.Tensor,  # [B]
+    reduction: str = "mean",
+) -> torch.Tensor:
+    """
+    CrossEntropyLoss with masking for variable-length sequences.
+    - logits: unnormalized scores [B, T, V]
+    - targets: ground truth indices [B, T]
+    - lengths: actual sequence lengths [B]
+    """
+    B, T, V = logits.size()
+    logits = logits.view(-1, V)
+    targets = targets.view(-1)
+
+    # Create mask
+    mask = torch.arange(T, device=lengths.device).expand(B, T) < lengths.unsqueeze(1)
+    mask = mask.reshape(-1)
+
+    # Apply CE only where mask == True
+    loss = F.cross_entropy(
+        logits[mask], targets[mask], reduction="mean" if reduction == "mean" else "none"
+    )
+    if reduction == "none":
+        return loss
+    return loss
+
+
+def diff_loss(pred_noise, true_noise, mask=None):
+    """Standard diffusion noise-prediction loss (e.g., DDPM)"""
+    if mask is not None:
+        return F.mse_loss(pred_noise * mask, true_noise * mask)
+    return F.mse_loss(pred_noise, true_noise)
+
+
+def hybrid_diff_loss(pred_noise, true_noise, alpha=0.5):
+    """Combines L1 and L2"""
+    l1 = F.l1_loss(pred_noise, true_noise)
+    l2 = F.mse_loss(pred_noise, true_noise)
+    return alpha * l1 + (1 - alpha) * l2


+def gan_d_loss(real_preds, fake_preds, use_lsgan=True):
+    loss = 0
+    for real, fake in zip(real_preds, fake_preds):
+        if use_lsgan:
+            loss += F.mse_loss(real, torch.ones_like(real)) + F.mse_loss(
+                fake, torch.zeros_like(fake)
+            )
+        else:
+            loss += -torch.mean(torch.log(real + 1e-7)) - torch.mean(
+                torch.log(1 - fake + 1e-7)
+            )
+    return loss
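A short smoke test of two of the new loss helpers. Shapes and values below are illustrative; `masked_cross_entropy` and `GanLosses` are imported by name, since `GanLosses` is not listed in the module's `__all__`.

```python
# Illustrative shapes only.
import torch
from lt_tensor.losses import masked_cross_entropy, GanLosses

B, T, V = 4, 12, 32
logits = torch.randn(B, T, V)            # [B, T, V] unnormalized scores
targets = torch.randint(0, V, (B, T))    # [B, T] class indices
lengths = torch.tensor([12, 9, 7, 3])    # positions beyond these are masked out
ce = masked_cross_entropy(logits, targets, lengths)

# Hinge loss for the discriminator, BCE-with-logits for the generator.
real_pred = torch.randn(B, 1)
fake_pred = torch.randn(B, 1)
d_loss = GanLosses.discriminator_loss(real_pred, fake_pred, loss_type="hinge")
g_loss = GanLosses.generator_loss(fake_pred, loss_type="bce")
```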
{lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/math_ops.py

@@ -12,7 +12,7 @@ __all__ = [
     "phase",
 ]

-from .torch_commons import *
+from lt_tensor.torch_commons import *


 def sin_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
{lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/misc_utils.py

@@ -15,11 +15,14 @@ __all__ = [
     "TorchCacheUtils",
     "clear_cache",
     "default_device",
+    "soft_restore",
     "Packing",
     "Padding",
     "Masking",
     "LogTensor",
     "get_losses",
+    "plot_view",
+    "get_weights",
 ]

 import re
@@ -27,22 +30,104 @@ import gc
 import sys
 import random
 import numpy as np
-from lt_utils.type_utils import is_str
-from .torch_commons import *
+import warnings
+from lt_utils.type_utils import is_str, is_dir, is_file, is_pathlike, is_path_valid
+from lt_utils.file_ops import FileScan, find_files, path_to_str, load_json, load_yaml
+from lt_tensor.torch_commons import *
 from lt_utils.misc_utils import cache_wrapper
 from lt_utils.common import *
 from lt_utils.misc_utils import ff_list
 import torch.nn.functional as F


+def plot_view(
+    data: Dict[str, List[Any]],
+    title: str = "Loss",
+    max_amount: int = 0,
+    xaxis_title="Step/Epoch",
+    yaxis_title="Loss",
+    template="plotly_dark",
+):
+    try:
+        import plotly.graph_objs as go
+    except ModuleNotFoundError:
+        warnings.warn(
+            "No installation of plotly was found. To use it use 'pip install plotly' and restart this application!"
+        )
+        return
+    fig = go.Figure()
+    for mode, values in data.items():
+        if values:
+            items = values if not max_amount > 0 else values[-max_amount:]
+            fig.add_trace(go.Scatter(y=items, name=mode.capitalize()))
+    fig.update_layout(
+        title=title,
+        xaxis_title=xaxis_title,
+        yaxis_title=yaxis_title,
+        template=template,
+    )
+    return fig
+
+
+def get_weights(directory: Union[str, PathLike]):
+    is_path_valid(directory, validate=True)  # raises validation if its invalid path
+    directory = Path(directory)
+    if is_file(directory):
+        if directory.name.endswith((".pt", ".ckpt", ".pth")):
+            return directory
+        directory = directory.parent
+    res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
+    return res[-1] if res else None
+
+
+def get_config(directory: Union[str, PathLike], default: Optional[Any] = None):
+    # raises validation if its invalid path only when default is None otherwise it returns the defaults.
+    if not is_path_valid(directory, validate=default is None):
+        return default
+    directory = Path(directory)
+    if is_file(directory):
+        if directory.name.endswith((".json", ".yaml", ".yml")):
+            if directory.name.endswith(".json"):
+                return load_json(directory, default)
+            return load_yaml(directory, default)
+        directory = directory.parent
+    res = sorted(find_files(directory, ["*.pt", "*.ckpt", "*.pth"]))
+    if res:
+        res = res[-1]
+        if Path(res).name.endswith(".json"):
+            return load_json(directory, default)
+        return load_yaml(directory, default)
+    return default
+
+
+def updateDict(self, dct: dict[str, Any]):
+    for k, v in dct.items():
+        setattr(self, k, v)
+
+
+def soft_restore(tensor, epsilon=1e-6):
+    return torch.where(tensor == 0, torch.full_like(tensor, epsilon), tensor)
+
+
 def try_torch(fn: str, *args, **kwargs):
+    tryed_torch = False
+    not_present_message = (
+        f"Both `torch` and `torch.nn.functional` does not contain the module `{fn}`"
+    )
     try:
-        return getattr(F, fn)(*args, **kwargs)
-    except Exception as e:
-        try:
+        if hasattr(F, fn):
+            return getattr(F, fn)(*args, **kwargs)
+        elif hasattr(torch, fn):
+            tryed_torch = True
             return getattr(torch, fn)(*args, **kwargs)
+        return not_present_message
+    except Exception as a:
+        try:
+            if not tryed_torch and hasattr(torch, fn):
+                return getattr(torch, fn)(*args, **kwargs)
+            return str(a)
         except Exception as e:
-            return str(e)
+            return str(e) + " | " + str(a)


 def log_tensor(
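A sketch exercising the new `misc_utils` helpers; the loss history and checkpoint directory below are made up for the example.

```python
# Fabricated history/paths, purely for illustration.
import torch
from lt_tensor.misc_utils import plot_view, get_weights, soft_restore

history = {"train": [1.02, 0.81, 0.65, 0.52], "eval": [1.10, 0.93, 0.77]}
fig = plot_view(history, title="Loss", max_amount=1000)  # returns None if plotly is absent
if fig is not None:
    fig.show()

# Picks the (sorted) last *.pt/*.ckpt/*.pth under an existing directory;
# raises if the path is invalid.
latest = get_weights("checkpoints/")

# Replaces exact zeros with epsilon, e.g. to guard a later division.
safe = soft_restore(torch.tensor([0.0, 2.0]), epsilon=1e-6)
```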
{lt_tensor-0.0.1a10 → lt_tensor-0.0.1a12}/lt_tensor/misc_utils.py (continued)

@@ -192,6 +277,7 @@ class LogTensor:
             ("max", dict(dim=-1)),
         ],
         validate_item_type: bool = False,
+        exclude_invalid_losses: bool = True,
         **kwargs,
     ):
         invalid_type = not isinstance(inputs, (Tensor, np.ndarray, list, tuple))
@@ -243,12 +329,13 @@ class LogTensor:
                 value = try_torch(log_fn, inputs)
             else:
                 value = try_torch(log_fn, inputs, log_args)
+
             results = self._process(log_fn, value)
             current_register[log_fn] = results
         self.do_print = old_print

         if target is not None:
-            losses = get_losses(inputs, target, False)
+            losses = get_losses(inputs, target, exclude_invalid_losses)
             started_ls = False
             if self.do_print:
                 for loss, res in losses.items():