lt-tensor 0.0.1.dev0__tar.gz → 0.0.1.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/PKG-INFO +1 -1
  2. lt_tensor-0.0.1.dev1/lt_tensor/__init__.py +1 -0
  3. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/math_ops.py +0 -60
  4. lt_tensor-0.0.1.dev1/lt_tensor/transform.py +332 -0
  5. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor.egg-info/PKG-INFO +1 -1
  6. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/setup.py +1 -1
  7. lt_tensor-0.0.1.dev0/lt_tensor/__init__.py +0 -0
  8. lt_tensor-0.0.1.dev0/lt_tensor/transform.py +0 -113
  9. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/LICENSE +0 -0
  10. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/README.md +0 -0
  11. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/_basics.py +0 -0
  12. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/_torch_commons.py +0 -0
  13. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/lr_schedulers.py +0 -0
  14. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/misc_utils.py +0 -0
  15. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/__init__.py +0 -0
  16. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/basic.py +0 -0
  17. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/diffusion/__init__.py +0 -0
  18. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/diffusion/models.py +0 -0
  19. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/residual.py +0 -0
  20. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/transformer_models/__init__.py +0 -0
  21. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/transformer_models/models.py +0 -0
  22. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/model_zoo/transformer_models/positional_encoders.py +0 -0
  23. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor/monotonic_align.py +0 -0
  24. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor.egg-info/SOURCES.txt +0 -0
  25. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor.egg-info/dependency_links.txt +0 -0
  26. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor.egg-info/requires.txt +0 -0
  27. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/lt_tensor.egg-info/top_level.txt +0 -0
  28. {lt_tensor-0.0.1.dev0 → lt_tensor-0.0.1.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1.dev0
3
+ Version: 0.0.1.dev1
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -0,0 +1 @@
1
+ __version__ = "0.0.1dev1"
@@ -47,66 +47,6 @@ def normalize_tensor(x: Tensor, eps: float = 1e-8) -> Tensor:
47
47
  return x / (torch.norm(x, dim=-1, keepdim=True) + eps)
48
48
 
49
49
 
50
- def stft(
51
- waveform: Tensor,
52
- n_fft: int = 512,
53
- hop_length: Optional[int] = None,
54
- win_length: Optional[int] = None,
55
- window_fn: str = "hann",
56
- center: bool = True,
57
- return_complex: bool = True,
58
- ) -> Tensor:
59
- """Performs short-time Fourier transform using PyTorch."""
60
- window = (
61
- torch.hann_window(win_length or n_fft).to(waveform.device)
62
- if window_fn == "hann"
63
- else None
64
- )
65
- return torch.stft(
66
- input=waveform,
67
- n_fft=n_fft,
68
- hop_length=hop_length,
69
- win_length=win_length,
70
- window=window,
71
- center=center,
72
- return_complex=return_complex,
73
- )
74
-
75
-
76
- def istft(
77
- stft_matrix: Tensor,
78
- n_fft: int = 512,
79
- hop_length: Optional[int] = None,
80
- win_length: Optional[int] = None,
81
- window_fn: str = "hann",
82
- center: bool = True,
83
- length: Optional[int] = None,
84
- ) -> Tensor:
85
- """Performs inverse short-time Fourier transform using PyTorch."""
86
- window = (
87
- torch.hann_window(win_length or n_fft).to(stft_matrix.device)
88
- if window_fn == "hann"
89
- else None
90
- )
91
- return torch.istft(
92
- input=stft_matrix,
93
- n_fft=n_fft,
94
- hop_length=hop_length,
95
- win_length=win_length,
96
- window=window,
97
- center=center,
98
- length=length,
99
- )
100
-
101
-
102
- def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
103
- """Returns the FFT of a real tensor."""
104
- return torch.fft.fft(x, norm=norm)
105
-
106
-
107
- def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
108
- """Returns the inverse FFT of a complex tensor."""
109
- return torch.fft.ifft(x, norm=norm)
110
50
 
111
51
 
112
52
  def log_magnitude(stft_complex: Tensor, eps: float = 1e-5) -> Tensor:
@@ -0,0 +1,332 @@
1
+ from ._torch_commons import *
2
+ import torchaudio
3
+ import math
4
+ import random
5
+ from .misc_utils import log_tensor
6
+
7
+
8
+ def to_mel_spectrogram(
9
+ waveform: torch.Tensor,
10
+ sample_rate: int = 22050,
11
+ n_fft: int = 1024,
12
+ hop_length: Optional[int] = None,
13
+ win_length: Optional[int] = None,
14
+ n_mels: int = 80,
15
+ f_min: float = 0.0,
16
+ f_max: Optional[float] = None,
17
+ ) -> torch.Tensor:
18
+ """Converts waveform to mel spectrogram."""
19
+ mel_spectrogram = torchaudio.transforms.MelSpectrogram(
20
+ sample_rate=sample_rate,
21
+ n_fft=n_fft,
22
+ hop_length=hop_length,
23
+ win_length=win_length,
24
+ n_mels=n_mels,
25
+ f_min=f_min,
26
+ f_max=f_max,
27
+ )
28
+ return mel_spectrogram(waveform)
29
+
30
+
31
+ def stft(
32
+ waveform: Tensor,
33
+ n_fft: int = 512,
34
+ hop_length: Optional[int] = None,
35
+ win_length: Optional[int] = None,
36
+ window_fn: str = "hann",
37
+ center: bool = True,
38
+ return_complex: bool = True,
39
+ ) -> Tensor:
40
+ """Performs short-time Fourier transform using PyTorch."""
41
+ window = (
42
+ torch.hann_window(win_length or n_fft).to(waveform.device)
43
+ if window_fn == "hann"
44
+ else None
45
+ )
46
+ return torch.stft(
47
+ input=waveform,
48
+ n_fft=n_fft,
49
+ hop_length=hop_length,
50
+ win_length=win_length,
51
+ window=window,
52
+ center=center,
53
+ return_complex=return_complex,
54
+ )
55
+
56
+
57
+ def istft(
58
+ stft_matrix: Tensor,
59
+ n_fft: int = 512,
60
+ hop_length: Optional[int] = None,
61
+ win_length: Optional[int] = None,
62
+ window_fn: str = "hann",
63
+ center: bool = True,
64
+ length: Optional[int] = None,
65
+ ) -> Tensor:
66
+ """Performs inverse short-time Fourier transform using PyTorch."""
67
+ window = (
68
+ torch.hann_window(win_length or n_fft).to(stft_matrix.device)
69
+ if window_fn == "hann"
70
+ else None
71
+ )
72
+ return torch.istft(
73
+ input=stft_matrix,
74
+ n_fft=n_fft,
75
+ hop_length=hop_length,
76
+ win_length=win_length,
77
+ window=window,
78
+ center=center,
79
+ length=length,
80
+ )
81
+
82
+
83
+ def fft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
84
+ """Returns the FFT of a real tensor."""
85
+ return torch.fft.fft(x, norm=norm)
86
+
87
+
88
+ def ifft(x: Tensor, norm: Optional[str] = "backward") -> Tensor:
89
+ """Returns the inverse FFT of a complex tensor."""
90
+ return torch.fft.ifft(x, norm=norm)
91
+
92
+
93
+ def to_log_mel_spectrogram(
94
+ waveform: torch.Tensor, sample_rate: int = 22050, eps: float = 1e-9, **kwargs
95
+ ) -> torch.Tensor:
96
+ """Converts waveform to log-mel spectrogram."""
97
+ mel = to_mel_spectrogram(waveform, sample_rate, **kwargs)
98
+ return torch.log(mel + eps)
99
+
100
+
101
+ def normalize(
102
+ x: torch.Tensor,
103
+ mean: Optional[float] = None,
104
+ std: Optional[float] = None,
105
+ eps: float = 1e-9,
106
+ ) -> torch.Tensor:
107
+ """Normalizes tensor by mean and std."""
108
+ if mean is None:
109
+ mean = x.mean()
110
+ if std is None:
111
+ std = x.std()
112
+ return (x - mean) / (std + eps)
113
+
114
+
115
+ def min_max_scale(
116
+ x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0
117
+ ) -> torch.Tensor:
118
+ """Scales tensor to [min_val, max_val] range."""
119
+ x_min, x_max = x.min(), x.max()
120
+ return (x - x_min) / (x_max - x_min + 1e-8) * (max_val - min_val) + min_val
121
+
122
+
123
+ def mel_to_linear(
124
+ mel_spec: torch.Tensor, mel_fb: torch.Tensor, eps: float = 1e-8
125
+ ) -> torch.Tensor:
126
+ """Approximate inversion of mel to linear spectrogram using pseudo-inverse."""
127
+ mel_fb_inv = torch.pinverse(mel_fb)
128
+ return torch.matmul(mel_fb_inv, mel_spec + eps)
129
+
130
+
131
+ def add_noise(x: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
132
+ """Adds Gaussian noise to tensor."""
133
+ return x + noise_level * torch.randn_like(x)
134
+
135
+
136
+ def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
137
+ """Shifts tensor along time axis (last dim)."""
138
+ return torch.roll(x, shifts=shift, dims=-1)
139
+
140
+
141
+ def stretch_tensor(x: torch.Tensor, rate: float, mode: str = "linear") -> torch.Tensor:
142
+ """Time-stretch tensor using interpolation."""
143
+ B, C, T = x.shape if x.ndim == 3 else (1, 1, x.shape[0])
144
+ new_T = int(T * rate)
145
+ x_reshaped = x.view(B * C, T).unsqueeze(1)
146
+ stretched = torch.nn.functional.interpolate(x_reshaped, size=new_T, mode=mode)
147
+ return (
148
+ stretched.squeeze(1).view(B, C, new_T) if x.ndim == 3 else stretched.squeeze()
149
+ )
150
+
151
+
152
+ def pad_tensor(
153
+ x: torch.Tensor, target_len: int, pad_value: float = 0.0
154
+ ) -> torch.Tensor:
155
+ """Pads tensor to target length along last dimension."""
156
+ current_len = x.shape[-1]
157
+ if current_len >= target_len:
158
+ return x[..., :target_len]
159
+ padding = [0] * (2 * (x.ndim - 1)) + [0, target_len - current_len]
160
+ return F.pad(x, padding, value=pad_value)
161
+
162
+
163
+ def get_sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
164
+ # Expect shape [B] or [B, 1]
165
+ if timesteps.dim() > 1:
166
+ timesteps = timesteps.view(-1) # flatten to [B]
167
+
168
+ device = timesteps.device
169
+ half_dim = dim // 2
170
+ emb = torch.exp(
171
+ torch.arange(half_dim, device=device) * -(math.log(10000.0) / half_dim)
172
+ )
173
+ emb = timesteps[:, None].float() * emb[None, :] # [B, half_dim]
174
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # [B, dim]
175
+ return emb
176
+
177
+
178
+ def generate_window(
179
+ M: int, alpha: float = 0.5, device: Optional[DeviceType] = None
180
+ ) -> Tensor:
181
+ if M < 1:
182
+ raise ValueError("Window length M must be >= 1.")
183
+ if M == 1:
184
+ return torch.ones(1, device=device)
185
+
186
+ n = torch.arange(M, dtype=torch.float32, device=device)
187
+ window = alpha - (1.0 - alpha) * torch.cos(2.0 * math.pi * n / (M - 1))
188
+ return window
189
+
190
+
191
+ def pad_center(tensor: torch.Tensor, size: int, axis: int = -1) -> torch.Tensor:
192
+ n = tensor.shape[axis]
193
+ if size < n:
194
+ raise ValueError(f"Target size ({size}) must be at least input size ({n})")
195
+
196
+ lpad = (size - n) // 2
197
+ rpad = size - n - lpad
198
+
199
+ pad = [0] * (2 * tensor.ndim)
200
+ pad[2 * axis + 1] = rpad
201
+ pad[2 * axis] = lpad
202
+
203
+ return F.pad(tensor, pad, mode="constant", value=0)
204
+
205
+
206
+ def normalize(
207
+ S: torch.Tensor,
208
+ norm: float = float("inf"),
209
+ axis: int = 0,
210
+ threshold: float = 1e-10,
211
+ fill: bool = False,
212
+ ) -> torch.Tensor:
213
+ mag = S.abs().float()
214
+
215
+ if norm is None:
216
+ return S
217
+
218
+ elif norm == float("inf"):
219
+ length = mag.max(dim=axis, keepdim=True).values
220
+
221
+ elif norm == float("-inf"):
222
+ length = mag.min(dim=axis, keepdim=True).values
223
+
224
+ elif norm == 0:
225
+ length = (mag > 0).sum(dim=axis, keepdim=True).float()
226
+
227
+ elif norm > 0:
228
+ length = (mag**norm).sum(dim=axis, keepdim=True) ** (1.0 / norm)
229
+
230
+ else:
231
+ raise ValueError(f"Unsupported norm: {norm}")
232
+
233
+ small_idx = length < threshold
234
+ length = length.clone()
235
+ if fill:
236
+ length[small_idx] = float("nan")
237
+ Snorm = S / length
238
+ Snorm[Snorm != Snorm] = 1.0 # replace nan with fill_norm (default 1.0)
239
+ else:
240
+ length[small_idx] = float("inf")
241
+ Snorm = S / length
242
+
243
+ return Snorm
244
+
245
+
246
+ def window_sumsquare(
247
+ window_spec: Union[str, int, float, Callable, List[Any], Tuple[Any, ...]],
248
+ n_frames: int,
249
+ hop_length: int = 300,
250
+ win_length: int = 1200,
251
+ n_fft: int = 2048,
252
+ dtype: torch.dtype = torch.float32,
253
+ norm: Optional[Union[int, float]] = None,
254
+ device: Optional[torch.device] = "cpu",
255
+ ):
256
+ if win_length is None:
257
+ win_length = n_fft
258
+
259
+ total_length = n_fft + hop_length * (n_frames - 1)
260
+ x = torch.zeros(total_length, dtype=dtype, device=device)
261
+
262
+ # Get the window (from scipy for now)
263
+ win = generate_window(window_spec, win_length, fftbins=True)
264
+ win = torch.tensor(win, dtype=dtype, device=device)
265
+
266
+ # Normalize and square
267
+ win_sq = normalize(win, norm=norm, axis=0) ** 2
268
+ win_sq = pad_center(win_sq, size=n_fft, axis=0)
269
+
270
+ # Accumulate squared windows
271
+ for i in range(n_frames):
272
+ sample = i * hop_length
273
+ end = min(total_length, sample + n_fft)
274
+ length = end - sample
275
+ x[sample:end] += win_sq[:length]
276
+
277
+ return x
278
+
279
+
280
+ def get_window(win_length: int = 1200):
281
+ return generate_window(win_length)
282
+
283
+
284
+ def inverse_transform(
285
+ spec: Tensor,
286
+ phase: Tensor,
287
+ window: Optional[Tensor] = None,
288
+ n_fft: int = 2048,
289
+ hop_length: int = 300,
290
+ win_length: int = 1200,
291
+ length: Optional[torch.shape] = None,
292
+ ):
293
+ if window is None:
294
+ window = generate_window(win_length)
295
+ return torch.istft(
296
+ spec * torch.exp(phase * 1j),
297
+ n_fft,
298
+ hop_length,
299
+ win_length,
300
+ window=window,
301
+ length=length,
302
+ )
303
+
304
+
305
+ def stft_istft_rebuild(
306
+ input_data: Tensor,
307
+ window: Optional[Tensor] = None,
308
+ n_fft: int = 2048,
309
+ hop_length: int = 300,
310
+ win_length: int = 1200,
311
+ ):
312
+ """
313
+ Perform STFT followed by ISTFT reconstruction using magnitude and phase.
314
+ """
315
+ if window is None:
316
+ window = generate_window(win_length)
317
+ st = torch.stft(
318
+ input_data,
319
+ n_fft,
320
+ hop_length,
321
+ win_length,
322
+ window=window,
323
+ return_complex=True,
324
+ )
325
+ return torch.istft(
326
+ torch.abs(st) * torch.exp(1j * torch.angle(st)),
327
+ n_fft,
328
+ hop_length,
329
+ win_length,
330
+ window=window,
331
+ length=input_data.shape[-1],
332
+ ).squeeze(0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lt-tensor
3
- Version: 0.0.1.dev0
3
+ Version: 0.0.1.dev1
4
4
  Summary: General utilities for PyTorch and others. Built for general use.
5
5
  Home-page: https://github.com/gr1336/lt-tensor/
6
6
  Author: gr1336
@@ -4,7 +4,7 @@ with open("README.md", "r", encoding="utf-8") as f:
4
4
  long_description = f.read()
5
5
 
6
6
  setup(
7
- version="0.0.1dev",
7
+ version="0.0.1dev1",
8
8
  name="lt-tensor",
9
9
  description="General utilities for PyTorch and others. Built for general use.",
10
10
  long_description=long_description,
File without changes
@@ -1,113 +0,0 @@
1
- from ._torch_commons import *
2
- import torchaudio
3
- import math
4
- import random
5
- from .misc_utils import log_tensor
6
-
7
-
8
- def to_mel_spectrogram(
9
- waveform: torch.Tensor,
10
- sample_rate: int = 22050,
11
- n_fft: int = 1024,
12
- hop_length: Optional[int] = None,
13
- win_length: Optional[int] = None,
14
- n_mels: int = 80,
15
- f_min: float = 0.0,
16
- f_max: Optional[float] = None,
17
- ) -> torch.Tensor:
18
- """Converts waveform to mel spectrogram."""
19
- mel_spectrogram = torchaudio.transforms.MelSpectrogram(
20
- sample_rate=sample_rate,
21
- n_fft=n_fft,
22
- hop_length=hop_length,
23
- win_length=win_length,
24
- n_mels=n_mels,
25
- f_min=f_min,
26
- f_max=f_max,
27
- )
28
- return mel_spectrogram(waveform)
29
-
30
-
31
- def to_log_mel_spectrogram(
32
- waveform: torch.Tensor, sample_rate: int = 22050, eps: float = 1e-9, **kwargs
33
- ) -> torch.Tensor:
34
- """Converts waveform to log-mel spectrogram."""
35
- mel = to_mel_spectrogram(waveform, sample_rate, **kwargs)
36
- return torch.log(mel + eps)
37
-
38
-
39
- def normalize(
40
- x: torch.Tensor,
41
- mean: Optional[float] = None,
42
- std: Optional[float] = None,
43
- eps: float = 1e-9,
44
- ) -> torch.Tensor:
45
- """Normalizes tensor by mean and std."""
46
- if mean is None:
47
- mean = x.mean()
48
- if std is None:
49
- std = x.std()
50
- return (x - mean) / (std + eps)
51
-
52
-
53
- def min_max_scale(
54
- x: torch.Tensor, min_val: float = 0.0, max_val: float = 1.0
55
- ) -> torch.Tensor:
56
- """Scales tensor to [min_val, max_val] range."""
57
- x_min, x_max = x.min(), x.max()
58
- return (x - x_min) / (x_max - x_min + 1e-8) * (max_val - min_val) + min_val
59
-
60
-
61
- def mel_to_linear(
62
- mel_spec: torch.Tensor, mel_fb: torch.Tensor, eps: float = 1e-8
63
- ) -> torch.Tensor:
64
- """Approximate inversion of mel to linear spectrogram using pseudo-inverse."""
65
- mel_fb_inv = torch.pinverse(mel_fb)
66
- return torch.matmul(mel_fb_inv, mel_spec + eps)
67
-
68
-
69
- def add_noise(x: torch.Tensor, noise_level: float = 0.01) -> torch.Tensor:
70
- """Adds Gaussian noise to tensor."""
71
- return x + noise_level * torch.randn_like(x)
72
-
73
-
74
- def shift_time(x: torch.Tensor, shift: int) -> torch.Tensor:
75
- """Shifts tensor along time axis (last dim)."""
76
- return torch.roll(x, shifts=shift, dims=-1)
77
-
78
-
79
- def stretch_tensor(x: torch.Tensor, rate: float, mode: str = "linear") -> torch.Tensor:
80
- """Time-stretch tensor using interpolation."""
81
- B, C, T = x.shape if x.ndim == 3 else (1, 1, x.shape[0])
82
- new_T = int(T * rate)
83
- x_reshaped = x.view(B * C, T).unsqueeze(1)
84
- stretched = torch.nn.functional.interpolate(x_reshaped, size=new_T, mode=mode)
85
- return (
86
- stretched.squeeze(1).view(B, C, new_T) if x.ndim == 3 else stretched.squeeze()
87
- )
88
-
89
-
90
- def pad_tensor(
91
- x: torch.Tensor, target_len: int, pad_value: float = 0.0
92
- ) -> torch.Tensor:
93
- """Pads tensor to target length along last dimension."""
94
- current_len = x.shape[-1]
95
- if current_len >= target_len:
96
- return x[..., :target_len]
97
- padding = [0] * (2 * (x.ndim - 1)) + [0, target_len - current_len]
98
- return F.pad(x, padding, value=pad_value)
99
-
100
-
101
- def get_sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
102
- # Expect shape [B] or [B, 1]
103
- if timesteps.dim() > 1:
104
- timesteps = timesteps.view(-1) # flatten to [B]
105
-
106
- device = timesteps.device
107
- half_dim = dim // 2
108
- emb = torch.exp(
109
- torch.arange(half_dim, device=device) * -(math.log(10000.0) / half_dim)
110
- )
111
- emb = timesteps[:, None].float() * emb[None, :] # [B, half_dim]
112
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # [B, dim]
113
- return emb
File without changes
File without changes
File without changes