diffusers 0.27.0__py3-none-any.whl → 0.32.2__py3-none-any.whl

Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +50 -53
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.0.dist-info/RECORD +0 -399
  443. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,889 @@
+ # Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import List, Optional, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn import LayerNorm
+ from torch.nn.utils import skip_init
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class ChatGLMConfig(PretrainedConfig):
+     model_type = "chatglm"
+
+     def __init__(
+         self,
+         num_layers=28,
+         padded_vocab_size=65024,
+         hidden_size=4096,
+         ffn_hidden_size=13696,
+         kv_channels=128,
+         num_attention_heads=32,
+         seq_length=2048,
+         hidden_dropout=0.0,
+         classifier_dropout=None,
+         attention_dropout=0.0,
+         layernorm_epsilon=1e-5,
+         rmsnorm=True,
+         apply_residual_connection_post_layernorm=False,
+         post_layer_norm=True,
+         add_bias_linear=False,
+         add_qkv_bias=False,
+         bias_dropout_fusion=True,
+         multi_query_attention=False,
+         multi_query_group_num=1,
+         apply_query_key_layer_scaling=True,
+         attention_softmax_in_fp32=True,
+         fp32_residual_connection=False,
+         quantization_bit=0,
+         pre_seq_len=None,
+         prefix_projection=False,
+         **kwargs,
+     ):
+         self.num_layers = num_layers
+         self.vocab_size = padded_vocab_size
+         self.padded_vocab_size = padded_vocab_size
+         self.hidden_size = hidden_size
+         self.ffn_hidden_size = ffn_hidden_size
+         self.kv_channels = kv_channels
+         self.num_attention_heads = num_attention_heads
+         self.seq_length = seq_length
+         self.hidden_dropout = hidden_dropout
+         self.classifier_dropout = classifier_dropout
+         self.attention_dropout = attention_dropout
+         self.layernorm_epsilon = layernorm_epsilon
+         self.rmsnorm = rmsnorm
+         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+         self.post_layer_norm = post_layer_norm
+         self.add_bias_linear = add_bias_linear
+         self.add_qkv_bias = add_qkv_bias
+         self.bias_dropout_fusion = bias_dropout_fusion
+         self.multi_query_attention = multi_query_attention
+         self.multi_query_group_num = multi_query_group_num
+         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+         self.attention_softmax_in_fp32 = attention_softmax_in_fp32
+         self.fp32_residual_connection = fp32_residual_connection
+         self.quantization_bit = quantization_bit
+         self.pre_seq_len = pre_seq_len
+         self.prefix_projection = prefix_projection
+         super().__init__(**kwargs)
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
+         self.eps = eps
+
+     def forward(self, hidden_states: torch.Tensor):
+         input_dtype = hidden_states.dtype
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+         return (self.weight * hidden_states).to(input_dtype)
+
+
+ def _config_to_kwargs(args):
+     common_kwargs = {
+         "dtype": args.torch_dtype,
+     }
+     return common_kwargs
+
+
+ class CoreAttention(torch.nn.Module):
+     def __init__(self, config: ChatGLMConfig, layer_number):
+         super(CoreAttention, self).__init__()
+
+         self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
+         self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
+         if self.apply_query_key_layer_scaling:
+             self.attention_softmax_in_fp32 = True
+         self.layer_number = max(1, layer_number)
+
+         projection_size = config.kv_channels * config.num_attention_heads
+
+         # Per attention head and per partition values.
+         self.hidden_size_per_partition = projection_size
+         self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
+         self.num_attention_heads_per_partition = config.num_attention_heads
+
+         coeff = None
+         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+         if self.apply_query_key_layer_scaling:
+             coeff = self.layer_number
+             self.norm_factor *= coeff
+         self.coeff = coeff
+
+         self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
+
+     def forward(self, query_layer, key_layer, value_layer, attention_mask):
+         pytorch_major_version = int(torch.__version__.split(".")[0])
+         if pytorch_major_version >= 2:
+             query_layer, key_layer, value_layer = [
+                 k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
+             ]
+             if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+                 context_layer = torch.nn.functional.scaled_dot_product_attention(
+                     query_layer, key_layer, value_layer, is_causal=True
+                 )
+             else:
+                 if attention_mask is not None:
+                     attention_mask = ~attention_mask
+                 context_layer = torch.nn.functional.scaled_dot_product_attention(
+                     query_layer, key_layer, value_layer, attention_mask
+                 )
+             context_layer = context_layer.permute(2, 0, 1, 3)
+             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+             context_layer = context_layer.reshape(*new_context_layer_shape)
+         else:
+             # Raw attention scores
+
+             # [b, np, sq, sk]
+             output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
+
+             # [sq, b, np, hn] -> [sq, b * np, hn]
+             query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
+             # [sk, b, np, hn] -> [sk, b * np, hn]
+             key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
+
+             # preallocating input tensor: [b * np, sq, sk]
+             matmul_input_buffer = torch.empty(
+                 output_size[0] * output_size[1],
+                 output_size[2],
+                 output_size[3],
+                 dtype=query_layer.dtype,
+                 device=query_layer.device,
+             )
+
+             # Raw attention scores. [b * np, sq, sk]
+             matmul_result = torch.baddbmm(
+                 matmul_input_buffer,
+                 query_layer.transpose(0, 1),  # [b * np, sq, hn]
+                 key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+                 beta=0.0,
+                 alpha=(1.0 / self.norm_factor),
+             )
+
+             # change view to [b, np, sq, sk]
+             attention_scores = matmul_result.view(*output_size)
+
+             # ===========================
+             # Attention probs and dropout
+             # ===========================
+
+             # attention scores and attention mask [b, np, sq, sk]
+             if self.attention_softmax_in_fp32:
+                 attention_scores = attention_scores.float()
+             if self.coeff is not None:
+                 attention_scores = attention_scores * self.coeff
+             if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
+                 attention_mask = torch.ones(
+                     output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool
+                 )
+                 attention_mask.tril_()
+                 attention_mask = ~attention_mask
+             if attention_mask is not None:
+                 attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
+             attention_probs = F.softmax(attention_scores, dim=-1)
+             attention_probs = attention_probs.type_as(value_layer)
+
+             # This is actually dropping out entire tokens to attend to, which might
+             # seem a bit unusual, but is taken from the original Transformer paper.
+             attention_probs = self.attention_dropout(attention_probs)
+             # =========================
+             # Context layer. [sq, b, hp]
+             # =========================
+
+             # value_layer -> context layer.
+             # [sk, b, np, hn] --> [b, np, sq, hn]
+
+             # context layer shape: [b, np, sq, hn]
+             output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+             # change view [sk, b * np, hn]
+             value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
+             # change view [b * np, sq, sk]
+             attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+             # matmul: [b * np, sq, hn]
+             context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+             # change view [b, np, sq, hn]
+             context_layer = context_layer.view(*output_size)
+             # [b, np, sq, hn] --> [sq, b, np, hn]
+             context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+             # [sq, b, np, hn] --> [sq, b, hp]
+             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+             context_layer = context_layer.view(*new_context_layer_shape)
+
+         return context_layer
+
+
+ def split_tensor_along_last_dim(
+     tensor: torch.Tensor,
+     num_partitions: int,
+     contiguous_split_chunks: bool = False,
+ ) -> List[torch.Tensor]:
+     """Split a tensor along its last dimension.
+
+     Arguments:
+         tensor: input tensor.
+         num_partitions: number of partitions to split the tensor
+         contiguous_split_chunks: If True, make each chunk contiguous
+             in memory.
+
+     Returns:
+         A list of Tensors
+     """
+     # Get the size and dimension.
+     last_dim = tensor.dim() - 1
+     last_dim_size = tensor.size()[last_dim] // num_partitions
+     # Split.
+     tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+     # Note: torch.split does not create contiguous tensors by default.
+     if contiguous_split_chunks:
+         return tuple(chunk.contiguous() for chunk in tensor_list)
+
+     return tensor_list
+
+
+ @torch.jit.script
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+     # x: [sq, b, np, hn]
+     sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3)
+     rot_dim = rope_cache.shape[-2] * 2
+     x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
+     # truncate to support variable sizes
+     rope_cache = rope_cache[:sq]
+     xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
+     rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
+     x_out2 = torch.stack(
+         [
+             xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
+             xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
+         ],
+         -1,
+     )
+     x_out2 = x_out2.flatten(3)
+     return torch.cat((x_out2, x_pass), dim=-1)
+
+
+ class SelfAttention(torch.nn.Module):
+     """Parallel self-attention layer abstract class.
+
+     Self-attention layer takes input with size [s, b, h] and returns output of the same size.
+     """
+
+     def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+         super(SelfAttention, self).__init__()
+         self.layer_number = max(1, layer_number)
+
+         self.projection_size = config.kv_channels * config.num_attention_heads
+
+         # Per attention head and per partition values.
+         self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
+         self.num_attention_heads_per_partition = config.num_attention_heads
+
+         self.multi_query_attention = config.multi_query_attention
+         self.qkv_hidden_size = 3 * self.projection_size
+         if self.multi_query_attention:
+             self.num_multi_query_groups_per_partition = config.multi_query_group_num
+             self.qkv_hidden_size = (
+                 self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
+             )
+         self.query_key_value = nn.Linear(
+             config.hidden_size,
+             self.qkv_hidden_size,
+             bias=config.add_bias_linear or config.add_qkv_bias,
+             device=device,
+             **_config_to_kwargs(config),
+         )
+
+         self.core_attention = CoreAttention(config, self.layer_number)
+
+         # Output.
+         self.dense = nn.Linear(
+             self.projection_size,
+             config.hidden_size,
+             bias=config.add_bias_linear,
+             device=device,
+             **_config_to_kwargs(config),
+         )
+
+     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
+         if self.multi_query_attention:
+             num_attention_heads = self.num_multi_query_groups_per_partition
+         else:
+             num_attention_heads = self.num_attention_heads_per_partition
+         return torch.empty(
+             inference_max_sequence_len,
+             batch_size,
+             num_attention_heads,
+             self.hidden_size_per_attention_head,
+             dtype=dtype,
+             device=device,
+         )
+
+     def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
+         # hidden_states: [sq, b, h]
+
+         # =================================================
+         # Pre-allocate memory for key-values for inference.
+         # =================================================
+         # =====================
+         # Query, Key, and Value
+         # =====================
+
+         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+         mixed_x_layer = self.query_key_value(hidden_states)
+
+         if self.multi_query_attention:
+             (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+                 [
+                     self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+                     self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                     self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+                 ],
+                 dim=-1,
+             )
+             query_layer = query_layer.view(
+                 query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+             )
+             key_layer = key_layer.view(
+                 key_layer.size()[:-1]
+                 + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+             )
+             value_layer = value_layer.view(
+                 value_layer.size()[:-1]
+                 + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
+             )
+         else:
+             new_tensor_shape = mixed_x_layer.size()[:-1] + (
+                 self.num_attention_heads_per_partition,
+                 3 * self.hidden_size_per_attention_head,
+             )
+             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+             (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+
+         # apply relative positional encoding (rotary embedding)
+         if rotary_pos_emb is not None:
+             query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+             key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
+
+         # adjust key and value for inference
+         if kv_cache is not None:
+             cache_k, cache_v = kv_cache
+             key_layer = torch.cat((cache_k, key_layer), dim=0)
+             value_layer = torch.cat((cache_v, value_layer), dim=0)
+         if use_cache:
+             kv_cache = (key_layer, value_layer)
+         else:
+             kv_cache = None
+
+         if self.multi_query_attention:
+             key_layer = key_layer.unsqueeze(-2)
+             key_layer = key_layer.expand(
+                 -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
+             )
+             key_layer = key_layer.contiguous().view(
+                 key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+             )
+             value_layer = value_layer.unsqueeze(-2)
+             value_layer = value_layer.expand(
+                 -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
+             )
+             value_layer = value_layer.contiguous().view(
+                 value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
+             )
+
+         # ==================================
+         # core attention computation
+         # ==================================
+
+         context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
+
+         # =================
+         # Output. [sq, b, h]
+         # =================
+
+         output = self.dense(context_layer)
+
+         return output, kv_cache
+
+
+ class MLP(torch.nn.Module):
+     """MLP.
+
+     MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation,
+     and project the state back into h hidden dimension.
+     """
+
+     def __init__(self, config: ChatGLMConfig, device=None):
+         super(MLP, self).__init__()
+
+         self.add_bias = config.add_bias_linear
+
+         # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+         self.dense_h_to_4h = nn.Linear(
+             config.hidden_size,
+             config.ffn_hidden_size * 2,
+             bias=self.add_bias,
+             device=device,
+             **_config_to_kwargs(config),
+         )
+
+         def swiglu(x):
+             x = torch.chunk(x, 2, dim=-1)
+             return F.silu(x[0]) * x[1]
+
+         self.activation_func = swiglu
+
+         # Project back to h.
+         self.dense_4h_to_h = nn.Linear(
+             config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
+         )
+
+     def forward(self, hidden_states):
+         # [s, b, 4hp]
+         intermediate_parallel = self.dense_h_to_4h(hidden_states)
+         intermediate_parallel = self.activation_func(intermediate_parallel)
+         # [s, b, h]
+         output = self.dense_4h_to_h(intermediate_parallel)
+         return output
+
+
+ class GLMBlock(torch.nn.Module):
+     """A single transformer layer.
+
+     Transformer layer takes input with size [s, b, h] and returns an output of the same size.
+     """
+
+     def __init__(self, config: ChatGLMConfig, layer_number, device=None):
+         super(GLMBlock, self).__init__()
+         self.layer_number = layer_number
+
+         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
+
+         self.fp32_residual_connection = config.fp32_residual_connection
+
+         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+         # Layernorm on the input data.
+         self.input_layernorm = LayerNormFunc(
+             config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
+         )
+
+         # Self attention.
+         self.self_attention = SelfAttention(config, layer_number, device=device)
+         self.hidden_dropout = config.hidden_dropout
+
+         # Layernorm on the attention output
+         self.post_attention_layernorm = LayerNormFunc(
+             config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
+         )
+
+         # MLP
+         self.mlp = MLP(config, device=device)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask,
+         rotary_pos_emb,
+         kv_cache=None,
+         use_cache=True,
+     ):
+         # hidden_states: [s, b, h]
+
+         # Layer norm at the beginning of the transformer layer.
+         layernorm_output = self.input_layernorm(hidden_states)
+         # Self attention.
+         attention_output, kv_cache = self.self_attention(
+             layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache
+         )
+
+         # Residual connection.
+         if self.apply_residual_connection_post_layernorm:
+             residual = layernorm_output
+         else:
+             residual = hidden_states
+
+         layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+         layernorm_input = residual + layernorm_input
+
+         # Layer norm post the self attention.
+         layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+         # MLP.
+         mlp_output = self.mlp(layernorm_output)
+
+         # Second residual connection.
+         if self.apply_residual_connection_post_layernorm:
+             residual = layernorm_output
+         else:
+             residual = layernorm_input
+
+         output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+         output = residual + output
+
+         return output, kv_cache
+
+
+ class GLMTransformer(torch.nn.Module):
+     """Transformer class."""
+
+     def __init__(self, config: ChatGLMConfig, device=None):
+         super(GLMTransformer, self).__init__()
+
+         self.fp32_residual_connection = config.fp32_residual_connection
+         self.post_layer_norm = config.post_layer_norm
+
+         # Number of layers.
+         self.num_layers = config.num_layers
+
+         # Transformer layers.
+         def build_layer(layer_number):
+             return GLMBlock(config, layer_number, device=device)
+
+         self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
+
+         if self.post_layer_norm:
+             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
+             # Final layer norm before output.
+             self.final_layernorm = LayerNormFunc(
+                 config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
+             )
+
+         self.gradient_checkpointing = False
+
+     def _get_layer(self, layer_number):
+         return self.layers[layer_number]
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask,
+         rotary_pos_emb,
+         kv_caches=None,
+         use_cache: Optional[bool] = True,
+         output_hidden_states: Optional[bool] = False,
+     ):
+         if not kv_caches:
+             kv_caches = [None for _ in range(self.num_layers)]
+         presents = () if use_cache else None
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         all_self_attentions = None
+         all_hidden_states = () if output_hidden_states else None
+         for index in range(self.num_layers):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             layer = self._get_layer(index)
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 layer_ret = torch.utils.checkpoint.checkpoint(
+                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
+                 )
+             else:
+                 layer_ret = layer(
+                     hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache
+                 )
+             hidden_states, kv_cache = layer_ret
+             if use_cache:
+                 presents = presents + (kv_cache,)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         # Final layer norm.
+         if self.post_layer_norm:
+             hidden_states = self.final_layernorm(hidden_states)
+
+         return hidden_states, presents, all_hidden_states, all_self_attentions
+
+
+ class ChatGLMPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     is_parallelizable = False
+     supports_gradient_checkpointing = True
+     config_class = ChatGLMConfig
+     base_model_prefix = "transformer"
+     _no_split_modules = ["GLMBlock"]
+
+     def _init_weights(self, module: nn.Module):
+         """Initialize the weights."""
+         return
+
+     def get_masks(self, input_ids, past_key_values, padding_mask=None):
+         batch_size, seq_length = input_ids.shape
+         full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+         full_attention_mask.tril_()
+         past_length = 0
+         if past_key_values:
+             past_length = past_key_values[0][0].shape[0]
+         if past_length:
+             full_attention_mask = torch.cat(
+                 (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1
+             )
+         if padding_mask is not None:
+             full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+         if not past_length and padding_mask is not None:
+             full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+         full_attention_mask = (full_attention_mask < 0.5).bool()
+         full_attention_mask.unsqueeze_(1)
+         return full_attention_mask
+
+     def get_position_ids(self, input_ids, device):
+         batch_size, seq_length = input_ids.shape
+         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+         return position_ids
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, GLMTransformer):
+             module.gradient_checkpointing = value
+
+
+ def default_init(cls, *args, **kwargs):
+     return cls(*args, **kwargs)
+
+
+ class Embedding(torch.nn.Module):
+     """Language model embeddings."""
+
+     def __init__(self, config: ChatGLMConfig, device=None):
+         super(Embedding, self).__init__()
+
+         self.hidden_size = config.hidden_size
+         # Word embeddings (parallel).
+         self.word_embeddings = nn.Embedding(
+             config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
+         )
+         self.fp32_residual_connection = config.fp32_residual_connection
+
+     def forward(self, input_ids):
+         # Embeddings.
+         words_embeddings = self.word_embeddings(input_ids)
+         embeddings = words_embeddings
+         # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+         embeddings = embeddings.transpose(0, 1).contiguous()
+         # If the input flag for fp32 residual connection is set, convert to float.
+         if self.fp32_residual_connection:
+             embeddings = embeddings.float()
+         return embeddings
+
+
703
+ class RotaryEmbedding(nn.Module):
704
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
705
+ super().__init__()
706
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
707
+ self.register_buffer("inv_freq", inv_freq)
708
+ self.dim = dim
709
+ self.original_impl = original_impl
710
+
711
+ def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
712
+ """Enhanced Transformer with Rotary Position Embedding.
713
+
714
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
715
+ transformers/rope/__init__.py. MIT License:
716
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
717
+ """
718
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
719
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
720
+
721
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
722
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
723
+
724
+ # Calculate the product of position index and $\theta_i$
725
+ idx_theta = torch.outer(seq_idx, theta).float()
726
+
727
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
728
+
729
+ # this is to mimic the behaviour of complex32, else we will get different results
730
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
731
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
732
+ return cache
733
+
734
+ def forward(self, max_seq_len, offset=0):
735
+ return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
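# Illustrative shape check (hypothetical values, using `torch` as imported above):
#
#     rope = RotaryEmbedding(dim=8, dtype=torch.float32)
#     cache = rope(max_seq_len=4)
#     assert cache.shape == (4, 4, 2)  # [seq_len, dim // 2, (cos, sin)]
#
# Each position thus stores one (cos, sin) pair per rotary frequency.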
736
+
737
+
738
+ class PrefixEncoder(torch.nn.Module):
739
+ """
740
+ The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size,
741
+ prefix-length, 2*layers*hidden)
742
+ """
743
+
744
+ def __init__(self, config: ChatGLMConfig):
745
+ super().__init__()
746
+ self.prefix_projection = config.prefix_projection
747
+ if self.prefix_projection:
748
+ # Use a two-layer MLP to encode the prefix
749
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
750
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
751
+ self.trans = torch.nn.Sequential(
752
+ torch.nn.Linear(kv_size, config.hidden_size),
753
+ torch.nn.Tanh(),
754
+ torch.nn.Linear(config.hidden_size, kv_size),
755
+ )
756
+ else:
757
+ self.embedding = torch.nn.Embedding(
758
+ config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2
759
+ )
760
+
761
+ def forward(self, prefix: torch.Tensor):
762
+ if self.prefix_projection:
763
+ prefix_tokens = self.embedding(prefix)
764
+ past_key_values = self.trans(prefix_tokens)
765
+ else:
766
+ past_key_values = self.embedding(prefix)
767
+ return past_key_values
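# Illustrative shape check (hypothetical config values, e.g. the ChatGLM3-6B defaults
# num_layers=28, kv_channels=128, multi_query_group_num=2):
#
#     kv_size = 28 * 128 * 2 * 2  # = 14336
#     # prefix: LongTensor of shape (batch_size, pre_seq_len)
#     # PrefixEncoder(config)(prefix).shape == (batch_size, pre_seq_len, kv_size)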
768
+
769
+
770
+ class ChatGLMModel(ChatGLMPreTrainedModel):
771
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
772
+ super().__init__(config)
773
+ if empty_init:
774
+ init_method = skip_init
775
+ else:
776
+ init_method = default_init
777
+ init_kwargs = {}
778
+ if device is not None:
779
+ init_kwargs["device"] = device
780
+ self.embedding = init_method(Embedding, config, **init_kwargs)
781
+ self.num_layers = config.num_layers
782
+ self.multi_query_group_num = config.multi_query_group_num
783
+ self.kv_channels = config.kv_channels
784
+
785
+ # Rotary positional embeddings
786
+ self.seq_length = config.seq_length
787
+ rotary_dim = (
788
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
789
+ )
790
+
791
+ self.rotary_pos_emb = RotaryEmbedding(
792
+ rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
793
+ )
794
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
795
+ self.output_layer = init_method(
796
+ nn.Linear,
797
+ config.hidden_size,
798
+ config.padded_vocab_size,
799
+ bias=False,
800
+ dtype=config.torch_dtype,
801
+ **init_kwargs,
802
+ )
803
+ self.pre_seq_len = config.pre_seq_len
804
+ self.prefix_projection = config.prefix_projection
805
+ if self.pre_seq_len is not None:
806
+ for param in self.parameters():
807
+ param.requires_grad = False
808
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
809
+ self.prefix_encoder = PrefixEncoder(config)
810
+ self.dropout = torch.nn.Dropout(0.1)
811
+
812
+ def get_input_embeddings(self):
813
+ return self.embedding.word_embeddings
814
+
815
+ def get_prompt(self, batch_size, device, dtype=torch.half):
816
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
817
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
818
+ past_key_values = past_key_values.view(
819
+ batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels
820
+ )
821
+ # seq_len, b, nh, hidden_size
822
+ past_key_values = self.dropout(past_key_values)
823
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
824
+ return past_key_values
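    # Illustrative layout trace (hypothetical walk-through of the permute/split above):
    # after `view`, past_key_values is [batch, pre_seq_len, num_layers * 2, groups,
    # kv_channels]; `permute([2, 1, 0, 3, 4])` moves the layer axis first, giving
    # [num_layers * 2, pre_seq_len, batch, groups, kv_channels]; `.split(2)` then
    # yields one (key, value) slab of shape [2, pre_seq_len, batch, groups,
    # kv_channels] per layer, i.e. the seq-first cache layout consumed as `kv_caches`.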
+
+     def forward(
+         self,
+         input_ids,
+         position_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         full_attention_mask: Optional[torch.BoolTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         batch_size, seq_length = input_ids.shape
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embedding(input_ids)
+
+         if self.pre_seq_len is not None:
+             if past_key_values is None:
+                 past_key_values = self.get_prompt(
+                     batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+                 )
+             if attention_mask is not None:
+                 attention_mask = torch.cat(
+                     [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+                 )
+
+         if full_attention_mask is None:
+             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+                 full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+         # Rotary positional embeddings
+         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
+         if position_ids is not None:
+             rotary_pos_emb = rotary_pos_emb[position_ids]
+         else:
+             rotary_pos_emb = rotary_pos_emb[None, :seq_length]
+         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
+
+         # Run encoder.
+         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
+             inputs_embeds,
+             full_attention_mask,
+             rotary_pos_emb=rotary_pos_emb,
+             kv_caches=past_key_values,
+             use_cache=use_cache,
+             output_hidden_states=output_hidden_states,
+         )
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=presents,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
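# Illustrative usage sketch (hypothetical names; assumes a populated `ChatGLMConfig`
# and tokenized `input_ids` of shape [batch_size, seq_length]):
#
#     model = ChatGLMModel(config, empty_init=False)
#     out = model(input_ids, return_dict=True)
#     out.last_hidden_state  # [seq_length, batch_size, hidden_size] -- seq-first layout
#     out.past_key_values    # per-layer (key, value) caches when use_cache=True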