InvokeAI 6.9.0rc3-py3-none-any.whl → 6.10.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,935 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for Z-Image model loading in InvokeAI."""

from pathlib import Path
from typing import Any, Optional

import accelerate
import torch
from transformers import AutoTokenizer, Qwen3ForCausalLM

from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.controlnet import ControlNet_Checkpoint_ZImage_Config
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.main import Main_Checkpoint_ZImage_Config, Main_GGUF_ZImage_Config
from invokeai.backend.model_manager.configs.qwen3_encoder import (
    Qwen3Encoder_Checkpoint_Config,
    Qwen3Encoder_GGUF_Config,
    Qwen3Encoder_Qwen3Encoder_Config,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
    AnyModel,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.devices import TorchDevice


def _convert_z_image_gguf_to_diffusers(sd: dict[str, Any]) -> dict[str, Any]:
    """Convert Z-Image GGUF state dict keys to diffusers format.

    The GGUF format uses original model keys that differ from diffusers:
    - qkv.weight (fused) -> to_q.weight, to_k.weight, to_v.weight (split)
    - out.weight -> to_out.0.weight
    - q_norm.weight -> norm_q.weight
    - k_norm.weight -> norm_k.weight
    - x_embedder.* -> all_x_embedder.2-1.*
    - final_layer.* -> all_final_layer.2-1.*
    - norm_final.* -> skipped (diffusers uses non-learnable LayerNorm)
    - x_pad_token, cap_pad_token: [dim] -> [1, dim] (diffusers expects batch dimension)
    """
    new_sd: dict[str, Any] = {}

    for key, value in sd.items():
        if not isinstance(key, str):
            new_sd[key] = value
            continue

        # Handle padding tokens: GGUF has shape [dim], diffusers expects [1, dim]
        if key in ("x_pad_token", "cap_pad_token"):
            if hasattr(value, "shape") and len(value.shape) == 1:
                # GGMLTensor doesn't support unsqueeze, so dequantize first if needed
                if hasattr(value, "get_dequantized_tensor"):
                    value = value.get_dequantized_tensor()
                # Use reshape instead of unsqueeze for better compatibility
                value = torch.as_tensor(value).reshape(1, -1)
            new_sd[key] = value
            continue

        # Handle x_embedder -> all_x_embedder.2-1
        if key.startswith("x_embedder."):
            suffix = key[len("x_embedder.") :]
            new_key = f"all_x_embedder.2-1.{suffix}"
            new_sd[new_key] = value
            continue

        # Handle final_layer -> all_final_layer.2-1
        if key.startswith("final_layer."):
            suffix = key[len("final_layer.") :]
            new_key = f"all_final_layer.2-1.{suffix}"
            new_sd[new_key] = value
            continue

        # Skip norm_final keys - the diffusers model uses LayerNorm with elementwise_affine=False
        # (no learnable weight/bias), but some checkpoints (e.g., FP8) include these as all-zeros
        if key.startswith("norm_final."):
            continue

        # Handle fused QKV weights - need to split
        if ".attention.qkv." in key:
            # Get the layer prefix and suffix
            prefix = key.rsplit(".attention.qkv.", 1)[0]
            suffix = key.rsplit(".attention.qkv.", 1)[1]  # "weight" or "bias"

            # Skip non-weight/bias tensors (e.g., FP8 scale_weight tensors)
            # These are quantization metadata and should not be split
            if suffix not in ("weight", "bias"):
                new_sd[key] = value
                continue

            # Split the fused QKV tensor into Q, K, V
            tensor = value
            if hasattr(tensor, "shape"):
                if tensor.shape[0] % 3 != 0:
                    raise ValueError(
                        f"Cannot split QKV tensor '{key}': first dimension ({tensor.shape[0]}) "
                        "is not divisible by 3. The model file may be corrupted or incompatible."
                    )
                dim = tensor.shape[0] // 3
                q = tensor[:dim]
                k = tensor[dim : 2 * dim]
                v = tensor[2 * dim :]

                new_sd[f"{prefix}.attention.to_q.{suffix}"] = q
                new_sd[f"{prefix}.attention.to_k.{suffix}"] = k
                new_sd[f"{prefix}.attention.to_v.{suffix}"] = v
            continue

        # Handle attention key renaming
        if ".attention." in key:
            new_key = key.replace(".q_norm.", ".norm_q.")
            new_key = new_key.replace(".k_norm.", ".norm_k.")
            new_key = new_key.replace(".attention.out.", ".attention.to_out.0.")
            new_sd[new_key] = value
            continue

        # For all other keys, just copy as-is
        new_sd[key] = value

    return new_sd

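Note (not part of the diff): a minimal sketch of how the key conversion above behaves, using toy-sized tensors and assuming the module is importable at the path this diff adds. The toy key names and shapes are illustrative only.

import torch

from invokeai.backend.model_manager.load.model_loaders.z_image import _convert_z_image_gguf_to_diffusers

toy_sd = {
    "x_embedder.weight": torch.zeros(8, 64),              # -> all_x_embedder.2-1.weight
    "layers.0.attention.qkv.weight": torch.zeros(24, 8),  # fused QKV, split into to_q/to_k/to_v, each [8, 8]
    "layers.0.attention.out.weight": torch.zeros(8, 8),   # -> layers.0.attention.to_out.0.weight
    "x_pad_token": torch.zeros(8),                        # [dim] -> [1, dim]
}
converted = _convert_z_image_gguf_to_diffusers(toy_sd)
print(sorted(converted))                # renamed/split keys, per the docstring above
print(converted["x_pad_token"].shape)   # torch.Size([1, 8])
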
@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.Diffusers)
class ZImageDiffusersModel(GenericDiffusersLoader):
    """Class to load Z-Image main models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if isinstance(config, Checkpoint_Config_Base):
            raise NotImplementedError("CheckpointConfigBase is not implemented for Z-Image models.")

        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading main pipelines.")

        model_path = Path(config.path)
        load_class = self.get_hf_load_class(model_path, submodel_type)
        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
        variant = repo_variant.value if repo_variant else None
        model_path = model_path / submodel_type.value

        # Z-Image prefers bfloat16, but use safe dtype based on target device capabilities.
        target_device = TorchDevice.choose_torch_device()
        dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
        try:
            result: AnyModel = load_class.from_pretrained(
                model_path,
                torch_dtype=dtype,
                variant=variant,
            )
        except OSError as e:
            if variant and "no file named" in str(
                e
            ):  # try without the variant, just in case user's preferences changed
                result = load_class.from_pretrained(model_path, torch_dtype=dtype)
            else:
                raise e

        return result

@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.Checkpoint)
class ZImageCheckpointModel(ModelLoader):
    """Class to load Z-Image transformer models from single-file checkpoints (safetensors, etc)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, Checkpoint_Config_Base):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        from diffusers import ZImageTransformer2DModel
        from safetensors.torch import load_file

        if not isinstance(config, Main_Checkpoint_ZImage_Config):
            raise TypeError(
                f"Expected Main_Checkpoint_ZImage_Config, got {type(config).__name__}. "
                "Model configuration type mismatch."
            )
        model_path = Path(config.path)

        # Load the state dict from safetensors/checkpoint file
        sd = load_file(model_path)

        # Some Z-Image checkpoint files have keys prefixed with "diffusion_model." or
        # "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
        prefix_to_strip = None
        for prefix in ["model.diffusion_model.", "diffusion_model."]:
            if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
                prefix_to_strip = prefix
                break

        if prefix_to_strip:
            stripped_sd = {}
            for key, value in sd.items():
                if isinstance(key, str) and key.startswith(prefix_to_strip):
                    stripped_sd[key[len(prefix_to_strip) :]] = value
                else:
                    stripped_sd[key] = value
            sd = stripped_sd

        # Check if the state dict is in original format (not diffusers format)
        # Original format has keys like "x_embedder.weight" instead of "all_x_embedder.2-1.weight"
        needs_conversion = any(k.startswith("x_embedder.") for k in sd.keys() if isinstance(k, str))

        if needs_conversion:
            # Convert from original format to diffusers format
            sd = _convert_z_image_gguf_to_diffusers(sd)

        # Create an empty model with the default Z-Image config
        # Z-Image-Turbo uses these default parameters from diffusers
        with accelerate.init_empty_weights():
            model = ZImageTransformer2DModel(
                all_patch_size=(2,),
                all_f_patch_size=(1,),
                in_channels=16,
                dim=3840,
                n_layers=30,
                n_refiner_layers=2,
                n_heads=30,
                n_kv_heads=30,
                norm_eps=1e-05,
                qk_norm=True,
                cap_feat_dim=2560,
                rope_theta=256.0,
                t_scale=1000.0,
                axes_dims=[32, 48, 48],
                axes_lens=[1024, 512, 512],
            )

        # Determine safe dtype based on target device capabilities
        target_device = TorchDevice.choose_torch_device()
        model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)

        # Handle memory management and dtype conversion
        new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
        self._ram_cache.make_room(new_sd_size)

        # Filter out FP8 scale_weight and scaled_fp8 metadata keys
        # These are quantization metadata that shouldn't be loaded into the model
        keys_to_remove = [k for k in sd.keys() if k.endswith(".scale_weight") or k == "scaled_fp8"]
        for k in keys_to_remove:
            del sd[k]

        # Convert to target dtype
        for k in sd.keys():
            sd[k] = sd[k].to(model_dtype)

        model.load_state_dict(sd, assign=True)
        return model

@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class ZImageGGUFCheckpointModel(ModelLoader):
    """Class to load GGUF-quantized Z-Image transformer models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, Checkpoint_Config_Base):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        from diffusers import ZImageTransformer2DModel

        if not isinstance(config, Main_GGUF_ZImage_Config):
            raise TypeError(
                f"Expected Main_GGUF_ZImage_Config, got {type(config).__name__}. Model configuration type mismatch."
            )
        model_path = Path(config.path)

        # Determine safe dtype based on target device capabilities
        target_device = TorchDevice.choose_torch_device()
        compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)

        # Load the GGUF state dict
        sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)

        # Some Z-Image GGUF models have keys prefixed with "diffusion_model." or
        # "model.diffusion_model." (ComfyUI-style format). Check if we need to strip this prefix.
        prefix_to_strip = None
        for prefix in ["model.diffusion_model.", "diffusion_model."]:
            if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
                prefix_to_strip = prefix
                break

        if prefix_to_strip:
            stripped_sd = {}
            for key, value in sd.items():
                if isinstance(key, str) and key.startswith(prefix_to_strip):
                    stripped_sd[key[len(prefix_to_strip) :]] = value
                else:
                    stripped_sd[key] = value
            sd = stripped_sd

        # Convert GGUF format keys to diffusers format
        sd = _convert_z_image_gguf_to_diffusers(sd)

        # Create an empty model with the default Z-Image config
        # Z-Image-Turbo uses these default parameters from diffusers
        with accelerate.init_empty_weights():
            model = ZImageTransformer2DModel(
                all_patch_size=(2,),
                all_f_patch_size=(1,),
                in_channels=16,
                dim=3840,
                n_layers=30,
                n_refiner_layers=2,
                n_heads=30,
                n_kv_heads=30,
                norm_eps=1e-05,
                qk_norm=True,
                cap_feat_dim=2560,
                rope_theta=256.0,
                t_scale=1000.0,
                axes_dims=[32, 48, 48],
                axes_lens=[1024, 512, 512],
            )

        model.load_state_dict(sd, assign=True)
        return model

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.Qwen3Encoder)
class Qwen3EncoderLoader(ModelLoader):
    """Class to load standalone Qwen3 Encoder models for Z-Image (directory format)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, Qwen3Encoder_Qwen3Encoder_Config):
            raise ValueError("Only Qwen3Encoder_Qwen3Encoder_Config models are supported here.")

        model_path = Path(config.path)

        # Support both structures:
        # 1. Full model: model_root/text_encoder/ and model_root/tokenizer/
        # 2. Standalone download: model_root/ contains text_encoder files directly
        text_encoder_path = model_path / "text_encoder"
        tokenizer_path = model_path / "tokenizer"

        # Check if this is a standalone text_encoder download (no nested text_encoder folder)
        is_standalone = not text_encoder_path.exists() and (model_path / "config.json").exists()

        if is_standalone:
            text_encoder_path = model_path
            tokenizer_path = model_path  # Tokenizer files should also be in root

        match submodel_type:
            case SubModelType.Tokenizer:
                return AutoTokenizer.from_pretrained(tokenizer_path)
            case SubModelType.TextEncoder:
                # Determine safe dtype based on target device capabilities
                target_device = TorchDevice.choose_torch_device()
                model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
                return Qwen3ForCausalLM.from_pretrained(
                    text_encoder_path,
                    torch_dtype=model_dtype,
                    low_cpu_mem_usage=True,
                )

        raise ValueError(
            f"Only Tokenizer and TextEncoder submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

@ModelLoaderRegistry.register(base=BaseModelType.ZImage, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ZImageControlCheckpointModel(ModelLoader):
    """Class to load Z-Image Control adapter models from safetensors checkpoint.

    Z-Image Control models are standalone adapters containing control layers
    (control_layers, control_all_x_embedder, control_noise_refiner) that can be
    combined with a base ZImageTransformer2DModel at runtime for spatial conditioning
    (Canny, HED, Depth, Pose, MLSD).
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, Checkpoint_Config_Base):
            raise ValueError("Only CheckpointConfigBase models are supported here.")

        # ControlNet type models don't use submodel_type - load the adapter directly
        return self._load_control_adapter(config)

    def _load_control_adapter(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        from safetensors.torch import load_file

        from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter

        assert isinstance(config, ControlNet_Checkpoint_ZImage_Config)
        model_path = Path(config.path)

        # Load the safetensors state dict
        sd = load_file(model_path)

        # Determine number of control blocks from state dict
        # Control blocks are named control_layers.0, control_layers.1, etc.
        control_block_indices = set()
        for key in sd.keys():
            if key.startswith("control_layers."):
                parts = key.split(".")
                if len(parts) > 1 and parts[1].isdigit():
                    control_block_indices.add(int(parts[1]))
        num_control_blocks = len(control_block_indices) if control_block_indices else 6

        # Determine number of refiner layers from state dict
        refiner_indices: set[int] = set()
        for key in sd.keys():
            if key.startswith("control_noise_refiner."):
                parts = key.split(".")
                if len(parts) > 1 and parts[1].isdigit():
                    refiner_indices.add(int(parts[1]))
        n_refiner_layers = len(refiner_indices) if refiner_indices else 2

        # Determine control_in_dim from embedder weight shape
        # control_in_dim = weight.shape[1] / (f_patch_size * patch_size * patch_size)
        # For patch_size=2, f_patch_size=1: control_in_dim = weight.shape[1] / 4
        control_in_dim = 16  # Default for V1
        embedder_key = "control_all_x_embedder.2-1.weight"
        if embedder_key in sd:
            weight_shape = sd[embedder_key].shape
            # weight_shape[1] = f_patch_size * patch_size * patch_size * control_in_dim
            control_in_dim = weight_shape[1] // 4  # 4 = 1 * 2 * 2

        # Log detected configuration for debugging
        from invokeai.backend.util.logging import InvokeAILogger

        logger = InvokeAILogger.get_logger(self.__class__.__name__)
        version = "V2.0" if control_in_dim > 16 else "V1"
        logger.info(
            f"Z-Image ControlNet detected: {version} "
            f"(control_in_dim={control_in_dim}, num_control_blocks={num_control_blocks}, "
            f"n_refiner_layers={n_refiner_layers})"
        )

        # Create an empty control adapter
        dim = 3840
        with accelerate.init_empty_weights():
            model = ZImageControlAdapter(
                num_control_blocks=num_control_blocks,
                control_in_dim=control_in_dim,
                all_patch_size=(2,),
                all_f_patch_size=(1,),
                dim=dim,
                n_refiner_layers=n_refiner_layers,
                n_heads=30,
                n_kv_heads=30,
                norm_eps=1e-05,
                qk_norm=True,
            )

        # Load state dict with strict=False to handle missing keys like x_pad_token
        # Some control adapters may not include x_pad_token in their checkpoint
        missing_keys, unexpected_keys = model.load_state_dict(sd, assign=True, strict=False)

        # Initialize x_pad_token if it was missing from the checkpoint
        if "x_pad_token" in missing_keys:
            import torch.nn as nn

            model.x_pad_token = nn.Parameter(torch.empty(dim))
            nn.init.normal_(model.x_pad_token, std=0.02)

        return model

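Note (not part of the diff): the shape-based detection in _load_control_adapter reduces to simple arithmetic. The shapes below are hypothetical, chosen only to illustrate how the V1 vs V2.0 distinction falls out of the embedder weight.

f_patch_size, patch_size = 1, 2
v1_embedder_shape = (3840, 64)    # hypothetical V1 checkpoint: control_all_x_embedder.2-1.weight
v2_embedder_shape = (3840, 128)   # hypothetical V2.0 checkpoint
assert v1_embedder_shape[1] // (f_patch_size * patch_size * patch_size) == 16  # control_in_dim 16 -> logged "V1"
assert v2_embedder_shape[1] // (f_patch_size * patch_size * patch_size) == 32  # control_in_dim > 16 -> logged "V2.0"
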
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.Checkpoint)
class Qwen3EncoderCheckpointLoader(ModelLoader):
    """Class to load single-file Qwen3 Encoder models for Z-Image (safetensors format)."""

    # Default HuggingFace model to load tokenizer from when using single-file Qwen3 encoder
    # Must be Qwen3 (not Qwen2.5) to match Z-Image's text encoder architecture and special tokens
    DEFAULT_TOKENIZER_SOURCE = "Qwen/Qwen3-4B"

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, Qwen3Encoder_Checkpoint_Config):
            raise ValueError("Only Qwen3Encoder_Checkpoint_Config models are supported here.")

        match submodel_type:
            case SubModelType.TextEncoder:
                return self._load_from_singlefile(config)
            case SubModelType.Tokenizer:
                # For single-file Qwen3, load tokenizer from HuggingFace
                return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE)

        raise ValueError(
            f"Only TextEncoder and Tokenizer submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        from safetensors.torch import load_file
        from transformers import Qwen3Config, Qwen3ForCausalLM

        from invokeai.backend.util.logging import InvokeAILogger

        logger = InvokeAILogger.get_logger(self.__class__.__name__)

        if not isinstance(config, Qwen3Encoder_Checkpoint_Config):
            raise TypeError(
                f"Expected Qwen3Encoder_Checkpoint_Config, got {type(config).__name__}. "
                "Model configuration type mismatch."
            )
        model_path = Path(config.path)

        # Determine safe dtype based on target device capabilities
        target_device = TorchDevice.choose_torch_device()
        model_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)

        # Load the state dict from safetensors file
        sd = load_file(model_path)

        # Determine Qwen model configuration from state dict
        # Count the number of layers by looking at layer keys
        layer_count = 0
        for key in sd.keys():
            if isinstance(key, str) and key.startswith("model.layers."):
                parts = key.split(".")
                if len(parts) > 2:
                    try:
                        layer_idx = int(parts[2])
                        layer_count = max(layer_count, layer_idx + 1)
                    except ValueError:
                        pass

        # Get hidden size from embed_tokens weight shape
        embed_weight = sd.get("model.embed_tokens.weight")
        if embed_weight is None:
            raise ValueError("Could not find model.embed_tokens.weight in state dict")
        if embed_weight.ndim != 2:
            raise ValueError(
                f"Expected 2D embed_tokens weight tensor, got shape {embed_weight.shape}. "
                "The model file may be corrupted or incompatible."
            )
        hidden_size = embed_weight.shape[1]
        vocab_size = embed_weight.shape[0]

        # Detect attention configuration from layer 0 weights
        q_proj_weight = sd.get("model.layers.0.self_attn.q_proj.weight")
        k_proj_weight = sd.get("model.layers.0.self_attn.k_proj.weight")
        gate_proj_weight = sd.get("model.layers.0.mlp.gate_proj.weight")

        if q_proj_weight is None or k_proj_weight is None or gate_proj_weight is None:
            raise ValueError("Could not find attention/mlp weights in state dict to determine configuration")

        # Calculate dimensions from actual weights
        # Qwen3 uses head_dim separately from hidden_size
        head_dim = 128  # Standard head dimension for Qwen3 models
        num_attention_heads = q_proj_weight.shape[0] // head_dim
        num_kv_heads = k_proj_weight.shape[0] // head_dim
        intermediate_size = gate_proj_weight.shape[0]

        # Create Qwen3 config - matches the diffusers text_encoder/config.json
        qwen_config = Qwen3Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=layer_count,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_kv_heads,
            head_dim=head_dim,
            max_position_embeddings=40960,
            rms_norm_eps=1e-6,
            tie_word_embeddings=True,
            rope_theta=1000000.0,
            use_sliding_window=False,
            attention_bias=False,
            attention_dropout=0.0,
            torch_dtype=model_dtype,
        )

        # Handle memory management
        new_sd_size = sum([ten.nelement() * model_dtype.itemsize for ten in sd.values()])
        self._ram_cache.make_room(new_sd_size)

        # Convert to target dtype
        for k in sd.keys():
            sd[k] = sd[k].to(model_dtype)

        # Use Qwen3ForCausalLM - the correct model class for Z-Image text encoder
        # Use init_empty_weights for fast model creation, then load weights with assign=True
        with accelerate.init_empty_weights():
            model = Qwen3ForCausalLM(qwen_config)

        # Load the text model weights from checkpoint
        # assign=True replaces meta tensors with real ones from state dict
        model.load_state_dict(sd, strict=False, assign=True)

        # Handle tied weights: lm_head shares weight with embed_tokens when tie_word_embeddings=True
        # This doesn't work automatically with init_empty_weights, so we need to manually tie them
        if qwen_config.tie_word_embeddings:
            model.tie_weights()

        # Re-initialize any remaining meta tensor buffers (like rotary embeddings inv_freq)
        # These are computed from config, not loaded from checkpoint
        for name, buffer in list(model.named_buffers()):
            if buffer.is_meta:
                # Get parent module and buffer name
                parts = name.rsplit(".", 1)
                if len(parts) == 2:
                    parent = model.get_submodule(parts[0])
                    buffer_name = parts[1]
                else:
                    parent = model
                    buffer_name = name

                # Re-initialize the buffer based on expected shape and dtype
                # For rotary embeddings, this is inv_freq which is computed from config
                if buffer_name == "inv_freq":
                    # Compute inv_freq from config (same logic as Qwen3RotaryEmbedding.__init__)
                    base = qwen_config.rope_theta
                    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
                    parent.register_buffer(buffer_name, inv_freq.to(model_dtype), persistent=False)
                else:
                    # For other buffers, log warning
                    logger.warning(f"Re-initializing unknown meta buffer: {name}")

        return model

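Note (not part of the diff): a worked example of the shape-driven config detection in _load_from_singlefile, using hypothetical tensor shapes rather than any specific released checkpoint. hidden_size and vocab_size come from the embed_tokens weight; head counts come from the projection weights and the fixed head_dim assumption.

head_dim = 128                       # fixed assumption in the loader above
embed_tokens_shape = (151936, 2560)  # [vocab_size, hidden_size] (hypothetical)
q_proj_shape = (4096, 2560)          # [num_heads * head_dim, hidden_size]
k_proj_shape = (1024, 2560)          # [num_kv_heads * head_dim, hidden_size]
gate_proj_shape = (9728, 2560)       # [intermediate_size, hidden_size]

hidden_size = embed_tokens_shape[1]                # 2560
vocab_size = embed_tokens_shape[0]                 # 151936
num_attention_heads = q_proj_shape[0] // head_dim  # 32
num_key_value_heads = k_proj_shape[0] // head_dim  # 8
intermediate_size = gate_proj_shape[0]             # 9728
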
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Qwen3Encoder, format=ModelFormat.GGUFQuantized)
class Qwen3EncoderGGUFLoader(ModelLoader):
    """Class to load GGUF-quantized Qwen3 Encoder models for Z-Image."""

    # Default HuggingFace model to load tokenizer from when using GGUF Qwen3 encoder
    # Must be Qwen3 (not Qwen2.5) to match Z-Image's text encoder architecture and special tokens
    DEFAULT_TOKENIZER_SOURCE = "Qwen/Qwen3-4B"

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if not isinstance(config, Qwen3Encoder_GGUF_Config):
            raise ValueError("Only Qwen3Encoder_GGUF_Config models are supported here.")

        match submodel_type:
            case SubModelType.TextEncoder:
                return self._load_from_gguf(config)
            case SubModelType.Tokenizer:
                # For GGUF Qwen3, load tokenizer from HuggingFace
                return AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_SOURCE)

        raise ValueError(
            f"Only TextEncoder and Tokenizer submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_gguf(
        self,
        config: AnyModelConfig,
    ) -> AnyModel:
        from transformers import Qwen3Config, Qwen3ForCausalLM

        from invokeai.backend.util.logging import InvokeAILogger

        logger = InvokeAILogger.get_logger(self.__class__.__name__)

        if not isinstance(config, Qwen3Encoder_GGUF_Config):
            raise TypeError(
                f"Expected Qwen3Encoder_GGUF_Config, got {type(config).__name__}. Model configuration type mismatch."
            )
        model_path = Path(config.path)

        # Determine safe dtype based on target device capabilities
        target_device = TorchDevice.choose_torch_device()
        compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)

        # Load the GGUF state dict - this returns GGMLTensor wrappers (on CPU)
        # We keep them on CPU and let the model cache system handle GPU movement
        # via apply_custom_layers_to_model() and the partial loading cache
        sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)

        # Check if this is llama.cpp format (blk.X.) or PyTorch format (model.layers.X.)
        is_llamacpp_format = any(k.startswith("blk.") for k in sd.keys() if isinstance(k, str))

        if is_llamacpp_format:
            logger.info("Detected llama.cpp GGUF format, converting keys to PyTorch format")
            sd = self._convert_llamacpp_to_pytorch(sd)

        # Determine Qwen model configuration from state dict
        # Count the number of layers by looking at layer keys
        layer_count = 0
        for key in sd.keys():
            if isinstance(key, str) and key.startswith("model.layers."):
                parts = key.split(".")
                if len(parts) > 2:
                    try:
                        layer_idx = int(parts[2])
                        layer_count = max(layer_count, layer_idx + 1)
                    except ValueError:
                        pass

        # Get hidden size from embed_tokens weight shape
        embed_weight = sd.get("model.embed_tokens.weight")
        if embed_weight is None:
            raise ValueError("Could not find model.embed_tokens.weight in state dict")

        # Handle GGMLTensor shape access
        embed_shape = embed_weight.shape if hasattr(embed_weight, "shape") else embed_weight.tensor_shape
        if len(embed_shape) != 2:
            raise ValueError(
                f"Expected 2D embed_tokens weight tensor, got shape {embed_shape}. "
                "The model file may be corrupted or incompatible."
            )
        hidden_size = embed_shape[1]
        vocab_size = embed_shape[0]

        # Detect attention configuration from layer 0 weights
        q_proj_weight = sd.get("model.layers.0.self_attn.q_proj.weight")
        k_proj_weight = sd.get("model.layers.0.self_attn.k_proj.weight")
        gate_proj_weight = sd.get("model.layers.0.mlp.gate_proj.weight")

        if q_proj_weight is None or k_proj_weight is None or gate_proj_weight is None:
            raise ValueError("Could not find attention/mlp weights in state dict to determine configuration")

        # Handle GGMLTensor shape access
        q_shape = q_proj_weight.shape if hasattr(q_proj_weight, "shape") else q_proj_weight.tensor_shape
        k_shape = k_proj_weight.shape if hasattr(k_proj_weight, "shape") else k_proj_weight.tensor_shape
        gate_shape = gate_proj_weight.shape if hasattr(gate_proj_weight, "shape") else gate_proj_weight.tensor_shape

        # Calculate dimensions from actual weights
        head_dim = 128  # Standard head dimension for Qwen3 models
        num_attention_heads = q_shape[0] // head_dim
        num_kv_heads = k_shape[0] // head_dim
        intermediate_size = gate_shape[0]

        logger.info(
            f"Qwen3 GGUF Encoder config detected: layers={layer_count}, hidden={hidden_size}, "
            f"heads={num_attention_heads}, kv_heads={num_kv_heads}, intermediate={intermediate_size}, "
            f"head_dim={head_dim}"
        )

        # Create Qwen3 config
        qwen_config = Qwen3Config(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=layer_count,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_kv_heads,
            head_dim=head_dim,
            max_position_embeddings=40960,
            rms_norm_eps=1e-6,
            tie_word_embeddings=True,
            rope_theta=1000000.0,
            use_sliding_window=False,
            attention_bias=False,
            attention_dropout=0.0,
            torch_dtype=compute_dtype,
        )

        # Use Qwen3ForCausalLM with empty weights, then load GGUF tensors
        with accelerate.init_empty_weights():
            model = Qwen3ForCausalLM(qwen_config)

        # Load the GGUF weights with assign=True
        # GGMLTensor wrappers will be dequantized on-the-fly during inference
        model.load_state_dict(sd, strict=False, assign=True)

        # Dequantize embed_tokens weight - embedding lookups require indexed access
        # which quantized GGMLTensors can't efficiently provide (no __torch_dispatch__ for embedding)
        from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor

        embed_tokens_weight = model.model.embed_tokens.weight
        if isinstance(embed_tokens_weight, GGMLTensor):
            dequantized = embed_tokens_weight.get_dequantized_tensor()
            model.model.embed_tokens.weight = torch.nn.Parameter(dequantized, requires_grad=False)
            logger.info("Dequantized embed_tokens weight for embedding lookups")

        # Handle tied weights - llama.cpp GGUF doesn't include lm_head.weight when embeddings are tied
        # So we need to manually tie them after loading
        if qwen_config.tie_word_embeddings:
            # Check if lm_head.weight is still a meta tensor (wasn't in GGUF state dict)
            if model.lm_head.weight.is_meta:
                # Directly assign embed_tokens weight to lm_head (now dequantized)
                model.lm_head.weight = model.model.embed_tokens.weight
                logger.info("Tied lm_head.weight to embed_tokens.weight (GGUF tied embeddings)")
            else:
                # If lm_head.weight was loaded, use standard tie_weights
                model.tie_weights()

        # Re-initialize any remaining meta tensor buffers (like rotary embeddings inv_freq)
        for name, buffer in list(model.named_buffers()):
            if buffer.is_meta:
                parts = name.rsplit(".", 1)
                if len(parts) == 2:
                    parent = model.get_submodule(parts[0])
                    buffer_name = parts[1]
                else:
                    parent = model
                    buffer_name = name

                if buffer_name == "inv_freq":
                    # Compute inv_freq from config - keep on CPU, cache system will move to GPU as needed
                    base = qwen_config.rope_theta
                    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
                    parent.register_buffer(buffer_name, inv_freq.to(dtype=compute_dtype), persistent=False)
                else:
                    logger.warning(f"Re-initializing unknown meta buffer: {name}")

        # Final check: ensure no meta tensors remain in parameters
        meta_params = [(name, p) for name, p in model.named_parameters() if p.is_meta]
        if meta_params:
            meta_names = [name for name, _ in meta_params]
            raise RuntimeError(
                f"Failed to load all parameters from GGUF. The following remain as meta tensors: {meta_names}. "
                "This may indicate missing keys in the GGUF file or a key mapping issue."
            )

        return model

    def _convert_llamacpp_to_pytorch(self, sd: dict[str, Any]) -> dict[str, Any]:
        """Convert llama.cpp GGUF keys to PyTorch/HuggingFace format for Qwen models.

        llama.cpp format:
        - blk.X.attn_q.weight -> model.layers.X.self_attn.q_proj.weight
        - blk.X.attn_k.weight -> model.layers.X.self_attn.k_proj.weight
        - blk.X.attn_v.weight -> model.layers.X.self_attn.v_proj.weight
        - blk.X.attn_output.weight -> model.layers.X.self_attn.o_proj.weight
        - blk.X.attn_q_norm.weight -> model.layers.X.self_attn.q_norm.weight (Qwen3 QK norm)
        - blk.X.attn_k_norm.weight -> model.layers.X.self_attn.k_norm.weight (Qwen3 QK norm)
        - blk.X.ffn_gate.weight -> model.layers.X.mlp.gate_proj.weight
        - blk.X.ffn_up.weight -> model.layers.X.mlp.up_proj.weight
        - blk.X.ffn_down.weight -> model.layers.X.mlp.down_proj.weight
        - blk.X.attn_norm.weight -> model.layers.X.input_layernorm.weight
        - blk.X.ffn_norm.weight -> model.layers.X.post_attention_layernorm.weight
        - token_embd.weight -> model.embed_tokens.weight
        - output_norm.weight -> model.norm.weight
        - output.weight -> lm_head.weight (if not tied)
        """
        import re

        key_map = {
            "attn_q": "self_attn.q_proj",
            "attn_k": "self_attn.k_proj",
            "attn_v": "self_attn.v_proj",
            "attn_output": "self_attn.o_proj",
            "attn_q_norm": "self_attn.q_norm",  # Qwen3 QK normalization
            "attn_k_norm": "self_attn.k_norm",  # Qwen3 QK normalization
            "ffn_gate": "mlp.gate_proj",
            "ffn_up": "mlp.up_proj",
            "ffn_down": "mlp.down_proj",
            "attn_norm": "input_layernorm",
            "ffn_norm": "post_attention_layernorm",
        }

        new_sd: dict[str, Any] = {}
        blk_pattern = re.compile(r"^blk\.(\d+)\.(.+)$")

        for key, value in sd.items():
            if not isinstance(key, str):
                new_sd[key] = value
                continue

            # Handle block layers
            match = blk_pattern.match(key)
            if match:
                layer_idx = match.group(1)
                rest = match.group(2)

                # Split rest into component and suffix (e.g., "attn_q.weight" -> "attn_q", "weight")
                parts = rest.split(".", 1)
                component = parts[0]
                suffix = parts[1] if len(parts) > 1 else ""

                if component in key_map:
                    new_component = key_map[component]
                    new_key = f"model.layers.{layer_idx}.{new_component}"
                    if suffix:
                        new_key += f".{suffix}"
                    new_sd[new_key] = value
                else:
                    # Unknown component, keep as-is with model.layers prefix
                    new_sd[f"model.layers.{layer_idx}.{rest}"] = value
                continue

            # Handle non-block keys
            if key == "token_embd.weight":
                new_sd["model.embed_tokens.weight"] = value
            elif key == "output_norm.weight":
                new_sd["model.norm.weight"] = value
            elif key == "output.weight":
                new_sd["lm_head.weight"] = value
            else:
                # Keep other keys as-is
                new_sd[key] = value

        return new_sd
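
Note (not part of the diff): a quick, illustrative way to exercise the llama.cpp key mapping above. The method never reads self, so a throwaway call on the class is enough for demonstration; this is not how the loader is used in production, and the toy values stand in for tensors.

from invokeai.backend.model_manager.load.model_loaders.z_image import Qwen3EncoderGGUFLoader

toy_sd = {
    "token_embd.weight": "E",      # -> model.embed_tokens.weight
    "output_norm.weight": "N",     # -> model.norm.weight
    "blk.0.attn_q.weight": "Q",    # -> model.layers.0.self_attn.q_proj.weight
    "blk.0.ffn_gate.weight": "G",  # -> model.layers.0.mlp.gate_proj.weight
}
converted = Qwen3EncoderGGUFLoader._convert_llamacpp_to_pytorch(None, toy_sd)
assert set(converted) == {
    "model.embed_tokens.weight",
    "model.norm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
}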