InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +610 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +43 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +70 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en.json +70 -1
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Utilities for patching the ZImageTransformer2DModel to support regional attention masks."""
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import Callable, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def create_regional_forward(
|
|
11
|
+
original_forward: Callable,
|
|
12
|
+
regional_attn_mask: torch.Tensor,
|
|
13
|
+
img_seq_len: int,
|
|
14
|
+
) -> Callable:
|
|
15
|
+
"""Create a modified forward function that uses a regional attention mask.
|
|
16
|
+
|
|
17
|
+
The regional attention mask replaces the internally computed padding mask,
|
|
18
|
+
allowing for regional prompting where different image regions attend to
|
|
19
|
+
different text prompts.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
original_forward: The original forward method of ZImageTransformer2DModel.
|
|
23
|
+
regional_attn_mask: Attention mask of shape (seq_len, seq_len) where
|
|
24
|
+
seq_len = img_seq_len + txt_seq_len.
|
|
25
|
+
img_seq_len: Number of image tokens in the sequence.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
A modified forward function with regional attention support.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
def regional_forward(
|
|
32
|
+
self,
|
|
33
|
+
x: List[torch.Tensor],
|
|
34
|
+
t: torch.Tensor,
|
|
35
|
+
cap_feats: List[torch.Tensor],
|
|
36
|
+
patch_size: int = 2,
|
|
37
|
+
f_patch_size: int = 1,
|
|
38
|
+
) -> Tuple[List[torch.Tensor], dict]:
|
|
39
|
+
"""Modified forward with regional attention mask injection.
|
|
40
|
+
|
|
41
|
+
This is based on the original ZImageTransformer2DModel.forward but
|
|
42
|
+
replaces the padding-based attention mask with a regional attention mask.
|
|
43
|
+
"""
|
|
44
|
+
assert patch_size in self.all_patch_size
|
|
45
|
+
assert f_patch_size in self.all_f_patch_size
|
|
46
|
+
|
|
47
|
+
bsz = len(x)
|
|
48
|
+
device = x[0].device
|
|
49
|
+
t_scaled = t * self.t_scale
|
|
50
|
+
t_emb = self.t_embedder(t_scaled)
|
|
51
|
+
|
|
52
|
+
SEQ_MULTI_OF = 32 # From diffusers transformer_z_image.py
|
|
53
|
+
|
|
54
|
+
# Patchify and embed (reusing the original method)
|
|
55
|
+
(
|
|
56
|
+
x,
|
|
57
|
+
cap_feats,
|
|
58
|
+
x_size,
|
|
59
|
+
x_pos_ids,
|
|
60
|
+
cap_pos_ids,
|
|
61
|
+
x_inner_pad_mask,
|
|
62
|
+
cap_inner_pad_mask,
|
|
63
|
+
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
|
64
|
+
|
|
65
|
+
# x embed & refine
|
|
66
|
+
x_item_seqlens = [len(_) for _ in x]
|
|
67
|
+
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
|
68
|
+
x_max_item_seqlen = max(x_item_seqlens)
|
|
69
|
+
|
|
70
|
+
x_cat = torch.cat(x, dim=0)
|
|
71
|
+
x_cat = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat)
|
|
72
|
+
|
|
73
|
+
adaln_input = t_emb.type_as(x_cat)
|
|
74
|
+
x_cat[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
|
75
|
+
x_list = list(x_cat.split(x_item_seqlens, dim=0))
|
|
76
|
+
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
|
77
|
+
|
|
78
|
+
x_padded = pad_sequence(x_list, batch_first=True, padding_value=0.0)
|
|
79
|
+
x_freqs_cis_padded = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
|
80
|
+
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
|
81
|
+
for i, seq_len in enumerate(x_item_seqlens):
|
|
82
|
+
x_attn_mask[i, :seq_len] = 1
|
|
83
|
+
|
|
84
|
+
# Process through noise_refiner
|
|
85
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
|
86
|
+
for layer in self.noise_refiner:
|
|
87
|
+
x_padded = self._gradient_checkpointing_func(
|
|
88
|
+
layer, x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input
|
|
89
|
+
)
|
|
90
|
+
else:
|
|
91
|
+
for layer in self.noise_refiner:
|
|
92
|
+
x_padded = layer(x_padded, x_attn_mask, x_freqs_cis_padded, adaln_input)
|
|
93
|
+
|
|
94
|
+
# cap embed & refine
|
|
95
|
+
cap_item_seqlens = [len(_) for _ in cap_feats]
|
|
96
|
+
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
|
97
|
+
cap_max_item_seqlen = max(cap_item_seqlens)
|
|
98
|
+
|
|
99
|
+
cap_cat = torch.cat(cap_feats, dim=0)
|
|
100
|
+
cap_cat = self.cap_embedder(cap_cat)
|
|
101
|
+
cap_cat[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
|
102
|
+
cap_list = list(cap_cat.split(cap_item_seqlens, dim=0))
|
|
103
|
+
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
|
104
|
+
|
|
105
|
+
cap_padded = pad_sequence(cap_list, batch_first=True, padding_value=0.0)
|
|
106
|
+
cap_freqs_cis_padded = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
|
107
|
+
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
|
108
|
+
for i, seq_len in enumerate(cap_item_seqlens):
|
|
109
|
+
cap_attn_mask[i, :seq_len] = 1
|
|
110
|
+
|
|
111
|
+
# Process through context_refiner
|
|
112
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
|
113
|
+
for layer in self.context_refiner:
|
|
114
|
+
cap_padded = self._gradient_checkpointing_func(layer, cap_padded, cap_attn_mask, cap_freqs_cis_padded)
|
|
115
|
+
else:
|
|
116
|
+
for layer in self.context_refiner:
|
|
117
|
+
cap_padded = layer(cap_padded, cap_attn_mask, cap_freqs_cis_padded)
|
|
118
|
+
|
|
119
|
+
# Unified sequence: [img_tokens, txt_tokens]
|
|
120
|
+
unified = []
|
|
121
|
+
unified_freqs_cis = []
|
|
122
|
+
for i in range(bsz):
|
|
123
|
+
x_len = x_item_seqlens[i]
|
|
124
|
+
cap_len = cap_item_seqlens[i]
|
|
125
|
+
unified.append(torch.cat([x_padded[i][:x_len], cap_padded[i][:cap_len]]))
|
|
126
|
+
unified_freqs_cis.append(torch.cat([x_freqs_cis_padded[i][:x_len], cap_freqs_cis_padded[i][:cap_len]]))
|
|
127
|
+
|
|
128
|
+
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens, strict=False)]
|
|
129
|
+
assert unified_item_seqlens == [len(_) for _ in unified]
|
|
130
|
+
unified_max_item_seqlen = max(unified_item_seqlens)
|
|
131
|
+
|
|
132
|
+
unified_padded = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
|
133
|
+
unified_freqs_cis_padded = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
|
134
|
+
|
|
135
|
+
# --- REGIONAL ATTENTION MASK INJECTION ---
|
|
136
|
+
# Instead of using the padding mask, we use the regional attention mask
|
|
137
|
+
# The regional mask is (seq_len, seq_len), we need to expand it to (batch, seq_len, seq_len)
|
|
138
|
+
# and then add the batch dimension for broadcasting: (batch, 1, seq_len, seq_len)
|
|
139
|
+
|
|
140
|
+
# Expand regional mask to match the actual sequence length (may include padding)
|
|
141
|
+
if regional_attn_mask.shape[0] != unified_max_item_seqlen:
|
|
142
|
+
# Pad the regional mask to match unified sequence length
|
|
143
|
+
padded_regional_mask = torch.zeros(
|
|
144
|
+
(unified_max_item_seqlen, unified_max_item_seqlen),
|
|
145
|
+
dtype=regional_attn_mask.dtype,
|
|
146
|
+
device=device,
|
|
147
|
+
)
|
|
148
|
+
mask_size = min(regional_attn_mask.shape[0], unified_max_item_seqlen)
|
|
149
|
+
padded_regional_mask[:mask_size, :mask_size] = regional_attn_mask[:mask_size, :mask_size]
|
|
150
|
+
else:
|
|
151
|
+
padded_regional_mask = regional_attn_mask.to(device)
|
|
152
|
+
|
|
153
|
+
# Convert boolean mask to additive float mask for attention
|
|
154
|
+
# True (attend) -> 0.0, False (block) -> -inf
|
|
155
|
+
# This is required because the attention backend expects additive masks for 4D inputs
|
|
156
|
+
# Use bfloat16 to match the transformer's query dtype
|
|
157
|
+
float_mask = torch.zeros_like(padded_regional_mask, dtype=torch.bfloat16)
|
|
158
|
+
float_mask[~padded_regional_mask] = float("-inf")
|
|
159
|
+
|
|
160
|
+
# Expand to (batch, 1, seq_len, seq_len) for attention
|
|
161
|
+
unified_attn_mask = float_mask.unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1)
|
|
162
|
+
|
|
163
|
+
# Process through main layers with regional attention mask
|
|
164
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
|
165
|
+
for layer_idx, layer in enumerate(self.layers):
|
|
166
|
+
# Alternate between regional mask and full attention
|
|
167
|
+
if layer_idx % 2 == 0:
|
|
168
|
+
unified_padded = self._gradient_checkpointing_func(
|
|
169
|
+
layer, unified_padded, unified_attn_mask, unified_freqs_cis_padded, adaln_input
|
|
170
|
+
)
|
|
171
|
+
else:
|
|
172
|
+
# Use padding mask only for odd layers (allows global coherence)
|
|
173
|
+
padding_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
|
174
|
+
for i, seq_len in enumerate(unified_item_seqlens):
|
|
175
|
+
padding_mask[i, :seq_len] = 1
|
|
176
|
+
unified_padded = self._gradient_checkpointing_func(
|
|
177
|
+
layer, unified_padded, padding_mask, unified_freqs_cis_padded, adaln_input
|
|
178
|
+
)
|
|
179
|
+
else:
|
|
180
|
+
for layer_idx, layer in enumerate(self.layers):
|
|
181
|
+
# Alternate between regional mask and full attention
|
|
182
|
+
if layer_idx % 2 == 0:
|
|
183
|
+
unified_padded = layer(unified_padded, unified_attn_mask, unified_freqs_cis_padded, adaln_input)
|
|
184
|
+
else:
|
|
185
|
+
# Use padding mask only for odd layers (allows global coherence)
|
|
186
|
+
padding_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
|
187
|
+
for i, seq_len in enumerate(unified_item_seqlens):
|
|
188
|
+
padding_mask[i, :seq_len] = 1
|
|
189
|
+
unified_padded = layer(unified_padded, padding_mask, unified_freqs_cis_padded, adaln_input)
|
|
190
|
+
|
|
191
|
+
# Final layer
|
|
192
|
+
unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified_padded, adaln_input)
|
|
193
|
+
unified_list = list(unified_out.unbind(dim=0))
|
|
194
|
+
x_out = self.unpatchify(unified_list, x_size, patch_size, f_patch_size)
|
|
195
|
+
|
|
196
|
+
return x_out, {}
|
|
197
|
+
|
|
198
|
+
return regional_forward
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@contextmanager
def patch_transformer_for_regional_prompting(
    transformer,
    regional_attn_mask: Optional[torch.Tensor],
    img_seq_len: int,
):
    """Context manager to temporarily patch the transformer for regional prompting.

    While the context is active, ``transformer.forward`` is replaced by the
    function built by ``create_regional_forward``. On exit, even if an exception
    was raised, the transformer's previous ``forward`` lookup is restored.

    Args:
        transformer: The ZImageTransformer2DModel instance.
        regional_attn_mask: Regional attention mask of shape (seq_len, seq_len).
            If None, the transformer is not patched.
        img_seq_len: Number of image tokens.

    Yields:
        The (possibly patched) transformer.
    """
    if regional_attn_mask is None:
        # No regional prompting, use original forward
        yield transformer
        return

    original_forward = transformer.forward
    # Remember whether `forward` was already an instance attribute (e.g. set by
    # another patch) so exit can restore exactly the previous state.
    had_instance_forward = "forward" in vars(transformer)

    regional_fwd = create_regional_forward(original_forward, regional_attn_mask, img_seq_len)
    transformer.forward = lambda *args, **kwargs: regional_fwd(transformer, *args, **kwargs)

    try:
        yield transformer
    finally:
        if had_instance_forward:
            transformer.forward = original_forward
        else:
            # Remove the override so lookup falls back to the class's forward.
            # Reassigning the bound method would leave a self-referencing
            # attribute (transformer -> bound method -> transformer) in __dict__
            # and shadow the class method permanently.
            del transformer.forward
|