xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of xinference might be problematic.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/s2mel/modules/flow_matching.py
@@ -0,0 +1,171 @@
+ from abc import ABC
+
+ import torch
+ import torch.nn.functional as F
+
+ from indextts.s2mel.modules.diffusion_transformer import DiT
+ from indextts.s2mel.modules.commons import sequence_mask
+
+ from tqdm import tqdm
+
+ class BASECFM(torch.nn.Module, ABC):
+     def __init__(
+         self,
+         args,
+     ):
+         super().__init__()
+         self.sigma_min = 1e-6
+
+         self.estimator = None
+
+         self.in_channels = args.DiT.in_channels
+
+         self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
+
+         if hasattr(args.DiT, 'zero_prompt_speech_token'):
+             self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
+         else:
+             self.zero_prompt_speech_token = False
+
+     @torch.inference_mode()
+     def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
+         """Forward diffusion
+
+         Args:
+             mu (torch.Tensor): semantic info of reference audio and altered audio
+                 shape: (batch_size, mel_timesteps(795+1069), 512)
+             x_lens (torch.Tensor): mel frames output
+                 shape: (batch_size, mel_timesteps)
+             prompt (torch.Tensor): reference mel
+                 shape: (batch_size, 80, 795)
+             style (torch.Tensor): reference global style
+                 shape: (batch_size, 192)
+             f0: None
+             n_timesteps (int): number of diffusion steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, 80, mel_timesteps)
+         """
+         B, T = mu.size(0), mu.size(1)
+         z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+         # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
+         return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
+
+     def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
+         """
+         Fixed euler solver for ODEs.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): n_timesteps interpolated
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): semantic info of reference audio and altered audio
+                 shape: (batch_size, mel_timesteps(795+1069), 512)
+             x_lens (torch.Tensor): mel frames output
+                 shape: (batch_size, mel_timesteps)
+             prompt (torch.Tensor): reference mel
+                 shape: (batch_size, 80, 795)
+             style (torch.Tensor): reference global style
+                 shape: (batch_size, 192)
+         """
+         t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+         # Or in future might add like a return_all_steps flag
+         sol = []
+         # apply prompt
+         prompt_len = prompt.size(-1)
+         prompt_x = torch.zeros_like(x)
+         prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
+         x[..., :prompt_len] = 0
+         if self.zero_prompt_speech_token:
+             mu[..., :prompt_len] = 0
+         for step in tqdm(range(1, len(t_span))):
+             dt = t_span[step] - t_span[step - 1]
+             if inference_cfg_rate > 0:
+                 # Stack original and CFG (null) inputs for batched processing
+                 stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
+                 stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
+                 stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+                 stacked_x = torch.cat([x, x], dim=0)
+                 stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
+
+                 # Perform a single forward pass for both original and CFG inputs
+                 stacked_dphi_dt = self.estimator(
+                     stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
+                 )
+
+                 # Split the output back into the original and CFG components
+                 dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
+
+                 # Apply CFG formula
+                 dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
+             else:
+                 dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
+
+             x = x + dt * dphi_dt
+             t = t + dt
+             sol.append(x)
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+             x[:, :, :prompt_len] = 0
+
+         return sol[-1]
+     def forward(self, x1, x_lens, prompt_lens, mu, style):
+         """Computes diffusion loss
+
+         Args:
+             mu (torch.Tensor): semantic info of reference audio and altered audio
+                 shape: (batch_size, mel_timesteps(795+1069), 512)
+             x1: mel
+             x_lens (torch.Tensor): mel frames output
+                 shape: (batch_size, mel_timesteps)
+             prompt (torch.Tensor): reference mel
+                 shape: (batch_size, 80, 795)
+             style (torch.Tensor): reference global style
+                 shape: (batch_size, 192)
+
+         Returns:
+             loss: conditional flow matching loss
+             y: conditional flow
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         b, _, t = x1.shape
+
+         # random timestep
+         t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
+         # sample noise p(x_0)
+         z = torch.randn_like(x1)
+
+         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+         u = x1 - (1 - self.sigma_min) * z
+
+         prompt = torch.zeros_like(x1)
+         for bib in range(b):
+             prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
+             # range covered by prompt are set to 0
+             y[bib, :, :prompt_lens[bib]] = 0
+             if self.zero_prompt_speech_token:
+                 mu[bib, :, :prompt_lens[bib]] = 0
+
+         estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
+         loss = 0
+         for bib in range(b):
+             loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
+         loss /= b
+
+         return loss, estimator_out + (1 - self.sigma_min) * z
+
+
+
+ class CFM(BASECFM):
+     def __init__(self, args):
+         super().__init__(
+             args
+         )
+         if args.dit_type == "DiT":
+             self.estimator = DiT(args)
+         else:
+             raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
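
The flow_matching.py hunk above reduces to two short formulas: a conditional flow-matching training target, y = (1 - (1 - sigma_min) * t) * z + t * x1 with regression target u = x1 - (1 - sigma_min) * z, and a classifier-free-guidance (CFG) Euler step at inference, v = (1 + rate) * v_cond - rate * v_uncond. The sketch below restates both on toy tensors; the one-line estimator is a made-up stand-in for the DiT velocity network, not code from this release.

import torch

sigma_min = 1e-6

def estimator(x, cond, t):
    # Hypothetical stand-in for the DiT velocity field v(x, cond, t).
    return cond - x

# Training side: interpolate noise z toward data x1 at a random time t
# and regress the velocity u, exactly as in BASECFM.forward above.
x1 = torch.randn(1, 80, 100)   # target mel
z = torch.randn_like(x1)       # noise sample from p(x_0)
t = torch.rand(1, 1, 1)
y = (1 - (1 - sigma_min) * t) * z + t * x1   # point on the conditional flow
u = x1 - (1 - sigma_min) * z                 # velocity the network should predict

# Inference side: Euler steps from t=0 to t=1 with CFG mixing,
# as in solve_euler above.
n_timesteps, cfg_rate = 10, 0.5
cond = torch.randn_like(x1)    # stands in for (mu, style, prompt)
x = torch.randn_like(x1)
t_span = torch.linspace(0, 1, n_timesteps + 1)
for step in range(1, len(t_span)):
    dt = t_span[step] - t_span[step - 1]
    v_cond = estimator(x, cond, t_span[step - 1])
    v_null = estimator(x, torch.zeros_like(cond), t_span[step - 1])  # "null" input
    x = x + dt * ((1.0 + cfg_rate) * v_cond - cfg_rate * v_null)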
xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py
@@ -0,0 +1,436 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ import itertools
+ import sys
+ import time
+ from pathlib import Path
+ from typing import Optional, Tuple
+
+ import torch
+ import torch._dynamo.config
+ import torch._inductor.config
+
+ def device_sync(device):
+     if "cuda" in device:
+         torch.cuda.synchronize(device)
+     elif ("cpu" in device) or ("mps" in device):
+         pass
+     else:
+         print(f"device={device} is not yet supported")
+
+
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.triton.unique_kernel_names = True
+ torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+
+ default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # support running without installing as a package
+ wd = Path(__file__).parent.parent.resolve()
+ sys.path.append(str(wd))
+
+ from model import Transformer
+ from tokenizer import get_tokenizer
+
+ def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
+     q = torch.empty_like(probs_sort).exponential_(1)
+     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+     logits = logits / max(temperature, 1e-5)
+
+     if top_k is not None:
+         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+         pivot = v.select(-1, -1).unsqueeze(-1)
+         logits = torch.where(logits < pivot, -float("Inf"), logits)
+     probs = torch.nn.functional.softmax(logits, dim=-1)
+     return probs
+
+ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+     probs = logits_to_probs(logits[0, -1], temperature, top_k)
+     idx_next = multinomial_sample_one_no_sync(probs)
+     return idx_next, probs
+
+ def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
+     # input_pos: [B, S]
+     logits = model(x, input_pos)
+     return sample(logits, **sampling_kwargs)[0]
+
+ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+     # input_pos: [B, 1]
+     assert input_pos.shape[-1] == 1
+     logits = model(x, input_pos)
+     return sample(logits, **sampling_kwargs)
+
+ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
+     new_tokens, new_probs = [], []
+     for i in range(num_new_tokens):
+         with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+             next_token, next_prob = decode_one_token(
+                 model, cur_token, input_pos, **sampling_kwargs
+             )
+             input_pos += 1
+             new_tokens.append(next_token.clone())
+             callback(new_tokens[-1])
+             new_probs.append(next_prob.clone())
+             cur_token = next_token.view(1, -1)
+
+     return new_tokens, new_probs
+
+
+ def model_forward(model, x, input_pos):
+     return model(x, input_pos)
+
+ def speculative_decode(
+     model: Transformer,
+     draft_model: Transformer,
+     cur_token: torch.Tensor,
+     input_pos: int,
+     speculate_k: int,
+     **sampling_kwargs
+ ) -> torch.Tensor:
+     # draft model inference sequentially
+     device = cur_token.device
+     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
+     draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
+
+     draft_tokens = torch.cat(draft_tokens)
+     # parallel inference on target model using draft tokens
+     target_logits = model_forward(
+         model,
+         torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
+         torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
+     )
+     target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
+     draft_probs = torch.stack(draft_probs)
+     # q: target prob, p: draft prob
+     # q >= p: always accept draft token
+     # q < p: q/p prob to accept draft token
+     p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
+     q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
+     accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
+     rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
+
+     if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
+         accept_length = speculate_k + 1
+         last_token = multinomial_sample_one_no_sync(target_probs[-1])
+         # fill last token into draft model
+         model_forward(
+             draft_model,
+             draft_tokens[-1].view(1, -1),
+             orig_input_pos + speculate_k,
+         )
+         return torch.cat([draft_tokens, last_token])
+     else:
+         accept_length = rejected_locations[0].item()
+         p = draft_probs[accept_length]
+         q = target_probs[accept_length]
+         new = q - p
+         new = torch.where(new > 0, new, 0.0)
+         new = new / new.sum()
+         next_token = multinomial_sample_one_no_sync(new)
+         return torch.cat([draft_tokens[:accept_length], next_token])
+
+ @torch.no_grad()
+ def generate(
+     model: Transformer,
+     prompt: torch.Tensor,
+     max_new_tokens: int,
+     *,
+     interactive: bool,
+     draft_model: Transformer,
+     speculate_k: Optional[int] = 8,
+     callback = lambda x: x,
+     **sampling_kwargs
+ ) -> torch.Tensor:
+     """
+     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+     """
+
+     is_speculative = draft_model is not None
+     # create an empty tensor of the expected final shape and fill in the current tokens
+     T = prompt.size(0)
+     T_new = T + max_new_tokens
+     if interactive:
+         max_seq_length = 350
+     else:
+         max_seq_length = min(T_new, model.config.block_size)
+
+     device, dtype = prompt.device, prompt.dtype
+     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
+     with torch.device(device):
+         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+         if is_speculative and draft_model is not model:
+             draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+
+     # create an empty tensor of the expected final shape and fill in the current tokens
+     empty = torch.empty(T_new, dtype=dtype, device=device)
+     empty[:T] = prompt
+     seq = empty
+     input_pos = torch.arange(0, T, device=device)
+
+     next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+     if is_speculative:
+         prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
+     seq[T] = next_token
+
+     input_pos = torch.tensor([T], device=device, dtype=torch.int)
+     accept_counts = [0] * (speculate_k + 1)
+
+     if is_speculative:
+         input_pos = input_pos.item() # for speculative decoding easier to keep on host
+         while input_pos < T_new - 1:
+             cur_token = next_token.view(())
+
+             next_tokens = speculative_decode(
+                 model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
+             )
+
+             accept_counts[len(next_tokens) - 1] += 1
+             num_added = min(T_new - input_pos - 1, len(next_tokens))
+             seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
+             for i in next_tokens[: num_added,]:
+                 callback(i)
+             input_pos = input_pos + num_added
+             next_token = next_tokens[-1]
+     else:
+         generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+         seq[T + 1:] = torch.cat(generated_tokens)
+
+     generate_stats = {
+         'accept_counts': accept_counts
+     }
+     return seq, generate_stats
+
+ def encode_tokens(tokenizer, string, bos=True, device=default_device):
+     tokens = tokenizer.encode(string)
+     if bos:
+         tokens = [tokenizer.bos_id()] + tokens
+     return torch.tensor(tokens, dtype=torch.int, device=device)
+
+ def _load_model(checkpoint_path, device, precision, use_tp):
+     use_cuda = 'cuda' in device
+     with torch.device('meta'):
+         model = Transformer.from_name(checkpoint_path.parent.name)
+
+     if "int8" in str(checkpoint_path):
+         print("Using int8 weight-only quantization!")
+         from quantize import WeightOnlyInt8QuantHandler
+         simple_quantizer = WeightOnlyInt8QuantHandler(model)
+         model = simple_quantizer.convert_for_runtime()
+
+     if "int4" in str(checkpoint_path):
+         print("Using int4 weight-only quantization!")
+         path_comps = checkpoint_path.name.split(".")
+         groupsize = int(path_comps[-2][1:])
+         from quantize import WeightOnlyInt4QuantHandler
+         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+         model = simple_quantizer.convert_for_runtime()
+
+     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+     if "model" in checkpoint and "stories" in str(checkpoint_path):
+         checkpoint = checkpoint["model"]
+     model.load_state_dict(checkpoint, assign=True)
+
+     if use_tp:
+         from tp import apply_tp
+         print("Applying tensor parallel to model ...")
+         apply_tp(model)
+
+     model = model.to(device=device, dtype=precision)
+     return model.eval()
+
+ def _get_model_size(model):
+     model_size = 0
+     for name, child in model.named_children():
+         if not isinstance(child, torch.nn.Embedding):
+             model_size += sum(
+                 [
+                     p.numel() * p.dtype.itemsize
+                     for p in itertools.chain(child.parameters(), child.buffers())
+                 ]
+             )
+     return model_size
+
+ B_INST, E_INST = "[INST]", "[/INST]"
+
+ def main(
+     prompt: str = "Hello, my name is",
+     interactive: bool = False,
+     num_samples: int = 5,
+     max_new_tokens: int = 100,
+     top_k: int = 200,
+     temperature: float = 0.8,
+     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
+     compile: bool = True,
+     compile_prefill: bool = False,
+     profile: Optional[Path] = None,
+     draft_checkpoint_path: Optional[Path] = None,
+     speculate_k: int = 5,
+     device=default_device,
+ ) -> None:
+     """Generates text samples based on a pre-trained Transformer model and tokenizer.
+     """
+     assert checkpoint_path.is_file(), checkpoint_path
+
+     tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+     assert tokenizer_path.is_file(), str(tokenizer_path)
+
+     global print
+     from tp import maybe_init_dist
+     rank = maybe_init_dist()
+     use_tp = rank is not None
+     if use_tp:
+         if rank != 0:
+             # only print on rank 0
+             print = lambda *args, **kwargs: None
+
+     print(f"Using device={device}")
+     precision = torch.bfloat16
+     is_speculative = draft_checkpoint_path is not None
+     is_chat = "chat" in str(checkpoint_path)
+
+     print("Loading model ...")
+     t0 = time.time()
+     model = _load_model(checkpoint_path, device, precision, use_tp)
+
+     if is_speculative:
+         draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
+     else:
+         draft_model = None
+
+     device_sync(device=device) # MKG
+     print(f"Time to load model: {time.time() - t0:.02f} seconds")
+
+     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
+
+     encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+     prompt_length = encoded.size(0)
+
+     torch.manual_seed(1234)
+     model_size = _get_model_size(model)
+     if compile:
+         if is_speculative and use_tp: # and ("cuda" in device):
+             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
+
+         if is_speculative:
+             global model_forward, logits_to_probs
+             model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+
+         global decode_one_token, prefill
+         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
+
+         # Uncomment to squeeze more perf out of prefill
+         if compile_prefill:
+             prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
+
+
+     aggregate_metrics = {
+         'tokens_per_sec': [],
+         'accept_counts': [],
+     }
+     start = -1 if compile else 0
+
+     for i in range(start, num_samples):
+         device_sync(device=device) # MKG
+         if i >= 0 and interactive:
+             prompt = input("What is your prompt? ")
+             if is_chat:
+                 prompt = f"{B_INST} {prompt.strip()} {E_INST}"
+             encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+
+         if interactive and i >= 0:
+             buffer = []
+             period_id = tokenizer.encode('.')[0]
+             done_generating = False
+             def callback(x):
+                 nonlocal done_generating
+                 if done_generating:
+                     return
+                 buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+                 if x.item() == tokenizer.eos_id():
+                     done_generating = True
+                 if len(buffer) == 4 or done_generating:
+                     print(''.join(buffer), end='', flush=True)
+                     buffer.clear()
+                 # print(, end='', flush=True)
+         else:
+             callback = lambda x: x
+         t0 = time.perf_counter()
+         import contextlib
+         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
+             prof = contextlib.nullcontext()
+         else:
+             torch.profiler._utils._init_for_cuda_graphs()
+             prof = torch.profiler.profile()
+         with prof:
+             y, metrics = generate(
+                 model,
+                 encoded,
+                 max_new_tokens,
+                 draft_model=draft_model,
+                 speculate_k=speculate_k,
+                 interactive=interactive,
+                 callback=callback,
+                 temperature=temperature,
+                 top_k=top_k,
+             )
+             aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
+         if i == -1:
+             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+             continue
+         if hasattr(prof, "export_chrome_trace"):
+             if use_tp:
+                 prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
+             else:
+                 prof.export_chrome_trace(f"{profile}.json")
+         device_sync(device=device) # MKG
+         t = time.perf_counter() - t0
+
+         if not interactive:
+             print(tokenizer.decode(y.tolist()))
+         else:
+             print()
+         tokens_generated = y.size(0) - prompt_length
+         tokens_sec = tokens_generated / t
+         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
+         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+     print("==========")
+     if is_speculative:
+         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
+         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
+         print(f"Acceptance probs: {acceptance_probs}")
+         print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)]) / sum(counts_aggregated)}")
+
+     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
+     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+
+
+ if __name__ == '__main__':
+     import argparse
+     parser = argparse.ArgumentParser(description='Your CLI description.')
+
+     parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
+     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
+     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
+     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
+     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
+     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
+     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
+     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
+     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
+     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
+     parser.add_argument('--device', type=str, default=default_device, help='Device to use')
+
+     args = parser.parse_args()
+     main(
+         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
+         args.speculate_k, args.device
+     )
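
The speculative_decode function in the hunk above uses the standard draft-then-verify rule: accept each draft token with probability min(1, q/p), and on the first rejection resample from the residual distribution max(q - p, 0), renormalized, which keeps the overall output distribution exactly equal to the target model's q. A minimal single-token sketch of that rule over a toy four-symbol vocabulary (both probability tables are invented for illustration):

import torch

torch.manual_seed(0)
p = torch.tensor([0.70, 0.10, 0.10, 0.10])  # draft model distribution over vocab
q = torch.tensor([0.40, 0.30, 0.20, 0.10])  # target model distribution over vocab

draft_token = torch.multinomial(p, 1).item()            # draft proposes a token
accept_prob = torch.minimum(torch.ones(()), q[draft_token] / p[draft_token])

if torch.rand(()) <= accept_prob:
    token = draft_token                                  # keep the draft token
else:
    # Resample from the residual max(q - p, 0), renormalized; this
    # correction makes the accepted output exactly distributed as q.
    residual = torch.clamp(q - p, min=0.0)
    token = torch.multinomial(residual / residual.sum(), 1).item()
print(token)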