InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl
This diff compares two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -219,7 +219,16 @@ async def reidentify_model(
         result = ModelConfigFactory.from_model_on_disk(mod)
         if result.config is None:
             raise InvalidModelException("Unable to identify model format")
-
+
+        # Retain user-editable fields from the original config
+        result.config.key = config.key
+        result.config.name = config.name
+        result.config.description = config.description
+        result.config.cover_image = config.cover_image
+        result.config.trigger_phrases = config.trigger_phrases
+        result.config.source = config.source
+        result.config.source_type = config.source_type
+
         new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
         return new_config
     except UnknownModelException as e:
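The hunk above re-derives a model config from disk and then restores the user-editable metadata from the stored record before replacing it, so re-identifying a model no longer discards the user's name, description, cover image, trigger phrases, or source. A minimal sketch of the same pattern, using a hypothetical `Config` dataclass and `USER_EDITABLE_FIELDS` tuple rather than InvokeAI's config classes:

```python
# Hypothetical illustration only - not part of the InvokeAI codebase.
from dataclasses import dataclass, replace

USER_EDITABLE_FIELDS = ("key", "name", "description", "cover_image", "trigger_phrases", "source", "source_type")


@dataclass
class Config:
    key: str = ""
    name: str = ""
    description: str = ""
    cover_image: str | None = None
    trigger_phrases: tuple[str, ...] = ()
    source: str = ""
    source_type: str = ""
    format: str = ""  # example of a field that keeps the freshly identified value


def restore_user_fields(old: Config, new: Config) -> Config:
    """Copy user-editable metadata from the previous record onto the re-identified config."""
    return replace(new, **{f: getattr(old, f) for f in USER_EDITABLE_FIELDS})
```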
@@ -905,15 +914,48 @@ class StarterModelResponse(BaseModel):
 def get_is_installed(
     starter_model: StarterModel | StarterModelWithoutDependencies, installed_models: list[AnyModelConfig]
 ) -> bool:
+    from invokeai.backend.model_manager.taxonomy import ModelType
+
     for model in installed_models:
+        # Check if source matches exactly
         if model.source == starter_model.source:
             return True
+        # Check if name (or previous names), base and type match
         if (
             (model.name == starter_model.name or model.name in starter_model.previous_names)
             and model.base == starter_model.base
             and model.type == starter_model.type
         ):
             return True
+
+    # Special handling for Qwen3Encoder models - check by type and variant
+    # This allows renamed models to still be detected as installed
+    if starter_model.type == ModelType.Qwen3Encoder:
+        from invokeai.backend.model_manager.taxonomy import Qwen3VariantType
+
+        # Determine expected variant from source pattern
+        expected_variant: Qwen3VariantType | None = None
+        if "klein-9B" in starter_model.source or "qwen3_8b" in starter_model.source.lower():
+            expected_variant = Qwen3VariantType.Qwen3_8B
+        elif (
+            "klein-4B" in starter_model.source
+            or "qwen3_4b" in starter_model.source.lower()
+            or "Z-Image" in starter_model.source
+        ):
+            expected_variant = Qwen3VariantType.Qwen3_4B
+
+        if expected_variant is not None:
+            for model in installed_models:
+                if model.type == ModelType.Qwen3Encoder and hasattr(model, "variant"):
+                    model_variant = model.variant
+                    # Handle both enum and string values
+                    if isinstance(model_variant, Qwen3VariantType):
+                        if model_variant == expected_variant:
+                            return True
+                    elif isinstance(model_variant, str):
+                        if model_variant == expected_variant.value:
+                            return True
+
     return False


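`get_is_installed` now falls back to matching Qwen3 encoders by variant, so a renamed install still counts. The variant is inferred from the starter model's source string. A self-contained sketch of that source-pattern rule (the enum and source strings below are local stand-ins, not InvokeAI's `Qwen3VariantType` or real model sources):

```python
# Illustrative sketch mirroring the source-pattern matching added above.
from enum import Enum


class Qwen3Variant(str, Enum):
    QWEN3_8B = "qwen3_8b"
    QWEN3_4B = "qwen3_4b"


def expected_variant_for_source(source: str) -> Qwen3Variant | None:
    """Mirror of the source-pattern rule in get_is_installed above."""
    if "klein-9B" in source or "qwen3_8b" in source.lower():
        return Qwen3Variant.QWEN3_8B
    if "klein-4B" in source or "qwen3_4b" in source.lower() or "Z-Image" in source:
        return Qwen3Variant.QWEN3_4B
    return None


# Example strings that merely contain the matched substrings:
assert expected_variant_for_source("example/flux2-klein-9B-text-encoder") is Qwen3Variant.QWEN3_8B
assert expected_variant_for_source("example/Z-Image-turbo") is Qwen3Variant.QWEN3_4B
assert expected_variant_for_source("example/some-other-model") is None
```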
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -532,7 +532,7 @@ def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, An
         case UIType.VAEModel:
             ui_model_type = [ModelType.VAE]
         case UIType.FluxVAEModel:
-            ui_model_base = [BaseModelType.Flux]
+            ui_model_base = [BaseModelType.Flux, BaseModelType.Flux2]
             ui_model_type = [ModelType.VAE]
         case UIType.LoRAModel:
             ui_model_type = [ModelType.LoRA]
--- /dev/null
+++ b/invokeai/app/invocations/flux2_denoise.py
@@ -0,0 +1,499 @@
+"""Flux2 Klein Denoise Invocation.
+
+Run denoising process with a FLUX.2 Klein transformer model.
+Uses Qwen3 conditioning instead of CLIP+T5.
+"""
+
+from contextlib import ExitStack
+from typing import Callable, Iterator, Optional, Tuple
+
+import torch
+import torchvision.transforms as tv_transforms
+from torchvision.transforms.functional import resize as tv_resize
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    DenoiseMaskField,
+    FieldDescriptions,
+    FluxConditioningField,
+    FluxKontextConditioningField,
+    Input,
+    InputField,
+    LatentsField,
+)
+from invokeai.app.invocations.model import TransformerField, VAEField
+from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
+from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
+from invokeai.backend.flux2.denoise import denoise
+from invokeai.backend.flux2.ref_image_extension import Flux2RefImageExtension
+from invokeai.backend.flux2.sampling_utils import (
+    compute_empirical_mu,
+    generate_img_ids_flux2,
+    get_noise_flux2,
+    get_schedule_flux2,
+    pack_flux2,
+    unpack_flux2,
+)
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux2_denoise",
+    title="FLUX2 Denoise",
+    tags=["image", "flux", "flux2", "klein", "denoise"],
+    category="image",
+    version="1.3.0",
+    classification=Classification.Prototype,
+)
+class Flux2DenoiseInvocation(BaseInvocation):
+    """Run denoising process with a FLUX.2 Klein transformer model.
+
+    This node is designed for FLUX.2 Klein models which use Qwen3 as the text encoder.
+    It does not support ControlNet, IP-Adapters, or regional prompting.
+    """
+
+    latents: Optional[LatentsField] = InputField(
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    denoise_mask: Optional[DenoiseMaskField] = InputField(
+        default=None,
+        description=FieldDescriptions.denoise_mask,
+        input=Input.Connection,
+    )
+    denoising_start: float = InputField(
+        default=0.0,
+        ge=0,
+        le=1,
+        description=FieldDescriptions.denoising_start,
+    )
+    denoising_end: float = InputField(
+        default=1.0,
+        ge=0,
+        le=1,
+        description=FieldDescriptions.denoising_end,
+    )
+    add_noise: bool = InputField(default=True, description="Add noise based on denoising start.")
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.flux_model,
+        input=Input.Connection,
+        title="Transformer",
+    )
+    positive_text_conditioning: FluxConditioningField = InputField(
+        description=FieldDescriptions.positive_cond,
+        input=Input.Connection,
+    )
+    negative_text_conditioning: Optional[FluxConditioningField] = InputField(
+        default=None,
+        description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
+        input=Input.Connection,
+    )
+    cfg_scale: float = InputField(
+        default=1.0,
+        description=FieldDescriptions.cfg_scale,
+        title="CFG Scale",
+    )
+    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
+    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
+    num_steps: int = InputField(
+        default=4,
+        description="Number of diffusion steps. Use 4 for distilled models, 28+ for base models.",
+    )
+    scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
+        default="euler",
+        description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
+        "'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
+        ui_choice_labels=FLUX_SCHEDULER_LABELS,
+    )
+    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    vae: VAEField = InputField(
+        description="FLUX.2 VAE model (required for BN statistics).",
+        input=Input.Connection,
+    )
+    kontext_conditioning: FluxKontextConditioningField | list[FluxKontextConditioningField] | None = InputField(
+        default=None,
+        description="FLUX Kontext conditioning (reference images for multi-reference image editing).",
+        input=Input.Connection,
+        title="Reference Images",
+    )
+
+    def _get_bn_stats(self, context: InvocationContext) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
+        """Extract BN statistics from the FLUX.2 VAE.
+
+        The FLUX.2 VAE uses batch normalization on the patchified 128-channel representation.
+        IMPORTANT: BFL FLUX.2 VAE uses affine=False, so there are NO learnable weight/bias.
+
+        BN formula (affine=False): y = (x - mean) / std
+        Inverse: x = y * std + mean
+
+        Returns:
+            Tuple of (bn_mean, bn_std) tensors of shape (128,), or None if BN layer not found.
+        """
+        with context.models.load(self.vae.vae).model_on_device() as (_, vae):
+            # Ensure VAE is in eval mode to prevent BN stats from being updated
+            vae.eval()
+
+            # Try to find the BN layer - it may be at different locations depending on model format
+            bn_layer = None
+            if hasattr(vae, "bn"):
+                bn_layer = vae.bn
+            elif hasattr(vae, "batch_norm"):
+                bn_layer = vae.batch_norm
+            elif hasattr(vae, "encoder") and hasattr(vae.encoder, "bn"):
+                bn_layer = vae.encoder.bn
+
+            if bn_layer is None:
+                return None
+
+            # Verify running statistics are initialized
+            if bn_layer.running_mean is None or bn_layer.running_var is None:
+                return None
+
+            # Get BN running statistics from VAE
+            bn_mean = bn_layer.running_mean.clone()  # Shape: (128,)
+            bn_var = bn_layer.running_var.clone()  # Shape: (128,)
+            bn_eps = bn_layer.eps if hasattr(bn_layer, "eps") else 1e-4  # BFL uses 1e-4
+            bn_std = torch.sqrt(bn_var + bn_eps)
+
+            return bn_mean, bn_std
+
+    def _bn_normalize(
+        self,
+        x: torch.Tensor,
+        bn_mean: torch.Tensor,
+        bn_std: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply BN normalization to packed latents.
+
+        BN formula (affine=False): y = (x - mean) / std
+
+        Args:
+            x: Packed latents of shape (B, seq, 128).
+            bn_mean: BN running mean of shape (128,).
+            bn_std: BN running std of shape (128,).
+
+        Returns:
+            Normalized latents of same shape.
+        """
+        # x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
+        bn_mean = bn_mean.to(x.device, x.dtype)
+        bn_std = bn_std.to(x.device, x.dtype)
+        return (x - bn_mean) / bn_std
+
+    def _bn_denormalize(
+        self,
+        x: torch.Tensor,
+        bn_mean: torch.Tensor,
+        bn_std: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply BN denormalization to packed latents (inverse of normalization).
+
+        Inverse BN (affine=False): x = y * std + mean
+
+        Args:
+            x: Packed latents of shape (B, seq, 128).
+            bn_mean: BN running mean of shape (128,).
+            bn_std: BN running std of shape (128,).
+
+        Returns:
+            Denormalized latents of same shape.
+        """
+        # x: (B, seq, 128), params: (128,) -> broadcast over batch and sequence dims
+        bn_mean = bn_mean.to(x.device, x.dtype)
+        bn_std = bn_std.to(x.device, x.dtype)
+        return x * bn_std + bn_mean
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = self._run_diffusion(context)
+        latents = latents.detach().to("cpu")
+
+        name = context.tensors.save(tensor=latents)
+        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
+
+    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
+        inference_dtype = torch.bfloat16
+        device = TorchDevice.choose_torch_device()
+
+        # Get BN statistics from VAE for latent denormalization (optional)
+        # BFL FLUX.2 VAE uses affine=False, so only mean/std are needed
+        # Some VAE formats (e.g. diffusers) may not expose BN stats directly
+        bn_stats = self._get_bn_stats(context)
+        bn_mean, bn_std = bn_stats if bn_stats is not None else (None, None)
+
+        # Load the input latents, if provided
+        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
+        if init_latents is not None:
+            init_latents = init_latents.to(device=device, dtype=inference_dtype)
+
+        # Prepare input noise (FLUX.2 uses 32 channels)
+        noise = get_noise_flux2(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=device,
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+        b, _c, latent_h, latent_w = noise.shape
+        packed_h = latent_h // 2
+        packed_w = latent_w // 2
+
+        # Load the conditioning data
+        pos_cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(pos_cond_data.conditionings) == 1
+        pos_flux_conditioning = pos_cond_data.conditionings[0]
+        assert isinstance(pos_flux_conditioning, FLUXConditioningInfo)
+        pos_flux_conditioning = pos_flux_conditioning.to(dtype=inference_dtype, device=device)
+
+        # Qwen3 stacked embeddings (stored in t5_embeds field for compatibility)
+        txt = pos_flux_conditioning.t5_embeds
+
+        # Generate text position IDs (4D format for FLUX.2: T, H, W, L)
+        # FLUX.2 uses 4D position coordinates for its rotary position embeddings
+        # IMPORTANT: Position IDs must be int64 (long) dtype
+        # Diffusers uses: T=0, H=0, W=0, L=0..seq_len-1
+        seq_len = txt.shape[1]
+        txt_ids = torch.zeros(1, seq_len, 4, device=device, dtype=torch.long)
+        txt_ids[..., 3] = torch.arange(seq_len, device=device, dtype=torch.long)  # L coordinate varies
+
+        # Load negative conditioning if provided
+        neg_txt = None
+        neg_txt_ids = None
+        if self.negative_text_conditioning is not None:
+            neg_cond_data = context.conditioning.load(self.negative_text_conditioning.conditioning_name)
+            assert len(neg_cond_data.conditionings) == 1
+            neg_flux_conditioning = neg_cond_data.conditionings[0]
+            assert isinstance(neg_flux_conditioning, FLUXConditioningInfo)
+            neg_flux_conditioning = neg_flux_conditioning.to(dtype=inference_dtype, device=device)
+            neg_txt = neg_flux_conditioning.t5_embeds
+            # For text tokens: T=0, H=0, W=0, L=0..seq_len-1 (only L varies per token)
+            neg_seq_len = neg_txt.shape[1]
+            neg_txt_ids = torch.zeros(1, neg_seq_len, 4, device=device, dtype=torch.long)
+            neg_txt_ids[..., 3] = torch.arange(neg_seq_len, device=device, dtype=torch.long)
+
+        # Validate transformer config
+        transformer_config = context.models.get_config(self.transformer.transformer)
+        assert transformer_config.base == BaseModelType.Flux2 and transformer_config.type == ModelType.Main
+
+        # Calculate the timestep schedule using FLUX.2 specific schedule
+        # This matches diffusers' Flux2Pipeline implementation
+        # Note: Schedule shifting is handled by the scheduler via mu parameter
+        image_seq_len = packed_h * packed_w
+        timesteps = get_schedule_flux2(
+            num_steps=self.num_steps,
+            image_seq_len=image_seq_len,
+        )
+        # Compute mu for dynamic schedule shifting (used by FlowMatchEulerDiscreteScheduler)
+        mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=self.num_steps)
+
+        # Clip the timesteps schedule based on denoising_start and denoising_end
+        timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
+
+        # Prepare input latent image
+        if init_latents is not None:
+            if self.add_noise:
+                t_0 = timesteps[0]
+                x = t_0 * noise + (1.0 - t_0) * init_latents
+            else:
+                x = init_latents
+        else:
+            if self.denoising_start > 1e-5:
+                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
+            x = noise
+
+        # If len(timesteps) == 1, then short-circuit
+        if len(timesteps) <= 1:
+            return x
+
+        # Generate image position IDs (FLUX.2 uses 4D coordinates)
+        # Position IDs use int64 dtype like diffusers
+        img_ids = generate_img_ids_flux2(h=latent_h, w=latent_w, batch_size=b, device=device)
+
+        # Prepare inpaint mask
+        inpaint_mask = self._prep_inpaint_mask(context, x)
+
+        # Pack all latent tensors
+        init_latents_packed = pack_flux2(init_latents) if init_latents is not None else None
+        inpaint_mask_packed = pack_flux2(inpaint_mask) if inpaint_mask is not None else None
+        noise_packed = pack_flux2(noise)
+        x = pack_flux2(x)
+
+        # Apply BN normalization BEFORE denoising (as per diffusers Flux2KleinPipeline)
+        # BN normalization: y = (x - mean) / std
+        # This transforms latents to normalized space for the transformer
+        # IMPORTANT: Also normalize init_latents and noise for inpainting to maintain consistency
+        if bn_mean is not None and bn_std is not None:
+            x = self._bn_normalize(x, bn_mean, bn_std)
+            if init_latents_packed is not None:
+                init_latents_packed = self._bn_normalize(init_latents_packed, bn_mean, bn_std)
+            noise_packed = self._bn_normalize(noise_packed, bn_mean, bn_std)
+
+        # Verify packed dimensions
+        assert packed_h * packed_w == x.shape[1]
+
+        # Prepare inpaint extension
+        inpaint_extension: Optional[RectifiedFlowInpaintExtension] = None
+        if inpaint_mask_packed is not None:
+            assert init_latents_packed is not None
+            inpaint_extension = RectifiedFlowInpaintExtension(
+                init_latents=init_latents_packed,
+                inpaint_mask=inpaint_mask_packed,
+                noise=noise_packed,
+            )
+
+        # Prepare CFG scale list
+        num_steps = len(timesteps) - 1
+        cfg_scale_list = [self.cfg_scale] * num_steps
+
+        # Check if we're doing inpainting (have a mask or a clipped schedule)
+        is_inpainting = self.denoise_mask is not None or self.denoising_start > 1e-5
+
+        # Create scheduler with FLUX.2 Klein configuration
+        # For inpainting/img2img, use manual Euler stepping to preserve the exact timestep schedule
+        # For txt2img, use the scheduler with dynamic shifting for optimal results
+        scheduler = None
+        if self.scheduler in FLUX_SCHEDULER_MAP and not is_inpainting:
+            # Only use scheduler for txt2img - use manual Euler for inpainting to preserve exact timesteps
+            scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
+            scheduler = scheduler_class(
+                num_train_timesteps=1000,
+                shift=3.0,
+                use_dynamic_shifting=True,
+                base_shift=0.5,
+                max_shift=1.15,
+                base_image_seq_len=256,
+                max_image_seq_len=4096,
+                time_shift_type="exponential",
+            )
+
+        # Prepare reference image extension for FLUX.2 Klein built-in editing
+        ref_image_extension = None
+        if self.kontext_conditioning:
+            ref_image_extension = Flux2RefImageExtension(
+                context=context,
+                ref_image_conditioning=self.kontext_conditioning
+                if isinstance(self.kontext_conditioning, list)
+                else [self.kontext_conditioning],
+                vae_field=self.vae,
+                device=device,
+                dtype=inference_dtype,
+                bn_mean=bn_mean,
+                bn_std=bn_std,
+            )
+
+        with ExitStack() as exit_stack:
+            # Load the transformer model
+            (cached_weights, transformer) = exit_stack.enter_context(
+                context.models.load(self.transformer.transformer).model_on_device()
+            )
+            config = transformer_config
+
+            # Determine if the model is quantized
+            if config.format in [ModelFormat.Diffusers]:
+                model_is_quantized = False
+            elif config.format in [
+                ModelFormat.BnbQuantizedLlmInt8b,
+                ModelFormat.BnbQuantizednf4b,
+                ModelFormat.GGUFQuantized,
+            ]:
+                model_is_quantized = True
+            else:
+                model_is_quantized = False
+
+            # Apply LoRA models to the transformer
+            exit_stack.enter_context(
+                LayerPatcher.apply_smart_model_patches(
+                    model=transformer,
+                    patches=self._lora_iterator(context),
+                    prefix=FLUX_LORA_TRANSFORMER_PREFIX,
+                    dtype=inference_dtype,
+                    cached_weights=cached_weights,
+                    force_sidecar_patching=model_is_quantized,
+                )
+            )
+
+            # Prepare reference image conditioning if provided
+            img_cond_seq = None
+            img_cond_seq_ids = None
+            if ref_image_extension is not None:
+                # Ensure batch sizes match
+                ref_image_extension.ensure_batch_size(x.shape[0])
+                img_cond_seq, img_cond_seq_ids = (
+                    ref_image_extension.ref_image_latents,
+                    ref_image_extension.ref_image_ids,
+                )
+
+            x = denoise(
+                model=transformer,
+                img=x,
+                img_ids=img_ids,
+                txt=txt,
+                txt_ids=txt_ids,
+                timesteps=timesteps,
+                step_callback=self._build_step_callback(context),
+                cfg_scale=cfg_scale_list,
+                neg_txt=neg_txt,
+                neg_txt_ids=neg_txt_ids,
+                scheduler=scheduler,
+                mu=mu,
+                inpaint_extension=inpaint_extension,
+                img_cond_seq=img_cond_seq,
+                img_cond_seq_ids=img_cond_seq_ids,
+            )
+
+            # Apply BN denormalization if BN stats are available
+            # The diffusers Flux2KleinPipeline applies: latents = latents * bn_std + bn_mean
+            # This transforms latents from normalized space to VAE's expected input space
+            if bn_mean is not None and bn_std is not None:
+                x = self._bn_denormalize(x, bn_mean, bn_std)
+
+            x = unpack_flux2(x.float(), self.height, self.width)
+            return x
+
+    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> Optional[torch.Tensor]:
+        """Prepare the inpaint mask."""
+        if self.denoise_mask is None:
+            return None
+
+        mask = context.tensors.load(self.denoise_mask.mask_name)
+        mask = 1.0 - mask
+
+        _, _, latent_height, latent_width = latents.shape
+        mask = tv_resize(
+            img=mask,
+            size=[latent_height, latent_width],
+            interpolation=tv_transforms.InterpolationMode.BILINEAR,
+            antialias=False,
+        )
+
+        mask = mask.to(device=latents.device, dtype=latents.dtype)
+        return mask.expand_as(latents)
+
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+        """Iterate over LoRA models to apply."""
+        for lora in self.transformer.loras:
+            lora_info = context.models.load(lora.lora)
+            assert isinstance(lora_info.model, ModelPatchRaw)
+            yield (lora_info.model, lora.weight)
+            del lora_info
+
+    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
+        """Build a callback for step progress updates."""
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            latents = state.latents.float()
+            state.latents = unpack_flux2(latents, self.height, self.width).squeeze()
+            context.util.flux2_step_callback(state)
+
+        return step_callback