diffsynth-engine 0.6.1.dev27__py3-none-any.whl → 0.6.1.dev29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/configs/pipeline.py +5 -0
- diffsynth_engine/models/base.py +1 -1
- diffsynth_engine/models/basic/lora.py +1 -0
- diffsynth_engine/models/basic/lora_nunchaku.py +221 -0
- diffsynth_engine/models/basic/video_sparse_attention.py +15 -3
- diffsynth_engine/models/qwen_image/__init__.py +8 -0
- diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py +341 -0
- diffsynth_engine/pipelines/base.py +11 -4
- diffsynth_engine/pipelines/qwen_image.py +64 -2
- diffsynth_engine/pipelines/wan_video.py +25 -1
- diffsynth_engine/utils/flag.py +24 -0
- diffsynth_engine/utils/parallel.py +23 -107
- diffsynth_engine/utils/process_group.py +149 -0
- {diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/METADATA +1 -1
- {diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/RECORD +18 -15
- {diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/top_level.txt +0 -0
diffsynth_engine/configs/pipeline.py
CHANGED
@@ -251,6 +251,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
     @classmethod
     def basic_config(
         cls,
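Note that all three flags are declared with `field(default=None, init=False)`: they are not constructor arguments, and the pipeline assigns them only after inspecting the checkpoint. A minimal, self-contained sketch of this dataclass pattern (the class name below is illustrative, not from the package):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ExampleConfig:
    # init=False: the field is excluded from __init__ and filled in later
    use_nunchaku: Optional[bool] = field(default=None, init=False)

cfg = ExampleConfig()      # constructor takes no argument for the flag
cfg.use_nunchaku = True    # assigned after checkpoint inspection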
diffsynth_engine/models/base.py
CHANGED
@@ -40,7 +40,7 @@ class PreTrainedModel(nn.Module):
 
     def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
         for args in lora_args:
-            key = args["
+            key = args["key"]
             module = self.get_submodule(key)
             if not isinstance(module, (LoRALinear, LoRAConv2d)):
                 raise ValueError(f"Unsupported lora key: {key}")
diffsynth_engine/models/basic/lora_nunchaku.py
ADDED
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+from .lora import LoRA
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from nunchaku.lora.flux.nunchaku_converter import (
+    pack_lowrank_weight,
+    unpack_lowrank_weight,
+)
+
+
+class LoRASVDQW4A4Linear(nn.Module):
+    def __init__(
+        self,
+        origin_linear: SVDQW4A4Linear,
+    ):
+        super().__init__()
+
+        self.origin_linear = origin_linear
+        self.base_rank = self.origin_linear.rank
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.origin_linear(x)
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int):
+        final_scale = scale * (alpha / rank)
+
+        up_scaled = (up * final_scale).to(
+            dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
+        )
+        down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            new_proj_down = torch.cat([pd, down_final], dim=0)
+            new_proj_up = torch.cat([pu, up_scaled], dim=1)
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)
+
+        current_total_rank = self.origin_linear.rank
+        self.origin_linear.rank += rank
+        self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank}
+
+    def add_frozen_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        self._apply_lora_weights(name, down, up, alpha, scale, rank)
+
+    def add_qkv_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        q_up: torch.Tensor,
+        q_down: torch.Tensor,
+        k_up: torch.Tensor,
+        k_down: torch.Tensor,
+        v_up: torch.Tensor,
+        v_down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        if name in self._lora_dict:
+            raise ValueError(f"LoRA with name '{name}' already exists.")
+
+        fused_down = torch.cat([q_down, k_down, v_down], dim=0)
+
+        fused_rank = 3 * rank
+        out_q, out_k = q_up.shape[0], k_up.shape[0]
+        fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
+        fused_up[:out_q, :rank] = q_up
+        fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
+        fused_up[out_q + out_k :, 2 * rank :] = v_up
+
+        self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+
+        info = self._lora_dict[name]
+        old_scale = info["scale"]
+
+        if old_scale == scale:
+            return
+
+        if old_scale == 0:
+            scale_factor = 0.0
+        else:
+            scale_factor = scale / old_scale
+
+        with torch.no_grad():
+            lora_rank = info["rank"]
+            start_idx = info["start_idx"]
+            end_idx = start_idx + lora_rank
+
+            pu_packed = self.origin_linear.proj_up.data
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+            pu[:, start_idx:end_idx] *= scale_factor
+
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)
+
+        self._lora_dict[name]["scale"] = scale
+
+    def clear(self, release_all_cpu_memory: bool = False):
+        if not self._lora_dict:
+            return
+
+        with torch.no_grad():
+            pd_packed = self.origin_linear.proj_down.data
+            pu_packed = self.origin_linear.proj_up.data
+
+            pd = unpack_lowrank_weight(pd_packed, down=True)
+            pu = unpack_lowrank_weight(pu_packed, down=False)
+
+            pd_reset = pd[: self.base_rank, :].clone()
+            pu_reset = pu[:, : self.base_rank].clone()
+
+            self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
+            self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)
+
+            self.origin_linear.rank = self.base_rank
+
+        self._lora_dict.clear()
+
+
+class LoRAAWQW4A16Linear(nn.Module):
+    def __init__(self, origin_linear: AWQW4A16Linear):
+        super().__init__()
+        self.origin_linear = origin_linear
+        self._lora_dict = OrderedDict()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        quantized_output = self.origin_linear(x)
+
+        for name, lora in self._lora_dict.items():
+            quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)
+
+        return quantized_output
+
+    def __getattr__(self, name: str):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.origin_linear, name)
+
+    def add_lora(
+        self,
+        name: str,
+        scale: float,
+        rank: int,
+        alpha: int,
+        up: torch.Tensor,
+        down: torch.Tensor,
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+        down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+
+        up_linear.weight.data = up.reshape(self.out_features, rank)
+        down_linear.weight.data = down.reshape(rank, self.in_features)
+
+        lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
+        self._lora_dict[name] = lora
+
+    def modify_scale(self, name: str, scale: float):
+        if name not in self._lora_dict:
+            raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+        self._lora_dict[name].scale = scale
+
+    def add_frozen_lora(self, *args, **kwargs):
+        raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")
+
+    def clear(self, *args, **kwargs):
+        self._lora_dict.clear()
+
+
+def patch_nunchaku_model_for_lora(model: nn.Module):
+    def _recursive_patch(module: nn.Module):
+        for name, child_module in module.named_children():
+            replacement = None
+            if isinstance(child_module, AWQW4A16Linear):
+                replacement = LoRAAWQW4A16Linear(child_module)
+            elif isinstance(child_module, SVDQW4A4Linear):
+                replacement = LoRASVDQW4A4Linear(child_module)
+
+            if replacement:
+                setattr(module, name, replacement)
+            else:
+                _recursive_patch(child_module)
+
+    _recursive_patch(model)
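The trick behind `_apply_lora_weights` above: an SVDQuant linear already carries a low-rank branch `x @ proj_down.T @ proj_up.T`, so a LoRA with factors `down` ([rank, in]) and `up` ([out, rank]) can be fused by concatenating onto that branch and widening the rank. A plain-PyTorch sketch of the identity, with the nunchaku packing/unpacking omitted and all names illustrative:

import torch

out_f, in_f, base_rank, lora_rank = 8, 6, 4, 2
proj_down = torch.randn(base_rank, in_f)   # existing low-rank branch, [rank, in]
proj_up = torch.randn(out_f, base_rank)    # existing low-rank branch, [out, rank]
down = torch.randn(lora_rank, in_f)        # LoRA A factor
up = torch.randn(out_f, lora_rank)         # LoRA B factor
alpha, scale = 2, 1.0
final_scale = scale * (alpha / lora_rank)  # same scaling as in the diff

x = torch.randn(3, in_f)
before = x @ proj_down.t() @ proj_up.t()

# Fuse: LoRA rows onto proj_down, scaled LoRA columns onto proj_up.
new_down = torch.cat([proj_down, down], dim=0)
new_up = torch.cat([proj_up, up * final_scale], dim=1)
after = x @ new_down.t() @ new_up.t()

# The widened branch equals the old branch plus the scaled LoRA contribution.
assert torch.allclose(after, before + final_scale * (x @ down.t() @ up.t()), atol=1e-5)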
diffsynth_engine/models/basic/video_sparse_attention.py
CHANGED
@@ -3,10 +3,15 @@ import math
 import functools
 
 from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
-from diffsynth_engine.utils.
+from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
 
+
+vsa_core = None
 if VIDEO_SPARSE_ATTN_AVAILABLE:
-
+    try:
+        from vsa import video_sparse_attn as vsa_core
+    except Exception:
+        vsa_core = None
 
 VSA_TILE_SIZE = (4, 4, 4)
 
@@ -171,6 +176,12 @@ def video_sparse_attn(
     variable_block_sizes: torch.LongTensor,
     non_pad_index: torch.LongTensor,
 ):
+    if vsa_core is None:
+        raise RuntimeError(
+            "Video sparse attention (VSA) is not available. "
+            "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
+        )
+
     q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
     k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
     v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
 ):
     from yunchang.comm.all_to_all import SeqAllToAll4D
 
-
+    ring_world_size = get_sp_ring_world_size()
+    assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
     sp_ulysses_group = get_sp_ulysses_group()
 
     q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
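The restructuring above moves the hard failure from import time to call time: the module imports cleanly without `vsa`, and only an actual call to `video_sparse_attn` raises. A generic sketch of the same guard (package and function names are placeholders, not from this codebase):

try:
    from optional_pkg import fast_kernel  # optional dependency, may be absent
except ImportError:
    fast_kernel = None

def apply_fast_kernel(x):
    # fail with an actionable message only when the feature is actually used
    if fast_kernel is None:
        raise RuntimeError("fast_kernel unavailable; install 'optional_pkg' to use it")
    return fast_kernel(x)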
diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py
ADDED
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+from typing import Any, Dict, List, Tuple, Optional
+from einops import rearrange
+
+from diffsynth_engine.models.basic import attention as attention_ops
+from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm
+from diffsynth_engine.models.qwen_image.qwen_image_dit import (
+    QwenFeedForward,
+    apply_rotary_emb_qwen,
+    QwenDoubleStreamAttention,
+    QwenImageTransformerBlock,
+    QwenImageDiT,
+    QwenEmbedRope,
+)
+
+from nunchaku.models.utils import fuse_linears
+from nunchaku.ops.fused import fused_gelu_mlp
+from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
+from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear
+
+
+class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention):
+    def __init__(
+        self,
+        dim_a,
+        dim_b,
+        num_heads,
+        head_dim,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype)
+
+        to_qkv = fuse_linears([self.to_q, self.to_k, self.to_v])
+        self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, rank=nunchaku_rank)
+        self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank)
+
+        del self.to_q, self.to_k, self.to_v
+
+        add_qkv_proj = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj])
+        self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, rank=nunchaku_rank)
+        self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank)
+
+        del self.add_q_proj, self.add_k_proj, self.add_v_proj
+
+    def forward(
+        self,
+        image: torch.FloatTensor,
+        text: torch.FloatTensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        img_q, img_k, img_v = self.to_qkv(image).chunk(3, dim=-1)
+        txt_q, txt_k, txt_v = self.add_qkv_proj(text).chunk(3, dim=-1)
+
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+        img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
+        txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
+
+        if rotary_emb is not None:
+            img_freqs, txt_freqs = rotary_emb
+            img_q = apply_rotary_emb_qwen(img_q, img_freqs)
+            img_k = apply_rotary_emb_qwen(img_k, img_freqs)
+            txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
+            txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
+
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)
+
+        attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
+
+        joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
+
+        txt_attn_output = joint_attn_out[:, : text.shape[1], :]
+        img_attn_output = joint_attn_out[:, text.shape[1] :, :]
+
+        img_attn_output = self.to_out(img_attn_output)
+        txt_attn_output = self.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
+
+
+class QwenFeedForwardNunchaku(QwenFeedForward):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        dropout: float = 0.0,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        rank: int = 32,
+    ):
+        super().__init__(dim, dim_out, dropout, device=device, dtype=dtype)
+        self.net[0].proj = SVDQW4A4Linear.from_linear(self.net[0].proj, rank=rank)
+        self.net[2] = SVDQW4A4Linear.from_linear(self.net[2], rank=rank)
+        self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
+
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
+
+
+class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        eps: float = 1e-6,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        scale_shift: float = 1.0,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype)
+
+        self.use_nunchaku_awq = use_nunchaku_awq
+        if use_nunchaku_awq:
+            self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank)
+
+        if use_nunchaku_attn:
+            self.attn = QwenDoubleStreamAttentionNunchaku(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+                nunchaku_rank=nunchaku_rank,
+            )
+        else:
+            self.attn = QwenDoubleStreamAttention(
+                dim_a=dim,
+                dim_b=dim,
+                num_heads=num_attention_heads,
+                head_dim=attention_head_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+        self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        if use_nunchaku_awq:
+            self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank)
+
+        self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+        self.scale_shift = scale_shift
+
+    def _modulate(self, x, mod_params):
+        shift, scale, gate = mod_params.chunk(3, dim=-1)
+        if self.use_nunchaku_awq:
+            if self.scale_shift != 0:
+                scale.add_(self.scale_shift)
+            return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+        else:
+            return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        text: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        attn_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.use_nunchaku_awq:
+            img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+            txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+            # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
+            img_mod_params = (
+                img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+            )
+            txt_mod_params = (
+                txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+            )
+
+            img_mod_attn, img_mod_mlp = img_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = txt_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+        else:
+            img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+            txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+
+        img_normed = self.img_norm1(image)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+
+        txt_normed = self.txt_norm1(text)
+        txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
+
+        img_attn_out, txt_attn_out = self.attn(
+            image=img_modulated,
+            text=txt_modulated,
+            rotary_emb=rotary_emb,
+            attn_mask=attn_mask,
+            attn_kwargs=attn_kwargs,
+        )
+
+        image = image + img_gate * img_attn_out
+        text = text + txt_gate * txt_attn_out
+
+        img_normed_2 = self.img_norm2(image)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+
+        txt_normed_2 = self.txt_norm2(text)
+        txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
+
+        img_mlp_out = self.img_mlp(img_modulated_2)
+        txt_mlp_out = self.txt_mlp(txt_modulated_2)
+
+        image = image + img_gate_2 * img_mlp_out
+        text = text + txt_gate_2 * txt_mlp_out
+
+        return text, image
+
+
+class QwenImageDiTNunchaku(QwenImageDiT):
+    def __init__(
+        self,
+        num_layers: int = 60,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        super().__init__()
+
+        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device)
+
+        self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+
+        self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype)
+
+        self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype)
+        self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                QwenImageTransformerBlockNunchaku(
+                    dim=3072,
+                    num_attention_heads=24,
+                    attention_head_dim=128,
+                    device=device,
+                    dtype=dtype,
+                    scale_shift=0,
+                    use_nunchaku_awq=use_nunchaku_awq,
+                    use_nunchaku_attn=use_nunchaku_attn,
+                    nunchaku_rank=nunchaku_rank,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+        self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
+        self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        num_layers: int = 60,
+        use_nunchaku_awq: bool = True,
+        use_nunchaku_attn: bool = True,
+        nunchaku_rank: int = 32,
+    ):
+        model = cls(
+            device="meta",
+            dtype=dtype,
+            num_layers=num_layers,
+            use_nunchaku_awq=use_nunchaku_awq,
+            use_nunchaku_attn=use_nunchaku_attn,
+            nunchaku_rank=nunchaku_rank,
+        )
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, non_blocking=True)
+        return model
+
+    def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
+        fuse_dict = {}
+        for args in lora_args:
+            key = args["key"]
+            if any(suffix in key for suffix in {"add_q_proj", "add_k_proj", "add_v_proj"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.add_qkv_proj"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            if any(suffix in key for suffix in {"to_q", "to_k", "to_v"}):
+                fuse_key = f"{key.rsplit('.', 1)[0]}.to_qkv"
+                type = key.rsplit(".", 1)[-1].split("_")[1]
+                fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+                fuse_dict[fuse_key][type] = args
+                continue
+
+            module = self.get_submodule(key)
+            if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)):
+                raise ValueError(f"Unsupported lora key: {key}")
+
+            if fused and not isinstance(module, LoRAAWQW4A16Linear):
+                module.add_frozen_lora(**args)
+            else:
+                module.add_lora(**args)
+
+        for key in fuse_dict.keys():
+            module = self.get_submodule(key)
+            if not isinstance(module, LoRASVDQW4A4Linear):
+                raise ValueError(f"Unsupported lora key: {key}")
+            module.add_qkv_lora(
+                name=args["name"],
+                scale=fuse_dict[key]["q"]["scale"],
+                rank=fuse_dict[key]["q"]["rank"],
+                alpha=fuse_dict[key]["q"]["alpha"],
+                q_up=fuse_dict[key]["q"]["up"],
+                q_down=fuse_dict[key]["q"]["down"],
+                k_up=fuse_dict[key]["k"]["up"],
+                k_down=fuse_dict[key]["k"]["down"],
+                v_up=fuse_dict[key]["v"]["up"],
+                v_down=fuse_dict[key]["v"]["down"],
+                device=fuse_dict[key]["q"]["device"],
+                dtype=fuse_dict[key]["q"]["dtype"],
+            )
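The one non-obvious step above is the layout shuffle behind the comment `# nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]`: the AWQ modulation projection emits the six AdaLN parameters interleaved per channel, and `view(...).transpose(1, 2).reshape(...)` regroups them into six contiguous dim-sized chunks so the usual `chunk()` calls work. A small verification sketch with synthetic values:

import torch

B, dim = 2, 4
# Interleaved layout: each channel's six modulation params sit together.
interleaved = torch.arange(B * dim * 6, dtype=torch.float32).reshape(B, dim * 6)

# Same reshuffle as in QwenImageTransformerBlockNunchaku.forward:
regrouped = interleaved.view(B, -1, 6).transpose(1, 2).reshape(B, -1)

# Six contiguous [B, dim] blocks, one per parameter.
params = regrouped.chunk(6, dim=-1)
assert all(p.shape == (B, dim) for p in params)
assert params[0][0, 0] == 0 and params[0][0, 1] == 6  # channel stride is now 6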
diffsynth_engine/pipelines/base.py
CHANGED
@@ -106,7 +106,8 @@ class BasePipeline:
         for key, param in state_dict.items():
             lora_args.append(
                 {
-                    "name":
+                    "name": lora_path,
+                    "key": key,
                     "scale": lora_scale,
                     "rank": param["rank"],
                     "alpha": param["alpha"],
@@ -130,7 +131,10 @@ class BasePipeline:
 
     @staticmethod
     def load_model_checkpoint(
-        checkpoint_path: str | List[str],
+        checkpoint_path: str | List[str],
+        device: str = "cpu",
+        dtype: torch.dtype = torch.float16,
+        convert_dtype: bool = True,
     ) -> Dict[str, torch.Tensor]:
         if isinstance(checkpoint_path, str):
             checkpoint_path = [checkpoint_path]
@@ -140,8 +144,11 @@ class BasePipeline:
                 raise FileNotFoundError(f"{path} is not a file")
             elif path.endswith(".safetensors"):
                 state_dict_ = load_file(path, device=device)
-
-
+                if convert_dtype:
+                    for key, value in state_dict_.items():
+                        state_dict[key] = value.to(dtype)
+                else:
+                    state_dict.update(state_dict_)
 
             elif path.endswith(".gguf"):
                 state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
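A hedged usage sketch of the extended signature (the checkpoint path is a placeholder, and the import path is inferred from this diff):

import torch
from diffsynth_engine.pipelines.base import BasePipeline

# convert_dtype=False keeps tensors in their on-disk dtype, which matters for
# nunchaku checkpoints whose packed int4/fp4 weights must not be cast.
state_dict = BasePipeline.load_model_checkpoint(
    "path/to/model.safetensors",  # placeholder
    device="cpu",
    dtype=torch.bfloat16,
    convert_dtype=False,
)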
diffsynth_engine/pipelines/qwen_image.py
CHANGED
@@ -2,6 +2,7 @@ import json
 import torch
 import torch.distributed as dist
 import math
+import sys
 from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
@@ -38,11 +39,13 @@ from diffsynth_engine.utils.parallel import ParallelWrapper
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
 from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE
 
 
 logger = logging.get_logger(__name__)
 
 
+
 class QwenImageLoRAConverter(LoRAStateDictConverter):
     def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
         dit_dict = {}
@@ -77,6 +80,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
 
             key = key.replace(f".{lora_a_suffix}", "")
             key = key.replace("base_model.model.", "")
+            key = key.replace("transformer.", "")
 
             if key.startswith("transformer") and "attn.to_out.0" in key:
                 key = key.replace("attn.to_out.0", "attn.to_out")
@@ -177,6 +181,36 @@ class QwenImagePipeline(BasePipeline):
         "vae",
     ]
 
+    @classmethod
+    def _setup_nunchaku_config(
+        cls, model_state_dict: Dict[str, torch.Tensor], config: QwenImagePipelineConfig
+    ) -> QwenImagePipelineConfig:
+        is_nunchaku_model = any("qweight" in key for key in model_state_dict)
+
+        if is_nunchaku_model:
+            logger.info("Nunchaku quantized model detected. Configuring for nunchaku.")
+            config.use_nunchaku = True
+            config.nunchaku_rank = model_state_dict["transformer_blocks.0.img_mlp.net.0.proj.proj_up"].shape[1]
+
+            if "transformer_blocks.0.img_mod.1.qweight" in model_state_dict:
+                config.use_nunchaku_awq = True
+                logger.info("Enable nunchaku AWQ.")
+            else:
+                config.use_nunchaku_awq = False
+                logger.info("Disable nunchaku AWQ.")
+
+            if "transformer_blocks.0.attn.to_qkv.qweight" in model_state_dict:
+                config.use_nunchaku_attn = True
+                logger.info("Enable nunchaku attention quantization.")
+            else:
+                config.use_nunchaku_attn = False
+                logger.info("Disable nunchaku attention quantization.")
+
+        else:
+            config.use_nunchaku = False
+
+        return config
+
     @classmethod
     def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) -> "QwenImagePipeline":
         if isinstance(model_path_or_config, str):
@@ -185,7 +219,16 @@ class QwenImagePipeline(BasePipeline):
             config = model_path_or_config
 
         logger.info(f"loading state dict from {config.model_path} ...")
-        model_state_dict = cls.load_model_checkpoint(
+        model_state_dict = cls.load_model_checkpoint(
+            config.model_path, device="cpu", dtype=config.model_dtype, convert_dtype=False
+        )
+
+        config = cls._setup_nunchaku_config(model_state_dict, config)
+
+        # for svd quant model fp4/int4 linear layers, do not convert dtype here
+        if not config.use_nunchaku:
+            for key, value in model_state_dict.items():
+                model_state_dict[key] = value.to(config.model_dtype)
 
         if config.vae_path is None:
             config.vae_path = fetch_model(
@@ -221,6 +264,8 @@ class QwenImagePipeline(BasePipeline):
 
     @classmethod
     def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
+        config = cls._setup_nunchaku_config(state_dicts.model, config)
+
         if config.parallelism > 1:
             pipe = ParallelWrapper(
                 cfg_degree=config.cfg_degree,
@@ -270,13 +315,30 @@ class QwenImagePipeline(BasePipeline):
                 dtype=config.model_dtype,
                 relative_l1_threshold=config.fbcache_relative_l1_threshold,
            )
+        elif config.use_nunchaku:
+            if not NUNCHAKU_AVAILABLE:
+                from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR
+                raise ImportError(NUNCHAKU_IMPORT_ERROR)
+
+            from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
+            from diffsynth_engine.models.basic.lora_nunchaku import patch_nunchaku_model_for_lora
+
+            dit = QwenImageDiTNunchaku.from_state_dict(
+                state_dicts.model,
+                device=init_device,
+                dtype=config.model_dtype,
+                use_nunchaku_awq=config.use_nunchaku_awq,
+                use_nunchaku_attn=config.use_nunchaku_attn,
+                nunchaku_rank=config.nunchaku_rank,
+            )
+            patch_nunchaku_model_for_lora(dit)
         else:
             dit = QwenImageDiT.from_state_dict(
                 state_dicts.model,
                 device=("cpu" if config.use_fsdp else init_device),
                 dtype=config.model_dtype,
            )
-        if config.use_fp8_linear:
+        if config.use_fp8_linear and not config.use_nunchaku:
             enable_fp8_linear(dit)
 
         pipe = cls(
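`_setup_nunchaku_config` keys entirely off the state dict: any packed `qweight` tensor marks an SVDQuant checkpoint, and two specific per-module keys distinguish the AWQ-modulation and quantized-attention variants. A standalone sketch of the same checks (key strings copied from the diff; the function name is illustrative):

def sketch_nunchaku_detection(model_state_dict: dict) -> dict:
    flags = {"use_nunchaku": any("qweight" in k for k in model_state_dict)}
    if flags["use_nunchaku"]:
        flags["use_nunchaku_awq"] = "transformer_blocks.0.img_mod.1.qweight" in model_state_dict
        flags["use_nunchaku_attn"] = "transformer_blocks.0.attn.to_qkv.qweight" in model_state_dict
    return flags

# A non-quantized checkpoint has no qweight keys:
assert sketch_nunchaku_detection({"img_in.weight": 0}) == {"use_nunchaku": False}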
diffsynth_engine/pipelines/wan_video.py
CHANGED
@@ -650,7 +650,7 @@ class WanVideoPipeline(BasePipeline):
            dit_type = "wan2.2-i2v-a14b"
        elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
            dit_type = "wan2.2-t2v-a14b"
-       elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
+       elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
            dit_type = "wan2.2-ti2v-5b"
        elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
            dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@ class WanVideoPipeline(BasePipeline):
         if config.attn_params is None:
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
+    def update_weights(self, state_dicts: WanStateDicts) -> None:
+        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
+                                    ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_pipeline = self.dit2 is not None
+
+        if is_dual_model_state_dict != is_dual_model_pipeline:
+            raise ValueError(
+                f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
+                f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
+                f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
+            )
+
+        if is_dual_model_state_dict:
+            if "high_noise_model" in state_dicts.model:
+                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+            if "low_noise_model" in state_dicts.model:
+                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+        else:
+            self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+
+        self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
+        self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+
     def compile(self):
         self.dit.compile_repeated_blocks()
         if self.dit2 is not None:
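The guard at the top of `update_weights` separates WAN 2.2 dual-expert checkpoints (weights nested under `high_noise_model` / `low_noise_model`) from WAN 2.1 single-model ones, refusing to mix the two. A standalone sketch of that structure check:

def is_dual_model(state_dict) -> bool:
    # WAN 2.2-style dual-expert state dicts nest their sub-model weights
    return isinstance(state_dict, dict) and (
        "high_noise_model" in state_dict or "low_noise_model" in state_dict
    )

assert is_dual_model({"high_noise_model": {}, "low_noise_model": {}})
assert not is_dual_model({"patch_embedding.weight": None})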
diffsynth_engine/utils/flag.py
CHANGED
@@ -55,3 +55,27 @@ if VIDEO_SPARSE_ATTN_AVAILABLE:
     logger.info("Video sparse attention is available")
 else:
     logger.info("Video sparse attention is not available")
+
+NUNCHAKU_AVAILABLE = importlib.util.find_spec("nunchaku") is not None
+NUNCHAKU_IMPORT_ERROR = None
+if NUNCHAKU_AVAILABLE:
+    logger.info("Nunchaku is available")
+else:
+    logger.info("Nunchaku is not available")
+    import sys
+    torch_version = getattr(torch, "__version__", "unknown")
+    python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+    NUNCHAKU_IMPORT_ERROR = (
+        "\n\n"
+        "ERROR: This model requires the 'nunchaku' library for quantized inference, but it is not installed.\n"
+        "'nunchaku' is not available on PyPI and must be installed manually.\n\n"
+        "Please follow these steps:\n"
+        "1. Visit the nunchaku releases page: https://github.com/nunchaku-tech/nunchaku/releases\n"
+        "2. Find the wheel (.whl) file that matches your environment:\n"
+        f"   - PyTorch version: {torch_version}\n"
+        f"   - Python version: {python_version}\n"
+        f"   - Operating System: {sys.platform}\n"
+        "3. Copy the URL of the correct wheel file.\n"
+        "4. Install it using pip, for example:\n"
+        "   pip install nunchaku @ https://.../your_specific_nunchaku_file.whl\n"
+    )
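Downstream code (see the `qwen_image.py` hunks above) consumes these flags as a late check with an actionable error. A minimal usage sketch:

from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE, NUNCHAKU_IMPORT_ERROR

if not NUNCHAKU_AVAILABLE:
    # NUNCHAKU_IMPORT_ERROR embeds wheel-selection hints (torch/python/OS)
    raise ImportError(NUNCHAKU_IMPORT_ERROR)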
diffsynth_engine/utils/parallel.py
CHANGED
@@ -21,117 +21,33 @@ from queue import Empty
 import diffsynth_engine.models.basic.attention as attention_ops
 from diffsynth_engine.utils.platform import empty_cache
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.process_group import (
+    PROCESS_GROUP,
+    get_cfg_group,
+    get_cfg_world_size,
+    get_cfg_rank,
+    get_cfg_ranks,
+    get_sp_group,
+    get_sp_world_size,
+    get_sp_rank,
+    get_sp_ranks,
+    get_sp_ulysses_group,
+    get_sp_ulysses_world_size,
+    get_sp_ulysses_rank,
+    get_sp_ulysses_ranks,
+    get_sp_ring_group,
+    get_sp_ring_world_size,
+    get_sp_ring_rank,
+    get_sp_ring_ranks,
+    get_tp_group,
+    get_tp_world_size,
+    get_tp_rank,
+    get_tp_ranks,
+)
 
 logger = logging.get_logger(__name__)
 
 
-class Singleton:
-    _instance = None
-
-    def __new__(cls, *args, **kwargs):
-        if not cls._instance:
-            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
-        return cls._instance
-
-
-class ProcessGroupSingleton(Singleton):
-    def __init__(self):
-        self.CFG_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
-        self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
-        self.TP_GROUP: Optional[dist.ProcessGroup] = None
-
-        self.CFG_RANKS: List[int] = []
-        self.SP_RANKS: List[int] = []
-        self.SP_ULYSSUES_RANKS: List[int] = []
-        self.SP_RING_RANKS: List[int] = []
-        self.TP_RANKS: List[int] = []
-
-
-PROCESS_GROUP = ProcessGroupSingleton()
-
-
-def get_cfg_group():
-    return PROCESS_GROUP.CFG_GROUP
-
-
-def get_cfg_world_size():
-    return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
-
-
-def get_cfg_rank():
-    return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
-
-
-def get_cfg_ranks():
-    return PROCESS_GROUP.CFG_RANKS
-
-
-def get_sp_group():
-    return PROCESS_GROUP.SP_GROUP
-
-
-def get_sp_world_size():
-    return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
-
-
-def get_sp_rank():
-    return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
-
-
-def get_sp_ranks():
-    return PROCESS_GROUP.SP_RANKS
-
-
-def get_sp_ulysses_group():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP
-
-
-def get_sp_ulysses_world_size():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
-
-
-def get_sp_ulysses_rank():
-    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
-
-
-def get_sp_ulysses_ranks():
-    return PROCESS_GROUP.SP_ULYSSUES_RANKS
-
-
-def get_sp_ring_group():
-    return PROCESS_GROUP.SP_RING_GROUP
-
-
-def get_sp_ring_world_size():
-    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
-
-
-def get_sp_ring_rank():
-    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
-
-
-def get_sp_ring_ranks():
-    return PROCESS_GROUP.SP_RING_RANKS
-
-
-def get_tp_group():
-    return PROCESS_GROUP.TP_GROUP
-
-
-def get_tp_world_size():
-    return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
-
-
-def get_tp_rank():
-    return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
-
-
-def get_tp_ranks():
-    return PROCESS_GROUP.TP_RANKS
-
-
 def init_parallel_pgs(
     cfg_degree: int = 1,
     sp_ulysses_degree: int = 1,
diffsynth_engine/utils/process_group.py
ADDED
@@ -0,0 +1,149 @@
+"""
+Process group management for distributed training.
+
+This module provides singleton-based process group management for distributed training,
+including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
+"""
+
+import torch.distributed as dist
+from typing import Optional, List
+
+
+class Singleton:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
+        return cls._instance
+
+
+class ProcessGroupSingleton(Singleton):
+    def __init__(self):
+        if not hasattr(self, 'initialized'):
+            self.CFG_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+            self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
+            self.TP_GROUP: Optional[dist.ProcessGroup] = None
+
+            self.CFG_RANKS: List[int] = []
+            self.SP_RANKS: List[int] = []
+            self.SP_ULYSSUES_RANKS: List[int] = []
+            self.SP_RING_RANKS: List[int] = []
+            self.TP_RANKS: List[int] = []
+
+            self.initialized = True
+
+
+PROCESS_GROUP = ProcessGroupSingleton()
+
+
+# CFG parallel group functions
+def get_cfg_group():
+    return PROCESS_GROUP.CFG_GROUP
+
+
+def get_cfg_world_size():
+    return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
+
+
+def get_cfg_rank():
+    return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
+
+
+def get_cfg_ranks():
+    return PROCESS_GROUP.CFG_RANKS
+
+
+# Sequence parallel group functions
+def get_sp_group():
+    return PROCESS_GROUP.SP_GROUP
+
+
+def get_sp_world_size():
+    return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
+
+
+def get_sp_rank():
+    return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
+
+
+def get_sp_ranks():
+    return PROCESS_GROUP.SP_RANKS
+
+
+# Sequence parallel Ulysses group functions
+def get_sp_ulysses_group():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+def get_sp_ulysses_world_size():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+def get_sp_ulysses_rank():
+    return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+def get_sp_ulysses_ranks():
+    return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+# Sequence parallel Ring group functions
+def get_sp_ring_group():
+    return PROCESS_GROUP.SP_RING_GROUP
+
+
+def get_sp_ring_world_size():
+    return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+def get_sp_ring_rank():
+    return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+def get_sp_ring_ranks():
+    return PROCESS_GROUP.SP_RING_RANKS
+
+
+# Tensor parallel group functions
+def get_tp_group():
+    return PROCESS_GROUP.TP_GROUP
+
+
+def get_tp_world_size():
+    return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
+
+
+def get_tp_rank():
+    return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
+
+
+def get_tp_ranks():
+    return PROCESS_GROUP.TP_RANKS
+
+
+__all__ = [
+    "PROCESS_GROUP",
+    "get_cfg_group",
+    "get_cfg_world_size",
+    "get_cfg_rank",
+    "get_cfg_ranks",
+    "get_sp_group",
+    "get_sp_world_size",
+    "get_sp_rank",
+    "get_sp_ranks",
+    "get_sp_ulysses_group",
+    "get_sp_ulysses_world_size",
+    "get_sp_ulysses_rank",
+    "get_sp_ulysses_ranks",
+    "get_sp_ring_group",
+    "get_sp_ring_world_size",
+    "get_sp_ring_rank",
+    "get_sp_ring_ranks",
+    "get_tp_group",
+    "get_tp_world_size",
+    "get_tp_rank",
+    "get_tp_ranks",
+]
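Every getter degrades gracefully when its group was never created, returning world size 1 and rank 0, so single-process code can call them unconditionally. A hedged usage sketch:

from diffsynth_engine.utils.process_group import get_sp_rank, get_sp_world_size

# Without init_parallel_pgs() these return the single-process defaults.
if get_sp_world_size() > 1:
    print(f"sequence parallel: rank {get_sp_rank()} of {get_sp_world_size()}")
else:
    print("no sequence parallelism configured")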
{diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/RECORD
CHANGED
@@ -81,18 +81,19 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoU
 diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
 diffsynth_engine/configs/__init__.py,sha256=vSjJToEdq3JX7t81_z4nwNwIdD4bYnFjxnMZH7PXMKo,1309
 diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
-diffsynth_engine/configs/pipeline.py,sha256=
+diffsynth_engine/configs/pipeline.py,sha256=7duSdoD0LIROtepsLW9PxYsK59p7qSv34BVz0k29vu4,13633
 diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
-diffsynth_engine/models/base.py,sha256=
+diffsynth_engine/models/base.py,sha256=svao__9WH8VNcyXz5o5dzywYXDcGV0YV9IfkLzDKews,2558
 diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/models/basic/attention.py,sha256=mvgk8LTqFwgtPdBeRv797IZNg9k7--X9wD92Hcr188c,15682
-diffsynth_engine/models/basic/lora.py,sha256=
+diffsynth_engine/models/basic/lora.py,sha256=Y6cBgrBsuDAP9FZz_fgK8vBi_EMg23saFIUSAsPIG-M,10670
+diffsynth_engine/models/basic/lora_nunchaku.py,sha256=7qhzGCzUIfDrwtWG0nspwdyZ7YUkaM4vMqzxZby2Zds,7510
 diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
 diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
 diffsynth_engine/models/basic/transformer_helper.py,sha256=6K7A5bVnN2bOoq6I0IQf7RJBhSZUP4jNf1n7NPGu8zA,5287
 diffsynth_engine/models/basic/unet_helper.py,sha256=4lN6F80Ubm6ip4dkLVmB-Og5-Y25Wduhs9Q8qjyzK6E,9044
-diffsynth_engine/models/basic/video_sparse_attention.py,sha256=
+diffsynth_engine/models/basic/video_sparse_attention.py,sha256=GxDN6PTpA1rCoQaXUwSPgH4708bEzVI1qsD48WVDXLA,8201
 diffsynth_engine/models/flux/__init__.py,sha256=x0JoxL0CdiiVrY0BjkIrGinud7mcXecLleGO0km91XQ,686
 diffsynth_engine/models/flux/flux_controlnet.py,sha256=NvFKQIx0NldX5uUxdmYwuS2s-xaFRlKotiE6lr3-HRY,8018
 diffsynth_engine/models/flux/flux_dit.py,sha256=7sdV8KFQiHcK-8aqyvXBgC7E_-D9rcgBcnMXUq_AybI,23403
@@ -108,10 +109,11 @@ diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py,sha256=0IUrUSBi-6eWeaScUoi0e6
 diffsynth_engine/models/hunyuan3d/moe.py,sha256=FAuUqgrB2ZFb0uGBhI-Afv850HmzDFP5yJKKogf4A4U,3552
 diffsynth_engine/models/hunyuan3d/surface_extractor.py,sha256=b15mb1N4PYwAvDk1Gude8qlccRKrSg461xT59RjMEQk,4167
 diffsynth_engine/models/hunyuan3d/volume_decoder.py,sha256=sgflj1a8sIerqGSalBAVQOlyiIihkLOLXYysNbulCoQ,2355
-diffsynth_engine/models/qwen_image/__init__.py,sha256=
+diffsynth_engine/models/qwen_image/__init__.py,sha256=_6f0LWaoLdDvD2CsjK2OzEIQryt9efge8DFS4_GUnHQ,582
 diffsynth_engine/models/qwen_image/qwen2_5_vl.py,sha256=Eu-r-c42t_q74Qpwz21ToCGHpvSi7VND4B1EI0e-ePA,57748
 diffsynth_engine/models/qwen_image/qwen_image_dit.py,sha256=iJ-FinDyXa982Uao1is37bxUttyPu0Eldyd7qPJO_XQ,22582
 diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py,sha256=LIv9X_BohKk5rcEzyl3ATLwd8MSoFX43wjkArQ68nq8,4828
+diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py,sha256=TCzNsFxw-QBHrRg94f_ITs5u85Em-aoCAeCr2AylPpE,13478
 diffsynth_engine/models/qwen_image/qwen_image_vae.py,sha256=eO7f4YqiYXfw7NncBNFTu-xEvdJ5uKY-SnfP15QY0tE,38443
 diffsynth_engine/models/sd/__init__.py,sha256=hjoKRnwoXOLD0wude-w7I6wK5ak7ACMbnbkPuBB2oU0,380
 diffsynth_engine/models/sd/sd_controlnet.py,sha256=kMGfIdriXhC7reT6iO2Z0rPICXEkXpytjeBQcR_sjT8,50577
@@ -141,15 +143,15 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-
 diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
 diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
 diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
-diffsynth_engine/pipelines/base.py,sha256=
+diffsynth_engine/pipelines/base.py,sha256=BNMNL-OU-9ilUv7O60trA3_rjHA21d6Oc5PKzKYBa80,16347
 diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
 diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
-diffsynth_engine/pipelines/qwen_image.py,sha256=
+diffsynth_engine/pipelines/qwen_image.py,sha256=ktOirdU2ljgb6vHhXosC0tWgXI3gwvsoAtrYKYvMwzI,35719
 diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
 diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
 diffsynth_engine/pipelines/utils.py,sha256=HZbJHErNJS1DhlwJKvZ9dY7Kh8Zdlsw3zE2e88TYGRY,2277
 diffsynth_engine/pipelines/wan_s2v.py,sha256=QHlCLMqlmnp55iYm2mzg4qCq4jceRAP3Zt5Mubz3mAM,29384
-diffsynth_engine/pipelines/wan_video.py,sha256=
+diffsynth_engine/pipelines/wan_video.py,sha256=9xjSvQ4mlVEDdaL6QuUURj4iyxhJ2xABBphQjkfzK8s,31323
 diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
 diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -171,7 +173,7 @@ diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6g
 diffsynth_engine/utils/constants.py,sha256=sJio3Vy8i0-PWYRnqquYt6ez9k6Tc9JdjCv6pn2BU_4,3551
 diffsynth_engine/utils/download.py,sha256=w9QQjllPfTUEY371UTREU7o_vvdMY-Q2DymDel3ZEZY,6792
 diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I,363
-diffsynth_engine/utils/flag.py,sha256=
+diffsynth_engine/utils/flag.py,sha256=KSzjnzRe7sleNCJm8IpbJQbmBY4KNV2kDrijxi27Jek,2928
 diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
 diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
 diffsynth_engine/utils/image.py,sha256=PiDButjv0fsRS23kpQgCLZAlBumpzQmNnolfvb5EKQ0,9626
@@ -180,15 +182,16 @@ diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPU
 diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
 diffsynth_engine/utils/offload.py,sha256=94og79TIkxldwYUgZT3L4OVu1WBlE7gfVPvO2MRhm6c,3551
 diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
-diffsynth_engine/utils/parallel.py,sha256=
+diffsynth_engine/utils/parallel.py,sha256=OBGsAK-3ncArRyMU1lea7tbYgxSdCucQvXheL3Ssl5M,17653
 diffsynth_engine/utils/platform.py,sha256=nbpG-XHJFRmYY6u_e7IBQ9Q6GyItrIkKf3VKuBPTUpY,627
+diffsynth_engine/utils/process_group.py,sha256=P-X04a--Zb4M4kjc3DddmusrxCKqv8wiDGhXG4Al-rE,3783
 diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
 diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CDhg,2200
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
-diffsynth_engine-0.6.1.
+diffsynth_engine-0.6.1.dev29.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev29.dist-info/METADATA,sha256=8A5q0qhRMxeJi7IOvP3dcqk58BsgIBxy16ndlnDM_6I,1164
+diffsynth_engine-0.6.1.dev29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev29.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev29.dist-info/RECORD,,
{diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/WHEEL
RENAMED
File without changes
{diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/licenses/LICENSE
RENAMED
File without changes
{diffsynth_engine-0.6.1.dev27.dist-info → diffsynth_engine-0.6.1.dev29.dist-info}/top_level.txt
RENAMED
File without changes