lt-tensor 0.0.1a35__py3-none-any.whl → 0.0.1a37__py3-none-any.whl

This diff shows the changes between publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -726,3 +726,235 @@ class MultiResolutionDiscriminator(_MultiDiscriminatorT):
             y_d_gs.append(y_d_g)
             fmap_gs.append(fmap_g)
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorCQT(ConvNets):
+    """Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license."""
+
+    def __init__(
+        self,
+        hop_length: int,
+        n_octaves: int,
+        bins_per_octave: int,
+        sampling_rate: int,
+        cqtd_filters: int = 128,
+        cqtd_max_filters: int = 1024,
+        cqtd_filters_scale: int = 1,
+        cqtd_dilations: list = [1, 2, 4],
+        cqtd_in_channels: int = 1,
+        cqtd_out_channels: int = 1,
+        cqtd_normalize_volume: bool = False,
+    ):
+        super().__init__()
+        self.filters = cqtd_filters
+        self.max_filters = cqtd_max_filters
+        self.filters_scale = cqtd_filters_scale
+        self.kernel_size = (3, 9)
+        self.dilations = cqtd_dilations
+        self.stride = (1, 2)
+
+        self.fs = sampling_rate
+        self.in_channels = cqtd_in_channels
+        self.out_channels = cqtd_out_channels
+        self.hop_length = hop_length
+        self.n_octaves = n_octaves
+        self.bins_per_octave = bins_per_octave
+
+        # Lazy-load
+        from lt_tensor.model_zoo.losses.CQT.transforms import CQT2010v2
+
+        self.cqt_transform = CQT2010v2(
+            sr=self.fs * 2,
+            hop_length=self.hop_length,
+            n_bins=self.bins_per_octave * self.n_octaves,
+            bins_per_octave=self.bins_per_octave,
+            output_format="Complex",
+            pad_mode="constant",
+        )
+
+        self.conv_pres = nn.ModuleList()
+        for _ in range(self.n_octaves):
+            self.conv_pres.append(
+                nn.Conv2d(
+                    self.in_channels * 2,
+                    self.in_channels * 2,
+                    kernel_size=self.kernel_size,
+                    padding=self.get_2d_padding(self.kernel_size),
+                )
+            )
+
+        self.convs = nn.ModuleList()
+
+        self.convs.append(
+            nn.Conv2d(
+                self.in_channels * 2,
+                self.filters,
+                kernel_size=self.kernel_size,
+                padding=self.get_2d_padding(self.kernel_size),
+            )
+        )
+
+        in_chs = min(self.filters_scale * self.filters, self.max_filters)
+        for i, dilation in enumerate(self.dilations):
+            out_chs = min(
+                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+            )
+            self.convs.append(
+                weight_norm(
+                    nn.Conv2d(
+                        in_chs,
+                        out_chs,
+                        kernel_size=self.kernel_size,
+                        stride=self.stride,
+                        dilation=(dilation, 1),
+                        padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
+                    )
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+            self.max_filters,
+        )
+        self.convs.append(
+            weight_norm(
+                nn.Conv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                    padding=self.get_2d_padding(
+                        (self.kernel_size[0], self.kernel_size[0])
+                    ),
+                )
+            )
+        )
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(
+                out_chs,
+                self.out_channels,
+                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+            )
+        )
+
+        self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
+        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+        self.cqtd_normalize_volume = cqtd_normalize_volume
+        if self.cqtd_normalize_volume:
+            print(
+                "[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
+            )
+
+    def get_2d_padding(
+        self,
+        kernel_size: Tuple[int, int],
+        dilation: Tuple[int, int] = (1, 1),
+    ):
+        return (
+            ((kernel_size[0] - 1) * dilation[0]) // 2,
+            ((kernel_size[1] - 1) * dilation[1]) // 2,
+        )
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        fmap = []
+
+        if self.cqtd_normalize_volume:
+            # Remove DC offset
+            x = x - x.mean(dim=-1, keepdims=True)
+            # Peak normalize the volume of input audio
+            x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+
+        x = self.resample(x)
+
+        z = self.cqt_transform(x)
+
+        z_amplitude = z[:, :, :, 0].unsqueeze(1)
+        z_phase = z[:, :, :, 1].unsqueeze(1)
+
+        z = torch.cat([z_amplitude, z_phase], dim=1)
+        z = torch.permute(z, (0, 1, 3, 2))  # [B, C, W, T] -> [B, C, T, W]
+
+        latent_z = []
+        for i in range(self.n_octaves):
+            latent_z.append(
+                self.conv_pres[i](
+                    z[
+                        :,
+                        :,
+                        :,
+                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+                    ]
+                )
+            )
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, l in enumerate(self.convs):
+            latent_z = l(latent_z)
+
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+
+        return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(_MultiDiscriminatorT):
+    def __init__(
+        self,
+        sampling_rate: int,
+        cqtd_filters: int = 128,
+        cqtd_max_filters: int = 1024,
+        cqtd_filters_scale: Number = 1,
+        cqtd_dilations: list = [1, 2, 4],
+        cqtd_hop_lengths: list = [512, 256, 256],
+        cqtd_n_octaves: list = [9, 9, 9],
+        cqtd_bins_per_octaves: list = [24, 36, 48],
+        cqtd_in_channels: int = 1,
+        cqtd_out_channels: int = 1,
+        cqtd_normalize_volume: bool = False,
+    ):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorCQT(
+                    hop_length=cqtd_hop_lengths[i],
+                    n_octaves=cqtd_n_octaves[i],
+                    bins_per_octave=cqtd_bins_per_octaves[i],
+                    sampling_rate=sampling_rate,
+                    cqtd_filters=cqtd_filters,
+                    cqtd_max_filters=cqtd_max_filters,
+                    cqtd_filters_scale=cqtd_filters_scale,
+                    cqtd_dilations=cqtd_dilations,
+                    cqtd_in_channels=cqtd_in_channels,
+                    cqtd_out_channels=cqtd_out_channels,
+                    cqtd_normalize_volume=cqtd_normalize_volume,
+                )
+                for i in range(len(cqtd_hop_lengths))
+            ]
+        )
+
+    def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
+        List[torch.Tensor],
+        List[torch.Tensor],
+        List[List[torch.Tensor]],
+        List[List[torch.Tensor]],
+    ]:
+
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(y)
+            y_d_g, fmap_g = disc(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
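
A note on the new get_2d_padding helper: it computes "same"-style padding for a dilated kernel, ((k - 1) * d) // 2 per axis, so each convolution above preserves its spatial extent wherever the stride is 1. A quick check of the values it yields for the fixed (3, 9) kernel, where dilation is applied along the first (time) axis; plain Python, nothing assumed beyond the code in this diff:

def get_2d_padding(kernel_size, dilation=(1, 1)):
    return (
        ((kernel_size[0] - 1) * dilation[0]) // 2,
        ((kernel_size[1] - 1) * dilation[1]) // 2,
    )

for d in (1, 2, 4):  # the default cqtd_dilations
    print(get_2d_padding((3, 9), (d, 1)))  # -> (1, 4), (2, 4), (4, 4)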
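
For orientation, a minimal usage sketch of the added discriminator pair. The import path and input shape are assumptions for illustration only: the diff shows the module body but not its name, and the waveform layout is inferred from the CQT front end.

import torch
# Hypothetical import path -- not confirmed by this diff.
from lt_tensor.model_zoo import MultiScaleSubbandCQTDiscriminator

disc = MultiScaleSubbandCQTDiscriminator(sampling_rate=24000)

# Assumed [B, 1, T] waveforms (nnAudio-style CQT transforms also accept [B, T]).
real = torch.randn(2, 1, 24000)
fake = torch.randn(2, 1, 24000)

# One logit tensor and one feature-map list per resolution; three by default
# (hop lengths 512/256/256 with 24/36/48 bins per octave).
y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(real, fake)
assert len(y_d_rs) == len(disc.discriminators) == 3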