flamo 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flamo/optimize/dataset.py +5 -1
- flamo/processor/dsp.py +415 -0
- {flamo-0.2.3.dist-info → flamo-0.2.5.dist-info}/METADATA +1 -1
- {flamo-0.2.3.dist-info → flamo-0.2.5.dist-info}/RECORD +6 -6
- {flamo-0.2.3.dist-info → flamo-0.2.5.dist-info}/WHEEL +0 -0
- {flamo-0.2.3.dist-info → flamo-0.2.5.dist-info}/licenses/LICENSE +0 -0
flamo/optimize/dataset.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from typing import Optional
|
|
1
2
|
import torch
|
|
2
3
|
import torch.utils.data as data
|
|
3
4
|
from flamo.utils import get_device
|
|
@@ -21,6 +22,7 @@ class Dataset(torch.utils.data.Dataset):
|
|
|
21
22
|
- **target** (torch.Tensor, optional): The target data tensor. Default: torch.randn(100, 100).
|
|
22
23
|
- **expand** (int): The first shape dimension of the input and target tensor after expansion. Default: 1. This coincides with the length of the dataset.
|
|
23
24
|
- **device** (torch.device, optional): The device to store the tensors on. Defaults to torch default device.
|
|
25
|
+
- **dtype** (torch.dtype, optional): The data type of the tensors. If None, the data type of the input tensor is used. Default: None.
|
|
24
26
|
"""
|
|
25
27
|
|
|
26
28
|
def __init__(
|
|
@@ -29,8 +31,10 @@ class Dataset(torch.utils.data.Dataset):
|
|
|
29
31
|
target: torch.Tensor = torch.randn(1, 1),
|
|
30
32
|
expand: int = 1,
|
|
31
33
|
device: torch.device = torch.get_default_device(),
|
|
32
|
-
dtype: torch.dtype =
|
|
34
|
+
dtype: Optional[torch.dtype] = None,
|
|
33
35
|
):
|
|
36
|
+
if dtype is None:
|
|
37
|
+
dtype = input.dtype
|
|
34
38
|
self.input = input.to(device).to(dtype)
|
|
35
39
|
self.target = target.to(device).to(dtype)
|
|
36
40
|
self.expand = expand
|
flamo/processor/dsp.py
CHANGED
|
@@ -1663,6 +1663,270 @@ class parallelBiquad(Biquad):
|
|
|
1663
1663
|
self.output_channels = self.size[-1]
|
|
1664
1664
|
|
|
1665
1665
|
|
|
1666
|
+
class SOSFilter(Filter):
    r"""
    Cascade of second-order sections (SOS) specified directly by their
    numerator/denominator coefficients (b/a).

    Section k stores the coefficients [b0_k, b1_k, b2_k, a0_k, a1_k, a2_k]. Sections
    are applied in series, so the overall response is the product of the per-section
    responses. The frequency response is computed from the FFT of the 3-tap
    polynomials after an anti time-aliasing envelope has been applied to the taps.

    Shape:
        - input: (B, M, N_in, ...)
        - param: (K, 6, N_out, N_in) with ordering [b0, b1, b2, a0, a1, a2]
        - freq_response: (M, N_out, N_in)
        - output: (B, M, N_out, ...)

    where B is the batch size, M is the number of frequency bins, N_in is the number of
    input channels, N_out is the number of output channels, and K is the number of SOS
    sections (cascaded in series).

        **Arguments / Attributes**:
            - size (tuple, optional): (N_out, N_in). Default: (1, 1).
            - n_sections (int, optional): Number of SOS sections (K). Default: 1.
            - nfft (int, optional): Number of FFT points. Default: 2**11.
            - fs (int, optional): Sampling frequency. Default: 48000.
            - alias_decay_db (float, optional): Anti time-aliasing envelope decay in dB after nfft samples. Default: 0.0.
            - device (str, optional): Device for constructed tensors. Default: None.
            - dtype (torch.dtype, optional): Data type for tensors. Default: torch.float32.
            - normalize_a0 (bool, optional): Normalize each section by a0 so a0=1. Default: True.

        **Attributes**:
            - param (nn.Parameter): Raw SOS coefficients with shape (K, 6, N_out, N_in).
            - alias_envelope_dcy (torch.Tensor): Length-3 envelope for time anti-aliasing.
            - freq_response (callable): Maps parameters to frequency response.
            - input_channels (int): Number of input channels.
            - output_channels (int): Number of output channels.
    """

    def __init__(
        self,
        size: tuple = (1, 1),
        n_sections: int = 1,
        nfft: int = 2**11,
        fs: int = 48000,
        alias_decay_db: float = 0.0,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float32,
        normalize_a0: bool = True,
    ):
        # These attributes are assigned before the base constructor runs.
        # NOTE(review): presumably the base constructor calls back into
        # init_param()/initialize_class(), which read them - confirm against Filter/DSP.
        self.n_sections = n_sections
        self.fs = fs
        self.device = device
        self.dtype = dtype
        self.normalize_a0 = normalize_a0
        # 3-tap envelope for [b0, b1, b2] and [a0, a1, a2]: gamma**0, gamma**1, gamma**2
        gamma = 10 ** (
            -torch.abs(torch.tensor(alias_decay_db, device=self.device, dtype=self.dtype))
            / (nfft)
            / 20
        )
        self.alias_envelope_dcy = gamma ** torch.arange(0, 3, 1, device=self.device)
        self.get_map()
        super().__init__(
            size=(n_sections, *self.get_size(), *size),
            nfft=nfft,
            map=self.map,
            requires_grad=False,
            alias_decay_db=alias_decay_db,
            device=device,
            dtype=dtype,
        )

    def get_size(self):
        r"""
        Leading dimensions for SOS parameters.

        - 6 for [b0, b1, b2, a0, a1, a2]
        """
        return (6,)

    def get_map(self):
        r"""
        Build the mapping applied to the raw SOS coefficients and store it in ``self.map``.

        When ``normalize_a0`` is True, every coefficient of a section is divided by that
        section's a0 and a0 itself is set to 1. Otherwise the mapping is the identity.
        """

        def _map(x: torch.Tensor) -> torch.Tensor:
            if not self.normalize_a0:
                return x
            # x: (K, 6, ...) with a0 stored at index 3 of the second dimension
            a0 = x[:, 3, ...]
            eps = torch.finfo(x.dtype).eps
            # Replace near-zero a0 by eps so the division below stays finite.
            a0_safe = torch.where(torch.abs(a0) > eps, a0, eps * torch.ones_like(a0))
            # One broadcast division over the coefficient axis instead of a
            # per-coefficient Python loop; the result is a new tensor, x is untouched.
            y = x / a0_safe.unsqueeze(1)
            # a0 is exactly 1 after normalization, also where a0_safe was clamped.
            y[:, 3, ...] = 1.0
            return y

        self.map = _map

    def init_param(self):
        r"""
        Initialize parameters to identity sections: b=[1,0,0], a=[1,0,0].
        """
        with torch.no_grad():
            self.param.zero_()
            # b0 = 1, a0 = 1
            self.param[:, 0, ...] = 1.0
            self.param[:, 3, ...] = 1.0

    def check_param_shape(self):
        r"""
        Checks if the shape of the SOS parameters is valid.

        Raises an AssertionError unless ``self.size`` is 4D with 6 as its second entry.
        """
        assert (
            len(self.size) == 4
        ), "Parameter size must be 4D, expected (K, 6, N_out, N_in)."
        assert (
            self.size[1] == 6
        ), "Second dimension must be 6: [b0,b1,b2,a0,a1,a2]."

    def initialize_class(self):
        r"""
        Initialize the SOSFilter class: validate the parameter shape, derive the channel
        counts, and build the frequency-response and frequency-convolution callables.
        """
        self.check_param_shape()
        self.get_io()
        self.get_freq_response()
        self.get_freq_convolve()

    def get_freq_response(self):
        r"""
        Store in ``self.freq_response`` a callable that maps raw parameters to the
        frequency response of the cascaded SOS.
        """
        self.freq_response = lambda param: self.get_poly_coeff(self.map(param))[0]

    def get_poly_coeff(self, param: torch.Tensor):
        r"""
        Split mapped parameters into b and a polynomials, apply the anti-aliasing
        envelope, and compute the frequency response.

        No precision cast is performed here: the computation runs in the dtype of
        ``param`` and of ``alias_envelope_dcy``.

            **Arguments**:
                - param (torch.Tensor): (K, 6, N_out, N_in)

            **Returns**:
                - H (torch.Tensor): (M, N_out, N_in)
                - B (torch.Tensor): (M, K, N_out, N_in)
                - A (torch.Tensor): (M, K, N_out, N_in)
        """
        # Arrange to (3, K, N_out, N_in) so the tap axis comes first for the FFT.
        b = torch.stack((param[:, 0, ...], param[:, 1, ...], param[:, 2, ...]), dim=0)
        a = torch.stack((param[:, 3, ...], param[:, 4, ...], param[:, 5, ...]), dim=0)

        # Scale tap p by the p-th envelope value.
        b_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, b)
        a_aa = torch.einsum("p, pomn -> pomn", self.alias_envelope_dcy, a)
        B = torch.fft.rfft(b_aa, self.nfft, dim=0)
        A = torch.fft.rfft(a_aa, self.nfft, dim=0)

        # Series connection: multiply the per-section responses over the K axis.
        # The denominator product is computed once and reused for the zero check.
        numerator = torch.prod(B, dim=1)
        denominator = torch.prod(A, dim=1)
        H_temp = numerator / denominator
        # Bins whose cascaded denominator is exactly zero get machine epsilon
        # instead of the inf/nan produced by the division.
        H = torch.where(
            torch.abs(denominator) != 0,
            H_temp,
            torch.finfo(H_temp.dtype).eps * torch.ones_like(H_temp),
        )
        return H, B, A

    def get_freq_convolve(self):
        r"""
        Store in ``self.freq_convolve`` the frequency-domain matrix product between the
        frequency response and the input.
        """
        self.freq_convolve = lambda x, param: torch.einsum(
            "fmn,bfn...->bfm...", self.freq_response(param), x
        )

    def get_io(self):
        r"""
        Computes the number of input and output channels based on the size parameter.
        """
        self.input_channels = self.size[-1]
        self.output_channels = self.size[-2]
|
|
1844
|
+
|
|
1845
|
+
|
|
1846
|
+
class parallelSOSFilter(SOSFilter):
    r"""
    Parallel counterpart of the SOSFilter class.

    Accepts direct SOS coefficients for N parallel channels. Parameter shape is
    (K, 6, N), with ordering [b0, b1, b2, a0, a1, a2] per section.

    Shape:
        - input: (B, M, N, ...)
        - param: (K, 6, N)
        - freq_response: (M, N)
        - output: (B, M, N, ...)

    where B is the batch size, M is the number of frequency bins, N is the number of
    channels, and K is the number of cascaded sections. The constructor arguments are
    forwarded unchanged to SOSFilter.
    """

    def __init__(
        self,
        size: tuple = (1,),
        n_sections: int = 1,
        nfft: int = 2**11,
        fs: int = 48000,
        alias_decay_db: float = 0.0,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float32,
        normalize_a0: bool = True,
    ):
        super().__init__(
            size=size,
            n_sections=n_sections,
            nfft=nfft,
            fs=fs,
            alias_decay_db=alias_decay_db,
            device=device,
            dtype=dtype,
            normalize_a0=normalize_a0,
        )

    def check_param_shape(self):
        r"""
        Checks if the shape of the SOS parameters is valid.

        Raises an AssertionError unless ``self.size`` is 3D with 6 as its second entry.
        """
        assert (
            len(self.size) == 3
        ), "Parameter size must be 3D, expected (K, 6, N)."
        assert self.size[1] == 6, "Second dimension must be 6: [b0,b1,b2,a0,a1,a2]."

    def get_freq_response(self):
        r"""Compute the frequency response of the cascaded SOS."""
        self.freq_response = lambda param: self.get_poly_coeff(self.map(param))[0]

    def get_poly_coeff(self, param: torch.Tensor):
        r"""
        Split mapped parameters into b and a polynomials (parallel case), apply the
        anti-aliasing envelope, and compute the frequency response.

        No precision cast is performed here: the computation runs in the dtype of
        ``param`` and of ``alias_envelope_dcy``.

            **Arguments**:
                - param (torch.Tensor): (K, 6, N)

            **Returns**:
                - H (torch.Tensor): (M, N)
                - B (torch.Tensor): (M, K, N)
                - A (torch.Tensor): (M, K, N)
        """
        # Arrange to (3, K, N) so the tap axis comes first for the FFT.
        b = torch.stack((param[:, 0, :], param[:, 1, :], param[:, 2, :]), dim=0)
        a = torch.stack((param[:, 3, :], param[:, 4, :], param[:, 5, :]), dim=0)

        # Scale tap p by the p-th envelope value.
        b_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, b)
        a_aa = torch.einsum("p, pon -> pon", self.alias_envelope_dcy, a)
        B = torch.fft.rfft(b_aa, self.nfft, dim=0)
        A = torch.fft.rfft(a_aa, self.nfft, dim=0)

        # Series connection: multiply the per-section responses over the K axis.
        # The denominator product is computed once and reused for the zero check.
        numerator = torch.prod(B, dim=1)
        denominator = torch.prod(A, dim=1)
        H_temp = numerator / denominator
        # Bins whose cascaded denominator is exactly zero get machine epsilon
        # instead of the inf/nan produced by the division.
        H = torch.where(
            torch.abs(denominator) != 0,
            H_temp,
            torch.finfo(H_temp.dtype).eps * torch.ones_like(H_temp),
        )
        return H, B, A

    def get_freq_convolve(self):
        r"""
        Store in ``self.freq_convolve`` the channel-wise product between the frequency
        response and the input (no mixing across channels).
        """
        self.freq_convolve = lambda x, param: torch.einsum(
            "fn,bfn...->bfn...", self.freq_response(param), x
        )

    def get_io(self):
        r"""Computes the number of input and output channels based on the size parameter."""
        self.input_channels = self.size[-1]
        self.output_channels = self.size[-1]
|
|
1928
|
+
|
|
1929
|
+
|
|
1666
1930
|
class SVF(Filter):
|
|
1667
1931
|
r"""
|
|
1668
1932
|
A class for IIR filters as a serially cascaded state variable filters (SVFs).
|
|
@@ -3105,3 +3369,154 @@ class parallelDelay(Delay):
|
|
|
3105
3369
|
"""
|
|
3106
3370
|
self.input_channels = self.size[-1]
|
|
3107
3371
|
self.output_channels = self.size[-1]
|
|
3372
|
+
|
|
3373
|
+
|
|
3374
|
+
class GainDelay(DSP):
    r"""
    Combined MIMO gain and delay stage operating in the frequency domain.

    The frequency response of a gain matrix followed by per-channel delays is computed
    in one step, without constructing intermediate expanded tensors of size
    :math:`N_{out} \times N_{in}`.

    Shape:
        - input: :math:`(B, M, N_{in}, ...)`
        - param: :math:`(2, N_{out}, N_{in})`
        - output: :math:`(B, M, N_{out}, ...)`

    where :math:`B` is the batch size, :math:`M` is the number of frequency bins,
    :math:`N_{in}` is the number of input channels, and :math:`N_{out}` is the number of output channels.
    Ellipsis :math:`(...)` represents additional dimensions.

    ``param[0]`` holds the raw gains and ``param[1]`` the raw delays.

        **Arguments / Attributes**:
            - **size** (tuple, optional): Size of the gain-delay stage as ``(N_{out}, N_{in})``. Default: (1, 1).
            - **max_len** (int, optional): Maximum delay length expressed in samples. Default: 2000.
            - **isint** (bool, optional): If ``True``, delays are rounded to the nearest integer sample. Default: False.
            - **unit** (int, optional): Unit scaling factor for converting seconds to samples. Default: 100.
            - **nfft** (int, optional): Number of FFT points. Default: 2 ** 11.
            - **fs** (int, optional): Sampling rate. Default: 48000.
            - **map_gain** (callable, optional): Mapping applied to raw gain parameters. Default: ``lambda x: x``.
            - **map_delay** (callable, optional): Mapping applied to raw delay parameters (in seconds). Default: ``lambda x: x``.
            - **requires_grad** (bool, optional): Whether parameters require gradients. Default: False.
            - **alias_decay_db** (float, optional): Decay in dB applied by the anti aliasing envelope. Default: 0.0.
            - **device** (str, optional): Device of the constructed tensors. Default: None.
            - **dtype** (torch.dtype, optional): Data type for tensors. Default: torch.float32.
    """

    def __init__(
        self,
        size: tuple = (1, 1),
        max_len: int = 2000,
        isint: bool = False,
        unit: int = 100,
        nfft: int = 2**11,
        fs: int = 48000,
        map_gain: Optional[callable] = None,
        map_delay: Optional[callable] = None,
        requires_grad: bool = False,
        alias_decay_db: float = 0.0,
        device: Optional[str] = None,
        dtype: torch.dtype = torch.float32,
    ):
        # Everything below is assigned before the base constructor runs.
        # NOTE(review): presumably DSP.__init__ calls init_param(), which reads
        # fs/max_len/unit/isint - confirm against the base class.
        self.fs = fs
        self.max_len = max_len
        self.unit = unit
        self.isint = isint
        # Remember whether the caller supplied the mappings; initialize_class() only
        # installs its softplus delay mapping when no custom delay map was given.
        self._custom_gain_map = map_gain is not None
        self._custom_delay_map = map_delay is not None
        self.map_gain = map_gain if self._custom_gain_map else (lambda x: x)
        self.map_delay = map_delay if self._custom_delay_map else (lambda x: x)
        super().__init__(
            size=(2, *size),
            nfft=nfft,
            requires_grad=requires_grad,
            alias_decay_db=alias_decay_db,
            device=device,
            dtype=dtype,
        )
        self.initialize_class()

    def forward(self, x, ext_param=None):
        r"""
        Apply the gain-delay stage to the frequency-domain input ``x``.

        When ``ext_param`` is given it is first passed to ``assign_value`` (under
        ``torch.no_grad()``) and then used for the convolution; otherwise the stored
        ``self.param`` is used.
        """
        self.check_input_shape(x)
        if ext_param is not None:
            with torch.no_grad():
                self.assign_value(ext_param)
            return self.freq_convolve(x, ext_param)
        return self.freq_convolve(x, self.param)

    def init_param(self):
        r"""
        Initialize gains to one and delays to random values below ``max_len`` samples,
        and set ``self.order`` to the largest (ceiled) delay in samples plus one.
        """
        channel_shape = self.size[1:]
        with torch.no_grad():
            nn.init.ones_(self.param[0])
            if not self.isint:
                # fractional delays, uniform in [0, max_len)
                delay_samples = (
                    torch.rand(channel_shape, device=self.device, dtype=self.dtype)
                    * self.max_len
                )
            else:
                # integer delays drawn from [1, max_len)
                drawn = torch.randint(
                    1, self.max_len, channel_shape, device=self.device, dtype=torch.int64
                )
                delay_samples = drawn.to(self.param.dtype)
            # delays are stored after conversion with sample2s
            self.param[1].copy_(self.sample2s(delay_samples))
            longest = torch.ceil(delay_samples).max().item()
            self.order = int(longest) + 1

    def s2sample(self, delay: torch.Tensor):
        r"""Convert a stored delay value to samples: ``delay * fs / unit``."""
        return delay * self.fs / self.unit

    def sample2s(self, delay: torch.Tensor):
        r"""Convert a delay in samples to the stored representation: ``delay / fs * unit``."""
        return delay / self.fs * self.unit

    def check_input_shape(self, x):
        r"""
        Raise ``ValueError`` unless dimensions 1 and 2 of ``x`` equal
        ``(nfft / 2 + 1, input_channels)``.
        """
        expected = (int(self.nfft / 2 + 1), self.input_channels)
        received = (x.shape[1], x.shape[2])
        if expected != received:
            raise ValueError(
                f"parameter shape = {self.param.shape} not compatible with input signal of shape = ({x.shape})."
            )

    def check_param_shape(self):
        r"""Assert that ``self.size`` is 3D with a leading dimension of 2."""
        assert (
            len(self.size) == 3 and self.size[0] == 2
        ), "GainDelay parameters must have shape (2, N_out, N_in)."

    def get_gains(self):
        r"""Return a callable mapping the parameters to the complex-valued mapped gains."""
        return lambda param: to_complex(self.map_gain(param[0]))

    def get_delays(self):
        r"""Return a callable mapping the parameters to the mapped delays in samples."""
        return lambda param: self.s2sample(self.map_delay(param[1]))

    def get_freq_response(self):
        r"""
        Store in ``self.freq_response`` the callable that maps parameters to the
        frequency response. The ``isint`` flag is read once, here: when it is set the
        delays are rounded before being combined with the gains.
        """
        gain_of = self.get_gains()
        delay_of = self.get_delays()
        if self.isint:

            def response(param):
                return self._combine_gain_delay(gain_of(param), delay_of(param).round())

        else:

            def response(param):
                return self._combine_gain_delay(gain_of(param), delay_of(param))

        self.freq_response = response

    def get_freq_convolve(self):
        r"""
        Store in ``self.freq_convolve`` the frequency-domain matrix product between the
        frequency response and the input.
        """
        self.freq_convolve = lambda x, param: torch.einsum(
            "fmn,bfn...->bfm...", self.freq_response(param), x
        )

    def initialize_class(self):
        r"""
        Validate the parameter shape, derive the channel counts, build the normalized
        angular frequency axis ``self.omega``, and create the response/convolution
        callables.
        """
        self.check_param_shape()
        self.get_io()
        # Trainable delays without a user mapping are passed through softplus.
        if self.requires_grad and not self._custom_delay_map:
            self.map_delay = lambda raw: F.softplus(raw)
        bins = torch.arange(
            0, self.nfft // 2 + 1, device=self.device, dtype=self.dtype
        )
        # column vector of 2*pi*k/nfft for k = 0 .. nfft//2
        self.omega = (2 * torch.pi * bins / self.nfft).unsqueeze(1)
        self.get_freq_response()
        self.get_freq_convolve()

    def get_io(self):
        r"""Computes the number of input and output channels based on the size parameter."""
        self.input_channels = self.size[-1]
        self.output_channels = self.size[-2]

    def _combine_gain_delay(self, gain: torch.Tensor, delay_samples: torch.Tensor):
        r"""
        Combine gains and delays (in samples) into the frequency response
        ``gain * gamma**delay * exp(-1j * omega * delay)``, with the frequency axis first.
        """
        delay_samples = delay_samples.to(gain.real.dtype)
        # omega has a trailing singleton axis that is contracted against the leading
        # singleton axis added to the delays, leaving (frequency, out, in).
        phase = torch.einsum("fo, omn -> fmn", self.omega, delay_samples.unsqueeze(0))
        # NOTE(review): self.gamma is not assigned in this class; presumably it is the
        # anti-aliasing decay factor provided by DSP - confirm.
        decay = self.gamma ** delay_samples
        rotation = torch.exp(-1j * phase)
        return gain.unsqueeze(0) * decay * rotation
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: flamo
|
|
3
|
-
Version: 0.2.3
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: An Open-Source Library for Frequency-Domain Differentiable Audio Processing
|
|
5
5
|
Project-URL: Homepage, https://github.com/gdalsanto/flamo
|
|
6
6
|
Project-URL: Issues, https://github.com/gdalsanto/flamo/issues
|
|
@@ -10,15 +10,15 @@ flamo/auxiliary/scattering.py,sha256=IV9ZgkMfgEWtQo3IeLJviAXAdjrYI_Tob-H_8Zv8oBA
|
|
|
10
10
|
flamo/auxiliary/velvet.py,sha256=B4pYEnhaQPkh02pxqiGdAhLRX2g-eWtHezphi0_h4Qs,4201
|
|
11
11
|
flamo/auxiliary/config/config.py,sha256=sZ3XvqwV6KiIc2n8HRtg7YJE3zhc7Vqblbqs-Z0bsKg,2978
|
|
12
12
|
flamo/optimize/__init__.py,sha256=grgxLmQ7m-c9MvRdIejmEAaaajfBwgeaZAv2qjHIvPw,65
|
|
13
|
-
flamo/optimize/dataset.py,sha256=
|
|
13
|
+
flamo/optimize/dataset.py,sha256=WPvWDhT-U-gFkPaP1UzvFfB2bxlxdDDQ64zQ2-OcbYY,6789
|
|
14
14
|
flamo/optimize/loss.py,sha256=h6EeqjdX5P1SqDBKBavSxV25VBgnYK8tuX91wk6lw_g,33466
|
|
15
15
|
flamo/optimize/surface.py,sha256=sWy1ImwxUh_QLoY6S68LXBa82_HdWJGplFg2ObtpNGc,26655
|
|
16
16
|
flamo/optimize/trainer.py,sha256=he4nUjLC-3RTlxxBIw33r5k8mQfgAGvN1wpPBAWCjVo,12045
|
|
17
17
|
flamo/optimize/utils.py,sha256=R5-KoZagRho3eykY88pC3UB2mc5SsE4Yv9X-ogskXdA,1610
|
|
18
18
|
flamo/processor/__init__.py,sha256=paGdxGVZgA2VAs0tBwRd0bobzGxeyK79DS7ZGO8drkI,41
|
|
19
|
-
flamo/processor/dsp.py,sha256=
|
|
19
|
+
flamo/processor/dsp.py,sha256=Znp_9qjHRb7V0DBaqPWxa9oOlU84CVTjy6h3DcAh-TU,144811
|
|
20
20
|
flamo/processor/system.py,sha256=Hct-o6IgF5NQ2xYbX-1j3st94hMoM8dOgAzle2gjDqU,43145
|
|
21
|
-
flamo-0.2.
|
|
22
|
-
flamo-0.2.
|
|
23
|
-
flamo-0.2.
|
|
24
|
-
flamo-0.2.
|
|
21
|
+
flamo-0.2.5.dist-info/METADATA,sha256=ilg6hr1DaZgWg2qq783vj0vd8VNE1DQVsMb3-kKDt24,7825
|
|
22
|
+
flamo-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
23
|
+
flamo-0.2.5.dist-info/licenses/LICENSE,sha256=smMocRH7xdPT5RvFNqSLtbSNzohXJM5G_rX1Qaej6vg,1120
|
|
24
|
+
flamo-0.2.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|