lt-tensor 0.0.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/__init__.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.0.1dev3"
lt_tensor/_basics.py ADDED
@@ -0,0 +1,270 @@
+ __all__ = ["Model"]
+
+
+ import warnings
+ from ._torch_commons import *
+
+ ROOT_DEVICE = torch.device(torch.zeros(1).device)
+
+
+ class _ModelDevice(nn.Module):
+     """
+     Makes it easier to assign a device to the module and retrieve it later.
+     """
+
+     _device: torch.device = ROOT_DEVICE
+
+     @property
+     def device(self):
+         return self._device
+
+     @device.setter
+     def device(self, device: Union[torch.device, str]):
+         assert isinstance(device, (str, torch.device))
+         self._device = torch.device(device) if isinstance(device, str) else device
+         self.tp_apply_device_to()
+
+     def tp_apply_device_to(self):
+         """Add here the components that need the device applied to them,
+         which the '.to()' call usually fails to reach.
+
+         example:
+         ```
+         def tp_apply_device_to(self):
+             self.my_tensor = self.my_tensor.to(device=self.device)
+         ```
+         """
+         pass
+
+     def to(self, *args, **kwargs):
+         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+             *args, **kwargs
+         )
+
+         if dtype is not None:
+             if not (dtype.is_floating_point or dtype.is_complex):
+                 raise TypeError(
+                     "nn.Module.to only accepts floating point or complex "
+                     f"dtypes, but got desired dtype={dtype}"
+                 )
+             if dtype.is_complex:
+                 warnings.warn(
+                     "Complex modules are a new feature under active development whose design may change, "
+                     "and some modules might not work as expected when using complex tensors as parameters or buffers. "
+                     "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
+                     "if a complex module does not work as expected."
+                 )
+
+         def convert(t: Tensor):
+             try:
+                 if convert_to_format is not None and t.dim() in (4, 5):
+                     return t.to(
+                         device,
+                         dtype if t.is_floating_point() or t.is_complex() else None,
+                         non_blocking,
+                         memory_format=convert_to_format,
+                     )
+                 return t.to(
+                     device,
+                     dtype if t.is_floating_point() or t.is_complex() else None,
+                     non_blocking,
+                 )
+             except NotImplementedError as e:
+                 if str(e) == "Cannot copy out of meta tensor; no data!":
+                     raise NotImplementedError(
+                         f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
+                         f"when moving module from meta to a different device."
+                     ) from None
+                 else:
+                     raise
+
+         self._apply(convert)
+         self.device = device
+         return self
+
+     def ipu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().ipu(device)
+         dvc = "ipu"
+         if device is not None:
+             dvc += ":" + str(
+                 device if isinstance(device, (int, float)) else device.index
+             )
+         self.device = dvc
+         return self
+
+     def xpu(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().xpu(device)
+         dvc = "xpu"
+         if device is not None:
+             dvc += ":" + str(
+                 device if isinstance(device, (int, float)) else device.index
+             )
+         self.device = dvc
+         return self
+
+     def cuda(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().cuda(device)
+         dvc = "cuda"
+         if device is not None:
+             dvc += ":" + str(
+                 device if isinstance(device, (int, float)) else device.index
+             )
+         self.device = dvc
+         return self
+
+     def mtia(self, device: Optional[Union[int, torch.device]] = None) -> T:
+         super().mtia(device)
+         dvc = "mtia"
+         if device is not None:
+             dvc += ":" + str(
+                 device if isinstance(device, (int, float)) else device.index
+             )
+         self.device = dvc
+         return self
+
+     def cpu(self) -> T:
+         super().cpu()
+         self.device = "cpu"
+         return self
+
+
+ class Model(_ModelDevice, ABC):
+     _autocast: bool = False
+
+     @property
+     def autocast(self):
+         return self._autocast
+
+     @autocast.setter
+     def autocast(self, value: bool):
+         self._autocast = value
+
+     def count_trainable_parameters(self, module_name: Optional[str] = None):
+         """Gets the number of trainable parameters from either the entire model or from a specific module."""
+         if module_name is not None:
+             assert hasattr(self, module_name), f"Module {module_name} does not exist"
+             module = getattr(self, module_name)
+             return sum(
+                 [
+                     x.numel()
+                     for x in module.parameters()
+                     if hasattr(x, "requires_grad") and x.requires_grad
+                 ]
+             )
+         return sum(
+             [
+                 x.numel()
+                 for x in self.parameters()
+                 if hasattr(x, "requires_grad") and x.requires_grad
+             ]
+         )
+
+     def count_non_trainable_parameters(self, module_name: Optional[str] = None):
+         """Gets the number of non-trainable parameters from either the entire model or from a specific module."""
+         if module_name is not None:
+             assert hasattr(self, module_name), f"Module {module_name} does not exist"
+             module = getattr(self, module_name)
+             return sum(
+                 [
+                     x.numel()
+                     for x in module.parameters()
+                     if not hasattr(x, "requires_grad") or not x.requires_grad
+                 ]
+             )
+         return sum(
+             [
+                 x.numel()
+                 for x in self.parameters()
+                 if not hasattr(x, "requires_grad") or not x.requires_grad
+             ]
+         )
+
+     def get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
+         """Returns the weights of the entire model or of a specified module."""
+         if module_name is not None:
+             assert hasattr(self, module_name), f"Module {module_name} does not exist"
+             module = getattr(self, module_name)
+             return [x.data.detach() for x in module.parameters()]
+         return [x.data.detach() for x in self.parameters()]
+
+     def print_trainable_parameters(
+         self, module_name: Optional[str] = None
+     ) -> None:
+         params = format(self.count_trainable_parameters(module_name), ",").replace(
+             ",", "."
+         )
+         if module_name:
+             print(f'Trainable Parameters from "{module_name}": {params}')
+         else:
+             print(f"Trainable Parameters: {params}")
+
+     def print_non_trainable_parameters(
+         self, module_name: Optional[str] = None
+     ) -> None:
+         params = format(self.count_non_trainable_parameters(module_name), ",").replace(
+             ",", "."
+         )
+         if module_name:
+             print(f'Non-Trainable Parameters from "{module_name}": {params}')
+         else:
+             print(f"Non-Trainable Parameters: {params}")
+
+     def save_weights(self, path: Union[Path, str]):
+         path = Path(path)
+         if path.exists():
+             assert (
+                 path.is_file()
+             ), "The provided path exists but is a directory, not a file!"
+             path.unlink()
+         path.parent.mkdir(exist_ok=True, parents=True)
+         torch.save(obj=self.state_dict(), f=str(path))
+
+     def load_weights(
+         self,
+         path: Union[Path, str],
+         raise_if_not_exists: bool = False,
+         strict: bool = True,
+         assign: bool = False,
+         weights_only: bool = False,
+         mmap: Optional[bool] = None,
+         **torch_loader_kwargs,
+     ):
+         path = Path(path)
+         if not path.exists():
+             assert not raise_if_not_exists, "Path does not exist!"
+             return None
+         assert path.is_file(), "The provided path is not a valid file!"
+         state_dict = torch.load(
+             str(path), weights_only=weights_only, mmap=mmap, **torch_loader_kwargs
+         )
+         incompatible_keys = self.load_state_dict(
+             state_dict,
+             strict=strict,
+             assign=assign,
+         )
+         return incompatible_keys
+
+     @torch.no_grad()
+     def inference(self, *args, **kwargs):
+         if self.training:
+             self.eval()
+         if self.autocast:
+             with torch.autocast(device_type=self.device.type):
+                 return self(*args, **kwargs)
+         return self(*args, **kwargs)
+
+     def train_step(
+         self,
+         *args,
+         **kwargs,
+     ):
+         """Train step."""
+         if not self.training:
+             self.train()
+         return self(*args, **kwargs)
+
+     @abstractmethod
+     def forward(
+         self, *args, **kwargs
+     ) -> Union[Tensor, Sequence[Tensor], Dict[Any, Union[Any, Tensor]]]:
+         pass
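The `Model` base class above bundles device tracking, parameter counting, checkpoint I/O, and train/inference helpers behind a single abstract `forward`. A minimal usage sketch, not part of the package: the subclass name, tensor shapes, and checkpoint path are hypothetical, and `Model` is assumed importable from `lt_tensor._basics` as defined in this diff.

```
import torch
from torch import nn
from lt_tensor._basics import Model


class TinyRegressor(Model):
    """Hypothetical subclass: a single linear layer implementing the abstract forward."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = TinyRegressor().cpu()               # .cpu() also records the device on the model
model.print_trainable_parameters()          # prints the count of weights + bias
out = model.train_step(torch.randn(4, 8))   # switches to train mode, then runs forward
pred = model.inference(torch.randn(4, 8))   # eval mode under torch.no_grad()
model.save_weights("checkpoints/tiny.pt")   # creates parent dirs, then torch.save
```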
lt_tensor/_torch_commons.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ from torch import nn, optim
+ import torch.nn.functional as F
+ from torch.optim import Optimizer
+ from torch.nn import Module, L1Loss, MSELoss
+ from torch.nn.utils import remove_weight_norm
+ from torch import Tensor, FloatTensor, device, LongTensor
+ from torch.nn.utils.parametrizations import weight_norm, spectral_norm
+
+ from lt_utils.common import *
+
+ DeviceType: TypeAlias = Union[device, str]
@@ -0,0 +1,114 @@
+ __all__ = [
+     "WarmupDecayScheduler",
+     "AdaptiveDropScheduler",
+     "WaveringLRScheduler",
+ ]
+
+ import math
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class WarmupDecayScheduler(_LRScheduler):
+     def __init__(
+         self,
+         optimizer: Optimizer,
+         warmup_steps: int,
+         total_steps: int,
+         decay_type: str = "linear",  # or "cosine"
+         min_lr: float = 0.0,
+         last_epoch: int = -1,
+     ):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.decay_type = decay_type
+         self.min_lr = min_lr
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         step = self.last_epoch + 1
+         warmup = self.warmup_steps
+         total = self.total_steps
+         lrs = []
+
+         for base_lr in self.base_lrs:
+             if step < warmup:
+                 lr = base_lr * step / warmup
+             else:
+                 progress = (step - warmup) / max(1, total - warmup)
+                 if self.decay_type == "linear":
+                     lr = base_lr * (1.0 - progress)
+                 elif self.decay_type == "cosine":
+                     lr = base_lr * 0.5 * (1 + math.cos(math.pi * progress))
+                 else:
+                     raise ValueError(f"Unknown decay type: {self.decay_type}")
+
+             lr = max(self.min_lr, lr)
+             lrs.append(lr)
+
+         return lrs
+
+
+ class AdaptiveDropScheduler(_LRScheduler):
+     def __init__(
+         self,
+         optimizer,
+         drop_factor=0.5,
+         patience=10,
+         min_lr=1e-6,
+         cooldown=5,
+         last_epoch=-1,
+     ):
+         self.drop_factor = drop_factor
+         self.patience = patience
+         self.min_lr = min_lr
+         self.cooldown = cooldown
+         self.cooldown_counter = 0
+         self.best_loss = float("inf")
+         self.bad_steps = 0
+         super().__init__(optimizer, last_epoch)
+
+     def step(self, val_loss=None):
+         if val_loss is not None:
+             if val_loss < self.best_loss:
+                 self.best_loss = val_loss
+                 self.bad_steps = 0
+                 self.cooldown_counter = 0
+             else:
+                 self.bad_steps += 1
+                 if self.bad_steps >= self.patience and self.cooldown_counter == 0:
+                     for i, group in enumerate(self.optimizer.param_groups):
+                         new_lr = max(group["lr"] * self.drop_factor, self.min_lr)
+                         group["lr"] = new_lr
+                     self.cooldown_counter = self.cooldown
+                     self.bad_steps = 0
+         if self.cooldown_counter > 0:
+             self.cooldown_counter -= 1
+
+     def get_lr(self):
+         return [group["lr"] for group in self.optimizer.param_groups]
+
+
+ class WaveringLRScheduler(_LRScheduler):
+     def __init__(
+         self, optimizer, base_lr, max_lr, period=1000, decay=0.999, last_epoch=-1
+     ):
+         """
+         Sinusoidal-like oscillating LR. Can escape shallow local minima.
+         - base_lr: minimum LR
+         - max_lr: maximum LR
+         - period: full sine cycle in steps
+         - decay: multiplies max_lr each cycle
+         """
+         self.base_lr = base_lr
+         self.max_lr = max_lr
+         self.period = period
+         self.decay = decay
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         cycle = self.last_epoch // self.period
+         step_in_cycle = self.last_epoch % self.period
+         factor = math.sin(math.pi * step_in_cycle / self.period)
+         amplitude = (self.max_lr - self.base_lr) * (self.decay**cycle)
+         return [self.base_lr + amplitude * factor for _ in self.optimizer.param_groups]
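These schedulers follow the standard `_LRScheduler` protocol, where `get_lr` is driven by `last_epoch`. A short usage sketch, not part of the package: this hunk does not show the module's file name, so the import is only assumed, and the model, optimizer, and step counts below are placeholders.

```
from torch import nn, optim

# Assumed to be in scope: WarmupDecayScheduler from the scheduler module above
# (its file name is not shown in this diff).

model = nn.Linear(8, 1)
opt = optim.AdamW(model.parameters(), lr=1e-3)
sched = WarmupDecayScheduler(
    opt, warmup_steps=100, total_steps=1000, decay_type="cosine", min_lr=1e-5
)

for step in range(1000):
    # ... compute loss, loss.backward() ...
    opt.step()       # optimizer step first, so PyTorch's ordering warning is avoided
    opt.zero_grad()
    sched.step()     # LR ramps linearly for 100 steps, then follows a cosine decay
```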
lt_tensor/math_ops.py ADDED
@@ -0,0 +1,71 @@
+ __all__ = [
+     "sin_tensor",
+     "cos_tensor",
+     "sin_plus_cos",
+     "sin_times_cos",
+     "apply_window",
+     "shift_ring",
+     "dot_product",
+     "normalize_tensor",
+     "log_magnitude",
+     "phase",
+ ]
+
+ from ._torch_commons import *
+
+
+ def sin_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
+     """Applies the sine function element-wise, with optional frequency scaling."""
+     return torch.sin(x * freq)
+
+
+ def cos_tensor(x: Tensor, freq: float = 1.0) -> Tensor:
+     """Applies the cosine function element-wise, with optional frequency scaling."""
+     return torch.cos(x * freq)
+
+
+ def sin_plus_cos(x: Tensor, freq: float = 1.0) -> Tensor:
+     """Returns sin(x * freq) + cos(x * freq)."""
+     return torch.sin(x * freq) + torch.cos(x * freq)
+
+
+ def sin_times_cos(x: Tensor, freq: float = 1.0) -> Tensor:
+     """Returns sin(x * freq) * cos(x * freq)."""
+     return torch.sin(x * freq) * torch.cos(x * freq)
+
+
+ def apply_window(x: Tensor, window_type: str = "hann") -> Tensor:
+     """Applies a window function along the last dimension of a tensor."""
+     if window_type == "hann":
+         window = torch.hann_window(x.shape[-1], device=x.device)
+     elif window_type == "hamming":
+         window = torch.hamming_window(x.shape[-1], device=x.device)
+     else:
+         raise ValueError(f"Unsupported window type: {window_type}")
+     return x * window
+
+
+ def shift_ring(x: Tensor, dim: int = -1) -> Tensor:
+     """Circularly shifts tensor values along the given dim: the last element becomes the first."""
+     return torch.roll(x, shifts=1, dims=dim)
+
+
+ def dot_product(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
+     """Computes the dot product along the specified dimension."""
+     return torch.sum(x * y, dim=dim)
+
+
+ def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
+     """Normalizes a tensor to unit L2 norm along the last dimension."""
+     return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
+
+
+ def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
+     """Returns the log magnitude of a complex STFT."""
+     magnitude = torch.abs(stft_complex)
+     return torch.log(magnitude + eps)
+
+
+ def phase(stft_complex: Tensor) -> Tensor:
+     """Returns the phase of a complex STFT."""
+     return torch.angle(stft_complex)
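The helpers in `math_ops.py` compose naturally with `torch.fft` for simple spectral features. A quick sketch, not part of the package; the signal shape and frame length below are arbitrary:

```
import torch
from lt_tensor.math_ops import apply_window, log_magnitude, phase

frame = torch.randn(1, 1024)              # one 1024-sample frame (batch of 1)
windowed = apply_window(frame, "hann")    # multiplies by a Hann window of length 1024
spec = torch.fft.rfft(windowed, dim=-1)   # complex one-sided spectrum
log_mag = log_magnitude(spec)             # log(|spec| + 1e-5)
phi = phase(spec)                         # element-wise angle of the complex spectrum
```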