diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from .base_prompter import BasePrompter
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
from ..models import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
4
|
+
from transformers import BertTokenizer, AutoTokenizer
|
|
5
|
+
import warnings, os
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HunyuanDiTPrompter(BasePrompter):
    """Prompt encoder for Hunyuan-DiT that combines a BERT-tokenized CLIP encoder with a T5 encoder.

    Both tokenizers are loaded in ``__init__``; the text encoders are attached afterwards
    through :meth:`fetch_models`.
    """

    def __init__(
        self,
        tokenizer_path=None,
        tokenizer_t5_path=None
    ):
        # Missing paths fall back to the tokenizer configs bundled with the package.
        package_root = os.path.dirname(os.path.dirname(__file__))
        if tokenizer_path is None:
            tokenizer_path = os.path.join(package_root, "tokenizer_configs/hunyuan_dit/tokenizer")
        if tokenizer_t5_path is None:
            tokenizer_t5_path = os.path.join(package_root, "tokenizer_configs/hunyuan_dit/tokenizer_t5")
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        # Any warnings raised while loading the T5 tokenizer are suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
        self.text_encoder: HunyuanDiTCLIPTextEncoder = None
        self.text_encoder_t5: HunyuanDiTT5TextEncoder = None

    def fetch_models(self, text_encoder: HunyuanDiTCLIPTextEncoder = None, text_encoder_t5: HunyuanDiTT5TextEncoder = None):
        """Attach the CLIP and T5 text encoders used by :meth:`encode_prompt`."""
        self.text_encoder, self.text_encoder_t5 = text_encoder, text_encoder_t5

    def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
        """Tokenize ``prompt`` and run it through a single text encoder.

        The prompt is padded and truncated to ``max_length`` tokens. Returns
        ``(prompt_embeds, attention_mask)``; the attention mask is already on ``device``.
        """
        encoded = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        mask = encoded.attention_mask.to(device)
        embeds = text_encoder(encoded.input_ids.to(device), attention_mask=mask, clip_skip=clip_skip)
        return embeds, mask

    def encode_prompt(
        self,
        prompt,
        clip_skip=1,
        clip_skip_2=1,
        positive=True,
        device="cuda"
    ):
        """Encode ``prompt`` with both the CLIP and the T5 encoder.

        ``clip_skip`` is forwarded to the CLIP encoder and ``clip_skip_2`` to the T5 encoder.
        Each side is padded to its tokenizer's ``model_max_length``.

        Returns:
            ``(prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5)``
        """
        prompt = self.process_prompt(prompt, positive=positive)

        clip_emb, clip_mask = self.encode_prompt_using_signle_model(
            prompt, self.text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device
        )
        t5_emb, t5_mask = self.encode_prompt_using_signle_model(
            prompt, self.text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device
        )
        return clip_emb, clip_mask, t5_emb, t5_mask
|
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
from .base_prompter import BasePrompter
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
import json, os, re
|
|
4
|
+
from typing import List, Optional, Union, Dict
|
|
5
|
+
from sentencepiece import SentencePieceProcessor
|
|
6
|
+
from transformers import PreTrainedTokenizer
|
|
7
|
+
from transformers.utils import PaddingStrategy
|
|
8
|
+
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
|
9
|
+
from ..models.kolors_text_encoder import ChatGLMModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SPTokenizer:
    """SentencePiece wrapper that appends the ChatGLM special tokens after the base vocabulary.

    Special tokens receive consecutive ids starting at the SentencePiece vocabulary size, so
    ``n_words`` counts both the base vocabulary and the added tokens.
    """

    def __init__(self, model_path: str):
        # Load the SentencePiece model from disk.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS ids come from the model; padding reuses the unk id.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
        # Give the special tokens consecutive ids directly after the base vocabulary, in list order.
        self.special_tokens = {name: self.n_words + offset for offset, name in enumerate(special_tokens)}
        self.index_special_tokens = {idx: name for name, idx in self.special_tokens.items()}
        self.n_words += len(special_tokens)
        self.role_special_token_expression = "|".join(re.escape(name) for name in role_special_tokens)

    def tokenize(self, s: str, encode_special_tokens=False):
        """Split ``s`` into SentencePiece pieces.

        With ``encode_special_tokens`` enabled, role markers such as ``<|user|>`` are kept as
        single pieces instead of being split by SentencePiece.
        """
        if not encode_special_tokens:
            return self.sp_model.EncodeAsPieces(s)
        pieces = []
        cursor = 0
        for match in re.finditer(self.role_special_token_expression, s):
            start, end = match.span()
            if cursor < start:
                pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:start]))
            pieces.append(s[start:end])
            cursor = end
        if cursor < len(s):
            pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:]))
        return pieces

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode ``s`` into SentencePiece ids, optionally adding BOS and/or EOS."""
        assert type(s) is str
        ids = self.sp_model.encode(s)
        if bos:
            ids = [self.bos_id, *ids]
        if eos:
            ids = [*ids, self.eos_id]
        return ids

    def decode(self, t: List[int]) -> str:
        """Decode ids to text; added special tokens are emitted verbatim."""
        chunks = []
        pending = []
        for token_id in t:
            if token_id not in self.index_special_tokens:
                pending.append(token_id)
                continue
            # Flush ordinary ids accumulated so far before emitting the special token.
            if pending:
                chunks.append(self.sp_model.decode(pending))
                pending = []
            chunks.append(self.index_special_tokens[token_id])
        if pending:
            chunks.append(self.sp_model.decode(pending))
        return "".join(chunks)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join SentencePiece pieces back into text."""
        return self.sp_model.DecodePieces(tokens)

    def convert_token_to_id(self, token):
        """Map a token string to its id, checking the added special tokens before SentencePiece."""
        return self.special_tokens[token] if token in self.special_tokens else self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Map an id to its token string.

        Added special tokens map to their names. BOS/EOS/PAD ids and negative ids map to ``""``.
        """
        special = self.index_special_tokens.get(index)
        if special is not None:
            return special
        if index < 0 or index in (self.eos_id, self.bos_id, self.pad_id):
            return ""
        return self.sp_model.IdToPiece(index)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class ChatGLMTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer wrapper around :class:`SPTokenizer` for the ChatGLM text encoder.

    Produces ``input_ids``, ``attention_mask`` and ``position_ids``. Only left padding is
    supported (see :meth:`_pad`).
    """
    vocab_files_names = {"vocab_file": "tokenizer.model"}

    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
                 **kwargs):
        """Load the SentencePiece model from ``vocab_file`` and initialise the base tokenizer.

        Args:
            vocab_file: Path to the SentencePiece ``tokenizer.model`` file.
            padding_side: Forwarded to the base class; :meth:`_pad` only accepts ``"left"``.
            clean_up_tokenization_spaces: Forwarded to the base class.
            encode_special_tokens: When true, role markers such as ``<|user|>`` are kept as
                single tokens during tokenization.
        """
        self.name = "GLMTokenizer"

        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        # Special tokens that map onto ids of the underlying SentencePiece model.
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        self.encode_special_tokens = encode_special_tokens
        # NOTE(review): the attributes above are set before super().__init__(), presumably because
        # the base initializer may already query the vocabulary — keep this ordering.
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         encode_special_tokens=encode_special_tokens,
                         **kwargs)

    def get_command(self, token):
        """Return the id of a special token.

        Looks in this wrapper's ``<bos>/<eos>/<pad>`` table first, then in the special tokens
        added by :class:`SPTokenizer`.

        Raises:
            AssertionError: if ``token`` is not a known special token.
        """
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        # The pad token string is the same as the unk token string.
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        # Base SentencePiece vocabulary plus the special tokens added by SPTokenizer.
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Copy the SentencePiece model file into ``save_directory``.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary. If it is not an existing
                directory, it is used as the output file path instead.
            filename_prefix (`str`, *optional*):
                Accepted for API compatibility; it is not applied to the saved file name.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        """Return the ids that start every encoded sequence: ``[gMASK]`` followed by ``sop``."""
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn as ``<|role|>`` + ``metadata + "\\n"`` + ``message`` ids."""
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user"):
        """Encode a chat history plus ``query`` and append an ``<|assistant|>`` marker.

        Returns a batch of one sequence as PyTorch tensors.
        """
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            # System turns may carry tool definitions, which are appended as pretty-printed JSON.
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by adding ChatGLM special tokens.

        - single sequence: `[gMASK] sop X`
        - pair of sequences: `[gMASK] sop A B <eos>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
            padding_side: Optional[str] = None,
    ) -> dict:
        """
        Left-pad a single encoded example (input ids, attention mask and position ids).

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) for one example.
            max_length: target length when padding.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST: uses this example's own length as the target, so no padding is added here
                - PaddingStrategy.MAX_LENGTH: pad to `max_length`
                - PaddingStrategy.DO_NOT_PAD: do not pad
            pad_to_multiple_of: (optional) if set, the target length is rounded up to a multiple of this value.
            return_attention_mask:
                Accepted for API compatibility; an attention mask is always produced.
            padding_side:
                (optional) Padding side requested by the caller; falls back to `self.padding_side`.
                Only "left" is supported.

        Padded positions get `pad_token_id` as input id, 0 in the attention mask and 0 as position id.
        """
        # Newer transformers releases (>= 4.45) forward `padding_side` to `_pad`. It is accepted
        # as a trailing keyword (so positional callers of the old signature still work) and validated.
        side = self.padding_side if padding_side is None else padding_side
        assert side == "left", f"{self.name} only supports left padding, got {side!r}"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = seq_length

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and seq_length != max_length

        # Initialize attention mask and position ids if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - seq_length
            # Both keys are guaranteed to exist here (initialized above); pad everything on the left.
            encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
class KolorsPrompter(BasePrompter):
    """Prompt encoder for Kolors, which uses a single ChatGLM text encoder."""

    def __init__(
        self,
        tokenizer_path=None
    ):
        if tokenizer_path is None:
            # Default to the tokenizer files shipped inside the package.
            package_root = os.path.dirname(os.path.dirname(__file__))
            tokenizer_path = os.path.join(package_root, "tokenizer_configs/kolors/tokenizer")
        super().__init__()
        self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
        self.text_encoder: ChatGLMModel = None

    def fetch_models(self, text_encoder: ChatGLMModel = None):
        """Attach the ChatGLM text encoder used by :meth:`encode_prompt`."""
        self.text_encoder = text_encoder

    def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
        """Run ``prompt`` through the ChatGLM encoder.

        Returns ``(prompt_emb, pooled_prompt_emb)``:

        - ``prompt_emb`` is ``hidden_states[-clip_skip]`` with its first two axes swapped.
        - ``pooled_prompt_emb`` is the last position of the final hidden state.
        """
        batch = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        result = text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            position_ids=batch['position_ids'],
            output_hidden_states=True
        )
        hidden_states = result.hidden_states
        # NOTE(review): presumably hidden states are (seq, batch, dim); permute moves batch first — confirm.
        prompt_emb = hidden_states[-clip_skip].permute(1, 0, 2).clone()
        pooled_prompt_emb = hidden_states[-1][-1, :, :].clone()
        return prompt_emb, pooled_prompt_emb

    def encode_prompt(
        self,
        prompt,
        clip_skip=1,
        clip_skip_2=2,
        positive=True,
        device="cuda"
    ):
        """Encode ``prompt`` with the ChatGLM encoder (padded/truncated to 256 tokens).

        Only ``clip_skip_2`` selects the hidden layer; ``clip_skip`` is accepted but unused.
        Returns ``(pooled_prompt_emb, prompt_emb)``.
        """
        prompt = self.process_prompt(prompt, positive=positive)
        prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(
            prompt, self.text_encoder, self.tokenizer, 256, clip_skip_2, device
        )
        return pooled_prompt_emb, prompt_emb
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from transformers import AutoTokenizer
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BeautifulPrompt(torch.nn.Module):
    """Prompt refiner that extends a short prompt with text generated by a language model.

    Only positive prompts are refined; negative prompts are returned unchanged.
    """

    def __init__(self, tokenizer_path=None, model=None, template=""):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.template = template

    @staticmethod
    def from_model_manager(model_nameger: ModelManager):
        """Build a refiner from the ``beautiful_prompt`` model held by ``model_nameger``.

        The tokenizer is loaded from the model path. A different instruction template is
        used when that path ends with ``v2``.
        """
        model, model_path = model_nameger.fetch_model("beautiful_prompt", require_model_path=True)
        if model_path.endswith("v2"):
            template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
        else:
            template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
        return BeautifulPrompt(tokenizer_path=model_path, model=model, template=template)

    def __call__(self, raw_prompt, positive=True, **kwargs):
        """Return ``raw_prompt`` followed by ``", "`` and the generated continuation.

        Negative prompts (``positive=False``) are passed through untouched.
        """
        if not positive:
            return raw_prompt
        model_input = self.template.format(raw_prompt=raw_prompt)
        input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=384,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1
        )
        # Drop the echoed instruction tokens and decode only the newly generated part.
        generated = outputs[:, input_ids.size(1):]
        addition = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0].strip()
        prompt = raw_prompt + ", " + addition
        print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
        return prompt
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Translator(torch.nn.Module):
    """Prompt translator wrapping a model that exposes ``generate`` plus its tokenizer."""

    def __init__(self, tokenizer_path=None, model=None):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model

    @staticmethod
    def from_model_manager(model_nameger: ModelManager):
        """Build a translator from the ``translator`` model held by ``model_nameger``."""
        model, model_path = model_nameger.fetch_model("translator", require_model_path=True)
        return Translator(tokenizer_path=model_path, model=model)

    def __call__(self, prompt, **kwargs):
        """Translate ``prompt`` and return the first decoded output sequence."""
        encoded = self.tokenizer.encode(prompt, return_tensors='pt')
        generated = self.model.generate(encoded.to(self.model.device))
        translated = self.tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
        print(f"Your prompt is translated: {translated}")
        return translated
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from .base_prompter import BasePrompter
|
|
2
|
+
from ..models.model_manager import ModelManager
|
|
3
|
+
from ..models import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
4
|
+
from transformers import CLIPTokenizer, T5TokenizerFast
|
|
5
|
+
import os, torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SD3Prompter(BasePrompter):
|
|
9
|
+
def __init__(
|
|
10
|
+
self,
|
|
11
|
+
tokenizer_1_path=None,
|
|
12
|
+
tokenizer_2_path=None,
|
|
13
|
+
tokenizer_3_path=None
|
|
14
|
+
):
|
|
15
|
+
if tokenizer_1_path is None:
|
|
16
|
+
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
17
|
+
tokenizer_1_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_1")
|
|
18
|
+
if tokenizer_2_path is None:
|
|
19
|
+
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
20
|
+
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_2")
|
|
21
|
+
if tokenizer_3_path is None:
|
|
22
|
+
base_path = os.path.dirname(os.path.dirname(__file__))
|
|
23
|
+
tokenizer_3_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion_3/tokenizer_3")
|
|
24
|
+
super().__init__()
|
|
25
|
+
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
|
26
|
+
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
|
27
|
+
self.tokenizer_3 = T5TokenizerFast.from_pretrained(tokenizer_3_path)
|
|
28
|
+
self.text_encoder_1: SD3TextEncoder1 = None
|
|
29
|
+
self.text_encoder_2: SD3TextEncoder2 = None
|
|
30
|
+
self.text_encoder_3: SD3TextEncoder3 = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: SD3TextEncoder2 = None, text_encoder_3: SD3TextEncoder3 = None):
|
|
34
|
+
self.text_encoder_1 = text_encoder_1
|
|
35
|
+
self.text_encoder_2 = text_encoder_2
|
|
36
|
+
self.text_encoder_3 = text_encoder_3
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def encode_prompt_using_clip(self, prompt, text_encoder, tokenizer, max_length, device):
|
|
40
|
+
input_ids = tokenizer(
|
|
41
|
+
prompt,
|
|
42
|
+
return_tensors="pt",
|
|
43
|
+
padding="max_length",
|
|
44
|
+
max_length=max_length,
|
|
45
|
+
truncation=True
|
|
46
|
+
).input_ids.to(device)
|
|
47
|
+
pooled_prompt_emb, prompt_emb = text_encoder(input_ids)
|
|
48
|
+
return pooled_prompt_emb, prompt_emb
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
|
|
52
|
+
input_ids = tokenizer(
|
|
53
|
+
prompt,
|
|
54
|
+
return_tensors="pt",
|
|
55
|
+
padding="max_length",
|
|
56
|
+
max_length=max_length,
|
|
57
|
+
truncation=True,
|
|
58
|
+
add_special_tokens=True,
|
|
59
|
+
).input_ids.to(device)
|
|
60
|
+
prompt_emb = text_encoder(input_ids)
|
|
61
|
+
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
|
62
|
+
|
|
63
|
+
return prompt_emb
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def encode_prompt(
    self,
    prompt,
    positive=True,
    device="cuda"
):
    """Encode *prompt* into the joint SD3 conditioning tensors.

    The processed prompt is run through both CLIP encoders (77 tokens each)
    and, if loaded, the T5 encoder (256 tokens). The two CLIP per-token
    embeddings are concatenated on the feature axis, zero-padded up to the
    T5 feature width, and placed before the T5 tokens on the sequence axis.
    When no T5 encoder is loaded, a zero tensor of 256 tokens x 4096 features
    stands in for its output.

    Args:
        prompt: Raw prompt, passed through ``self.process_prompt``.
        positive: Forwarded to ``self.process_prompt``.
        device: Device for the token ids and the zero T5 fallback.

    Returns:
        ``(prompt_emb, pooled_prompt_emb)``, where ``pooled_prompt_emb`` is the
        feature-axis concatenation of the two CLIP pooled outputs.

    Raises:
        ValueError: if the concatenated CLIP features are wider than the T5
            features (padding would otherwise silently truncate them).
    """
    clip_max_length = 77
    t5_max_length = 256
    t5_fallback_dim = 4096  # feature width of the zero stand-in when T5 is absent

    prompt = self.process_prompt(prompt, positive=positive)

    # CLIP
    pooled_prompt_emb_1, prompt_emb_1 = self.encode_prompt_using_clip(prompt, self.text_encoder_1, self.tokenizer_1, clip_max_length, device)
    pooled_prompt_emb_2, prompt_emb_2 = self.encode_prompt_using_clip(prompt, self.text_encoder_2, self.tokenizer_2, clip_max_length, device)

    # T5
    if self.text_encoder_3 is None:
        prompt_emb_3 = torch.zeros((prompt_emb_1.shape[0], t5_max_length, t5_fallback_dim), dtype=prompt_emb_1.dtype, device=device)
    else:
        prompt_emb_3 = self.encode_prompt_using_t5(prompt, self.text_encoder_3, self.tokenizer_3, t5_max_length, device)
        # Cast to the CLIP dtype so the concatenation below is dtype-consistent.
        prompt_emb_3 = prompt_emb_3.to(prompt_emb_1.dtype)

    # Merge: pad the CLIP features up to the actual T5 width instead of assuming
    # 768 + 1280 -> 4096 (identical result for the stock SD3 encoders).
    clip_prompt_emb = torch.cat([prompt_emb_1, prompt_emb_2], dim=-1)
    pad_width = prompt_emb_3.shape[-1] - clip_prompt_emb.shape[-1]
    if pad_width < 0:
        raise ValueError(
            f"CLIP feature width {clip_prompt_emb.shape[-1]} exceeds T5 feature width {prompt_emb_3.shape[-1]}"
        )
    clip_prompt_emb = torch.nn.functional.pad(clip_prompt_emb, (0, pad_width))
    prompt_emb = torch.cat([clip_prompt_emb, prompt_emb_3], dim=-2)
    pooled_prompt_emb = torch.cat([pooled_prompt_emb_1, pooled_prompt_emb_2], dim=-1)

    return prompt_emb, pooled_prompt_emb
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from .base_prompter import BasePrompter, tokenize_long_prompt
|
|
2
|
+
from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings
|
|
3
|
+
from ..models import SDTextEncoder
|
|
4
|
+
from transformers import CLIPTokenizer
|
|
5
|
+
import torch, os
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SDPrompter(BasePrompter):
    """Prompt encoder for a single-CLIP-text-encoder Stable Diffusion pipeline.

    Supports textual inversion: every loaded keyword is expanded into a run of
    new tokens whose embeddings are appended to the text encoder's
    token-embedding table.
    """

    def __init__(self, tokenizer_path=None):
        if tokenizer_path is None:
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_path = os.path.join(base_path, "tokenizer_configs/stable_diffusion/tokenizer")
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.text_encoder: SDTextEncoder = None
        # keyword -> (list of added token strings, embedding tensor)
        self.textual_inversion_dict = {}
        # keyword -> " tok_0 tok_1 ... " replacement text substituted into prompts
        self.keyword_dict = {}

    def fetch_models(self, text_encoder: SDTextEncoder = None):
        """Attach the text encoder used by ``encode_prompt``."""
        self.text_encoder = text_encoder

    def add_textual_inversions_to_model(self, textual_inversion_dict, text_encoder):
        """Append the embeddings in *textual_inversion_dict* to the encoder's token table.

        Rows are appended in the dict's iteration order, which must match the
        order in which ``add_textual_inversions_to_tokenizer`` registers tokens,
        so that the new token ids line up with the new rows.
        ``text_encoder.token_embedding`` is replaced with a larger Embedding.
        """
        dtype = next(iter(text_encoder.parameters())).dtype
        state_dict = text_encoder.token_embedding.state_dict()
        device = state_dict["weight"].device
        token_embeddings = [state_dict["weight"]]
        for keyword in textual_inversion_dict:
            _, embeddings = textual_inversion_dict[keyword]
            token_embeddings.append(embeddings.to(dtype=dtype, device=device))
        token_embeddings = torch.concat(token_embeddings, dim=0)
        state_dict["weight"] = token_embeddings
        text_encoder.token_embedding = torch.nn.Embedding(token_embeddings.shape[0], token_embeddings.shape[1])
        text_encoder.token_embedding = text_encoder.token_embedding.to(dtype=dtype, device=device)
        text_encoder.token_embedding.load_state_dict(state_dict)

    def add_textual_inversions_to_tokenizer(self, textual_inversion_dict, tokenizer):
        """Register each keyword's tokens with *tokenizer* and record its prompt expansion."""
        additional_tokens = []
        for keyword in textual_inversion_dict:
            tokens, _ = textual_inversion_dict[keyword]
            additional_tokens += tokens
            # Space-padded so the tokens stay separate from neighbouring words.
            self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
        tokenizer.add_tokens(additional_tokens)

    def load_textual_inversions(self, model_paths):
        """Load textual-inversion embeddings from *model_paths* and register them.

        The keyword is the file name without its extension. Only 2-D tensors
        with 768 features are accepted; if several match in one file, the last
        one wins. Keywords loaded by an earlier call are skipped.

        Only the keywords loaded by *this* call are added to the model and the
        tokenizer. Re-adding earlier ones would append duplicate embedding rows,
        while ``tokenizer.add_tokens`` skips tokens that already exist. The new
        tokens' ids would then point at the duplicated rows.
        """
        new_inversions = {}
        for model_path in model_paths:
            keyword = os.path.splitext(os.path.split(model_path)[-1])[0]
            if keyword in self.textual_inversion_dict:
                print(f"Textual inversion {keyword} is already loaded; skipping.")
                continue
            state_dict = load_state_dict(model_path)

            # Search for embeddings
            for embeddings in search_for_embeddings(state_dict):
                if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
                    tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
                    new_inversions[keyword] = (tokens, embeddings)

        if not new_inversions:
            return
        self.textual_inversion_dict.update(new_inversions)
        self.add_textual_inversions_to_model(new_inversions, self.text_encoder)
        self.add_textual_inversions_to_tokenizer(new_inversions, self.tokenizer)

    def encode_prompt(self, prompt, clip_skip=1, device="cuda", positive=True):
        """Encode *prompt* with the CLIP text encoder.

        Textual-inversion keywords found in the prompt (plain substring match)
        are replaced by their added tokens before tokenization.

        Returns:
            A tensor reshaped to ``(1, chunks * tokens_per_chunk, -1)``.
        """
        prompt = self.process_prompt(prompt, positive=positive)
        for keyword in self.keyword_dict:
            if keyword in prompt:
                print(f"Textual inversion {keyword} is enabled.")
                prompt = prompt.replace(keyword, self.keyword_dict[keyword])
        input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
        prompt_emb = self.text_encoder(input_ids, clip_skip=clip_skip)
        # NOTE(review): tokenize_long_prompt presumably returns one row per
        # fixed-size chunk of a single prompt; the chunks are merged back into
        # one long sequence here -- confirm against base_prompter.
        prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))

        return prompt_emb
|