InvokeAI 6.10.0rc2__py3-none-any.whl → 6.11.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +50 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/z_image_denoise.py +8 -3
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +42 -0
- invokeai/backend/flux/denoise.py +239 -204
- invokeai/backend/flux/dype/__init__.py +18 -0
- invokeai/backend/flux/dype/base.py +226 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +141 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +261 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/load_default.py +0 -1
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/flux.py +1007 -2
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +0 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +121 -28
- invokeai/backend/model_manager/starter_models.py +128 -0
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/frontend/web/dist/assets/App-ClpIJstk.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-BP0RxJ4G.js → browser-ponyfill-Cw07u5G1.js} +1 -1
- invokeai/frontend/web/dist/assets/{index-B44qKjrs.js → index-DSKM8iGj.js} +69 -69
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en.json +58 -5
- invokeai/frontend/web/dist/locales/it.json +2 -1
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/METADATA +7 -1
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/RECORD +66 -49
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +0 -161
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/top_level.txt +0 -0
--- invokeai/backend/model_manager/configs/main.py (6.10.0rc2)
+++ invokeai/backend/model_manager/configs/main.py (6.11.0rc1)

@@ -23,6 +23,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
+    Flux2VariantType,
     FluxVariantType,
     ModelFormat,
     ModelType,
@@ -52,7 +53,11 @@ class MainModelDefaultSettings(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
     @classmethod
-    def from_base(cls, base: BaseModelType) -> Self | None:
+    def from_base(
+        cls,
+        base: BaseModelType,
+        variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
+    ) -> Self | None:
         match base:
             case BaseModelType.StableDiffusion1:
                 return cls(width=512, height=512)
@@ -62,6 +67,14 @@ class MainModelDefaultSettings(BaseModel):
                 return cls(width=1024, height=1024)
             case BaseModelType.ZImage:
                 return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
+            case BaseModelType.Flux2:
+                # Different defaults based on variant
+                if variant == Flux2VariantType.Klein9BBase:
+                    # Undistilled base model needs more steps
+                    return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
+                else:
+                    # Distilled models (Klein 4B, Klein 9B) use fewer steps
+                    return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
             case _:
                 # TODO(psyche): Do we want defaults for other base types?
                 return None
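
The variant-aware defaults above can be exercised directly. A minimal sketch, assuming this release is installed and using only names visible in the hunk (from_base may also return None for bases without defaults):

from invokeai.backend.model_manager.configs.main import MainModelDefaultSettings
from invokeai.backend.model_manager.taxonomy import BaseModelType, Flux2VariantType

# Undistilled Klein 9B Base gets the 28-step default...
defaults = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein9BBase)
assert defaults is not None and defaults.steps == 28
# ...while omitting the variant falls into the distilled 4-step branch.
defaults = MainModelDefaultSettings.from_base(BaseModelType.Flux2)
assert defaults is not None and defaults.steps == 4
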
@@ -114,7 +127,11 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
 
 
 def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
-    """Check if state dict contains Z-Image S3-DiT transformer keys."""
+    """Check if state dict contains Z-Image S3-DiT transformer keys.
+
+    This function returns True only for Z-Image main models, not LoRAs.
+    LoRAs are excluded by checking for LoRA-specific weight suffixes.
+    """
     # Z-Image specific keys that distinguish it from other models
     z_image_specific_keys = {
         "cap_embedder",  # Caption embedder - unique to Z-Image
@@ -122,9 +139,23 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
         "cap_pad_token",  # Caption padding token
     }
 
+    # LoRA-specific suffixes - if present, this is a LoRA not a main model
+    lora_suffixes = (
+        ".lora_down.weight",
+        ".lora_up.weight",
+        ".lora_A.weight",
+        ".lora_B.weight",
+        ".dora_scale",
+    )
+
     for key in state_dict.keys():
         if isinstance(key, int):
             continue
+
+        # If we find any LoRA-specific keys, this is not a main model
+        if key.endswith(lora_suffixes):
+            return False
+
         # Check for Z-Image specific key prefixes
         # Handle both direct keys (cap_embedder.0.weight) and
         # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
@@ -132,6 +163,7 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
         for part in key_parts:
             if part in z_image_specific_keys:
                 return True
+
     return False
 
 
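
The LoRA screen relies on str.endswith accepting a tuple of suffixes, so a single call tests a key against every pattern. A standalone sketch (the keys are hypothetical):

lora_suffixes = (".lora_down.weight", ".lora_up.weight", ".lora_A.weight", ".lora_B.weight", ".dora_scale")

assert "blocks.0.attn.qkv.lora_down.weight".endswith(lora_suffixes)  # LoRA key -> excluded
assert not "cap_embedder.0.weight".endswith(lora_suffixes)           # main-model key -> kept
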
@@ -249,6 +281,108 @@ class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
 
 
+def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict is a FLUX.2 model by examining context_embedder dimensions.
+
+    FLUX.2 Klein uses Qwen3 encoder with larger context dimension:
+    - FLUX.1: context_in_dim = 4096 (T5)
+    - FLUX.2 Klein 4B: context_in_dim = 7680 (3×Qwen3-4B hidden size)
+    - FLUX.2 Klein 9B: context_in_dim = 12288 (3×Qwen3-8B hidden size)
+
+    Also checks for FLUX.2-specific 32-channel latent space (in_channels=128 after packing).
+    """
+    # Check context_embedder input dimension (most reliable)
+    # Weight shape: [hidden_size, context_in_dim]
+    for key in {"context_embedder.weight", "model.diffusion_model.context_embedder.weight"}:
+        if key in state_dict:
+            weight = state_dict[key]
+            if hasattr(weight, "shape") and len(weight.shape) >= 2:
+                context_in_dim = weight.shape[1]
+                # FLUX.2 has context_in_dim > 4096 (Qwen3 vs T5)
+                if context_in_dim > 4096:
+                    return True
+
+    # Also check in_channels - FLUX.2 uses 128 (32 latent channels × 4 packing)
+    for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
+        if key in state_dict:
+            in_channels = state_dict[key].shape[1]
+            # FLUX.2 uses 128 in_channels (32 latent channels × 4)
+            # FLUX.1 uses 64 in_channels (16 latent channels × 4)
+            if in_channels == 128:
+                return True
+
+    return False
+
+
+def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
+    """Determine FLUX.2 variant from state dict.
+
+    Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
+    - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
+    - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
+
+    Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
+    We default to Klein9B (distilled) for all 9B models since GGUF models may not
+    include guidance embedding keys needed to distinguish them.
+
+    Supports both BFL format (checkpoint) and diffusers format keys:
+    - BFL format: txt_in.weight (context embedder)
+    - Diffusers format: context_embedder.weight
+    """
+    # Context dimensions for each variant
+    KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
+    KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096
+
+    # Check context_embedder to determine variant
+    # Support both BFL format (txt_in.weight) and diffusers format (context_embedder.weight)
+    context_keys = {
+        # Diffusers format
+        "context_embedder.weight",
+        "model.diffusion_model.context_embedder.weight",
+        # BFL format (used by checkpoint/GGUF models)
+        "txt_in.weight",
+        "model.diffusion_model.txt_in.weight",
+    }
+    for key in context_keys:
+        if key in state_dict:
+            weight = state_dict[key]
+            # Handle GGUF quantized tensors which use tensor_shape instead of shape
+            if hasattr(weight, "tensor_shape"):
+                shape = weight.tensor_shape
+            elif hasattr(weight, "shape"):
+                shape = weight.shape
+            else:
+                continue
+            if len(shape) >= 2:
+                context_in_dim = shape[1]
+                # Determine variant based on context dimension
+                if context_in_dim == KLEIN_9B_CONTEXT_DIM:
+                    # Default to Klein9B (distilled) - the official/common 9B model
+                    return Flux2VariantType.Klein9B
+                elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
+                    return Flux2VariantType.Klein4B
+                elif context_in_dim > 4096:
+                    # Unknown FLUX.2 variant, default to 4B
+                    return Flux2VariantType.Klein4B
+
+    # Check in_channels as backup - can only confirm it's FLUX.2, not which variant
+    for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
+        if key in state_dict:
+            weight = state_dict[key]
+            # Handle GGUF quantized tensors
+            if hasattr(weight, "tensor_shape"):
+                in_channels = weight.tensor_shape[1]
+            elif hasattr(weight, "shape"):
+                in_channels = weight.shape[1]
+            else:
+                continue
+            if in_channels == 128:
+                # It's FLUX.2 but we can't determine which Klein variant, default to 4B
+                return Flux2VariantType.Klein4B
+
+    return None
+
+
 def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
     # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
 
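
Both helpers only ever read tensor shapes, so the heuristic can be illustrated without loading a checkpoint. A runnable sketch with a stand-in tensor class (real state dicts hold torch or GGML tensors):

from dataclasses import dataclass

@dataclass
class FakeTensor:
    shape: tuple[int, ...]

def looks_like_flux2(state_dict: dict) -> bool:
    # Weight shape is [hidden_size, context_in_dim]; Qwen3-based FLUX.2 exceeds T5's 4096.
    w = state_dict.get("context_embedder.weight")
    return w is not None and len(w.shape) >= 2 and w.shape[1] > 4096

assert not looks_like_flux2({"context_embedder.weight": FakeTensor((3072, 4096))})  # T5 -> FLUX.1
assert looks_like_flux2({"context_embedder.weight": FakeTensor((3072, 12288))})     # 3x Qwen3-8B -> FLUX.2
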
@@ -322,8 +456,9 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
 
     @classmethod
     def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
+        state_dict = mod.load_state_dict()
         if not state_dict_has_any_keys_exact(
-            mod.load_state_dict(),
+            state_dict,
             {
                 "double_blocks.0.img_attn.norm.key_norm.scale",
                 "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -331,6 +466,10 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
         ):
             raise NotAMatchError("state dict does not look like a FLUX checkpoint")
 
+        # Exclude FLUX.2 models - they have their own config class
+        if _is_flux2_model(state_dict):
+            raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
+
     @classmethod
     def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
         # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
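
This exclusion matters because config classes are probed one after another and signal "not mine" by raising NotAMatchError; a FLUX.2 Klein checkpoint can also carry double_blocks.* keys, so without the check it could be claimed by the FLUX.1 config first. An illustrative sketch of such a probe loop (the real one lives in configs/factory.py and may differ in detail):

class NotAMatchError(Exception):
    """Raised by a config class to mean: not this type, try the next candidate."""

def classify(mod, candidate_config_classes):
    for config_cls in candidate_config_classes:
        try:
            return config_cls.from_model_on_disk(mod, {})
        except NotAMatchError:
            continue
    return None
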
@@ -364,6 +503,68 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
             raise NotAMatchError("state dict looks like GGUF quantized")
 
 
+class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for FLUX.2 checkpoint models (e.g. Klein)."""
+
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+
+    variant: Flux2VariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_main_model(mod)
+
+        cls._validate_is_flux2(mod)
+
+        cls._validate_does_not_look_like_bnb_quantized(mod)
+
+        cls._validate_does_not_look_like_gguf_quantized(mod)
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        return cls(**override_fields, variant=variant)
+
+    @classmethod
+    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is a FLUX.2 model, not FLUX.1."""
+        state_dict = mod.load_state_dict()
+        if not _is_flux2_model(state_dict):
+            raise NotAMatchError("state dict does not look like a FLUX.2 model")
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
+        state_dict = mod.load_state_dict()
+        variant = _get_flux2_variant(state_dict)
+
+        if variant is None:
+            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
+
+        return variant
+
+    @classmethod
+    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
+        has_main_model_keys = _has_main_keys(mod.load_state_dict())
+        if not has_main_model_keys:
+            raise NotAMatchError("state dict does not look like a main model")
+
+    @classmethod
+    def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
+        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
+        if has_bnb_nf4_keys:
+            raise NotAMatchError("state dict looks like bnb quantized nf4")
+
+    @classmethod
+    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if has_ggml_tensors:
+            raise NotAMatchError("state dict looks like GGUF quantized")
+
+
 class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
     """Model config for main checkpoint models."""
 
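
Note the `or` fallback in from_model_on_disk: a truthy user override wins, while an absent (or None) override drops through to on-disk detection. A two-line sketch with plain strings standing in for the enum:

detected = "Klein4B"  # stand-in for cls._get_variant_or_raise(mod)
assert ({}.get("variant") or detected) == "Klein4B"                      # no override -> detected value
assert ({"variant": "Klein9B"}.get("variant") or detected) == "Klein9B"  # override wins
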
@@ -431,6 +632,8 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
 
         cls._validate_looks_like_gguf_quantized(mod)
 
+        cls._validate_is_not_flux2(mod)
+
         variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
 
         return cls(**override_fields, variant=variant)
@@ -461,6 +664,195 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
         if not has_ggml_tensors:
             raise NotAMatchError("state dict does not look like GGUF quantized")
 
+    @classmethod
+    def _validate_is_not_flux2(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is NOT a FLUX.2 model."""
+        state_dict = mod.load_state_dict()
+        if _is_flux2_model(state_dict):
+            raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
+
+
+class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for GGUF-quantized FLUX.2 checkpoint models (e.g. Klein)."""
+
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    variant: Flux2VariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_main_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        cls._validate_is_flux2(mod)
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        return cls(**override_fields, variant=variant)
+
+    @classmethod
+    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is a FLUX.2 model, not FLUX.1."""
+        state_dict = mod.load_state_dict()
+        if not _is_flux2_model(state_dict):
+            raise NotAMatchError("state dict does not look like a FLUX.2 model")
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
+        state_dict = mod.load_state_dict()
+        variant = _get_flux2_variant(state_dict)
+
+        if variant is None:
+            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
+
+        return variant
+
+    @classmethod
+    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
+        has_main_model_keys = _has_main_keys(mod.load_state_dict())
+        if not has_main_model_keys:
+            raise NotAMatchError("state dict does not look like a main model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml_tensors:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
+
+
+class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for FLUX.1 models in diffusers format."""
+
+    base: Literal[BaseModelType.Flux] = Field(BaseModelType.Flux)
+    variant: FluxVariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # Check for FLUX-specific pipeline or transformer class names
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "FluxPipeline",
+                "FluxFillPipeline",
+                "FluxTransformer2DModel",
+            },
+        )
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+
+        return cls(
+            **override_fields,
+            variant=variant,
+            repo_variant=repo_variant,
+        )
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
+        """Determine the FLUX variant from the transformer config.
+
+        FLUX variants are distinguished by:
+        - in_channels: 64 for Dev/Schnell, 384 for DevFill
+        - guidance_embeds: True for Dev, False for Schnell
+        """
+        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
+
+        in_channels = transformer_config.get("in_channels", 64)
+        guidance_embeds = transformer_config.get("guidance_embeds", False)
+
+        # DevFill has 384 input channels
+        if in_channels == 384:
+            return FluxVariantType.DevFill
+
+        # Dev has guidance_embeds=True, Schnell has guidance_embeds=False
+        if guidance_embeds:
+            return FluxVariantType.Dev
+        else:
+            return FluxVariantType.Schnell
+
+
+class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for FLUX.2 models in diffusers format (e.g. FLUX.2 Klein)."""
+
+    base: Literal[BaseModelType.Flux2] = Field(BaseModelType.Flux2)
+    variant: Flux2VariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # Check for FLUX.2-specific pipeline class names
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "Flux2KleinPipeline",
+            },
+        )
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+
+        return cls(
+            **override_fields,
+            variant=variant,
+            repo_variant=repo_variant,
+        )
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
+        """Determine the FLUX.2 variant from the transformer config.
+
+        FLUX.2 Klein uses Qwen3 text encoder with larger joint_attention_dim:
+        - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
+        - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
+
+        To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
+        we check guidance_embeds:
+        - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
+        - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
+
+        Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
+        """
+        KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
+        KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096
+
+        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
+
+        joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
+        guidance_embeds = transformer_config.get("guidance_embeds", False)
+
+        # Determine variant based on joint_attention_dim
+        if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
+            # Check guidance_embeds to distinguish distilled from undistilled
+            # Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
+            # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
+            if guidance_embeds:
+                return Flux2VariantType.Klein9BBase
+            else:
+                return Flux2VariantType.Klein9B
+        elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
+            return Flux2VariantType.Klein4B
+        elif joint_attention_dim > 4096:
+            # Unknown FLUX.2 variant, default to 4B
+            return Flux2VariantType.Klein4B
+
+        # Default to 4B
+        return Flux2VariantType.Klein4B
+
 
 class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
     prediction_type: SchedulerPredictionType = Field()
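
For diffusers-format models the same decision table operates on the plain dict read from transformer/config.json. A runnable sketch, with strings standing in for the Flux2VariantType members:

def flux2_variant(cfg: dict) -> str:
    joint_attention_dim = cfg.get("joint_attention_dim", 4096)
    if joint_attention_dim == 12288:  # 3x Qwen3-8B
        # guidance_embeds=True marks the undistilled Klein 9B Base
        return "Klein9BBase" if cfg.get("guidance_embeds", False) else "Klein9B"
    # 7680 (3x Qwen3-4B), unknown dims above 4096, and the final default all resolve to 4B
    return "Klein4B"

assert flux2_variant({"joint_attention_dim": 12288, "guidance_embeds": True}) == "Klein9BBase"
assert flux2_variant({"joint_attention_dim": 12288}) == "Klein9B"
assert flux2_variant({"joint_attention_dim": 7680}) == "Klein4B"
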
--- invokeai/backend/model_manager/configs/qwen3_encoder.py (6.10.0rc2)
+++ invokeai/backend/model_manager/configs/qwen3_encoder.py (6.11.0rc1)

@@ -1,4 +1,5 @@
-from typing import Any, Literal, Self
+import json
+from typing import Any, Literal, Optional, Self
 
 from pydantic import Field
 
@@ -11,7 +12,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
     raise_if_not_file,
 )
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
-from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, Qwen3VariantType
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
@@ -45,12 +46,67 @@ def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
     return any(isinstance(v, GGMLTensor) for v in state_dict.values())
 
 
+def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
+    """Determine Qwen3 variant (4B vs 8B) from state dict based on hidden_size.
+
+    The hidden_size can be determined from the embed_tokens.weight tensor shape:
+    - Qwen3 4B: hidden_size = 2560
+    - Qwen3 8B: hidden_size = 4096
+
+    For GGUF format, the key is 'token_embd.weight'.
+    For PyTorch format, the key is 'model.embed_tokens.weight'.
+    """
+    # Hidden size thresholds
+    QWEN3_4B_HIDDEN_SIZE = 2560
+    QWEN3_8B_HIDDEN_SIZE = 4096
+
+    # Try to find embed_tokens weight
+    embed_key = None
+    for key in state_dict.keys():
+        if isinstance(key, str):
+            if key == "model.embed_tokens.weight" or key == "token_embd.weight":
+                embed_key = key
+                break
+
+    if embed_key is None:
+        return None
+
+    tensor = state_dict[embed_key]
+
+    # Get hidden_size from tensor shape
+    # Shape is [vocab_size, hidden_size]
+    if isinstance(tensor, GGMLTensor):
+        # GGUF tensor
+        if hasattr(tensor, "shape") and len(tensor.shape) >= 2:
+            hidden_size = tensor.shape[1]
+        else:
+            return None
+    elif hasattr(tensor, "shape"):
+        # PyTorch tensor
+        if len(tensor.shape) >= 2:
+            hidden_size = tensor.shape[1]
+        else:
+            return None
+    else:
+        return None
+
+    # Determine variant based on hidden_size
+    if hidden_size == QWEN3_4B_HIDDEN_SIZE:
+        return Qwen3VariantType.Qwen3_4B
+    elif hidden_size == QWEN3_8B_HIDDEN_SIZE:
+        return Qwen3VariantType.Qwen3_8B
+    else:
+        # Unknown size, default to 4B (more common)
+        return Qwen3VariantType.Qwen3_4B
+
+
 class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
     """Configuration for single-file Qwen3 Encoder models (safetensors)."""
 
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
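
As with the FLUX.2 probe, only the embedding shape matters here; the first axis (vocabulary size) is ignored. A sketch of the decision, with illustrative shapes and strings standing in for the enum:

def qwen3_variant(embed_shape: tuple[int, ...]) -> str:
    hidden_size = embed_shape[1]  # shape is [vocab_size, hidden_size]
    return "Qwen3_8B" if hidden_size == 4096 else "Qwen3_4B"  # unknown sizes default to 4B

assert qwen3_variant((151936, 2560)) == "Qwen3_4B"  # vocab size is illustrative
assert qwen3_variant((151936, 4096)) == "Qwen3_8B"
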
@@ -62,7 +118,17 @@ class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
 
         cls._validate_does_not_look_like_gguf_quantized(mod)
 
-        return cls(**override_fields)
+        # Determine variant from state dict
+        variant = cls._get_variant_or_default(mod)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
+        """Get variant from state dict, defaulting to 4B if unknown."""
+        state_dict = mod.load_state_dict()
+        variant = _get_qwen3_variant_from_state_dict(state_dict)
+        return variant if variant is not None else Qwen3VariantType.Qwen3_4B
 
     @classmethod
     def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
@@ -87,6 +153,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -94,6 +161,16 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
 
         raise_for_override_fields(cls, override_fields)
 
+        # Exclude full pipeline models - these should be matched as main models, not just Qwen3 encoders.
+        # Full pipelines have model_index.json at root (diffusers format) or a transformer subfolder.
+        model_index_path = mod.path / "model_index.json"
+        transformer_path = mod.path / "transformer"
+        if model_index_path.exists() or transformer_path.exists():
+            raise NotAMatchError(
+                "directory looks like a full diffusers pipeline (has model_index.json or transformer folder), "
+                "not a standalone Qwen3 encoder"
+            )
+
         # Check for text_encoder config - support both:
         # 1. Full model structure: model_root/text_encoder/config.json
         # 2. Standalone text_encoder download: model_root/config.json (when text_encoder subfolder is downloaded separately)
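
The pipeline exclusion reduces to two pathlib existence checks; a self-contained sketch:

import tempfile
from pathlib import Path

def looks_like_full_pipeline(root: Path) -> bool:
    # diffusers pipelines keep model_index.json at the root and/or a transformer/ subfolder
    return (root / "model_index.json").exists() or (root / "transformer").exists()

with tempfile.TemporaryDirectory() as d:
    assert not looks_like_full_pipeline(Path(d))
    (Path(d) / "model_index.json").touch()
    assert looks_like_full_pipeline(Path(d))
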
@@ -105,8 +182,6 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
         elif config_path_direct.exists():
             expected_config_path = config_path_direct
         else:
-            from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
-
             raise NotAMatchError(
                 f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
             )
@@ -121,7 +196,30 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
             },
         )
 
-        return cls(**override_fields)
+        # Determine variant from config.json hidden_size
+        variant = cls._get_variant_from_config(expected_config_path)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_from_config(cls, config_path) -> Qwen3VariantType:
+        """Get variant from config.json based on hidden_size."""
+        QWEN3_4B_HIDDEN_SIZE = 2560
+        QWEN3_8B_HIDDEN_SIZE = 4096
+
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                config = json.load(f)
+            hidden_size = config.get("hidden_size")
+            if hidden_size == QWEN3_8B_HIDDEN_SIZE:
+                return Qwen3VariantType.Qwen3_8B
+            elif hidden_size == QWEN3_4B_HIDDEN_SIZE:
+                return Qwen3VariantType.Qwen3_4B
+            else:
+                # Default to 4B for unknown sizes
+                return Qwen3VariantType.Qwen3_4B
+        except (json.JSONDecodeError, OSError):
+            return Qwen3VariantType.Qwen3_4B
 
 
 class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
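
_get_variant_from_config degrades to the 4B default on any read or parse failure; since FileNotFoundError subclasses OSError, a missing file is covered too. A runnable sketch of that fallback, again with strings standing in for the enum:

import json
import tempfile
from pathlib import Path

def variant_from_config(config_path: Path) -> str:
    try:
        hidden_size = json.loads(config_path.read_text(encoding="utf-8")).get("hidden_size")
        return "Qwen3_8B" if hidden_size == 4096 else "Qwen3_4B"
    except (json.JSONDecodeError, OSError):
        return "Qwen3_4B"

with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "config.json"
    path.write_text(json.dumps({"hidden_size": 4096}), encoding="utf-8")
    assert variant_from_config(path) == "Qwen3_8B"
    assert variant_from_config(Path(d) / "missing.json") == "Qwen3_4B"
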
@@ -130,6 +228,7 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -141,7 +240,17 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
 
         cls._validate_looks_like_gguf_quantized(mod)
 
-        return cls(**override_fields)
+        # Determine variant from state dict
+        variant = cls._get_variant_or_default(mod)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
+        """Get variant from state dict, defaulting to 4B if unknown."""
+        state_dict = mod.load_state_dict()
+        variant = _get_qwen3_variant_from_state_dict(state_dict)
+        return variant if variant is not None else Qwen3VariantType.Qwen3_4B
 
     @classmethod
     def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None: