diffsynth-engine 0.6.1.dev29__py3-none-any.whl → 0.6.1.dev31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -439,7 +439,7 @@ class WanImageEncoder(PreTrainedModel):
439
439
  def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
440
440
  super().__init__()
441
441
  # init model
442
- self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device="cpu")
442
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device=device)
443
443
 
444
444
  def encode_image(self, images: List[torch.Tensor]):
445
445
  # preprocess
@@ -38,19 +38,20 @@ class T5LayerNorm(nn.Module):
38
38
 
39
39
 
40
40
  class T5Attention(nn.Module):
41
- def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
41
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
42
42
  assert dim_attn % num_heads == 0
43
43
  super(T5Attention, self).__init__()
44
44
  self.dim = dim
45
45
  self.dim_attn = dim_attn
46
46
  self.num_heads = num_heads
47
47
  self.head_dim = dim_attn // num_heads
48
+ self.device = device
48
49
 
49
50
  # layers
50
- self.q = nn.Linear(dim, dim_attn, bias=False)
51
- self.k = nn.Linear(dim, dim_attn, bias=False)
52
- self.v = nn.Linear(dim, dim_attn, bias=False)
53
- self.o = nn.Linear(dim_attn, dim, bias=False)
51
+ self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
52
+ self.k = nn.Linear(dim, dim_attn, bias=False, device=device)
53
+ self.v = nn.Linear(dim, dim_attn, bias=False, device=device)
54
+ self.o = nn.Linear(dim_attn, dim, bias=False, device=device)
54
55
  self.dropout = nn.Dropout(dropout)
55
56
 
56
57
  def forward(self, x, context=None, mask=None, pos_bias=None):
@@ -90,15 +91,16 @@ class T5Attention(nn.Module):
90
91
 
91
92
 
92
93
  class T5FeedForward(nn.Module):
93
- def __init__(self, dim, dim_ffn, dropout=0.0):
94
+ def __init__(self, dim, dim_ffn, dropout=0.0, device="cuda:0"):
94
95
  super(T5FeedForward, self).__init__()
95
96
  self.dim = dim
96
97
  self.dim_ffn = dim_ffn
98
+ self.device = device
97
99
 
98
100
  # layers
99
- self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
100
- self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
101
- self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
101
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False, device=device), GELU())
102
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False, device=device)
103
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False, device=device)
102
104
  self.dropout = nn.Dropout(dropout)
103
105
 
104
106
  def forward(self, x):
@@ -110,7 +112,7 @@ class T5FeedForward(nn.Module):
110
112
 
111
113
 
112
114
  class T5SelfAttention(nn.Module):
113
- def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
115
+ def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0, device="cuda:0"):
114
116
  super(T5SelfAttention, self).__init__()
115
117
  self.dim = dim
116
118
  self.dim_attn = dim_attn
@@ -118,13 +120,14 @@ class T5SelfAttention(nn.Module):
118
120
  self.num_heads = num_heads
119
121
  self.num_buckets = num_buckets
120
122
  self.shared_pos = shared_pos
123
+ self.device = device
121
124
 
122
125
  # layers
123
126
  self.norm1 = T5LayerNorm(dim)
124
- self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
127
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout, device)
125
128
  self.norm2 = T5LayerNorm(dim)
126
- self.ffn = T5FeedForward(dim, dim_ffn, dropout)
127
- self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
129
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout, device)
130
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device)
128
131
 
129
132
  def forward(self, x, mask=None, pos_bias=None):
130
133
  e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
@@ -134,15 +137,16 @@ class T5SelfAttention(nn.Module):
134
137
 
135
138
 
136
139
  class T5RelativeEmbedding(nn.Module):
137
- def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
140
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128, device="cuda:0"):
138
141
  super(T5RelativeEmbedding, self).__init__()
139
142
  self.num_buckets = num_buckets
140
143
  self.num_heads = num_heads
141
144
  self.bidirectional = bidirectional
142
145
  self.max_dist = max_dist
146
+ self.device = device
143
147
 
144
148
  # layers
145
- self.embedding = nn.Embedding(num_buckets, num_heads)
149
+ self.embedding = nn.Embedding(num_buckets, num_heads, device=device)
146
150
 
147
151
  def forward(self, lq, lk):
148
152
  device = self.embedding.weight.device
@@ -257,12 +261,12 @@ class WanTextEncoder(PreTrainedModel):
257
261
  self.shared_pos = shared_pos
258
262
 
259
263
  # layers
260
- self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
261
- self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
264
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, device=device)
265
+ self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device) if shared_pos else None
262
266
  self.dropout = nn.Dropout(dropout)
263
267
  self.blocks = nn.ModuleList(
264
268
  [
265
- T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
269
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, device)
266
270
  for _ in range(num_layers)
267
271
  ]
268
272
  )
@@ -74,9 +74,9 @@ class BasePipeline:
74
74
  component.load_state_dict(state_dict, assign=True)
75
75
  component.to(device=device, dtype=dtype, non_blocking=True)
76
76
 
77
- def load_loras(
77
+ def _load_lora_state_dicts(
78
78
  self,
79
- lora_list: List[Tuple[str, Union[float, LoraConfig]]],
79
+ lora_state_dict_list: List[Tuple[Dict[str, torch.Tensor], Union[float, LoraConfig], str]],
80
80
  fused: bool = True,
81
81
  save_original_weight: bool = False,
82
82
  lora_converter: Optional[LoRAStateDictConverter] = None,
@@ -84,29 +84,30 @@ class BasePipeline:
84
84
  if not lora_converter:
85
85
  lora_converter = self.lora_converter
86
86
 
87
- for lora_path, lora_item in lora_list:
87
+ for state_dict, lora_item, lora_name in lora_state_dict_list:
88
88
  if isinstance(lora_item, float):
89
89
  lora_scale = lora_item
90
90
  scheduler_config = None
91
- if isinstance(lora_item, LoraConfig):
91
+ elif isinstance(lora_item, LoraConfig):
92
92
  lora_scale = lora_item.scale
93
93
  scheduler_config = lora_item.scheduler_config
94
+ else:
95
+ raise ValueError(f"lora_item must be float or LoraConfig, got {type(lora_item)}")
94
96
 
95
- logger.info(f"loading lora from {lora_path} with LoraConfig (scale={lora_scale})")
96
- state_dict = load_file(lora_path, device=self.device)
97
+ logger.info(f"loading lora from state_dict '{lora_name}' with scale={lora_scale}")
97
98
 
98
99
  if scheduler_config is not None:
99
100
  self.apply_scheduler_config(scheduler_config)
100
101
  logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
101
102
 
102
103
  lora_state_dict = lora_converter.convert(state_dict)
103
- for model_name, state_dict in lora_state_dict.items():
104
+ for model_name, model_state_dict in lora_state_dict.items():
104
105
  model = getattr(self, model_name)
105
106
  lora_args = []
106
- for key, param in state_dict.items():
107
+ for key, param in model_state_dict.items():
107
108
  lora_args.append(
108
109
  {
109
- "name": lora_path,
110
+ "name": lora_name,
110
111
  "key": key,
111
112
  "scale": lora_scale,
112
113
  "rank": param["rank"],
@@ -120,6 +121,26 @@ class BasePipeline:
120
121
  )
121
122
  model.load_loras(lora_args, fused=fused)
122
123
 
124
+ def load_loras(
125
+ self,
126
+ lora_list: List[Tuple[str, Union[float, LoraConfig]]],
127
+ fused: bool = True,
128
+ save_original_weight: bool = False,
129
+ lora_converter: Optional[LoRAStateDictConverter] = None,
130
+ ):
131
+ lora_state_dict_list = []
132
+ for lora_path, lora_item in lora_list:
133
+ logger.info(f"loading lora from {lora_path}")
134
+ state_dict = load_file(lora_path, device=self.device)
135
+ lora_state_dict_list.append((state_dict, lora_item, lora_path))
136
+
137
+ self._load_lora_state_dicts(
138
+ lora_state_dict_list=lora_state_dict_list,
139
+ fused=fused,
140
+ save_original_weight=save_original_weight,
141
+ lora_converter=lora_converter,
142
+ )
143
+
123
144
  def load_lora(self, path: str, scale: float, fused: bool = True, save_original_weight: bool = False):
124
145
  self.load_loras([(path, scale)], fused, save_original_weight)
125
146
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffsynth_engine
3
- Version: 0.6.1.dev29
3
+ Version: 0.6.1.dev31
4
4
  Author: MuseAI x ModelScope
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Operating System :: OS Independent
@@ -138,12 +138,12 @@ diffsynth_engine/models/vae/vae.py,sha256=1Hz5Yb6f8V-psC0qothfzg8EZBPVPpg9KGlSMD
138
138
  diffsynth_engine/models/wan/__init__.py,sha256=eYwZ2Upo2mTjaAcBWuSft1m4mLnqE47bz2V_u-WtkwQ,246
139
139
  diffsynth_engine/models/wan/wan_audio_encoder.py,sha256=i8mVu5lhVlTnzVTDcSv7qGC6HjB3MuS9hFVkUrw9458,13629
140
140
  diffsynth_engine/models/wan/wan_dit.py,sha256=MEt9eWy6djWT1dtlFEHP9Yevat4-M_LSzWRauNSIHck,21599
141
- diffsynth_engine/models/wan/wan_image_encoder.py,sha256=VE7crdTxOFN2UCMN2cQlvHB9BilSbKOBQYgnXgl4E2Y,14313
141
+ diffsynth_engine/models/wan/wan_image_encoder.py,sha256=Vdd39lv_QvOsmPxihZWZZbpP-9QuCFpNJ39bdtI5qTQ,14314
142
142
  diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-n8BTvqb98YExU,23615
143
- diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
143
+ diffsynth_engine/models/wan/wan_text_encoder.py,sha256=ePeOifbTI_o650mckzugyWPuHn5vhM-uFMcDVCijxPM,11394
144
144
  diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
145
145
  diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
146
- diffsynth_engine/pipelines/base.py,sha256=BNMNL-OU-9ilUv7O60trA3_rjHA21d6Oc5PKzKYBa80,16347
146
+ diffsynth_engine/pipelines/base.py,sha256=ShRiX5MY6bUkRKfuGrA1aalAqeHyeZxhzT87Mwc30b4,17231
147
147
  diffsynth_engine/pipelines/flux_image.py,sha256=L0ggxpthLD8a5-zdPHu9z668uWBei9YzPb4PFVypDNU,50707
148
148
  diffsynth_engine/pipelines/hunyuan3d_shape.py,sha256=TNV0Wr09Dj2bzzlpua9WioCClOj3YiLfE6utI9aWL8A,8164
149
149
  diffsynth_engine/pipelines/qwen_image.py,sha256=ktOirdU2ljgb6vHhXosC0tWgXI3gwvsoAtrYKYvMwzI,35719
@@ -190,8 +190,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
190
190
  diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
191
191
  diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
192
192
  diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
193
- diffsynth_engine-0.6.1.dev29.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
194
- diffsynth_engine-0.6.1.dev29.dist-info/METADATA,sha256=8A5q0qhRMxeJi7IOvP3dcqk58BsgIBxy16ndlnDM_6I,1164
195
- diffsynth_engine-0.6.1.dev29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
196
- diffsynth_engine-0.6.1.dev29.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
197
- diffsynth_engine-0.6.1.dev29.dist-info/RECORD,,
193
+ diffsynth_engine-0.6.1.dev31.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
194
+ diffsynth_engine-0.6.1.dev31.dist-info/METADATA,sha256=PGHUdyy75RQEl6ownCDC66hY24x07mNdRA7oFszGvss,1164
195
+ diffsynth_engine-0.6.1.dev31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
196
+ diffsynth_engine-0.6.1.dev31.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
197
+ diffsynth_engine-0.6.1.dev31.dist-info/RECORD,,