InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ from invokeai.backend.flux.model import Flux
  from invokeai.backend.flux.modules.autoencoder import AutoEncoder
  from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
  from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
- from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
+ from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
  from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_Config_Base
  from invokeai.backend.model_manager.configs.controlnet import (
  ControlNet_Checkpoint_Config_Base,
@@ -45,13 +45,16 @@ from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoi
  from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_Config_Base
  from invokeai.backend.model_manager.configs.main import (
  Main_BnBNF4_FLUX_Config,
+ Main_Checkpoint_Flux2_Config,
  Main_Checkpoint_FLUX_Config,
+ Main_GGUF_Flux2_Config,
  Main_GGUF_FLUX_Config,
  )
  from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
- from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base
+ from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base, VAE_Checkpoint_Flux2_Config
  from invokeai.backend.model_manager.load.load_default import ModelLoader
  from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+ from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
  from invokeai.backend.model_manager.taxonomy import (
  AnyModel,
  BaseModelType,
@@ -108,6 +111,264 @@ class FluxVAELoader(ModelLoader):
108
111
  return model
109
112
 
110
113
 
114
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.VAE, format=ModelFormat.Diffusers)
115
+ class Flux2VAEDiffusersLoader(ModelLoader):
116
+ """Class to load FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2 with 32 latent channels)."""
117
+
118
+ def _load_model(
119
+ self,
120
+ config: AnyModelConfig,
121
+ submodel_type: Optional[SubModelType] = None,
122
+ ) -> AnyModel:
123
+ from diffusers import AutoencoderKLFlux2
124
+
125
+ model_path = Path(config.path)
126
+
127
+ # VAE is broken in float16, which mps defaults to
128
+ if self._torch_dtype == torch.float16:
129
+ try:
130
+ vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
131
+ except TypeError:
132
+ vae_dtype = torch.float32
133
+ else:
134
+ vae_dtype = self._torch_dtype
135
+
136
+ model = AutoencoderKLFlux2.from_pretrained(
137
+ model_path,
138
+ torch_dtype=vae_dtype,
139
+ local_files_only=True,
140
+ )
141
+
142
+ return model
143
+
144
+
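The float16 guard above probes whether the target device can actually allocate bfloat16 tensors and otherwise falls back to float32, since the FLUX VAEs produce broken output in float16 (the default on mps). A minimal standalone sketch of that probe; the helper name pick_vae_dtype is illustrative, not part of the codebase:

    import torch

    def pick_vae_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
        """If float16 was requested, probe for bfloat16 support and fall back to float32."""
        if requested != torch.float16:
            return requested
        try:
            # Creating a tiny bfloat16 tensor on the target device raises TypeError
            # on backends that do not support bfloat16.
            return torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
        except TypeError:
            return torch.float32

    print(pick_vae_dtype(torch.float16, torch.device("cpu")))  # torch.bfloat16
    print(pick_vae_dtype(torch.float32, torch.device("cpu")))  # torch.float32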
145
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
146
+ class Flux2VAELoader(ModelLoader):
147
+ """Class to load FLUX.2 VAE models (AutoencoderKLFlux2 with 32 latent channels)."""
148
+
149
+ def _load_model(
150
+ self,
151
+ config: AnyModelConfig,
152
+ submodel_type: Optional[SubModelType] = None,
153
+ ) -> AnyModel:
154
+ if not isinstance(config, VAE_Checkpoint_Flux2_Config):
155
+ raise ValueError("Only VAE_Checkpoint_Flux2_Config models are currently supported here.")
156
+
157
+ from diffusers import AutoencoderKLFlux2
158
+
159
+ model_path = Path(config.path)
160
+
161
+ # Load state dict manually since from_single_file may not support AutoencoderKLFlux2 yet
162
+ sd = load_file(model_path)
163
+
164
+ # Convert BFL format to diffusers format if needed
165
+ # BFL format uses: encoder.down., decoder.up., decoder.mid.block_1, decoder.mid.attn_1, decoder.norm_out
166
+ # Diffusers uses: encoder.down_blocks., decoder.up_blocks., decoder.mid_block.resnets., decoder.conv_norm_out
167
+ is_bfl_format = any(
168
+ k.startswith("encoder.down.")
169
+ or k.startswith("decoder.up.")
170
+ or k.startswith("decoder.mid.block_")
171
+ or k.startswith("decoder.mid.attn_")
172
+ or k.startswith("decoder.norm_out")
173
+ or k.startswith("encoder.mid.block_")
174
+ or k.startswith("encoder.mid.attn_")
175
+ or k.startswith("encoder.norm_out")
176
+ for k in sd.keys()
177
+ )
178
+ if is_bfl_format:
179
+ sd = self._convert_flux2_vae_bfl_to_diffusers(sd)
180
+
181
+ # FLUX.2 VAE configuration (32 latent channels)
182
+ # Based on the official FLUX.2 VAE architecture
183
+ # Use default config - AutoencoderKLFlux2 has built-in defaults
184
+ with SilenceWarnings():
185
+ with accelerate.init_empty_weights():
186
+ model = AutoencoderKLFlux2()
187
+
188
+ # Convert to bfloat16 and load
189
+ for k in sd.keys():
190
+ sd[k] = sd[k].to(torch.bfloat16)
191
+
192
+ model.load_state_dict(sd, assign=True)
193
+
194
+ # VAE is broken in float16, which mps defaults to
195
+ if self._torch_dtype == torch.float16:
196
+ try:
197
+ vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
198
+ except TypeError:
199
+ vae_dtype = torch.float32
200
+ else:
201
+ vae_dtype = self._torch_dtype
202
+ model.to(vae_dtype)
203
+
204
+ return model
205
+
206
+ def _convert_flux2_vae_bfl_to_diffusers(self, sd: dict) -> dict:
207
+ """Convert FLUX.2 VAE BFL format state dict to diffusers format.
208
+
209
+ Key differences:
210
+ - encoder.down.X.block.Y -> encoder.down_blocks.X.resnets.Y
211
+ - encoder.down.X.downsample.conv -> encoder.down_blocks.X.downsamplers.0.conv
212
+ - encoder.mid.block_1/2 -> encoder.mid_block.resnets.0/1
213
+ - encoder.mid.attn_1.q/k/v -> encoder.mid_block.attentions.0.to_q/k/v
214
+ - encoder.norm_out -> encoder.conv_norm_out
215
+ - encoder.quant_conv -> quant_conv (top-level)
216
+ - decoder.up.X -> decoder.up_blocks.(num_blocks-1-X) (reversed order!)
217
+ - decoder.post_quant_conv -> post_quant_conv (top-level)
218
+ - *.nin_shortcut -> *.conv_shortcut
219
+ """
220
+ import re
221
+
222
+ converted = {}
223
+ num_up_blocks = 4 # Standard VAE has 4 up blocks
224
+
225
+ for old_key, tensor in sd.items():
226
+ new_key = old_key
227
+
228
+ # Encoder down blocks: encoder.down.X.block.Y -> encoder.down_blocks.X.resnets.Y
229
+ match = re.match(r"encoder\.down\.(\d+)\.block\.(\d+)\.(.*)", old_key)
230
+ if match:
231
+ block_idx, resnet_idx, rest = match.groups()
232
+ rest = rest.replace("nin_shortcut", "conv_shortcut")
233
+ new_key = f"encoder.down_blocks.{block_idx}.resnets.{resnet_idx}.{rest}"
234
+ converted[new_key] = tensor
235
+ continue
236
+
237
+ # Encoder downsamplers: encoder.down.X.downsample.conv -> encoder.down_blocks.X.downsamplers.0.conv
238
+ match = re.match(r"encoder\.down\.(\d+)\.downsample\.conv\.(.*)", old_key)
239
+ if match:
240
+ block_idx, rest = match.groups()
241
+ new_key = f"encoder.down_blocks.{block_idx}.downsamplers.0.conv.{rest}"
242
+ converted[new_key] = tensor
243
+ continue
244
+
245
+ # Encoder mid block resnets: encoder.mid.block_1/2 -> encoder.mid_block.resnets.0/1
246
+ match = re.match(r"encoder\.mid\.block_(\d+)\.(.*)", old_key)
247
+ if match:
248
+ block_num, rest = match.groups()
249
+ resnet_idx = int(block_num) - 1 # block_1 -> resnets.0, block_2 -> resnets.1
250
+ new_key = f"encoder.mid_block.resnets.{resnet_idx}.{rest}"
251
+ converted[new_key] = tensor
252
+ continue
253
+
254
+ # Encoder mid block attention: encoder.mid.attn_1.* -> encoder.mid_block.attentions.0.*
255
+ match = re.match(r"encoder\.mid\.attn_1\.(.*)", old_key)
256
+ if match:
257
+ rest = match.group(1)
258
+ # Map attention keys
259
+ # BFL uses Conv2d (shape [out, in, 1, 1]), diffusers uses Linear (shape [out, in])
260
+ # Squeeze the extra dimensions for weight tensors
261
+ if rest.startswith("q."):
262
+ new_key = f"encoder.mid_block.attentions.0.to_q.{rest[2:]}"
263
+ if rest.endswith(".weight") and tensor.dim() == 4:
264
+ tensor = tensor.squeeze(-1).squeeze(-1)
265
+ elif rest.startswith("k."):
266
+ new_key = f"encoder.mid_block.attentions.0.to_k.{rest[2:]}"
267
+ if rest.endswith(".weight") and tensor.dim() == 4:
268
+ tensor = tensor.squeeze(-1).squeeze(-1)
269
+ elif rest.startswith("v."):
270
+ new_key = f"encoder.mid_block.attentions.0.to_v.{rest[2:]}"
271
+ if rest.endswith(".weight") and tensor.dim() == 4:
272
+ tensor = tensor.squeeze(-1).squeeze(-1)
273
+ elif rest.startswith("proj_out."):
274
+ new_key = f"encoder.mid_block.attentions.0.to_out.0.{rest[9:]}"
275
+ if rest.endswith(".weight") and tensor.dim() == 4:
276
+ tensor = tensor.squeeze(-1).squeeze(-1)
277
+ elif rest.startswith("norm."):
278
+ new_key = f"encoder.mid_block.attentions.0.group_norm.{rest[5:]}"
279
+ else:
280
+ new_key = f"encoder.mid_block.attentions.0.{rest}"
281
+ converted[new_key] = tensor
282
+ continue
283
+
284
+ # Encoder norm_out -> conv_norm_out
285
+ if old_key.startswith("encoder.norm_out."):
286
+ new_key = old_key.replace("encoder.norm_out.", "encoder.conv_norm_out.")
287
+ converted[new_key] = tensor
288
+ continue
289
+
290
+ # Encoder quant_conv -> quant_conv (move to top level)
291
+ if old_key.startswith("encoder.quant_conv."):
292
+ new_key = old_key.replace("encoder.quant_conv.", "quant_conv.")
293
+ converted[new_key] = tensor
294
+ continue
295
+
296
+ # Decoder up blocks (reversed order!): decoder.up.X -> decoder.up_blocks.(num_blocks-1-X)
297
+ match = re.match(r"decoder\.up\.(\d+)\.block\.(\d+)\.(.*)", old_key)
298
+ if match:
299
+ block_idx, resnet_idx, rest = match.groups()
300
+ # Reverse the block index
301
+ new_block_idx = num_up_blocks - 1 - int(block_idx)
302
+ rest = rest.replace("nin_shortcut", "conv_shortcut")
303
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.{rest}"
304
+ converted[new_key] = tensor
305
+ continue
306
+
307
+ # Decoder upsamplers (reversed order!)
308
+ match = re.match(r"decoder\.up\.(\d+)\.upsample\.conv\.(.*)", old_key)
309
+ if match:
310
+ block_idx, rest = match.groups()
311
+ new_block_idx = num_up_blocks - 1 - int(block_idx)
312
+ new_key = f"decoder.up_blocks.{new_block_idx}.upsamplers.0.conv.{rest}"
313
+ converted[new_key] = tensor
314
+ continue
315
+
316
+ # Decoder mid block resnets: decoder.mid.block_1/2 -> decoder.mid_block.resnets.0/1
317
+ match = re.match(r"decoder\.mid\.block_(\d+)\.(.*)", old_key)
318
+ if match:
319
+ block_num, rest = match.groups()
320
+ resnet_idx = int(block_num) - 1
321
+ new_key = f"decoder.mid_block.resnets.{resnet_idx}.{rest}"
322
+ converted[new_key] = tensor
323
+ continue
324
+
325
+ # Decoder mid block attention: decoder.mid.attn_1.* -> decoder.mid_block.attentions.0.*
326
+ match = re.match(r"decoder\.mid\.attn_1\.(.*)", old_key)
327
+ if match:
328
+ rest = match.group(1)
329
+ # BFL uses Conv2d (shape [out, in, 1, 1]), diffusers uses Linear (shape [out, in])
330
+ # Squeeze the extra dimensions for weight tensors
331
+ if rest.startswith("q."):
332
+ new_key = f"decoder.mid_block.attentions.0.to_q.{rest[2:]}"
333
+ if rest.endswith(".weight") and tensor.dim() == 4:
334
+ tensor = tensor.squeeze(-1).squeeze(-1)
335
+ elif rest.startswith("k."):
336
+ new_key = f"decoder.mid_block.attentions.0.to_k.{rest[2:]}"
337
+ if rest.endswith(".weight") and tensor.dim() == 4:
338
+ tensor = tensor.squeeze(-1).squeeze(-1)
339
+ elif rest.startswith("v."):
340
+ new_key = f"decoder.mid_block.attentions.0.to_v.{rest[2:]}"
341
+ if rest.endswith(".weight") and tensor.dim() == 4:
342
+ tensor = tensor.squeeze(-1).squeeze(-1)
343
+ elif rest.startswith("proj_out."):
344
+ new_key = f"decoder.mid_block.attentions.0.to_out.0.{rest[9:]}"
345
+ if rest.endswith(".weight") and tensor.dim() == 4:
346
+ tensor = tensor.squeeze(-1).squeeze(-1)
347
+ elif rest.startswith("norm."):
348
+ new_key = f"decoder.mid_block.attentions.0.group_norm.{rest[5:]}"
349
+ else:
350
+ new_key = f"decoder.mid_block.attentions.0.{rest}"
351
+ converted[new_key] = tensor
352
+ continue
353
+
354
+ # Decoder norm_out -> conv_norm_out
355
+ if old_key.startswith("decoder.norm_out."):
356
+ new_key = old_key.replace("decoder.norm_out.", "decoder.conv_norm_out.")
357
+ converted[new_key] = tensor
358
+ continue
359
+
360
+ # Decoder post_quant_conv -> post_quant_conv (move to top level)
361
+ if old_key.startswith("decoder.post_quant_conv."):
362
+ new_key = old_key.replace("decoder.post_quant_conv.", "post_quant_conv.")
363
+ converted[new_key] = tensor
364
+ continue
365
+
366
+ # Keep other keys as-is (like encoder.conv_in, decoder.conv_in, decoder.conv_out, bn.*)
367
+ converted[new_key] = tensor
368
+
369
+ return converted
370
+
371
+
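To make the key mapping described in _convert_flux2_vae_bfl_to_diffusers concrete, here is a toy rename function covering just two of the documented patterns (encoder down blocks and the reversed decoder up-block indexing). It is an illustration under the assumption of 4 up blocks, not the shipped converter:

    import re

    NUM_UP_BLOCKS = 4  # assumed standard VAE layout

    def rename_example(key: str) -> str:
        # encoder.down.X.block.Y.* -> encoder.down_blocks.X.resnets.Y.*
        m = re.match(r"encoder\.down\.(\d+)\.block\.(\d+)\.(.*)", key)
        if m:
            blk, res, rest = m.groups()
            return f"encoder.down_blocks.{blk}.resnets.{res}.{rest.replace('nin_shortcut', 'conv_shortcut')}"
        # decoder.up.X.block.Y.* -> decoder.up_blocks.(N-1-X).resnets.Y.*  (reversed order)
        m = re.match(r"decoder\.up\.(\d+)\.block\.(\d+)\.(.*)", key)
        if m:
            blk, res, rest = m.groups()
            return f"decoder.up_blocks.{NUM_UP_BLOCKS - 1 - int(blk)}.resnets.{res}.{rest}"
        return key

    assert rename_example("encoder.down.0.block.1.nin_shortcut.weight") == \
        "encoder.down_blocks.0.resnets.1.conv_shortcut.weight"
    assert rename_example("decoder.up.0.block.0.conv1.weight") == \
        "decoder.up_blocks.3.resnets.0.conv1.weight"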
111
372
  @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
112
373
  class CLIPDiffusersLoader(ModelLoader):
113
374
  """Class to load main models."""
@@ -122,9 +383,9 @@ class CLIPDiffusersLoader(ModelLoader):
 
  match submodel_type:
  case SubModelType.Tokenizer:
- return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
+ return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer", local_files_only=True)
  case SubModelType.TextEncoder:
- return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
+ return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder", local_files_only=True)
 
  raise ValueError(
  f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -148,10 +409,12 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
  )
  match submodel_type:
  case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
- return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
+ return T5TokenizerFast.from_pretrained(
+ Path(config.path) / "tokenizer_2", max_length=512, local_files_only=True
+ )
  case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
  te2_model_path = Path(config.path) / "text_encoder_2"
- model_config = AutoConfig.from_pretrained(te2_model_path)
+ model_config = AutoConfig.from_pretrained(te2_model_path, local_files_only=True)
  with accelerate.init_empty_weights():
  model = AutoModelForTextEncoding.from_config(model_config)
  model = quantize_model_llm_int8(model, modules_to_not_convert=set())
@@ -192,10 +455,15 @@ class T5EncoderCheckpointModel(ModelLoader):
 
  match submodel_type:
  case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
- return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
+ return T5TokenizerFast.from_pretrained(
+ Path(config.path) / "tokenizer_2", max_length=512, local_files_only=True
+ )
  case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
  return T5EncoderModel.from_pretrained(
- Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
+ Path(config.path) / "text_encoder_2",
+ torch_dtype="auto",
+ low_cpu_mem_usage=True,
+ local_files_only=True,
  )
 
  raise ValueError(
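The common thread in these tokenizer and text-encoder hunks is the added local_files_only=True, which makes transformers resolve already-installed models strictly from disk instead of falling back to a Hugging Face Hub lookup. A minimal usage sketch of the same pattern; the directory path is a placeholder:

    from pathlib import Path

    from transformers import CLIPTextModel, CLIPTokenizer

    # Substitute a real installed model directory; nothing is downloaded.
    model_dir = Path("/path/to/installed/clip-model")

    tokenizer = CLIPTokenizer.from_pretrained(model_dir / "tokenizer", local_files_only=True)
    text_encoder = CLIPTextModel.from_pretrained(model_dir / "text_encoder", local_files_only=True)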
@@ -333,6 +601,750 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
333
601
  return model
334
602
 
335
603
 
604
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Diffusers)
605
+ class FluxDiffusersModel(GenericDiffusersLoader):
606
+ """Class to load FLUX.1 main models in diffusers format."""
607
+
608
+ def _load_model(
609
+ self,
610
+ config: AnyModelConfig,
611
+ submodel_type: Optional[SubModelType] = None,
612
+ ) -> AnyModel:
613
+ if isinstance(config, Checkpoint_Config_Base):
614
+ raise NotImplementedError("CheckpointConfigBase is not implemented for FLUX diffusers models.")
615
+
616
+ if submodel_type is None:
617
+ raise Exception("A submodel type must be provided when loading main pipelines.")
618
+
619
+ model_path = Path(config.path)
620
+ load_class = self.get_hf_load_class(model_path, submodel_type)
621
+ repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
622
+ variant = repo_variant.value if repo_variant else None
623
+ model_path = model_path / submodel_type.value
624
+
625
+ # We force bfloat16 for FLUX models. This is required for correct inference.
626
+ dtype = torch.bfloat16
627
+ try:
628
+ result: AnyModel = load_class.from_pretrained(
629
+ model_path,
630
+ torch_dtype=dtype,
631
+ variant=variant,
632
+ local_files_only=True,
633
+ )
634
+ except OSError as e:
635
+ if variant and "no file named" in str(
636
+ e
637
+ ): # try without the variant, just in case user's preferences changed
638
+ result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
639
+ else:
640
+ raise e
641
+
642
+ return result
643
+
644
+
645
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.Main, format=ModelFormat.Diffusers)
646
+ class Flux2DiffusersModel(GenericDiffusersLoader):
647
+ """Class to load FLUX.2 main models in diffusers format (e.g. FLUX.2 Klein)."""
648
+
649
+ def _load_model(
650
+ self,
651
+ config: AnyModelConfig,
652
+ submodel_type: Optional[SubModelType] = None,
653
+ ) -> AnyModel:
654
+ if isinstance(config, Checkpoint_Config_Base):
655
+ raise NotImplementedError("CheckpointConfigBase is not implemented for FLUX.2 diffusers models.")
656
+
657
+ if submodel_type is None:
658
+ raise Exception("A submodel type must be provided when loading main pipelines.")
659
+
660
+ model_path = Path(config.path)
661
+ load_class = self.get_hf_load_class(model_path, submodel_type)
662
+ repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
663
+ variant = repo_variant.value if repo_variant else None
664
+ model_path = model_path / submodel_type.value
665
+
666
+ # We force bfloat16 for FLUX.2 models. This is required for correct inference.
667
+ # We use low_cpu_mem_usage=False to avoid meta tensors for weights not in checkpoint.
668
+ # FLUX.2 Klein models may have guidance_embeds=False, so the guidance_embed layers
669
+ # won't be in the checkpoint but the model class still creates them.
670
+ # We use SilenceWarnings to suppress the "guidance_embeds is not expected" warning
671
+ # from diffusers Flux2Transformer2DModel.
672
+ dtype = torch.bfloat16
673
+ with SilenceWarnings():
674
+ try:
675
+ result: AnyModel = load_class.from_pretrained(
676
+ model_path,
677
+ torch_dtype=dtype,
678
+ variant=variant,
679
+ local_files_only=True,
680
+ low_cpu_mem_usage=False,
681
+ )
682
+ except OSError as e:
683
+ if variant and "no file named" in str(
684
+ e
685
+ ): # try without the variant, just in case user's preferences changed
686
+ result = load_class.from_pretrained(
687
+ model_path,
688
+ torch_dtype=dtype,
689
+ local_files_only=True,
690
+ low_cpu_mem_usage=False,
691
+ )
692
+ else:
693
+ raise e
694
+
695
+ # For Klein models without guidance_embeds, zero out the guidance_embedder weights
696
+ # that were randomly initialized by diffusers. This prevents noise from affecting
697
+ # the time embeddings.
698
+ if submodel_type == SubModelType.Transformer and hasattr(result, "time_guidance_embed"):
699
+ # Check if this is a Klein model without guidance (guidance_embeds=False in config)
700
+ transformer_config_path = model_path / "config.json"
701
+ if transformer_config_path.exists():
702
+ import json
703
+
704
+ with open(transformer_config_path, "r") as f:
705
+ transformer_config = json.load(f)
706
+ if not transformer_config.get("guidance_embeds", True):
707
+ # Zero out the guidance embedder weights
708
+ guidance_emb = result.time_guidance_embed.guidance_embedder
709
+ if hasattr(guidance_emb, "linear_1"):
710
+ guidance_emb.linear_1.weight.data.zero_()
711
+ if guidance_emb.linear_1.bias is not None:
712
+ guidance_emb.linear_1.bias.data.zero_()
713
+ if hasattr(guidance_emb, "linear_2"):
714
+ guidance_emb.linear_2.weight.data.zero_()
715
+ if guidance_emb.linear_2.bias is not None:
716
+ guidance_emb.linear_2.bias.data.zero_()
717
+
718
+ return result
719
+
720
+
721
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.Main, format=ModelFormat.Checkpoint)
722
+ class Flux2CheckpointModel(ModelLoader):
723
+ """Class to load FLUX.2 transformer models from single-file checkpoints (safetensors)."""
724
+
725
+ def _load_model(
726
+ self,
727
+ config: AnyModelConfig,
728
+ submodel_type: Optional[SubModelType] = None,
729
+ ) -> AnyModel:
730
+ if not isinstance(config, Checkpoint_Config_Base):
731
+ raise ValueError("Only CheckpointConfigBase models are currently supported here.")
732
+
733
+ match submodel_type:
734
+ case SubModelType.Transformer:
735
+ return self._load_from_singlefile(config)
736
+
737
+ raise ValueError(
738
+ f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
739
+ )
740
+
741
+ def _load_from_singlefile(
742
+ self,
743
+ config: AnyModelConfig,
744
+ ) -> AnyModel:
745
+ from diffusers import Flux2Transformer2DModel
746
+
747
+ if not isinstance(config, Main_Checkpoint_Flux2_Config):
748
+ raise TypeError(
749
+ f"Expected Main_Checkpoint_Flux2_Config, got {type(config).__name__}. "
750
+ "Model configuration type mismatch."
751
+ )
752
+ model_path = Path(config.path)
753
+
754
+ # Load state dict
755
+ sd = load_file(model_path)
756
+
757
+ # Handle FP8 quantized weights (ComfyUI-style or scaled FP8)
758
+ # These store weights as: layer.weight (FP8) + layer.weight_scale (FP32 scalar)
759
+ sd = self._dequantize_fp8_weights(sd)
760
+
761
+ # Check if keys have ComfyUI-style prefix and strip if needed
762
+ prefix_to_strip = None
763
+ for prefix in ["model.diffusion_model.", "diffusion_model."]:
764
+ if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
765
+ prefix_to_strip = prefix
766
+ break
767
+
768
+ if prefix_to_strip:
769
+ sd = {
770
+ (k[len(prefix_to_strip) :] if isinstance(k, str) and k.startswith(prefix_to_strip) else k): v
771
+ for k, v in sd.items()
772
+ }
773
+
774
+ # Convert BFL format state dict to diffusers format
775
+ converted_sd = self._convert_flux2_bfl_to_diffusers(sd)
776
+
777
+ # Detect architecture from checkpoint keys
778
+ double_block_indices = [
779
+ int(k.split(".")[1])
780
+ for k in converted_sd.keys()
781
+ if isinstance(k, str) and k.startswith("transformer_blocks.")
782
+ ]
783
+ single_block_indices = [
784
+ int(k.split(".")[1])
785
+ for k in converted_sd.keys()
786
+ if isinstance(k, str) and k.startswith("single_transformer_blocks.")
787
+ ]
788
+
789
+ num_layers = max(double_block_indices) + 1 if double_block_indices else 5
790
+ num_single_layers = max(single_block_indices) + 1 if single_block_indices else 20
791
+
792
+ # Get dimensions from weights
793
+ # context_embedder.weight shape: [hidden_size, joint_attention_dim]
794
+ context_embedder_weight = converted_sd.get("context_embedder.weight")
795
+ if context_embedder_weight is not None:
796
+ hidden_size = context_embedder_weight.shape[0]
797
+ joint_attention_dim = context_embedder_weight.shape[1]
798
+ else:
799
+ # Default to Klein 4B dimensions
800
+ hidden_size = 3072
801
+ joint_attention_dim = 7680
802
+
803
+ x_embedder_weight = converted_sd.get("x_embedder.weight")
804
+ if x_embedder_weight is not None:
805
+ in_channels = x_embedder_weight.shape[1]
806
+ else:
807
+ in_channels = 128
808
+
809
+ # Calculate num_attention_heads from hidden_size
810
+ # Klein 4B: hidden_size=3072, num_attention_heads=24 (3072/128=24)
811
+ # Klein 9B: hidden_size=4096, num_attention_heads=32 (4096/128=32)
812
+ attention_head_dim = 128
813
+ num_attention_heads = hidden_size // attention_head_dim
814
+
815
+ # Klein models don't have guidance embeddings - check if they're in the checkpoint
816
+ has_guidance = "time_guidance_embed.guidance_embedder.linear_1.weight" in converted_sd
817
+
818
+ # Create model with detected configuration
819
+ with SilenceWarnings():
820
+ with accelerate.init_empty_weights():
821
+ model = Flux2Transformer2DModel(
822
+ in_channels=in_channels,
823
+ out_channels=in_channels,
824
+ num_layers=num_layers,
825
+ num_single_layers=num_single_layers,
826
+ attention_head_dim=attention_head_dim,
827
+ num_attention_heads=num_attention_heads,
828
+ joint_attention_dim=joint_attention_dim,
829
+ patch_size=1,
830
+ )
831
+
832
+ # If Klein model without guidance, initialize guidance embedder with zeros
833
+ if not has_guidance:
834
+ # Get the expected dimensions from timestep embedder (they should match)
835
+ timestep_linear1 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_1.weight")
836
+ if timestep_linear1 is not None:
837
+ in_features = timestep_linear1.shape[1]
838
+ out_features = timestep_linear1.shape[0]
839
+ # Initialize guidance embedder with same shape as timestep embedder
840
+ converted_sd["time_guidance_embed.guidance_embedder.linear_1.weight"] = torch.zeros(
841
+ out_features, in_features, dtype=torch.bfloat16
842
+ )
843
+ timestep_linear2 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_2.weight")
844
+ if timestep_linear2 is not None:
845
+ in_features2 = timestep_linear2.shape[1]
846
+ out_features2 = timestep_linear2.shape[0]
847
+ converted_sd["time_guidance_embed.guidance_embedder.linear_2.weight"] = torch.zeros(
848
+ out_features2, in_features2, dtype=torch.bfloat16
849
+ )
850
+
851
+ # Convert to bfloat16 and load
852
+ for k in converted_sd.keys():
853
+ converted_sd[k] = converted_sd[k].to(torch.bfloat16)
854
+
855
+ # Load the state dict - guidance weights were already initialized above if missing
856
+ model.load_state_dict(converted_sd, assign=True)
857
+
858
+ return model
859
+
860
+ def _convert_flux2_bfl_to_diffusers(self, sd: dict) -> dict:
861
+ """Convert FLUX.2 BFL format state dict to diffusers format.
862
+
863
+ Based on diffusers convert_flux2_to_diffusers.py key mappings.
864
+ """
865
+ converted = {}
866
+
867
+ # Basic key renames
868
+ key_renames = {
869
+ "img_in.weight": "x_embedder.weight",
870
+ "txt_in.weight": "context_embedder.weight",
871
+ "time_in.in_layer.weight": "time_guidance_embed.timestep_embedder.linear_1.weight",
872
+ "time_in.out_layer.weight": "time_guidance_embed.timestep_embedder.linear_2.weight",
873
+ "guidance_in.in_layer.weight": "time_guidance_embed.guidance_embedder.linear_1.weight",
874
+ "guidance_in.out_layer.weight": "time_guidance_embed.guidance_embedder.linear_2.weight",
875
+ "double_stream_modulation_img.lin.weight": "double_stream_modulation_img.linear.weight",
876
+ "double_stream_modulation_txt.lin.weight": "double_stream_modulation_txt.linear.weight",
877
+ "single_stream_modulation.lin.weight": "single_stream_modulation.linear.weight",
878
+ "final_layer.linear.weight": "proj_out.weight",
879
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
880
+ }
881
+
882
+ for old_key, tensor in sd.items():
883
+ new_key = old_key
884
+
885
+ # Apply basic renames
886
+ if old_key in key_renames:
887
+ new_key = key_renames[old_key]
888
+ # Apply scale-shift swap for adaLN modulation weights
889
+ # BFL and diffusers use different parameter ordering for AdaLayerNorm
890
+ if old_key == "final_layer.adaLN_modulation.1.weight":
891
+ tensor = self._swap_scale_shift(tensor)
892
+ converted[new_key] = tensor
893
+ continue
894
+
895
+ # Convert double_blocks.X.* to transformer_blocks.X.*
896
+ if old_key.startswith("double_blocks."):
897
+ new_key = self._convert_double_block_key(old_key, tensor, converted)
898
+ if new_key is None:
899
+ continue # Key was handled specially
900
+ # Convert single_blocks.X.* to single_transformer_blocks.X.*
901
+ elif old_key.startswith("single_blocks."):
902
+ new_key = self._convert_single_block_key(old_key, tensor, converted)
903
+ if new_key is None:
904
+ continue # Key was handled specially
905
+
906
+ if new_key != old_key or new_key not in converted:
907
+ converted[new_key] = tensor
908
+
909
+ return converted
910
+
911
+ def _convert_double_block_key(self, key: str, tensor: torch.Tensor, converted: dict) -> str | None:
912
+ """Convert double_blocks key to transformer_blocks format."""
913
+ parts = key.split(".")
914
+ block_idx = parts[1]
915
+ rest = ".".join(parts[2:])
916
+
917
+ prefix = f"transformer_blocks.{block_idx}"
918
+
919
+ # Attention QKV conversion - BFL uses fused qkv, diffusers uses separate
920
+ if "img_attn.qkv.weight" in rest:
921
+ # Split fused QKV into separate Q, K, V
922
+ # Defensive check: ensure tensor has at least 1 dimension and can be split into 3
923
+ if tensor.dim() < 1 or tensor.shape[0] % 3 != 0:
924
+ # Skip malformed tensors (might be metadata or corrupted)
925
+ return key
926
+ q, k, v = tensor.chunk(3, dim=0)
927
+ converted[f"{prefix}.attn.to_q.weight"] = q
928
+ converted[f"{prefix}.attn.to_k.weight"] = k
929
+ converted[f"{prefix}.attn.to_v.weight"] = v
930
+ return None
931
+ elif "txt_attn.qkv.weight" in rest:
932
+ # Defensive check
933
+ if tensor.dim() < 1 or tensor.shape[0] % 3 != 0:
934
+ return key
935
+ q, k, v = tensor.chunk(3, dim=0)
936
+ converted[f"{prefix}.attn.add_q_proj.weight"] = q
937
+ converted[f"{prefix}.attn.add_k_proj.weight"] = k
938
+ converted[f"{prefix}.attn.add_v_proj.weight"] = v
939
+ return None
940
+
941
+ # Attention output projection
942
+ if "img_attn.proj.weight" in rest:
943
+ return f"{prefix}.attn.to_out.0.weight"
944
+ elif "txt_attn.proj.weight" in rest:
945
+ return f"{prefix}.attn.to_add_out.weight"
946
+
947
+ # Attention norms
948
+ if "img_attn.norm.query_norm.scale" in rest:
949
+ return f"{prefix}.attn.norm_q.weight"
950
+ elif "img_attn.norm.key_norm.scale" in rest:
951
+ return f"{prefix}.attn.norm_k.weight"
952
+ elif "txt_attn.norm.query_norm.scale" in rest:
953
+ return f"{prefix}.attn.norm_added_q.weight"
954
+ elif "txt_attn.norm.key_norm.scale" in rest:
955
+ return f"{prefix}.attn.norm_added_k.weight"
956
+
957
+ # MLP layers
958
+ if "img_mlp.0.weight" in rest:
959
+ return f"{prefix}.ff.linear_in.weight"
960
+ elif "img_mlp.2.weight" in rest:
961
+ return f"{prefix}.ff.linear_out.weight"
962
+ elif "txt_mlp.0.weight" in rest:
963
+ return f"{prefix}.ff_context.linear_in.weight"
964
+ elif "txt_mlp.2.weight" in rest:
965
+ return f"{prefix}.ff_context.linear_out.weight"
966
+
967
+ return key
968
+
969
+ def _convert_single_block_key(self, key: str, tensor: torch.Tensor, converted: dict) -> str | None:
970
+ """Convert single_blocks key to single_transformer_blocks format."""
971
+ parts = key.split(".")
972
+ block_idx = parts[1]
973
+ rest = ".".join(parts[2:])
974
+
975
+ prefix = f"single_transformer_blocks.{block_idx}"
976
+
977
+ # linear1 is the fused QKV+MLP projection
978
+ if "linear1.weight" in rest:
979
+ return f"{prefix}.attn.to_qkv_mlp_proj.weight"
980
+ elif "linear2.weight" in rest:
981
+ return f"{prefix}.attn.to_out.weight"
982
+
983
+ # Norms
984
+ if "norm.query_norm.scale" in rest:
985
+ return f"{prefix}.attn.norm_q.weight"
986
+ elif "norm.key_norm.scale" in rest:
987
+ return f"{prefix}.attn.norm_k.weight"
988
+
989
+ return key
990
+
991
+ def _swap_scale_shift(self, weight: torch.Tensor) -> torch.Tensor:
992
+ """Swap scale and shift in AdaLayerNorm weights.
993
+
994
+ BFL and diffusers use different parameter ordering for AdaLayerNorm.
995
+ This function swaps the two halves of the weight tensor.
996
+
997
+ Args:
998
+ weight: Weight tensor of shape (out_features,) or (out_features, in_features)
999
+
1000
+ Returns:
1001
+ Weight tensor with scale and shift swapped.
1002
+ """
1003
+ # Defensive check: ensure tensor can be split
1004
+ if weight.dim() < 1 or weight.shape[0] % 2 != 0:
1005
+ return weight
1006
+ # Split in half along the first dimension and swap
1007
+ shift, scale = weight.chunk(2, dim=0)
1008
+ return torch.cat([scale, shift], dim=0)
1009
+
1010
+ def _dequantize_fp8_weights(self, sd: dict) -> dict:
1011
+ """Dequantize FP8 quantized weights in the state dict.
1012
+
1013
+ ComfyUI and some FLUX.2 models store quantized weights as:
1014
+ - layer.weight: quantized FP8 data
1015
+ - layer.weight_scale: scale factor (FP32 scalar or per-channel)
1016
+
1017
+ Dequantization formula: dequantized = weight.to(float) * weight_scale
1018
+
1019
+ Also handles FP8 tensors stored with float8_e4m3fn dtype by converting to float.
1020
+ """
1021
+ # Check for ComfyUI-style scale factors
1022
+ weight_scale_keys = [k for k in sd.keys() if isinstance(k, str) and k.endswith(".weight_scale")]
1023
+
1024
+ for scale_key in weight_scale_keys:
1025
+ # Get the corresponding weight key
1026
+ weight_key = scale_key.replace(".weight_scale", ".weight")
1027
+ if weight_key in sd:
1028
+ weight = sd[weight_key]
1029
+ scale = sd[scale_key]
1030
+
1031
+ # Dequantize: convert FP8 to float and multiply by scale
1032
+ # Note: Float8 types require .float() instead of .to(torch.float32)
1033
+ weight_float = weight.float()
1034
+ scale = scale.float()
1035
+
1036
+ # Handle block-wise quantization where scale may have different shape
1037
+ if scale.dim() > 0 and scale.shape != weight_float.shape and scale.numel() > 1:
1038
+ for dim in range(len(weight_float.shape)):
1039
+ if dim < len(scale.shape) and scale.shape[dim] != weight_float.shape[dim]:
1040
+ block_size = weight_float.shape[dim] // scale.shape[dim]
1041
+ if block_size > 1:
1042
+ scale = scale.repeat_interleave(block_size, dim=dim)
1043
+
1044
+ sd[weight_key] = weight_float * scale
1045
+
1046
+ # Filter out scale metadata keys and other FP8 metadata
1047
+ keys_to_remove = [
1048
+ k
1049
+ for k in sd.keys()
1050
+ if isinstance(k, str)
1051
+ and (k.endswith(".weight_scale") or k.endswith(".scale_weight") or "comfy_quant" in k or k == "scaled_fp8")
1052
+ ]
1053
+ for k in keys_to_remove:
1054
+ del sd[k]
1055
+
1056
+ # Handle native FP8 tensors (float8_e4m3fn dtype) that aren't already dequantized
1057
+ # Also filter out 0-dimensional tensors (scalars) which are typically metadata
1058
+ keys_to_convert = []
1059
+ keys_to_remove_scalars = []
1060
+ for key in list(sd.keys()):
1061
+ tensor = sd[key]
1062
+ if hasattr(tensor, "dim"):
1063
+ if tensor.dim() == 0:
1064
+ # 0-dimensional tensor (scalar) - likely metadata, remove it
1065
+ keys_to_remove_scalars.append(key)
1066
+ elif hasattr(tensor, "dtype") and "float8" in str(tensor.dtype):
1067
+ # Native FP8 tensor - mark for conversion
1068
+ keys_to_convert.append(key)
1069
+
1070
+ for k in keys_to_remove_scalars:
1071
+ del sd[k]
1072
+
1073
+ for key in keys_to_convert:
1074
+ # Convert FP8 tensor to float32
1075
+ sd[key] = sd[key].float()
1076
+
1077
+ return sd
1078
+
1079
+
1080
+ @ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
1081
+ class Flux2GGUFCheckpointModel(ModelLoader):
1082
+ """Class to load GGUF-quantized FLUX.2 transformer models."""
1083
+
1084
+ def _load_model(
1085
+ self,
1086
+ config: AnyModelConfig,
1087
+ submodel_type: Optional[SubModelType] = None,
1088
+ ) -> AnyModel:
1089
+ if not isinstance(config, Main_GGUF_Flux2_Config):
1090
+ raise ValueError("Only Main_GGUF_Flux2_Config models are currently supported here.")
1091
+
1092
+ match submodel_type:
1093
+ case SubModelType.Transformer:
1094
+ return self._load_from_singlefile(config)
1095
+
1096
+ raise ValueError(
1097
+ f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
1098
+ )
1099
+
1100
+ def _load_from_singlefile(
1101
+ self,
1102
+ config: Main_GGUF_Flux2_Config,
1103
+ ) -> AnyModel:
1104
+ from diffusers import Flux2Transformer2DModel
1105
+
1106
+ model_path = Path(config.path)
1107
+
1108
+ # Load GGUF state dict
1109
+ sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
1110
+
1111
+ # Check if keys have ComfyUI-style prefix and strip if needed
1112
+ prefix_to_strip = None
1113
+ for prefix in ["model.diffusion_model.", "diffusion_model."]:
1114
+ if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
1115
+ prefix_to_strip = prefix
1116
+ break
1117
+
1118
+ if prefix_to_strip:
1119
+ sd = {
1120
+ (k[len(prefix_to_strip) :] if isinstance(k, str) and k.startswith(prefix_to_strip) else k): v
1121
+ for k, v in sd.items()
1122
+ }
1123
+
1124
+ # Convert BFL format state dict to diffusers format
1125
+ converted_sd = self._convert_flux2_bfl_to_diffusers(sd)
1126
+
1127
+ # Detect architecture from checkpoint keys
1128
+ double_block_indices = [
1129
+ int(k.split(".")[1])
1130
+ for k in converted_sd.keys()
1131
+ if isinstance(k, str) and k.startswith("transformer_blocks.")
1132
+ ]
1133
+ single_block_indices = [
1134
+ int(k.split(".")[1])
1135
+ for k in converted_sd.keys()
1136
+ if isinstance(k, str) and k.startswith("single_transformer_blocks.")
1137
+ ]
1138
+
1139
+ num_layers = max(double_block_indices) + 1 if double_block_indices else 5
1140
+ num_single_layers = max(single_block_indices) + 1 if single_block_indices else 20
1141
+
1142
+ # Get dimensions from weights
1143
+ # context_embedder.weight shape: [hidden_size, joint_attention_dim]
1144
+ context_embedder_weight = converted_sd.get("context_embedder.weight")
1145
+ if context_embedder_weight is not None:
1146
+ if hasattr(context_embedder_weight, "tensor_shape"):
1147
+ hidden_size = context_embedder_weight.tensor_shape[0]
1148
+ joint_attention_dim = context_embedder_weight.tensor_shape[1]
1149
+ else:
1150
+ hidden_size = context_embedder_weight.shape[0]
1151
+ joint_attention_dim = context_embedder_weight.shape[1]
1152
+ else:
1153
+ # Default to Klein 4B dimensions
1154
+ hidden_size = 3072
1155
+ joint_attention_dim = 7680
1156
+
1157
+ x_embedder_weight = converted_sd.get("x_embedder.weight")
1158
+ if x_embedder_weight is not None:
1159
+ in_channels = (
1160
+ x_embedder_weight.tensor_shape[1]
1161
+ if hasattr(x_embedder_weight, "tensor_shape")
1162
+ else x_embedder_weight.shape[1]
1163
+ )
1164
+ else:
1165
+ in_channels = 128
1166
+
1167
+ # Calculate num_attention_heads from hidden_size
1168
+ # Klein 4B: hidden_size=3072, num_attention_heads=24 (3072/128=24)
1169
+ # Klein 9B: hidden_size=4096, num_attention_heads=32 (4096/128=32)
1170
+ attention_head_dim = 128
1171
+ num_attention_heads = hidden_size // attention_head_dim
1172
+
1173
+ # Klein models don't have guidance embeddings - check if they're in the checkpoint
1174
+ has_guidance = "time_guidance_embed.guidance_embedder.linear_1.weight" in converted_sd
1175
+
1176
+ # Create model with detected configuration
1177
+ with SilenceWarnings():
1178
+ with accelerate.init_empty_weights():
1179
+ model = Flux2Transformer2DModel(
1180
+ in_channels=in_channels,
1181
+ out_channels=in_channels,
1182
+ num_layers=num_layers,
1183
+ num_single_layers=num_single_layers,
1184
+ attention_head_dim=attention_head_dim,
1185
+ num_attention_heads=num_attention_heads,
1186
+ joint_attention_dim=joint_attention_dim,
1187
+ patch_size=1,
1188
+ )
1189
+
1190
+ # If Klein model without guidance, initialize guidance embedder with zeros
1191
+ if not has_guidance:
1192
+ timestep_linear1 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_1.weight")
1193
+ if timestep_linear1 is not None:
1194
+ in_features = (
1195
+ timestep_linear1.tensor_shape[1]
1196
+ if hasattr(timestep_linear1, "tensor_shape")
1197
+ else timestep_linear1.shape[1]
1198
+ )
1199
+ out_features = (
1200
+ timestep_linear1.tensor_shape[0]
1201
+ if hasattr(timestep_linear1, "tensor_shape")
1202
+ else timestep_linear1.shape[0]
1203
+ )
1204
+ converted_sd["time_guidance_embed.guidance_embedder.linear_1.weight"] = torch.zeros(
1205
+ out_features, in_features, dtype=torch.bfloat16
1206
+ )
1207
+ timestep_linear2 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_2.weight")
1208
+ if timestep_linear2 is not None:
1209
+ in_features2 = (
1210
+ timestep_linear2.tensor_shape[1]
1211
+ if hasattr(timestep_linear2, "tensor_shape")
1212
+ else timestep_linear2.shape[1]
1213
+ )
1214
+ out_features2 = (
1215
+ timestep_linear2.tensor_shape[0]
1216
+ if hasattr(timestep_linear2, "tensor_shape")
1217
+ else timestep_linear2.shape[0]
1218
+ )
1219
+ converted_sd["time_guidance_embed.guidance_embedder.linear_2.weight"] = torch.zeros(
1220
+ out_features2, in_features2, dtype=torch.bfloat16
1221
+ )
1222
+
1223
+ model.load_state_dict(converted_sd, assign=True)
1224
+ return model
1225
+
1226
+ def _convert_flux2_bfl_to_diffusers(self, sd: dict) -> dict:
1227
+ """Convert FLUX.2 BFL format state dict to diffusers format."""
1228
+ converted = {}
1229
+
1230
+ key_renames = {
1231
+ "img_in.weight": "x_embedder.weight",
1232
+ "txt_in.weight": "context_embedder.weight",
1233
+ "time_in.in_layer.weight": "time_guidance_embed.timestep_embedder.linear_1.weight",
1234
+ "time_in.out_layer.weight": "time_guidance_embed.timestep_embedder.linear_2.weight",
1235
+ "guidance_in.in_layer.weight": "time_guidance_embed.guidance_embedder.linear_1.weight",
1236
+ "guidance_in.out_layer.weight": "time_guidance_embed.guidance_embedder.linear_2.weight",
1237
+ "double_stream_modulation_img.lin.weight": "double_stream_modulation_img.linear.weight",
1238
+ "double_stream_modulation_txt.lin.weight": "double_stream_modulation_txt.linear.weight",
1239
+ "single_stream_modulation.lin.weight": "single_stream_modulation.linear.weight",
1240
+ "final_layer.linear.weight": "proj_out.weight",
1241
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
1242
+ }
1243
+
1244
+ for old_key, tensor in sd.items():
1245
+ new_key = old_key
1246
+
1247
+ if old_key in key_renames:
1248
+ new_key = key_renames[old_key]
1249
+ if old_key == "final_layer.adaLN_modulation.1.weight":
1250
+ tensor = self._swap_scale_shift(tensor)
1251
+ converted[new_key] = tensor
1252
+ continue
1253
+
1254
+ if old_key.startswith("double_blocks."):
1255
+ new_key = self._convert_double_block_key(old_key, tensor, converted)
1256
+ if new_key is None:
1257
+ continue
1258
+ elif old_key.startswith("single_blocks."):
1259
+ new_key = self._convert_single_block_key(old_key, tensor, converted)
1260
+ if new_key is None:
1261
+ continue
1262
+
1263
+ if new_key != old_key or new_key not in converted:
1264
+ converted[new_key] = tensor
1265
+
1266
+ return converted
1267
+
1268
+ def _convert_double_block_key(self, key: str, tensor, converted: dict) -> str | None:
1269
+ parts = key.split(".")
1270
+ block_idx = parts[1]
1271
+ rest = ".".join(parts[2:])
1272
+ prefix = f"transformer_blocks.{block_idx}"
1273
+
1274
+ if "img_attn.qkv.weight" in rest:
1275
+ q, k, v = self._chunk_tensor(tensor, 3)
1276
+ converted[f"{prefix}.attn.to_q.weight"] = q
1277
+ converted[f"{prefix}.attn.to_k.weight"] = k
1278
+ converted[f"{prefix}.attn.to_v.weight"] = v
1279
+ return None
1280
+ elif "txt_attn.qkv.weight" in rest:
1281
+ q, k, v = self._chunk_tensor(tensor, 3)
1282
+ converted[f"{prefix}.attn.add_q_proj.weight"] = q
1283
+ converted[f"{prefix}.attn.add_k_proj.weight"] = k
1284
+ converted[f"{prefix}.attn.add_v_proj.weight"] = v
1285
+ return None
1286
+
1287
+ if "img_attn.proj.weight" in rest:
1288
+ return f"{prefix}.attn.to_out.0.weight"
1289
+ elif "txt_attn.proj.weight" in rest:
1290
+ return f"{prefix}.attn.to_add_out.weight"
1291
+
1292
+ if "img_attn.norm.query_norm.scale" in rest:
1293
+ return f"{prefix}.attn.norm_q.weight"
1294
+ elif "img_attn.norm.key_norm.scale" in rest:
1295
+ return f"{prefix}.attn.norm_k.weight"
1296
+ elif "txt_attn.norm.query_norm.scale" in rest:
1297
+ return f"{prefix}.attn.norm_added_q.weight"
1298
+ elif "txt_attn.norm.key_norm.scale" in rest:
1299
+ return f"{prefix}.attn.norm_added_k.weight"
1300
+
1301
+ if "img_mlp.0.weight" in rest:
1302
+ return f"{prefix}.ff.linear_in.weight"
1303
+ elif "img_mlp.2.weight" in rest:
1304
+ return f"{prefix}.ff.linear_out.weight"
1305
+ elif "txt_mlp.0.weight" in rest:
1306
+ return f"{prefix}.ff_context.linear_in.weight"
1307
+ elif "txt_mlp.2.weight" in rest:
1308
+ return f"{prefix}.ff_context.linear_out.weight"
1309
+
1310
+ return key
1311
+
1312
+ def _convert_single_block_key(self, key: str, tensor, converted: dict) -> str | None:
1313
+ parts = key.split(".")
1314
+ block_idx = parts[1]
1315
+ rest = ".".join(parts[2:])
1316
+ prefix = f"single_transformer_blocks.{block_idx}"
1317
+
1318
+ if "linear1.weight" in rest:
1319
+ return f"{prefix}.attn.to_qkv_mlp_proj.weight"
1320
+ elif "linear2.weight" in rest:
1321
+ return f"{prefix}.attn.to_out.weight"
1322
+
1323
+ if "norm.query_norm.scale" in rest:
1324
+ return f"{prefix}.attn.norm_q.weight"
1325
+ elif "norm.key_norm.scale" in rest:
1326
+ return f"{prefix}.attn.norm_k.weight"
1327
+
1328
+ return key
1329
+
1330
+ def _chunk_tensor(self, tensor, chunks: int):
1331
+ """Chunk a tensor, handling both regular tensors and GGUF quantized tensors."""
1332
+ if hasattr(tensor, "get_dequantized_tensor"):
1333
+ # GGUF quantized tensor - dequantize first, then chunk
1334
+ # This loses quantization for the split weights, but is necessary
1335
+ # because diffusers uses separate Q/K/V projections
1336
+ tensor = tensor.get_dequantized_tensor()
1337
+ return tensor.chunk(chunks, dim=0)
1338
+
1339
+ def _swap_scale_shift(self, weight) -> torch.Tensor:
1340
+ """Swap scale and shift in AdaLayerNorm weights."""
1341
+ if hasattr(weight, "get_dequantized_tensor"):
1342
+ # For GGUF, dequantize first
1343
+ weight = weight.get_dequantized_tensor()
1344
+ shift, scale = weight.chunk(2, dim=0)
1345
+ return torch.cat([scale, shift], dim=0)
1346
+
1347
+
336
1348
  @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
337
1349
  @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
338
1350
  class FluxControlnetModel(ModelLoader):