InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +610 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +43 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +70 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en.json +70 -1
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,935 @@
|
|
|
1
|
+
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
2
|
+
"""Class for Z-Image model loading in InvokeAI."""
|
|
3
|
+
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
import accelerate
|
|
8
|
+
import torch
|
|
9
|
+
from transformers import AutoTokenizer, Qwen3ForCausalLM
|
|
10
|
+
|
|
11
|
+
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
|
|
12
|
+
from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_ZImage_Config
|
|
13
|
+
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
|
|
14
|
+
from invokeai.backend.model_manager.configs.main import Main_Checkpoint_ZImage_Config, Main_GGUF_ZImage_Config
|
|
15
|
+
from invokeai.backend.model_manager.configs.qwen3_encoder import (
|
|
16
|
+
Qwen3Encoder_Checkpoint_Config,
|
|
17
|
+
Qwen3Encoder_GGUF_Config,
|
|
18
|
+
Qwen3Encoder_Qwen3Encoder_Config,
|
|
19
|
+
)
|
|
20
|
+
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
|
21
|
+
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
|
22
|
+
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
|
23
|
+
from invokeai.backend.model_manager.taxonomy import (
|
|
24
|
+
AnyModel,
|
|
25
|
+
BaseModelType,
|
|
26
|
+
ModelFormat,
|
|
27
|
+
ModelType,
|
|
28
|
+
SubModelType,
|
|
29
|
+
)
|
|
30
|
+
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
|
31
|
+
from invokeai.backend.util.devices import TorchDevice
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
    """Translate an original-layout Z-Image state dict into diffusers ``ZImageTransformer2DModel`` naming.

    A new dict is returned; ``sd`` itself is not modified. Translations applied:

    - ``x_pad_token`` / ``cap_pad_token`` with a 1-D shape ``[dim]`` are reshaped to ``[1, dim]``
      (GGML tensors are dequantized first). Other shapes are copied unchanged.
    - ``x_embedder.*`` -> ``all_x_embedder.2-1.*`` and ``final_layer.*`` -> ``all_final_layer.2-1.*``.
    - ``norm_final.*`` entries are dropped (the diffusers final norm has no learnable parameters).
    - ``*.attention.qkv.weight`` / ``*.attention.qkv.bias`` are split into equal thirds along dim 0 as
      ``to_q``, ``to_k`` and ``to_v``. Any other ``qkv`` suffix (e.g. FP8 ``scale_weight``) is kept as-is.
    - Keys containing ``.attention.`` get ``.q_norm.`` -> ``.norm_q.``, ``.k_norm.`` -> ``.norm_k.`` and
      ``.attention.out.`` -> ``.attention.to_out.0.``.
    - Non-string keys and every remaining key are copied unchanged.

    Args:
        sd: State dict using the original Z-Image key layout.

    Returns:
        A new state dict using diffusers key names. Key order follows the input; the split Q/K/V entries
        take the place of the fused tensor, in q, k, v order.

    Raises:
        ValueError: If a fused QKV weight/bias has a first dimension that is not divisible by 3.
    """
    prefix_renames = (
        ("x_embedder.", "all_x_embedder.2-1."),
        ("final_layer.", "all_final_layer.2-1."),
    )
    qkv_marker = ".attention.qkv."
    converted: dict[str, Any] = {}

    for name, tensor in sd.items():
        if not isinstance(name, str):
            converted[name] = tensor
            continue

        if name in ("x_pad_token", "cap_pad_token"):
            if hasattr(tensor, "shape") and len(tensor.shape) == 1:
                # GGMLTensor has no unsqueeze, so materialize a plain tensor before adding the batch dim.
                if hasattr(tensor, "get_dequantized_tensor"):
                    tensor = tensor.get_dequantized_tensor()
                tensor = torch.as_tensor(tensor).reshape(1, -1)
            converted[name] = tensor
            continue

        renamed = next(
            (new + name.removeprefix(old) for old, new in prefix_renames if name.startswith(old)),
            None,
        )
        if renamed is not None:
            converted[renamed] = tensor
            continue

        if name.startswith("norm_final."):
            # diffusers uses a non-affine LayerNorm here; some checkpoints (e.g. FP8) still ship all-zero params.
            continue

        if qkv_marker in name:
            head, _, param = name.rpartition(qkv_marker)
            if param not in ("weight", "bias"):
                # Quantization metadata (e.g. FP8 scale_weight) is not split per projection.
                converted[name] = tensor
                continue
            if hasattr(tensor, "shape"):
                rows = tensor.shape[0]
                if rows % 3:
                    raise ValueError(
                        f"Cannot split QKV tensor '{name}': first dimension ({rows}) "
                        "is not divisible by 3. The model file may be corrupted or incompatible."
                    )
                third = rows // 3
                pieces = (slice(None, third), slice(third, 2 * third), slice(2 * third, None))
                for proj, piece in zip(("to_q", "to_k", "to_v"), pieces):
                    converted[f"{head}.attention.{proj}.{param}"] = tensor[piece]
                continue
            # A shapeless weight/bias falls through to the generic attention renaming below.

        if ".attention." in name:
            name = (
                name.replace(".q_norm.", ".norm_q.")
                .replace(".k_norm.", ".norm_k.")
                .replace(".attention.out.", ".attention.to_out.0.")
            )
        converted[name] = tensor

    return converted
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.Diffusers)
class ZImageDiffusersModel(GenericDiffusersLoader):
    """Class to load Z-Image main models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit).

    Models are stored in diffusers layout: each submodel lives in a subfolder of the model root named after
    ``submodel_type.value`` and is loaded with the class reported by ``get_hf_load_class``.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load one submodel of a diffusers-format Z-Image pipeline.

        The configured repo variant is requested first. If those variant files are missing ("no file named"
        in the ``OSError``), loading is retried without a variant.

        Raises:
            NotImplementedError: If ``config`` is a checkpoint config.
            Exception: If no ``submodel_type`` is given.
            OSError: If ``from_pretrained`` fails for any other reason.
        """
        if isinstance(config, Checkpoint_Config_Base):
            raise NotImplementedError("CheckpointConfigBase is not implemented for Z-Image models.")
        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading main pipelines.")

        root = Path(config.path)
        load_class = self.get_hf_load_class(root, submodel_type)
        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
        variant = repo_variant.value if repo_variant else None
        submodel_dir = root / submodel_type.value

        # bfloat16 is preferred, but only when the target device supports it.
        dtype = TorchDevice.choose_bfloat16_safe_dtype(TorchDevice.choose_torch_device())

        try:
            return load_class.from_pretrained(submodel_dir, torch_dtype=dtype, variant=variant)
        except OSError as err:
            # The user's variant preference may have changed since download; retry without it.
            variant_files_missing = bool(variant) and "no file named" in str(err)
            if not variant_files_missing:
                raise
            return load_class.from_pretrained(submodel_dir, torch_dtype=dtype)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
|
170
|
+
class ZImageCheckpointModel(ModelLoader):
|
|
171
|
+
"""Class to load Z-Image transformer models from single-file checkpoints (safetensors, etc)."""
|
|
172
|
+
|
|
173
|
+
def _load_model(
|
|
174
|
+
self,
|
|
175
|
+
config: AnyModelConfig,
|
|
176
|
+
submodel_type: Optional[SubModelType] = None,
|
|
177
|
+
) -> AnyModel:
|
|
178
|
+
if not isinstance(config, Checkpoint_Config_Base):
|
|
179
|
+
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
|
|
180
|
+
|
|
181
|
+
match submodel_type:
|
|
182
|
+
case SubModelType.Transformer:
|
|
183
|
+
return self._load_from_singlefile(config)
|
|
184
|
+
|
|
185
|
+
raise ValueError(
|
|
186
|
+
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
def _load_from_singlefile(
|
|
190
|
+
self,
|
|
191
|
+
config: AnyModelConfig,
|
|
192
|
+
) -> AnyModel:
|
|
193
|
+
from diffusers import ZImageTransformer2DModel
|
|
194
|
+
from safetensors.torch import load_file
|
|
195
|
+
|
|
196
|
+
if not isinstance(config, Main_Checkpoint_ZImage_Config):
|
|
197
|
+
raise TypeError(
|
|
198
|
+
f"Expected Main_Checkpoint_ZImage_Config, got {type(config).__name__}. "
|
|
199
|
+
"Model configuration type mismatch."
|
|
200
|
+
)
|
|
201
|
+
model_path = Path(config.path)
|
|
202
|
+
|
|
203
|
+
# Load the state dict from safetensors/checkpoint file
|
|
204
|
+
sd = load_file(model_path)
|
|
205
|
+
|
|
206
|
+
# Some Z-Image checkpoint files have keys prefixed with "diffusion_model." or
|
|
207
|
+
# "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
|
|
208
|
+
prefix_to_strip = None
|
|
209
|
+
for prefix in ["model.diffusion_model.", "diffusion_model."]:
|
|
210
|
+
if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
|
|
211
|
+
prefix_to_strip = prefix
|
|
212
|
+
break
|
|
213
|
+
|
|
214
|
+
if prefix_to_strip:
|
|
215
|
+
stripped_sd = {}
|
|
216
|
+
for key, value in sd.items():
|
|
217
|
+
if isinstance(key, str) and key.startswith(prefix_to_strip):
|
|
218
|
+
stripped_sd[key[len(prefix_to_strip) :]] = value
|
|
219
|
+
else:
|
|
220
|
+
stripped_sd[key] = value
|
|
221
|
+
sd = stripped_sd
|
|
222
|
+
|
|
223
|
+
# Check if the state dict is in original format (not diffusers format)
|
|
224
|
+
# Original format has keys like "x_embedder.weight" instead of "all_x_embedder.2-1.weight"
|
|
225
|
+
needs_conversion = any(k.startswith("x_embedder.") for k in sd.keys() if isinstance(k, str))
|
|
226
|
+
|
|
227
|
+
if needs_conversion:
|
|
228
|
+
# Convert from original format to diffusers format
|
|
229
|
+
sd = _convert_z_image_gguf_to_diffusers(sd)
|
|
230
|
+
|
|
231
|
+
# Create an empty model with the default Z-Image config
|
|
232
|
+
# Z-Image-Turbo uses these default parameters from diffusers
|
|
233
|
+
with accelerate.init_empty_weights():
|
|
234
|
+
model = ZImageTransformer2DModel(
|
|
235
|
+
all_patch_size=(2,),
|
|
236
|
+
all_f_patch_size=(1,),
|
|
237
|
+
in_channels=16,
|
|
238
|
+
dim=3840,
|
|
239
|
+
n_layers=30,
|
|
240
|
+
n_refiner_layers=2,
|
|
241
|
+
n_heads=30,
|
|
242
|
+
n_kv_heads=30,
|
|
243
|
+
norm_eps=1e-05,
|
|
244
|
+
qk_norm=True,
|
|
245
|
+
cap_feat_dim=2560,
|
|
246
|
+
rope_theta=256.0,
|
|
247
|
+
t_scale=1000.0,
|
|
248
|
+
axes_dims=[32, 48, 48],
|
|
249
|
+
axes_lens=[1024, 512, 512],
|
|
250
|
+
)
|
|
251
|
+
|
|
252
|
+
# Determine safe dtype based on target device capabilities
|
|
253
|
+
target_device = TorchDevice.choose_torch_device()
|
|
254
|
+
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
|
|
255
|
+
|
|
256
|
+
# Handle memory management and dtype conversion
|
|
257
|
+
new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
|
|
258
|
+
self._ram_cache.make_room(new_sd_size)
|
|
259
|
+
|
|
260
|
+
# Filter out FP8 scale_weight and scaled_fp8 metadata keys
|
|
261
|
+
# These are quantization metadata that shouldn't be loaded into the model
|
|
262
|
+
keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
|
|
263
|
+
for k in keys_to_remove:
|
|
264
|
+
del sd[k]
|
|
265
|
+
|
|
266
|
+
# Convert to target dtype
|
|
267
|
+
for k in sd.keys():
|
|
268
|
+
sd[k] = sd[k].to(model_dtype)
|
|
269
|
+
|
|
270
|
+
model.load_state_dict(sd, assign=True)
|
|
271
|
+
return model
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class ZImageGGUFCheckpointModel(ModelLoader):
    """Class to load GGUF-quantized Z-Image transformer models.

    Only the transformer submodel is supported. The tensors returned by ``gguf_sd_loader`` are assigned to an
    empty-weight diffusers ``ZImageTransformer2DModel`` via ``load_state_dict(assign=True)``.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Dispatch to the single-file loader for the transformer submodel.

        Raises:
            ValueError: If ``config`` is not a checkpoint config or ``submodel_type`` is not ``Transformer``.
        """
        if not isinstance(config, Checkpoint_Config_Base):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        if submodel_type == SubModelType.Transformer:
            return self._load_from_singlefile(config)

        received = submodel_type.value if submodel_type else "None"
        raise ValueError(f"Only Transformer submodels are currently supported. Received: {received}")

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        """Build a ``ZImageTransformer2DModel`` from a GGUF file.

        The steps are:

        1. Read the GGUF tensors using a device-safe compute dtype.
        2. Strip an optional ComfyUI-style key prefix.
        3. Translate the keys to diffusers naming.
        4. Assign the tensors to an empty-weight model.

        Raises:
            TypeError: If ``config`` is not a ``Main_GGUF_ZImage_Config``.
        """
        from diffusers import ZImageTransformer2DModel

        if not isinstance(config, Main_GGUF_ZImage_Config):
            raise TypeError(
                f"Expected Main_GGUF_ZImage_Config, got {type(config).__name__}. Model configuration type mismatch."
            )
        gguf_file = Path(config.path)

        # bfloat16 when the target device supports it, otherwise a safe fallback.
        compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(TorchDevice.choose_torch_device())
        state_dict = gguf_sd_loader(gguf_file, compute_dtype=compute_dtype)

        # ComfyUI exports may nest every key under "model.diffusion_model." or "diffusion_model.";
        # the first of these found on any key is removed wherever it occurs.
        comfy_prefix = next(
            (
                candidate
                for candidate in ("model.diffusion_model.", "diffusion_model.")
                if any(isinstance(k, str) and k.startswith(candidate) for k in state_dict)
            ),
            None,
        )
        if comfy_prefix:
            state_dict = {
                (k[len(comfy_prefix) :] if isinstance(k, str) and k.startswith(comfy_prefix) else k): v
                for k, v in state_dict.items()
            }

        state_dict = _convert_z_image_gguf_to_diffusers(state_dict)

        # Default Z-Image(-Turbo) transformer hyper-parameters, matching diffusers.
        with accelerate.init_empty_weights():
            model = ZImageTransformer2DModel(
                all_patch_size=(2,),
                all_f_patch_size=(1,),
                in_channels=16,
                dim=3840,
                n_layers=30,
                n_refiner_layers=2,
                n_heads=30,
                n_kv_heads=30,
                norm_eps=1e-05,
                qk_norm=True,
                cap_feat_dim=2560,
                rope_theta=256.0,
                t_scale=1000.0,
                axes_dims=[32, 48, 48],
                axes_lens=[1024, 512, 512],
            )

        model.load_state_dict(state_dict, assign=True)
        return model
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.Qwen3Encoder)
class Qwen3EncoderLoader(ModelLoader):
    """Class to load standalone Qwen3 Encoder models for Z-Image (directory format).

    Two directory layouts are accepted:

    1. A full model root containing ``text_encoder/`` and ``tokenizer/`` subfolders.
    2. A standalone text-encoder download with no ``text_encoder/`` folder, whose root holds
       ``config.json`` (and the tokenizer files) directly.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load the tokenizer or the text encoder from the configured directory.

        Raises:
            ValueError: If ``config`` is not a ``Qwen3Encoder_Qwen3Encoder_Config``, or if ``submodel_type``
                is neither ``Tokenizer`` nor ``TextEncoder``.
        """
        if not isinstance(config, Qwen3Encoder_Qwen3Encoder_Config):
            raise ValueError("Only Qwen3Encoder_Qwen3Encoder_Config models are supported here.")

        root = Path(config.path)
        nested_encoder_dir = root / "text_encoder"
        flat_layout = not nested_encoder_dir.exists() and (root / "config.json").exists()
        if flat_layout:
            # Weights, config and tokenizer files all live directly in the root folder.
            text_encoder_path = tokenizer_path = root
        else:
            text_encoder_path = nested_encoder_dir
            tokenizer_path = root / "tokenizer"

        if submodel_type == SubModelType.Tokenizer:
            return AutoTokenizer.from_pretrained(tokenizer_path)

        if submodel_type == SubModelType.TextEncoder:
            # bfloat16 when the target device supports it, otherwise a safe fallback.
            model_dtype = TorchDevice.choose_bfloat16_safe_dtype(TorchDevice.choose_torch_device())
            return Qwen3ForCausalLM.from_pretrained(
                text_encoder_path,
                torch_dtype=model_dtype,
                low_cpu_mem_usage=True,
            )

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ZImageControlCheckpointModel(ModelLoader):
    """Class to load Z-Image Control adapter models from safetensors checkpoint.

    Z-Image Control models are standalone adapters containing control layers
    (control_layers, control_all_x_embedder, control_noise_refiner) that can be
    combined with a base ZImageTransformer2DModel at runtime for spatial conditioning
    (Canny, HED, Depth, Pose, MLSD).
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load the control adapter; ``submodel_type`` is ignored for ControlNet models.

        Raises:
            ValueError: If ``config`` is not a checkpoint config.
        """
        if not isinstance(config, Checkpoint_Config_Base):
            raise ValueError("Only CheckpointConfigBase models are supported here.")

        # ControlNet type models don't use submodel_type - load the adapter directly
        return self._load_control_adapter(config)

    @staticmethod
    def _count_indexed_modules(sd: dict[str, Any], prefix: str) -> int:
        """Return the ModuleList length implied by ``{prefix}{index}.*`` keys in ``sd``, or 0 if none exist.

        Uses ``max(index) + 1`` rather than the number of distinct indices, so the created list is always long
        enough to hold every checkpoint entry. This matches how Qwen3 layer counts are derived in this module.
        """
        indices: set[int] = set()
        for key in sd:
            if not (isinstance(key, str) and key.startswith(prefix)):
                continue
            index = key[len(prefix) :].split(".", 1)[0]
            if index.isdecimal():
                indices.add(int(index))
        return max(indices) + 1 if indices else 0

    def _load_control_adapter(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        """Build a ``ZImageControlAdapter`` whose shape is inferred from the checkpoint, then load its weights.

        Raises:
            TypeError: If ``config`` is not a ``ControlNet_Checkpoint_ZImage_Config``.
        """
        from safetensors.torch import load_file

        from invokeai.backend.util.logging import InvokeAILogger
        from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter

        # Raise instead of assert so the check survives `python -O`, like the other loaders in this module.
        if not isinstance(config, ControlNet_Checkpoint_ZImage_Config):
            raise TypeError(
                f"Expected ControlNet_Checkpoint_ZImage_Config, got {type(config).__name__}. "
                "Model configuration type mismatch."
            )

        logger = InvokeAILogger.get_logger(self.__class__.__name__)
        model_path = Path(config.path)

        # Load the safetensors state dict
        sd = load_file(model_path)

        # Control blocks are named control_layers.0, control_layers.1, ...; refiners control_noise_refiner.N.
        # Fall back to the V1 defaults when the keys are absent.
        num_control_blocks = self._count_indexed_modules(sd, "control_layers.") or 6
        n_refiner_layers = self._count_indexed_modules(sd, "control_noise_refiner.") or 2

        # control_in_dim = weight.shape[1] / (f_patch_size * patch_size * patch_size); with patch_size=2 and
        # f_patch_size=1 that divisor is 4.
        control_in_dim = 16  # Default for V1
        embedder_key = "control_all_x_embedder.2-1.weight"
        embedder_weight = sd.get(embedder_key)
        if embedder_weight is not None:
            control_in_dim = embedder_weight.shape[1] // 4  # 4 = 1 * 2 * 2

        version = "V2.0" if control_in_dim > 16 else "V1"
        logger.info(
            f"Z-Image ControlNet detected: {version} "
            f"(control_in_dim={control_in_dim}, num_control_blocks={num_control_blocks}, "
            f"n_refiner_layers={n_refiner_layers})"
        )

        # Create an empty control adapter
        dim = 3840
        with accelerate.init_empty_weights():
            model = ZImageControlAdapter(
                num_control_blocks=num_control_blocks,
                control_in_dim=control_in_dim,
                all_patch_size=(2,),
                all_f_patch_size=(1,),
                dim=dim,
                n_refiner_layers=n_refiner_layers,
                n_heads=30,
                n_kv_heads=30,
                norm_eps=1e-05,
                qk_norm=True,
            )

        # strict=False because some control adapters do not include x_pad_token in their checkpoint.
        missing_keys, unexpected_keys = model.load_state_dict(sd, assign=True, strict=False)

        if "x_pad_token" in missing_keys:
            # Seeded so repeated loads give identical outputs. The dtype follows the checkpoint weights rather
            # than defaulting to float32.
            pad_dtype = (
                embedder_weight.dtype
                if embedder_weight is not None and embedder_weight.is_floating_point()
                else torch.float32
            )
            generator = torch.Generator().manual_seed(0)
            pad_token = (torch.randn(dim, generator=generator) * 0.02).to(pad_dtype)
            model.x_pad_token = torch.nn.Parameter(pad_token)

        # Any other missing parameter stays on the meta device and fails at inference time; surface it now.
        other_missing = [k for k in missing_keys if k != "x_pad_token"]
        if other_missing:
            logger.warning(
                f"Z-Image ControlNet checkpoint is missing {len(other_missing)} expected key(s); "
                f"they remain uninitialized: {other_missing[:10]}"
            )
        if unexpected_keys:
            logger.warning(
                f"Z-Image ControlNet checkpoint has {len(unexpected_keys)} unexpected key(s) that were ignored: "
                f"{unexpected_keys[:10]}"
            )

        return model
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.Checkpoint)
|
|
509
|
+
class Qwen3EncoderCheckpointLoader(ModelLoader):
|
|
510
|
+
"""Class to load single-file Qwen3 Encoder models for Z-Image (safetensors format)."""
|
|
511
|
+
|
|
512
|
+
# Default HuggingFace model to load tokenizer from when using single-file Qwen3 encoder
|
|
513
|
+
# Must be Qwen3 (not Qwen2.5) to match Z-Image's text encoder architecture and special tokens
|
|
514
|
+
DEFAULT_TOKENIZER_SOURCE = "Qwen/Qwen3-4B"
|
|
515
|
+
|
|
516
|
+
def _load_model(
|
|
517
|
+
self,
|
|
518
|
+
config: AnyModelConfig,
|
|
519
|
+
submodel_type: Optional[SubModelType] = None,
|
|
520
|
+
) -> AnyModel:
|
|
521
|
+
if not isinstance(config, Qwen3Encoder_Checkpoint_Config):
|
|
522
|
+
raise ValueError("Only Qwen3Encoder_Checkpoint_Config models are supported here.")
|
|
523
|
+
|
|
524
|
+
match submodel_type:
|
|
525
|
+
case SubModelType.TextEncoder:
|
|
526
|
+
return self._load_from_singlefile(config)
|
|
527
|
+
case SubModelType.Tokenizer:
|
|
528
|
+
# For single-file Qwen3, load tokenizer from HuggingFace
|
|
529
|
+
return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE)
|
|
530
|
+
|
|
531
|
+
raise ValueError(
|
|
532
|
+
f"Only TextEncoder and Tokenizer submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
def _load_from_singlefile(
|
|
536
|
+
self,
|
|
537
|
+
config: AnyModelConfig,
|
|
538
|
+
) -> AnyModel:
|
|
539
|
+
from safetensors.torch import load_file
|
|
540
|
+
from transformers import Qwen3Config, Qwen3ForCausalLM
|
|
541
|
+
|
|
542
|
+
from invokeai.backend.util.logging import InvokeAILogger
|
|
543
|
+
|
|
544
|
+
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
|
545
|
+
|
|
546
|
+
if not isinstance(config, Qwen3Encoder_Checkpoint_Config):
|
|
547
|
+
raise TypeError(
|
|
548
|
+
f"Expected Qwen3Encoder_Checkpoint_Config, got {type(config).__name__}. "
|
|
549
|
+
"Model configuration type mismatch."
|
|
550
|
+
)
|
|
551
|
+
model_path = Path(config.path)
|
|
552
|
+
|
|
553
|
+
# Determine safe dtype based on target device capabilities
|
|
554
|
+
target_device = TorchDevice.choose_torch_device()
|
|
555
|
+
model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
|
|
556
|
+
|
|
557
|
+
# Load the state dict from safetensors file
|
|
558
|
+
sd = load_file(model_path)
|
|
559
|
+
|
|
560
|
+
# Determine Qwen model configuration from state dict
|
|
561
|
+
# Count the number of layers by looking at layer keys
|
|
562
|
+
layer_count = 0
|
|
563
|
+
for key in sd.keys():
|
|
564
|
+
if isinstance(key, str) and key.startswith("model.layers."):
|
|
565
|
+
parts = key.split(".")
|
|
566
|
+
if len(parts) > 2:
|
|
567
|
+
try:
|
|
568
|
+
layer_idx = int(parts[2])
|
|
569
|
+
layer_count = max(layer_count, layer_idx + 1)
|
|
570
|
+
except ValueError:
|
|
571
|
+
pass
|
|
572
|
+
|
|
573
|
+
# Get hidden size from embed_tokens weight shape
|
|
574
|
+
embed_weight = sd.get("model.embed_tokens.weight")
|
|
575
|
+
if embed_weight is None:
|
|
576
|
+
raise ValueError("Could not find model.embed_tokens.weight in state dict")
|
|
577
|
+
if embed_weight.ndim != 2:
|
|
578
|
+
raise ValueError(
|
|
579
|
+
f"Expected 2D embed_tokens weight tensor, got shape {embed_weight.shape}. "
|
|
580
|
+
"The model file may be corrupted or incompatible."
|
|
581
|
+
)
|
|
582
|
+
hidden_size = embed_weight.shape[1]
|
|
583
|
+
vocab_size = embed_weight.shape[0]
|
|
584
|
+
|
|
585
|
+
# Detect attention configuration from layer 0 weights
|
|
586
|
+
q_proj_weight = sd.get("model.layers.0.self_attn.q_proj.weight")
|
|
587
|
+
k_proj_weight = sd.get("model.layers.0.self_attn.k_proj.weight")
|
|
588
|
+
gate_proj_weight = sd.get("model.layers.0.mlp.gate_proj.weight")
|
|
589
|
+
|
|
590
|
+
if q_proj_weight is None or k_proj_weight is None or gate_proj_weight is None:
|
|
591
|
+
raise ValueError("Could not find attention/mlp weights in state dict to determine configuration")
|
|
592
|
+
|
|
593
|
+
# Calculate dimensions from actual weights
|
|
594
|
+
# Qwen3 uses head_dim separately from hidden_size
|
|
595
|
+
head_dim = 128 # Standard head dimension for Qwen3 models
|
|
596
|
+
num_attention_heads = q_proj_weight.shape[0] // head_dim
|
|
597
|
+
num_kv_heads = k_proj_weight.shape[0] // head_dim
|
|
598
|
+
intermediate_size = gate_proj_weight.shape[0]
|
|
599
|
+
|
|
600
|
+
# Create Qwen3 config - matches the diffusers text_encoder/config.json
|
|
601
|
+
qwen_config = Qwen3Config(
|
|
602
|
+
vocab_size=vocab_size,
|
|
603
|
+
hidden_size=hidden_size,
|
|
604
|
+
intermediate_size=intermediate_size,
|
|
605
|
+
num_hidden_layers=layer_count,
|
|
606
|
+
num_attention_heads=num_attention_heads,
|
|
607
|
+
num_key_value_heads=num_kv_heads,
|
|
608
|
+
head_dim=head_dim,
|
|
609
|
+
max_position_embeddings=40960,
|
|
610
|
+
rms_norm_eps=1e-6,
|
|
611
|
+
tie_word_embeddings=True,
|
|
612
|
+
rope_theta=1000000.0,
|
|
613
|
+
use_sliding_window=False,
|
|
614
|
+
attention_bias=False,
|
|
615
|
+
attention_dropout=0.0,
|
|
616
|
+
torch_dtype=model_dtype,
|
|
617
|
+
)
|
|
618
|
+
|
|
619
|
+
# Handle memory management
|
|
620
|
+
new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
|
|
621
|
+
self._ram_cache.make_room(new_sd_size)
|
|
622
|
+
|
|
623
|
+
# Convert to target dtype
|
|
624
|
+
for k in sd.keys():
|
|
625
|
+
sd[k] = sd[k].to(model_dtype)
|
|
626
|
+
|
|
627
|
+
# Use Qwen3ForCausalLM - the correct model class for Z-Image text encoder
|
|
628
|
+
# Use init_empty_weights for fast model creation, then load weights with assign=True
|
|
629
|
+
with accelerate.init_empty_weights():
|
|
630
|
+
model = Qwen3ForCausalLM(qwen_config)
|
|
631
|
+
|
|
632
|
+
# Load the text model weights from checkpoint
|
|
633
|
+
# assign=True replaces meta tensors with real ones from state dict
|
|
634
|
+
model.load_state_dict(sd, strict=False, assign=True)
|
|
635
|
+
|
|
636
|
+
# Handle tied weights: lm_head shares weight with embed_tokens when tie_word_embeddings=True
|
|
637
|
+
# This doesn't work automatically with init_empty_weights, so we need to manually tie them
|
|
638
|
+
if qwen_config.tie_word_embeddings:
|
|
639
|
+
model.tie_weights()
|
|
640
|
+
|
|
641
|
+
# Re-initialize any remaining meta tensor buffers (like rotary embeddings inv_freq)
|
|
642
|
+
# These are computed from config, not loaded from checkpoint
|
|
643
|
+
for name, buffer in list(model.named_buffers()):
|
|
644
|
+
if buffer.is_meta:
|
|
645
|
+
# Get parent module and buffer name
|
|
646
|
+
parts = name.rsplit(".", 1)
|
|
647
|
+
if len(parts) == 2:
|
|
648
|
+
parent = model.get_submodule(parts[0])
|
|
649
|
+
buffer_name = parts[1]
|
|
650
|
+
else:
|
|
651
|
+
parent = model
|
|
652
|
+
buffer_name = name
|
|
653
|
+
|
|
654
|
+
# Re-initialize the buffer based on expected shape and dtype
|
|
655
|
+
# For rotary embeddings, this is inv_freq which is computed from config
|
|
656
|
+
if buffer_name == "inv_freq":
|
|
657
|
+
# Compute inv_freq from config (same logic as Qwen3RotaryEmbedding.__init__)
|
|
658
|
+
base = qwen_config.rope_theta
|
|
659
|
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
|
|
660
|
+
parent.register_buffer(buffer_name, inv_freq.to(model_dtype), persistent=False)
|
|
661
|
+
else:
|
|
662
|
+
# For other buffers, log warning
|
|
663
|
+
logger.warning(f"Re-initializing unknown meta buffer: {name}")
|
|
664
|
+
|
|
665
|
+
return model
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.GGUFQuantized)
|
|
669
|
+
class Qwen3EncoderGGUFLoader(ModelLoader):
|
|
670
|
+
"""Class to load GGUF-quantized Qwen3 Encoder models for Z-Image."""
|
|
671
|
+
|
|
672
|
+
# Default HuggingFace model to load tokenizer from when using GGUF Qwen3 encoder
|
|
673
|
+
# Must be Qwen3 (not Qwen2.5) to match Z-Image's text encoder architecture and special tokens
|
|
674
|
+
DEFAULT_TOKENIZER_SOURCE = "Qwen/Qwen3-4B"
|
|
675
|
+
|
|
676
|
+
def _load_model(
|
|
677
|
+
self,
|
|
678
|
+
config: AnyModelConfig,
|
|
679
|
+
submodel_type: Optional[SubModelType] = None,
|
|
680
|
+
) -> AnyModel:
|
|
681
|
+
if not isinstance(config, Qwen3Encoder_GGUF_Config):
|
|
682
|
+
raise ValueError("Only Qwen3Encoder_GGUF_Config models are supported here.")
|
|
683
|
+
|
|
684
|
+
match submodel_type:
|
|
685
|
+
case SubModelType.TextEncoder:
|
|
686
|
+
return self._load_from_gguf(config)
|
|
687
|
+
case SubModelType.Tokenizer:
|
|
688
|
+
# For GGUF Qwen3, load tokenizer from HuggingFace
|
|
689
|
+
return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE)
|
|
690
|
+
|
|
691
|
+
raise ValueError(
|
|
692
|
+
f"Only TextEncoder and Tokenizer submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
|
|
693
|
+
)
|
|
694
|
+
|
|
695
|
+
def _load_from_gguf(
|
|
696
|
+
self,
|
|
697
|
+
config: AnyModelConfig,
|
|
698
|
+
) -> AnyModel:
|
|
699
|
+
from transformers import Qwen3Config, Qwen3ForCausalLM
|
|
700
|
+
|
|
701
|
+
from invokeai.backend.util.logging import InvokeAILogger
|
|
702
|
+
|
|
703
|
+
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
|
704
|
+
|
|
705
|
+
if not isinstance(config, Qwen3Encoder_GGUF_Config):
|
|
706
|
+
raise TypeError(
|
|
707
|
+
f"Expected Qwen3Encoder_GGUF_Config, got {type(config).__name__}. Model configuration type mismatch."
|
|
708
|
+
)
|
|
709
|
+
model_path = Path(config.path)
|
|
710
|
+
|
|
711
|
+
# Determine safe dtype based on target device capabilities
|
|
712
|
+
target_device = TorchDevice.choose_torch_device()
|
|
713
|
+
compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
|
|
714
|
+
|
|
715
|
+
# Load the GGUF state dict - this returns GGMLTensor wrappers (on CPU)
|
|
716
|
+
# We keep them on CPU and let the model cache system handle GPU movement
|
|
717
|
+
# via apply_custom_layers_to_model() and the partial loading cache
|
|
718
|
+
sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)
|
|
719
|
+
|
|
720
|
+
# Check if this is llama.cpp format (blk.X.) or PyTorch format (model.layers.X.)
|
|
721
|
+
is_llamacpp_format = any(k.startswith("blk.") for k in sd.keys() if isinstance(k, str))
|
|
722
|
+
|
|
723
|
+
if is_llamacpp_format:
|
|
724
|
+
logger.info("Detected llama.cpp GGUF format, converting keys to PyTorch format")
|
|
725
|
+
sd = self._convert_llamacpp_to_pytorch(sd)
|
|
726
|
+
|
|
727
|
+
# Determine Qwen model configuration from state dict
|
|
728
|
+
# Count the number of layers by looking at layer keys
|
|
729
|
+
layer_count = 0
|
|
730
|
+
for key in sd.keys():
|
|
731
|
+
if isinstance(key, str) and key.startswith("model.layers."):
|
|
732
|
+
parts = key.split(".")
|
|
733
|
+
if len(parts) > 2:
|
|
734
|
+
try:
|
|
735
|
+
layer_idx = int(parts[2])
|
|
736
|
+
layer_count = max(layer_count, layer_idx + 1)
|
|
737
|
+
except ValueError:
|
|
738
|
+
pass
|
|
739
|
+
|
|
740
|
+
# Get hidden size from embed_tokens weight shape
|
|
741
|
+
embed_weight = sd.get("model.embed_tokens.weight")
|
|
742
|
+
if embed_weight is None:
|
|
743
|
+
raise ValueError("Could not find model.embed_tokens.weight in state dict")
|
|
744
|
+
|
|
745
|
+
# Handle GGMLTensor shape access
|
|
746
|
+
embed_shape = embed_weight.shape if hasattr(embed_weight, "shape") else embed_weight.tensor_shape
|
|
747
|
+
if len(embed_shape) != 2:
|
|
748
|
+
raise ValueError(
|
|
749
|
+
f"Expected 2D embed_tokens weight tensor, got shape {embed_shape}. "
|
|
750
|
+
"The model file may be corrupted or incompatible."
|
|
751
|
+
)
|
|
752
|
+
hidden_size = embed_shape[1]
|
|
753
|
+
vocab_size = embed_shape[0]
|
|
754
|
+
|
|
755
|
+
# Detect attention configuration from layer 0 weights
|
|
756
|
+
q_proj_weight = sd.get("model.layers.0.self_attn.q_proj.weight")
|
|
757
|
+
k_proj_weight = sd.get("model.layers.0.self_attn.k_proj.weight")
|
|
758
|
+
gate_proj_weight = sd.get("model.layers.0.mlp.gate_proj.weight")
|
|
759
|
+
|
|
760
|
+
if q_proj_weight is None or k_proj_weight is None or gate_proj_weight is None:
|
|
761
|
+
raise ValueError("Could not find attention/mlp weights in state dict to determine configuration")
|
|
762
|
+
|
|
763
|
+
# Handle GGMLTensor shape access
|
|
764
|
+
q_shape = q_proj_weight.shape if hasattr(q_proj_weight, "shape") else q_proj_weight.tensor_shape
|
|
765
|
+
k_shape = k_proj_weight.shape if hasattr(k_proj_weight, "shape") else k_proj_weight.tensor_shape
|
|
766
|
+
gate_shape = gate_proj_weight.shape if hasattr(gate_proj_weight, "shape") else gate_proj_weight.tensor_shape
|
|
767
|
+
|
|
768
|
+
# Calculate dimensions from actual weights
|
|
769
|
+
head_dim = 128 # Standard head dimension for Qwen3 models
|
|
770
|
+
num_attention_heads = q_shape[0] // head_dim
|
|
771
|
+
num_kv_heads = k_shape[0] // head_dim
|
|
772
|
+
intermediate_size = gate_shape[0]
|
|
773
|
+
|
|
774
|
+
logger.info(
|
|
775
|
+
f"Qwen3 GGUF Encoder config detected: layers={layer_count}, hidden={hidden_size}, "
|
|
776
|
+
f"heads={num_attention_heads}, kv_heads={num_kv_heads}, intermediate={intermediate_size}, "
|
|
777
|
+
f"head_dim={head_dim}"
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
# Create Qwen3 config
|
|
781
|
+
qwen_config = Qwen3Config(
|
|
782
|
+
vocab_size=vocab_size,
|
|
783
|
+
hidden_size=hidden_size,
|
|
784
|
+
intermediate_size=intermediate_size,
|
|
785
|
+
num_hidden_layers=layer_count,
|
|
786
|
+
num_attention_heads=num_attention_heads,
|
|
787
|
+
num_key_value_heads=num_kv_heads,
|
|
788
|
+
head_dim=head_dim,
|
|
789
|
+
max_position_embeddings=40960,
|
|
790
|
+
rms_norm_eps=1e-6,
|
|
791
|
+
tie_word_embeddings=True,
|
|
792
|
+
rope_theta=1000000.0,
|
|
793
|
+
use_sliding_window=False,
|
|
794
|
+
attention_bias=False,
|
|
795
|
+
attention_dropout=0.0,
|
|
796
|
+
torch_dtype=compute_dtype,
|
|
797
|
+
)
|
|
798
|
+
|
|
799
|
+
# Use Qwen3ForCausalLM with empty weights, then load GGUF tensors
|
|
800
|
+
with accelerate.init_empty_weights():
|
|
801
|
+
model = Qwen3ForCausalLM(qwen_config)
|
|
802
|
+
|
|
803
|
+
# Load the GGUF weights with assign=True
|
|
804
|
+
# GGMLTensor wrappers will be dequantized on-the-fly during inference
|
|
805
|
+
model.load_state_dict(sd, strict=False, assign=True)
|
|
806
|
+
|
|
807
|
+
# Dequantize embed_tokens weight - embedding lookups require indexed access
|
|
808
|
+
# which quantized GGMLTensors can't efficiently provide (no __torch_dispatch__ for embedding)
|
|
809
|
+
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
|
810
|
+
|
|
811
|
+
embed_tokens_weight = model.model.embed_tokens.weight
|
|
812
|
+
if isinstance(embed_tokens_weight, GGMLTensor):
|
|
813
|
+
dequantized = embed_tokens_weight.get_dequantized_tensor()
|
|
814
|
+
model.model.embed_tokens.weight = torch.nn.Parameter(dequantized, requires_grad=False)
|
|
815
|
+
logger.info("Dequantized embed_tokens weight for embedding lookups")
|
|
816
|
+
|
|
817
|
+
# Handle tied weights - llama.cpp GGUF doesn't include lm_head.weight when embeddings are tied
|
|
818
|
+
# So we need to manually tie them after loading
|
|
819
|
+
if qwen_config.tie_word_embeddings:
|
|
820
|
+
# Check if lm_head.weight is still a meta tensor (wasn't in GGUF state dict)
|
|
821
|
+
if model.lm_head.weight.is_meta:
|
|
822
|
+
# Directly assign embed_tokens weight to lm_head (now dequantized)
|
|
823
|
+
model.lm_head.weight = model.model.embed_tokens.weight
|
|
824
|
+
logger.info("Tied lm_head.weight to embed_tokens.weight (GGUF tied embeddings)")
|
|
825
|
+
else:
|
|
826
|
+
# If lm_head.weight was loaded, use standard tie_weights
|
|
827
|
+
model.tie_weights()
|
|
828
|
+
|
|
829
|
+
# Re-initialize any remaining meta tensor buffers (like rotary embeddings inv_freq)
|
|
830
|
+
for name, buffer in list(model.named_buffers()):
|
|
831
|
+
if buffer.is_meta:
|
|
832
|
+
parts = name.rsplit(".", 1)
|
|
833
|
+
if len(parts) == 2:
|
|
834
|
+
parent = model.get_submodule(parts[0])
|
|
835
|
+
buffer_name = parts[1]
|
|
836
|
+
else:
|
|
837
|
+
parent = model
|
|
838
|
+
buffer_name = name
|
|
839
|
+
|
|
840
|
+
if buffer_name == "inv_freq":
|
|
841
|
+
# Compute inv_freq from config - keep on CPU, cache system will move to GPU as needed
|
|
842
|
+
base = qwen_config.rope_theta
|
|
843
|
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
|
|
844
|
+
parent.register_buffer(buffer_name, inv_freq.to(dtype=compute_dtype), persistent=False)
|
|
845
|
+
else:
|
|
846
|
+
logger.warning(f"Re-initializing unknown meta buffer: {name}")
|
|
847
|
+
|
|
848
|
+
# Final check: ensure no meta tensors remain in parameters
|
|
849
|
+
meta_params = [(name, p) for name, p in model.named_parameters() if p.is_meta]
|
|
850
|
+
if meta_params:
|
|
851
|
+
meta_names = [name for name, _ in meta_params]
|
|
852
|
+
raise RuntimeError(
|
|
853
|
+
f"Failed to load all parameters from GGUF. The following remain as meta tensors: {meta_names}. "
|
|
854
|
+
"This may indicate missing keys in the GGUF file or a key mapping issue."
|
|
855
|
+
)
|
|
856
|
+
|
|
857
|
+
return model
|
|
858
|
+
|
|
859
|
+
def _convert_llamacpp_to_pytorch(self, sd: dict[str, Any]) -> dict[str, Any]:
|
|
860
|
+
"""Convert llama.cpp GGUF keys to PyTorch/HuggingFace format for Qwen models.
|
|
861
|
+
|
|
862
|
+
llama.cpp format:
|
|
863
|
+
- blk.X.attn_q.weight -> model.layers.X.self_attn.q_proj.weight
|
|
864
|
+
- blk.X.attn_k.weight -> model.layers.X.self_attn.k_proj.weight
|
|
865
|
+
- blk.X.attn_v.weight -> model.layers.X.self_attn.v_proj.weight
|
|
866
|
+
- blk.X.attn_output.weight -> model.layers.X.self_attn.o_proj.weight
|
|
867
|
+
- blk.X.attn_q_norm.weight -> model.layers.X.self_attn.q_norm.weight (Qwen3 QK norm)
|
|
868
|
+
- blk.X.attn_k_norm.weight -> model.layers.X.self_attn.k_norm.weight (Qwen3 QK norm)
|
|
869
|
+
- blk.X.ffn_gate.weight -> model.layers.X.mlp.gate_proj.weight
|
|
870
|
+
- blk.X.ffn_up.weight -> model.layers.X.mlp.up_proj.weight
|
|
871
|
+
- blk.X.ffn_down.weight -> model.layers.X.mlp.down_proj.weight
|
|
872
|
+
- blk.X.attn_norm.weight -> model.layers.X.input_layernorm.weight
|
|
873
|
+
- blk.X.ffn_norm.weight -> model.layers.X.post_attention_layernorm.weight
|
|
874
|
+
- token_embd.weight -> model.embed_tokens.weight
|
|
875
|
+
- output_norm.weight -> model.norm.weight
|
|
876
|
+
- output.weight -> lm_head.weight (if not tied)
|
|
877
|
+
"""
|
|
878
|
+
import re
|
|
879
|
+
|
|
880
|
+
key_map = {
|
|
881
|
+
"attn_q": "self_attn.q_proj",
|
|
882
|
+
"attn_k": "self_attn.k_proj",
|
|
883
|
+
"attn_v": "self_attn.v_proj",
|
|
884
|
+
"attn_output": "self_attn.o_proj",
|
|
885
|
+
"attn_q_norm": "self_attn.q_norm", # Qwen3 QK normalization
|
|
886
|
+
"attn_k_norm": "self_attn.k_norm", # Qwen3 QK normalization
|
|
887
|
+
"ffn_gate": "mlp.gate_proj",
|
|
888
|
+
"ffn_up": "mlp.up_proj",
|
|
889
|
+
"ffn_down": "mlp.down_proj",
|
|
890
|
+
"attn_norm": "input_layernorm",
|
|
891
|
+
"ffn_norm": "post_attention_layernorm",
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
new_sd: dict[str, Any] = {}
|
|
895
|
+
blk_pattern = re.compile(r"^blk\.(\d+)\.(.+)$")
|
|
896
|
+
|
|
897
|
+
for key, value in sd.items():
|
|
898
|
+
if not isinstance(key, str):
|
|
899
|
+
new_sd[key] = value
|
|
900
|
+
continue
|
|
901
|
+
|
|
902
|
+
# Handle block layers
|
|
903
|
+
match = blk_pattern.match(key)
|
|
904
|
+
if match:
|
|
905
|
+
layer_idx = match.group(1)
|
|
906
|
+
rest = match.group(2)
|
|
907
|
+
|
|
908
|
+
# Split rest into component and suffix (e.g., "attn_q.weight" -> "attn_q", "weight")
|
|
909
|
+
parts = rest.split(".", 1)
|
|
910
|
+
component = parts[0]
|
|
911
|
+
suffix = parts[1] if len(parts) > 1 else ""
|
|
912
|
+
|
|
913
|
+
if component in key_map:
|
|
914
|
+
new_component = key_map[component]
|
|
915
|
+
new_key = f"model.layers.{layer_idx}.{new_component}"
|
|
916
|
+
if suffix:
|
|
917
|
+
new_key += f".{suffix}"
|
|
918
|
+
new_sd[new_key] = value
|
|
919
|
+
else:
|
|
920
|
+
# Unknown component, keep as-is with model.layers prefix
|
|
921
|
+
new_sd[f"model.layers.{layer_idx}.{rest}"] = value
|
|
922
|
+
continue
|
|
923
|
+
|
|
924
|
+
# Handle non-block keys
|
|
925
|
+
if key == "token_embd.weight":
|
|
926
|
+
new_sd["model.embed_tokens.weight"] = value
|
|
927
|
+
elif key == "output_norm.weight":
|
|
928
|
+
new_sd["model.norm.weight"] = value
|
|
929
|
+
elif key == "output.weight":
|
|
930
|
+
new_sd["lm_head.weight"] = value
|
|
931
|
+
else:
|
|
932
|
+
# Keep other keys as-is
|
|
933
|
+
new_sd[key] = value
|
|
934
|
+
|
|
935
|
+
return new_sd
|