diffsynth-engine 0.6.1.dev34__py3-none-any.whl → 0.6.1.dev36__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,602 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, List, Tuple, Union, Optional
+from torch.nn.utils.rnn import pad_sequence
+import math
+
+from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.models.basic import attention as attention_ops
+from diffsynth_engine.models.basic.transformer_helper import RMSNorm
+from diffsynth_engine.utils.gguf import gguf_inference
+from diffsynth_engine.utils.fp8_linear import fp8_inference
+from diffsynth_engine.utils.parallel import (
+    cfg_parallel,
+    cfg_parallel_unshard,
+    sequence_parallel,
+    sequence_parallel_unshard,
+)
+
+
+class ZImageStateDictConverter(StateDictConverter):
+    def __init__(self):
+        pass
+
+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        state_dict_ = {}
+        for name, param in state_dict.items():
+            name_ = name
+            if "attention.to_out.0" in name:
+                name_ = name.replace("attention.to_out.0", "attention.to_out")
+            if "adaLN_modulation.0" in name:
+                name_ = name.replace("adaLN_modulation.0", "adaLN_modulation")
+            state_dict_[name_] = param
+        return state_dict_
+
+    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        state_dict = self._from_diffusers(state_dict)
+        return state_dict
+
+
+class ZImageTimestepEmbedder(nn.Module):
+    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256, device="cuda:0", dtype=torch.bfloat16):
+        super().__init__()
+        if mid_size is None:
+            mid_size = out_size
+        self.frequency_embedding_size = frequency_embedding_size
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, mid_size, bias=True, device=device, dtype=dtype),
+            nn.SiLU(),
+            nn.Linear(mid_size, out_size, bias=True, device=device, dtype=dtype),
+        )
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        with torch.amp.autocast("cuda", enabled=False):
+            half = dim // 2
+            freqs = torch.exp(
+                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+            )
+            args = t[:, None].float() * freqs[None]
+            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+            if dim % 2:
+                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        weight_dtype = self.mlp[0].weight.dtype
+        t_freq = t_freq.to(dtype=weight_dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class ZImageFeedForward(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, device="cuda:0", dtype=torch.bfloat16):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class ZImageFinalLayer(nn.Module):
+    def __init__(self, hidden_size, out_channels, adaln_embed_dim=256, device="cuda:0", dtype=torch.bfloat16):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
+        self.linear = nn.Linear(hidden_size, out_channels, bias=True, device=device, dtype=dtype)
+
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(min(hidden_size, adaln_embed_dim), hidden_size, bias=True, device=device, dtype=dtype),
+        )
+
+    def forward(self, x, c):
+        scale = 1.0 + self.adaLN_modulation(c)
+        x = self.norm_final(x) * scale.unsqueeze(1)
+        x = self.linear(x)
+        return x
+
+
+class ZImageRopeEmbedder:
+    def __init__(
+        self,
+        theta: float = 256.0,
+        axes_dims: List[int] = (16, 56, 56),
+        axes_lens: List[int] = (64, 128, 128),
+        device: str = "cuda:0",
+    ):
+        self.theta = theta
+        self.axes_dims = axes_dims
+        self.axes_lens = axes_lens
+        assert len(axes_dims) == len(axes_lens)
+        self.freqs_cis = None
+        self.device = device
+
+    def precompute_freqs_cis(self, dim: List[int], end: List[int], theta: float = 256.0):
+        freqs_cis = []
+        for i, (d, e) in enumerate(zip(dim, end)):
+            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
+            timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
+            freqs = torch.outer(timestep, freqs).float()
+            freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)
+            freqs_cis.append(freqs_cis_i)
+        return freqs_cis
+
+    def __call__(self, ids: torch.Tensor):
+        assert ids.ndim == 2
+        assert ids.shape[-1] == len(self.axes_dims)
+
+        if self.freqs_cis is None:
+            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+            self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
+        elif self.freqs_cis[0].device != ids.device:
+            self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
+
+        result = []
+        for i in range(len(self.axes_dims)):
+            index = ids[:, i]
+            result.append(self.freqs_cis[i][index])
+        return torch.cat(result, dim=-1)
+
+
+def apply_rotary_emb_zimage(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+    with torch.amp.autocast("cuda", enabled=False):
+        x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+    return x_out.type_as(x_in)
+
+
+class ZImageAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        head_dim: int,
+        qk_norm: bool = True,
+        eps: float = 1e-5,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        super().__init__()
+        self.heads = num_heads
+        self.head_dim = head_dim
+
+        self.to_q = nn.Linear(dim, dim, bias=False, device=device, dtype=dtype)
+        self.to_k = nn.Linear(dim, dim, bias=False, device=device, dtype=dtype)
+        self.to_v = nn.Linear(dim, dim, bias=False, device=device, dtype=dtype)
+
+        self.norm_q = RMSNorm(head_dim, eps=eps, device=device, dtype=dtype) if qk_norm else None
+        self.norm_k = RMSNorm(head_dim, eps=eps, device=device, dtype=dtype) if qk_norm else None
+
+        self.to_out = nn.Linear(dim, dim, bias=False, device=device, dtype=dtype)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        q = self.to_q(x)
+        k = self.to_k(x)
+        v = self.to_v(x)
+
+        q = q.view(*q.shape[:2], self.heads, self.head_dim)
+        k = k.view(*k.shape[:2], self.heads, self.head_dim)
+        v = v.view(*v.shape[:2], self.heads, self.head_dim)
+
+        if self.norm_q is not None:
+            q = self.norm_q(q)
+        if self.norm_k is not None:
+            k = self.norm_k(k)
+
+        if freqs_cis is not None:
+            q = apply_rotary_emb_zimage(q, freqs_cis)
+            k = apply_rotary_emb_zimage(k, freqs_cis)
+
+        out = attention_ops.attention(q, k, v, attn_mask=attn_mask, **kwargs)
+
+        out = out.flatten(2)
+        out = self.to_out(out)
+        return out
+
+
+class ZImageTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        norm_eps: float,
+        qk_norm: bool,
+        modulation: bool = True,
+        adaln_embed_dim: int = 256,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.modulation = modulation
+
+        self.attention = ZImageAttention(
+            dim=dim,
+            num_heads=n_heads,
+            head_dim=dim // n_heads,
+            qk_norm=qk_norm,
+            eps=1e-5,
+            device=device,
+            dtype=dtype,
+        )
+
+        self.feed_forward = ZImageFeedForward(dim=dim, hidden_dim=int(dim / 3 * 8), device=device, dtype=dtype)
+
+        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)
+        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)
+        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype)
+
+        if modulation:
+            self.adaLN_modulation = nn.Linear(min(dim, adaln_embed_dim), 4 * dim, bias=True, device=device, dtype=dtype)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        adaln_input: Optional[torch.Tensor] = None,
+    ):
+        if self.modulation:
+            assert adaln_input is not None
+            mod_output = self.adaLN_modulation(adaln_input)
+            mod_output = mod_output.unsqueeze(1)
+            scale_msa, gate_msa, scale_mlp, gate_mlp = mod_output.chunk(4, dim=2)
+
+            gate_msa = gate_msa.tanh()
+            gate_mlp = gate_mlp.tanh()
+            scale_msa = 1.0 + scale_msa
+            scale_mlp = 1.0 + scale_mlp
+
+            attn_out = self.attention(self.attention_norm1(x) * scale_msa, freqs_cis=freqs_cis, attn_mask=attn_mask)
+            x = x + gate_msa * self.attention_norm2(attn_out)
+
+            ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp)
+            x = x + gate_mlp * self.ffn_norm2(ffn_out)
+        else:
+            attn_out = self.attention(self.attention_norm1(x), freqs_cis=freqs_cis, attn_mask=attn_mask)
+            x = x + self.attention_norm2(attn_out)
+
+            ffn_out = self.feed_forward(self.ffn_norm1(x))
+            x = x + self.ffn_norm2(ffn_out)
+
+        return x
+
+
+class ZImageDiT(PreTrainedModel):
+    converter = ZImageStateDictConverter()
+    _supports_parallelization = True
+
+    def __init__(
+        self,
+        all_patch_size=(2,),
+        all_f_patch_size=(1,),
+        in_channels=16,
+        dim=3840,
+        n_layers=30,
+        n_refiner_layers=2,
+        n_heads=30,
+        n_kv_heads=30,
+        norm_eps=1e-5,
+        qk_norm=True,
+        cap_feat_dim=2560,
+        rope_theta=256.0,
+        t_scale=1000.0,
+        axes_dims=[32, 48, 48],
+        axes_lens=[1024, 512, 512],
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.bfloat16,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = in_channels
+        self.all_patch_size = all_patch_size
+        self.all_f_patch_size = all_f_patch_size
+        self.dim = dim
+        self.n_heads = n_heads
+        self.rope_theta = rope_theta
+        self.t_scale = t_scale
+        self.dtype = dtype
+        self.device = device
+        self.ADALN_EMBED_DIM = 256
+        self.SEQ_MULTI_OF = 32
+
+        all_x_embedder = {}
+        all_final_layer = {}
+        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size):
+            key = f"{patch_size}-{f_patch_size}"
+            all_x_embedder[key] = nn.Linear(
+                f_patch_size * patch_size * patch_size * in_channels, dim, bias=True, device=device, dtype=dtype
+            )
+            all_final_layer[key] = ZImageFinalLayer(
+                dim, patch_size * patch_size * f_patch_size * in_channels, device=device, dtype=dtype
+            )
+
+        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
+        self.all_final_layer = nn.ModuleDict(all_final_layer)
+
+        self.noise_refiner = nn.ModuleList(
+            [
+                ZImageTransformerBlock(
+                    dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, device=device, dtype=dtype
+                )
+                for _ in range(n_refiner_layers)
+            ]
+        )
+        self.context_refiner = nn.ModuleList(
+            [
+                ZImageTransformerBlock(
+                    dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False, device=device, dtype=dtype
+                )
+                for _ in range(n_refiner_layers)
+            ]
+        )
+
+        self.t_embedder = ZImageTimestepEmbedder(
+            min(dim, self.ADALN_EMBED_DIM), mid_size=1024, device=device, dtype=dtype
+        )
+
+        self.cap_embedder = nn.Sequential(
+            RMSNorm(cap_feat_dim, eps=norm_eps, device=device, dtype=dtype),
+            nn.Linear(cap_feat_dim, dim, bias=True, device=device, dtype=dtype),
+        )
+
+        self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+        self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+
+        self.layers = nn.ModuleList(
+            [
+                ZImageTransformerBlock(dim, n_heads, n_kv_heads, norm_eps, qk_norm, device=device, dtype=dtype)
+                for _ in range(n_layers)
+            ]
+        )
+
+        self.rope_embedder = ZImageRopeEmbedder(
+            theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens, device=device
+        )
+
+    @staticmethod
+    def create_coordinate_grid(size, start=None, device=None):
+        if start is None:
+            start = (0 for _ in size)
+        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
+        grids = torch.meshgrid(axes, indexing="ij")
+        return torch.stack(grids, dim=-1)
+
+    def patchify_and_embed(self, all_image, all_cap_feats, patch_size, f_patch_size):
+        pH = pW = patch_size
+        pF = f_patch_size
+        device = all_image[0].device
+
+        all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask = [], [], [], []
+        all_cap_feats_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
+
+        for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
+            cap_ori_len = len(cap_feat)
+            cap_padding_len = (-cap_ori_len) % self.SEQ_MULTI_OF
+            cap_padded_pos_ids = self.create_coordinate_grid(
+                size=(cap_ori_len + cap_padding_len, 1, 1), start=(1, 0, 0), device=device
+            ).flatten(0, 2)
+            all_cap_pos_ids.append(cap_padded_pos_ids)
+
+            cap_pad_mask = torch.cat(
+                [
+                    torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )
+            all_cap_pad_mask.append(cap_pad_mask)
+
+            cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
+            all_cap_feats_out.append(cap_padded_feat)
+
+            C, F, H, W = image.size()
+            all_image_size.append((F, H, W))
+            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+
+            image_ori_len = len(image)
+            image_padding_len = (-image_ori_len) % self.SEQ_MULTI_OF
+
+            image_ori_pos_ids = self.create_coordinate_grid(
+                size=(F_tokens, H_tokens, W_tokens), start=(cap_ori_len + cap_padding_len + 1, 0, 0), device=device
+            ).flatten(0, 2)
+
+            if image_padding_len > 0:
+                pad_grid = self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device).flatten(0, 2)
+                image_padded_pos_ids = torch.cat([image_ori_pos_ids, pad_grid.repeat(image_padding_len, 1)], dim=0)
+                image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
+                image_pad_mask = torch.cat(
+                    [
+                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                    ],
+                    dim=0,
+                )
+            else:
+                image_padded_pos_ids = image_ori_pos_ids
+                image_padded_feat = image
+                image_pad_mask = torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
+
+            all_image_pos_ids.append(image_padded_pos_ids)
+            all_image_pad_mask.append(image_pad_mask)
+            all_image_out.append(image_padded_feat)
+
+        return (
+            all_image_out,
+            all_cap_feats_out,
+            all_image_size,
+            all_image_pos_ids,
+            all_cap_pos_ids,
+            all_image_pad_mask,
+            all_cap_pad_mask,
+        )
+
+    def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
+        pH, pW, pF = patch_size, patch_size, f_patch_size
+        bsz = len(x)
+        for i in range(bsz):
+            F, H, W = size[i]
+            ori_len = (F // pF) * (H // pH) * (W // pW)
+            x[i] = (
+                x[i][:ori_len]
+                .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
+                .permute(6, 0, 3, 1, 4, 2, 5)
+                .reshape(self.out_channels, F, H, W)
+            )
+        return x
+
+    def forward(
+        self,
+        image: Union[torch.Tensor, List[torch.Tensor]],
+        timestep: torch.Tensor,
+        cap_feats: Union[torch.Tensor, List[torch.Tensor]],
+        patch_size: int = 2,
+        f_patch_size: int = 1,
+    ):
+        if isinstance(image, torch.Tensor):
+            image = list(image.unbind(0))
+        if isinstance(cap_feats, torch.Tensor):
+            cap_feats = list(cap_feats.unbind(0))
+
+        use_cfg = len(image) > 1
+        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
+        with (
+            fp8_inference(fp8_linear_enabled),
+            gguf_inference(),
+            cfg_parallel((image, timestep, cap_feats), use_cfg=use_cfg),
+        ):
+            bsz = len(image)
+            device = image[0].device
+
+            t = timestep * self.t_scale
+            t = self.t_embedder(t)
+
+            (x, cap_feats_processed, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = (
+                self.patchify_and_embed(image, cap_feats, patch_size, f_patch_size)
+            )
+
+            x_item_seqlens = [len(_) for _ in x]
+            x_max_item_seqlen = max(x_item_seqlens)
+            x = torch.cat(x, dim=0)
+            x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+            adaln_input = t.type_as(x)
+
+            x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+
+            x = list(x.split(x_item_seqlens, dim=0))
+            x_freqs_cis = list(
+                self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)
+            )
+
+            x = pad_sequence(x, batch_first=True, padding_value=0.0)
+            x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+            x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
+
+            x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(x_item_seqlens):
+                x_attn_mask[i, :seq_len] = 1
+
+            x_attn_mask_4d = x_attn_mask.unsqueeze(1).unsqueeze(1)
+
+            for layer in self.noise_refiner:
+                x = layer(x, x_attn_mask_4d, x_freqs_cis, adaln_input)
+
+            cap_item_seqlens = [len(_) for _ in cap_feats_processed]
+            cap_max_item_seqlen = max(cap_item_seqlens)
+            cap_feats_tensor = torch.cat(cap_feats_processed, dim=0)
+            cap_feats_tensor = self.cap_embedder(cap_feats_tensor)
+            mask_tmp = torch.cat(cap_inner_pad_mask)
+            target_len = mask_tmp.sum()
+            if target_len > 0:
+                cap_feats_tensor[mask_tmp] = self.cap_pad_token.to(dtype=cap_feats_tensor.dtype).expand(target_len, -1)
+
+            cap_feats_list = list(cap_feats_tensor.split(cap_item_seqlens, dim=0))
+            cap_freqs_cis = list(
+                self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
+            )
+
+            cap_feats_padded = pad_sequence(cap_feats_list, batch_first=True, padding_value=0.0)
+            cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+            cap_freqs_cis = cap_freqs_cis[:, : cap_feats_padded.shape[1]]
+
+            cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(cap_item_seqlens):
+                cap_attn_mask[i, :seq_len] = 1
+            cap_attn_mask_4d = cap_attn_mask.unsqueeze(1).unsqueeze(1)
+
+            for layer in self.context_refiner:
+                cap_feats_padded = layer(cap_feats_padded, cap_attn_mask_4d, cap_freqs_cis, adaln_input=None)
+
+            unified = []
+            unified_freqs_cis = []
+            for i in range(bsz):
+                x_len = x_item_seqlens[i]
+                cap_len = cap_item_seqlens[i]
+                unified.append(torch.cat([x[i][:x_len], cap_feats_padded[i][:cap_len]]))
+                unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+
+            unified_item_seqlens = [len(_) for _ in unified]
+            unified_max_item_seqlen = max(unified_item_seqlens)
+
+            unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+            unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+
+            unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(unified_item_seqlens):
+                unified_attn_mask[i, :seq_len] = 1
+            unified_attn_mask_4d = unified_attn_mask.unsqueeze(1).unsqueeze(1)
+
+            with sequence_parallel((unified, unified_freqs_cis), seq_dims=(1, 1)):
+                for layer in self.layers:
+                    unified = layer(unified, unified_attn_mask_4d, unified_freqs_cis, adaln_input)
+                (unified,) = sequence_parallel_unshard((unified,), seq_dims=(1,), seq_lens=(unified.shape[1],))
+            unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
+            unified_list = list(unified.unbind(dim=0))
+
+            output = self.unpatchify(unified_list, x_size, patch_size, f_patch_size)
+
+            (output,) = cfg_parallel_unshard((output,), use_cfg=use_cfg)
+
+            return output
+
+    @classmethod
+    def from_state_dict(
+        cls,
+        state_dict: Dict[str, torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+        **kwargs,
+    ):
+        model = cls(device="meta", dtype=dtype, **kwargs)
+        model = model.requires_grad_(False)
+        model.load_state_dict(state_dict, assign=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
+        return model
+
+    def compile_repeated_blocks(self, *args, **kwargs):
+        for block in self.noise_refiner:
+            block.compile(*args, **kwargs)
+        for block in self.context_refiner:
+            block.compile(*args, **kwargs)
+        for block in self.layers:
+            block.compile(*args, **kwargs)
+
+    def get_fsdp_module_cls(self):
+        return {ZImageTransformerBlock}
@@ -6,6 +6,7 @@ from .wan_video import WanVideoPipeline
 from .wan_s2v import WanSpeech2VideoPipeline
 from .qwen_image import QwenImagePipeline
 from .hunyuan3d_shape import Hunyuan3DShapePipeline
+from .z_image import ZImagePipeline

 __all__ = [
     "BasePipeline",
@@ -17,4 +18,5 @@ __all__ = [
     "WanSpeech2VideoPipeline",
     "QwenImagePipeline",
     "Hunyuan3DShapePipeline",
+    "ZImagePipeline",
 ]
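
The core layout transform introduced by the first hunk is the patchify/unpatchify pair in ZImageDiT.patchify_and_embed and ZImageDiT.unpatchify. A minimal standalone sketch of that round trip (plain PyTorch, not code from the package; the shapes below are chosen only for illustration):

    import torch

    # Illustrative shapes; the real model uses in_channels=16 latents with patch_size=2, f_patch_size=1.
    C, F, H, W = 16, 1, 8, 8
    pF, pH, pW = 1, 2, 2
    image = torch.randn(C, F, H, W)

    # patchify_and_embed: split each axis into (tokens, patch), move the token axes first,
    # then flatten to one row of length pF * pH * pW * C per patch token.
    tokens = (
        image.view(C, F // pF, pF, H // pH, pH, W // pW, pW)
        .permute(1, 3, 5, 2, 4, 6, 0)
        .reshape((F // pF) * (H // pH) * (W // pW), pF * pH * pW * C)
    )

    # unpatchify: the inverse permutation restores the original (C, F, H, W) latent.
    restored = (
        tokens.view(F // pF, H // pH, W // pW, pF, pH, pW, C)
        .permute(6, 0, 3, 1, 4, 2, 5)
        .reshape(C, F, H, W)
    )
    assert torch.equal(image, restored)

Because the second permute is the exact inverse of the first, the assertion holds for any shapes divisible by the patch sizes; the model relies on this when it slices off the padded tokens and reshapes the final-layer output back into latents.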