lt-tensor 0.0.1a4__py3-none-any.whl → 0.0.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lt_tensor/model_base.py CHANGED
@@ -4,6 +4,7 @@ __all__ = ["Model"]
4
4
  import warnings
5
5
  from .torch_commons import *
6
6
  from lt_utils.common import *
7
+ from lt_utils.misc_utils import log_traceback
7
8
 
8
9
  T = TypeVar("T")
9
10
 
@@ -40,20 +41,113 @@ class Model(nn.Module, ABC):
40
41
  def device(self, device: Union[torch.device, str]):
41
42
  assert isinstance(device, (str, torch.device))
42
43
  self._device = torch.device(device) if isinstance(device, str) else device
43
- self.tp_apply_device_to()
44
+ self._apply_device_to()
44
45
 
45
- def tp_apply_device_to(self):
46
+ def _apply_device_to(self):
46
47
  """Add here components that are needed to have device applied to them,
47
- that usualy the '.to()' function fails to apply
48
+ that usually the '.to()' function fails to apply
48
49
 
49
50
  example:
50
51
  ```
51
- def tp_apply_device_to(self):
52
+ def _apply_device_to(self):
52
53
  self.my_tensor = self.my_tensor.to(device=self.device)
53
54
  ```
54
55
  """
55
56
  pass
56
57
 
58
+ def freeze_weight(self, weight: Union[str, nn.Module], freeze: bool):
59
+ assert isinstance(weight, (str, nn.Module))
60
+ if isinstance(weight, str):
61
+ if hasattr(self, weight):
62
+ w = getattr(self, weight)
63
+ if isinstance(w, nn.Module):
64
+ w.requires_grad_(not freeze)
65
+ else:
66
+ weight.requires_grad_(not freeze)
67
+
68
+ def _freeze_unfreeze(
69
+ self,
70
+ weight: Union[str, nn.Module],
71
+ task: Literal["freeze", "unfreeze"] = "freeze",
72
+ _skip_except: bool = False,
73
+ ):
74
+ try:
75
+ assert isinstance(weight, (str, nn.Module))
76
+ if isinstance(weight, str):
77
+ w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a valid attribute of {self._get_name()}"
78
+ if hasattr(self, weight):
79
+ w_txt = f"Failed to {task} the module '{weight}'. Reason: is not a Module type."
80
+ w = getattr(self, weight)
81
+ if isinstance(w, nn.Module):
82
+ w_txt = f"Successfully {task} the module '{weight}'."
83
+ w.requires_grad_(task == "unfreeze")
84
+
85
+ else:
86
+ w.requires_grad_(task == "unfreeze")
87
+ w_txt = f"Successfully '{task}' the module '{weight}'."
88
+ return w_txt
89
+ except Exception as e:
90
+ if not _skip_except:
91
+ raise e
92
+ return str(e)
93
+
94
+ def freeze_weight(
95
+ self,
96
+ weight: Union[str, nn.Module],
97
+ _skip_except: bool = False,
98
+ ):
99
+ return self._freeze_unfreeze(weight, "freeze", _skip_except)
100
+
101
+ def unfreeze_weight(
102
+ self,
103
+ weight: Union[str, nn.Module],
104
+ _skip_except: bool = False,
105
+ ):
106
+ return self._freeze_unfreeze(weight, "freeze", _skip_except)
107
+
108
def freeze_all(self, exclude: Optional[List[str]] = None):
    """Disable gradients for every parameter except the excluded ones.

    Args:
        exclude: substrings; a parameter whose qualified name contains any of
            them is left untouched. ``None`` or empty freezes everything.

    Returns:
        ``dict(frozen=[names], not_frozen=[(name, reason), ...])``.
    """
    exclusions = list(exclude or [])
    frozen = []
    not_frozen = []
    for name, param in self.named_parameters():
        # Previously the match was inverted: excluded names were frozen.
        if any(layer in name for layer in exclusions):
            not_frozen.append((name, "Excluded"))
            continue
        try:
            param.requires_grad_(False)
            frozen.append(name)
        except Exception as e:  # best effort: record the failure and go on
            not_frozen.append((name, str(e)))
    return dict(frozen=frozen, not_frozen=not_frozen)
128
+
129
def unfreeze_all_except(self, exclude: Optional[list[str]] = None):
    """Enable gradients for every parameter except the excluded ones.

    Args:
        exclude: substrings; a parameter whose qualified name contains any of
            them is left untouched. ``None`` or empty unfreezes everything.

    Returns:
        ``dict(unfrozen=[names], not_unfrozen=[(name, reason), ...])``.
    """
    exclusions = list(exclude or [])
    unfrozen = []
    not_unfrozen = []
    for name, param in self.named_parameters():
        # Previously the match was inverted: only excluded names were unfrozen.
        if any(layer in name for layer in exclusions):
            not_unfrozen.append((name, "Excluded"))
            continue
        try:
            param.requires_grad_(True)
            unfrozen.append(name)
        except Exception as e:  # best effort: record the failure and go on
            not_unfrozen.append((name, str(e)))
    return dict(unfrozen=unfrozen, not_unfrozen=not_unfrozen)
150
+
57
151
  def to(self, *args, **kwargs):
58
152
  device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
59
153
  *args, **kwargs
@@ -186,11 +280,16 @@ class Model(nn.Module, ABC):
186
280
  )
187
281
 
188
282
  def get_weights(self, module_name: Optional[str] = None) -> List[Tensor]:
189
- """Returns the weights of the model entrie model or from a specified module"""
283
+ """Returns the weights of the model entry model or from a specified module"""
190
284
  if module_name is not None:
191
285
  assert hasattr(self, module_name), f"Module {module_name} does not exits"
192
286
  module = getattr(self, module_name)
193
- return [x.data.detach() for x in module.parameters()]
287
+ params = []
288
+ if isinstance(module, nn.Module):
289
+ return [x.data.detach() for x in module.parameters()]
290
+ elif isinstance(module, (Tensor, nn.Parameter)):
291
+ return [module.data.detach()]
292
+ raise (f"{module_name} is has no weights")
194
293
  return [x.data.detach() for x in self.parameters()]
195
294
 
196
295
  def print_trainable_parameters(
@@ -11,37 +11,36 @@ class PeriodDiscriminator(Model):
11
11
  use_spectral_norm=False,
12
12
  kernel_size: int = 5,
13
13
  stride: int = 3,
14
- initial_s: int = 32,
15
14
  ):
16
15
  super().__init__()
17
16
  self.period = period
17
+ self.stride = stride
18
+ self.kernel_size = kernel_size
18
19
  self.norm_f = weight_norm if use_spectral_norm == False else spectral_norm
20
+
21
+ self.channels = [32, 128, 512, 1024, 1024]
19
22
  self.first_pass = nn.Sequential(
20
23
  self.norm_f(
21
24
  nn.Conv2d(
22
- 1, initial_s * 4, (kernel_size, 1), (stride, 1), padding=(2, 0)
25
+ 1, self.channels[0], (kernel_size, 1), (stride, 1), padding=(2, 0)
23
26
  )
24
27
  ),
25
28
  nn.LeakyReLU(0.1),
26
29
  )
27
- self._last_sz = initial_s * 4
28
30
 
29
- self.convs = nn.ModuleList([self._get_next(i == 3) for i in range(4)])
31
+
32
+ self.convs = nn.ModuleList([self._get_next(self.channels[i+1], self.channels[i], i == 3) for i in range(4)])
30
33
 
31
34
  self.post_conv = nn.Conv2d(1024, 1, (stride, 1), 1, padding=(1, 0))
32
- self.kernel_size = kernel_size
33
- self.stride = stride
34
35
 
35
- def _get_next(self, is_last: bool = False):
36
- in_dim = self._last_sz
37
- self._last_sz *= 4
38
- print(self._last_sz, "-----------------------")
36
+ def _get_next(self, out_dim:int, last_in:int, is_last: bool = False):
39
37
  stride = (self.stride, 1) if not is_last else 1
38
+
40
39
  return nn.Sequential(
41
40
  self.norm_f(
42
41
  nn.Conv2d(
43
- in_dim,
44
- self._last_sz,
42
+ last_in,
43
+ out_dim,
45
44
  (self.kernel_size, 1),
46
45
  stride,
47
46
  padding=(2, 0),
@@ -91,6 +90,7 @@ class ScaleDiscriminator(nn.Module):
91
90
  def __init__(self, use_spectral_norm=False):
92
91
  super().__init__()
93
92
  norm_f = weight_norm if use_spectral_norm == False else spectral_norm
93
+ self.activation = nn.LeakyReLU(0.1)
94
94
  self.convs = nn.ModuleList(
95
95
  [
96
96
  norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
@@ -103,7 +103,6 @@ class ScaleDiscriminator(nn.Module):
103
103
  ]
104
104
  )
105
105
  self.post_conv = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
106
- self.activation = nn.LeakyReLU(0.1)
107
106
 
108
107
  def forward(self, x: torch.Tensor):
109
108
  """
@@ -147,9 +146,10 @@ class GeneralLossDescriminator(Model):
147
146
  super().__init__()
148
147
  self.mpd = MultiPeriodDiscriminator()
149
148
  self.msd = MultiScaleDiscriminator()
149
+ self.print_trainable_parameters()
150
150
 
151
151
  def _get_group_(self):
152
152
  pass
153
153
 
154
154
  def forward(self, x: Tensor, y_hat: Tensor):
155
- return
155
+ return
@@ -106,3 +106,44 @@ class Generator(Model):
106
106
  classname = m.__class__.__name__
107
107
  if "Conv" in classname:
108
108
  m.weight.data.normal_(mean, std)
109
+
110
+
111
+ # The helpers below come from Rishikesh's iSTFTNet repository and may work with this generator.
112
+ # https://github.com/rishikksh20/iSTFTNet-pytorch/blob/781480e9563d4dff5a8cc9ef1af6c6e0cab025c8/models.py
113
+
114
+
115
def feature_loss(fmap_r, fmap_g, weight=2.0):
    """Feature-matching loss.

    Sums the mean absolute difference of every paired layer across every
    paired discriminator, then scales the total by ``weight``.
    """
    layer_pairs = (
        (real_layer, fake_layer)
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_layer, fake_layer in zip(real_maps, fake_maps)
    )
    total = 0.0
    for real_layer, fake_layer in layer_pairs:
        total += torch.mean(torch.abs(real_layer - fake_layer))
    return total * weight
122
+
123
+
124
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares GAN discriminator loss.

    Real outputs are pushed toward 1 and generated outputs toward 0.

    Returns:
        ``(total_loss, real_loss_values, generated_loss_values)``, where the
        two lists hold the per-discriminator terms as Python floats.
    """
    total = 0.0
    real_terms = []
    fake_terms = []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1.0 - real_out) ** 2)
        fake_term = torch.mean(fake_out**2)
        total += real_term + fake_term
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms
137
+
138
+
139
def generator_loss(disc_generated_outputs):
    """Least-squares GAN generator loss.

    Pushes every discriminator's output on generated data toward 1.

    Returns:
        ``(total_loss, per_discriminator_loss_values)``.
    """
    total = 0.0
    per_disc = []
    for fake_out in disc_generated_outputs:
        term = torch.mean((1.0 - fake_out) ** 2)
        per_disc.append(term.item())
        total += term
    return total, per_disc
@@ -0,0 +1,368 @@
1
__all__ = [
    # Noise schedulers defined in this module
    "NoiseSchedulerA",
    "NoiseSchedulerB",
    "NoiseSchedulerC",
    # Stand-alone noise functions
    "add_gaussian_noise",
    "add_uniform_noise",
    "add_linear_noise",
    "add_impulse_noise",
    "add_pink_noise",
    "add_clipped_gaussian_noise",
    "add_multiplicative_noise",
    # Name-based dispatcher
    "apply_noise",
]
14
+
15
+ from lt_utils.common import *
16
+ import torch.nn.functional as F
17
+ from .torch_commons import *
18
+ import math
19
+ import random
20
+ from .misc_utils import set_seed
21
+
22
+
23
def add_gaussian_noise(x: Tensor, noise_level=0.025):
    """Return ``x`` plus zero-mean Gaussian noise whose std is ``noise_level``."""
    return x + torch.randn_like(x) * noise_level
26
+
27
+
28
def add_uniform_noise(x: Tensor, noise_level=0.025):
    """Return ``x`` plus noise drawn uniformly from [-noise_level, noise_level)."""
    # rand_like samples [0, 1); shift to [-0.5, 0.5) and stretch to [-1, 1).
    centered = (torch.rand_like(x) - 0.5) * 2
    return x + centered * noise_level
31
+
32
+
33
def add_linear_noise(x, noise_level=0.05):
    """Add a ramp rising linearly from 0 to ``noise_level`` along the last dim.

    The ramp is created on ``x``'s device and, for floating inputs, in
    ``x``'s dtype so the offsets are computed at the input's precision.
    Non-floating inputs keep the default floating dtype (the sum is float).
    It broadcasts over all leading dimensions.
    """
    length = x.shape[-1]
    ramp_dtype = x.dtype if x.is_floating_point() else None
    ramp = torch.linspace(0, noise_level, length, device=x.device, dtype=ramp_dtype)
    # A 1-D ramp of size T broadcasts against (..., T) without reshaping.
    return x + ramp
39
+
40
+
41
def add_impulse_noise(x: Tensor, noise_level=0.025):
    """Salt-and-pepper noise, aimed at image-like inputs.

    Elements whose uniform draw falls below ``noise_level / 2`` become 0.0
    (pepper); draws above ``1 - noise_level / 2`` become 1.0 (salt). Works
    on a detached copy, so ``x`` is unchanged and the result has no grad.
    """
    draws = torch.rand_like(x)
    half = noise_level / 2
    noisy = x.detach().clone()
    noisy[draws < half] = 0.0  # pepper (black)
    noisy[draws > (1 - half)] = 1.0  # salt (white)
    return noisy
48
+
49
+
50
def add_pink_noise(x: Tensor, noise_level=0.05):
    """Add 1/f ("pink") noise along the last dimension of ``x``.

    White Gaussian noise shaped like ``x`` is taken to the frequency domain
    (rFFT over the last dim), its spectrum is divided by sqrt(f), and it is
    transformed back. Every leading dimension is an independent row, so the
    output keeps ``x``'s full shape (previously 3-D inputs came back
    flattened to 2-D). The frequency grid is built on ``x``'s device.
    """
    length = x.shape[-1]
    white = torch.randn_like(x)
    spectrum = torch.fft.rfft(white, dim=-1)
    freqs = torch.fft.rfftfreq(length, d=1.0, device=x.device)
    freqs[0] = 1.0  # prevent division by zero at the DC bin
    spectrum = spectrum / freqs.sqrt()
    pink = torch.fft.irfft(spectrum, n=length, dim=-1)
    return x + pink * noise_level
67
+
68
+
69
def add_clipped_gaussian_noise(x, noise_level=0.025):
    """Add Gaussian noise (std ``noise_level``) and clamp the result to [0, 1]."""
    noisy = x + torch.randn_like(x) * noise_level
    return noisy.clamp(0.0, 1.0)
72
+
73
+
74
def add_multiplicative_noise(x, noise_level=0.025):
    """Scale each element of ``x`` by ``1 + N(0, noise_level**2)``."""
    return x * (1 + noise_level * torch.randn_like(x))
77
+
78
+
79
# Dispatch table: noise name -> function(x, noise_level).
_NOISE_MAP = {
    "gaussian": add_gaussian_noise,
    "uniform": add_uniform_noise,
    "linear": add_linear_noise,
    "impulse": add_impulse_noise,
    "pink": add_pink_noise,
    "clipped_gaussian": add_clipped_gaussian_noise,
    "multiplicative": add_multiplicative_noise,
}

# Every supported noise name, in the same order as the dispatch table.
_VALID_NOISES = list(_NOISE_MAP)

# Input ranks (x.ndim) each noise type is allowed to run on.
_NOISE_DIM_SUPPORT = {
    "gaussian": (1, 2),
    "uniform": (1, 2),
    "multiplicative": (1, 2, 3),
    "clipped_gaussian": (1, 2, 3),
    "linear": (2, 3),
    "impulse": (2, 3),
    "pink": (2, 3),
}
108
+
109
+
110
def _retry_with_other_noise(x, noise_level, seed, on_error, tried):
    """Apply a random not-yet-tried noise type that supports ``x.ndim``."""
    candidates = [
        name
        for name in _VALID_NOISES
        if name not in tried and x.ndim in _NOISE_DIM_SUPPORT[name]
    ]
    if not candidates:
        return x, None
    return apply_noise(
        x, random.choice(candidates), noise_level, seed, on_error, tried
    )


def apply_noise(
    x: Tensor,
    noise_type: str = "gaussian",
    noise_level: float = 0.01,
    seed: Optional[int] = None,
    on_error: Literal["raise", "try_others", "return_unchanged"] = "raise",
    _last_tries: Optional[list[str]] = None,
):
    """Apply the named noise function to ``x``.

    Args:
        x: input tensor.
        noise_type: key of ``_NOISE_MAP`` (case and surrounding whitespace
            are ignored).
        noise_level: strength forwarded to the noise function.
        seed: when an int, ``set_seed(seed)`` is called before the noise.
        on_error: behaviour when ``noise_type`` does not support ``x.ndim``
            or the noise function raises:
            "raise" re-raises; "return_unchanged" returns ``(x, None)``;
            "try_others" retries with a random untried type that supports
            ``x.ndim`` and returns ``(x, None)`` when none remain.
        _last_tries: internal list of already-attempted types; it is copied
            and never mutated.

    Returns:
        ``(tensor, applied_noise_name)`` — always a flat pair; the name is
        ``None`` when no noise was applied.

    Raises:
        ValueError: ``noise_type`` is unknown.
        AssertionError: ``on_error == "raise"`` and ``x.ndim`` is unsupported.
    """
    noise_type = noise_type.lower().strip()
    if noise_type not in _NOISE_MAP:
        raise ValueError(f"Noise type '{noise_type}' not supported.")

    # Fresh list per call (no shared mutable default, caller's list untouched);
    # the current type counts as tried so a failing type is never re-picked.
    tried = list(_last_tries or [])
    if noise_type not in tried:
        tried.append(noise_type)

    allowed_dims = _NOISE_DIM_SUPPORT.get(noise_type, (1, 2))
    if x.ndim not in allowed_dims:
        if on_error == "raise":
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError kept for backward compatibility.
            raise AssertionError(
                f"Noise '{noise_type}' is not supported for {x.ndim}D input."
            )
        if on_error == "return_unchanged":
            return x, None
        return _retry_with_other_noise(x, noise_level, seed, on_error, tried)

    try:
        if isinstance(seed, int):
            set_seed(seed)
        return _NOISE_MAP[noise_type](x, noise_level), noise_type
    except Exception:
        if on_error == "raise":
            raise
        if on_error == "return_unchanged":
            return x, None
        return _retry_with_other_noise(x, noise_level, seed, on_error, tried)
166
+
167
+
168
class NoiseSchedulerA(nn.Module):
    """Apply noise repeatedly to an input, keeping every intermediate result.

    Args:
        samples: default number of noising rounds used by ``forward`` when
            ``steps`` is not given.
    """

    def __init__(self, samples: int = 64):
        super().__init__()
        self.base_steps = samples

    @staticmethod
    def plot_noise_progression(
        noise_seq: list[Tensor], titles: Optional[list[str]] = None
    ):
        """Display each tensor of ``noise_seq`` side by side with matplotlib.

        Static because it uses no instance state; callable on the class or an
        instance. ``titles`` entries beyond ``len(titles)`` are skipped.
        """
        import matplotlib.pyplot as plt  # optional dependency, imported lazily

        steps = len(noise_seq)
        plt.figure(figsize=(15, 3))
        for i, tensor in enumerate(noise_seq):
            plt.subplot(1, steps, i + 1)
            plt.imshow(
                tensor.detach().squeeze().cpu().numpy(), aspect="auto", origin="lower"
            )
            if titles and i < len(titles):
                plt.title(titles[i])
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    def forward(
        self,
        source_item: torch.Tensor,
        steps: Optional[int] = None,
        noise_type: Literal[
            "gaussian",
            "uniform",
            "linear",
            "impulse",
            "pink",
            "clipped_gaussian",
            "multiplicative",
        ] = "gaussian",
        seed: Optional[int] = None,
        noise_level: float = 0.01,
        shuffle_noise_types: bool = False,
        return_dict: bool = True,
    ):
        """Apply ``steps`` rounds of noise, each on top of the previous result.

        Args:
            source_item: starting tensor; a detached clone is stored first.
            steps: number of rounds; defaults to ``self.base_steps``.
            noise_type: noise used for every round (only the first round when
                ``shuffle_noise_types`` is set).
            seed: forwarded to ``apply_noise`` on every round.
            noise_level: strength forwarded to ``apply_noise``.
            shuffle_noise_types: pick a random noise type for rounds after
                the first.
            return_dict: return a dict instead of a ``(steps, history)`` tuple.

        Returns:
            dict with "steps" (all tensors, initial first), "history" (noise
            name per round, ``None`` when nothing was applied), "final" and
            "init"; or ``(steps, history)`` when ``return_dict`` is False.
        """
        if steps is None:
            steps = self.base_steps
        collected = [source_item.detach().clone()]
        noise_history = []
        for i in range(steps):
            if i > 0 and shuffle_noise_types:
                noise_type = random.choice(_VALID_NOISES)
            current, noise_name = apply_noise(
                collected[-1],
                noise_type,
                noise_level,
                seed=seed,
                on_error="try_others",
            )
            # apply_noise's "try_others" fallback may nest its result as
            # ((tensor, applied_name), requested_name); unwrap to the tensor.
            while isinstance(current, tuple):
                current, noise_name = current
            noise_history.append(noise_name)
            collected.append(current)

        if return_dict:
            return {
                "steps": collected,
                "history": noise_history,
                "final": collected[-1],
                "init": collected[0],
            }
        return collected, noise_history
230
+
231
+
232
class NoiseSchedulerB(nn.Module):
    """Closed-form forward diffusion with a linear beta schedule.

    ``forward(x_0, t)`` returns
    ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

    Args:
        timesteps: number of steps; betas run linearly from 1e-4 to 0.02.
    """

    def __init__(self, timesteps: int = 512):
        super().__init__()

        betas = torch.linspace(1e-4, 0.02, timesteps)
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)

        # Registered as buffers so they move with the module and are saved
        # in its state_dict.
        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
        )

        self.timesteps = timesteps
        # Scale applied to standard-normal noise when forward() gets none.
        self.default_noise = math.sqrt(1.25)

    def _get_random_noise(
        self,
        min_max: Tuple[float, float] = (-3, 3),
        seed: Optional[int] = None,
    ) -> float:
        """Draw a scalar uniformly from ``min_max``.

        A seeded draw uses a private ``random.Random`` (same value as after
        ``random.seed(seed)``) so the global ``random`` state is not reset.
        """
        rng = random.Random(seed) if isinstance(seed, int) else random
        return rng.uniform(*min_max)

    def set_noise(
        self,
        noise: Optional[Union[Tensor, Number]] = None,
        seed: Optional[int] = None,
        min_max: Tuple[float, float] = (-3, 3),
        default: bool = False,
    ):
        """Set the default noise scale.

        An explicit ``noise`` wins; otherwise ``default=True`` restores
        sqrt(1.25), and ``default=False`` draws a random scale from
        ``min_max`` (optionally seeded).
        """
        if noise is not None:
            self.default_noise = noise
        else:
            self.default_noise = (
                math.sqrt(1.25) if default else self._get_random_noise(min_max, seed)
            )

    def forward(
        self, x_0: Tensor, t: int, noise: Optional[Union[Tensor, float]] = None
    ) -> Tensor:
        """Sample the noised input at step ``t``.

        Args:
            x_0: clean input.
            t: step index, must satisfy ``0 <= t < self.timesteps``.
            noise: a tensor is used as-is; a number scales fresh standard-normal
                noise; ``None`` scales it by ``self.default_noise``.

        Raises:
            AssertionError: ``t`` is out of range.
        """
        assert (
            0 <= t < self.timesteps
        ), f"Time step t={t} is out of bounds for scheduler with {self.timesteps} steps."

        if noise is None:
            noise = torch.randn_like(x_0) * self.default_noise
        elif isinstance(noise, (float, int)):
            noise = torch.randn_like(x_0) * noise

        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
        return alpha_term + noise_term
288
+
289
+
290
class NoiseSchedulerC(nn.Module):
    """Forward diffusion with a linear beta schedule and selectable noise type.

    ``forward(x_0, t)`` returns
    ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.

    Args:
        timesteps: number of steps; betas run linearly from 1e-4 to 0.02.
    """

    def __init__(self, timesteps: int = 512):
        super().__init__()

        betas = torch.linspace(1e-4, 0.02, timesteps)
        alphas = 1.0 - betas
        alpha_cumprod = torch.cumprod(alphas, dim=0)

        self.register_buffer("sqrt_alpha_cumprod", torch.sqrt(alpha_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alpha_cumprod", torch.sqrt(1.0 - alpha_cumprod)
        )

        self.timesteps = timesteps
        self.default_noise_strength = math.sqrt(1.25)
        self.default_noise_type = "gaussian"
        self.noise_seed = None

    def _get_random_uniform(
        self, shape, min_val=-1.0, max_val=1.0, like: Optional[Tensor] = None
    ):
        """Uniform samples in [min_val, max_val).

        When ``like`` is given the samples are created on its device and,
        for floating tensors, in its dtype; otherwise on the default device
        and dtype.
        """
        if like is None:
            return torch.empty(shape).uniform_(min_val, max_val)
        dtype = like.dtype if like.is_floating_point() else None
        return torch.empty(shape, device=like.device, dtype=dtype).uniform_(
            min_val, max_val
        )

    def _get_noise(self, x: Tensor, noise_type: str, noise_level: float) -> Tensor:
        """Build a noise tensor shaped like ``x`` (noise only, never ``x`` itself).

        Raises:
            ValueError: unsupported ``noise_type``.
        """
        if noise_type == "gaussian":
            return torch.randn_like(x) * noise_level
        if noise_type == "uniform":
            return self._get_random_uniform(x.shape, like=x) * noise_level
        if noise_type == "multiplicative":
            # Signal-proportional perturbation: x * (1 + eps) - x. Returning
            # x * (1 + eps) would add a scaled copy of x to the noise term.
            return x * (torch.randn_like(x) * noise_level)
        if noise_type == "clipped_gaussian":
            noise = torch.randn_like(x) * noise_level
            return noise.clamp(-1.0, 1.0)
        if noise_type == "impulse":
            mask = torch.rand_like(x) < noise_level
            impulses = torch.randn_like(x) * noise_level
            # Sparse impulses only; x is added by forward() via the alpha term.
            return impulses * mask
        raise ValueError(f"Unsupported noise type: '{noise_type}'")

    def set_noise(
        self,
        noise_strength: Optional[Union[Tensor, float]] = None,
        noise_type: Optional[str] = None,
        seed: Optional[int] = None,
        default: bool = False,
    ):
        """Update default noise strength/type and optionally seed the RNGs.

        An int ``seed`` is stored and passed to ``torch.manual_seed`` and
        ``random.seed`` (global RNG state).
        """
        if noise_strength is not None:
            self.default_noise_strength = noise_strength
        elif default:
            self.default_noise_strength = math.sqrt(1.25)

        if noise_type is not None:
            self.default_noise_type = noise_type.lower().strip()

        if isinstance(seed, int):
            self.noise_seed = seed
            torch.manual_seed(seed)
            random.seed(seed)

    def forward(
        self,
        x_0: Tensor,
        t: int,
        noise: Optional[Union[Tensor, float]] = None,
        noise_type: Optional[str] = None,
    ) -> Tensor:
        """Sample the noised input at step ``t``.

        Args:
            x_0: clean input.
            t: step index, ``0 <= t < self.timesteps``.
            noise: a tensor is used as-is; a number is the strength of fresh
                noise of ``noise_type``; ``None`` uses the default strength.
            noise_type: overrides ``self.default_noise_type``.
        """
        assert 0 <= t < self.timesteps, f"t={t} is out of bounds [0, {self.timesteps})"

        noise_type = noise_type or self.default_noise_type
        noise_level = self.default_noise_strength

        if noise is None:
            noise = self._get_noise(x_0, noise_type, noise_level)
        elif isinstance(noise, (float, int)):
            noise = self._get_noise(x_0, noise_type, noise)

        alpha_term = self.sqrt_alpha_cumprod[t] * x_0
        noise_term = self.sqrt_one_minus_alpha_cumprod[t] * noise
        return alpha_term + noise_term
@@ -0,0 +1,3 @@
1
+ from .audio import AudioProcessor
2
+
3
+ __all__ = ["AudioProcessor"]