diffsynth 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. diffsynth/__init__.py +6 -0
  2. diffsynth/configs/__init__.py +0 -0
  3. diffsynth/configs/model_config.py +243 -0
  4. diffsynth/controlnets/__init__.py +2 -0
  5. diffsynth/controlnets/controlnet_unit.py +53 -0
  6. diffsynth/controlnets/processors.py +51 -0
  7. diffsynth/data/__init__.py +1 -0
  8. diffsynth/data/simple_text_image.py +35 -0
  9. diffsynth/data/video.py +148 -0
  10. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  11. diffsynth/extensions/FastBlend/__init__.py +63 -0
  12. diffsynth/extensions/FastBlend/api.py +397 -0
  13. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  14. diffsynth/extensions/FastBlend/data.py +146 -0
  15. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  16. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  17. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  18. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  19. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  20. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  21. diffsynth/extensions/RIFE/__init__.py +242 -0
  22. diffsynth/extensions/__init__.py +0 -0
  23. diffsynth/models/__init__.py +1 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +66 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +163 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/lora.py +195 -0
  30. diffsynth/models/model_manager.py +536 -0
  31. diffsynth/models/sd3_dit.py +798 -0
  32. diffsynth/models/sd3_text_encoder.py +1107 -0
  33. diffsynth/models/sd3_vae_decoder.py +81 -0
  34. diffsynth/models/sd3_vae_encoder.py +95 -0
  35. diffsynth/models/sd_controlnet.py +588 -0
  36. diffsynth/models/sd_ipadapter.py +57 -0
  37. diffsynth/models/sd_motion.py +199 -0
  38. diffsynth/models/sd_text_encoder.py +321 -0
  39. diffsynth/models/sd_unet.py +1108 -0
  40. diffsynth/models/sd_vae_decoder.py +336 -0
  41. diffsynth/models/sd_vae_encoder.py +282 -0
  42. diffsynth/models/sdxl_ipadapter.py +122 -0
  43. diffsynth/models/sdxl_motion.py +104 -0
  44. diffsynth/models/sdxl_text_encoder.py +759 -0
  45. diffsynth/models/sdxl_unet.py +1899 -0
  46. diffsynth/models/sdxl_vae_decoder.py +24 -0
  47. diffsynth/models/sdxl_vae_encoder.py +24 -0
  48. diffsynth/models/svd_image_encoder.py +505 -0
  49. diffsynth/models/svd_unet.py +2004 -0
  50. diffsynth/models/svd_vae_decoder.py +578 -0
  51. diffsynth/models/svd_vae_encoder.py +139 -0
  52. diffsynth/models/tiler.py +106 -0
  53. diffsynth/pipelines/__init__.py +9 -0
  54. diffsynth/pipelines/base.py +34 -0
  55. diffsynth/pipelines/dancer.py +178 -0
  56. diffsynth/pipelines/hunyuan_image.py +274 -0
  57. diffsynth/pipelines/pipeline_runner.py +105 -0
  58. diffsynth/pipelines/sd3_image.py +132 -0
  59. diffsynth/pipelines/sd_image.py +173 -0
  60. diffsynth/pipelines/sd_video.py +266 -0
  61. diffsynth/pipelines/sdxl_image.py +191 -0
  62. diffsynth/pipelines/sdxl_video.py +223 -0
  63. diffsynth/pipelines/svd_video.py +297 -0
  64. diffsynth/processors/FastBlend.py +142 -0
  65. diffsynth/processors/PILEditor.py +28 -0
  66. diffsynth/processors/RIFE.py +77 -0
  67. diffsynth/processors/__init__.py +0 -0
  68. diffsynth/processors/base.py +6 -0
  69. diffsynth/processors/sequencial_processor.py +41 -0
  70. diffsynth/prompters/__init__.py +6 -0
  71. diffsynth/prompters/base_prompter.py +57 -0
  72. diffsynth/prompters/hunyuan_dit_prompter.py +69 -0
  73. diffsynth/prompters/kolors_prompter.py +353 -0
  74. diffsynth/prompters/prompt_refiners.py +77 -0
  75. diffsynth/prompters/sd3_prompter.py +92 -0
  76. diffsynth/prompters/sd_prompter.py +73 -0
  77. diffsynth/prompters/sdxl_prompter.py +61 -0
  78. diffsynth/schedulers/__init__.py +3 -0
  79. diffsynth/schedulers/continuous_ode.py +59 -0
  80. diffsynth/schedulers/ddim.py +79 -0
  81. diffsynth/schedulers/flow_match.py +51 -0
  82. diffsynth/tokenizer_configs/__init__.py +0 -0
  83. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/special_tokens_map.json +7 -0
  84. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/tokenizer_config.json +16 -0
  85. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab.txt +47020 -0
  86. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer/vocab_org.txt +21128 -0
  87. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/config.json +28 -0
  88. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json +1 -0
  89. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/spiece.model +0 -0
  90. diffsynth/tokenizer_configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json +1 -0
  91. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model +0 -0
  92. diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json +12 -0
  93. diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt +0 -0
  94. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/merges.txt +48895 -0
  95. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  96. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  97. diffsynth/tokenizer_configs/stable_diffusion/tokenizer/vocab.json +49410 -0
  98. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/merges.txt +48895 -0
  99. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/special_tokens_map.json +30 -0
  100. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/tokenizer_config.json +30 -0
  101. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1/vocab.json +49410 -0
  102. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/merges.txt +48895 -0
  103. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/special_tokens_map.json +30 -0
  104. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/tokenizer_config.json +38 -0
  105. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_2/vocab.json +49410 -0
  106. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/special_tokens_map.json +125 -0
  107. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/spiece.model +0 -0
  108. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer.json +129428 -0
  109. diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_3/tokenizer_config.json +940 -0
  110. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/merges.txt +40213 -0
  111. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  112. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  113. diffsynth/tokenizer_configs/stable_diffusion_xl/tokenizer_2/vocab.json +49411 -0
  114. diffsynth/trainers/__init__.py +0 -0
  115. diffsynth/trainers/text_to_image.py +253 -0
  116. diffsynth-1.0.0.dist-info/LICENSE +201 -0
  117. diffsynth-1.0.0.dist-info/METADATA +23 -0
  118. diffsynth-1.0.0.dist-info/RECORD +120 -0
  119. diffsynth-1.0.0.dist-info/WHEEL +5 -0
  120. diffsynth-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,536 @@
1
+ import os, torch, hashlib, json, importlib
2
+ from safetensors import safe_open
3
+ from torch import Tensor
4
+ from typing_extensions import Literal, TypeAlias
5
+ from typing import List
6
+
7
+ from .downloader import download_models, Preset_model_id, Preset_model_website
8
+
9
+ from .sd_text_encoder import SDTextEncoder
10
+ from .sd_unet import SDUNet
11
+ from .sd_vae_encoder import SDVAEEncoder
12
+ from .sd_vae_decoder import SDVAEDecoder
13
+ from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
14
+
15
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
16
+ from .sdxl_unet import SDXLUNet
17
+ from .sdxl_vae_decoder import SDXLVAEDecoder
18
+ from .sdxl_vae_encoder import SDXLVAEEncoder
19
+
20
+ from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
21
+ from .sd3_dit import SD3DiT
22
+ from .sd3_vae_decoder import SD3VAEDecoder
23
+ from .sd3_vae_encoder import SD3VAEEncoder
24
+
25
+ from .sd_controlnet import SDControlNet
26
+
27
+ from .sd_motion import SDMotionModel
28
+ from .sdxl_motion import SDXLMotionModel
29
+
30
+ from .svd_image_encoder import SVDImageEncoder
31
+ from .svd_unet import SVDUNet
32
+ from .svd_vae_decoder import SVDVAEDecoder
33
+ from .svd_vae_encoder import SVDVAEEncoder
34
+
35
+ from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
36
+ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
37
+
38
+ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
39
+ from .hunyuan_dit import HunyuanDiT
40
+
41
+ from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
42
+
43
+
44
+
45
def load_state_dict(file_path, torch_dtype=None):
    """Load a checkpoint file into a state dict, choosing the reader by extension.

    Paths ending in ``.safetensors`` go through the safetensors reader; every
    other path is handed to the ``torch.load``-based reader.

    Args:
        file_path: Path of the checkpoint file.
        torch_dtype: Optional dtype that loaded tensors are cast to.

    Returns:
        Whatever the selected reader returns (a dict of tensors for safetensors).
    """
    if file_path.endswith(".safetensors"):
        reader = load_state_dict_from_safetensors
    else:
        reader = load_state_dict_from_bin
    return reader(file_path, torch_dtype=torch_dtype)
50
+
51
+
52
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    """Read every tensor stored in a ``.safetensors`` file onto the CPU.

    Args:
        file_path: Path of the safetensors file.
        torch_dtype: Optional dtype each tensor is cast to after reading.

    Returns:
        Dict mapping tensor names (in the file's key order) to tensors.
    """
    tensors = {}
    with safe_open(file_path, framework="pt", device="cpu") as reader:
        for tensor_name in reader.keys():
            tensor = reader.get_tensor(tensor_name)
            tensors[tensor_name] = tensor if torch_dtype is None else tensor.to(torch_dtype)
    return tensors
60
+
61
+
62
def load_state_dict_from_bin(file_path, torch_dtype=None):
    """Load a ``torch.load``-compatible checkpoint onto the CPU.

    Args:
        file_path: Path of the pickled checkpoint.
        torch_dtype: Optional dtype for casting. Only top-level tensor values
            are cast; non-tensor values (including nested dicts) are returned
            unchanged.

    Returns:
        The object produced by ``torch.load`` (with top-level tensors cast).
    """
    # SECURITY NOTE(review): torch.load unpickles the file, which can run
    # arbitrary code if the checkpoint comes from an untrusted source. Consider
    # passing weights_only=True once all supported checkpoints are compatible.
    checkpoint = torch.load(file_path, map_location="cpu")
    if torch_dtype is None:
        return checkpoint
    for key, value in checkpoint.items():
        if isinstance(value, torch.Tensor):
            checkpoint[key] = value.to(torch_dtype)
    return checkpoint
69
+
70
+
71
def search_for_embeddings(state_dict):
    """Collect every tensor in ``state_dict``, descending into nested dicts.

    Values that are neither tensors nor dicts are ignored. Tensors are returned
    depth-first in the dicts' iteration order.
    """
    found = []
    for value in state_dict.values():
        if isinstance(value, torch.Tensor):
            found.append(value)
        elif isinstance(value, dict):
            found.extend(search_for_embeddings(value))
    return found
79
+
80
+
81
def search_parameter(param, state_dict):
    """Return the name of the first tensor in ``state_dict`` equal to ``param``.

    Two tensors match when they have the same number of elements and their
    Euclidean distance (``torch.dist``) is below 1e-6. If the shapes differ,
    the comparison is done on the flattened tensors.

    Returns:
        The matching key, or None if nothing matches.
    """
    for candidate_name, candidate in state_dict.items():
        if param.numel() != candidate.numel():
            continue
        if param.shape == candidate.shape:
            lhs, rhs = param, candidate
        else:
            lhs, rhs = param.flatten(), candidate.flatten()
        if torch.dist(lhs, rhs) < 1e-6:
            return candidate_name
    return None
91
+
92
+
93
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    """Print a source-name -> target-name mapping by matching tensor values.

    This looks like a development helper for writing key-rename tables: each
    match is printed as a dict-literal line. With ``split_qkv=True``, a source
    tensor that has no direct match and whose first dimension is divisible by 3
    is cut into three equal chunks. It is reported when all three chunks match
    target tensors. Every target key that was never matched is printed
    afterwards.
    """
    matched_targets = set()
    with torch.no_grad():
        for source_name, source_param in source_state_dict.items():
            target_name = search_parameter(source_param, target_state_dict)
            if target_name is not None:
                print(f'"{source_name}": "{target_name}",')
                matched_targets.add(target_name)
                continue
            splittable = split_qkv and len(source_param.shape) >= 1 and source_param.shape[0] % 3 == 0
            if not splittable:
                continue
            chunk = source_param.shape[0] // 3
            pieces = [
                search_parameter(source_param[k * chunk: k * chunk + chunk], target_state_dict)
                for k in range(3)
            ]
            if None in pieces:
                continue
            print(f'"{source_name}": {pieces},')
            matched_targets.update(pieces)
    for target_name in target_state_dict:
        if target_name not in matched_targets:
            print("Cannot find", target_name, target_state_dict[target_name].shape)
113
+
114
+
115
def search_for_files(folder, extensions):
    """Recursively list files under ``folder`` whose names end with one of ``extensions``.

    Directory entries are visited in sorted order. ``folder`` may also be a
    single file, in which case it is returned if its name matches. Paths that
    are neither files nor directories yield an empty list.
    """
    if os.path.isdir(folder):
        matches = []
        for entry in sorted(os.listdir(folder)):
            matches.extend(search_for_files(os.path.join(folder, entry), extensions))
        return matches
    if os.path.isfile(folder) and any(folder.endswith(extension) for extension in extensions):
        return [folder]
    return []
126
+
127
+
128
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """Serialize the string keys of a (possibly nested) state dict into one canonical string.

    For each tensor value, the bare key is emitted. When ``with_shape`` is True,
    ``"key:d0_d1_..."`` is emitted as well. A nested dict contributes
    ``"key|<nested string>"``. Non-string keys and other value types are skipped.
    Entries are sorted and joined with commas, so the result does not depend on
    insertion order.
    """
    parts = []
    for name, item in state_dict.items():
        if not isinstance(name, str):
            continue
        if isinstance(item, Tensor):
            if with_shape:
                parts.append(name + ":" + "_".join(str(dim) for dim in item.shape))
            parts.append(name)
        elif isinstance(item, dict):
            parts.append(name + "|" + convert_state_dict_keys_to_single_str(item, with_shape=with_shape))
    return ",".join(sorted(parts))
142
+
143
+
144
def split_state_dict_with_prefix(state_dict):
    """Group the string keys of ``state_dict`` by their first dotted component.

    Keys are processed in sorted order, so both the groups and the keys inside
    each group come out sorted. Non-string keys are dropped.

    Returns:
        A list of sub-dicts, one per prefix.
    """
    groups = {}
    for name in sorted(k for k in state_dict if isinstance(k, str)):
        groups.setdefault(name.split(".")[0], {})[name] = state_dict[name]
    return list(groups.values())
157
+
158
+
159
def hash_state_dict_keys(state_dict, with_shape=True):
    """Return the MD5 hex digest of the canonical key string of ``state_dict``.

    The detectors below compare this digest against registered hashes to
    recognise checkpoint layouts. ``with_shape`` controls whether tensor shapes
    take part in the fingerprint.
    """
    fingerprint_source = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(fingerprint_source.encode("UTF-8")).hexdigest()
163
+
164
+
165
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
    """Build one model per (name, class) pair from a single shared state dict.

    Args:
        state_dict: Raw checkpoint dict shared by every model in the file.
        model_names: Names to report the loaded models under (parallel to model_classes).
        model_classes: Model classes. Each must provide ``state_dict_converter()``,
            whose result offers ``from_civitai`` / ``from_diffusers``. The converter
            may return either a state dict or ``(state_dict, extra_kwargs)``; the
            extra kwargs are passed to the model constructor.
        model_resource: ``"civitai"`` or ``"diffusers"``; selects the converter method.
        torch_dtype: Target dtype. A model whose extra kwargs contain a truthy
            ``upcast_to_float32`` is built in ``torch.float32`` instead.
        device: Device the models are moved to.

    Returns:
        ``(loaded_model_names, loaded_models)``.

    Raises:
        ValueError: If ``model_resource`` is not ``"civitai"`` or ``"diffusers"``.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f" model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        else:
            raise ValueError(
                f"Unsupported model_resource {model_resource!r}; expected 'civitai' or 'diffusers'."
            )
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f" This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        # Choose the dtype per model so that upcasting one model does not
        # change the dtype of the models loaded after it.
        model_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        model = model_class(**extra_kwargs).to(dtype=model_dtype, device=device)
        model.load_state_dict(model_state_dict)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models
185
+
186
+
187
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    """Load models from a folder through each class's ``from_pretrained``.

    Every model is put in eval mode and moved to ``device``. When ``torch_dtype``
    is float16 and the model has a ``half`` method, ``half()`` is also applied,
    presumably for classes that do not honour the ``torch_dtype`` argument.

    Returns:
        ``(loaded_model_names, loaded_models)``.
    """
    names_out, models_out = [], []
    for name, cls in zip(model_names, model_classes):
        instance = cls.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        if torch_dtype == torch.float16 and hasattr(instance, "half"):
            instance = instance.half()
        instance = instance.to(device=device)
        names_out.append(name)
        models_out.append(instance)
    return names_out, models_out
197
+
198
+
199
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    """Build a ``model_class`` instance initialised from ``base_model`` and then ``state_dict``.

    The base model's weights are loaded first and the patch weights second, both
    non-strictly, so the patch overrides any overlapping keys. The base model is
    moved to the CPU before the new model is constructed.

    Returns:
        The patched model, cast to ``torch_dtype`` and moved to ``device``.
    """
    print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    inherited_weights = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    patched = model_class(**extra_kwargs)
    for weights in (inherited_weights, state_dict):
        patched.load_state_dict(weights, strict=False)
    patched.to(dtype=torch_dtype, device=device)
    return patched
209
+
210
+
211
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    """Patch every registered model whose name matches one of ``model_names``.

    For each name, all base models with that name are taken out of
    ``model_manager``'s parallel ``model`` / ``model_path`` / ``model_name``
    lists one at a time. Each is rebuilt through
    ``load_single_patch_model_from_single_file``. The patched models are
    returned, not registered; the caller in ``ModelManager`` registers them.

    Returns:
        ``(loaded_model_names, loaded_models)``.
    """
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            match_id = next(
                (idx for idx in range(len(model_manager.model)) if model_manager.model_name[idx] == model_name),
                None,
            )
            if match_id is None:
                break
            base_model_name = model_manager.model_name[match_id]
            base_model_path = model_manager.model_path[match_id]
            base_model = model_manager.model[match_id]
            print(f" Adding patch model to {base_model_name} ({base_model_path})")
            patched_model = load_single_patch_model_from_single_file(
                state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
            loaded_model_names.append(base_model_name)
            loaded_models.append(patched_model)
            # Drop the base entry so the next search cannot find it again.
            for registry in (model_manager.model, model_manager.model_path, model_manager.model_name):
                registry.pop(match_id)
    return loaded_model_names, loaded_models
232
+
233
+
234
+
235
class ModelDetectorTemplate:
    """No-op detector showing the ``match`` / ``load`` interface that detectors expose."""

    def __init__(self):
        """Nothing to initialise."""

    def match(self, file_path="", state_dict={}):
        """Report whether this detector can load the input. The template never can."""
        can_load = False
        return can_load

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Return ``(model_names, models)``. The template loads nothing."""
        no_names, no_models = [], []
        return no_names, no_models
244
+
245
+
246
+
247
class ModelDetectorFromSingleFile:
    """Detect and load models stored in a single checkpoint file by hashing its keys.

    Each registered entry maps a key hash to ``(model_names, model_classes,
    model_resource)``. One hash includes tensor shapes and one ignores them.
    ``model_resource`` selects the converter method (``from_civitai`` /
    ``from_diffusers``).
    """

    def __init__(self, model_loader_configs=None):
        """
        Args:
            model_loader_configs: Iterable of tuples unpacked into ``add_model_metadata``:
                ``(keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource)``.
        """
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        if model_loader_configs is None:
            model_loader_configs = []
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        """Register one checkpoint layout. If ``keys_hash`` is None, only exact-shape matching applies."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def _find_metadata(self, state_dict):
        """Return the registered ``(model_names, model_classes, model_resource)`` for ``state_dict``, or None.

        Strict matching (keys and shapes) is tried first. The shape-agnostic hash is
        the fallback for checkpoints whose shapes differ from the registered ones;
        there, the state_dict_converter adapts the model architecture.
        """
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return self.keys_hash_with_shape_dict[keys_hash_with_shape]
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        return self.keys_hash_dict.get(keys_hash)

    def match(self, file_path="", state_dict=None):
        """Return True if the file (or the given state dict) has a registered layout.

        Directories never match. When ``state_dict`` is empty or None, it is read
        from ``file_path``.
        """
        if os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        return self._find_metadata(state_dict) is not None

    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models registered for this checkpoint layout.

        Returns:
            ``(loaded_model_names, loaded_models)``; both lists are empty when no
            registered layout matches.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)
        metadata = self._find_metadata(state_dict)
        if metadata is None:
            return [], []
        model_names, model_classes, model_resource = metadata
        return load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
295
+
296
+
297
+
298
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    """Detect checkpoints that bundle several models in one file.

    The state dict is split by the first dotted component of each key (via
    ``split_state_dict_with_prefix``), and every group is matched against the
    registered key hashes independently.
    """

    def __init__(self, model_loader_configs=None):
        super().__init__([] if model_loader_configs is None else model_loader_configs)

    def _matching_sub_state_dicts(self, file_path, state_dict):
        """Return the prefix groups of ``state_dict`` that the single-file detector recognises."""
        matched = []
        # A plain loop (not a comprehension) keeps zero-argument super() usable on all Python versions.
        for sub_state_dict in split_state_dict_with_prefix(state_dict):
            if super().match(file_path, sub_state_dict):
                matched.append(sub_state_dict)
        return matched

    def match(self, file_path="", state_dict=None):
        """Return True if at least one prefix group of the checkpoint has a registered layout."""
        if os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        for sub_state_dict in split_state_dict_with_prefix(state_dict):
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Load the models found in the recognised prefix groups.

        The recognised groups are first tried together as one state dict. If that
        combination is not itself a registered layout, each group is loaded on
        its own.

        Returns:
            ``(loaded_model_names, loaded_models)``.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)
        matched = self._matching_sub_state_dicts(file_path, state_dict)
        if not matched:
            return [], []
        valid_state_dict = {}
        for sub_state_dict in matched:
            valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            return super().load(file_path, valid_state_dict, device, torch_dtype)
        loaded_model_names, loaded_models = [], []
        for sub_state_dict in matched:
            # Load each recognised group on its own; the merged dict is known not to match.
            loaded_model_names_, loaded_models_ = super().load(file_path, sub_state_dict, device, torch_dtype)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
332
+
333
+
334
+
335
class ModelDetectorFromHuggingfaceFolder:
    """Detect and load models saved as a folder containing a ``config.json``.

    ``config.json["architectures"]`` names the model classes. Each registered
    architecture maps to ``(huggingface_lib, model_name)``: the module to import
    the class from, and the name the model is reported under.
    """

    def __init__(self, model_loader_configs=None):
        """
        Args:
            model_loader_configs: Iterable of ``(architecture, huggingface_lib, model_name)`` tuples.
        """
        self.architecture_dict = {}
        if model_loader_configs is None:
            model_loader_configs = []
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name):
        """Register the module and model name used for ``architecture``."""
        self.architecture_dict[architecture] = (huggingface_lib, model_name)

    def _read_architectures(self, file_path):
        """Return the ``architectures`` entry of ``file_path/config.json``, or None if unavailable."""
        config_path = os.path.join(file_path, "config.json")
        if not os.path.isfile(config_path):
            return None
        with open(config_path, "r") as f:
            config = json.load(f)
        if not isinstance(config, dict):
            return None
        return config.get("architectures")

    def match(self, file_path="", state_dict=None):
        """Return True for a folder whose config lists only registered architectures.

        Requiring every architecture to be registered keeps ``load`` from failing
        with a KeyError on an unknown one.
        """
        if not os.path.isdir(file_path):
            return False
        architectures = self._read_architectures(file_path)
        if not architectures:
            return False
        return all(architecture in self.architecture_dict for architecture in architectures)

    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, **kwargs):
        """Import and instantiate every architecture listed in the folder's config.

        Returns:
            ``(loaded_model_names, loaded_models)``.
        """
        loaded_model_names, loaded_models = [], []
        for architecture in self._read_architectures(file_path) or []:
            huggingface_lib, model_name = self.architecture_dict[architecture]
            # getattr (unlike calling __getattribute__ directly) also honours a
            # lazily-populated module's __getattr__ hook.
            model_class = getattr(importlib.import_module(huggingface_lib), architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(
                file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models
370
+
371
+
372
+
373
class ModelDetectorFromPatchedSingleFile:
    """Detect patch checkpoints that modify models already loaded in a ModelManager.

    Only strict matching (keys and shapes) is supported. Each registered hash maps
    to ``(model_name, model_class, extra_kwargs)``. Despite the singular names,
    ``model_name`` and ``model_class`` are zipped together as sequences when
    loading.
    """

    def __init__(self, model_loader_configs=None):
        """
        Args:
            model_loader_configs: Iterable of ``(keys_hash_with_shape, model_name, model_class, extra_kwargs)``.
        """
        self.keys_hash_with_shape_dict = {}
        if model_loader_configs is None:
            model_loader_configs = []
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        """Register one patch layout by its shape-aware key hash."""
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict=None):
        """Return True if the file's shape-aware key hash is registered. Directories never match."""
        if os.path.isdir(file_path):
            return False
        if not state_dict:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        return keys_hash_with_shape in self.keys_hash_with_shape_dict

    def load(self, file_path="", state_dict=None, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        """Patch the matching models held by ``model_manager``.

        Returns:
            ``(loaded_model_names, loaded_models)``; both lists are empty when the
            layout is not registered.
        """
        if not state_dict:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        metadata = self.keys_hash_with_shape_dict.get(keys_hash_with_shape)
        if metadata is None:
            return [], []
        model_names, model_classes, extra_kwargs = metadata
        loaded_model_names, loaded_models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
        return list(loaded_model_names), list(loaded_models)
409
+
410
+
411
+
412
class ModelManager:
    """Load, register and hand out models.

    Loaded models are tracked in three parallel lists: ``model``, ``model_path``
    and ``model_name``. Index ``i`` of each list describes the same model. Paths
    are recognised by the detectors in ``model_detector``, which are tried in
    order; the first detector whose ``match`` succeeds loads the models.
    """

    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List[Preset_model_id] = None,
        downloading_priority: List[Preset_model_website] = None,
        file_path_list: List[str] = None,
    ):
        """
        Args:
            torch_dtype: dtype newly loaded models are cast to.
            device: Device newly loaded models are moved to.
            model_id_list: Preset model ids passed to ``download_models`` (default: none).
            downloading_priority: Download sources passed to ``download_models``
                (default: ``["ModelScope", "HuggingFace"]``).
            file_path_list: Local files/folders loaded after the downloaded ones.
        """
        # None defaults avoid sharing one mutable default list across all instances.
        if model_id_list is None:
            model_id_list = []
        if downloading_priority is None:
            downloading_priority = ["ModelScope", "HuggingFace"]
        if file_path_list is None:
            file_path_list = []
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
            ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
        ]
        self.load_models(list(downloaded_files) + list(file_path_list))

    def _register_models(self, file_path, model_names, models):
        """Append each (name, model) pair to the parallel registries, tagged with ``file_path``."""
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)

    def load_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, model_resource=None):
        """Load explicitly named model classes from one checkpoint (no detection)."""
        print(f"Loading models from file: {file_path}")
        if not state_dict:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(
            state_dict, model_names or [], model_classes or [], model_resource, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following models are loaded: {model_names}.")

    def load_model_from_huggingface_folder(self, file_path="", model_names=None, model_classes=None):
        """Load explicitly named model classes from a ``from_pretrained`` folder (no detection)."""
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(
            file_path, model_names or [], model_classes or [], self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following models are loaded: {model_names}.")

    def load_patch_model_from_single_file(self, file_path="", state_dict=None, model_names=None, model_classes=None, extra_kwargs=None):
        """Patch already-registered models with the weights in ``file_path`` / ``state_dict``."""
        print(f"Loading patch models from file: {file_path}")
        # Read the checkpoint when no state dict is given; otherwise models would be patched with nothing.
        if not state_dict:
            state_dict = load_state_dict(file_path)
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names or [], model_classes or [], extra_kwargs if extra_kwargs is not None else {},
            self, self.torch_dtype, self.device)
        self._register_models(file_path, model_names, models)
        print(f" The following patched models are loaded: {model_names}.")

    def load_lora(self, file_path="", state_dict=None, lora_alpha=1.0):
        """Apply a LoRA checkpoint to every registered model that a LoRA loader recognises.

        For each model, the first loader whose ``match`` returns a result is used.
        """
        print(f"Loading LoRA models from file: {file_path}")
        if not state_dict:
            state_dict = load_state_dict(file_path)
        for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
            for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f" Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
                    break

    def load_model(self, file_path, model_names=None):
        """Detect the type of ``file_path`` and register every model it contains.

        Args:
            file_path: A checkpoint file or a model folder.
            model_names: Forwarded to the detector as ``allowed_model_names``. The
                detectors defined here accept it via ``**kwargs`` but do not read it.
        """
        print(f"Loading models from: {file_path}")
        # Non-file paths (e.g. folders) are handed to the detectors without a state dict.
        state_dict = load_state_dict(file_path) if os.path.isfile(file_path) else None
        for model_detector in self.model_detector:
            if not model_detector.match(file_path, state_dict):
                continue
            loaded_names, models = model_detector.load(
                file_path, state_dict,
                device=self.device, torch_dtype=self.torch_dtype,
                allowed_model_names=model_names, model_manager=self,
            )
            self._register_models(file_path, loaded_names, models)
            print(f" The following models are loaded: {loaded_names}.")
            return
        print(" We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None):
        """Call ``load_model`` for each path in order."""
        for file_path in file_path_list:
            self.load_model(file_path, model_names)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        """Return the first registered model called ``model_name``.

        Args:
            model_name: Registered model name to look up.
            file_path: If given, only models loaded from this path are considered.
            require_model_path: If True, return ``(model, model_path)`` instead of the model.

        Returns:
            The model (or ``(model, path)``), or None when nothing matches.
        """
        matches = [
            (model, model_path)
            for model, model_path, name in zip(self.model, self.model_path, self.model_name)
            if name == model_name and (file_path is None or file_path == model_path)
        ]
        if not matches:
            print(f"No {model_name} models available.")
            return None
        fetched_model_paths = [path for _, path in matches]
        if len(matches) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        model, model_path = matches[0]
        if require_model_path:
            return model, model_path
        return model

    def to(self, device):
        """Move every registered model to ``device``."""
        for model in self.model:
            model.to(device)
536
+