diffusers-0.29.2-py3-none-any.whl → diffusers-0.30.1-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
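
The 0.30.x release adds several new pipeline families (Flux, Kolors, CogVideoX, AuraFlow, Latte, Lumina, Stable Audio, PAG) and refactors the LoRA loaders (loaders/lora.py is replaced by lora_base.py and lora_pipeline.py). A minimal sketch for sanity-checking an upgrade is shown below; the class names are assumed to match the new pipeline modules listed above and may need adjusting.

import importlib.metadata

import diffusers

# Confirm the installed version matches the right-hand side of this diff.
print("diffusers version:", importlib.metadata.version("diffusers"))

# Assumed exported names for the new pipeline modules above (flux, kolors, cogvideo, aura_flow).
for name in ("FluxPipeline", "KolorsPipeline", "CogVideoXPipeline", "AuraFlowPipeline"):
    print(f"{name} exported:", hasattr(diffusers, name))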
diffusers/pipelines/kolors/text_encoder.py (new file; entry 115 above)
@@ -0,0 +1,889 @@
1
+ # Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import List, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+ from torch.nn import LayerNorm
22
+ from torch.nn.utils import skip_init
23
+ from transformers import PretrainedConfig, PreTrainedModel
24
+ from transformers.modeling_outputs import BaseModelOutputWithPast
25
+
26
+ from ...utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class ChatGLMConfig(PretrainedConfig):
33
+ model_type = "chatglm"
34
+
35
+ def __init__(
36
+ self,
37
+ num_layers=28,
38
+ padded_vocab_size=65024,
39
+ hidden_size=4096,
40
+ ffn_hidden_size=13696,
41
+ kv_channels=128,
42
+ num_attention_heads=32,
43
+ seq_length=2048,
44
+ hidden_dropout=0.0,
45
+ classifier_dropout=None,
46
+ attention_dropout=0.0,
47
+ layernorm_epsilon=1e-5,
48
+ rmsnorm=True,
49
+ apply_residual_connection_post_layernorm=False,
50
+ post_layer_norm=True,
51
+ add_bias_linear=False,
52
+ add_qkv_bias=False,
53
+ bias_dropout_fusion=True,
54
+ multi_query_attention=False,
55
+ multi_query_group_num=1,
56
+ apply_query_key_layer_scaling=True,
57
+ attention_softmax_in_fp32=True,
58
+ fp32_residual_connection=False,
59
+ quantization_bit=0,
60
+ pre_seq_len=None,
61
+ prefix_projection=False,
62
+ **kwargs,
63
+ ):
64
+ self.num_layers = num_layers
65
+ self.vocab_size = padded_vocab_size
66
+ self.padded_vocab_size = padded_vocab_size
67
+ self.hidden_size = hidden_size
68
+ self.ffn_hidden_size = ffn_hidden_size
69
+ self.kv_channels = kv_channels
70
+ self.num_attention_heads = num_attention_heads
71
+ self.seq_length = seq_length
72
+ self.hidden_dropout = hidden_dropout
73
+ self.classifier_dropout = classifier_dropout
74
+ self.attention_dropout = attention_dropout
75
+ self.layernorm_epsilon = layernorm_epsilon
76
+ self.rmsnorm = rmsnorm
77
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
78
+ self.post_layer_norm = post_layer_norm
79
+ self.add_bias_linear = add_bias_linear
80
+ self.add_qkv_bias = add_qkv_bias
81
+ self.bias_dropout_fusion = bias_dropout_fusion
82
+ self.multi_query_attention = multi_query_attention
83
+ self.multi_query_group_num = multi_query_group_num
84
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
85
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
86
+ self.fp32_residual_connection = fp32_residual_connection
87
+ self.quantization_bit = quantization_bit
88
+ self.pre_seq_len = pre_seq_len
89
+ self.prefix_projection = prefix_projection
90
+ super().__init__(**kwargs)
91
+
92
+
93
+ class RMSNorm(torch.nn.Module):
94
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
95
+ super().__init__()
96
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
97
+ self.eps = eps
98
+
99
+ def forward(self, hidden_states: torch.Tensor):
100
+ input_dtype = hidden_states.dtype
101
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
102
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
103
+
104
+ return (self.weight * hidden_states).to(input_dtype)
105
+
106
+
107
+ def _config_to_kwargs(args):
108
+ common_kwargs = {
109
+ "dtype": args.torch_dtype,
110
+ }
111
+ return common_kwargs
112
+
113
+
114
+ class CoreAttention(torch.nn.Module):
115
+ def __init__(self, config: ChatGLMConfig, layer_number):
116
+ super(CoreAttention, self).__init__()
117
+
118
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
119
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
120
+ if self.apply_query_key_layer_scaling:
121
+ self.attention_softmax_in_fp32 = True
122
+ self.layer_number = max(1, layer_number)
123
+
124
+ projection_size = config.kv_channels * config.num_attention_heads
125
+
126
+ # Per attention head and per partition values.
127
+ self.hidden_size_per_partition = projection_size
128
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
129
+ self.num_attention_heads_per_partition = config.num_attention_heads
130
+
131
+ coeff = None
132
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
133
+ if self.apply_query_key_layer_scaling:
134
+ coeff = self.layer_number
135
+ self.norm_factor *= coeff
136
+ self.coeff = coeff
137
+
138
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
139
+
140
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
141
+ pytorch_major_version = int(torch.__version__.split(".")[0])
142
+ if pytorch_major_version >= 2:
143
+ query_layer, key_layer, value_layer = [
144
+ k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
145
+ ]
146
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
147
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
148
+ query_layer, key_layer, value_layer, is_causal=True
149
+ )
150
+ else:
151
+ if attention_mask is not None:
152
+ attention_mask = ~attention_mask
153
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
154
+ query_layer, key_layer, value_layer, attention_mask
155
+ )
156
+ context_layer = context_layer.permute(2, 0, 1, 3)
157
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
158
+ context_layer = context_layer.reshape(*new_context_layer_shape)
159
+ else:
160
+ # Raw attention scores
161
+
162
+ # [b, np, sq, sk]
163
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
164
+
165
+ # [sq, b, np, hn] -> [sq, b * np, hn]
166
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
167
+ # [sk, b, np, hn] -> [sk, b * np, hn]
168
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
169
+
170
+ # preallocating input tensor: [b * np, sq, sk]
171
+ matmul_input_buffer = torch.empty(
172
+ output_size[0] * output_size[1],
173
+ output_size[2],
174
+ output_size[3],
175
+ dtype=query_layer.dtype,
176
+ device=query_layer.device,
177
+ )
178
+
179
+ # Raw attention scores. [b * np, sq, sk]
180
+ matmul_result = torch.baddbmm(
181
+ matmul_input_buffer,
182
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
183
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
184
+ beta=0.0,
185
+ alpha=(1.0 / self.norm_factor),
186
+ )
187
+
188
+ # change view to [b, np, sq, sk]
189
+ attention_scores = matmul_result.view(*output_size)
190
+
191
+ # ===========================
192
+ # Attention probs and dropout
193
+ # ===========================
194
+
195
+ # attention scores and attention mask [b, np, sq, sk]
196
+ if self.attention_softmax_in_fp32:
197
+ attention_scores = attention_scores.float()
198
+ if self.coeff is not None:
199
+ attention_scores = attention_scores * self.coeff
200
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
201
+ attention_mask = torch.ones(
202
+ output_size[0], 1, output_size[2], output_size[3], device=attention_scores.device, dtype=torch.bool
203
+ )
204
+ attention_mask.tril_()
205
+ attention_mask = ~attention_mask
206
+ if attention_mask is not None:
207
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
208
+ attention_probs = F.softmax(attention_scores, dim=-1)
209
+ attention_probs = attention_probs.type_as(value_layer)
210
+
211
+ # This is actually dropping out entire tokens to attend to, which might
212
+ # seem a bit unusual, but is taken from the original Transformer paper.
213
+ attention_probs = self.attention_dropout(attention_probs)
214
+ # =========================
215
+ # Context layer. [sq, b, hp]
216
+ # =========================
217
+
218
+ # value_layer -> context layer.
219
+ # [sk, b, np, hn] --> [b, np, sq, hn]
220
+
221
+ # context layer shape: [b, np, sq, hn]
222
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
223
+ # change view [sk, b * np, hn]
224
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
225
+ # change view [b * np, sq, sk]
226
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
227
+ # matmul: [b * np, sq, hn]
228
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
229
+ # change view [b, np, sq, hn]
230
+ context_layer = context_layer.view(*output_size)
231
+ # [b, np, sq, hn] --> [sq, b, np, hn]
232
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
233
+ # [sq, b, np, hn] --> [sq, b, hp]
234
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
235
+ context_layer = context_layer.view(*new_context_layer_shape)
236
+
237
+ return context_layer
238
+
239
+
240
+ def split_tensor_along_last_dim(
241
+ tensor: torch.Tensor,
242
+ num_partitions: int,
243
+ contiguous_split_chunks: bool = False,
244
+ ) -> List[torch.Tensor]:
245
+ """Split a tensor along its last dimension.
246
+
247
+ Arguments:
248
+ tensor: input tensor.
249
+ num_partitions: number of partitions to split the tensor into
250
+ contiguous_split_chunks: If True, make each chunk contiguous
251
+ in memory.
252
+
253
+ Returns:
254
+ A list of Tensors
255
+ """
256
+ # Get the size and dimension.
257
+ last_dim = tensor.dim() - 1
258
+ last_dim_size = tensor.size()[last_dim] // num_partitions
259
+ # Split.
260
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
261
+ # Note: torch.split does not create contiguous tensors by default.
262
+ if contiguous_split_chunks:
263
+ return tuple(chunk.contiguous() for chunk in tensor_list)
264
+
265
+ return tensor_list
266
+
267
+
268
+ @torch.jit.script
269
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
270
+ # x: [sq, b, np, hn]
271
+ sq, _b, np, _hn = x.size(0), x.size(1), x.size(2), x.size(3)
272
+ rot_dim = rope_cache.shape[-2] * 2
273
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
274
+ # truncate to support variable sizes
275
+ rope_cache = rope_cache[:sq]
276
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
277
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
278
+ x_out2 = torch.stack(
279
+ [
280
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
281
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
282
+ ],
283
+ -1,
284
+ )
285
+ x_out2 = x_out2.flatten(3)
286
+ return torch.cat((x_out2, x_pass), dim=-1)
287
+
288
+
289
+ class SelfAttention(torch.nn.Module):
290
+ """Parallel self-attention layer abstract class.
291
+
292
+ Self-attention layer takes input with size [s, b, h] and returns output of the same size.
293
+ """
294
+
295
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
296
+ super(SelfAttention, self).__init__()
297
+ self.layer_number = max(1, layer_number)
298
+
299
+ self.projection_size = config.kv_channels * config.num_attention_heads
300
+
301
+ # Per attention head and per partition values.
302
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
303
+ self.num_attention_heads_per_partition = config.num_attention_heads
304
+
305
+ self.multi_query_attention = config.multi_query_attention
306
+ self.qkv_hidden_size = 3 * self.projection_size
307
+ if self.multi_query_attention:
308
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
309
+ self.qkv_hidden_size = (
310
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
311
+ )
312
+ self.query_key_value = nn.Linear(
313
+ config.hidden_size,
314
+ self.qkv_hidden_size,
315
+ bias=config.add_bias_linear or config.add_qkv_bias,
316
+ device=device,
317
+ **_config_to_kwargs(config),
318
+ )
319
+
320
+ self.core_attention = CoreAttention(config, self.layer_number)
321
+
322
+ # Output.
323
+ self.dense = nn.Linear(
324
+ self.projection_size,
325
+ config.hidden_size,
326
+ bias=config.add_bias_linear,
327
+ device=device,
328
+ **_config_to_kwargs(config),
329
+ )
330
+
331
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
332
+ if self.multi_query_attention:
333
+ num_attention_heads = self.num_multi_query_groups_per_partition
334
+ else:
335
+ num_attention_heads = self.num_attention_heads_per_partition
336
+ return torch.empty(
337
+ inference_max_sequence_len,
338
+ batch_size,
339
+ num_attention_heads,
340
+ self.hidden_size_per_attention_head,
341
+ dtype=dtype,
342
+ device=device,
343
+ )
344
+
345
+ def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
346
+ # hidden_states: [sq, b, h]
347
+
348
+ # =================================================
349
+ # Pre-allocate memory for key-values for inference.
350
+ # =================================================
351
+ # =====================
352
+ # Query, Key, and Value
353
+ # =====================
354
+
355
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
356
+ mixed_x_layer = self.query_key_value(hidden_states)
357
+
358
+ if self.multi_query_attention:
359
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
360
+ [
361
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
362
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
363
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
364
+ ],
365
+ dim=-1,
366
+ )
367
+ query_layer = query_layer.view(
368
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
369
+ )
370
+ key_layer = key_layer.view(
371
+ key_layer.size()[:-1]
372
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
373
+ )
374
+ value_layer = value_layer.view(
375
+ value_layer.size()[:-1]
376
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
377
+ )
378
+ else:
379
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
380
+ self.num_attention_heads_per_partition,
381
+ 3 * self.hidden_size_per_attention_head,
382
+ )
383
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
384
+
385
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
386
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
387
+
388
+ # apply relative positional encoding (rotary embedding)
389
+ if rotary_pos_emb is not None:
390
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
391
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
392
+
393
+ # adjust key and value for inference
394
+ if kv_cache is not None:
395
+ cache_k, cache_v = kv_cache
396
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
397
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
398
+ if use_cache:
399
+ kv_cache = (key_layer, value_layer)
400
+ else:
401
+ kv_cache = None
402
+
403
+ if self.multi_query_attention:
404
+ key_layer = key_layer.unsqueeze(-2)
405
+ key_layer = key_layer.expand(
406
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
407
+ )
408
+ key_layer = key_layer.contiguous().view(
409
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
410
+ )
411
+ value_layer = value_layer.unsqueeze(-2)
412
+ value_layer = value_layer.expand(
413
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
414
+ )
415
+ value_layer = value_layer.contiguous().view(
416
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
417
+ )
418
+
419
+ # ==================================
420
+ # core attention computation
421
+ # ==================================
422
+
423
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
424
+
425
+ # =================
426
+ # Output. [sq, b, h]
427
+ # =================
428
+
429
+ output = self.dense(context_layer)
430
+
431
+ return output, kv_cache
432
+
433
+
434
+ class MLP(torch.nn.Module):
435
+ """MLP.
436
+
437
+ The MLP takes an input with hidden size h, projects it to a 4*h hidden dimension, applies a nonlinear
438
+ transformation, and projects the state back to hidden size h.
439
+ """
440
+
441
+ def __init__(self, config: ChatGLMConfig, device=None):
442
+ super(MLP, self).__init__()
443
+
444
+ self.add_bias = config.add_bias_linear
445
+
446
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
447
+ self.dense_h_to_4h = nn.Linear(
448
+ config.hidden_size,
449
+ config.ffn_hidden_size * 2,
450
+ bias=self.add_bias,
451
+ device=device,
452
+ **_config_to_kwargs(config),
453
+ )
454
+
455
+ def swiglu(x):
456
+ x = torch.chunk(x, 2, dim=-1)
457
+ return F.silu(x[0]) * x[1]
458
+
459
+ self.activation_func = swiglu
460
+
461
+ # Project back to h.
462
+ self.dense_4h_to_h = nn.Linear(
463
+ config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
464
+ )
465
+
466
+ def forward(self, hidden_states):
467
+ # [s, b, 4hp]
468
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
469
+ intermediate_parallel = self.activation_func(intermediate_parallel)
470
+ # [s, b, h]
471
+ output = self.dense_4h_to_h(intermediate_parallel)
472
+ return output
473
+
474
+
475
+ class GLMBlock(torch.nn.Module):
476
+ """A single transformer layer.
477
+
478
+ Transformer layer takes input with size [s, b, h] and returns an output of the same size.
479
+ """
480
+
481
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
482
+ super(GLMBlock, self).__init__()
483
+ self.layer_number = layer_number
484
+
485
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
486
+
487
+ self.fp32_residual_connection = config.fp32_residual_connection
488
+
489
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
490
+ # Layernorm on the input data.
491
+ self.input_layernorm = LayerNormFunc(
492
+ config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
493
+ )
494
+
495
+ # Self attention.
496
+ self.self_attention = SelfAttention(config, layer_number, device=device)
497
+ self.hidden_dropout = config.hidden_dropout
498
+
499
+ # Layernorm on the attention output
500
+ self.post_attention_layernorm = LayerNormFunc(
501
+ config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
502
+ )
503
+
504
+ # MLP
505
+ self.mlp = MLP(config, device=device)
506
+
507
+ def forward(
508
+ self,
509
+ hidden_states,
510
+ attention_mask,
511
+ rotary_pos_emb,
512
+ kv_cache=None,
513
+ use_cache=True,
514
+ ):
515
+ # hidden_states: [s, b, h]
516
+
517
+ # Layer norm at the beginning of the transformer layer.
518
+ layernorm_output = self.input_layernorm(hidden_states)
519
+ # Self attention.
520
+ attention_output, kv_cache = self.self_attention(
521
+ layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache
522
+ )
523
+
524
+ # Residual connection.
525
+ if self.apply_residual_connection_post_layernorm:
526
+ residual = layernorm_output
527
+ else:
528
+ residual = hidden_states
529
+
530
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
531
+ layernorm_input = residual + layernorm_input
532
+
533
+ # Layer norm post the self attention.
534
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
535
+
536
+ # MLP.
537
+ mlp_output = self.mlp(layernorm_output)
538
+
539
+ # Second residual connection.
540
+ if self.apply_residual_connection_post_layernorm:
541
+ residual = layernorm_output
542
+ else:
543
+ residual = layernorm_input
544
+
545
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
546
+ output = residual + output
547
+
548
+ return output, kv_cache
549
+
550
+
551
+ class GLMTransformer(torch.nn.Module):
552
+ """Transformer class."""
553
+
554
+ def __init__(self, config: ChatGLMConfig, device=None):
555
+ super(GLMTransformer, self).__init__()
556
+
557
+ self.fp32_residual_connection = config.fp32_residual_connection
558
+ self.post_layer_norm = config.post_layer_norm
559
+
560
+ # Number of layers.
561
+ self.num_layers = config.num_layers
562
+
563
+ # Transformer layers.
564
+ def build_layer(layer_number):
565
+ return GLMBlock(config, layer_number, device=device)
566
+
567
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
568
+
569
+ if self.post_layer_norm:
570
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
571
+ # Final layer norm before output.
572
+ self.final_layernorm = LayerNormFunc(
573
+ config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
574
+ )
575
+
576
+ self.gradient_checkpointing = False
577
+
578
+ def _get_layer(self, layer_number):
579
+ return self.layers[layer_number]
580
+
581
+ def forward(
582
+ self,
583
+ hidden_states,
584
+ attention_mask,
585
+ rotary_pos_emb,
586
+ kv_caches=None,
587
+ use_cache: Optional[bool] = True,
588
+ output_hidden_states: Optional[bool] = False,
589
+ ):
590
+ if not kv_caches:
591
+ kv_caches = [None for _ in range(self.num_layers)]
592
+ presents = () if use_cache else None
593
+ if self.gradient_checkpointing and self.training:
594
+ if use_cache:
595
+ logger.warning_once(
596
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
597
+ )
598
+ use_cache = False
599
+
600
+ all_self_attentions = None
601
+ all_hidden_states = () if output_hidden_states else None
602
+ for index in range(self.num_layers):
603
+ if output_hidden_states:
604
+ all_hidden_states = all_hidden_states + (hidden_states,)
605
+
606
+ layer = self._get_layer(index)
607
+ if self.gradient_checkpointing and self.training:
608
+ layer_ret = torch.utils.checkpoint.checkpoint(
609
+ layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
610
+ )
611
+ else:
612
+ layer_ret = layer(
613
+ hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache
614
+ )
615
+ hidden_states, kv_cache = layer_ret
616
+ if use_cache:
617
+ presents = presents + (kv_cache,)
618
+
619
+ if output_hidden_states:
620
+ all_hidden_states = all_hidden_states + (hidden_states,)
621
+
622
+ # Final layer norm.
623
+ if self.post_layer_norm:
624
+ hidden_states = self.final_layernorm(hidden_states)
625
+
626
+ return hidden_states, presents, all_hidden_states, all_self_attentions
627
+
628
+
629
+ class ChatGLMPreTrainedModel(PreTrainedModel):
630
+ """
631
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
632
+ models.
633
+ """
634
+
635
+ is_parallelizable = False
636
+ supports_gradient_checkpointing = True
637
+ config_class = ChatGLMConfig
638
+ base_model_prefix = "transformer"
639
+ _no_split_modules = ["GLMBlock"]
640
+
641
+ def _init_weights(self, module: nn.Module):
642
+ """Initialize the weights."""
643
+ return
644
+
645
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
646
+ batch_size, seq_length = input_ids.shape
647
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
648
+ full_attention_mask.tril_()
649
+ past_length = 0
650
+ if past_key_values:
651
+ past_length = past_key_values[0][0].shape[0]
652
+ if past_length:
653
+ full_attention_mask = torch.cat(
654
+ (torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1
655
+ )
656
+ if padding_mask is not None:
657
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
658
+ if not past_length and padding_mask is not None:
659
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
660
+ full_attention_mask = (full_attention_mask < 0.5).bool()
661
+ full_attention_mask.unsqueeze_(1)
662
+ return full_attention_mask
663
+
664
+ def get_position_ids(self, input_ids, device):
665
+ batch_size, seq_length = input_ids.shape
666
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
667
+ return position_ids
668
+
669
+ def _set_gradient_checkpointing(self, module, value=False):
670
+ if isinstance(module, GLMTransformer):
671
+ module.gradient_checkpointing = value
672
+
673
+
674
+ def default_init(cls, *args, **kwargs):
675
+ return cls(*args, **kwargs)
676
+
677
+
678
+ class Embedding(torch.nn.Module):
679
+ """Language model embeddings."""
680
+
681
+ def __init__(self, config: ChatGLMConfig, device=None):
682
+ super(Embedding, self).__init__()
683
+
684
+ self.hidden_size = config.hidden_size
685
+ # Word embeddings (parallel).
686
+ self.word_embeddings = nn.Embedding(
687
+ config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
688
+ )
689
+ self.fp32_residual_connection = config.fp32_residual_connection
690
+
691
+ def forward(self, input_ids):
692
+ # Embeddings.
693
+ words_embeddings = self.word_embeddings(input_ids)
694
+ embeddings = words_embeddings
695
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
696
+ embeddings = embeddings.transpose(0, 1).contiguous()
697
+ # If the fp32 residual connection flag is set, convert to float.
698
+ if self.fp32_residual_connection:
699
+ embeddings = embeddings.float()
700
+ return embeddings
701
+
702
+
703
+ class RotaryEmbedding(nn.Module):
704
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
705
+ super().__init__()
706
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
707
+ self.register_buffer("inv_freq", inv_freq)
708
+ self.dim = dim
709
+ self.original_impl = original_impl
710
+
711
+ def forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):
712
+ """Enhanced Transformer with Rotary Position Embedding.
713
+
714
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
715
+ transformers/rope/__init__.py. MIT License:
716
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
717
+ """
718
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
719
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
720
+
721
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
722
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
723
+
724
+ # Calculate the product of position index and $\theta_i$
725
+ idx_theta = torch.outer(seq_idx, theta).float()
726
+
727
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
728
+
729
+ # this is to mimic the behaviour of complex32, else we will get different results
730
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
731
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
732
+ return cache
733
+
734
+ def forward(self, max_seq_len, offset=0):
735
+ return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
736
+
737
+
738
+ class PrefixEncoder(torch.nn.Module):
739
+ """
740
+ The torch.nn model to encode the prefix Input shape: (batch-size, prefix-length) Output shape: (batch-size,
741
+ prefix-length, 2*layers*hidden)
742
+ """
743
+
744
+ def __init__(self, config: ChatGLMConfig):
745
+ super().__init__()
746
+ self.prefix_projection = config.prefix_projection
747
+ if self.prefix_projection:
748
+ # Use a two-layer MLP to encode the prefix
749
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
750
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
751
+ self.trans = torch.nn.Sequential(
752
+ torch.nn.Linear(kv_size, config.hidden_size),
753
+ torch.nn.Tanh(),
754
+ torch.nn.Linear(config.hidden_size, kv_size),
755
+ )
756
+ else:
757
+ self.embedding = torch.nn.Embedding(
758
+ config.pre_seq_len, config.num_layers * config.kv_channels * config.multi_query_group_num * 2
759
+ )
760
+
761
+ def forward(self, prefix: torch.Tensor):
762
+ if self.prefix_projection:
763
+ prefix_tokens = self.embedding(prefix)
764
+ past_key_values = self.trans(prefix_tokens)
765
+ else:
766
+ past_key_values = self.embedding(prefix)
767
+ return past_key_values
768
+
769
+
770
+ class ChatGLMModel(ChatGLMPreTrainedModel):
771
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
772
+ super().__init__(config)
773
+ if empty_init:
774
+ init_method = skip_init
775
+ else:
776
+ init_method = default_init
777
+ init_kwargs = {}
778
+ if device is not None:
779
+ init_kwargs["device"] = device
780
+ self.embedding = init_method(Embedding, config, **init_kwargs)
781
+ self.num_layers = config.num_layers
782
+ self.multi_query_group_num = config.multi_query_group_num
783
+ self.kv_channels = config.kv_channels
784
+
785
+ # Rotary positional embeddings
786
+ self.seq_length = config.seq_length
787
+ rotary_dim = (
788
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
789
+ )
790
+
791
+ self.rotary_pos_emb = RotaryEmbedding(
792
+ rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
793
+ )
794
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
795
+ self.output_layer = init_method(
796
+ nn.Linear,
797
+ config.hidden_size,
798
+ config.padded_vocab_size,
799
+ bias=False,
800
+ dtype=config.torch_dtype,
801
+ **init_kwargs,
802
+ )
803
+ self.pre_seq_len = config.pre_seq_len
804
+ self.prefix_projection = config.prefix_projection
805
+ if self.pre_seq_len is not None:
806
+ for param in self.parameters():
807
+ param.requires_grad = False
808
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
809
+ self.prefix_encoder = PrefixEncoder(config)
810
+ self.dropout = torch.nn.Dropout(0.1)
811
+
812
+ def get_input_embeddings(self):
813
+ return self.embedding.word_embeddings
814
+
815
+ def get_prompt(self, batch_size, device, dtype=torch.half):
816
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
817
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
818
+ past_key_values = past_key_values.view(
819
+ batch_size, self.pre_seq_len, self.num_layers * 2, self.multi_query_group_num, self.kv_channels
820
+ )
821
+ # seq_len, b, nh, hidden_size
822
+ past_key_values = self.dropout(past_key_values)
823
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
824
+ return past_key_values
825
+
826
+ def forward(
827
+ self,
828
+ input_ids,
829
+ position_ids: Optional[torch.Tensor] = None,
830
+ attention_mask: Optional[torch.BoolTensor] = None,
831
+ full_attention_mask: Optional[torch.BoolTensor] = None,
832
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
833
+ inputs_embeds: Optional[torch.Tensor] = None,
834
+ use_cache: Optional[bool] = None,
835
+ output_hidden_states: Optional[bool] = None,
836
+ return_dict: Optional[bool] = None,
837
+ ):
838
+ output_hidden_states = (
839
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
840
+ )
841
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
842
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
843
+
844
+ batch_size, seq_length = input_ids.shape
845
+
846
+ if inputs_embeds is None:
847
+ inputs_embeds = self.embedding(input_ids)
848
+
849
+ if self.pre_seq_len is not None:
850
+ if past_key_values is None:
851
+ past_key_values = self.get_prompt(
852
+ batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
853
+ )
854
+ if attention_mask is not None:
855
+ attention_mask = torch.cat(
856
+ [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
857
+ )
858
+
859
+ if full_attention_mask is None:
860
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
861
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
862
+
863
+ # Rotary positional embeddings
864
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
865
+ if position_ids is not None:
866
+ rotary_pos_emb = rotary_pos_emb[position_ids]
867
+ else:
868
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
869
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
870
+
871
+ # Run encoder.
872
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
873
+ inputs_embeds,
874
+ full_attention_mask,
875
+ rotary_pos_emb=rotary_pos_emb,
876
+ kv_caches=past_key_values,
877
+ use_cache=use_cache,
878
+ output_hidden_states=output_hidden_states,
879
+ )
880
+
881
+ if not return_dict:
882
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
883
+
884
+ return BaseModelOutputWithPast(
885
+ last_hidden_state=hidden_states,
886
+ past_key_values=presents,
887
+ hidden_states=all_hidden_states,
888
+ attentions=all_self_attentions,
889
+ )
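
The file above defines the ChatGLM-based text encoder used by the new Kolors pipelines. As a small illustrative check (not part of the diff), the RMSNorm defined above can be exercised on its own; the import path assumes the file lands at diffusers/pipelines/kolors/text_encoder.py as listed in entry 115.

import torch

from diffusers.pipelines.kolors.text_encoder import RMSNorm  # assumed path, entry 115 above

norm = RMSNorm(8, eps=1e-5, dtype=torch.float32)
with torch.no_grad():
    norm.weight.fill_(1.0)  # the weight is created with torch.empty, so initialize it explicitly

x = torch.randn(2, 4, 8)
y = norm(x)
# With unit weights, every vector along the last dimension should have roughly unit root-mean-square.
print(y.pow(2).mean(dim=-1).sqrt())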