lt-tensor 0.0.1.dev0__py3-none-any.whl
- lt_tensor/__init__.py +0 -0
- lt_tensor/_basics.py +244 -0
- lt_tensor/_torch_commons.py +12 -0
- lt_tensor/lr_schedulers.py +108 -0
- lt_tensor/math_ops.py +120 -0
- lt_tensor/misc_utils.py +596 -0
- lt_tensor/model_zoo/__init__.py +2 -0
- lt_tensor/model_zoo/basic.py +65 -0
- lt_tensor/model_zoo/diffusion/__init__.py +6 -0
- lt_tensor/model_zoo/diffusion/models.py +114 -0
- lt_tensor/model_zoo/residual.py +236 -0
- lt_tensor/model_zoo/transformer_models/__init__.py +6 -0
- lt_tensor/model_zoo/transformer_models/models.py +132 -0
- lt_tensor/model_zoo/transformer_models/positional_encoders.py +95 -0
- lt_tensor/monotonic_align.py +70 -0
- lt_tensor/transform.py +113 -0
- lt_tensor-0.0.1.dev0.dist-info/METADATA +33 -0
- lt_tensor-0.0.1.dev0.dist-info/RECORD +21 -0
- lt_tensor-0.0.1.dev0.dist-info/WHEEL +5 -0
- lt_tensor-0.0.1.dev0.dist-info/licenses/LICENSE +201 -0
- lt_tensor-0.0.1.dev0.dist-info/top_level.txt +1 -0
lt_tensor/__init__.py
ADDED
File without changes
lt_tensor/_basics.py
ADDED
@@ -0,0 +1,244 @@
__all__ = ["Model"]


import warnings
from ._torch_commons import *

ROOT_DEVICE = torch.device(torch.zeros(1).device)

class _ModelDevice(nn.Module):
    """
    This makes it easier to assign a device and retrieve it later.
    """

    _device: torch.device = ROOT_DEVICE

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device: Union[torch.device, str]):
        assert isinstance(device, (str, torch.device))
        self._device = torch.device(device) if isinstance(device, str) else device
        self.tp_apply_device_to()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def tp_apply_device_to(self):
        """Add here components that need the device applied to them,
        which the '.to()' function usually fails to apply.

        example:
        ```
        def tp_apply_device_to(self):
            self.my_tensor = self.my_tensor.to(device=self.device)
        ```
        """
        pass

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if dtype is not None:
            if not (dtype.is_floating_point or dtype.is_complex):
                raise TypeError(
                    "nn.Module.to only accepts floating point or complex "
                    f"dtypes, but got desired dtype={dtype}"
                )
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected."
                )

        def convert(t: Tensor):
            try:
                if convert_to_format is not None and t.dim() in (4, 5):
                    return t.to(
                        device,
                        dtype if t.is_floating_point() or t.is_complex() else None,
                        non_blocking,
                        memory_format=convert_to_format,
                    )
                return t.to(
                    device,
                    dtype if t.is_floating_point() or t.is_complex() else None,
                    non_blocking,
                )
            except NotImplementedError as e:
                if str(e) == "Cannot copy out of meta tensor; no data!":
                    raise NotImplementedError(
                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                        f"when moving module from meta to a different device."
                    ) from None
                else:
                    raise

        self._apply(convert)
        if device is not None:  # .to(dtype=...) may be called without a device
            self.device = device
        return self

    def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
        super().ipu(device)
        dvc = "ipu"
        if device is not None:
            # append the index so the stored device reads e.g. "ipu:0"
            dvc += ":" + str(device if isinstance(device, (int, float)) else device.index)
        self.device = dvc
        return self

    def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
        super().xpu(device)
        dvc = "xpu"
        if device is not None:
            dvc += ":" + str(device if isinstance(device, (int, float)) else device.index)
        self.device = dvc
        return self

    def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
        super().cuda(device)
        dvc = "cuda"
        if device is not None:
            dvc += ":" + str(device if isinstance(device, (int, float)) else device.index)
        self.device = dvc
        return self

    def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
        super().mtia(device)
        dvc = "mtia"
        if device is not None:
            dvc += ":" + str(device if isinstance(device, (int, float)) else device.index)
        self.device = dvc
        return self

    def cpu(self) -> T:
        super().cpu()
        self.device = "cpu"
        return self

class Model(_ModelDevice, ABC):
    """
    For easier access to items in a model.
    This applies to everything except "forward", "save_weights", "load_weights",
    "train_step", and "inference".
    """

    _autocast: bool = False

    @property
    def autocast(self):
        return self._autocast

    @autocast.setter
    def autocast(self, value: bool):
        self._autocast = value

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Pop the custom keywords before forwarding the rest, since
        # nn.Module.__init__ rejects unexpected keyword arguments.
        device = kwargs.pop("device", None)
        autocast = kwargs.pop("autocast", None)
        super().__init__(*args, **kwargs)
        if isinstance(device, (torch.device, str)):
            self.to(device)
        if autocast is not None:
            self.autocast = autocast

    def tp_count_parameters(self, module_name: Optional[str] = None):
        """Gets the number of parameters from either the entire model or from a specific module."""
        if module_name is not None:
            assert hasattr(self, module_name), f"Module {module_name} does not exist"
            module = getattr(self, module_name)
            return sum([x.numel() for x in module.parameters()])
        else:
            return sum([x.numel() for x in self.parameters()])

    def tp_get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
        """Returns the weights of the entire model or of a specified module."""
        if module_name is not None:
            assert hasattr(self, module_name), f"Module {module_name} does not exist"
            module = getattr(self, module_name)
            return [x.data.detach() for x in module.parameters()]
        return [x.data.detach() for x in self.parameters()]

    def tp_print_num_params(self, module_name: Optional[str] = None) -> None:
        params = format(self.tp_count_parameters(module_name), ",").replace(",", ".")

        if module_name:
            print(f'Parameters from "{module_name}": {params}')
        else:
            print(f"Parameters: {params}")

    def save_weights(self, path: Union[Path, str]):
        path = Path(path)
        if path.exists():
            assert (
                path.is_file()
            ), "The provided path exists but it is a directory, not a file!"
            path.unlink()  # remove the existing checkpoint file before overwriting
        path.parent.mkdir(exist_ok=True, parents=True)
        torch.save(obj=self.state_dict(), f=str(path))

    def load_weights(
        self,
        path: Union[Path, str],
        raise_if_not_exists: bool = False,
        strict: bool = True,
        assign: bool = False,
        weights_only: bool = False,
        mmap: Optional[bool] = None,
        **torch_loader_kwargs,
    ):
        path = Path(path)
        if not path.exists():
            assert not raise_if_not_exists, "Path does not exist!"
            return None
        assert path.is_file(), "The provided path is not a valid file!"
        state_dict = torch.load(
            str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
        )
        incompatible_keys = self.load_state_dict(
            state_dict,
            strict=strict,
            assign=assign,
        )
        return incompatible_keys

    @torch.no_grad()
    def inference(self, *args, **kwargs):
        if self.training:
            self.eval()
        if self.autocast:
            with torch.autocast(device_type=self.device.type):
                return self(*args, **kwargs)
        return self(*args, **kwargs)

    def train_step(
        self,
        *args,
        **kwargs,
    ):
        """Train Step"""
        if not self.training:
            self.train()
        return self(*args, **kwargs)

    @abstractmethod
    def forward(
        self, *args, **kwargs
    ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
        pass
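The `Model` base class above is abstract: a subclass implements `forward` and inherits the device handling, weight saving/loading, `inference`, and `train_step` helpers. A minimal sketch of how a subclass might be used (the `TinyRegressor` name, layer sizes, and checkpoint path are illustrative, not part of the package):

```python
import torch
from torch import nn
from lt_tensor._basics import Model


class TinyRegressor(Model):
    def __init__(self, in_dim: int = 8, out_dim: int = 1, **kwargs):
        super().__init__(**kwargs)  # may carry "device" and/or "autocast"
        self.net = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = TinyRegressor(in_dim=8, device="cpu")
model.tp_print_num_params()                # prints the parameter count
out = model.inference(torch.randn(4, 8))   # switches to eval() and runs under no_grad
model.save_weights("checkpoints/tiny.pt")  # creates parent directories as needed
model.load_weights("checkpoints/tiny.pt")
```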
lt_tensor/_torch_commons.py
ADDED
@@ -0,0 +1,12 @@
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.nn import Module, L1Loss, MSELoss
from torch.nn.utils import remove_weight_norm
from torch import Tensor, FloatTensor, device, LongTensor
from torch.nn.utils.parametrizations import weight_norm, spectral_norm

from lt_utils.common import *

DeviceType: TypeAlias = Union[device, str]
lt_tensor/lr_schedulers.py
ADDED
@@ -0,0 +1,108 @@
import math
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


class WarmupDecayScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        total_steps: int,
        decay_type: str = "linear",  # or "cosine"
        min_lr: float = 0.0,
        last_epoch: int = -1,
    ):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_type = decay_type
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        warmup = self.warmup_steps
        total = self.total_steps
        lrs = []

        for base_lr in self.base_lrs:
            if step < warmup:
                lr = base_lr * step / warmup
            else:
                progress = (step - warmup) / max(1, total - warmup)
                if self.decay_type == "linear":
                    lr = base_lr * (1.0 - progress)
                elif self.decay_type == "cosine":
                    lr = base_lr * 0.5 * (1 + math.cos(math.pi * progress))
                else:
                    raise ValueError(f"Unknown decay type: {self.decay_type}")

            lr = max(self.min_lr, lr)
            lrs.append(lr)

        return lrs


class AdaptiveDropScheduler(_LRScheduler):
    def __init__(
        self,
        optimizer,
        drop_factor=0.5,
        patience=10,
        min_lr=1e-6,
        cooldown=5,
        last_epoch=-1,
    ):
        self.drop_factor = drop_factor
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.best_loss = float("inf")
        self.bad_steps = 0
        super().__init__(optimizer, last_epoch)

    def step(self, val_loss=None):
        if val_loss is not None:
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.bad_steps = 0
                self.cooldown_counter = 0
            else:
                self.bad_steps += 1
                if self.bad_steps >= self.patience and self.cooldown_counter == 0:
                    for i, group in enumerate(self.optimizer.param_groups):
                        new_lr = max(group["lr"] * self.drop_factor, self.min_lr)
                        group["lr"] = new_lr
                    self.cooldown_counter = self.cooldown
                    self.bad_steps = 0
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]


class WaveringLRScheduler(_LRScheduler):
    def __init__(
        self, optimizer, base_lr, max_lr, period=1000, decay=0.999, last_epoch=-1
    ):
        """
        Sinusoidal-like oscillating LR. Can escape shallow local minima.
        - base_lr: minimum LR
        - max_lr: maximum LR
        - period: full sine cycle in steps
        - decay: multiplies max_lr each cycle
        """
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.period = period
        self.decay = decay
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        cycle = self.last_epoch // self.period
        step_in_cycle = self.last_epoch % self.period
        factor = math.sin(math.pi * step_in_cycle / self.period)
        amplitude = (self.max_lr - self.base_lr) * (self.decay**cycle)
        return [self.base_lr + amplitude * factor for _ in self.optimizer.param_groups]
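These schedulers follow the usual `torch.optim.lr_scheduler` pattern: construct one around an optimizer and call `step()` once per training step. A usage sketch for `WarmupDecayScheduler` (the model, optimizer, step counts, and learning rates below are illustrative):

```python
import torch
from lt_tensor.lr_schedulers import WarmupDecayScheduler

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupDecayScheduler(
    optimizer,
    warmup_steps=500,    # LR ramps linearly from 0 up to 1e-3 over the first 500 steps
    total_steps=10_000,  # then decays toward min_lr by step 10,000
    decay_type="cosine",
    min_lr=1e-6,
)

for step in range(10_000):
    # forward pass, loss.backward() would go here
    optimizer.step()
    scheduler.step()
```

`AdaptiveDropScheduler` is driven differently: its `step()` expects a validation loss (`scheduler.step(val_loss)`) and multiplies the LR by `drop_factor` after `patience` non-improving calls, with a cooldown between drops.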
lt_tensor/math_ops.py
ADDED
@@ -0,0 +1,120 @@
from ._torch_commons import *


def sin_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
    """Applies sine function element-wise."""
    return torch.sin(x * freq)


def cos_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
    """Applies cosine function element-wise."""
    return torch.cos(x * freq)


def sin_plus_cos(x: Tensor, freq: float = 1.0) -> Tensor:
    """Returns sin(x) + cos(x)."""
    return torch.sin(x * freq) + torch.cos(x * freq)


def sin_times_cos(x: Tensor, freq: float = 1.0) -> Tensor:
    """Returns sin(x) * cos(x)."""
    return torch.sin(x * freq) * torch.cos(x * freq)


def apply_window(x: Tensor, window_type: str = "hann") -> Tensor:
    """Applies a window function to a 1D tensor."""
    if window_type == "hann":
        window = torch.hann_window(x.shape[-1], device=x.device)
    elif window_type == "hamming":
        window = torch.hamming_window(x.shape[-1], device=x.device)
    else:
        raise ValueError(f"Unsupported window type: {window_type}")
    return x * window


def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
    """Circularly shifts tensor values: last becomes first (along given dim)."""
    return torch.roll(x, shifts=1, dims=dim)


def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
    """Computes dot product along the specified dimension."""
    return torch.sum(x * y, dim=dim)


def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
    """Normalizes a tensor to unit norm (L2)."""
    return x / (torch.norm(x, dim=-1, keepdim=True) + eps)


def stft(
    waveform: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    return_complex: bool = True,
) -> Tensor:
    """Performs short-time Fourier transform using PyTorch."""
    window = (
        torch.hann_window(win_length or n_fft).to(waveform.device)
        if window_fn == "hann"
        else None
    )
    return torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        return_complex=return_complex,
    )


def istft(
    stft_matrix: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    length: Optional[int] = None,
) -> Tensor:
    """Performs inverse short-time Fourier transform using PyTorch."""
    window = (
        torch.hann_window(win_length or n_fft).to(stft_matrix.device)
        if window_fn == "hann"
        else None
    )
    return torch.istft(
        input=stft_matrix,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        length=length,
    )


def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Returns the FFT of a real tensor."""
    return torch.fft.fft(x, norm=norm)


def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """Returns the inverse FFT of a complex tensor."""
    return torch.fft.ifft(x, norm=norm)


def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
    """Returns log magnitude from complex STFT."""
    magnitude = torch.abs(stft_complex)
    return torch.log(magnitude + eps)


def phase(stft_complex: Tensor) -> Tensor:
    """Returns phase from complex STFT."""
    return torch.angle(stft_complex)
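One way these helpers compose is an analysis/synthesis round trip through `stft` and `istft`, with `log_magnitude` and `phase` splitting the complex spectrogram into the usual feature pair. A short sketch (the sample length, `n_fft`, and `hop_length` are arbitrary, not defaults prescribed by the package):

```python
import torch
from lt_tensor.math_ops import stft, istft, log_magnitude, phase

waveform = torch.randn(1, 16_000)                 # ~1 s of fake audio at 16 kHz
spec = stft(waveform, n_fft=512, hop_length=128)  # complex spectrogram, shape (1, 257, frames)
log_mag = log_magnitude(spec)                     # log|STFT|, e.g. for spectral losses
ang = phase(spec)                                 # phase angles in radians
recon = istft(spec, n_fft=512, hop_length=128, length=waveform.shape[-1])
print((waveform - recon).abs().max())             # small, since the Hann window with hop = n_fft/4 satisfies COLA
```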