xinference 0.14.4.post1__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (194)
  1. xinference/_compat.py +51 -0
  2. xinference/_version.py +3 -3
  3. xinference/api/restful_api.py +209 -40
  4. xinference/client/restful/restful_client.py +7 -26
  5. xinference/conftest.py +1 -1
  6. xinference/constants.py +5 -0
  7. xinference/core/cache_tracker.py +1 -1
  8. xinference/core/chat_interface.py +8 -14
  9. xinference/core/event.py +1 -1
  10. xinference/core/image_interface.py +28 -0
  11. xinference/core/model.py +110 -31
  12. xinference/core/scheduler.py +37 -37
  13. xinference/core/status_guard.py +1 -1
  14. xinference/core/supervisor.py +17 -10
  15. xinference/core/utils.py +80 -22
  16. xinference/core/worker.py +17 -16
  17. xinference/deploy/cmdline.py +8 -16
  18. xinference/deploy/local.py +1 -1
  19. xinference/deploy/supervisor.py +1 -1
  20. xinference/deploy/utils.py +1 -1
  21. xinference/deploy/worker.py +1 -1
  22. xinference/model/audio/cosyvoice.py +86 -41
  23. xinference/model/audio/fish_speech.py +9 -9
  24. xinference/model/audio/model_spec.json +9 -9
  25. xinference/model/audio/whisper.py +4 -1
  26. xinference/model/embedding/core.py +52 -31
  27. xinference/model/image/core.py +2 -1
  28. xinference/model/image/model_spec.json +16 -4
  29. xinference/model/image/model_spec_modelscope.json +16 -4
  30. xinference/model/image/sdapi.py +136 -0
  31. xinference/model/image/stable_diffusion/core.py +164 -19
  32. xinference/model/llm/__init__.py +29 -11
  33. xinference/model/llm/llama_cpp/core.py +16 -33
  34. xinference/model/llm/llm_family.json +1011 -1296
  35. xinference/model/llm/llm_family.py +34 -53
  36. xinference/model/llm/llm_family_csghub.json +18 -35
  37. xinference/model/llm/llm_family_modelscope.json +981 -1122
  38. xinference/model/llm/lmdeploy/core.py +56 -88
  39. xinference/model/llm/mlx/core.py +46 -69
  40. xinference/model/llm/sglang/core.py +36 -18
  41. xinference/model/llm/transformers/chatglm.py +168 -306
  42. xinference/model/llm/transformers/cogvlm2.py +36 -63
  43. xinference/model/llm/transformers/cogvlm2_video.py +33 -223
  44. xinference/model/llm/transformers/core.py +55 -50
  45. xinference/model/llm/transformers/deepseek_v2.py +340 -0
  46. xinference/model/llm/transformers/deepseek_vl.py +53 -96
  47. xinference/model/llm/transformers/glm4v.py +55 -111
  48. xinference/model/llm/transformers/intern_vl.py +39 -70
  49. xinference/model/llm/transformers/internlm2.py +32 -54
  50. xinference/model/llm/transformers/minicpmv25.py +22 -55
  51. xinference/model/llm/transformers/minicpmv26.py +158 -68
  52. xinference/model/llm/transformers/omnilmm.py +5 -28
  53. xinference/model/llm/transformers/qwen2_audio.py +168 -0
  54. xinference/model/llm/transformers/qwen2_vl.py +234 -0
  55. xinference/model/llm/transformers/qwen_vl.py +34 -86
  56. xinference/model/llm/transformers/utils.py +32 -38
  57. xinference/model/llm/transformers/yi_vl.py +32 -72
  58. xinference/model/llm/utils.py +280 -554
  59. xinference/model/llm/vllm/core.py +161 -100
  60. xinference/model/rerank/core.py +41 -8
  61. xinference/model/rerank/model_spec.json +7 -0
  62. xinference/model/rerank/model_spec_modelscope.json +7 -1
  63. xinference/model/utils.py +1 -31
  64. xinference/thirdparty/cosyvoice/bin/export_jit.py +64 -0
  65. xinference/thirdparty/cosyvoice/bin/export_trt.py +8 -0
  66. xinference/thirdparty/cosyvoice/bin/inference.py +5 -2
  67. xinference/thirdparty/cosyvoice/cli/cosyvoice.py +38 -22
  68. xinference/thirdparty/cosyvoice/cli/model.py +139 -26
  69. xinference/thirdparty/cosyvoice/flow/flow.py +15 -9
  70. xinference/thirdparty/cosyvoice/flow/length_regulator.py +20 -1
  71. xinference/thirdparty/cosyvoice/hifigan/generator.py +8 -4
  72. xinference/thirdparty/cosyvoice/llm/llm.py +14 -13
  73. xinference/thirdparty/cosyvoice/transformer/attention.py +7 -3
  74. xinference/thirdparty/cosyvoice/transformer/decoder.py +1 -1
  75. xinference/thirdparty/cosyvoice/transformer/embedding.py +4 -3
  76. xinference/thirdparty/cosyvoice/transformer/encoder.py +4 -2
  77. xinference/thirdparty/cosyvoice/utils/common.py +36 -0
  78. xinference/thirdparty/cosyvoice/utils/file_utils.py +16 -0
  79. xinference/thirdparty/deepseek_vl/serve/assets/Kelpy-Codos.js +100 -0
  80. xinference/thirdparty/deepseek_vl/serve/assets/avatar.png +0 -0
  81. xinference/thirdparty/deepseek_vl/serve/assets/custom.css +355 -0
  82. xinference/thirdparty/deepseek_vl/serve/assets/custom.js +22 -0
  83. xinference/thirdparty/deepseek_vl/serve/assets/favicon.ico +0 -0
  84. xinference/thirdparty/deepseek_vl/serve/examples/app.png +0 -0
  85. xinference/thirdparty/deepseek_vl/serve/examples/chart.png +0 -0
  86. xinference/thirdparty/deepseek_vl/serve/examples/mirror.png +0 -0
  87. xinference/thirdparty/deepseek_vl/serve/examples/pipeline.png +0 -0
  88. xinference/thirdparty/deepseek_vl/serve/examples/puzzle.png +0 -0
  89. xinference/thirdparty/deepseek_vl/serve/examples/rap.jpeg +0 -0
  90. xinference/thirdparty/fish_speech/fish_speech/configs/base.yaml +87 -0
  91. xinference/thirdparty/fish_speech/fish_speech/configs/firefly_gan_vq.yaml +33 -0
  92. xinference/thirdparty/fish_speech/fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  93. xinference/thirdparty/fish_speech/fish_speech/configs/text2semantic_finetune.yaml +83 -0
  94. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text-data.proto +24 -0
  95. xinference/thirdparty/fish_speech/fish_speech/i18n/README.md +27 -0
  96. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +1 -1
  97. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +1 -1
  98. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +1 -1
  99. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +1 -1
  100. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +1 -1
  101. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +2 -2
  102. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +0 -3
  103. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +169 -198
  104. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +4 -27
  105. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/.gitignore +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/README.md +36 -0
  107. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +9 -47
  108. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +2 -2
  109. xinference/thirdparty/fish_speech/fish_speech/train.py +2 -0
  110. xinference/thirdparty/fish_speech/fish_speech/webui/css/style.css +161 -0
  111. xinference/thirdparty/fish_speech/fish_speech/webui/html/footer.html +11 -0
  112. xinference/thirdparty/fish_speech/fish_speech/webui/js/animate.js +69 -0
  113. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +12 -10
  114. xinference/thirdparty/fish_speech/tools/api.py +79 -134
  115. xinference/thirdparty/fish_speech/tools/commons.py +35 -0
  116. xinference/thirdparty/fish_speech/tools/download_models.py +3 -3
  117. xinference/thirdparty/fish_speech/tools/file.py +17 -0
  118. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +1 -1
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +29 -24
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +1 -1
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +2 -2
  122. xinference/thirdparty/fish_speech/tools/msgpack_api.py +34 -0
  123. xinference/thirdparty/fish_speech/tools/post_api.py +85 -44
  124. xinference/thirdparty/fish_speech/tools/sensevoice/README.md +59 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +1 -1
  126. xinference/thirdparty/fish_speech/tools/smart_pad.py +16 -3
  127. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +2 -2
  128. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +4 -2
  129. xinference/thirdparty/fish_speech/tools/webui.py +12 -146
  130. xinference/thirdparty/matcha/VERSION +1 -0
  131. xinference/thirdparty/matcha/hifigan/LICENSE +21 -0
  132. xinference/thirdparty/matcha/hifigan/README.md +101 -0
  133. xinference/thirdparty/omnilmm/LICENSE +201 -0
  134. xinference/thirdparty/whisper/__init__.py +156 -0
  135. xinference/thirdparty/whisper/__main__.py +3 -0
  136. xinference/thirdparty/whisper/assets/gpt2.tiktoken +50256 -0
  137. xinference/thirdparty/whisper/assets/mel_filters.npz +0 -0
  138. xinference/thirdparty/whisper/assets/multilingual.tiktoken +50257 -0
  139. xinference/thirdparty/whisper/audio.py +157 -0
  140. xinference/thirdparty/whisper/decoding.py +826 -0
  141. xinference/thirdparty/whisper/model.py +314 -0
  142. xinference/thirdparty/whisper/normalizers/__init__.py +2 -0
  143. xinference/thirdparty/whisper/normalizers/basic.py +76 -0
  144. xinference/thirdparty/whisper/normalizers/english.json +1741 -0
  145. xinference/thirdparty/whisper/normalizers/english.py +550 -0
  146. xinference/thirdparty/whisper/timing.py +386 -0
  147. xinference/thirdparty/whisper/tokenizer.py +395 -0
  148. xinference/thirdparty/whisper/transcribe.py +605 -0
  149. xinference/thirdparty/whisper/triton_ops.py +109 -0
  150. xinference/thirdparty/whisper/utils.py +316 -0
  151. xinference/thirdparty/whisper/version.py +1 -0
  152. xinference/types.py +14 -53
  153. xinference/web/ui/build/asset-manifest.json +6 -6
  154. xinference/web/ui/build/index.html +1 -1
  155. xinference/web/ui/build/static/css/{main.4bafd904.css → main.5061c4c3.css} +2 -2
  156. xinference/web/ui/build/static/css/main.5061c4c3.css.map +1 -0
  157. xinference/web/ui/build/static/js/main.754740c0.js +3 -0
  158. xinference/web/ui/build/static/js/{main.eb13fe95.js.LICENSE.txt → main.754740c0.js.LICENSE.txt} +2 -0
  159. xinference/web/ui/build/static/js/main.754740c0.js.map +1 -0
  160. xinference/web/ui/node_modules/.cache/babel-loader/10c69dc7a296779fcffedeff9393d832dfcb0013c36824adf623d3c518b801ff.json +1 -0
  161. xinference/web/ui/node_modules/.cache/babel-loader/68bede6d95bb5ef0b35bbb3ec5b8c937eaf6862c6cdbddb5ef222a7776aaf336.json +1 -0
  162. xinference/web/ui/node_modules/.cache/babel-loader/77d50223f3e734d4485cca538cb098a8c3a7a0a1a9f01f58cdda3af42fe1adf5.json +1 -0
  163. xinference/web/ui/node_modules/.cache/babel-loader/a56d5a642409a84988891089c98ca28ad0546432dfbae8aaa51bc5a280e1cdd2.json +1 -0
  164. xinference/web/ui/node_modules/.cache/babel-loader/cd90b08d177025dfe84209596fc51878f8a86bcaa6a240848a3d2e5fd4c7ff24.json +1 -0
  165. xinference/web/ui/node_modules/.cache/babel-loader/d9ff696a3e3471f01b46c63d18af32e491eb5dc0e43cb30202c96871466df57f.json +1 -0
  166. xinference/web/ui/node_modules/.cache/babel-loader/e42b72d4cc1ea412ebecbb8d040dc6c6bfee462c33903c2f1f3facb602ad742e.json +1 -0
  167. xinference/web/ui/node_modules/.cache/babel-loader/f5039ddbeb815c51491a1989532006b96fc3ae49c6c60e3c097f875b4ae915ae.json +1 -0
  168. xinference/web/ui/node_modules/.package-lock.json +37 -0
  169. xinference/web/ui/node_modules/a-sync-waterfall/package.json +21 -0
  170. xinference/web/ui/node_modules/nunjucks/node_modules/commander/package.json +48 -0
  171. xinference/web/ui/node_modules/nunjucks/package.json +112 -0
  172. xinference/web/ui/package-lock.json +38 -0
  173. xinference/web/ui/package.json +1 -0
  174. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/METADATA +16 -10
  175. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/RECORD +179 -127
  176. xinference/model/llm/transformers/llama_2.py +0 -108
  177. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +0 -442
  178. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +0 -44
  179. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +0 -115
  180. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +0 -225
  181. xinference/thirdparty/fish_speech/tools/auto_rerank.py +0 -159
  182. xinference/thirdparty/fish_speech/tools/gen_ref.py +0 -36
  183. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +0 -55
  184. xinference/web/ui/build/static/css/main.4bafd904.css.map +0 -1
  185. xinference/web/ui/build/static/js/main.eb13fe95.js +0 -3
  186. xinference/web/ui/build/static/js/main.eb13fe95.js.map +0 -1
  187. xinference/web/ui/node_modules/.cache/babel-loader/0b11a5339468c13b2d31ac085e7effe4303259b2071abd46a0a8eb8529233a5e.json +0 -1
  188. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +0 -1
  189. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +0 -1
  190. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +0 -1
  191. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.4.post1.dist-info → xinference-0.15.1.dist-info}/top_level.txt +0 -0
@@ -1,25 +1,26 @@
1
- # A inference only version of the FireflyGAN model
2
-
3
1
  import math
4
2
  from functools import partial
5
3
  from math import prod
6
4
  from typing import Callable
7
5
 
8
- import numpy as np
9
6
  import torch
10
7
  import torch.nn.functional as F
11
8
  from torch import nn
12
- from torch.nn import Conv1d
13
9
  from torch.nn.utils.parametrizations import weight_norm
14
10
  from torch.nn.utils.parametrize import remove_parametrizations
15
11
  from torch.utils.checkpoint import checkpoint
16
12
 
17
- from fish_speech.models.vqgan.utils import sequence_mask
13
+
14
+ def sequence_mask(length, max_length=None):
15
+ if max_length is None:
16
+ max_length = length.max()
17
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18
+ return x.unsqueeze(0) < length.unsqueeze(1)
18
19
 
19
20
 
20
21
  def init_weights(m, mean=0.0, std=0.01):
21
22
  classname = m.__class__.__name__
22
- if classname.find("Conv") != -1:
23
+ if classname.find("Conv1D") != -1:
23
24
  m.weight.data.normal_(mean, std)
24
25
 
25
26
 
@@ -27,78 +28,141 @@ def get_padding(kernel_size, dilation=1):
27
28
  return (kernel_size * dilation - dilation) // 2
28
29
 
29
30
 
31
+ def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
32
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
33
+ padding_left, padding_right = paddings
34
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
35
+ assert (padding_left + padding_right) <= x.shape[-1]
36
+ end = x.shape[-1] - padding_right
37
+ return x[..., padding_left:end]
38
+
39
+
40
+ def get_extra_padding_for_conv1d(
41
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
42
+ ) -> int:
43
+ """See `pad_for_conv1d`."""
44
+ length = x.shape[-1]
45
+ n_frames = (length - kernel_size + padding_total) / stride + 1
46
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
47
+ return ideal_length - length
48
+
49
+
50
+ def pad1d(
51
+ x: torch.Tensor,
52
+ paddings: tuple[int, int],
53
+ mode: str = "zeros",
54
+ value: float = 0.0,
55
+ ):
56
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
57
+ If this is the case, we insert extra 0 padding to the right
58
+ before the reflection happen.
59
+ """
60
+ length = x.shape[-1]
61
+ padding_left, padding_right = paddings
62
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
63
+ if mode == "reflect":
64
+ max_pad = max(padding_left, padding_right)
65
+ extra_pad = 0
66
+ if length <= max_pad:
67
+ extra_pad = max_pad - length + 1
68
+ x = F.pad(x, (0, extra_pad))
69
+ padded = F.pad(x, paddings, mode, value)
70
+ end = padded.shape[-1] - extra_pad
71
+ return padded[..., :end]
72
+ else:
73
+ return F.pad(x, paddings, mode, value)
74
+
75
+
76
+ class FishConvNet(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
79
+ ):
80
+ super(FishConvNet, self).__init__()
81
+ self.conv = nn.Conv1d(
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size,
85
+ stride=stride,
86
+ dilation=dilation,
87
+ groups=groups,
88
+ )
89
+ self.stride = stride
90
+ self.kernel_size = (kernel_size - 1) * dilation + 1
91
+ self.dilation = dilation
92
+
93
+ def forward(self, x):
94
+ pad = self.kernel_size - self.stride
95
+ extra_padding = get_extra_padding_for_conv1d(
96
+ x, self.kernel_size, self.stride, pad
97
+ )
98
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
99
+ return self.conv(x).contiguous()
100
+
101
+ def weight_norm(self, name="weight", dim=0):
102
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
103
+ return self
104
+
105
+ def remove_weight_norm(self):
106
+ self.conv = remove_parametrizations(self.conv)
107
+ return self
108
+
109
+
110
+ class FishTransConvNet(nn.Module):
111
+ def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
112
+ super(FishTransConvNet, self).__init__()
113
+ self.conv = nn.ConvTranspose1d(
114
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
115
+ )
116
+ self.stride = stride
117
+ self.kernel_size = kernel_size
118
+
119
+ def forward(self, x):
120
+ x = self.conv(x)
121
+ pad = self.kernel_size - self.stride
122
+ padding_right = math.ceil(pad)
123
+ padding_left = pad - padding_right
124
+ x = unpad1d(x, (padding_left, padding_right))
125
+ return x.contiguous()
126
+
127
+ def weight_norm(self, name="weight", dim=0):
128
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
129
+ return self
130
+
131
+ def remove_weight_norm(self):
132
+ self.conv = remove_parametrizations(self.conv)
133
+ return self
134
+
135
+
30
136
  class ResBlock1(torch.nn.Module):
31
137
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
32
138
  super().__init__()
33
139
 
34
140
  self.convs1 = nn.ModuleList(
35
141
  [
36
- weight_norm(
37
- Conv1d(
38
- channels,
39
- channels,
40
- kernel_size,
41
- 1,
42
- dilation=dilation[0],
43
- padding=get_padding(kernel_size, dilation[0]),
44
- )
45
- ),
46
- weight_norm(
47
- Conv1d(
48
- channels,
49
- channels,
50
- kernel_size,
51
- 1,
52
- dilation=dilation[1],
53
- padding=get_padding(kernel_size, dilation[1]),
54
- )
55
- ),
56
- weight_norm(
57
- Conv1d(
58
- channels,
59
- channels,
60
- kernel_size,
61
- 1,
62
- dilation=dilation[2],
63
- padding=get_padding(kernel_size, dilation[2]),
64
- )
65
- ),
142
+ FishConvNet(
143
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
144
+ ).weight_norm(),
145
+ FishConvNet(
146
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
147
+ ).weight_norm(),
148
+ FishConvNet(
149
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
150
+ ).weight_norm(),
66
151
  ]
67
152
  )
68
153
  self.convs1.apply(init_weights)
69
154
 
70
155
  self.convs2 = nn.ModuleList(
71
156
  [
72
- weight_norm(
73
- Conv1d(
74
- channels,
75
- channels,
76
- kernel_size,
77
- 1,
78
- dilation=1,
79
- padding=get_padding(kernel_size, 1),
80
- )
81
- ),
82
- weight_norm(
83
- Conv1d(
84
- channels,
85
- channels,
86
- kernel_size,
87
- 1,
88
- dilation=1,
89
- padding=get_padding(kernel_size, 1),
90
- )
91
- ),
92
- weight_norm(
93
- Conv1d(
94
- channels,
95
- channels,
96
- kernel_size,
97
- 1,
98
- dilation=1,
99
- padding=get_padding(kernel_size, 1),
100
- )
101
- ),
157
+ FishConvNet(
158
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
159
+ ).weight_norm(),
160
+ FishConvNet(
161
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
162
+ ).weight_norm(),
163
+ FishConvNet(
164
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
165
+ ).weight_norm(),
102
166
  ]
103
167
  )
104
168
  self.convs2.apply(init_weights)
@@ -119,7 +183,7 @@ class ResBlock1(torch.nn.Module):
119
183
  remove_parametrizations(conv, tensor_name="weight")
120
184
 
121
185
 
122
- class ParralelBlock(nn.Module):
186
+ class ParallelBlock(nn.Module):
123
187
  def __init__(
124
188
  self,
125
189
  channels: int,
@@ -153,7 +217,6 @@ class HiFiGANGenerator(nn.Module):
153
217
  resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
154
218
  num_mels: int = 128,
155
219
  upsample_initial_channel: int = 512,
156
- use_template: bool = True,
157
220
  pre_conv_kernel_size: int = 7,
158
221
  post_conv_kernel_size: int = 7,
159
222
  post_activation: Callable = partial(nn.SiLU, inplace=True),
@@ -164,85 +227,51 @@ class HiFiGANGenerator(nn.Module):
164
227
  prod(upsample_rates) == hop_length
165
228
  ), f"hop_length must be {prod(upsample_rates)}"
166
229
 
167
- self.conv_pre = weight_norm(
168
- nn.Conv1d(
169
- num_mels,
170
- upsample_initial_channel,
171
- pre_conv_kernel_size,
172
- 1,
173
- padding=get_padding(pre_conv_kernel_size),
174
- )
175
- )
230
+ self.conv_pre = FishConvNet(
231
+ num_mels,
232
+ upsample_initial_channel,
233
+ pre_conv_kernel_size,
234
+ stride=1,
235
+ ).weight_norm()
176
236
 
177
237
  self.num_upsamples = len(upsample_rates)
178
238
  self.num_kernels = len(resblock_kernel_sizes)
179
239
 
180
240
  self.noise_convs = nn.ModuleList()
181
- self.use_template = use_template
182
241
  self.ups = nn.ModuleList()
183
242
 
184
243
  for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
185
- c_cur = upsample_initial_channel // (2 ** (i + 1))
186
244
  self.ups.append(
187
- weight_norm(
188
- nn.ConvTranspose1d(
189
- upsample_initial_channel // (2**i),
190
- upsample_initial_channel // (2 ** (i + 1)),
191
- k,
192
- u,
193
- padding=(k - u) // 2,
194
- )
195
- )
245
+ FishTransConvNet(
246
+ upsample_initial_channel // (2**i),
247
+ upsample_initial_channel // (2 ** (i + 1)),
248
+ k,
249
+ stride=u,
250
+ ).weight_norm()
196
251
  )
197
252
 
198
- if not use_template:
199
- continue
200
-
201
- if i + 1 < len(upsample_rates):
202
- stride_f0 = np.prod(upsample_rates[i + 1 :])
203
- self.noise_convs.append(
204
- Conv1d(
205
- 1,
206
- c_cur,
207
- kernel_size=stride_f0 * 2,
208
- stride=stride_f0,
209
- padding=stride_f0 // 2,
210
- )
211
- )
212
- else:
213
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
214
-
215
253
  self.resblocks = nn.ModuleList()
216
254
  for i in range(len(self.ups)):
217
255
  ch = upsample_initial_channel // (2 ** (i + 1))
218
256
  self.resblocks.append(
219
- ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
257
+ ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
220
258
  )
221
259
 
222
260
  self.activation_post = post_activation()
223
- self.conv_post = weight_norm(
224
- nn.Conv1d(
225
- ch,
226
- 1,
227
- post_conv_kernel_size,
228
- 1,
229
- padding=get_padding(post_conv_kernel_size),
230
- )
231
- )
261
+ self.conv_post = FishConvNet(
262
+ ch, 1, post_conv_kernel_size, stride=1
263
+ ).weight_norm()
232
264
  self.ups.apply(init_weights)
233
265
  self.conv_post.apply(init_weights)
234
266
 
235
- def forward(self, x, template=None):
267
+ def forward(self, x):
236
268
  x = self.conv_pre(x)
237
269
 
238
270
  for i in range(self.num_upsamples):
239
271
  x = F.silu(x, inplace=True)
240
272
  x = self.ups[i](x)
241
273
 
242
- if self.use_template:
243
- x = x + self.noise_convs[i](template)
244
-
245
- if self.training:
274
+ if self.training and self.checkpointing:
246
275
  x = checkpoint(
247
276
  self.resblocks[i],
248
277
  x,
@@ -364,11 +393,11 @@ class ConvNeXtBlock(nn.Module):
364
393
  ):
365
394
  super().__init__()
366
395
 
367
- self.dwconv = nn.Conv1d(
396
+ self.dwconv = FishConvNet(
368
397
  dim,
369
398
  dim,
370
399
  kernel_size=kernel_size,
371
- padding=int(dilation * (kernel_size - 1) / 2),
400
+ # padding=int(dilation * (kernel_size - 1) / 2),
372
401
  groups=dim,
373
402
  ) # depthwise conv
374
403
  self.norm = LayerNorm(dim, eps=1e-6)
@@ -421,12 +450,13 @@ class ConvNeXtEncoder(nn.Module):
421
450
 
422
451
  self.downsample_layers = nn.ModuleList()
423
452
  stem = nn.Sequential(
424
- nn.Conv1d(
453
+ FishConvNet(
425
454
  input_channels,
426
455
  dims[0],
427
- kernel_size=kernel_size,
428
- padding=kernel_size // 2,
429
- padding_mode="zeros",
456
+ kernel_size=7,
457
+ # padding=3,
458
+ # padding_mode="replicate",
459
+ # padding_mode="zeros",
430
460
  ),
431
461
  LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
432
462
  )
@@ -491,6 +521,7 @@ class FireflyArchitecture(nn.Module):
491
521
  self.head = head
492
522
  self.quantizer = quantizer
493
523
  self.spec_transform = spec_transform
524
+ self.downsample_factor = math.prod(self.quantizer.downsample_factor)
494
525
 
495
526
  def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
496
527
  if self.spec_transform is not None:
@@ -512,7 +543,7 @@ class FireflyArchitecture(nn.Module):
512
543
  if x.ndim == 2:
513
544
  x = x[:, None, :]
514
545
 
515
- if self.quantizer is not None:
546
+ if self.vq is not None:
516
547
  return x, vq_result
517
548
 
518
549
  return x
@@ -528,25 +559,30 @@ class FireflyArchitecture(nn.Module):
528
559
 
529
560
  # Encode
530
561
  encoded_features = self.backbone(mels) * mel_masks_float_conv
531
- feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
562
+ feature_lengths = mel_lengths // self.downsample_factor
532
563
 
533
564
  return self.quantizer.encode(encoded_features), feature_lengths
534
565
 
535
566
  def decode(self, indices, feature_lengths) -> torch.Tensor:
536
- factor = math.prod(self.quantizer.downsample_factor)
537
- mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
567
+ mel_masks = sequence_mask(
568
+ feature_lengths * self.downsample_factor,
569
+ indices.shape[2] * self.downsample_factor,
570
+ )
538
571
  mel_masks_float_conv = mel_masks[:, None, :].float()
572
+ audio_lengths = (
573
+ feature_lengths * self.downsample_factor * self.spec_transform.hop_length
574
+ )
539
575
 
540
576
  audio_masks = sequence_mask(
541
- feature_lengths * factor * self.spec_transform.hop_length,
542
- indices.shape[2] * factor * self.spec_transform.hop_length,
577
+ audio_lengths,
578
+ indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
543
579
  )
544
580
  audio_masks_float_conv = audio_masks[:, None, :].float()
545
581
 
546
582
  z = self.quantizer.decode(indices) * mel_masks_float_conv
547
583
  x = self.head(z) * audio_masks_float_conv
548
584
 
549
- return x
585
+ return x, audio_lengths
550
586
 
551
587
  def remove_parametrizations(self):
552
588
  if hasattr(self.backbone, "remove_parametrizations"):
@@ -558,68 +594,3 @@ class FireflyArchitecture(nn.Module):
558
594
  @property
559
595
  def device(self):
560
596
  return next(self.parameters()).device
561
-
562
-
563
- class FireflyBase(nn.Module):
564
- def __init__(self, ckpt_path: str = None, pretrained: bool = True):
565
- super().__init__()
566
-
567
- self.backbone = ConvNeXtEncoder(
568
- input_channels=128,
569
- depths=[3, 3, 9, 3],
570
- dims=[128, 256, 384, 512],
571
- drop_path_rate=0.2,
572
- kernel_size=7,
573
- )
574
-
575
- self.head = HiFiGANGenerator(
576
- hop_length=512,
577
- upsample_rates=[8, 8, 2, 2, 2],
578
- upsample_kernel_sizes=[16, 16, 4, 4, 4],
579
- resblock_kernel_sizes=[3, 7, 11],
580
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
581
- num_mels=512,
582
- upsample_initial_channel=512,
583
- use_template=False,
584
- pre_conv_kernel_size=13,
585
- post_conv_kernel_size=13,
586
- )
587
-
588
- if ckpt_path is not None:
589
- state_dict = torch.load(ckpt_path, map_location="cpu")
590
- elif pretrained:
591
- state_dict = torch.hub.load_state_dict_from_url(
592
- "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
593
- map_location="cpu",
594
- model_dir="checkpoints",
595
- )
596
-
597
- if "state_dict" in state_dict:
598
- state_dict = state_dict["state_dict"]
599
-
600
- if any("generator." in k for k in state_dict):
601
- state_dict = {
602
- k.replace("generator.", ""): v
603
- for k, v in state_dict.items()
604
- if "generator." in k
605
- }
606
-
607
- self.load_state_dict(state_dict, strict=True)
608
- self.head.remove_parametrizations()
609
-
610
- @torch.no_grad()
611
- def forward(self, x: torch.Tensor) -> torch.Tensor:
612
- x = self.backbone(x)
613
- x = self.head(x)
614
- if x.ndim == 2:
615
- x = x[:, None, :]
616
- return x
617
-
618
-
619
- if __name__ == "__main__":
620
- model = FireflyBase()
621
- model.eval()
622
- x = torch.randn(1, 128, 128)
623
- with torch.no_grad():
624
- y = model(x)
625
- print(y.shape)
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
6
  from einops import rearrange
7
7
  from vector_quantize_pytorch import GroupedResidualFSQ
8
8
 
9
- from .firefly import ConvNeXtBlock
9
+ from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
10
10
 
11
11
 
12
12
  @dataclass
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
20
20
  def __init__(
21
21
  self,
22
22
  input_dim: int = 512,
23
- n_codebooks: int = 1,
23
+ n_codebooks: int = 9,
24
24
  n_groups: int = 1,
25
25
  levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
26
  downsample_factor: tuple[int] = (2, 2),
@@ -46,7 +46,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
46
46
  self.downsample = nn.Sequential(
47
47
  *[
48
48
  nn.Sequential(
49
- nn.Conv1d(
49
+ FishConvNet(
50
50
  all_dims[idx],
51
51
  all_dims[idx + 1],
52
52
  kernel_size=factor,
@@ -61,7 +61,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
61
61
  self.upsample = nn.Sequential(
62
62
  *[
63
63
  nn.Sequential(
64
- nn.ConvTranspose1d(
64
+ FishTransConvNet(
65
65
  all_dims[idx + 1],
66
66
  all_dims[idx],
67
67
  kernel_size=factor,
@@ -114,26 +114,3 @@ class DownsampleFiniteScalarQuantize(nn.Module):
114
114
  z_q = self.residual_fsq.get_output_from_indices(indices)
115
115
  z_q = self.upsample(z_q.mT)
116
116
  return z_q
117
-
118
- # def from_latents(self, latents: torch.Tensor):
119
- # z_q, z_p, codes = super().from_latents(latents)
120
- # z_q = self.upsample(z_q)
121
- # return z_q, z_p, codes
122
-
123
-
124
- if __name__ == "__main__":
125
- rvq = DownsampleFiniteScalarQuantize(
126
- n_codebooks=1,
127
- downsample_factor=(2, 2),
128
- )
129
- x = torch.randn(16, 512, 80)
130
-
131
- result = rvq(x)
132
- print(rvq)
133
- print(result.latents.shape, result.codes.shape, result.z.shape)
134
-
135
- # y = rvq.from_codes(result.codes)
136
- # print(y[0].shape)
137
-
138
- # y = rvq.from_latents(result.latents)
139
- # print(y[0].shape)
@@ -0,0 +1,114 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ # JetBrains PyCharm
107
+ .idea
108
+
109
+ # Customize
110
+ references
111
+ url.txt
112
+
113
+ # Git
114
+ .git