InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl
This diff covers publicly released package versions published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
@@ -34,7 +34,7 @@ from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
 from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params
-from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
 from invokeai.backend.model_manager.configs.clip_embed import CLIPEmbed_Diffusers_Config_Base
 from invokeai.backend.model_manager.configs.controlnet import (
     ControlNet_Checkpoint_Config_Base,
@@ -45,13 +45,16 @@ from invokeai.backend.model_manager.configs.flux_redux import FLUXRedux_Checkpoi
 from invokeai.backend.model_manager.configs.ip_adapter import IPAdapter_Checkpoint_Config_Base
 from invokeai.backend.model_manager.configs.main import (
     Main_BnBNF4_FLUX_Config,
+    Main_Checkpoint_Flux2_Config,
     Main_Checkpoint_FLUX_Config,
+    Main_GGUF_Flux2_Config,
     Main_GGUF_FLUX_Config,
 )
 from invokeai.backend.model_manager.configs.t5_encoder import T5Encoder_BnBLLMint8_Config, T5Encoder_T5Encoder_Config
-from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base
+from invokeai.backend.model_manager.configs.vae import VAE_Checkpoint_Config_Base, VAE_Checkpoint_Flux2_Config
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
 from invokeai.backend.model_manager.taxonomy import (
     AnyModel,
     BaseModelType,
@@ -108,6 +111,264 @@ class FluxVAELoader(ModelLoader):
         return model
 
 
+@ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.VAE, format=ModelFormat.Diffusers)
+class Flux2VAEDiffusersLoader(ModelLoader):
+    """Class to load FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2 with 32 latent channels)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        from diffusers import AutoencoderKLFlux2
+
+        model_path = Path(config.path)
+
+        # VAE is broken in float16, which mps defaults to
+        if self._torch_dtype == torch.float16:
+            try:
+                vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+            except TypeError:
+                vae_dtype = torch.float32
+        else:
+            vae_dtype = self._torch_dtype
+
+        model = AutoencoderKLFlux2.from_pretrained(
+            model_path,
+            torch_dtype=vae_dtype,
+            local_files_only=True,
+        )
+
+        return model
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
+class Flux2VAELoader(ModelLoader):
+    """Class to load FLUX.2 VAE models (AutoencoderKLFlux2 with 32 latent channels)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, VAE_Checkpoint_Flux2_Config):
+            raise ValueError("Only VAE_Checkpoint_Flux2_Config models are currently supported here.")
+
+        from diffusers import AutoencoderKLFlux2
+
+        model_path = Path(config.path)
+
+        # Load state dict manually since from_single_file may not support AutoencoderKLFlux2 yet
+        sd = load_file(model_path)
+
+        # Convert BFL format to diffusers format if needed
+        # BFL format uses: encoder.down., decoder.up., decoder.mid.block_1, decoder.mid.attn_1, decoder.norm_out
+        # Diffusers uses: encoder.down_blocks., decoder.up_blocks., decoder.mid_block.resnets., decoder.conv_norm_out
+        is_bfl_format = any(
+            k.startswith("encoder.down.")
+            or k.startswith("decoder.up.")
+            or k.startswith("decoder.mid.block_")
+            or k.startswith("decoder.mid.attn_")
+            or k.startswith("decoder.norm_out")
+            or k.startswith("encoder.mid.block_")
+            or k.startswith("encoder.mid.attn_")
+            or k.startswith("encoder.norm_out")
+            for k in sd.keys()
+        )
+        if is_bfl_format:
+            sd = self._convert_flux2_vae_bfl_to_diffusers(sd)
+
+        # FLUX.2 VAE configuration (32 latent channels)
+        # Based on the official FLUX.2 VAE architecture
+        # Use default config - AutoencoderKLFlux2 has built-in defaults
+        with SilenceWarnings():
+            with accelerate.init_empty_weights():
+                model = AutoencoderKLFlux2()
+
+            # Convert to bfloat16 and load
+            for k in sd.keys():
+                sd[k] = sd[k].to(torch.bfloat16)
+
+            model.load_state_dict(sd, assign=True)
+
+            # VAE is broken in float16, which mps defaults to
+            if self._torch_dtype == torch.float16:
+                try:
+                    vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
+                except TypeError:
+                    vae_dtype = torch.float32
+            else:
+                vae_dtype = self._torch_dtype
+            model.to(vae_dtype)
+
+        return model
+
+    def _convert_flux2_vae_bfl_to_diffusers(self, sd: dict) -> dict:
+        """Convert FLUX.2 VAE BFL format state dict to diffusers format.
+
+        Key differences:
+        - encoder.down.X.block.Y -> encoder.down_blocks.X.resnets.Y
+        - encoder.down.X.downsample.conv -> encoder.down_blocks.X.downsamplers.0.conv
+        - encoder.mid.block_1/2 -> encoder.mid_block.resnets.0/1
+        - encoder.mid.attn_1.q/k/v -> encoder.mid_block.attentions.0.to_q/k/v
+        - encoder.norm_out -> encoder.conv_norm_out
+        - encoder.quant_conv -> quant_conv (top-level)
+        - decoder.up.X -> decoder.up_blocks.(num_blocks-1-X) (reversed order!)
+        - decoder.post_quant_conv -> post_quant_conv (top-level)
+        - *.nin_shortcut -> *.conv_shortcut
+        """
+        import re
+
+        converted = {}
+        num_up_blocks = 4  # Standard VAE has 4 up blocks
+
+        for old_key, tensor in sd.items():
+            new_key = old_key
+
+            # Encoder down blocks: encoder.down.X.block.Y -> encoder.down_blocks.X.resnets.Y
+            match = re.match(r"encoder\.down\.(\d+)\.block\.(\d+)\.(.*)", old_key)
+            if match:
+                block_idx, resnet_idx, rest = match.groups()
+                rest = rest.replace("nin_shortcut", "conv_shortcut")
+                new_key = f"encoder.down_blocks.{block_idx}.resnets.{resnet_idx}.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Encoder downsamplers: encoder.down.X.downsample.conv -> encoder.down_blocks.X.downsamplers.0.conv
+            match = re.match(r"encoder\.down\.(\d+)\.downsample\.conv\.(.*)", old_key)
+            if match:
+                block_idx, rest = match.groups()
+                new_key = f"encoder.down_blocks.{block_idx}.downsamplers.0.conv.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Encoder mid block resnets: encoder.mid.block_1/2 -> encoder.mid_block.resnets.0/1
+            match = re.match(r"encoder\.mid\.block_(\d+)\.(.*)", old_key)
+            if match:
+                block_num, rest = match.groups()
+                resnet_idx = int(block_num) - 1  # block_1 -> resnets.0, block_2 -> resnets.1
+                new_key = f"encoder.mid_block.resnets.{resnet_idx}.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Encoder mid block attention: encoder.mid.attn_1.* -> encoder.mid_block.attentions.0.*
+            match = re.match(r"encoder\.mid\.attn_1\.(.*)", old_key)
+            if match:
+                rest = match.group(1)
+                # Map attention keys
+                # BFL uses Conv2d (shape [out, in, 1, 1]), diffusers uses Linear (shape [out, in])
+                # Squeeze the extra dimensions for weight tensors
+                if rest.startswith("q."):
+                    new_key = f"encoder.mid_block.attentions.0.to_q.{rest[2:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("k."):
+                    new_key = f"encoder.mid_block.attentions.0.to_k.{rest[2:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("v."):
+                    new_key = f"encoder.mid_block.attentions.0.to_v.{rest[2:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("proj_out."):
+                    new_key = f"encoder.mid_block.attentions.0.to_out.0.{rest[9:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("norm."):
+                    new_key = f"encoder.mid_block.attentions.0.group_norm.{rest[5:]}"
+                else:
+                    new_key = f"encoder.mid_block.attentions.0.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Encoder norm_out -> conv_norm_out
+            if old_key.startswith("encoder.norm_out."):
+                new_key = old_key.replace("encoder.norm_out.", "encoder.conv_norm_out.")
+                converted[new_key] = tensor
+                continue
+
+            # Encoder quant_conv -> quant_conv (move to top level)
+            if old_key.startswith("encoder.quant_conv."):
+                new_key = old_key.replace("encoder.quant_conv.", "quant_conv.")
+                converted[new_key] = tensor
+                continue
+
+            # Decoder up blocks (reversed order!): decoder.up.X -> decoder.up_blocks.(num_blocks-1-X)
+            match = re.match(r"decoder\.up\.(\d+)\.block\.(\d+)\.(.*)", old_key)
+            if match:
+                block_idx, resnet_idx, rest = match.groups()
+                # Reverse the block index
+                new_block_idx = num_up_blocks - 1 - int(block_idx)
+                rest = rest.replace("nin_shortcut", "conv_shortcut")
+                new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Decoder upsamplers (reversed order!)
+            match = re.match(r"decoder\.up\.(\d+)\.upsample\.conv\.(.*)", old_key)
+            if match:
+                block_idx, rest = match.groups()
+                new_block_idx = num_up_blocks - 1 - int(block_idx)
+                new_key = f"decoder.up_blocks.{new_block_idx}.upsamplers.0.conv.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Decoder mid block resnets: decoder.mid.block_1/2 -> decoder.mid_block.resnets.0/1
+            match = re.match(r"decoder\.mid\.block_(\d+)\.(.*)", old_key)
+            if match:
+                block_num, rest = match.groups()
+                resnet_idx = int(block_num) - 1
+                new_key = f"decoder.mid_block.resnets.{resnet_idx}.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Decoder mid block attention: decoder.mid.attn_1.* -> decoder.mid_block.attentions.0.*
+            match = re.match(r"decoder\.mid\.attn_1\.(.*)", old_key)
+            if match:
+                rest = match.group(1)
+                # BFL uses Conv2d (shape [out, in, 1, 1]), diffusers uses Linear (shape [out, in])
+                # Squeeze the extra dimensions for weight tensors
+                if rest.startswith("q."):
+                    new_key = f"decoder.mid_block.attentions.0.to_q.{rest[2:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("k."):
+                    new_key = f"decoder.mid_block.attentions.0.to_k.{rest[2:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("v."):
+                    new_key = f"decoder.mid_block.attentions.0.to_v.{rest[2:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("proj_out."):
+                    new_key = f"decoder.mid_block.attentions.0.to_out.0.{rest[9:]}"
+                    if rest.endswith(".weight") and tensor.dim() == 4:
+                        tensor = tensor.squeeze(-1).squeeze(-1)
+                elif rest.startswith("norm."):
+                    new_key = f"decoder.mid_block.attentions.0.group_norm.{rest[5:]}"
+                else:
+                    new_key = f"decoder.mid_block.attentions.0.{rest}"
+                converted[new_key] = tensor
+                continue
+
+            # Decoder norm_out -> conv_norm_out
+            if old_key.startswith("decoder.norm_out."):
+                new_key = old_key.replace("decoder.norm_out.", "decoder.conv_norm_out.")
+                converted[new_key] = tensor
+                continue
+
+            # Decoder post_quant_conv -> post_quant_conv (move to top level)
+            if old_key.startswith("decoder.post_quant_conv."):
+                new_key = old_key.replace("decoder.post_quant_conv.", "post_quant_conv.")
+                converted[new_key] = tensor
+                continue
+
+            # Keep other keys as-is (like encoder.conv_in, decoder.conv_in, decoder.conv_out, bn.*)
+            converted[new_key] = tensor
+
+        return converted
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers)
 class CLIPDiffusersLoader(ModelLoader):
     """Class to load main models."""
@@ -122,9 +383,9 @@ class CLIPDiffusersLoader(ModelLoader):
 
         match submodel_type:
             case SubModelType.Tokenizer:
-                return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
+                return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer", local_files_only=True)
             case SubModelType.TextEncoder:
-                return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
+                return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder", local_files_only=True)
 
         raise ValueError(
             f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -148,10 +409,12 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
         )
         match submodel_type:
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
-                return T5TokenizerFast.from_pretrained(
+                return T5TokenizerFast.from_pretrained(
+                    Path(config.path) / "tokenizer_2", max_length=512, local_files_only=True
+                )
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 te2_model_path = Path(config.path) / "text_encoder_2"
-                model_config = AutoConfig.from_pretrained(te2_model_path)
+                model_config = AutoConfig.from_pretrained(te2_model_path, local_files_only=True)
                 with accelerate.init_empty_weights():
                     model = AutoModelForTextEncoding.from_config(model_config)
                     model = quantize_model_llm_int8(model, modules_to_not_convert=set())
@@ -192,10 +455,15 @@ class T5EncoderCheckpointModel(ModelLoader):
 
         match submodel_type:
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
-                return T5TokenizerFast.from_pretrained(
+                return T5TokenizerFast.from_pretrained(
+                    Path(config.path) / "tokenizer_2", max_length=512, local_files_only=True
+                )
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 return T5EncoderModel.from_pretrained(
-                    Path(config.path) / "text_encoder_2",
+                    Path(config.path) / "text_encoder_2",
+                    torch_dtype="auto",
+                    low_cpu_mem_usage=True,
+                    local_files_only=True,
                 )
 
         raise ValueError(
@@ -333,6 +601,750 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
         return model
 
 
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Diffusers)
+class FluxDiffusersModel(GenericDiffusersLoader):
+    """Class to load FLUX.1 main models in diffusers format."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if isinstance(config, Checkpoint_Config_Base):
+            raise NotImplementedError("CheckpointConfigBase is not implemented for FLUX diffusers models.")
+
+        if submodel_type is None:
+            raise Exception("A submodel type must be provided when loading main pipelines.")
+
+        model_path = Path(config.path)
+        load_class = self.get_hf_load_class(model_path, submodel_type)
+        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
+        variant = repo_variant.value if repo_variant else None
+        model_path = model_path / submodel_type.value
+
+        # We force bfloat16 for FLUX models. This is required for correct inference.
+        dtype = torch.bfloat16
+        try:
+            result: AnyModel = load_class.from_pretrained(
+                model_path,
+                torch_dtype=dtype,
+                variant=variant,
+                local_files_only=True,
+            )
+        except OSError as e:
+            if variant and "no file named" in str(
+                e
+            ):  # try without the variant, just in case user's preferences changed
+                result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
+            else:
+                raise e
+
+        return result
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.Main, format=ModelFormat.Diffusers)
+class Flux2DiffusersModel(GenericDiffusersLoader):
+    """Class to load FLUX.2 main models in diffusers format (e.g. FLUX.2 Klein)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if isinstance(config, Checkpoint_Config_Base):
+            raise NotImplementedError("CheckpointConfigBase is not implemented for FLUX.2 diffusers models.")
+
+        if submodel_type is None:
+            raise Exception("A submodel type must be provided when loading main pipelines.")
+
+        model_path = Path(config.path)
+        load_class = self.get_hf_load_class(model_path, submodel_type)
+        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
+        variant = repo_variant.value if repo_variant else None
+        model_path = model_path / submodel_type.value
+
+        # We force bfloat16 for FLUX.2 models. This is required for correct inference.
+        # We use low_cpu_mem_usage=False to avoid meta tensors for weights not in checkpoint.
+        # FLUX.2 Klein models may have guidance_embeds=False, so the guidance_embed layers
+        # won't be in the checkpoint but the model class still creates them.
+        # We use SilenceWarnings to suppress the "guidance_embeds is not expected" warning
+        # from diffusers Flux2Transformer2DModel.
+        dtype = torch.bfloat16
+        with SilenceWarnings():
+            try:
+                result: AnyModel = load_class.from_pretrained(
+                    model_path,
+                    torch_dtype=dtype,
+                    variant=variant,
+                    local_files_only=True,
+                    low_cpu_mem_usage=False,
+                )
+            except OSError as e:
+                if variant and "no file named" in str(
+                    e
+                ):  # try without the variant, just in case user's preferences changed
+                    result = load_class.from_pretrained(
+                        model_path,
+                        torch_dtype=dtype,
+                        local_files_only=True,
+                        low_cpu_mem_usage=False,
+                    )
+                else:
+                    raise e
+
+        # For Klein models without guidance_embeds, zero out the guidance_embedder weights
+        # that were randomly initialized by diffusers. This prevents noise from affecting
+        # the time embeddings.
+        if submodel_type == SubModelType.Transformer and hasattr(result, "time_guidance_embed"):
+            # Check if this is a Klein model without guidance (guidance_embeds=False in config)
+            transformer_config_path = model_path / "config.json"
+            if transformer_config_path.exists():
+                import json
+
+                with open(transformer_config_path, "r") as f:
+                    transformer_config = json.load(f)
+                if not transformer_config.get("guidance_embeds", True):
+                    # Zero out the guidance embedder weights
+                    guidance_emb = result.time_guidance_embed.guidance_embedder
+                    if hasattr(guidance_emb, "linear_1"):
+                        guidance_emb.linear_1.weight.data.zero_()
+                        if guidance_emb.linear_1.bias is not None:
+                            guidance_emb.linear_1.bias.data.zero_()
+                    if hasattr(guidance_emb, "linear_2"):
+                        guidance_emb.linear_2.weight.data.zero_()
+                        if guidance_emb.linear_2.bias is not None:
+                            guidance_emb.linear_2.bias.data.zero_()
+
+        return result
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.Main, format=ModelFormat.Checkpoint)
+class Flux2CheckpointModel(ModelLoader):
+    """Class to load FLUX.2 transformer models from single-file checkpoints (safetensors)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, Checkpoint_Config_Base):
+            raise ValueError("Only CheckpointConfigBase models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config)
+
+        raise ValueError(
+            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+        )
+
+    def _load_from_singlefile(
+        self,
+        config: AnyModelConfig,
+    ) -> AnyModel:
+        from diffusers import Flux2Transformer2DModel
+
+        if not isinstance(config, Main_Checkpoint_Flux2_Config):
+            raise TypeError(
+                f"Expected Main_Checkpoint_Flux2_Config, got {type(config).__name__}. "
+                "Model configuration type mismatch."
+            )
+        model_path = Path(config.path)
+
+        # Load state dict
+        sd = load_file(model_path)
+
+        # Handle FP8 quantized weights (ComfyUI-style or scaled FP8)
+        # These store weights as: layer.weight (FP8) + layer.weight_scale (FP32 scalar)
+        sd = self._dequantize_fp8_weights(sd)
+
+        # Check if keys have ComfyUI-style prefix and strip if needed
+        prefix_to_strip = None
+        for prefix in ["model.diffusion_model.", "diffusion_model."]:
+            if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
+                prefix_to_strip = prefix
+                break
+
+        if prefix_to_strip:
+            sd = {
+                (k[len(prefix_to_strip) :] if isinstance(k, str) and k.startswith(prefix_to_strip) else k): v
+                for k, v in sd.items()
+            }
+
+        # Convert BFL format state dict to diffusers format
+        converted_sd = self._convert_flux2_bfl_to_diffusers(sd)
+
+        # Detect architecture from checkpoint keys
+        double_block_indices = [
+            int(k.split(".")[1])
+            for k in converted_sd.keys()
+            if isinstance(k, str) and k.startswith("transformer_blocks.")
+        ]
+        single_block_indices = [
+            int(k.split(".")[1])
+            for k in converted_sd.keys()
+            if isinstance(k, str) and k.startswith("single_transformer_blocks.")
+        ]
+
+        num_layers = max(double_block_indices) + 1 if double_block_indices else 5
+        num_single_layers = max(single_block_indices) + 1 if single_block_indices else 20
+
+        # Get dimensions from weights
+        # context_embedder.weight shape: [hidden_size, joint_attention_dim]
+        context_embedder_weight = converted_sd.get("context_embedder.weight")
+        if context_embedder_weight is not None:
+            hidden_size = context_embedder_weight.shape[0]
+            joint_attention_dim = context_embedder_weight.shape[1]
+        else:
+            # Default to Klein 4B dimensions
+            hidden_size = 3072
+            joint_attention_dim = 7680
+
+        x_embedder_weight = converted_sd.get("x_embedder.weight")
+        if x_embedder_weight is not None:
+            in_channels = x_embedder_weight.shape[1]
+        else:
+            in_channels = 128
+
+        # Calculate num_attention_heads from hidden_size
+        # Klein 4B: hidden_size=3072, num_attention_heads=24 (3072/128=24)
+        # Klein 9B: hidden_size=4096, num_attention_heads=32 (4096/128=32)
+        attention_head_dim = 128
+        num_attention_heads = hidden_size // attention_head_dim
+
+        # Klein models don't have guidance embeddings - check if they're in the checkpoint
+        has_guidance = "time_guidance_embed.guidance_embedder.linear_1.weight" in converted_sd
+
+        # Create model with detected configuration
+        with SilenceWarnings():
+            with accelerate.init_empty_weights():
+                model = Flux2Transformer2DModel(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    num_layers=num_layers,
+                    num_single_layers=num_single_layers,
+                    attention_head_dim=attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    joint_attention_dim=joint_attention_dim,
+                    patch_size=1,
+                )
+
+        # If Klein model without guidance, initialize guidance embedder with zeros
+        if not has_guidance:
+            # Get the expected dimensions from timestep embedder (they should match)
+            timestep_linear1 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_1.weight")
+            if timestep_linear1 is not None:
+                in_features = timestep_linear1.shape[1]
+                out_features = timestep_linear1.shape[0]
+                # Initialize guidance embedder with same shape as timestep embedder
+                converted_sd["time_guidance_embed.guidance_embedder.linear_1.weight"] = torch.zeros(
+                    out_features, in_features, dtype=torch.bfloat16
+                )
+            timestep_linear2 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_2.weight")
+            if timestep_linear2 is not None:
+                in_features2 = timestep_linear2.shape[1]
+                out_features2 = timestep_linear2.shape[0]
+                converted_sd["time_guidance_embed.guidance_embedder.linear_2.weight"] = torch.zeros(
+                    out_features2, in_features2, dtype=torch.bfloat16
+                )
+
+        # Convert to bfloat16 and load
+        for k in converted_sd.keys():
+            converted_sd[k] = converted_sd[k].to(torch.bfloat16)
+
+        # Load the state dict - guidance weights were already initialized above if missing
+        model.load_state_dict(converted_sd, assign=True)
+
+        return model
+
+    def _convert_flux2_bfl_to_diffusers(self, sd: dict) -> dict:
+        """Convert FLUX.2 BFL format state dict to diffusers format.
+
+        Based on diffusers convert_flux2_to_diffusers.py key mappings.
+        """
+        converted = {}
+
+        # Basic key renames
+        key_renames = {
+            "img_in.weight": "x_embedder.weight",
+            "txt_in.weight": "context_embedder.weight",
+            "time_in.in_layer.weight": "time_guidance_embed.timestep_embedder.linear_1.weight",
+            "time_in.out_layer.weight": "time_guidance_embed.timestep_embedder.linear_2.weight",
+            "guidance_in.in_layer.weight": "time_guidance_embed.guidance_embedder.linear_1.weight",
+            "guidance_in.out_layer.weight": "time_guidance_embed.guidance_embedder.linear_2.weight",
+            "double_stream_modulation_img.lin.weight": "double_stream_modulation_img.linear.weight",
+            "double_stream_modulation_txt.lin.weight": "double_stream_modulation_txt.linear.weight",
+            "single_stream_modulation.lin.weight": "single_stream_modulation.linear.weight",
+            "final_layer.linear.weight": "proj_out.weight",
+            "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+        }
+
+        for old_key, tensor in sd.items():
+            new_key = old_key
+
+            # Apply basic renames
+            if old_key in key_renames:
+                new_key = key_renames[old_key]
+                # Apply scale-shift swap for adaLN modulation weights
+                # BFL and diffusers use different parameter ordering for AdaLayerNorm
+                if old_key == "final_layer.adaLN_modulation.1.weight":
+                    tensor = self._swap_scale_shift(tensor)
+                converted[new_key] = tensor
+                continue
+
+            # Convert double_blocks.X.* to transformer_blocks.X.*
+            if old_key.startswith("double_blocks."):
+                new_key = self._convert_double_block_key(old_key, tensor, converted)
+                if new_key is None:
+                    continue  # Key was handled specially
+            # Convert single_blocks.X.* to single_transformer_blocks.X.*
+            elif old_key.startswith("single_blocks."):
+                new_key = self._convert_single_block_key(old_key, tensor, converted)
+                if new_key is None:
+                    continue  # Key was handled specially
+
+            if new_key != old_key or new_key not in converted:
+                converted[new_key] = tensor
+
+        return converted
+
+    def _convert_double_block_key(self, key: str, tensor: torch.Tensor, converted: dict) -> str | None:
+        """Convert double_blocks key to transformer_blocks format."""
+        parts = key.split(".")
+        block_idx = parts[1]
+        rest = ".".join(parts[2:])
+
+        prefix = f"transformer_blocks.{block_idx}"
+
+        # Attention QKV conversion - BFL uses fused qkv, diffusers uses separate
+        if "img_attn.qkv.weight" in rest:
+            # Split fused QKV into separate Q, K, V
+            # Defensive check: ensure tensor has at least 1 dimension and can be split into 3
+            if tensor.dim() < 1 or tensor.shape[0] % 3 != 0:
+                # Skip malformed tensors (might be metadata or corrupted)
+                return key
+            q, k, v = tensor.chunk(3, dim=0)
+            converted[f"{prefix}.attn.to_q.weight"] = q
+            converted[f"{prefix}.attn.to_k.weight"] = k
+            converted[f"{prefix}.attn.to_v.weight"] = v
+            return None
+        elif "txt_attn.qkv.weight" in rest:
+            # Defensive check
+            if tensor.dim() < 1 or tensor.shape[0] % 3 != 0:
+                return key
+            q, k, v = tensor.chunk(3, dim=0)
+            converted[f"{prefix}.attn.add_q_proj.weight"] = q
+            converted[f"{prefix}.attn.add_k_proj.weight"] = k
+            converted[f"{prefix}.attn.add_v_proj.weight"] = v
+            return None
+
+        # Attention output projection
+        if "img_attn.proj.weight" in rest:
+            return f"{prefix}.attn.to_out.0.weight"
+        elif "txt_attn.proj.weight" in rest:
+            return f"{prefix}.attn.to_add_out.weight"
+
+        # Attention norms
+        if "img_attn.norm.query_norm.scale" in rest:
+            return f"{prefix}.attn.norm_q.weight"
+        elif "img_attn.norm.key_norm.scale" in rest:
+            return f"{prefix}.attn.norm_k.weight"
+        elif "txt_attn.norm.query_norm.scale" in rest:
+            return f"{prefix}.attn.norm_added_q.weight"
+        elif "txt_attn.norm.key_norm.scale" in rest:
+            return f"{prefix}.attn.norm_added_k.weight"
+
+        # MLP layers
+        if "img_mlp.0.weight" in rest:
+            return f"{prefix}.ff.linear_in.weight"
+        elif "img_mlp.2.weight" in rest:
+            return f"{prefix}.ff.linear_out.weight"
+        elif "txt_mlp.0.weight" in rest:
+            return f"{prefix}.ff_context.linear_in.weight"
+        elif "txt_mlp.2.weight" in rest:
+            return f"{prefix}.ff_context.linear_out.weight"
+
+        return key
+
+    def _convert_single_block_key(self, key: str, tensor: torch.Tensor, converted: dict) -> str | None:
+        """Convert single_blocks key to single_transformer_blocks format."""
+        parts = key.split(".")
+        block_idx = parts[1]
+        rest = ".".join(parts[2:])
+
+        prefix = f"single_transformer_blocks.{block_idx}"
+
+        # linear1 is the fused QKV+MLP projection
+        if "linear1.weight" in rest:
+            return f"{prefix}.attn.to_qkv_mlp_proj.weight"
+        elif "linear2.weight" in rest:
+            return f"{prefix}.attn.to_out.weight"
+
+        # Norms
+        if "norm.query_norm.scale" in rest:
+            return f"{prefix}.attn.norm_q.weight"
+        elif "norm.key_norm.scale" in rest:
+            return f"{prefix}.attn.norm_k.weight"
+
+        return key
+
+    def _swap_scale_shift(self, weight: torch.Tensor) -> torch.Tensor:
+        """Swap scale and shift in AdaLayerNorm weights.
+
+        BFL and diffusers use different parameter ordering for AdaLayerNorm.
+        This function swaps the two halves of the weight tensor.
+
+        Args:
+            weight: Weight tensor of shape (out_features,) or (out_features, in_features)
+
+        Returns:
+            Weight tensor with scale and shift swapped.
+        """
+        # Defensive check: ensure tensor can be split
+        if weight.dim() < 1 or weight.shape[0] % 2 != 0:
+            return weight
+        # Split in half along the first dimension and swap
+        shift, scale = weight.chunk(2, dim=0)
+        return torch.cat([scale, shift], dim=0)
+
+    def _dequantize_fp8_weights(self, sd: dict) -> dict:
+        """Dequantize FP8 quantized weights in the state dict.
+
+        ComfyUI and some FLUX.2 models store quantized weights as:
+        - layer.weight: quantized FP8 data
+        - layer.weight_scale: scale factor (FP32 scalar or per-channel)
+
+        Dequantization formula: dequantized = weight.to(float) * weight_scale
+
+        Also handles FP8 tensors stored with float8_e4m3fn dtype by converting to float.
+        """
+        # Check for ComfyUI-style scale factors
+        weight_scale_keys = [k for k in sd.keys() if isinstance(k, str) and k.endswith(".weight_scale")]
+
+        for scale_key in weight_scale_keys:
+            # Get the corresponding weight key
+            weight_key = scale_key.replace(".weight_scale", ".weight")
+            if weight_key in sd:
+                weight = sd[weight_key]
+                scale = sd[scale_key]
+
+                # Dequantize: convert FP8 to float and multiply by scale
+                # Note: Float8 types require .float() instead of .to(torch.float32)
+                weight_float = weight.float()
+                scale = scale.float()
+
+                # Handle block-wise quantization where scale may have different shape
+                if scale.dim() > 0 and scale.shape != weight_float.shape and scale.numel() > 1:
+                    for dim in range(len(weight_float.shape)):
+                        if dim < len(scale.shape) and scale.shape[dim] != weight_float.shape[dim]:
+                            block_size = weight_float.shape[dim] // scale.shape[dim]
+                            if block_size > 1:
+                                scale = scale.repeat_interleave(block_size, dim=dim)
+
+                sd[weight_key] = weight_float * scale
+
+        # Filter out scale metadata keys and other FP8 metadata
+        keys_to_remove = [
+            k
+            for k in sd.keys()
+            if isinstance(k, str)
+            and (k.endswith(".weight_scale") or k.endswith(".scale_weight") or "comfy_quant" in k or k == "scaled_fp8")
+        ]
+        for k in keys_to_remove:
+            del sd[k]
+
+        # Handle native FP8 tensors (float8_e4m3fn dtype) that aren't already dequantized
+        # Also filter out 0-dimensional tensors (scalars) which are typically metadata
+        keys_to_convert = []
+        keys_to_remove_scalars = []
+        for key in list(sd.keys()):
+            tensor = sd[key]
+            if hasattr(tensor, "dim"):
+                if tensor.dim() == 0:
+                    # 0-dimensional tensor (scalar) - likely metadata, remove it
+                    keys_to_remove_scalars.append(key)
+                elif hasattr(tensor, "dtype") and "float8" in str(tensor.dtype):
+                    # Native FP8 tensor - mark for conversion
+                    keys_to_convert.append(key)
+
+        for k in keys_to_remove_scalars:
+            del sd[k]
+
+        for key in keys_to_convert:
+            # Convert FP8 tensor to float32
+            sd[key] = sd[key].float()
+
+        return sd
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux2, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
+class Flux2GGUFCheckpointModel(ModelLoader):
+    """Class to load GGUF-quantized FLUX.2 transformer models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if not isinstance(config, Main_GGUF_Flux2_Config):
+            raise ValueError("Only Main_GGUF_Flux2_Config models are currently supported here.")
+
+        match submodel_type:
+            case SubModelType.Transformer:
+                return self._load_from_singlefile(config)
+
+        raise ValueError(
+            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
+        )
+
+    def _load_from_singlefile(
+        self,
+        config: Main_GGUF_Flux2_Config,
+    ) -> AnyModel:
+        from diffusers import Flux2Transformer2DModel
+
+        model_path = Path(config.path)
+
+        # Load GGUF state dict
+        sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16)
+
+        # Check if keys have ComfyUI-style prefix and strip if needed
+        prefix_to_strip = None
+        for prefix in ["model.diffusion_model.", "diffusion_model."]:
+            if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
+                prefix_to_strip = prefix
+                break
+
+        if prefix_to_strip:
+            sd = {
+                (k[len(prefix_to_strip) :] if isinstance(k, str) and k.startswith(prefix_to_strip) else k): v
+                for k, v in sd.items()
+            }
+
+        # Convert BFL format state dict to diffusers format
+        converted_sd = self._convert_flux2_bfl_to_diffusers(sd)
+
+        # Detect architecture from checkpoint keys
+        double_block_indices = [
+            int(k.split(".")[1])
+            for k in converted_sd.keys()
+            if isinstance(k, str) and k.startswith("transformer_blocks.")
+        ]
+        single_block_indices = [
+            int(k.split(".")[1])
+            for k in converted_sd.keys()
+            if isinstance(k, str) and k.startswith("single_transformer_blocks.")
+        ]
+
+        num_layers = max(double_block_indices) + 1 if double_block_indices else 5
+        num_single_layers = max(single_block_indices) + 1 if single_block_indices else 20
+
+        # Get dimensions from weights
+        # context_embedder.weight shape: [hidden_size, joint_attention_dim]
+        context_embedder_weight = converted_sd.get("context_embedder.weight")
+        if context_embedder_weight is not None:
+            if hasattr(context_embedder_weight, "tensor_shape"):
+                hidden_size = context_embedder_weight.tensor_shape[0]
+                joint_attention_dim = context_embedder_weight.tensor_shape[1]
+            else:
+                hidden_size = context_embedder_weight.shape[0]
+                joint_attention_dim = context_embedder_weight.shape[1]
+        else:
+            # Default to Klein 4B dimensions
+            hidden_size = 3072
+            joint_attention_dim = 7680
+
+        x_embedder_weight = converted_sd.get("x_embedder.weight")
+        if x_embedder_weight is not None:
+            in_channels = (
+                x_embedder_weight.tensor_shape[1]
+                if hasattr(x_embedder_weight, "tensor_shape")
+                else x_embedder_weight.shape[1]
+            )
+        else:
+            in_channels = 128
+
+        # Calculate num_attention_heads from hidden_size
+        # Klein 4B: hidden_size=3072, num_attention_heads=24 (3072/128=24)
+        # Klein 9B: hidden_size=4096, num_attention_heads=32 (4096/128=32)
+        attention_head_dim = 128
+        num_attention_heads = hidden_size // attention_head_dim
+
+        # Klein models don't have guidance embeddings - check if they're in the checkpoint
+        has_guidance = "time_guidance_embed.guidance_embedder.linear_1.weight" in converted_sd
+
+        # Create model with detected configuration
+        with SilenceWarnings():
+            with accelerate.init_empty_weights():
+                model = Flux2Transformer2DModel(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    num_layers=num_layers,
+                    num_single_layers=num_single_layers,
+                    attention_head_dim=attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    joint_attention_dim=joint_attention_dim,
+                    patch_size=1,
+                )
+
+        # If Klein model without guidance, initialize guidance embedder with zeros
+        if not has_guidance:
+            timestep_linear1 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_1.weight")
+            if timestep_linear1 is not None:
+                in_features = (
+                    timestep_linear1.tensor_shape[1]
+                    if hasattr(timestep_linear1, "tensor_shape")
+                    else timestep_linear1.shape[1]
+                )
+                out_features = (
+                    timestep_linear1.tensor_shape[0]
+                    if hasattr(timestep_linear1, "tensor_shape")
+                    else timestep_linear1.shape[0]
+                )
+                converted_sd["time_guidance_embed.guidance_embedder.linear_1.weight"] = torch.zeros(
+                    out_features, in_features, dtype=torch.bfloat16
+                )
+            timestep_linear2 = converted_sd.get("time_guidance_embed.timestep_embedder.linear_2.weight")
+            if timestep_linear2 is not None:
+                in_features2 = (
+                    timestep_linear2.tensor_shape[1]
+                    if hasattr(timestep_linear2, "tensor_shape")
+                    else timestep_linear2.shape[1]
+                )
+                out_features2 = (
+                    timestep_linear2.tensor_shape[0]
+                    if hasattr(timestep_linear2, "tensor_shape")
+                    else timestep_linear2.shape[0]
+                )
+                converted_sd["time_guidance_embed.guidance_embedder.linear_2.weight"] = torch.zeros(
+                    out_features2, in_features2, dtype=torch.bfloat16
+                )
+
+        model.load_state_dict(converted_sd, assign=True)
+        return model
+
+    def _convert_flux2_bfl_to_diffusers(self, sd: dict) -> dict:
+        """Convert FLUX.2 BFL format state dict to diffusers format."""
+        converted = {}
+
+        key_renames = {
+            "img_in.weight": "x_embedder.weight",
+            "txt_in.weight": "context_embedder.weight",
+            "time_in.in_layer.weight": "time_guidance_embed.timestep_embedder.linear_1.weight",
+            "time_in.out_layer.weight": "time_guidance_embed.timestep_embedder.linear_2.weight",
+            "guidance_in.in_layer.weight": "time_guidance_embed.guidance_embedder.linear_1.weight",
+            "guidance_in.out_layer.weight": "time_guidance_embed.guidance_embedder.linear_2.weight",
+            "double_stream_modulation_img.lin.weight": "double_stream_modulation_img.linear.weight",
+            "double_stream_modulation_txt.lin.weight": "double_stream_modulation_txt.linear.weight",
+            "single_stream_modulation.lin.weight": "single_stream_modulation.linear.weight",
+            "final_layer.linear.weight": "proj_out.weight",
+            "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+        }
+
+        for old_key, tensor in sd.items():
+            new_key = old_key
+
+            if old_key in key_renames:
+                new_key = key_renames[old_key]
+                if old_key == "final_layer.adaLN_modulation.1.weight":
+                    tensor = self._swap_scale_shift(tensor)
+                converted[new_key] = tensor
+                continue
+
+            if old_key.startswith("double_blocks."):
+                new_key = self._convert_double_block_key(old_key, tensor, converted)
+                if new_key is None:
+                    continue
+            elif old_key.startswith("single_blocks."):
+                new_key = self._convert_single_block_key(old_key, tensor, converted)
+                if new_key is None:
+                    continue
+
+            if new_key != old_key or new_key not in converted:
+                converted[new_key] = tensor
+
+        return converted
+
+    def _convert_double_block_key(self, key: str, tensor, converted: dict) -> str | None:
+        parts = key.split(".")
+        block_idx = parts[1]
+        rest = ".".join(parts[2:])
+        prefix = f"transformer_blocks.{block_idx}"
+
+        if "img_attn.qkv.weight" in rest:
+            q, k, v = self._chunk_tensor(tensor, 3)
+            converted[f"{prefix}.attn.to_q.weight"] = q
+            converted[f"{prefix}.attn.to_k.weight"] = k
+            converted[f"{prefix}.attn.to_v.weight"] = v
+            return None
+        elif "txt_attn.qkv.weight" in rest:
+            q, k, v = self._chunk_tensor(tensor, 3)
+            converted[f"{prefix}.attn.add_q_proj.weight"] = q
+            converted[f"{prefix}.attn.add_k_proj.weight"] = k
+            converted[f"{prefix}.attn.add_v_proj.weight"] = v
+            return None
+
+        if "img_attn.proj.weight" in rest:
+            return f"{prefix}.attn.to_out.0.weight"
+        elif "txt_attn.proj.weight" in rest:
+            return f"{prefix}.attn.to_add_out.weight"
+
+        if "img_attn.norm.query_norm.scale" in rest:
+            return f"{prefix}.attn.norm_q.weight"
+        elif "img_attn.norm.key_norm.scale" in rest:
+            return f"{prefix}.attn.norm_k.weight"
+        elif "txt_attn.norm.query_norm.scale" in rest:
+            return f"{prefix}.attn.norm_added_q.weight"
+        elif "txt_attn.norm.key_norm.scale" in rest:
+            return f"{prefix}.attn.norm_added_k.weight"
+
+        if "img_mlp.0.weight" in rest:
+            return f"{prefix}.ff.linear_in.weight"
+        elif "img_mlp.2.weight" in rest:
+            return f"{prefix}.ff.linear_out.weight"
+        elif "txt_mlp.0.weight" in rest:
+            return f"{prefix}.ff_context.linear_in.weight"
+        elif "txt_mlp.2.weight" in rest:
+            return f"{prefix}.ff_context.linear_out.weight"
+
+        return key
+
+    def _convert_single_block_key(self, key: str, tensor, converted: dict) -> str | None:
+        parts = key.split(".")
+        block_idx = parts[1]
+        rest = ".".join(parts[2:])
+        prefix = f"single_transformer_blocks.{block_idx}"
+
+        if "linear1.weight" in rest:
+            return f"{prefix}.attn.to_qkv_mlp_proj.weight"
+        elif "linear2.weight" in rest:
+            return f"{prefix}.attn.to_out.weight"
+
+        if "norm.query_norm.scale" in rest:
+            return f"{prefix}.attn.norm_q.weight"
+        elif "norm.key_norm.scale" in rest:
+            return f"{prefix}.attn.norm_k.weight"
+
+        return key
+
+    def _chunk_tensor(self, tensor, chunks: int):
+        """Chunk a tensor, handling both regular tensors and GGUF quantized tensors."""
+        if hasattr(tensor, "get_dequantized_tensor"):
+            # GGUF quantized tensor - dequantize first, then chunk
+            # This loses quantization for the split weights, but is necessary
+            # because diffusers uses separate Q/K/V projections
+            tensor = tensor.get_dequantized_tensor()
+        return tensor.chunk(chunks, dim=0)
+
+    def _swap_scale_shift(self, weight) -> torch.Tensor:
+        """Swap scale and shift in AdaLayerNorm weights."""
+        if hasattr(weight, "get_dequantized_tensor"):
+            # For GGUF, dequantize first
+            weight = weight.get_dequantized_tensor()
+        shift, scale = weight.chunk(2, dim=0)
+        return torch.cat([scale, shift], dim=0)
+
+
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
 class FluxControlnetModel(ModelLoader):