diffsynth-engine 0.4.3.dev5__py3-none-any.whl → 0.4.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,10 @@ from .configs import (
4
4
  FluxPipelineConfig,
5
5
  WanPipelineConfig,
6
6
  QwenImagePipelineConfig,
7
+ SDStateDicts,
8
+ SDXLStateDicts,
9
+ FluxStateDicts,
10
+ QwenImageStateDicts,
7
11
  ControlNetParams,
8
12
  ControlType,
9
13
  )
@@ -38,6 +42,10 @@ __all__ = [
38
42
  "SDXLPipelineConfig",
39
43
  "FluxPipelineConfig",
40
44
  "WanPipelineConfig",
45
+ "SDStateDicts",
46
+ "SDXLStateDicts",
47
+ "FluxStateDicts",
48
+ "QwenImageStateDicts",
41
49
  "FluxImagePipeline",
42
50
  "QwenImagePipelineConfig",
43
51
  "FluxControlNet",
@@ -172,31 +172,40 @@ class SDImagePipeline(BasePipeline):
172
172
  else:
173
173
  config = model_path_or_config
174
174
 
175
- logger.info(f"loading state dict from {config.model_path} ...")
176
- unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
177
-
178
- if config.vae_path is not None:
179
- logger.info(f"loading state dict from {config.vae_path} ...")
180
- vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
181
- else:
182
- vae_state_dict = unet_state_dict
183
-
184
- if config.clip_path is not None:
185
- logger.info(f"loading state dict from {config.clip_path} ...")
186
- clip_state_dict = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
187
- else:
188
- clip_state_dict = unet_state_dict
175
+ return cls.from_state_dict(SDStateDicts(), config)
189
176
 
177
+ @classmethod
178
+ def from_state_dict(cls, state_dicts: SDStateDicts, config: SDPipelineConfig) -> "SDImagePipeline":
179
+ if state_dicts.model is None:
180
+ if config.model_path is None:
181
+ raise ValueError("`model_path` cannot be empty")
182
+ logger.info(f"loading state dict from {config.model_path} ...")
183
+ state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
184
+
185
+ if state_dicts.vae is None:
186
+ if config.vae_path is None:
187
+ state_dicts.vae = state_dicts.model
188
+ else:
189
+ logger.info(f"loading state dict from {config.vae_path} ...")
190
+ state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
191
+
192
+ if state_dicts.clip is None:
193
+ if config.clip_path is None:
194
+ state_dicts.clip = state_dicts.model
195
+ else:
196
+ logger.info(f"loading state dict from {config.clip_path} ...")
197
+ state_dicts.clip = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
198
+
190
199
  init_device = "cpu" if config.offload_mode is not None else config.device
191
200
  tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
192
201
  with LoRAContext():
193
- text_encoder = SDTextEncoder.from_state_dict(clip_state_dict, device=init_device, dtype=config.clip_dtype)
194
- unet = SDUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
202
+ text_encoder = SDTextEncoder.from_state_dict(state_dicts.clip, device=init_device, dtype=config.clip_dtype)
203
+ unet = SDUNet.from_state_dict(state_dicts.model, device=init_device, dtype=config.model_dtype)
195
204
  vae_decoder = SDVAEDecoder.from_state_dict(
196
- vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
205
+ state_dicts.vae, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
197
206
  )
198
207
  vae_encoder = SDVAEEncoder.from_state_dict(
199
- vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
208
+ state_dicts.vae, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
200
209
  )
201
210
 
202
211
  pipe = cls(
@@ -213,10 +222,6 @@ class SDImagePipeline(BasePipeline):
213
222
  pipe.enable_cpu_offload(config.offload_mode)
214
223
  return pipe
215
224
 
216
- @classmethod
217
- def from_state_dict(cls, state_dicts: SDStateDicts, pipeline_config: SDPipelineConfig) -> "SDImagePipeline":
218
- raise NotImplementedError()
219
-
220
225
  def denoising_model(self):
221
226
  return self.unet
222
227
 
@@ -150,43 +150,53 @@ class SDXLImagePipeline(BasePipeline):
150
150
  else:
151
151
  config = model_path_or_config
152
152
 
153
- logger.info(f"loading state dict from {config.model_path} ...")
154
- unet_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
155
-
156
- if config.vae_path is not None:
157
- logger.info(f"loading state dict from {config.vae_path} ...")
158
- vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
159
- else:
160
- vae_state_dict = unet_state_dict
161
-
162
- if config.clip_l_path is not None:
163
- logger.info(f"loading state dict from {config.clip_l_path} ...")
164
- clip_l_state_dict = cls.load_model_checkpoint(config.clip_l_path, device="cpu", dtype=config.clip_l_dtype)
165
- else:
166
- clip_l_state_dict = unet_state_dict
167
-
168
- if config.clip_g_path is not None:
169
- logger.info(f"loading state dict from {config.clip_g_path} ...")
170
- clip_g_state_dict = cls.load_model_checkpoint(config.clip_g_path, device="cpu", dtype=config.clip_g_dtype)
171
- else:
172
- clip_g_state_dict = unet_state_dict
153
+ return cls.from_state_dict(SDXLStateDicts(), config)
173
154
 
155
+ @classmethod
156
+ def from_state_dict(cls, state_dicts: SDXLStateDicts, config: SDXLPipelineConfig) -> "SDXLImagePipeline":
157
+ if state_dicts.model is None:
158
+ if config.model_path is None:
159
+ raise ValueError("`model_path` cannot be empty")
160
+ logger.info(f"loading state dict from {config.model_path} ...")
161
+ state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
162
+
163
+ if state_dicts.vae is None:
164
+ if config.vae_path is None:
165
+ state_dicts.vae = state_dicts.model
166
+ else:
167
+ logger.info(f"loading state dict from {config.vae_path} ...")
168
+ state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
169
+
170
+ if state_dicts.clip_l is None:
171
+ if config.clip_l_path is None:
172
+ state_dicts.clip_l = state_dicts.model
173
+ else:
174
+ logger.info(f"loading state dict from {config.clip_l_path} ...")
175
+ state_dicts.clip_l = cls.load_model_checkpoint(config.clip_l_path, device="cpu", dtype=config.clip_l_dtype)
176
+
177
+ if state_dicts.clip_g is None:
178
+ if config.clip_g_path is None:
179
+ state_dicts.clip_g = state_dicts.model
180
+ else:
181
+ logger.info(f"loading state dict from {config.clip_g_path} ...")
182
+ state_dicts.clip_g = cls.load_model_checkpoint(config.clip_g_path, device="cpu", dtype=config.clip_g_dtype)
183
+
174
184
  init_device = "cpu" if config.offload_mode else config.device
175
185
  tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
176
186
  tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
177
187
  with LoRAContext():
178
188
  text_encoder = SDXLTextEncoder.from_state_dict(
179
- clip_l_state_dict, device=init_device, dtype=config.clip_l_dtype
189
+ state_dicts.clip_l, device=init_device, dtype=config.clip_l_dtype
180
190
  )
181
191
  text_encoder_2 = SDXLTextEncoder2.from_state_dict(
182
- clip_g_state_dict, device=init_device, dtype=config.clip_g_dtype
192
+ state_dicts.clip_g, device=init_device, dtype=config.clip_g_dtype
183
193
  )
184
- unet = SDXLUNet.from_state_dict(unet_state_dict, device=init_device, dtype=config.model_dtype)
194
+ unet = SDXLUNet.from_state_dict(state_dicts.model, device=init_device, dtype=config.model_dtype)
185
195
  vae_decoder = SDXLVAEDecoder.from_state_dict(
186
- vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
196
+ state_dicts.vae, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
187
197
  )
188
198
  vae_encoder = SDXLVAEEncoder.from_state_dict(
189
- vae_state_dict, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
199
+ state_dicts.vae, device=init_device, dtype=config.vae_dtype, attn_impl="sdpa"
190
200
  )
191
201
 
192
202
  pipe = cls(
@@ -205,10 +215,6 @@ class SDXLImagePipeline(BasePipeline):
205
215
  pipe.enable_cpu_offload(config.offload_mode)
206
216
  return pipe
207
217
 
208
- @classmethod
209
- def from_state_dict(cls, state_dicts: SDXLStateDicts, pipeline_config: SDXLPipelineConfig) -> "SDXLImagePipeline":
210
- raise NotImplementedError()
211
-
212
218
  def denoising_model(self):
213
219
  return self.unet
214
220
 
@@ -417,7 +417,7 @@ class WanVideoPipeline(BasePipeline):
417
417
  cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0]
418
418
 
419
419
  timestep = timestep * mask[:, :, :, ::2, ::2].flatten() # seq_len
420
- timestep = timestep.to(dtype=self.config.model_dtype, device=self.device)
420
+ timestep = timestep.to(dtype=self.dtype, device=self.device)
421
421
  # Classifier-free guidance
422
422
  noise_pred = self.predict_noise_with_cfg(
423
423
  model=model,
@@ -574,6 +574,18 @@ class WanVideoPipeline(BasePipeline):
574
574
  if config.offload_mode is not None:
575
575
  pipe.enable_cpu_offload(config.offload_mode)
576
576
 
577
+ if config.model_dtype == torch.float8_e4m3fn:
578
+ pipe.dtype = torch.bfloat16 # compute dtype
579
+ pipe.enable_fp8_autocast(
580
+ model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
581
+ )
582
+
583
+ if config.t5_dtype == torch.float8_e4m3fn:
584
+ pipe.dtype = torch.bfloat16 # compute dtype
585
+ pipe.enable_fp8_autocast(
586
+ model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
587
+ )
588
+
577
589
  if config.parallelism > 1:
578
590
  return ParallelWrapper(
579
591
  pipe,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.4.3.dev5
3
+ Version: 0.4.3.dev7
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -1,4 +1,4 @@
1
- diffsynth_engine/__init__.py,sha256=tXaLuKje4NQ3zARAvqBUdj1pGLjP0ttkXKE6ysuzsOc,1586
1
+ diffsynth_engine/__init__.py,sha256=fcY1Z0QWNyrYuGX2dVTj2M8crWhVIL-vnPndfVI7mZs,1760
2
2
  diffsynth_engine/algorithm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  diffsynth_engine/algorithm/noise_scheduler/__init__.py,sha256=YvcwE2tCNua-OAX9GEPm0EXsINNWH4XvJMNZb-uaZMM,745
4
4
  diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py,sha256=WICrLEh7b2TdZMMEN14NqiYydj7dxXT6RolXymKiMk8,188
@@ -132,10 +132,10 @@ diffsynth_engine/pipelines/__init__.py,sha256=9QZVhZeRm_5m7yxie08yBtgM26NB4mfVBO
132
132
  diffsynth_engine/pipelines/base.py,sha256=goe_UO1LvUXVwP5geUmu0zdFUrSms9iss3OuRyuMjXY,13726
133
133
  diffsynth_engine/pipelines/flux_image.py,sha256=gWuZaMeupB_Wz3AY97eE1eEVSAmAm14aXIxkAqNXY7E,49224
134
134
  diffsynth_engine/pipelines/qwen_image.py,sha256=3S-eL2GY-c0g9nqDyYByr9RV-kdY589m75a0k4vw_AQ,18459
135
- diffsynth_engine/pipelines/sd_image.py,sha256=5cIIknh2M-fOqj7urKi9nZ40yc1LnvepbH_Af7SF4UA,17789
136
- diffsynth_engine/pipelines/sdxl_image.py,sha256=otv1T_0fhX3UcIoKbKCqb47Yge6xg0fPM0ry-uPEanI,21548
135
+ diffsynth_engine/pipelines/sd_image.py,sha256=GhrCadEmAWv4id0NdRpJW_EC2PgItBctXLkfPxq5gDI,18100
136
+ diffsynth_engine/pipelines/sdxl_image.py,sha256=kmidIz8zDtrw9ggLXI3WG7AQq_jmOPVct-O3hGNra_g,21951
137
137
  diffsynth_engine/pipelines/utils.py,sha256=lk7sFGEk-fGjgadLpwwppHKG-yZ0RC-4ZmHW7pRRe8A,473
138
- diffsynth_engine/pipelines/wan_video.py,sha256=stoYKm0wHf_pxZ_WHRTGHTR61KVG_U21yBUaUrDjSqw,25605
138
+ diffsynth_engine/pipelines/wan_video.py,sha256=lb0FrMFxQ6BNfOUErveWcnzPJa1gq0yYtMXUZjNTOuU,26126
139
139
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
140
140
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
141
141
  diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -168,8 +168,8 @@ diffsynth_engine/utils/parallel.py,sha256=Z9jqCv4mLV4JyXR3uTHyv1rujPiKU8PSCbAfiN
168
168
  diffsynth_engine/utils/platform.py,sha256=2lXdw6YkqcRONCeT98n4cyg1Ii8Ybbyj2Ns72Se9tlk,496
169
169
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
170
170
  diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
171
- diffsynth_engine-0.4.3.dev5.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
172
- diffsynth_engine-0.4.3.dev5.dist-info/METADATA,sha256=hMlrgbZrIStdg2Sr4iIePwnxCltZ1rttaP_1I2gXflA,1110
173
- diffsynth_engine-0.4.3.dev5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
174
- diffsynth_engine-0.4.3.dev5.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
175
- diffsynth_engine-0.4.3.dev5.dist-info/RECORD,,
171
+ diffsynth_engine-0.4.3.dev7.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
172
+ diffsynth_engine-0.4.3.dev7.dist-info/METADATA,sha256=AkzJrm0DuplPG552EIlIa1um1VoHzWK6DfHCgwDHYLQ,1110
173
+ diffsynth_engine-0.4.3.dev7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
174
+ diffsynth_engine-0.4.3.dev7.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
175
+ diffsynth_engine-0.4.3.dev7.dist-info/RECORD,,