xinference 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.


This version of xinference might be problematic.

Files changed (328)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +473 -31
  3. xinference/client/restful/async_restful_client.py +178 -8
  4. xinference/client/restful/restful_client.py +151 -3
  5. xinference/core/supervisor.py +99 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +58 -21
  11. xinference/model/image/model_spec.json +159 -90
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +6 -2
  14. xinference/model/llm/llm_family.json +1299 -174
  15. xinference/model/llm/mlx/distributed_models/core.py +41 -0
  16. xinference/model/llm/mlx/distributed_models/qwen2.py +1 -2
  17. xinference/model/llm/sglang/core.py +44 -11
  18. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  19. xinference/model/llm/tool_parsers/qwen_tool_parser.py +29 -4
  20. xinference/model/llm/transformers/chatglm.py +3 -0
  21. xinference/model/llm/transformers/core.py +129 -36
  22. xinference/model/llm/transformers/multimodal/minicpmv45.py +340 -0
  23. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  24. xinference/model/llm/transformers/utils.py +23 -0
  25. xinference/model/llm/utils.py +48 -32
  26. xinference/model/llm/vllm/core.py +207 -72
  27. xinference/model/utils.py +74 -31
  28. xinference/thirdparty/audiotools/__init__.py +10 -0
  29. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  30. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  31. xinference/thirdparty/audiotools/core/display.py +194 -0
  32. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  33. xinference/thirdparty/audiotools/core/effects.py +647 -0
  34. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  35. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  36. xinference/thirdparty/audiotools/core/playback.py +252 -0
  37. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  38. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  39. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  40. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  41. xinference/thirdparty/audiotools/core/util.py +671 -0
  42. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  43. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  44. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  45. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  46. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  47. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  48. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  49. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  50. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  51. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  52. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  53. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  54. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  55. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  56. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  57. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  58. xinference/thirdparty/audiotools/post.py +140 -0
  59. xinference/thirdparty/audiotools/preference.py +600 -0
  60. xinference/thirdparty/fish_speech/fish_speech/text/chn_text_norm/text.py +1 -1
  61. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  62. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  63. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  81. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  82. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  83. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  84. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  85. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  86. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  87. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  88. xinference/thirdparty/indextts/__init__.py +0 -0
  89. xinference/thirdparty/indextts/cli.py +65 -0
  90. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  91. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  92. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  93. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  94. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  95. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  96. xinference/thirdparty/indextts/gpt/model.py +713 -0
  97. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  98. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  99. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  100. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  101. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  102. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  103. xinference/thirdparty/indextts/infer.py +690 -0
  104. xinference/thirdparty/indextts/infer_v2.py +739 -0
  105. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  106. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  107. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  108. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  109. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  110. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  111. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  112. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  113. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  114. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  115. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  116. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  117. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  118. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  119. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  120. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  121. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  123. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  124. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  142. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  143. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  144. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  145. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  146. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  147. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  148. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  149. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  150. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  151. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  152. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  153. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  154. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  155. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  159. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  160. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  161. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  162. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  168. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  169. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  170. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  171. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  172. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  173. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  174. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  175. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  176. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  178. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  179. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  180. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  181. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  182. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  183. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  184. xinference/thirdparty/indextts/utils/common.py +121 -0
  185. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  186. xinference/thirdparty/indextts/utils/front.py +536 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  249. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  250. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  251. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  252. xinference/thirdparty/indextts/utils/utils.py +93 -0
  253. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  254. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  255. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  256. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  257. xinference/thirdparty/melo/text/chinese_mix.py +2 -2
  258. xinference/types.py +9 -0
  259. xinference/ui/gradio/media_interface.py +66 -8
  260. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  261. xinference/ui/web/ui/build/index.html +1 -1
  262. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  263. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  264. xinference/ui/web/ui/build/static/js/main.45e78536.js +3 -0
  265. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.45e78536.js.LICENSE.txt} +0 -7
  266. xinference/ui/web/ui/build/static/js/main.45e78536.js.map +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ea2a26361204e70cf1018d6990fb6354bed82b3ac69690391e0f100385e7abb7.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  284. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  285. xinference/ui/web/ui/package-lock.json +0 -34
  286. xinference/ui/web/ui/package.json +0 -1
  287. xinference/ui/web/ui/src/locales/en.json +9 -3
  288. xinference/ui/web/ui/src/locales/ja.json +9 -3
  289. xinference/ui/web/ui/src/locales/ko.json +9 -3
  290. xinference/ui/web/ui/src/locales/zh.json +9 -3
  291. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/METADATA +24 -6
  292. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/RECORD +296 -77
  293. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  294. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  295. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  296. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  302. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  313. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  314. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  315. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  316. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  317. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  318. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  319. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  320. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  321. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  322. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  323. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  324. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  325. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/WHEEL +0 -0
  326. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/entry_points.txt +0 -0
  327. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/licenses/LICENSE +0 -0
  328. {xinference-1.10.0.dist-info → xinference-1.11.0.dist-info}/top_level.txt +0 -0
xinference/thirdparty/indextts/gpt/transformers_beam_search.py (new file)
@@ -0,0 +1,1013 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+ from collections import UserDict
18
+ from typing import Dict, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from transformers.utils import add_start_docstrings
24
+ from transformers.generation.beam_constraints import Constraint, ConstraintListState
25
+
26
+
27
+ PROCESS_INPUTS_DOCSTRING = r"""
28
+ Args:
29
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
30
+ Indices of input sequence tokens in the vocabulary.
31
+
32
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
33
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
34
+
35
+ [What are input IDs?](../glossary#input-ids)
36
+ next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
37
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
38
+ next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
39
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
40
+ next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
41
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
42
+ pad_token_id (`int`, *optional*):
43
+ The id of the *padding* token.
44
+ eos_token_id (`Union[int, List[int]]`, *optional*):
45
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
46
+ beam_indices (`torch.LongTensor`, *optional*):
47
+ Beam indices indicating to which beam hypothesis each token correspond.
48
+ group_index (`int`, *optional*):
49
+ The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
50
+
51
+ Return:
52
+ `UserDict`: A dictionary composed of the fields as defined above:
53
+
54
+ - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
55
+ non-finished beams.
56
+ - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
57
+ to the non-finished beam_hypotheses.
58
+ - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
59
+ indicating to which beam the next tokens shall be added.
60
+
61
+ """
62
+
63
+ FINALIZE_INPUTS_DOCSTRING = r"""
64
+ Args:
65
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
66
+ Indices of input sequence tokens in the vocabulary.
67
+
68
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
69
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
70
+
71
+ [What are input IDs?](../glossary#input-ids)
72
+ final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
73
+ The final scores of all non-finished beams.
74
+ final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
75
+ The last tokens to be added to the non-finished beam_hypotheses.
76
+ final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
77
+ The beam indices indicating to which beam the `final_beam_tokens` shall be added.
78
+ pad_token_id (`int`, *optional*):
79
+ The id of the *padding* token.
80
+ eos_token_id (`Union[int, List[int]]`, *optional*):
81
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
82
+
83
+ Return:
84
+ `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
85
+ The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
86
+ due to the `eos_token_id`.
87
+
88
+ """
89
+
90
+
91
+ class BeamScorer(ABC):
92
+ """
93
+ Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
94
+ [`~PreTrainedModel.beam_sample`].
95
+ """
96
+
97
+ @abstractmethod
98
+ @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
99
+ def process(
100
+ self,
101
+ input_ids: torch.LongTensor,
102
+ next_scores: torch.FloatTensor,
103
+ next_tokens: torch.LongTensor,
104
+ next_indices: torch.LongTensor,
105
+ **kwargs,
106
+ ) -> Tuple[torch.Tensor]:
107
+ raise NotImplementedError("This is an abstract method.")
108
+
109
+ @abstractmethod
110
+ @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
111
+ def finalize(
112
+ self,
113
+ input_ids: torch.LongTensor,
114
+ next_scores: torch.FloatTensor,
115
+ next_tokens: torch.LongTensor,
116
+ next_indices: torch.LongTensor,
117
+ max_length: int,
118
+ **kwargs,
119
+ ) -> torch.LongTensor:
120
+ raise NotImplementedError("This is an abstract method.")
121
+
122
+
123
+ class BeamSearchScorer(BeamScorer):
124
+ r"""
125
+ [`BeamScorer`] implementing standard beam search decoding.
126
+
127
+ Adapted in part from [Facebook's XLM beam search
128
+ code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
129
+
130
+ Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
131
+ implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
132
+
133
+ Args:
134
+ batch_size (`int`):
135
+ Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
136
+ num_beams (`int`):
137
+ Number of beams for beam search.
138
+ device (`torch.device`):
139
+ Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
140
+ allocated.
141
+ length_penalty (`float`, *optional*, defaults to 1.0):
142
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
143
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
144
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
145
+ `length_penalty` < 0.0 encourages shorter sequences.
146
+ do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
147
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
148
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
149
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
150
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
151
+ beam search algorithm).
152
+ num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
153
+ The number of beam hypotheses that shall be returned upon calling
154
+ [`~transformers.BeamSearchScorer.finalize`].
155
+ num_beam_groups (`int`, *optional*, defaults to 1):
156
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
157
+ See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
158
+ max_length (`int`, *optional*):
159
+ The maximum length of the sequence to be generated.
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ batch_size: int,
165
+ num_beams: int,
166
+ device: torch.device,
167
+ length_penalty: Optional[float] = 1.0,
168
+ do_early_stopping: Optional[Union[bool, str]] = False,
169
+ num_beam_hyps_to_keep: Optional[int] = 1,
170
+ num_beam_groups: Optional[int] = 1,
171
+ max_length: Optional[int] = None,
172
+ ):
173
+ self.num_beams = num_beams
174
+ self.device = device
175
+ self.length_penalty = length_penalty
176
+ self.do_early_stopping = do_early_stopping
177
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
178
+ self.num_beam_groups = num_beam_groups
179
+ self.group_size = self.num_beams // self.num_beam_groups
180
+
181
+ self._is_init = False
182
+ # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
183
+ # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
184
+ self._beam_hyps = [
185
+ BeamHypotheses(
186
+ num_beams=self.group_size,
187
+ length_penalty=self.length_penalty,
188
+ early_stopping=self.do_early_stopping,
189
+ max_length=max_length,
190
+ )
191
+ for _ in range(batch_size * self.num_beam_groups)
192
+ ]
193
+ # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
194
+ # in the i-th mini-batch is complete.
195
+ self._done = torch.tensor(
196
+ [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
197
+ )
198
+
199
+ if not isinstance(num_beams, int) or num_beams <= 1:
200
+ raise ValueError(
201
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
202
+ " one should make use of `greedy_search` instead."
203
+ )
204
+
205
+ if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
206
+ raise ValueError(
207
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
208
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
209
+ )
210
+
211
+ @property
212
+ def is_done(self) -> bool:
213
+ return self._done.all()
214
+
215
+ def process(
216
+ self,
217
+ input_ids: torch.LongTensor,
218
+ next_scores: torch.FloatTensor,
219
+ next_tokens: torch.LongTensor,
220
+ next_indices: torch.LongTensor,
221
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
222
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
223
+ beam_indices: Optional[torch.LongTensor] = None,
224
+ group_index: Optional[int] = 0,
225
+ decoder_prompt_len: Optional[int] = 0,
226
+ ) -> Dict[str, torch.Tensor]:
227
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
228
+ cur_len = input_ids.shape[-1] + 1
229
+ batch_size = len(self._beam_hyps) // self.num_beam_groups
230
+
231
+ if not (batch_size == (input_ids.shape[0] // self.group_size)):
232
+ if self.num_beam_groups > 1:
233
+ raise ValueError(
234
+ f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
235
+ f"size of {self.group_size} is expected by the beam scorer."
236
+ )
237
+ else:
238
+ raise ValueError(
239
+ f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
240
+ f"{self.group_size} is expected by the beam scorer."
241
+ )
242
+
243
+ device = input_ids.device
244
+ next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
245
+ next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
246
+ next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
247
+
248
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
249
+ if isinstance(eos_token_id, int):
250
+ eos_token_id = [eos_token_id]
251
+ eos_token_id = torch.tensor(eos_token_id)
252
+
253
+ for batch_idx in range(batch_size):
254
+ batch_group_idx = batch_idx * self.num_beam_groups + group_index
255
+ if self._done[batch_group_idx]:
256
+ if self.num_beams < len(self._beam_hyps[batch_group_idx]):
257
+ raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
258
+ if eos_token_id is None or pad_token_id is None:
259
+ raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
260
+ # pad the batch
261
+ next_beam_scores[batch_idx, :] = 0
262
+ next_beam_tokens[batch_idx, :] = pad_token_id
263
+ next_beam_indices[batch_idx, :] = 0
264
+ continue
265
+
266
+ # next tokens for this sentence
267
+ beam_idx = 0
268
+ for beam_token_rank, (next_token, next_score, next_index) in enumerate(
269
+ zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
270
+ ):
271
+ batch_beam_idx = batch_idx * self.group_size + next_index
272
+ # add to generated hypotheses if end of sentence
273
+ if (eos_token_id is not None) and (next_token.item() in eos_token_id):
274
+ # if beam_token does not belong to top num_beams tokens, it should not be added
275
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
276
+ if is_beam_token_worse_than_top_num_beams:
277
+ continue
278
+ if beam_indices is not None:
279
+ beam_index = beam_indices[batch_beam_idx]
280
+ beam_index = beam_index + (batch_beam_idx,)
281
+ else:
282
+ beam_index = None
283
+
284
+ self._beam_hyps[batch_group_idx].add(
285
+ input_ids[batch_beam_idx].clone(),
286
+ next_score.item(),
287
+ beam_indices=beam_index,
288
+ generated_len=cur_len - decoder_prompt_len,
289
+ )
290
+ else:
291
+ # add next predicted token since it is not eos_token
292
+ next_beam_scores[batch_idx, beam_idx] = next_score
293
+ next_beam_tokens[batch_idx, beam_idx] = next_token
294
+ next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
295
+ beam_idx += 1
296
+
297
+ # once the beam for next step is full, don't add more tokens to it.
298
+ if beam_idx == self.group_size:
299
+ break
300
+
301
+ if beam_idx < self.group_size:
302
+ raise ValueError(
303
+ f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
304
+ f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
305
+ )
306
+
307
+ # Check if we are done so that we can save a pad step if all(done)
308
+ self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
309
+ next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
310
+ )
311
+
312
+ return UserDict(
313
+ {
314
+ "next_beam_scores": next_beam_scores.view(-1),
315
+ "next_beam_tokens": next_beam_tokens.view(-1),
316
+ "next_beam_indices": next_beam_indices.view(-1),
317
+ }
318
+ )
319
+
320
+ def finalize(
321
+ self,
322
+ input_ids: torch.LongTensor,
323
+ final_beam_scores: torch.FloatTensor,
324
+ final_beam_tokens: torch.LongTensor,
325
+ final_beam_indices: torch.LongTensor,
326
+ max_length: int,
327
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
328
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
329
+ beam_indices: Optional[torch.LongTensor] = None,
330
+ decoder_prompt_len: Optional[int] = 0,
331
+ ) -> Tuple[torch.LongTensor]:
332
+ batch_size = len(self._beam_hyps) // self.num_beam_groups
333
+
334
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
335
+ if isinstance(eos_token_id, int):
336
+ eos_token_id = [eos_token_id]
337
+ eos_token_id = torch.tensor(eos_token_id)
338
+
339
+ # finalize all open beam hypotheses and add to generated hypotheses
340
+ for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
341
+ if self._done[batch_group_idx]:
342
+ continue
343
+
344
+ # all open beam hypotheses are added to the beam hypothesis
345
+ # beam hypothesis class automatically keeps the best beams
346
+ for index_per_group in range(self.group_size):
347
+ batch_beam_idx = batch_group_idx * self.group_size + index_per_group
348
+ final_score = final_beam_scores[batch_beam_idx].item()
349
+ final_tokens = input_ids[batch_beam_idx]
350
+ beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
351
+ generated_len = final_tokens.shape[-1] - decoder_prompt_len
352
+ beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
353
+
354
+ # select the best hypotheses
355
+ sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
356
+ best = []
357
+ best_indices = []
358
+ best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
359
+
360
+ # retrieve best hypotheses
361
+ for i in range(batch_size):
362
+ beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
363
+ candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
364
+ sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
365
+ for j in range(self.num_beam_hyps_to_keep):
366
+ best_hyp_tuple = sorted_hyps.pop()
367
+ best_score = best_hyp_tuple[0]
368
+ best_hyp = best_hyp_tuple[1]
369
+ best_index = best_hyp_tuple[2]
370
+ sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
371
+
372
+ # append hyp to lists
373
+ best.append(best_hyp)
374
+
375
+ # append indices to list
376
+ best_indices.append(best_index)
377
+
378
+ best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
379
+
380
+ # prepare for adding eos
381
+ sent_lengths_max = sent_lengths.max().item() + 1
382
+ sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
383
+ decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
384
+
385
+ if len(best_indices) > 0 and best_indices[0] is not None:
386
+ indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
387
+ else:
388
+ indices = None
389
+
390
+ # shorter batches are padded if needed
391
+ if sent_lengths.min().item() != sent_lengths.max().item():
392
+ if pad_token_id is None:
393
+ raise ValueError("`pad_token_id` has to be defined")
394
+ decoded.fill_(pad_token_id)
395
+
396
+ if indices is not None:
397
+ indices.fill_(-1)
398
+
399
+ # fill with hypotheses and eos_token_id if the latter fits in
400
+ for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
401
+ decoded[i, : sent_lengths[i]] = hypo
402
+
403
+ if indices is not None:
404
+ indices[i, : len(best_idx)] = torch.tensor(best_idx)
405
+
406
+ if sent_lengths[i] < sent_max_len:
407
+ # inserting only the first eos_token_id
408
+ decoded[i, sent_lengths[i]] = eos_token_id[0]
409
+
410
+ return UserDict(
411
+ {
412
+ "sequences": decoded,
413
+ "sequence_scores": best_scores,
414
+ "beam_indices": indices,
415
+ }
416
+ )
417
+
418
+
419
+ class ConstrainedBeamSearchScorer(BeamScorer):
420
+ r"""
421
+ [`BeamScorer`] implementing constrained beam search decoding.
422
+
423
+
424
+ Args:
425
+ batch_size (`int`):
426
+ Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
427
+ num_beams (`int`):
428
+ Number of beams for beam search.
429
+ constraints (`List[Constraint]`):
430
+ A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
431
+ output. For more information, the documentation of [`Constraint`] should be read.
432
+ device (`torch.device`):
433
+ Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
434
+ allocated.
435
+ length_penalty (`float`, *optional*, defaults to 1.0):
436
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
437
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
438
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
439
+ `length_penalty` < 0.0 encourages shorter sequences.
440
+ do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
441
+ Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
442
+ `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
443
+ heuristic is applied and the generation stops when is it very unlikely to find better candidates;
444
+ `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
445
+ beam search algorithm).
446
+ num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
447
+ The number of beam hypotheses that shall be returned upon calling
448
+ [`~transformers.BeamSearchScorer.finalize`].
449
+ num_beam_groups (`int`, *optional*, defaults to 1):
450
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
451
+ See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
452
+ max_length (`int`, *optional*):
453
+ The maximum length of the sequence to be generated.
454
+ """
455
+
456
+ def __init__(
457
+ self,
458
+ batch_size: int,
459
+ num_beams: int,
460
+ constraints: List[Constraint],
461
+ device: torch.device,
462
+ length_penalty: Optional[float] = 1.0,
463
+ do_early_stopping: Optional[Union[bool, str]] = False,
464
+ num_beam_hyps_to_keep: Optional[int] = 1,
465
+ num_beam_groups: Optional[int] = 1,
466
+ max_length: Optional[int] = None,
467
+ ):
468
+ self.num_beams = num_beams
469
+ self.device = device
470
+ self.length_penalty = length_penalty
471
+ self.do_early_stopping = do_early_stopping
472
+ self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
473
+ self.num_beam_groups = num_beam_groups
474
+ self.group_size = self.num_beams // self.num_beam_groups
475
+ self.constraints = constraints
476
+
477
+ self._is_init = False
478
+ self._beam_hyps = [
479
+ BeamHypotheses(
480
+ num_beams=self.num_beams,
481
+ length_penalty=self.length_penalty,
482
+ early_stopping=self.do_early_stopping,
483
+ max_length=max_length,
484
+ )
485
+ for _ in range(batch_size)
486
+ ]
487
+ self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
488
+
489
+ if not isinstance(num_beams, int) or num_beams <= 1:
490
+ raise ValueError(
491
+ f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
492
+ " one should make use of `greedy_search` instead."
493
+ )
494
+
495
+ if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
496
+ raise ValueError(
497
+ "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
498
+ f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
499
+ )
500
+
501
+ @property
502
+ def is_done(self) -> bool:
503
+ return self._done.all()
504
+
505
+ def make_constraint_states(self, n):
506
+ return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
507
+
508
+ def check_completes_constraints(self, sequence):
509
+ new_state = self.make_constraint_states(1)[0]
510
+ new_state.reset(sequence)
511
+ return new_state.completed
512
+
513
+ def process(
514
+ self,
515
+ input_ids: torch.LongTensor,
516
+ next_scores: torch.FloatTensor,
517
+ next_tokens: torch.LongTensor,
518
+ next_indices: torch.LongTensor,
519
+ scores_for_all_vocab: torch.FloatTensor,
520
+ pad_token_id: Optional[Union[int, torch.Tensor]] = None,
521
+ eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
522
+ beam_indices: Optional[torch.LongTensor] = None,
523
+ decoder_prompt_len: Optional[int] = 0,
524
+ ) -> Tuple[torch.Tensor]:
525
+ r"""
526
+ Args:
527
+ input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
528
+ Indices of input sequence tokens in the vocabulary.
529
+
530
+ Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
531
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
532
+
533
+ [What are input IDs?](../glossary#input-ids)
534
+ next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
535
+ Current scores of the top `2 * num_beams` non-finished beam hypotheses.
536
+ next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
537
+ `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
538
+ next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
539
+ Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
540
+ scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
541
+ The scores of all tokens in the vocabulary for each of the beam hypotheses.
542
+ pad_token_id (`int`, *optional*):
543
+ The id of the *padding* token.
544
+ eos_token_id (`Union[int, List[int]]`, *optional*):
545
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
546
+ beam_indices (`torch.LongTensor`, *optional*):
547
+ Beam indices indicating to which beam hypothesis each token correspond.
548
+ decoder_prompt_len (`int`, *optional*):
549
+ The length of prompt that is included in the input to decoder.
550
+ Return:
551
+ `UserDict`: A dictionary composed of the fields as defined above:
552
+
553
+ - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
554
+ all
555
+ non-finished beams.
556
+
557
+ - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
558
+ added
559
+ to the non-finished beam_hypotheses.
560
+ - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
561
+ indicating to which beam the next tokens shall be added.
562
+ """
563
+
564
+ # add up to the length which the next_scores is calculated on (including decoder prompt)
565
+ cur_len = input_ids.shape[-1] + 1
566
+ batch_size = len(self._beam_hyps)
567
+ if not (batch_size == (input_ids.shape[0] // self.group_size)):
568
+ if self.num_beam_groups > 1:
569
+ raise ValueError(
570
+ f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
571
+ f"size of {self.group_size} is expected by the beam scorer."
572
+ )
573
+ else:
574
+ raise ValueError(
575
+ f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
576
+ f"{self.group_size} is expected by the beam scorer."
577
+ )
578
+
579
+ device = input_ids.device
580
+
581
+ next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
582
+ next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
583
+ next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
584
+
585
+ if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
586
+ if isinstance(eos_token_id, int):
587
+ eos_token_id = [eos_token_id]
588
+ eos_token_id = torch.tensor(eos_token_id)
589
+
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                if self.num_beams < len(beam_hyp):
+                    raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
+                if eos_token_id is None or pad_token_id is None:
+                    raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
+                # pad the batch
+                next_beam_scores[batch_idx, :] = 0
+                next_beam_tokens[batch_idx, :] = pad_token_id
+                next_beam_indices[batch_idx, :] = 0
+                continue
+
+            # next tokens for this sentence.
+            beam_idx = 0
+            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
+                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
+            ):
+                batch_beam_idx = batch_idx * self.group_size + next_index
+                # add to generated hypotheses if end of sentence
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
+                    # if beam_token does not belong to top num_beams tokens, it should not be added
+                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
+                    if is_beam_token_worse_than_top_num_beams:
+                        continue
+
+                    completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].cpu().tolist())
+                    if completes_constraint:
+                        if beam_indices is not None:
+                            beam_index = beam_indices[batch_beam_idx]
+                            beam_index = beam_index + (batch_beam_idx,)
+                        else:
+                            beam_index = None
+
+                        beam_hyp.add(
+                            input_ids[batch_beam_idx].clone(),
+                            next_score.item(),
+                            beam_indices=beam_index,
+                            generated_len=cur_len - decoder_prompt_len,
+                        )
+                else:
+                    # add next predicted token since it is not eos_token
+                    next_beam_scores[batch_idx, beam_idx] = next_score
+                    next_beam_tokens[batch_idx, beam_idx] = next_token
+                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
+                    beam_idx += 1
+
+                # once the beam for next step is full, don't add more tokens to it.
+                if beam_idx == self.group_size:
+                    break
+
+            new_scores, new_tokens, new_indices = self.step_sentence_constraint(
+                batch_idx,
+                input_ids,
+                scores_for_all_vocab,
+                next_beam_scores[batch_idx],
+                next_beam_tokens[batch_idx],
+                next_beam_indices[batch_idx],
+            )
+
+            next_beam_scores[batch_idx] = new_scores
+            next_beam_tokens[batch_idx] = new_tokens
+            next_beam_indices[batch_idx] = new_indices
+
+            if beam_idx < self.group_size:
+                raise ValueError(
+                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
+                    f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+                )
+
+            # Check if we are done so that we can save a pad step if all(done)
+            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
+            )
+
+        return UserDict(
+            {
+                "next_beam_scores": next_beam_scores.view(-1),
+                "next_beam_tokens": next_beam_tokens.view(-1),
+                "next_beam_indices": next_beam_indices.view(-1),
+            }
+        )
+
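+    # How a decoding loop typically consumes the UserDict returned by `process` above (illustrative
+    # sketch only; the actual generation loop lives in the caller, and the variable names here are
+    # placeholders):
+    #
+    #     beam_outputs = scorer.process(input_ids, next_scores, next_tokens, next_indices, scores_for_all_vocab)
+    #     beam_scores = beam_outputs["next_beam_scores"]
+    #     beam_next_tokens = beam_outputs["next_beam_tokens"]
+    #     beam_idx = beam_outputs["next_beam_indices"]
+    #     input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+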
+    def step_sentence_constraint(
+        self,
+        batch_idx: int,
+        input_ids: torch.LongTensor,
+        vocab_scores: torch.FloatTensor,
+        sent_beam_scores: torch.FloatTensor,
+        sent_beam_tokens: torch.LongTensor,
+        sent_beam_indices: torch.LongTensor,
+        push_progress: bool = False,
+    ):
+        # sent_beam_tokens are the next {num_beams} tokens that are under consideration for this beam
+        # (candidate next tokens)
+
+        # 1. Adding "advance_tokens"
+        # using ConstraintListState.advance(), we propose new tokens to be added into this "candidate list" that will
+        # advance us in fulfilling the constraints.
+
+        # 2. Selecting the best candidates such that we end up with the highest probable candidates
+        # that fulfill our constraints.
+
+        orig_len = sent_beam_indices.size(0)
+        device = sent_beam_indices.device
+
+        # initialize states
+        topk_constraint_states = self.make_constraint_states(orig_len)
+        advance_constraint_states = self.make_constraint_states(orig_len)
+
+        sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
+        this_batch_input_ids = input_ids[sidx:eidx]
+        this_batch_token_scores = vocab_scores[sidx:eidx]
+        full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
+
+        # need to make new hypotheses that advance the constraints
+        track_new = {
+            "new_seqs": full_hypotheses.tolist(),
+            "new_states": [],
+            "new_indices": [],
+            "new_tokens": [],
+            "new_scores": [],
+        }
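+        # For every beam hypothesis in this batch entry, propose constraint-advancing continuations and
+        # record them alongside the model's own top-k candidates so that both can compete in the
+        # re-ranking further below.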
+        for seq_idx, pre_seq in enumerate(this_batch_input_ids):
+            # pre_seq = ith sequence generated before this step.
+
+            # input_ids -> (topk) generic beam search best model next tokens
+            #           -> (advance) constraints forcing the next token
+            # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
+            # hypotheses.
+
+            topk_state = topk_constraint_states[seq_idx]
+            topk_state.reset(full_hypotheses[seq_idx].cpu().tolist())
+
+            advance_state = advance_constraint_states[seq_idx]
+            advance_state.reset(pre_seq.cpu().tolist())
+
+            if not advance_state.completed:
+                advance_tokens = torch.LongTensor(advance_state.advance()).to(device)
+                for advance_token in advance_tokens:
+                    # since adding each `advance_token` leads to a different hypothesis, create new state instance.
+                    new_state = advance_state.copy(stateful=True)
+                    new_state.add(advance_token.cpu().tolist())
+
+                    advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).cpu().tolist()
+                    if advance_seq not in track_new["new_seqs"]:
+                        # prevent duplicates, which are basically bound to happen in this process.
+                        track_new["new_seqs"].append(advance_seq)
+                        track_new["new_indices"].append(sidx + seq_idx)  # idx -> global idx across all the batches
+                        track_new["new_tokens"].append(advance_token)
+                        track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
+                        track_new["new_states"].append(new_state)
+            elif push_progress:
+                # Basically, `sent_beam_indices` often chooses very few of the generated sequences in `input_ids`
+                # that actually fulfill our constraints. For example, let constraints == ["loves pies"] and
+
+                #     pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
+
+                # Without this step, if `sent_beam_indices` is something like [1,1], then
+                #     1. `pre_seq_1` won't be added to the list of (topk) hypotheses since it's not in the indices and
+                #     2. it won't be added to the list of (advance) hypotheses since it's completed already. (this is
+                #        the else part of `if constraints_completed[seq_idx]`)
+                #     3. it ends up simply getting removed from consideration.
+
+                # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyway,
+                # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
+                # search times, since completed sequences keep getting removed after all this effort for constrained
+                # generation.
+
+                # Here, we basically take `pre_seq_1` and "push" it into the considered list of hypotheses, by simply
+                # appending the next likely token in the vocabulary and adding it to the list of hypotheses.
+
+                new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0)  # some next probable token
+                advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
+
+                advance_state = advance_constraint_states[seq_idx]
+
+                advance_seq = advance_seq.cpu().tolist()
+
+                advance_state.reset(advance_seq)
+                if advance_seq not in track_new["new_seqs"]:
+                    # but still don't want to have duplicates
+                    track_new["new_seqs"].append(advance_seq)
+                    track_new["new_indices"].append(seq_idx)
+                    track_new["new_tokens"].append(new_token)
+                    track_new["new_scores"].append(new_score)
+                    track_new["new_states"].append(advance_state)
+
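+        # Merge the model's own top-k candidates with the constraint-advancing candidates collected above,
+        # then re-rank everything by how many constraints each hypothesis has completed (its "bank") and,
+        # within a bank, by score.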
+        if len(track_new["new_indices"]) > 0:
+            new_indices = torch.tensor(track_new["new_indices"]).to(device)
+            new_tokens = torch.stack(track_new["new_tokens"]).to(device)
+            new_scores = torch.stack(track_new["new_scores"]).to(device)
+
+            all_states = topk_constraint_states + track_new["new_states"]
+            all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
+            all_scores = torch.cat((sent_beam_scores, new_scores), -1)
+            all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device)
+
+            zipped = all_banks * 100 + all_scores
+            indices = zipped.sort(descending=True).indices
+            sorted_banks = all_banks[indices]
+
+            # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
+
+            counter = -1
+            cur_bank = sorted_banks[0]
+            increments = []
+            for bank in sorted_banks:
+                if bank == cur_bank:
+                    counter += 1
+                else:
+                    counter = 0
+                    cur_bank = bank
+                increments.append(counter)
+            rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
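+            # Example: sorted_banks == [2, 2, 1, 0] yields increments == [0, 1, 0, 0] and
+            # rearrangers == [0, 2, 3, 1], i.e. the best candidate of each bank (most constraints
+            # completed first) is taken before any runner-up within a bank.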
+
+            indices = indices[rearrangers][:orig_len]
+
+            sent_beam_scores = all_scores[indices]
+            sent_beam_tokens = all_tokens[indices]
+            sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
+
+        return sent_beam_scores, sent_beam_tokens, sent_beam_indices
+
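+    # finalize() flushes the still-open beams once generation stops: hypotheses that satisfy all
+    # constraints are preferred, the best `num_beam_hyps_to_keep` per batch entry are selected, and the
+    # returned sequences are padded to a common length.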
+    def finalize(
+        self,
+        input_ids: torch.LongTensor,
+        final_beam_scores: torch.FloatTensor,
+        final_beam_tokens: torch.LongTensor,
+        final_beam_indices: torch.LongTensor,
+        max_length: int,
+        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
+        eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
+        decoder_prompt_len: Optional[int] = 0,
+    ) -> Tuple[torch.LongTensor]:
+        batch_size = len(self._beam_hyps)
+
+        if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
+            if isinstance(eos_token_id, int):
+                eos_token_id = [eos_token_id]
+            eos_token_id = torch.tensor(eos_token_id)
+
+        # finalize all open beam hypotheses and add to generated hypotheses
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                continue
+
+            # all open beam hypotheses are added to the beam hypothesis
+            # beam hypothesis class automatically keeps the best beams
+
+            ids_collect = []
+            for beam_id in range(self.num_beams):
+                batch_beam_idx = batch_idx * self.num_beams + beam_id
+                final_score = final_beam_scores[batch_beam_idx].item()
+                final_tokens = input_ids[batch_beam_idx]
+
+                completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
+                if completes_constraint:
+                    beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
+                    generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                    beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
+                    ids_collect.append(beam_id)
+
+            # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
+            # generation. In these cases we simply return the highest scoring outputs.
+            if len(ids_collect) < self.num_beam_hyps_to_keep:
+                for beam_id in range(self.num_beams):
+                    if beam_id not in ids_collect:
+                        batch_beam_idx = batch_idx * self.num_beams + beam_id
+                        final_score = final_beam_scores[batch_beam_idx].item()
+                        final_tokens = input_ids[batch_beam_idx]
+                        generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                        beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
+                    if len(ids_collect) >= self.num_beam_hyps_to_keep:
+                        break
+
+        # select the best hypotheses
+        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
+        best = []
+        best_indices = []
+        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
+
+        # retrieve best hypotheses
+        for i, beam_hyp in enumerate(self._beam_hyps):
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
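+            # beams are stored as (score, hypothesis, beam_indices) tuples sorted here in ascending
+            # score order, so pop() below returns the current best hypothesis first.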
+            for j in range(self.num_beam_hyps_to_keep):
+                best_hyp_tuple = sorted_hyps.pop()
+                best_score = best_hyp_tuple[0]
+                best_hyp = best_hyp_tuple[1]
+                best_index = best_hyp_tuple[2]
+                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+
+                # append to lists
+                best.append(best_hyp)
+
+                # append indices to list
+                best_indices.append(best_index)
+
+                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
+
+        # prepare for adding eos
+        sent_lengths_max = sent_lengths.max().item() + 1
+
+        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
+        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+
+        if len(best_indices) > 0 and best_indices[0] is not None:
+            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        else:
+            indices = None
+
+        # shorter batches are padded if needed
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            if pad_token_id is None:
+                raise ValueError("`pad_token_id` has to be defined")
+            decoded.fill_(pad_token_id)
+
+        if indices is not None:
+            indices.fill_(-1)
+
+        # fill with hypotheses and eos_token_id if the latter fits in
+        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
+            decoded[i, : sent_lengths[i]] = hypo
+
+            if indices is not None:
+                indices[i, : len(best_idx)] = torch.tensor(best_idx)
+
+            if sent_lengths[i] < sent_max_len:
+                # inserting only the first eos_token_id
+                decoded[i, sent_lengths[i]] = eos_token_id[0]
+
+        return UserDict(
+            {
+                "sequences": decoded,
+                "sequence_scores": best_scores,
+                "beam_indices": indices,
+            }
+        )
+
+
+class BeamHypotheses:
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.max_length = max_length
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+        if not isinstance(self.early_stopping, bool) and self.max_length is None:
+            raise ValueError(
+                "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
+                " BeamScorer class instance at initialization time."
+            )
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(
+        self,
+        hyp: torch.LongTensor,
+        sum_logprobs: float,
+        beam_indices: Optional[torch.LongTensor] = None,
+        generated_len: Optional[int] = None,
+    ):
+        """
+        Add a new hypothesis to the list.
+        """
+        if generated_len is not None:
+            score = sum_logprobs / (generated_len**self.length_penalty)
+        # This 'else' case exists for backward compatibility
+        else:
+            score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
+
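+        # The raw log-prob sum is length-normalized, e.g. sum_logprobs = -6.0 over 4 generated tokens
+        # gives -6.0 / 4**1.0 = -1.5 with length_penalty=1.0 but -6.0 / 4**2.0 = -0.375 with
+        # length_penalty=2.0; since log-probs are negative, values > 0.0 promote longer sequences and
+        # values < 0.0 encourage shorter ones.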
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp, beam_indices))
+            if len(self) > self.num_beams:
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
+                del self.beams[sorted_next_scores[0][1]]
+                self.worst_score = sorted_next_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
+        """
+        If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
+        one in the heap, then we are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+
+        # `True`: stop as soon as at least `num_beams` hypotheses are finished
+        if self.early_stopping is True:
+            return True
+        # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
+        # when `length_penalty` is positive. See the discussion below for more details.
+        # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+        elif self.early_stopping is False:
+            highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
+            return ret
+        # `"never"`: compute the best possible score, depending on the sign of `length_penalty`
+        else:
+            # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min
+            # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
+            # its max this way
+            if self.length_penalty > 0.0:
+                if self.max_length <= decoder_prompt_len:
+                    raise ValueError("max_length is not larger than decoder prompt length")
+                highest_attainable_score = (
+                    best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+                )
+            # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
+            else:
+                highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
+            ret = self.worst_score >= highest_attainable_score
+            return ret