diffsynth-engine 0.6.1.dev27__py3-none-any.whl → 0.6.1.dev29__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
@@ -251,6 +251,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
  # override OptimizationConfig
  fbcache_relative_l1_threshold = 0.009

+ # svd
+ use_nunchaku: Optional[bool] = field(default=None, init=False)
+ use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+ use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
  @classmethod
  def basic_config(
  cls,
@@ -40,7 +40,7 @@ class PreTrainedModel(nn.Module):

  def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = True):
  for args in lora_args:
- key = args["name"]
+ key = args["key"]
  module = self.get_submodule(key)
  if not isinstance(module, (LoRALinear, LoRAConv2d)):
  raise ValueError(f"Unsupported lora key: {key}")
@@ -132,6 +132,7 @@ class LoRALinear(nn.Linear):
  device: str,
  dtype: torch.dtype,
  save_original_weight: bool = True,
+ **kwargs,
  ):
  if save_original_weight and self._original_weight is None:
  if self.weight.dtype == torch.float8_e4m3fn:
@@ -0,0 +1,221 @@
+ import torch
+ import torch.nn as nn
+ from collections import OrderedDict
+
+ from .lora import LoRA
+ from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+ from nunchaku.lora.flux.nunchaku_converter import (
+ pack_lowrank_weight,
+ unpack_lowrank_weight,
+ )
+
+
+ class LoRASVDQW4A4Linear(nn.Module):
+ def __init__(
+ self,
+ origin_linear: SVDQW4A4Linear,
+ ):
+ super().__init__()
+
+ self.origin_linear = origin_linear
+ self.base_rank = self.origin_linear.rank
+ self._lora_dict = OrderedDict()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.origin_linear(x)
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.origin_linear, name)
+
+ def _apply_lora_weights(self, name: str, down: torch.Tensor, up: torch.Tensor, alpha: int, scale: float, rank: int):
+ final_scale = scale * (alpha / rank)
+
+ up_scaled = (up * final_scale).to(
+ dtype=self.origin_linear.proj_up.dtype, device=self.origin_linear.proj_up.device
+ )
+ down_final = down.to(dtype=self.origin_linear.proj_down.dtype, device=self.origin_linear.proj_down.device)
+
+ with torch.no_grad():
+ pd_packed = self.origin_linear.proj_down.data
+ pu_packed = self.origin_linear.proj_up.data
+ pd = unpack_lowrank_weight(pd_packed, down=True)
+ pu = unpack_lowrank_weight(pu_packed, down=False)
+
+ new_proj_down = torch.cat([pd, down_final], dim=0)
+ new_proj_up = torch.cat([pu, up_scaled], dim=1)
+
+ self.origin_linear.proj_down.data = pack_lowrank_weight(new_proj_down, down=True)
+ self.origin_linear.proj_up.data = pack_lowrank_weight(new_proj_up, down=False)
+
+ current_total_rank = self.origin_linear.rank
+ self.origin_linear.rank += rank
+ self._lora_dict[name] = {"rank": rank, "alpha": alpha, "scale": scale, "start_idx": current_total_rank}
+
+ def add_frozen_lora(
+ self,
+ name: str,
+ scale: float,
+ rank: int,
+ alpha: int,
+ up: torch.Tensor,
+ down: torch.Tensor,
+ device: str,
+ dtype: torch.dtype,
+ **kwargs,
+ ):
+ if name in self._lora_dict:
+ raise ValueError(f"LoRA with name '{name}' already exists.")
+
+ self._apply_lora_weights(name, down, up, alpha, scale, rank)
+
+ def add_qkv_lora(
+ self,
+ name: str,
+ scale: float,
+ rank: int,
+ alpha: int,
+ q_up: torch.Tensor,
+ q_down: torch.Tensor,
+ k_up: torch.Tensor,
+ k_down: torch.Tensor,
+ v_up: torch.Tensor,
+ v_down: torch.Tensor,
+ device: str,
+ dtype: torch.dtype,
+ **kwargs,
+ ):
+ if name in self._lora_dict:
+ raise ValueError(f"LoRA with name '{name}' already exists.")
+
+ fused_down = torch.cat([q_down, k_down, v_down], dim=0)
+
+ fused_rank = 3 * rank
+ out_q, out_k = q_up.shape[0], k_up.shape[0]
+ fused_up = torch.zeros((self.out_features, fused_rank), device=q_up.device, dtype=q_up.dtype)
+ fused_up[:out_q, :rank] = q_up
+ fused_up[out_q : out_q + out_k, rank : 2 * rank] = k_up
+ fused_up[out_q + out_k :, 2 * rank :] = v_up
+
+ self._apply_lora_weights(name, fused_down, fused_up, alpha, scale, rank)
+
+ def modify_scale(self, name: str, scale: float):
+ if name not in self._lora_dict:
+ raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+
+ info = self._lora_dict[name]
+ old_scale = info["scale"]
+
+ if old_scale == scale:
+ return
+
+ if old_scale == 0:
+ scale_factor = 0.0
+ else:
+ scale_factor = scale / old_scale
+
+ with torch.no_grad():
+ lora_rank = info["rank"]
+ start_idx = info["start_idx"]
+ end_idx = start_idx + lora_rank
+
+ pu_packed = self.origin_linear.proj_up.data
+ pu = unpack_lowrank_weight(pu_packed, down=False)
+ pu[:, start_idx:end_idx] *= scale_factor
+
+ self.origin_linear.proj_up.data = pack_lowrank_weight(pu, down=False)
+
+ self._lora_dict[name]["scale"] = scale
+
+ def clear(self, release_all_cpu_memory: bool = False):
+ if not self._lora_dict:
+ return
+
+ with torch.no_grad():
+ pd_packed = self.origin_linear.proj_down.data
+ pu_packed = self.origin_linear.proj_up.data
+
+ pd = unpack_lowrank_weight(pd_packed, down=True)
+ pu = unpack_lowrank_weight(pu_packed, down=False)
+
+ pd_reset = pd[: self.base_rank, :].clone()
+ pu_reset = pu[:, : self.base_rank].clone()
+
+ self.origin_linear.proj_down.data = pack_lowrank_weight(pd_reset, down=True)
+ self.origin_linear.proj_up.data = pack_lowrank_weight(pu_reset, down=False)
+
+ self.origin_linear.rank = self.base_rank
+
+ self._lora_dict.clear()
+
+
+ class LoRAAWQW4A16Linear(nn.Module):
+ def __init__(self, origin_linear: AWQW4A16Linear):
+ super().__init__()
+ self.origin_linear = origin_linear
+ self._lora_dict = OrderedDict()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ quantized_output = self.origin_linear(x)
+
+ for name, lora in self._lora_dict.items():
+ quantized_output += lora(x.to(lora.dtype)).to(quantized_output.dtype)
+
+ return quantized_output
+
+ def __getattr__(self, name: str):
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ return getattr(self.origin_linear, name)
+
+ def add_lora(
+ self,
+ name: str,
+ scale: float,
+ rank: int,
+ alpha: int,
+ up: torch.Tensor,
+ down: torch.Tensor,
+ device: str,
+ dtype: torch.dtype,
+ **kwargs,
+ ):
+ up_linear = nn.Linear(rank, self.out_features, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+ down_linear = nn.Linear(self.in_features, rank, bias=False, device="meta", dtype=dtype).to_empty(device=device)
+
+ up_linear.weight.data = up.reshape(self.out_features, rank)
+ down_linear.weight.data = down.reshape(rank, self.in_features)
+
+ lora = LoRA(scale, rank, alpha, up_linear, down_linear, device, dtype)
+ self._lora_dict[name] = lora
+
+ def modify_scale(self, name: str, scale: float):
+ if name not in self._lora_dict:
+ raise ValueError(f"LoRA name {name} not found in {self.__class__.__name__}")
+ self._lora_dict[name].scale = scale
+
+ def add_frozen_lora(self, *args, **kwargs):
+ raise NotImplementedError("Frozen LoRA (merging weights) is not supported for AWQW4A16Linear.")
+
+ def clear(self, *args, **kwargs):
+ self._lora_dict.clear()
+
+
+ def patch_nunchaku_model_for_lora(model: nn.Module):
+ def _recursive_patch(module: nn.Module):
+ for name, child_module in module.named_children():
+ replacement = None
+ if isinstance(child_module, AWQW4A16Linear):
+ replacement = LoRAAWQW4A16Linear(child_module)
+ elif isinstance(child_module, SVDQW4A4Linear):
+ replacement = LoRASVDQW4A4Linear(child_module)
+
+ if replacement:
+ setattr(module, name, replacement)
+ else:
+ _recursive_patch(child_module)
+
+ _recursive_patch(model)
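The new module above is meant to be applied to an already-built nunchaku model before any LoRA is attached: `patch_nunchaku_model_for_lora` swaps every `AWQW4A16Linear` / `SVDQW4A4Linear` for its LoRA-aware wrapper, and the wrapper methods then graft the low-rank weights onto the packed projections. A minimal usage sketch, assuming `nunchaku` is installed and that `dit`, `up_weight`, and `down_weight` stand in for an existing model and LoRA tensors (the submodule key is hypothetical):

    import torch
    from diffsynth_engine.models.basic.lora_nunchaku import patch_nunchaku_model_for_lora

    # Replace quantized linears in-place with LoRA-aware wrappers.
    patch_nunchaku_model_for_lora(dit)

    # Fold one low-rank update into the packed SVDQ projections of a single layer.
    layer = dit.get_submodule("transformer_blocks.0.attn.to_out")  # hypothetical key
    layer.add_frozen_lora(
        name="example_lora", scale=1.0, rank=16, alpha=16,
        up=up_weight, down=down_weight, device="cuda:0", dtype=torch.bfloat16,
    )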
@@ -3,10 +3,15 @@ import math
  import functools

  from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
- from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
+ from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size

+
+ vsa_core = None
  if VIDEO_SPARSE_ATTN_AVAILABLE:
- from vsa import video_sparse_attn as vsa_core
+ try:
+ from vsa import video_sparse_attn as vsa_core
+ except Exception:
+ vsa_core = None

  VSA_TILE_SIZE = (4, 4, 4)

@@ -171,6 +176,12 @@ def video_sparse_attn(
  variable_block_sizes: torch.LongTensor,
  non_pad_index: torch.LongTensor,
  ):
+ if vsa_core is None:
+ raise RuntimeError(
+ "Video sparse attention (VSA) is not available. "
+ "Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
+ )
+
  q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
  k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
  v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
  ):
  from yunchang.comm.all_to_all import SeqAllToAll4D

- assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
+ ring_world_size = get_sp_ring_world_size()
+ assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
  sp_ulysses_group = get_sp_ulysses_group()

  q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)
@@ -11,3 +11,11 @@ __all__ = [
  "Qwen2_5_VLVisionConfig",
  "Qwen2_5_VLConfig",
  ]
+
+ try:
+ from .qwen_image_dit_nunchaku import QwenImageDiTNunchaku
+
+ __all__.append("QwenImageDiTNunchaku")
+
+ except (ImportError, ModuleNotFoundError):
+ pass
@@ -0,0 +1,341 @@
+ import torch
+ import torch.nn as nn
+ from typing import Any, Dict, List, Tuple, Optional
+ from einops import rearrange
+
+ from diffsynth_engine.models.basic import attention as attention_ops
+ from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
+ from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, RMSNorm
+ from diffsynth_engine.models.qwen_image.qwen_image_dit import (
+ QwenFeedForward,
+ apply_rotary_emb_qwen,
+ QwenDoubleStreamAttention,
+ QwenImageTransformerBlock,
+ QwenImageDiT,
+ QwenEmbedRope,
+ )
+
+ from nunchaku.models.utils import fuse_linears
+ from nunchaku.ops.fused import fused_gelu_mlp
+ from nunchaku.models.linear import AWQW4A16Linear, SVDQW4A4Linear
+ from diffsynth_engine.models.basic.lora import LoRALinear, LoRAConv2d
+ from diffsynth_engine.models.basic.lora_nunchaku import LoRASVDQW4A4Linear, LoRAAWQW4A16Linear
+
+
+ class QwenDoubleStreamAttentionNunchaku(QwenDoubleStreamAttention):
+ def __init__(
+ self,
+ dim_a,
+ dim_b,
+ num_heads,
+ head_dim,
+ device: str = "cuda:0",
+ dtype: torch.dtype = torch.bfloat16,
+ nunchaku_rank: int = 32,
+ ):
+ super().__init__(dim_a, dim_b, num_heads, head_dim, device=device, dtype=dtype)
+
+ to_qkv = fuse_linears([self.to_q, self.to_k, self.to_v])
+ self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, rank=nunchaku_rank)
+ self.to_out = SVDQW4A4Linear.from_linear(self.to_out, rank=nunchaku_rank)
+
+ del self.to_q, self.to_k, self.to_v
+
+ add_qkv_proj = fuse_linears([self.add_q_proj, self.add_k_proj, self.add_v_proj])
+ self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, rank=nunchaku_rank)
+ self.to_add_out = SVDQW4A4Linear.from_linear(self.to_add_out, rank=nunchaku_rank)
+
+ del self.add_q_proj, self.add_k_proj, self.add_v_proj
+
+ def forward(
+ self,
+ image: torch.FloatTensor,
+ text: torch.FloatTensor,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ img_q, img_k, img_v = self.to_qkv(image).chunk(3, dim=-1)
+ txt_q, txt_k, txt_v = self.add_qkv_proj(text).chunk(3, dim=-1)
+
+ img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+ img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+ img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+ txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+ txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+ txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)
+
+ img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
+ txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
+
+ if rotary_emb is not None:
+ img_freqs, txt_freqs = rotary_emb
+ img_q = apply_rotary_emb_qwen(img_q, img_freqs)
+ img_k = apply_rotary_emb_qwen(img_k, img_freqs)
+ txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
+ txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)
+
+ joint_q = torch.cat([txt_q, img_q], dim=1)
+ joint_k = torch.cat([txt_k, img_k], dim=1)
+ joint_v = torch.cat([txt_v, img_v], dim=1)
+
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
+ joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)
+
+ joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
+
+ txt_attn_output = joint_attn_out[:, : text.shape[1], :]
+ img_attn_output = joint_attn_out[:, text.shape[1] :, :]
+
+ img_attn_output = self.to_out(img_attn_output)
+ txt_attn_output = self.to_add_out(txt_attn_output)
+
+ return img_attn_output, txt_attn_output
+
+
+ class QwenFeedForwardNunchaku(QwenFeedForward):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ dropout: float = 0.0,
+ device: str = "cuda:0",
+ dtype: torch.dtype = torch.bfloat16,
+ rank: int = 32,
+ ):
+ super().__init__(dim, dim_out, dropout, device=device, dtype=dtype)
+ self.net[0].proj = SVDQW4A4Linear.from_linear(self.net[0].proj, rank=rank)
+ self.net[2] = SVDQW4A4Linear.from_linear(self.net[2], rank=rank)
+ self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
+
+
+ class QwenImageTransformerBlockNunchaku(QwenImageTransformerBlock):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ eps: float = 1e-6,
+ device: str = "cuda:0",
+ dtype: torch.dtype = torch.bfloat16,
+ scale_shift: float = 1.0,
+ use_nunchaku_awq: bool = True,
+ use_nunchaku_attn: bool = True,
+ nunchaku_rank: int = 32,
+ ):
+ super().__init__(dim, num_attention_heads, attention_head_dim, eps, device=device, dtype=dtype)
+
+ self.use_nunchaku_awq = use_nunchaku_awq
+ if use_nunchaku_awq:
+ self.img_mod[1] = AWQW4A16Linear.from_linear(self.img_mod[1], rank=nunchaku_rank)
+
+ if use_nunchaku_attn:
+ self.attn = QwenDoubleStreamAttentionNunchaku(
+ dim_a=dim,
+ dim_b=dim,
+ num_heads=num_attention_heads,
+ head_dim=attention_head_dim,
+ device=device,
+ dtype=dtype,
+ nunchaku_rank=nunchaku_rank,
+ )
+ else:
+ self.attn = QwenDoubleStreamAttention(
+ dim_a=dim,
+ dim_b=dim,
+ num_heads=num_attention_heads,
+ head_dim=attention_head_dim,
+ device=device,
+ dtype=dtype,
+ )
+
+ self.img_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+ if use_nunchaku_awq:
+ self.txt_mod[1] = AWQW4A16Linear.from_linear(self.txt_mod[1], rank=nunchaku_rank)
+
+ self.txt_mlp = QwenFeedForwardNunchaku(dim=dim, dim_out=dim, device=device, dtype=dtype, rank=nunchaku_rank)
+
+ self.scale_shift = scale_shift
+
+ def _modulate(self, x, mod_params):
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
+ if self.use_nunchaku_awq:
+ if self.scale_shift != 0:
+ scale.add_(self.scale_shift)
+ return x * scale.unsqueeze(1) + shift.unsqueeze(1), gate.unsqueeze(1)
+ else:
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+ def forward(
+ self,
+ image: torch.Tensor,
+ text: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ attn_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if self.use_nunchaku_awq:
+ img_mod_params = self.img_mod(temb)  # [B, 6*dim]
+ txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]
+
+ # nunchaku's mod_params is [B, 6*dim] instead of [B, dim*6]
+ img_mod_params = (
+ img_mod_params.view(img_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(img_mod_params.shape[0], -1)
+ )
+ txt_mod_params = (
+ txt_mod_params.view(txt_mod_params.shape[0], -1, 6).transpose(1, 2).reshape(txt_mod_params.shape[0], -1)
+ )
+
+ img_mod_attn, img_mod_mlp = img_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+ txt_mod_attn, txt_mod_mlp = txt_mod_params.chunk(2, dim=-1)  # [B, 3*dim] each
+ else:
+ img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+ txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+
+ img_normed = self.img_norm1(image)
+ img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+
+ txt_normed = self.txt_norm1(text)
+ txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
+
+ img_attn_out, txt_attn_out = self.attn(
+ image=img_modulated,
+ text=txt_modulated,
+ rotary_emb=rotary_emb,
+ attn_mask=attn_mask,
+ attn_kwargs=attn_kwargs,
+ )
+
+ image = image + img_gate * img_attn_out
+ text = text + txt_gate * txt_attn_out
+
+ img_normed_2 = self.img_norm2(image)
+ img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+
+ txt_normed_2 = self.txt_norm2(text)
+ txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
+
+ img_mlp_out = self.img_mlp(img_modulated_2)
+ txt_mlp_out = self.txt_mlp(txt_modulated_2)
+
+ image = image + img_gate_2 * img_mlp_out
+ text = text + txt_gate_2 * txt_mlp_out
+
+ return text, image
+
+
+ class QwenImageDiTNunchaku(QwenImageDiT):
+ def __init__(
+ self,
+ num_layers: int = 60,
+ device: str = "cuda:0",
+ dtype: torch.dtype = torch.bfloat16,
+ use_nunchaku_awq: bool = True,
+ use_nunchaku_attn: bool = True,
+ nunchaku_rank: int = 32,
+ ):
+ super().__init__()
+
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True, device=device)
+
+ self.time_text_embed = TimestepEmbeddings(256, 3072, device=device, dtype=dtype)
+
+ self.txt_norm = RMSNorm(3584, eps=1e-6, device=device, dtype=dtype)
+
+ self.img_in = nn.Linear(64, 3072, device=device, dtype=dtype)
+ self.txt_in = nn.Linear(3584, 3072, device=device, dtype=dtype)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ QwenImageTransformerBlockNunchaku(
+ dim=3072,
+ num_attention_heads=24,
+ attention_head_dim=128,
+ device=device,
+ dtype=dtype,
+ scale_shift=0,
+ use_nunchaku_awq=use_nunchaku_awq,
+ use_nunchaku_attn=use_nunchaku_attn,
+ nunchaku_rank=nunchaku_rank,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
+ self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+
+ @classmethod
+ def from_state_dict(
+ cls,
+ state_dict: Dict[str, torch.Tensor],
+ device: str,
+ dtype: torch.dtype,
+ num_layers: int = 60,
+ use_nunchaku_awq: bool = True,
+ use_nunchaku_attn: bool = True,
+ nunchaku_rank: int = 32,
+ ):
+ model = cls(
+ device="meta",
+ dtype=dtype,
+ num_layers=num_layers,
+ use_nunchaku_awq=use_nunchaku_awq,
+ use_nunchaku_attn=use_nunchaku_attn,
+ nunchaku_rank=nunchaku_rank,
+ )
+ model = model.requires_grad_(False)
+ model.load_state_dict(state_dict, assign=True)
+ model.to(device=device, non_blocking=True)
+ return model
+
+ def load_loras(self, lora_args: List[Dict[str, Any]], fused: bool = False):
+ fuse_dict = {}
+ for args in lora_args:
+ key = args["key"]
+ if any(suffix in key for suffix in {"add_q_proj", "add_k_proj", "add_v_proj"}):
+ fuse_key = f"{key.rsplit('.', 1)[0]}.add_qkv_proj"
+ type = key.rsplit(".", 1)[-1].split("_")[1]
+ fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+ fuse_dict[fuse_key][type] = args
+ continue
+
+ if any(suffix in key for suffix in {"to_q", "to_k", "to_v"}):
+ fuse_key = f"{key.rsplit('.', 1)[0]}.to_qkv"
+ type = key.rsplit(".", 1)[-1].split("_")[1]
+ fuse_dict[fuse_key] = fuse_dict.get(fuse_key, {})
+ fuse_dict[fuse_key][type] = args
+ continue
+
+ module = self.get_submodule(key)
+ if not isinstance(module, (LoRALinear, LoRAConv2d, LoRASVDQW4A4Linear, LoRAAWQW4A16Linear)):
+ raise ValueError(f"Unsupported lora key: {key}")
+
+ if fused and not isinstance(module, LoRAAWQW4A16Linear):
+ module.add_frozen_lora(**args)
+ else:
+ module.add_lora(**args)
+
+ for key in fuse_dict.keys():
+ module = self.get_submodule(key)
+ if not isinstance(module, LoRASVDQW4A4Linear):
+ raise ValueError(f"Unsupported lora key: {key}")
+ module.add_qkv_lora(
+ name=args["name"],
+ scale=fuse_dict[key]["q"]["scale"],
+ rank=fuse_dict[key]["q"]["rank"],
+ alpha=fuse_dict[key]["q"]["alpha"],
+ q_up=fuse_dict[key]["q"]["up"],
+ q_down=fuse_dict[key]["q"]["down"],
+ k_up=fuse_dict[key]["k"]["up"],
+ k_down=fuse_dict[key]["k"]["down"],
+ v_up=fuse_dict[key]["v"]["up"],
+ v_down=fuse_dict[key]["v"]["down"],
+ device=fuse_dict[key]["q"]["device"],
+ dtype=fuse_dict[key]["q"]["dtype"],
+ )
@@ -106,7 +106,8 @@ class BasePipeline:
  for key, param in state_dict.items():
  lora_args.append(
  {
- "name": key,
+ "name": lora_path,
+ "key": key,
  "scale": lora_scale,
  "rank": param["rank"],
  "alpha": param["alpha"],
@@ -130,7 +131,10 @@ class BasePipeline:

  @staticmethod
  def load_model_checkpoint(
- checkpoint_path: str | List[str], device: str = "cpu", dtype: torch.dtype = torch.float16
+ checkpoint_path: str | List[str],
+ device: str = "cpu",
+ dtype: torch.dtype = torch.float16,
+ convert_dtype: bool = True,
  ) -> Dict[str, torch.Tensor]:
  if isinstance(checkpoint_path, str):
  checkpoint_path = [checkpoint_path]
@@ -140,8 +144,11 @@ class BasePipeline:
  raise FileNotFoundError(f"{path} is not a file")
  elif path.endswith(".safetensors"):
  state_dict_ = load_file(path, device=device)
- for key, value in state_dict_.items():
- state_dict[key] = value.to(dtype)
+ if convert_dtype:
+ for key, value in state_dict_.items():
+ state_dict[key] = value.to(dtype)
+ else:
+ state_dict.update(state_dict_)

  elif path.endswith(".gguf"):
  state_dict.update(**load_gguf_checkpoint(path, device=device, dtype=dtype))
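The new `convert_dtype` flag lets callers load packed quantized tensors (for example nunchaku int4/fp4 weights) without casting them to the pipeline dtype; the Qwen-Image pipeline below passes `convert_dtype=False` for exactly this reason. A hedged sketch of a direct call (the checkpoint filename is hypothetical; `BasePipeline` lives in `diffsynth_engine/pipelines/base.py` per the RECORD listing):

    import torch
    from diffsynth_engine.pipelines.base import BasePipeline  # assumed import path

    state_dict = BasePipeline.load_model_checkpoint(
        "qwen-image-w4a4.safetensors",  # hypothetical file
        device="cpu",
        dtype=torch.bfloat16,
        convert_dtype=False,            # keep packed quantized tensors untouched
    )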
@@ -2,6 +2,7 @@ import json
  import torch
  import torch.distributed as dist
  import math
+ import sys
  from typing import Callable, List, Dict, Tuple, Optional, Union
  from tqdm import tqdm
  from einops import rearrange
@@ -38,11 +39,13 @@ from diffsynth_engine.utils.parallel import ParallelWrapper
  from diffsynth_engine.utils import logging
  from diffsynth_engine.utils.fp8_linear import enable_fp8_linear
  from diffsynth_engine.utils.download import fetch_model
+ from diffsynth_engine.utils.flag import NUNCHAKU_AVAILABLE


  logger = logging.get_logger(__name__)


+
  class QwenImageLoRAConverter(LoRAStateDictConverter):
  def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
  dit_dict = {}
@@ -77,6 +80,7 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):

  key = key.replace(f".{lora_a_suffix}", "")
  key = key.replace("base_model.model.", "")
+ key = key.replace("transformer.", "")

  if key.startswith("transformer") and "attn.to_out.0" in key:
  key = key.replace("attn.to_out.0", "attn.to_out")
@@ -177,6 +181,36 @@ class QwenImagePipeline(BasePipeline):
  "vae",
  ]

+ @classmethod
+ def _setup_nunchaku_config(
+ cls, model_state_dict: Dict[str, torch.Tensor], config: QwenImagePipelineConfig
+ ) -> QwenImagePipelineConfig:
+ is_nunchaku_model = any("qweight" in key for key in model_state_dict)
+
+ if is_nunchaku_model:
+ logger.info("Nunchaku quantized model detected. Configuring for nunchaku.")
+ config.use_nunchaku = True
+ config.nunchaku_rank = model_state_dict["transformer_blocks.0.img_mlp.net.0.proj.proj_up"].shape[1]
+
+ if "transformer_blocks.0.img_mod.1.qweight" in model_state_dict:
+ config.use_nunchaku_awq = True
+ logger.info("Enable nunchaku AWQ.")
+ else:
+ config.use_nunchaku_awq = False
+ logger.info("Disable nunchaku AWQ.")
+
+ if "transformer_blocks.0.attn.to_qkv.qweight" in model_state_dict:
+ config.use_nunchaku_attn = True
+ logger.info("Enable nunchaku attention quantization.")
+ else:
+ config.use_nunchaku_attn = False
+ logger.info("Disable nunchaku attention quantization.")
+
+ else:
+ config.use_nunchaku = False
+
+ return config
+
  @classmethod
  def from_pretrained(cls, model_path_or_config: str | QwenImagePipelineConfig) -> "QwenImagePipeline":
  if isinstance(model_path_or_config, str):
@@ -185,7 +219,16 @@ class QwenImagePipeline(BasePipeline):
  config = model_path_or_config

  logger.info(f"loading state dict from {config.model_path} ...")
- model_state_dict = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
+ model_state_dict = cls.load_model_checkpoint(
+ config.model_path, device="cpu", dtype=config.model_dtype, convert_dtype=False
+ )
+
+ config = cls._setup_nunchaku_config(model_state_dict, config)
+
+ # for svd quant model fp4/int4 linear layers, do not convert dtype here
+ if not config.use_nunchaku:
+ for key, value in model_state_dict.items():
+ model_state_dict[key] = value.to(config.model_dtype)

  if config.vae_path is None:
  config.vae_path = fetch_model(
@@ -221,6 +264,8 @@

  @classmethod
  def from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePipelineConfig) -> "QwenImagePipeline":
+ config = cls._setup_nunchaku_config(state_dicts.model, config)
+
  if config.parallelism > 1:
  pipe = ParallelWrapper(
  cfg_degree=config.cfg_degree,
@@ -270,13 +315,30 @@
  dtype=config.model_dtype,
  relative_l1_threshold=config.fbcache_relative_l1_threshold,
  )
+ elif config.use_nunchaku:
+ if not NUNCHAKU_AVAILABLE:
+ from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR
+ raise ImportError(NUNCHAKU_IMPORT_ERROR)
+
+ from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
+ from diffsynth_engine.models.basic.lora_nunchaku import patch_nunchaku_model_for_lora
+
+ dit = QwenImageDiTNunchaku.from_state_dict(
+ state_dicts.model,
+ device=init_device,
+ dtype=config.model_dtype,
+ use_nunchaku_awq=config.use_nunchaku_awq,
+ use_nunchaku_attn=config.use_nunchaku_attn,
+ nunchaku_rank=config.nunchaku_rank,
+ )
+ patch_nunchaku_model_for_lora(dit)
  else:
  dit = QwenImageDiT.from_state_dict(
  state_dicts.model,
  device=("cpu" if config.use_fsdp else init_device),
  dtype=config.model_dtype,
  )
- if config.use_fp8_linear:
+ if config.use_fp8_linear and not config.use_nunchaku:
  enable_fp8_linear(dit)

  pipe = cls(
@@ -650,7 +650,7 @@ class WanVideoPipeline(BasePipeline):
  dit_type = "wan2.2-i2v-a14b"
  elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
  dit_type = "wan2.2-t2v-a14b"
- elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
+ elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
  dit_type = "wan2.2-ti2v-5b"
  elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
  dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@
  if config.attn_params is None:
  config.attn_params = VideoSparseAttentionParams(sparsity=0.9)

+ def update_weights(self, state_dicts: WanStateDicts) -> None:
+ is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
+ ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+ is_dual_model_pipeline = self.dit2 is not None
+
+ if is_dual_model_state_dict != is_dual_model_pipeline:
+ raise ValueError(
+ f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
+ f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
+ f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
+ )
+
+ if is_dual_model_state_dict:
+ if "high_noise_model" in state_dicts.model:
+ self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+ if "low_noise_model" in state_dicts.model:
+ self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+ else:
+ self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
+
+ self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
+ self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
+ self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+
  def compile(self):
  self.dit.compile_repeated_blocks()
  if self.dit2 is not None:
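`update_weights` swaps checkpoints on an already-constructed pipeline and refuses to cross between the single-model (WAN 2.1) and dual-model (WAN 2.2) layouts. A hedged sketch of a call, using only the attributes the method reads (`model`, `t5`, `vae`, `image_encoder`); the `WanStateDicts` constructor keywords and the `*_sd` variables are assumptions:

    state_dicts = WanStateDicts(  # field names taken from the attribute accesses above
        model={"high_noise_model": high_sd, "low_noise_model": low_sd},  # or one flat dict for WAN 2.1
        t5=t5_sd,
        vae=vae_sd,
        image_encoder=image_encoder_sd,
    )
    pipe.update_weights(state_dicts)  # `pipe` is an existing WanVideoPipeline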
@@ -55,3 +55,27 @@ if VIDEO_SPARSE_ATTN_AVAILABLE:
  logger.info("Video sparse attention is available")
  else:
  logger.info("Video sparse attention is not available")
+
+ NUNCHAKU_AVAILABLE = importlib.util.find_spec("nunchaku") is not None
+ NUNCHAKU_IMPORT_ERROR = None
+ if NUNCHAKU_AVAILABLE:
+ logger.info("Nunchaku is available")
+ else:
+ logger.info("Nunchaku is not available")
+ import sys
+ torch_version = getattr(torch, "__version__", "unknown")
+ python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+ NUNCHAKU_IMPORT_ERROR = (
+ "\n\n"
+ "ERROR: This model requires the 'nunchaku' library for quantized inference, but it is not installed.\n"
+ "'nunchaku' is not available on PyPI and must be installed manually.\n\n"
+ "Please follow these steps:\n"
+ "1. Visit the nunchaku releases page: https://github.com/nunchaku-tech/nunchaku/releases\n"
+ "2. Find the wheel (.whl) file that matches your environment:\n"
+ f" - PyTorch version: {torch_version}\n"
+ f" - Python version: {python_version}\n"
+ f" - Operating System: {sys.platform}\n"
+ "3. Copy the URL of the correct wheel file.\n"
+ "4. Install it using pip, for example:\n"
+ " pip install nunchaku @ https://.../your_specific_nunchaku_file.whl\n"
+ )
@@ -21,117 +21,33 @@ from queue import Empty
  import diffsynth_engine.models.basic.attention as attention_ops
  from diffsynth_engine.utils.platform import empty_cache
  from diffsynth_engine.utils import logging
+ from diffsynth_engine.utils.process_group import (
+ PROCESS_GROUP,
+ get_cfg_group,
+ get_cfg_world_size,
+ get_cfg_rank,
+ get_cfg_ranks,
+ get_sp_group,
+ get_sp_world_size,
+ get_sp_rank,
+ get_sp_ranks,
+ get_sp_ulysses_group,
+ get_sp_ulysses_world_size,
+ get_sp_ulysses_rank,
+ get_sp_ulysses_ranks,
+ get_sp_ring_group,
+ get_sp_ring_world_size,
+ get_sp_ring_rank,
+ get_sp_ring_ranks,
+ get_tp_group,
+ get_tp_world_size,
+ get_tp_rank,
+ get_tp_ranks,
+ )

  logger = logging.get_logger(__name__)


- class Singleton:
- _instance = None
-
- def __new__(cls, *args, **kwargs):
- if not cls._instance:
- cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
- return cls._instance
-
-
- class ProcessGroupSingleton(Singleton):
- def __init__(self):
- self.CFG_GROUP: Optional[dist.ProcessGroup] = None
- self.SP_GROUP: Optional[dist.ProcessGroup] = None
- self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
- self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
- self.TP_GROUP: Optional[dist.ProcessGroup] = None
-
- self.CFG_RANKS: List[int] = []
- self.SP_RANKS: List[int] = []
- self.SP_ULYSSUES_RANKS: List[int] = []
- self.SP_RING_RANKS: List[int] = []
- self.TP_RANKS: List[int] = []
-
-
- PROCESS_GROUP = ProcessGroupSingleton()
-
-
- def get_cfg_group():
- return PROCESS_GROUP.CFG_GROUP
-
-
- def get_cfg_world_size():
- return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
-
-
- def get_cfg_rank():
- return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
-
-
- def get_cfg_ranks():
- return PROCESS_GROUP.CFG_RANKS
-
-
- def get_sp_group():
- return PROCESS_GROUP.SP_GROUP
-
-
- def get_sp_world_size():
- return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
-
-
- def get_sp_rank():
- return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
-
-
- def get_sp_ranks():
- return PROCESS_GROUP.SP_RANKS
-
-
- def get_sp_ulysses_group():
- return PROCESS_GROUP.SP_ULYSSUES_GROUP
-
-
- def get_sp_ulysses_world_size():
- return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
-
-
- def get_sp_ulysses_rank():
- return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
-
-
- def get_sp_ulysses_ranks():
- return PROCESS_GROUP.SP_ULYSSUES_RANKS
-
-
- def get_sp_ring_group():
- return PROCESS_GROUP.SP_RING_GROUP
-
-
- def get_sp_ring_world_size():
- return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
-
-
- def get_sp_ring_rank():
- return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
-
-
- def get_sp_ring_ranks():
- return PROCESS_GROUP.SP_RING_RANKS
-
-
- def get_tp_group():
- return PROCESS_GROUP.TP_GROUP
-
-
- def get_tp_world_size():
- return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
-
-
- def get_tp_rank():
- return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
-
-
- def get_tp_ranks():
- return PROCESS_GROUP.TP_RANKS
-
-
  def init_parallel_pgs(
  cfg_degree: int = 1,
  sp_ulysses_degree: int = 1,
@@ -0,0 +1,149 @@
+ """
+ Process group management for distributed training.
+
+ This module provides singleton-based process group management for distributed training,
+ including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
+ """
+
+ import torch.distributed as dist
+ from typing import Optional, List
+
+
+ class Singleton:
+ _instance = None
+
+ def __new__(cls, *args, **kwargs):
+ if not cls._instance:
+ cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
+ return cls._instance
+
+
+ class ProcessGroupSingleton(Singleton):
+ def __init__(self):
+ if not hasattr(self, 'initialized'):
+ self.CFG_GROUP: Optional[dist.ProcessGroup] = None
+ self.SP_GROUP: Optional[dist.ProcessGroup] = None
+ self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
+ self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
+ self.TP_GROUP: Optional[dist.ProcessGroup] = None
+
+ self.CFG_RANKS: List[int] = []
+ self.SP_RANKS: List[int] = []
+ self.SP_ULYSSUES_RANKS: List[int] = []
+ self.SP_RING_RANKS: List[int] = []
+ self.TP_RANKS: List[int] = []
+
+ self.initialized = True
+
+
+ PROCESS_GROUP = ProcessGroupSingleton()
+
+
+ # CFG parallel group functions
+ def get_cfg_group():
+ return PROCESS_GROUP.CFG_GROUP
+
+
+ def get_cfg_world_size():
+ return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
+
+
+ def get_cfg_rank():
+ return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
+
+
+ def get_cfg_ranks():
+ return PROCESS_GROUP.CFG_RANKS
+
+
+ # Sequence parallel group functions
+ def get_sp_group():
+ return PROCESS_GROUP.SP_GROUP
+
+
+ def get_sp_world_size():
+ return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
+
+
+ def get_sp_rank():
+ return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
+
+
+ def get_sp_ranks():
+ return PROCESS_GROUP.SP_RANKS
+
+
+ # Sequence parallel Ulysses group functions
+ def get_sp_ulysses_group():
+ return PROCESS_GROUP.SP_ULYSSUES_GROUP
+
+
+ def get_sp_ulysses_world_size():
+ return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
+
+
+ def get_sp_ulysses_rank():
+ return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
+
+
+ def get_sp_ulysses_ranks():
+ return PROCESS_GROUP.SP_ULYSSUES_RANKS
+
+
+ # Sequence parallel Ring group functions
+ def get_sp_ring_group():
+ return PROCESS_GROUP.SP_RING_GROUP
+
+
+ def get_sp_ring_world_size():
+ return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
+
+
+ def get_sp_ring_rank():
+ return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
+
+
+ def get_sp_ring_ranks():
+ return PROCESS_GROUP.SP_RING_RANKS
+
+
+ # Tensor parallel group functions
+ def get_tp_group():
+ return PROCESS_GROUP.TP_GROUP
+
+
+ def get_tp_world_size():
+ return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
+
+
+ def get_tp_rank():
+ return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
+
+
+ def get_tp_ranks():
+ return PROCESS_GROUP.TP_RANKS
+
+
+ __all__ = [
+ "PROCESS_GROUP",
+ "get_cfg_group",
+ "get_cfg_world_size",
+ "get_cfg_rank",
+ "get_cfg_ranks",
+ "get_sp_group",
+ "get_sp_world_size",
+ "get_sp_rank",
+ "get_sp_ranks",
+ "get_sp_ulysses_group",
+ "get_sp_ulysses_world_size",
+ "get_sp_ulysses_rank",
+ "get_sp_ulysses_ranks",
+ "get_sp_ring_group",
+ "get_sp_ring_world_size",
+ "get_sp_ring_rank",
+ "get_sp_ring_ranks",
+ "get_tp_group",
+ "get_tp_world_size",
+ "get_tp_rank",
+ "get_tp_ranks",
+ ]
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: diffsynth_engine
- Version: 0.6.1.dev27
+ Version: 0.6.1.dev29
  Author: MuseAI x ModelScope
  Classifier: Programming Language :: Python :: 3
  Classifier: Operating System :: OS Independent
@@ -81,18 +81,19 @@ diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer.json,sha256=bhl7TT29cdoU
  diffsynth_engine/conf/tokenizers/wan/umt5-xxl/tokenizer_config.json,sha256=7Zo6iw-qcacKMoR-BDX-A25uES1N9O23u0ipIeNE3AU,61728
  diffsynth_engine/configs/__init__.py,sha256=vSjJToEdq3JX7t81_z4nwNwIdD4bYnFjxnMZH7PXMKo,1309
  diffsynth_engine/configs/controlnet.py,sha256=f3vclyP3lcAjxDGD9C1vevhqqQ7W2LL_c6Wye0uxk3Q,1180
- diffsynth_engine/configs/pipeline.py,sha256=ADgWJa7bA3Z3Z1JtVLgmt4N3eS1KRp9yHu1QvTBzTm0,13404
+ diffsynth_engine/configs/pipeline.py,sha256=7duSdoD0LIROtepsLW9PxYsK59p7qSv34BVz0k29vu4,13633
  diffsynth_engine/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diffsynth_engine/models/__init__.py,sha256=8Ze7cSE8InetgXWTNb0neVA2Q44K7WlE-h7O-02m2sY,119
- diffsynth_engine/models/base.py,sha256=BA5vgMqfy_cjuL2OtXbrFD-Qg5xQnaumHpj5TabwSy8,2559
+ diffsynth_engine/models/base.py,sha256=svao__9WH8VNcyXz5o5dzywYXDcGV0YV9IfkLzDKews,2558
  diffsynth_engine/models/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diffsynth_engine/models/basic/attention.py,sha256=mvgk8LTqFwgtPdBeRv797IZNg9k7--X9wD92Hcr188c,15682
- diffsynth_engine/models/basic/lora.py,sha256=PT-A3pwIuUrW2w3TnNlBPb1KRj70QYiBaoCvLnkR5cs,10652
+ diffsynth_engine/models/basic/lora.py,sha256=Y6cBgrBsuDAP9FZz_fgK8vBi_EMg23saFIUSAsPIG-M,10670
+ diffsynth_engine/models/basic/lora_nunchaku.py,sha256=7qhzGCzUIfDrwtWG0nspwdyZ7YUkaM4vMqzxZby2Zds,7510
  diffsynth_engine/models/basic/relative_position_emb.py,sha256=rCXOweZMcayVnNUVvBcYXMdhHS257B_PC8PZSWxvhNQ,2540
  diffsynth_engine/models/basic/timestep.py,sha256=WJODYqkSXEM0wcS42YkkfrGwxWt0e60zMTkDdUBQqBw,2810
  diffsynth_engine/models/basic/transformer_helper.py,sha256=6K7A5bVnN2bOoq6I0IQf7RJBhSZUP4jNf1n7NPGu8zA,5287
  diffsynth_engine/models/basic/unet_helper.py,sha256=4lN6F80Ubm6ip4dkLVmB-Og5-Y25Wduhs9Q8qjyzK6E,9044
- diffsynth_engine/models/basic/video_sparse_attention.py,sha256=iXA3sHDLWk1ns1lVCNbZdiaDu94kBIsw-9vrCGAll7g,7843
+ diffsynth_engine/models/basic/video_sparse_attention.py,sha256=GxDN6PTpA1rCoQaXUwSPgH4708bEzVI1qsD48WVDXLA,8201
  diffsynth_engine/models/flux/__init__.py,sha256=x0JoxL0CdiiVrY0BjkIrGinud7mcXecLleGO0km91XQ,686
  diffsynth_engine/models/flux/flux_controlnet.py,sha256=NvFKQIx0NldX5uUxdmYwuS2s-xaFRlKotiE6lr3-HRY,8018
  diffsynth_engine/models/flux/flux_dit.py,sha256=7sdV8KFQiHcK-8aqyvXBgC7E_-D9rcgBcnMXUq_AybI,23403
@@ -108,10 +109,11 @@ diffsynth_engine/models/hunyuan3d/hunyuan3d_vae.py,sha256=0IUrUSBi-6eWeaScUoi0e6
  diffsynth_engine/models/hunyuan3d/moe.py,sha256=FAuUqgrB2ZFb0uGBhI-Afv850HmzDFP5yJKKogf4A4U,3552
  diffsynth_engine/models/hunyuan3d/surface_extractor.py,sha256=b15mb1N4PYwAvDk1Gude8qlccRKrSg461xT59RjMEQk,4167
  diffsynth_engine/models/hunyuan3d/volume_decoder.py,sha256=sgflj1a8sIerqGSalBAVQOlyiIihkLOLXYysNbulCoQ,2355
- diffsynth_engine/models/qwen_image/__init__.py,sha256=X5pig621WEsDZ6L7HVkmYspV53-GDfs_la1ncaq_NFw,417
+ diffsynth_engine/models/qwen_image/__init__.py,sha256=_6f0LWaoLdDvD2CsjK2OzEIQryt9efge8DFS4_GUnHQ,582
  diffsynth_engine/models/qwen_image/qwen2_5_vl.py,sha256=Eu-r-c42t_q74Qpwz21ToCGHpvSi7VND4B1EI0e-ePA,57748
  diffsynth_engine/models/qwen_image/qwen_image_dit.py,sha256=iJ-FinDyXa982Uao1is37bxUttyPu0Eldyd7qPJO_XQ,22582
  diffsynth_engine/models/qwen_image/qwen_image_dit_fbcache.py,sha256=LIv9X_BohKk5rcEzyl3ATLwd8MSoFX43wjkArQ68nq8,4828
+ diffsynth_engine/models/qwen_image/qwen_image_dit_nunchaku.py,sha256=TCzNsFxw-QBHrRg94f_ITs5u85Em-aoCAeCr2AylPpE,13478
  diffsynth_engine/models/qwen_image/qwen_image_vae.py,sha256=eO7f4YqiYXfw7NncBNFTu-xEvdJ5uKY-SnfP15QY0tE,38443
  diffsynth_engine/models/sd/__init__.py,sha256=hjoKRnwoXOLD0wude-w7I6wK5ak7ACMbnbkPuBB2oU0,380
  diffsynth_engine/models/sd/sd_controlnet.py,sha256=kMGfIdriXhC7reT6iO2Z0rPICXEkXpytjeBQcR_sjT8,50577
@@ -141,15 +143,15 @@ diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-
  diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
  diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
  diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
- diffsynth_engine/pipelines/base.py,sha256=Yvb2xiHT1Jhx4HDkNPHdXjzhUkM9_65D4zM-GSSOWoU,16133
+ diffsynth_engine/pipelines/base.py,sha256=BNMNL-OU-9ilUv7O60trA3_rjHA21d6Oc5PKzKYBa80,16347
  diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
  diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
- diffsynth_engine/pipelines/qwen_image.py,sha256=n6Nnin8OyC9Mfp8O-3N4GNq12Mws8_hHWv-SwU4-HCc,33054
+ diffsynth_engine/pipelines/qwen_image.py,sha256=ktOirdU2ljgb6vHhXosC0tWgXI3gwvsoAtrYKYvMwzI,35719
  diffsynth_engine/pipelines/sd_image.py,sha256=nr-Nhsnomq8CsUqhTM3i2l2zG01YjwXdfRXgr_bC3F0,17891
  diffsynth_engine/pipelines/sdxl_image.py,sha256=v7ZACGPb6EcBunL6e5E9jynSQjE7GQx8etEV-ZLP91g,21704
  diffsynth_engine/pipelines/utils.py,sha256=HZbJHErNJS1DhlwJKvZ9dY7Kh8Zdlsw3zE2e88TYGRY,2277
  diffsynth_engine/pipelines/wan_s2v.py,sha256=QHlCLMqlmnp55iYm2mzg4qCq4jceRAP3Zt5Mubz3mAM,29384
- diffsynth_engine/pipelines/wan_video.py,sha256=rJq60LiaCoLq1PkqUzzrdvFkp6h73fc-ZUu0MiMQC-c,29668
+ diffsynth_engine/pipelines/wan_video.py,sha256=9xjSvQ4mlVEDdaL6QuUURj4iyxhJ2xABBphQjkfzK8s,31323
  diffsynth_engine/processor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diffsynth_engine/processor/canny_processor.py,sha256=hV30NlblTkEFUAmF_O-LJrNlGVM2SFrqq6okfF8VpOo,602
  diffsynth_engine/processor/depth_processor.py,sha256=dQvs3JsnyMbz4dyI9QoR8oO-mMFBFAgNvgqeCoaU5jk,1532
@@ -171,7 +173,7 @@ diffsynth_engine/utils/cache.py,sha256=Ivef22pCuhEq-4H00gSvkLS8ceVZoGis7OSitYL6g
  diffsynth_engine/utils/constants.py,sha256=sJio3Vy8i0-PWYRnqquYt6ez9k6Tc9JdjCv6pn2BU_4,3551
  diffsynth_engine/utils/download.py,sha256=w9QQjllPfTUEY371UTREU7o_vvdMY-Q2DymDel3ZEZY,6792
  diffsynth_engine/utils/env.py,sha256=k749eYt_qKGq38GocDiXfkhp8nZrowFefNVTZ8R755I,363
- diffsynth_engine/utils/flag.py,sha256=v9GcRFYiNMonD9qmDLWdbXONuF-AcQ_KABPFtRZd0Tc,1767
+ diffsynth_engine/utils/flag.py,sha256=KSzjnzRe7sleNCJm8IpbJQbmBY4KNV2kDrijxi27Jek,2928
  diffsynth_engine/utils/fp8_linear.py,sha256=k34YFWo2dc3t8aKjHaCW9CbQMOTqXxaDHk8aw8aKif4,3857
  diffsynth_engine/utils/gguf.py,sha256=ZWvw46V4g4uVyAR_oCq-4K5nPdKVrYk3u47uXMgA9lU,14092
  diffsynth_engine/utils/image.py,sha256=PiDButjv0fsRS23kpQgCLZAlBumpzQmNnolfvb5EKQ0,9626
@@ -180,15 +182,16 @@ diffsynth_engine/utils/lock.py,sha256=1Ipgst9eEFfFdViAvD5bxdB6HnHHBcqWYOb__fGaPU
  diffsynth_engine/utils/logging.py,sha256=XB0xTT8PBN6btkOjFtOvjlrOCRVgDGT8PFAp1vmse28,467
  diffsynth_engine/utils/offload.py,sha256=94og79TIkxldwYUgZT3L4OVu1WBlE7gfVPvO2MRhm6c,3551
  diffsynth_engine/utils/onnx.py,sha256=jeWUudJHnESjuiEAHyUZYUZz7dCj34O9aGjHCe8yjWo,1149
- diffsynth_engine/utils/parallel.py,sha256=6T8oCTp-7Gb3qsgNRB2Bp3DF4eyx1FzvS6pFnEJbsek,19789
+ diffsynth_engine/utils/parallel.py,sha256=OBGsAK-3ncArRyMU1lea7tbYgxSdCucQvXheL3Ssl5M,17653
  diffsynth_engine/utils/platform.py,sha256=nbpG-XHJFRmYY6u_e7IBQ9Q6GyItrIkKf3VKuBPTUpY,627
+ diffsynth_engine/utils/process_group.py,sha256=P-X04a--Zb4M4kjc3DddmusrxCKqv8wiDGhXG4Al-rE,3783
  diffsynth_engine/utils/prompt.py,sha256=YItMchoVzsG6y-LB4vzzDUWrkhKRVlt1HfVhxZjSxMQ,280
  diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CDhg,2200
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
- diffsynth_engine-0.6.1.dev27.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
- diffsynth_engine-0.6.1.dev27.dist-info/METADATA,sha256=w8FRm_Fr7AZp3TPFh1TUHk93eWxm9CFAZcU8S4qwKj0,1164
- diffsynth_engine-0.6.1.dev27.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- diffsynth_engine-0.6.1.dev27.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
- diffsynth_engine-0.6.1.dev27.dist-info/RECORD,,
+ diffsynth_engine-0.6.1.dev29.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+ diffsynth_engine-0.6.1.dev29.dist-info/METADATA,sha256=8A5q0qhRMxeJi7IOvP3dcqk58BsgIBxy16ndlnDM_6I,1164
+ diffsynth_engine-0.6.1.dev29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ diffsynth_engine-0.6.1.dev29.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+ diffsynth_engine-0.6.1.dev29.dist-info/RECORD,,