xinference 0.14.2__py3-none-any.whl → 0.14.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (191)
  1. xinference/_version.py +3 -3
  2. xinference/core/chat_interface.py +1 -1
  3. xinference/core/image_interface.py +9 -0
  4. xinference/core/model.py +4 -1
  5. xinference/core/worker.py +60 -44
  6. xinference/model/audio/chattts.py +25 -9
  7. xinference/model/audio/core.py +8 -2
  8. xinference/model/audio/cosyvoice.py +4 -3
  9. xinference/model/audio/custom.py +4 -5
  10. xinference/model/audio/fish_speech.py +228 -0
  11. xinference/model/audio/model_spec.json +8 -0
  12. xinference/model/embedding/core.py +25 -1
  13. xinference/model/embedding/custom.py +4 -5
  14. xinference/model/flexible/core.py +5 -1
  15. xinference/model/image/custom.py +4 -5
  16. xinference/model/image/model_spec.json +2 -1
  17. xinference/model/image/model_spec_modelscope.json +2 -1
  18. xinference/model/image/stable_diffusion/core.py +66 -3
  19. xinference/model/llm/__init__.py +6 -0
  20. xinference/model/llm/llm_family.json +54 -9
  21. xinference/model/llm/llm_family.py +7 -6
  22. xinference/model/llm/llm_family_modelscope.json +56 -10
  23. xinference/model/llm/lmdeploy/__init__.py +0 -0
  24. xinference/model/llm/lmdeploy/core.py +557 -0
  25. xinference/model/llm/sglang/core.py +7 -1
  26. xinference/model/llm/transformers/cogvlm2.py +4 -45
  27. xinference/model/llm/transformers/cogvlm2_video.py +524 -0
  28. xinference/model/llm/transformers/core.py +3 -0
  29. xinference/model/llm/transformers/glm4v.py +2 -23
  30. xinference/model/llm/transformers/intern_vl.py +94 -11
  31. xinference/model/llm/transformers/minicpmv25.py +2 -23
  32. xinference/model/llm/transformers/minicpmv26.py +2 -22
  33. xinference/model/llm/transformers/yi_vl.py +2 -24
  34. xinference/model/llm/utils.py +13 -1
  35. xinference/model/llm/vllm/core.py +1 -34
  36. xinference/model/rerank/custom.py +4 -5
  37. xinference/model/utils.py +41 -1
  38. xinference/model/video/core.py +3 -1
  39. xinference/model/video/diffusers.py +41 -38
  40. xinference/model/video/model_spec.json +24 -1
  41. xinference/model/video/model_spec_modelscope.json +25 -1
  42. xinference/thirdparty/fish_speech/__init__.py +0 -0
  43. xinference/thirdparty/fish_speech/fish_speech/__init__.py +0 -0
  44. xinference/thirdparty/fish_speech/fish_speech/callbacks/__init__.py +3 -0
  45. xinference/thirdparty/fish_speech/fish_speech/callbacks/grad_norm.py +113 -0
  46. xinference/thirdparty/fish_speech/fish_speech/configs/__init__.py +0 -0
  47. xinference/thirdparty/fish_speech/fish_speech/configs/lora/__init__.py +0 -0
  48. xinference/thirdparty/fish_speech/fish_speech/conversation.py +2 -0
  49. xinference/thirdparty/fish_speech/fish_speech/datasets/__init__.py +0 -0
  50. xinference/thirdparty/fish_speech/fish_speech/datasets/concat_repeat.py +53 -0
  51. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/__init__.py +0 -0
  52. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_pb2.py +33 -0
  53. xinference/thirdparty/fish_speech/fish_speech/datasets/protos/text_data_stream.py +36 -0
  54. xinference/thirdparty/fish_speech/fish_speech/datasets/semantic.py +496 -0
  55. xinference/thirdparty/fish_speech/fish_speech/datasets/vqgan.py +147 -0
  56. xinference/thirdparty/fish_speech/fish_speech/i18n/__init__.py +3 -0
  57. xinference/thirdparty/fish_speech/fish_speech/i18n/core.py +40 -0
  58. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/__init__.py +0 -0
  59. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/en_US.json +122 -0
  60. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/es_ES.json +122 -0
  61. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/ja_JP.json +123 -0
  62. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/pt_BR.json +133 -0
  63. xinference/thirdparty/fish_speech/fish_speech/i18n/locale/zh_CN.json +122 -0
  64. xinference/thirdparty/fish_speech/fish_speech/i18n/scan.py +122 -0
  65. xinference/thirdparty/fish_speech/fish_speech/models/__init__.py +0 -0
  66. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/__init__.py +0 -0
  67. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lit_module.py +202 -0
  68. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/llama.py +779 -0
  69. xinference/thirdparty/fish_speech/fish_speech/models/text2semantic/lora.py +92 -0
  70. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/__init__.py +3 -0
  71. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/lit_module.py +442 -0
  72. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/__init__.py +0 -0
  73. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/discriminator.py +44 -0
  74. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/firefly.py +625 -0
  75. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/fsq.py +139 -0
  76. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/reference.py +115 -0
  77. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/modules/wavenet.py +225 -0
  78. xinference/thirdparty/fish_speech/fish_speech/models/vqgan/utils.py +94 -0
  79. xinference/thirdparty/fish_speech/fish_speech/scheduler.py +40 -0
  80. xinference/thirdparty/fish_speech/fish_speech/text/__init__.py +4 -0
  81. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/__init__.py +0 -0
  82. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_class.py +172 -0
  83. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  84. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/basic_util.py +342 -0
  85. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/cardinal.py +32 -0
  86. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/date.py +75 -0
  87. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/digit.py +32 -0
  88. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/fraction.py +35 -0
  89. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/money.py +43 -0
  90. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/percentage.py +33 -0
  91. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/telephone.py +51 -0
  92. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +177 -0
  93. xinference/thirdparty/fish_speech/fish_speech/text/clean.py +69 -0
  94. xinference/thirdparty/fish_speech/fish_speech/text/spliter.py +130 -0
  95. xinference/thirdparty/fish_speech/fish_speech/train.py +139 -0
  96. xinference/thirdparty/fish_speech/fish_speech/utils/__init__.py +23 -0
  97. xinference/thirdparty/fish_speech/fish_speech/utils/braceexpand.py +217 -0
  98. xinference/thirdparty/fish_speech/fish_speech/utils/context.py +13 -0
  99. xinference/thirdparty/fish_speech/fish_speech/utils/file.py +16 -0
  100. xinference/thirdparty/fish_speech/fish_speech/utils/instantiators.py +50 -0
  101. xinference/thirdparty/fish_speech/fish_speech/utils/logger.py +55 -0
  102. xinference/thirdparty/fish_speech/fish_speech/utils/logging_utils.py +48 -0
  103. xinference/thirdparty/fish_speech/fish_speech/utils/rich_utils.py +100 -0
  104. xinference/thirdparty/fish_speech/fish_speech/utils/spectrogram.py +122 -0
  105. xinference/thirdparty/fish_speech/fish_speech/utils/utils.py +114 -0
  106. xinference/thirdparty/fish_speech/fish_speech/webui/__init__.py +0 -0
  107. xinference/thirdparty/fish_speech/fish_speech/webui/launch_utils.py +120 -0
  108. xinference/thirdparty/fish_speech/fish_speech/webui/manage.py +1237 -0
  109. xinference/thirdparty/fish_speech/tools/__init__.py +0 -0
  110. xinference/thirdparty/fish_speech/tools/api.py +495 -0
  111. xinference/thirdparty/fish_speech/tools/auto_rerank.py +159 -0
  112. xinference/thirdparty/fish_speech/tools/download_models.py +55 -0
  113. xinference/thirdparty/fish_speech/tools/extract_model.py +21 -0
  114. xinference/thirdparty/fish_speech/tools/file.py +108 -0
  115. xinference/thirdparty/fish_speech/tools/gen_ref.py +36 -0
  116. xinference/thirdparty/fish_speech/tools/llama/__init__.py +0 -0
  117. xinference/thirdparty/fish_speech/tools/llama/build_dataset.py +169 -0
  118. xinference/thirdparty/fish_speech/tools/llama/eval_in_context.py +171 -0
  119. xinference/thirdparty/fish_speech/tools/llama/generate.py +698 -0
  120. xinference/thirdparty/fish_speech/tools/llama/merge_lora.py +95 -0
  121. xinference/thirdparty/fish_speech/tools/llama/quantize.py +497 -0
  122. xinference/thirdparty/fish_speech/tools/llama/rebuild_tokenizer.py +57 -0
  123. xinference/thirdparty/fish_speech/tools/merge_asr_files.py +55 -0
  124. xinference/thirdparty/fish_speech/tools/post_api.py +164 -0
  125. xinference/thirdparty/fish_speech/tools/sensevoice/__init__.py +0 -0
  126. xinference/thirdparty/fish_speech/tools/sensevoice/auto_model.py +573 -0
  127. xinference/thirdparty/fish_speech/tools/sensevoice/fun_asr.py +332 -0
  128. xinference/thirdparty/fish_speech/tools/sensevoice/vad_utils.py +61 -0
  129. xinference/thirdparty/fish_speech/tools/smart_pad.py +47 -0
  130. xinference/thirdparty/fish_speech/tools/vqgan/__init__.py +0 -0
  131. xinference/thirdparty/fish_speech/tools/vqgan/create_train_split.py +83 -0
  132. xinference/thirdparty/fish_speech/tools/vqgan/extract_vq.py +227 -0
  133. xinference/thirdparty/fish_speech/tools/vqgan/inference.py +120 -0
  134. xinference/thirdparty/fish_speech/tools/webui.py +619 -0
  135. xinference/thirdparty/fish_speech/tools/whisper_asr.py +176 -0
  136. xinference/thirdparty/matcha/__init__.py +0 -0
  137. xinference/thirdparty/matcha/app.py +357 -0
  138. xinference/thirdparty/matcha/cli.py +419 -0
  139. xinference/thirdparty/matcha/data/__init__.py +0 -0
  140. xinference/thirdparty/matcha/data/components/__init__.py +0 -0
  141. xinference/thirdparty/matcha/data/text_mel_datamodule.py +274 -0
  142. xinference/thirdparty/matcha/hifigan/__init__.py +0 -0
  143. xinference/thirdparty/matcha/hifigan/config.py +28 -0
  144. xinference/thirdparty/matcha/hifigan/denoiser.py +64 -0
  145. xinference/thirdparty/matcha/hifigan/env.py +17 -0
  146. xinference/thirdparty/matcha/hifigan/meldataset.py +217 -0
  147. xinference/thirdparty/matcha/hifigan/models.py +368 -0
  148. xinference/thirdparty/matcha/hifigan/xutils.py +60 -0
  149. xinference/thirdparty/matcha/models/__init__.py +0 -0
  150. xinference/thirdparty/matcha/models/baselightningmodule.py +210 -0
  151. xinference/thirdparty/matcha/models/components/__init__.py +0 -0
  152. xinference/thirdparty/matcha/models/components/decoder.py +443 -0
  153. xinference/thirdparty/matcha/models/components/flow_matching.py +132 -0
  154. xinference/thirdparty/matcha/models/components/text_encoder.py +410 -0
  155. xinference/thirdparty/matcha/models/components/transformer.py +316 -0
  156. xinference/thirdparty/matcha/models/matcha_tts.py +244 -0
  157. xinference/thirdparty/matcha/onnx/__init__.py +0 -0
  158. xinference/thirdparty/matcha/onnx/export.py +181 -0
  159. xinference/thirdparty/matcha/onnx/infer.py +168 -0
  160. xinference/thirdparty/matcha/text/__init__.py +53 -0
  161. xinference/thirdparty/matcha/text/cleaners.py +121 -0
  162. xinference/thirdparty/matcha/text/numbers.py +71 -0
  163. xinference/thirdparty/matcha/text/symbols.py +17 -0
  164. xinference/thirdparty/matcha/train.py +122 -0
  165. xinference/thirdparty/matcha/utils/__init__.py +5 -0
  166. xinference/thirdparty/matcha/utils/audio.py +82 -0
  167. xinference/thirdparty/matcha/utils/generate_data_statistics.py +112 -0
  168. xinference/thirdparty/matcha/utils/get_durations_from_trained_model.py +195 -0
  169. xinference/thirdparty/matcha/utils/instantiators.py +56 -0
  170. xinference/thirdparty/matcha/utils/logging_utils.py +53 -0
  171. xinference/thirdparty/matcha/utils/model.py +90 -0
  172. xinference/thirdparty/matcha/utils/monotonic_align/__init__.py +22 -0
  173. xinference/thirdparty/matcha/utils/monotonic_align/core.pyx +47 -0
  174. xinference/thirdparty/matcha/utils/monotonic_align/setup.py +7 -0
  175. xinference/thirdparty/matcha/utils/pylogger.py +21 -0
  176. xinference/thirdparty/matcha/utils/rich_utils.py +101 -0
  177. xinference/thirdparty/matcha/utils/utils.py +259 -0
  178. xinference/web/ui/build/asset-manifest.json +3 -3
  179. xinference/web/ui/build/index.html +1 -1
  180. xinference/web/ui/build/static/js/{main.ffc26121.js → main.661c7b0a.js} +3 -3
  181. xinference/web/ui/build/static/js/main.661c7b0a.js.map +1 -0
  182. xinference/web/ui/node_modules/.cache/babel-loader/070d8c6b3b0f3485c6d3885f0b6bbfdf9643e088a468acbd5d596f2396071c16.json +1 -0
  183. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/METADATA +31 -11
  184. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/RECORD +189 -49
  185. xinference/web/ui/build/static/js/main.ffc26121.js.map +0 -1
  186. xinference/web/ui/node_modules/.cache/babel-loader/2f40209b32e7e46a2eab6b8c8a355eb42c3caa8bc3228dd929f32fd2b3940294.json +0 -1
  187. /xinference/web/ui/build/static/js/{main.ffc26121.js.LICENSE.txt → main.661c7b0a.js.LICENSE.txt} +0 -0
  188. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/LICENSE +0 -0
  189. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/WHEEL +0 -0
  190. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/entry_points.txt +0 -0
  191. {xinference-0.14.2.dist-info → xinference-0.14.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,316 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.models.attention import (
6
+ GEGLU,
7
+ GELU,
8
+ AdaLayerNorm,
9
+ AdaLayerNormZero,
10
+ ApproximateGELU,
11
+ )
12
+ from diffusers.models.attention_processor import Attention
13
+ from diffusers.models.lora import LoRACompatibleLinear
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+
16
+
17
class SnakeBeta(nn.Module):
    """
    Snake-style periodic activation with separate learnable frequency (alpha) and magnitude (beta).

    The input is first passed through a linear projection ``in_features -> out_features``
    (``LoRACompatibleLinear``), after which

        y = x + 1 / (beta + eps) * sin(alpha * x) ** 2

    is applied elementwise, with ``alpha`` and ``beta`` broadcast against the projected tensor
    (presumably along its last dimension, since they have length ``out_features``).

    References:
        Liu Ziyin, Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = SnakeBeta(256, 256)
        >>> y = act(torch.randn(8, 256))
    """

    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        """
        Args:
            in_features: input size of the linear projection.
            out_features: output size of the projection; also the shape of the alpha/beta
                parameters (a list is used as the shape as-is, anything else is wrapped in a list).
            alpha: multiplier for the initial value of both alpha and beta. In log-scale mode
                the parameters are ``zeros * alpha`` (i.e. 0, so exp(.) == 1); otherwise both
                start at ``ones * alpha``.
            alpha_trainable: whether alpha and beta receive gradients.
            alpha_logscale: if True, alpha/beta are stored in log space and exponentiated in forward.
        """
        super().__init__()
        # NOTE: despite its name, ``in_features`` stores the *output* shape used for alpha/beta.
        param_shape = out_features if isinstance(out_features, list) else [out_features]
        self.in_features = param_shape
        self.proj = LoRACompatibleLinear(in_features, out_features)

        self.alpha_logscale = alpha_logscale
        # Log-scale parameters start at zero; linear-scale ones start at ``alpha``.
        make_init = torch.zeros if self.alpha_logscale else torch.ones
        self.alpha = nn.Parameter(make_init(param_shape) * alpha)
        self.beta = nn.Parameter(make_init(param_shape) * alpha)

        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

        # Small epsilon guarding the 1 / beta term against division by zero.
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """Project ``x`` then return ``y + 1 / (beta + eps) * sin(y * alpha) ** 2`` elementwise."""
        projected = self.proj(x)
        if self.alpha_logscale:
            alpha, beta = torch.exp(self.alpha), torch.exp(self.beta)
        else:
            alpha, beta = self.alpha, self.beta
        return projected + (1.0 / (beta + self.no_div_by_zero)) * torch.sin(projected * alpha).pow(2)
+
82
+
83
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: input projection + activation, dropout, output projection and an
    optional final dropout, applied in that order.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in
            feed-forward. One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`,
            `"snakebeta"`.
        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: if `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # The activation module doubles as the dim -> inner_dim input projection ("project in").
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        elif activation_fn == "snakebeta":
            act_fn = SnakeBeta(dim, inner_dim)
        else:
            # Fail loudly here instead of with an UnboundLocalError on `act_fn` below.
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate', 'snakebeta'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        """Apply every sub-module of ``self.net`` in order and return the result."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
135
+
136
+
137
@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: self-attention, optional cross-attention and a feed-forward
    network. Each stage is preceded by its own normalization layer and followed by a residual add.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*): The number of diffusion steps used during training. See
            `Transformer2DModel`. Required when `norm_type` is `"ada_norm"` or `"ada_norm_zero"`.
        attention_bias (`bool`, *optional*, defaults to `False`): Configure if the attentions should contain
            a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Whether the first attention layer attends to `encoder_hidden_states` (cross-attention)
            instead of to the block input.
        double_self_attention (`bool`, *optional*, defaults to `False`):
            Whether the second attention layer is built without a cross-attention dim, i.e. as a
            second self-attention layer.
        upcast_attention (`bool`, *optional*, defaults to `False`): Forwarded to diffusers' `Attention`.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether the `nn.LayerNorm`
            layers have learnable affine parameters.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`): `"ada_norm"` / `"ada_norm_zero"` select
            adaptive normalization for the first (and, for `"ada_norm"`, the second) norm layer; any
            other value uses `nn.LayerNorm`.
        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout in the feed-forward.

    Raises:
        ValueError: if `norm_type` is an ada-norm variant but `num_embeds_ada_norm` is None.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout: float = 0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention

        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
        elif self.use_ada_layer_norm_zero:
            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
            # the second cross attention block.
            self.norm2 = (
                AdaLayerNorm(dim, num_embeds_ada_norm)
                if self.use_ada_layer_norm
                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
            )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                # scale_qk=False, # uncomment this to not to use flash attention
            )  # is self-attn if encoder_hidden_states is none
        else:
            self.norm2 = None
            self.attn2 = None

        # 3. Feed-forward
        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
        """Enable (or, with ``chunk_size=None``, disable) chunked feed-forward along axis ``dim``."""
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.LongTensor] = None,
    ):
        """
        Run self-attention, optional cross-attention and the feed-forward, each with a residual add.

        Args:
            hidden_states: block input; returned tensor has the same shape.
            attention_mask: mask for self-attention (used by ``attn1`` unless ``only_cross_attention``).
            encoder_hidden_states: context for cross-attention (``attn2``, or ``attn1`` when
                ``only_cross_attention``).
            encoder_attention_mask: mask applied wherever ``encoder_hidden_states`` is attended to.
            timestep: conditioning passed to the adaptive norm layers (ada_norm / ada_norm_zero).
            cross_attention_kwargs: extra keyword arguments forwarded to both attention layers.
            class_labels: extra conditioning passed to ``AdaLayerNormZero``.

        Raises:
            ValueError: if feed-forward chunking is enabled and the chunked axis is not divisible
                by the chunk size.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 1. Self-Attention
        if self.use_ada_layer_norm:
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.use_ada_layer_norm_zero:
            # AdaLayerNormZero also returns the gates/shift/scale consumed further below.
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        else:
            norm_hidden_states = self.norm1(hidden_states)

        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
            **cross_attention_kwargs,
        )
        if self.use_ada_layer_norm_zero:
            attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = attn_output + hidden_states

        # 2. Cross-Attention
        if self.attn2 is not None:
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )

            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self.use_ada_layer_norm_zero:
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
                raise ValueError(
                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
                )

            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
            ff_output = torch.cat(
                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
                dim=self._chunk_dim,
            )
        else:
            ff_output = self.ff(norm_hidden_states)

        if self.use_ada_layer_norm_zero:
            ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = ff_output + hidden_states

        return hidden_states
@@ -0,0 +1,244 @@
1
+ import datetime as dt
2
+ import math
3
+ import random
4
+
5
+ import torch
6
+
7
+ import matcha.utils.monotonic_align as monotonic_align
8
+ from matcha import utils
9
+ from matcha.models.baselightningmodule import BaseLightningClass
10
+ from matcha.models.components.flow_matching import CFM
11
+ from matcha.models.components.text_encoder import TextEncoder
12
+ from matcha.utils.model import (
13
+ denormalize,
14
+ duration_loss,
15
+ fix_len_compatibility,
16
+ generate_path,
17
+ sequence_mask,
18
+ )
19
+
20
+ log = utils.get_pylogger(__name__)
21
+
22
+
23
+ class MatchaTTS(BaseLightningClass): # 🍵
24
+ def __init__(
25
+ self,
26
+ n_vocab,
27
+ n_spks,
28
+ spk_emb_dim,
29
+ n_feats,
30
+ encoder,
31
+ decoder,
32
+ cfm,
33
+ data_statistics,
34
+ out_size,
35
+ optimizer=None,
36
+ scheduler=None,
37
+ prior_loss=True,
38
+ use_precomputed_durations=False,
39
+ ):
40
+ super().__init__()
41
+
42
+ self.save_hyperparameters(logger=False)
43
+
44
+ self.n_vocab = n_vocab
45
+ self.n_spks = n_spks
46
+ self.spk_emb_dim = spk_emb_dim
47
+ self.n_feats = n_feats
48
+ self.out_size = out_size
49
+ self.prior_loss = prior_loss
50
+ self.use_precomputed_durations = use_precomputed_durations
51
+
52
+ if n_spks > 1:
53
+ self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
54
+
55
+ self.encoder = TextEncoder(
56
+ encoder.encoder_type,
57
+ encoder.encoder_params,
58
+ encoder.duration_predictor_params,
59
+ n_vocab,
60
+ n_spks,
61
+ spk_emb_dim,
62
+ )
63
+
64
+ self.decoder = CFM(
65
+ in_channels=2 * encoder.encoder_params.n_feats,
66
+ out_channel=encoder.encoder_params.n_feats,
67
+ cfm_params=cfm,
68
+ decoder_params=decoder,
69
+ n_spks=n_spks,
70
+ spk_emb_dim=spk_emb_dim,
71
+ )
72
+
73
+ self.update_data_statistics(data_statistics)
74
+
75
+ @torch.inference_mode()
76
+ def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0):
77
+ """
78
+ Generates mel-spectrogram from text. Returns:
79
+ 1. encoder outputs
80
+ 2. decoder outputs
81
+ 3. generated alignment
82
+
83
+ Args:
84
+ x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
85
+ shape: (batch_size, max_text_length)
86
+ x_lengths (torch.Tensor): lengths of texts in batch.
87
+ shape: (batch_size,)
88
+ n_timesteps (int): number of steps to use for reverse diffusion in decoder.
89
+ temperature (float, optional): controls variance of terminal distribution.
90
+ spks (bool, optional): speaker ids.
91
+ shape: (batch_size,)
92
+ length_scale (float, optional): controls speech pace.
93
+ Increase value to slow down generated speech and vice versa.
94
+
95
+ Returns:
96
+ dict: {
97
+ "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
98
+ # Average mel spectrogram generated by the encoder
99
+ "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
100
+ # Refined mel spectrogram improved by the CFM
101
+ "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
102
+ # Alignment map between text and mel spectrogram
103
+ "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
104
+ # Denormalized mel spectrogram
105
+ "mel_lengths": torch.Tensor, shape: (batch_size,),
106
+ # Lengths of mel spectrograms
107
+ "rtf": float,
108
+ # Real-time factor
109
+ """
110
+ # For RTF computation
111
+ t = dt.datetime.now()
112
+
113
+ if self.n_spks > 1:
114
+ # Get speaker embedding
115
+ spks = self.spk_emb(spks.long())
116
+
117
+ # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
118
+ mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
119
+
120
+ w = torch.exp(logw) * x_mask
121
+ w_ceil = torch.ceil(w) * length_scale
122
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
123
+ y_max_length = y_lengths.max()
124
+ y_max_length_ = fix_len_compatibility(y_max_length)
125
+
126
+ # Using obtained durations `w` construct alignment map `attn`
127
+ y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
128
+ attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
129
+ attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
130
+
131
+ # Align encoded text and get mu_y
132
+ mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
133
+ mu_y = mu_y.transpose(1, 2)
134
+ encoder_outputs = mu_y[:, :, :y_max_length]
135
+
136
+ # Generate sample tracing the probability flow
137
+ decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks)
138
+ decoder_outputs = decoder_outputs[:, :, :y_max_length]
139
+
140
+ t = (dt.datetime.now() - t).total_seconds()
141
+ rtf = t * 22050 / (decoder_outputs.shape[-1] * 256)
142
+
143
+ return {
144
+ "encoder_outputs": encoder_outputs,
145
+ "decoder_outputs": decoder_outputs,
146
+ "attn": attn[:, :, :y_max_length],
147
+ "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std),
148
+ "mel_lengths": y_lengths,
149
+ "rtf": rtf,
150
+ }
151
+
152
def forward(self, x, x_lengths, y, y_lengths, spks=None, out_size=None, cond=None, durations=None):
    """
    Compute the three training losses.

    1. duration loss: loss between predicted log-scaled token durations and the durations
       implied by the alignment `attn` (from Monotonic Alignment Search (MAS), or from
       `durations` when `self.use_precomputed_durations` is set).
    2. prior loss: loss between mel-spectrogram and encoder outputs (only when
       `self.prior_loss` is truthy; otherwise 0).
    3. flow matching loss: loss between mel-spectrogram and decoder outputs.

    Args:
        x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
            shape: (batch_size, max_text_length)
        x_lengths (torch.Tensor): lengths of texts in batch.
            shape: (batch_size,)
        y (torch.Tensor): batch of corresponding mel-spectrograms.
            shape: (batch_size, n_feats, max_mel_length)
        y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
            shape: (batch_size,)
        spks (torch.Tensor, optional): speaker ids. Converted with `self.spk_emb` only when
            `self.n_spks > 1`; otherwise passed through unchanged.
            shape: (batch_size,)
        out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which
            decoder will be trained. Should be divisible by 2^{num of UNet downsamplings}.
            Needed to increase batch size.
        cond (optional): extra conditioning, forwarded unchanged to
            `self.decoder.compute_loss`.
        durations (torch.Tensor, optional): precomputed token durations; only read when
            `self.use_precomputed_durations` is set.

    Returns:
        tuple: `(dur_loss, prior_loss, diff_loss, attn)`, where `attn` is the alignment map
        (cut to `out_size` frames when `out_size` is given).
    """
    if self.n_spks > 1:
        # Get speaker embedding
        spks = self.spk_emb(spks)

    # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
    mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
    y_max_length = y.shape[-1]

    y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
    attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

    if self.use_precomputed_durations:
        attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1))
    else:
        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram.
        # Computed under no_grad and detached, so no gradient flows through the alignment.
        with torch.no_grad():
            const = -0.5 * math.log(2 * math.pi) * self.n_feats
            factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            y_square = torch.matmul(factor.transpose(1, 2), y**2)
            y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            log_prior = y_square - y_mu_double + mu_square + const

            attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
            attn = attn.detach()  # b, t_text, T_mel

    # Duration loss: compare predicted log-scaled durations with the per-token frame
    # counts implied by `attn` (summed over the mel axis); 1e-8 guards against log(0).
    logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
    dur_loss = duration_loss(logw, logw_, x_lengths)

    # Cut a small segment of mel-spectrogram in order to increase batch size
    #   - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it
    #   - Do not need this hack for Matcha-TTS, but it works with it as well
    if out_size is not None:
        # Valid start offsets are 0..max_offset INCLUSIVE (offset == max_offset keeps the
        # final frames); randint is inclusive and yields 0 when max_offset == 0.
        max_offset = (y_lengths - out_size).clamp(0)
        out_offset = [random.randint(0, int(m)) for m in max_offset.tolist()]
        attn_cut = torch.zeros(attn.shape[0], attn.shape[1], out_size, dtype=attn.dtype, device=attn.device)
        y_cut = torch.zeros(y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device)

        y_cut_lengths = []
        for i, (y_, out_offset_) in enumerate(zip(y, out_offset)):
            # min(out_size, y_lengths[i]): shorter utterances are copied whole and the
            # remainder of the out_size window stays zero.
            y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0)
            y_cut_lengths.append(y_cut_length)
            cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length
            y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper]
            attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper]

        y_cut_lengths = torch.stack(y_cut_lengths)
        # Pass out_size explicitly: without it the mask width is max(y_cut_lengths), which
        # is smaller than y_cut/attn_cut whenever every utterance is shorter than out_size.
        y_cut_mask = sequence_mask(y_cut_lengths, out_size).unsqueeze(1).to(y_mask)

        attn = attn_cut
        y = y_cut
        y_mask = y_cut_mask

    # Align encoded text with mel-spectrogram and get mu_y segment
    mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
    mu_y = mu_y.transpose(1, 2)

    # Compute loss of the decoder
    diff_loss, _ = self.decoder.compute_loss(x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond)

    if self.prior_loss:
        # Unit-variance Gaussian negative log-likelihood of y around mu_y, averaged over
        # the unmasked mel frames and feature channels.
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
    else:
        prior_loss = 0

    return dur_loss, prior_loss, diff_loss, attn
File without changes