xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of xinference might be problematic; see the registry's advisory for this release for more details.
- xinference/_version.py +3 -3
- xinference/api/restful_api.py +473 -31
- xinference/client/restful/async_restful_client.py +178 -8
- xinference/client/restful/restful_client.py +151 -3
- xinference/core/supervisor.py +99 -53
- xinference/core/worker.py +10 -0
- xinference/deploy/cmdline.py +15 -0
- xinference/model/audio/core.py +21 -6
- xinference/model/audio/indextts2.py +166 -0
- xinference/model/audio/model_spec.json +58 -21
- xinference/model/image/model_spec.json +159 -90
- xinference/model/image/stable_diffusion/core.py +13 -4
- xinference/model/llm/__init__.py +6 -2
- xinference/model/llm/llm_family.json +1299 -174
- xinference/model/llm/mlx/distributed_models/core.py +41 -0
- xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
- xinference/model/llm/sglang/core.py +44 -11
- xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
- xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
- xinference/model/llm/transformers/chatglm.py +3 -0
- xinference/model/llm/transformers/core.py +129 -36
- xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
- xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
- xinference/model/llm/transformers/utils.py +23 -0
- xinference/model/llm/utils.py +48 -32
- xinference/model/llm/vllm/core.py +207 -72
- xinference/model/utils.py +74 -31
- xinference/thirdparty/audiotools/__init__.py +10 -0
- xinference/thirdparty/audiotools/core/__init__.py +4 -0
- xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
- xinference/thirdparty/audiotools/core/display.py +194 -0
- xinference/thirdparty/audiotools/core/dsp.py +390 -0
- xinference/thirdparty/audiotools/core/effects.py +647 -0
- xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
- xinference/thirdparty/audiotools/core/loudness.py +320 -0
- xinference/thirdparty/audiotools/core/playback.py +252 -0
- xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
- xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
- xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
- xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
- xinference/thirdparty/audiotools/core/util.py +671 -0
- xinference/thirdparty/audiotools/core/whisper.py +97 -0
- xinference/thirdparty/audiotools/data/__init__.py +3 -0
- xinference/thirdparty/audiotools/data/datasets.py +517 -0
- xinference/thirdparty/audiotools/data/preprocess.py +81 -0
- xinference/thirdparty/audiotools/data/transforms.py +1592 -0
- xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
- xinference/thirdparty/audiotools/metrics/distance.py +131 -0
- xinference/thirdparty/audiotools/metrics/quality.py +159 -0
- xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
- xinference/thirdparty/audiotools/ml/__init__.py +5 -0
- xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
- xinference/thirdparty/audiotools/ml/decorators.py +440 -0
- xinference/thirdparty/audiotools/ml/experiment.py +90 -0
- xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
- xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
- xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
- xinference/thirdparty/audiotools/post.py +140 -0
- xinference/thirdparty/audiotools/preference.py +600 -0
- xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
- xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
- xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
- xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
- xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
- xinference/thirdparty/indextts/__init__.py +0 -0
- xinference/thirdparty/indextts/cli.py +65 -0
- xinference/thirdparty/indextts/gpt/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
- xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
- xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
- xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
- xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
- xinference/thirdparty/indextts/gpt/model.py +713 -0
- xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
- xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
- xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
- xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
- xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
- xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
- xinference/thirdparty/indextts/infer.py +690 -0
- xinference/thirdparty/indextts/infer_v2.py +739 -0
- xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
- xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
- xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
- xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
- xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
- xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
- xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
- xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
- xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
- xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
- xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
- xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
- xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
- xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
- xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
- xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
- xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
- xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
- xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
- xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
- xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
- xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
- xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
- xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
- xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
- xinference/thirdparty/indextts/utils/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/arch_util.py +120 -0
- xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
- xinference/thirdparty/indextts/utils/common.py +121 -0
- xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
- xinference/thirdparty/indextts/utils/front.py +536 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
- xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
- xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
- xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
- xinference/thirdparty/indextts/utils/text_utils.py +41 -0
- xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
- xinference/thirdparty/indextts/utils/utils.py +93 -0
- xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
- xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
- xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
- xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
- xinference/thirdparty/melo/text/chinese_mix.py +2 -2
- xinference/types.py +9 -0
- xinference/ui/gradio/media_interface.py +66 -8
- xinference/ui/web/ui/build/asset-manifest.json +6 -6
- xinference/ui/web/ui/build/index.html +1 -1
- xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
- xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
- xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
- xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
- xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
- xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
- xinference/ui/web/ui/package-lock.json +0 -34
- xinference/ui/web/ui/package.json +0 -1
- xinference/ui/web/ui/src/locales/en.json +9 -3
- xinference/ui/web/ui/src/locales/ja.json +9 -3
- xinference/ui/web/ui/src/locales/ko.json +9 -3
- xinference/ui/web/ui/src/locales/zh.json +9 -3
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
- xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
- xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
- xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
- xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
- xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
- xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
- xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
- xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
- xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
- xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
- xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
- xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
- xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
- xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
- xinference/ui/web/ui/node_modules/select/bower.json +0 -13
- xinference/ui/web/ui/node_modules/select/package.json +0 -29
- xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
- {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,713 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
|
|
7
|
+
import transformers
|
|
8
|
+
from transformers import GPT2Config, LogitsProcessorList
|
|
9
|
+
from indextts.gpt.transformers_gpt2 import GPT2PreTrainedModel, GPT2Model
|
|
10
|
+
|
|
11
|
+
# from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList
|
|
12
|
+
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|
13
|
+
from transformers.utils.model_parallel_utils import (assert_device_map,
|
|
14
|
+
get_device_map)
|
|
15
|
+
|
|
16
|
+
from indextts.gpt.conformer_encoder import ConformerEncoder
|
|
17
|
+
from indextts.gpt.perceiver import PerceiverResampler
|
|
18
|
+
from indextts.utils.arch_util import AttentionBlock
|
|
19
|
+
from indextts.utils.typical_sampling import TypicalLogitsWarper
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def null_position_embeddings(range, dim):
    """Return an all-zero embedding of shape ``(range.shape[0], range.shape[1], dim)``.

    The result is created on ``range.device`` with torch's default dtype. Only the
    first two dimensions of ``range`` are read; its values are ignored.

    NOTE(review): presumably installed in place of GPT-2's learned position table
    so that positions contribute nothing — confirm at the call site. The parameter
    name ``range`` shadows the builtin but is kept for caller compatibility.
    """
    batch_size, seq_len = range.shape[0], range.shape[1]
    return torch.zeros(batch_size, seq_len, dim, device=range.device)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ResBlock(nn.Module):
    """
    Basic residual convolutional block that uses GroupNorm.

    Layout: ``Conv1d -> GroupNorm -> ReLU -> Conv1d -> GroupNorm``, all with
    ``chan`` channels (kernel 3, padding 1, so the length axis is preserved).
    The input is added back and a final ReLU is applied, so the output has the
    same shape as the input (assumes ``x`` is in ``nn.Conv1d`` layout,
    ``(batch, chan, length)`` — TODO confirm against callers).

    GroupNorm uses ``chan // 8`` groups whenever that value is positive and
    divides ``chan``, which is exactly the original configuration. Other channel
    counts previously failed at construction; ``chan < 8`` built
    ``GroupNorm(0, chan)`` and raised ``ZeroDivisionError``. For those, the
    block now falls back to the largest divisor of ``chan`` not exceeding
    ``max(1, chan // 8)``. The submodule indices inside ``self.net`` (and
    therefore ``state_dict`` keys) are unchanged.
    """

    def __init__(self, chan):
        super().__init__()
        groups = self._num_groups(chan)
        self.net = nn.Sequential(
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(groups, chan),
            nn.ReLU(),
            nn.Conv1d(chan, chan, kernel_size=3, padding=1),
            nn.GroupNorm(groups, chan),
        )

    @staticmethod
    def _num_groups(chan):
        """Return a GroupNorm group count that divides ``chan`` (``chan // 8`` when valid)."""
        target = max(1, chan // 8)
        # Walk down from the historical default; 1 always divides a positive chan.
        for groups in range(target, 0, -1):
            if chan % groups == 0:
                return groups
        return 1

    def forward(self, x):
        # Residual connection followed by a final ReLU.
        return F.relu(self.net(x) + x)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GPT2InferenceModel(GPT2PreTrainedModel):
|
|
46
|
+
def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
    """Wrap already-built modules so they can be driven by HF-style generation.

    Args:
        config: Model config passed to ``GPT2PreTrainedModel.__init__``.
        gpt: Transformer stack; stored as ``self.transformer`` and called with
            ``inputs_embeds`` in ``forward``.
        text_pos_emb: Position-embedding module whose output is added to the
            token embeddings. Despite its name, this is the *mel* position
            embedding.
        embeddings: Token-embedding module applied to ``input_ids``.
        norm: Final normalisation layer. It is kept as ``final_norm`` and is
            also the first stage of ``lm_head`` (the same module instance).
        linear: Output projection; the second stage of ``lm_head``.
        kv_cache: When False, ``prepare_inputs_for_generation`` drops
            ``past_key_values`` so each step re-feeds the whole sequence.
    """
    super().__init__(config)

    # Attributes are registered in the original order so module/parameter
    # iteration order (and state_dict key order) is unchanged.
    self.transformer = gpt
    self.text_pos_embedding = text_pos_emb
    self.embeddings = embeddings
    self.final_norm = norm
    head = nn.Sequential(norm, linear)
    self.lm_head = head
    self.kv_cache = kv_cache

    # Model-parallel bookkeeping; populated by ``parallelize``.
    self.model_parallel = False
    self.device_map = None
    # Prefix embedding read by ``forward`` (which asserts it is not None);
    # set through ``store_mel_emb``.
    self.cached_mel_emb = None
|
|
60
|
+
|
|
61
|
+
def parallelize(self, device_map=None):
    """Distribute the transformer blocks over devices (legacy HF model-parallel API).

    Args:
        device_map: Layer placement passed to ``assert_device_map`` and
            ``self.transformer.parallelize``. When ``None``, one is built with
            ``get_device_map`` over ``range(max(1, torch.cuda.device_count()))``.
            Presumably ``{device: [layer indices]}`` as in HF's
            ``model_parallel_utils`` — confirm.

    Side effects: stores the map on ``self.device_map``, moves ``lm_head`` to
    ``self.transformer.first_device`` and sets ``self.model_parallel = True``.
    """
    layer_count = len(self.transformer.h)
    if device_map is None:
        visible_devices = range(max(1, torch.cuda.device_count()))
        device_map = get_device_map(layer_count, visible_devices)
    self.device_map = device_map

    assert_device_map(self.device_map, layer_count)
    self.transformer.parallelize(self.device_map)
    # Keep the output head on the transformer's first device.
    self.lm_head = self.lm_head.to(self.transformer.first_device)
    self.model_parallel = True
|
|
71
|
+
|
|
72
|
+
def deparallelize(self):
    """Undo ``parallelize``: move ``transformer`` and ``lm_head`` back to the CPU.

    Also resets ``model_parallel`` and empties the CUDA allocator cache, and the
    MPS cache when MPS is available.

    NOTE(review): ``embeddings``, ``text_pos_embedding`` and ``device_map`` are
    not touched here — confirm that is intended.
    """
    target = "cpu"
    self.transformer.deparallelize()
    self.transformer = self.transformer.to(target)
    self.lm_head = self.lm_head.to(target)
    self.model_parallel = False

    # Release cached allocator memory on the accelerators.
    torch.cuda.empty_cache()
    mps_available = torch.backends.mps.is_available()
    if mps_available:
        torch.mps.empty_cache()
|
|
80
|
+
|
|
81
|
+
def get_output_embeddings(self):
    """Return the output head (``self.lm_head``), as the HF model API expects."""
    head = self.lm_head
    return head
|
|
83
|
+
|
|
84
|
+
def set_output_embeddings(self, new_embeddings):
    """Replace the output head; ``new_embeddings`` is stored as ``lm_head`` unchanged."""
    setattr(self, "lm_head", new_embeddings)
|
|
86
|
+
|
|
87
|
+
def store_mel_emb(self, mel_emb):
    """Cache the prefix embedding used by ``forward``.

    ``forward`` asserts this is set, uses its second dimension as the prefix
    length, and concatenates it in front of the token embeddings on calls
    where ``input_ids`` has more than one column.
    """
    setattr(self, "cached_mel_emb", mel_emb)
|
|
89
|
+
|
|
90
|
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    """Assemble the ``forward`` keyword arguments for one generation step.

    Args:
        input_ids: Token ids produced so far (indexed as ``[:, -1]`` below, so
            at least 2-D).
        past_key_values: Cached attention state from the previous step. It is
            replaced with ``None`` when ``self.kv_cache`` is False.
        **kwargs: Optional ``token_type_ids``, ``attention_mask``,
            ``position_ids`` and ``use_cache``.

    Returns:
        dict with keys ``input_ids``, ``past_key_values``, ``use_cache``,
        ``position_ids``, ``attention_mask`` and ``token_type_ids``.

    When the (possibly cleared) cache is truthy, only the newest column of
    ``input_ids``, ``token_type_ids`` and the derived ``position_ids`` is kept.

    NOTE(review): a caller-supplied ``position_ids`` is discarded (set to
    ``None``). Positions are only derived from ``attention_mask``, and only
    when no ``position_ids`` was passed.
    """

    def _newest(t):
        # Keep only the final step while preserving a trailing length-1 axis.
        return t[:, -1].unsqueeze(-1)

    token_type_ids = kwargs.get("token_type_ids", None)  # usually None
    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    # Without KV caching every step re-feeds the full sequence.
    if not self.kv_cache:
        past_key_values = None
    has_past = bool(past_key_values)

    if has_past:
        input_ids = _newest(input_ids)
        if token_type_ids is not None:
            token_type_ids = _newest(token_type_ids)

    if attention_mask is None or position_ids is not None:
        position_ids = None
    else:
        # Create position_ids on the fly for batch generation: running count of
        # attended tokens minus one, with masked slots pinned to 0.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)
        if has_past:
            position_ids = _newest(position_ids)

    return dict(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=kwargs.get("use_cache"),
        position_ids=position_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
|
|
119
|
+
|
|
120
|
+
def forward(
    self,
    input_ids=None,
    past_key_values=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """Inference-only forward step that splices the cached prefix embedding into the sequence.

    ``self.cached_mel_emb`` (set by ``store_mel_emb``) supplies the embeddings for
    the first ``mel_len = self.cached_mel_emb.shape[1]`` positions; on a
    multi-token call the ids in those columns of ``input_ids`` are never looked up.

    * Multi-token call (``input_ids.shape[1] != 1``): the ids after the first
      ``mel_len`` columns are embedded with ``self.embeddings`` plus
      ``self.text_pos_embedding(text_emb)`` and concatenated after the prefix.
      When the batch sizes differ, the prefix is expanded with
      ``repeat_interleave`` by ``text_emb.shape[0] // cached_mel_emb.shape[0]``.
    * Single-token call: only the new id is embedded, plus
      ``self.text_pos_embedding.get_fixed_embedding(attention_mask.shape[1] - mel_len)``;
      ``attention_mask`` must therefore not be None on this path.

    ``inputs_embeds`` and ``labels`` must be None (asserted). The remaining
    arguments are forwarded to ``self.transformer``; ``return_dict`` falls back
    to ``self.config.use_return_dict``.

    Returns:
        ``(lm_logits,) + transformer_outputs[1:]`` when ``return_dict`` is falsy,
        otherwise a ``CausalLMOutputWithCrossAttentions`` with ``loss=None`` and the
        transformer's cache, hidden states and attentions.
    """
    assert self.cached_mel_emb is not None
    assert inputs_embeds is None  # Not supported by this inference model.
    assert labels is None  # Training not supported by this inference model.
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # Create embedding
    mel_len = self.cached_mel_emb.shape[1]
    if input_ids.shape[1] != 1:
        # Prompt step: embed only the ids that come after the cached prefix.
        text_inputs = input_ids[:, mel_len:]
        text_emb = self.embeddings(text_inputs)
        text_emb = text_emb + self.text_pos_embedding(text_emb)
        if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
            # Expand the prefix to the current batch; each prefix row is repeated consecutively.
            mel_emb = self.cached_mel_emb.repeat_interleave(
                text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
            )
        else:  # this outcome only occurs once per loop in most cases
            mel_emb = self.cached_mel_emb
        emb = torch.cat([mel_emb, text_emb], dim=1)
    else:
        # Incremental step: a single new id whose position is derived from the mask length.
        # NOTE(review): if text_pos_embedding counts the prompt-step suffix from 0 (k tokens ->
        # positions 0..k-1) and the caller grows attention_mask by one column per step, the
        # first generated token gets index k + 1 here rather than k -- confirm this is intended.
        emb = self.embeddings(input_ids)
        emb = emb + self.text_pos_embedding.get_fixed_embedding(
            attention_mask.shape[1] - mel_len, attention_mask.device
        )
    transformer_outputs = self.transformer(
        inputs_embeds=emb,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = transformer_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        if torch.backends.mps.is_available():
            self.to(self.transformer.first_device)
        else:
            torch.cuda.set_device(self.transformer.first_device)
        # Move activations to wherever the output head's weights live.
        hidden_states = hidden_states.to(self.lm_head.weight.device)

    lm_logits = self.lm_head(hidden_states)

    if not return_dict:
        return (lm_logits,) + transformer_outputs[1:]

    return CausalLMOutputWithCrossAttentions(
        loss=None,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
        cross_attentions=transformer_outputs.cross_attentions,
    )
|
|
198
|
+
|
|
199
|
+
@staticmethod
|
|
200
|
+
def _reorder_cache(past, beam_idx):
|
|
201
|
+
"""
|
|
202
|
+
This function is used to re-order the :obj:`past_key_values` cache if
|
|
203
|
+
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
|
204
|
+
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
|
205
|
+
"""
|
|
206
|
+
return tuple(
|
|
207
|
+
tuple(
|
|
208
|
+
past_state.index_select(0, beam_idx.to(past_state.device))
|
|
209
|
+
for past_state in layer_past
|
|
210
|
+
)
|
|
211
|
+
for layer_past in past
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class ConditioningEncoder(nn.Module):
    """A 1x1 ``Conv1d`` projection followed by a stack of ``AttentionBlock`` modules.

    ``forward`` applies ``self.init`` then ``self.attn`` to its input (laid out
    for ``nn.Conv1d``, i.e. channels on dim 1). With ``mean=True`` the result is
    averaged over dim 2. ``do_checkpointing`` is stored but not used here.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 attn_blocks=6,
                 num_attn_heads=4,
                 do_checkpointing=False,
                 mean=False):
        super().__init__()
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        self.attn = nn.Sequential(
            *[AttentionBlock(embedding_dim, num_attn_heads) for _ in range(attn_blocks)]
        )
        self.dim = embedding_dim
        self.do_checkpointing = do_checkpointing
        self.mean = mean

    def forward(self, x):
        features = self.attn(self.init(x))
        return features.mean(dim=2) if self.mean else features
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class LearnedPositionEmbeddings(nn.Module):
    """A trainable table of ``seq_len`` absolute position vectors of size ``model_dim``."""

    def __init__(self, seq_len, model_dim, init=.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # GPT-2 style initialisation: weights ~ N(0, init).
        nn.init.normal_(self.emb.weight, mean=0.0, std=init)

    def forward(self, x):
        """Return the embeddings of positions ``0 .. x.shape[1] - 1`` on ``x``'s device."""
        positions = torch.arange(0, x.shape[1], device=x.device)
        return self.emb(positions)

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of the single position ``ind``, with a leading batch axis of 1."""
        index = torch.tensor([ind], device=dev)
        return self.emb(index).unsqueeze(0)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function):
    """Build a HuggingFace ``GPT2Model`` backbone with its built-in embeddings removed.

    The token embedding (``wte``) is deleted and the position embedding
    (``wpe``) is replaced by ``functools.partial(null_position_embeddings,
    dim=model_dim)``, so callers provide their own embeddings.

    Returns:
        ``(gpt, mel_pos_embedding, text_pos_embedding, None, None)``, where the
        two position embeddings are ``LearnedPositionEmbeddings`` with
        ``max_mel_seq_len`` and ``max_text_seq_len`` rows respectively.
    """
    from transformers import GPT2Config, GPT2Model

    total_positions = max_mel_seq_len + max_text_seq_len
    config = GPT2Config(
        vocab_size=256,  # Unused: the token embedding is deleted below.
        n_positions=total_positions,
        n_ctx=total_positions,
        n_embd=model_dim,
        n_layer=layers,
        n_head=heads,
        activation_function=activation_function or "gelu_new",
        gradient_checkpointing=checkpointing,
        use_cache=not checkpointing,
    )
    backbone = GPT2Model(config)
    # Swap the built-in positional embeddings for the null variant.
    del backbone.wpe
    backbone.wpe = functools.partial(null_position_embeddings, dim=model_dim)
    # Built-in token embeddings are unused.
    del backbone.wte
    mel_positions = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
    text_positions = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
    return backbone, mel_positions, text_positions, None, None
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
class MelEncoder(nn.Module):
    """Convolutional mel-spectrogram encoder with two stride-2 time reductions.

    ``forward`` takes a ``Conv1d``-layout input (``mel_channels`` on dim 1) and
    returns ``(batch, time', channels)`` after swapping the last two axes.
    ``self.reduction`` records the overall time reduction factor (4).
    """

    def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
        super().__init__()
        self.channels = channels

        def res_stack(width):
            # A run of `resblocks_per_reduction` residual blocks at one width.
            return nn.Sequential(*[ResBlock(width) for _ in range(resblocks_per_reduction)])

        quarter, half = channels // 4, channels // 2
        # Module order (and therefore checkpoint keys encoder.0 .. encoder.9) is significant.
        self.encoder = nn.Sequential(
            nn.Conv1d(mel_channels, quarter, kernel_size=3, padding=1),
            res_stack(quarter),
            nn.Conv1d(quarter, half, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(channels // 16, half),
            nn.ReLU(),
            res_stack(half),
            nn.Conv1d(half, channels, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(channels // 8, channels),
            nn.ReLU(),
            res_stack(channels),
        )
        self.reduction = 4

    def forward(self, x):
        out = x
        for layer in self.encoder:
            out = layer(out)
        return out.permute(0, 2, 1)
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class UnifiedVoice(nn.Module):
|
|
306
|
+
def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
             mel_length_compression=1024, number_text_tokens=256,
             start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
             train_solo_embeddings=False, use_mel_codes_as_input=True,
             checkpointing=True, types=1, activation_function=None,
             condition_num_latent=32, condition_type="perceiver", condition_module=None):
    """
    Build the joint text/mel GPT-2 model, its embeddings, heads and conditioning encoder.

    Args:
        layers: Number of layers in transformer stack.
        model_dim: Operating dimensions of the transformer
        heads: Number of transformer heads. model_dim must be divisible by heads. Recommend model_dim//64
        max_text_tokens: Maximum number of text tokens that will be encountered by model.
        max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
        max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
        mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
        number_text_tokens: text vocabulary size; the text embedding and text head have
            number_text_tokens * types + 1 entries.
        start_text_token: id prepended to text sequences.
        stop_text_token: id appended to / used to pad text sequences.
        number_mel_codes: size of the mel-code vocabulary (mel embedding and mel head).
        start_mel_token: id prepended to mel-code sequences.
        stop_mel_token: id appended to / used to pad mel-code sequences.
        train_solo_embeddings: if True, create trainable mel/text "solo" embedding
            parameters; otherwise both are the constant 0.
        use_mel_codes_as_input: if True, mel inputs are code ids embedded with
            nn.Embedding; otherwise a MelEncoder over raw mels is used.
        checkpointing: forwarded to build_hf_gpt_transformer.
        types: multiplier applied to the text vocabulary size.
        activation_function: GPT-2 activation name, forwarded to build_hf_gpt_transformer.
        condition_num_latent: number of perceiver latents; also the width of the
            all-True prefix that cond_mask_pad adds to conditioning masks.
        condition_type: "perceiver", "conformer_perceiver" or "conformer_encoder";
            any other value builds a mean-pooled ConditioningEncoder.
        condition_module: dict of conformer settings, read only for the conformer
            types (keys: output_size, linear_units, attention_heads, num_blocks,
            input_layer, plus perceiver_mult for "conformer_perceiver").
    """
    super().__init__()
    self.number_text_tokens = number_text_tokens
    self.start_text_token = start_text_token
    self.stop_text_token = stop_text_token
    self.number_mel_codes = number_mel_codes
    self.start_mel_token = start_mel_token
    self.stop_mel_token = stop_mel_token
    self.layers = layers
    self.heads = heads
    self.max_mel_tokens = max_mel_tokens
    self.max_text_tokens = max_text_tokens
    self.model_dim = model_dim
    self.max_conditioning_inputs = max_conditioning_inputs
    self.mel_length_compression = mel_length_compression
    self.condition_type = condition_type
    self.cond_num = condition_num_latent
    # Left-pads a mask with `cond_num` True values (used by get_conditioning on the
    # conformer_perceiver path before calling the perceiver).
    self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
    # The conditioning input feature size is hard-coded to 100 below
    # (presumably the number of mel bins -- confirm against the feature extractor).
    if condition_type == "perceiver":
        self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads)
        self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num)
    elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder":
        self.conditioning_encoder = ConformerEncoder(input_size=100,
                                                     output_size=condition_module['output_size'],
                                                     linear_units=condition_module['linear_units'],
                                                     attention_heads=condition_module['attention_heads'],
                                                     num_blocks=condition_module['num_blocks'],
                                                     input_layer=condition_module['input_layer'])
        if condition_type == "conformer_perceiver":
            self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
                                                        ff_mult=condition_module['perceiver_mult'],
                                                        heads=condition_module['attention_heads'],
                                                        num_latents=self.cond_num)
    else:
        # NOTE(review): condition_type "gst" also lands here; no gst_encoder is created by
        # this constructor even though get_conditioning has a "gst" branch that uses one.
        self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads, mean=True)

    self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
    if use_mel_codes_as_input:
        self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
    else:
        self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
    # The mel/text position tables get 2 extra slots (presumably for start/stop tokens);
    # the mel table also reserves room for the conditioning inputs. The two *_layer_*
    # entries are whatever build_hf_gpt_transformer returns in its last two slots.
    self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
        build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
                                 self.max_text_tokens + 2, checkpointing, activation_function)
    if train_solo_embeddings:
        self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
        self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
    else:
        self.mel_solo_embedding = 0
        self.text_solo_embedding = 0

    self.final_norm = nn.LayerNorm(model_dim)
    self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
    self.mel_head = nn.Linear(model_dim, self.number_mel_codes)

    # Initialize the embeddings per the GPT-2 scheme
    embeddings = [self.text_embedding]
    if use_mel_codes_as_input:
        embeddings.append(self.mel_embedding)
    for module in embeddings:
        module.weight.data.normal_(mean=0.0, std=.02)
|
|
392
|
+
|
|
393
|
+
def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
    """Create ``self.inference_model``, the ``GPT2InferenceModel`` used for generation.

    The wrapper shares ``self.gpt``, ``self.mel_pos_embedding``,
    ``self.mel_embedding``, ``self.final_norm`` and ``self.mel_head``. When
    ``use_deepspeed`` is set and CUDA is available, it is wrapped with
    ``deepspeed.init_inference`` (float16 if ``half`` else float32) and the
    engine is kept in ``self.ds_engine``. In every case the resulting model is
    switched to eval mode. Finally ``self.gpt.wte`` is set to ``self.mel_embedding``.
    """
    total_len = self.max_mel_tokens + self.max_text_tokens + 2
    config = GPT2Config(
        vocab_size=self.number_mel_codes,
        n_positions=total_len,
        n_ctx=total_len,
        n_embd=self.model_dim,
        n_layer=self.layers,
        n_head=self.heads,
        gradient_checkpointing=False,
        use_cache=True,
    )
    self.inference_model = GPT2InferenceModel(
        config,
        self.gpt,
        self.mel_pos_embedding,
        self.mel_embedding,
        self.final_norm,
        self.mel_head,
        kv_cache=kv_cache,
    )
    if not (use_deepspeed and torch.cuda.is_available()):
        self.inference_model = self.inference_model.eval()
    else:
        import deepspeed

        self.ds_engine = deepspeed.init_inference(
            model=self.inference_model,
            mp_size=1,
            replace_with_kernel_inject=False,
            dtype=torch.float16 if half else torch.float32,
        )
        self.inference_model = self.ds_engine.module.eval()

    # Expose the mel embedding as the backbone's token embedding.
    self.gpt.wte = self.mel_embedding
|
|
433
|
+
|
|
434
|
+
def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
    """Return ``(inputs, targets)`` shifted by one step for teacher forcing.

    ``inputs`` is ``input`` with ``start_token`` prepended on the last axis and
    ``targets`` is ``input`` with ``stop_token`` appended, so column ``i`` of
    ``inputs`` lines up with column ``i`` of ``targets``.
    """
    shifted_in = F.pad(input, (1, 0), value=start_token)
    shifted_out = F.pad(input, (0, 1), value=stop_token)
    return shifted_in, shifted_out
|
|
438
|
+
|
|
439
|
+
def set_mel_padding(self, mel_input_tokens, mel_lengths):
    """Overwrite the padding of each mel-code row with ``self.stop_mel_token``.

    For row ``b``, every column at or beyond ``mel_lengths[b]`` is set to the
    stop token; rows whose length already covers the full width are left as
    they are. Lengths are used as given (the training ``forward`` passes one
    extra position so the model also learns to predict past the last real code).
    The tensor is modified in place and returned.
    """
    width = mel_input_tokens.shape[-1]
    for row, end in enumerate(mel_lengths):
        if end < width:
            mel_input_tokens[row, end:] = self.stop_mel_token
    return mel_input_tokens
|
|
452
|
+
|
|
453
|
+
def set_text_padding(self, text_input_tokens, text_lengths):
    """Overwrite the padding of each text row with ``self.stop_text_token``.

    Given padded text token ids and the true length of every row, all columns
    of row ``b`` at or beyond ``text_lengths[b]`` are replaced by the stop-text
    token. Rows whose length already covers the full width are unchanged.
    The tensor is modified in place and also returned.
    """
    width = text_input_tokens.shape[-1]
    for row, end in enumerate(text_lengths):
        if end < width:
            text_input_tokens[row, end:] = self.stop_text_token
    return text_input_tokens
|
|
466
|
+
|
|
467
|
+
def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
    """Run the GPT backbone over ``[conditioning | first | second]`` embeddings.

    Args:
        speech_conditioning_inputs: conditioning embeddings placed first; their
            positions are dropped from the outputs.
        first_inputs: first embedded segment (concatenated on dim 1).
        first_head: projection applied to the first segment's normalised states.
        second_inputs: optional second embedded segment.
        second_head: projection for the second segment.
        get_attns: if True, return ``gpt_out.attentions`` straight away.
        return_latent: if True, return the ``final_norm``-ed hidden states of each
            segment instead of logits.

    Returns:
        The attentions when ``get_attns``. Otherwise, per present segment, its
        latent or its head output permuted with ``(0, 2, 1)``; a ``(first, second)``
        pair when ``second_inputs`` is given, else only the first.
    """
    segments = [speech_conditioning_inputs, first_inputs]
    if second_inputs is not None:
        segments.append(second_inputs)
    emb = torch.cat(segments, dim=1)

    gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
    if get_attns:
        return gpt_out.attentions

    offset = speech_conditioning_inputs.shape[1]
    enc = self.final_norm(gpt_out.last_hidden_state[:, offset:])

    # Split by explicit offsets: `enc[:, -n:]` would yield the whole sequence when n == 0.
    first_len = first_inputs.shape[1]
    first_latent = enc[:, :first_len]
    second_latent = enc[:, first_len:] if second_inputs is not None else None

    if return_latent:
        if second_latent is None:
            return first_latent
        return first_latent, second_latent

    first_logits = first_head(first_latent).permute(0, 2, 1)
    if second_latent is None:
        return first_logits
    second_logits = second_head(second_latent).permute(0, 2, 1)
    return first_logits, second_logits
|
|
494
|
+
|
|
495
|
+
def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
    """Encode reference speech features into the conditioning latents for the GPT prefix.

    The path is selected by ``self.condition_type``:

    * ``"perceiver"``: a 4-D input has dim 1 squeezed; the result goes through
      ``conditioning_encoder`` and then ``perceiver_encoder`` (on the encoder
      output transposed over dims 1/2).
    * ``"conformer_perceiver"``: ``conditioning_encoder`` receives the input
      transposed over dims 1/2 plus ``cond_mel_lengths`` and returns
      ``(features, mask)``; the mask, squeezed on dim 1 and left-padded by
      ``cond_mask_pad``, is passed with the features to ``perceiver_encoder``.
    * ``"gst"``: uses ``self.gst_encoder``, which ``__init__`` does not create.
    * anything else: a 3-D input is treated as a single clip (a clip axis is
      inserted at dim 1); each clip is encoded separately, the results are
      averaged over clips, and a length-1 axis is inserted at dim 1.

    Raises:
        AttributeError: if ``condition_type == "gst"`` and no ``gst_encoder`` is set.
    """
    condition_type = self.condition_type

    if condition_type == "perceiver":
        if speech_conditioning_input.ndim == 4:
            speech_conditioning_input = speech_conditioning_input.squeeze(1)
        encoded = self.conditioning_encoder(speech_conditioning_input)  # (b, d, s)
        return self.perceiver_encoder(encoded.transpose(1, 2))  # (b, 32, d)

    if condition_type == "conformer_perceiver":
        encoded, mask = self.conditioning_encoder(
            speech_conditioning_input.transpose(1, 2), cond_mel_lengths
        )  # (b, s, d), (b, 1, s)
        conds_mask = self.cond_mask_pad(mask.squeeze(1))
        return self.perceiver_encoder(encoded, conds_mask)  # (b, 32, d)

    if condition_type == "gst":
        gst_encoder = getattr(self, "gst_encoder", None)
        if gst_encoder is None:
            raise AttributeError(
                "condition_type='gst' requires a 'gst_encoder' module, "
                "but none is set on this model"
            )
        if speech_conditioning_input.ndim == 4:
            speech_conditioning_input = speech_conditioning_input.squeeze(1)
        return gst_encoder(speech_conditioning_input.transpose(1, 2))  # (b, 1, d)

    # Default encoder: encode every reference clip, then average them.
    if speech_conditioning_input.ndim == 3:
        speech_conditioning_input = speech_conditioning_input.unsqueeze(1)
    per_clip = [
        self.conditioning_encoder(speech_conditioning_input[:, j])
        for j in range(speech_conditioning_input.shape[1])
    ]
    return torch.stack(per_clip, dim=1).mean(dim=1).unsqueeze(1)
|
|
525
|
+
|
|
526
|
+
def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths,
            cond_mel_lengths=None, types=None, text_first=True, raw_mels=None, return_attentions=False,
            return_latent=False, clip_inputs=False):
    """
    Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
    (actuated by `text_first`).

    speech_conditioning_latent: reference speech features; passed to ``get_conditioning``
        together with ``cond_mel_lengths``.
    text_inputs: long tensor, (b,t)
    text_lengths: long tensor, (b,)
    mel_codes: long tensor, (b,m)
    wav_lengths: long tensor, (b,)
    raw_mels: MEL float tensor (b,80,s); when given it replaces ``mel_codes`` as the input
        of ``mel_embedding`` (the codes are still used for positions and targets).

    If return_attentions is specified, the backbone attentions from ``get_logits`` are returned as-is.
    If return_latent is specified, loss & logits are not computed or returned. Only the latents of the
    second segment (mel when ``text_first``, otherwise text) are returned, minus their last two positions.
    If clip_inputs is True, text and mel codes are trimmed to the longest actual length in the batch
    (``text_lengths.max()`` and ``wav_lengths.max() // mel_length_compression``).

    Returns:
        Otherwise ``(loss_text, loss_mel, mel_logits)``.
    """
    speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
    # Types are expressed by expanding the text embedding space.
    if types is not None:
        text_inputs = text_inputs * (1 + types).unsqueeze(-1)

    if clip_inputs:
        # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
        # chopping the inputs by the maximum actual length.
        max_text_len = text_lengths.max()
        text_inputs = text_inputs[:, :max_text_len]
        max_mel_len = wav_lengths.max() // self.mel_length_compression
        mel_codes = mel_codes[:, :max_mel_len]
        if raw_mels is not None:
            raw_mels = raw_mels[:, :, :max_mel_len * 4]

    # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
    mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
    mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
    text_inputs = self.set_text_padding(text_inputs, text_lengths)
    text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
    mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)

    conds = speech_conditioning_latent
    text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
    text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
    mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
    if raw_mels is not None:
        mel_inp = F.pad(raw_mels, (0, 8))
    else:
        mel_inp = mel_codes
    mel_emb = self.mel_embedding(mel_inp)
    mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)

    # Segment order fed to the backbone: [text][mel] when text_first, else [mel][text].
    if text_first:
        segments = (text_emb, self.text_head, mel_emb, self.mel_head)
    else:
        segments = (mel_emb, self.mel_head, text_emb, self.text_head)

    if return_attentions:
        # In this mode get_logits returns the attentions object directly; it must not be
        # unpacked as a (first, second) pair, which fails unless the backbone has exactly 2 layers.
        return self.get_logits(conds, *segments, get_attns=True)

    first_out, second_out = self.get_logits(conds, *segments, return_latent=return_latent)
    if return_latent:
        # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
        return second_out[:, :-2]

    if text_first:
        text_logits, mel_logits = first_out, second_out
    else:
        mel_logits, text_logits = first_out, second_out

    loss_text = F.cross_entropy(text_logits, text_targets.long())
    loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
    return loss_text.mean(), loss_mel.mean(), mel_logits
|
|
595
|
+
|
|
596
|
+
def prepare_gpt_inputs(
    self,
    conditional_latents: torch.Tensor,
    text_inputs: torch.Tensor,
):
    """Build left-padded ``[cond][text]`` prefixes for ``GPT2InferenceModel.generate``.

    For each row of ``text_inputs`` every ``start_text_token`` / ``stop_text_token``
    id is removed, the remaining ids are re-wrapped as ``[start, ..., stop]``,
    embedded (``text_embedding`` plus ``text_pos_embedding.emb`` at positions
    ``0..n-1``), and appended to that row's conditioning latents. Rows shorter
    than ``L + 2`` text tokens are left-padded with zero vectors that are masked
    out in the attention mask.

    Args:
        conditional_latents: conditioning embeddings, e.g. ``(b, 32, dim)`` from
            ``get_conditioning()``. A 3-D tensor with batch 1 is shared by every
            text row; otherwise its batch must equal ``b``.
        text_inputs: ``(b, L)`` text token ids.

    Returns:
        input_ids: ``(b, s+1)`` placeholder ids (all 1) whose last column is
            ``start_mel_token``; ``s = cond_len + L + 2``.
        inputs_embeds: ``(b, s, dim)`` prefix embeddings for ``store_mel_emb``.
        attention_mask: ``(b, s+1)``; zeros over the left padding, ones elsewhere.
    """
    b, L = text_inputs.shape[:2]
    device = text_inputs.device
    single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
    if not single_cond:
        assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
    target_len = conditional_latents.shape[1] + L + 2
    dim = conditional_latents.size(-1)

    batched_mel_emb = []
    attention_masks = []
    for i in range(b):
        row = text_inputs[i]
        keep = (row != self.stop_text_token) & (row != self.start_text_token)
        text_input = F.pad(row[keep], (1, 0), value=self.start_text_token)
        text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
        positions = torch.arange(0, text_input.size(-1), device=device)
        text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(positions)
        cond = conditional_latents.squeeze(0) if single_cond else conditional_latents[i]
        pieces = [cond, text_emb]
        # One extra slot for the start_mel_token appended to input_ids below.
        attention_mask = torch.ones(target_len + 1, dtype=torch.long, device=device)
        padding = L + 2 - text_input.size(-1)
        if padding > 0:
            # Left-pad [cond][text] -> [pad][cond][text] and mask the padding out.
            pieces.insert(0, torch.zeros((padding, dim), dtype=text_emb.dtype, device=device))
            attention_mask[:padding] = 0
        mel_emb = torch.cat(pieces)  # [s, dim]
        assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
        batched_mel_emb.append(mel_emb)
        attention_masks.append(attention_mask)

    batched_mel_emb = torch.stack(batched_mel_emb, dim=0)  # [b, s, dim]
    attention_mask = torch.stack(attention_masks, dim=0)  # [b, s+1]
    fake_inputs = torch.ones(
        (batched_mel_emb.shape[0], batched_mel_emb.shape[1] + 1),  # +1 for the start_mel_token
        dtype=torch.long,
        device=device,
    )
    fake_inputs[:, -1] = self.start_mel_token
    return fake_inputs, batched_mel_emb, attention_mask
|
|
660
|
+
def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
                     max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
    """Autoregressively generate mel codes for ``text_inputs``.

    Args:
        speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames)
        text_inputs: (b, L)
        cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,);
            defaults to the full frame count.
        input_tokens: additional tokens for generation in shape (k, s) or (s,), appended
            after the start_mel_token. ``num_return_sequences`` must be divisible by both
            ``k`` and ``b``; when ``k == b``, row j is paired with text row j.
        num_return_sequences: forwarded to ``generate``.
        max_generate_length: limit the number of generated tokens (default
            ``self.max_mel_tokens - 1``).
        typical_sampling: add a ``TypicalLogitsWarper`` with mass ``typical_mass``.
        hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`

    Returns:
        The generated ids with the prompt columns removed: a tensor, or the
        ``generate`` output object with its ``sequences`` trimmed.

    Raises:
        ValueError: if ``typical_sampling`` and ``typical_mass`` is not in (0, 1).
    """
    if speech_conditioning_mel.ndim == 2:
        speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0)
    if cond_mel_lengths is None:
        cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device)
    conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths)
    input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
    self.inference_model.store_mel_emb(inputs_embeds)
    if input_tokens is None:
        inputs = input_ids
    else:
        if input_tokens.ndim == 1:
            input_tokens = input_tokens.unsqueeze(0)
        assert num_return_sequences % input_tokens.shape[0] == 0, \
            "The num_return_sequences must be divisible by the batch number of input_tokens"
        assert num_return_sequences % text_inputs.shape[0] == 0, \
            "The num_return_sequences must be divisible by the batch number of text_inputs"
        # GPT2InferenceModel.forward expands the cached prefix embeddings with
        # repeat_interleave (row i -> rows i*b .. i*b+b-1). Ids, masks and extra tokens
        # must be expanded the same way; tiling with `.repeat` would pair each mask with
        # another row's prefix whenever the text batch is larger than 1.
        b = num_return_sequences // input_ids.shape[0]
        if b > 1:
            input_ids = input_ids.repeat_interleave(b, dim=0)
            attention_mask = attention_mask.repeat_interleave(b, dim=0)
        input_tokens = input_tokens.repeat_interleave(num_return_sequences // input_tokens.shape[0], dim=0)
        inputs = torch.cat([input_ids, input_tokens], dim=1)
        attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
        # NOTE(review): `num_return_sequences` is still forwarded to generate(), which may
        # expand this already-expanded batch again (HF does so when sampling) -- confirm
        # whether the extra copies are intended.
    trunc_index = inputs.shape[1]
    logits_processor = LogitsProcessorList()
    if typical_sampling:
        # employ custom typical sampling
        if not (0.0 < typical_mass < 1.0):
            raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
        min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
        logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
    if max_generate_length is None:
        max_length = trunc_index + self.max_mel_tokens - 1
    else:
        max_length = trunc_index + max_generate_length
    output = self.inference_model.generate(inputs,
                                           bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
                                           eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
                                           max_length=max_length, logits_processor=logits_processor,
                                           num_return_sequences=num_return_sequences,
                                           **hf_generate_kwargs)
    if isinstance(output, torch.Tensor):
        return output[:, trunc_index:]
    # GenerateOutput
    output.sequences = output.sequences[:, trunc_index:]
    return output
|