diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1363 @@
1
+ """
2
+ This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.
3
+ We didn't modify this model.
4
+ The tensor operation is performed in the prompter.
5
+ """
6
+
7
+
8
+ """ PyTorch ChatGLM model. """
9
+
10
+ import math
11
+ import copy
12
+ import warnings
13
+ import re
14
+ import sys
15
+
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ from torch.nn import CrossEntropyLoss, LayerNorm
21
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
22
+ from torch.nn.utils import skip_init
23
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
24
+ from copy import deepcopy
25
+
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPast,
28
+ CausalLMOutputWithPast,
29
+ SequenceClassifierOutputWithPast,
30
+ )
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import logging
33
+ from transformers.generation.logits_process import LogitsProcessor
34
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
35
+ from transformers import PretrainedConfig
36
+
37
+
38
+
39
class ChatGLMConfig(PretrainedConfig):
    """Hyper-parameters of the ChatGLM transformer used as a text encoder.

    Every constructor argument is stored on the instance under the same name.
    ``padded_vocab_size`` is additionally exposed as ``vocab_size``. Any
    remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        # --- model geometry -------------------------------------------------
        self.num_layers = num_layers
        # Both names refer to the same (padded) vocabulary size.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length

        # --- dropout --------------------------------------------------------
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout

        # --- normalisation / residual layout --------------------------------
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm

        # --- linear-layer / attention options -------------------------------
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection

        # --- quantisation / prefix tuning -----------------------------------
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection

        # Base-class initialisation runs last, after all fields above are set.
        super().__init__(**kwargs)
97
+
98
+
99
+
100
# flags required to enable jit fusion kernels
# NOTE(review): these calls run as a side effect of importing this module and
# appear to toggle global TorchScript settings (private ``torch._C`` APIs), so
# they may affect other scripted code in the same process — confirm.
# They are skipped on macOS ('darwin').
if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

# Module logger, from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Documentation constants carried over from the copied upstream implementation;
# presumably not used by this package's own code — TODO confirm.
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
_CONFIG_FOR_DOC = "ChatGLM6BConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm3-6b-base",
    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]
117
+
118
+
119
def default_init(cls, *args, **kwargs):
    """Instantiate ``cls`` with the given positional and keyword arguments.

    A trivial factory: the result is exactly ``cls(*args, **kwargs)``.
    """
    instance = cls(*args, **kwargs)
    return instance
121
+
122
+
123
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that neutralises score tensors containing NaN or Inf.

    If any entry of ``scores`` is NaN or infinite, the tensor is zeroed in place
    and index 5 of the last dimension is set to ``5e4``, so that index dominates.
    Otherwise ``scores`` is returned untouched.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        has_invalid = torch.isnan(scores).any() or torch.isinf(scores).any()
        if not has_invalid:
            return scores
        # In-place: the caller's tensor is modified as well as returned.
        scores.zero_()
        # NOTE(review): index 5 is presumably a special token id in the ChatGLM
        # vocabulary — confirm against the tokenizer.
        scores[..., 5] = 5e4
        return scores
129
+
130
+
131
class PrefixEncoder(torch.nn.Module):
    """Encode prefix-tuning token ids into flattened past key/values.

    Input shape: ``(batch_size, prefix_length)`` of indices into an embedding
    table with ``config.pre_seq_len`` rows.
    Output shape: ``(batch_size, prefix_length, kv_size)`` with
    ``kv_size = 2 * num_layers * kv_channels * multi_query_group_num``.

    When ``config.prefix_projection`` is set, the embedding (width ``kv_size``)
    is passed through a two-layer MLP (Linear -> Tanh -> Linear) that
    round-trips through ``hidden_size``. Otherwise the raw embedding is returned.
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
        self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
        if self.prefix_projection:
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size),
            )

    def forward(self, prefix: torch.Tensor):
        encoded = self.embedding(prefix)
        if self.prefix_projection:
            encoded = self.trans(encoded)
        return encoded
161
+
162
+
163
def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split ``tensor`` into equal-width chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions; each chunk has width
            ``tensor.shape[-1] // num_partitions``.
        contiguous_split_chunks: if True, return contiguous copies of the
            chunks (``torch.split`` itself returns non-contiguous views).

    Returns:
        A tuple of tensors.
    """
    split_axis = tensor.dim() - 1
    chunk_width = tensor.size()[split_axis] // num_partitions
    pieces = torch.split(tensor, chunk_width, dim=split_axis)
    if not contiguous_split_chunks:
        return pieces
    return tuple(piece.contiguous() for piece in pieces)
189
+
190
+
191
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cache builder.

    ``forward(max_seq_len)`` returns a tensor of shape
    ``[max_seq_len, dim // 2, 2]`` whose last axis holds ``(cos, sin)`` of
    ``position * theta_i``, with ``theta_i = 1 / 10000 ** (2i / dim)``.

    ``offset`` and ``original_impl`` are accepted but not used here.
    """

    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        # Registered (persistent) buffer: its dtype/device select the dtype and
        # device of the cache built in ``forward``.
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
        self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Build the ``(cos, sin)`` cache for ``seq_len`` positions.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        frequencies = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
        positions = torch.arange(seq_len, dtype=torch.float, device=device)
        angles = torch.outer(positions, frequencies).float()
        rope_cache = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

        # Low-precision requests mimic complex32 behaviour: bf16 stays bf16,
        # while fp16 and int8 both map to half precision.
        if dtype == torch.bfloat16:
            return rope_cache.bfloat16()
        if dtype in (torch.float16, torch.int8):
            return rope_cache.half()
        return rope_cache

    def forward(self, max_seq_len, offset=0):
        buffer = self.inv_freq
        return self.forward_impl(max_seq_len, self.dim, dtype=buffer.dtype, device=buffer.device)
228
+
229
+
230
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Rotate the leading ``2 * rope_cache.shape[-2]`` features of ``x``.

    ``x`` is laid out ``[sq, b, np, hn]``. Consecutive feature pairs
    ``(a, b)`` are rotated by the ``(cos, sin)`` stored in the last axis of
    ``rope_cache``. Features beyond the rotary width pass through unchanged.
    Only the first ``sq`` rows of ``rope_cache`` are used.
    """
    seq_len = x.size(0)
    num_heads = x.size(2)
    rot_dim = rope_cache.shape[-2] * 2

    x_rot = x[..., :rot_dim]
    x_pass = x[..., rot_dim:]

    # Truncate the cache to the sequence length actually present.
    cache = rope_cache[:seq_len]
    pairs = x_rot.reshape(seq_len, -1, num_heads, rot_dim // 2, 2)
    cache = cache.view(seq_len, -1, 1, pairs.size(3), 2)

    first = pairs[..., 0]
    second = pairs[..., 1]
    cos = cache[..., 0]
    sin = cache[..., 1]
    rotated = torch.stack(
        [
            first * cos - second * sin,
            second * cos + first * sin,
        ],
        -1,
    ).flatten(3)
    return torch.cat((rotated, x_pass), dim=-1)
249
+
250
+
251
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension (no bias).

    ``y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``. The mean square is
    computed in float32 and the result is cast back to the input dtype.

    Args:
        normalized_shape: shape of ``weight`` (size of the trailing dimension).
        eps: added to the mean square for numerical stability.
        device, dtype: placement and dtype of ``weight``.
        **kwargs: accepted and ignored, so the constructor can be called the
            same way as ``torch.nn.LayerNorm``.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # Start from the identity scale. ``torch.empty`` left the weight
        # uninitialised, so a module used before a checkpoint was loaded produced
        # arbitrary (possibly NaN) outputs. Loading a state dict overwrites it.
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)
263
+
264
+
265
class CoreAttention(torch.nn.Module):
    """Scaled dot-product attention over sequence-first tensors.

    ``query_layer``/``key_layer``/``value_layer`` are ``[s, b, np, hn]``. The
    result is ``[sq, b, np * hn]``. ``attention_mask`` is boolean, with ``True``
    marking positions to exclude. When it is ``None`` and query/key lengths
    match, a causal mask is applied.

    Two implementations are used:
      * PyTorch >= 2: ``F.scaled_dot_product_attention`` (fused kernel).
      * Otherwise: an explicit baddbmm/softmax/bmm computation.
    Query-key layer scaling divides the scores by ``layer_number`` and
    multiplies them back before softmax, so both paths use the same effective
    scale ``1 / sqrt(hn)``.
    """

    def __init__(self, config: "ChatGLMConfig", layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

        # Choose the implementation once, instead of re-parsing the version
        # string on every forward call.
        self._use_sdpa = int(torch.__version__.split('.')[0]) >= 2

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        if self._use_sdpa:
            return self._forward_sdpa(query_layer, key_layer, value_layer, attention_mask)
        return self._forward_manual(query_layer, key_layer, value_layer, attention_mask)

    def _forward_sdpa(self, query_layer, key_layer, value_layer, attention_mask):
        """Fused attention via ``F.scaled_dot_product_attention`` (PyTorch >= 2)."""
        # [s, b, np, hn] -> [b, np, s, hn], the layout SDPA expects.
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        # Apply the same attention dropout the manual path applies, but only in
        # training mode. Previously this path silently skipped dropout.
        dropout_p = self.attention_dropout.p if self.training else 0.0
        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
            context_layer = F.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                           dropout_p=dropout_p, is_causal=True)
        else:
            if attention_mask is not None:
                # Our mask marks positions to drop. SDPA's boolean mask marks
                # positions to keep, so invert it.
                attention_mask = ~attention_mask
            context_layer = F.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                           attention_mask, dropout_p=dropout_p)
        # [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, hp]
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
        return context_layer

    def _forward_manual(self, query_layer, key_layer, value_layer, attention_mask):
        """Explicit attention computation, used on PyTorch < 2."""
        # [b, np, sq, sk]
        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = torch.empty(
            output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
            device=query_layer.device
        )

        # Raw attention scores. [b * np, sq, sk]
        # beta=0.0 ignores the (uninitialised) buffer contents.
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        if self.attention_softmax_in_fp32:
            attention_scores = attention_scores.float()
        if self.coeff is not None:
            # Undo the extra division by layer_number applied via norm_factor.
            attention_scores = attention_scores * self.coeff
        if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
            # Build a causal mask: True above the diagonal = masked out.
            attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                        device=attention_scores.device, dtype=torch.bool)
            attention_mask.tril_()
            attention_mask = ~attention_mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.type_as(value_layer)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)
        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
381
+
382
+
383
class SelfAttention(torch.nn.Module):
    """Multi-head self-attention with rotary embeddings and an optional KV cache.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size, together with the key/value cache
    (``None`` when ``use_cache`` is False).

    When ``config.multi_query_attention`` is set, keys and values use
    ``config.multi_query_group_num`` heads, which are expanded to the full
    number of query heads before the core attention.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        # Fused QKV width: 3 full projections, or with multi-query attention a
        # full query projection plus one K and one V head per group.
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        # Returns an uninitialised buffer of shape
        # [inference_max_sequence_len, batch_size, heads, head_dim], where heads
        # is the multi-query group count when multi-query attention is enabled.
        # NOTE(review): not called within this class — confirm whether it is used.
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # hidden_states: [sq, b, h]

        # NOTE(review): no key/value memory is pre-allocated here; the cache is
        # grown by concatenation below.

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            # Split the fused projection into Q (np heads) and K/V (group heads).
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference: prepend cached keys/values along
        # the sequence dimension (dim 0)
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        # The returned cache holds keys/values before multi-query expansion.
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            # Repeat each K/V group so every query head has a matching K/V head.
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache
520
+
521
+
522
+ def _config_to_kwargs(args):
523
+ common_kwargs = {
524
+ "dtype": args.torch_dtype,
525
+ }
526
+ return common_kwargs
527
+
528
+
529
class MLP(torch.nn.Module):
    """Gated feed-forward block with a SwiGLU activation.

    ``dense_h_to_4h`` maps ``hidden_size`` to ``2 * ffn_hidden_size``. SwiGLU
    splits that output into two halves ``(a, b)`` and returns ``silu(a) * b``
    (width ``ffn_hidden_size``), which ``dense_4h_to_h`` maps back to
    ``hidden_size``. Only the last (feature) dimension changes; leading
    dimensions pass through. See https://arxiv.org/pdf/2002.05202.pdf.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear
        linear_options = dict(bias=self.add_bias, device=device, **_config_to_kwargs(config))

        # Doubled width: the two halves feed the SwiGLU gate.
        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.ffn_hidden_size * 2, **linear_options)

        def swiglu(x):
            halves = torch.chunk(x, 2, dim=-1)
            return F.silu(halves[0]) * halves[1]

        self.activation_func = swiglu

        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, **linear_options)

    def forward(self, hidden_states):
        gated = self.activation_func(self.dense_h_to_4h(hidden_states))
        return self.dense_4h_to_h(gated)
573
+
574
+
575
class GLMBlock(torch.nn.Module):
    """A single ChatGLM transformer layer.

    Runs norm -> self-attention -> residual add, then norm -> MLP ->
    residual add, with optional dropout on each sub-layer output. Takes
    ``hidden_states`` of shape ``[s, b, h]`` and returns ``(output, kv_cache)``,
    where ``output`` has the same shape and ``kv_cache`` comes from the
    attention sub-layer.

    With ``apply_residual_connection_post_layernorm`` the residual branch
    uses the normalised input instead of the raw input.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        # Stored from the config; not read inside this class.
        self.fp32_residual_connection = config.fp32_residual_connection

        norm_cls = RMSNorm if config.rmsnorm else LayerNorm

        def make_norm():
            return norm_cls(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                            dtype=config.torch_dtype)

        # Submodules are created in the same order as before, so parameter
        # registration order is unchanged.
        self.input_layernorm = make_norm()
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout
        self.post_attention_layernorm = make_norm()
        self.mlp = MLP(config, device=device)

    def forward(
        self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        residual_from_norm = self.apply_residual_connection_post_layernorm

        # --- attention sub-layer -------------------------------------------
        attn_input = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            attn_input,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )
        attn_residual = attn_input if residual_from_norm else hidden_states
        attn_dropped = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        mlp_block_input = attn_residual + attn_dropped

        # --- MLP sub-layer -------------------------------------------------
        mlp_input = self.post_attention_layernorm(mlp_block_input)
        mlp_output = self.mlp(mlp_input)
        mlp_residual = mlp_input if residual_from_norm else mlp_block_input
        mlp_dropped = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)

        return mlp_residual + mlp_dropped, kv_cache
647
+
648
+
649
class GLMTransformer(torch.nn.Module):
    """Stack of ``config.num_layers`` GLMBlock layers with an optional final layer norm."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers; GLMBlock receives 1-based layer numbers.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        """Return the layer at 0-based index ``layer_number``."""
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        """Run every layer in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed to every layer.
            rotary_pos_emb: Passed to every layer.
            kv_caches: Optional per-layer caches (indexed by layer). An empty or
                missing value means "no cache" for every layer.
            use_cache: Whether to collect each layer's returned cache. Forced to
                False while gradient checkpointing is active in training mode.
            output_hidden_states: Whether to also return the input of every
                layer plus the last layer's output.

        Returns:
            ``(hidden_states, presents, all_hidden_states, all_self_attentions)``.
            ``presents`` is a tuple of per-layer caches, or None when caching is
            disabled. ``all_self_attentions`` is always None.
        """
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False
        # Create the cache accumulator only after use_cache is final, so that a
        # forcibly disabled cache yields None rather than an empty tuple.
        presents = () if use_cache else None

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions
729
+
730
+
731
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    Common base for the ChatGLM models.

    Wires the ChatGLM config class into the `PreTrainedModel` loading
    machinery and provides shared mask / position-id helpers.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Deliberately a no-op: no weight initialization is performed here."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        """Build a boolean attention mask where True marks positions that must NOT be attended.

        Steps:
        1. Start from a causal (lower-triangular) mask.
        2. If ``past_key_values`` is non-empty, prepend fully-visible columns
           for the cached positions (cache length is read from dim 0 of the
           first cached tensor).
        3. Apply ``padding_mask`` when given.

        The result has an extra broadcast dim inserted at position 1.
        """
        batch_size, seq_length = input_ids.shape
        device = input_ids.device

        # 1.0 = may attend, 0.0 = blocked; start causal.
        mask = torch.ones(batch_size, seq_length, seq_length, device=device).tril_()

        past_length = past_key_values[0][0].shape[0] if past_key_values else 0
        if past_length:
            cached_cols = torch.ones(batch_size, seq_length, past_length, device=device)
            mask = torch.cat((cached_cols, mask), dim=-1)

        if padding_mask is not None:
            mask = mask * padding_mask.unsqueeze(1)
            if not past_length:
                mask -= padding_mask.unsqueeze(-1) - 1

        # Invert: True now means "masked out".
        mask = (mask < 0.5).bool()
        mask.unsqueeze_(1)
        return mask

    def get_position_ids(self, input_ids, device):
        """Return ``0..seq_length-1`` position ids repeated for every batch row."""
        batch_size, seq_length = input_ids.shape
        positions = torch.arange(seq_length, dtype=torch.long, device=device)
        return positions.unsqueeze(0).repeat(batch_size, 1)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on GLMTransformer submodules only."""
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value
773
+
774
+
775
class Embedding(torch.nn.Module):
    """Word-embedding lookup that returns sequence-first activations.

    ``input_ids`` (``[b, s]``) are looked up in a ``padded_vocab_size x
    hidden_size`` table. The ``[b, s, h]`` result is then transposed to
    ``[s, b, h]``.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super().__init__()

        self.hidden_size = config.hidden_size
        # Token embedding table (vocabulary padded per the config).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )
        # When True, the output is upcast to float32.
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        # [b, s, h] -> [s, b, h]; the rest of the stack works sequence-first,
        # which avoids explicit transposes later.
        embeddings = self.word_embeddings(input_ids).transpose(0, 1).contiguous()
        if not self.fp32_residual_connection:
            return embeddings
        return embeddings.float()
801
+
802
+
803
class ChatGLMModel(ChatGLMPreTrainedModel):
    """ChatGLM backbone.

    Components:
    - token embedding
    - rotary position embedding
    - GLMTransformer encoder
    - ``output_layer`` projection to the vocabulary

    ``forward`` returns encoder outputs only; it does not apply
    ``output_layer`` itself.

    When ``config.pre_seq_len`` is set, every parameter created so far is
    frozen. A ``PrefixEncoder`` then supplies learned prefixes that are used
    as ``past_key_values``.
    """

    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        init_method = skip_init if empty_init else default_init
        init_kwargs = {} if device is None else {"device": device}

        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings: per-head width unless kv_channels overrides it.
        self.seq_length = config.seq_length
        if config.kv_channels is None:
            rotary_dim = config.hidden_size // config.num_attention_heads
        else:
            rotary_dim = config.kv_channels
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )

        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )

        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            # Freeze everything built above; only the prefix encoder (created next) stays trainable.
            self.requires_grad_(False)
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        """Encode the learned prefix tokens into cache-shaped tensors.

        The encoder output is viewed as
        ``[b, pre_seq_len, 2 * num_layers, multi_query_group_num, kv_channels]``
        and permuted so the ``2 * num_layers`` axis comes first. It is then
        split into chunks of 2 along that axis.
        """
        tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        prompt = self.prefix_encoder(tokens).type(dtype).view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels,
        )
        prompt = self.dropout(prompt)
        # -> [2 * num_layers, pre_seq_len, b, groups, kv_channels], chunked two at a time.
        return prompt.permute([2, 1, 0, 3, 4]).split(2)

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        """Embed the input (unless ``inputs_embeds`` is given) and run the encoder.

        Unset flags fall back to the corresponding config values.

        Returns a ``BaseModelOutputWithPast``. When ``return_dict`` is falsy it
        returns instead a tuple of the non-None values among
        ``(hidden_states, presents, all_hidden_states, all_self_attentions)``.
        """
        config = self.config
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if use_cache is None:
            use_cache = config.use_cache
        if return_dict is None:
            return_dict = config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(
                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
                )
            if attention_mask is not None:
                # The prefix positions are always visible.
                prefix_mask = attention_mask.new_ones((batch_size, self.pre_seq_len))
                attention_mask = torch.cat([prefix_mask, attention_mask], dim=-1)

        if full_attention_mask is None:
            has_padding = attention_mask is not None and not attention_mask.all()
            if has_padding or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Rotary positional embeddings, gathered per position and made sequence-first.
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is None:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        else:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        encoder_outputs = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        hidden_states, presents, all_hidden_states, all_self_attentions = encoder_outputs

        if not return_dict:
            return tuple(value for value in encoder_outputs if value is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        """Quantize the encoder in place using the sibling ``quantization`` module; returns ``self``."""
        from .quantization import quantize
        quantize(self.encoder, weight_bit_width)
        return self
919
+
920
+
921
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    """ChatGLM causal language model.

    Consists of the ChatGLM backbone plus its ``output_layer`` projection to
    vocabulary logits, with chat and streaming helpers built on top of
    generation.
    """

    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Advance ``model_kwargs`` by one generated token.

        Steps:
        - Store the new cache from ``outputs``.
        - Append a 1 to ``attention_mask`` (if present).
        - Append ``last position id + 1`` to ``position_ids`` (if present).
        - Clear ``is_first_forward``.
        """
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        """Build forward() kwargs for one generation step.

        After the first step, and only when a cache exists, just the last token
        and its position are fed.
        """
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        """Run the backbone and project to vocabulary logits.

        ``output_attentions`` is accepted for API compatibility but unused.
        When ``return_last_logit`` is set, only the last sequence position is
        projected.

        When ``labels`` are given, a shifted next-token cross-entropy loss is
        computed. The loss is computed in float32, and labels of -100 are
        ignored.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            # Hidden states are sequence-first, so this keeps the last position.
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, output, history):
        """Split decoded model output into assistant turns and update a copy of ``history``.

        Each ``<|assistant|>`` segment is ``metadata\\ncontent``. Empty metadata
        means a plain reply. Non-empty metadata names a tool call (when the
        system turn declares tools) or another named output.

        Returns ``(content, new_history)``, where ``content`` is taken from the
        last segment.
        """
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            if "\n" in response:
                metadata, content = response.split("\n", maxsplit=1)
            else:
                # A segment without a newline (e.g. a partial decode from
                # stream_chat) has no metadata header; unpacking split() here
                # used to raise ValueError.
                metadata, content = "", response
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])

                    def tool_call(**kwargs):
                        return kwargs

                    # SECURITY(review): eval() executes model-generated text. It is
                    # expected to be a `tool_call(...)` expression, but arbitrary code
                    # would also run. Consider a restricted parser (e.g. ast-based)
                    # before exposing this to untrusted prompts.
                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        """Generate one full reply to ``query``; returns ``(response, history)``."""
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        # Keep only newly generated tokens and drop the final (stop) token.
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response, history = self.process_response(response, history)
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        """Yield ``(response, history)`` (plus the cache if requested) as the reply grows.

        When ``past_key_values`` is supplied, only ``query`` is encoded.
        Position ids and the attention mask are then offset by the cached
        length, excluding any prefix length.
        """
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None:
            inputs = tokenizer.build_chat_input(query, history=history, role=role)
        else:
            inputs = tokenizer.build_chat_input(query, role=role)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
            response = tokenizer.decode(outputs)
            # Skip partial output that ends in a replacement character (incomplete multi-byte sequence).
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        """Token-by-token sampling/greedy loop that yields ``input_ids`` after every step.

        Also yields ``past_key_values`` when ``return_past_key_values`` is set.
        Stops once every sequence has emitted an EOS id or
        ``stopping_criteria`` fires.
        """
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                # logger.warning takes no warning category: the previous
                # `logger.warn(msg, UserWarning)` passed UserWarning as a %-format
                # argument, so logging failed to format and the message was lost.
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            # Mark a sequence finished once it emits any EOS id; without EOS ids
            # only stopping_criteria can end generation (previously this crashed).
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break

    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
        """Quantize the encoder weights to ``bits`` bits (at most once); always returns ``self``.

        ``bits == 0`` is a no-op.
        """
        if bits == 0:
            # Return self like every other path so `model = model.quantize(...)` is safe.
            return self

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
                                            **kwargs)
        return self
1278
+
1279
+
1280
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    """Sequence classification / regression head on top of the ChatGLM backbone.

    Each sequence is scored from the final hidden state at the last sequence
    position.
    """

    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)

        # The head is created in float16 regardless of config.torch_dtype;
        # forward() casts its input to the head's dtype.
        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
        if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
        else:
            self.dropout = None
        self.config = config

        if self.config.quantization_bit:
            # Neither this class nor ChatGLMPreTrainedModel defines `quantize`, so
            # `self.quantize(...)` raised AttributeError. Quantize the backbone's
            # encoder instead. NOTE(review): confirm transformers' PreTrainedModel
            # provides no `quantize` that was meant to be used here.
            self.transformer.quantize(self.config.quantization_bit)

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            full_attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        """Score each sequence and optionally compute a loss.

        When ``labels`` are given and ``config.problem_type`` is unset, the
        problem type is inferred:
        - ``num_labels == 1``: regression (MSE);
        - integer labels: single-label classification (cross-entropy);
        - otherwise: multi-label classification (BCE with logits).
        The inferred type is stored back on the config. Losses are computed on
        float32 logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        # Assumes the backbone's sequence-first [s, b, h] layout, so [-1] is the last position per batch row.
        pooled_hidden_states = hidden_states[-1]
        if self.dropout is not None:
            pooled_hidden_states = self.dropout(pooled_hidden_states)
        # Match the head's dtype (float16 at construction); otherwise a backbone
        # running in another dtype makes nn.Linear raise a dtype mismatch.
        logits = self.classifier_head(pooled_hidden_states.to(self.classifier_head.weight.dtype))

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
                else:
                    loss = loss_fct(logits.float(), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )