invokeai-6.9.0rc3-py3-none-any.whl → invokeai-6.10.0-py3-none-any.whl

This diff compares the contents of two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (104)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/flux_denoise.py +15 -1
  6. invokeai/app/invocations/image_to_latents.py +23 -5
  7. invokeai/app/invocations/latents_to_image.py +2 -25
  8. invokeai/app/invocations/metadata.py +9 -1
  9. invokeai/app/invocations/metadata_linked.py +47 -0
  10. invokeai/app/invocations/model.py +8 -0
  11. invokeai/app/invocations/pbr_maps.py +59 -0
  12. invokeai/app/invocations/primitives.py +12 -0
  13. invokeai/app/invocations/prompt_template.py +57 -0
  14. invokeai/app/invocations/z_image_control.py +112 -0
  15. invokeai/app/invocations/z_image_denoise.py +770 -0
  16. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  17. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  18. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  19. invokeai/app/invocations/z_image_model_loader.py +135 -0
  20. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  21. invokeai/app/services/config/config_default.py +3 -1
  22. invokeai/app/services/model_install/model_install_common.py +14 -1
  23. invokeai/app/services/model_install/model_install_default.py +119 -19
  24. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  25. invokeai/app/services/model_records/model_records_base.py +12 -0
  26. invokeai/app/services/model_records/model_records_sql.py +17 -0
  27. invokeai/app/services/shared/graph.py +132 -77
  28. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  29. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  30. invokeai/app/util/step_callback.py +3 -0
  31. invokeai/backend/flux/denoise.py +196 -11
  32. invokeai/backend/flux/schedulers.py +62 -0
  33. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  34. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  35. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  36. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  37. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  38. invokeai/backend/model_manager/configs/factory.py +26 -1
  39. invokeai/backend/model_manager/configs/lora.py +79 -1
  40. invokeai/backend/model_manager/configs/main.py +113 -0
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  42. invokeai/backend/model_manager/load/model_cache/model_cache.py +104 -2
  43. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  44. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  45. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  46. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  47. invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
  48. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  49. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  50. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  51. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  52. invokeai/backend/model_manager/load/model_loaders/z_image.py +969 -0
  53. invokeai/backend/model_manager/load/model_util.py +6 -1
  54. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  55. invokeai/backend/model_manager/model_on_disk.py +3 -0
  56. invokeai/backend/model_manager/starter_models.py +79 -0
  57. invokeai/backend/model_manager/taxonomy.py +5 -0
  58. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  59. invokeai/backend/patches/layer_patcher.py +34 -16
  60. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  61. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  62. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  63. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  64. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  65. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +189 -0
  66. invokeai/backend/quantization/gguf/ggml_tensor.py +38 -4
  67. invokeai/backend/quantization/gguf/loaders.py +47 -12
  68. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  69. invokeai/backend/util/devices.py +25 -0
  70. invokeai/backend/util/hotfixes.py +2 -2
  71. invokeai/backend/z_image/__init__.py +16 -0
  72. invokeai/backend/z_image/extensions/__init__.py +1 -0
  73. invokeai/backend/z_image/extensions/regional_prompting_extension.py +205 -0
  74. invokeai/backend/z_image/text_conditioning.py +74 -0
  75. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  76. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  77. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  78. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  79. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  80. invokeai/frontend/web/dist/assets/App-BBELGD-n.js +161 -0
  81. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-4xPFTMT3.js} +1 -1
  82. invokeai/frontend/web/dist/assets/index-vCDSQboA.js +530 -0
  83. invokeai/frontend/web/dist/index.html +1 -1
  84. invokeai/frontend/web/dist/locales/de.json +24 -6
  85. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  86. invokeai/frontend/web/dist/locales/en.json +78 -3
  87. invokeai/frontend/web/dist/locales/es.json +0 -5
  88. invokeai/frontend/web/dist/locales/fr.json +0 -6
  89. invokeai/frontend/web/dist/locales/it.json +17 -64
  90. invokeai/frontend/web/dist/locales/ja.json +379 -44
  91. invokeai/frontend/web/dist/locales/ru.json +0 -6
  92. invokeai/frontend/web/dist/locales/vi.json +7 -54
  93. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  94. invokeai/version/invokeai_version.py +1 -1
  95. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/METADATA +4 -4
  96. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/RECORD +102 -71
  97. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  98. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  99. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/WHEEL +0 -0
  100. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/entry_points.txt +0 -0
  101. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE +0 -0
  102. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  103. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  104. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/top_level.txt +0 -0
invokeai/backend/model_manager/load/model_loaders/z_image.py (new file)
@@ -0,0 +1,969 @@
1
+ # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
2
+ """Class for Z-Image model loading in InvokeAI."""
3
+
4
+ from pathlib import Path
5
+ from typing import Any, Optional
6
+
7
+ import accelerate
8
+ import torch
9
+ from transformers import AutoTokenizer, Qwen3ForCausalLM
10
+
11
+ from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
12
+ from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_ZImage_Config
13
+ from invokeai.backend.model_manager.configs.factory import AnyModelConfig
14
+ from invokeai.backend.model_manager.configs.main import Main_Checkpoint_ZImage_Config, Main_GGUF_ZImage_Config
15
+ from invokeai.backend.model_manager.configs.qwen3_encoder import (
16
+ Qwen3Encoder_Checkpoint_Config,
17
+ Qwen3Encoder_GGUF_Config,
18
+ Qwen3Encoder_Qwen3Encoder_Config,
19
+ )
20
+ from invokeai.backend.model_manager.load.load_default import ModelLoader
21
+ from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
22
+ from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
23
+ from invokeai.backend.model_manager.taxonomy import (
24
+ AnyModel,
25
+ BaseModelType,
26
+ ModelFormat,
27
+ ModelType,
28
+ SubModelType,
29
+ )
30
+ from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
31
+ from invokeai.backend.util.devices import TorchDevice
32
+
33
+
34
+ def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
35
+ """Convert Z-Image GGUF state dict keys to diffusers format.
36
+
37
+ The GGUF format uses original model keys that differ from diffusers:
38
+ - qkv.weight (fused) -> to_q.weight, to_k.weight, to_v.weight (split)
39
+ - out.weight -> to_out.0.weight
40
+ - q_norm.weight -> norm_q.weight
41
+ - k_norm.weight -> norm_k.weight
42
+ - x_embedder.* -> all_x_embedder.2-1.*
43
+ - final_layer.* -> all_final_layer.2-1.*
44
+ - norm_final.* -> skipped (diffusers uses non-learnable LayerNorm)
45
+ - x_pad_token, cap_pad_token: [dim] -> [1, dim] (diffusers expects batch dimension)
46
+ """
47
+ new_sd: dict[str, Any] = {}
48
+
49
+ for key, value in sd.items():
50
+ if not isinstance(key, str):
51
+ new_sd[key] = value
52
+ continue
53
+
54
+ # Handle padding tokens: GGUF has shape [dim], diffusers expects [1, dim]
55
+ if key in ("x_pad_token", "cap_pad_token"):
56
+ if hasattr(value, "shape") and len(value.shape) == 1:
57
+ # GGMLTensor doesn't support unsqueeze, so dequantize first if needed
58
+ if hasattr(value, "get_dequantized_tensor"):
59
+ value = value.get_dequantized_tensor()
60
+ # Use reshape instead of unsqueeze for better compatibility
61
+ value = torch.as_tensor(value).reshape(1, -1)
62
+ new_sd[key] = value
63
+ continue
64
+
65
+ # Handle x_embedder -> all_x_embedder.2-1
66
+ if key.startswith("x_embedder."):
67
+ suffix = key[len("x_embedder.") :]
68
+ new_key = f"all_x_embedder.2-1.{suffix}"
69
+ new_sd[new_key] = value
70
+ continue
71
+
72
+ # Handle final_layer -> all_final_layer.2-1
73
+ if key.startswith("final_layer."):
74
+ suffix = key[len("final_layer.") :]
75
+ new_key = f"all_final_layer.2-1.{suffix}"
76
+ new_sd[new_key] = value
77
+ continue
78
+
79
+ # Skip norm_final keys - the diffusers model uses LayerNorm with elementwise_affine=False
80
+ # (no learnable weight/bias), but some checkpoints (e.g., FP8) include these as all-zeros
81
+ if key.startswith("norm_final."):
82
+ continue
83
+
84
+ # Handle fused QKV weights - need to split
85
+ if ".attention.qkv." in key:
86
+ # Get the layer prefix and suffix
87
+ prefix = key.rsplit(".attention.qkv.", 1)[0]
88
+ suffix = key.rsplit(".attention.qkv.", 1)[1] # "weight" or "bias"
89
+
90
+ # Skip non-weight/bias tensors (e.g., FP8 scale_weight tensors)
91
+ # These are quantization metadata and should not be split
92
+ if suffix not in ("weight", "bias"):
93
+ new_sd[key] = value
94
+ continue
95
+
96
+ # Split the fused QKV tensor into Q, K, V
97
+ tensor = value
98
+ if hasattr(tensor, "shape"):
99
+ if tensor.shape[0] % 3 != 0:
100
+ raise ValueError(
101
+ f"Cannot split QKV tensor '{key}': first dimension ({tensor.shape[0]}) "
102
+ "is not divisible by 3. The model file may be corrupted or incompatible."
103
+ )
104
+ dim = tensor.shape[0] // 3
105
+ q = tensor[:dim]
106
+ k = tensor[dim : 2 * dim]
107
+ v = tensor[2 * dim :]
108
+
109
+ new_sd[f"{prefix}.attention.to_q.{suffix}"] = q
110
+ new_sd[f"{prefix}.attention.to_k.{suffix}"] = k
111
+ new_sd[f"{prefix}.attention.to_v.{suffix}"] = v
112
+ continue
113
+
114
+ # Handle attention key renaming
115
+ if ".attention." in key:
116
+ new_key = key.replace(".q_norm.", ".norm_q.")
117
+ new_key = new_key.replace(".k_norm.", ".norm_k.")
118
+ new_key = new_key.replace(".attention.out.", ".attention.to_out.0.")
119
+ new_sd[new_key] = value
120
+ continue
121
+
122
+ # For all other keys, just copy as-is
123
+ new_sd[key] = value
124
+
125
+ return new_sd
126
+
127
+
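# --- Editorial sketch (not part of the diff): a quick check of the key mapping above. ---
# A minimal usage example of _convert_z_image_gguf_to_diffusers, assuming torch is importable.
# Shapes are toy values chosen only so the fused QKV split and pad-token reshape are visible.
import torch

toy_sd = {
    "x_embedder.weight": torch.zeros(3840, 64),                    # -> all_x_embedder.2-1.weight
    "layers.0.attention.qkv.weight": torch.zeros(3 * 3840, 3840),  # fused QKV -> to_q / to_k / to_v
    "layers.0.attention.q_norm.weight": torch.zeros(128),          # -> norm_q
    "x_pad_token": torch.zeros(3840),                              # [dim] -> [1, dim]
}
converted = _convert_z_image_gguf_to_diffusers(toy_sd)
assert "all_x_embedder.2-1.weight" in converted
assert converted["layers.0.attention.to_q.weight"].shape == (3840, 3840)
assert "layers.0.attention.norm_q.weight" in converted
assert converted["x_pad_token"].shape == (1, 3840)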
128
+ @ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.Diffusers)
129
+ class ZImageDiffusersModel(GenericDiffusersLoader):
130
+ """Class to load Z-Image main models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""
131
+
132
+ def _load_model(
133
+ self,
134
+ config: AnyModelConfig,
135
+ submodel_type: Optional[SubModelType] = None,
136
+ ) -> AnyModel:
137
+ if isinstance(config, Checkpoint_Config_Base):
138
+ raise NotImplementedError("CheckpointConfigBase is not implemented for Z-Image models.")
139
+
140
+ if submodel_type is None:
141
+ raise Exception("A submodel type must be provided when loading main pipelines.")
142
+
143
+ model_path = Path(config.path)
144
+ load_class = self.get_hf_load_class(model_path, submodel_type)
145
+ repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
146
+ variant = repo_variant.value if repo_variant else None
147
+ model_path = model_path / submodel_type.value
148
+
149
+ # Z-Image prefers bfloat16, but use safe dtype based on target device capabilities.
150
+ target_device = TorchDevice.choose_torch_device()
151
+ dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
152
+ try:
153
+ result: AnyModel = load_class.from_pretrained(
154
+ model_path,
155
+ torch_dtype=dtype,
156
+ variant=variant,
157
+ )
158
+ except OSError as e:
159
+ if variant and "no file named" in str(
160
+ e
161
+ ): # try without the variant, just in case user's preferences changed
162
+ result = load_class.from_pretrained(model_path, torch_dtype=dtype)
163
+ else:
164
+ raise e
165
+
166
+ return result
167
+
168
+
169
+ @ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.Checkpoint)
170
+ class ZImageCheckpointModel(ModelLoader):
171
+ """Class to load Z-Image transformer models from single-file checkpoints (safetensors, etc)."""
172
+
173
+ def _load_model(
174
+ self,
175
+ config: AnyModelConfig,
176
+ submodel_type: Optional[SubModelType] = None,
177
+ ) -> AnyModel:
178
+ if not isinstance(config, Checkpoint_Config_Base):
179
+ raise ValueError("Only CheckpointConfigBase models are currently supported here.")
180
+
181
+ match submodel_type:
182
+ case SubModelType.Transformer:
183
+ return self._load_from_singlefile(config)
184
+
185
+ raise ValueError(
186
+ f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
187
+ )
188
+
189
+ def _load_from_singlefile(
190
+ self,
191
+ config: AnyModelConfig,
192
+ ) -> AnyModel:
193
+ from diffusers import ZImageTransformer2DModel
194
+ from safetensors.torch import load_file
195
+
196
+ if not isinstance(config, Main_Checkpoint_ZImage_Config):
197
+ raise TypeError(
198
+ f"Expected Main_Checkpoint_ZImage_Config, got {type(config).__name__}. "
199
+ "Model configuration type mismatch."
200
+ )
201
+ model_path = Path(config.path)
202
+
203
+ # Load the state dict from safetensors/checkpoint file
204
+ sd = load_file(model_path)
205
+
206
+ # Some Z-Image checkpoint files have keys prefixed with "diffusion_model." or
207
+ # "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
208
+ prefix_to_strip = None
209
+ for prefix in ["model.diffusion_model.", "diffusion_model."]:
210
+ if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
211
+ prefix_to_strip = prefix
212
+ break
213
+
214
+ if prefix_to_strip:
215
+ stripped_sd = {}
216
+ for key, value in sd.items():
217
+ if isinstance(key, str) and key.startswith(prefix_to_strip):
218
+ stripped_sd[key[len(prefix_to_strip) :]] = value
219
+ else:
220
+ stripped_sd[key] = value
221
+ sd = stripped_sd
222
+
223
+ # Check if the state dict is in original format (not diffusers format)
224
+ # Original format has keys like "x_embedder.weight" instead of "all_x_embedder.2-1.weight"
225
+ needs_conversion = any(k.startswith("x_embedder.") for k in sd.keys() if isinstance(k, str))
226
+
227
+ if needs_conversion:
228
+ # Convert from original format to diffusers format
229
+ sd = _convert_z_image_gguf_to_diffusers(sd)
230
+
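# Editorial note (not part of the diff): an example of the two steps above on a
# ComfyUI-style key. The key name is illustrative, not taken from a specific checkpoint.
#   "model.diffusion_model.x_embedder.weight"
#     -> prefix stripped:  "x_embedder.weight"
#     -> converted:        "all_x_embedder.2-1.weight"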
231
+ # Create an empty model with the default Z-Image config
232
+ # Z-Image-Turbo uses these default parameters from diffusers
233
+ with accelerate.init_empty_weights():
234
+ model = ZImageTransformer2DModel(
235
+ all_patch_size=(2,),
236
+ all_f_patch_size=(1,),
237
+ in_channels=16,
238
+ dim=3840,
239
+ n_layers=30,
240
+ n_refiner_layers=2,
241
+ n_heads=30,
242
+ n_kv_heads=30,
243
+ norm_eps=1e-05,
244
+ qk_norm=True,
245
+ cap_feat_dim=2560,
246
+ rope_theta=256.0,
247
+ t_scale=1000.0,
248
+ axes_dims=[32, 48, 48],
249
+ axes_lens=[1024, 512, 512],
250
+ )
251
+
252
+ # Determine safe dtype based on target device capabilities
253
+ target_device = TorchDevice.choose_torch_device()
254
+ model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
255
+
256
+ # Handle memory management and dtype conversion
257
+ new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
258
+ self._ram_cache.make_room(new_sd_size)
259
+
260
+ # Filter out FP8 scale_weight and scaled_fp8 metadata keys
261
+ # These are quantization metadata that shouldn't be loaded into the model
262
+ keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
263
+ for k in keys_to_remove:
264
+ del sd[k]
265
+
266
+ # Convert to target dtype
267
+ for k in sd.keys():
268
+ sd[k] = sd[k].to(model_dtype)
269
+
270
+ model.load_state_dict(sd, assign=True)
271
+ return model
272
+
273
+
274
+ @ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
275
+ class ZImageGGUFCheckpointModel(ModelLoader):
276
+ """Class to load GGUF-quantized Z-Image transformer models."""
277
+
278
+ def _load_model(
279
+ self,
280
+ config: AnyModelConfig,
281
+ submodel_type: Optional[SubModelType] = None,
282
+ ) -> AnyModel:
283
+ if not isinstance(config, Checkpoint_Config_Base):
284
+ raise ValueError("Only CheckpointConfigBase models are currently supported here.")
285
+
286
+ match submodel_type:
287
+ case SubModelType.Transformer:
288
+ return self._load_from_singlefile(config)
289
+
290
+ raise ValueError(
291
+ f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
292
+ )
293
+
294
+ def _load_from_singlefile(
295
+ self,
296
+ config: AnyModelConfig,
297
+ ) -> AnyModel:
298
+ from diffusers import ZImageTransformer2DModel
299
+
300
+ if not isinstance(config, Main_GGUF_ZImage_Config):
301
+ raise TypeError(
302
+ f"Expected Main_GGUF_ZImage_Config, got {type(config).__name__}. Model configuration type mismatch."
303
+ )
304
+ model_path = Path(config.path)
305
+
306
+ # Determine safe dtype based on target device capabilities
307
+ target_device = TorchDevice.choose_torch_device()
308
+ compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
309
+
310
+ # Load the GGUF state dict
311
+ sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)
312
+
313
+ # Some Z-Image GGUF models have keys prefixed with "diffusion_model." or
314
+ # "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
315
+ prefix_to_strip = None
316
+ for prefix in ["model.diffusion_model.", "diffusion_model."]:
317
+ if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
318
+ prefix_to_strip = prefix
319
+ break
320
+
321
+ if prefix_to_strip:
322
+ stripped_sd = {}
323
+ for key, value in sd.items():
324
+ if isinstance(key, str) and key.startswith(prefix_to_strip):
325
+ stripped_sd[key[len(prefix_to_strip) :]] = value
326
+ else:
327
+ stripped_sd[key] = value
328
+ sd = stripped_sd
329
+
330
+ # Convert GGUF format keys to diffusers format
331
+ sd = _convert_z_image_gguf_to_diffusers(sd)
332
+
333
+ # Create an empty model with the default Z-Image config
334
+ # Z-Image-Turbo uses these default parameters from diffusers
335
+ with accelerate.init_empty_weights():
336
+ model = ZImageTransformer2DModel(
337
+ all_patch_size=(2,),
338
+ all_f_patch_size=(1,),
339
+ in_channels=16,
340
+ dim=3840,
341
+ n_layers=30,
342
+ n_refiner_layers=2,
343
+ n_heads=30,
344
+ n_kv_heads=30,
345
+ norm_eps=1e-05,
346
+ qk_norm=True,
347
+ cap_feat_dim=2560,
348
+ rope_theta=256.0,
349
+ t_scale=1000.0,
350
+ axes_dims=[32, 48, 48],
351
+ axes_lens=[1024, 512, 512],
352
+ )
353
+
354
+ model.load_state_dict(sd, assign=True)
355
+ return model
356
+
357
+
358
+ @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.Qwen3Encoder)
359
+ class Qwen3EncoderLoader(ModelLoader):
360
+ """Class to load standalone Qwen3 Encoder models for Z-Image (directory format)."""
361
+
362
+ def _load_model(
363
+ self,
364
+ config: AnyModelConfig,
365
+ submodel_type: Optional[SubModelType] = None,
366
+ ) -> AnyModel:
367
+ if not isinstance(config, Qwen3Encoder_Qwen3Encoder_Config):
368
+ raise ValueError("Only Qwen3Encoder_Qwen3Encoder_Config models are supported here.")
369
+
370
+ model_path = Path(config.path)
371
+
372
+ # Support both structures:
373
+ # 1. Full model: model_root/text_encoder/ and model_root/tokenizer/
374
+ # 2. Standalone download: model_root/ contains text_encoder files directly
375
+ text_encoder_path = model_path / "text_encoder"
376
+ tokenizer_path = model_path / "tokenizer"
377
+
378
+ # Check if this is a standalone text_encoder download (no nested text_encoder folder)
379
+ is_standalone = not text_encoder_path.exists() and (model_path / "config.json").exists()
380
+
381
+ if is_standalone:
382
+ text_encoder_path = model_path
383
+ tokenizer_path = model_path # Tokenizer files should also be in root
384
+
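# Editorial sketch (not part of the diff) of the two on-disk layouts handled above; paths illustrative:
#   1. Full Z-Image model directory:  <model_root>/text_encoder/config.json and <model_root>/tokenizer/...
#   2. Standalone encoder download:   <model_root>/config.json with the tokenizer files alongside
# In case 2, both text_encoder_path and tokenizer_path resolve to model_path itself.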
385
+ match submodel_type:
386
+ case SubModelType.Tokenizer:
387
+ # Use local_files_only=True to prevent network requests for validation
388
+ # The tokenizer files should already exist locally in the model directory
389
+ return AutoTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
390
+ case SubModelType.TextEncoder:
391
+ # Determine safe dtype based on target device capabilities
392
+ target_device = TorchDevice.choose_torch_device()
393
+ model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
394
+ # Use local_files_only=True to prevent network requests for validation
395
+ return Qwen3ForCausalLM.from_pretrained(
396
+ text_encoder_path,
397
+ torch_dtype=model_dtype,
398
+ low_cpu_mem_usage=True,
399
+ local_files_only=True,
400
+ )
401
+
402
+ raise ValueError(
403
+ f"Only Tokenizer and TextEncoder submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
404
+ )
405
+
406
+
407
+ @ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
408
+ class ZImageControlCheckpointModel(ModelLoader):
409
+ """Class to load Z-Image Control adapter models from safetensors checkpoint.
410
+
411
+ Z-Image Control models are standalone adapters containing control layers
412
+ (control_layers, control_all_x_embedder, control_noise_refiner) that can be
413
+ combined with a base ZImageTransformer2DModel at runtime for spatial conditioning
414
+ (Canny, HED, Depth, Pose, MLSD).
415
+ """
416
+
417
+ def _load_model(
418
+ self,
419
+ config: AnyModelConfig,
420
+ submodel_type: Optional[SubModelType] = None,
421
+ ) -> AnyModel:
422
+ if not isinstance(config, Checkpoint_Config_Base):
423
+ raise ValueError("Only CheckpointConfigBase models are supported here.")
424
+
425
+ # ControlNet type models don't use submodel_type - load the adapter directly
426
+ return self._load_control_adapter(config)
427
+
428
+ def _load_control_adapter(
429
+ self,
430
+ config: AnyModelConfig,
431
+ ) -> AnyModel:
432
+ from safetensors.torch import load_file
433
+
434
+ from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
435
+
436
+ assert isinstance(config, ControlNet_Checkpoint_ZImage_Config)
437
+ model_path = Path(config.path)
438
+
439
+ # Load the safetensors state dict
440
+ sd = load_file(model_path)
441
+
442
+ # Determine number of control blocks from state dict
443
+ # Control blocks are named control_layers.0, control_layers.1, etc.
444
+ control_block_indices = set()
445
+ for key in sd.keys():
446
+ if key.startswith("control_layers."):
447
+ parts = key.split(".")
448
+ if len(parts) > 1 and parts[1].isdigit():
449
+ control_block_indices.add(int(parts[1]))
450
+ num_control_blocks = len(control_block_indices) if control_block_indices else 6
451
+
452
+ # Determine number of refiner layers from state dict
453
+ refiner_indices: set[int] = set()
454
+ for key in sd.keys():
455
+ if key.startswith("control_noise_refiner."):
456
+ parts = key.split(".")
457
+ if len(parts) > 1 and parts[1].isdigit():
458
+ refiner_indices.add(int(parts[1]))
459
+ n_refiner_layers = len(refiner_indices) if refiner_indices else 2
460
+
461
+ # Determine control_in_dim from embedder weight shape
462
+ # control_in_dim = weight.shape[1] / (f_patch_size * patch_size * patch_size)
463
+ # For patch_size=2, f_patch_size=1: control_in_dim = weight.shape[1] / 4
464
+ control_in_dim = 16 # Default for V1
465
+ embedder_key = "control_all_x_embedder.2-1.weight"
466
+ if embedder_key in sd:
467
+ weight_shape = sd[embedder_key].shape
468
+ # weight_shape[1] = f_patch_size * patch_size * patch_size * control_in_dim
469
+ control_in_dim = weight_shape[1] // 4 # 4 = 1 * 2 * 2
470
+
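# Editorial note (not part of the diff), restating the shape arithmetic above:
#   weight_shape[1] = f_patch_size * patch_size * patch_size * control_in_dim = 1 * 2 * 2 * control_in_dim
#   so a V1 embedder weight of shape [dim, 64] gives control_in_dim = 64 // 4 = 16, while a
#   hypothetical [dim, 128] weight would give 32 and be reported as "V2.0" by the logging below.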
471
+ # Log detected configuration for debugging
472
+ from invokeai.backend.util.logging import InvokeAILogger
473
+
474
+ logger = InvokeAILogger.get_logger(self.__class__.__name__)
475
+ version = "V2.0" if control_in_dim > 16 else "V1"
476
+ logger.info(
477
+ f"Z-Image ControlNet detected: {version} "
478
+ f"(control_in_dim={control_in_dim}, num_control_blocks={num_control_blocks}, "
479
+ f"n_refiner_layers={n_refiner_layers})"
480
+ )
481
+
482
+ # Create an empty control adapter
483
+ dim = 3840
484
+ with accelerate.init_empty_weights():
485
+ model = ZImageControlAdapter(
486
+ num_control_blocks=num_control_blocks,
487
+ control_in_dim=control_in_dim,
488
+ all_patch_size=(2,),
489
+ all_f_patch_size=(1,),
490
+ dim=dim,
491
+ n_refiner_layers=n_refiner_layers,
492
+ n_heads=30,
493
+ n_kv_heads=30,
494
+ norm_eps=1e-05,
495
+ qk_norm=True,
496
+ )
497
+
498
+ # Load state dict with strict=False to handle missing keys like x_pad_token
499
+ # Some control adapters may not include x_pad_token in their checkpoint
500
+ missing_keys, unexpected_keys = model.load_state_dict(sd, assign=True, strict=False)
501
+
502
+ # Initialize x_pad_token if it was missing from the checkpoint
503
+ if "x_pad_token" in missing_keys:
504
+ import torch.nn as nn
505
+
506
+ model.x_pad_token = nn.Parameter(torch.empty(dim))
507
+ nn.init.normal_(model.x_pad_token, std=0.02)
508
+
509
+ return model
510
+
511
+
512
+ @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.Checkpoint)
513
+ class Qwen3EncoderCheckpointLoader(ModelLoader):
514
+ """Class to load single-file Qwen3 Encoder models for Z-Image (safetensors format)."""
515
+
516
+ # Default HuggingFace model to load tokenizer from when using single-file Qwen3 encoder
517
+ # Must be Qwen3 (not Qwen2.5) to match Z-Image's text encoder architecture and special tokens
518
+ DEFAULT_TOKENIZER_SOURCE = "Qwen/Qwen3-4B"
519
+
520
+ def _load_model(
521
+ self,
522
+ config: AnyModelConfig,
523
+ submodel_type: Optional[SubModelType] = None,
524
+ ) -> AnyModel:
525
+ if not isinstance(config, Qwen3Encoder_Checkpoint_Config):
526
+ raise ValueError("Only Qwen3Encoder_Checkpoint_Config models are supported here.")
527
+
528
+ match submodel_type:
529
+ case SubModelType.TextEncoder:
530
+ return self._load_from_singlefile(config)
531
+ case SubModelType.Tokenizer:
532
+ # For single-file Qwen3, load tokenizer from HuggingFace
533
+ # Try local cache first to support offline usage after initial download
534
+ return self._load_tokenizer_with_offline_fallback()
535
+
536
+ raise ValueError(
537
+ f"Only TextEncoder and Tokenizer submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
538
+ )
539
+
540
+ def _load_tokenizer_with_offline_fallback(self) -> AnyModel:
541
+ """Load tokenizer with local_files_only fallback for offline support.
542
+
543
+ First tries to load from local cache (offline), falling back to network download
544
+ if the tokenizer hasn't been cached yet. This ensures offline operation after
545
+ the initial download.
546
+ """
547
+ try:
548
+ # Try loading from local cache first (supports offline usage)
549
+ return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE, local_files_only=True)
550
+ except OSError:
551
+ # Not in cache yet, download from HuggingFace
552
+ return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE)
553
+
554
+ def _load_from_singlefile(
555
+ self,
556
+ config: AnyModelConfig,
557
+ ) -> AnyModel:
558
+ from safetensors.torch import load_file
559
+ from transformers import Qwen3Config, Qwen3ForCausalLM
560
+
561
+ from invokeai.backend.util.logging import InvokeAILogger
562
+
563
+ logger = InvokeAILogger.get_logger(self.__class__.__name__)
564
+
565
+ if not isinstance(config, Qwen3Encoder_Checkpoint_Config):
566
+ raise TypeError(
567
+ f"Expected Qwen3Encoder_Checkpoint_Config, got {type(config).__name__}. "
568
+ "Model configuration type mismatch."
569
+ )
570
+ model_path = Path(config.path)
571
+
572
+ # Determine safe dtype based on target device capabilities
573
+ target_device = TorchDevice.choose_torch_device()
574
+ model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
575
+
576
+ # Load the state dict from safetensors file
577
+ sd = load_file(model_path)
578
+
579
+ # Determine Qwen model configuration from state dict
580
+ # Count the number of layers by looking at layer keys
581
+ layer_count = 0
582
+ for key in sd.keys():
583
+ if isinstance(key, str) and key.startswith("model.layers."):
584
+ parts = key.split(".")
585
+ if len(parts) > 2:
586
+ try:
587
+ layer_idx = int(parts[2])
588
+ layer_count = max(layer_count, layer_idx + 1)
589
+ except ValueError:
590
+ pass
591
+
592
+ # Get hidden size from embed_tokens weight shape
593
+ embed_weight = sd.get("model.embed_tokens.weight")
594
+ if embed_weight is None:
595
+ raise ValueError("Could not find model.embed_tokens.weight in state dict")
596
+ if embed_weight.ndim != 2:
597
+ raise ValueError(
598
+ f"Expected 2D embed_tokens weight tensor, got shape {embed_weight.shape}. "
599
+ "The model file may be corrupted or incompatible."
600
+ )
601
+ hidden_size = embed_weight.shape[1]
602
+ vocab_size = embed_weight.shape[0]
603
+
604
+ # Detect attention configuration from layer 0 weights
605
+ q_proj_weight = sd.get("model.layers.0.self_attn.q_proj.weight")
606
+ k_proj_weight = sd.get("model.layers.0.self_attn.k_proj.weight")
607
+ gate_proj_weight = sd.get("model.layers.0.mlp.gate_proj.weight")
608
+
609
+ if q_proj_weight is None or k_proj_weight is None or gate_proj_weight is None:
610
+ raise ValueError("Could not find attention/mlp weights in state dict to determine configuration")
611
+
612
+ # Calculate dimensions from actual weights
613
+ # Qwen3 uses head_dim separately from hidden_size
614
+ head_dim = 128 # Standard head dimension for Qwen3 models
615
+ num_attention_heads = q_proj_weight.shape[0] // head_dim
616
+ num_kv_heads = k_proj_weight.shape[0] // head_dim
617
+ intermediate_size = gate_proj_weight.shape[0]
618
+
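# Editorial sketch (not part of the diff): the arithmetic above applied to Qwen3-4B-style shapes.
# The figures are quoted from memory as an assumption, not read from the diff:
#   q_proj.weight    [4096, 2560]  -> num_attention_heads = 4096 // 128 = 32
#   k_proj.weight    [1024, 2560]  -> num_key_value_heads = 1024 // 128 = 8
#   gate_proj.weight [9728, 2560]  -> intermediate_size   = 9728
#   embed_tokens.weight [vocab, 2560] -> hidden_size = 2560 (matching cap_feat_dim=2560 used above)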
619
+ # Create Qwen3 config - matches the diffusers text_encoder/config.json
620
+ qwen_config = Qwen3Config(
621
+ vocab_size=vocab_size,
622
+ hidden_size=hidden_size,
623
+ intermediate_size=intermediate_size,
624
+ num_hidden_layers=layer_count,
625
+ num_attention_heads=num_attention_heads,
626
+ num_key_value_heads=num_kv_heads,
627
+ head_dim=head_dim,
628
+ max_position_embeddings=40960,
629
+ rms_norm_eps=1e-6,
630
+ tie_word_embeddings=True,
631
+ rope_theta=1000000.0,
632
+ use_sliding_window=False,
633
+ attention_bias=False,
634
+ attention_dropout=0.0,
635
+ torch_dtype=model_dtype,
636
+ )
637
+
638
+ # Handle memory management
639
+ new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
640
+ self._ram_cache.make_room(new_sd_size)
641
+
642
+ # Convert to target dtype
643
+ for k in sd.keys():
644
+ sd[k] = sd[k].to(model_dtype)
645
+
646
+ # Use Qwen3ForCausalLM - the correct model class for Z-Image text encoder
647
+ # Use init_empty_weights for fast model creation, then load weights with assign=True
648
+ with accelerate.init_empty_weights():
649
+ model = Qwen3ForCausalLM(qwen_config)
650
+
651
+ # Load the text model weights from checkpoint
652
+ # assign=True replaces meta tensors with real ones from state dict
653
+ model.load_state_dict(sd, strict=False, assign=True)
654
+
655
+ # Handle tied weights: lm_head shares weight with embed_tokens when tie_word_embeddings=True
656
+ # This doesn't work automatically with init_empty_weights, so we need to manually tie them
657
+ if qwen_config.tie_word_embeddings:
658
+ model.tie_weights()
659
+
660
+ # Re-initialize any remaining meta tensor buffers (like rotary embeddings inv_freq)
661
+ # These are computed from config, not loaded from checkpoint
662
+ for name, buffer in list(model.named_buffers()):
663
+ if buffer.is_meta:
664
+ # Get parent module and buffer name
665
+ parts = name.rsplit(".", 1)
666
+ if len(parts) == 2:
667
+ parent = model.get_submodule(parts[0])
668
+ buffer_name = parts[1]
669
+ else:
670
+ parent = model
671
+ buffer_name = name
672
+
673
+ # Re-initialize the buffer based on expected shape and dtype
674
+ # For rotary embeddings, this is inv_freq which is computed from config
675
+ if buffer_name == "inv_freq":
676
+ # Compute inv_freq from config (same logic as Qwen3RotaryEmbedding.__init__)
677
+ base = qwen_config.rope_theta
678
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
679
+ parent.register_buffer(buffer_name, inv_freq.to(model_dtype), persistent=False)
680
+ else:
681
+ # For other buffers, log warning
682
+ logger.warning(f"Re-initializing unknown meta buffer: {name}")
683
+
684
+ return model
685
+
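# Editorial note (not part of the diff): a numeric instance of the inv_freq formula used in the
# meta-buffer re-initialization above, assuming rope_theta=1e6 and head_dim=128 (64 frequencies):
#   inv_freq[0]  = 1.0
#   inv_freq[1]  = 1 / 1e6**(2/128)   ≈ 0.806
#   inv_freq[63] = 1 / 1e6**(126/128) ≈ 1.24e-6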
686
+
687
+ @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.GGUFQuantized)
688
+ class Qwen3EncoderGGUFLoader(ModelLoader):
689
+ """Class to load GGUF-quantized Qwen3 Encoder models for Z-Image."""
690
+
691
+ # Default HuggingFace model to load tokenizer from when using GGUF Qwen3 encoder
692
+ # Must be Qwen3 (not Qwen2.5) to match Z-Image's text encoder architecture and special tokens
693
+ DEFAULT_TOKENIZER_SOURCE = "Qwen/Qwen3-4B"
694
+
695
+ def _load_model(
696
+ self,
697
+ config: AnyModelConfig,
698
+ submodel_type: Optional[SubModelType] = None,
699
+ ) -> AnyModel:
700
+ if not isinstance(config, Qwen3Encoder_GGUF_Config):
701
+ raise ValueError("Only Qwen3Encoder_GGUF_Config models are supported here.")
702
+
703
+ match submodel_type:
704
+ case SubModelType.TextEncoder:
705
+ return self._load_from_gguf(config)
706
+ case SubModelType.Tokenizer:
707
+ # For GGUF Qwen3, load tokenizer from HuggingFace
708
+ # Try local cache first to support offline usage after initial download
709
+ return self._load_tokenizer_with_offline_fallback()
710
+
711
+ raise ValueError(
712
+ f"Only TextEncoder and Tokenizer submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
713
+ )
714
+
715
+ def _load_tokenizer_with_offline_fallback(self) -> AnyModel:
716
+ """Load tokenizer with local_files_only fallback for offline support.
717
+
718
+ First tries to load from local cache (offline), falling back to network download
719
+ if the tokenizer hasn't been cached yet. This ensures offline operation after
720
+ the initial download.
721
+ """
722
+ try:
723
+ # Try loading from local cache first (supports offline usage)
724
+ return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE, local_files_only=True)
725
+ except OSError:
726
+ # Not in cache yet, download from HuggingFace
727
+ return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE)
728
+
729
+ def _load_from_gguf(
730
+ self,
731
+ config: AnyModelConfig,
732
+ ) -> AnyModel:
733
+ from transformers import Qwen3Config, Qwen3ForCausalLM
734
+
735
+ from invokeai.backend.util.logging import InvokeAILogger
736
+
737
+ logger = InvokeAILogger.get_logger(self.__class__.__name__)
738
+
739
+ if not isinstance(config, Qwen3Encoder_GGUF_Config):
740
+ raise TypeError(
741
+ f"Expected Qwen3Encoder_GGUF_Config, got {type(config).__name__}. Model configuration type mismatch."
742
+ )
743
+ model_path = Path(config.path)
744
+
745
+ # Determine safe dtype based on target device capabilities
746
+ target_device = TorchDevice.choose_torch_device()
747
+ compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
748
+
749
+ # Load the GGUF state dict - this returns GGMLTensor wrappers (on CPU)
750
+ # We keep them on CPU and let the model cache system handle GPU movement
751
+ # via apply_custom_layers_to_model() and the partial loading cache
752
+ sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)
753
+
754
+ # Check if this is llama.cpp format (blk.X.) or PyTorch format (model.layers.X.)
755
+ is_llamacpp_format = any(k.startswith("blk.") for k in sd.keys() if isinstance(k, str))
756
+
757
+ if is_llamacpp_format:
758
+ logger.info("Detected llama.cpp GGUF format, converting keys to PyTorch format")
759
+ sd = self._convert_llamacpp_to_pytorch(sd)
760
+
761
+ # Determine Qwen model configuration from state dict
762
+ # Count the number of layers by looking at layer keys
763
+ layer_count = 0
764
+ for key in sd.keys():
765
+ if isinstance(key, str) and key.startswith("model.layers."):
766
+ parts = key.split(".")
767
+ if len(parts) > 2:
768
+ try:
769
+ layer_idx = int(parts[2])
770
+ layer_count = max(layer_count, layer_idx + 1)
771
+ except ValueError:
772
+ pass
773
+
774
+ # Get hidden size from embed_tokens weight shape
775
+ embed_weight = sd.get("model.embed_tokens.weight")
776
+ if embed_weight is None:
777
+ raise ValueError("Could not find model.embed_tokens.weight in state dict")
778
+
779
+ # Handle GGMLTensor shape access
780
+ embed_shape = embed_weight.shape if hasattr(embed_weight, "shape") else embed_weight.tensor_shape
781
+ if len(embed_shape) != 2:
782
+ raise ValueError(
783
+ f"Expected 2D embed_tokens weight tensor, got shape {embed_shape}. "
784
+ "The model file may be corrupted or incompatible."
785
+ )
786
+ hidden_size = embed_shape[1]
787
+ vocab_size = embed_shape[0]
788
+
789
+ # Detect attention configuration from layer 0 weights
790
+ q_proj_weight = sd.get("model.layers.0.self_attn.q_proj.weight")
791
+ k_proj_weight = sd.get("model.layers.0.self_attn.k_proj.weight")
792
+ gate_proj_weight = sd.get("model.layers.0.mlp.gate_proj.weight")
793
+
794
+ if q_proj_weight is None or k_proj_weight is None or gate_proj_weight is None:
795
+ raise ValueError("Could not find attention/mlp weights in state dict to determine configuration")
796
+
797
+ # Handle GGMLTensor shape access
798
+ q_shape = q_proj_weight.shape if hasattr(q_proj_weight, "shape") else q_proj_weight.tensor_shape
799
+ k_shape = k_proj_weight.shape if hasattr(k_proj_weight, "shape") else k_proj_weight.tensor_shape
800
+ gate_shape = gate_proj_weight.shape if hasattr(gate_proj_weight, "shape") else gate_proj_weight.tensor_shape
801
+
802
+ # Calculate dimensions from actual weights
803
+ head_dim = 128 # Standard head dimension for Qwen3 models
804
+ num_attention_heads = q_shape[0] // head_dim
805
+ num_kv_heads = k_shape[0] // head_dim
806
+ intermediate_size = gate_shape[0]
807
+
808
+ logger.info(
809
+ f"Qwen3 GGUF Encoder config detected: layers={layer_count}, hidden={hidden_size}, "
810
+ f"heads={num_attention_heads}, kv_heads={num_kv_heads}, intermediate={intermediate_size}, "
811
+ f"head_dim={head_dim}"
812
+ )
813
+
814
+ # Create Qwen3 config
815
+ qwen_config = Qwen3Config(
816
+ vocab_size=vocab_size,
817
+ hidden_size=hidden_size,
818
+ intermediate_size=intermediate_size,
819
+ num_hidden_layers=layer_count,
820
+ num_attention_heads=num_attention_heads,
821
+ num_key_value_heads=num_kv_heads,
822
+ head_dim=head_dim,
823
+ max_position_embeddings=40960,
824
+ rms_norm_eps=1e-6,
825
+ tie_word_embeddings=True,
826
+ rope_theta=1000000.0,
827
+ use_sliding_window=False,
828
+ attention_bias=False,
829
+ attention_dropout=0.0,
830
+ torch_dtype=compute_dtype,
831
+ )
832
+
833
+ # Use Qwen3ForCausalLM with empty weights, then load GGUF tensors
834
+ with accelerate.init_empty_weights():
835
+ model = Qwen3ForCausalLM(qwen_config)
836
+
837
+ # Load the GGUF weights with assign=True
838
+ # GGMLTensor wrappers will be dequantized on-the-fly during inference
839
+ model.load_state_dict(sd, strict=False, assign=True)
840
+
841
+ # Dequantize embed_tokens weight - embedding lookups require indexed access
842
+ # which quantized GGMLTensors can't efficiently provide (no __torch_dispatch__ for embedding)
843
+ from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
844
+
845
+ embed_tokens_weight = model.model.embed_tokens.weight
846
+ if isinstance(embed_tokens_weight, GGMLTensor):
847
+ dequantized = embed_tokens_weight.get_dequantized_tensor()
848
+ model.model.embed_tokens.weight = torch.nn.Parameter(dequantized, requires_grad=False)
849
+ logger.info("Dequantized embed_tokens weight for embedding lookups")
850
+
851
+ # Handle tied weights - llama.cpp GGUF doesn't include lm_head.weight when embeddings are tied
852
+ # So we need to manually tie them after loading
853
+ if qwen_config.tie_word_embeddings:
854
+ # Check if lm_head.weight is still a meta tensor (wasn't in GGUF state dict)
855
+ if model.lm_head.weight.is_meta:
856
+ # Directly assign embed_tokens weight to lm_head (now dequantized)
857
+ model.lm_head.weight = model.model.embed_tokens.weight
858
+ logger.info("Tied lm_head.weight to embed_tokens.weight (GGUF tied embeddings)")
859
+ else:
860
+ # If lm_head.weight was loaded, use standard tie_weights
861
+ model.tie_weights()
862
+
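# Editorial note (not part of the diff): after either branch above, both references should point
# at the same Parameter object, e.g.
#     assert model.lm_head.weight is model.model.embed_tokens.weight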
863
+ # Re-initialize any remaining meta tensor buffers (like rotary embeddings inv_freq)
864
+ for name, buffer in list(model.named_buffers()):
865
+ if buffer.is_meta:
866
+ parts = name.rsplit(".", 1)
867
+ if len(parts) == 2:
868
+ parent = model.get_submodule(parts[0])
869
+ buffer_name = parts[1]
870
+ else:
871
+ parent = model
872
+ buffer_name = name
873
+
874
+ if buffer_name == "inv_freq":
875
+ # Compute inv_freq from config - keep on CPU, cache system will move to GPU as needed
876
+ base = qwen_config.rope_theta
877
+ inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
878
+ parent.register_buffer(buffer_name, inv_freq.to(dtype=compute_dtype), persistent=False)
879
+ else:
880
+ logger.warning(f"Re-initializing unknown meta buffer: {name}")
881
+
882
+ # Final check: ensure no meta tensors remain in parameters
883
+ meta_params = [(name, p) for name, p in model.named_parameters() if p.is_meta]
884
+ if meta_params:
885
+ meta_names = [name for name, _ in meta_params]
886
+ raise RuntimeError(
887
+ f"Failed to load all parameters from GGUF. The following remain as meta tensors: {meta_names}. "
888
+ "This may indicate missing keys in the GGUF file or a key mapping issue."
889
+ )
890
+
891
+ return model
892
+
893
+ def _convert_llamacpp_to_pytorch(self, sd: dict[str, Any]) -> dict[str, Any]:
894
+ """Convert llama.cpp GGUF keys to PyTorch/HuggingFace format for Qwen models.
895
+
896
+ llama.cpp format:
897
+ - blk.X.attn_q.weight -> model.layers.X.self_attn.q_proj.weight
898
+ - blk.X.attn_k.weight -> model.layers.X.self_attn.k_proj.weight
899
+ - blk.X.attn_v.weight -> model.layers.X.self_attn.v_proj.weight
900
+ - blk.X.attn_output.weight -> model.layers.X.self_attn.o_proj.weight
901
+ - blk.X.attn_q_norm.weight -> model.layers.X.self_attn.q_norm.weight (Qwen3 QK norm)
902
+ - blk.X.attn_k_norm.weight -> model.layers.X.self_attn.k_norm.weight (Qwen3 QK norm)
903
+ - blk.X.ffn_gate.weight -> model.layers.X.mlp.gate_proj.weight
904
+ - blk.X.ffn_up.weight -> model.layers.X.mlp.up_proj.weight
905
+ - blk.X.ffn_down.weight -> model.layers.X.mlp.down_proj.weight
906
+ - blk.X.attn_norm.weight -> model.layers.X.input_layernorm.weight
907
+ - blk.X.ffn_norm.weight -> model.layers.X.post_attention_layernorm.weight
908
+ - token_embd.weight -> model.embed_tokens.weight
909
+ - output_norm.weight -> model.norm.weight
910
+ - output.weight -> lm_head.weight (if not tied)
911
+ """
912
+ import re
913
+
914
+ key_map = {
915
+ "attn_q": "self_attn.q_proj",
916
+ "attn_k": "self_attn.k_proj",
917
+ "attn_v": "self_attn.v_proj",
918
+ "attn_output": "self_attn.o_proj",
919
+ "attn_q_norm": "self_attn.q_norm", # Qwen3 QK normalization
920
+ "attn_k_norm": "self_attn.k_norm", # Qwen3 QK normalization
921
+ "ffn_gate": "mlp.gate_proj",
922
+ "ffn_up": "mlp.up_proj",
923
+ "ffn_down": "mlp.down_proj",
924
+ "attn_norm": "input_layernorm",
925
+ "ffn_norm": "post_attention_layernorm",
926
+ }
927
+
928
+ new_sd: dict[str, Any] = {}
929
+ blk_pattern = re.compile(r"^blk\.(\d+)\.(.+)$")
930
+
931
+ for key, value in sd.items():
932
+ if not isinstance(key, str):
933
+ new_sd[key] = value
934
+ continue
935
+
936
+ # Handle block layers
937
+ match = blk_pattern.match(key)
938
+ if match:
939
+ layer_idx = match.group(1)
940
+ rest = match.group(2)
941
+
942
+ # Split rest into component and suffix (e.g., "attn_q.weight" -> "attn_q", "weight")
943
+ parts = rest.split(".", 1)
944
+ component = parts[0]
945
+ suffix = parts[1] if len(parts) > 1 else ""
946
+
947
+ if component in key_map:
948
+ new_component = key_map[component]
949
+ new_key = f"model.layers.{layer_idx}.{new_component}"
950
+ if suffix:
951
+ new_key += f".{suffix}"
952
+ new_sd[new_key] = value
953
+ else:
954
+ # Unknown component, keep as-is with model.layers prefix
955
+ new_sd[f"model.layers.{layer_idx}.{rest}"] = value
956
+ continue
957
+
958
+ # Handle non-block keys
959
+ if key == "token_embd.weight":
960
+ new_sd["model.embed_tokens.weight"] = value
961
+ elif key == "output_norm.weight":
962
+ new_sd["model.norm.weight"] = value
963
+ elif key == "output.weight":
964
+ new_sd["lm_head.weight"] = value
965
+ else:
966
+ # Keep other keys as-is
967
+ new_sd[key] = value
968
+
969
+ return new_sd
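# --- Editorial sketch (not part of the diff): the key conversion above on a few sample keys. ---
# self is unused in _convert_llamacpp_to_pytorch, so the unbound method is called directly here
# purely for illustration; the dict values are dummy placeholders.
sample = {
    "blk.0.attn_q.weight": 0,
    "token_embd.weight": 1,
    "output_norm.weight": 2,
}
converted = Qwen3EncoderGGUFLoader._convert_llamacpp_to_pytorch(None, sample)
assert converted == {
    "model.layers.0.self_attn.q_proj.weight": 0,
    "model.embed_tokens.weight": 1,
    "model.norm.weight": 2,
}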