InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +610 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +43 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +70 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en.json +70 -1
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0

invokeai/backend/quantization/gguf/ggml_tensor.py
CHANGED
@@ -14,11 +14,31 @@ def dequantize_and_run(func, args, kwargs):
     """A helper function for running math ops on GGMLTensor inputs.
 
     Dequantizes the inputs, and runs the function.
+    Also casts other floating point tensors to match the compute_dtype of GGMLTensors
+    to avoid dtype mismatches in matrix operations.
     """
-    … (4 removed lines not shown in this view)
+    # Find the compute_dtype from any GGMLTensor in the args
+    compute_dtype = None
+    for a in args:
+        if hasattr(a, "compute_dtype"):
+            compute_dtype = a.compute_dtype
+            break
+    if compute_dtype is None:
+        for v in kwargs.values():
+            if hasattr(v, "compute_dtype"):
+                compute_dtype = v.compute_dtype
+                break
+
+    def process_tensor(t):
+        if hasattr(t, "get_dequantized_tensor"):
+            return t.get_dequantized_tensor()
+        elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
+            # Cast other floating point tensors to match the GGUF compute_dtype
+            return t.to(compute_dtype)
+        return t
+
+    dequantized_args = [process_tensor(a) for a in args]
+    dequantized_kwargs = {k: process_tensor(v) for k, v in kwargs.items()}
     return func(*dequantized_args, **dequantized_kwargs)
 
 
@@ -57,6 +77,9 @@ GGML_TENSOR_OP_TABLE = {
     torch.ops.aten.sub.Tensor: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.allclose.default: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.slice.Tensor: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.view.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.expand.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.index_put_.default: dequantize_and_run,  # pyright: ignore
 }
 
 if torch.backends.mps.is_available():
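
For context, the new casting path avoids the dtype-mismatch errors that occur when float32 activations meet GGUF weights dequantized to a lower-precision compute dtype. A minimal standalone sketch in plain torch, illustrative only and not using GGMLTensor itself:

import torch

w = torch.randn(8, 8, dtype=torch.bfloat16)  # stand-in for a GGUF weight dequantized at compute_dtype=bfloat16
x = torch.randn(4, 8, dtype=torch.float32)   # activation arriving in float32

# torch.matmul(x, w) would raise a dtype-mismatch RuntimeError;
# the patched dequantize_and_run now casts floating point inputs to the GGMLTensor's compute_dtype first:
y = torch.matmul(x.to(w.dtype), w)
print(y.dtype)  # torch.bfloat16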

invokeai/backend/quantization/gguf/loaders.py
CHANGED
@@ -1,3 +1,4 @@
+import gc
 from pathlib import Path
 
 import gguf
@@ -5,18 +6,52 @@ import torch
 
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
+from invokeai.backend.util.logging import InvokeAILogger
+
+logger = InvokeAILogger.get_logger()
+
+
+class WrappedGGUFReader:
+    """Wrapper around GGUFReader that adds a close() method."""
+
+    def __init__(self, path: Path):
+        self.reader = gguf.GGUFReader(path)
+
+    def __enter__(self):
+        return self.reader
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
+
+    def close(self):
+        """Explicitly close the memory-mapped file."""
+        if hasattr(self.reader, "data"):
+            try:
+                self.reader.data.flush()
+                del self.reader.data
+            except (AttributeError, OSError, ValueError) as e:
+                logger.warning(f"Wasn't able to close GGUF memory map: {e}")
+        del self.reader
+        gc.collect()
 
 
 def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
-    … (12 removed lines not shown in this view)
+    with WrappedGGUFReader(path) as reader:
+        sd: dict[str, GGMLTensor] = {}
+        for tensor in reader.tensors:
+            # Use .copy() to create a true copy of the data, not a view.
+            # This is critical on Windows where the memory-mapped file cannot be deleted
+            # while tensors still hold references to the mapped memory.
+            torch_tensor = torch.from_numpy(tensor.data.copy())
+
+            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+            if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
+                torch_tensor = torch_tensor.view(*shape)
+            sd[tensor.name] = GGMLTensor(
+                torch_tensor,
+                ggml_quantization_type=tensor.tensor_type,
+                tensor_shape=shape,
+                compute_dtype=compute_dtype,
+            )
+        return sd
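
Call-site usage is unchanged; an illustrative sketch of loading a GGUF checkpoint with the new context-managed reader (the checkpoint path is hypothetical):

from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# Hypothetical checkpoint path.
state_dict = gguf_sd_loader(Path("models/some-model-Q4_0.gguf"), compute_dtype=torch.bfloat16)
print(f"Loaded {len(state_dict)} GGML tensors")

Because the reader is wrapped in WrappedGGUFReader and the tensor data is copied out of the memory map, the .gguf file is released as soon as loading finishes, which matters on Windows where a mapped file cannot be deleted while references remain.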

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
CHANGED
@@ -76,6 +76,18 @@ class CogView4ConditioningInfo:
         return self
 
 
+@dataclass
+class ZImageConditioningInfo:
+    """Z-Image text conditioning information from Qwen3 text encoder."""
+
+    prompt_embeds: torch.Tensor
+    """Text embeddings from Qwen3 encoder. Shape: (batch_size, seq_len, hidden_size)."""
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.prompt_embeds = self.prompt_embeds.to(device=device, dtype=dtype)
+        return self
+
+
 @dataclass
 class ConditioningFieldData:
     # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
@@ -87,6 +99,7 @@ class ConditioningFieldData:
         | List[FLUXConditioningInfo]
         | List[SD3ConditioningInfo]
         | List[CogView4ConditioningInfo]
+        | List[ZImageConditioningInfo]
     )
 
 
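
An illustrative sketch of the new conditioning dataclass (tensor sizes are made up):

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo

cond = ZImageConditioningInfo(prompt_embeds=torch.randn(1, 77, 2560))  # (batch, seq_len, hidden); sizes illustrative
cond = cond.to(device=torch.device("cpu"), dtype=torch.float16)
print(cond.prompt_embeds.dtype)  # torch.float16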

invokeai/backend/util/devices.py
CHANGED
@@ -112,3 +112,28 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def choose_bfloat16_safe_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return bfloat16 if supported on the device, else fallback to float16/float32.
+
+        This is useful for models that require bfloat16 precision (e.g., Z-Image, Flux)
+        but need to run on hardware that may not support bfloat16.
+
+        Args:
+            device: The target device. If None, uses choose_torch_device().
+
+        Returns:
+            torch.bfloat16 if supported, torch.float16 for CUDA without bfloat16 support,
+            or torch.float32 for CPU/MPS.
+        """
+        device = device or cls.choose_torch_device()
+        try:
+            # Test if bfloat16 is supported on this device
+            torch.tensor([1.0], dtype=torch.bfloat16, device=device)
+            return torch.bfloat16
+        except TypeError:
+            # bfloat16 not supported - fallback based on device type
+            if device.type == "cuda":
+                return torch.float16
+            return torch.float32
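
An illustrative usage sketch (choose_torch_device() is the existing TorchDevice helper; the printed values depend on the host hardware):

import torch

from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
print(device, dtype)  # torch.bfloat16 where supported, otherwise a float16/float32 fallback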

invokeai/backend/util/hotfixes.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.loaders.single_file_model import FromOriginalModelMixin
 from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
-from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
+from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 from diffusers.models.embeddings import (
     TextImageProjection,
     TextImageTimeEmbedding,
@@ -777,7 +777,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
 
 diffusers.ControlNetModel = ControlNetModel
-diffusers.models.controlnet.ControlNetModel = ControlNetModel
+diffusers.models.controlnets.controlnet.ControlNetModel = ControlNetModel
 
 
 # patch LoRACompatibleConv to use original Conv2D forward function

invokeai/backend/z_image/__init__.py
@@ -0,0 +1,16 @@
+# Z-Image backend utilities
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
+from invokeai.backend.z_image.z_image_controlnet_extension import (
+    ZImageControlNetExtension,
+    z_image_forward_with_control,
+)
+from invokeai.backend.z_image.z_image_patchify_utils import patchify_control_context
+
+__all__ = [
+    "ZImageControlAdapter",
+    "ZImageControlTransformer2DModel",
+    "ZImageControlNetExtension",
+    "z_image_forward_with_control",
+    "patchify_control_context",
+]

invokeai/backend/z_image/extensions/__init__.py
@@ -0,0 +1 @@
+# Z-Image extensions

invokeai/backend/z_image/extensions/regional_prompting_extension.py
@@ -0,0 +1,207 @@
+from typing import Optional
+
+import torch
+import torchvision
+
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning
+
+
+class ZImageRegionalPromptingExtension:
+    """A class for managing regional prompting with Z-Image.
+
+    This implementation is inspired by the FLUX regional prompting extension and
+    the paper https://arxiv.org/pdf/2411.02395.
+
+    Key difference from FLUX: Z-Image uses sequence order [img_tokens, txt_tokens],
+    while FLUX uses [txt_tokens, img_tokens]. The attention mask construction
+    accounts for this difference.
+    """
+
+    def __init__(
+        self,
+        regional_text_conditioning: ZImageRegionalTextConditioning,
+        regional_attn_mask: torch.Tensor | None = None,
+    ):
+        self.regional_text_conditioning = regional_text_conditioning
+        self.regional_attn_mask = regional_attn_mask
+
+    def get_attn_mask(self, block_index: int) -> torch.Tensor | None:
+        """Get the attention mask for a given block index.
+
+        Uses alternating pattern: apply mask on even blocks, no mask on odd blocks.
+        This helps balance regional control with global coherence.
+        """
+        order = [self.regional_attn_mask, None]
+        return order[block_index % len(order)]
+
+    @classmethod
+    def from_text_conditionings(
+        cls,
+        text_conditionings: list[ZImageTextConditioning],
+        img_seq_len: int,
+    ) -> "ZImageRegionalPromptingExtension":
+        """Create a ZImageRegionalPromptingExtension from a list of text conditionings.
+
+        Args:
+            text_conditionings: List of text conditionings with optional masks.
+            img_seq_len: The image sequence length (i.e. (H // patch_size) * (W // patch_size)).
+
+        Returns:
+            A configured ZImageRegionalPromptingExtension.
+        """
+        regional_text_conditioning = ZImageRegionalTextConditioning.from_text_conditionings(text_conditionings)
+        attn_mask = cls._prepare_regional_attn_mask(regional_text_conditioning, img_seq_len)
+        return cls(
+            regional_text_conditioning=regional_text_conditioning,
+            regional_attn_mask=attn_mask,
+        )
+
+    @classmethod
+    def _prepare_regional_attn_mask(
+        cls,
+        regional_text_conditioning: ZImageRegionalTextConditioning,
+        img_seq_len: int,
+    ) -> torch.Tensor | None:
+        """Prepare a regional attention mask for Z-Image.
+
+        The mask controls which tokens can attend to each other:
+        - Image tokens within a region attend only to each other
+        - Image tokens attend only to their corresponding regional text
+        - Text tokens attend only to their corresponding regional image
+        - Text tokens attend to themselves
+
+        Z-Image sequence order: [img_tokens, txt_tokens]
+
+        Args:
+            regional_text_conditioning: The regional text conditioning data.
+            img_seq_len: Number of image tokens.
+
+        Returns:
+            Attention mask of shape (img_seq_len + txt_seq_len, img_seq_len + txt_seq_len).
+            Returns None if no regional masks are present.
+        """
+        # Check if any regional masks exist
+        has_regional_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
+        if not has_regional_masks:
+            # No regional masks, return None to use default attention
+            return None
+
+        # Identify background region (area not covered by any mask)
+        background_region_mask: torch.Tensor | None = None
+        for image_mask in regional_text_conditioning.image_masks:
+            if image_mask is not None:
+                # image_mask shape: (1, 1, img_seq_len) -> flatten to (img_seq_len,)
+                mask_flat = image_mask.view(-1)
+                if background_region_mask is None:
+                    background_region_mask = torch.ones_like(mask_flat)
+                background_region_mask = background_region_mask * (1 - mask_flat)
+
+        device = TorchDevice.choose_torch_device()
+        txt_seq_len = regional_text_conditioning.prompt_embeds.shape[0]
+        total_seq_len = img_seq_len + txt_seq_len
+
+        # Initialize empty attention mask
+        # Z-Image sequence: [img_tokens (0:img_seq_len), txt_tokens (img_seq_len:total_seq_len)]
+        regional_attention_mask = torch.zeros((total_seq_len, total_seq_len), device=device, dtype=torch.float16)
+
+        for image_mask, embedding_range in zip(
+            regional_text_conditioning.image_masks,
+            regional_text_conditioning.embedding_ranges,
+            strict=True,
+        ):
+            # Calculate text token positions in the unified sequence
+            txt_start = img_seq_len + embedding_range.start
+            txt_end = img_seq_len + embedding_range.end
+
+            # 1. txt attends to itself
+            regional_attention_mask[txt_start:txt_end, txt_start:txt_end] = 1.0
+
+            if image_mask is not None:
+                # Flatten mask: (1, 1, img_seq_len) -> (img_seq_len,)
+                mask_flat = image_mask.view(img_seq_len)
+
+                # 2. img attends to corresponding regional txt
+                # Reshape mask to (img_seq_len, 1) for broadcasting
+                regional_attention_mask[:img_seq_len, txt_start:txt_end] = mask_flat.view(img_seq_len, 1)
+
+                # 3. txt attends to corresponding regional img
+                # Reshape mask to (1, img_seq_len) for broadcasting
+                regional_attention_mask[txt_start:txt_end, :img_seq_len] = mask_flat.view(1, img_seq_len)
+
+                # 4. img self-attention within region
+                # mask @ mask.T creates pairwise attention within the masked region
+                regional_attention_mask[:img_seq_len, :img_seq_len] += mask_flat.view(img_seq_len, 1) @ mask_flat.view(
+                    1, img_seq_len
+                )
+            else:
+                # Global prompt: allow attention to/from background regions only
+                if background_region_mask is not None:
+                    # 2. background img attends to global txt
+                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = background_region_mask.view(
+                        img_seq_len, 1
+                    )
+
+                    # 3. global txt attends to background img
+                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = background_region_mask.view(
+                        1, img_seq_len
+                    )
+
+                else:
+                    # No regional masks at all, allow full attention
+                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = 1.0
+                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = 1.0
+
+        # Allow background regions to attend to themselves
+        if background_region_mask is not None:
+            bg_mask = background_region_mask.view(img_seq_len, 1)
+            regional_attention_mask[:img_seq_len, :img_seq_len] += bg_mask @ bg_mask.T
+
+        # Convert to boolean mask
+        regional_attention_mask = regional_attention_mask > 0.5
+
+        return regional_attention_mask
+
+    @staticmethod
+    def preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor],
+        target_height: int,
+        target_width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target image token grid.
+
+        Args:
+            mask: Input mask tensor. If None, returns a mask of all ones.
+            target_height: Height of the image token grid (H // patch_size).
+            target_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype for the mask.
+            device: Target device for the mask.
+
+        Returns:
+            Processed mask of shape (1, 1, target_height * target_width).
+        """
+        img_seq_len = target_height * target_width
+
+        if mask is None:
+            return torch.ones((1, 1, img_seq_len), dtype=dtype, device=device)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        # Resize mask to target dimensions
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width),
+            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
+        )
+
+        # Add batch dimension if needed: (h, w) -> (1, h, w) -> (1, 1, h, w)
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(0)
+
+        resized_mask = tf(mask)
+
+        # Flatten to (1, 1, img_seq_len)
+        return resized_mask.flatten(start_dim=2).to(device=device)

invokeai/backend/z_image/text_conditioning.py
@@ -0,0 +1,74 @@
+from dataclasses import dataclass
+
+import torch
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+
+
+@dataclass
+class ZImageTextConditioning:
+    """Z-Image text conditioning with optional regional mask.
+
+    Attributes:
+        prompt_embeds: Text embeddings from Qwen3 encoder. Shape: (seq_len, hidden_size).
+        mask: Optional binary mask for regional prompting. If None, the prompt is global.
+            Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size).
+    """
+
+    prompt_embeds: torch.Tensor
+    mask: torch.Tensor | None = None
+
+
+@dataclass
+class ZImageRegionalTextConditioning:
+    """Container for multiple regional text conditionings concatenated together.
+
+    In Z-Image, the unified sequence is [img_tokens, txt_tokens], which is different
+    from FLUX where it's [txt_tokens, img_tokens]. The attention mask must account for this.
+
+    Attributes:
+        prompt_embeds: Concatenated text embeddings from all regional prompts.
+            Shape: (total_seq_len, hidden_size).
+        image_masks: List of binary masks for each regional prompt.
+            image_masks[i] corresponds to embedding_ranges[i].
+            If None, the prompt is global (applies to entire image).
+            Shape: (1, 1, img_seq_len).
+        embedding_ranges: List of ranges indicating which portion of prompt_embeds
+            corresponds to each regional prompt.
+    """
+
+    prompt_embeds: torch.Tensor
+    image_masks: list[torch.Tensor | None]
+    embedding_ranges: list[Range]
+
+    @classmethod
+    def from_text_conditionings(
+        cls,
+        text_conditionings: list[ZImageTextConditioning],
+    ) -> "ZImageRegionalTextConditioning":
+        """Create a ZImageRegionalTextConditioning from a list of ZImageTextConditioning objects.
+
+        Args:
+            text_conditionings: List of text conditionings, each with optional mask.
+
+        Returns:
+            A single ZImageRegionalTextConditioning with concatenated embeddings.
+        """
+        concat_embeds: list[torch.Tensor] = []
+        concat_ranges: list[Range] = []
+        image_masks: list[torch.Tensor | None] = []
+
+        cur_embed_len = 0
+        for tc in text_conditionings:
+            concat_embeds.append(tc.prompt_embeds)
+            concat_ranges.append(Range(start=cur_embed_len, end=cur_embed_len + tc.prompt_embeds.shape[0]))
+            image_masks.append(tc.mask)
+            cur_embed_len += tc.prompt_embeds.shape[0]
+
+        prompt_embeds = torch.cat(concat_embeds, dim=0)
+
+        return cls(
+            prompt_embeds=prompt_embeds,
+            image_masks=image_masks,
+            embedding_ranges=concat_ranges,
+        )
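
Tying the last two new files together, an illustrative sketch of building a regional attention mask for one global and one left-half regional prompt. The token-grid size, hidden size, and sequence lengths are made up, and the random tensors stand in for prompt embeddings that would normally come from the Qwen3 text encoder:

import torch

from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning

device = TorchDevice.choose_torch_device()
h_tokens, w_tokens, hidden = 64, 64, 2560  # image token grid and embedding width (illustrative sizes)
img_seq_len = h_tokens * w_tokens

# Boolean mask covering the left half of the image token grid.
left_half = torch.zeros(h_tokens, w_tokens, dtype=torch.bool)
left_half[:, : w_tokens // 2] = True
regional_mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
    left_half, target_height=h_tokens, target_width=w_tokens, dtype=torch.float16, device=device
)

conditionings = [
    ZImageTextConditioning(prompt_embeds=torch.randn(77, hidden)),                      # global prompt, no mask
    ZImageTextConditioning(prompt_embeds=torch.randn(77, hidden), mask=regional_mask),  # left-half prompt
]
ext = ZImageRegionalPromptingExtension.from_text_conditionings(conditionings, img_seq_len=img_seq_len)

attn_mask = ext.get_attn_mask(block_index=0)  # boolean (img+txt, img+txt) mask on even blocks
print(attn_mask.shape, ext.get_attn_mask(1))  # torch.Size([4250, 4250]) None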