xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of xinference has been flagged as potentially problematic.

Files changed (328):
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +473 -31
  3. xinference/client/restful/async_restful_client.py +178 -8
  4. xinference/client/restful/restful_client.py +151 -3
  5. xinference/core/supervisor.py +99 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +58 -21
  11. xinference/model/image/model_spec.json +159 -90
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +6 -2
  14. xinference/model/llm/llm_family.json +1299 -174
  15. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  16. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  17. xinference/model/llm/sglang/core.py +44 -11
  18. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  19. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  20. xinference/model/llm/transformers/chatglm.py +3 -0
  21. xinference/model/llm/transformers/core.py +129 -36
  22. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  23. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  24. xinference/model/llm/transformers/utils.py +23 -0
  25. xinference/model/llm/utils.py +48 -32
  26. xinference/model/llm/vllm/core.py +207 -72
  27. xinference/model/utils.py +74 -31
  28. xinference/thirdparty/audiotools/__init__.py +10 -0
  29. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  30. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  31. xinference/thirdparty/audiotools/core/display.py +194 -0
  32. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  33. xinference/thirdparty/audiotools/core/effects.py +647 -0
  34. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  35. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  36. xinference/thirdparty/audiotools/core/playback.py +252 -0
  37. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  38. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  39. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  40. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  41. xinference/thirdparty/audiotools/core/util.py +671 -0
  42. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  43. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  44. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  45. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  46. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  47. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  48. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  49. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  50. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  51. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  52. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  53. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  54. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  55. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  56. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  57. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  58. xinference/thirdparty/audiotools/post.py +140 -0
  59. xinference/thirdparty/audiotools/preference.py +600 -0
  60. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  61. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  62. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  63. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  81. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  82. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  83. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  84. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  85. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  86. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  87. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  88. xinference/thirdparty/indextts/__init__.py +0 -0
  89. xinference/thirdparty/indextts/cli.py +65 -0
  90. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  91. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  92. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  93. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  94. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  95. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  96. xinference/thirdparty/indextts/gpt/model.py +713 -0
  97. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  98. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  99. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  100. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  101. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  102. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  103. xinference/thirdparty/indextts/infer.py +690 -0
  104. xinference/thirdparty/indextts/infer_v2.py +739 -0
  105. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  106. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  107. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  108. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  109. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  110. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  111. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  112. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  113. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  114. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  115. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  116. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  117. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  118. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  119. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  120. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  121. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  123. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  124. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  142. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  143. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  144. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  145. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  146. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  147. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  148. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  149. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  150. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  151. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  152. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  153. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  154. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  155. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  159. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  160. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  161. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  162. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  168. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  169. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  170. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  171. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  172. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  173. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  174. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  175. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  176. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  178. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  179. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  180. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  181. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  182. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  183. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  184. xinference/thirdparty/indextts/utils/common.py +121 -0
  185. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  186. xinference/thirdparty/indextts/utils/front.py +536 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  249. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  250. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  251. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  252. xinference/thirdparty/indextts/utils/utils.py +93 -0
  253. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  254. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  255. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  256. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  257. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  258. xinference/types.py +9 -0
  259. xinference/ui/gradio/media_interface.py +66 -8
  260. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  261. xinference/ui/web/ui/build/index.html +1 -1
  262. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  263. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  264. xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
  265. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
  266. xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  284. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  285. xinference/ui/web/ui/package-lock.json +0 -34
  286. xinference/ui/web/ui/package.json +0 -1
  287. xinference/ui/web/ui/src/locales/en.json +9 -3
  288. xinference/ui/web/ui/src/locales/ja.json +9 -3
  289. xinference/ui/web/ui/src/locales/ko.json +9 -3
  290. xinference/ui/web/ui/src/locales/zh.json +9 -3
  291. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
  292. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
  293. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  294. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  295. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  296. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  302. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  313. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  314. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  315. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  316. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  317. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  318. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  319. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  320. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  321. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  322. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  323. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  324. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  325. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
  326. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
  327. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
  328. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
The diff body below is the new file xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py (entry 247 in the list above), added in its entirety:

@@ -0,0 +1,650 @@
+ # Copyright (c) 2024 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ import os
+ import torch.nn as nn
+ from typing import List, Optional, Tuple, Union
+ import math
+
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+ from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
+
+
+ # sinusoidal positional encoding
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+         emb = x[:, None] * emb[None, :] * 1.0
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+
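Not part of the diff: a quick illustrative sketch of the class above. With `SinusoidalPosEmb` in scope, a batch of integer diffusion timesteps maps to a fixed sin/cos embedding.

    import torch

    emb = SinusoidalPosEmb(dim=1024)
    steps = torch.tensor([0, 250, 999])  # (B,) diffusion timesteps
    out = emb(steps)                     # (B, 1024): sin half, then cos half
    print(out.shape)                     # torch.Size([3, 1024])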
+ class LlamaAdaptiveRMSNorm(nn.Module):
+     def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
+         super().__init__()
+         self.to_weight = nn.Linear(dim_cond, hidden_size)
+         nn.init.zeros_(self.to_weight.weight)
+         nn.init.ones_(self.to_weight.bias)
+         self.variance_epsilon = eps
+         self._is_hf_initialized = True # disable automatic init
+
+     def forward(self, hidden_states, cond_embedding):
+         input_dtype = hidden_states.dtype
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+         weight = self.to_weight(cond_embedding)
+         if len(weight.shape) == 2:
+             weight = weight.unsqueeze(1)
+
+         return (weight * hidden_states).to(input_dtype)
+
+
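Not part of the diff: `LlamaAdaptiveRMSNorm` replaces RMSNorm's fixed learned gain with one predicted from a conditioning vector; because `to_weight` is initialized with zero weights and ones bias, it starts out behaving like a plain RMSNorm. A minimal sketch, assuming the class above is in scope:

    import torch

    norm = LlamaAdaptiveRMSNorm(hidden_size=1024, dim_cond=1024)
    h = torch.randn(2, 10, 1024)   # (B, T, C) hidden states
    c = torch.randn(2, 1024)       # (B, C) condition, broadcast across T
    y = norm(h, cond_embedding=c)  # (B, T, C), same dtype as h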
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
+     def __init__(self, config: LlamaConfig, layer_idx: int):
+         """Override to adaptive layer norm"""
+         super().__init__(config, layer_idx) # init attention, mlp, etc.
+         self.input_layernorm = LlamaAdaptiveRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         )
+         self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         )
+
+     # add `cond` in forward function
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cond_embedding: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(
+             hidden_states, cond_embedding=cond_embedding
+         )
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(
+             hidden_states, cond_embedding=cond_embedding
+         )
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+     def __init__(self, config: LlamaConfig, layer_idx: int):
+         """Override to adaptive layer norm"""
+         super().__init__(config, layer_idx) # init attention, mlp, etc.
+         self.layer_idx = layer_idx
+         self.input_layernorm = LlamaAdaptiveRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         )
+         self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cond_embedding: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+     ) -> Tuple[
+         torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         """
+         Args:
+             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             use_cache (`bool`, *optional*):
+                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                 (see `past_key_values`).
+             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+         """
+
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(
+             hidden_states, cond_embedding=cond_embedding
+         )
+
+         # Self Attention
+         hidden_states, self_attn_weights, present_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             use_cache=use_cache,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(
+             hidden_states, cond_embedding=cond_embedding
+         )
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weights,)
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
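Note that `LlamaNARDecoderLayer` defines `__init__` and `forward` twice; Python keeps the later definitions, so the second pair (which additionally stores `self.layer_idx`) is the one in effect.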
+ class DiffLlama(LlamaModel):
+     def __init__(
+         self,
+         hidden_size=1024,
+         num_heads=16,
+         num_layers=16,
+         config=LlamaConfig(0, 256, 1024, 1, 1),
+     ):
+         super().__init__(config)
+
+         self.layers = nn.ModuleList(
+             [
+                 LlamaNARDecoderLayer(
+                     LlamaConfig(
+                         hidden_size=hidden_size,
+                         num_attention_heads=num_heads,
+                         max_position_embeddings=4096,
+                         intermediate_size=hidden_size * 4,
+                     ),
+                     layer_idx=i,
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+         self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+         self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+         self.diff_step_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         # self.position_embedding = PositionalEncoding(hidden_size, dropout=0.0)
+
+         self.cond_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         for layer in self.layers:
+             layer.input_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+             layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+
+         self.post_init()
+
+         # self.reset_parameters()
+
+     def _prepare_decoder_attention_mask(
+         self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+     ):
+         # create noncausal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+
+         def _expand_mask(
+             mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+         ):
+             """
+             Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+             """
+             bsz, src_len = mask.size()
+             tgt_len = tgt_len if tgt_len is not None else src_len
+
+             expanded_mask = (
+                 mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+             )
+
+             inverted_mask = 1.0 - expanded_mask
+
+             return inverted_mask.masked_fill(
+                 inverted_mask.to(torch.bool), torch.finfo(dtype).min
+             )
+
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             expanded_attn_mask = _expand_mask(
+                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+             ).to(inputs_embeds.device)
+             combined_attention_mask = (
+                 expanded_attn_mask
+                 if combined_attention_mask is None
+                 else expanded_attn_mask + combined_attention_mask
+             )
+
+         return combined_attention_mask
+
+     def forward(
+         self,
+         x,
+         diffusion_step,
+         cond,
+         x_mask,
+         input_ids: torch.LongTensor = None, # [num_quant, B, T]
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         # retrieve some shape info
+         batch_size, seq_length, _ = x.shape
+
+         # condtion mlp
+         cond_embedding = self.cond_mlp(cond) # (B, T, C)
+
+         # diffusion step embedding
+         diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+         diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
+         x = x + cond_embedding
+
+         inputs_embeds = x
+         attention_mask = x_mask
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+
+         if past_key_values is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length,
+                 seq_length + past_key_values_length,
+                 dtype=torch.long,
+                 device=device,
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         # embed positions
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 (batch_size, seq_length_with_past),
+                 dtype=torch.bool,
+                 device=inputs_embeds.device,
+             )
+         attention_mask = self._prepare_decoder_attention_mask(
+             attention_mask,
+             (batch_size, seq_length),
+             inputs_embeds,
+             past_key_values_length,
+         )
+
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             past_key_value = (
+                 past_key_values[idx] if past_key_values is not None else None
+             )
+
+             if self.gradient_checkpointing and self.training:
+                 raise NotImplementedError
+
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cond_embedding=diffusion_step,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         return hidden_states
+
+
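Not part of the diff: `DiffLlama.forward` ignores `return_dict` and the collected caches and always returns the final hidden states. A hedged usage sketch, assuming the classes above are in scope and a transformers version whose Llama decoder layers still accept the tuple-style arguments this vendored code expects; shapes are illustrative:

    import torch

    model = DiffLlama(hidden_size=1024, num_heads=16, num_layers=16).eval()
    B, T, C = 2, 100, 1024
    x = torch.randn(B, T, C)          # noisy latents to denoise
    cond = torch.randn(B, T, C)       # frame-aligned condition
    t = torch.randint(0, 1000, (B,))  # diffusion steps
    mask = torch.ones(B, T)           # 1 = valid frame, 0 = padding
    with torch.no_grad():
        h = model(x=x, diffusion_step=t, cond=cond, x_mask=mask)  # (B, T, C)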
+ class DiffLlamaPrefix(LlamaModel):
+     def __init__(
+         self,
+         hidden_size=1024,
+         num_heads=16,
+         num_layers=16,
+         config=LlamaConfig(0, 256, 1024, 1, 1),
+     ):
+         super().__init__(config)
+
+         self.layers = nn.ModuleList(
+             [
+                 LlamaNARDecoderLayer(
+                     LlamaConfig(
+                         hidden_size=hidden_size,
+                         num_attention_heads=num_heads,
+                         max_position_embeddings=4096,
+                         intermediate_size=hidden_size * 4,
+                     ),
+                     layer_idx=i,
+                 )
+                 for i in range(num_layers)
+             ]
+         )
+
+         self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)
+
+         self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
+         self.diff_step_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         self.cond_mlp = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size * 4),
+             nn.SiLU(),
+             nn.Linear(hidden_size * 4, hidden_size),
+         )
+
+         for layer in self.layers:
+             layer.input_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+             layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
+                 hidden_size, dim_cond=hidden_size
+             )
+
+         self.embed_tokens = None
+
+         self.post_init()
+
+     def _prepare_decoder_attention_mask(
+         self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+     ):
+         # create noncausal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+
+         def _expand_mask(
+             mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
+         ):
+             """
+             Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+             """
+             bsz, src_len = mask.size()
+             tgt_len = tgt_len if tgt_len is not None else src_len
+
+             expanded_mask = (
+                 mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+             )
+
+             inverted_mask = 1.0 - expanded_mask
+
+             return inverted_mask.masked_fill(
+                 inverted_mask.to(torch.bool), torch.finfo(dtype).min
+             )
+
+         if attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             expanded_attn_mask = _expand_mask(
+                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+             ).to(inputs_embeds.device)
+             combined_attention_mask = (
+                 expanded_attn_mask
+                 if combined_attention_mask is None
+                 else expanded_attn_mask + combined_attention_mask
+             )
+
+         return combined_attention_mask
+
+     def forward(
+         self,
+         x,
+         diffusion_step,
+         x_mask,
+         phone_embedding: Optional[torch.LongTensor] = None,
+         phone_mask: Optional[torch.FloatTensor] = None,
+         input_ids: torch.LongTensor = None, # [num_quant, B, T]
+         attention_mask: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         # retrieve some shape info
+
+         phone_embedding = self.cond_mlp(phone_embedding) # (B, T, C)
+         phone_length = phone_embedding.shape[1]
+         inputs_embeds = torch.cat([phone_embedding, x], dim=1)
+         attention_mask = torch.cat([phone_mask, x_mask], dim=1)
+
+         # diffusion step embedding
+         diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
+         diffusion_step = self.diff_step_mlp(diffusion_step) # (B, C)
+
+         batch_size, seq_length, _ = inputs_embeds.shape
+
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         output_hidden_states = (
+             output_hidden_states
+             if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         seq_length_with_past = seq_length
+         past_key_values_length = 0
+
+         if past_key_values is not None:
+             past_key_values_length = past_key_values[0][0].shape[2]
+             seq_length_with_past = seq_length_with_past + past_key_values_length
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length,
+                 seq_length + past_key_values_length,
+                 dtype=torch.long,
+                 device=device,
+             )
+             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+         else:
+             position_ids = position_ids.view(-1, seq_length).long()
+
+         # embed positions
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 (batch_size, seq_length_with_past),
+                 dtype=torch.bool,
+                 device=inputs_embeds.device,
+             )
+         attention_mask = self._prepare_decoder_attention_mask(
+             attention_mask,
+             (batch_size, seq_length),
+             inputs_embeds,
+             past_key_values_length,
+         )
+
+         hidden_states = inputs_embeds
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             past_key_value = (
+                 past_key_values[idx] if past_key_values is not None else None
+             )
+
+             if self.gradient_checkpointing and self.training:
+                 raise NotImplementedError
+
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     attention_mask=attention_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_value,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cond_embedding=diffusion_step,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         return hidden_states[
+             :,
+             phone_length:,
+         ]
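Not part of the diff: `DiffLlamaPrefix` conditions by prepending projected phone embeddings as a prefix to the sequence and slices them off the output, so only the acoustic positions come back. A sketch under the same assumptions as the `DiffLlama` example above:

    import torch

    model = DiffLlamaPrefix(hidden_size=1024, num_heads=16, num_layers=16).eval()
    B, Tp, Tx, C = 2, 30, 100, 1024
    phones = torch.randn(B, Tp, C)    # phone-level condition (prefix)
    x = torch.randn(B, Tx, C)         # noisy acoustic latents
    t = torch.randint(0, 1000, (B,))  # diffusion steps
    with torch.no_grad():
        out = model(
            x=x,
            diffusion_step=t,
            x_mask=torch.ones(B, Tx),
            phone_embedding=phones,
            phone_mask=torch.ones(B, Tp),
        )
    # out: (B, Tx, C) -- the phone prefix is dropped from the output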