diffsynth-engine 0.6.1.dev30__py3-none-any.whl → 0.6.1.dev31__py3-none-any.whl

--- a/diffsynth_engine/models/wan/wan_image_encoder.py
+++ b/diffsynth_engine/models/wan/wan_image_encoder.py
@@ -439,7 +439,7 @@ class WanImageEncoder(PreTrainedModel):
     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
         super().__init__()
         # init model
-        self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device="cpu")
+        self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device=device)

     def encode_image(self, images: List[torch.Tensor]):
         # preprocess
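The effect of this one-line change is that the CLIP weights are materialized on the requested device instead of on the CPU and moved afterwards. A minimal sketch of the two construction strategies; build_model is a hypothetical stand-in for a large factory such as clip_xlm_roberta_vit_h_14, and only the construction pattern matters:

import torch
import torch.nn as nn

# Hypothetical stand-in for a large factory-built model; the pattern,
# not the architecture, is the point.
def build_model(device="cpu"):
    return nn.Sequential(
        nn.Linear(1024, 4096, device=device),
        nn.GELU(),
        nn.Linear(4096, 1024, device=device),
    )

if torch.cuda.is_available():
    # Old pattern: every parameter is first allocated in host RAM,
    # then copied tensor-by-tensor to the GPU.
    cpu_then_move = build_model(device="cpu").to("cuda:0")

    # New pattern: parameters are allocated directly on the target device,
    # skipping the host-side allocation and the host-to-device copy.
    direct = build_model(device="cuda:0")
    print(next(direct.parameters()).device)  # cuda:0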
--- a/diffsynth_engine/models/wan/wan_text_encoder.py
+++ b/diffsynth_engine/models/wan/wan_text_encoder.py
@@ -38,19 +38,20 @@ class T5LayerNorm(nn.Module):


 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
+    def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
         assert dim_attn % num_heads == 0
         super(T5Attention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
         self.num_heads = num_heads
         self.head_dim = dim_attn // num_heads
+        self.device = device

         # layers
-        self.q = nn.Linear(dim, dim_attn, bias=False)
-        self.k = nn.Linear(dim, dim_attn, bias=False)
-        self.v = nn.Linear(dim, dim_attn, bias=False)
-        self.o = nn.Linear(dim_attn, dim, bias=False)
+        self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.k = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.v = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.o = nn.Linear(dim_attn, dim, bias=False, device=device)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x, context=None, mask=None, pos_bias=None):
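PyTorch's nn.Linear (like other built-in modules) accepts device and dtype keyword arguments that control where the weight tensor is allocated at construction time; without them, parameters default to the CPU. A minimal sketch of what the patched constructor now relies on, assuming CUDA is available:

import torch
import torch.nn as nn

# Without `device`, parameters are created on the CPU and must be moved later.
q_cpu = nn.Linear(512, 512, bias=False)
print(q_cpu.weight.device)  # cpu

# With `device`, the weight is allocated on the target device directly,
# which is what the patched T5Attention projections now do.
if torch.cuda.is_available():
    q_gpu = nn.Linear(512, 512, bias=False, device="cuda:0")
    print(q_gpu.weight.device)  # cuda:0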
@@ -90,15 +91,16 @@ class T5Attention(nn.Module):


 class T5FeedForward(nn.Module):
-    def __init__(self, dim, dim_ffn, dropout=0.0):
+    def __init__(self, dim, dim_ffn, dropout=0.0, device="cuda:0"):
         super(T5FeedForward, self).__init__()
         self.dim = dim
         self.dim_ffn = dim_ffn
+        self.device = device

         # layers
-        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
-        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
-        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False, device=device), GELU())
+        self.fc1 = nn.Linear(dim, dim_ffn, bias=False, device=device)
+        self.fc2 = nn.Linear(dim_ffn, dim, bias=False, device=device)
         self.dropout = nn.Dropout(dropout)

     def forward(self, x):
@@ -110,7 +112,7 @@ class T5FeedForward(nn.Module):


 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0, device="cuda:0"):
         super(T5SelfAttention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
@@ -118,13 +120,14 @@ class T5SelfAttention(nn.Module):
         self.num_heads = num_heads
         self.num_buckets = num_buckets
         self.shared_pos = shared_pos
+        self.device = device

         # layers
         self.norm1 = T5LayerNorm(dim)
-        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, device)
         self.norm2 = T5LayerNorm(dim)
-        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
-        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, device)
+        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device)

     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
@@ -134,15 +137,16 @@ class T5SelfAttention(nn.Module):


 class T5RelativeEmbedding(nn.Module):
-    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128, device="cuda:0"):
         super(T5RelativeEmbedding, self).__init__()
         self.num_buckets = num_buckets
         self.num_heads = num_heads
         self.bidirectional = bidirectional
         self.max_dist = max_dist
+        self.device = device

         # layers
-        self.embedding = nn.Embedding(num_buckets, num_heads)
+        self.embedding = nn.Embedding(num_buckets, num_heads, device=device)

     def forward(self, lq, lk):
         device = self.embedding.weight.device
@@ -257,12 +261,12 @@ class WanTextEncoder(PreTrainedModel):
         self.shared_pos = shared_pos

         # layers
-        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
-        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
+        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, device=device)
+        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
         self.blocks = nn.ModuleList(
             [
-                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
+                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, device)
                 for _ in range(num_layers)
             ]
         )
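Taken together, the wan_text_encoder.py hunks thread the same device argument from the top-level encoder down through every submodule, so all parameters are born on the target device. A hypothetical miniature of that constructor chain; Block and Encoder are illustrative stand-ins, not classes from diffsynth_engine:

import torch
import torch.nn as nn

# Each level of the hierarchy accepts `device` and forwards it downward,
# mirroring WanTextEncoder -> T5SelfAttention -> T5Attention/T5FeedForward.
class Block(nn.Module):
    def __init__(self, dim, device="cuda:0"):
        super().__init__()
        self.fc = nn.Linear(dim, dim, bias=False, device=device)

class Encoder(nn.Module):
    def __init__(self, dim, num_layers, device="cuda:0"):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim, device) for _ in range(num_layers))

if torch.cuda.is_available():
    enc = Encoder(256, num_layers=4, device="cuda:0")
    assert all(p.device.type == "cuda" for p in enc.parameters())

    # An alternative that avoids threading the argument by hand: since
    # PyTorch 2.0, torch.device works as a context manager that sets the
    # default device for parameter creation inside the block.
    with torch.device("cuda:0"):
        lin = nn.Linear(256, 256, bias=False)  # no device kwarg needed
    assert lin.weight.device.type == "cuda"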
--- a/diffsynth_engine-0.6.1.dev30.dist-info/METADATA
+++ b/diffsynth_engine-0.6.1.dev31.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diffsynth_engine
-Version: 0.6.1.dev30
+Version: 0.6.1.dev31
 Author: MuseAI x ModelScope
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
--- a/diffsynth_engine-0.6.1.dev30.dist-info/RECORD
+++ b/diffsynth_engine-0.6.1.dev31.dist-info/RECORD
@@ -138,9 +138,9 @@ diffsynth_engine/models/vae/vae.py,sha256=1Hz5Yb6f8V-psC0qothfzg8EZBPVPpg9KGlSMD
 diffsynth_engine/models/wan/__init__.py,sha256=eYwZ2Upo2mTjaAcBWuSft1m4mLnqE47bz2V_u-WtkwQ,246
 diffsynth_engine/models/wan/wan_audio_encoder.py,sha256=i8mVu5lhVlTnzVTDcSv7qGC6HjB3MuS9hFVkUrw9458,13629
 diffsynth_engine/models/wan/wan_dit.py,sha256=MEt9eWy6djWT1dtlFEHP9Yevat4-M_LSzWRauNSIHck,21599
-diffsynth_engine/models/wan/wan_image_encoder.py,sha256=VE7crdTxOFN2UCMN2cQlvHB9BilSbKOBQYgnXgl4E2Y,14313
+diffsynth_engine/models/wan/wan_image_encoder.py,sha256=Vdd39lv_QvOsmPxihZWZZbpP-9QuCFpNJ39bdtI5qTQ,14314
 diffsynth_engine/models/wan/wan_s2v_dit.py,sha256=j63ulcWLY4XGITOKUMGX292LtSEtP-n8BTvqb98YExU,23615
-diffsynth_engine/models/wan/wan_text_encoder.py,sha256=OERlmwOqthAFPNnnT2sXJ4OjyyRmsRLx7VGp1zlBkLU,11021
+diffsynth_engine/models/wan/wan_text_encoder.py,sha256=ePeOifbTI_o650mckzugyWPuHn5vhM-uFMcDVCijxPM,11394
 diffsynth_engine/models/wan/wan_vae.py,sha256=dC7MoUFeXRL7SIY0LG1OOUiZW-pp9IbXCghutMxpXr4,38889
 diffsynth_engine/pipelines/__init__.py,sha256=jh-4LSJ0vqlXiT8BgFgRIQxuAr2atEPyHrxXWj-Ud1U,604
 diffsynth_engine/pipelines/base.py,sha256=ShRiX5MY6bUkRKfuGrA1aalAqeHyeZxhzT87Mwc30b4,17231
@@ -190,8 +190,8 @@ diffsynth_engine/utils/video.py,sha256=8FCaeqIdUsWMgWI_6SO9SPynsToGcLCQAVYFTc4CD
 diffsynth_engine/utils/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diffsynth_engine/utils/memory/linear_regression.py,sha256=oW_EQEw13oPoyUrxiL8A7Ksa5AuJ2ynI2qhCbfAuZbg,3930
 diffsynth_engine/utils/memory/memory_predcit_model.py,sha256=EXprSl_zlVjgfMWNXP-iw83Ot3hyMcgYaRPv-dvyL84,3943
-diffsynth_engine-0.6.1.dev30.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
-diffsynth_engine-0.6.1.dev30.dist-info/METADATA,sha256=z-j4fdSyJwgilKYRl-MrSlhicE8MJP9uvoGYYTFrYKk,1164
-diffsynth_engine-0.6.1.dev30.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-diffsynth_engine-0.6.1.dev30.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
-diffsynth_engine-0.6.1.dev30.dist-info/RECORD,,
+diffsynth_engine-0.6.1.dev31.dist-info/licenses/LICENSE,sha256=x7aBqQuVI0IYnftgoTPI_A0I_rjdjPPQkjnU6N2nikM,11346
+diffsynth_engine-0.6.1.dev31.dist-info/METADATA,sha256=PGHUdyy75RQEl6ownCDC66hY24x07mNdRA7oFszGvss,1164
+diffsynth_engine-0.6.1.dev31.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+diffsynth_engine-0.6.1.dev31.dist-info/top_level.txt,sha256=6zgbiIzEHLbhgDKRyX0uBJOV3F6VnGGBRIQvSiYYn6w,17
+diffsynth_engine-0.6.1.dev31.dist-info/RECORD,,