diffsynth-engine 0.4.2.dev4__py3-none-any.whl → 0.4.2.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -315,6 +315,7 @@ class QwenImageTransformerBlock(nn.Module):
315
315
 
316
316
  class QwenImageDiT(PreTrainedModel):
317
317
  converter = QwenImageDiTStateDictConverter()
318
+ _supports_parallelization = True
318
319
 
319
320
  def __init__(
320
321
  self,
@@ -423,3 +424,6 @@ class QwenImageDiT(PreTrainedModel):
423
424
  model.load_state_dict(state_dict, assign=True)
424
425
  model.to(device=device, dtype=dtype, non_blocking=True)
425
426
  return model
427
+
428
def get_fsdp_modules(self):
    """Return the attribute names of the sub-modules designated for FSDP.

    Returns:
        A new single-element list, ``["transformer_blocks"]``, on every call.
    """
    # NOTE(review): presumably consumed by the parallelization utilities to
    # decide which children get wrapped -- confirm against the caller.
    fsdp_module_names = ["transformer_blocks"]
    return fsdp_module_names
@@ -6,7 +6,7 @@ from typing import Dict, List, Tuple
6
6
  from PIL import Image
7
7
 
8
8
  from diffsynth_engine.configs import BaseConfig, BaseStateDicts
9
- from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
9
+ from diffsynth_engine.utils.offload import enable_sequential_cpu_offload, offload_model_to_dict, restore_model_from_dict
10
10
  from diffsynth_engine.utils.fp8_linear import enable_fp8_autocast
11
11
  from diffsynth_engine.utils.gguf import load_gguf_checkpoint
12
12
  from diffsynth_engine.utils import logging
@@ -40,6 +40,7 @@ class BasePipeline:
40
40
  self.dtype = dtype
41
41
  self.offload_mode = None
42
42
  self.model_names = []
43
+ self._offload_param_dict = {}
43
44
 
44
45
  @classmethod
45
46
  def from_pretrained(cls, model_path_or_config: str | BaseConfig) -> "BasePipeline":
@@ -243,14 +244,13 @@ class BasePipeline:
243
244
  for model_name in self.model_names:
244
245
  model = getattr(self, model_name)
245
246
  if model is not None:
246
- model.to("cpu")
247
+ self._offload_param_dict[model_name] = offload_model_to_dict(model)
247
248
  self.offload_mode = "cpu_offload"
248
249
 
249
250
  def _enable_sequential_cpu_offload(self):
250
251
  for model_name in self.model_names:
251
252
  model = getattr(self, model_name)
252
253
  if model is not None:
253
- model.to("cpu")
254
254
  enable_sequential_cpu_offload(model, self.device)
255
255
  self.offload_mode = "sequential_cpu_offload"
256
256
 
@@ -277,20 +277,12 @@ class BasePipeline:
277
277
  for model_name in self.model_names:
278
278
  if model_name not in load_model_names:
279
279
  model = getattr(self, model_name)
280
- if (
281
- model is not None
282
- and (p := next(model.parameters(), None)) is not None
283
- and p.device != torch.device("cpu")
284
- ):
285
- model.to("cpu")
280
+ if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != "cpu":
281
+ restore_model_from_dict(model, self._offload_param_dict[model_name])
286
282
  # load the needed models to device
287
283
  for model_name in load_model_names:
288
284
  model = getattr(self, model_name)
289
- if (
290
- model is not None
291
- and (p := next(model.parameters(), None)) is not None
292
- and p.device != torch.device(self.device)
293
- ):
285
+ if model is not None and (p := next(model.parameters(), None)) is not None and p.device.type != self.device:
294
286
  model.to(self.device)
295
287
  # free the cached CUDA memory
296
288
  empty_cache()
@@ -584,4 +584,11 @@ class WanVideoPipeline(BasePipeline):
584
584
  use_fsdp=config.use_fsdp,
585
585
  device="cuda",
586
586
  )
587
+ if config.use_torch_compile:
588
+ pipe.compile()
587
589
  return pipe
590
+
591
def compile(self):
    """Call ``compile()`` on ``self.dit`` and, when it is set, on ``self.dit2``.

    ``self.dit`` is compiled unconditionally; ``self.dit2`` is skipped when it
    is ``None``.
    """
    # NOTE(review): ``compile()`` is presumably ``torch.nn.Module.compile`` --
    # confirm against the DiT model classes.
    primary = self.dit
    primary.compile()

    secondary = self.dit2
    if secondary is None:
        return
    secondary.compile()
@@ -1,8 +1,10 @@
1
1
  import torch
2
2
  import torch.nn as nn
3
+ from typing import Dict
3
4
 
4
5
 
5
6
  def enable_sequential_cpu_offload(module: nn.Module, device: str = "cuda"):
7
+ module = module.to("cpu")
6
8
  if len(list(module.children())) == 0:
7
9
  if len(list(module.parameters())) > 0 or len(list(module.buffers())) > 0:
8
10
  # leaf module with parameters or buffers
@@ -50,3 +52,24 @@ def add_cpu_offload_hook(module: nn.Module, device: str = "cuda", recurse: bool
50
52
  module.register_forward_pre_hook(_forward_pre_hook)
51
53
  module.register_forward_hook(_forward_hook)
52
54
  setattr(module, "_cpu_offload_enabled", True)
55
+
56
+
57
def _pin_or_keep(tensor: torch.Tensor) -> torch.Tensor:
    """Return a page-locked copy of *tensor*, or *tensor* itself if pinning fails."""
    try:
        return tensor.pin_memory()
    except RuntimeError:
        # Pinning needs an accelerator runtime (and enough lockable host
        # memory). It is only a transfer optimization, so fall back to the
        # ordinary pageable tensor instead of failing the whole offload.
        return tensor


def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]:
    """Move *module* to the CPU and snapshot its tensors by name.

    The module is moved with ``module.to("cpu")``. The storage of every
    parameter and buffer is then replaced by a pinned (page-locked) copy when
    pinning is possible, and left as the regular CPU tensor otherwise.

    Args:
        module: The model to offload. Its parameters and buffers are rebound
            to the CPU tensors recorded in the returned dict.

    Returns:
        A dict mapping each parameter/buffer name (as yielded by
        ``named_parameters`` / ``named_buffers``) to the CPU tensor that now
        backs it. The dict values share storage with the module's tensors.
    """
    module = module.to("cpu")
    offload_param_dict: Dict[str, torch.Tensor] = {}
    for name, param in module.named_parameters(recurse=True):
        param.data = _pin_or_keep(param.data)
        offload_param_dict[name] = param.data
    for name, buffer in module.named_buffers(recurse=True):
        buffer.data = _pin_or_keep(buffer.data)
        offload_param_dict[name] = buffer.data
    return offload_param_dict
67
+
68
+
69
def restore_model_from_dict(module: nn.Module, offload_param_dict: Dict[str, torch.Tensor]):
    """Rebind *module*'s parameters and buffers to previously saved tensors.

    Args:
        module: The model whose tensors are rebound in place.
        offload_param_dict: Mapping from parameter/buffer name to the tensor
            that should back it. Names missing from the mapping are left
            untouched.
    """
    # Parameters are handled first, then buffers; both iterators are lazy, so
    # nothing is traversed until the inner loop consumes them.
    named_tensor_groups = (
        module.named_parameters(recurse=True),
        module.named_buffers(recurse=True),
    )
    for named_tensors in named_tensor_groups:
        for tensor_name, tensor in named_tensors:
            if tensor_name not in offload_param_dict:
                continue
            tensor.data = offload_param_dict[tensor_name]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.4.2.dev4
3
+ Version: 0.4.2.dev6
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -100,7 +100,7 @@ diffsynth_engine/models/flux/flux_text_encoder.py,sha256=Qcs277RIPP-O5AkcAb5Fb0j
100
100
  diffsynth_engine/models/flux/flux_vae.py,sha256=YoeZZSTbXo5rpYQ5-XY3A-Bd1swb8H1yOmtVjm0g5ZI,2994
101
101
  diffsynth_engine/models/qwen_image/__init__.py,sha256=X5pig621WEsDZ6L7HVkmYspV53-GDfs_la1ncaq_NFw,417
102
102
  diffsynth_engine/models/qwen_image/qwen2_5_vl.py,sha256=vpBo6eo_96iLky9YV5MX0nbmOleY2EX97TrJoRBNnw4,56511
103
- diffsynth_engine/models/qwen_image/qwen_image_dit.py,sha256=CzYtYrqBzaHcn6OE4Z4yYqYdAoGRbfQtZ5KQVrN03to,16839
103
+ diffsynth_engine/models/qwen_image/qwen_image_dit.py,sha256=bmsVuwMU26i34lCDDpKsxY1OVRwtWawn44r_E_Pd1Ws,16947
104
104
  diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py,sha256=tgCmD4MFNgd3HmLyoYnt8HZCzPBTgQ4zCQjV5qZSW_I,4870
105
105
  diffsynth_engine/models/qwen_image/qwen_image_vae.py,sha256=m455iJfJx8KVEsrkinjClokkEqd0RshSZoKZ_QAdRyk,38509
106
106
  diffsynth_engine/models/sd/__init__.py,sha256=hjoKRnwoXOLD0wude-w7I6wK5ak7ACMbnbkPuBB2oU0,380
@@ -129,13 +129,13 @@ diffsynth_engine/models/wan/wan_image_encoder.py,sha256=LYwcfCcQmXf9FP08DGaU2bfa
129
129
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=bkphxtqNNwXcEA_OaUrwV9CvICV-s16awu5Z9gjjzsM,10912
130
130
  diffsynth_engine/models/wan/wan_vae.py,sha256=AmBuqyPwZCFY0e8lUThlJoNHmpmTm2_dE1XYzXBCaAI,38937
131
131
  diffsynth_engine/pipelines/__init__.py,sha256=9QZVhZeRm_5m7yxie08yBtgM26NB4mfVBOjj1Prlv-k,447
132
- diffsynth_engine/pipelines/base.py,sha256=AUoB07rCvrfVyied-qQBIsexyf5eE8hqyXr-2IIWYwU,12339
132
+ diffsynth_engine/pipelines/base.py,sha256=730FzAn1cxZ2SM2qr9SXk7qNLpAau9PKEqwfMSW_8Aw,12336
133
133
  diffsynth_engine/pipelines/flux_image.py,sha256=gWuZaMeupB_Wz3AY97eE1eEVSAmAm14aXIxkAqNXY7E,49224
134
134
  diffsynth_engine/pipelines/qwen_image.py,sha256=S15FAxmiWcF3qItBigXR9jVJa4h9Qqs4_8cT3KXi8Ec,17506
135
135
  diffsynth_engine/pipelines/sd_image.py,sha256=5cIIknh2M-fOqj7urKi9nZ40yc1LnvepbH_Af7SF4UA,17789
136
136
  diffsynth_engine/pipelines/sdxl_image.py,sha256=otv1T_0fhX3UcIoKbKCqb47Yge6xg0fPM0ry-uPEanI,21548
137
137
  diffsynth_engine/pipelines/utils.py,sha256=lk7sFGEk-fGjgadLpwwppHKG-yZ0RC-4ZmHW7pRRe8A,473
138
- diffsynth_engine/pipelines/wan_video.py,sha256=wbCHPDgs4BmyX1DsawaXqxeCoVAcISNcXNnFr2qcTx0,25424
138
+ diffsynth_engine/pipelines/wan_video.py,sha256=stoYKm0wHf_pxZ_WHRTGHTR61KVG_U21yBUaUrDjSqw,25605
139
139
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
140
140
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
141
141
  diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -162,14 +162,14 @@ diffsynth_engine/utils/image.py,sha256=_46CVs1Qe7GdZNulWWJISnR_Y6FotC2tZGLKtr04g
162
162
  diffsynth_engine/utils/loader.py,sha256=Z5v1WNDWFY0OrVubB70j5VU3zeaAfEK_j8c1KrGI4yM,1240
163
163
  diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPUI,1601
164
164
  diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
165
- diffsynth_engine/utils/offload.py,sha256=CYIIDr9CDGE3YN1kCOUnd1BBGvYxjfBCR3BMzqCx4RQ,2580
165
+ diffsynth_engine/utils/offload.py,sha256=qEiqbeMQqeV1DtHF-6OuUO8Akdr1enfpPYwF0OBpx98,3500
166
166
  diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
167
167
  diffsynth_engine/utils/parallel.py,sha256=TtB6FzP2qo4VQqSenVnV-ZaKmp9xHaIWJ8D1ZjHtukE,17064
168
168
  diffsynth_engine/utils/platform.py,sha256=2lXdw6YkqcRONCeT98n4cyg1Ii8Ybbyj2Ns72Se9tlk,496
169
169
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
170
170
  diffsynth_engine/utils/video.py,sha256=Ne0rd2lb59UT1q5EotpjlY7OT8F9oTCFDyo1ST77uoQ,1004
171
- diffsynth_engine-0.4.2.dev4.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
172
- diffsynth_engine-0.4.2.dev4.dist-info/METADATA,sha256=-RN0vnmcItmjNVo9q-OV8oMAyfLk8MMvy03DUdZ516Q,1110
173
- diffsynth_engine-0.4.2.dev4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
174
- diffsynth_engine-0.4.2.dev4.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
175
- diffsynth_engine-0.4.2.dev4.dist-info/RECORD,,
171
+ diffsynth_engine-0.4.2.dev6.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
172
+ diffsynth_engine-0.4.2.dev6.dist-info/METADATA,sha256=xy1w-I4tAN0UahyFnr4ZP9EoMzm8C1uFGSrC4Y1P0m4,1110
173
+ diffsynth_engine-0.4.2.dev6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
174
+ diffsynth_engine-0.4.2.dev6.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
175
+ diffsynth_engine-0.4.2.dev6.dist-info/RECORD,,