xinference 0.14.1.post1__py3-none-any.whl → 0.14.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (194)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +15 -34
  3. xinference/client/restful/restful_client.py +2 -2
  4. xinference/core/chat_interface.py +45 -10
  5. xinference/core/image_interface.py +9 -0
  6. xinference/core/model.py +8 -5
  7. xinference/core/scheduler.py +1 -2
  8. xinference/core/worker.py +49 -42
  9. xinference/deploy/cmdline.py +2 -2
  10. xinference/deploy/test/test_cmdline.py +7 -7
  11. xinference/model/audio/chattts.py +24 -9
  12. xinference/model/audio/core.py +8 -2
  13. xinference/model/audio/fish_speech.py +228 -0
  14. xinference/model/audio/model_spec.json +8 -0
  15. xinference/model/embedding/core.py +23 -1
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +49 -1
  19. xinference/model/llm/__init__.py +26 -27
  20. xinference/model/llm/{ggml/llamacpp.py → llama_cpp/core.py} +2 -35
  21. xinference/model/llm/llm_family.json +606 -1266
  22. xinference/model/llm/llm_family.py +16 -139
  23. xinference/model/llm/llm_family_modelscope.json +276 -313
  24. xinference/model/llm/lmdeploy/__init__.py +0 -0
  25. xinference/model/llm/lmdeploy/core.py +557 -0
  26. xinference/model/llm/memory.py +9 -9
  27. xinference/model/llm/sglang/core.py +2 -2
  28. xinference/model/llm/{pytorch → transformers}/chatglm.py +6 -13
  29. xinference/model/llm/{pytorch → transformers}/cogvlm2.py +4 -45
  30. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  31. xinference/model/llm/{pytorch → transformers}/core.py +3 -10
  32. xinference/model/llm/{pytorch → transformers}/glm4v.py +2 -23
  33. xinference/model/llm/transformers/intern_vl.py +540 -0
  34. xinference/model/llm/{pytorch → transformers}/internlm2.py +4 -8
  35. xinference/model/llm/{pytorch → transformers}/minicpmv25.py +2 -23
  36. xinference/model/llm/{pytorch → transformers}/minicpmv26.py +66 -41
  37. xinference/model/llm/{pytorch → transformers}/utils.py +1 -2
  38. xinference/model/llm/{pytorch → transformers}/yi_vl.py +2 -24
  39. xinference/model/llm/utils.py +85 -70
  40. xinference/model/llm/vllm/core.py +110 -11
  41. xinference/model/utils.py +1 -95
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/internvl/__init__.py +0 -0
  137. xinference/thirdparty/internvl/conversation.py +393 -0
  138. xinference/thirdparty/omnilmm/model/utils.py +16 -1
  139. xinference/web/ui/build/asset-manifest.json +3 -3
  140. xinference/web/ui/build/index.html +1 -1
  141. xinference/web/ui/build/static/js/main.661c7b0a.js +3 -0
  142. xinference/web/ui/build/static/js/{main.17ca0398.js.map → main.661c7b0a.js.map} +1 -1
  143. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  144. xinference/web/ui/node_modules/.cache/babel-loader/213b5913e164773c2b0567455377765715f5f07225fbac77ad8e1e9dc9648a47.json +1 -0
  145. xinference/web/ui/node_modules/.cache/babel-loader/4de9a6942c5f1749d6cbfdd54279699975f16016b182848bc253886f52ec2ec3.json +1 -0
  146. xinference/web/ui/node_modules/.cache/babel-loader/5391543180fead1eeef5364300301498d58a7d91d62de3841a32768b67f4552f.json +1 -0
  147. xinference/web/ui/node_modules/.cache/babel-loader/5c26a23b5eacf5b752a08531577ae3840bb247745ef9a39583dc2d05ba93a82a.json +1 -0
  148. xinference/web/ui/node_modules/.cache/babel-loader/714c37ce0ec5b5c591033f02be2f3f491fdd70da3ef568ee4a4f94689a3d5ca2.json +1 -0
  149. xinference/web/ui/node_modules/.cache/babel-loader/822586ed1077201b64b954f12f25e3f9b45678c1acbabe53d8af3ca82ca71f33.json +1 -0
  150. xinference/web/ui/node_modules/.cache/babel-loader/978b57d1a04a701bc3fcfebc511f5f274eed6ed7eade67f6fb76c27d5fd9ecc8.json +1 -0
  151. xinference/web/ui/node_modules/.cache/babel-loader/a797831de0dc74897f4b50b3426555d748f328b4c2cc391de709eadaf6a5f3e3.json +1 -0
  152. xinference/web/ui/node_modules/.cache/babel-loader/bd6ad8159341315a1764c397621a560809f7eb7219ab5174c801fca7e969d943.json +1 -0
  153. xinference/web/ui/node_modules/.cache/babel-loader/e64b7e8cedcf43d4c95deba60ec1341855c887705805bb62431693118b870c69.json +1 -0
  154. xinference/web/ui/node_modules/.cache/babel-loader/e91938976f229ce986b2907e51e1f00540b584ced0a315d498c172d13220739d.json +1 -0
  155. xinference/web/ui/node_modules/.cache/babel-loader/f72f011744c4649fabddca6f7a9327861ac0a315a89b1a2e62a39774e7863845.json +1 -0
  156. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/METADATA +22 -13
  157. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/RECORD +170 -79
  158. xinference/locale/utils.py +0 -39
  159. xinference/locale/zh_CN.json +0 -26
  160. xinference/model/llm/ggml/tools/__init__.py +0 -15
  161. xinference/model/llm/ggml/tools/convert_ggml_to_gguf.py +0 -498
  162. xinference/model/llm/ggml/tools/gguf.py +0 -884
  163. xinference/model/llm/pytorch/__init__.py +0 -13
  164. xinference/model/llm/pytorch/baichuan.py +0 -81
  165. xinference/model/llm/pytorch/falcon.py +0 -138
  166. xinference/model/llm/pytorch/intern_vl.py +0 -352
  167. xinference/model/llm/pytorch/vicuna.py +0 -69
  168. xinference/web/ui/build/static/js/main.17ca0398.js +0 -3
  169. xinference/web/ui/node_modules/.cache/babel-loader/1444c41a4d04494f1cbc2d8c1537df107b451cb569cb2c1fbf5159f3a4841a5f.json +0 -1
  170. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  171. xinference/web/ui/node_modules/.cache/babel-loader/44774c783428f952d8e2e4ad0998a9c5bc16a57cd9c68b7c5ff18aaa5a41d65c.json +0 -1
  172. xinference/web/ui/node_modules/.cache/babel-loader/5262556baf9207738bf6a8ba141ec6599d0a636345c245d61fdf88d3171998cb.json +0 -1
  173. xinference/web/ui/node_modules/.cache/babel-loader/6450605fac003812485f6251b9f0caafbf2e5bfc3bbe2f000050d9e2fdb8dcd3.json +0 -1
  174. xinference/web/ui/node_modules/.cache/babel-loader/71684495d995c7e266eecc6a0ad8ea0284cc785f80abddf863789c57a6134969.json +0 -1
  175. xinference/web/ui/node_modules/.cache/babel-loader/80acd1edf31542ab1dcccfad02cb4b38f3325cff847a781fcce97500cfd6f878.json +0 -1
  176. xinference/web/ui/node_modules/.cache/babel-loader/8a9742ddd8ba8546ef42dc14caca443f2b4524fabed7bf269e0eff3b7b64ee7d.json +0 -1
  177. xinference/web/ui/node_modules/.cache/babel-loader/d06a96a3c9c32e42689094aa3aaad41c8125894e956b8f84a70fadce6e3f65b3.json +0 -1
  178. xinference/web/ui/node_modules/.cache/babel-loader/d93730e2b5d7e8c957b4d0965d2ed1dac9045a649adbd47c220d11f255d4b1e0.json +0 -1
  179. xinference/web/ui/node_modules/.cache/babel-loader/e656dc00b4d8b387f0a81ba8fc558767df1601c66369e2eb86a5ef27cf080572.json +0 -1
  180. xinference/web/ui/node_modules/.cache/babel-loader/f28b83886159d83b84f099b05d607a822dca4dd7f2d8aa6d56fe08bab0b5b086.json +0 -1
  181. xinference/web/ui/node_modules/.cache/babel-loader/f3e02274cb1964e99b1fe69cbb6db233d3d8d7dd05d50ebcdb8e66d50b224b7b.json +0 -1
  182. /xinference/{locale → model/llm/llama_cpp}/__init__.py +0 -0
  183. /xinference/model/llm/{ggml → transformers}/__init__.py +0 -0
  184. /xinference/model/llm/{pytorch → transformers}/compression.py +0 -0
  185. /xinference/model/llm/{pytorch → transformers}/deepseek_vl.py +0 -0
  186. /xinference/model/llm/{pytorch → transformers}/llama_2.py +0 -0
  187. /xinference/model/llm/{pytorch → transformers}/omnilmm.py +0 -0
  188. /xinference/model/llm/{pytorch → transformers}/qwen_vl.py +0 -0
  189. /xinference/model/llm/{pytorch → transformers}/tensorizer_utils.py +0 -0
  190. /xinference/web/ui/build/static/js/{main.17ca0398.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  191. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/LICENSE +0 -0
  192. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/WHEEL +0 -0
  193. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/entry_points.txt +0 -0
  194. {xinference-0.14.1.post1.dist-info → xinference-0.14.3.dist-info}/top_level.txt +0 -0

xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py
@@ -0,0 +1,139 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from vector_quantize_pytorch import GroupedResidualFSQ
+
+ from .firefly import ConvNeXtBlock
+
+
+ @dataclass
+ class FSQResult:
+     z: torch.Tensor
+     codes: torch.Tensor
+     latents: torch.Tensor
+
+
+ class DownsampleFiniteScalarQuantize(nn.Module):
+     def __init__(
+         self,
+         input_dim: int = 512,
+         n_codebooks: int = 1,
+         n_groups: int = 1,
+         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
+         downsample_factor: tuple[int] = (2, 2),
+         downsample_dims: tuple[int] | None = None,
+     ):
+         super().__init__()
+
+         if downsample_dims is None:
+             downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+         all_dims = (input_dim,) + tuple(downsample_dims)
+
+         self.residual_fsq = GroupedResidualFSQ(
+             dim=all_dims[-1],
+             levels=levels,
+             num_quantizers=n_codebooks,
+             groups=n_groups,
+         )
+
+         self.downsample_factor = downsample_factor
+         self.downsample_dims = downsample_dims
+
+         self.downsample = nn.Sequential(
+             *[
+                 nn.Sequential(
+                     nn.Conv1d(
+                         all_dims[idx],
+                         all_dims[idx + 1],
+                         kernel_size=factor,
+                         stride=factor,
+                     ),
+                     ConvNeXtBlock(dim=all_dims[idx + 1]),
+                 )
+                 for idx, factor in enumerate(downsample_factor)
+             ]
+         )
+
+         self.upsample = nn.Sequential(
+             *[
+                 nn.Sequential(
+                     nn.ConvTranspose1d(
+                         all_dims[idx + 1],
+                         all_dims[idx],
+                         kernel_size=factor,
+                         stride=factor,
+                     ),
+                     ConvNeXtBlock(dim=all_dims[idx]),
+                 )
+                 for idx, factor in reversed(list(enumerate(downsample_factor)))
+             ]
+         )
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             nn.init.constant_(m.bias, 0)
+
+     def forward(self, z) -> FSQResult:
+         original_shape = z.shape
+         z = self.downsample(z)
+         quantized, indices = self.residual_fsq(z.mT)
+         result = FSQResult(
+             z=quantized.mT,
+             codes=indices.mT,
+             latents=z,
+         )
+         result.z = self.upsample(result.z)
+
+         # Pad or crop z to match original shape
+         diff = original_shape[-1] - result.z.shape[-1]
+         left = diff // 2
+         right = diff - left
+
+         if diff > 0:
+             result.z = F.pad(result.z, (left, right))
+         elif diff < 0:
+             result.z = result.z[..., left:-right]
+
+         return result
+
+     def encode(self, z):
+         z = self.downsample(z)
+         _, indices = self.residual_fsq(z.mT)
+         indices = rearrange(indices, "g b l r -> b (g r) l")
+         return indices
+
+     def decode(self, indices: torch.Tensor):
+         indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
+         z_q = self.residual_fsq.get_output_from_indices(indices)
+         z_q = self.upsample(z_q.mT)
+         return z_q
+
+     # def from_latents(self, latents: torch.Tensor):
+     #     z_q, z_p, codes = super().from_latents(latents)
+     #     z_q = self.upsample(z_q)
+     #     return z_q, z_p, codes
+
+
+ if __name__ == "__main__":
+     rvq = DownsampleFiniteScalarQuantize(
+         n_codebooks=1,
+         downsample_factor=(2, 2),
+     )
+     x = torch.randn(16, 512, 80)
+
+     result = rvq(x)
+     print(rvq)
+     print(result.latents.shape, result.codes.shape, result.z.shape)
+
+     # y = rvq.from_codes(result.codes)
+     # print(y[0].shape)
+
+     # y = rvq.from_latents(result.latents)
+     # print(y[0].shape)
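
The `__main__` demo above only exercises `forward`. As a point of reference (not part of the diff), here is a minimal sketch of the discrete round-trip through `encode`/`decode`, reusing the demo's toy shapes and assuming the vendored `fish_speech` package is importable:

import torch
from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize

rvq = DownsampleFiniteScalarQuantize(n_codebooks=1, downsample_factor=(2, 2))
x = torch.randn(16, 512, 80)  # (batch, input_dim, frames)

# Two stride-2 convolutions shorten the time axis 4x before quantization.
codes = rvq.encode(x)    # (batch, groups * quantizers, frames // 4) integer codes
z_q = rvq.decode(codes)  # (batch, input_dim, frames), upsampled back to full rate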

xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py
@@ -0,0 +1,115 @@
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from fish_speech.utils import autocast_exclude_mps
+
+ from .wavenet import WaveNet
+
+
+ class ReferenceEncoder(WaveNet):
+     def __init__(
+         self,
+         input_channels: Optional[int] = None,
+         output_channels: Optional[int] = None,
+         residual_channels: int = 512,
+         residual_layers: int = 20,
+         dilation_cycle: Optional[int] = 4,
+         num_heads: int = 8,
+         latent_len: int = 4,
+     ):
+         super().__init__(
+             input_channels=input_channels,
+             residual_channels=residual_channels,
+             residual_layers=residual_layers,
+             dilation_cycle=dilation_cycle,
+         )
+
+         self.head_dim = residual_channels // num_heads
+         self.num_heads = num_heads
+
+         self.latent_len = latent_len
+         self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
+
+         self.q = nn.Linear(residual_channels, residual_channels, bias=True)
+         self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
+         self.q_norm = nn.LayerNorm(self.head_dim)
+         self.k_norm = nn.LayerNorm(self.head_dim)
+         self.proj = nn.Linear(residual_channels, residual_channels)
+         self.proj_drop = nn.Dropout(0.1)
+
+         self.norm = nn.LayerNorm(residual_channels)
+         self.mlp = nn.Sequential(
+             nn.Linear(residual_channels, residual_channels * 4),
+             nn.SiLU(),
+             nn.Linear(residual_channels * 4, residual_channels),
+         )
+         self.output_projection_attn = nn.Linear(residual_channels, output_channels)
+
+         torch.nn.init.trunc_normal_(self.latent, std=0.02)
+         self.apply(self.init_weights)
+
+     def init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             torch.nn.init.trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 torch.nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, attn_mask=None):
+         x = super().forward(x).mT
+         B, N, C = x.shape
+
+         # Calculate mask
+         if attn_mask is not None:
+             assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
+
+             attn_mask = attn_mask[:, None, None, :].expand(
+                 B, self.num_heads, self.latent_len, N
+             )
+
+         q_latent = self.latent.expand(B, -1, -1)
+         q = (
+             self.q(q_latent)
+             .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+             .transpose(1, 2)
+         )
+
+         kv = (
+             self.kv(x)
+             .reshape(B, N, 2, self.num_heads, self.head_dim)
+             .permute(2, 0, 3, 1, 4)
+         )
+         k, v = kv.unbind(0)
+
+         q, k = self.q_norm(q), self.k_norm(k)
+         x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+
+         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         x = x + self.mlp(self.norm(x))
+         x = self.output_projection_attn(x)
+         x = x.mean(1)
+
+         return x
+
+
+ if __name__ == "__main__":
+     with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
+         model = ReferenceEncoder(
+             input_channels=128,
+             output_channels=64,
+             residual_channels=384,
+             residual_layers=20,
+             dilation_cycle=4,
+             num_heads=8,
+         )
+         x = torch.randn(4, 128, 64)
+         mask = torch.ones(4, 64, dtype=torch.bool)
+         y = model(x, mask)
+         print(y.shape)
+         loss = F.mse_loss(y, torch.randn(4, 64))
+         loss.backward()

xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py
@@ -0,0 +1,225 @@
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class Mish(nn.Module):
+     def forward(self, x):
+         return x * torch.tanh(F.softplus(x))
+
+
+ class DiffusionEmbedding(nn.Module):
+     """Diffusion Step Embedding"""
+
+     def __init__(self, d_denoiser):
+         super(DiffusionEmbedding, self).__init__()
+         self.dim = d_denoiser
+
+     def forward(self, x):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+         emb = x[:, None] * emb[None, :]
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+
+ class LinearNorm(nn.Module):
+     """LinearNorm Projection"""
+
+     def __init__(self, in_features, out_features, bias=False):
+         super(LinearNorm, self).__init__()
+         self.linear = nn.Linear(in_features, out_features, bias)
+
+         nn.init.xavier_uniform_(self.linear.weight)
+         if bias:
+             nn.init.constant_(self.linear.bias, 0.0)
+
+     def forward(self, x):
+         x = self.linear(x)
+         return x
+
+
+ class ConvNorm(nn.Module):
+     """1D Convolution"""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size=1,
+         stride=1,
+         padding=None,
+         dilation=1,
+         bias=True,
+         w_init_gain="linear",
+     ):
+         super(ConvNorm, self).__init__()
+
+         if padding is None:
+             assert kernel_size % 2 == 1
+             padding = int(dilation * (kernel_size - 1) / 2)
+
+         self.conv = nn.Conv1d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+         )
+         nn.init.kaiming_normal_(self.conv.weight)
+
+     def forward(self, signal):
+         conv_signal = self.conv(signal)
+
+         return conv_signal
+
+
+ class ResidualBlock(nn.Module):
+     """Residual Block"""
+
+     def __init__(
+         self,
+         residual_channels,
+         use_linear_bias=False,
+         dilation=1,
+         condition_channels=None,
+     ):
+         super(ResidualBlock, self).__init__()
+         self.conv_layer = ConvNorm(
+             residual_channels,
+             2 * residual_channels,
+             kernel_size=3,
+             stride=1,
+             padding=dilation,
+             dilation=dilation,
+         )
+
+         if condition_channels is not None:
+             self.diffusion_projection = LinearNorm(
+                 residual_channels, residual_channels, use_linear_bias
+             )
+             self.condition_projection = ConvNorm(
+                 condition_channels, 2 * residual_channels, kernel_size=1
+             )
+
+         self.output_projection = ConvNorm(
+             residual_channels, 2 * residual_channels, kernel_size=1
+         )
+
+     def forward(self, x, condition=None, diffusion_step=None):
+         y = x
+
+         if diffusion_step is not None:
+             diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+             y = y + diffusion_step
+
+         y = self.conv_layer(y)
+
+         if condition is not None:
+             condition = self.condition_projection(condition)
+             y = y + condition
+
+         gate, filter = torch.chunk(y, 2, dim=1)
+         y = torch.sigmoid(gate) * torch.tanh(filter)
+
+         y = self.output_projection(y)
+         residual, skip = torch.chunk(y, 2, dim=1)
+
+         return (x + residual) / math.sqrt(2.0), skip
+
+
+ class WaveNet(nn.Module):
+     def __init__(
+         self,
+         input_channels: Optional[int] = None,
+         output_channels: Optional[int] = None,
+         residual_channels: int = 512,
+         residual_layers: int = 20,
+         dilation_cycle: Optional[int] = 4,
+         is_diffusion: bool = False,
+         condition_channels: Optional[int] = None,
+     ):
+         super().__init__()
+
+         # Input projection
+         self.input_projection = None
+         if input_channels is not None and input_channels != residual_channels:
+             self.input_projection = ConvNorm(
+                 input_channels, residual_channels, kernel_size=1
+             )
+
+         if input_channels is None:
+             input_channels = residual_channels
+
+         self.input_channels = input_channels
+
+         # Residual layers
+         self.residual_layers = nn.ModuleList(
+             [
+                 ResidualBlock(
+                     residual_channels=residual_channels,
+                     use_linear_bias=False,
+                     dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+                     condition_channels=condition_channels,
+                 )
+                 for i in range(residual_layers)
+             ]
+         )
+
+         # Skip projection
+         self.skip_projection = ConvNorm(
+             residual_channels, residual_channels, kernel_size=1
+         )
+
+         # Output projection
+         self.output_projection = None
+         if output_channels is not None and output_channels != residual_channels:
+             self.output_projection = ConvNorm(
+                 residual_channels, output_channels, kernel_size=1
+             )
+
+         if is_diffusion:
+             self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+             self.mlp = nn.Sequential(
+                 LinearNorm(residual_channels, residual_channels * 4, False),
+                 Mish(),
+                 LinearNorm(residual_channels * 4, residual_channels, False),
+             )
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if getattr(m, "bias", None) is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, t=None, condition=None):
+         if self.input_projection is not None:
+             x = self.input_projection(x)
+             x = F.silu(x)
+
+         if t is not None:
+             t = self.diffusion_embedding(t)
+             t = self.mlp(t)
+
+         skip = []
+         for layer in self.residual_layers:
+             x, skip_connection = layer(x, condition, t)
+             skip.append(skip_connection)
+
+         x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+         x = self.skip_projection(x)
+
+         if self.output_projection is not None:
+             x = F.silu(x)
+             x = self.output_projection(x)
+
+         return x
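
Beyond the plain feature-encoder use shown in `reference.py`, the `is_diffusion` flag turns this `WaveNet` into a conditional denoiser that consumes a per-sample timestep. A minimal sketch with illustrative shapes (the channel sizes here are assumptions, not values used elsewhere in the diff):

import torch
from fish_speech.models.vqgan.modules.wavenet import WaveNet

net = WaveNet(
    input_channels=128,      # projected up to residual_channels internally
    output_channels=128,     # projected back down on the way out
    residual_channels=256,
    residual_layers=8,
    is_diffusion=True,       # enables DiffusionEmbedding + Mish MLP
    condition_channels=192,  # each ResidualBlock gains a condition projection
)

x = torch.randn(4, 128, 100)       # noisy features (B, C_in, T)
t = torch.randint(0, 1000, (4,))   # one diffusion timestep per sample
cond = torch.randn(4, 192, 100)    # conditioning features (B, C_cond, T)

out = net(x, t=t, condition=cond)  # (B, 128, 100), e.g. predicted noise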

xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py
@@ -0,0 +1,94 @@
+ import matplotlib
+ import torch
+ from matplotlib import pyplot as plt
+
+ matplotlib.use("Agg")
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def plot_mel(data, titles=None):
+     fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+     if titles is None:
+         titles = [None for i in range(len(data))]
+
+     plt.tight_layout()
+
+     for i in range(len(data)):
+         mel = data[i]
+
+         if isinstance(mel, torch.Tensor):
+             mel = mel.float().detach().cpu().numpy()
+
+         axes[i][0].imshow(mel, origin="lower")
+         axes[i][0].set_aspect(2.5, adjustable="box")
+         axes[i][0].set_ylim(0, mel.shape[0])
+         axes[i][0].set_title(titles[i], fontsize="medium")
+         axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+         axes[i][0].set_anchor("W")
+
+     return fig
+
+
+ def slice_segments(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, :, idx_str:idx_end]
+
+     return ret
+
+
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
+     b, d, t = x.size()
+     if x_lengths is None:
+         x_lengths = t
+     ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+     ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
+     ret = slice_segments(x, ids_str, segment_size)
+     return ret, ids_str
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
+     n_channels_int = n_channels[0]
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+
+     return acts
+
+
+ def avg_with_mask(x, mask):
+     assert mask.dtype == torch.float, "Mask should be float"
+
+     if mask.ndim == 2:
+         mask = mask.unsqueeze(1)
+
+     if mask.shape[1] == 1:
+         mask = mask.expand_as(x)
+
+     return (x * mask).sum() / mask.sum()
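
Of these helpers, `sequence_mask` and `avg_with_mask` compose naturally for padding-aware reductions. A small sketch (shapes are illustrative, not taken from the diff):

import torch
from fish_speech.models.vqgan.utils import avg_with_mask, sequence_mask

lengths = torch.tensor([3, 5])               # valid frames per sequence
mask = sequence_mask(lengths, max_length=5)  # (2, 5) bool, True at valid frames

x = torch.randn(2, 8, 5)                     # (batch, channels, frames)

# avg_with_mask asserts a float mask and broadcasts it over the channel dim,
# so padded frames contribute nothing to the mean.
mean = avg_with_mask(x, mask.float())        # scalar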

xinference/thirdparty/fish_speech/fish_speech/scheduler.py
@@ -0,0 +1,40 @@
+ import math
+
+
+ def get_cosine_schedule_with_warmup_lr_lambda(
+     current_step: int,
+     *,
+     num_warmup_steps: int | float,
+     num_training_steps: int,
+     num_cycles: float = 0.5,
+     final_lr_ratio: float = 0.0,
+ ):
+     if 0 < num_warmup_steps < 1:  # float mode
+         num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+     if current_step < num_warmup_steps:
+         return float(current_step) / float(max(1, num_warmup_steps))
+
+     progress = float(current_step - num_warmup_steps) / float(
+         max(1, num_training_steps - num_warmup_steps)
+     )
+
+     return max(
+         final_lr_ratio,
+         0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
+     )
+
+
+ def get_constant_schedule_with_warmup_lr_lambda(
+     current_step: int,
+     *,
+     num_warmup_steps: int | float,
+     num_training_steps: int | None = None,
+ ):
+     if 0 < num_warmup_steps < 1:  # float mode
+         num_warmup_steps = int(num_warmup_steps * num_training_steps)
+
+     if current_step < num_warmup_steps:
+         return float(current_step) / float(max(1, num_warmup_steps))
+
+     return 1.0
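
Both functions return a learning-rate multiplier keyed by step, in the style of Hugging Face's schedule helpers, so they slot straight into `torch.optim.lr_scheduler.LambdaLR` via `functools.partial`. A minimal sketch (the optimizer and step counts are illustrative assumptions):

from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR

from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

model = torch.nn.Linear(16, 16)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

lr_lambda = partial(
    get_cosine_schedule_with_warmup_lr_lambda,
    num_warmup_steps=0.1,  # a float in (0, 1) is read as a fraction of training
    num_training_steps=1000,
    final_lr_ratio=0.1,    # cosine decay is floored at 10% of the base lr
)
scheduler = LambdaLR(optimizer, lr_lambda)

for step in range(1000):
    optimizer.step()
    scheduler.step()  # base lr is scaled by lr_lambda(current step)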

xinference/thirdparty/fish_speech/fish_speech/text/__init__.py
@@ -0,0 +1,4 @@
+ from .clean import clean_text
+ from .spliter import split_text
+
+ __all__ = ["clean_text", "split_text"]