InvokeAI 6.9.0rc3-py3-none-any.whl → 6.10.0rc1-py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/z_image/z_image_controlnet_extension.py
@@ -0,0 +1,531 @@
+ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+ """Z-Image ControlNet Extension for spatial conditioning.
+
+ This module provides an extension-based approach to Z-Image ControlNet,
+ similar to how FLUX ControlNet works. Instead of duplicating the entire
+ transformer, we compute control hints separately and inject them into
+ the base transformer's forward pass.
+ """
+
+ from typing import List, Optional, Tuple
+
+ import torch
+ from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
+ from torch.nn.utils.rnn import pad_sequence
+
+ from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+ from invokeai.backend.z_image.z_image_patchify_utils import SEQ_MULTI_OF, patchify_control_context
+
+
+ class ZImageControlNetExtension:
+     """Extension for Z-Image ControlNet - computes control hints without duplicating the transformer.
+
+     This class follows the same pattern as FLUX ControlNet extensions:
+     - The control adapter is loaded separately
+     - Control hints are computed per step
+     - Hints are injected into the transformer's layer outputs
+
+     Attributes:
+         control_adapter: The Z-Image control adapter model
+         control_cond: VAE-encoded control image latents
+         weight: Control strength (recommended: 0.65-0.80)
+         begin_step_percent: When to start applying control (0.0 = start)
+         end_step_percent: When to stop applying control (1.0 = end)
+     """
+
+     def __init__(
+         self,
+         control_adapter: ZImageControlAdapter,
+         control_cond: torch.Tensor,
+         weight: float = 0.75,
+         begin_step_percent: float = 0.0,
+         end_step_percent: float = 1.0,
+         skip_layers: int = 0,  # Skip first N control injection layers
+     ):
+         self._adapter = control_adapter
+         self._control_cond = control_cond
+         self._weight = weight
+         self._begin_step_percent = begin_step_percent
+         self._end_step_percent = end_step_percent
+         self._skip_layers = skip_layers
+
+         # Get actual number of control blocks from loaded model (not config!)
+         # The safetensors may have more blocks than the config suggests
+         self._num_control_blocks = len(control_adapter.control_layers)
+
+         # Control layers are applied at every other layer (0, 2, 4, ...)
+         # This matches the default configuration in the original implementation
+         self._control_places = [i * 2 for i in range(self._num_control_blocks)]
+
+         # DEBUG: Print control configuration
+         print(f"[DEBUG] Actual num_control_blocks: {self._num_control_blocks}")
+         print(f"[DEBUG] control_places: {self._control_places}")
+
+         # DEBUG: Check if control_layers have non-zero weights
+         first_layer = control_adapter.control_layers[0]
+         if hasattr(first_layer, "after_proj"):
+             after_proj_norm = first_layer.after_proj.weight.norm().item()
+             print(f"[DEBUG] First control layer after_proj weight norm: {after_proj_norm}")
+             if after_proj_norm < 1e-6:
+                 print("[WARNING] after_proj weights are near-zero! Weights may not be loaded correctly.")
+
+     @property
+     def weight(self) -> float:
+         return self._weight
+
+     @property
+     def control_places(self) -> List[int]:
+         return self._control_places
+
+     def should_apply(self, step_index: int, total_steps: int) -> bool:
+         """Check if control should be applied at this step."""
+         if total_steps == 0:
+             return True
+         step_percent = step_index / total_steps
+         return self._begin_step_percent <= step_percent <= self._end_step_percent
+
+     def prepare_control_state(
+         self,
+         base_transformer: ZImageTransformer2DModel,
+         cap_feats: torch.Tensor,
+         timestep_emb: torch.Tensor,
+         x_item_seqlens: List[int],
+         cap_item_seqlens: List[int],
+         x_freqs_cis: torch.Tensor,
+         patch_size: int = 2,
+         f_patch_size: int = 1,
+     ) -> torch.Tensor:
+         """Prepare control state (control_unified) for incremental hint computation.
+
+         This processes the control condition through patchify and noise_refiner,
+         returning the control_unified tensor that will be used incrementally.
+         """
+         bsz = 1
+         device = self._control_cond.device
+
+         # Patchify control context
+         control_context = [self._control_cond]
+         (
+             control_patches,
+             _,
+             _control_pos_ids,
+             control_pad_mask,
+         ) = patchify_control_context(
+             control_context,
+             patch_size,
+             f_patch_size,
+             cap_feats.size(1),
+         )
+
+         # Embed control context
+         ctrl_item_seqlens = [len(p) for p in control_patches]
+         ctrl_max_seqlen = max(ctrl_item_seqlens)
+
+         control_cat = torch.cat(control_patches, dim=0)
+         embedder_key = f"{patch_size}-{f_patch_size}"
+         control_cat = self._adapter.control_all_x_embedder[embedder_key](control_cat)
+
+         # Apply padding token
+         adaln_input = timestep_emb.type_as(control_cat)
+         x_pad_token = self._adapter.x_pad_token.to(dtype=control_cat.dtype)
+         control_cat[torch.cat(control_pad_mask)] = x_pad_token
+
+         control_list = list(control_cat.split(ctrl_item_seqlens, dim=0))
+         control_padded = pad_sequence(control_list, batch_first=True, padding_value=0.0)
+
+         # Use x_freqs_cis from main path for aligned position encoding
+         ctrl_freqs_cis_for_refiner = x_freqs_cis[:, : control_padded.shape[1]]
+
+         ctrl_attn_mask = torch.zeros((bsz, ctrl_max_seqlen), dtype=torch.bool, device=device)
+         for i, seq_len in enumerate(ctrl_item_seqlens):
+             ctrl_attn_mask[i, :seq_len] = 1
+
+         # Refine control context through control_noise_refiner
+         for layer in self._adapter.control_noise_refiner:
+             control_padded = layer(control_padded, ctrl_attn_mask, ctrl_freqs_cis_for_refiner, adaln_input)
+
+         # Store these for compute_single_hint
+         self._ctrl_item_seqlens = ctrl_item_seqlens
+         self._adaln_input = adaln_input
+
+         # Unify control with caption features
+         control_unified = []
+         for i in range(bsz):
+             ctrl_len = ctrl_item_seqlens[i]
+             cap_len = cap_item_seqlens[i]
+             control_unified.append(torch.cat([control_padded[i][:ctrl_len], cap_feats[i][:cap_len]]))
+
+         control_unified = pad_sequence(control_unified, batch_first=True, padding_value=0.0)
+
+         # DEBUG (only once)
+         if not hasattr(self, "_prepare_printed"):
+             self._prepare_printed = True
+             print(f"[DEBUG] Control state prepared: shape {control_unified.shape}")
+
+         return control_unified
+
+     def compute_single_hint(
+         self,
+         control_layer_idx: int,
+         control_state: torch.Tensor,
+         unified_hidden_states: torch.Tensor,
+         attn_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         adaln_input: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Compute a single hint from one control layer.
+
+         Args:
+             control_layer_idx: Which control layer to use (0, 1, 2, ...)
+             control_state: Current control state (stacked tensor from previous layers)
+             unified_hidden_states: Current unified hidden states from main transformer
+             attn_mask: Attention mask
+             freqs_cis: RoPE frequencies
+             adaln_input: Timestep embedding
+
+         Returns:
+             Tuple of (hint tensor, updated control_state)
+         """
+         layer = self._adapter.control_layers[control_layer_idx]
+
+         # Run control layer with CURRENT unified_hidden_states
+         control_state = layer(
+             control_state,
+             x=unified_hidden_states,
+             attn_mask=attn_mask,
+             freqs_cis=freqs_cis,
+             adaln_input=adaln_input,
+         )
+
+         # Extract hint from stacked state
+         # After control layer, control_state is stacked: [skip_0, ..., skip_n, running_state]
+         # We want the latest skip (second to last element)
+         unbinded = torch.unbind(control_state)
+         hint = unbinded[-2]  # Latest skip connection
+
+         return hint, control_state
+
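# A minimal sketch of the incremental path above (prepare_control_state followed by
# compute_single_hint once per control layer), interleaved with the main transformer
# layers. These comment lines are editorial illustration, not part of the diff; names
# such as ext, unified, attn_mask, freqs_cis, and adaln are assumptions:
#
#     state = ext.prepare_control_state(transformer, cap_padded, t_emb, x_lens, cap_lens, x_freqs)
#     for idx, layer_pos in enumerate(ext.control_places):
#         hint, state = ext.compute_single_hint(idx, state, unified, attn_mask, freqs_cis, adaln)
#         # ... after running main layer `layer_pos`: unified = unified + hint * ext.weight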
+     def compute_hints(
+         self,
+         base_transformer: ZImageTransformer2DModel,
+         unified_hidden_states: torch.Tensor,
+         cap_feats: torch.Tensor,
+         timestep_emb: torch.Tensor,
+         attn_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         x_item_seqlens: List[int],
+         cap_item_seqlens: List[int],
+         x_freqs_cis: torch.Tensor,
+         patch_size: int = 2,
+         f_patch_size: int = 1,
+     ) -> Tuple[torch.Tensor, ...]:
+         """Compute control hints using the adapter.
+
+         This method processes the control condition through the adapter's
+         control_noise_refiner and control_layers to produce hints that
+         will be added to the transformer's hidden states.
+
+         Args:
+             base_transformer: The base Z-Image transformer (for rope_embedder)
+             unified_hidden_states: Combined image+caption hidden states
+             cap_feats: Caption feature embeddings (padded)
+             timestep_emb: Timestep embeddings (adaln_input)
+             attn_mask: Unified attention mask
+             freqs_cis: RoPE frequencies
+             x_item_seqlens: Image sequence lengths per batch item
+             cap_item_seqlens: Caption sequence lengths per batch item
+             x_freqs_cis: Image RoPE frequencies from the main path
+             patch_size: Spatial patch size
+             f_patch_size: Frame patch size
+
+         Returns:
+             Tuple of hint tensors to add at each control layer position
+         """
+         # control_cond is always [C, F, H, W] format (single control image)
+         # where C = control_in_dim (16 for V1, 33 for V2.0), F = 1 frame
+         bsz = 1
+         device = self._control_cond.device
+
+         # Wrap control_cond in a list for patchify_control_context
+         # Expected input: List of [C, F, H, W] tensors
+         control_context = [self._control_cond]
+
+         # Patchify control context
+         # Note: We don't use control_pos_ids anymore - we use x_freqs_cis from main path instead
+         (
+             control_patches,
+             _,
+             _control_pos_ids,  # Not used - we use main path's position encoding
+             control_pad_mask,
+         ) = patchify_control_context(
+             control_context,
+             patch_size,
+             f_patch_size,
+             cap_feats.size(1),
+         )
+
+         # Embed control context
+         ctrl_item_seqlens = [len(p) for p in control_patches]
+         assert all(s % SEQ_MULTI_OF == 0 for s in ctrl_item_seqlens)
+         ctrl_max_seqlen = max(ctrl_item_seqlens)
+
+         control_cat = torch.cat(control_patches, dim=0)
+         embedder_key = f"{patch_size}-{f_patch_size}"
+         control_cat = self._adapter.control_all_x_embedder[embedder_key](control_cat)
+
+         # Apply padding token (ensure dtype matches)
+         adaln_input = timestep_emb.type_as(control_cat)
+         x_pad_token = self._adapter.x_pad_token.to(dtype=control_cat.dtype)
+         control_cat[torch.cat(control_pad_mask)] = x_pad_token
+
+         control_list = list(control_cat.split(ctrl_item_seqlens, dim=0))
+
+         control_padded = pad_sequence(control_list, batch_first=True, padding_value=0.0)
+
+         # Use x_freqs_cis from main path for control patches (same spatial structure)
+         # This ensures control and image have aligned position encodings
+         ctrl_freqs_cis_for_refiner = x_freqs_cis[:, : control_padded.shape[1]]
+
+         ctrl_attn_mask = torch.zeros((bsz, ctrl_max_seqlen), dtype=torch.bool, device=device)
+         for i, seq_len in enumerate(ctrl_item_seqlens):
+             ctrl_attn_mask[i, :seq_len] = 1
+
+         # Refine control context through control_noise_refiner
+         # Using x_freqs_cis to match main path's position encoding
+         for layer in self._adapter.control_noise_refiner:
+             control_padded = layer(control_padded, ctrl_attn_mask, ctrl_freqs_cis_for_refiner, adaln_input)
+
+         # Unify control with caption features
+         control_unified = []
+         for i in range(bsz):
+             ctrl_len = ctrl_item_seqlens[i]
+             cap_len = cap_item_seqlens[i]
+             control_unified.append(torch.cat([control_padded[i][:ctrl_len], cap_feats[i][:cap_len]]))
+
+         control_unified = pad_sequence(control_unified, batch_first=True, padding_value=0.0)
+         c = control_unified
+
+         # Process through control_layers to generate hints
+         # DEBUG: Print shapes before control_layers (only on first call)
+         if not hasattr(self, "_debug_printed"):
+             self._debug_printed = True
+             print(f"[DEBUG] control_unified shape: {control_unified.shape}")
+             print(f"[DEBUG] unified_hidden_states shape: {unified_hidden_states.shape}")
+             print(f"[DEBUG] ctrl_item_seqlens: {ctrl_item_seqlens}, x_item_seqlens: {x_item_seqlens}")
+
+             # Check weight norms of critical layers
+             layer0 = self._adapter.control_layers[0]
+             if hasattr(layer0, "before_proj"):
+                 print(f"[DEBUG] before_proj weight norm: {layer0.before_proj.weight.norm().item():.6f}")
+             if hasattr(layer0, "after_proj"):
+                 print(f"[DEBUG] after_proj weight norm: {layer0.after_proj.weight.norm().item():.6f}")
+
+             # Check control_noise_refiner weights
+             if len(self._adapter.control_noise_refiner) > 0:
+                 refiner0 = self._adapter.control_noise_refiner[0]
+                 if hasattr(refiner0, "attn"):
+                     print(f"[DEBUG] noise_refiner[0] attn.wq norm: {refiner0.attn.wq.weight.norm().item():.6f}")
+
+         for layer in self._adapter.control_layers:
+             c = layer(
+                 c,
+                 x=unified_hidden_states,
+                 attn_mask=attn_mask,
+                 freqs_cis=freqs_cis,
+                 adaln_input=adaln_input,
+             )
+
+         # Extract hints (all but the last element which is the running state)
+         hints = tuple(torch.unbind(c)[:-1])
+
+         # DEBUG: Print hint shapes (only on first call)
+         if not hasattr(self, "_hints_printed"):
+             self._hints_printed = True
+             print(f"[DEBUG] Number of hints: {len(hints)}")
+             if hints:
+                 print(f"[DEBUG] First hint shape: {hints[0].shape}")
+                 # Also check hint statistics for each hint
+                 for i, h in enumerate(hints[:3]):  # First 3 hints
+                     print(
+                         f"[DEBUG] Hint[{i}] mean: {h.mean().item():.6f}, std: {h.std().item():.6f}, min: {h.min().item():.6f}, max: {h.max().item():.6f}"
+                     )
+
+         return hints
+
+
+ def z_image_forward_with_control(
+     transformer: ZImageTransformer2DModel,
+     x: List[torch.Tensor],
+     t: torch.Tensor,
+     cap_feats: List[torch.Tensor],
+     control_extension: Optional[ZImageControlNetExtension] = None,
+     patch_size: int = 2,
+     f_patch_size: int = 1,
+ ) -> Tuple[List[torch.Tensor], dict]:
+     """Forward pass through Z-Image transformer with optional control injection.
+
+     This function replicates the base transformer's forward pass but allows
+     injecting control hints at specific layer positions. It uses the base
+     transformer's weights directly without duplicating them.
+
+     Args:
+         transformer: The base Z-Image transformer model
+         x: List of image tensors [C, F, H, W]
+         t: Timestep tensor
+         cap_feats: List of caption feature tensors
+         control_extension: Optional control extension for hint injection
+         patch_size: Spatial patch size (default: 2)
+         f_patch_size: Frame patch size (default: 1)
+
+     Returns:
+         Tuple of (output tensors list, empty dict for compatibility)
+     """
+     assert patch_size in transformer.all_patch_size
+     assert f_patch_size in transformer.all_f_patch_size
+
+     bsz = len(x)
+     device = x[0].device
+     t_scaled = t * transformer.t_scale
+     t_emb = transformer.t_embedder(t_scaled)
+
+     # Patchify and embed using base transformer's method
+     (
+         x_patches,
+         cap_feats_patches,
+         x_size,
+         x_pos_ids,
+         cap_pos_ids,
+         x_inner_pad_mask,
+         cap_inner_pad_mask,
+     ) = transformer.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
+
+     # === X embed & refine ===
+     x_item_seqlens = [len(p) for p in x_patches]
+     assert all(s % SEQ_MULTI_OF == 0 for s in x_item_seqlens)
+     x_max_item_seqlen = max(x_item_seqlens)
+
+     embedder_key = f"{patch_size}-{f_patch_size}"
+     x_cat = torch.cat(x_patches, dim=0)
+     x_cat = transformer.all_x_embedder[embedder_key](x_cat)
+
+     adaln_input = t_emb.type_as(x_cat)
+     x_cat[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token
+
+     x_list = list(x_cat.split(x_item_seqlens, dim=0))
+     x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(p) for p in x_pos_ids], dim=0))
+
+     x_padded = pad_sequence(x_list, batch_first=True, padding_value=0.0)
+     x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+     x_freqs_cis = x_freqs_cis[:, : x_padded.shape[1]]
+
+     x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+     for i, seq_len in enumerate(x_item_seqlens):
+         x_attn_mask[i, :seq_len] = 1
+
+     # Noise refiner
+     for layer in transformer.noise_refiner:
+         x_padded = layer(x_padded, x_attn_mask, x_freqs_cis, adaln_input)
+
+     # === Cap embed & refine ===
+     cap_item_seqlens = [len(p) for p in cap_feats_patches]
+     cap_max_item_seqlen = max(cap_item_seqlens)
+
+     cap_cat = torch.cat(cap_feats_patches, dim=0)
+     cap_cat = transformer.cap_embedder(cap_cat)
+     cap_cat[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token
+
+     cap_list = list(cap_cat.split(cap_item_seqlens, dim=0))
+     cap_freqs_cis = list(
+         transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(p) for p in cap_pos_ids], dim=0)
+     )
+
+     cap_padded = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
+     cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+     cap_freqs_cis = cap_freqs_cis[:, : cap_padded.shape[1]]
+
+     cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+     for i, seq_len in enumerate(cap_item_seqlens):
+         cap_attn_mask[i, :seq_len] = 1
+
+     # Context refiner
+     for layer in transformer.context_refiner:
+         cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis)
+
+     # === Unified ===
+     unified = []
+     unified_freqs_cis = []
+     for i in range(bsz):
+         x_len = x_item_seqlens[i]
+         cap_len = cap_item_seqlens[i]
+         unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
+         unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+
+     unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=False)]
+     unified_max_item_seqlen = max(unified_item_seqlens)
+
+     unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+     unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+
+     unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+     for i, seq_len in enumerate(unified_item_seqlens):
+         unified_attn_mask[i, :seq_len] = 1
+
+     # === Compute control hints if extension provided ===
+     # IMPORTANT: Hints are computed ONCE using the INITIAL unified state (before main layers)
+     # This matches the original VideoX-Fun architecture
+     control_places: List[int] = []
+     control_weight: float = 1.0
+     hints: Optional[Tuple[torch.Tensor, ...]] = None
+
+     # DEBUG: Print number of transformer layers (only once per session)
+     if not hasattr(z_image_forward_with_control, "_layers_printed"):
+         z_image_forward_with_control._layers_printed = True
+         print(f"[DEBUG] Base transformer has {len(transformer.layers)} layers")
+
+     if control_extension is not None:
+         # Compute ALL hints at once using the INITIAL unified state (before main layers run)
+         hints = control_extension.compute_hints(
+             base_transformer=transformer,
+             unified_hidden_states=unified,  # INITIAL unified state!
+             cap_feats=cap_padded,
+             timestep_emb=adaln_input,
+             attn_mask=unified_attn_mask,
+             freqs_cis=unified_freqs_cis,
+             x_item_seqlens=x_item_seqlens,
+             cap_item_seqlens=cap_item_seqlens,
+             x_freqs_cis=x_freqs_cis,
+             patch_size=patch_size,
+             f_patch_size=f_patch_size,
+         )
+         control_places = control_extension.control_places
+         control_weight = control_extension.weight
+
+     # === Main transformer layers with pre-computed hint injection ===
+     skip_layers = control_extension._skip_layers if control_extension is not None else 0
+     control_layer_idx = 0
+     for layer_idx, layer in enumerate(transformer.layers):
+         unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+
+         # Inject pre-computed control hint at designated positions
+         if hints is not None and layer_idx in control_places and control_layer_idx < len(hints):
+             # Skip first N hints if configured
+             if control_layer_idx >= skip_layers:
+                 hint = hints[control_layer_idx]
+
+                 # DEBUG: Print on first injection
+                 if not hasattr(z_image_forward_with_control, "_injection_printed"):
+                     z_image_forward_with_control._injection_printed = True
+                     print(f"[DEBUG] Injection at layer {layer_idx} (control_layer {control_layer_idx})")
+                     print(f"[DEBUG] Hint mean: {hint.mean().item():.6f}, std: {hint.std().item():.6f}")
+                     print(f"[DEBUG] Unified mean: {unified.mean().item():.6f}, std: {unified.std().item():.6f}")
+                     print(f"[DEBUG] control_weight: {control_weight}, skip_layers: {skip_layers}")
+
+                 unified = unified + hint * control_weight
+
+             control_layer_idx += 1
+
+     # === Final layer and unpatchify ===
+     unified = transformer.all_final_layer[embedder_key](unified, adaln_input)
+     unified = list(unified.unbind(dim=0))
+     output = transformer.unpatchify(unified, x_size, patch_size, f_patch_size)
+
+     return output, {}
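
A minimal sketch of how these pieces fit together in a denoise loop. It assumes a loaded
adapter (ZImageControlAdapter), a VAE-encoded control image control_latents of shape
[C, F, H, W], and per-step latent, timestep, and caption_feats tensors supplied by the
caller; those names are illustrative assumptions, not identifiers from this diff:

    from invokeai.backend.z_image.z_image_controlnet_extension import (
        ZImageControlNetExtension,
        z_image_forward_with_control,
    )

    ext = ZImageControlNetExtension(
        control_adapter=adapter,
        control_cond=control_latents,  # [C, F, H, W]; C = 16 (V1) or 33 (V2.0)
        weight=0.7,  # recommended range per the class docstring: 0.65-0.80
    )

    for step_index in range(total_steps):
        # Gate control to the configured [begin, end] step-percent window.
        active_ext = ext if ext.should_apply(step_index, total_steps) else None
        output, _ = z_image_forward_with_control(
            transformer,
            x=[latent],  # list of [C, F, H, W] latents
            t=timestep,
            cap_feats=[caption_feats],
            control_extension=active_ext,
        )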
invokeai/backend/z_image/z_image_patchify_utils.py
@@ -0,0 +1,135 @@
+ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+ """Utility functions for Z-Image patchify operations."""
+
+ from typing import List, Tuple
+
+ import torch
+
+ # Sequence must be multiple of this value (from diffusers transformer_z_image)
+ SEQ_MULTI_OF = 32
+
+
+ def create_coordinate_grid(
+     size: Tuple[int, ...],
+     start: Tuple[int, ...] | None = None,
+     device: torch.device | None = None,
+ ) -> torch.Tensor:
+     """Create a coordinate grid for position embeddings.
+
+     Args:
+         size: Size of the grid (e.g., (F, H, W))
+         start: Starting coordinates (default: all zeros)
+         device: Target device
+
+     Returns:
+         Coordinate grid tensor of shape (*size, len(size))
+     """
+     if start is None:
+         start = tuple(0 for _ in size)
+
+     axes = [
+         torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size, strict=False)
+     ]
+     grids = torch.meshgrid(axes, indexing="ij")
+     return torch.stack(grids, dim=-1)
+
+
+ def patchify_control_context(
+     all_image: List[torch.Tensor],
+     patch_size: int,
+     f_patch_size: int,
+     cap_seq_len: int,
+ ) -> Tuple[List[torch.Tensor], List[Tuple[int, int, int]], List[torch.Tensor], List[torch.Tensor]]:
+     """Patchify control images without embedding.
+
+     This function extracts patches from control images for control context processing.
+     It handles padding and position ID creation for the control signal.
+
+     Args:
+         all_image: List of control image tensors [C, F, H, W]
+         patch_size: Spatial patch size (height and width)
+         f_patch_size: Frame patch size
+         cap_seq_len: Caption sequence length (for position ID offset)
+
+     Returns:
+         Tuple of:
+         - all_image_out: List of patchified image tensors
+         - all_image_size: List of (F, H, W) tuples
+         - all_image_pos_ids: List of position ID tensors
+         - all_image_pad_mask: List of padding mask tensors
+     """
+     pH = pW = patch_size
+     pF = f_patch_size
+     device = all_image[0].device
+
+     all_image_out: List[torch.Tensor] = []
+     all_image_size: List[Tuple[int, int, int]] = []
+     all_image_pos_ids: List[torch.Tensor] = []
+     all_image_pad_mask: List[torch.Tensor] = []
+
+     # Calculate padded caption length for position offset
+     cap_padding_len = (-cap_seq_len) % SEQ_MULTI_OF
+     cap_padded_len = cap_seq_len + cap_padding_len
+
+     for image in all_image:
+         C, F, H, W = image.size()
+         all_image_size.append((F, H, W))
+         F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+         # Patchify: [C, F, H, W] -> [(F_tokens*H_tokens*W_tokens), (pF*pH*pW*C)]
+         # Step 1: Rearrange to put spatial dims together for proper patching
+         # [C, F, H, W] -> [F, H, W, C]
+         image = image.permute(1, 2, 3, 0).contiguous()
+
+         # Step 2: Split H and W into tokens and patch sizes
+         # [F, H, W, C] -> [F, H_tokens, pH, W_tokens, pW, C]
+         image = image.view(F, H_tokens, pH, W_tokens, pW, C)
+
+         # Step 3: Rearrange to group patches and features
+         # [F, H_tokens, pH, W_tokens, pW, C] -> [F, H_tokens, W_tokens, pH, pW, C]
+         image = image.permute(0, 1, 3, 2, 4, 5).contiguous()
+
+         # Step 4: For F > 1, we'd need to handle F similarly, but for F=1 this is simpler
+         # Final reshape: [F*H_tokens*W_tokens, pH*pW*C]
+         num_patches = F_tokens * H_tokens * W_tokens
+         patch_features = pF * pH * pW * C
+         image = image.reshape(num_patches, patch_features)
+
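#         For reference, with pF == 1 the four steps above are equivalent to a single
#         einops expression (einops is not imported by this module; this comment is
#         editorial illustration, not part of the diff):
#             image = rearrange(image, "c f (ht ph) (wt pw) -> (f ht wt) (ph pw c)", ph=pH, pw=pW)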
+         image_ori_len = len(image)
+         image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+
+         # Create position IDs
+         image_ori_pos_ids = create_coordinate_grid(
+             size=(F_tokens, H_tokens, W_tokens),
+             start=(cap_padded_len + 1, 0, 0),
+             device=device,
+         ).flatten(0, 2)
+
+         image_padding_pos_ids = (
+             create_coordinate_grid(
+                 size=(1, 1, 1),
+                 start=(0, 0, 0),
+                 device=device,
+             )
+             .flatten(0, 2)
+             .repeat(image_padding_len, 1)
+         )
+         image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
+         all_image_pos_ids.append(image_padded_pos_ids)
+
+         # Padding mask
+         all_image_pad_mask.append(
+             torch.cat(
+                 [
+                     torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                     torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                 ],
+                 dim=0,
+             )
+         )
+
+         # Padded feature
+         image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
+         all_image_out.append(image_padded_feat)
+
+     return all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask
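
A worked example of the SEQ_MULTI_OF padding arithmetic above, with illustrative numbers
that are not taken from the diff: a single-frame control latent of spatial size 90x90
with patch_size=2 yields 45 * 45 = 2025 patch tokens; (-2025) % 32 = 23, so 23 padding
tokens (copies of the last patch, masked out via all_image_pad_mask) round the sequence
up to 2048. The caption offset works the same way: a caption of length 300 pads to
cap_padded_len = 320, and the control position IDs start their frame axis at 321 so
control tokens never overlap caption positions.

    tokens = 45 * 45                      # 2025 patch tokens
    pad = (-tokens) % 32                  # 23 padding tokens
    assert tokens + pad == 2048           # rounded up to a SEQ_MULTI_OF boundary

    cap_padded_len = 300 + ((-300) % 32)  # 320
    assert cap_padded_len + 1 == 321      # start value for the frame axis of control pos IDs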