InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff shows the content of publicly released package versions as published to their respective public registries. It is provided for informational purposes only.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/app/invocations/flux2_klein_model_loader.py (new file)
@@ -0,0 +1,222 @@
+"""Flux2 Klein Model Loader Invocation.
+
+Loads a Flux2 Klein model with its Qwen3 text encoder and VAE.
+Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3.
+"""
+
+from typing import Literal, Optional
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
+from invokeai.app.invocations.model import (
+    ModelIdentifierField,
+    Qwen3EncoderField,
+    TransformerField,
+    VAEField,
+)
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.taxonomy import (
+    BaseModelType,
+    Flux2VariantType,
+    ModelFormat,
+    ModelType,
+    Qwen3VariantType,
+    SubModelType,
+)
+
+
+@invocation_output("flux2_klein_model_loader_output")
+class Flux2KleinModelLoaderOutput(BaseInvocationOutput):
+    """Flux2 Klein model loader output."""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(
+        description="The max sequence length for the Qwen3 encoder.",
+        title="Max Seq Length",
+    )
+
+
+@invocation(
+    "flux2_klein_model_loader",
+    title="Main Model - Flux2 Klein",
+    tags=["model", "flux", "klein", "qwen3"],
+    category="model",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Flux2KleinModelLoaderInvocation(BaseInvocation):
+    """Loads a Flux2 Klein model, outputting its submodels.
+
+    Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
+    It uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel FLUX.1 VAE.
+
+    When using a Diffusers format model, both VAE and Qwen3 encoder are extracted
+    automatically from the main model. You can override with standalone models:
+    - Transformer: Always from Flux2 Klein main model
+    - VAE: From main model (Diffusers) or standalone VAE
+    - Qwen3 Encoder: From main model (Diffusers) or standalone Qwen3 model
+    """
+
+    model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.flux_model,
+        input=Input.Direct,
+        ui_model_base=BaseModelType.Flux2,
+        ui_model_type=ModelType.Main,
+        title="Transformer",
+    )
+
+    vae_model: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="Standalone VAE model. Flux2 Klein uses the same VAE as FLUX (16-channel). "
+        "If not provided, VAE will be loaded from the Qwen3 Source model.",
+        input=Input.Direct,
+        ui_model_base=[BaseModelType.Flux, BaseModelType.Flux2],
+        ui_model_type=ModelType.VAE,
+        title="VAE",
+    )
+
+    qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="Standalone Qwen3 Encoder model. "
+        "If not provided, encoder will be loaded from the Qwen3 Source model.",
+        input=Input.Direct,
+        ui_model_type=ModelType.Qwen3Encoder,
+        title="Qwen3 Encoder",
+    )
+
+    qwen3_source_model: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="Diffusers Flux2 Klein model to extract VAE and/or Qwen3 encoder from. "
+        "Use this if you don't have separate VAE/Qwen3 models. "
+        "Ignored if both VAE and Qwen3 Encoder are provided separately.",
+        input=Input.Direct,
+        ui_model_base=BaseModelType.Flux2,
+        ui_model_type=ModelType.Main,
+        ui_model_format=ModelFormat.Diffusers,
+        title="Qwen3 Source (Diffusers)",
+    )
+
+    max_seq_len: Literal[256, 512] = InputField(
+        default=512,
+        description="Max sequence length for the Qwen3 encoder.",
+        title="Max Seq Length",
+    )
+
+    def invoke(self, context: InvocationContext) -> Flux2KleinModelLoaderOutput:
+        # Transformer always comes from the main model
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+
+        # Check if main model is Diffusers format (can extract VAE directly)
+        main_config = context.models.get_config(self.model)
+        main_is_diffusers = main_config.format == ModelFormat.Diffusers
+
+        # Determine VAE source
+        # IMPORTANT: FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), not the 16-channel FLUX.1 VAE.
+        # The VAE should come from the FLUX.2 Klein Diffusers model, not a separate FLUX VAE.
+        if self.vae_model is not None:
+            # Use standalone VAE (user explicitly selected one)
+            vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+        elif main_is_diffusers:
+            # Extract VAE from main model (recommended for FLUX.2)
+            vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
+        elif self.qwen3_source_model is not None:
+            # Extract from Qwen3 source Diffusers model
+            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
+            vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
+        else:
+            raise ValueError(
+                "No VAE source provided. Standalone safetensors/GGUF models require a separate VAE. "
+                "Options:\n"
+                " 1. Set 'VAE' to a standalone FLUX VAE model\n"
+                " 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the VAE from"
+            )
+
+        # Determine Qwen3 Encoder source
+        if self.qwen3_encoder_model is not None:
+            # Use standalone Qwen3 Encoder - validate it matches the FLUX.2 Klein variant
+            self._validate_qwen3_encoder_variant(context, main_config)
+            qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+            qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        elif main_is_diffusers:
+            # Extract from main model (recommended for FLUX.2 Klein)
+            qwen3_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+            qwen3_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        elif self.qwen3_source_model is not None:
+            # Extract from separate Diffusers model
+            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
+            qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+            qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        else:
+            raise ValueError(
+                "No Qwen3 Encoder source provided. Standalone safetensors/GGUF models require a separate text encoder. "
+                "Options:\n"
+                " 1. Set 'Qwen3 Encoder' to a standalone Qwen3 text encoder model "
+                "(Klein 4B needs Qwen3 4B, Klein 9B needs Qwen3 8B)\n"
+                " 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the encoder from"
+            )
+
+        return Flux2KleinModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer, loras=[]),
+            qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
+            vae=VAEField(vae=vae),
+            max_seq_len=self.max_seq_len,
+        )
+
+    def _validate_diffusers_format(
+        self, context: InvocationContext, model: ModelIdentifierField, model_name: str
+    ) -> None:
+        """Validate that a model is in Diffusers format."""
+        config = context.models.get_config(model)
+        if config.format != ModelFormat.Diffusers:
+            raise ValueError(
+                f"The {model_name} model must be a Diffusers format model. "
+                f"The selected model '{config.name}' is in {config.format.value} format."
+            )
+
+    def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_config) -> None:
+        """Validate that the standalone Qwen3 encoder variant matches the FLUX.2 Klein variant.
+
+        - FLUX.2 Klein 4B requires Qwen3 4B encoder
+        - FLUX.2 Klein 9B requires Qwen3 8B encoder
+        """
+        if self.qwen3_encoder_model is None:
+            return
+
+        # Get the Qwen3 encoder config
+        qwen3_config = context.models.get_config(self.qwen3_encoder_model)
+
+        # Check if the config has a variant field
+        if not hasattr(qwen3_config, "variant"):
+            # Can't validate, skip
+            return
+
+        qwen3_variant = qwen3_config.variant
+
+        # Get the FLUX.2 Klein variant from the main model config
+        if not hasattr(main_config, "variant"):
+            return
+
+        flux2_variant = main_config.variant
+
+        # Validate the variants match
+        # Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
+        expected_qwen3_variant = None
+        if flux2_variant == Flux2VariantType.Klein4B:
+            expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
+        elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
+            expected_qwen3_variant = Qwen3VariantType.Qwen3_8B
+
+        if expected_qwen3_variant is not None and qwen3_variant != expected_qwen3_variant:
+            raise ValueError(
+                f"Qwen3 encoder variant mismatch: FLUX.2 Klein {flux2_variant.value} requires "
+                f"{expected_qwen3_variant.value} encoder, but {qwen3_variant.value} was selected. "
+                f"Please select a matching Qwen3 encoder or use a Diffusers format model which includes the correct encoder."
+            )
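
Editor's note: the loader's invoke() above resolves both the VAE and the Qwen3 encoder with the same three-step precedence. The sketch below is illustrative only and is not part of the package; pick_submodel_source and its arguments are hypothetical stand-ins for the invocation's fields, shown to make the fallback order explicit.

from typing import Optional


def pick_submodel_source(
    standalone: Optional[str],
    main_is_diffusers: bool,
    source: Optional[str],
    main: str = "main-model",
) -> str:
    """Return which model a submodel (VAE or Qwen3 encoder) would be taken from."""
    if standalone is not None:
        return standalone  # 1. an explicitly selected standalone model wins
    if main_is_diffusers:
        return main  # 2. otherwise extract from the Diffusers main model
    if source is not None:
        return source  # 3. otherwise fall back to the Qwen3 Source model
    raise ValueError("No source provided for this submodel")


# Example: a single-file (safetensors/GGUF) main model with a Qwen3 Source set.
print(pick_submodel_source(None, False, "flux2-klein-diffusers"))  # -> flux2-klein-diffusers
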
invokeai/app/invocations/flux2_klein_text_encoder.py (new file)
@@ -0,0 +1,222 @@
+"""Flux2 Klein Text Encoder Invocation.
+
+Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
+The key difference is that it extracts hidden states from layers (9, 18, 27)
+and stacks them together for richer text representations.
+
+This implementation matches the diffusers Flux2KleinPipeline exactly.
+"""
+
+from contextlib import ExitStack
+from typing import Iterator, Literal, Optional, Tuple
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FluxConditioningField,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+)
+from invokeai.app.invocations.model import Qwen3EncoderField
+from invokeai.app.invocations.primitives import FluxConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
+
+# FLUX.2 Klein extracts hidden states from these specific layers
+# Matching diffusers Flux2KleinPipeline: (9, 18, 27)
+# hidden_states[0] is embedding layer, so layer N is at index N
+KLEIN_EXTRACTION_LAYERS = (9, 18, 27)
+
+# Default max sequence length for Klein models
+KLEIN_MAX_SEQ_LEN = 512
+
+
+@invocation(
+    "flux2_klein_text_encoder",
+    title="Prompt - Flux2 Klein",
+    tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
+    category="conditioning",
+    version="1.1.0",
+    classification=Classification.Prototype,
+)
+class Flux2KleinTextEncoderInvocation(BaseInvocation):
+    """Encodes and preps a prompt for Flux2 Klein image generation.
+
+    Flux2 Klein uses Qwen3 as the text encoder, extracting hidden states from
+    layers (9, 18, 27) and stacking them for richer text representations.
+    This matches the diffusers Flux2KleinPipeline implementation exactly.
+    """
+
+    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
+    qwen3_encoder: Qwen3EncoderField = InputField(
+        title="Qwen3 Encoder",
+        description=FieldDescriptions.qwen3_encoder,
+        input=Input.Connection,
+    )
+    max_seq_len: Literal[256, 512] = InputField(
+        default=512,
+        description="Max sequence length for the Qwen3 encoder.",
+    )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
+        qwen3_embeds, pooled_embeds = self._encode_prompt(context)
+
+        # Use FLUXConditioningInfo for compatibility with existing Flux denoiser
+        # t5_embeds -> qwen3 stacked embeddings
+        # clip_embeds -> pooled qwen3 embedding
+        conditioning_data = ConditioningFieldData(
+            conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
+        )
+
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return FluxConditioningOutput(
+            conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
+
+    def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode prompt using Qwen3 text encoder with Klein-style layer extraction.
+
+        This matches the diffusers Flux2KleinPipeline._get_qwen3_prompt_embeds() exactly.
+
+        Returns:
+            Tuple of (stacked_embeddings, pooled_embedding):
+            - stacked_embeddings: Hidden states from layers (9, 18, 27) stacked together.
+              Shape: (1, seq_len, hidden_size * 3)
+            - pooled_embedding: Pooled representation for global conditioning.
+              Shape: (1, hidden_size)
+        """
+        prompt = self.prompt
+        device = TorchDevice.choose_torch_device()
+
+        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
+        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
+
+        with ExitStack() as exit_stack:
+            (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
+            (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
+
+            # Apply LoRA models to the text encoder
+            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
+            exit_stack.enter_context(
+                LayerPatcher.apply_smart_model_patches(
+                    model=text_encoder,
+                    patches=self._lora_iterator(context),
+                    prefix=FLUX_LORA_T5_PREFIX,  # Reuse T5 prefix for Qwen3 LoRAs
+                    dtype=lora_dtype,
+                    cached_weights=cached_weights,
+                )
+            )
+
+            context.util.signal_progress("Running Qwen3 text encoder (Klein)")
+
+            if not isinstance(text_encoder, PreTrainedModel):
+                raise TypeError(
+                    f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
+                    "The Qwen3 encoder model may be corrupted or incompatible."
+                )
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise TypeError(
+                    f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
+                    "The Qwen3 tokenizer may be corrupted or incompatible."
+                )
+
+            # Format messages exactly like diffusers Flux2KleinPipeline:
+            # - Only user message, NO system message
+            # - add_generation_prompt=True (adds assistant prefix)
+            # - enable_thinking=False
+            messages = [{"role": "user", "content": prompt}]
+
+            # Step 1: Apply chat template to get formatted text (tokenize=False)
+            text: str = tokenizer.apply_chat_template(  # type: ignore[assignment]
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,  # Adds assistant prefix like diffusers
+                enable_thinking=False,  # Disable thinking mode
+            )
+
+            # Step 2: Tokenize the formatted text
+            inputs = tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=self.max_seq_len,
+            )
+
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs["attention_mask"]
+
+            # Move to device
+            input_ids = input_ids.to(device)
+            attention_mask = attention_mask.to(device)
+
+            # Forward pass through the model - matching diffusers exactly
+            outputs = text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                use_cache=False,
+            )
+
+            # Validate hidden_states output
+            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
+                raise RuntimeError(
+                    "Text encoder did not return hidden_states. "
+                    "Ensure output_hidden_states=True is supported by this model."
+                )
+
+            num_hidden_layers = len(outputs.hidden_states)
+
+            # Extract and stack hidden states - EXACTLY like diffusers:
+            # out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+            # prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+            hidden_states_list = []
+            for layer_idx in KLEIN_EXTRACTION_LAYERS:
+                if layer_idx >= num_hidden_layers:
+                    layer_idx = num_hidden_layers - 1
+                hidden_states_list.append(outputs.hidden_states[layer_idx])
+
+            # Stack along dim=1, then permute and reshape - exactly like diffusers
+            out = torch.stack(hidden_states_list, dim=1)
+            out = out.to(dtype=text_encoder.dtype, device=device)
+
+            batch_size, num_channels, seq_len, hidden_dim = out.shape
+            prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+            # Create pooled embedding for global conditioning
+            # Use mean pooling over the sequence (excluding padding)
+            # This serves a similar role to CLIP's pooled output in standard FLUX
+            last_hidden_state = outputs.hidden_states[-1]  # Use last layer for pooling
+            # Expand mask to match hidden state dimensions
+            expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
+            sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
+            num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
+            pooled_embeds = sum_embeds / num_tokens
+
+        return prompt_embeds, pooled_embeds
+
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+        """Iterate over LoRA models to apply to the Qwen3 text encoder."""
+        for lora in self.qwen3_encoder.loras:
+            lora_info = context.models.load(lora.lora)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                    "The LoRA model may be corrupted or incompatible."
+                )
+            yield (lora_info.model, lora.weight)
+            del lora_info
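
Editor's note: the core of _encode_prompt() is the layer extraction plus the masked mean pooling. The standalone sketch below (not from the package) reproduces only the shape bookkeeping on random tensors; the layer count and hidden size are assumptions chosen to make the shapes concrete.

import torch

batch, seq_len, hidden = 1, 512, 2560
num_hidden_states = 37  # embedding output + 36 transformer layers (assumed for illustration)
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(num_hidden_states)]
attention_mask = torch.ones(batch, seq_len)

# Stack the three extraction layers along a new "channel" dim ...
out = torch.stack([hidden_states[k] for k in (9, 18, 27)], dim=1)  # (B, 3, S, H)
# ... then interleave them per token to (B, S, 3*H), matching the permute/reshape above.
b, c, s, h = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(b, s, c * h)

# Masked mean pooling over the last layer, used as the "pooled" embedding.
last = hidden_states[-1]
mask = attention_mask.unsqueeze(-1).expand_as(last).float()
pooled = (last * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

print(prompt_embeds.shape, pooled.shape)  # torch.Size([1, 512, 7680]) torch.Size([1, 2560])
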
invokeai/app/invocations/flux2_vae_decode.py (new file)
@@ -0,0 +1,106 @@
+"""Flux2 Klein VAE Decode Invocation.
+
+Decodes latents to images using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
+"""
+
+import torch
+from einops import rearrange
+from PIL import Image
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux2_vae_decode",
+    title="Latents to Image - FLUX2",
+    tags=["latents", "image", "vae", "l2i", "flux2", "klein"],
+    category="latents",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Flux2VaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Generates an image from latents using FLUX.2 Klein's 32-channel VAE."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+
+    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
+        """Decode latents to image using FLUX.2 VAE.
+
+        Input latents should already be in the correct space after BN denormalization
+        was applied in the denoiser. The VAE expects (B, 32, H, W) format.
+        """
+        with vae_info.model_on_device() as (_, vae):
+            vae_dtype = next(iter(vae.parameters())).dtype
+            device = TorchDevice.choose_torch_device()
+            latents = latents.to(device=device, dtype=vae_dtype)
+
+            # Decode using diffusers API
+            decoded = vae.decode(latents, return_dict=False)[0]
+
+            # Debug: Log decoded output statistics
+            print(
+                f"[FLUX.2 VAE] Decoded output: shape={decoded.shape}, "
+                f"min={decoded.min().item():.4f}, max={decoded.max().item():.4f}, "
+                f"mean={decoded.mean().item():.4f}"
+            )
+            # Check per-channel statistics to diagnose color issues
+            for c in range(min(3, decoded.shape[1])):
+                ch = decoded[0, c]
+                print(
+                    f"[FLUX.2 VAE] Channel {c}: min={ch.min().item():.4f}, "
+                    f"max={ch.max().item():.4f}, mean={ch.mean().item():.4f}"
+                )
+
+            # Convert from [-1, 1] to [0, 1] then to [0, 255] PIL image
+            img = (decoded / 2 + 0.5).clamp(0, 1)
+            img = rearrange(img[0], "c h w -> h w c")
+            img_np = (img * 255).byte().cpu().numpy()
+            # Explicitly create RGB image (not grayscale)
+            img_pil = Image.fromarray(img_np, mode="RGB")
+            return img_pil
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        # Log latent statistics for debugging black image issues
+        context.logger.debug(
+            f"FLUX.2 VAE decode input: shape={latents.shape}, "
+            f"min={latents.min().item():.4f}, max={latents.max().item():.4f}, "
+            f"mean={latents.mean().item():.4f}"
+        )
+
+        # Warn if input latents are all zeros or very small (would cause black images)
+        if latents.abs().max() < 1e-6:
+            context.logger.warning(
+                "FLUX.2 VAE decode received near-zero latents! This will cause black images. "
+                "The latent cache may be corrupted - try clearing the cache."
+            )
+
+        vae_info = context.models.load(self.vae.vae)
+        context.util.signal_progress("Running VAE")
+        image = self._vae_decode(vae_info=vae_info, latents=latents)
+
+        TorchDevice.empty_cache()
+        image_dto = context.images.save(image=image)
+        return ImageOutput.build(image_dto)
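
Editor's note: the tensor-to-PIL conversion at the end of _vae_decode() is the usual mapping from a [-1, 1] decoder output to an 8-bit RGB image. A minimal standalone sketch (the input tensor here is an arbitrary stand-in, not a real VAE output):

import torch
from einops import rearrange
from PIL import Image

decoded = torch.tanh(torch.randn(1, 3, 64, 64))  # stand-in for vae.decode() output in [-1, 1]
img = (decoded / 2 + 0.5).clamp(0, 1)            # map [-1, 1] -> [0, 1]
img = rearrange(img[0], "c h w -> h w c")        # CHW -> HWC for PIL
img_pil = Image.fromarray((img * 255).byte().cpu().numpy(), mode="RGB")
print(img_pil.size, img_pil.mode)                # (64, 64) RGB
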
invokeai/app/invocations/flux2_vae_encode.py (new file)
@@ -0,0 +1,88 @@
+"""Flux2 Klein VAE Encode Invocation.
+
+Encodes images to latents using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
+"""
+
+import einops
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    Input,
+    InputField,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux2_vae_encode",
+    title="Image to Latents - FLUX2",
+    tags=["latents", "image", "vae", "i2l", "flux2", "klein"],
+    category="latents",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Flux2VaeEncodeInvocation(BaseInvocation):
+    """Encodes an image into latents using FLUX.2 Klein's 32-channel VAE."""
+
+    image: ImageField = InputField(
+        description="The image to encode.",
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+
+    def _vae_encode(self, vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        """Encode image to latents using FLUX.2 VAE.
+
+        The VAE encodes to 32-channel latent space.
+        Output latents shape: (B, 32, H/8, W/8).
+        """
+        with vae_info.model_on_device() as (_, vae):
+            vae_dtype = next(iter(vae.parameters())).dtype
+            device = TorchDevice.choose_torch_device()
+            image_tensor = image_tensor.to(device=device, dtype=vae_dtype)
+
+            # Encode using diffusers API
+            # The VAE.encode() returns a DiagonalGaussianDistribution-like object
+            latent_dist = vae.encode(image_tensor, return_dict=False)[0]
+
+            # Sample from the distribution (or use mode for deterministic output)
+            # Using mode() for deterministic encoding
+            if hasattr(latent_dist, "mode"):
+                latents = latent_dist.mode()
+            elif hasattr(latent_dist, "sample"):
+                # Fall back to sampling if mode is not available
+                generator = torch.Generator(device=device).manual_seed(0)
+                latents = latent_dist.sample(generator=generator)
+            else:
+                # Direct tensor output (some VAE implementations)
+                latents = latent_dist
+
+            return latents
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.images.get_pil(self.image.image_name)
+
+        vae_info = context.models.load(self.vae.vae)
+
+        # Convert image to tensor (HWC -> CHW, normalize to [-1, 1])
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        context.util.signal_progress("Running VAE Encode")
+        latents = self._vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+
+        latents = latents.to("cpu")
+        name = context.tensors.save(tensor=latents)
+        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
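
Editor's note: _vae_encode() prefers the posterior's mode() over sample() so that re-encoding the same image always yields the same latents. The toy sketch below (not from the package; ToyGaussianPosterior is a hypothetical stand-in, not the diffusers DiagonalGaussianDistribution) shows the difference between the two choices.

from typing import Optional

import torch


class ToyGaussianPosterior:
    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor) -> None:
        self.mean, self.std = mean, torch.exp(0.5 * logvar)

    def mode(self) -> torch.Tensor:
        return self.mean  # deterministic: identical latents on every run

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        noise = torch.randn(self.mean.shape, generator=generator)
        return self.mean + self.std * noise  # stochastic unless the generator is seeded


post = ToyGaussianPosterior(torch.zeros(1, 32, 8, 8), torch.full((1, 32, 8, 8), -2.0))
assert torch.equal(post.mode(), post.mode())  # mode() is reproducible by construction
gen = torch.Generator().manual_seed(0)
print(post.sample(generator=gen).std().item())  # reproducible only because the generator is seeded
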