diffsynth 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth/__init__.py +6 -0
- diffsynth/configs/__init__.py +0 -0
- diffsynth/configs/model_config.py +243 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/simple_text_image.py +35 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/ESRGAN/__init__.py +118 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +242 -0
- diffsynth/extensions/__init__.py +0 -0
- diffsynth/models/__init__.py +1 -0
- diffsynth/models/attention.py +89 -0
- diffsynth/models/downloader.py +66 -0
- diffsynth/models/hunyuan_dit.py +451 -0
- diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
- diffsynth/models/kolors_text_encoder.py +1363 -0
- diffsynth/models/lora.py +195 -0
- diffsynth/models/model_manager.py +536 -0
- diffsynth/models/sd3_dit.py +798 -0
- diffsynth/models/sd3_text_encoder.py +1107 -0
- diffsynth/models/sd3_vae_decoder.py +81 -0
- diffsynth/models/sd3_vae_encoder.py +95 -0
- diffsynth/models/sd_controlnet.py +588 -0
- diffsynth/models/sd_ipadapter.py +57 -0
- diffsynth/models/sd_motion.py +199 -0
- diffsynth/models/sd_text_encoder.py +321 -0
- diffsynth/models/sd_unet.py +1108 -0
- diffsynth/models/sd_vae_decoder.py +336 -0
- diffsynth/models/sd_vae_encoder.py +282 -0
- diffsynth/models/sdxl_ipadapter.py +122 -0
- diffsynth/models/sdxl_motion.py +104 -0
- diffsynth/models/sdxl_text_encoder.py +759 -0
- diffsynth/models/sdxl_unet.py +1899 -0
- diffsynth/models/sdxl_vae_decoder.py +24 -0
- diffsynth/models/sdxl_vae_encoder.py +24 -0
- diffsynth/models/svd_image_encoder.py +505 -0
- diffsynth/models/svd_unet.py +2004 -0
- diffsynth/models/svd_vae_decoder.py +578 -0
- diffsynth/models/svd_vae_encoder.py +139 -0
- diffsynth/models/tiler.py +106 -0
- diffsynth/pipelines/__init__.py +9 -0
- diffsynth/pipelines/base.py +34 -0
- diffsynth/pipelines/dancer.py +178 -0
- diffsynth/pipelines/hunyuan_image.py +274 -0
- diffsynth/pipelines/pipeline_runner.py +105 -0
- diffsynth/pipelines/sd3_image.py +132 -0
- diffsynth/pipelines/sd_image.py +173 -0
- diffsynth/pipelines/sd_video.py +266 -0
- diffsynth/pipelines/sdxl_image.py +191 -0
- diffsynth/pipelines/sdxl_video.py +223 -0
- diffsynth/pipelines/svd_video.py +297 -0
- diffsynth/processors/FastBlend.py +142 -0
- diffsynth/processors/PILEditor.py +28 -0
- diffsynth/processors/RIFE.py +77 -0
- diffsynth/processors/__init__.py +0 -0
- diffsynth/processors/base.py +6 -0
- diffsynth/processors/sequencial_processor.py +41 -0
- diffsynth/prompters/__init__.py +6 -0
- diffsynth/prompters/base_prompter.py +57 -0
- diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
- diffsynth/prompters/kolors_prompter.py +353 -0
- diffsynth/prompters/prompt_refiners.py +77 -0
- diffsynth/prompters/sd3_prompter.py +92 -0
- diffsynth/prompters/sd_prompter.py +73 -0
- diffsynth/prompters/sdxl_prompter.py +61 -0
- diffsynth/schedulers/__init__.py +3 -0
- diffsynth/schedulers/continuous_ode.py +59 -0
- diffsynth/schedulers/ddim.py +79 -0
- diffsynth/schedulers/flow_match.py +51 -0
- diffsynth/tokenizer_configs/__init__.py +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
- diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
- diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
- diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
- diffsynth/trainers/__init__.py +0 -0
- diffsynth/trainers/text_to_image.py +253 -0
- diffsynth-1.0.0.dist-info/LICENSE +201 -0
- diffsynth-1.0.0.dist-info/METADATA +23 -0
- diffsynth-1.0.0.dist-info/RECORD +120 -0
- diffsynth-1.0.0.dist-info/WHEEL +5 -0
- diffsynth-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
import os, torch, hashlib, json, importlib
|
|
2
|
+
from safetensors import safe_open
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
from typing_extensions import Literal, TypeAlias
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
from .downloader import download_models, Preset_model_id, Preset_model_website
|
|
8
|
+
|
|
9
|
+
from .sd_text_encoder import SDTextEncoder
|
|
10
|
+
from .sd_unet import SDUNet
|
|
11
|
+
from .sd_vae_encoder import SDVAEEncoder
|
|
12
|
+
from .sd_vae_decoder import SDVAEDecoder
|
|
13
|
+
from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
|
|
14
|
+
|
|
15
|
+
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
|
16
|
+
from .sdxl_unet import SDXLUNet
|
|
17
|
+
from .sdxl_vae_decoder import SDXLVAEDecoder
|
|
18
|
+
from .sdxl_vae_encoder import SDXLVAEEncoder
|
|
19
|
+
|
|
20
|
+
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
|
21
|
+
from .sd3_dit import SD3DiT
|
|
22
|
+
from .sd3_vae_decoder import SD3VAEDecoder
|
|
23
|
+
from .sd3_vae_encoder import SD3VAEEncoder
|
|
24
|
+
|
|
25
|
+
from .sd_controlnet import SDControlNet
|
|
26
|
+
|
|
27
|
+
from .sd_motion import SDMotionModel
|
|
28
|
+
from .sdxl_motion import SDXLMotionModel
|
|
29
|
+
|
|
30
|
+
from .svd_image_encoder import SVDImageEncoder
|
|
31
|
+
from .svd_unet import SVDUNet
|
|
32
|
+
from .svd_vae_decoder import SVDVAEDecoder
|
|
33
|
+
from .svd_vae_encoder import SVDVAEEncoder
|
|
34
|
+
|
|
35
|
+
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
|
36
|
+
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
|
37
|
+
|
|
38
|
+
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
|
39
|
+
from .hunyuan_dit import HunyuanDiT
|
|
40
|
+
|
|
41
|
+
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def load_state_dict(file_path, torch_dtype=None):
    """Load a checkpoint from disk into a state dict.

    Files ending in ``.safetensors`` go through the safetensors reader; every other
    path is handed to ``torch.load`` via ``load_state_dict_from_bin``.

    Args:
        file_path: Path to the checkpoint file.
        torch_dtype: Optional dtype that tensors are cast to after loading.

    Returns:
        The loaded state dict.
    """
    if file_path.endswith(".safetensors"):
        loader = load_state_dict_from_safetensors
    else:
        loader = load_state_dict_from_bin
    return loader(file_path, torch_dtype=torch_dtype)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor of a ``.safetensors`` file onto the CPU.

    Args:
        file_path: Path to the safetensors file.
        torch_dtype: If not None, each tensor is cast to this dtype as it is read.

    Returns:
        A dict mapping tensor names to tensors, in the file's key order.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as reader:
        for tensor_name in reader.keys():
            tensor = reader.get_tensor(tensor_name)
            tensors[tensor_name] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a non-safetensors checkpoint with ``torch.load`` onto the CPU.

    Args:
        file_path: Path passed to ``torch.load``.
        torch_dtype: If not None, every top-level value that is a ``torch.Tensor``
            is cast to this dtype; non-tensor values are left as they are.

    Returns:
        The object returned by ``torch.load`` (expected to be a dict), with tensors
        cast in place when ``torch_dtype`` is given.
    """
    # SECURITY NOTE(review): torch.load unpickles the file. Depending on the installed
    # torch version's ``weights_only`` default, an untrusted checkpoint may execute
    # arbitrary code -- only load files from trusted sources.
    checkpoint = torch.load(file_path, map_location="cpu")
    if torch_dtype is None:
        return checkpoint
    for key in checkpoint:
        value = checkpoint[key]
        if isinstance(value, torch.Tensor):
            checkpoint[key] = value.to(torch_dtype)
    return checkpoint
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def search_for_embeddings(state_dict):
    """Collect every tensor in a (possibly nested) dict, depth-first in insertion order.

    Tensor values are collected directly; dict values are searched recursively;
    anything else is ignored.
    """
    found = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            found.append(value)
        elif isinstance(value, dict):
            found.extend(search_for_embeddings(value))
    return found
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def search_parameter(param, state_dict):
    """Return the name of the first tensor in ``state_dict`` whose values match ``param``.

    A candidate must have the same number of elements. When shapes agree the tensors
    are compared directly; otherwise both are flattened first. A match means
    ``torch.dist`` (Euclidean distance) below ``1e-6``.

    Returns:
        The matching key, or None if no candidate matches.
    """
    for candidate_name, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            lhs, rhs = param, candidate
        else:
            lhs, rhs = param.flatten(), candidate.flatten()
        if torch.dist(lhs, rhs) < 1e-6:
            return candidate_name
    return None
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source->target key mapping by matching parameter values (porting helper).

    Each source tensor is looked up in ``target_state_dict`` with ``search_parameter``.
    A direct match prints ``"src": "dst",``. If there is no direct match and
    ``split_qkv`` is set, a tensor whose first dimension is divisible by 3 is cut into
    three equal chunks along dim 0, and each chunk is matched on its own. When all
    three chunks match, ``"src": [a, b, c],`` is printed. Finally every target key that
    was never matched is reported with its shape.

    Nothing is returned; the printed lines are meant to be pasted into a converter.
    """
    matched_keys = set()
    with torch.no_grad():
        for source_name, source_param in source_state_dict.items():
            target_name = search_parameter(source_param, target_state_dict)
            if target_name is not None:
                print(f'"{source_name}": "{target_name}",')
                matched_keys.add(target_name)
                continue
            if not split_qkv or len(source_param.shape) < 1 or source_param.shape[0] % 3 != 0:
                continue
            chunk = source_param.shape[0] // 3
            chunk_names = [
                search_parameter(source_param[i * chunk: (i + 1) * chunk], target_state_dict)
                for i in range(3)
            ]
            if None not in chunk_names:
                print(f'"{source_name}": {chunk_names},')
                matched_keys.update(chunk_names)
    for target_name in target_state_dict:
        if target_name not in matched_keys:
            print("Cannot find", target_name, target_state_dict[target_name].shape)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def search_for_files(folder, extensions):
    """Recursively collect paths under ``folder`` that end with any of ``extensions``.

    ``folder`` may also be a single file; it is returned in a one-element list when it
    matches. Directory entries are visited in sorted order, so the result is
    deterministic. A path that is neither a file nor a directory yields ``[]``.
    """
    if os.path.isdir(folder):
        found = []
        for entry in sorted(os.listdir(folder)):
            found.extend(search_for_files(os.path.join(folder, entry), extensions))
        return found
    if os.path.isfile(folder) and any(folder.endswith(ext) for ext in extensions):
        return [folder]
    return []
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialize the string keys of a (possibly nested) state dict into one canonical string.

    Only string keys are considered, and their values are handled as follows:

    * A tensor contributes ``key``. With ``with_shape`` it additionally contributes
      ``key:d0_d1_...`` (both entries are kept, and the stored model hashes depend
      on that).
    * A dict contributes ``key|<nested serialization>``.
    * Any other value is ignored.

    The parts are sorted and joined with commas, so the result does not depend on
    dict ordering.
    """
    parts = []
    for key, value in state_dict.items():
        if not isinstance(key, str):
            continue
        if isinstance(value, Tensor):
            if with_shape:
                parts.append(key + ":" + "_".join(str(dim) for dim in value.shape))
            parts.append(key)
        elif isinstance(value, dict):
            nested = convert_state_dict_keys_to_single_str(value, with_shape=with_shape)
            parts.append(key + "|" + nested)
    return ",".join(sorted(parts))
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def split_state_dict_with_prefix(state_dict):
    """Group string keys by the part before their first ``.`` and split the dict accordingly.

    Non-string keys are dropped. Keys are processed in sorted order, so both the
    returned list and each sub-dict follow that order.

    Returns:
        A list of sub-dicts, one per prefix.
    """
    groups = {}
    for key in sorted(k for k in state_dict if isinstance(k, str)):
        groups.setdefault(key.partition(".")[0], []).append(key)
    return [{key: state_dict[key] for key in group} for group in groups.values()]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def hash_state_dict_keys(state_dict, with_shape=True):
    """Return the MD5 hex digest of the canonical key string of ``state_dict``.

    The key string comes from ``convert_state_dict_keys_to_single_str`` (optionally
    including tensor shapes) and is UTF-8 encoded before hashing.
    """
    serialized = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    digest = hashlib.md5(serialized.encode("UTF-8"))
    return digest.hexdigest()
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Instantiate each ``model_class`` from a shared checkpoint ``state_dict``.

    For every ``(model_name, model_class)`` pair, the class's ``state_dict_converter()``
    turns the raw checkpoint into the model's own key layout. The converter's
    ``from_civitai`` or ``from_diffusers`` method is chosen by ``model_resource``. A
    converter may return either a state dict or ``(state_dict, extra_kwargs)``. The
    extra kwargs are passed to the model constructor. If they contain
    ``upcast_to_float32=True``, that one model is created in float32 instead of
    ``torch_dtype``.

    Args:
        state_dict: The raw checkpoint.
        model_names: Names to register the loaded models under (parallel to ``model_classes``).
        model_classes: Model classes to instantiate.
        model_resource: ``"civitai"`` or ``"diffusers"``.
        torch_dtype: Default dtype for the models.
        device: Device the models are moved to.

    Returns:
        ``(loaded_model_names, loaded_models)``.

    Raises:
        ValueError: If ``model_resource`` is not a supported resource.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            raise ValueError(
                f"Unsupported model_resource: {model_resource!r} (expected 'civitai' or 'diffusers')."
            )
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Use a per-model dtype: rebinding ``torch_dtype`` here would leak the float32
        # upcast to every subsequent model loaded from the same file.
        model_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        model = model_class(**extra_kwargs).to(dtype=model_dtype, device=device)
        model.load_state_dict(model_state_dict)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    """Load models from a Hugging Face-style folder via each class's ``from_pretrained``.

    Each model is put in eval mode, additionally cast with ``.half()`` when
    ``torch_dtype`` is ``torch.float16`` (and the model supports it), and then moved
    to ``device``.

    Returns:
        ``(loaded_model_names, loaded_models)``.
    """
    names_out, models_out = [], []
    for name, cls in zip(model_names, model_classes):
        instance = cls.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        # Explicit fp16 cast as well -- presumably in case from_pretrained does not
        # fully honor torch_dtype for every class. TODO confirm.
        if torch_dtype == torch.float16 and hasattr(instance, "half"):
            instance = instance.half()
        names_out.append(name)
        models_out.append(instance.to(device=device))
    return names_out, models_out
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    """Build a patched model that inherits ``base_model``'s weights plus the patch weights.

    A fresh ``model_class(**extra_kwargs)`` is created. It is first loaded
    non-strictly with the base model's state dict, then non-strictly with the patch
    ``state_dict``, so patch entries override base entries with the same key. The
    base model is moved to the CPU before the patched model is built.

    Returns:
        The patched model, cast to ``torch_dtype`` and moved to ``device``.
    """
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    inherited_weights = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    patched = model_class(**extra_kwargs)
    for weights in (inherited_weights, state_dict):
        patched.load_state_dict(weights, strict=False)
    patched.to(dtype=torch_dtype, device=device)
    return patched
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    """Replace every already-loaded model named ``model_name`` with a patched version.

    For each ``(model_name, model_class)``, the registry of ``model_manager`` is
    scanned repeatedly. The registry is the parallel lists ``model``, ``model_path``
    and ``model_name``. On each scan the first entry with that name is patched via
    ``load_single_patch_model_from_single_file`` and then removed from all three
    lists. Scanning stops when no entry with that name remains.

    Returns:
        ``(loaded_model_names, loaded_models)`` for the patched models. The caller is
        responsible for registering them again.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            match_id = next(
                (i for i in range(len(model_manager.model)) if model_manager.model_name[i] == model_name),
                None,
            )
            if match_id is None:
                break
            base_model_name = model_manager.model_name[match_id]
            base_model_path = model_manager.model_path[match_id]
            base_model = model_manager.model[match_id]
            print(f" Adding patch model to {base_model_name} ({base_model_path})")
            patched_model = load_single_patch_model_from_single_file(
                state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
            loaded_model_names.append(base_model_name)
            loaded_models.append(patched_model)
            # Drop the base entry from all three parallel lists so they stay aligned.
            model_manager.model.pop(match_id)
            model_manager.model_path.pop(match_id)
            model_manager.model_name.pop(match_id)
    return loaded_model_names, loaded_models
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class ModelDetectorTemplate:
    """Interface shared by model detectors: ``match`` a path/state dict, then ``load`` it.

    This template matches nothing and loads nothing.
    """

    def __init__(self):
        """No state to initialize."""

    def match(self, file_path="", state_dict={}):
        """Return whether this detector can load the given file/state dict (never, here)."""
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Return ``(model_names, models)``; the template loads nothing."""
        return ([], [])
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class ModelDetectorFromSingleFile:
    """Detect and load models from a single checkpoint file by hashing its key set.

    Each registered config maps a hash of the checkpoint's keys to
    ``(model_names, model_classes, model_resource)``. There are two hash tables:

    * one hash includes parameter shapes and is used for strict matching;
    * one is keys-only and is optional per config.

    The two tables are kept in separate dicts.
    """

    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one config; a ``keys_hash`` of None enables only shape-exact matching."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def _find_metadata(self, state_dict):
        """Return the registered ``(model_names, model_classes, model_resource)`` or None.

        The shape-aware hash is tried first (strict match). The keys-only hash is the
        fallback for checkpoints whose parameter shapes differ; the state_dict_converter
        then adapts the model architecture.
        """
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        return self.keys_hash_dict.get(keys_hash)

    def match(self, file_path="", state_dict={}):
        """Return True if the checkpoint's key hash is registered. Directories never match.

        If ``state_dict`` is empty, the checkpoint is read from ``file_path``.
        """
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        return self._find_metadata(state_dict) is not None

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models registered for this checkpoint.

        Returns:
            ``(loaded_model_names, loaded_models)``. Both lists are empty when no
            registered config matches.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        metadata = self._find_metadata(state_dict)
        if metadata is None:
            # No registered config matches: load nothing rather than fail.
            return [], []
        model_names, model_classes, model_resource = metadata
        return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    """Detect models packed together in one checkpoint under distinct top-level prefixes.

    The checkpoint is split by the first ``.``-separated component of each key, and
    each part is matched against the single-file configs separately.
    """

    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)

    def match(self, file_path="", state_dict={}):
        """Return True if any prefix group of the checkpoint matches a registered config."""
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models found in the matching prefix groups.

        If the union of all matching groups is itself a registered checkpoint, it is
        loaded in one go. Otherwise each matching group is loaded individually.

        Returns:
            ``(loaded_model_names, loaded_models)``.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        # Plain loop rather than a comprehension: zero-argument super() does not work
        # inside a comprehension's own scope on older Python versions.
        matched_sub_state_dicts = []
        for sub_state_dict in split_state_dict_with_prefix(state_dict):
            if super().match(file_path, sub_state_dict):
                matched_sub_state_dicts.append(sub_state_dict)
        if len(matched_sub_state_dicts) == 0:
            # Nothing matched; avoid calling match() with an empty dict, which would
            # re-read the file from disk.
            return [], []
        valid_state_dict = {}
        for sub_state_dict in matched_sub_state_dicts:
            valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            return super().load(file_path, valid_state_dict, device, torch_dtype)
        loaded_model_names, loaded_models = [], []
        for sub_state_dict in matched_sub_state_dicts:
            # Load each matching component on its own; the merged dict did not match.
            loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class ModelDetectorFromHuggingfaceFolder:
    """Detect and load models stored as Hugging Face-style folders with a ``config.json``.

    ``architecture_dict`` maps each architecture name that may appear in the config's
    ``architectures`` list to ``(huggingface_lib, model_name)``. ``huggingface_lib`` is
    the module the class is imported from, and ``model_name`` is the name the loaded
    model is registered under.
    """

    def __init__(self, model_loader_configs=[]):
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name):
        self.architecture_dict[architecture] = (huggingface_lib, model_name)

    def match(self, file_path="", state_dict={}):
        """Return True if ``file_path`` is a folder whose config lists only registered architectures."""
        if not os.path.isdir(file_path):
            return False
        config_path = os.path.join(file_path, "config.json")
        if not os.path.isfile(config_path):
            return False
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        if not isinstance(config, dict):
            return False
        architectures = config.get("architectures")
        if not isinstance(architectures, list):
            return False
        # Reject unregistered architectures here so that load() cannot fail on them.
        return all(architecture in self.architecture_dict for architecture in architectures)

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load one model per architecture listed in the folder's ``config.json``.

        Returns:
            ``(loaded_model_names, loaded_models)``.
        """
        with open(os.path.join(file_path, "config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        for architecture in config["architectures"]:
            huggingface_lib, model_name = self.architecture_dict[architecture]
            # getattr goes through the full attribute protocol, including a __getattr__
            # fallback defined on the module's type, which an explicit
            # __getattribute__ call skips.
            model_class = getattr(importlib.import_module(huggingface_lib), architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(
                file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class ModelDetectorFromPatchedSingleFile:
    """Detect checkpoints that patch models already held by a ``ModelManager``.

    Configs map a shape-aware key hash to ``(model_names, model_classes, extra_kwargs)``.
    Only exact (shape-aware) hash matches are recognized.
    """

    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        for entry in model_loader_configs:
            self.add_model_metadata(*entry)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict={}):
        """Return True if the checkpoint's shape-aware key hash is registered. Directories never match."""
        if os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        return hash_state_dict_keys(state_dict, with_shape=True) in self.keys_hash_with_shape_dict

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Patch the matching models inside ``model_manager``.

        Returns:
            ``(loaded_model_names, loaded_models)``. Both lists are empty when the
            checkpoint is not registered.
        """
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        metadata = self.keys_hash_with_shape_dict.get(hash_state_dict_keys(state_dict, with_shape=True))
        if metadata is None:
            return [], []
        model_names, model_classes, extra_kwargs = metadata
        return load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
class ModelManager:
    """Detect, load and hand out models from checkpoint files and folders.

    Loaded models are tracked in three parallel, index-aligned lists:
    ``self.model``, ``self.model_path`` and ``self.model_name``.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = [],
        downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
        file_path_list: List[str] = [],
    ):
        """Create the manager and immediately load the requested models.

        Args:
            torch_dtype: Default dtype for loaded models.
            device: Device loaded models are placed on.
            model_id_list: Preset model ids to download with ``download_models`` first.
            downloading_priority: Download sources, passed to ``download_models``.
            file_path_list: Additional local files/folders to load.
        """
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        # Detectors are tried in this order; the first one whose match() succeeds loads the path.
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)

    def _register_models(self, file_path, model_names, models):
        """Append each loaded model with its source path and name to the parallel registry lists."""
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)

    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        """Load explicitly specified model classes from one checkpoint and register them."""
        print(f"Loading models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following models are loaded: {model_names}.")

    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        """Load explicitly specified model classes from a Hugging Face-style folder and register them."""
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following models are loaded: {model_names}.")

    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        """Patch already-registered models with the weights of one checkpoint.

        The base models are removed from the registry and the patched models are
        registered under ``file_path``. As with the other loaders, the checkpoint is
        read from ``file_path`` when ``state_dict`` is empty.
        """
        print(f"Loading patch models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following patched models are loaded: {model_names}.")

    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        """Merge a LoRA checkpoint into every registered model that one of the LoRA loaders accepts.

        For each model, the first loader whose ``match`` returns a result is used.
        """
        print(f"Loading LoRA models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f" Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    break

    def load_model(self, file_path, model_names=None):
        """Auto-detect the model type of ``file_path`` (file or folder) and register what it contains.

        Raises:
            FileNotFoundError: If ``file_path`` is neither an existing file nor a directory.
        """
        print(f"Loading models from: {file_path}")
        if os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        elif os.path.isdir(file_path):
            # The single-file detectors reject directories before touching state_dict.
            state_dict = None
        else:
            raise FileNotFoundError(f"No such file or directory: {file_path}")
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=self.device, torch_dtype=self.torch_dtype,
                    allowed_model_names=model_names, model_manager=self
                )
                self._register_models(file_path, model_names, models)
                print(f" The following models are loaded: {model_names}.")
                break
        else:
            print(" We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None):
        """Call ``load_model`` for each path in order."""
        for file_path in file_path_list:
            self.load_model(file_path, model_names)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        """Return the first registered model named ``model_name`` (optionally from ``file_path``).

        Returns:
            The model, or ``(model, model_path)`` when ``require_model_path`` is True.
            Returns None when no such model is registered.
        """
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]

    def to(self, device):
        """Move every registered model to ``device`` (``self.device`` itself is not changed)."""
        for model in self.model:
            model.to(device)
|
|
536
|
+
|