xinference 1.9.1__py3-none-any.whl → 1.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xinference might be problematic; see the registry's advisory page for more details.

Files changed (334):
  1. xinference/_version.py +3 -3
  2. xinference/api/restful_api.py +400 -3
  3. xinference/client/restful/async_restful_client.py +20 -3
  4. xinference/client/restful/restful_client.py +20 -3
  5. xinference/constants.py +2 -0
  6. xinference/core/supervisor.py +111 -49
  7. xinference/core/worker.py +10 -0
  8. xinference/deploy/cmdline.py +15 -0
  9. xinference/model/audio/core.py +26 -6
  10. xinference/model/audio/indextts2.py +166 -0
  11. xinference/model/audio/kokoro.py +1 -1
  12. xinference/model/audio/kokoro_zh.py +124 -0
  13. xinference/model/audio/model_spec.json +58 -1
  14. xinference/model/embedding/sentence_transformers/core.py +4 -4
  15. xinference/model/embedding/vllm/core.py +7 -1
  16. xinference/model/image/model_spec.json +71 -3
  17. xinference/model/image/stable_diffusion/core.py +13 -4
  18. xinference/model/llm/__init__.py +4 -0
  19. xinference/model/llm/core.py +10 -0
  20. xinference/model/llm/llama_cpp/core.py +1 -0
  21. xinference/model/llm/llm_family.json +503 -21
  22. xinference/model/llm/llm_family.py +1 -0
  23. xinference/model/llm/mlx/core.py +52 -33
  24. xinference/model/llm/sglang/core.py +32 -55
  25. xinference/model/llm/tool_parsers/__init__.py +58 -0
  26. xinference/model/llm/tool_parsers/abstract_tool_parser.py +33 -0
  27. xinference/model/llm/tool_parsers/deepseek_r1_tool_parser.py +190 -0
  28. xinference/model/llm/tool_parsers/deepseek_v3_tool_parser.py +145 -0
  29. xinference/model/llm/tool_parsers/glm4_tool_parser.py +123 -0
  30. xinference/model/llm/tool_parsers/llama3_tool_parser.py +77 -0
  31. xinference/model/llm/tool_parsers/qwen_tool_parser.py +320 -0
  32. xinference/model/llm/transformers/core.py +1 -1
  33. xinference/model/llm/transformers/multimodal/qwen2_vl.py +34 -8
  34. xinference/model/llm/utils.py +138 -53
  35. xinference/model/llm/vllm/core.py +95 -78
  36. xinference/thirdparty/audiotools/__init__.py +10 -0
  37. xinference/thirdparty/audiotools/core/__init__.py +4 -0
  38. xinference/thirdparty/audiotools/core/audio_signal.py +1682 -0
  39. xinference/thirdparty/audiotools/core/display.py +194 -0
  40. xinference/thirdparty/audiotools/core/dsp.py +390 -0
  41. xinference/thirdparty/audiotools/core/effects.py +647 -0
  42. xinference/thirdparty/audiotools/core/ffmpeg.py +211 -0
  43. xinference/thirdparty/audiotools/core/loudness.py +320 -0
  44. xinference/thirdparty/audiotools/core/playback.py +252 -0
  45. xinference/thirdparty/audiotools/core/templates/__init__.py +0 -0
  46. xinference/thirdparty/audiotools/core/templates/headers.html +322 -0
  47. xinference/thirdparty/audiotools/core/templates/pandoc.css +407 -0
  48. xinference/thirdparty/audiotools/core/templates/widget.html +52 -0
  49. xinference/thirdparty/audiotools/core/util.py +671 -0
  50. xinference/thirdparty/audiotools/core/whisper.py +97 -0
  51. xinference/thirdparty/audiotools/data/__init__.py +3 -0
  52. xinference/thirdparty/audiotools/data/datasets.py +517 -0
  53. xinference/thirdparty/audiotools/data/preprocess.py +81 -0
  54. xinference/thirdparty/audiotools/data/transforms.py +1592 -0
  55. xinference/thirdparty/audiotools/metrics/__init__.py +6 -0
  56. xinference/thirdparty/audiotools/metrics/distance.py +131 -0
  57. xinference/thirdparty/audiotools/metrics/quality.py +159 -0
  58. xinference/thirdparty/audiotools/metrics/spectral.py +247 -0
  59. xinference/thirdparty/audiotools/ml/__init__.py +5 -0
  60. xinference/thirdparty/audiotools/ml/accelerator.py +184 -0
  61. xinference/thirdparty/audiotools/ml/decorators.py +440 -0
  62. xinference/thirdparty/audiotools/ml/experiment.py +90 -0
  63. xinference/thirdparty/audiotools/ml/layers/__init__.py +2 -0
  64. xinference/thirdparty/audiotools/ml/layers/base.py +328 -0
  65. xinference/thirdparty/audiotools/ml/layers/spectral_gate.py +127 -0
  66. xinference/thirdparty/audiotools/post.py +140 -0
  67. xinference/thirdparty/audiotools/preference.py +600 -0
  68. xinference/thirdparty/indextts/BigVGAN/ECAPA_TDNN.py +656 -0
  69. xinference/thirdparty/indextts/BigVGAN/__init__.py +0 -0
  70. xinference/thirdparty/indextts/BigVGAN/activations.py +122 -0
  71. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
  72. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
  73. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
  74. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
  75. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  76. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
  77. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
  78. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
  79. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
  80. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
  81. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
  82. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
  83. xinference/thirdparty/indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
  84. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
  85. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/act.py +29 -0
  86. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/filter.py +96 -0
  87. xinference/thirdparty/indextts/BigVGAN/alias_free_torch/resample.py +49 -0
  88. xinference/thirdparty/indextts/BigVGAN/bigvgan.py +534 -0
  89. xinference/thirdparty/indextts/BigVGAN/models.py +451 -0
  90. xinference/thirdparty/indextts/BigVGAN/nnet/CNN.py +546 -0
  91. xinference/thirdparty/indextts/BigVGAN/nnet/__init__.py +0 -0
  92. xinference/thirdparty/indextts/BigVGAN/nnet/linear.py +89 -0
  93. xinference/thirdparty/indextts/BigVGAN/nnet/normalization.py +670 -0
  94. xinference/thirdparty/indextts/BigVGAN/utils.py +101 -0
  95. xinference/thirdparty/indextts/__init__.py +0 -0
  96. xinference/thirdparty/indextts/cli.py +65 -0
  97. xinference/thirdparty/indextts/gpt/__init__.py +0 -0
  98. xinference/thirdparty/indextts/gpt/conformer/__init__.py +0 -0
  99. xinference/thirdparty/indextts/gpt/conformer/attention.py +312 -0
  100. xinference/thirdparty/indextts/gpt/conformer/embedding.py +163 -0
  101. xinference/thirdparty/indextts/gpt/conformer/subsampling.py +348 -0
  102. xinference/thirdparty/indextts/gpt/conformer_encoder.py +520 -0
  103. xinference/thirdparty/indextts/gpt/model.py +713 -0
  104. xinference/thirdparty/indextts/gpt/model_v2.py +747 -0
  105. xinference/thirdparty/indextts/gpt/perceiver.py +317 -0
  106. xinference/thirdparty/indextts/gpt/transformers_beam_search.py +1013 -0
  107. xinference/thirdparty/indextts/gpt/transformers_generation_utils.py +4747 -0
  108. xinference/thirdparty/indextts/gpt/transformers_gpt2.py +1878 -0
  109. xinference/thirdparty/indextts/gpt/transformers_modeling_utils.py +5525 -0
  110. xinference/thirdparty/indextts/infer.py +690 -0
  111. xinference/thirdparty/indextts/infer_v2.py +739 -0
  112. xinference/thirdparty/indextts/s2mel/dac/__init__.py +16 -0
  113. xinference/thirdparty/indextts/s2mel/dac/__main__.py +36 -0
  114. xinference/thirdparty/indextts/s2mel/dac/model/__init__.py +4 -0
  115. xinference/thirdparty/indextts/s2mel/dac/model/base.py +294 -0
  116. xinference/thirdparty/indextts/s2mel/dac/model/dac.py +400 -0
  117. xinference/thirdparty/indextts/s2mel/dac/model/discriminator.py +228 -0
  118. xinference/thirdparty/indextts/s2mel/dac/model/encodec.py +320 -0
  119. xinference/thirdparty/indextts/s2mel/dac/nn/__init__.py +3 -0
  120. xinference/thirdparty/indextts/s2mel/dac/nn/layers.py +33 -0
  121. xinference/thirdparty/indextts/s2mel/dac/nn/loss.py +368 -0
  122. xinference/thirdparty/indextts/s2mel/dac/nn/quantize.py +339 -0
  123. xinference/thirdparty/indextts/s2mel/dac/utils/__init__.py +123 -0
  124. xinference/thirdparty/indextts/s2mel/dac/utils/decode.py +95 -0
  125. xinference/thirdparty/indextts/s2mel/dac/utils/encode.py +94 -0
  126. xinference/thirdparty/indextts/s2mel/hf_utils.py +12 -0
  127. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/__init__.py +5 -0
  128. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/act.py +29 -0
  129. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/filter.py +96 -0
  130. xinference/thirdparty/indextts/s2mel/modules/alias_free_torch/resample.py +57 -0
  131. xinference/thirdparty/indextts/s2mel/modules/audio.py +82 -0
  132. xinference/thirdparty/indextts/s2mel/modules/bigvgan/activations.py +120 -0
  133. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  134. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  135. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  136. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  137. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/compat.h +29 -0
  138. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/load.py +86 -0
  139. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  140. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/__init__.py +6 -0
  141. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/act.py +30 -0
  142. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/filter.py +101 -0
  143. xinference/thirdparty/indextts/s2mel/modules/bigvgan/alias_free_activation/torch/resample.py +58 -0
  144. xinference/thirdparty/indextts/s2mel/modules/bigvgan/bigvgan.py +492 -0
  145. xinference/thirdparty/indextts/s2mel/modules/bigvgan/config.json +63 -0
  146. xinference/thirdparty/indextts/s2mel/modules/bigvgan/env.py +18 -0
  147. xinference/thirdparty/indextts/s2mel/modules/bigvgan/meldataset.py +354 -0
  148. xinference/thirdparty/indextts/s2mel/modules/bigvgan/utils.py +99 -0
  149. xinference/thirdparty/indextts/s2mel/modules/campplus/DTDNN.py +115 -0
  150. xinference/thirdparty/indextts/s2mel/modules/campplus/classifier.py +70 -0
  151. xinference/thirdparty/indextts/s2mel/modules/campplus/layers.py +253 -0
  152. xinference/thirdparty/indextts/s2mel/modules/commons.py +632 -0
  153. xinference/thirdparty/indextts/s2mel/modules/diffusion_transformer.py +257 -0
  154. xinference/thirdparty/indextts/s2mel/modules/encodec.py +292 -0
  155. xinference/thirdparty/indextts/s2mel/modules/flow_matching.py +171 -0
  156. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/generate.py +436 -0
  157. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/model.py +360 -0
  158. xinference/thirdparty/indextts/s2mel/modules/gpt_fast/quantize.py +622 -0
  159. xinference/thirdparty/indextts/s2mel/modules/hifigan/f0_predictor.py +55 -0
  160. xinference/thirdparty/indextts/s2mel/modules/hifigan/generator.py +454 -0
  161. xinference/thirdparty/indextts/s2mel/modules/layers.py +354 -0
  162. xinference/thirdparty/indextts/s2mel/modules/length_regulator.py +141 -0
  163. xinference/thirdparty/indextts/s2mel/modules/openvoice/__init__.py +0 -0
  164. xinference/thirdparty/indextts/s2mel/modules/openvoice/api.py +186 -0
  165. xinference/thirdparty/indextts/s2mel/modules/openvoice/attentions.py +465 -0
  166. xinference/thirdparty/indextts/s2mel/modules/openvoice/checkpoints_v2/converter/config.json +57 -0
  167. xinference/thirdparty/indextts/s2mel/modules/openvoice/commons.py +160 -0
  168. xinference/thirdparty/indextts/s2mel/modules/openvoice/mel_processing.py +183 -0
  169. xinference/thirdparty/indextts/s2mel/modules/openvoice/models.py +499 -0
  170. xinference/thirdparty/indextts/s2mel/modules/openvoice/modules.py +598 -0
  171. xinference/thirdparty/indextts/s2mel/modules/openvoice/openvoice_app.py +275 -0
  172. xinference/thirdparty/indextts/s2mel/modules/openvoice/se_extractor.py +153 -0
  173. xinference/thirdparty/indextts/s2mel/modules/openvoice/transforms.py +209 -0
  174. xinference/thirdparty/indextts/s2mel/modules/openvoice/utils.py +194 -0
  175. xinference/thirdparty/indextts/s2mel/modules/quantize.py +229 -0
  176. xinference/thirdparty/indextts/s2mel/modules/rmvpe.py +631 -0
  177. xinference/thirdparty/indextts/s2mel/modules/vocos/__init__.py +4 -0
  178. xinference/thirdparty/indextts/s2mel/modules/vocos/heads.py +164 -0
  179. xinference/thirdparty/indextts/s2mel/modules/vocos/helpers.py +71 -0
  180. xinference/thirdparty/indextts/s2mel/modules/vocos/loss.py +114 -0
  181. xinference/thirdparty/indextts/s2mel/modules/vocos/models.py +118 -0
  182. xinference/thirdparty/indextts/s2mel/modules/vocos/modules.py +213 -0
  183. xinference/thirdparty/indextts/s2mel/modules/vocos/pretrained.py +51 -0
  184. xinference/thirdparty/indextts/s2mel/modules/vocos/spectral_ops.py +192 -0
  185. xinference/thirdparty/indextts/s2mel/modules/wavenet.py +174 -0
  186. xinference/thirdparty/indextts/s2mel/optimizers.py +96 -0
  187. xinference/thirdparty/indextts/s2mel/wav2vecbert_extract.py +148 -0
  188. xinference/thirdparty/indextts/utils/__init__.py +0 -0
  189. xinference/thirdparty/indextts/utils/arch_util.py +120 -0
  190. xinference/thirdparty/indextts/utils/checkpoint.py +34 -0
  191. xinference/thirdparty/indextts/utils/common.py +121 -0
  192. xinference/thirdparty/indextts/utils/feature_extractors.py +50 -0
  193. xinference/thirdparty/indextts/utils/front.py +536 -0
  194. xinference/thirdparty/indextts/utils/maskgct/models/codec/__init__.py +0 -0
  195. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/codec.py +427 -0
  196. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/__init__.py +11 -0
  197. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/factorized_vector_quantize.py +150 -0
  198. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/lookup_free_quantize.py +77 -0
  199. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/residual_vq.py +177 -0
  200. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/quantize/vector_quantize.py +401 -0
  201. xinference/thirdparty/indextts/utils/maskgct/models/codec/amphion_codec/vocos.py +881 -0
  202. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_dataset.py +264 -0
  203. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_inference.py +515 -0
  204. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_sampler.py +126 -0
  205. xinference/thirdparty/indextts/utils/maskgct/models/codec/codec_trainer.py +166 -0
  206. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/__init__.py +0 -0
  207. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/__init__.py +5 -0
  208. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/act.py +29 -0
  209. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/filter.py +96 -0
  210. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/alias_free_torch/resample.py +57 -0
  211. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_dataset.py +98 -0
  212. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_inference.py +137 -0
  213. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/facodec_trainer.py +776 -0
  214. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/__init__.py +1 -0
  215. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/bst.t7 +0 -0
  216. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/JDC/model.py +219 -0
  217. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/attentions.py +437 -0
  218. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/commons.py +331 -0
  219. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/gradient_reversal.py +35 -0
  220. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/layers.py +460 -0
  221. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/quantize.py +741 -0
  222. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/style_encoder.py +110 -0
  223. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/modules/wavenet.py +224 -0
  224. xinference/thirdparty/indextts/utils/maskgct/models/codec/facodec/optimizer.py +104 -0
  225. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/repcodec_model.py +210 -0
  226. xinference/thirdparty/indextts/utils/maskgct/models/codec/kmeans/vocos.py +850 -0
  227. xinference/thirdparty/indextts/utils/maskgct/models/codec/melvqgan/melspec.py +108 -0
  228. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/README.md +216 -0
  229. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/__init__.py +6 -0
  230. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  231. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  232. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  233. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/alias_free_torch/resample.py +57 -0
  234. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/facodec.py +1222 -0
  235. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/gradient_reversal.py +35 -0
  236. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/melspec.py +102 -0
  237. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/__init__.py +7 -0
  238. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/fvq.py +116 -0
  239. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/quantize/rvq.py +87 -0
  240. xinference/thirdparty/indextts/utils/maskgct/models/codec/ns3_codec/transformer.py +234 -0
  241. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/model.py +184 -0
  242. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/__init__.py +27 -0
  243. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/conv.py +346 -0
  244. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/lstm.py +46 -0
  245. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/norm.py +37 -0
  246. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/__init__.py +14 -0
  247. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/ac.py +317 -0
  248. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/core_vq.py +388 -0
  249. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/distrib.py +135 -0
  250. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/quantization/vq.py +125 -0
  251. xinference/thirdparty/indextts/utils/maskgct/models/codec/speechtokenizer/modules/seanet.py +414 -0
  252. xinference/thirdparty/indextts/utils/maskgct/models/codec/vevo/vevo_repcodec.py +592 -0
  253. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/ckpt/wav2vec2bert_stats.pt +0 -0
  254. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/llama_nar.py +650 -0
  255. xinference/thirdparty/indextts/utils/maskgct/models/tts/maskgct/maskgct_s2a.py +503 -0
  256. xinference/thirdparty/indextts/utils/maskgct_utils.py +259 -0
  257. xinference/thirdparty/indextts/utils/text_utils.py +41 -0
  258. xinference/thirdparty/indextts/utils/typical_sampling.py +30 -0
  259. xinference/thirdparty/indextts/utils/utils.py +93 -0
  260. xinference/thirdparty/indextts/utils/webui_utils.py +42 -0
  261. xinference/thirdparty/indextts/utils/xtransformers.py +1247 -0
  262. xinference/thirdparty/indextts/vqvae/__init__.py +0 -0
  263. xinference/thirdparty/indextts/vqvae/xtts_dvae.py +395 -0
  264. xinference/types.py +105 -2
  265. xinference/ui/gradio/media_interface.py +66 -8
  266. xinference/ui/web/ui/build/asset-manifest.json +6 -6
  267. xinference/ui/web/ui/build/index.html +1 -1
  268. xinference/ui/web/ui/build/static/css/main.5ea97072.css +2 -0
  269. xinference/ui/web/ui/build/static/css/main.5ea97072.css.map +1 -0
  270. xinference/ui/web/ui/build/static/js/main.d192c4f3.js +3 -0
  271. xinference/ui/web/ui/build/static/js/{main.1086c759.js.LICENSE.txt → main.d192c4f3.js.LICENSE.txt} +0 -7
  272. xinference/ui/web/ui/build/static/js/main.d192c4f3.js.map +1 -0
  273. xinference/ui/web/ui/node_modules/.cache/babel-loader/089c38df5f52348d212ed868dda5c518a42e0c2762caed4175487c0405830c35.json +1 -0
  274. xinference/ui/web/ui/node_modules/.cache/babel-loader/2b6e3a5b6eb2c5c5f2d007e68cd46c372721cd52bf63508adcdb21ecf79241d8.json +1 -0
  275. xinference/ui/web/ui/node_modules/.cache/babel-loader/2d887825fd07a56f872eda4420da25fba0b5b62a23bdcc6c6da1a5281887f618.json +1 -0
  276. xinference/ui/web/ui/node_modules/.cache/babel-loader/4001f9c3e64e73a4f2158826650c174a59d5e3f89ddecddf17cbb6bb688cc4ca.json +1 -0
  277. xinference/ui/web/ui/node_modules/.cache/babel-loader/4a7018a69e6b7f90fc313248c2aa86f2a8f1eb1db120df586047a8023549b44b.json +1 -0
  278. xinference/ui/web/ui/node_modules/.cache/babel-loader/64b12aaa1c1d1bf53820ada8a63769067c0ccc5aab46b32348eb1917ae7f2a11.json +1 -0
  279. xinference/ui/web/ui/node_modules/.cache/babel-loader/7275b67c78ec76ce38a686bb8a576d8c9cecf54e1573614c84859d538efb9be5.json +1 -0
  280. xinference/ui/web/ui/node_modules/.cache/babel-loader/a68b6ee3b31eadc051fb95ce8f8ccb9c2e8b52c60f290dbab545a1917e065282.json +1 -0
  281. xinference/ui/web/ui/node_modules/.cache/babel-loader/ae8771cc37693feb160fa8727231312a0c54ef2d1d1ca893be568cd70016ca7e.json +1 -0
  282. xinference/ui/web/ui/node_modules/.cache/babel-loader/bb4e8722d2d41d87f1fce3661bc8937bffe9448e231fc5f0462630849e851592.json +1 -0
  283. xinference/ui/web/ui/node_modules/.cache/babel-loader/be6aada1ee4adc2bbf65dbe56d17db32bb3b5478be05d6b527805a8ba6cfb2b9.json +1 -0
  284. xinference/ui/web/ui/node_modules/.cache/babel-loader/de91c352653c233cf0cb6674e6e04049a44fd0e1156560de65d5c4620521391e.json +1 -0
  285. xinference/ui/web/ui/node_modules/.cache/babel-loader/e85f7002fc325c83b9c9cd8a1619e5b3ebc701d30e811afc284b88e6ae710cb5.json +1 -0
  286. xinference/ui/web/ui/node_modules/.cache/babel-loader/e8b603c78944bf3d213639078bfe155ff5c0dfa4048a93cbb967cad6a4eb4ff3.json +1 -0
  287. xinference/ui/web/ui/node_modules/.cache/babel-loader/f05535160a508b2a312de546a6de234776c613db276479ea4253c0b1bdeeb7d6.json +1 -0
  288. xinference/ui/web/ui/node_modules/.cache/babel-loader/f09ba9e11106bd59a0de10cc85c55084097729dcab575f43dfcf07375961ed87.json +1 -0
  289. xinference/ui/web/ui/node_modules/.cache/babel-loader/f995a2425dfb0822fd07127f66ffe9b026883bc156b402eb8bd0b83d52460a93.json +1 -0
  290. xinference/ui/web/ui/node_modules/.package-lock.json +0 -33
  291. xinference/ui/web/ui/package-lock.json +0 -34
  292. xinference/ui/web/ui/package.json +0 -1
  293. xinference/ui/web/ui/src/locales/en.json +9 -3
  294. xinference/ui/web/ui/src/locales/ja.json +9 -3
  295. xinference/ui/web/ui/src/locales/ko.json +9 -3
  296. xinference/ui/web/ui/src/locales/zh.json +9 -3
  297. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/METADATA +24 -4
  298. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/RECORD +302 -76
  299. xinference/ui/web/ui/build/static/css/main.013f296b.css +0 -2
  300. xinference/ui/web/ui/build/static/css/main.013f296b.css.map +0 -1
  301. xinference/ui/web/ui/build/static/js/main.1086c759.js +0 -3
  302. xinference/ui/web/ui/build/static/js/main.1086c759.js.map +0 -1
  303. xinference/ui/web/ui/node_modules/.cache/babel-loader/0b0f77000cc1b482ca091cfbcae511dfe02f08916971645fad21d0b1234d04a2.json +0 -1
  304. xinference/ui/web/ui/node_modules/.cache/babel-loader/1c5f8ff423a7c9202bea60b15680f04b1e9964b445b0da3f86c6ff70cf24e797.json +0 -1
  305. xinference/ui/web/ui/node_modules/.cache/babel-loader/44ce7993e344980e3ed4f13e8f69237d4a5dfc60e37ca6b54f51f8ee1357bd67.json +0 -1
  306. xinference/ui/web/ui/node_modules/.cache/babel-loader/4aec1cc414ac3ebb3481d3d915e4db597d9127de813291346eacb8554ab170d4.json +0 -1
  307. xinference/ui/web/ui/node_modules/.cache/babel-loader/644cfec52f3c57a6e222ce60f112237a1efefe9835efd9aad857a685f53d8eed.json +0 -1
  308. xinference/ui/web/ui/node_modules/.cache/babel-loader/663436f72af53fe0d72394f56d003fa4e0bba489e5bb4e483fd34b00f84637f7.json +0 -1
  309. xinference/ui/web/ui/node_modules/.cache/babel-loader/69db82ca9bfe27fe417cc6cf2b1716b09be9c6f0cd198530f12bfc60e801bbcf.json +0 -1
  310. xinference/ui/web/ui/node_modules/.cache/babel-loader/85087e27618d740c236bf159f30e0219db443ab55f0997388eed5fde6f9e90cc.json +0 -1
  311. xinference/ui/web/ui/node_modules/.cache/babel-loader/88b07838348864aa86c672be3bbca1e9f58f6f3a2881b32070ec27f4e7b449d1.json +0 -1
  312. xinference/ui/web/ui/node_modules/.cache/babel-loader/8b8cd408ccfbe115acef27ccfa5b233da8597131a2a5712add13e1e4d5d4504b.json +0 -1
  313. xinference/ui/web/ui/node_modules/.cache/babel-loader/a23824fe746b9c6ca5eee9159b5764d1ff1653c1d856288c0f75c742bbb0023b.json +0 -1
  314. xinference/ui/web/ui/node_modules/.cache/babel-loader/a3eb18af328280b139693c9092dff2a0ef8c9a967e6c8956ceee0996611f1984.json +0 -1
  315. xinference/ui/web/ui/node_modules/.cache/babel-loader/bc1aacc65a102db325ca61bcd2f681e1ae22c36a1f1d98a6ff5e4ad49dc7544f.json +0 -1
  316. xinference/ui/web/ui/node_modules/.cache/babel-loader/c682fd521747c19dae437d83ce3235a306ce6b68e24a117bc57c27ebb8d1f1ca.json +0 -1
  317. xinference/ui/web/ui/node_modules/.cache/babel-loader/d5c224be7081f18cba1678b7874a9782eba895df004874ff8f243f94ba79942a.json +0 -1
  318. xinference/ui/web/ui/node_modules/.cache/babel-loader/f7f18bfb539b036a6a342176dd98a85df5057a884a8da978d679f2a0264883d0.json +0 -1
  319. xinference/ui/web/ui/node_modules/clipboard/.babelrc.json +0 -11
  320. xinference/ui/web/ui/node_modules/clipboard/.eslintrc.json +0 -24
  321. xinference/ui/web/ui/node_modules/clipboard/.prettierrc.json +0 -9
  322. xinference/ui/web/ui/node_modules/clipboard/bower.json +0 -18
  323. xinference/ui/web/ui/node_modules/clipboard/composer.json +0 -25
  324. xinference/ui/web/ui/node_modules/clipboard/package.json +0 -63
  325. xinference/ui/web/ui/node_modules/delegate/package.json +0 -31
  326. xinference/ui/web/ui/node_modules/good-listener/bower.json +0 -11
  327. xinference/ui/web/ui/node_modules/good-listener/package.json +0 -35
  328. xinference/ui/web/ui/node_modules/select/bower.json +0 -13
  329. xinference/ui/web/ui/node_modules/select/package.json +0 -29
  330. xinference/ui/web/ui/node_modules/tiny-emitter/package.json +0 -53
  331. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/WHEEL +0 -0
  332. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/entry_points.txt +0 -0
  333. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/licenses/LICENSE +0 -0
  334. {xinference-1.9.1.dist-info → xinference-1.10.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1878 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT-2 model."""
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from packaging import version
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ import transformers
32
+
33
+ from indextts.gpt.transformers_generation_utils import GenerationMixin
34
+ from indextts.gpt.transformers_modeling_utils import PreTrainedModel
35
+ from transformers.modeling_utils import SequenceSummary
36
+
37
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ CausalLMOutputWithCrossAttentions,
41
+ QuestionAnsweringModelOutput,
42
+ SequenceClassifierOutputWithPast,
43
+ TokenClassifierOutput,
44
+ )
45
+ # from transformers.modeling_utils import PreTrainedModel, SequenceSummary
46
+
47
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
48
+ from transformers.utils import (
49
+ ModelOutput,
50
+ add_code_sample_docstrings,
51
+ add_start_docstrings,
52
+ add_start_docstrings_to_model_forward,
53
+ get_torch_version,
54
+ is_flash_attn_2_available,
55
+ is_flash_attn_greater_or_equal_2_10,
56
+ logging,
57
+ replace_return_docstrings,
58
+ )
59
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
60
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
61
+
62
+
63
+ if is_flash_attn_2_available():
64
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
65
+
66
+
67
+ logger = logging.get_logger(__name__)
68
+
69
+ _CHECKPOINT_FOR_DOC = "openai-community/gpt2"
70
+ _CONFIG_FOR_DOC = "GPT2Config"
71
+
72
+
73
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Walks every variable in the TensorFlow checkpoint at
    ``gpt2_checkpoint_path`` and copies it onto the matching attribute of
    ``model``, translating the TF variable-scope path (e.g. ``model/h0/attn/w``)
    into PyTorch attribute access. Returns the mutated ``model``.

    Raises:
        ImportError: if TensorFlow is not installed (re-raised after logging).
        ValueError: if a checkpoint array's shape does not match the target
            parameter's shape.
    """
    try:
        import re

        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        # squeeze() drops TF's size-1 dims so shapes line up with PyTorch params
        arrays.append(array.squeeze())

    for name, array in zip(names, arrays):
        name = name[6:]  # skip "model/"
        name = name.split("/")
        pointer = model
        # Walk the scope path one component at a time, resolving each to an
        # attribute (and optionally an index, e.g. "h0" -> model.h[0]).
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                # e.g. "h12" splits into ["h", "12", ""] -> attribute + index
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "w" or scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "wpe" or scope_names[0] == "wte":
                # token / position embedding tables: descend into .weight
                pointer = getattr(pointer, scope_names[0])
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            # Attach both shapes to the exception for easier debugging upstream.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
127
+
128
+
129
class GPT2Attention(nn.Module):
    """Multi-head attention for GPT-2 (eager/reference implementation).

    Supports both causal self-attention and, when ``is_cross_attention=True``,
    cross-attention over ``encoder_hidden_states``. Attribute names
    (``c_attn``, ``q_attn``, ``c_proj``) match the original GPT-2 checkpoint
    layout and must not be renamed.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # Lower-triangular boolean causal mask, shaped (1, 1, P, P); stored as a
        # non-persistent buffer so it is not saved in checkpoints.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        # Legacy masking constant kept for backward compatibility.
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            # Cross-attention: K/V come from the encoder (c_attn), Q from the decoder (q_attn).
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # Self-attention: one fused projection producing Q, K and V.
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads in place and shrink the projections."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # c_attn packs Q, K, V along dim 1, so the same head indices are pruned
        # at three offsets (0, split_size, 2*split_size).
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Standard scaled dot-product attention; returns (attn_output, attn_weights)."""
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            # Slice the precomputed triangular mask so cached past keys stay visible.
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Numerically safer attention: computes QK^T in float32 with autocast disabled."""
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with torch.amp.autocast(query.device.type, enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Run attention; returns (attn_output, present[, attn_weights]).

        ``present`` is the (key, value) pair for KV caching when
        ``use_cache=True``, otherwise ``None``.
        """
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence dimension.
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
349
+
350
+
351
class GPT2FlashAttention2(GPT2Attention):
    """
    GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Same contract as GPT2Attention.forward, but dispatched through flash-attn.

        NOTE(review): when ``output_attentions=True`` the third element returned
        is the reshaped attention *output*, not real attention weights — flash
        attention does not materialize the weight matrix.
        """
        bsz, _, _ = hidden_states.size()
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence dimension.
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        query_length = query.shape[2]
        tgt_len = key.shape[2]

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
        key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
        value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

        # Dropout is only applied during training.
        attn_dropout = self.attn_dropout.p if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        if query.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                # Fall back to the projection weight's dtype.
                target_dtype = self.c_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)

        attn_output = _flash_attention_forward(
            query,
            key,
            value,
            attention_mask,
            query_length,
            dropout=attn_dropout,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # Collapse heads back into the hidden dimension before the output projection.
        attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
        attn_output = self.c_proj(attn_weights_reshaped)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights_reshaped,)

        return outputs
462
+
463
+
464
class GPT2SdpaAttention(GPT2Attention):
    """
    GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass
    to adapt to the SDPA API.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__
        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """Same contract as GPT2Attention.forward; attention weights are never
        returned (SDPA does not expose them), so the third element is None."""
        if output_attentions or head_mask is not None:
            # SDPA cannot return weights or apply a head mask -> fall back to eager.
            logger.warning_once(
                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        bsz, q_len, _ = hidden_states.size()

        # Initial attention projections
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Optional kv caching
        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = None
        if use_cache is True:
            present = (key, value)

        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p if self.training else 0.0,
            is_causal=is_causal,
        )

        # Reshape outputs
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.embed_dim)

        # Final projection
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, present, None
569
+
570
+
571
class GPT2MLP(nn.Module):
    """Position-wise feed-forward block of a GPT-2 layer.

    Expands the hidden state to ``intermediate_size``, applies the configured
    activation, projects back to ``config.hidden_size``, then applies residual
    dropout. Attribute names (``c_fc``, ``c_proj``) match the GPT-2
    checkpoint layout and must not be renamed.
    """

    def __init__(self, intermediate_size, config):
        super().__init__()
        hidden_dim = config.hidden_size
        # Conv1D(out_features, in_features) is a transposed nn.Linear kept for
        # compatibility with original GPT-2 checkpoints.
        self.c_fc = Conv1D(intermediate_size, hidden_dim)
        self.c_proj = Conv1D(hidden_dim, intermediate_size)
        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        """Apply up-projection, activation, down-projection, and dropout."""
        expanded = self.act(self.c_fc(hidden_states))
        return self.dropout(self.c_proj(expanded))
586
+
587
+
588
# Dispatch table: maps config._attn_implementation to the attention class
# instantiated by GPT2Block.
GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
589
+
590
+
591
class GPT2Block(nn.Module):
    """A single GPT-2 transformer layer.

    Pre-LayerNorm self-attention + MLP, each with a residual connection, plus
    an optional cross-attention sub-layer when ``config.add_cross_attention``
    is set. The attention implementation is selected from
    GPT2_ATTENTION_CLASSES via ``config._attn_implementation``.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        # GPT-2 default MLP width is 4x the hidden size unless overridden.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
        attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation]

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = attention_class(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        if config.add_cross_attention:
            self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run one transformer layer.

        Returns a tuple: hidden_states, then (when ``use_cache``) the present
        key/value pair, then (when ``output_attentions``) self-attention and,
        if applicable, cross-attention weights.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        # Keep the remaining attention outputs (present, attentions) to append later.
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the `present` slot (None) when no cache is requested.
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)
668
+
669
+
670
class GPT2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Hooks/flags consumed by the PreTrainedModel machinery.
    config_class = GPT2Config
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPT2Block"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name == "c_proj.weight":
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
715
+
716
+
717
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
            Multiple choice classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
            sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            GPT2Attentions weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    mc_loss: Optional[torch.FloatTensor] = None
    # Annotated Optional because the dataclass default is None until populated.
    logits: Optional[torch.FloatTensor] = None
    mc_logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
757
+
758
+
759
# Shared docstring fragment: prepended (via `add_start_docstrings`) to every GPT-2
# model class defined below. Keep the text itself unchanged — it is user-visible help.
GPT2_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
774
+
775
# Shared docstring fragment: injected into every model's `forward` via
# `add_start_docstrings_to_model_forward`. Documents the common GPT-2 forward arguments.
GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
842
# Docstring attached to the (deprecated) `parallelize` methods below via `add_start_docstrings`.
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is a subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, *optional*):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
            following number of attention modules:

                - openai-community/gpt2: 12
                - openai-community/gpt2-medium: 24
                - openai-community/gpt2-large: 36
                - openai-community/gpt2-xl: 48

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
    }
    model.parallelize(device_map)
    ```
"""
874
# Docstring attached to the (deprecated) `deparallelize` methods below via `add_start_docstrings`.
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with openai-community/gpt2-large:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],
        1: [8, 9, 10, 11, 12, 13, 14, 15],
        2: [16, 17, 18, 19, 20, 21, 22, 23],
        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""
892
+
893
+
894
@add_start_docstrings(
    "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
    GPT2_START_DOCSTRING,
)
class GPT2Model(GPT2PreTrainedModel):
    # Opt out of param/buffer assignment on load; weights must be copied in place here.
    _supports_param_buffer_assignment = False

    def __init__(self, config):
        """Build the GPT-2 trunk: token/position embeddings, transformer blocks, final LayerNorm."""
        super().__init__(config)

        self.embed_dim = config.hidden_size

        # wte = word/token embedding, wpe = learned absolute position embedding.
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        """Deprecated: spread transformer blocks across CUDA devices per `device_map`."""
        # Check validity of device_map
        warnings.warn(
            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        # Embeddings live on the first device; ln_f on the last (blocks in between per map).
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        """Deprecated: move every sub-module back to CPU and release cached CUDA memory."""
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the token embedding module (`wte`)."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module (`wte`)."""
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # Resolve per-call flags against config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        # past_key_values is a legacy tuple cache; past_length offsets the position ids
        # so cached tokens keep their original positions.
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # Attention mask.
        # SDPA can only be used when attention weights are not requested and no head mask is set.
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
        if self._attn_implementation == "flash_attention_2":
            # Flash attention takes the raw 2D mask, and only needs it when there is padding.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif _use_sdpa:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask=attention_mask,
                input_shape=(batch_size, input_shape[-1]),
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_length,
            )
        else:
            if attention_mask is not None:
                # We create a 3D attention mask from a 2D tensor mask.
                # Sizes are [batch_size, 1, 1, to_seq_length]
                # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
                # this attention mask is more simple than the triangular masking of causal attention
                # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
                attention_mask = attention_mask[:, None, None, :]

                # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
                # masked positions, this operation will create a tensor which is 0.0 for
                # positions we want to attend and the dtype's smallest value for masked positions.
                # Since we are adding it to the raw scores before the softmax, this is
                # effectively the same as removing these entirely.
                attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
                attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif not self._attn_implementation == "flash_attention_2":
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        # Token-type embeddings share the word embedding table (wte), per original GPT-2.
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Checkpointing re-runs the block in backward; positional args must match
                # the block's signature (layer_past is forced to None under checkpointing).
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Output tuple layout shifts by one when the cache ("present") is included.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output: drop entries that were not requested (None).
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
1185
+
1186
+
1187
@add_start_docstrings(
    """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
    # lm_head shares its weight with the input embeddings (wte); excluded from duplicate saving.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """GPT-2 trunk plus an untied-bias linear LM head projecting to the vocabulary."""
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        """Deprecated: delegate block placement to the trunk, keep lm_head with the embeddings."""
        warnings.warn(
            "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
            " 0, 'transformer.h.1': 1, ...}",
            FutureWarning,
        )
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        # lm_head must sit on the same device as the tied input embeddings.
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        """Deprecated: move trunk and head back to CPU and release cached CUDA memory."""
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        """Return the LM head (used by weight tying and resizing utilities)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (used by weight tying and resizing utilities)."""
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens; CrossEntropyLoss ignores -100 labels by default.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            # Legacy tuple output: loss (if any) first, then logits and trunk outputs.
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
1338
+
1339
+
1340
+ @add_start_docstrings(
1341
+ """
1342
+ The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
1343
+ RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
1344
+ input embeddings, the classification head takes as input the input of a specified classification token index in the
1345
+ input sequence).
1346
+ """,
1347
+ GPT2_START_DOCSTRING,
1348
+ )
1349
+ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
1350
+ _tied_weights_keys = ["lm_head.weight"]
1351
+
1352
+ def __init__(self, config):
1353
+ super().__init__(config)
1354
+ config.num_labels = 1
1355
+ self.transformer = GPT2Model(config)
1356
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
1357
+ self.multiple_choice_head = SequenceSummary(config)
1358
+
1359
+ # Model parallel
1360
+ self.model_parallel = False
1361
+ self.device_map = None
1362
+
1363
+ # Initialize weights and apply final processing
1364
+ self.post_init()
1365
+
1366
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
1367
+ def parallelize(self, device_map=None):
1368
+ warnings.warn(
1369
+ "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
1370
+ " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
1371
+ " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1372
+ " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
1373
+ FutureWarning,
1374
+ )
1375
+ self.device_map = (
1376
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
1377
+ if device_map is None
1378
+ else device_map
1379
+ )
1380
+ assert_device_map(self.device_map, len(self.transformer.h))
1381
+ self.transformer.parallelize(self.device_map)
1382
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
1383
+ self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
1384
+ self.model_parallel = True
1385
+
1386
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
1387
+ def deparallelize(self):
1388
+ warnings.warn(
1389
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1390
+ FutureWarning,
1391
+ )
1392
+ self.transformer.deparallelize()
1393
+ self.transformer = self.transformer.to("cpu")
1394
+ self.lm_head = self.lm_head.to("cpu")
1395
+ self.multiple_choice_head = self.multiple_choice_head.to("cpu")
1396
+ self.model_parallel = False
1397
+ torch.cuda.empty_cache()
1398
+
1399
+ def get_output_embeddings(self):
1400
+ return self.lm_head
1401
+
1402
+ def set_output_embeddings(self, new_embeddings):
1403
+ self.lm_head = new_embeddings
1404
+
1405
    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        mc_token_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mc_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]:
        r"""
        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
            1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
        mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)

        Return:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")

        >>> # Add a [CLS] to the vocabulary (we should train it also!)
        >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
        >>> # Update the model embeddings with the new vocabulary size
        >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))

        >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        >>> encoded_choices = [tokenizer.encode(s) for s in choices]
        >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]

        >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        >>> mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1

        >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
        >>> lm_logits = outputs.logits
        >>> mc_logits = outputs.mc_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Last-layer hidden states, regardless of return type.
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        # LM head scores every position; the multiple-choice head pools the
        # hidden state at `mc_token_ids` for each choice before scoring.
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        mc_loss = None
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
        lm_loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # Shift so that the token at position n is predicted from positions < n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            # Legacy tuple order: (lm_loss?, mc_loss?, lm_logits, mc_logits, *rest).
            # Each loss is prepended only when it was computed.
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            if mc_loss is not None:
                output = (mc_loss,) + output
            return ((lm_loss,) + output) if lm_loss is not None else output

        return GPT2DoubleHeadsModelOutput(
            loss=lm_loss,
            mc_loss=mc_loss,
            logits=lm_logits,
            mc_logits=mc_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1517
+
1518
+ @staticmethod
1519
+ def _reorder_cache(
1520
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
1521
+ ) -> Tuple[Tuple[torch.Tensor]]:
1522
+ """
1523
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1524
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1525
+ beam_idx at every generation step.
1526
+ """
1527
+ return tuple(
1528
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1529
+ for layer_past in past_key_values
1530
+ )
1531
+
1532
+
1533
@add_start_docstrings(
    """
    The GPT2 Model transformer with a sequence classification head on top (linear layer).

    [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    GPT2_START_DOCSTRING,
)
class GPT2ForSequenceClassification(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPT2Model(config)
        # Classification head scores the hidden state of the last non-padding token.
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint="microsoft/DialogRPT-updown",
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-position class logits; pooled to one position per row below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # NOTE(review): `assert` is stripped under `python -O`; raising
        # ValueError would be more robust — kept as-is to match upstream.
        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.config.pad_token_id is None:
            # No pad token: fall back to the last position of every row.
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                # argmax finds the first pad position; -1 steps back to the last
                # real token, and the modulo wraps -1 to the final index when a
                # row contains no padding at all.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                # Padding cannot be detected in inputs_embeds; use last position.
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # One pooled logit vector per batch row, taken at its last-token index.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config
            # (mutates self.config as a side effect, matching upstream).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            # Legacy tuple: (loss?, pooled_logits, *transformer extras).
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1665
+
1666
+
1667
@add_start_docstrings(
    """
    GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    GPT2_START_DOCSTRING,
)
class GPT2ForTokenClassification(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPT2Model(config)
        # Dropout rate preference: classifier_dropout > hidden_dropout > 0.1 default.
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    # fmt: off
    @add_code_sample_docstrings(
        checkpoint="brad1141/gpt2-finetuned-comp2",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_loss=0.25,
        expected_output=[
            "Lead",
            "Lead",
            "Lead",
            "Position",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
        ],
    )
    # fmt: on
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify every position: dropout -> linear head over hidden states.
        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # NOTE(review): `[2:]` skips the cached key/values in the tuple
            # output, unlike other heads here which use `[1:]` — this matches
            # upstream transformers; confirm intentional before changing.
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1776
+
1777
+
1778
@add_start_docstrings(
    """
    The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    GPT2_START_DOCSTRING,
)
class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPT2Model(config)
        # Projects each hidden state to a (start, end) logit pair.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Model-parallel bookkeeping (disabled by default).
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        real_checkpoint=_CHECKPOINT_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # One (start, end) logit pair per position, split into two 1D-per-row heads.
        span_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Multi-GPU training can add a trailing dim to the labels; flatten
            # and move them onto the logits' device.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # Positions beyond the sequence are clamped onto an ignored index
            # so they contribute nothing to the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2

        if not return_dict:
            # Legacy tuple output: (loss?, start_logits, end_logits, *extras).
            tail = (start_logits, end_logits) + outputs[2:]
            return tail if total_loss is None else (total_loss,) + tail

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )