lt-tensor 0.0.1a12__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,11 +5,14 @@ from lt_utils.type_utils import is_file, is_array
  from lt_tensor.misc_utils import log_tensor
  import librosa
  import torchaudio
-
+ import numpy as np
  from lt_tensor.transform import InverseTransformConfig, InverseTransform
  from lt_utils.file_ops import FileScan, get_file_name, path_to_str
-
+ from torchaudio.functional import detect_pitch_frequency
  from lt_tensor.model_base import Model
+ import torch.nn.functional as F
+
+ DEFAULT_DEVICE = torch.tensor([0]).device


  class AudioProcessor(Model):
@@ -21,14 +24,14 @@ class AudioProcessor(Model):
  win_length: Optional[int] = None,
  hop_length: Optional[int] = None,
  f_min: float = 0,
- f_max: float | None = None,
+ f_max: float = 12000.0,
  center: bool = True,
  mel_scale: Literal["htk", "slaney"] = "htk",
  std: int = 4,
  mean: int = -4,
  n_iter: int = 32,
  window: Optional[Tensor] = None,
- normalized: bool =False,
+ normalized: bool = False,
  onesided: Optional[bool] = None,
  ):
  super().__init__()
@@ -38,7 +41,7 @@ class AudioProcessor(Model):
  self.n_fft = n_fft
  self.n_stft = n_fft // 2 + 1
  self.f_min = f_min
- self.f_max = f_max
+ self.f_max = max(min(f_max, 12000), self.f_min + 1)
  self.n_iter = n_iter
  self.hop_length = hop_length or n_fft // 4
  self.win_length = win_length or n_fft
@@ -76,7 +79,176 @@ class AudioProcessor(Model):
  "window",
  (torch.hann_window(self.win_length) if window is None else window),
  )
- # self._inv_transform = InverseTransform(**inverse_transform_config.to_dict())
+
+ def from_numpy(
+ self,
+ array: np.ndarray,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ converted = torch.from_numpy(array)
+ if not any([device is not None, dtype is not None]):
+ return converted
+ return converted.to(device=device, dtype=dtype)
+
+ def from_numpy_batch(
+ self,
+ arrays: List[np.ndarray],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ stacked = torch.stack([torch.from_numpy(x) for x in arrays])
+ if not any([device is not None, dtype is not None]):
+ return stacked
+ return stacked.to(device=device, dtype=dtype)
+
+ def to_numpy_safe(self, tensor: Tensor):
+ return tensor.detach().to(DEFAULT_DEVICE).numpy(force=True)
+
+ def compute_rms(
+ self, audio: Union[Tensor, np.ndarray], mel: Optional[Tensor] = None
+ ):
+ default_dtype = audio.dtype
+ default_device = audio.device
+ assert audio.ndim in [1, 2], (
+ f"Audio should have 1D for unbatched and 2D for batched"
+ ", received instead a: {audio.ndim}D"
+ )
+ if mel is not None:
+ assert mel.ndim in [2, 3], (
+ "Mel spectogram should have 2D dim for non-batched or 3D dim for both non-batched or batched"
+ f". Received instead {mel.ndim}D."
+ )
+ if audio.ndim == 2:
+ B = audio.shape[0]
+ else:
+ B = 1
+ audio = audio.unsqueeze(0)
+
+ if mel is not None:
+ if mel.ndim == 2:
+ assert B == 1, "Batch from mel and audio must be the same!"
+ mel = mel.unsqueeze(0)
+ else:
+ assert B == mel.shape[0], "Batch from mel and audio must be the same!"
+ mel = self.to_numpy_safe(mel)
+ gt_mel = lambda idx: mel[idx, :, :]
+
+ else:
+ gt_mel = lambda idx: None
+ audio = self.to_numpy_safe(audio)
+ if B == 1:
+ _r = librosa.feature.rms(
+ y=audio, frame_length=self.n_fft, hop_length=self.hop_length
+ )[0]
+ rms = self.from_numpy(_r, default_device, default_dtype)
+
+ else:
+ rms_ = []
+ for i in range(B):
+ _r = librosa.feature.rms(
+ y=audio[i, :],
+ S=gt_mel(i),
+ frame_length=self.n_fft,
+ hop_length=self.hop_length,
+ )[0]
+ rms_.append(_r)
+ rms = self.from_numpy_batch(rms_, default_device, default_dtype)
+
+ return rms
+
+ def compute_pitch(
+ self,
+ audio: Tensor,
+ ):
+ default_dtype = audio.dtype
+ default_device = audio.device
+ assert audio.ndim in [1, 2], (
+ f"Audio should have 1D for unbatched and 2D for batched"
+ ", received instead a: {audio.ndim}D"
+ )
+ if audio.ndim == 2:
+ B = audio.shape[0]
+ else:
+ B = 1
+ fmin = max(self.f_min, 80)
+ if B == 1:
+ f0 = self.from_numpy(
+ librosa.yin(
+ self.to_numpy_safe(audio),
+ fmin=fmin,
+ fmax=self.f_max,
+ frame_length=self.n_fft,
+ sr=self.sample_rate,
+ hop_length=self.hop_length,
+ center=self.center,
+ ),
+ default_device,
+ default_dtype,
+ )
+
+ else:
+ f0_ = []
+ for i in range(B):
+ r = librosa.yin(
+ self.to_numpy_safe(audio[i, :]),
+ fmin=fmin,
+ fmax=self.f_max,
+ frame_length=self.n_fft,
+ sr=self.sample_rate,
+ hop_length=self.hop_length,
+ center=self.center,
+ )
+ f0_.append(r)
+ f0 = self.from_numpy_batch(f0_, default_device, default_dtype)
+
+ # librosa.pyin(self.f_min, self.f_max)
+ return f0 # dict(f0=f0, attention_mask=f0 != f_max)
+
+ def compute_pitch_torch(self, audio: Tensor):
+ return detect_pitch_frequency(
+ audio,
+ sample_rate=self.sample_rate,
+ frame_time=self.n_fft,
+ win_length=self.win_length,
+ freq_low=max(self.f_min, 35),
+ freq_high=self.f_max,
+ )
+
+ def interpolate_tensor(
+ self,
+ tensor: Tensor,
+ target_len: int,
+ mode: Literal[
+ "nearest",
+ "linear",
+ "bilinear",
+ "bicubic",
+ "trilinear",
+ "area",
+ "nearest-exact",
+ ] = "nearest",
+ align_corners: Optional[bool] = None,
+ scale_factor: Optional[list[float]] = None,
+ recompute_scale_factor: Optional[bool] = None,
+ antialias: bool = False,
+ ):
+ """
+ The modes available for upsampling are: `nearest`, `linear` (3D-only),
+ `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only)
+ """
+
+ if tensor.ndim == 2: # [1, T]
+ tensor = tensor.unsqueeze(1) # [1, 1, T]
+ return F.interpolate(
+ tensor,
+ size=target_len,
+ mode=mode,
+ align_corners=align_corners,
+ scale_factor=scale_factor,
+ recompute_scale_factor=recompute_scale_factor,
+ antialias=antialias,
+ )

  def inverse_transform(
  self,
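A minimal usage sketch of the helpers added in this hunk (illustrative only: the import path is taken from the wheel RECORD, the default constructor arguments and the `wave` tensor are assumptions, not something documented by the package):

```python
import torch
from lt_tensor.processors.audio import AudioProcessor  # path assumed from the RECORD listing

ap = AudioProcessor()             # f_max is now clamped to at most 12000 Hz
wave = torch.randn(1, 22050)      # hypothetical one-second mono signal, shape [1, T]

rms = ap.compute_rms(wave)            # frame-wise RMS via librosa.feature.rms
f0 = ap.compute_pitch(wave)           # frame-wise F0 via librosa.yin
f0_t = ap.compute_pitch_torch(wave)   # torchaudio detect_pitch_frequency variant

# Stretch the pitch track to the RMS track's frame count before pairing them.
f0_aligned = ap.interpolate_tensor(f0, target_len=rms.shape[-1], mode="nearest")

# NumPy round-trip helpers restore the original device/dtype on request.
np_wave = ap.to_numpy_safe(wave)
back = ap.from_numpy(np_wave, device=wave.device, dtype=wave.dtype)
```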
@@ -95,7 +267,9 @@ class AudioProcessor(Model):
  n_fft=n_fft or self.n_fft,
  hop_length=hop_length or self.hop_length,
  win_length=win_length or self.win_length,
- window=torch.hann_window(win_length or self.win_length, device=spec.device),
+ window=torch.hann_window(
+ win_length or self.win_length, device=spec.device
+ ),
  center=self.center,
  normalized=self.normalized,
  onesided=self.onesided,
@@ -105,10 +279,12 @@ class AudioProcessor(Model):
  except RuntimeError as e:
  if not _recall and spec.device != self.window.device:
  self.window = self.window.to(spec.device)
- return self.inverse_transform(spec, phase, n_fft, hop_length, win_length, length, _recall=True)
+ return self.inverse_transform(
+ spec, phase, n_fft, hop_length, win_length, length, _recall=True
+ )
  raise e

- def rebuild_spectrogram(
+ def normalize_audio(
  self,
  wave: Tensor,
  length: Optional[int] = None,
@@ -148,7 +324,7 @@ class AudioProcessor(Model):
  except RuntimeError as e:
  if not _recall and wave.device != self.window.device:
  self.window = self.window.to(wave.device)
- return self.rebuild_spectrogram(wave, length, _recall=True)
+ return self.normalize_audio(wave, length, _recall=True)
  raise e

  def compute_mel(
@@ -167,12 +343,7 @@ class AudioProcessor(Model):

  def inverse_mel_spectogram(self, mel: Tensor, n_iter: Optional[int] = None):
  if isinstance(n_iter, int) and n_iter != self.n_iter:
- self.giffin_lim = torchaudio.transforms.GriffinLim(
- n_fft=self.n_fft,
- n_iter=n_iter,
- win_length=self.win_length,
- hop_length=self.hop_length,
- )
+ self.giffin_lim.n_iter = n_iter
  self.n_iter = n_iter
  return self.giffin_lim.forward(
  self.mel_rscale(mel),
@@ -182,21 +353,26 @@ class AudioProcessor(Model):
  self,
  path: PathLike,
  top_db: float = 30,
+ normalize: bool = False,
+ alpha: float = 1.0,
  ) -> Tensor:
  is_file(path, True)
  wave, sr = librosa.load(str(path), sr=self.sample_rate)
  wave, _ = librosa.effects.trim(wave, top_db=top_db)
- return (
- torch.from_numpy(
- librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
- if sr != self.sample_rate
- else wave
- )
- .float()
- .unsqueeze(0)
- )
+ if sr != self.sample_rate:
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=self.sample_rate)
+ if normalize:
+ wave = librosa.util.normalize(wave)
+ if alpha not in [0.0, 1.0]:
+ wave = wave * alpha
+ return torch.from_numpy(wave).float().unsqueeze(0)

- def find_audios(self, path: PathLike, additional_extensions: List[str] = []):
+ def find_audios(
+ self,
+ path: PathLike,
+ additional_extensions: List[str] = [],
+ maximum: int | None = None,
+ ):
  extensions = [
  "*.wav",
  "*.aac",
@@ -212,6 +388,7 @@ class AudioProcessor(Model):
  return FileScan.files(
  path,
  extensions,
+ maximum,
  )

  def find_audio_text_pairs(
@@ -240,57 +417,6 @@ class AudioProcessor(Model):
  break
  return results

- def slice_mismatch_outputs(
- self,
- tensor_1: Tensor,
- tensor_2: Tensor,
- smallest_size: Optional[int] = None,
- left_to_right: bool = True,
- ):
- assert tensor_1.ndim == tensor_2.ndim, (
- "Tensors must have the same dimentions to be sliced! \n"
- f"Received instead a tensor_1 with {tensor_1.ndim}D and tensor_2 with {tensor_1.ndim}D."
- )
- dim = tensor_1.ndim
- assert dim < 5, (
- "Excpected to receive tensors with from 1D up to 4D. "
- f"Received instead a {dim}D tensor."
- )
-
- if tensor_1.shape[-1] == tensor_2.shape[-1]:
- return tensor_1, tensor_2
-
- if smallest_size is None:
- smallest_size = min(tensor_1.shape[-1], tensor_2.shape[-1])
- if dim == 0:
- tensor_1 = tensor_1.unsqueeze(0)
- tensor_2 = tensor_2.unsqueeze(0)
- dim = 1
-
- if dim == 1:
- if left_to_right:
- return tensor_1[:smallest_size], tensor_2[:smallest_size]
- return tensor_1[-smallest_size:], tensor_2[-smallest_size:]
- elif dim == 2:
- if left_to_right:
- return tensor_1[:, :smallest_size], tensor_2[:, :smallest_size]
- return tensor_1[:, -smallest_size:], tensor_2[:, -smallest_size:]
- elif dim == 3:
- if left_to_right:
- return tensor_1[:, :, :smallest_size], tensor_2[:, :, :smallest_size]
- return tensor_1[:, :, -smallest_size:], tensor_2[:, :, -smallest_size:]
-
- # else:
- if left_to_right:
- return (
- tensor_1[:, :, :, :smallest_size],
- tensor_2[:, :, :, :smallest_size],
- )
- return (
- tensor_1[:, :, :, -smallest_size:],
- tensor_2[:, :, :, -smallest_size:],
- )
-
  def stft_loss(
  self,
  signal: Tensor,
@@ -302,11 +428,23 @@ class AudioProcessor(Model):
  smallest = min(signal.shape[-1], ground.shape[-1])
  signal = signal[:, -smallest:]
  ground = ground[:, -smallest:]
- sig_mel = self.compute_mel(signal, base, True).detach().cpu()
- gnd_mel = self.compute_mel(ground, base, True).detach().cpu()
+ sig_mel = self.compute_mel(signal, base, True)
+ gnd_mel = self.compute_mel(ground, base, True)
  return torch.norm(gnd_mel - sig_mel, p=1) / torch.norm(gnd_mel, p=1)

- # def forward(self, wave: Tensor, base: Optional[float] = None):
+ @staticmethod
+ def plot_spectrogram(spectrogram, ax_):
+ import matplotlib.pylab as plt
+
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+ plt.colorbar(im, ax=ax)
+
+ fig.canvas.draw()
+ plt.close()
+
+ return fig
+
  def forward(
  self,
  *inputs: Union[Tensor, float],
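Continuing the same hypothetical `ap` instance, a short sketch of the extended loading and scanning API from the hunks above (the dataset path and the second waveform are made up; only the parameter names come from the diff):

```python
import torch

# find_audios can now stop after a maximum number of matches;
# load_audio can peak-normalize and rescale by a gain factor.
files = ap.find_audios("data/speech", maximum=100)          # hypothetical folder
ref = ap.load_audio(files[0], top_db=30, normalize=True, alpha=0.8)

# stft_loss now keeps both mel spectrograms on the input device,
# since the .detach().cpu() calls were dropped (relative L1 distance).
fake = torch.randn_like(ref)
loss = ap.stft_loss(fake, ref)
```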
lt_tensor/transform.py CHANGED
@@ -316,18 +316,6 @@ def inverse_transform(
  )


- def is_nand(a: bool, b: bool):
- """[a -> b = result]
- ```
- False -> False = True
- False -> True = True
- True -> False = True
- True -> True = False
- ```
- """
- return not (a and b)
-
-
  class InverseTransformConfig:
  def __init__(
  self,
@@ -413,9 +401,11 @@ class InverseTransform(Model):
  self.return_complex = return_complex
  self.onesided = onesided
  self.normalized = normalized
- self.register_buffer('window', torch.hann_window(self.win_length) if window is None else window)
+ self.register_buffer(
+ "window", torch.hann_window(self.win_length) if window is None else window
+ )

- def forward(self, spec: Tensor, phase: Tensor, *, _recall:bool = False):
+ def forward(self, spec: Tensor, phase: Tensor, *, _recall: bool = False):
  """
  Perform the inverse short-time Fourier transform.

@@ -434,7 +424,7 @@ class InverseTransform(Model):
  try:
  return torch.istft(
  spec * torch.exp(phase * 1j),
- n_fft = self.n_fft,
+ n_fft=self.n_fft,
  hop_length=self.hop_length,
  win_length=self.win_length,
  window=self.window,
@@ -448,4 +438,5 @@ class InverseTransform(Model):
  if not _recall and spec.device != self.window.device:
  self.window = self.window.to(spec.device)
  return self.forward(spec, phase, _recall=True)
- raise e
+ raise e
+
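For context on the `InverseTransform.forward` path reformatted above, a hedged sketch of reconstructing audio from a magnitude/phase pair; the tensor shapes and constructor keywords are assumptions, while the import and the `spec * exp(1j * phase)` iSTFT behaviour come from the diff itself:

```python
import torch
from lt_tensor.transform import InverseTransform  # same import the audio processor uses

n_fft = 1024
frames = 100
mag = torch.rand(1, n_fft // 2 + 1, frames)                        # magnitude spectrogram
phase = (torch.rand(1, n_fft // 2 + 1, frames) * 2 - 1) * torch.pi  # phase in radians

# Constructor keywords assumed to mirror the attributes set in __init__.
inv = InverseTransform(n_fft=n_fft, hop_length=n_fft // 4, win_length=n_fft)

# forward() calls torch.istft(mag * exp(1j * phase), ...) and, on a device
# mismatch, moves its Hann window to mag's device and retries once.
wave = inv(mag, phase)
```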
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: lt-tensor
- Version: 0.0.1a12
+ Version: 0.0.1a13
  Summary: General utilities for PyTorch and others. Built for general use.
  Home-page: https://github.com/gr1336/lt-tensor/
  Author: gr1336
@@ -11,15 +11,17 @@ Classifier: Topic :: Software Development :: Libraries
  Classifier: Topic :: Utilities
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: torch>=2.2.0
- Requires-Dist: torchaudio>=2.2.0
+ Requires-Dist: torch>=2.7.0
+ Requires-Dist: torchaudio>=2.7.0
  Requires-Dist: numpy>=1.26.4
  Requires-Dist: tokenizers
  Requires-Dist: pyyaml>=6.0.0
  Requires-Dist: numba>0.60.0
  Requires-Dist: lt-utils>=0.0.2a1
- Requires-Dist: librosa>=0.11.0
+ Requires-Dist: librosa==0.11.*
+ Requires-Dist: einops
  Requires-Dist: plotly
+ Requires-Dist: scipy
  Dynamic: author
  Dynamic: classifier
  Dynamic: description
@@ -0,0 +1,32 @@
+ lt_tensor/__init__.py,sha256=XxNCGcVL-haJyMpifr-GRaamo32R6jmqe3iOuS4ecfs,469
+ lt_tensor/config_templates.py,sha256=FRN4-i1amoqMh_wyp4gNsw61ABWTIhGC62Uc3l3SNss,3515
+ lt_tensor/losses.py,sha256=zvkCOnE5XpF3v6ymivRIdqPTsMM5zc94ZMom7YDi3zM,4946
+ lt_tensor/lr_schedulers.py,sha256=LSZzqrOOLzSthD8k-W4cYPJt0vCjmHkiJkLr5e3yRTE,3659
+ lt_tensor/math_ops.py,sha256=TkD4WQG42KsQ9Fg7FXOjf8f-ixtW0apf2XjaooecVx4,2257
+ lt_tensor/misc_utils.py,sha256=UNba6UEsAv1oZ60IAaKBNGbhXK2WPxRI9E4QcjP-_w0,28755
+ lt_tensor/model_base.py,sha256=lxzRXfPlR_t_6LfgRw2dct55evrtmwTiDqZGAe3jLro,20026
+ lt_tensor/monotonic_align.py,sha256=LhBd8p1xdBzg6jQrQX1j7b4PNeYGwIqM24zcU-pHOLE,2239
+ lt_tensor/noise_tools.py,sha256=wFeAsHhLhSlEc5XU5LbFKaXoHeVxrWjiMeljjGdIKyM,11363
+ lt_tensor/torch_commons.py,sha256=fntsEU8lhBQo0ebonI1iXBkMbWMN3HpBsG13EWlP5s8,718
+ lt_tensor/transform.py,sha256=dZm8T_ov0blHMQu6nGiehsdG1VSB7bZBUVmTkT-PBdc,13257
+ lt_tensor/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ lt_tensor/datasets/audio.py,sha256=j73oRyXt-AK4tWWYWjH-3c5RYouQBgDSCTuWHmyG8kQ,7450
+ lt_tensor/model_zoo/__init__.py,sha256=RzG7fltZLyiIU_Za4pgfBPli5uPITiJkq4sTCd4uA_0,319
+ lt_tensor/model_zoo/basic.py,sha256=_26H_jJk5Ld3DZiNpIhGosGfMxoFDZrI8bpDAYUOYno,10660
+ lt_tensor/model_zoo/discriminator.py,sha256=dS5UmJZV5MxIFiaBlIXfgGLDdUT3y0Vuv9lDGHsjJE8,5849
+ lt_tensor/model_zoo/features.py,sha256=CTFMidzza31pqQjwPfp_g0BNVfuQ8Dlo5JnxpYpKgag,13144
+ lt_tensor/model_zoo/fusion.py,sha256=usC1bcjQRNivDc8xzkIS5T1glm78OLcs2V_tPqfp-eI,5422
+ lt_tensor/model_zoo/pos_encoder.py,sha256=3d1EYLinCU9UAy-WuEWeYMGhMqaGknCiQ5qEmhw_UYM,4487
+ lt_tensor/model_zoo/residual.py,sha256=knVLxzrLUjNQ6vdBESTZOk3r86ldi5PHetoBuJmymcw,6388
+ lt_tensor/model_zoo/transformer.py,sha256=HUFoFFh7EQJErxdd9XIxhssdjvNVx2tNGDJOTUfwG2A,4301
+ lt_tensor/model_zoo/istft/__init__.py,sha256=SV96w9WUWfHMee8Vjgn2MP0igKft7_mLTju9rFVYGHY,102
+ lt_tensor/model_zoo/istft/generator.py,sha256=lotGkMu67fctzwa5FSwX_xtHILOuV95uP-djCz2N3C8,5261
+ lt_tensor/model_zoo/istft/sg.py,sha256=EaEi3otw_uY5QfqDBNIWBWTJSg3KnwzzR4FBr0u09C0,4838
+ lt_tensor/model_zoo/istft/trainer.py,sha256=EPuGtvfgR8vCrVc72p5OwVy73nNVlx510VxnH3NeErY,16080
+ lt_tensor/processors/__init__.py,sha256=4b9MxAJolXiJfSm20ZEspQTDm1tgLazwlPWA_jB1yLM,63
+ lt_tensor/processors/audio.py,sha256=uBvMls4u_B1M-pk3xAiOIRnwM2l_3LcdfESNkE0Ch30,15314
+ lt_tensor-0.0.1a13.dist-info/licenses/LICENSE,sha256=HUnu_iSPpnDfZS_PINhO3AoVizJD1A2vee8WX7D7uXo,11358
+ lt_tensor-0.0.1a13.dist-info/METADATA,sha256=yzNtg91vOGZCoXi6XWpn1kWk7LgVD2mIWQXL-7tw_Uc,1033
+ lt_tensor-0.0.1a13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ lt_tensor-0.0.1a13.dist-info/top_level.txt,sha256=35FuhFeXnUyvHWdbVHGPh0hS8euofafnJ_GJAVSF4Kk,10
+ lt_tensor-0.0.1a13.dist-info/RECORD,,
@@ -1,67 +0,0 @@
- __all__ = [
- "ConcatFusion",
- "FiLMFusion",
- "BilinearFusion",
- "CrossAttentionFusion",
- "GatedFusion",
- ]
-
- from lt_tensor.torch_commons import *
- from lt_tensor.model_base import Model
-
-
- class ConcatFusion(Model):
- def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
- super().__init__()
- self.proj = nn.Linear(in_dim_a + in_dim_b, out_dim)
-
- def forward(self, a: Tensor, b: Tensor) -> Tensor:
- x = torch.cat([a, b], dim=-1)
- return self.proj(x)
-
-
- class FiLMFusion(Model):
- def __init__(self, cond_dim: int, feature_dim: int):
- super().__init__()
- self.modulator = nn.Linear(cond_dim, 2 * feature_dim)
-
- def forward(self, x: Tensor, cond: Tensor) -> Tensor:
- scale, shift = self.modulator(cond).chunk(2, dim=-1)
- return x * scale + shift
-
-
- class BilinearFusion(Model):
- def __init__(self, in_dim_a: int, in_dim_b: int, out_dim: int):
- super().__init__()
- self.bilinear = nn.Bilinear(in_dim_a, in_dim_b, out_dim)
-
- def forward(self, a: Tensor, b: Tensor) -> Tensor:
- return self.bilinear(a, b)
-
-
- class CrossAttentionFusion(Model):
- def __init__(self, q_dim: int, kv_dim: int, n_heads: int = 4, d_model: int = 256):
- super().__init__()
- self.q_proj = nn.Linear(q_dim, d_model)
- self.k_proj = nn.Linear(kv_dim, d_model)
- self.v_proj = nn.Linear(kv_dim, d_model)
- self.attn = nn.MultiheadAttention(
- embed_dim=d_model, num_heads=n_heads, batch_first=True
- )
-
- def forward(self, query: Tensor, context: Tensor, mask: Tensor = None) -> Tensor:
- Q = self.q_proj(query)
- K = self.k_proj(context)
- V = self.v_proj(context)
- output, _ = self.attn(Q, K, V, key_padding_mask=mask)
- return output
-
-
- class GatedFusion(Model):
- def __init__(self, in_dim: int):
- super().__init__()
- self.gate = nn.Sequential(nn.Linear(in_dim * 2, in_dim), nn.Sigmoid())
-
- def forward(self, a: Tensor, b: Tensor) -> Tensor:
- gate = self.gate(torch.cat([a, b], dim=-1))
- return gate * a + (1 - gate) * b