InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/flux_denoise.py +15 -1
  6. invokeai/app/invocations/image_to_latents.py +23 -5
  7. invokeai/app/invocations/latents_to_image.py +2 -25
  8. invokeai/app/invocations/metadata.py +9 -1
  9. invokeai/app/invocations/metadata_linked.py +47 -0
  10. invokeai/app/invocations/model.py +8 -0
  11. invokeai/app/invocations/pbr_maps.py +59 -0
  12. invokeai/app/invocations/primitives.py +12 -0
  13. invokeai/app/invocations/prompt_template.py +57 -0
  14. invokeai/app/invocations/z_image_control.py +112 -0
  15. invokeai/app/invocations/z_image_denoise.py +770 -0
  16. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  17. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  18. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  19. invokeai/app/invocations/z_image_model_loader.py +135 -0
  20. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  21. invokeai/app/services/config/config_default.py +3 -1
  22. invokeai/app/services/model_install/model_install_common.py +14 -1
  23. invokeai/app/services/model_install/model_install_default.py +119 -19
  24. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  25. invokeai/app/services/model_records/model_records_base.py +12 -0
  26. invokeai/app/services/model_records/model_records_sql.py +17 -0
  27. invokeai/app/services/shared/graph.py +132 -77
  28. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  29. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  30. invokeai/app/util/step_callback.py +3 -0
  31. invokeai/backend/flux/denoise.py +196 -11
  32. invokeai/backend/flux/schedulers.py +62 -0
  33. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  34. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  35. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  36. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  37. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  38. invokeai/backend/model_manager/configs/factory.py +26 -1
  39. invokeai/backend/model_manager/configs/lora.py +79 -1
  40. invokeai/backend/model_manager/configs/main.py +113 -0
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  42. invokeai/backend/model_manager/load/model_cache/model_cache.py +104 -2
  43. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  44. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  45. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  46. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  47. invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
  48. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  49. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  50. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  51. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  52. invokeai/backend/model_manager/load/model_loaders/z_image.py +969 -0
  53. invokeai/backend/model_manager/load/model_util.py +6 -1
  54. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  55. invokeai/backend/model_manager/model_on_disk.py +3 -0
  56. invokeai/backend/model_manager/starter_models.py +79 -0
  57. invokeai/backend/model_manager/taxonomy.py +5 -0
  58. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  59. invokeai/backend/patches/layer_patcher.py +34 -16
  60. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  61. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  62. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  63. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  64. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  65. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +189 -0
  66. invokeai/backend/quantization/gguf/ggml_tensor.py +38 -4
  67. invokeai/backend/quantization/gguf/loaders.py +47 -12
  68. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  69. invokeai/backend/util/devices.py +25 -0
  70. invokeai/backend/util/hotfixes.py +2 -2
  71. invokeai/backend/z_image/__init__.py +16 -0
  72. invokeai/backend/z_image/extensions/__init__.py +1 -0
  73. invokeai/backend/z_image/extensions/regional_prompting_extension.py +205 -0
  74. invokeai/backend/z_image/text_conditioning.py +74 -0
  75. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  76. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  77. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  78. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  79. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  80. invokeai/frontend/web/dist/assets/App-BBELGD-n.js +161 -0
  81. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-4xPFTMT3.js} +1 -1
  82. invokeai/frontend/web/dist/assets/index-vCDSQboA.js +530 -0
  83. invokeai/frontend/web/dist/index.html +1 -1
  84. invokeai/frontend/web/dist/locales/de.json +24 -6
  85. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  86. invokeai/frontend/web/dist/locales/en.json +78 -3
  87. invokeai/frontend/web/dist/locales/es.json +0 -5
  88. invokeai/frontend/web/dist/locales/fr.json +0 -6
  89. invokeai/frontend/web/dist/locales/it.json +17 -64
  90. invokeai/frontend/web/dist/locales/ja.json +379 -44
  91. invokeai/frontend/web/dist/locales/ru.json +0 -6
  92. invokeai/frontend/web/dist/locales/vi.json +7 -54
  93. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  94. invokeai/version/invokeai_version.py +1 -1
  95. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/METADATA +4 -4
  96. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/RECORD +102 -71
  97. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  98. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  99. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/WHEEL +0 -0
  100. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/entry_points.txt +0 -0
  101. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE +0 -0
  102. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  103. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  104. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/top_level.txt +0 -0
invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py
@@ -0,0 +1,189 @@
+"""Z-Image LoRA conversion utilities.
+
+Z-Image uses S3-DiT transformer architecture with Qwen3 text encoder.
+LoRAs for Z-Image typically follow the diffusers PEFT format.
+"""
+
+from typing import Dict
+
+import torch
+
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.lora_conversions.z_image_lora_constants import (
+    Z_IMAGE_LORA_QWEN3_PREFIX,
+    Z_IMAGE_LORA_TRANSFORMER_PREFIX,
+)
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+
+
+def is_state_dict_likely_z_image_lora(state_dict: dict[str | int, torch.Tensor]) -> bool:
+    """Checks if the provided state dict is likely a Z-Image LoRA.
+
+    Z-Image LoRAs can have keys for transformer and/or Qwen3 text encoder.
+    They may use various prefixes depending on the training framework.
+    """
+    str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
+
+    # Check for Z-Image transformer keys (S3-DiT architecture)
+    # Various training frameworks use different prefixes
+    has_transformer_keys = any(
+        k.startswith(
+            (
+                "transformer.",
+                "base_model.model.transformer.",
+                "diffusion_model.",
+            )
+        )
+        for k in str_keys
+    )
+
+    # Check for Qwen3 text encoder keys
+    has_qwen3_keys = any(k.startswith(("text_encoder.", "base_model.model.text_encoder.")) for k in str_keys)
+
+    return has_transformer_keys or has_qwen3_keys
+
+
+def lora_model_from_z_image_state_dict(
+    state_dict: Dict[str, torch.Tensor], alpha: float | None = None
+) -> ModelPatchRaw:
+    """Convert a Z-Image LoRA state dict to a ModelPatchRaw.
+
+    Z-Image LoRAs can contain layers for:
+    - Transformer (S3-DiT architecture)
+    - Qwen3 text encoder
+
+    Z-Image LoRAs may use various key prefixes depending on how they were trained:
+    - "transformer." or "base_model.model.transformer." for diffusers PEFT format
+    - "diffusion_model." for some training frameworks
+    - "text_encoder." or "base_model.model.text_encoder." for Qwen3 encoder
+
+    Args:
+        state_dict: The LoRA state dict
+        alpha: The alpha value for LoRA scaling. If None, uses rank as alpha.
+
+    Returns:
+        A ModelPatchRaw containing the LoRA layers
+    """
+    layers: dict[str, BaseLayerPatch] = {}
+
+    # Group keys by layer
+    grouped_state_dict = _group_by_layer(state_dict)
+
+    for layer_key, layer_dict in grouped_state_dict.items():
+        # Convert PEFT format keys to internal format
+        values = _get_lora_layer_values(layer_dict, alpha)
+
+        # Determine the appropriate prefix based on the layer type and clean up the key
+        clean_key = layer_key
+
+        # Handle various transformer prefixes
+        transformer_prefixes = [
+            "base_model.model.transformer.diffusion_model.",
+            "base_model.model.transformer.",
+            "transformer.diffusion_model.",
+            "transformer.",
+            "diffusion_model.",
+        ]
+
+        # Handle text encoder prefixes
+        text_encoder_prefixes = [
+            "base_model.model.text_encoder.",
+            "text_encoder.",
+        ]
+
+        is_text_encoder = False
+
+        # Check and strip text encoder prefixes first
+        for prefix in text_encoder_prefixes:
+            if layer_key.startswith(prefix):
+                clean_key = layer_key[len(prefix) :]
+                is_text_encoder = True
+                break
+
+        # If not text encoder, check transformer prefixes
+        if not is_text_encoder:
+            for prefix in transformer_prefixes:
+                if layer_key.startswith(prefix):
+                    clean_key = layer_key[len(prefix) :]
+                    break
+
+        # Apply the appropriate internal prefix
+        if is_text_encoder:
+            final_key = f"{Z_IMAGE_LORA_QWEN3_PREFIX}{clean_key}"
+        else:
+            final_key = f"{Z_IMAGE_LORA_TRANSFORMER_PREFIX}{clean_key}"
+
+        layer = any_lora_layer_from_state_dict(values)
+        layers[final_key] = layer
+
+    return ModelPatchRaw(layers=layers)
+
+
+def _get_lora_layer_values(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]:
+    """Convert layer dict keys from PEFT format to internal format."""
+    if "lora_A.weight" in layer_dict:
+        # PEFT format: lora_A.weight, lora_B.weight
+        values = {
+            "lora_down.weight": layer_dict["lora_A.weight"],
+            "lora_up.weight": layer_dict["lora_B.weight"],
+        }
+        if alpha is not None:
+            values["alpha"] = torch.tensor(alpha)
+        return values
+    elif "lora_down.weight" in layer_dict:
+        # Already in internal format
+        return layer_dict
+    else:
+        # Unknown format, return as-is
+        return layer_dict
+
+
+def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
+    """Groups the keys in the state dict by layer.
+
+    Z-Image LoRAs have keys like:
+    - diffusion_model.layers.17.attention.to_k.alpha
+    - diffusion_model.layers.17.attention.to_k.dora_scale
+    - diffusion_model.layers.17.attention.to_k.lora_down.weight
+    - diffusion_model.layers.17.attention.to_k.lora_up.weight
+
+    We need to group these by the full layer path (e.g., diffusion_model.layers.17.attention.to_k)
+    and extract the suffix (alpha, dora_scale, lora_down.weight, lora_up.weight).
+    """
+    layer_dict: dict[str, dict[str, torch.Tensor]] = {}
+
+    # Known suffixes that indicate the end of a layer name
+    known_suffixes = [
+        ".lora_A.weight",
+        ".lora_B.weight",
+        ".lora_down.weight",
+        ".lora_up.weight",
+        ".dora_scale",
+        ".alpha",
+    ]
+
+    for key in state_dict:
+        if not isinstance(key, str):
+            continue
+
+        # Try to find a known suffix
+        layer_name = None
+        key_name = None
+        for suffix in known_suffixes:
+            if key.endswith(suffix):
+                layer_name = key[: -len(suffix)]
+                key_name = suffix[1:]  # Remove leading dot
+                break
+
+        if layer_name is None:
+            # Fallback to original logic for unknown formats
+            parts = key.rsplit(".", maxsplit=2)
+            layer_name = parts[0]
+            key_name = ".".join(parts[1:])
+
+        if layer_name not in layer_dict:
+            layer_dict[layer_name] = {}
+        layer_dict[layer_name][key_name] = state_dict[key]
+
+    return layer_dict
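
For orientation, a minimal usage sketch of the two public helpers above. The checkpoint file name and the safetensors loading call are illustrative assumptions, not part of this release.

from pathlib import Path

import torch
from safetensors.torch import load_file

from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import (
    is_state_dict_likely_z_image_lora,
    lora_model_from_z_image_state_dict,
)

# Hypothetical checkpoint path; any PEFT-style Z-Image LoRA file would do.
state_dict: dict[str, torch.Tensor] = load_file(Path("z_image_lora.safetensors"))
if is_state_dict_likely_z_image_lora(state_dict):
    # alpha=None lets each layer's rank act as its effective alpha.
    patch = lora_model_from_z_image_state_dict(state_dict, alpha=None)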
invokeai/backend/quantization/gguf/ggml_tensor.py
@@ -14,11 +14,42 @@ def dequantize_and_run(func, args, kwargs):
     """A helper function for running math ops on GGMLTensor inputs.
 
     Dequantizes the inputs, and runs the function.
+    Also casts other floating point tensors to match the compute_dtype of GGMLTensors
+    to avoid dtype mismatches in matrix operations.
     """
-    dequantized_args = [a.get_dequantized_tensor() if hasattr(a, "get_dequantized_tensor") else a for a in args]
-    dequantized_kwargs = {
-        k: v.get_dequantized_tensor() if hasattr(v, "get_dequantized_tensor") else v for k, v in kwargs.items()
-    }
+    # Find the compute_dtype and target_device from any GGMLTensor in the args
+    compute_dtype = None
+    target_device = None
+    for a in args:
+        if hasattr(a, "compute_dtype"):
+            compute_dtype = a.compute_dtype
+        if isinstance(a, torch.Tensor) and target_device is None:
+            target_device = a.device
+        if compute_dtype is not None and target_device is not None:
+            break
+    if compute_dtype is None or target_device is None:
+        for v in kwargs.values():
+            if hasattr(v, "compute_dtype") and compute_dtype is None:
+                compute_dtype = v.compute_dtype
+            if isinstance(v, torch.Tensor) and target_device is None:
+                target_device = v.device
+            if compute_dtype is not None and target_device is not None:
+                break
+
+    def process_tensor(t):
+        if hasattr(t, "get_dequantized_tensor"):
+            result = t.get_dequantized_tensor()
+            # Ensure the dequantized tensor is on the target device
+            if target_device is not None and result.device != target_device:
+                result = result.to(target_device)
+            return result
+        elif isinstance(t, torch.Tensor) and compute_dtype is not None and t.is_floating_point():
+            # Cast other floating point tensors to match the GGUF compute_dtype
+            return t.to(compute_dtype)
+        return t
+
+    dequantized_args = [process_tensor(a) for a in args]
+    dequantized_kwargs = {k: process_tensor(v) for k, v in kwargs.items()}
     return func(*dequantized_args, **dequantized_kwargs)
 
 
@@ -57,6 +88,9 @@ GGML_TENSOR_OP_TABLE = {
     torch.ops.aten.sub.Tensor: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.allclose.default: dequantize_and_run,  # pyright: ignore
     torch.ops.aten.slice.Tensor: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.view.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.expand.default: dequantize_and_run,  # pyright: ignore
+    torch.ops.aten.index_put_.default: dequantize_and_run,  # pyright: ignore
 }
 
 if torch.backends.mps.is_available():
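
A standalone sketch (not InvokeAI API) restating the casting rule that dequantize_and_run now applies: floating-point operands are moved onto the GGML tensor's device and cast to its compute_dtype before the aten op runs, so a float32 operand no longer trips a dtype mismatch against a bfloat16 dequantized weight.

import torch

def match_operand(t: torch.Tensor, compute_dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Mirrors process_tensor() above for plain tensors: only floating-point
    # inputs are cast; integer and bool tensors pass through unchanged.
    return t.to(device=device, dtype=compute_dtype) if t.is_floating_point() else t

weight = torch.randn(4, 4, dtype=torch.bfloat16)  # stands in for a dequantized GGML weight
bias = torch.zeros(4, dtype=torch.float32)  # operand with a mismatched dtype
out = weight @ weight.T + match_operand(bias, weight.dtype, weight.device)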
invokeai/backend/quantization/gguf/loaders.py
@@ -1,3 +1,4 @@
+import gc
 from pathlib import Path
 
 import gguf
@@ -5,18 +6,52 @@ import torch
 
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
+from invokeai.backend.util.logging import InvokeAILogger
+
+logger = InvokeAILogger.get_logger()
+
+
+class WrappedGGUFReader:
+    """Wrapper around GGUFReader that adds a close() method."""
+
+    def __init__(self, path: Path):
+        self.reader = gguf.GGUFReader(path)
+
+    def __enter__(self):
+        return self.reader
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        return False
+
+    def close(self):
+        """Explicitly close the memory-mapped file."""
+        if hasattr(self.reader, "data"):
+            try:
+                self.reader.data.flush()
+                del self.reader.data
+            except (AttributeError, OSError, ValueError) as e:
+                logger.warning(f"Wasn't able to close GGUF memory map: {e}")
+        del self.reader
+        gc.collect()
 
 
 def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
-    reader = gguf.GGUFReader(path)
-
-    sd: dict[str, GGMLTensor] = {}
-    for tensor in reader.tensors:
-        torch_tensor = torch.from_numpy(tensor.data)
-        shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
-        if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
-            torch_tensor = torch_tensor.view(*shape)
-        sd[tensor.name] = GGMLTensor(
-            torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
-        )
-    return sd
+    with WrappedGGUFReader(path) as reader:
+        sd: dict[str, GGMLTensor] = {}
+        for tensor in reader.tensors:
+            # Use .copy() to create a true copy of the data, not a view.
+            # This is critical on Windows where the memory-mapped file cannot be deleted
+            # while tensors still hold references to the mapped memory.
+            torch_tensor = torch.from_numpy(tensor.data.copy())
+
+            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
+            if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
+                torch_tensor = torch_tensor.view(*shape)
+            sd[tensor.name] = GGMLTensor(
+                torch_tensor,
+                ggml_quantization_type=tensor.tensor_type,
+                tensor_shape=shape,
+                compute_dtype=compute_dtype,
+            )
+        return sd
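
A minimal usage sketch of the loader above; the .gguf file name is an illustrative assumption.

from pathlib import Path

import torch

from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader

sd = gguf_sd_loader(Path("model-Q4_K_S.gguf"), compute_dtype=torch.bfloat16)
# Because each tensor is copied out of the memory map and the map is closed on
# context exit, the .gguf file is no longer locked afterwards (notably on Windows).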
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -76,6 +76,18 @@ class CogView4ConditioningInfo:
         return self
 
 
+@dataclass
+class ZImageConditioningInfo:
+    """Z-Image text conditioning information from Qwen3 text encoder."""
+
+    prompt_embeds: torch.Tensor
+    """Text embeddings from Qwen3 encoder. Shape: (batch_size, seq_len, hidden_size)."""
+
+    def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
+        self.prompt_embeds = self.prompt_embeds.to(device=device, dtype=dtype)
+        return self
+
+
 @dataclass
 class ConditioningFieldData:
     # If you change this class, adding more types, you _must_ update the instantiation of ObjectSerializerDisk in
@@ -87,6 +99,7 @@ class ConditioningFieldData:
         | List[FLUXConditioningInfo]
         | List[SD3ConditioningInfo]
         | List[CogView4ConditioningInfo]
+        | List[ZImageConditioningInfo]
     )
 
 
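
A short sketch of how the new conditioning type plugs into ConditioningFieldData, assuming the class's existing conditionings field; the embedding shape is illustrative.

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    ZImageConditioningInfo,
)

info = ZImageConditioningInfo(prompt_embeds=torch.zeros(1, 512, 2560))  # (batch, seq_len, hidden_size)
info = info.to(device=torch.device("cpu"), dtype=torch.float32)
data = ConditioningFieldData(conditionings=[info])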
invokeai/backend/util/devices.py
@@ -112,3 +112,28 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
+
+    @classmethod
+    def choose_bfloat16_safe_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return bfloat16 if supported on the device, else fallback to float16/float32.
+
+        This is useful for models that require bfloat16 precision (e.g., Z-Image, Flux)
+        but need to run on hardware that may not support bfloat16.
+
+        Args:
+            device: The target device. If None, uses choose_torch_device().
+
+        Returns:
+            torch.bfloat16 if supported, torch.float16 for CUDA without bfloat16 support,
+            or torch.float32 for CPU/MPS.
+        """
+        device = device or cls.choose_torch_device()
+        try:
+            # Test if bfloat16 is supported on this device
+            torch.tensor([1.0], dtype=torch.bfloat16, device=device)
+            return torch.bfloat16
+        except TypeError:
+            # bfloat16 not supported - fallback based on device type
+            if device.type == "cuda":
+                return torch.float16
+            return torch.float32
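
A minimal usage sketch of the new helper:

import torch

from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_bfloat16_safe_dtype(device)  # bfloat16 where supported, else float16/float32
latents = torch.zeros(1, 16, 64, 64, device=device, dtype=dtype)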
invokeai/backend/util/hotfixes.py
@@ -5,7 +5,7 @@ import torch
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.loaders.single_file_model import FromOriginalModelMixin
 from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
-from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
+from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 from diffusers.models.embeddings import (
     TextImageProjection,
     TextImageTimeEmbedding,
@@ -777,7 +777,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
 
 diffusers.ControlNetModel = ControlNetModel
-diffusers.models.controlnet.ControlNetModel = ControlNetModel
+diffusers.models.controlnets.controlnet.ControlNetModel = ControlNetModel
 
 
 # patch LoRACompatibleConv to use original Conv2D forward function
invokeai/backend/z_image/__init__.py
@@ -0,0 +1,16 @@
+# Z-Image backend utilities
+from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
+from invokeai.backend.z_image.z_image_control_transformer import ZImageControlTransformer2DModel
+from invokeai.backend.z_image.z_image_controlnet_extension import (
+    ZImageControlNetExtension,
+    z_image_forward_with_control,
+)
+from invokeai.backend.z_image.z_image_patchify_utils import patchify_control_context
+
+__all__ = [
+    "ZImageControlAdapter",
+    "ZImageControlTransformer2DModel",
+    "ZImageControlNetExtension",
+    "z_image_forward_with_control",
+    "patchify_control_context",
+]
invokeai/backend/z_image/extensions/__init__.py
@@ -0,0 +1 @@
+# Z-Image extensions
invokeai/backend/z_image/extensions/regional_prompting_extension.py
@@ -0,0 +1,205 @@
+from typing import Optional
+
+import torch
+import torchvision
+
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+from invokeai.backend.z_image.text_conditioning import ZImageRegionalTextConditioning, ZImageTextConditioning
+
+
+class ZImageRegionalPromptingExtension:
+    """A class for managing regional prompting with Z-Image.
+
+    This implementation is inspired by the FLUX regional prompting extension and
+    the paper https://arxiv.org/pdf/2411.02395.
+
+    Key difference from FLUX: Z-Image uses sequence order [img_tokens, txt_tokens],
+    while FLUX uses [txt_tokens, img_tokens]. The attention mask construction
+    accounts for this difference.
+    """
+
+    def __init__(
+        self,
+        regional_text_conditioning: ZImageRegionalTextConditioning,
+        regional_attn_mask: torch.Tensor | None = None,
+    ):
+        self.regional_text_conditioning = regional_text_conditioning
+        self.regional_attn_mask = regional_attn_mask
+
+    def get_attn_mask(self, block_index: int) -> torch.Tensor | None:
+        """Get the attention mask for a given block index.
+
+        Uses alternating pattern: apply mask on even blocks, no mask on odd blocks.
+        This helps balance regional control with global coherence.
+        """
+        order = [self.regional_attn_mask, None]
+        return order[block_index % len(order)]
+
+    @classmethod
+    def from_text_conditionings(
+        cls,
+        text_conditionings: list[ZImageTextConditioning],
+        img_seq_len: int,
+    ) -> "ZImageRegionalPromptingExtension":
+        """Create a ZImageRegionalPromptingExtension from a list of text conditionings.
+
+        Args:
+            text_conditionings: List of text conditionings with optional masks.
+            img_seq_len: The image sequence length (i.e. (H // patch_size) * (W // patch_size)).
+
+        Returns:
+            A configured ZImageRegionalPromptingExtension.
+        """
+        regional_text_conditioning = ZImageRegionalTextConditioning.from_text_conditionings(text_conditionings)
+        attn_mask = cls._prepare_regional_attn_mask(regional_text_conditioning, img_seq_len)
+        return cls(
+            regional_text_conditioning=regional_text_conditioning,
+            regional_attn_mask=attn_mask,
+        )
+
+    @classmethod
+    def _prepare_regional_attn_mask(
+        cls,
+        regional_text_conditioning: ZImageRegionalTextConditioning,
+        img_seq_len: int,
+    ) -> torch.Tensor | None:
+        """Prepare a regional attention mask for Z-Image.
+
+        This uses an 'unrestricted' image self-attention approach (similar to FLUX):
+        - Image tokens can attend to ALL other image tokens (unrestricted self-attention)
+        - Image tokens attend only to their corresponding regional text
+        - Text tokens attend only to their corresponding regional image
+        - Text tokens attend to themselves
+
+        The unrestricted image self-attention allows the model to maintain global
+        coherence across regions, preventing the generation of separate/disconnected
+        images for each region.
+
+        Z-Image sequence order: [img_tokens, txt_tokens]
+
+        Args:
+            regional_text_conditioning: The regional text conditioning data.
+            img_seq_len: Number of image tokens.
+
+        Returns:
+            Attention mask of shape (img_seq_len + txt_seq_len, img_seq_len + txt_seq_len).
+            Returns None if no regional masks are present.
+        """
+        # Check if any regional masks exist
+        has_regional_masks = any(mask is not None for mask in regional_text_conditioning.image_masks)
+        if not has_regional_masks:
+            # No regional masks, return None to use default attention
+            return None
+
+        # Identify background region (area not covered by any mask)
+        background_region_mask: torch.Tensor | None = None
+        for image_mask in regional_text_conditioning.image_masks:
+            if image_mask is not None:
+                # image_mask shape: (1, 1, img_seq_len) -> flatten to (img_seq_len,)
+                mask_flat = image_mask.view(-1)
+                if background_region_mask is None:
+                    background_region_mask = torch.ones_like(mask_flat)
+                background_region_mask = background_region_mask * (1 - mask_flat)
+
+        device = TorchDevice.choose_torch_device()
+        txt_seq_len = regional_text_conditioning.prompt_embeds.shape[0]
+        total_seq_len = img_seq_len + txt_seq_len
+
+        # Initialize empty attention mask
+        # Z-Image sequence: [img_tokens (0:img_seq_len), txt_tokens (img_seq_len:total_seq_len)]
+        regional_attention_mask = torch.zeros((total_seq_len, total_seq_len), device=device, dtype=torch.float16)
+
+        for image_mask, embedding_range in zip(
+            regional_text_conditioning.image_masks,
+            regional_text_conditioning.embedding_ranges,
+            strict=True,
+        ):
+            # Calculate text token positions in the unified sequence
+            txt_start = img_seq_len + embedding_range.start
+            txt_end = img_seq_len + embedding_range.end
+
+            # 1. txt attends to itself
+            regional_attention_mask[txt_start:txt_end, txt_start:txt_end] = 1.0
+
+            if image_mask is not None:
+                # Flatten mask: (1, 1, img_seq_len) -> (img_seq_len,)
+                mask_flat = image_mask.view(img_seq_len)
+
+                # 2. img attends to corresponding regional txt
+                # Reshape mask to (img_seq_len, 1) for broadcasting
+                regional_attention_mask[:img_seq_len, txt_start:txt_end] = mask_flat.view(img_seq_len, 1)
+
+                # 3. txt attends to corresponding regional img
+                # Reshape mask to (1, img_seq_len) for broadcasting
+                regional_attention_mask[txt_start:txt_end, :img_seq_len] = mask_flat.view(1, img_seq_len)
+            else:
+                # Global prompt: allow attention to/from background regions only
+                if background_region_mask is not None:
+                    # 2. background img attends to global txt
+                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = background_region_mask.view(
+                        img_seq_len, 1
+                    )
+
+                    # 3. global txt attends to background img
+                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = background_region_mask.view(
+                        1, img_seq_len
+                    )
+                else:
+                    # No regional masks at all, allow full attention
+                    regional_attention_mask[:img_seq_len, txt_start:txt_end] = 1.0
+                    regional_attention_mask[txt_start:txt_end, :img_seq_len] = 1.0
+
+        # 4. Allow unrestricted image self-attention
+        # This is the key difference from the restricted approach - all image tokens
+        # can attend to each other, which helps maintain global coherence across regions
+        regional_attention_mask[:img_seq_len, :img_seq_len] = 1.0
+
+        # Convert to boolean mask
+        regional_attention_mask = regional_attention_mask > 0.5
+
+        return regional_attention_mask
+
+    @staticmethod
+    def preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor],
+        target_height: int,
+        target_width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target image token grid.
+
+        Args:
+            mask: Input mask tensor. If None, returns a mask of all ones.
+            target_height: Height of the image token grid (H // patch_size).
+            target_width: Width of the image token grid (W // patch_size).
+            dtype: Target dtype for the mask.
+            device: Target device for the mask.
+
+        Returns:
+            Processed mask of shape (1, 1, target_height * target_width).
+        """
+        img_seq_len = target_height * target_width
+
+        if mask is None:
+            return torch.ones((1, 1, img_seq_len), dtype=dtype, device=device)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        # Resize mask to target dimensions
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width),
+            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
+        )
+
+        # Add batch dimension if needed: (h, w) -> (1, h, w) -> (1, 1, h, w)
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0)
+        if mask.ndim == 3:
+            mask = mask.unsqueeze(0)
+
+        resized_mask = tf(mask)
+
+        # Flatten to (1, 1, img_seq_len)
+        return resized_mask.flatten(start_dim=2).to(device=device)
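
A minimal sketch of preparing a regional mask with the static helper above; the 1024x1024 resolution and 16-pixel patch size are assumptions for illustration only.

import torch

from invokeai.backend.z_image.extensions.regional_prompting_extension import ZImageRegionalPromptingExtension

raw_mask = torch.zeros(1024, 1024, dtype=torch.bool)
raw_mask[:, :512] = True  # this region's prompt applies to the left half of the image

token_mask = ZImageRegionalPromptingExtension.preprocess_regional_prompt_mask(
    raw_mask,
    target_height=64,  # 1024 // 16 (assumed patch size)
    target_width=64,
    dtype=torch.float16,
    device=torch.device("cpu"),
)
# token_mask.shape == (1, 1, 4096); it can then be paired with that region's text
# conditioning before the joint image/text attention mask is built.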