InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/quantization/gguf/ggml_tensor.py

@@ -14,11 +14,31 @@ def dequantize_and_run(func, args, kwargs):
     """A helper function for running math ops on GGMLTensor inputs.
 
     Dequantizes the inputs, and runs the function.
+    Also casts other floating point tensors to match the compute_dtype of GGMLTensors
+    to avoid dtype mismatches in matrix operations.
     """
-    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
-    dequantized_kwargs = {
-        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
-    }
+    # Find the compute_dtype from any GGMLTensor in the args
+    compute_dtype = None
+    for a in args:
+        if hasattr(a, "compute_dtype"):
+            compute_dtype = a.compute_dtype
+            break
+    if compute_dtype is None:
+        for v in kwargs.values():
+            if hasattr(v, "compute_dtype"):
+                compute_dtype = v.compute_dtype
+                break
+
+    def process_tensor(t):
+        if hasattr(t, "get_dequantized_tensor"):
+            return t.get_dequantized_tensor()
+        elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
+            # Cast other floating point tensors to match the GGUF compute_dtype
+            return t.to(compute_dtype)
+        return t
+
+    dequantized_args = [process_tensor(a) for a in args]
+    dequantized_kwargs = {k: process_tensor(v) for k, v in kwargs.items()}
     return func(*dequantized_args, **dequantized_kwargs)
 
 
@@ -57,6 +77,9 @@ GGML_TENSOR_OP_TABLE = {
     torch.ops.aten.sub.Tensor: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.allclose.default: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.slice.Tensor: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.view.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.expand.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.index_put_.default: dequantize_and_run,  # pyright: ignore
 }
 
 if torch.backends.mps.is_available():
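To illustrate the new casting behavior, here is a minimal sketch (not part of the diff) that calls dequantize_and_run directly with a hypothetical stand-in object exposing the two attributes the helper probes for, get_dequantized_tensor() and compute_dtype:

import torch

from invokeai.backend.quantization.gguf.ggml_tensor import dequantize_and_run


class FakeQuantizedWeight:
    """Hypothetical stand-in for a GGMLTensor, used only for illustration."""

    def __init__(self, data: torch.Tensor, compute_dtype: torch.dtype):
        self._data = data
        self.compute_dtype = compute_dtype

    def get_dequantized_tensor(self) -> torch.Tensor:
        return self._data.to(self.compute_dtype)


weight = FakeQuantizedWeight(torch.randn(8, 8), compute_dtype=torch.bfloat16)
activation = torch.randn(8, 8)  # float32; mixing dtypes previously risked a matmul error

# The quantized side is dequantized to bfloat16 and the plain float32 tensor is cast to
# the same compute_dtype before the op runs, so both operands match.
out = dequantize_and_run(torch.matmul, (activation, weight), {})
assert out.dtype == torch.bfloat16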
invokeai/backend/quantization/gguf/loaders.py

@@ -1,3 +1,4 @@
+import gc
 from pathlib import Path
 
 import gguf
@@ -5,18 +6,52 @@ import torch
 
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
+from invokeai.backend.util.logging import InvokeAILogger
+
+logger = InvokeAILogger.get_logger()
+
+
+class WrappedGGUFReader:
+    """Wrapper around GGUFReader that adds a close() method."""
+
+    def __init__(self, path: Path):
+        self.reader = gguf.GGUFReader(path)
+
+    def __enter__(self):
+        return self.reader
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
+
+    def close(self):
+        """Explicitly close the memory-mapped file."""
+        if hasattr(self.reader, "data"):
+            try:
+                self.reader.data.flush()
+                del self.reader.data
+            except (AttributeError, OSError, ValueError) as e:
+                logger.warning(f"Wasn't able to close GGUF memory map: {e}")
+        del self.reader
+        gc.collect()
 
 
 def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
-    reader = gguf.GGUFReader(path)
-
-    sd: dict[str, GGMLTensor] = {}
-    for tensor in reader.tensors:
-        torch_tensor = torch.from_numpy(tensor.data)
-        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
-        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
-            torch_tensor = torch_tensor.view(*shape)
-        sd[tensor.name] = GGMLTensor(
-            torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
-        )
-    return sd
+    with WrappedGGUFReader(path) as reader:
+        sd: dict[str, GGMLTensor] = {}
+        for tensor in reader.tensors:
+            # Use .copy() to create a true copy of the data, not a view.
+            # This is critical on Windows where the memory-mapped file cannot be deleted
+            # while tensors still hold references to the mapped memory.
+            torch_tensor = torch.from_numpy(tensor.data.copy())
+
+            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+            if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
+                torch_tensor = torch_tensor.view(*shape)
+            sd[tensor.name] = GGMLTensor(
+                torch_tensor,
+                ggml_quantization_type=tensor.tensor_type,
+                tensor_shape=shape,
+                compute_dtype=compute_dtype,
+            )
+        return sd
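A hedged usage sketch of the reworked loader follows; the checkpoint path is hypothetical, and only the gguf_sd_loader signature and GGMLTensor compute_dtype attribute shown above are assumed:

from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

# Hypothetical path to a GGUF-quantized checkpoint.
state_dict = gguf_sd_loader(Path("/models/example-Q4_K_M.gguf"), compute_dtype=torch.bfloat16)
for name, tensor in state_dict.items():
    # Each value is a GGMLTensor; storage stays quantized until an op dequantizes it.
    print(name, tensor.compute_dtype)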
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

@@ -76,6 +76,18 @@ class CogView4ConditioningInfo:
         return self
 
 
+@dataclass
+class ZImageConditioningInfo:
+    """Z-Image text conditioning information from Qwen3 text encoder."""
+
+    prompt_embeds: torch.Tensor
+    """Text embeddings from Qwen3 encoder. Shape: (batch_size, seq_len, hidden_size)."""
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.prompt_embeds = self.prompt_embeds.to(device=device, dtype=dtype)
+        return self
+
+
 @dataclass
 class ConditioningFieldData:
     # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
@@ -87,6 +99,7 @@ class ConditioningFieldData:
         | List[FLUXConditioningInfo]
         | List[SD3ConditioningInfo]
         | List[CogView4ConditioningInfo]
+        | List[ZImageConditioningInfo]
     )
 
 
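For context, a minimal sketch of how the new conditioning type would be wrapped; the embedding shape is illustrative, and the field name conditionings is an assumption based on the existing ConditioningFieldData class:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    ZImageConditioningInfo,
)

# Hypothetical tensor standing in for Qwen3 encoder output: (batch_size, seq_len, hidden_size).
prompt_embeds = torch.zeros(1, 512, 2560, dtype=torch.bfloat16)
info = ZImageConditioningInfo(prompt_embeds=prompt_embeds).to(dtype=torch.float16)

# Field name `conditionings` is assumed from the existing class definition.
cond_data = ConditioningFieldData(conditionings=[info])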
invokeai/backend/util/devices.py

@@ -112,3 +112,28 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def choose_bfloat16_safe_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return bfloat16 if supported on the device, else fallback to float16/float32.
+
+        This is useful for models that require bfloat16 precision (e.g., Z-Image, Flux)
+        but need to run on hardware that may not support bfloat16.
+
+        Args:
+            device: The target device. If None, uses choose_torch_device().
+
+        Returns:
+            torch.bfloat16 if supported, torch.float16 for CUDA without bfloat16 support,
+            or torch.float32 for CPU/MPS.
+        """
+        device = device or cls.choose_torch_device()
+        try:
+            # Test if bfloat16 is supported on this device
+            torch.tensor([1.0], dtype=torch.bfloat16, device=device)
+            return torch.bfloat16
+        except TypeError:
+            # bfloat16 not supported - fallback based on device type
+            if device.type == "cuda":
+                return torch.float16
+            return torch.float32
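A short usage sketch of the new helper, assuming only the TorchDevice methods shown above:

import torch

from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
# bfloat16 where the hardware supports it, otherwise float16 (CUDA) or float32 (CPU/MPS).
dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
latents = torch.zeros(1, 16, 128, 128, device=device, dtype=dtype)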
invokeai/backend/util/hotfixes.py

@@ -5,7 +5,7 @@ import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.loaders.single_file_model import FromOriginalModelMixin
 from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
-from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
+from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 from diffusers.models.embeddings import (
     TextImageProjection,
     TextImageTimeEmbedding,
@@ -777,7 +777,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
 
 diffusers.ControlNetModel = ControlNetModel
-diffusers.models.controlnet.ControlNetModel = ControlNetModel
+diffusers.models.controlnets.controlnet.ControlNetModel = ControlNetModel
 
 
 # patch LoRACompatibleConv to use original Conv2D forward function
invokeai/backend/z_image/__init__.py

@@ -0,0 +1,16 @@
+# Z-Image backend utilities
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
+from invokeai.backend.z_image.z_image_controlnet_extension import (
+    ZImageControlNetExtension,
+    z_image_forward_with_control,
+)
+from invokeai.backend.z_image.z_image_patchify_utils import patchify_control_context
+
+__all__ = [
+    "ZImageControlAdapter",
+    "ZImageControlTransformer2DModel",
+    "ZImageControlNetExtension",
+    "z_image_forward_with_control",
+    "patchify_control_context",
+]
invokeai/backend/z_image/extensions/__init__.py

@@ -0,0 +1 @@
+# Z-Image extensions
invokeai/backend/z_image/extensions/regional_prompting_extension.py

@@ -0,0 +1,207 @@
+from typing import Optional
+
+import torch
+import torchvision
+
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning
+
+
+class ZImageRegionalPromptingExtension:
+    """A class for managing regional prompting with Z-Image.
+
+    This implementation is inspired by the FLUX regional prompting extension and
+    the paper https://arxiv.org/pdf/2411.02395.
+
+    Key difference from FLUX: Z-Image uses sequence order [img_tokens, txt_tokens],
+    while FLUX uses [txt_tokens, img_tokens]. The attention mask construction
+    accounts for this difference.
+    """
+
+    def __init__(
+        self,
+        regional_text_conditioning: ZImageRegionalTextConditioning,
+        regional_attn_mask: torch.Tensor | None = None,
+    ):
+        self.regional_text_conditioning = regional_text_conditioning
+        self.regional_attn_mask = regional_attn_mask
+
+    def get_attn_mask(self, block_index: int) -> torch.Tensor | None:
+        """Get the attention mask for a given block index.
+
+        Uses alternating pattern: apply mask on even blocks, no mask on odd blocks.
+        This helps balance regional control with global coherence.
+        """
+        order = [self.regional_attn_mask, None]
+        return order[block_index % len(order)]
+
+    @classmethod
+    def from_text_conditionings(
+        cls,
+        text_conditionings: list[ZImageTextConditioning],
+        img_seq_len: int,
+    ) -> "ZImageRegionalPromptingExtension":
+        """Create a ZImageRegionalPromptingExtension from a list of text conditionings.
+
+        Args:
+            text_conditionings: List of text conditionings with optional masks.
+            img_seq_len: The image sequence length (i.e. (H // patch_size) * (W // patch_size)).
+
+        Returns:
+            A configured ZImageRegionalPromptingExtension.
+        """
+        regional_text_conditioning = ZImageRegionalTextConditioning.from_text_conditionings(text_conditionings)
+        attn_mask = cls._prepare_regional_attn_mask(regional_text_conditioning, img_seq_len)
+        return cls(
+            regional_text_conditioning=regional_text_conditioning,
+            regional_attn_mask=attn_mask,
+        )
+
+    @classmethod
+    def _prepare_regional_attn_mask(
+        cls,
+        regional_text_conditioning: ZImageRegionalTextConditioning,
+        img_seq_len: int,
+    ) -> torch.Tensor | None:
+        """Prepare a regional attention mask for Z-Image.
+
+        The mask controls which tokens can attend to each other:
+        - Image tokens within a region attend only to each other
+        - Image tokens attend only to their corresponding regional text
+        - Text tokens attend only to their corresponding regional image
+        - Text tokens attend to themselves
+
+        Z-Image sequence order: [img_tokens, txt_tokens]
+
+        Args:
+            regional_text_conditioning: The regional text conditioning data.
+            img_seq_len: Number of image tokens.
+
+        Returns:
+            Attention mask of shape (img_seq_len + txt_seq_len, img_seq_len + txt_seq_len).
+            Returns None if no regional masks are present.
+        """
+        # Check if any regional masks exist
+        has_regional_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
+        if not has_regional_masks:
+            # No regional masks, return None to use default attention
+            return None
+
+        # Identify background region (area not covered by any mask)
+        background_region_mask: torch.Tensor | None = None
+        for image_mask in regional_text_conditioning.image_masks:
+            if image_mask is not None:
+                # image_mask shape: (1, 1, img_seq_len) -> flatten to (img_seq_len,)
+                mask_flat = image_mask.view(-1)
+                if background_region_mask is None:
+                    background_region_mask = torch.ones_like(mask_flat)
+                background_region_mask = background_region_mask * (1 - mask_flat)
+
+        device = TorchDevice.choose_torch_device()
+        txt_seq_len = regional_text_conditioning.prompt_embeds.shape[0]
+        total_seq_len = img_seq_len + txt_seq_len
+
+        # Initialize empty attention mask
+        # Z-Image sequence: [img_tokens (0:img_seq_len), txt_tokens (img_seq_len:total_seq_len)]
+        regional_attention_mask = torch.zeros((total_seq_len, total_seq_len), device=device, dtype=torch.float16)
+
+        for image_mask, embedding_range in zip(
+            regional_text_conditioning.image_masks,
+            regional_text_conditioning.embedding_ranges,
+            strict=True,
+        ):
+            # Calculate text token positions in the unified sequence
+            txt_start = img_seq_len + embedding_range.start
+            txt_end = img_seq_len + embedding_range.end
+
+            # 1. txt attends to itself
+            regional_attention_mask[txt_start:txt_end, txt_start:txt_end] = 1.0
+
+            if image_mask is not None:
+                # Flatten mask: (1, 1, img_seq_len) -> (img_seq_len,)
+                mask_flat = image_mask.view(img_seq_len)
+
+                # 2. img attends to corresponding regional txt
+                # Reshape mask to (img_seq_len, 1) for broadcasting
+                regional_attention_mask[:img_seq_len, txt_start:txt_end] = mask_flat.view(img_seq_len, 1)
+
+                # 3. txt attends to corresponding regional img
+                # Reshape mask to (1, img_seq_len) for broadcasting
+                regional_attention_mask[txt_start:txt_end, :img_seq_len] = mask_flat.view(1, img_seq_len)
+
+                # 4. img self-attention within region
+                # mask @ mask.T creates pairwise attention within the masked region
+                regional_attention_mask[:img_seq_len, :img_seq_len] += mask_flat.view(img_seq_len, 1) @ mask_flat.view(
+                    1, img_seq_len
+                )
+            else:
+                # Global prompt: allow attention to/from background regions only
+                if background_region_mask is not None:
+                    # 2. background img attends to global txt
+                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = background_region_mask.view(
+                        img_seq_len, 1
+                    )
+
+                    # 3. global txt attends to background img
+                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = background_region_mask.view(
+                        1, img_seq_len
+                    )
+                else:
+                    # No regional masks at all, allow full attention
+                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = 1.0
+                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = 1.0
+
+        # Allow background regions to attend to themselves
+        if background_region_mask is not None:
+            bg_mask = background_region_mask.view(img_seq_len, 1)
+            regional_attention_mask[:img_seq_len, :img_seq_len] += bg_mask @ bg_mask.T
+
+        # Convert to boolean mask
+        regional_attention_mask = regional_attention_mask > 0.5
+
+        return regional_attention_mask
+
+    @staticmethod
+    def preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor],
+        target_height: int,
+        target_width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target image token grid.
+
+        Args:
+            mask: Input mask tensor. If None, returns a mask of all ones.
+            target_height: Height of the image token grid (H // patch_size).
+            target_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype for the mask.
+            device: Target device for the mask.
+
+        Returns:
+            Processed mask of shape (1, 1, target_height * target_width).
+        """
+        img_seq_len = target_height * target_width
+
+        if mask is None:
+            return torch.ones((1, 1, img_seq_len), dtype=dtype, device=device)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        # Resize mask to target dimensions
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width),
+            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
+        )
+
+        # Add batch dimension if needed: (h, w) -> (1, h, w) -> (1, 1, h, w)
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(0)
+
+        resized_mask = tf(mask)
+
+        # Flatten to (1, 1, img_seq_len)
+        return resized_mask.flatten(start_dim=2).to(device=device)
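A minimal, hedged sketch of the extension's behavior; the embedding sizes are made up, and only the classes and methods added above are exercised:

import torch

from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension
from invokeai.backend.z_image.text_conditioning import ZImageTextConditioning

# Hypothetical prompt embeddings standing in for Qwen3 encoder output.
global_embeds = torch.randn(32, 2560)

# With no regional masks, the attention mask stays None and every block uses default attention.
ext = ZImageRegionalPromptingExtension.from_text_conditionings(
    [ZImageTextConditioning(prompt_embeds=global_embeds, mask=None)],
    img_seq_len=64 * 64,
)
assert ext.get_attn_mask(block_index=0) is None

# When a regional mask is present, get_attn_mask alternates: mask on even blocks, None on odd blocks.
dummy_mask = torch.zeros(16, 16, dtype=torch.bool)
ext_with_mask = ZImageRegionalPromptingExtension(
    regional_text_conditioning=ext.regional_text_conditioning,
    regional_attn_mask=dummy_mask,
)
assert ext_with_mask.get_attn_mask(0) is dummy_mask
assert ext_with_mask.get_attn_mask(1) is None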
invokeai/backend/z_image/text_conditioning.py

@@ -0,0 +1,74 @@
+from dataclasses import dataclass
+
+import torch
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+
+
+@dataclass
+class ZImageTextConditioning:
+    """Z-Image text conditioning with optional regional mask.
+
+    Attributes:
+        prompt_embeds: Text embeddings from Qwen3 encoder. Shape: (seq_len, hidden_size).
+        mask: Optional binary mask for regional prompting. If None, the prompt is global.
+            Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size).
+    """
+
+    prompt_embeds: torch.Tensor
+    mask: torch.Tensor | None = None
+
+
+@dataclass
+class ZImageRegionalTextConditioning:
+    """Container for multiple regional text conditionings concatenated together.
+
+    In Z-Image, the unified sequence is [img_tokens, txt_tokens], which is different
+    from FLUX where it's [txt_tokens, img_tokens]. The attention mask must account for this.
+
+    Attributes:
+        prompt_embeds: Concatenated text embeddings from all regional prompts.
+            Shape: (total_seq_len, hidden_size).
+        image_masks: List of binary masks for each regional prompt.
+            image_masks[i] corresponds to embedding_ranges[i].
+            If None, the prompt is global (applies to entire image).
+            Shape: (1, 1, img_seq_len).
+        embedding_ranges: List of ranges indicating which portion of prompt_embeds
+            corresponds to each regional prompt.
+    """
+
+    prompt_embeds: torch.Tensor
+    image_masks: list[torch.Tensor | None]
+    embedding_ranges: list[Range]
+
+    @classmethod
+    def from_text_conditionings(
+        cls,
+        text_conditionings: list[ZImageTextConditioning],
+    ) -> "ZImageRegionalTextConditioning":
+        """Create a ZImageRegionalTextConditioning from a list of ZImageTextConditioning objects.
+
+        Args:
+            text_conditionings: List of text conditionings, each with optional mask.
+
+        Returns:
+            A single ZImageRegionalTextConditioning with concatenated embeddings.
+        """
+        concat_embeds: list[torch.Tensor] = []
+        concat_ranges: list[Range] = []
+        image_masks: list[torch.Tensor | None] = []
+
+        cur_embed_len = 0
+        for tc in text_conditionings:
+            concat_embeds.append(tc.prompt_embeds)
+            concat_ranges.append(Range(start=cur_embed_len, end=cur_embed_len + tc.prompt_embeds.shape[0]))
+            image_masks.append(tc.mask)
+            cur_embed_len += tc.prompt_embeds.shape[0]
+
+        prompt_embeds = torch.cat(concat_embeds, dim=0)
+
+        return cls(
+            prompt_embeds=prompt_embeds,
+            image_masks=image_masks,
+            embedding_ranges=concat_ranges,
+        )
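A small sketch of the concatenation logic; hidden size and sequence lengths are illustrative, and only the dataclasses defined above are used:

import torch

from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning

# Illustrative sequence lengths (32 and 24 tokens) and hidden size (2560).
global_cond = ZImageTextConditioning(prompt_embeds=torch.randn(32, 2560), mask=None)
regional_cond = ZImageTextConditioning(prompt_embeds=torch.randn(24, 2560), mask=torch.ones(1, 1, 4096))

merged = ZImageRegionalTextConditioning.from_text_conditionings([global_cond, regional_cond])
print(merged.prompt_embeds.shape)                           # torch.Size([56, 2560])
print([(r.start, r.end) for r in merged.embedding_ranges])  # [(0, 32), (32, 56)]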