xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.
Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0
xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py
@@ -0,0 +1,625 @@
+# An inference-only version of the FireflyGAN model
+
+import math
+from functools import partial
+from math import prod
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+from torch.utils.checkpoint import checkpoint
+
+from fish_speech.models.vqgan.utils import sequence_mask
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super().__init__()
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.silu(x)
+            xt = c1(xt)
+            xt = F.silu(xt)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_parametrizations(self):
+        for conv in self.convs1:
+            remove_parametrizations(conv, tensor_name="weight")
+        for conv in self.convs2:
+            remove_parametrizations(conv, tensor_name="weight")
+
+
+class ParralelBlock(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        kernel_sizes: tuple[int] = (3, 7, 11),
+        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    ):
+        super().__init__()
+
+        assert len(kernel_sizes) == len(dilation_sizes)
+
+        self.blocks = nn.ModuleList()
+        for k, d in zip(kernel_sizes, dilation_sizes):
+            self.blocks.append(ResBlock1(channels, k, d))
+
+    def forward(self, x):
+        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
+
+    def remove_parametrizations(self):
+        for block in self.blocks:
+            block.remove_parametrizations()
+
+
+class HiFiGANGenerator(nn.Module):
+    def __init__(
+        self,
+        *,
+        hop_length: int = 512,
+        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        num_mels: int = 128,
+        upsample_initial_channel: int = 512,
+        use_template: bool = True,
+        pre_conv_kernel_size: int = 7,
+        post_conv_kernel_size: int = 7,
+        post_activation: Callable = partial(nn.SiLU, inplace=True),
+    ):
+        super().__init__()
+
+        assert (
+            prod(upsample_rates) == hop_length
+        ), f"hop_length must be {prod(upsample_rates)}"
+
+        self.conv_pre = weight_norm(
+            nn.Conv1d(
+                num_mels,
+                upsample_initial_channel,
+                pre_conv_kernel_size,
+                1,
+                padding=get_padding(pre_conv_kernel_size),
+            )
+        )
+
+        self.num_upsamples = len(upsample_rates)
+        self.num_kernels = len(resblock_kernel_sizes)
+
+        self.noise_convs = nn.ModuleList()
+        self.use_template = use_template
+        self.ups = nn.ModuleList()
+
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            c_cur = upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+            if not use_template:
+                continue
+
+            if i + 1 < len(upsample_rates):
+                stride_f0 = np.prod(upsample_rates[i + 1 :])
+                self.noise_convs.append(
+                    Conv1d(
+                        1,
+                        c_cur,
+                        kernel_size=stride_f0 * 2,
+                        stride=stride_f0,
+                        padding=stride_f0 // 2,
+                    )
+                )
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.resblocks.append(
+                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+            )
+
+        self.activation_post = post_activation()
+        self.conv_post = weight_norm(
+            nn.Conv1d(
+                ch,
+                1,
+                post_conv_kernel_size,
+                1,
+                padding=get_padding(post_conv_kernel_size),
+            )
+        )
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x, template=None):
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.silu(x, inplace=True)
+            x = self.ups[i](x)
+
+            if self.use_template:
+                x = x + self.noise_convs[i](template)
+
+            if self.training:
+                x = checkpoint(
+                    self.resblocks[i],
+                    x,
+                    use_reentrant=False,
+                )
+            else:
+                x = self.resblocks[i](x)
+
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_parametrizations(self):
+        for up in self.ups:
+            remove_parametrizations(up, tensor_name="weight")
+        for block in self.resblocks:
+            block.remove_parametrizations()
+        remove_parametrizations(self.conv_pre, tensor_name="weight")
+        remove_parametrizations(self.conv_post, tensor_name="weight")
+
+
+# DropPath copied from timm library
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """  # noqa: E501
+
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""  # noqa: E501
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """  # noqa: E501
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None] * x + self.bias[:, None]
+            return x
+
+
+# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
+class ConvNeXtBlock(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+        kernel_size (int): Kernel size for depthwise conv. Default: 7.
+        dilation (int): Dilation for depthwise conv. Default: 1.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim: int,
+        drop_path: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        mlp_ratio: float = 4.0,
+        kernel_size: int = 7,
+        dilation: int = 1,
+    ):
+        super().__init__()
+
+        self.dwconv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=int(dilation * (kernel_size - 1) / 2),
+            groups=dim,
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, int(mlp_ratio * dim)
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x, apply_residual: bool = True):
+        input = x
+
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        if self.gamma is not None:
+            x = self.gamma * x
+
+        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
+        x = self.drop_path(x)
+
+        if apply_residual:
+            x = input + x
+
+        return x
+
+
+class ConvNeXtEncoder(nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 3,
+        depths: list[int] = [3, 3, 9, 3],
+        dims: list[int] = [96, 192, 384, 768],
+        drop_path_rate: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        kernel_size: int = 7,
+    ):
+        super().__init__()
+        assert len(depths) == len(dims)
+
+        self.downsample_layers = nn.ModuleList()
+        stem = nn.Sequential(
+            nn.Conv1d(
+                input_channels,
+                dims[0],
+                kernel_size=kernel_size,
+                padding=kernel_size // 2,
+                padding_mode="zeros",
+            ),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.downsample_layers.append(stem)
+
+        for i in range(len(depths) - 1):
+            mid_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+            )
+            self.downsample_layers.append(mid_layer)
+
+        self.stages = nn.ModuleList()
+        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+
+        cur = 0
+        for i in range(len(depths)):
+            stage = nn.Sequential(
+                *[
+                    ConvNeXtBlock(
+                        dim=dims[i],
+                        drop_path=dp_rates[cur + j],
+                        layer_scale_init_value=layer_scale_init_value,
+                        kernel_size=kernel_size,
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        for i in range(len(self.downsample_layers)):
+            x = self.downsample_layers[i](x)
+            x = self.stages[i](x)
+
+        return self.norm(x)
+
+
+class FireflyArchitecture(nn.Module):
+    def __init__(
+        self,
+        backbone: nn.Module,
+        head: nn.Module,
+        quantizer: nn.Module,
+        spec_transform: nn.Module,
+    ):
+        super().__init__()
+
+        self.backbone = backbone
+        self.head = head
+        self.quantizer = quantizer
+        self.spec_transform = spec_transform
+
+    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
+        if self.spec_transform is not None:
+            x = self.spec_transform(x)
+
+        x = self.backbone(x)
+        if mask is not None:
+            x = x * mask
+
+        if self.quantizer is not None:
+            vq_result = self.quantizer(x)
+            x = vq_result.z
+
+            if mask is not None:
+                x = x * mask
+
+        x = self.head(x, template=template)
+
+        if x.ndim == 2:
+            x = x[:, None, :]
+
+        if self.quantizer is not None:
+            return x, vq_result
+
+        return x
+
+    def encode(self, audios, audio_lengths):
+        audios = audios.float()
+
+        mels = self.spec_transform(audios)
+        mel_lengths = audio_lengths // self.spec_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        mels = mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.backbone(mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
+
+    def decode(self, indices, feature_lengths) -> torch.Tensor:
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        audio_masks = sequence_mask(
+            feature_lengths * factor * self.spec_transform.hop_length,
+            indices.shape[2] * factor * self.spec_transform.hop_length,
+        )
+        audio_masks_float_conv = audio_masks[:, None, :].float()
+
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+        x = self.head(z) * audio_masks_float_conv
+
+        return x
+
+    def remove_parametrizations(self):
+        if hasattr(self.backbone, "remove_parametrizations"):
+            self.backbone.remove_parametrizations()
+
+        if hasattr(self.head, "remove_parametrizations"):
+            self.head.remove_parametrizations()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+
+class FireflyBase(nn.Module):
+    def __init__(self, ckpt_path: str = None, pretrained: bool = True):
+        super().__init__()
+
+        self.backbone = ConvNeXtEncoder(
+            input_channels=128,
+            depths=[3, 3, 9, 3],
+            dims=[128, 256, 384, 512],
+            drop_path_rate=0.2,
+            kernel_size=7,
+        )
+
+        self.head = HiFiGANGenerator(
+            hop_length=512,
+            upsample_rates=[8, 8, 2, 2, 2],
+            upsample_kernel_sizes=[16, 16, 4, 4, 4],
+            resblock_kernel_sizes=[3, 7, 11],
+            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            num_mels=512,
+            upsample_initial_channel=512,
+            use_template=False,
+            pre_conv_kernel_size=13,
+            post_conv_kernel_size=13,
+        )
+
+        if ckpt_path is not None:
+            state_dict = torch.load(ckpt_path, map_location="cpu")
+        elif pretrained:
+            state_dict = torch.hub.load_state_dict_from_url(
+                "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
+                map_location="cpu",
+                model_dir="checkpoints",
+            )
+
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+
+        if any("generator." in k for k in state_dict):
+            state_dict = {
+                k.replace("generator.", ""): v
+                for k, v in state_dict.items()
+                if "generator." in k
+            }
+
+        self.load_state_dict(state_dict, strict=True)
+        self.head.remove_parametrizations()
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.backbone(x)
+        x = self.head(x)
+        if x.ndim == 2:
+            x = x[:, None, :]
+        return x
+
+
+if __name__ == "__main__":
+    model = FireflyBase()
+    model.eval()
+    x = torch.randn(1, 128, 128)
+    with torch.no_grad():
+        y = model(x)
+    print(y.shape)
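
For orientation, here is a minimal usage sketch of the FireflyBase vocoder added above. The import path is an assumption based on the vendored layout in the file list (firefly.py itself imports "fish_speech.models.vqgan.utils", so the fish_speech package must also be importable), and the default constructor downloads the pretrained checkpoint from the URL hard-coded in the file; neither detail is verified behavior of this release.

import torch

# Assumed import path for the vendored copy shipped in this wheel.
from xinference.thirdparty.fish_speech.fish_speech.models.vqgan.modules.firefly import (
    FireflyBase,
)

model = FireflyBase()  # pretrained=True fetches firefly-gan-base-generator.ckpt
model.eval()

mel = torch.randn(1, 128, 128)  # (batch, num_mels=128, mel frames)
with torch.no_grad():
    audio = model(mel)

# The head upsamples by prod(upsample_rates) = 8*8*2*2*2 = 512 = hop_length,
# so 128 mel frames decode to 128 * 512 = 65536 samples.
print(audio.shape)  # torch.Size([1, 1, 65536])

Note that the constructor only binds state_dict when ckpt_path is given or pretrained is True; FireflyBase(pretrained=False) with no checkpoint would hit an unbound variable, so a checkpoint source is effectively required.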