InvokeAI 6.10.0rc2 → 6.11.0rc1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +50 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/ideal_size.py +6 -1
  11. invokeai/app/invocations/metadata.py +4 -0
  12. invokeai/app/invocations/metadata_linked.py +47 -0
  13. invokeai/app/invocations/model.py +1 -0
  14. invokeai/app/invocations/z_image_denoise.py +8 -3
  15. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  16. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  17. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  18. invokeai/app/services/config/config_default.py +3 -1
  19. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  20. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  21. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  22. invokeai/app/services/model_records/model_records_base.py +4 -2
  23. invokeai/app/services/shared/invocation_context.py +15 -0
  24. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  25. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  26. invokeai/app/util/step_callback.py +42 -0
  27. invokeai/backend/flux/denoise.py +239 -204
  28. invokeai/backend/flux/dype/__init__.py +18 -0
  29. invokeai/backend/flux/dype/base.py +226 -0
  30. invokeai/backend/flux/dype/embed.py +116 -0
  31. invokeai/backend/flux/dype/presets.py +141 -0
  32. invokeai/backend/flux/dype/rope.py +110 -0
  33. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  34. invokeai/backend/flux/util.py +35 -1
  35. invokeai/backend/flux2/__init__.py +4 -0
  36. invokeai/backend/flux2/denoise.py +261 -0
  37. invokeai/backend/flux2/ref_image_extension.py +294 -0
  38. invokeai/backend/flux2/sampling_utils.py +209 -0
  39. invokeai/backend/model_manager/configs/factory.py +19 -1
  40. invokeai/backend/model_manager/configs/main.py +395 -3
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  42. invokeai/backend/model_manager/configs/vae.py +104 -2
  43. invokeai/backend/model_manager/load/load_default.py +0 -1
  44. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  45. invokeai/backend/model_manager/load/model_loaders/flux.py +1007 -2
  46. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +0 -1
  47. invokeai/backend/model_manager/load/model_loaders/z_image.py +121 -28
  48. invokeai/backend/model_manager/starter_models.py +128 -0
  49. invokeai/backend/model_manager/taxonomy.py +31 -4
  50. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  51. invokeai/backend/util/vae_working_memory.py +0 -2
  52. invokeai/frontend/web/dist/assets/App-ClpIJstk.js +161 -0
  53. invokeai/frontend/web/dist/assets/{browser-ponyfill-BP0RxJ4G.js → browser-ponyfill-Cw07u5G1.js} +1 -1
  54. invokeai/frontend/web/dist/assets/{index-B44qKjrs.js → index-DSKM8iGj.js} +69 -69
  55. invokeai/frontend/web/dist/index.html +1 -1
  56. invokeai/frontend/web/dist/locales/en.json +58 -5
  57. invokeai/frontend/web/dist/locales/it.json +2 -1
  58. invokeai/version/invokeai_version.py +1 -1
  59. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/METADATA +7 -1
  60. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/RECORD +66 -49
  61. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/WHEEL +1 -1
  62. invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +0 -161
  63. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/entry_points.txt +0 -0
  64. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE +0 -0
  65. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  66. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  67. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/model_manager/configs/main.py

@@ -23,6 +23,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.taxonomy import (
     BaseModelType,
+    Flux2VariantType,
     FluxVariantType,
     ModelFormat,
     ModelType,
@@ -52,7 +53,11 @@ class MainModelDefaultSettings(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
     @classmethod
-    def from_base(cls, base: BaseModelType) -> Self | None:
+    def from_base(
+        cls,
+        base: BaseModelType,
+        variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
+    ) -> Self | None:
         match base:
             case BaseModelType.StableDiffusion1:
                 return cls(width=512, height=512)
@@ -62,6 +67,14 @@ class MainModelDefaultSettings(BaseModel):
                 return cls(width=1024, height=1024)
             case BaseModelType.ZImage:
                 return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
+            case BaseModelType.Flux2:
+                # Different defaults based on variant
+                if variant == Flux2VariantType.Klein9BBase:
+                    # Undistilled base model needs more steps
+                    return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
+                else:
+                    # Distilled models (Klein 4B, Klein 9B) use fewer steps
+                    return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
             case _:
                 # TODO(psyche): Do we want defaults for other base types?
                 return None
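
A quick sketch (not part of the diff) of how the new variant-aware defaults behave, assuming the enum members and settings fields shown above:

    # Distilled Klein checkpoints get few-step defaults...
    s = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein4B)
    assert s is not None and s.steps == 4 and s.cfg_scale == 1.0

    # ...while the undistilled Klein 9B Base gets a full 28-step schedule.
    s = MainModelDefaultSettings.from_base(BaseModelType.Flux2, Flux2VariantType.Klein9BBase)
    assert s is not None and s.steps == 28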
@@ -114,7 +127,11 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
 
 
 def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
-    """Check if state dict contains Z-Image S3-DiT transformer keys."""
+    """Check if state dict contains Z-Image S3-DiT transformer keys.
+
+    This function returns True only for Z-Image main models, not LoRAs.
+    LoRAs are excluded by checking for LoRA-specific weight suffixes.
+    """
     # Z-Image specific keys that distinguish it from other models
     z_image_specific_keys = {
         "cap_embedder",  # Caption embedder - unique to Z-Image
@@ -122,9 +139,23 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
         "cap_pad_token",  # Caption padding token
     }
 
+    # LoRA-specific suffixes - if present, this is a LoRA not a main model
+    lora_suffixes = (
+        ".lora_down.weight",
+        ".lora_up.weight",
+        ".lora_A.weight",
+        ".lora_B.weight",
+        ".dora_scale",
+    )
+
     for key in state_dict.keys():
         if isinstance(key, int):
             continue
+
+        # If we find any LoRA-specific keys, this is not a main model
+        if key.endswith(lora_suffixes):
+            return False
+
         # Check for Z-Image specific key prefixes
         # Handle both direct keys (cap_embedder.0.weight) and
         # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
@@ -132,6 +163,7 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
         for part in key_parts:
             if part in z_image_specific_keys:
                 return True
+
     return False
 
 
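The LoRA exclusion leans on str.endswith accepting a tuple of suffixes, so one call screens all five patterns before the prefix scan runs. A minimal sketch (illustrative key names, not from the package; values are never read):

    # Z-Image main-model key -> matched
    main_sd = {"cap_embedder.0.weight": None}
    assert _has_z_image_keys(main_sd) is True

    # Same prefixes but LoRA suffixes -> rejected per key, before the prefix check
    lora_sd = {"cap_embedder.0.lora_A.weight": None, "cap_embedder.0.lora_B.weight": None}
    assert _has_z_image_keys(lora_sd) is False
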
@@ -249,6 +281,108 @@ class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_
     base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
 
 
+def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict is a FLUX.2 model by examining context_embedder dimensions.
+
+    FLUX.2 Klein uses a Qwen3 encoder with a larger context dimension:
+    - FLUX.1: context_in_dim = 4096 (T5)
+    - FLUX.2 Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden size)
+    - FLUX.2 Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden size)
+
+    Also checks for the FLUX.2-specific 32-channel latent space (in_channels=128 after packing).
+    """
+    # Check context_embedder input dimension (most reliable)
+    # Weight shape: [hidden_size, context_in_dim]
+    for key in {"context_embedder.weight", "model.diffusion_model.context_embedder.weight"}:
+        if key in state_dict:
+            weight = state_dict[key]
+            if hasattr(weight, "shape") and len(weight.shape) >= 2:
+                context_in_dim = weight.shape[1]
+                # FLUX.2 has context_in_dim > 4096 (Qwen3 vs T5)
+                if context_in_dim > 4096:
+                    return True
+
+    # Also check in_channels - FLUX.2 uses 128 (32 latent channels × 4 packing)
+    for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
+        if key in state_dict:
+            in_channels = state_dict[key].shape[1]
+            # FLUX.2 uses 128 in_channels (32 latent channels × 4)
+            # FLUX.1 uses 64 in_channels (16 latent channels × 4)
+            if in_channels == 128:
+                return True
+
+    return False
+
+
+def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
+    """Determine FLUX.2 variant from state dict.
+
+    Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
+    - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
+    - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
+
+    Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
+    We default to Klein9B (distilled) for all 9B models since GGUF models may not
+    include the guidance embedding keys needed to distinguish them.
+
+    Supports both BFL format (checkpoint) and diffusers format keys:
+    - BFL format: txt_in.weight (context embedder)
+    - Diffusers format: context_embedder.weight
+    """
+    # Context dimensions for each variant
+    KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
+    KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096
+
+    # Check context_embedder to determine variant
+    # Support both BFL format (txt_in.weight) and diffusers format (context_embedder.weight)
+    context_keys = {
+        # Diffusers format
+        "context_embedder.weight",
+        "model.diffusion_model.context_embedder.weight",
+        # BFL format (used by checkpoint/GGUF models)
+        "txt_in.weight",
+        "model.diffusion_model.txt_in.weight",
+    }
+    for key in context_keys:
+        if key in state_dict:
+            weight = state_dict[key]
+            # Handle GGUF quantized tensors which use tensor_shape instead of shape
+            if hasattr(weight, "tensor_shape"):
+                shape = weight.tensor_shape
+            elif hasattr(weight, "shape"):
+                shape = weight.shape
+            else:
+                continue
+            if len(shape) >= 2:
+                context_in_dim = shape[1]
+                # Determine variant based on context dimension
+                if context_in_dim == KLEIN_9B_CONTEXT_DIM:
+                    # Default to Klein9B (distilled) - the official/common 9B model
+                    return Flux2VariantType.Klein9B
+                elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
+                    return Flux2VariantType.Klein4B
+                elif context_in_dim > 4096:
+                    # Unknown FLUX.2 variant, default to 4B
+                    return Flux2VariantType.Klein4B
+
+    # Check in_channels as backup - can only confirm it's FLUX.2, not which variant
+    for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
+        if key in state_dict:
+            weight = state_dict[key]
+            # Handle GGUF quantized tensors
+            if hasattr(weight, "tensor_shape"):
+                in_channels = weight.tensor_shape[1]
+            elif hasattr(weight, "shape"):
+                in_channels = weight.shape[1]
+            else:
+                continue
+            if in_channels == 128:
+                # It's FLUX.2 but we can't determine which Klein variant, default to 4B
+                return Flux2VariantType.Klein4B
+
+    return None
+
+
 def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
     # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
 
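A minimal sketch (not part of the diff) of the shape-based routing, assuming torch is available; the first tensor dimensions are illustrative, only the second dimension drives the checks:

    import torch

    # FLUX.1-style context embedder: second dim 4096 (T5) -> not FLUX.2
    assert _is_flux2_model({"context_embedder.weight": torch.empty(3072, 4096)}) is False

    # BFL-format Klein 9B: txt_in.weight with second dim 12288 -> Klein9B
    assert _get_flux2_variant({"txt_in.weight": torch.empty(3072, 12288)}) == Flux2VariantType.Klein9B
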
@@ -322,8 +456,9 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
 
     @classmethod
     def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
+        state_dict = mod.load_state_dict()
         if not state_dict_has_any_keys_exact(
-            mod.load_state_dict(),
+            state_dict,
             {
                 "double_blocks.0.img_attn.norm.key_norm.scale",
                 "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -331,6 +466,10 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
         ):
             raise NotAMatchError("state dict does not look like a FLUX checkpoint")
 
+        # Exclude FLUX.2 models - they have their own config class
+        if _is_flux2_model(state_dict):
+            raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
+
     @classmethod
     def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
         # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
@@ -364,6 +503,68 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
             raise NotAMatchError("state dict looks like GGUF quantized")
 
 
+class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for FLUX.2 checkpoint models (e.g. Klein)."""
+
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+
+    variant: Flux2VariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_main_model(mod)
+
+        cls._validate_is_flux2(mod)
+
+        cls._validate_does_not_look_like_bnb_quantized(mod)
+
+        cls._validate_does_not_look_like_gguf_quantized(mod)
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        return cls(**override_fields, variant=variant)
+
+    @classmethod
+    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is a FLUX.2 model, not FLUX.1."""
+        state_dict = mod.load_state_dict()
+        if not _is_flux2_model(state_dict):
+            raise NotAMatchError("state dict does not look like a FLUX.2 model")
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
+        state_dict = mod.load_state_dict()
+        variant = _get_flux2_variant(state_dict)
+
+        if variant is None:
+            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
+
+        return variant
+
+    @classmethod
+    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
+        has_main_model_keys = _has_main_keys(mod.load_state_dict())
+        if not has_main_model_keys:
+            raise NotAMatchError("state dict does not look like a main model")
+
+    @classmethod
+    def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
+        has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
+        if has_bnb_nf4_keys:
+            raise NotAMatchError("state dict looks like bnb quantized nf4")
+
+    @classmethod
+    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if has_ggml_tensors:
+            raise NotAMatchError("state dict looks like GGUF quantized")
+
+
 class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
     """Model config for main checkpoint models."""
 
@@ -431,6 +632,8 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
 
         cls._validate_looks_like_gguf_quantized(mod)
 
+        cls._validate_is_not_flux2(mod)
+
        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
 
        return cls(**override_fields, variant=variant)
@@ -461,6 +664,195 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
         if not has_ggml_tensors:
             raise NotAMatchError("state dict does not look like GGUF quantized")
 
+    @classmethod
+    def _validate_is_not_flux2(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is NOT a FLUX.2 model."""
+        state_dict = mod.load_state_dict()
+        if _is_flux2_model(state_dict):
+            raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
+
+
+class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for GGUF-quantized FLUX.2 checkpoint models (e.g. Klein)."""
+
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    variant: Flux2VariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_main_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        cls._validate_is_flux2(mod)
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        return cls(**override_fields, variant=variant)
+
+    @classmethod
+    def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is a FLUX.2 model, not FLUX.1."""
+        state_dict = mod.load_state_dict()
+        if not _is_flux2_model(state_dict):
+            raise NotAMatchError("state dict does not look like a FLUX.2 model")
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
+        state_dict = mod.load_state_dict()
+        variant = _get_flux2_variant(state_dict)
+
+        if variant is None:
+            raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
+
+        return variant
+
+    @classmethod
+    def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
+        has_main_model_keys = _has_main_keys(mod.load_state_dict())
+        if not has_main_model_keys:
+            raise NotAMatchError("state dict does not look like a main model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml_tensors:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
+
+
+class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for FLUX.1 models in diffusers format."""
+
+    base: Literal[BaseModelType.Flux] = Field(BaseModelType.Flux)
+    variant: FluxVariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # Check for FLUX-specific pipeline or transformer class names
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "FluxPipeline",
+                "FluxFillPipeline",
+                "FluxTransformer2DModel",
+            },
+        )
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+
+        return cls(
+            **override_fields,
+            variant=variant,
+            repo_variant=repo_variant,
+        )
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
+        """Determine the FLUX variant from the transformer config.
+
+        FLUX variants are distinguished by:
+        - in_channels: 64 for Dev/Schnell, 384 for DevFill
+        - guidance_embeds: True for Dev, False for Schnell
+        """
+        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
+
+        in_channels = transformer_config.get("in_channels", 64)
+        guidance_embeds = transformer_config.get("guidance_embeds", False)
+
+        # DevFill has 384 input channels
+        if in_channels == 384:
+            return FluxVariantType.DevFill
+
+        # Dev has guidance_embeds=True, Schnell has guidance_embeds=False
+        if guidance_embeds:
+            return FluxVariantType.Dev
+        else:
+            return FluxVariantType.Schnell
+
+
+class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for FLUX.2 models in diffusers format (e.g. FLUX.2 Klein)."""
+
+    base: Literal[BaseModelType.Flux2] = Field(BaseModelType.Flux2)
+    variant: Flux2VariantType = Field()
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # Check for FLUX.2-specific pipeline class names
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "Flux2KleinPipeline",
+            },
+        )
+
+        variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
+
+        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+
+        return cls(
+            **override_fields,
+            variant=variant,
+            repo_variant=repo_variant,
+        )
+
+    @classmethod
+    def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
+        """Determine the FLUX.2 variant from the transformer config.
+
+        FLUX.2 Klein uses a Qwen3 text encoder with a larger joint_attention_dim:
+        - Klein 4B: joint_attention_dim = 7680 (3 × Qwen3-4B hidden size)
+        - Klein 9B/9B Base: joint_attention_dim = 12288 (3 × Qwen3-8B hidden size)
+
+        To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
+        we check guidance_embeds:
+        - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
+        - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
+
+        Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
+        """
+        KLEIN_4B_CONTEXT_DIM = 7680  # 3 × 2560
+        KLEIN_9B_CONTEXT_DIM = 12288  # 3 × 4096
+
+        transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
+
+        joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
+        guidance_embeds = transformer_config.get("guidance_embeds", False)
+
+        # Determine variant based on joint_attention_dim
+        if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
+            # Check guidance_embeds to distinguish distilled from undistilled
+            # Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
+            # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
+            if guidance_embeds:
+                return Flux2VariantType.Klein9BBase
+            else:
+                return Flux2VariantType.Klein9B
+        elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
+            return Flux2VariantType.Klein4B
+        elif joint_attention_dim > 4096:
+            # Unknown FLUX.2 variant, default to 4B
+            return Flux2VariantType.Klein4B
+
+        # Default to 4B
+        return Flux2VariantType.Klein4B
 
 
 class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
     prediction_type: SchedulerPredictionType = Field()
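
To see the decision table at a glance, here is a self-contained sketch that mirrors (but does not import) the diffusers-format logic above; the config values are hypothetical:

    def flux2_variant_from_config(cfg: dict) -> str:
        # joint_attention_dim picks the family; guidance_embeds splits the 9B family
        jad = cfg.get("joint_attention_dim", 4096)
        if jad == 12288:
            return "Klein9BBase" if cfg.get("guidance_embeds", False) else "Klein9B"
        return "Klein4B"  # 7680, plus the catch-all default for anything unrecognized

    assert flux2_variant_from_config({"joint_attention_dim": 12288, "guidance_embeds": True}) == "Klein9BBase"
    assert flux2_variant_from_config({"joint_attention_dim": 7680}) == "Klein4B"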
invokeai/backend/model_manager/configs/qwen3_encoder.py

@@ -1,4 +1,5 @@
-from typing import Any, Literal, Self
+import json
+from typing import Any, Literal, Optional, Self
 
 from pydantic import Field
 
@@ -11,7 +12,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
     raise_if_not_file,
 )
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
-from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, Qwen3VariantType
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
@@ -45,12 +46,67 @@ def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
     return any(isinstance(v, GGMLTensor) for v in state_dict.values())
 
 
+def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
+    """Determine Qwen3 variant (4B vs 8B) from state dict based on hidden_size.
+
+    The hidden_size can be determined from the embed_tokens.weight tensor shape:
+    - Qwen3 4B: hidden_size = 2560
+    - Qwen3 8B: hidden_size = 4096
+
+    For GGUF format, the key is 'token_embd.weight'.
+    For PyTorch format, the key is 'model.embed_tokens.weight'.
+    """
+    # Hidden size thresholds
+    QWEN3_4B_HIDDEN_SIZE = 2560
+    QWEN3_8B_HIDDEN_SIZE = 4096
+
+    # Try to find embed_tokens weight
+    embed_key = None
+    for key in state_dict.keys():
+        if isinstance(key, str):
+            if key == "model.embed_tokens.weight" or key == "token_embd.weight":
+                embed_key = key
+                break
+
+    if embed_key is None:
+        return None
+
+    tensor = state_dict[embed_key]
+
+    # Get hidden_size from tensor shape
+    # Shape is [vocab_size, hidden_size]
+    if isinstance(tensor, GGMLTensor):
+        # GGUF tensor
+        if hasattr(tensor, "shape") and len(tensor.shape) >= 2:
+            hidden_size = tensor.shape[1]
+        else:
+            return None
+    elif hasattr(tensor, "shape"):
+        # PyTorch tensor
+        if len(tensor.shape) >= 2:
+            hidden_size = tensor.shape[1]
+        else:
+            return None
+    else:
+        return None
+
+    # Determine variant based on hidden_size
+    if hidden_size == QWEN3_4B_HIDDEN_SIZE:
+        return Qwen3VariantType.Qwen3_4B
+    elif hidden_size == QWEN3_8B_HIDDEN_SIZE:
+        return Qwen3VariantType.Qwen3_8B
+    else:
+        # Unknown size, default to 4B (more common)
+        return Qwen3VariantType.Qwen3_4B
+
+
 class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
     """Configuration for single-file Qwen3 Encoder models (safetensors)."""
 
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
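
A minimal sketch (illustrative, not from the package) of the embed-tokens heuristic, assuming torch; the vocab size is hypothetical, only the second dimension matters:

    import torch

    sd = {"model.embed_tokens.weight": torch.empty(151936, 2560)}
    assert _get_qwen3_variant_from_state_dict(sd) == Qwen3VariantType.Qwen3_4B

    sd = {"token_embd.weight": torch.empty(151936, 4096)}  # GGUF-style key name
    assert _get_qwen3_variant_from_state_dict(sd) == Qwen3VariantType.Qwen3_8B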
@@ -62,7 +118,17 @@ class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
 
         cls._validate_does_not_look_like_gguf_quantized(mod)
 
-        return cls(**override_fields)
+        # Determine variant from state dict
+        variant = cls._get_variant_or_default(mod)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
+        """Get variant from state dict, defaulting to 4B if unknown."""
+        state_dict = mod.load_state_dict()
+        variant = _get_qwen3_variant_from_state_dict(state_dict)
+        return variant if variant is not None else Qwen3VariantType.Qwen3_4B
 
     @classmethod
     def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
@@ -87,6 +153,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -94,6 +161,16 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
 
         raise_for_override_fields(cls, override_fields)
 
+        # Exclude full pipeline models - these should be matched as main models, not just Qwen3 encoders.
+        # Full pipelines have model_index.json at root (diffusers format) or a transformer subfolder.
+        model_index_path = mod.path / "model_index.json"
+        transformer_path = mod.path / "transformer"
+        if model_index_path.exists() or transformer_path.exists():
+            raise NotAMatchError(
+                "directory looks like a full diffusers pipeline (has model_index.json or transformer folder), "
+                "not a standalone Qwen3 encoder"
+            )
+
         # Check for text_encoder config - support both:
         # 1. Full model structure: model_root/text_encoder/config.json
         # 2. Standalone text_encoder download: model_root/config.json (when text_encoder subfolder is downloaded separately)
@@ -105,8 +182,6 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
         elif config_path_direct.exists():
             expected_config_path = config_path_direct
         else:
-            from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
-
             raise NotAMatchError(
                 f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
             )
@@ -121,7 +196,30 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
             },
         )
 
-        return cls(**override_fields)
+        # Determine variant from config.json hidden_size
+        variant = cls._get_variant_from_config(expected_config_path)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_from_config(cls, config_path) -> Qwen3VariantType:
+        """Get variant from config.json based on hidden_size."""
+        QWEN3_4B_HIDDEN_SIZE = 2560
+        QWEN3_8B_HIDDEN_SIZE = 4096
+
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                config = json.load(f)
+            hidden_size = config.get("hidden_size")
+            if hidden_size == QWEN3_8B_HIDDEN_SIZE:
+                return Qwen3VariantType.Qwen3_8B
+            elif hidden_size == QWEN3_4B_HIDDEN_SIZE:
+                return Qwen3VariantType.Qwen3_4B
+            else:
+                # Default to 4B for unknown sizes
+                return Qwen3VariantType.Qwen3_4B
+        except (json.JSONDecodeError, OSError):
+            return Qwen3VariantType.Qwen3_4B
 
 
 class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
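
For reference, a hypothetical text-encoder config.json that would resolve to the 8B variant under this logic (only the field the code reads is shown):

    {
      "hidden_size": 4096
    }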
@@ -130,6 +228,7 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -141,7 +240,17 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
 
         cls._validate_looks_like_gguf_quantized(mod)
 
-        return cls(**override_fields)
+        # Determine variant from state dict
+        variant = cls._get_variant_or_default(mod)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
+        """Get variant from state dict, defaulting to 4B if unknown."""
+        state_dict = mod.load_state_dict()
+        variant = _get_qwen3_variant_from_state_dict(state_dict)
+        return variant if variant is not None else Qwen3VariantType.Qwen3_4B
 
     @classmethod
     def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None: