InvokeAI: invokeai-6.9.0rc3-py3-none-any.whl → invokeai-6.10.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/model_manager/configs/lora.py
@@ -150,11 +150,16 @@ class LoRA_LyCORIS_Config_Base(LoRA_Config_Base):
 
     @classmethod
     def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
-        # First rule out ControlLoRA and Diffusers LoRA
+        # First rule out ControlLoRA
         flux_format = _get_flux_lora_format(mod)
         if flux_format in [FluxLoRAFormat.Control]:
             raise NotAMatchError("model looks like Control LoRA")
 
+        # If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
+        # it's valid and we skip the heuristic check
+        if flux_format is not None:
+            return
+
         # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
         # Some main models have these keys, likely due to the creator merging in a LoRA.
         has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
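The reordering above means the generic key heuristic only runs when no known Flux LoRA format matched. A minimal sketch of the resulting control flow, with a hypothetical detect_flux_lora_format standing in for _get_flux_lora_format and illustrative Kohya-style keys:

from typing import Any


def detect_flux_lora_format(state_dict: dict[str, Any]) -> str | None:
    # Hypothetical stand-in for _get_flux_lora_format: recognizes two key shapes for illustration.
    if any(k.startswith("lora_unet_double_blocks_") for k in state_dict):
        return "kohya"
    if any(k.startswith("double_blocks.") and ".lora_B." in k for k in state_dict):
        return "control"
    return None


def validate_looks_like_lora(state_dict: dict[str, Any]) -> None:
    fmt = detect_flux_lora_format(state_dict)
    if fmt == "control":
        raise ValueError("model looks like Control LoRA")
    if fmt is not None:
        return  # recognized Flux LoRA format: accept without running the heuristic
    # Generic heuristic: look for LoRA-style key fragments.
    if not any(isinstance(k, str) and "lora" in k for k in state_dict):
        raise ValueError("state dict does not look like a LoRA")


# A Kohya-format Flux LoRA now passes via the format match alone:
validate_looks_like_lora({"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": 0})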
@@ -217,6 +222,37 @@ class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base):
     base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
 
 
+class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
+    """Model config for Z-Image LoRA models in LyCORIS format."""
+
+    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
+
+    @classmethod
+    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
+        """Z-Image LoRAs are identified by their diffusion_model.layers structure.
+
+        Z-Image uses S3-DiT architecture with layer names like:
+        - diffusion_model.layers.0.attention.to_k.lora_A.weight
+        - diffusion_model.layers.0.feed_forward.w1.lora_A.weight
+        """
+        state_dict = mod.load_state_dict()
+
+        # Check for Z-Image transformer layer patterns
+        # Z-Image uses diffusion_model.layers.X structure (unlike Flux which uses double_blocks/single_blocks)
+        has_z_image_keys = state_dict_has_any_keys_starting_with(
+            state_dict,
+            {
+                "diffusion_model.layers.",  # Z-Image S3-DiT layer pattern
+            },
+        )
+
+        # If it looks like a Z-Image LoRA, return ZImage base
+        if has_z_image_keys:
+            return BaseModelType.ZImage
+
+        raise NotAMatchError("model does not look like a Z-Image LoRA")
+
+
 class ControlAdapter_Config_Base(ABC, BaseModel):
     default_settings: ControlAdapterDefaultSettings | None = Field(None)
 
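For a concrete picture of the check above, a Z-Image LoRA state dict matches on the diffusion_model.layers. prefix while a Flux LoRA does not. A self-contained version of the prefix test (the helper below mirrors the assumed semantics of state_dict_has_any_keys_starting_with; the sample keys follow the patterns named in the docstring):

from typing import Any


def has_any_keys_starting_with(state_dict: dict[str, Any], prefixes: set[str]) -> bool:
    # Mirrors the assumed semantics of state_dict_has_any_keys_starting_with.
    return any(isinstance(k, str) and k.startswith(tuple(prefixes)) for k in state_dict)


z_image_lora = {
    "diffusion_model.layers.0.attention.to_k.lora_A.weight": 0,
    "diffusion_model.layers.0.feed_forward.w1.lora_B.weight": 0,
}
flux_lora = {"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": 0}

assert has_any_keys_starting_with(z_image_lora, {"diffusion_model.layers."})
assert not has_any_keys_starting_with(flux_lora, {"diffusion_model.layers."})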
@@ -320,3 +356,9 @@ class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base):
 
 class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base):
     base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
+
+
+class LoRA_Diffusers_ZImage_Config(LoRA_Diffusers_Config_Base, Config_Base):
+    """Model config for Z-Image LoRA models in Diffusers format."""
+
+    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
invokeai/backend/model_manager/configs/main.py
@@ -60,6 +60,8 @@ class MainModelDefaultSettings(BaseModel):
                 return cls(width=768, height=768)
             case BaseModelType.StableDiffusionXL:
                 return cls(width=1024, height=1024)
+            case BaseModelType.ZImage:
+                return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
             case _:
                 # TODO(psyche): Do we want defaults for other base types?
                 return None
@@ -111,6 +113,28 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
     return False
 
 
+def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains Z-Image S3-DiT transformer keys."""
+    # Z-Image specific keys that distinguish it from other models
+    z_image_specific_keys = {
+        "cap_embedder",  # Caption embedder - unique to Z-Image
+        "context_refiner",  # Context refiner blocks
+        "cap_pad_token",  # Caption padding token
+    }
+
+    for key in state_dict.keys():
+        if isinstance(key, int):
+            continue
+        # Check for Z-Image specific key prefixes
+        # Handle both direct keys (cap_embedder.0.weight) and
+        # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight)
+        key_parts = key.split(".")
+        for part in key_parts:
+            if part in z_image_specific_keys:
+                return True
+    return False
+
+
 class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
     """Model config for main checkpoint models."""
 
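Because _has_z_image_keys splits each key on dots and tests whole components, it matches both layouts mentioned in the comments without a regex. Replaying the component check on sample keys:

Z_IMAGE_KEYS = {"cap_embedder", "context_refiner", "cap_pad_token"}


def matches(key: str) -> bool:
    # Same component-wise test as the inner loop of _has_z_image_keys.
    return any(part in Z_IMAGE_KEYS for part in key.split("."))


print(matches("cap_embedder.0.weight"))                          # True (direct key)
print(matches("model.diffusion_model.cap_embedder.0.weight"))    # True (ComfyUI-style key)
print(matches("model.diffusion_model.input_blocks.0.0.weight"))  # False (SD-style key)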
@@ -657,3 +681,92 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co
             **override_fields,
             repo_variant=repo_variant,
         )
+
+
+class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""
+
+    base: Literal[BaseModelType.ZImage] = Field(BaseModelType.ZImage)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # This check implies the base type - no further validation needed.
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "ZImagePipeline",
+            },
+        )
+
+        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+
+        return cls(
+            **override_fields,
+            repo_variant=repo_variant,
+        )
+
+
+class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for Z-Image single-file checkpoint models (safetensors, etc)."""
+
+    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_z_image_model(mod)
+
+        cls._validate_does_not_look_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
+        has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
+        if not has_z_image_keys:
+            raise NotAMatchError("state dict does not look like a Z-Image model")
+
+    @classmethod
+    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if has_ggml_tensors:
+            raise NotAMatchError("state dict looks like GGUF quantized")
+
+
+class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for GGUF-quantized Z-Image transformer models."""
+
+    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_z_image_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
+        has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
+        if not has_z_image_keys:
+            raise NotAMatchError("state dict does not look like a Z-Image model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml_tensors:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
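The two single-file configs above partition the same detection: both require Z-Image keys, and the GGML-tensor test decides which one matches. A sketch of that routing, with a stand-in marker class in place of GGMLTensor:

from typing import Any


class GGMLTensor:  # stand-in for invokeai.backend.quantization.gguf.ggml_tensor.GGMLTensor
    pass


def route_z_image_single_file(state_dict: dict[str, Any]) -> str:
    # Mirrors _validate_does_not_look_like_gguf_quantized / _validate_looks_like_gguf_quantized.
    has_ggml = any(isinstance(v, GGMLTensor) for v in state_dict.values())
    return "Main_GGUF_ZImage_Config" if has_ggml else "Main_Checkpoint_ZImage_Config"


print(route_z_image_single_file({"cap_embedder.0.weight": object()}))      # Main_Checkpoint_ZImage_Config
print(route_z_image_single_file({"cap_embedder.0.weight": GGMLTensor()}))  # Main_GGUF_ZImage_Config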
invokeai/backend/model_manager/configs/qwen3_encoder.py (new file)
@@ -0,0 +1,156 @@
+from typing import Any, Literal, Self
+
+from pydantic import Field
+
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base
+from invokeai.backend.model_manager.configs.identification_utils import (
+    NotAMatchError,
+    raise_for_class_name,
+    raise_for_override_fields,
+    raise_if_not_dir,
+    raise_if_not_file,
+)
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+
+
+def _has_qwen3_keys(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains Qwen3 model keys.
+
+    Supports both:
+    - PyTorch/diffusers format: model.layers.0., model.embed_tokens.weight
+    - GGUF/llama.cpp format: blk.0., token_embd.weight
+    """
+    # PyTorch/diffusers format indicators
+    pytorch_indicators = ["model.layers.0.", "model.embed_tokens.weight"]
+    # GGUF/llama.cpp format indicators
+    gguf_indicators = ["blk.0.", "token_embd.weight"]
+
+    for key in state_dict.keys():
+        if isinstance(key, str):
+            # Check PyTorch format
+            for indicator in pytorch_indicators:
+                if key.startswith(indicator) or key == indicator:
+                    return True
+            # Check GGUF format
+            for indicator in gguf_indicators:
+                if key.startswith(indicator) or key == indicator:
+                    return True
+    return False
+
+
+def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains GGML tensors (GGUF quantized)."""
+    return any(isinstance(v, GGMLTensor) for v in state_dict.values())
+
+
+class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
+    """Configuration for single-file Qwen3 Encoder models (safetensors)."""
+
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_qwen3_model(mod)
+
+        cls._validate_does_not_look_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
+        has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
+        if not has_qwen3_keys:
+            raise NotAMatchError("state dict does not look like a Qwen3 model")
+
+    @classmethod
+    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml = _has_ggml_tensors(mod.load_state_dict())
+        if has_ggml:
+            raise NotAMatchError("state dict looks like GGUF quantized")
+
+
+class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
+    """Configuration for Qwen3 Encoder models in a diffusers-like format.
+
+    The model weights are expected to be in a folder called text_encoder inside the model directory,
+    compatible with Qwen2VLForConditionalGeneration or similar architectures used by Z-Image.
+    """
+
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
+    format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # Check for text_encoder config - support both:
+        # 1. Full model structure: model_root/text_encoder/config.json
+        # 2. Standalone text_encoder download: model_root/config.json (when text_encoder subfolder is downloaded separately)
+        config_path_nested = mod.path / "text_encoder" / "config.json"
+        config_path_direct = mod.path / "config.json"
+
+        if config_path_nested.exists():
+            expected_config_path = config_path_nested
+        elif config_path_direct.exists():
+            expected_config_path = config_path_direct
+        else:
+            from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
+
+            raise NotAMatchError(
+                f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
+            )
+
+        # Qwen3 uses Qwen2VLForConditionalGeneration or similar
+        raise_for_class_name(
+            expected_config_path,
+            {
+                "Qwen2VLForConditionalGeneration",
+                "Qwen2ForCausalLM",
+                "Qwen3ForCausalLM",
+            },
+        )
+
+        return cls(**override_fields)
+
+
+class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
+    """Configuration for GGUF-quantized Qwen3 Encoder models."""
+
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_qwen3_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
+        has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
+        if not has_qwen3_keys:
+            raise NotAMatchError("state dict does not look like a Qwen3 model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
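The indicator lists let the same encoder be recognized whether it ships as safetensors or as a llama.cpp conversion. Exercising the two key styles from _has_qwen3_keys (sample keys are illustrative):

PYTORCH_INDICATORS = ["model.layers.0.", "model.embed_tokens.weight"]
GGUF_INDICATORS = ["blk.0.", "token_embd.weight"]


def looks_like_qwen3(keys: list[str]) -> bool:
    # Same startswith-or-equals test as _has_qwen3_keys.
    indicators = PYTORCH_INDICATORS + GGUF_INDICATORS
    return any(k.startswith(i) or k == i for k in keys for i in indicators)


print(looks_like_qwen3(["model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight"]))  # True
print(looks_like_qwen3(["token_embd.weight", "blk.0.attn_q.weight"]))  # True
print(looks_like_qwen3(["text_model.encoder.layers.0.mlp.fc1.weight"]))  # False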
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py (new file)
@@ -0,0 +1,40 @@
+import torch
+from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+
+
+class CustomDiffusersRMSNorm(DiffusersRMSNorm, CustomModuleMixin):
+    """Custom wrapper for diffusers RMSNorm that supports device autocasting for partial model loading."""
+
+    def _autocast_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, hidden_states.device) if self.weight is not None else None
+        bias = cast_to_device(self.bias, hidden_states.device) if self.bias is not None else None
+
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if weight is not None:
+            # convert into half-precision if necessary
+            if weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(weight.dtype)
+            hidden_states = hidden_states * weight
+            if bias is not None:
+                hidden_states = hidden_states + bias
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            raise RuntimeError("DiffusersRMSNorm layers do not support patches")
+
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(hidden_states)
+        else:
+            return super().forward(hidden_states)
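_autocast_forward reimplements the RMSNorm computation rather than delegating to the parent, so the weight and bias can first be cast to the activation's device: each feature vector is scaled by the reciprocal root of its mean square, x * rsqrt(mean(x^2) + eps), then multiplied by the learned weight. A quick torch-only check of that arithmetic (shapes and eps arbitrary):

import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as _autocast_forward above, minus the device casts.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if weight.dtype in (torch.float16, torch.bfloat16):
        x = x.to(weight.dtype)
    return x * weight


x = torch.randn(2, 8)
out = rms_norm(x, torch.ones(8))
print(out.pow(2).mean(-1))  # ~1.0 per row: unit root-mean-square after normalization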
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py (new file)
@@ -0,0 +1,25 @@
+import torch
+import torch.nn.functional as F
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+
+
+class CustomLayerNorm(torch.nn.LayerNorm, CustomModuleMixin):
+    """Custom wrapper for torch.nn.LayerNorm that supports device autocasting for partial model loading."""
+
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device) if self.weight is not None else None
+        bias = cast_to_device(self.bias, input.device) if self.bias is not None else None
+        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            raise RuntimeError("LayerNorm layers do not support patches")
+
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -1,14 +1,18 @@
 from typing import TypeVar
 
 import torch
+from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
 
-from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.flux.modules.layers import RMSNorm as FluxRMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
     CustomConv1d,
 )
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
     CustomConv2d,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_diffusers_rms_norm import (
+    CustomDiffusersRMSNorm,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
     CustomEmbedding,
 )
@@ -18,6 +22,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
     CustomGroupNorm,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_layer_norm import (
+    CustomLayerNorm,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
     CustomLinear,
 )
@@ -31,7 +38,9 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
     torch.nn.Conv2d: CustomConv2d,
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
-    RMSNorm: CustomFluxRMSNorm,
+    torch.nn.LayerNorm: CustomLayerNorm,
+    FluxRMSNorm: CustomFluxRMSNorm,
+    DiffusersRMSNorm: CustomDiffusersRMSNorm,
 }
 
 try:
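For context on how the mapping is consumed: modules whose type appears as a key get their class swapped to the corresponding Custom* subclass, which layers the device casts on top of the original behavior. The swap helper below is an illustrative sketch of that mechanism, not InvokeAI's actual wrapping code:

import torch


class MyCustomLayerNorm(torch.nn.LayerNorm):
    """Stand-in for an autocasting subclass like CustomLayerNorm."""


MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
    torch.nn.LayerNorm: MyCustomLayerNorm,
}


def swap_module_classes(model: torch.nn.Module) -> None:
    for module in model.modules():
        custom = MAPPING.get(type(module))
        if custom is not None:
            # Reassigning __class__ keeps parameters and buffers intact;
            # the subclass only adds behavior on top of the original forward.
            module.__class__ = custom


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
swap_module_classes(model)
print(type(model[1]).__name__)  # MyCustomLayerNorm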
invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -41,8 +41,13 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
     is_state_dict_likely_in_flux_onetrainer_format,
     lora_model_from_flux_onetrainer_state_dict,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
+from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@@ -117,6 +122,8 @@ class LoRALoader(ModelLoader):
                 model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
                 model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
+            elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
+                model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
             else:
                 raise ValueError("LoRA model is in unsupported FLUX format")
         else:
@@ -124,6 +131,10 @@
         elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
             # Currently, we don't apply any conversions for SD1 and SD2 LoRA models.
             model = lora_model_from_sd_state_dict(state_dict=state_dict)
+        elif self._model_base == BaseModelType.ZImage:
+            # Z-Image LoRAs use diffusers PEFT format with transformer and/or Qwen3 encoder layers.
+            # We set alpha=None to use rank as alpha (common default).
+            model = lora_model_from_z_image_state_dict(state_dict=state_dict, alpha=None)
         else:
             raise ValueError(f"Unsupported LoRA base model: {self._model_base}")
 
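On the alpha=None comment above: in PEFT-style LoRA the effective weight is W + (alpha / rank) * B @ A, so treating a missing alpha as equal to the rank makes the scale factor exactly 1.0. A worked check (shapes arbitrary):

import torch

rank, d_out, d_in = 4, 8, 8
A = torch.randn(rank, d_in)   # lora_A ("down") projection
B = torch.randn(d_out, rank)  # lora_B ("up") projection


def lora_delta(A: torch.Tensor, B: torch.Tensor, alpha: float | None) -> torch.Tensor:
    scale = (alpha / A.shape[0]) if alpha is not None else 1.0  # alpha=None -> alpha == rank -> scale 1.0
    return scale * (B @ A)


assert torch.allclose(lora_delta(A, B, None), lora_delta(A, B, float(rank)))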