diffsynth-engine 0.6.1.dev41__py3-none-any.whl → 0.6.1.dev42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1132 @@
1
+ import math
2
+ from typing import List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+ from diffsynth_engine.models.base import PreTrainedModel
10
+ from diffsynth_engine.models.basic.transformer_helper import RMSNorm
11
+ from diffsynth_engine.models.basic import attention as attention_ops
12
+ from diffsynth_engine.utils.gguf import gguf_inference
13
+ from diffsynth_engine.utils.fp8_linear import fp8_inference
14
+ from diffsynth_engine.utils.parallel import (
15
+ cfg_parallel,
16
+ cfg_parallel_unshard,
17
+ sequence_parallel,
18
+ sequence_parallel_unshard,
19
+ )
20
+
21
+
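+ # ADALN_EMBED_DIM caps the width of the timestep conditioning fed to the adaLN
+ # modulation layers, SEQ_MULTI_OF is the multiple every token sequence is padded to,
+ # and X_PAD_DIM is the placeholder feature width used for absent images in omni mode
+ # (it matches the patchified feature size of the default 2x2x1 patch with 16 channels).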
22
+ ADALN_EMBED_DIM = 256
23
+ SEQ_MULTI_OF = 32
24
+ X_PAD_DIM = 64
25
+
26
+
27
+ class TimestepEmbedder(nn.Module):
28
+ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
29
+ super().__init__()
30
+ if mid_size is None:
31
+ mid_size = out_size
32
+ self.mlp = nn.Sequential(
33
+ nn.Linear(
34
+ frequency_embedding_size,
35
+ mid_size,
36
+ bias=True,
37
+ ),
38
+ nn.SiLU(),
39
+ nn.Linear(
40
+ mid_size,
41
+ out_size,
42
+ bias=True,
43
+ ),
44
+ )
45
+
46
+ self.frequency_embedding_size = frequency_embedding_size
47
+
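+ # Sinusoidal timestep embedding: cos/sin of t times log-spaced frequencies,
+ # computed in fp32 with autocast disabled.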
48
+ @staticmethod
49
+ def timestep_embedding(t, dim, max_period=10000):
50
+ with torch.amp.autocast("cuda", enabled=False):
51
+ half = dim // 2
52
+ freqs = torch.exp(
53
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
54
+ )
55
+ args = t[:, None].float() * freqs[None]
56
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
57
+ if dim % 2:
58
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
59
+ return embedding
60
+
61
+ def forward(self, t):
62
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
63
+ t_emb = self.mlp(t_freq.to(torch.bfloat16))
64
+ return t_emb
65
+
66
+
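+ # SwiGLU-style feed-forward: w2(silu(w1(x)) * w3(x)), all projections without bias.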
67
+ class FeedForward(nn.Module):
68
+ def __init__(self, dim: int, hidden_dim: int):
69
+ super().__init__()
70
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
71
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
72
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
73
+
74
+ def _forward_silu_gating(self, x1, x3):
75
+ return F.silu(x1) * x3
76
+
77
+ def forward(self, x):
78
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
79
+
80
+
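+ # Self-attention with RMSNorm applied per head to Q and K and rotary position
+ # embeddings applied in complex form; attention_mask marks valid (non-padded) tokens.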
81
+ class Attention(torch.nn.Module):
82
+
83
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
84
+ super().__init__()
85
+ dim_inner = head_dim * num_heads
86
+ kv_dim = kv_dim if kv_dim is not None else q_dim
87
+ self.num_heads = num_heads
88
+ self.head_dim = head_dim
89
+
90
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
91
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
92
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
93
+ self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)])
94
+
95
+ self.norm_q = RMSNorm(head_dim, eps=1e-5)
96
+ self.norm_k = RMSNorm(head_dim, eps=1e-5)
97
+
98
+ def forward(self, hidden_states, freqs_cis, attention_mask):
99
+ query = self.to_q(hidden_states)
100
+ key = self.to_k(hidden_states)
101
+ value = self.to_v(hidden_states)
102
+
103
+ query = query.unflatten(-1, (self.num_heads, -1))
104
+ key = key.unflatten(-1, (self.num_heads, -1))
105
+ value = value.unflatten(-1, (self.num_heads, -1))
106
+
107
+ # Apply Norms
108
+ if self.norm_q is not None:
109
+ query = self.norm_q(query)
110
+ if self.norm_k is not None:
111
+ key = self.norm_k(key)
112
+
113
+ # Apply RoPE
114
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
115
+ with torch.amp.autocast("cuda", enabled=False):
116
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
117
+ freqs_cis = freqs_cis.unsqueeze(2)
118
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
119
+ return x_out.type_as(x_in)  # cast back to the input dtype after fp32 RoPE math
120
+
121
+ if freqs_cis is not None:
122
+ query = apply_rotary_emb(query, freqs_cis)
123
+ key = apply_rotary_emb(key, freqs_cis)
124
+
125
+ # Cast to correct dtype
126
+ dtype = query.dtype
127
+ query, key = query.to(dtype), key.to(dtype)
128
+
129
+ # Compute joint attention
130
+ if attention_mask.shape[0] > 1:
131
+ attention_mask = attention_mask[:1]
132
+ hidden_states = attention_ops.attention(query, key, value, attn_mask=attention_mask)
133
+
134
+ # Reshape back
135
+ hidden_states = hidden_states.flatten(2, 3)
136
+ hidden_states = hidden_states.to(dtype)
137
+
138
+ output = self.to_out[0](hidden_states)
139
+ if len(self.to_out) > 1: # dropout
140
+ output = self.to_out[1](output)
141
+
142
+ return output
143
+
144
+
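+ # Expands two per-batch modulation vectors over the sequence dimension and picks,
+ # per token, the "noisy" variant where noise_mask == 1 and the "clean" variant otherwise.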
145
+ def select_per_token(
146
+ value_noisy: torch.Tensor,
147
+ value_clean: torch.Tensor,
148
+ noise_mask: torch.Tensor,
149
+ seq_len: int,
150
+ ) -> torch.Tensor:
151
+ noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
152
+ return torch.where(
153
+ noise_mask_expanded == 1,
154
+ value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
155
+ value_clean.unsqueeze(1).expand(-1, seq_len, -1),
156
+ )
157
+
158
+
159
+ class ZImageTransformerBlock(nn.Module):
160
+ def __init__(
161
+ self,
162
+ layer_id: int,
163
+ dim: int,
164
+ n_heads: int,
165
+ n_kv_heads: int,
166
+ norm_eps: float,
167
+ qk_norm: bool,
168
+ modulation=True,
169
+ ):
170
+ super().__init__()
171
+ self.dim = dim
172
+ self.head_dim = dim // n_heads
173
+
174
+ # Attention is implemented by the local Attention module above (QK RMSNorm + RoPE)
175
+ # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
176
+ self.attention = Attention(
177
+ q_dim=dim,
178
+ num_heads=n_heads,
179
+ head_dim=dim // n_heads,
180
+ )
181
+
182
+ self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
183
+ self.layer_id = layer_id
184
+
185
+ self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
186
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
187
+
188
+ self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
189
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
190
+
191
+ self.modulation = modulation
192
+ if modulation:
193
+ self.adaLN_modulation = nn.Sequential(
194
+ nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
195
+ )
196
+
197
+ def forward(
198
+ self,
199
+ x: torch.Tensor,
200
+ attn_mask: torch.Tensor,
201
+ freqs_cis: torch.Tensor,
202
+ adaln_input: Optional[torch.Tensor] = None,
203
+ noise_mask: Optional[torch.Tensor] = None,
204
+ adaln_noisy: Optional[torch.Tensor] = None,
205
+ adaln_clean: Optional[torch.Tensor] = None,
206
+ ):
207
+ if self.modulation:
208
+ seq_len = x.shape[1]
209
+
210
+ if noise_mask is not None:
211
+ # Per-token modulation: different modulation for noisy/clean tokens
212
+ mod_noisy = self.adaLN_modulation(adaln_noisy)
213
+ mod_clean = self.adaLN_modulation(adaln_clean)
214
+
215
+ scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
216
+ scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
217
+
218
+ gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
219
+ gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
220
+
221
+ scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
222
+ scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
223
+
224
+ scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
225
+ scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
226
+ gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
227
+ gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
228
+ else:
229
+ # Global modulation: same modulation for all tokens (avoid double select)
230
+ mod = self.adaLN_modulation(adaln_input)
231
+ scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
232
+ gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
233
+ scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
234
+
235
+ # Attention block
236
+ attn_out = self.attention(
237
+ self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
238
+ )
239
+ x = x + gate_msa * self.attention_norm2(attn_out)
240
+
241
+ # FFN block
242
+ x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
243
+ else:
244
+ # Attention block
245
+ attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
246
+ x = x + self.attention_norm2(attn_out)
247
+
248
+ # FFN block
249
+ x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
250
+
251
+ return x
252
+
253
+
254
+ class FinalLayer(nn.Module):
255
+ def __init__(self, hidden_size, out_channels):
256
+ super().__init__()
257
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
258
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
259
+
260
+ self.adaLN_modulation = nn.Sequential(
261
+ nn.SiLU(),
262
+ nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
263
+ )
264
+
265
+ def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
266
+ seq_len = x.shape[1]
267
+
268
+ if noise_mask is not None:
269
+ # Per-token modulation
270
+ scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
271
+ scale_clean = 1.0 + self.adaLN_modulation(c_clean)
272
+ scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
273
+ else:
274
+ # Original global modulation
275
+ assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
276
+ scale = 1.0 + self.adaLN_modulation(c)
277
+ scale = scale.unsqueeze(1)
278
+
279
+ x = self.norm_final(x) * scale
280
+ x = self.linear(x)
281
+ return x
282
+
283
+
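+ # Multi-axis rotary embedding: each positional axis gets its own complex frequency
+ # table, and __call__ gathers the per-token entries and concatenates them along the
+ # last dimension.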
284
+ class RopeEmbedder:
285
+ def __init__(
286
+ self,
287
+ theta: float = 256.0,
288
+ axes_dims: List[int] = (16, 56, 56),
289
+ axes_lens: List[int] = (64, 128, 128),
290
+ ):
291
+ self.theta = theta
292
+ self.axes_dims = axes_dims
293
+ self.axes_lens = axes_lens
294
+ assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
295
+ self.freqs_cis = None
296
+
297
+ @staticmethod
298
+ def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
299
+ with torch.device("cpu"):
300
+ freqs_cis = []
301
+ for i, (d, e) in enumerate(zip(dim, end)):
302
+ freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
303
+ timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
304
+ freqs = torch.outer(timestep, freqs).float()
305
+ freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
306
+ freqs_cis.append(freqs_cis_i)
307
+
308
+ return freqs_cis
309
+
310
+ def __call__(self, ids: torch.Tensor):
311
+ assert ids.ndim == 2
312
+ assert ids.shape[-1] == len(self.axes_dims)
313
+ device = ids.device
314
+
315
+ if self.freqs_cis is None:
316
+ self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
317
+ self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
318
+
319
+ result = []
320
+ for i in range(len(self.axes_dims)):
321
+ index = ids[:, i]
322
+ result.append(self.freqs_cis[i][index])
323
+ return torch.cat(result, dim=-1)
324
+
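+ # Shape note (for the DiT defaults axes_dims=[32, 48, 48], head_dim=128): calling the
+ # embedder with ids of shape (num_tokens, 3) yields a complex64 tensor of shape
+ # (num_tokens, 64), i.e. one rotation per pair of head channels in apply_rotary_emb.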
325
+
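+ # DiT backbone: dedicated refiner stacks for noisy image tokens, caption tokens and
+ # (optionally) SigLIP tokens, followed by a shared stack of joint transformer layers
+ # that runs on the concatenated "unified" sequence.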
326
+ class ZImageOmniBaseDiT(PreTrainedModel):
327
+ _supports_gradient_checkpointing = True
328
+ _no_split_modules = ["ZImageTransformerBlock"]
329
+
330
+ def __init__(
331
+ self,
332
+ all_patch_size=(2,),
333
+ all_f_patch_size=(1,),
334
+ in_channels=16,
335
+ dim=3840,
336
+ n_layers=30,
337
+ n_refiner_layers=2,
338
+ n_heads=30,
339
+ n_kv_heads=30,
340
+ norm_eps=1e-5,
341
+ qk_norm=True,
342
+ cap_feat_dim=2560,
343
+ rope_theta=256.0,
344
+ t_scale=1000.0,
345
+ axes_dims=[32, 48, 48],
346
+ axes_lens=[1024, 512, 512],
347
+ siglip_feat_dim=1152,
348
+ **kwargs,
349
+ ) -> None:
350
+ super().__init__()
351
+ self.in_channels = in_channels
352
+ self.out_channels = in_channels
353
+ self.all_patch_size = all_patch_size
354
+ self.all_f_patch_size = all_f_patch_size
355
+ self.dim = dim
356
+ self.n_heads = n_heads
357
+
358
+ self.rope_theta = rope_theta
359
+ self.t_scale = t_scale
360
+ self.gradient_checkpointing = False
361
+
362
+ assert len(all_patch_size) == len(all_f_patch_size)
363
+
364
+ all_x_embedder = {}
365
+ all_final_layer = {}
366
+ for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
367
+ x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
368
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
369
+
370
+ final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
371
+ all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
372
+
373
+ self.all_x_embedder = nn.ModuleDict(all_x_embedder)
374
+ self.all_final_layer = nn.ModuleDict(all_final_layer)
375
+ self.noise_refiner = nn.ModuleList(
376
+ [
377
+ ZImageTransformerBlock(
378
+ 1000 + layer_id,
379
+ dim,
380
+ n_heads,
381
+ n_kv_heads,
382
+ norm_eps,
383
+ qk_norm,
384
+ modulation=True,
385
+ )
386
+ for layer_id in range(n_refiner_layers)
387
+ ]
388
+ )
389
+ self.context_refiner = nn.ModuleList(
390
+ [
391
+ ZImageTransformerBlock(
392
+ layer_id,
393
+ dim,
394
+ n_heads,
395
+ n_kv_heads,
396
+ norm_eps,
397
+ qk_norm,
398
+ modulation=False,
399
+ )
400
+ for layer_id in range(n_refiner_layers)
401
+ ]
402
+ )
403
+ self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
404
+ self.cap_embedder = nn.Sequential(
405
+ RMSNorm(cap_feat_dim, eps=norm_eps),
406
+ nn.Linear(cap_feat_dim, dim, bias=True),
407
+ )
408
+
409
+ # Optional SigLIP components (for Omni variant)
410
+ self.siglip_feat_dim = siglip_feat_dim
411
+ if siglip_feat_dim is not None:
412
+ self.siglip_embedder = nn.Sequential(
413
+ RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
414
+ )
415
+ self.siglip_refiner = nn.ModuleList(
416
+ [
417
+ ZImageTransformerBlock(
418
+ 2000 + layer_id,
419
+ dim,
420
+ n_heads,
421
+ n_kv_heads,
422
+ norm_eps,
423
+ qk_norm,
424
+ modulation=False,
425
+ )
426
+ for layer_id in range(n_refiner_layers)
427
+ ]
428
+ )
429
+ self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
430
+ else:
431
+ self.siglip_embedder = None
432
+ self.siglip_refiner = None
433
+ self.siglip_pad_token = None
434
+
435
+ self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
436
+ self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
437
+
438
+ self.layers = nn.ModuleList(
439
+ [
440
+ ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
441
+ for layer_id in range(n_layers)
442
+ ]
443
+ )
444
+ head_dim = dim // n_heads
445
+ assert head_dim == sum(axes_dims)
446
+ self.axes_dims = axes_dims
447
+ self.axes_lens = axes_lens
448
+
449
+ self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
450
+
451
+ def unpatchify(
452
+ self,
453
+ x: List[torch.Tensor],
454
+ size: List[Tuple],
455
+ patch_size=2,
456
+ f_patch_size=1,
457
+ x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
458
+ ) -> List[torch.Tensor]:
459
+ pH = pW = patch_size
460
+ pF = f_patch_size
461
+ bsz = len(x)
462
+ assert len(size) == bsz
463
+
464
+ if x_pos_offsets is not None:
465
+ # Omni: extract target image from unified sequence (cond_images + target)
466
+ result = []
467
+ for i in range(bsz):
468
+ unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
469
+ cu_len = 0
470
+ x_item = None
471
+ for j in range(len(size[i])):
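+ # size[i][j] is None for absent images, which occupy SEQ_MULTI_OF placeholder tokens
+ # in the unified sequence and are skipped here.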
472
+ if size[i][j] is None:
473
+ ori_len = 0
474
+ pad_len = SEQ_MULTI_OF
475
+ cu_len += pad_len + ori_len
476
+ else:
477
+ F, H, W = size[i][j]
478
+ ori_len = (F // pF) * (H // pH) * (W // pW)
479
+ pad_len = (-ori_len) % SEQ_MULTI_OF
480
+ x_item = (
481
+ unified_x[cu_len : cu_len + ori_len]
482
+ .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
483
+ .permute(6, 0, 3, 1, 4, 2, 5)
484
+ .reshape(self.out_channels, F, H, W)
485
+ )
486
+ cu_len += ori_len + pad_len
487
+ result.append(x_item)  # keep only the last (target) image for this batch item
488
+ return result
489
+ else:
490
+ # Original mode: simple unpatchify
491
+ for i in range(bsz):
492
+ F, H, W = size[i]
493
+ ori_len = (F // pF) * (H // pH) * (W // pW)
494
+ # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
495
+ x[i] = (
496
+ x[i][:ori_len]
497
+ .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
498
+ .permute(6, 0, 3, 1, 4, 2, 5)
499
+ .reshape(self.out_channels, F, H, W)
500
+ )
501
+ return x
502
+
503
+ @staticmethod
504
+ def create_coordinate_grid(size, start=None, device=None):
505
+ if start is None:
506
+ start = (0 for _ in size)
507
+
508
+ axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
509
+ grids = torch.meshgrid(axes, indexing="ij")
510
+ return torch.stack(grids, dim=-1)
511
+
512
+ def patchify_and_embed(
513
+ self,
514
+ all_image: List[torch.Tensor],
515
+ all_cap_feats: List[torch.Tensor],
516
+ patch_size: int = 2,
517
+ f_patch_size: int = 1,
518
+ ):
519
+ pH = pW = patch_size
520
+ pF = f_patch_size
521
+ device = all_image[0].device
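+ # Position-id layout along the first RoPE axis: caption tokens occupy 1..cap_len,
+ # image tokens start right after the padded caption, and image padding tokens are
+ # pinned to position 0.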
522
+
523
+ all_image_out = []
524
+ all_image_size = []
525
+ all_image_pos_ids = []
526
+ all_image_pad_mask = []
527
+ all_cap_pos_ids = []
528
+ all_cap_pad_mask = []
529
+ all_cap_feats_out = []
530
+
531
+ for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
532
+ ### Process Caption
533
+ cap_ori_len = len(cap_feat)
534
+ cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
535
+ # padded position ids
536
+ cap_padded_pos_ids = self.create_coordinate_grid(
537
+ size=(cap_ori_len + cap_padding_len, 1, 1),
538
+ start=(1, 0, 0),
539
+ device=device,
540
+ ).flatten(0, 2)
541
+ all_cap_pos_ids.append(cap_padded_pos_ids)
542
+ # pad mask
543
+ all_cap_pad_mask.append(
544
+ torch.cat(
545
+ [
546
+ torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
547
+ torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
548
+ ],
549
+ dim=0,
550
+ )
551
+ )
552
+ # padded feature
553
+ cap_padded_feat = torch.cat(
554
+ [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
555
+ dim=0,
556
+ )
557
+ all_cap_feats_out.append(cap_padded_feat)
558
+
559
+ ### Process Image
560
+ C, F, H, W = image.size()
561
+ all_image_size.append((F, H, W))
562
+ F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
563
+
564
+ image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
565
+ # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
566
+ image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
567
+
568
+ image_ori_len = len(image)
569
+ image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
570
+
571
+ image_ori_pos_ids = self.create_coordinate_grid(
572
+ size=(F_tokens, H_tokens, W_tokens),
573
+ start=(cap_ori_len + cap_padding_len + 1, 0, 0),
574
+ device=device,
575
+ ).flatten(0, 2)
576
+ image_padding_pos_ids = (
577
+ self.create_coordinate_grid(
578
+ size=(1, 1, 1),
579
+ start=(0, 0, 0),
580
+ device=device,
581
+ )
582
+ .flatten(0, 2)
583
+ .repeat(image_padding_len, 1)
584
+ )
585
+ image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
586
+ all_image_pos_ids.append(image_padded_pos_ids)
587
+ # pad mask
588
+ all_image_pad_mask.append(
589
+ torch.cat(
590
+ [
591
+ torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
592
+ torch.ones((image_padding_len,), dtype=torch.bool, device=device),
593
+ ],
594
+ dim=0,
595
+ )
596
+ )
597
+ # padded feature
598
+ image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
599
+ all_image_out.append(image_padded_feat)
600
+
601
+ return (
+ all_image_out,
+ all_cap_feats_out,
+ all_image_size,
+ all_image_pos_ids,
+ all_cap_pos_ids,
+ all_image_pad_mask,
+ all_cap_pad_mask,
+ )
608
+
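+ # Image-only variant of patchify_and_embed intended for ControlNet conditioning;
+ # the caller supplies the caption length so control tokens get matching RoPE offsets.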
609
+ def patchify_controlnet(
610
+ self,
611
+ all_image: List[torch.Tensor],
612
+ patch_size: int = 2,
613
+ f_patch_size: int = 1,
614
+ cap_padding_len: Optional[int] = None,
615
+ ):
616
+ pH = pW = patch_size
617
+ pF = f_patch_size
618
+ device = all_image[0].device
619
+
620
+ all_image_out = []
621
+ all_image_size = []
622
+ all_image_pos_ids = []
623
+ all_image_pad_mask = []
624
+
625
+ for i, image in enumerate(all_image):
626
+ ### Process Image
627
+ C, F, H, W = image.size()
628
+ all_image_size.append((F, H, W))
629
+ F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
630
+
631
+ image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
632
+ # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
633
+ image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
634
+
635
+ image_ori_len = len(image)
636
+ image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
637
+
638
+ image_ori_pos_ids = self.create_coordinate_grid(
639
+ size=(F_tokens, H_tokens, W_tokens),
640
+ start=(cap_padding_len + 1, 0, 0),
641
+ device=device,
642
+ ).flatten(0, 2)
643
+ image_padding_pos_ids = (
644
+ self.create_coordinate_grid(
645
+ size=(1, 1, 1),
646
+ start=(0, 0, 0),
647
+ device=device,
648
+ )
649
+ .flatten(0, 2)
650
+ .repeat(image_padding_len, 1)
651
+ )
652
+ image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
653
+ all_image_pos_ids.append(image_padded_pos_ids)
654
+ # pad mask
655
+ all_image_pad_mask.append(
656
+ torch.cat(
657
+ [
658
+ torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
659
+ torch.ones((image_padding_len,), dtype=torch.bool, device=device),
660
+ ],
661
+ dim=0,
662
+ )
663
+ )
664
+ # padded feature
665
+ image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
666
+ all_image_out.append(image_padded_feat)
667
+
668
+ return (
669
+ all_image_out,
670
+ all_image_size,
671
+ all_image_pos_ids,
672
+ all_image_pad_mask,
673
+ )
674
+
675
+ def _prepare_sequence(
676
+ self,
677
+ feats: List[torch.Tensor],
678
+ pos_ids: List[torch.Tensor],
679
+ inner_pad_mask: List[torch.Tensor],
680
+ pad_token: torch.nn.Parameter,
681
+ noise_mask: Optional[List[List[int]]] = None,
682
+ device: torch.device = None,
683
+ ):
684
+ """Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
685
+ item_seqlens = [len(f) for f in feats]
686
+ max_seqlen = max(item_seqlens)
687
+ bsz = len(feats)
688
+
689
+ # Pad token
690
+ feats_cat = torch.cat(feats, dim=0)
691
+ feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
692
+ feats = list(feats_cat.split(item_seqlens, dim=0))
693
+
694
+ # RoPE
695
+ freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
696
+
697
+ # Pad to batch
698
+ feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
699
+ freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
700
+
701
+ # Attention mask
702
+ attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
703
+ for i, seq_len in enumerate(item_seqlens):
704
+ attn_mask[i, :seq_len] = 1
705
+
706
+ # Noise mask
707
+ noise_mask_tensor = None
708
+ if noise_mask is not None:
709
+ noise_mask_tensor = pad_sequence(
710
+ [torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
711
+ batch_first=True,
712
+ padding_value=0,
713
+ )[:, : feats.shape[1]]
714
+
715
+ return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
716
+
717
+ def _build_unified_sequence(
718
+ self,
719
+ x: torch.Tensor,
720
+ x_freqs: torch.Tensor,
721
+ x_seqlens: List[int],
722
+ x_noise_mask: Optional[List[List[int]]],
723
+ cap: torch.Tensor,
724
+ cap_freqs: torch.Tensor,
725
+ cap_seqlens: List[int],
726
+ cap_noise_mask: Optional[List[List[int]]],
727
+ siglip: Optional[torch.Tensor],
728
+ siglip_freqs: Optional[torch.Tensor],
729
+ siglip_seqlens: Optional[List[int]],
730
+ siglip_noise_mask: Optional[List[List[int]]],
731
+ omni_mode: bool,
732
+ device: torch.device,
733
+ ):
734
+ """Build unified sequence: x, cap, and optionally siglip.
735
+ Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
736
+ """
737
+ bsz = len(x_seqlens)
738
+ unified = []
739
+ unified_freqs = []
740
+ unified_noise_mask = []
741
+
742
+ for i in range(bsz):
743
+ x_len, cap_len = x_seqlens[i], cap_seqlens[i]
744
+
745
+ if omni_mode:
746
+ # Omni: [cap, x, siglip]
747
+ if siglip is not None and siglip_seqlens is not None:
748
+ sig_len = siglip_seqlens[i]
749
+ unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
750
+ unified_freqs.append(
751
+ torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
752
+ )
753
+ unified_noise_mask.append(
754
+ torch.tensor(
755
+ cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
756
+ )
757
+ )
758
+ else:
759
+ unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
760
+ unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
761
+ unified_noise_mask.append(
762
+ torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
763
+ )
764
+ else:
765
+ # Basic: [x, cap]
766
+ unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
767
+ unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
768
+
769
+ # Compute unified seqlens
770
+ if omni_mode:
771
+ if siglip is not None and siglip_seqlens is not None:
772
+ unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
773
+ else:
774
+ unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
775
+ else:
776
+ unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
777
+
778
+ max_seqlen = max(unified_seqlens)
779
+
780
+ # Pad to batch
781
+ unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
782
+ unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
783
+
784
+ # Attention mask
785
+ attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
786
+ for i, seq_len in enumerate(unified_seqlens):
787
+ attn_mask[i, :seq_len] = 1
788
+
789
+ # Noise mask
790
+ noise_mask_tensor = None
791
+ if omni_mode:
792
+ noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
793
+ :, : unified.shape[1]
794
+ ]
795
+
796
+ return unified, unified_freqs, attn_mask, noise_mask_tensor
797
+
798
+ def _pad_with_ids(
799
+ self,
800
+ feat: torch.Tensor,
801
+ pos_grid_size: Tuple,
802
+ pos_start: Tuple,
803
+ device: torch.device,
804
+ noise_mask_val: Optional[int] = None,
805
+ ):
806
+ """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
807
+ ori_len = len(feat)
808
+ pad_len = (-ori_len) % SEQ_MULTI_OF
809
+ total_len = ori_len + pad_len
810
+
811
+ # Pos IDs
812
+ ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
813
+ if pad_len > 0:
814
+ pad_pos_ids = (
815
+ self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
816
+ .flatten(0, 2)
817
+ .repeat(pad_len, 1)
818
+ )
819
+ pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
820
+ padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
821
+ pad_mask = torch.cat(
822
+ [
823
+ torch.zeros(ori_len, dtype=torch.bool, device=device),
824
+ torch.ones(pad_len, dtype=torch.bool, device=device),
825
+ ]
826
+ )
827
+ else:
828
+ pos_ids = ori_pos_ids
829
+ padded_feat = feat
830
+ pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
831
+
832
+ noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
833
+ return padded_feat, pos_ids, pad_mask, total_len, noise_mask
834
+
835
+ def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
836
+ """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
837
+ pH, pW, pF = patch_size, patch_size, f_patch_size
838
+ C, F, H, W = image.size()
839
+ F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
840
+ image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
841
+ image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
842
+ return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
843
+
844
+ def patchify_and_embed_omni(
845
+ self,
846
+ all_x: List[List[torch.Tensor]],
847
+ all_cap_feats: List[List[torch.Tensor]],
848
+ all_siglip_feats: List[List[torch.Tensor]],
849
+ patch_size: int = 2,
850
+ f_patch_size: int = 1,
851
+ images_noise_mask: Optional[List[List[int]]] = None,
852
+ ):
853
+ """Patchify for omni mode: multiple images per batch item with noise masks."""
854
+ bsz = len(all_x)
855
+ device = all_x[0][-1].device
856
+ dtype = all_x[0][-1].dtype
857
+
858
+ all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
859
+ all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
860
+ all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
861
+
862
+ for i in range(bsz):
863
+ num_images = len(all_x[i])
864
+ cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
865
+ cap_end_pos = []
866
+ cap_cu_len = 1
867
+
868
+ # Process captions
869
+ for j, cap_item in enumerate(all_cap_feats[i]):
870
+ noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
871
+ cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
872
+ cap_item,
873
+ (len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
874
+ (cap_cu_len, 0, 0),
875
+ device,
876
+ noise_val,
877
+ )
878
+ cap_feats_list.append(cap_out)
879
+ cap_pos_list.append(cap_pos)
880
+ cap_mask_list.append(cap_mask)
881
+ cap_lens.append(cap_len)
882
+ cap_noise.extend(cap_nm)
883
+ cap_cu_len += len(cap_item)
884
+ cap_end_pos.append(cap_cu_len)
885
+ cap_cu_len += 2  # reserve axis-0 positions for the image (VAE) and SigLIP tokens
886
+
887
+ all_cap_out.append(torch.cat(cap_feats_list, dim=0))
888
+ all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
889
+ all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
890
+ all_cap_len.append(cap_lens)
891
+ all_cap_noise_mask.append(cap_noise)
892
+
893
+ # Process images
894
+ x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
895
+ for j, x_item in enumerate(all_x[i]):
896
+ noise_val = images_noise_mask[i][j]
897
+ if x_item is not None:
898
+ x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
899
+ x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
900
+ x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
901
+ )
902
+ x_size.append(size)
903
+ else:
904
+ x_len = SEQ_MULTI_OF
905
+ x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
906
+ x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
907
+ x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
908
+ x_nm = [noise_val] * x_len
909
+ x_size.append(None)
910
+ x_feats_list.append(x_out)
911
+ x_pos_list.append(x_pos)
912
+ x_mask_list.append(x_mask)
913
+ x_lens.append(x_len)
914
+ x_noise.extend(x_nm)
915
+
916
+ all_x_out.append(torch.cat(x_feats_list, dim=0))
917
+ all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
918
+ all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
919
+ all_x_size.append(x_size)
920
+ all_x_len.append(x_lens)
921
+ all_x_noise_mask.append(x_noise)
922
+
923
+ # Process siglip
924
+ if all_siglip_feats[i] is None:
925
+ all_sig_len.append([0] * num_images)
926
+ all_sig_out.append(None)
927
+ else:
928
+ sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
929
+ for j, sig_item in enumerate(all_siglip_feats[i]):
930
+ noise_val = images_noise_mask[i][j]
931
+ if sig_item is not None:
932
+ sig_H, sig_W, sig_C = sig_item.size()
933
+ sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
934
+ sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
935
+ sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
936
+ )
937
+ # Scale position IDs to match x resolution
938
+ if x_size[j] is not None:
939
+ sig_pos = sig_pos.float()
940
+ sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
941
+ sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
942
+ sig_pos = sig_pos.to(torch.int32)
943
+ else:
944
+ sig_len = SEQ_MULTI_OF
945
+ sig_out = torch.zeros((sig_len, self.siglip_feat_dim), dtype=dtype, device=device)
946
+ sig_pos = (
947
+ self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
948
+ )
949
+ sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
950
+ sig_nm = [noise_val] * sig_len
951
+ sig_feats_list.append(sig_out)
952
+ sig_pos_list.append(sig_pos)
953
+ sig_mask_list.append(sig_mask)
954
+ sig_lens.append(sig_len)
955
+ sig_noise.extend(sig_nm)
956
+
957
+ all_sig_out.append(torch.cat(sig_feats_list, dim=0))
958
+ all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
959
+ all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
960
+ all_sig_len.append(sig_lens)
961
+ all_sig_noise_mask.append(sig_noise)
962
+
963
+ # Compute x position offsets
964
+ all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
965
+
966
+ return (
967
+ all_x_out,
968
+ all_cap_out,
969
+ all_sig_out,
970
+ all_x_size,
971
+ all_x_pos_ids,
972
+ all_cap_pos_ids,
973
+ all_sig_pos_ids,
974
+ all_x_pad_mask,
975
+ all_cap_pad_mask,
976
+ all_sig_pad_mask,
977
+ all_x_pos_offsets,
978
+ all_x_noise_mask,
979
+ all_cap_noise_mask,
980
+ all_sig_noise_mask,
981
+ )
982
+
983
+ def forward(
984
+ self,
985
+ x: List[torch.Tensor],
986
+ t,
987
+ cap_feats: List[torch.Tensor],
988
+ siglip_feats=None,
989
+ image_noise_mask=None,
990
+ patch_size=2,
991
+ f_patch_size=1,
992
+ ):
993
+ assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
994
+ omni_mode = isinstance(x[0], list)
995
+ device = x[0][-1].device if omni_mode else x[0].device
996
+
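+ # use_cfg is True when conditional and unconditional branches are batched together;
+ # cfg_parallel / cfg_parallel_unshard are assumed to split and regather these branches
+ # when CFG parallelism is enabled.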
997
+ use_cfg = len(x) > 1 and isinstance(x[0], list)
998
+ fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
999
+ with (
1000
+ fp8_inference(fp8_linear_enabled),
1001
+ gguf_inference(),
1002
+ cfg_parallel((x, t, cap_feats, siglip_feats, image_noise_mask), use_cfg=use_cfg),
1003
+ ):
1004
+ if omni_mode:
1005
+ # Dual embeddings: noisy (t) and clean (t=1)
1006
+ t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
1007
+ t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
1008
+ adaln_input = None
1009
+ else:
1010
+ # Single embedding for all tokens
1011
+ adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
1012
+ t_noisy = t_clean = None
1013
+
1014
+ # Patchify
1015
+ if omni_mode:
1016
+ (
1017
+ x,
1018
+ cap_feats,
1019
+ siglip_feats,
1020
+ x_size,
1021
+ x_pos_ids,
1022
+ cap_pos_ids,
1023
+ siglip_pos_ids,
1024
+ x_pad_mask,
1025
+ cap_pad_mask,
1026
+ siglip_pad_mask,
1027
+ x_pos_offsets,
1028
+ x_noise_mask,
1029
+ cap_noise_mask,
1030
+ siglip_noise_mask,
1031
+ ) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
1032
+ else:
1033
+ (
1034
+ x,
1035
+ cap_feats,
1036
+ x_size,
1037
+ x_pos_ids,
1038
+ cap_pos_ids,
1039
+ x_pad_mask,
1040
+ cap_pad_mask,
1041
+ ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
1042
+ x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
1043
+
1044
+ # x embed & refine
1045
+ x_seqlens = [len(xi) for xi in x]
1046
+ x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
1047
+ x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
1048
+ list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
1049
+ )
1050
+
1051
+ for layer in self.noise_refiner:
1052
+ x = layer(
+ x=x,
+ attn_mask=x_mask,
+ freqs_cis=x_freqs,
+ adaln_input=adaln_input,
+ noise_mask=x_noise_tensor,
+ adaln_noisy=t_noisy,
+ adaln_clean=t_clean,
+ )
1053
+
1054
+ # Cap embed & refine
1055
+ cap_seqlens = [len(ci) for ci in cap_feats]
1056
+ cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
1057
+ cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
1058
+ list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
1059
+ )
1060
+
1061
+ for layer in self.context_refiner:
1062
+ cap_feats = layer(x=cap_feats, attn_mask=cap_mask, freqs_cis=cap_freqs)
1063
+
1064
+ # Siglip embed & refine
1065
+ siglip_seqlens = siglip_freqs = None
1066
+ if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
1067
+ siglip_seqlens = [len(si) for si in siglip_feats]
1068
+ siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
1069
+ siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
1070
+ list(siglip_feats.split(siglip_seqlens, dim=0)),
1071
+ siglip_pos_ids,
1072
+ siglip_pad_mask,
1073
+ self.siglip_pad_token,
1074
+ None,
1075
+ device,
1076
+ )
1077
+
1078
+ for layer in self.siglip_refiner:
1079
+ siglip_feats = layer(x=siglip_feats, attn_mask=siglip_mask, freqs_cis=siglip_freqs)
1080
+
1081
+ # Unified sequence
1082
+ unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
1083
+ x,
1084
+ x_freqs,
1085
+ x_seqlens,
1086
+ x_noise_mask,
1087
+ cap_feats,
1088
+ cap_freqs,
1089
+ cap_seqlens,
1090
+ cap_noise_mask,
1091
+ siglip_feats,
1092
+ siglip_freqs,
1093
+ siglip_seqlens,
1094
+ siglip_noise_mask,
1095
+ omni_mode,
1096
+ device,
1097
+ )
1098
+
1099
+ # Main transformer layers
1100
+ with sequence_parallel((unified, unified_freqs, unified_noise_tensor), seq_dims=(1, 1, 1)):
1101
+ for layer_idx, layer in enumerate(self.layers):
1102
+ unified = layer(
+ x=unified,
+ attn_mask=unified_mask,
+ freqs_cis=unified_freqs,
+ adaln_input=adaln_input,
+ noise_mask=unified_noise_tensor,
+ adaln_noisy=t_noisy,
+ adaln_clean=t_clean,
+ )
1103
+ (unified,) = sequence_parallel_unshard((unified,), seq_dims=(1,), seq_lens=(unified.shape[1],))
1104
+
1105
+ unified = (
1106
+ self.all_final_layer[f"{patch_size}-{f_patch_size}"](
1107
+ unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
1108
+ )
1109
+ if omni_mode
1110
+ else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
1111
+ )
1112
+
1113
+ # Unpatchify
1114
+ x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
1115
+
1116
+ (x,) = cfg_parallel_unshard((x,), use_cfg=use_cfg)
1117
+
1118
+ return x
1119
+
1120
+ @classmethod
1121
+ def from_state_dict(
1122
+ cls,
1123
+ state_dict,
1124
+ device: str,
1125
+ dtype: torch.dtype,
1126
+ **kwargs,
1127
+ ):
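+ # Instantiate with device="meta" so weights are not materialized up front, adopt the
+ # state-dict tensors via load_state_dict(assign=True), then move to the target
+ # device and dtype.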
1128
+ model = cls(device="meta", dtype=dtype, **kwargs)
1129
+ model = model.requires_grad_(False)
1130
+ model.load_state_dict(state_dict, assign=True)
1131
+ model.to(device=device, dtype=dtype, non_blocking=True)
1132
+ return model
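+ # Hypothetical usage sketch (illustration only; names and shapes below are assumptions,
+ # the real entry points live in the diffsynth_engine pipelines):
+ #   dit = ZImageOmniBaseDiT.from_state_dict(state_dict, device="cuda", dtype=torch.bfloat16)
+ #   out = dit(
+ #       x=[latent],                 # one (in_channels, F, H, W) latent per sample
+ #       t=torch.tensor([0.5], device="cuda"),
+ #       cap_feats=[caption_feat],   # one (num_tokens, cap_feat_dim) tensor per sample
+ #   )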