invokeai-6.10.0rc1-py3-none-any.whl → invokeai-6.11.0-py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/app/invocations/flux2_klein_model_loader.py (new file)
@@ -0,0 +1,222 @@
+"""Flux2 Klein Model Loader Invocation.
+
+Loads a Flux2 Klein model with its Qwen3 text encoder and VAE.
+Unlike standard FLUX which uses CLIP+T5, Klein uses only Qwen3.
+"""
+
+from typing import Literal, Optional
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
+from invokeai.app.invocations.model import (
+    ModelIdentifierField,
+    Qwen3EncoderField,
+    TransformerField,
+    VAEField,
+)
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.taxonomy import (
+    BaseModelType,
+    Flux2VariantType,
+    ModelFormat,
+    ModelType,
+    Qwen3VariantType,
+    SubModelType,
+)
+
+
+@invocation_output("flux2_klein_model_loader_output")
+class Flux2KleinModelLoaderOutput(BaseInvocationOutput):
+    """Flux2 Klein model loader output."""
+
+    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
+    qwen3_encoder: Qwen3EncoderField = OutputField(description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder")
+    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(
+        description="The max sequence length for the Qwen3 encoder.",
+        title="Max Seq Length",
+    )
+
+
+@invocation(
+    "flux2_klein_model_loader",
+    title="Main Model - Flux2 Klein",
+    tags=["model", "flux", "klein", "qwen3"],
+    category="model",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Flux2KleinModelLoaderInvocation(BaseInvocation):
+    """Loads a Flux2 Klein model, outputting its submodels.
+
+    Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
+    It uses a 32-channel VAE (AutoencoderKLFlux2) instead of the 16-channel FLUX.1 VAE.
+
+    When using a Diffusers format model, both VAE and Qwen3 encoder are extracted
+    automatically from the main model. You can override with standalone models:
+    - Transformer: Always from Flux2 Klein main model
+    - VAE: From main model (Diffusers) or standalone VAE
+    - Qwen3 Encoder: From main model (Diffusers) or standalone Qwen3 model
+    """
+
+    model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.flux_model,
+        input=Input.Direct,
+        ui_model_base=BaseModelType.Flux2,
+        ui_model_type=ModelType.Main,
+        title="Transformer",
+    )
+
+    vae_model: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="Standalone VAE model. Flux2 Klein uses the same VAE as FLUX (16-channel). "
+        "If not provided, VAE will be loaded from the Qwen3 Source model.",
+        input=Input.Direct,
+        ui_model_base=[BaseModelType.Flux, BaseModelType.Flux2],
+        ui_model_type=ModelType.VAE,
+        title="VAE",
+    )
+
+    qwen3_encoder_model: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="Standalone Qwen3 Encoder model. "
+        "If not provided, encoder will be loaded from the Qwen3 Source model.",
+        input=Input.Direct,
+        ui_model_type=ModelType.Qwen3Encoder,
+        title="Qwen3 Encoder",
+    )
+
+    qwen3_source_model: Optional[ModelIdentifierField] = InputField(
+        default=None,
+        description="Diffusers Flux2 Klein model to extract VAE and/or Qwen3 encoder from. "
+        "Use this if you don't have separate VAE/Qwen3 models. "
+        "Ignored if both VAE and Qwen3 Encoder are provided separately.",
+        input=Input.Direct,
+        ui_model_base=BaseModelType.Flux2,
+        ui_model_type=ModelType.Main,
+        ui_model_format=ModelFormat.Diffusers,
+        title="Qwen3 Source (Diffusers)",
+    )
+
+    max_seq_len: Literal[256, 512] = InputField(
+        default=512,
+        description="Max sequence length for the Qwen3 encoder.",
+        title="Max Seq Length",
+    )
+
+    def invoke(self, context: InvocationContext) -> Flux2KleinModelLoaderOutput:
+        # Transformer always comes from the main model
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+
+        # Check if main model is Diffusers format (can extract VAE directly)
+        main_config = context.models.get_config(self.model)
+        main_is_diffusers = main_config.format == ModelFormat.Diffusers
+
+        # Determine VAE source
+        # IMPORTANT: FLUX.2 Klein uses a 32-channel VAE (AutoencoderKLFlux2), not the 16-channel FLUX.1 VAE.
+        # The VAE should come from the FLUX.2 Klein Diffusers model, not a separate FLUX VAE.
+        if self.vae_model is not None:
+            # Use standalone VAE (user explicitly selected one)
+            vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+        elif main_is_diffusers:
+            # Extract VAE from main model (recommended for FLUX.2)
+            vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
+        elif self.qwen3_source_model is not None:
+            # Extract from Qwen3 source Diffusers model
+            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
+            vae = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.VAE})
+        else:
+            raise ValueError(
+                "No VAE source provided. Standalone safetensors/GGUF models require a separate VAE. "
+                "Options:\n"
+                " 1. Set 'VAE' to a standalone FLUX VAE model\n"
+                " 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the VAE from"
+            )
+
+        # Determine Qwen3 Encoder source
+        if self.qwen3_encoder_model is not None:
+            # Use standalone Qwen3 Encoder - validate it matches the FLUX.2 Klein variant
+            self._validate_qwen3_encoder_variant(context, main_config)
+            qwen3_tokenizer = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+            qwen3_encoder = self.qwen3_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        elif main_is_diffusers:
+            # Extract from main model (recommended for FLUX.2 Klein)
+            qwen3_tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+            qwen3_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        elif self.qwen3_source_model is not None:
+            # Extract from separate Diffusers model
+            self._validate_diffusers_format(context, self.qwen3_source_model, "Qwen3 Source")
+            qwen3_tokenizer = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+            qwen3_encoder = self.qwen3_source_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+        else:
+            raise ValueError(
+                "No Qwen3 Encoder source provided. Standalone safetensors/GGUF models require a separate text encoder. "
+                "Options:\n"
+                " 1. Set 'Qwen3 Encoder' to a standalone Qwen3 text encoder model "
+                "(Klein 4B needs Qwen3 4B, Klein 9B needs Qwen3 8B)\n"
+                " 2. Set 'Qwen3 Source' to a Diffusers Flux2 Klein model to extract the encoder from"
+            )
+
+        return Flux2KleinModelLoaderOutput(
+            transformer=TransformerField(transformer=transformer, loras=[]),
+            qwen3_encoder=Qwen3EncoderField(tokenizer=qwen3_tokenizer, text_encoder=qwen3_encoder),
+            vae=VAEField(vae=vae),
+            max_seq_len=self.max_seq_len,
+        )
+
+    def _validate_diffusers_format(
+        self, context: InvocationContext, model: ModelIdentifierField, model_name: str
+    ) -> None:
+        """Validate that a model is in Diffusers format."""
+        config = context.models.get_config(model)
+        if config.format != ModelFormat.Diffusers:
+            raise ValueError(
+                f"The {model_name} model must be a Diffusers format model. "
+                f"The selected model '{config.name}' is in {config.format.value} format."
+            )
+
+    def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_config) -> None:
+        """Validate that the standalone Qwen3 encoder variant matches the FLUX.2 Klein variant.
+
+        - FLUX.2 Klein 4B requires Qwen3 4B encoder
+        - FLUX.2 Klein 9B requires Qwen3 8B encoder
+        """
+        if self.qwen3_encoder_model is None:
+            return
+
+        # Get the Qwen3 encoder config
+        qwen3_config = context.models.get_config(self.qwen3_encoder_model)
+
+        # Check if the config has a variant field
+        if not hasattr(qwen3_config, "variant"):
+            # Can't validate, skip
+            return
+
+        qwen3_variant = qwen3_config.variant
+
+        # Get the FLUX.2 Klein variant from the main model config
+        if not hasattr(main_config, "variant"):
+            return
+
+        flux2_variant = main_config.variant
+
+        # Validate the variants match
+        # Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
+        expected_qwen3_variant = None
+        if flux2_variant == Flux2VariantType.Klein4B:
+            expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
+        elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
+            expected_qwen3_variant = Qwen3VariantType.Qwen3_8B
+
+        if expected_qwen3_variant is not None and qwen3_variant != expected_qwen3_variant:
+            raise ValueError(
+                f"Qwen3 encoder variant mismatch: FLUX.2 Klein {flux2_variant.value} requires "
+                f"{expected_qwen3_variant.value} encoder, but {qwen3_variant.value} was selected. "
+                f"Please select a matching Qwen3 encoder or use a Diffusers format model which includes the correct encoder."
+            )
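Not part of the wheel contents — a minimal, framework-free sketch of the fallback order that invoke() applies when resolving the VAE source (the Qwen3 encoder source is resolved the same way). The function and model names are hypothetical stand-ins for the ModelIdentifierField plumbing above.

# Hypothetical sketch: source-resolution precedence used by the Klein model loader.
from typing import Optional


def pick_vae_source(
    standalone_vae: Optional[str],
    main_model: str,
    main_is_diffusers: bool,
    qwen3_source: Optional[str],
) -> str:
    """Return which model the VAE submodel should be pulled from.

    Precedence: explicit standalone VAE > Diffusers main model > Qwen3 source model.
    """
    if standalone_vae is not None:
        return standalone_vae
    if main_is_diffusers:
        return main_model
    if qwen3_source is not None:
        return qwen3_source
    raise ValueError("No VAE source provided for a standalone safetensors/GGUF main model.")


if __name__ == "__main__":
    # A GGUF main model with no standalone VAE falls back to the Qwen3 source model.
    print(pick_vae_source(None, "klein-9b.gguf", False, "klein-9b-diffusers"))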
invokeai/app/invocations/flux2_klein_text_encoder.py (new file)
@@ -0,0 +1,222 @@
+"""Flux2 Klein Text Encoder Invocation.
+
+Flux2 Klein uses Qwen3 as the text encoder instead of CLIP+T5.
+The key difference is that it extracts hidden states from layers (9, 18, 27)
+and stacks them together for richer text representations.
+
+This implementation matches the diffusers Flux2KleinPipeline exactly.
+"""
+
+from contextlib import ExitStack
+from typing import Iterator, Literal, Optional, Tuple
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FluxConditioningField,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+)
+from invokeai.app.invocations.model import Qwen3EncoderField
+from invokeai.app.invocations.primitives import FluxConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_T5_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
+
+# FLUX.2 Klein extracts hidden states from these specific layers
+# Matching diffusers Flux2KleinPipeline: (9, 18, 27)
+# hidden_states[0] is embedding layer, so layer N is at index N
+KLEIN_EXTRACTION_LAYERS = (9, 18, 27)
+
+# Default max sequence length for Klein models
+KLEIN_MAX_SEQ_LEN = 512
+
+
+@invocation(
+    "flux2_klein_text_encoder",
+    title="Prompt - Flux2 Klein",
+    tags=["prompt", "conditioning", "flux", "klein", "qwen3"],
+    category="conditioning",
+    version="1.1.0",
+    classification=Classification.Prototype,
+)
+class Flux2KleinTextEncoderInvocation(BaseInvocation):
+    """Encodes and preps a prompt for Flux2 Klein image generation.
+
+    Flux2 Klein uses Qwen3 as the text encoder, extracting hidden states from
+    layers (9, 18, 27) and stacking them for richer text representations.
+    This matches the diffusers Flux2KleinPipeline implementation exactly.
+    """
+
+    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
+    qwen3_encoder: Qwen3EncoderField = InputField(
+        title="Qwen3 Encoder",
+        description=FieldDescriptions.qwen3_encoder,
+        input=Input.Connection,
+    )
+    max_seq_len: Literal[256, 512] = InputField(
+        default=512,
+        description="Max sequence length for the Qwen3 encoder.",
+    )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
+        qwen3_embeds, pooled_embeds = self._encode_prompt(context)
+
+        # Use FLUXConditioningInfo for compatibility with existing Flux denoiser
+        # t5_embeds -> qwen3 stacked embeddings
+        # clip_embeds -> pooled qwen3 embedding
+        conditioning_data = ConditioningFieldData(
+            conditionings=[FLUXConditioningInfo(clip_embeds=pooled_embeds, t5_embeds=qwen3_embeds)]
+        )
+
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return FluxConditioningOutput(
+            conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
+
+    def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode prompt using Qwen3 text encoder with Klein-style layer extraction.
+
+        This matches the diffusers Flux2KleinPipeline._get_qwen3_prompt_embeds() exactly.
+
+        Returns:
+            Tuple of (stacked_embeddings, pooled_embedding):
+            - stacked_embeddings: Hidden states from layers (9, 18, 27) stacked together.
+              Shape: (1, seq_len, hidden_size * 3)
+            - pooled_embedding: Pooled representation for global conditioning.
+              Shape: (1, hidden_size)
+        """
+        prompt = self.prompt
+        device = TorchDevice.choose_torch_device()
+
+        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
+        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
+
+        with ExitStack() as exit_stack:
+            (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
+            (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
+
+            # Apply LoRA models to the text encoder
+            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
+            exit_stack.enter_context(
+                LayerPatcher.apply_smart_model_patches(
+                    model=text_encoder,
+                    patches=self._lora_iterator(context),
+                    prefix=FLUX_LORA_T5_PREFIX,  # Reuse T5 prefix for Qwen3 LoRAs
+                    dtype=lora_dtype,
+                    cached_weights=cached_weights,
+                )
+            )
+
+            context.util.signal_progress("Running Qwen3 text encoder (Klein)")
+
+            if not isinstance(text_encoder, PreTrainedModel):
+                raise TypeError(
+                    f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
+                    "The Qwen3 encoder model may be corrupted or incompatible."
+                )
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise TypeError(
+                    f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
+                    "The Qwen3 tokenizer may be corrupted or incompatible."
+                )
+
+            # Format messages exactly like diffusers Flux2KleinPipeline:
+            # - Only user message, NO system message
+            # - add_generation_prompt=True (adds assistant prefix)
+            # - enable_thinking=False
+            messages = [{"role": "user", "content": prompt}]
+
+            # Step 1: Apply chat template to get formatted text (tokenize=False)
+            text: str = tokenizer.apply_chat_template(  # type: ignore[assignment]
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,  # Adds assistant prefix like diffusers
+                enable_thinking=False,  # Disable thinking mode
+            )
+
+            # Step 2: Tokenize the formatted text
+            inputs = tokenizer(
+                text,
+                return_tensors="pt",
+                padding="max_length",
+                truncation=True,
+                max_length=self.max_seq_len,
+            )
+
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs["attention_mask"]
+
+            # Move to device
+            input_ids = input_ids.to(device)
+            attention_mask = attention_mask.to(device)
+
+            # Forward pass through the model - matching diffusers exactly
+            outputs = text_encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                use_cache=False,
+            )
+
+            # Validate hidden_states output
+            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
+                raise RuntimeError(
+                    "Text encoder did not return hidden_states. "
+                    "Ensure output_hidden_states=True is supported by this model."
+                )
+
+            num_hidden_layers = len(outputs.hidden_states)
+
+            # Extract and stack hidden states - EXACTLY like diffusers:
+            # out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+            # prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+            hidden_states_list = []
+            for layer_idx in KLEIN_EXTRACTION_LAYERS:
+                if layer_idx >= num_hidden_layers:
+                    layer_idx = num_hidden_layers - 1
+                hidden_states_list.append(outputs.hidden_states[layer_idx])
+
+            # Stack along dim=1, then permute and reshape - exactly like diffusers
+            out = torch.stack(hidden_states_list, dim=1)
+            out = out.to(dtype=text_encoder.dtype, device=device)
+
+            batch_size, num_channels, seq_len, hidden_dim = out.shape
+            prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+            # Create pooled embedding for global conditioning
+            # Use mean pooling over the sequence (excluding padding)
+            # This serves a similar role to CLIP's pooled output in standard FLUX
+            last_hidden_state = outputs.hidden_states[-1]  # Use last layer for pooling
+            # Expand mask to match hidden state dimensions
+            expanded_mask = attention_mask.unsqueeze(-1).expand_as(last_hidden_state).float()
+            sum_embeds = (last_hidden_state * expanded_mask).sum(dim=1)
+            num_tokens = expanded_mask.sum(dim=1).clamp(min=1)
+            pooled_embeds = sum_embeds / num_tokens
+
+        return prompt_embeds, pooled_embeds
+
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+        """Iterate over LoRA models to apply to the Qwen3 text encoder."""
+        for lora in self.qwen3_encoder.loras:
+            lora_info = context.models.load(lora.lora)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                    "The LoRA model may be corrupted or incompatible."
+                )
+            yield (lora_info.model, lora.weight)
+            del lora_info
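Not part of the wheel contents — a small sketch of the Klein-style embedding construction above, run on dummy tensors so the stack/permute/reshape shapes and the masked mean pooling are easy to verify without loading a Qwen3 checkpoint. The dimensions are illustrative, not taken from any specific Qwen3 variant.

# Hypothetical sketch: Klein layer extraction and pooling on random tensors.
import torch

batch, seq_len, hidden = 1, 512, 2560           # example dims; the real hidden size depends on the Qwen3 variant
num_layers = 37                                  # hidden_states[0] is the embedding layer
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(num_layers)]
attention_mask = torch.ones(batch, seq_len)

# Stack layers (9, 18, 27) along a new "channel" dim, then fold that dim into the feature dim.
out = torch.stack([hidden_states[k] for k in (9, 18, 27)], dim=1)   # (B, 3, seq, hidden)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch, seq_len, 3 * hidden)

# Masked mean pooling over the last layer stands in for CLIP's pooled output.
mask = attention_mask.unsqueeze(-1).expand_as(hidden_states[-1]).float()
pooled = (hidden_states[-1] * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

print(prompt_embeds.shape, pooled.shape)         # torch.Size([1, 512, 7680]) torch.Size([1, 2560])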
invokeai/app/invocations/flux2_vae_decode.py (new file)
@@ -0,0 +1,106 @@
+"""Flux2 Klein VAE Decode Invocation.
+
+Decodes latents to images using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
+"""
+
+import torch
+from einops import rearrange
+from PIL import Image
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    LatentsField,
+    WithBoard,
+    WithMetadata,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux2_vae_decode",
+    title="Latents to Image - FLUX2",
+    tags=["latents", "image", "vae", "l2i", "flux2", "klein"],
+    category="latents",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Flux2VaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Generates an image from latents using FLUX.2 Klein's 32-channel VAE."""
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+
+    def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
+        """Decode latents to image using FLUX.2 VAE.
+
+        Input latents should already be in the correct space after BN denormalization
+        was applied in the denoiser. The VAE expects (B, 32, H, W) format.
+        """
+        with vae_info.model_on_device() as (_, vae):
+            vae_dtype = next(iter(vae.parameters())).dtype
+            device = TorchDevice.choose_torch_device()
+            latents = latents.to(device=device, dtype=vae_dtype)
+
+            # Decode using diffusers API
+            decoded = vae.decode(latents, return_dict=False)[0]
+
+            # Debug: Log decoded output statistics
+            print(
+                f"[FLUX.2 VAE] Decoded output: shape={decoded.shape}, "
+                f"min={decoded.min().item():.4f}, max={decoded.max().item():.4f}, "
+                f"mean={decoded.mean().item():.4f}"
+            )
+            # Check per-channel statistics to diagnose color issues
+            for c in range(min(3, decoded.shape[1])):
+                ch = decoded[0, c]
+                print(
+                    f"[FLUX.2 VAE] Channel {c}: min={ch.min().item():.4f}, "
+                    f"max={ch.max().item():.4f}, mean={ch.mean().item():.4f}"
+                )
+
+            # Convert from [-1, 1] to [0, 1] then to [0, 255] PIL image
+            img = (decoded / 2 + 0.5).clamp(0, 1)
+            img = rearrange(img[0], "c h w -> h w c")
+            img_np = (img * 255).byte().cpu().numpy()
+            # Explicitly create RGB image (not grayscale)
+            img_pil = Image.fromarray(img_np, mode="RGB")
+            return img_pil
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        latents = context.tensors.load(self.latents.latents_name)
+
+        # Log latent statistics for debugging black image issues
+        context.logger.debug(
+            f"FLUX.2 VAE decode input: shape={latents.shape}, "
+            f"min={latents.min().item():.4f}, max={latents.max().item():.4f}, "
+            f"mean={latents.mean().item():.4f}"
+        )
+
+        # Warn if input latents are all zeros or very small (would cause black images)
+        if latents.abs().max() < 1e-6:
+            context.logger.warning(
+                "FLUX.2 VAE decode received near-zero latents! This will cause black images. "
+                "The latent cache may be corrupted - try clearing the cache."
+            )
+
+        vae_info = context.models.load(self.vae.vae)
+        context.util.signal_progress("Running VAE")
+        image = self._vae_decode(vae_info=vae_info, latents=latents)
+
+        TorchDevice.empty_cache()
+        image_dto = context.images.save(image=image)
+        return ImageOutput.build(image_dto)
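Not part of the wheel contents — a short sketch of the tensor-to-PIL conversion performed at the end of _vae_decode, run on a dummy tensor in place of the real AutoencoderKLFlux2 output.

# Hypothetical sketch: [-1, 1] decoded tensor -> RGB PIL image.
import torch
from einops import rearrange
from PIL import Image

decoded = torch.rand(1, 3, 64, 64) * 2 - 1        # pretend VAE output in [-1, 1], (B, C, H, W)

img = (decoded / 2 + 0.5).clamp(0, 1)             # map [-1, 1] -> [0, 1]
img = rearrange(img[0], "c h w -> h w c")         # CHW -> HWC for PIL
img_np = (img * 255).byte().cpu().numpy()         # float [0, 1] -> uint8 [0, 255]
img_pil = Image.fromarray(img_np, mode="RGB")
print(img_pil.size, img_pil.mode)                 # (64, 64) RGB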
invokeai/app/invocations/flux2_vae_encode.py (new file)
@@ -0,0 +1,88 @@
+"""Flux2 Klein VAE Encode Invocation.
+
+Encodes images to latents using the FLUX.2 32-channel VAE (AutoencoderKLFlux2).
+"""
+
+import einops
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    Input,
+    InputField,
+)
+from invokeai.app.invocations.model import VAEField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux2_vae_encode",
+    title="Image to Latents - FLUX2",
+    tags=["latents", "image", "vae", "i2l", "flux2", "klein"],
+    category="latents",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class Flux2VaeEncodeInvocation(BaseInvocation):
+    """Encodes an image into latents using FLUX.2 Klein's 32-channel VAE."""
+
+    image: ImageField = InputField(
+        description="The image to encode.",
+    )
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+
+    def _vae_encode(self, vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+        """Encode image to latents using FLUX.2 VAE.
+
+        The VAE encodes to 32-channel latent space.
+        Output latents shape: (B, 32, H/8, W/8).
+        """
+        with vae_info.model_on_device() as (_, vae):
+            vae_dtype = next(iter(vae.parameters())).dtype
+            device = TorchDevice.choose_torch_device()
+            image_tensor = image_tensor.to(device=device, dtype=vae_dtype)
+
+            # Encode using diffusers API
+            # The VAE.encode() returns a DiagonalGaussianDistribution-like object
+            latent_dist = vae.encode(image_tensor, return_dict=False)[0]
+
+            # Sample from the distribution (or use mode for deterministic output)
+            # Using mode() for deterministic encoding
+            if hasattr(latent_dist, "mode"):
+                latents = latent_dist.mode()
+            elif hasattr(latent_dist, "sample"):
+                # Fall back to sampling if mode is not available
+                generator = torch.Generator(device=device).manual_seed(0)
+                latents = latent_dist.sample(generator=generator)
+            else:
+                # Direct tensor output (some VAE implementations)
+                latents = latent_dist
+
+            return latents
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.images.get_pil(self.image.image_name)
+
+        vae_info = context.models.load(self.vae.vae)
+
+        # Convert image to tensor (HWC -> CHW, normalize to [-1, 1])
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        context.util.signal_progress("Running VAE Encode")
+        latents = self._vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+
+        latents = latents.to("cpu")
+        name = context.tensors.save(tensor=latents)
+        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
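Not part of the wheel contents — a sketch of the mode()/sample() fallback used in _vae_encode, with a hypothetical posterior class standing in for the distribution object returned by vae.encode(); shapes assume a 256x256 RGB input, so the latents come out as (1, 32, 32, 32).

# Hypothetical sketch: deterministic vs. sampled latent extraction.
from typing import Optional

import torch


class DummyPosterior:
    """Pretend latent distribution with a (B, 32, H/8, W/8) mean."""

    def __init__(self) -> None:
        self.mean = torch.zeros(1, 32, 32, 32)

    def mode(self) -> torch.Tensor:
        return self.mean

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        return self.mean + torch.randn(self.mean.shape, generator=generator)


latent_dist = DummyPosterior()
if hasattr(latent_dist, "mode"):           # deterministic path, preferred
    latents = latent_dist.mode()
elif hasattr(latent_dist, "sample"):       # stochastic fallback with a fixed seed
    latents = latent_dist.sample(torch.Generator().manual_seed(0))
else:
    latents = latent_dist                   # some VAEs return the tensor directly

print(latents.shape)                        # torch.Size([1, 32, 32, 32])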