InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
|
@@ -56,6 +56,7 @@ from invokeai.backend.model_manager.configs.lora import (
|
|
|
56
56
|
)
|
|
57
57
|
from invokeai.backend.model_manager.configs.main import (
|
|
58
58
|
Main_BnBNF4_FLUX_Config,
|
|
59
|
+
Main_Checkpoint_Flux2_Config,
|
|
59
60
|
Main_Checkpoint_FLUX_Config,
|
|
60
61
|
Main_Checkpoint_SD1_Config,
|
|
61
62
|
Main_Checkpoint_SD2_Config,
|
|
@@ -63,12 +64,15 @@ from invokeai.backend.model_manager.configs.main import (
|
|
|
63
64
|
Main_Checkpoint_SDXLRefiner_Config,
|
|
64
65
|
Main_Checkpoint_ZImage_Config,
|
|
65
66
|
Main_Diffusers_CogView4_Config,
|
|
67
|
+
Main_Diffusers_Flux2_Config,
|
|
68
|
+
Main_Diffusers_FLUX_Config,
|
|
66
69
|
Main_Diffusers_SD1_Config,
|
|
67
70
|
Main_Diffusers_SD2_Config,
|
|
68
71
|
Main_Diffusers_SD3_Config,
|
|
69
72
|
Main_Diffusers_SDXL_Config,
|
|
70
73
|
Main_Diffusers_SDXLRefiner_Config,
|
|
71
74
|
Main_Diffusers_ZImage_Config,
|
|
75
|
+
Main_GGUF_Flux2_Config,
|
|
72
76
|
Main_GGUF_FLUX_Config,
|
|
73
77
|
Main_GGUF_ZImage_Config,
|
|
74
78
|
MainModelDefaultSettings,
|
|
@@ -95,10 +99,12 @@ from invokeai.backend.model_manager.configs.textual_inversion import (
|
|
|
95
99
|
)
|
|
96
100
|
from invokeai.backend.model_manager.configs.unknown import Unknown_Config
|
|
97
101
|
from invokeai.backend.model_manager.configs.vae import (
|
|
102
|
+
VAE_Checkpoint_Flux2_Config,
|
|
98
103
|
VAE_Checkpoint_FLUX_Config,
|
|
99
104
|
VAE_Checkpoint_SD1_Config,
|
|
100
105
|
VAE_Checkpoint_SD2_Config,
|
|
101
106
|
VAE_Checkpoint_SDXL_Config,
|
|
107
|
+
VAE_Diffusers_Flux2_Config,
|
|
102
108
|
VAE_Diffusers_SD1_Config,
|
|
103
109
|
VAE_Diffusers_SDXL_Config,
|
|
104
110
|
)
|
|
@@ -148,17 +154,25 @@ AnyModelConfig = Annotated[
|
|
|
148
154
|
Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
|
|
149
155
|
Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
|
|
150
156
|
Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
|
|
157
|
+
Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()],
|
|
158
|
+
Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()],
|
|
151
159
|
Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
|
|
152
160
|
Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()],
|
|
153
161
|
# Main (Pipeline) - checkpoint format
|
|
162
|
+
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
|
|
163
|
+
# that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
|
|
154
164
|
Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
|
|
155
165
|
Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
|
|
156
166
|
Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
|
|
157
167
|
Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
|
|
168
|
+
Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()],
|
|
158
169
|
Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
|
|
159
170
|
Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
|
|
160
171
|
# Main (Pipeline) - quantized formats
|
|
172
|
+
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
|
|
173
|
+
# that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
|
|
161
174
|
Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
|
|
175
|
+
Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()],
|
|
162
176
|
Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
|
|
163
177
|
Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()],
|
|
164
178
|
# VAE - checkpoint format
|
|
@@ -166,9 +180,11 @@ AnyModelConfig = Annotated[
|
|
|
166
180
|
Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
|
|
167
181
|
Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
|
|
168
182
|
Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
|
|
183
|
+
Annotated[VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_Flux2_Config.get_tag()],
|
|
169
184
|
# VAE - diffusers format
|
|
170
185
|
Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
|
|
171
186
|
Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
|
|
187
|
+
Annotated[VAE_Diffusers_Flux2_Config, VAE_Diffusers_Flux2_Config.get_tag()],
|
|
172
188
|
# ControlNet - checkpoint format
|
|
173
189
|
Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
|
|
174
190
|
Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
|
|
@@ -498,7 +514,9 @@ class ModelConfigFactory:
|
|
|
498
514
|
# Now do any post-processing needed for specific model types/bases/etc.
|
|
499
515
|
match config.type:
|
|
500
516
|
case ModelType.Main:
|
|
501
|
-
|
|
517
|
+
# Pass variant if available (e.g., for Flux2 models)
|
|
518
|
+
variant = getattr(config, "variant", None)
|
|
519
|
+
config.default_settings = MainModelDefaultSettings.from_base(config.base, variant)
|
|
502
520
|
case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
|
|
503
521
|
config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
|
|
504
522
|
case ModelType.LoRA:
|
|
@@ -227,6 +227,42 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
|
|
|
227
227
|
|
|
228
228
|
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
|
|
229
229
|
|
|
230
|
+
@classmethod
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
    """Accept the model only if its state dict matches the Z-Image LoRA key heuristics.

    Two conditions must both hold:
    - at least one key starts with ``diffusion_model.layers.``
    - at least one key ends with a LoRA weight suffix (``lora_A.weight``,
      ``lora_B.weight``, ``lora_down.weight``, ``lora_up.weight`` or ``dora_scale``)

    Raises:
        NotAMatchError: if either condition is not met.
    """
    weights = mod.load_state_dict()

    # Prefix used by Z-Image transformer layers.
    layer_prefixes = {"diffusion_model.layers."}
    # Suffixes used by the supported LoRA weight naming schemes.
    lora_suffixes = {
        "lora_A.weight",
        "lora_B.weight",
        "lora_down.weight",
        "lora_up.weight",
        "dora_scale",
    }

    prefix_found = state_dict_has_any_keys_starting_with(weights, layer_prefixes)
    suffix_found = state_dict_has_any_keys_ending_with(weights, lora_suffixes)

    if not (prefix_found and suffix_found):
        raise NotAMatchError("model does not match Z-Image LoRA heuristics")
|
|
265
|
+
|
|
230
266
|
@classmethod
|
|
231
267
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
232
268
|
"""Z-Image LoRAs are identified by their diffusion_model.layers structure.
|
|
@@ -23,6 +23,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
|
|
|
23
23
|
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
|
24
24
|
from invokeai.backend.model_manager.taxonomy import (
|
|
25
25
|
BaseModelType,
|
|
26
|
+
Flux2VariantType,
|
|
26
27
|
FluxVariantType,
|
|
27
28
|
ModelFormat,
|
|
28
29
|
ModelType,
|
|
@@ -52,7 +53,11 @@ class MainModelDefaultSettings(BaseModel):
|
|
|
52
53
|
model_config = ConfigDict(extra="forbid")
|
|
53
54
|
|
|
54
55
|
@classmethod
|
|
55
|
-
def from_base(
|
|
56
|
+
def from_base(
|
|
57
|
+
cls,
|
|
58
|
+
base: BaseModelType,
|
|
59
|
+
variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
|
|
60
|
+
) -> Self | None:
|
|
56
61
|
match base:
|
|
57
62
|
case BaseModelType.StableDiffusion1:
|
|
58
63
|
return cls(width=512, height=512)
|
|
@@ -62,6 +67,14 @@ class MainModelDefaultSettings(BaseModel):
|
|
|
62
67
|
return cls(width=1024, height=1024)
|
|
63
68
|
case BaseModelType.ZImage:
|
|
64
69
|
return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
|
|
70
|
+
case BaseModelType.Flux2:
|
|
71
|
+
# Different defaults based on variant
|
|
72
|
+
if variant == Flux2VariantType.Klein9BBase:
|
|
73
|
+
# Undistilled base model needs more steps
|
|
74
|
+
return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
|
|
75
|
+
else:
|
|
76
|
+
# Distilled models (Klein 4B, Klein 9B) use fewer steps
|
|
77
|
+
return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
|
|
65
78
|
case _:
|
|
66
79
|
# TODO(psyche): Do we want defaults for other base types?
|
|
67
80
|
return None
|
|
@@ -114,7 +127,11 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
|
114
127
|
|
|
115
128
|
|
|
116
129
|
def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
117
|
-
"""Check if state dict contains Z-Image S3-DiT transformer keys.
|
|
130
|
+
"""Check if state dict contains Z-Image S3-DiT transformer keys.
|
|
131
|
+
|
|
132
|
+
This function returns True only for Z-Image main models, not LoRAs.
|
|
133
|
+
LoRAs are excluded by checking for LoRA-specific weight suffixes.
|
|
134
|
+
"""
|
|
118
135
|
# Z-Image specific keys that distinguish it from other models
|
|
119
136
|
z_image_specific_keys = {
|
|
120
137
|
"cap_embedder", # Caption embedder - unique to Z-Image
|
|
@@ -122,9 +139,23 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
|
122
139
|
"cap_pad_token", # Caption padding token
|
|
123
140
|
}
|
|
124
141
|
|
|
142
|
+
# LoRA-specific suffixes - if present, this is a LoRA not a main model
|
|
143
|
+
lora_suffixes = (
|
|
144
|
+
".lora_down.weight",
|
|
145
|
+
".lora_up.weight",
|
|
146
|
+
".lora_A.weight",
|
|
147
|
+
".lora_B.weight",
|
|
148
|
+
".dora_scale",
|
|
149
|
+
)
|
|
150
|
+
|
|
125
151
|
for key in state_dict.keys():
|
|
126
152
|
if isinstance(key, int):
|
|
127
153
|
continue
|
|
154
|
+
|
|
155
|
+
# If we find any LoRA-specific keys, this is not a main model
|
|
156
|
+
if key.endswith(lora_suffixes):
|
|
157
|
+
return False
|
|
158
|
+
|
|
128
159
|
# Check for Z-Image specific key prefixes
|
|
129
160
|
# Handle both direct keys (cap_embedder.0.weight) and
|
|
130
161
|
# ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
|
|
@@ -132,6 +163,7 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
|
|
|
132
163
|
for part in key_parts:
|
|
133
164
|
if part in z_image_specific_keys:
|
|
134
165
|
return True
|
|
166
|
+
|
|
135
167
|
return False
|
|
136
168
|
|
|
137
169
|
|
|
@@ -249,6 +281,108 @@ class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_
|
|
|
249
281
|
base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
|
|
250
282
|
|
|
251
283
|
|
|
284
|
+
def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
|
|
285
|
+
"""Check if state dict is a FLUX.2 model by examining context_embedder dimensions.
|
|
286
|
+
|
|
287
|
+
FLUX.2 Klein uses Qwen3 encoder with larger context dimension:
|
|
288
|
+
- FLUX.1: context_in_dim = 4096 (T5)
|
|
289
|
+
- FLUX.2 Klein 4B: context_in_dim = 7680 (3×Qwen3-4B hidden size)
|
|
290
|
+
- FLUX.2 Klein 8B: context_in_dim = 12288 (3×Qwen3-8B hidden size)
|
|
291
|
+
|
|
292
|
+
Also checks for FLUX.2-specific 32-channel latent space (in_channels=128 after packing).
|
|
293
|
+
"""
|
|
294
|
+
# Check context_embedder input dimension (most reliable)
|
|
295
|
+
# Weight shape: [hidden_size, context_in_dim]
|
|
296
|
+
for key in {"context_embedder.weight", "model.diffusion_model.context_embedder.weight"}:
|
|
297
|
+
if key in state_dict:
|
|
298
|
+
weight = state_dict[key]
|
|
299
|
+
if hasattr(weight, "shape") and len(weight.shape) >= 2:
|
|
300
|
+
context_in_dim = weight.shape[1]
|
|
301
|
+
# FLUX.2 has context_in_dim > 4096 (Qwen3 vs T5)
|
|
302
|
+
if context_in_dim > 4096:
|
|
303
|
+
return True
|
|
304
|
+
|
|
305
|
+
# Also check in_channels - FLUX.2 uses 128 (32 latent channels × 4 packing)
|
|
306
|
+
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
|
|
307
|
+
if key in state_dict:
|
|
308
|
+
in_channels = state_dict[key].shape[1]
|
|
309
|
+
# FLUX.2 uses 128 in_channels (32 latent channels × 4)
|
|
310
|
+
# FLUX.1 uses 64 in_channels (16 latent channels × 4)
|
|
311
|
+
if in_channels == 128:
|
|
312
|
+
return True
|
|
313
|
+
|
|
314
|
+
return False
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
|
|
318
|
+
"""Determine FLUX.2 variant from state dict.
|
|
319
|
+
|
|
320
|
+
Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
|
|
321
|
+
- Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
|
|
322
|
+
- Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
|
|
323
|
+
|
|
324
|
+
Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
|
|
325
|
+
We default to Klein9B (distilled) for all 9B models since GGUF models may not
|
|
326
|
+
include guidance embedding keys needed to distinguish them.
|
|
327
|
+
|
|
328
|
+
Supports both BFL format (checkpoint) and diffusers format keys:
|
|
329
|
+
- BFL format: txt_in.weight (context embedder)
|
|
330
|
+
- Diffusers format: context_embedder.weight
|
|
331
|
+
"""
|
|
332
|
+
# Context dimensions for each variant
|
|
333
|
+
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
|
|
334
|
+
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
|
|
335
|
+
|
|
336
|
+
# Check context_embedder to determine variant
|
|
337
|
+
# Support both BFL format (txt_in.weight) and diffusers format (context_embedder.weight)
|
|
338
|
+
context_keys = {
|
|
339
|
+
# Diffusers format
|
|
340
|
+
"context_embedder.weight",
|
|
341
|
+
"model.diffusion_model.context_embedder.weight",
|
|
342
|
+
# BFL format (used by checkpoint/GGUF models)
|
|
343
|
+
"txt_in.weight",
|
|
344
|
+
"model.diffusion_model.txt_in.weight",
|
|
345
|
+
}
|
|
346
|
+
for key in context_keys:
|
|
347
|
+
if key in state_dict:
|
|
348
|
+
weight = state_dict[key]
|
|
349
|
+
# Handle GGUF quantized tensors which use tensor_shape instead of shape
|
|
350
|
+
if hasattr(weight, "tensor_shape"):
|
|
351
|
+
shape = weight.tensor_shape
|
|
352
|
+
elif hasattr(weight, "shape"):
|
|
353
|
+
shape = weight.shape
|
|
354
|
+
else:
|
|
355
|
+
continue
|
|
356
|
+
if len(shape) >= 2:
|
|
357
|
+
context_in_dim = shape[1]
|
|
358
|
+
# Determine variant based on context dimension
|
|
359
|
+
if context_in_dim == KLEIN_9B_CONTEXT_DIM:
|
|
360
|
+
# Default to Klein9B (distilled) - the official/common 9B model
|
|
361
|
+
return Flux2VariantType.Klein9B
|
|
362
|
+
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
|
|
363
|
+
return Flux2VariantType.Klein4B
|
|
364
|
+
elif context_in_dim > 4096:
|
|
365
|
+
# Unknown FLUX.2 variant, default to 4B
|
|
366
|
+
return Flux2VariantType.Klein4B
|
|
367
|
+
|
|
368
|
+
# Check in_channels as backup - can only confirm it's FLUX.2, not which variant
|
|
369
|
+
for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
|
|
370
|
+
if key in state_dict:
|
|
371
|
+
weight = state_dict[key]
|
|
372
|
+
# Handle GGUF quantized tensors
|
|
373
|
+
if hasattr(weight, "tensor_shape"):
|
|
374
|
+
in_channels = weight.tensor_shape[1]
|
|
375
|
+
elif hasattr(weight, "shape"):
|
|
376
|
+
in_channels = weight.shape[1]
|
|
377
|
+
else:
|
|
378
|
+
continue
|
|
379
|
+
if in_channels == 128:
|
|
380
|
+
# It's FLUX.2 but we can't determine which Klein variant, default to 4B
|
|
381
|
+
return Flux2VariantType.Klein4B
|
|
382
|
+
|
|
383
|
+
return None
|
|
384
|
+
|
|
385
|
+
|
|
252
386
|
def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
|
|
253
387
|
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
|
|
254
388
|
|
|
@@ -322,8 +456,9 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
|
|
|
322
456
|
|
|
323
457
|
@classmethod
|
|
324
458
|
def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
|
|
459
|
+
state_dict = mod.load_state_dict()
|
|
325
460
|
if not state_dict_has_any_keys_exact(
|
|
326
|
-
|
|
461
|
+
state_dict,
|
|
327
462
|
{
|
|
328
463
|
"double_blocks.0.img_attn.norm.key_norm.scale",
|
|
329
464
|
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
|
@@ -331,6 +466,10 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
|
|
|
331
466
|
):
|
|
332
467
|
raise NotAMatchError("state dict does not look like a FLUX checkpoint")
|
|
333
468
|
|
|
469
|
+
# Exclude FLUX.2 models - they have their own config class
|
|
470
|
+
if _is_flux2_model(state_dict):
|
|
471
|
+
raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
|
|
472
|
+
|
|
334
473
|
@classmethod
|
|
335
474
|
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
|
|
336
475
|
# FLUX Model variant types are distinguished by input channels and the presence of certain keys.
|
|
@@ -364,6 +503,68 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
|
|
|
364
503
|
raise NotAMatchError("state dict looks like GGUF quantized")
|
|
365
504
|
|
|
366
505
|
|
|
506
|
+
class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for FLUX.2 checkpoint models (e.g. Klein)."""

    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)

    variant: Flux2VariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for ``mod`` if it passes every FLUX.2 checkpoint check.

        The checks run in order: file/override validation, main-model keys, FLUX.2
        detection, then rejection of bnb-nf4 and GGUF quantized state dicts.

        Args:
            mod: The on-disk model being probed.
            override_fields: Caller-supplied field values; a truthy ``variant`` entry
                takes precedence over auto-detection.

        Raises:
            NotAMatchError: from any of this class's validators when the model does
                not match. The ``raise_*`` helpers may raise as well.
        """
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_is_flux2(mod)

        cls._validate_does_not_look_like_bnb_quantized(mod)

        cls._validate_does_not_look_like_gguf_quantized(mod)

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        # Merge rather than pass ``variant`` as a second keyword: when the overrides
        # already contain "variant", ``cls(**override_fields, variant=...)`` raises
        # TypeError for the duplicated keyword argument.
        return cls(**{**override_fields, "variant": variant})

    @classmethod
    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
        """Validate that this is a FLUX.2 model, not FLUX.1."""
        state_dict = mod.load_state_dict()
        if not _is_flux2_model(state_dict):
            raise NotAMatchError("state dict does not look like a FLUX.2 model")

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
        """Return the variant detected from the state dict, raising if none is found."""
        state_dict = mod.load_state_dict()
        variant = _get_flux2_variant(state_dict)

        if variant is None:
            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError unless the state dict has main-model keys."""
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatchError("state dict does not look like a main model")

    @classmethod
    def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError if the state dict has bnb nf4 keys."""
        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
        if has_bnb_nf4_keys:
            raise NotAMatchError("state dict looks like bnb quantized nf4")

    @classmethod
    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError if the state dict contains GGML tensors."""
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if has_ggml_tensors:
            raise NotAMatchError("state dict looks like GGUF quantized")
|
|
566
|
+
|
|
567
|
+
|
|
367
568
|
class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
|
|
368
569
|
"""Model config for main checkpoint models."""
|
|
369
570
|
|
|
@@ -431,6 +632,8 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
|
|
|
431
632
|
|
|
432
633
|
cls._validate_looks_like_gguf_quantized(mod)
|
|
433
634
|
|
|
635
|
+
cls._validate_is_not_flux2(mod)
|
|
636
|
+
|
|
434
637
|
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
|
|
435
638
|
|
|
436
639
|
return cls(**override_fields, variant=variant)
|
|
@@ -461,6 +664,195 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
|
|
|
461
664
|
if not has_ggml_tensors:
|
|
462
665
|
raise NotAMatchError("state dict does not look like GGUF quantized")
|
|
463
666
|
|
|
667
|
+
@classmethod
def _validate_is_not_flux2(cls, mod: ModelOnDisk) -> None:
    """Reject models whose state dict is recognised as FLUX.2.

    Raises:
        NotAMatchError: if ``_is_flux2_model`` reports a FLUX.2 state dict.
    """
    looks_like_flux2 = _is_flux2_model(mod.load_state_dict())
    if not looks_like_flux2:
        return
    raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for GGUF-quantized FLUX.2 checkpoint models (e.g. Klein)."""

    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)

    variant: Flux2VariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for ``mod`` if it passes every GGUF FLUX.2 check.

        The checks run in order: file/override validation, main-model keys, presence
        of GGML tensors, then FLUX.2 detection.

        Args:
            mod: The on-disk model being probed.
            override_fields: Caller-supplied field values; a truthy ``variant`` entry
                takes precedence over auto-detection.

        Raises:
            NotAMatchError: from any of this class's validators when the model does
                not match. The ``raise_*`` helpers may raise as well.
        """
        raise_if_not_file(mod)

        raise_for_override_fields(cls, override_fields)

        cls._validate_looks_like_main_model(mod)

        cls._validate_looks_like_gguf_quantized(mod)

        cls._validate_is_flux2(mod)

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        # Merge rather than pass ``variant`` as a second keyword: when the overrides
        # already contain "variant", ``cls(**override_fields, variant=...)`` raises
        # TypeError for the duplicated keyword argument.
        return cls(**{**override_fields, "variant": variant})

    @classmethod
    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
        """Validate that this is a FLUX.2 model, not FLUX.1."""
        state_dict = mod.load_state_dict()
        if not _is_flux2_model(state_dict):
            raise NotAMatchError("state dict does not look like a FLUX.2 model")

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
        """Return the variant detected from the state dict, raising if none is found."""
        state_dict = mod.load_state_dict()
        variant = _get_flux2_variant(state_dict)

        if variant is None:
            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

        return variant

    @classmethod
    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError unless the state dict has main-model keys."""
        has_main_model_keys = _has_main_keys(mod.load_state_dict())
        if not has_main_model_keys:
            raise NotAMatchError("state dict does not look like a main model")

    @classmethod
    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError unless the state dict contains GGML tensors."""
        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
        if not has_ggml_tensors:
            raise NotAMatchError("state dict does not look like GGUF quantized")
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
    """Model config for FLUX.1 models in diffusers format."""

    base: Literal[BaseModelType.Flux] = Field(BaseModelType.Flux)
    variant: FluxVariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for ``mod`` if it is a FLUX.1 diffusers directory.

        Args:
            mod: The on-disk model being probed.
            override_fields: Caller-supplied field values; truthy ``variant`` and
                ``repo_variant`` entries take precedence over auto-detection.

        Raises:
            Whatever the ``raise_*`` helpers and ``_get_*_or_raise`` methods raise
            when the model does not match (presumably NotAMatchError - confirm
            against their definitions).
        """
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # Check for FLUX-specific pipeline or transformer class names
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "FluxPipeline",
                "FluxFillPipeline",
                "FluxTransformer2DModel",
            },
        )

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)

        # Merge rather than pass these as extra keywords: when the overrides already
        # contain "variant" or "repo_variant", passing them again next to
        # ``**override_fields`` raises TypeError for the duplicated keyword argument.
        return cls(**{**override_fields, "variant": variant, "repo_variant": repo_variant})

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
        """Determine the FLUX variant from the transformer config.

        FLUX variants are distinguished by:
        - in_channels: 64 for Dev/Schnell, 384 for DevFill
        - guidance_embeds: True for Dev, False for Schnell

        Missing config entries fall back to ``in_channels=64`` and
        ``guidance_embeds=False``, i.e. Schnell.
        """
        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")

        in_channels = transformer_config.get("in_channels", 64)
        guidance_embeds = transformer_config.get("guidance_embeds", False)

        # DevFill has 384 input channels
        if in_channels == 384:
            return FluxVariantType.DevFill

        # Dev has guidance_embeds=True, Schnell has guidance_embeds=False
        return FluxVariantType.Dev if guidance_embeds else FluxVariantType.Schnell
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
    """Model config for FLUX.2 models in diffusers format (e.g. FLUX.2 Klein)."""

    base: Literal[BaseModelType.Flux2] = Field(BaseModelType.Flux2)
    variant: Flux2VariantType = Field()

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for ``mod`` if it is a FLUX.2 diffusers directory.

        Args:
            mod: The on-disk model being probed.
            override_fields: Caller-supplied field values; truthy ``variant`` and
                ``repo_variant`` entries take precedence over auto-detection.

        Raises:
            Whatever the ``raise_*`` helpers and ``_get_repo_variant_or_raise`` raise
            when the model does not match (presumably NotAMatchError - confirm
            against their definitions).
        """
        raise_if_not_dir(mod)

        raise_for_override_fields(cls, override_fields)

        # Check for FLUX.2-specific pipeline class names
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "Flux2KleinPipeline",
            },
        )

        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)

        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)

        # Merge rather than pass these as extra keywords: when the overrides already
        # contain "variant" or "repo_variant", passing them again next to
        # ``**override_fields`` raises TypeError for the duplicated keyword argument.
        return cls(**{**override_fields, "variant": variant, "repo_variant": repo_variant})

    @classmethod
    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
        """Determine the FLUX.2 variant from the transformer config.

        FLUX.2 Klein uses a Qwen3 text encoder with a larger joint_attention_dim:
        - Klein 4B: joint_attention_dim = 7680 (3 x Qwen3-4B hidden size)
        - Klein 9B / 9B Base: joint_attention_dim = 12288 (3 x Qwen3-8B hidden size)

        Klein 9B (distilled) and Klein 9B Base (undistilled) are told apart by
        guidance_embeds:
        - Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
        - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)

        Every other joint_attention_dim value - including 7680, unknown larger
        values, and a missing entry (which falls back to 4096) - yields Klein4B.
        Despite the method's name, the only raising path visible here is the
        config read itself.
        """
        klein_9b_context_dim = 12288  # 3 x 4096

        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")

        joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
        guidance_embeds = transformer_config.get("guidance_embeds", False)

        if joint_attention_dim == klein_9b_context_dim:
            # guidance_embeds separates the undistilled base model from the distilled one.
            if guidance_embeds:
                return Flux2VariantType.Klein9BBase
            return Flux2VariantType.Klein9B

        # Klein 4B (7680), unknown FLUX.2 dimensions and the fallback all default to 4B.
        return Flux2VariantType.Klein4B
|
|
855
|
+
|
|
464
856
|
|
|
465
857
|
class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
|
|
466
858
|
prediction_type: SchedulerPredictionType = Field()
|