InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
@@ -56,6 +56,7 @@ from invokeai.backend.model_manager.configs.lora import (
56
56
  )
57
57
  from invokeai.backend.model_manager.configs.main import (
58
58
  Main_BnBNF4_FLUX_Config,
59
+ Main_Checkpoint_Flux2_Config,
59
60
  Main_Checkpoint_FLUX_Config,
60
61
  Main_Checkpoint_SD1_Config,
61
62
  Main_Checkpoint_SD2_Config,
@@ -63,12 +64,15 @@ from invokeai.backend.model_manager.configs.main import (
63
64
  Main_Checkpoint_SDXLRefiner_Config,
64
65
  Main_Checkpoint_ZImage_Config,
65
66
  Main_Diffusers_CogView4_Config,
67
+ Main_Diffusers_Flux2_Config,
68
+ Main_Diffusers_FLUX_Config,
66
69
  Main_Diffusers_SD1_Config,
67
70
  Main_Diffusers_SD2_Config,
68
71
  Main_Diffusers_SD3_Config,
69
72
  Main_Diffusers_SDXL_Config,
70
73
  Main_Diffusers_SDXLRefiner_Config,
71
74
  Main_Diffusers_ZImage_Config,
75
+ Main_GGUF_Flux2_Config,
72
76
  Main_GGUF_FLUX_Config,
73
77
  Main_GGUF_ZImage_Config,
74
78
  MainModelDefaultSettings,
@@ -95,10 +99,12 @@ from invokeai.backend.model_manager.configs.textual_inversion import (
95
99
  )
96
100
  from invokeai.backend.model_manager.configs.unknown import Unknown_Config
97
101
  from invokeai.backend.model_manager.configs.vae import (
102
+ VAE_Checkpoint_Flux2_Config,
98
103
  VAE_Checkpoint_FLUX_Config,
99
104
  VAE_Checkpoint_SD1_Config,
100
105
  VAE_Checkpoint_SD2_Config,
101
106
  VAE_Checkpoint_SDXL_Config,
107
+ VAE_Diffusers_Flux2_Config,
102
108
  VAE_Diffusers_SD1_Config,
103
109
  VAE_Diffusers_SDXL_Config,
104
110
  )
@@ -148,17 +154,25 @@ AnyModelConfig = Annotated[
148
154
  Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()],
149
155
  Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()],
150
156
  Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()],
157
+ Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()],
158
+ Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()],
151
159
  Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
152
160
  Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()],
153
161
  # Main (Pipeline) - checkpoint format
162
+ # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
163
+ # that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
154
164
  Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()],
155
165
  Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()],
156
166
  Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()],
157
167
  Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()],
168
+ Annotated[Main_Checkpoint_Flux2_Config, Main_Checkpoint_Flux2_Config.get_tag()],
158
169
  Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()],
159
170
  Annotated[Main_Checkpoint_ZImage_Config, Main_Checkpoint_ZImage_Config.get_tag()],
160
171
  # Main (Pipeline) - quantized formats
172
+ # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
173
+ # that will reject FLUX.1 models, but FLUX.1 validation may incorrectly match FLUX.2 models
161
174
  Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
175
+ Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()],
162
176
  Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
163
177
  Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()],
164
178
  # VAE - checkpoint format
@@ -166,9 +180,11 @@ AnyModelConfig = Annotated[
166
180
  Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()],
167
181
  Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()],
168
182
  Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()],
183
+ Annotated[VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_Flux2_Config.get_tag()],
169
184
  # VAE - diffusers format
170
185
  Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()],
171
186
  Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()],
187
+ Annotated[VAE_Diffusers_Flux2_Config, VAE_Diffusers_Flux2_Config.get_tag()],
172
188
  # ControlNet - checkpoint format
173
189
  Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()],
174
190
  Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()],
@@ -498,7 +514,9 @@ class ModelConfigFactory:
498
514
  # Now do any post-processing needed for specific model types/bases/etc.
499
515
  match config.type:
500
516
  case ModelType.Main:
501
- config.default_settings = MainModelDefaultSettings.from_base(config.base)
517
+ # Pass variant if available (e.g., for Flux2 models)
518
+ variant = getattr(config, "variant", None)
519
+ config.default_settings = MainModelDefaultSettings.from_base(config.base, variant)
502
520
  case ModelType.ControlNet | ModelType.T2IAdapter | ModelType.ControlLoRa:
503
521
  config.default_settings = ControlAdapterDefaultSettings.from_model_name(config.name)
504
522
  case ModelType.LoRA:
@@ -227,6 +227,42 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
227
227
 
228
228
  base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
229
229
 
230
+ @classmethod
231
+ def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
232
+ """Z-Image LoRAs have different key patterns than SD/SDXL LoRAs.
233
+
234
+ Z-Image LoRAs use keys like:
235
+ - diffusion_model.layers.X.attention.to_k.lora_down.weight (DoRA format)
236
+ - diffusion_model.layers.X.attention.to_k.lora_A.weight (PEFT format)
237
+ - diffusion_model.layers.X.attention.to_k.dora_scale (DoRA scale)
238
+ """
239
+ state_dict = mod.load_state_dict()
240
+
241
+ # Check for Z-Image specific LoRA patterns
242
+ has_z_image_lora_keys = state_dict_has_any_keys_starting_with(
243
+ state_dict,
244
+ {
245
+ "diffusion_model.layers.", # Z-Image S3-DiT layer pattern
246
+ },
247
+ )
248
+
249
+ # Also check for LoRA weight suffixes (various formats)
250
+ has_lora_suffix = state_dict_has_any_keys_ending_with(
251
+ state_dict,
252
+ {
253
+ "lora_A.weight",
254
+ "lora_B.weight",
255
+ "lora_down.weight",
256
+ "lora_up.weight",
257
+ "dora_scale",
258
+ },
259
+ )
260
+
261
+ if has_z_image_lora_keys and has_lora_suffix:
262
+ return
263
+
264
+ raise NotAMatchError("model does not match Z-Image LoRA heuristics")
265
+
230
266
  @classmethod
231
267
  def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
232
268
  """Z-Image LoRAs are identified by their diffusion_model.layers structure.
@@ -23,6 +23,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
23
23
  from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
24
24
  from invokeai.backend.model_manager.taxonomy import (
25
25
  BaseModelType,
26
+ Flux2VariantType,
26
27
  FluxVariantType,
27
28
  ModelFormat,
28
29
  ModelType,
@@ -52,7 +53,11 @@ class MainModelDefaultSettings(BaseModel):
52
53
  model_config = ConfigDict(extra="forbid")
53
54
 
54
55
  @classmethod
55
- def from_base(cls, base: BaseModelType) -> Self | None:
56
+ def from_base(
57
+ cls,
58
+ base: BaseModelType,
59
+ variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
60
+ ) -> Self | None:
56
61
  match base:
57
62
  case BaseModelType.StableDiffusion1:
58
63
  return cls(width=512, height=512)
@@ -62,6 +67,14 @@ class MainModelDefaultSettings(BaseModel):
62
67
  return cls(width=1024, height=1024)
63
68
  case BaseModelType.ZImage:
64
69
  return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
70
+ case BaseModelType.Flux2:
71
+ # Different defaults based on variant
72
+ if variant == Flux2VariantType.Klein9BBase:
73
+ # Undistilled base model needs more steps
74
+ return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
75
+ else:
76
+ # Distilled models (Klein 4B, Klein 9B) use fewer steps
77
+ return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
65
78
  case _:
66
79
  # TODO(psyche): Do we want defaults for other base types?
67
80
  return None
@@ -114,7 +127,11 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
114
127
 
115
128
 
116
129
  def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
117
- """Check if state dict contains Z-Image S3-DiT transformer keys."""
130
+ """Check if state dict contains Z-Image S3-DiT transformer keys.
131
+
132
+ This function returns True only for Z-Image main models, not LoRAs.
133
+ LoRAs are excluded by checking for LoRA-specific weight suffixes.
134
+ """
118
135
  # Z-Image specific keys that distinguish it from other models
119
136
  z_image_specific_keys = {
120
137
  "cap_embedder", # Caption embedder - unique to Z-Image
@@ -122,9 +139,23 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
122
139
  "cap_pad_token", # Caption padding token
123
140
  }
124
141
 
142
+ # LoRA-specific suffixes - if present, this is a LoRA not a main model
143
+ lora_suffixes = (
144
+ ".lora_down.weight",
145
+ ".lora_up.weight",
146
+ ".lora_A.weight",
147
+ ".lora_B.weight",
148
+ ".dora_scale",
149
+ )
150
+
125
151
  for key in state_dict.keys():
126
152
  if isinstance(key, int):
127
153
  continue
154
+
155
+ # If we find any LoRA-specific keys, this is not a main model
156
+ if key.endswith(lora_suffixes):
157
+ return False
158
+
128
159
  # Check for Z-Image specific key prefixes
129
160
  # Handle both direct keys (cap_embedder.0.weight) and
130
161
  # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
@@ -132,6 +163,7 @@ def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
132
163
  for part in key_parts:
133
164
  if part in z_image_specific_keys:
134
165
  return True
166
+
135
167
  return False
136
168
 
137
169
 
@@ -249,6 +281,108 @@ class Main_Checkpoint_SDXLRefiner_Config(Main_SD_Checkpoint_Config_Base, Config_
249
281
  base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner)
250
282
 
251
283
 
284
+ def _is_flux2_model(state_dict: dict[str | int, Any]) -> bool:
285
+ """Check if state dict is a FLUX.2 model by examining context_embedder dimensions.
286
+
287
+ FLUX.2 Klein uses Qwen3 encoder with larger context dimension:
288
+ - FLUX.1: context_in_dim = 4096 (T5)
289
+ - FLUX.2 Klein 4B: context_in_dim = 7680 (3×Qwen3-4B hidden size)
290
+ - FLUX.2 Klein 8B: context_in_dim = 12288 (3×Qwen3-8B hidden size)
291
+
292
+ Also checks for FLUX.2-specific 32-channel latent space (in_channels=128 after packing).
293
+ """
294
+ # Check context_embedder input dimension (most reliable)
295
+ # Weight shape: [hidden_size, context_in_dim]
296
+ for key in {"context_embedder.weight", "model.diffusion_model.context_embedder.weight"}:
297
+ if key in state_dict:
298
+ weight = state_dict[key]
299
+ if hasattr(weight, "shape") and len(weight.shape) >= 2:
300
+ context_in_dim = weight.shape[1]
301
+ # FLUX.2 has context_in_dim > 4096 (Qwen3 vs T5)
302
+ if context_in_dim > 4096:
303
+ return True
304
+
305
+ # Also check in_channels - FLUX.2 uses 128 (32 latent channels × 4 packing)
306
+ for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
307
+ if key in state_dict:
308
+ in_channels = state_dict[key].shape[1]
309
+ # FLUX.2 uses 128 in_channels (32 latent channels × 4)
310
+ # FLUX.1 uses 64 in_channels (16 latent channels × 4)
311
+ if in_channels == 128:
312
+ return True
313
+
314
+ return False
315
+
316
+
317
+ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | None:
318
+ """Determine FLUX.2 variant from state dict.
319
+
320
+ Distinguishes between Klein 4B and Klein 9B based on context embedding dimension:
321
+ - Klein 4B: context_in_dim = 7680 (3 × Qwen3-4B hidden_size 2560)
322
+ - Klein 9B: context_in_dim = 12288 (3 × Qwen3-8B hidden_size 4096)
323
+
324
+ Note: Klein 9B Base (undistilled) also has context_in_dim = 12288 but is rare.
325
+ We default to Klein9B (distilled) for all 9B models since GGUF models may not
326
+ include guidance embedding keys needed to distinguish them.
327
+
328
+ Supports both BFL format (checkpoint) and diffusers format keys:
329
+ - BFL format: txt_in.weight (context embedder)
330
+ - Diffusers format: context_embedder.weight
331
+ """
332
+ # Context dimensions for each variant
333
+ KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
334
+ KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
335
+
336
+ # Check context_embedder to determine variant
337
+ # Support both BFL format (txt_in.weight) and diffusers format (context_embedder.weight)
338
+ context_keys = {
339
+ # Diffusers format
340
+ "context_embedder.weight",
341
+ "model.diffusion_model.context_embedder.weight",
342
+ # BFL format (used by checkpoint/GGUF models)
343
+ "txt_in.weight",
344
+ "model.diffusion_model.txt_in.weight",
345
+ }
346
+ for key in context_keys:
347
+ if key in state_dict:
348
+ weight = state_dict[key]
349
+ # Handle GGUF quantized tensors which use tensor_shape instead of shape
350
+ if hasattr(weight, "tensor_shape"):
351
+ shape = weight.tensor_shape
352
+ elif hasattr(weight, "shape"):
353
+ shape = weight.shape
354
+ else:
355
+ continue
356
+ if len(shape) >= 2:
357
+ context_in_dim = shape[1]
358
+ # Determine variant based on context dimension
359
+ if context_in_dim == KLEIN_9B_CONTEXT_DIM:
360
+ # Default to Klein9B (distilled) - the official/common 9B model
361
+ return Flux2VariantType.Klein9B
362
+ elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
363
+ return Flux2VariantType.Klein4B
364
+ elif context_in_dim > 4096:
365
+ # Unknown FLUX.2 variant, default to 4B
366
+ return Flux2VariantType.Klein4B
367
+
368
+ # Check in_channels as backup - can only confirm it's FLUX.2, not which variant
369
+ for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}:
370
+ if key in state_dict:
371
+ weight = state_dict[key]
372
+ # Handle GGUF quantized tensors
373
+ if hasattr(weight, "tensor_shape"):
374
+ in_channels = weight.tensor_shape[1]
375
+ elif hasattr(weight, "shape"):
376
+ in_channels = weight.shape[1]
377
+ else:
378
+ continue
379
+ if in_channels == 128:
380
+ # It's FLUX.2 but we can't determine which Klein variant, default to 4B
381
+ return Flux2VariantType.Klein4B
382
+
383
+ return None
384
+
385
+
252
386
  def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None:
253
387
  # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
254
388
 
@@ -322,8 +456,9 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
322
456
 
323
457
  @classmethod
324
458
  def _validate_is_flux(cls, mod: ModelOnDisk) -> None:
459
+ state_dict = mod.load_state_dict()
325
460
  if not state_dict_has_any_keys_exact(
326
- mod.load_state_dict(),
461
+ state_dict,
327
462
  {
328
463
  "double_blocks.0.img_attn.norm.key_norm.scale",
329
464
  "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -331,6 +466,10 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
331
466
  ):
332
467
  raise NotAMatchError("state dict does not look like a FLUX checkpoint")
333
468
 
469
+ # Exclude FLUX.2 models - they have their own config class
470
+ if _is_flux2_model(state_dict):
471
+ raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
472
+
334
473
  @classmethod
335
474
  def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
336
475
  # FLUX Model variant types are distinguished by input channels and the presence of certain keys.
@@ -364,6 +503,68 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
364
503
  raise NotAMatchError("state dict looks like GGUF quantized")
365
504
 
366
505
 
506
+ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
507
+ """Model config for FLUX.2 checkpoint models (e.g. Klein)."""
508
+
509
+ format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
510
+ base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
511
+
512
+ variant: Flux2VariantType = Field()
513
+
514
+ @classmethod
515
+ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
516
+ raise_if_not_file(mod)
517
+
518
+ raise_for_override_fields(cls, override_fields)
519
+
520
+ cls._validate_looks_like_main_model(mod)
521
+
522
+ cls._validate_is_flux2(mod)
523
+
524
+ cls._validate_does_not_look_like_bnb_quantized(mod)
525
+
526
+ cls._validate_does_not_look_like_gguf_quantized(mod)
527
+
528
+ variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
529
+
530
+ return cls(**override_fields, variant=variant)
531
+
532
+ @classmethod
533
+ def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
534
+ """Validate that this is a FLUX.2 model, not FLUX.1."""
535
+ state_dict = mod.load_state_dict()
536
+ if not _is_flux2_model(state_dict):
537
+ raise NotAMatchError("state dict does not look like a FLUX.2 model")
538
+
539
+ @classmethod
540
+ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
541
+ state_dict = mod.load_state_dict()
542
+ variant = _get_flux2_variant(state_dict)
543
+
544
+ if variant is None:
545
+ raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
546
+
547
+ return variant
548
+
549
+ @classmethod
550
+ def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
551
+ has_main_model_keys = _has_main_keys(mod.load_state_dict())
552
+ if not has_main_model_keys:
553
+ raise NotAMatchError("state dict does not look like a main model")
554
+
555
+ @classmethod
556
+ def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None:
557
+ has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict())
558
+ if has_bnb_nf4_keys:
559
+ raise NotAMatchError("state dict looks like bnb quantized nf4")
560
+
561
+ @classmethod
562
+ def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk):
563
+ has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
564
+ if has_ggml_tensors:
565
+ raise NotAMatchError("state dict looks like GGUF quantized")
566
+
567
+
367
568
  class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
368
569
  """Model config for main checkpoint models."""
369
570
 
@@ -431,6 +632,8 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
431
632
 
432
633
  cls._validate_looks_like_gguf_quantized(mod)
433
634
 
635
+ cls._validate_is_not_flux2(mod)
636
+
434
637
  variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
435
638
 
436
639
  return cls(**override_fields, variant=variant)
@@ -461,6 +664,195 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
461
664
  if not has_ggml_tensors:
462
665
  raise NotAMatchError("state dict does not look like GGUF quantized")
463
666
 
667
+ @classmethod
668
+ def _validate_is_not_flux2(cls, mod: ModelOnDisk) -> None:
669
+ """Validate that this is NOT a FLUX.2 model."""
670
+ state_dict = mod.load_state_dict()
671
+ if _is_flux2_model(state_dict):
672
+ raise NotAMatchError("model is a FLUX.2 model, not FLUX.1")
673
+
674
+
675
+ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
676
+ """Model config for GGUF-quantized FLUX.2 checkpoint models (e.g. Klein)."""
677
+
678
+ base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
679
+ format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
680
+
681
+ variant: Flux2VariantType = Field()
682
+
683
+ @classmethod
684
+ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
685
+ raise_if_not_file(mod)
686
+
687
+ raise_for_override_fields(cls, override_fields)
688
+
689
+ cls._validate_looks_like_main_model(mod)
690
+
691
+ cls._validate_looks_like_gguf_quantized(mod)
692
+
693
+ cls._validate_is_flux2(mod)
694
+
695
+ variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
696
+
697
+ return cls(**override_fields, variant=variant)
698
+
699
+ @classmethod
700
+ def _validate_is_flux2(cls, mod: ModelOnDisk) -> None:
701
+ """Validate that this is a FLUX.2 model, not FLUX.1."""
702
+ state_dict = mod.load_state_dict()
703
+ if not _is_flux2_model(state_dict):
704
+ raise NotAMatchError("state dict does not look like a FLUX.2 model")
705
+
706
+ @classmethod
707
+ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
708
+ state_dict = mod.load_state_dict()
709
+ variant = _get_flux2_variant(state_dict)
710
+
711
+ if variant is None:
712
+ raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")
713
+
714
+ return variant
715
+
716
+ @classmethod
717
+ def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None:
718
+ has_main_model_keys = _has_main_keys(mod.load_state_dict())
719
+ if not has_main_model_keys:
720
+ raise NotAMatchError("state dict does not look like a main model")
721
+
722
+ @classmethod
723
+ def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
724
+ has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
725
+ if not has_ggml_tensors:
726
+ raise NotAMatchError("state dict does not look like GGUF quantized")
727
+
728
+
729
+ class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
730
+ """Model config for FLUX.1 models in diffusers format."""
731
+
732
+ base: Literal[BaseModelType.Flux] = Field(BaseModelType.Flux)
733
+ variant: FluxVariantType = Field()
734
+
735
+ @classmethod
736
+ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
737
+ raise_if_not_dir(mod)
738
+
739
+ raise_for_override_fields(cls, override_fields)
740
+
741
+ # Check for FLUX-specific pipeline or transformer class names
742
+ raise_for_class_name(
743
+ common_config_paths(mod.path),
744
+ {
745
+ "FluxPipeline",
746
+ "FluxFillPipeline",
747
+ "FluxTransformer2DModel",
748
+ },
749
+ )
750
+
751
+ variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
752
+
753
+ repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
754
+
755
+ return cls(
756
+ **override_fields,
757
+ variant=variant,
758
+ repo_variant=repo_variant,
759
+ )
760
+
761
+ @classmethod
762
+ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType:
763
+ """Determine the FLUX variant from the transformer config.
764
+
765
+ FLUX variants are distinguished by:
766
+ - in_channels: 64 for Dev/Schnell, 384 for DevFill
767
+ - guidance_embeds: True for Dev, False for Schnell
768
+ """
769
+ transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
770
+
771
+ in_channels = transformer_config.get("in_channels", 64)
772
+ guidance_embeds = transformer_config.get("guidance_embeds", False)
773
+
774
+ # DevFill has 384 input channels
775
+ if in_channels == 384:
776
+ return FluxVariantType.DevFill
777
+
778
+ # Dev has guidance_embeds=True, Schnell has guidance_embeds=False
779
+ if guidance_embeds:
780
+ return FluxVariantType.Dev
781
+ else:
782
+ return FluxVariantType.Schnell
783
+
784
+
785
+ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
786
+ """Model config for FLUX.2 models in diffusers format (e.g. FLUX.2 Klein)."""
787
+
788
+ base: Literal[BaseModelType.Flux2] = Field(BaseModelType.Flux2)
789
+ variant: Flux2VariantType = Field()
790
+
791
+ @classmethod
792
+ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
793
+ raise_if_not_dir(mod)
794
+
795
+ raise_for_override_fields(cls, override_fields)
796
+
797
+ # Check for FLUX.2-specific pipeline class names
798
+ raise_for_class_name(
799
+ common_config_paths(mod.path),
800
+ {
801
+ "Flux2KleinPipeline",
802
+ },
803
+ )
804
+
805
+ variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
806
+
807
+ repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
808
+
809
+ return cls(
810
+ **override_fields,
811
+ variant=variant,
812
+ repo_variant=repo_variant,
813
+ )
814
+
815
+ @classmethod
816
+ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
817
+ """Determine the FLUX.2 variant from the transformer config.
818
+
819
+ FLUX.2 Klein uses Qwen3 text encoder with larger joint_attention_dim:
820
+ - Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
821
+ - Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)
822
+
823
+ To distinguish Klein 9B (distilled) from Klein 9B Base (undistilled),
824
+ we check guidance_embeds:
825
+ - Klein 9B (distilled): guidance_embeds = False (guidance is "baked in" during distillation)
826
+ - Klein 9B Base (undistilled): guidance_embeds = True (needs guidance at inference)
827
+
828
+ Note: The official BFL Klein 9B model is the distilled version with guidance_embeds=False.
829
+ """
830
+ KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
831
+ KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
832
+
833
+ transformer_config = get_config_dict_or_raise(mod.path / "transformer" / "config.json")
834
+
835
+ joint_attention_dim = transformer_config.get("joint_attention_dim", 4096)
836
+ guidance_embeds = transformer_config.get("guidance_embeds", False)
837
+
838
+ # Determine variant based on joint_attention_dim
839
+ if joint_attention_dim == KLEIN_9B_CONTEXT_DIM:
840
+ # Check guidance_embeds to distinguish distilled from undistilled
841
+ # Klein 9B (distilled): guidance_embeds = False (guidance is baked in)
842
+ # Klein 9B Base (undistilled): guidance_embeds = True (needs guidance)
843
+ if guidance_embeds:
844
+ return Flux2VariantType.Klein9BBase
845
+ else:
846
+ return Flux2VariantType.Klein9B
847
+ elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
848
+ return Flux2VariantType.Klein4B
849
+ elif joint_attention_dim > 4096:
850
+ # Unknown FLUX.2 variant, default to 4B
851
+ return Flux2VariantType.Klein4B
852
+
853
+ # Default to 4B
854
+ return Flux2VariantType.Klein4B
855
+
464
856
 
465
857
  class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
466
858
  prediction_type: SchedulerPredictionType = Field()