lt-tensor 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py ADDED
File without changes
lt_tensor/_basics.py ADDED
@@ -0,0 +1,244 @@
1
# Public API: only ``Model`` is exported by ``from lt_tensor._basics import *``.
__all__ = [
    "Model",
]
2
+
3
+
4
+ import warnings
5
+ from ._torch_commons import *
6
+
7
# Device a freshly allocated tensor lands on when this module is imported;
# used as the class-level default device of ``_ModelDevice``.
ROOT_DEVICE = torch.empty(0).device
8
+
9
class _ModelDevice(nn.Module):
    """Base module that remembers the device it was last assigned to.

    The device is exposed through the ``device`` property and is updated by the
    overridden ``to``, ``cpu``, ``cuda``, ``ipu``, ``xpu`` and ``mtia`` methods.
    Every assignment to ``device`` calls :meth:`tp_apply_device_to`, which
    subclasses can override to move state that ``nn.Module.to`` does not reach.
    """

    # Class-level default used until a device is assigned to an instance.
    _device: torch.device = ROOT_DEVICE

    @property
    def device(self):
        """The ``torch.device`` last assigned to this module."""
        return self._device

    @device.setter
    def device(self, device: Union[torch.device, str]):
        assert isinstance(device, (str, torch.device))
        self._device = torch.device(device) if isinstance(device, str) else device
        self.tp_apply_device_to()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def tp_apply_device_to(self):
        """Hook called after every device assignment.

        Add here components that need the device applied to them but that the
        '.to()' function fails to move (for example plain tensor attributes
        that are not registered as parameters or buffers).

        example:
        ```
        def tp_apply_device_to(self):
            self.my_tensor = self.my_tensor.to(device=self.device)
        ```
        """
        pass

    @staticmethod
    def _device_string(
        kind: str, device: Optional[Union[int, str, torch.device]]
    ) -> str:
        """Build a device string such as ``"cuda"`` or ``"cuda:1"``.

        Args:
            kind: device type, e.g. ``"cuda"``.
            device: ``None`` (no index), an integer index, or a
                ``torch.device``/string whose ``index`` is used.
        """
        if device is None:
            return kind
        if isinstance(device, (str, torch.device)):
            index = torch.device(device).index
        else:
            index = int(device)
        return kind if index is None else f"{kind}:{index}"

    def to(self, *args, **kwargs):
        """Move and/or cast parameters and buffers, mirroring ``nn.Module.to``.

        After conversion, ``self.device`` is updated only when a device was
        actually requested. A dtype-only call such as ``model.to(torch.float16)``
        leaves the recorded device unchanged.

        Raises:
            TypeError: if a non floating point / non complex dtype is requested.
            NotImplementedError: when moving out of a meta tensor.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if dtype is not None:
            if not (dtype.is_floating_point or dtype.is_complex):
                raise TypeError(
                    "nn.Module.to only accepts floating point or complex "
                    f"dtypes, but got desired dtype={dtype}"
                )
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected."
                )

        def convert(t: Tensor):
            try:
                if convert_to_format is not None and t.dim() in (4, 5):
                    return t.to(
                        device,
                        dtype if t.is_floating_point() or t.is_complex() else None,
                        non_blocking,
                        memory_format=convert_to_format,
                    )
                return t.to(
                    device,
                    dtype if t.is_floating_point() or t.is_complex() else None,
                    non_blocking,
                )
            except NotImplementedError as e:
                if str(e) == "Cannot copy out of meta tensor; no data!":
                    raise NotImplementedError(
                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                        f"when moving module from meta to a different device."
                    ) from None
                else:
                    raise

        self._apply(convert)
        # ``_parse_to`` yields ``device=None`` for dtype-only calls; the setter
        # rejects ``None``, so only record a device when one was requested.
        if device is not None:
            self.device = device
        return self

    def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move to an IPU (see ``nn.Module.ipu``) and record the device."""
        super().ipu(device)
        self.device = self._device_string("ipu", device)
        return self

    def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move to an XPU (see ``nn.Module.xpu``) and record the device."""
        super().xpu(device)
        self.device = self._device_string("xpu", device)
        return self

    def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move to a CUDA device (see ``nn.Module.cuda``) and record the device."""
        super().cuda(device)
        self.device = self._device_string("cuda", device)
        return self

    def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move to an MTIA device (see ``nn.Module.mtia``) and record the device."""
        super().mtia(device)
        self.device = self._device_string("mtia", device)
        return self

    def cpu(self) -> T:
        """Move to the CPU and record the device."""
        super().cpu()
        self.device = "cpu"
        return self
131
+
132
+
133
class Model(_ModelDevice, ABC):
    """Abstract base model with device tracking and convenience helpers.

    Subclasses implement :meth:`forward`. On top of the device handling
    inherited from ``_ModelDevice``, this class offers parameter counting,
    weight saving/loading and ``inference``/``train_step`` wrappers.

    Keyword arguments consumed by ``__init__``. They are removed before the
    remaining arguments are forwarded to ``nn.Module``:
        device: a ``torch.device`` or string; if given, ``self.to(device)`` is called.
        autocast: if given (not ``None``), the initial value of :attr:`autocast`.
    """

    # Class-level default for the ``autocast`` property.
    _autocast: bool = False

    @property
    def autocast(self):
        """Whether :meth:`inference` runs the forward pass under ``torch.autocast``."""
        return self._autocast

    @autocast.setter
    def autocast(self, value: bool):
        self._autocast = value

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        # Pop our own options first: ``nn.Module.__init__`` raises TypeError on
        # unexpected keyword arguments, so forwarding them made ``device=`` /
        # ``autocast=`` unusable.
        device = kwargs.pop("device", None)
        autocast = kwargs.pop("autocast", None)
        super().__init__(*args, **kwargs)
        if isinstance(device, (torch.device, str)):
            # NOTE(review): a subclass normally calls this before building its
            # layers, so layers created afterwards are not moved by this call;
            # call ``.to()`` again after construction if needed.
            self.to(device)
        if autocast is not None:
            self.autocast = autocast

    def tp_count_parameters(self, module_name: Optional[str] = None) -> int:
        """Return the number of parameter elements of the model or of the attribute ``module_name``."""
        if module_name is not None:
            assert hasattr(self, module_name), f"Module {module_name} does not exits"
            module = getattr(self, module_name)
        else:
            module = self
        return sum(x.numel() for x in module.parameters())

    def tp_get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
        """Return detached parameter tensors of the model or of the attribute ``module_name``."""
        if module_name is not None:
            assert hasattr(self, module_name), f"Module {module_name} does not exits"
            module = getattr(self, module_name)
        else:
            module = self
        return [x.data.detach() for x in module.parameters()]

    def tp_print_num_params(self, module_name: Optional[str] = None) -> None:
        """Print the parameter count, using '.' as the thousands separator."""
        params = f"{self.tp_count_parameters(module_name):,}".replace(",", ".")

        if module_name:
            print(f'Parameters from "{module_name}": {params}')
        else:
            print(f"Parameters: {params}")

    def save_weights(self, path: Union[Path, str]):
        """Save ``self.state_dict()`` to ``path`` with ``torch.save``.

        Parent directories are created as needed. An existing file at ``path``
        is overwritten.

        Raises:
            AssertionError: if ``path`` exists but is not a regular file.
        """
        path = Path(path)
        if path.exists():
            # ``torch.save`` overwrites an existing file, so nothing needs to
            # be removed first (calling ``rmdir`` on a file always failed).
            assert (
                path.is_file()
            ), "The provided path exists but its a directory not a file!"
        path.parent.mkdir(exist_ok=True, parents=True)
        torch.save(obj=self.state_dict(), f=str(path))

    def load_weights(
        self,
        path: Union[Path, str],
        raise_if_not_exists: bool = False,
        strict: bool = True,
        assign: bool = False,
        weights_only: bool = False,
        mmap: Optional[bool] = None,
        **torch_loader_kwargs,
    ):
        """Load a state dict from ``path`` into this model.

        Args:
            path: file to read with ``torch.load``.
            raise_if_not_exists: if True, a missing path raises AssertionError;
                otherwise ``None`` is returned.
            strict, assign: forwarded to ``load_state_dict``.
            weights_only, mmap, **torch_loader_kwargs: forwarded to ``torch.load``.

        Returns:
            The value returned by ``load_state_dict``, or ``None`` when the
            path does not exist.
        """
        path = Path(path)
        if not path.exists():
            assert not raise_if_not_exists, "Path does not exists!"
            return None
        assert path.is_file(), "The provided path is not a valid file!"
        # SECURITY NOTE: with ``weights_only=False`` (the default here),
        # ``torch.load`` unpickles arbitrary objects; only load trusted files.
        state_dict = torch.load(
            str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
        )
        incompatible_keys = self.load_state_dict(
            state_dict,
            strict=strict,
            assign=assign,
        )
        return incompatible_keys

    @torch.no_grad()
    def inference(self, *args, **kwargs):
        """Run the forward pass in eval mode without gradient tracking.

        When :attr:`autocast` is True, the call runs inside ``torch.autocast``
        for ``self.device.type``.
        """
        if self.training:
            self.eval()
        if self.autocast:
            with torch.autocast(device_type=self.device.type):
                return self(*args, **kwargs)
        return self(*args, **kwargs)

    def train_step(
        self,
        *args,
        **kwargs,
    ):
        """Switch to train mode if needed and run the forward pass."""
        if not self.training:
            self.train()
        return self(*args, **kwargs)

    @abstractmethod
    def forward(
        self, *args, **kwargs
    ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
        pass
@@ -0,0 +1,12 @@
1
+ import torch
2
+ from torch import nn, optim
3
+ import torch.nn.functional as F
4
+ from torch.optim import Optimizer
5
+ from torch.nn import Module, L1Loss, MSELoss
6
+ from torch.nn.utils import remove_weight_norm
7
+ from torch import Tensor, FloatTensor, device, LongTensor
8
+ from torch.nn.utils.parametrizations import weight_norm, spectral_norm
9
+
10
+ from lt_utils.common import *
11
+
12
+ DeviceType: TypeAlias = Union[device, str]
@@ -0,0 +1,108 @@
1
+ import math
2
+ from torch.optim import Optimizer
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
class WarmupDecayScheduler(_LRScheduler):
    """Linear warmup followed by linear or cosine decay.

    With ``step = last_epoch + 1``:

    * while ``step < warmup_steps``: ``lr = base_lr * step / warmup_steps``;
    * afterwards ``progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)``,
      clamped to at most 1 so the schedule stays flat after ``total_steps``:

      - ``"linear"``: ``lr = base_lr * (1 - progress)``
      - ``"cosine"``: ``lr = base_lr * 0.5 * (1 + cos(pi * progress))``

    Every learning rate is floored at ``min_lr``.

    Raises:
        ValueError: if ``decay_type`` is not ``"linear"`` or ``"cosine"``.
    """

    _DECAY_TYPES = ("linear", "cosine")

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_steps: int,
        total_steps: int,
        decay_type: str = "linear",  # or "cosine"
        min_lr: float = 0.0,
        last_epoch: int = -1,
    ):
        # Fail fast: otherwise an invalid type only surfaces after warmup.
        if decay_type not in self._DECAY_TYPES:
            raise ValueError(f"Unknown decay type: {decay_type}")
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_type = decay_type
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        warmup = self.warmup_steps
        total = self.total_steps
        lrs = []

        for base_lr in self.base_lrs:
            if step < warmup:
                lr = base_lr * step / warmup
            else:
                # Clamp so cosine does not rise again past ``total_steps``.
                progress = min(1.0, (step - warmup) / max(1, total - warmup))
                if self.decay_type == "linear":
                    lr = base_lr * (1.0 - progress)
                elif self.decay_type == "cosine":
                    lr = base_lr * 0.5 * (1 + math.cos(math.pi * progress))
                else:
                    raise ValueError(f"Unknown decay type: {self.decay_type}")

            lrs.append(max(self.min_lr, lr))

        return lrs
44
+
45
+
46
class AdaptiveDropScheduler(_LRScheduler):
    """Drop learning rates when the monitored loss stops improving.

    Call ``step(val_loss)`` after each evaluation:

    * If the loss is lower than the best seen so far, it becomes the new best
      and the bad-step and cooldown counters are reset.
    * Otherwise the bad-step counter grows. Once it reaches ``patience`` while
      no cooldown is active, every param group's lr is multiplied by
      ``drop_factor`` (floored at ``min_lr``). The cooldown counter is then set
      to ``cooldown``.

    The cooldown counter is decremented once per ``step`` call, including the
    call that set it. Calling ``step()`` without a loss only advances the
    cooldown.
    """

    def __init__(
        self,
        optimizer,
        drop_factor=0.5,
        patience=10,
        min_lr=1e-6,
        cooldown=5,
        last_epoch=-1,
    ):
        self.drop_factor = drop_factor
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.best_loss = float("inf")
        self.bad_steps = 0
        super().__init__(optimizer, last_epoch)

    def step(self, val_loss=None):
        if val_loss is not None:
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self.bad_steps = 0
                self.cooldown_counter = 0
            else:
                self.bad_steps += 1
                if self.bad_steps >= self.patience and self.cooldown_counter == 0:
                    for group in self.optimizer.param_groups:
                        group["lr"] = max(group["lr"] * self.drop_factor, self.min_lr)
                    self.cooldown_counter = self.cooldown
                    self.bad_steps = 0
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
        # This override replaces ``_LRScheduler.step``. Keep its bookkeeping so
        # ``get_last_lr()`` and ``last_epoch`` stay valid.
        self.last_epoch += 1
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]
84
+
85
+
86
class WaveringLRScheduler(_LRScheduler):
    def __init__(
        self, optimizer, base_lr, max_lr, period=1000, decay=0.999, last_epoch=-1
    ):
        """Oscillating learning rate intended to help escape shallow local minima.

        Within each window of ``period`` steps the learning rate follows half a
        sine wave. It starts at ``base_lr``, peaks mid-window and returns toward
        ``base_lr``. The peak height above ``base_lr`` is ``max_lr - base_lr``
        scaled by ``decay ** cycle``, where ``cycle`` counts completed windows.
        Every param group receives the same learning rate.

        Args:
            base_lr: learning rate at the start of every window.
            max_lr: peak learning rate of the first window.
            period: window length in steps.
            decay: per-window multiplier applied to the peak height.
        """
        self.base_lr, self.max_lr = base_lr, max_lr
        self.period, self.decay = period, decay
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        cycle, step_in_cycle = divmod(self.last_epoch, self.period)
        envelope = (self.max_lr - self.base_lr) * self.decay**cycle
        wave = math.sin(math.pi * step_in_cycle / self.period)
        lr = self.base_lr + envelope * wave
        return [lr for _ in self.optimizer.param_groups]
lt_tensor/math_ops.py ADDED
@@ -0,0 +1,120 @@
1
+ from ._torch_commons import *
2
+
3
+
4
def sin_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
    """Return ``sin(freq * x)`` computed element-wise."""
    scaled = x * freq
    return scaled.sin()
7
+
8
+
9
def cos_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
    """Return ``cos(freq * x)`` computed element-wise."""
    scaled = x * freq
    return scaled.cos()
12
+
13
+
14
def sin_plus_cos(x: Tensor, freq: float = 1.0) -> Tensor:
    """Return ``sin(freq * x) + cos(freq * x)`` element-wise."""
    angle = x * freq
    return angle.sin() + angle.cos()
17
+
18
+
19
def sin_times_cos(x: Tensor, freq: float = 1.0) -> Tensor:
    """Return ``sin(freq * x) * cos(freq * x)`` element-wise."""
    angle = x * freq
    return angle.sin() * angle.cos()
22
+
23
+
24
def apply_window(x: Tensor, window_type: str = "hann") -> Tensor:
    """Multiply ``x`` by a window function along its last dimension.

    Args:
        x: input tensor; the window length is ``x.shape[-1]`` and the window is
            broadcast over any leading dimensions.
        window_type: ``"hann"``, ``"hamming"``, ``"blackman"`` or ``"bartlett"``.

    Returns:
        ``x * window``. For floating point ``x`` the result keeps ``x.dtype``.

    Raises:
        ValueError: if ``window_type`` is not supported.
    """
    window_fns = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }
    window_fn = window_fns.get(window_type)
    if window_fn is None:
        raise ValueError(f"Unsupported window type: {window_type}")
    window = window_fn(x.shape[-1], device=x.device)
    # The window is created in torch's default dtype. Cast it to a floating
    # input's dtype so e.g. float16/bfloat16 inputs are not promoted.
    if x.is_floating_point():
        window = window.to(x.dtype)
    return x * window
33
+
34
+
35
def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
    """Rotate ``x`` by one position along ``dim``; the last element wraps to the front."""
    return x.roll(1, dim)
38
+
39
+
40
def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
    """Sum of the element-wise product of ``x`` and ``y`` along ``dim``."""
    elementwise = x * y
    return elementwise.sum(dim=dim)
43
+
44
+
45
def normalize_tensor(x: Tensor, eps: float = 1e-8, dim: int = -1) -> Tensor:
    """Scale ``x`` to unit L2 norm along ``dim`` (last dimension by default).

    ``eps`` is added to the norm so all-zero slices do not divide by zero.
    Uses ``torch.linalg.vector_norm``, since ``torch.norm`` is deprecated.
    """
    norm = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True)
    return x / (norm + eps)
48
+
49
+
50
def stft(
    waveform: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    return_complex: bool = True,
) -> Tensor:
    """Short-time Fourier transform via :func:`torch.stft`.

    Args:
        waveform: signal tensor forwarded to ``torch.stft``.
        n_fft, hop_length, win_length, center, return_complex: forwarded to
            ``torch.stft``.
        window_fn: ``"hann"``, ``"hamming"``, ``"blackman"`` or ``"bartlett"``.
            Any other value applies no window (rectangular), as before.

    Returns:
        The result of ``torch.stft``.
    """
    window_builders = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }
    builder = window_builders.get(window_fn)
    window = None
    if builder is not None:
        window = builder(win_length or n_fft, device=waveform.device)
        # Match a real floating input's dtype (e.g. float64) instead of the
        # default dtype the window is created in.
        if waveform.is_floating_point():
            window = window.to(waveform.dtype)
    return torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        return_complex=return_complex,
    )
74
+
75
+
76
def istft(
    stft_matrix: Tensor,
    n_fft: int = 512,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window_fn: str = "hann",
    center: bool = True,
    length: Optional[int] = None,
) -> Tensor:
    """Inverse short-time Fourier transform via :func:`torch.istft`.

    Args:
        stft_matrix: spectrogram forwarded to ``torch.istft``.
        n_fft, hop_length, win_length, center, length: forwarded to
            ``torch.istft``.
        window_fn: ``"hann"``, ``"hamming"``, ``"blackman"`` or ``"bartlett"``.
            Any other value applies no window (rectangular), as before. It
            should match the window used for the forward transform.

    Returns:
        The result of ``torch.istft``.
    """
    window_builders = {
        "hann": torch.hann_window,
        "hamming": torch.hamming_window,
        "blackman": torch.blackman_window,
        "bartlett": torch.bartlett_window,
    }
    builder = window_builders.get(window_fn)
    window = None
    if builder is not None:
        window = builder(win_length or n_fft, device=stft_matrix.device)
        # Use the real counterpart of a complex input's dtype
        # (complex128 -> float64) rather than the default dtype.
        if stft_matrix.is_complex():
            window = window.to(stft_matrix.real.dtype)
        elif stft_matrix.is_floating_point():
            window = window.to(stft_matrix.dtype)
    return torch.istft(
        input=stft_matrix,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        length=length,
    )
100
+
101
+
102
def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """One-dimensional discrete Fourier transform of ``x`` over its last dimension.

    ``norm`` selects the normalization mode of ``torch.fft.fft``.
    """
    return torch.fft.fft(x, n=None, dim=-1, norm=norm)
105
+
106
+
107
def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
    """One-dimensional inverse discrete Fourier transform of ``x`` over its last dimension.

    ``norm`` selects the normalization mode of ``torch.fft.ifft``.
    """
    return torch.fft.ifft(x, n=None, dim=-1, norm=norm)
110
+
111
+
112
def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
    """Return ``log(|stft_complex| + eps)``; ``eps`` keeps the log finite at zero magnitude."""
    return (stft_complex.abs() + eps).log()
116
+
117
+
118
def phase(stft_complex: Tensor) -> Tensor:
    """Return the element-wise angle (argument, in radians) of a complex STFT."""
    return stft_complex.angle()