xinference 1.10.0__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic. Click here for more details.

Files changed (317)
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +11 -28
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/core/supervisor.py +87 -53
  6. xinference/core/worker.py +10 -0
  7. xinference/deploy/cmdline.py +15 -0
  8. xinference/model/audio/core.py +21 -6
  9. xinference/model/audio/indextts2.py +166 -0
  10. xinference/model/audio/model_spec.json +38 -1
  11. xinference/model/image/model_spec.json +69 -0
  12. xinference/model/image/stable_diffusion/core.py +13 -4
  13. xinference/model/llm/__init__.py +4 -0
  14. xinference/model/llm/llm_family.json +464 -2
  15. xinference/model/llm/sglang/core.py +30 -11
  16. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +94 -32
  17. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  18. xinference/model/llm/utils.py +12 -9
  19. xinference/model/llm/vllm/core.py +93 -17
  20. xinference/thirdparty/audiotools/__init__.py +10 -0
  21. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  22. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  23. xinference/thirdparty/audiotools/core/display.py +194 -0
  24. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  25. xinference/thirdparty/audiotools/core/effects.py +647 -0
  26. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  27. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  28. xinference/thirdparty/audiotools/core/playback.py +252 -0
  29. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  30. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  31. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  32. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  33. xinference/thirdparty/audiotools/core/util.py +671 -0
  34. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  35. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  36. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  37. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  38. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  39. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  40. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  41. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  42. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  43. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  44. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  45. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  46. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  47. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  48. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  49. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  50. xinference/thirdparty/audiotools/post.py +140 -0
  51. xinference/thirdparty/audiotools/preference.py +600 -0
  52. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  53. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  54. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  55. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  56. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  57. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  58. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  59. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  60. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  61. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  62. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  63. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  64. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  65. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  66. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  67. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  68. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  69. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  70. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  72. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  73. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  74. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  75. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  76. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  77. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  78. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  79. xinference/thirdparty/indextts/__init__.py +0 -0
  80. xinference/thirdparty/indextts/cli.py +65 -0
  81. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  82. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  83. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  84. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  85. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  86. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  87. xinference/thirdparty/indextts/gpt/model.py +713 -0
  88. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  89. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  90. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  91. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  92. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  93. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  94. xinference/thirdparty/indextts/infer.py +690 -0
  95. xinference/thirdparty/indextts/infer_v2.py +739 -0
  96. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  97. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  98. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  99. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  100. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  101. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  102. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  103. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  104. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  105. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  106. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  107. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  108. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  109. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  110. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  111. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  112. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  113. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  114. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  115. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  116. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  117. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  118. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  119. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  120. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  121. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  122. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  123. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  124. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  125. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  126. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  127. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  128. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  129. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  130. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  131. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  133. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  134. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  135. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  136. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  137. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  138. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  139. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  140. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  141. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  142. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  143. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  144. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  145. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  146. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  147. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  148. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  149. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  150. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  151. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  152. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  153. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  154. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  155. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  156. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  157. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  158. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  159. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  160. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  161. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  162. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  163. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  164. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  165. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  166. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  167. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  168. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  169. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  170. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  171. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  172. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  173. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  174. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  175. xinference/thirdparty/indextts/utils/common.py +121 -0
  176. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  177. xinference/thirdparty/indextts/utils/front.py +536 -0
  178. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  179. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  180. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  181. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  182. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  183. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  184. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  185. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  186. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  187. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  188. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  189. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  190. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  191. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  192. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  193. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  240. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  241. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  242. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  243. xinference/thirdparty/indextts/utils/utils.py +93 -0
  244. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  245. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  246. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  247. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  248. xinference/ui/gradio/media_interface.py +66 -8
  249. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  250. xinference/ui/web/ui/build/index.html +1 -1
  251. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  252. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  253. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  254. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  255. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  256. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  257. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  258. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  259. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  260. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  261. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  262. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  263. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  264. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  265. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  266. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  267. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  268. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  269. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  270. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  271. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  272. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  273. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  274. xinference/ui/web/ui/package-lock.json +0 -34
  275. xinference/ui/web/ui/package.json +0 -1
  276. xinference/ui/web/ui/src/locales/en.json +9 -3
  277. xinference/ui/web/ui/src/locales/ja.json +9 -3
  278. xinference/ui/web/ui/src/locales/ko.json +9 -3
  279. xinference/ui/web/ui/src/locales/zh.json +9 -3
  280. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/METADATA +18 -2
  281. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/RECORD +285 -67
  282. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  283. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  284. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  285. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  290. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  291. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  292. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  293. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  294. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  295. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  296. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  297. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  298. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  299. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  300. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  301. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  302. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  303. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  304. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  305. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  306. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  307. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  308. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  309. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  310. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  311. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  312. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  313. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  314. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  315. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  316. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  317. {xinference-1.10.0.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1247 @@
1
+ import math
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from inspect import isfunction
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange, repeat
9
+ from torch import einsum, nn
10
+
11
# NOTE(review): presumably the default per-head dimension; its use is outside
# this excerpt, confirm against the callers.
DEFAULT_DIM_HEAD = 64

# Record with two fields: 'pre_softmax_attn' and 'post_softmax_attn'.
Intermediates = namedtuple('Intermediates', [
    'pre_softmax_attn',
    'post_softmax_attn'
])

# Record with three fields: 'hiddens', 'attn_intermediates', 'past_key_values'.
# The typename must match the name the class is bound to: with the former
# typename ('Intermediates') the class reported itself as `Intermediates` in
# reprs, and instances could not be pickled because pickle resolves a class by
# its module-level name and would find the other namedtuple instead.
LayerIntermediates = namedtuple('LayerIntermediates', [
    'hiddens',
    'attn_intermediates',
    'past_key_values',
])
23
+
24
+
25
+ # helpers
26
+
27
def exists(val):
    """Return ``True`` for every value except ``None``."""
    if val is None:
        return False
    return True
29
+
30
+
31
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (as judged by ``inspect.isfunction``)
    it is called with no arguments and its result is returned; any other ``d``
    is returned unchanged.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
35
+
36
+
37
def cast_tuple(val, depth):
    """Return ``val`` unchanged if it is a tuple, else a tuple holding ``val`` ``depth`` times."""
    if isinstance(val, tuple):
        # An existing tuple is passed through as-is; its length is not checked.
        return val
    repeated = (val,) * depth
    return repeated
39
+
40
+
41
class always:
    """Callable that returns the value given at construction, whatever it is called with."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *_unused_args, **_unused_kwargs):
        # Arguments are accepted and discarded.
        return self.val
47
+
48
+
49
class not_equals:
    """Callable predicate: ``pred(x)`` evaluates ``x != val`` for the stored ``val``."""

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *_unused_args, **_unused_kwargs):
        # Only the first positional argument takes part in the comparison.
        reference = self.val
        return x != reference
55
+
56
+
57
class equals:
    """Callable predicate: ``pred(x)`` evaluates ``x == val`` for the stored ``val``."""

    def __init__(self, val):
        self.val = val

    def __call__(self, x, *_unused_args, **_unused_kwargs):
        # Only the first positional argument takes part in the comparison.
        reference = self.val
        return x == reference
63
+
64
+
65
def max_neg_value(tensor):
    """Return the negated largest finite value of ``tensor``'s floating-point dtype."""
    dtype_info = torch.finfo(tensor.dtype)
    return -dtype_info.max
67
+
68
+
69
def l2norm(t):
    """Return ``t`` normalized with the L2 norm along its last dimension."""
    normalized = F.normalize(t, dim=-1, p=2)
    return normalized
71
+
72
+
73
+ # init helpers
74
+
75
def init_zero_(layer):
    """Fill ``layer.weight`` and, when it is not ``None``, ``layer.bias`` with zeros in place."""
    nn.init.constant_(layer.weight, 0.)
    bias = layer.bias
    if not exists(bias):
        # Nothing more to reset when the layer has no bias.
        return
    nn.init.constant_(bias, 0.)
79
+
80
+
81
+ # keyword argument helpers
82
+
83
def pick_and_pop(keys, d):
    """Remove each key in ``keys`` from ``d`` and return the removed entries as a new dict.

    ``d`` is mutated in place. A key missing from ``d`` raises ``KeyError``.
    """
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
86
+
87
+
88
def group_dict_by_key(cond, d):
    """Split ``d`` into two dicts according to ``cond(key)``.

    Returns a 2-tuple ``(matching, rest)``: entries whose key makes ``cond``
    truthy go to the first dict, all others to the second.
    """
    matching, rest = {}, {}
    for key, value in d.items():
        target = matching if cond(key) else rest
        target[key] = value
    return matching, rest
95
+
96
+
97
def string_begins_with(prefix, str):
    """Return whether ``str`` starts with ``prefix``.

    The second parameter shadows the ``str`` builtin; the name is part of the
    existing signature and is therefore kept.
    """
    has_prefix = str.startswith(prefix)
    return has_prefix
99
+
100
+
101
def group_by_key_prefix(prefix, d):
    """Split ``d`` into ``(entries whose key starts with prefix, remaining entries)``."""
    key_has_prefix = partial(string_begins_with, prefix)
    return group_dict_by_key(key_has_prefix, d)
103
+
104
+
105
def groupby_prefix_and_trim(prefix, d):
    """Split ``d`` by key prefix and strip the prefix from the matching keys.

    Returns a 2-tuple: a dict of the entries whose key starts with ``prefix``
    (keys with the prefix removed), and a dict of all remaining entries.
    """
    key_has_prefix = partial(string_begins_with, prefix)
    prefixed, remaining = group_dict_by_key(key_has_prefix, d)
    cut = len(prefix)
    trimmed = {key[cut:]: value for key, value in prefixed.items()}
    return trimmed, remaining
109
+
110
+
111
+ # activations
112
+
113
class ReluSquared(nn.Module):
    """Activation module that applies ReLU and squares the result elementwise."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified ** 2
116
+
117
+
118
+ # positional embeddings
119
+
120
class AbsolutePositionalEmbedding(nn.Module):
    """Learned position-embedding table whose output is scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim ** -0.5
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        """Return the scaled embeddings for positions ``0 .. x.shape[1] - 1``.

        Only the size of dimension 1 and the device of ``x`` are read. The
        result has shape ``(1, x.shape[1], dim)``.
        """
        positions = torch.arange(x.shape[1], device=x.device)
        # Add a leading unit axis: (n, d) -> (1, n, d).
        table = self.emb(positions).unsqueeze(0)
        return table * self.scale
131
+
132
+
133
class FixedPositionalEmbedding(nn.Module):
    """Non-learned sinusoidal position table built from inverse frequencies of base 10000."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x, seq_dim=1, offset=0):
        """Return the sin/cos table for ``x.shape[seq_dim]`` positions, shifted by ``offset``.

        Only the size of ``seq_dim`` and the device of ``x`` are read. The
        result has a leading unit axis; its last axis holds all sines followed
        by all cosines.
        """
        length = x.shape[seq_dim]
        positions = torch.arange(length, device=x.device).type_as(self.inv_freq) + offset
        # Outer product of positions and inverse frequencies.
        angles = torch.einsum('i , j -> i j', positions, self.inv_freq)
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        # Add a leading unit axis: (n, d) -> (1, n, d).
        return table.unsqueeze(0)
144
+
145
+
146
class RelativePositionBias(nn.Module):
    """Adds a learned, bucketed relative-position bias to a score tensor.

    Relative positions are mapped to one of ``num_buckets`` buckets, and each
    bucket has a learned value per head in ``relative_attention_bias``.
    """

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128):
        """Map a tensor of relative positions to integer bucket indices.

        Distances below half of the (per-direction) bucket count get their own
        bucket; larger distances share logarithmically spaced buckets, capped
        at the last bucket.
        """
        distance = -relative_position
        if causal:
            # Negative distances are clamped to zero.
            bucket_offset = 0
            distance = torch.max(distance, torch.zeros_like(distance))
        else:
            # Half of the buckets serve each sign of the distance.
            num_buckets //= 2
            bucket_offset = (distance < 0).long() * num_buckets
            distance = torch.abs(distance)

        max_exact = num_buckets // 2
        is_small = distance < max_exact

        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
        val_if_large = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        last_bucket = torch.full_like(val_if_large, num_buckets - 1)
        val_if_large = torch.min(val_if_large, last_bucket)

        return bucket_offset + torch.where(is_small, distance, val_if_large)

    def forward(self, qk_dots):
        """Return ``qk_dots`` plus the scaled bias for its last two dimensions."""
        rows, cols = qk_dots.shape[-2:]
        device = qk_dots.device
        q_pos = torch.arange(rows, dtype=torch.long, device=device)
        k_pos = torch.arange(cols, dtype=torch.long, device=device)
        # rel_pos[i, j] = j - i
        rel_pos = k_pos.unsqueeze(0) - q_pos.unsqueeze(1)
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        values = self.relative_attention_bias(rp_bucket)
        # (i, j, h) -> (1, h, i, j)
        bias = values.permute(2, 0, 1).unsqueeze(0)
        return qk_dots + (bias * self.scale)
187
+
188
+
189
class AlibiPositionalBias(nn.Module):
    """Fixed-slope ALiBi bias: a per-head slope times the key position index.

    The computed bias is cached in a non-persistent buffer and reused for any
    later call whose key length does not exceed the cached one.
    """

    def __init__(self, heads, **kwargs):
        super().__init__()
        self.heads = heads
        head_slopes = rearrange(torch.Tensor(self._get_slopes(heads)), 'h -> () h () ()')
        self.register_buffer('slopes', head_slopes, persistent=False)
        self.register_buffer('bias', None, persistent=False)

    @staticmethod
    def _get_slopes(heads):
        """Return the list of ``heads`` slopes (a geometric sequence per power of two)."""

        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(math.log2(n) - 3)))
            return [start * start ** i for i in range(n)]

        if math.log2(heads).is_integer():
            return get_slopes_power_of_2(heads)

        # Not a power of two: take the slopes of the power of two below, then
        # fill up with every other slope of the power of two above.
        base = 2 ** math.floor(math.log2(heads))
        extra = get_slopes_power_of_2(2 * base)[0::2][:heads - base]
        return get_slopes_power_of_2(base) + extra

    def forward(self, qk_dots):
        """Return ``qk_dots`` with the (possibly cached) bias added."""
        h, _, j = qk_dots.shape[-3:]
        device = qk_dots.device

        cached = self.bias
        if exists(cached) and cached.shape[-1] >= j:
            return qk_dots + cached[..., :j]

        positions = rearrange(torch.arange(j, device=device), 'j -> () () () j')
        bias = positions * self.slopes

        # Heads beyond the ones that own a slope receive a zero bias.
        bias = F.pad(bias, (0, 0, 0, 0, 0, h - bias.shape[1]))

        self.register_buffer('bias', bias, persistent=False)
        return qk_dots + self.bias
227
+
228
+
229
class LearnedAlibiPositionalBias(AlibiPositionalBias):
    """ALiBi bias whose per-head log-slopes are trainable.

    With ``bidirectional=True`` a second, independent set of log-slopes is
    learned for keys that lie after the query position.
    """

    def __init__(self, heads, bidirectional=False):
        super().__init__(heads)
        log_slopes = torch.log(self.slopes)
        self.learned_logslopes = nn.Parameter(log_slopes)

        self.bidirectional = bidirectional
        if self.bidirectional:
            # nn.Parameter wraps the tensor it is given without copying, so
            # the second parameter needs its own clone. Otherwise both
            # parameters alias one storage and every in-place update
            # (optimizer step, load_state_dict) of one overwrites the other.
            self.learned_logslopes_future = nn.Parameter(log_slopes.clone())

    def forward(self, qk_dots):
        """Return ``qk_dots`` plus ``(key_pos - query_pos) * slope`` per head."""
        h, i, j, device = *qk_dots.shape[-3:], qk_dots.device

        def get_slopes(param):
            # exp() turns log-slopes into slopes; pad heads without a slope with 0.
            return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1]))

        cached = self.bias
        # The cache must cover both the query and the key length. Checking
        # only the key length would reuse a bias with too few rows (e.g. one
        # cached from a single-query call) for a longer query.
        if exists(cached) and cached.shape[-2] >= i and cached.shape[-1] >= j:
            bias = cached[..., :i, :j]
        else:
            i_arange = torch.arange(i, device=device)
            j_arange = torch.arange(j, device=device)
            bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
            self.register_buffer('bias', bias, persistent=False)

        if self.bidirectional:
            past_slopes = get_slopes(self.learned_logslopes)
            future_slopes = get_slopes(self.learned_logslopes_future)
            # Lower triangle = keys at or before the query, upper = keys after.
            bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes)
        else:
            slopes = get_slopes(self.learned_logslopes)
            bias = bias * slopes

        return qk_dots + bias
262
+
263
+
264
class RotaryEmbedding(nn.Module):
    """Produces the angle table consumed by ``apply_rotary_pos_emb``."""

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, max_seq_len, device):
        """Return angles for positions ``0 .. max_seq_len - 1``.

        The per-position angles are duplicated along the last axis and two
        leading axes of size 1 are added.
        """
        positions = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
        angles = positions[:, None] * self.inv_freq[None, :]
        doubled = torch.cat((angles, angles), dim=-1)
        return doubled[None, None]
275
+
276
+
277
def rotate_half(x):
    """Split the last axis into two halves ``(a, b)`` and return ``(-b, a)``."""
    halves = x.reshape(*x.shape[:-1], 2, -1)
    first, second = halves.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
281
+
282
+
283
def apply_rotary_pos_emb(t, freqs):
    """Rotate ``t`` by the angles in ``freqs``.

    Only the last ``t.shape[-2]`` positions of ``freqs`` (third axis) are used,
    so a longer angle table can be passed in.
    """
    n = t.shape[-2]
    tail = freqs[:, :, -n:]
    rotated = rotate_half(t) * tail.sin()
    return (t * tail.cos()) + rotated
287
+
288
+
289
+ # norms
290
+
291
class Scale(nn.Module):
    """Wrap ``fn`` and multiply its (first) output by a constant ``value``."""

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        """Call ``fn(x, **kwargs)``; scale the result, or its first item if it is a tuple."""
        result = self.fn(x, **kwargs)
        if isinstance(result, tuple):
            # Only the primary output is scaled; auxiliary outputs pass through.
            return (result[0] * self.value, *result[1:])
        return result * self.value
305
+
306
+
307
class Rezero(nn.Module):
    """Wrap ``fn`` and multiply its (first) output by a learned gate ``g``.

    ``g`` is a single-element parameter initialised to zero.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.zeros(1))

    def forward(self, x, **kwargs):
        """Call ``fn(x, **kwargs)``; gate the result, or its first item if it is a tuple."""
        result = self.fn(x, **kwargs)
        if isinstance(result, tuple):
            # Only the primary output is gated; auxiliary outputs pass through.
            return (result[0] * self.g, *result[1:])
        return result * self.g
321
+
322
+
323
class ScaleNorm(nn.Module):
    """Divide by the scaled L2 norm of the last axis and apply one learned scalar gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        scaled_norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        # Clamp so that all-zero rows do not divide by zero.
        denom = scaled_norm.clamp(min=self.eps)
        return x / denom * self.g
333
+
334
+
335
class RMSNorm(nn.Module):
    """Divide by the scaled L2 norm of the last axis and apply a learned per-feature gain."""

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        scaled_norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        # Clamp so that all-zero rows do not divide by zero.
        denom = scaled_norm.clamp(min=self.eps)
        return x / denom * self.g
345
+
346
+
347
class RMSScaleShiftNorm(nn.Module):
    """RMS-style norm followed by a conditioning-dependent scale and shift.

    The conditioning input is projected by a linear layer (``dim * 2`` in,
    ``dim * 2`` out) and split in two along axis 1 into a scale and a shift.
    """

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))
        self.scale_shift_process = nn.Linear(dim * 2, dim * 2)

    def forward(self, x, norm_scale_shift_inp):
        """Normalise ``x`` and modulate it with ``norm_scale_shift_inp``."""
        scaled_norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        normed = x / scaled_norm.clamp(min=self.eps) * self.g

        projected = self.scale_shift_process(norm_scale_shift_inp)
        scale, shift = torch.chunk(projected, 2, dim=1)
        # unsqueeze(1) inserts an axis so scale/shift broadcast over axis 1 of x.
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
363
+
364
+
365
+ # residual and residual gates
366
+
367
class Residual(nn.Module):
    """Plain residual connection with an optional learned per-feature residual scale."""

    def __init__(self, dim, scale_residual=False):
        super().__init__()
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        """Return ``x + residual`` (residual scaled first when enabled)."""
        scale = self.residual_scale
        if exists(scale):
            residual = residual * scale
        return x + residual
377
+
378
+
379
class GRUGating(nn.Module):
    """Combine branch output and residual with a GRU cell instead of a sum."""

    def __init__(self, dim, scale_residual=False):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None

    def forward(self, x, residual):
        """Run the GRU cell with ``x`` as input and ``residual`` as hidden state."""
        scale = self.residual_scale
        if exists(scale):
            residual = residual * scale

        # GRUCell works on 2-D input, so batch and sequence axes are merged.
        flat_input = rearrange(x, 'b n d -> (b n) d')
        flat_hidden = rearrange(residual, 'b n d -> (b n) d')
        gated = self.gru(flat_input, flat_hidden)

        return gated.reshape_as(x)
395
+
396
+
397
+ # token shifting
398
+
399
def shift(t, amount, mask=None):
    """Shift ``t`` by ``amount`` steps along its second-to-last axis, zero-filling.

    When ``mask`` is given, positions where it is False are zeroed before the
    shift. ``amount == 0`` returns ``t`` untouched (the mask is not applied).
    """
    if amount == 0:
        return t

    if exists(mask):
        keep = mask[..., None]
        t = t.masked_fill(~keep, 0.)

    # Pad on one side and crop (negative pad) on the other: a pure shift.
    padding = (0, 0, amount, -amount)
    return F.pad(t, padding, value=0.)
407
+
408
+
409
class ShiftTokens(nn.Module):
    """Shift equal-width feature segments of the input before calling ``fn``.

    The last axis is split into chunks of ``features // len(shifts)``; the
    first ``len(shifts)`` chunks are each shifted by the matching amount and
    any remaining chunks are passed through unchanged.
    """

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    def forward(self, x, **kwargs):
        mask = kwargs.get('mask', None)
        amounts = self.shifts
        n_segments = len(amounts)
        width = x.shape[-1] // n_segments

        pieces = x.split(width, dim=-1)
        to_shift, untouched = pieces[:n_segments], pieces[n_segments:]
        shifted = [shift(piece, amount, mask=mask) for piece, amount in zip(to_shift, amounts)]

        x = torch.cat((*shifted, *untouched), dim=-1)
        return self.fn(x, **kwargs)
425
+
426
+
427
+ # feedforward
428
+
429
class GLU(nn.Module):
    """Gated linear unit: one projection split into a value half and a gate half."""

    def __init__(self, dim_in, dim_out, activation):
        super().__init__()
        self.act = activation
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Return ``value * activation(gate)``."""
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * self.act(gate)
438
+
439
+
440
class FeedForward(nn.Module):
    """Position-wise feed-forward block.

    Layout: input projection (linear + activation, or a GLU), optional
    LayerNorm, dropout, output projection.

    Args:
        dim: input feature size.
        dim_out: output feature size; falls back to ``dim``.
        mult: hidden width multiplier (hidden size is ``int(dim * mult)``).
        glu: use a ``GLU`` input projection instead of linear + activation.
        relu_squared: use ``ReluSquared`` instead of GELU.
        post_act_ln: insert a LayerNorm after the input projection.
        dropout: dropout probability before the output projection.
        zero_init_output: pass the output projection to ``init_zero_``.
    """

    def __init__(
            self,
            dim,
            dim_out=None,
            mult=4,
            glu=False,
            relu_squared=False,
            post_act_ln=False,
            dropout=0.,
            zero_init_output=False
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        activation = ReluSquared() if relu_squared else nn.GELU()

        if glu:
            project_in = GLU(dim, inner_dim, activation)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)

        # Built in the same order as they appear in the Sequential.
        post_norm = nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity()
        drop = nn.Dropout(dropout)
        project_out = nn.Linear(inner_dim, dim_out)

        self.net = nn.Sequential(project_in, post_norm, drop, project_out)

        # init last linear layer to 0
        if zero_init_output:
            init_zero_(self.net[-1])

    def forward(self, x):
        return self.net(x)
475
+
476
+
477
+ # attention.
478
+
479
class Attention(nn.Module):
    """Multi-head attention with a collection of optional variants.

    Optional features selected in ``__init__``: collaborative heads, value
    gating, cosine-similarity attention (``qk_norm``), talking heads, per-head
    output scaling, top-k sparsification, learned memory key/values,
    attention-on-attention output, and a bucketed relative position bias.

    ``forward`` returns a 4-tuple ``(out, intermediates, k_cache, v_cache)``:
    the projected output, an ``Intermediates`` record holding the pre- and
    post-softmax attention, and the keys/values as they are right after any
    ``layer_past`` has been prepended (before rotary embedding and memory
    key/values are applied).
    """

    def __init__(
            self,
            dim,
            dim_head=DEFAULT_DIM_HEAD,
            heads=8,
            causal=False,
            talking_heads=False,
            head_scale=False,
            collab_heads=False,
            collab_compression=.3,
            sparse_topk=None,
            use_entmax15=False,
            num_mem_kv=0,
            dropout=0.,
            on_attn=False,
            gate_values=False,
            zero_init_output=False,
            max_attend_past=None,
            qk_norm=False,
            scale_init_value=None,
            rel_pos_bias=False,
            rel_pos_num_buckets=32,
            rel_pos_max_distance=128,
    ):
        super().__init__()
        # Default logit scale; replaced by a learned parameter when qk_norm is on.
        self.scale = dim_head ** -0.5

        self.heads = heads
        self.causal = causal
        self.max_attend_past = max_attend_past

        qk_dim = v_dim = dim_head * heads

        # collaborative heads
        self.collab_heads = collab_heads
        if self.collab_heads:
            qk_dim = int(collab_compression * qk_dim)
            self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim))

        self.to_q = nn.Linear(dim, qk_dim, bias=False)
        self.to_k = nn.Linear(dim, qk_dim, bias=False)
        self.to_v = nn.Linear(dim, v_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, v_dim)
            # weight 0 / bias 1: the gate starts input-independent at sigmoid(1)
            nn.init.constant_(self.to_v_gate.weight, 0)
            nn.init.constant_(self.to_v_gate.bias, 1)

        # cosine sim attention
        self.qk_norm = qk_norm
        if qk_norm:
            scale_init_value = default(scale_init_value,
                                       -3)  # if not provided, initialize as though it were sequence length of 1024
            self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

        # head scaling
        self.head_scale = head_scale
        if head_scale:
            self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        # entmax
        # NOTE(review): `use_entmax15` is accepted but never read in this
        # class; softmax is always used.
        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim)

        self.rel_pos_bias = rel_pos_bias
        if rel_pos_bias:
            assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
            self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
                                                num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)

        # init output projection 0
        if zero_init_output:
            init_zero_(self.to_out)

    def forward(
            self,
            x,
            context=None,
            mask=None,
            context_mask=None,
            attn_mask=None,
            sinusoidal_emb=None,
            rotary_pos_emb=None,
            prev_attn=None,
            mem=None,
            layer_past=None,
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: query input; unpacked as ``(batch, length, features)``.
            context: optional key/value input for cross-attention.
            mask: boolean mask over query positions (True = keep).
            context_mask: boolean mask over context positions; used for the
                keys only when ``context`` is given.
            attn_mask: boolean mask with 2 to 4 dimensions over the attention
                logits (True = keep).
            sinusoidal_emb: callable added to the query and key inputs before
                projection; the query call receives an ``offset``.
            rotary_pos_emb: rotary angle table; ignored when ``context`` is given.
            prev_attn: tensor added to the attention logits.
            mem: tensor prepended to the key/value input along the sequence axis.
            layer_past: ``(past_key, past_value)`` pair prepended to the
                projected keys/values.

        Returns:
            ``(out, intermediates, k_cache, v_cache)``, as described on the class.
        """
        b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
            context)
        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = torch.cat((mem, k_input), dim=-2)
            v_input = torch.cat((mem, v_input), dim=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        if not collab_heads:
            # split the feature axis into heads
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
        else:
            # collaborative heads: per-head mixing of a shared query, one shared key
            q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
            k = rearrange(k, 'b n d -> b () n d')
            v = rearrange(v, 'b n (h d) -> b h n d', h=h)

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)
        # Returned to the caller as this layer's key/value cache.
        k_cache = k
        v_cache = v

        if exists(rotary_pos_emb) and not has_context:
            # Rotate only the first `l` features of each head; the rest pass through.
            l = rotary_pos_emb.shape[-1]
            (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
            ql, kl, vl = map(lambda t: apply_rotary_pos_emb(t, rotary_pos_emb), (ql, kl, vl))
            q, k, v = map(lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)))

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            # Combine query and key masks into one (b, 1, i, j) mask; a missing
            # side defaults to all-True.
            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
            q_mask = rearrange(q_mask, 'b i -> b () i ()')
            k_mask = rearrange(k_mask, 'b j -> b () () j')
            input_mask = q_mask * k_mask

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
            k = torch.cat((mem_k, k), dim=-2)
            v = torch.cat((mem_v, v), dim=-2)
            if exists(input_mask):
                # memory key/values are always attendable
                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)

        if collab_heads:
            k = k.expand(-1, h, -1, -1)

        if self.qk_norm:
            q, k = map(l2norm, (q, k))
            scale = 1 / (self.scale.exp().clamp(min=1e-2))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        # Snapshot taken before talking heads, relative bias and masking.
        pre_softmax_attn = dots.clone()

        if talking_heads:
            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

        if self.rel_pos_bias:
            dots = self.rel_pos(dots)

        if exists(input_mask):
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if exists(attn_mask):
            assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
            if attn_mask.ndim == 2:
                attn_mask = rearrange(attn_mask, 'i j -> () () i j')
            elif attn_mask.ndim == 3:
                attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
            dots.masked_fill_(~attn_mask, mask_value)

        if exists(self.max_attend_past):
            # Mask keys further back than max_attend_past; queries are treated
            # as the last i of the j key positions.
            i, j = dots.shape[-2:]
            range_q = torch.arange(j - i, j, device=device)
            range_k = torch.arange(j, device=device)
            dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
            mask = dist > self.max_attend_past
            dots.masked_fill_(mask, mask_value)
            del mask

        if self.causal:
            # Upper-triangular mask over the i x i query block, left-padded with
            # False so the j - i leading keys stay visible.
            i, j = dots.shape[-2:]
            r = torch.arange(i, device=device)
            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
            mask = F.pad(mask, (j - i, 0), value=False)
            dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            # Keep the top-k logits per query; mask everything below the k-th.
            top, _ = dots.topk(self.sparse_topk, dim=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, dim=-1)
        # Snapshot taken before dropout and the post-softmax talking-heads mix.
        post_softmax_attn = attn.clone()

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        if head_scale:
            out = out * self.head_scale_params

        out = rearrange(out, 'b h n d -> b n (h d)')

        if exists(self.to_v_gate):
            gates = self.to_v_gate(x)
            out = out * gates.sigmoid()

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn,
            post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates, k_cache, v_cache
729
+
730
+
731
class AttentionLayers(nn.Module):
    """Configurable stack of self-attention ('a'), cross-attention ('c') and
    feed-forward ('f') layers, each wrapped with norms and a residual.

    Keyword arguments prefixed with ``ff_`` are forwarded (prefix stripped) to
    every ``FeedForward``; those prefixed with ``attn_`` go to every
    ``Attention``.

    Args:
        dim: model width.
        depth: number of repetitions of the default layer block.
        heads: attention heads per attention layer.
        causal: passed to self-attention layers; also controls token-shift
            direction and the rotary ``expected_seq_len`` check.
        cross_attend / only_cross: include cross-attention layers / drop the
            self-attention layers.
        use_scalenorm / use_rmsnorm / use_rms_scaleshift_norm: choose the norm
            class (default ``nn.LayerNorm``); later flags win.
        use_rezero: replace norms with identity and wrap each layer in ``Rezero``.
        position_infused_attn / rotary_pos_emb: enable the respective
            positional scheme.
        custom_layers / sandwich_coef / par_ratio: alternative ways to decide
            the layer order.
        residual_attn / cross_residual_attn: feed each attention layer the
            pre-softmax logits of the previous layer of the same kind.
        macaron: prepend a feed-forward layer to each block and halve the
            feed-forward outputs.
        pre_norm / sandwich_norm: norm placement.
        gate_residual / scale_residual: residual variant.
        shift_tokens: token-shift range per layer (0 disables).
        use_qk_norm_attn / qk_norm_attn_seq_len: cosine-similarity attention.
        zero_init_branch_output: zero-initialise each branch's output projection.
    """

    def __init__(
            self,
            dim,
            depth,
            heads=8,
            causal=False,
            cross_attend=False,
            only_cross=False,
            use_scalenorm=False,
            use_rms_scaleshift_norm=False,
            use_rmsnorm=False,
            use_rezero=False,
            alibi_pos_bias=False,
            alibi_num_heads=None,
            alibi_learned=False,
            position_infused_attn=False,
            rotary_pos_emb=False,
            rotary_emb_dim=None,
            custom_layers=None,
            sandwich_coef=None,
            par_ratio=None,
            residual_attn=False,
            cross_residual_attn=False,
            macaron=False,
            pre_norm=True,
            gate_residual=False,
            scale_residual=False,
            shift_tokens=0,
            sandwich_norm=False,
            use_qk_norm_attn=False,
            qk_norm_attn_seq_len=None,
            zero_init_branch_output=False,
            **kwargs
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

        self.dim = dim
        self.depth = depth
        self.layers = nn.ModuleList([])
        self.causal = causal

        # NOTE(review): this tests for the presence of the key, not its value,
        # so an explicit `attn_rel_pos_bias=False` still counts as "has
        # positional embedding". Left as is because changing it could alter
        # which submodules wrapper classes create for existing configurations.
        rel_pos_bias = 'rel_pos_bias' in attn_kwargs
        self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None

        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
        self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None

        assert not (
                alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'

        if alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
            # NOTE(review): forward() below never reads self.rel_pos, so the
            # ALiBi module is constructed but not applied by this class.
            self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
        else:
            self.rel_pos = None

        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
        self.pre_norm = pre_norm
        self.sandwich_norm = sandwich_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        self.cross_attend = cross_attend

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ('a', 'c', 'f')
        elif cross_attend and only_cross:
            default_block = ('c', 'f')
        else:
            default_block = ('a', 'f')

        if macaron:
            default_block = ('f',) + default_block

        # qk normalization

        if use_qk_norm_attn:
            attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
                qk_norm_attn_seq_len) else None
            attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

        # zero init

        if zero_init_branch_output:
            attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
            ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

        # calculate layer block order

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, 'par ratio out of range'
            default_block = tuple(filter(not_equals('f'), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert len(default_block) <= par_width, 'default block is too large for par_ratio'
            par_block = default_block + ('f',) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ('f',) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
            layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

        # calculate token shifting

        shift_tokens = cast_tuple(shift_tokens, len(layer_types))

        # iterate and construct layers

        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
            is_last_layer = ind == (len(self.layer_types) - 1)

            if layer_type == 'a':
                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
            elif layer_type == 'c':
                layer = Attention(dim, heads=heads, **attn_kwargs)
            elif layer_type == 'f':
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                raise Exception(f'invalid layer type {layer_type}')

            if layer_shift_tokens > 0:
                shift_range_upper = layer_shift_tokens + 1
                shift_range_lower = -layer_shift_tokens if not causal else 0
                layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)

            if exists(branch_fn):
                layer = branch_fn(layer)

            residual_fn = GRUGating if gate_residual else Residual
            residual = residual_fn(dim, scale_residual=scale_residual)

            layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')

            pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
            post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
            post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None

            norms = nn.ModuleList([
                pre_branch_norm,
                post_branch_norm,
                post_main_norm
            ])

            self.layers.append(nn.ModuleList([
                norms,
                layer,
                residual
            ]))

    def forward(
            self,
            x,
            context=None,
            full_context=None,  # for passing a list of hidden states from an encoder
            mask=None,
            context_mask=None,
            attn_mask=None,
            mems=None,
            return_hiddens=False,
            norm_scale_shift_inp=None,
            past_key_values=None,
            expected_seq_len=None,
    ):
        """Run the layer stack over ``x``.

        Args:
            x: input; axis 1 is read as the sequence length.
            context: key/value input shared by all cross-attention layers.
            full_context: alternative to ``context``: one entry per
                cross-attention layer, indexed in layer order.
            mask / context_mask / attn_mask: forwarded to the attention layers.
            mems: optional list with one memory entry per self-attention
                layer; copied before being consumed.
            return_hiddens: when True, also return a ``LayerIntermediates``.
            norm_scale_shift_inp: forwarded to every norm when given.
            past_key_values: optional sequence with one ``(key, value)`` pair
                per attention layer, in layer order. It is copied, so the
                caller's sequence is left untouched.
            expected_seq_len: rotary table length; must be given when the
                stack is causal, uses rotary embeddings and is in eval mode.

        Returns:
            ``x`` alone, or ``(x, LayerIntermediates)`` when ``return_hiddens``
            is set. ``hiddens`` holds the output after each feed-forward layer;
            ``past_key_values`` holds the detached keys/values of every
            attention layer.
        """

        assert not (self.cross_attend ^ (exists(context) or exists(
            full_context))), 'context must be passed in if cross_attend is set to True'
        assert context is None or full_context is None, 'only one of full_context or context can be provided'

        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
        norm_args = {}
        if exists(norm_scale_shift_inp):
            norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp

        # Consume a private copy: popping from the argument itself would
        # empty the caller's cache list as a side effect.
        if past_key_values is not None:
            past_key_values = list(past_key_values)

        rotary_pos_emb = None
        if exists(self.rotary_pos_emb):
            if not self.training and self.causal:
                assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
            elif expected_seq_len is None:
                expected_seq_len = 0
            seq_len = x.shape[1]
            if past_key_values is not None:
                seq_len += past_key_values[0][0].shape[-2]
            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

        present_key_values = []
        cross_attn_count = 0
        for layer_type, (norm, block, residual_fn) in zip(self.layer_types, self.layers):
            is_attention = layer_type in ('a', 'c')

            if layer_type == 'a':
                layer_mem = mems.pop(0) if mems else None

            residual = x

            pre_branch_norm, post_branch_norm, post_main_norm = norm

            if exists(pre_branch_norm):
                x = pre_branch_norm(x, **norm_args)

            if is_attention:
                if past_key_values is not None:
                    layer_kv = past_key_values.pop(0)
                    layer_past = tuple(s.to(x.device) for s in layer_kv)
                else:
                    layer_past = None

            # Blocks are called with keyword arguments: an attention layer
            # may be wrapped in Rezero / ShiftTokens, whose forward signature
            # is (x, **kwargs) and therefore rejects extra positionals.
            if layer_type == 'a':
                out, inter, k, v = block(
                    x,
                    context=None,
                    mask=mask,
                    context_mask=None,
                    attn_mask=attn_mask,
                    sinusoidal_emb=self.pia_pos_emb,
                    rotary_pos_emb=rotary_pos_emb,
                    prev_attn=prev_attn,
                    mem=layer_mem,
                    layer_past=layer_past,
                )
            elif layer_type == 'c':
                layer_context = full_context[cross_attn_count] if exists(full_context) else context
                # Cross-attention gets the previous *cross*-attention logits;
                # self-attention logits have a different key length.
                out, inter, k, v = block(
                    x,
                    context=layer_context,
                    mask=mask,
                    context_mask=context_mask,
                    attn_mask=None,
                    sinusoidal_emb=None,
                    rotary_pos_emb=None,
                    prev_attn=prev_cross_attn,
                    mem=None,
                    layer_past=layer_past,
                )
            elif layer_type == 'f':
                out = block(x)

            if is_attention:
                present_key_values.append((k.detach(), v.detach()))

            if exists(post_branch_norm):
                out = post_branch_norm(out, **norm_args)

            x = residual_fn(out, residual)

            if is_attention:
                intermediates.append(inter)

            if layer_type == 'a' and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == 'c' and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            if exists(post_main_norm):
                x = post_main_norm(x, **norm_args)

            if layer_type == 'c':
                cross_attn_count += 1

            if layer_type == 'f':
                hiddens.append(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens,
                attn_intermediates=intermediates,
                past_key_values=present_key_values
            )

            return x, intermediates

        return x
1014
+
1015
+
1016
class Encoder(AttentionLayers):
    """``AttentionLayers`` stack fixed to ``causal=False``."""

    def __init__(self, **kwargs):
        # Causality is fixed by this class; callers must not pass it.
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)
1020
+
1021
+
1022
class Decoder(AttentionLayers):
    """``AttentionLayers`` stack fixed to ``causal=True``."""

    def __init__(self, **kwargs):
        # Causality is fixed by this class; callers must not pass it.
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal=True, **kwargs)
1026
+
1027
+
1028
class CrossAttender(AttentionLayers):
    """``AttentionLayers`` stack built with ``cross_attend=True, only_cross=True``."""

    def __init__(self, **kwargs):
        super().__init__(cross_attend=True, only_cross=True, **kwargs)
1031
+
1032
+
1033
class ViTransformerWrapper(nn.Module):
    """Vision-transformer front end around an ``Encoder``.

    Images are cut into square patches, linearly embedded, prefixed with a
    learned class token, given learned position embeddings and passed through
    the encoder and a final LayerNorm.

    Args:
        image_size: image side length; must be divisible by ``patch_size``.
        patch_size: patch side length.
        attn_layers: an ``Encoder`` instance.
        num_classes: when given, a ``FeedForward`` head maps the class token to
            this many outputs.
        dropout: dropout inside the classification head.
        emb_dropout: dropout applied to the embedded patch sequence.
    """

    def __init__(
            self,
            *,
            image_size,
            patch_size,
            attn_layers,
            num_classes=None,
            dropout=0.,
            emb_dropout=0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'

        dim = attn_layers.dim
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        # 3 = number of colour channels the patch embedding expects
        patch_dim = 3 * patch_size ** 2

        self.patch_size = patch_size

        # +1 position for the class token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)

        if exists(num_classes):
            self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout)
        else:
            self.mlp_head = None

    def forward(
            self,
            img,
            return_embeddings=False
    ):
        """Return class-head output, or the normalised token sequence.

        The token sequence is returned when no head was configured or when
        ``return_embeddings`` is True.
        """
        size = self.patch_size

        patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=size, p2=size)
        tokens = self.patch_to_embedding(patches)
        batch, n_patches, _ = tokens.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        tokens = self.attn_layers(tokens)
        tokens = self.norm(tokens)

        if return_embeddings or not exists(self.mlp_head):
            return tokens

        # Classify from the class token (position 0).
        return self.mlp_head(tokens[:, 0])
1085
+
1086
+
1087
class TransformerWrapper(nn.Module):
    """Token-id front end and logit head around an ``AttentionLayers`` stack.

    Embeds integer tokens, adds an absolute positional embedding when
    requested and the attention stack does not supply its own, optionally
    prepends learned memory tokens, runs the attention layers plus a final
    layer norm, and projects the result to vocabulary logits.
    """

    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim=None,
        max_mem_len=0.,
        shift_mem_down=0,
        emb_dropout=0.,
        num_memory_tokens=None,
        tie_embedding=False,
        use_pos_emb=True,
    ):
        """Build the embedding, attention and output layers.

        Args:
            num_tokens: Vocabulary size for the embedding and logit head.
            max_seq_len: Stored on the instance and passed to the absolute
                positional embedding.
            attn_layers: An ``AttentionLayers`` instance; its ``dim`` sets
                the model width.
            emb_dim: Embedding width; falls back to ``attn_layers.dim``.
            max_mem_len: Stored on the instance; not otherwise used here.
            shift_mem_down: Rotation offset applied to ``mems`` in
                ``forward``.
            emb_dropout: Dropout probability applied after embedding.
            num_memory_tokens: Number of learned memory tokens to prepend;
                ``None`` is treated as 0.
            tie_embedding: When true, logits are computed against the
                transposed token-embedding matrix instead of a separate
                linear layer.
            use_pos_emb: Whether to add an absolute positional embedding
                (skipped if ``attn_layers.has_pos_emb`` is true).
        """
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        model_dim = attn_layers.dim
        emb_dim = default(emb_dim, model_dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
        self.shift_mem_down = shift_mem_down

        # Submodules are created in a fixed order so parameter registration
        # and random initialisation stay reproducible.
        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        if use_pos_emb and not attn_layers.has_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len)
        else:
            self.pos_emb = always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        # Only project when the embedding width differs from the model width.
        if emb_dim == model_dim:
            self.project_emb = nn.Identity()
        else:
            self.project_emb = nn.Linear(emb_dim, model_dim)
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(model_dim)

        self.init_()

        if tie_embedding:
            # Reuse the embedding matrix as the output projection.
            self.to_logits = lambda t: t @ self.token_emb.weight.t()
        else:
            self.to_logits = nn.Linear(model_dim, num_tokens)

        # memory tokens (like [cls]) from Memory Transformers paper
        self.num_memory_tokens = default(num_memory_tokens, 0)
        if self.num_memory_tokens > 0:
            self.memory_tokens = nn.Parameter(torch.randn(self.num_memory_tokens, model_dim))

    def init_(self):
        """Initialise the token-embedding weights with Kaiming-normal values."""
        nn.init.kaiming_normal_(self.token_emb.weight)

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_hiddens=False,
        return_attn=False,
        mems=None,
        use_cache=False,
        **kwargs
    ):
        """Run token ids through the wrapped attention stack.

        Args:
            x: Token ids; ``x.shape`` must unpack to exactly two values.
            return_embeddings: Return the normalised hidden states instead
                of logits.
            mask: Optional mask forwarded to the attention layers. When
                memory tokens are in use it is left-padded with ``True``
                for those positions.
            return_hiddens: Return ``(out, intermediates.hiddens)`` and
                ignore ``return_attn`` / ``use_cache``.
            return_attn: Also return the per-layer post-softmax attention.
            mems: Optional memories forwarded to the attention layers,
                rotated by ``shift_mem_down`` when that is non-zero.
            use_cache: Also return ``intermediates.past_key_values``.
            **kwargs: Forwarded unchanged to the attention layers.

        Returns:
            ``out`` alone, or a tuple starting with ``out`` followed by the
            attention maps and/or cached key-values when requested, or
            ``(out, hiddens)`` when ``return_hiddens`` is set.
        """
        # The unpack doubles as a check that x has exactly two dimensions.
        batch, _seq_len, _device, num_mem = *x.shape, x.device, self.num_memory_tokens

        x = self.token_emb(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)
        x = self.project_emb(x)

        if num_mem > 0:
            memory = repeat(self.memory_tokens, 'n d -> b n d', b=batch)
            x = torch.cat((memory, x), dim=1)

            # auto-handle masking after appending memory tokens
            if exists(mask):
                mask = F.pad(mask, (num_mem, 0), value=True)

        if self.shift_mem_down and exists(mems):
            pivot = self.shift_mem_down
            mems = [*mems[pivot:], *mems[:pivot]]

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        # Drop the memory-token positions before producing the output.
        x = x[:, num_mem:]

        out = x if return_embeddings else self.to_logits(x)

        if return_hiddens:
            return out, intermediates.hiddens

        outputs = [out]
        if return_attn:
            outputs.append([layer.post_softmax_attn for layer in intermediates.attn_intermediates])
        if use_cache:
            outputs.append(intermediates.past_key_values)

        return outputs[0] if len(outputs) == 1 else tuple(outputs)
1185
+
1186
+
1187
class ContinuousTransformerWrapper(nn.Module):
    """Wrapper that feeds continuous feature vectors to an ``AttentionLayers`` stack.

    Optionally projects the input to the model width, adds an absolute
    positional embedding when requested and the attention stack does not
    supply its own, runs the attention layers plus a final layer norm, and
    optionally projects the result to ``dim_out``.
    """

    def __init__(
        self,
        *,
        max_seq_len,
        attn_layers,
        dim_in=None,
        dim_out=None,
        emb_dim=None,
        emb_dropout=0.,
        use_pos_emb=True,
    ):
        """Build the projection, attention and normalisation layers.

        Args:
            max_seq_len: Stored on the instance and passed to the absolute
                positional embedding.
            attn_layers: An ``AttentionLayers`` instance; its ``dim`` sets
                the model width.
            dim_in: Input feature width; when given, a linear layer maps it
                to the model width, otherwise the input passes through.
            dim_out: Output feature width; when given, a linear layer maps
                the model width to it, otherwise the output passes through.
            emb_dim: Accepted but not used by this class.
            emb_dropout: Dropout probability applied after the positional
                embedding is added.
            use_pos_emb: Whether to add an absolute positional embedding
                (skipped if ``attn_layers.has_pos_emb`` is true).
        """
        super().__init__()
        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

        model_dim = attn_layers.dim

        self.max_seq_len = max_seq_len

        # Submodules are created in a fixed order so parameter registration
        # and random initialisation stay reproducible.
        if use_pos_emb and not attn_layers.has_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(model_dim, max_seq_len)
        else:
            self.pos_emb = always(0)
        self.emb_dropout = nn.Dropout(emb_dropout)

        if exists(dim_in):
            self.project_in = nn.Linear(dim_in, model_dim)
        else:
            self.project_in = nn.Identity()

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(model_dim)

        if exists(dim_out):
            self.project_out = nn.Linear(model_dim, dim_out)
        else:
            self.project_out = nn.Identity()

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_attn=False,
        mems=None,
        use_cache=False,
        **kwargs
    ):
        """Run continuous features through the wrapped attention stack.

        Args:
            x: Input features; ``x.shape`` must unpack to exactly three
                values.
            return_embeddings: Return the normalised hidden states instead
                of applying ``project_out``.
            mask: Optional mask forwarded to the attention layers.
            return_attn: Also return the per-layer post-softmax attention.
            mems: Optional memories forwarded to the attention layers.
            use_cache: Also return ``intermediates.past_key_values``.
            **kwargs: Forwarded unchanged to the attention layers.

        Returns:
            ``out`` alone, or a tuple starting with ``out`` followed by the
            attention maps and/or cached key-values when requested.
        """
        # The unpack doubles as a check that x has exactly three dimensions.
        _batch, _seq_len, _features, _device = *x.shape, x.device

        x = self.project_in(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
        x = self.norm(x)

        out = x if return_embeddings else self.project_out(x)

        outputs = [out]
        if return_attn:
            outputs.append([layer.post_softmax_attn for layer in intermediates.attn_intermediates])
        if use_cache:
            outputs.append(intermediates.past_key_values)

        return outputs[0] if len(outputs) == 1 else tuple(outputs)