InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
|
|
1
|
+
import json
|
|
2
|
+
from typing import Any, Literal, Optional, Self
|
|
2
3
|
|
|
3
4
|
from pydantic import Field
|
|
4
5
|
|
|
@@ -11,7 +12,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
|
|
|
11
12
|
raise_if_not_file,
|
|
12
13
|
)
|
|
13
14
|
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
|
14
|
-
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
|
15
|
+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, Qwen3VariantType
|
|
15
16
|
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
|
16
17
|
|
|
17
18
|
|
|
@@ -45,12 +46,67 @@ def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
|
|
|
45
46
|
return any(isinstance(v, GGMLTensor) for v in state_dict.values())
|
|
46
47
|
|
|
47
48
|
|
|
49
|
+
def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
    """Determine the Qwen3 variant (4B vs 8B) from a state dict, based on hidden_size.

    The hidden size is read from dimension 1 of the token-embedding weight, whose shape is
    [vocab_size, hidden_size]:
      - Qwen3 4B: hidden_size = 2560
      - Qwen3 8B: hidden_size = 4096

    The embedding is looked up under 'model.embed_tokens.weight' (PyTorch format) or
    'token_embd.weight' (GGUF format). If both keys are present, the PyTorch key is used.

    Args:
        state_dict: The model's state dict. Values may be torch tensors or GGMLTensors; only their
            ``shape`` attribute is read.

    Returns:
        Qwen3VariantType.Qwen3_8B when hidden_size is 4096; Qwen3VariantType.Qwen3_4B for 2560 and for
        any unrecognised size (4B is the more common variant). None when no embedding key is present,
        or when its value has no usable shape (missing, None, or fewer than two dimensions).
    """
    QWEN3_8B_HIDDEN_SIZE = 4096

    # Two O(1) membership checks instead of scanning every key of a (large) state dict.
    embed_key = next(
        (key for key in ("model.embed_tokens.weight", "token_embd.weight") if key in state_dict),
        None,
    )
    if embed_key is None:
        return None

    # GGMLTensor and torch.Tensor both expose `.shape` in [vocab_size, hidden_size] order, so a single
    # code path serves both formats.
    shape = getattr(state_dict[embed_key], "shape", None)
    if shape is None or len(shape) < 2:
        return None
    hidden_size = shape[1]

    if hidden_size == QWEN3_8B_HIDDEN_SIZE:
        return Qwen3VariantType.Qwen3_8B
    # 2560 (Qwen3 4B) and any unknown size both map to 4B.
    return Qwen3VariantType.Qwen3_4B
|
|
101
|
+
|
|
102
|
+
|
|
48
103
|
class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
|
|
49
104
|
"""Configuration for single-file Qwen3 Encoder models (safetensors)."""
|
|
50
105
|
|
|
51
106
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
52
107
|
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
|
|
53
108
|
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
|
|
109
|
+
variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
|
|
54
110
|
|
|
55
111
|
@classmethod
|
|
56
112
|
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
|
|
@@ -62,7 +118,17 @@ class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
|
|
|
62
118
|
|
|
63
119
|
cls._validate_does_not_look_like_gguf_quantized(mod)
|
|
64
120
|
|
|
65
|
-
|
|
121
|
+
# Determine variant from state dict
|
|
122
|
+
variant = cls._get_variant_or_default(mod)
|
|
123
|
+
|
|
124
|
+
return cls(variant=variant, **override_fields)
|
|
125
|
+
|
|
126
|
+
@classmethod
def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
    """Return the Qwen3 variant detected from the checkpoint's state dict.

    Falls back to Qwen3VariantType.Qwen3_4B when detection yields None.
    """
    detected = _get_qwen3_variant_from_state_dict(mod.load_state_dict())
    if detected is None:
        return Qwen3VariantType.Qwen3_4B
    return detected
|
|
66
132
|
|
|
67
133
|
@classmethod
|
|
68
134
|
def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
|
|
@@ -87,6 +153,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
|
|
|
87
153
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
88
154
|
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
|
|
89
155
|
format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
|
|
156
|
+
variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
|
|
90
157
|
|
|
91
158
|
@classmethod
|
|
92
159
|
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
|
|
@@ -94,6 +161,16 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
|
|
|
94
161
|
|
|
95
162
|
raise_for_override_fields(cls, override_fields)
|
|
96
163
|
|
|
164
|
+
# Exclude full pipeline models - these should be matched as main models, not just Qwen3 encoders.
|
|
165
|
+
# Full pipelines have model_index.json at root (diffusers format) or a transformer subfolder.
|
|
166
|
+
model_index_path = mod.path / "model_index.json"
|
|
167
|
+
transformer_path = mod.path / "transformer"
|
|
168
|
+
if model_index_path.exists() or transformer_path.exists():
|
|
169
|
+
raise NotAMatchError(
|
|
170
|
+
"directory looks like a full diffusers pipeline (has model_index.json or transformer folder), "
|
|
171
|
+
"not a standalone Qwen3 encoder"
|
|
172
|
+
)
|
|
173
|
+
|
|
97
174
|
# Check for text_encoder config - support both:
|
|
98
175
|
# 1. Full model structure: model_root/text_encoder/config.json
|
|
99
176
|
# 2. Standalone text_encoder download: model_root/config.json (when text_encoder subfolder is downloaded separately)
|
|
@@ -105,8 +182,6 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
|
|
|
105
182
|
elif config_path_direct.exists():
|
|
106
183
|
expected_config_path = config_path_direct
|
|
107
184
|
else:
|
|
108
|
-
from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
|
|
109
|
-
|
|
110
185
|
raise NotAMatchError(
|
|
111
186
|
f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
|
|
112
187
|
)
|
|
@@ -121,7 +196,30 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
|
|
|
121
196
|
},
|
|
122
197
|
)
|
|
123
198
|
|
|
124
|
-
|
|
199
|
+
# Determine variant from config.json hidden_size
|
|
200
|
+
variant = cls._get_variant_from_config(expected_config_path)
|
|
201
|
+
|
|
202
|
+
return cls(variant=variant, **override_fields)
|
|
203
|
+
|
|
204
|
+
@classmethod
def _get_variant_from_config(cls, config_path) -> Qwen3VariantType:
    """Determine the Qwen3 variant from a config.json file's top-level ``hidden_size``.

    This is a best-effort probe that never raises because of a bad config file:
      - hidden_size 4096 -> Qwen3 8B.
      - Everything else falls back to Qwen3 4B: 2560, any other value, a missing key, an
        unreadable/undecodable/unparsable file, or a JSON document that is not an object.

    Args:
        config_path: Path to the config.json file (anything accepted by ``open``).

    Returns:
        The detected Qwen3VariantType (Qwen3_4B by default).
    """
    QWEN3_8B_HIDDEN_SIZE = 4096

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    # UnicodeDecodeError (non-UTF-8 bytes) is a ValueError, not an OSError or JSONDecodeError, so it
    # must be listed explicitly for such files to fall back to the default instead of crashing.
    except (json.JSONDecodeError, UnicodeDecodeError, OSError):
        return Qwen3VariantType.Qwen3_4B

    # Valid JSON need not be an object (e.g. a top-level list); `.get` would raise AttributeError on it.
    hidden_size = config.get("hidden_size") if isinstance(config, dict) else None
    if hidden_size == QWEN3_8B_HIDDEN_SIZE:
        return Qwen3VariantType.Qwen3_8B
    # 2560 is Qwen3 4B; unknown sizes default to 4B as well.
    return Qwen3VariantType.Qwen3_4B
|
|
125
223
|
|
|
126
224
|
|
|
127
225
|
class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
|
|
@@ -130,6 +228,7 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
|
|
|
130
228
|
base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
|
|
131
229
|
type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
|
|
132
230
|
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
|
|
231
|
+
variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
|
|
133
232
|
|
|
134
233
|
@classmethod
|
|
135
234
|
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
|
|
@@ -141,7 +240,17 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
|
|
|
141
240
|
|
|
142
241
|
cls._validate_looks_like_gguf_quantized(mod)
|
|
143
242
|
|
|
144
|
-
|
|
243
|
+
# Determine variant from state dict
|
|
244
|
+
variant = cls._get_variant_or_default(mod)
|
|
245
|
+
|
|
246
|
+
return cls(variant=variant, **override_fields)
|
|
247
|
+
|
|
248
|
+
@classmethod
def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
    """Return the Qwen3 variant detected from the GGUF file's state dict.

    Falls back to Qwen3VariantType.Qwen3_4B when detection yields None.
    """
    detected = _get_qwen3_variant_from_state_dict(mod.load_state_dict())
    if detected is None:
        return Qwen3VariantType.Qwen3_4B
    return detected
|
|
145
254
|
|
|
146
255
|
@classmethod
|
|
147
256
|
def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
|
|
@@ -33,6 +33,25 @@ REGEX_TO_BASE: dict[str, BaseModelType] = {
|
|
|
33
33
|
}
|
|
34
34
|
|
|
35
35
|
|
|
36
|
+
def _is_flux2_vae(state_dict: dict[str | int, Any]) -> bool:
|
|
37
|
+
"""Check if state dict is a FLUX.2 VAE (AutoencoderKLFlux2).
|
|
38
|
+
|
|
39
|
+
FLUX.2 VAE can be identified by:
|
|
40
|
+
1. Batch Normalization layers (bn.running_mean, bn.running_var) - unique to FLUX.2
|
|
41
|
+
2. 32-dimensional latent space (decoder.conv_in has 32 input channels)
|
|
42
|
+
|
|
43
|
+
FLUX.1 VAE has 16-dimensional latent space and no BatchNorm layers.
|
|
44
|
+
"""
|
|
45
|
+
# Check for BN layer which is unique to FLUX.2 VAE
|
|
46
|
+
has_bn = "bn.running_mean" in state_dict or "bn.running_var" in state_dict
|
|
47
|
+
|
|
48
|
+
# Check for 32-channel latent space (FLUX.2 has 32, FLUX.1 has 16)
|
|
49
|
+
decoder_conv_in_key = "decoder.conv_in.weight"
|
|
50
|
+
has_32_latent_channels = decoder_conv_in_key in state_dict and state_dict[decoder_conv_in_key].shape[1] == 32
|
|
51
|
+
|
|
52
|
+
return has_bn or has_32_latent_channels
|
|
53
|
+
|
|
54
|
+
|
|
36
55
|
class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
|
|
37
56
|
"""Model config for standalone VAE models."""
|
|
38
57
|
|
|
@@ -61,8 +80,9 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
|
|
|
61
80
|
|
|
62
81
|
@classmethod
|
|
63
82
|
def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
|
|
83
|
+
state_dict = mod.load_state_dict()
|
|
64
84
|
if not state_dict_has_any_keys_starting_with(
|
|
65
|
-
|
|
85
|
+
state_dict,
|
|
66
86
|
{
|
|
67
87
|
"encoder.conv_in",
|
|
68
88
|
"decoder.conv_in",
|
|
@@ -70,9 +90,30 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
|
|
|
70
90
|
):
|
|
71
91
|
raise NotAMatchError("model does not match Checkpoint VAE heuristics")
|
|
72
92
|
|
|
93
|
+
# Exclude FLUX.2 VAEs - they have their own config class
|
|
94
|
+
if _is_flux2_vae(state_dict):
|
|
95
|
+
raise NotAMatchError("model is a FLUX.2 VAE, not a standard VAE")
|
|
96
|
+
|
|
73
97
|
@classmethod
|
|
74
98
|
def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
|
|
75
|
-
#
|
|
99
|
+
# First, try to identify by latent space dimensions (most reliable)
|
|
100
|
+
state_dict = mod.load_state_dict()
|
|
101
|
+
decoder_conv_in_key = "decoder.conv_in.weight"
|
|
102
|
+
if decoder_conv_in_key in state_dict:
|
|
103
|
+
latent_channels = state_dict[decoder_conv_in_key].shape[1]
|
|
104
|
+
if latent_channels == 16:
|
|
105
|
+
# Flux1 VAE has 16-dimensional latent space
|
|
106
|
+
return BaseModelType.Flux
|
|
107
|
+
elif latent_channels == 4:
|
|
108
|
+
# SD/SDXL VAE has 4-dimensional latent space
|
|
109
|
+
# Try to distinguish SD1/SD2/SDXL by name, fallback to SD1
|
|
110
|
+
for regexp, base in REGEX_TO_BASE.items():
|
|
111
|
+
if re.search(regexp, mod.path.name, re.IGNORECASE):
|
|
112
|
+
return base
|
|
113
|
+
# Default to SD1 if we can't determine from name
|
|
114
|
+
return BaseModelType.StableDiffusion1
|
|
115
|
+
|
|
116
|
+
# Fallback: guess based on name
|
|
76
117
|
for regexp, base in REGEX_TO_BASE.items():
|
|
77
118
|
if re.search(regexp, mod.path.name, re.IGNORECASE):
|
|
78
119
|
return base
|
|
@@ -96,6 +137,44 @@ class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
|
|
|
96
137
|
base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
|
|
97
138
|
|
|
98
139
|
|
|
140
|
+
class VAE_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Config_Base):
    """Model config for FLUX.2 VAE checkpoint models (AutoencoderKLFlux2)."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for a single-file FLUX.2 VAE checkpoint.

        The checks run in order and the first failing one raises:
          - the path must be a file;
          - override_fields must be acceptable (raise_for_override_fields);
          - the state dict must look like a VAE;
          - the state dict must look like a FLUX.2 VAE.

        The two VAE checks raise NotAMatchError.
        """
        raise_if_not_file(mod)
        raise_for_override_fields(cls, override_fields)

        for validate in (cls._validate_looks_like_vae, cls._validate_is_flux2_vae):
            validate(mod)

        return cls(**override_fields)

    @classmethod
    def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError unless some state-dict key starts with "encoder.conv_in" or "decoder.conv_in".

        This class does not derive from VAE_Checkpoint_Config_Base, whose version of this check also
        rejects FLUX.2 VAEs, so it carries its own copy.
        """
        conv_in_prefixes = {"encoder.conv_in", "decoder.conv_in"}
        if not state_dict_has_any_keys_starting_with(mod.load_state_dict(), conv_in_prefixes):
            raise NotAMatchError("model does not match Checkpoint VAE heuristics")

    @classmethod
    def _validate_is_flux2_vae(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError unless the state dict looks like a FLUX.2 (rather than FLUX.1) VAE."""
        if _is_flux2_vae(mod.load_state_dict()):
            return
        raise NotAMatchError("state dict does not look like a FLUX.2 VAE")
|
|
176
|
+
|
|
177
|
+
|
|
99
178
|
class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
|
|
100
179
|
"""Model config for standalone VAE models (diffusers version)."""
|
|
101
180
|
|
|
@@ -161,3 +240,26 @@ class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
|
|
|
161
240
|
|
|
162
241
|
class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
|
|
163
242
|
base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
class VAE_Diffusers_Flux2_Config(Diffusers_Config_Base, Config_Base):
    """Model config for FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2)."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config for a diffusers-format FLUX.2 VAE directory.

        The checks run in order and the first failing one raises:
          - the path must be a directory;
          - override_fields must be acceptable;
          - the config file(s) from common_config_paths must declare the class name "AutoencoderKLFlux2".
        """
        raise_if_not_dir(mod)
        raise_for_override_fields(cls, override_fields)

        accepted_class_names = {"AutoencoderKLFlux2"}
        raise_for_class_name(common_config_paths(mod.path), accepted_class_names)

        return cls(**override_fields)
|
|
@@ -55,6 +55,21 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
|
|
|
55
55
|
return wrapper
|
|
56
56
|
|
|
57
57
|
|
|
58
|
+
def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a method so that ``self._record_activity()`` runs after it returns normally.

    The wrapped method's return value is passed through unchanged. If the method raises, the
    exception propagates and no activity is recorded.

    Note: ``_record_activity`` expects ``self._lock`` to be held. In this file the decorator is
    stacked beneath ``@synchronized``, so it should only decorate methods that already hold the lock.
    """

    @wraps(method)
    def _wrapped(self, *args, **kwargs):
        outcome = method(self, *args, **kwargs)
        self._record_activity()
        return outcome

    return _wrapped
|
|
71
|
+
|
|
72
|
+
|
|
58
73
|
@dataclass
|
|
59
74
|
class CacheEntrySnapshot:
|
|
60
75
|
cache_key: str
|
|
@@ -132,6 +147,7 @@ class ModelCache:
|
|
|
132
147
|
storage_device: torch.device | str = "cpu",
|
|
133
148
|
log_memory_usage: bool = False,
|
|
134
149
|
logger: Optional[Logger] = None,
|
|
150
|
+
keep_alive_minutes: float = 0,
|
|
135
151
|
):
|
|
136
152
|
"""Initialize the model RAM cache.
|
|
137
153
|
|
|
@@ -151,6 +167,7 @@ class ModelCache:
|
|
|
151
167
|
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
|
152
168
|
behaviour.
|
|
153
169
|
:param logger: InvokeAILogger to use (otherwise creates one)
|
|
170
|
+
:param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
|
|
154
171
|
"""
|
|
155
172
|
self._enable_partial_loading = enable_partial_loading
|
|
156
173
|
self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
|
|
@@ -182,6 +199,12 @@ class ModelCache:
|
|
|
182
199
|
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
|
|
183
200
|
self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
|
|
184
201
|
|
|
202
|
+
# Keep-alive timeout support
|
|
203
|
+
self._keep_alive_minutes = keep_alive_minutes
|
|
204
|
+
self._last_activity_time: Optional[float] = None
|
|
205
|
+
self._timeout_timer: Optional[threading.Timer] = None
|
|
206
|
+
self._shutdown_event = threading.Event()
|
|
207
|
+
|
|
185
208
|
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
|
|
186
209
|
self._on_cache_hit_callbacks.add(cb)
|
|
187
210
|
|
|
@@ -190,7 +213,7 @@ class ModelCache:
|
|
|
190
213
|
|
|
191
214
|
return unsubscribe
|
|
192
215
|
|
|
193
|
-
def on_cache_miss(self, cb:
|
|
216
|
+
def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
|
|
194
217
|
self._on_cache_miss_callbacks.add(cb)
|
|
195
218
|
|
|
196
219
|
def unsubscribe() -> None:
|
|
@@ -217,8 +240,82 @@ class ModelCache:
|
|
|
217
240
|
def stats(self, stats: CacheStats) -> None:
|
|
218
241
|
"""Set the CacheStats object for collecting cache statistics."""
|
|
219
242
|
self._stats = stats
|
|
243
|
+
# Populate the cache size in the stats object when it's set
|
|
244
|
+
if self._stats is not None:
|
|
245
|
+
self._stats.cache_size = self._ram_cache_size_bytes
|
|
246
|
+
|
|
247
|
+
def _record_activity(self) -> None:
|
|
248
|
+
"""Record model activity and reset the timeout timer if configured.
|
|
249
|
+
|
|
250
|
+
Note: This method should only be called when self._lock is already held.
|
|
251
|
+
"""
|
|
252
|
+
if self._keep_alive_minutes <= 0:
|
|
253
|
+
return
|
|
254
|
+
|
|
255
|
+
self._last_activity_time = time.time()
|
|
256
|
+
|
|
257
|
+
# Cancel any existing timer
|
|
258
|
+
if self._timeout_timer is not None:
|
|
259
|
+
self._timeout_timer.cancel()
|
|
260
|
+
|
|
261
|
+
# Start a new timer
|
|
262
|
+
timeout_seconds = self._keep_alive_minutes * 60
|
|
263
|
+
self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
|
|
264
|
+
# Set as daemon so it doesn't prevent application shutdown
|
|
265
|
+
self._timeout_timer.daemon = True
|
|
266
|
+
self._timeout_timer.start()
|
|
267
|
+
self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
|
|
220
268
|
|
|
221
269
|
@synchronized
def _on_timeout(self) -> None:
    """Clear unlocked models from the cache when the keep-alive timeout expires.

    This is the callback of the ``threading.Timer`` started by ``_record_activity``.
    ``@synchronized`` presumably acquires ``self._lock``, which ``_make_room_internal`` assumes is
    held.

    This method deliberately is NOT decorated with ``@record_activity``, because a timeout is not
    model activity. Recording it would:
      - re-arm the timer after every expiry, so it would keep firing forever, even on an empty cache;
      - push ``_last_activity_time`` forward;
      - on the early-return path below, start a new timer after ``shutdown()`` had cancelled the
        pending one.
    """
    if self._shutdown_event.is_set():
        return

    # Double-check for activity since the timer was armed. If activity happened after this timer
    # fired but before the lock was acquired, the timer could no longer be cancelled. In that case
    # _last_activity_time is recent and the cache must be kept.
    if self._last_activity_time is not None and self._keep_alive_minutes > 0:
        elapsed_minutes = (time.time() - self._last_activity_time) / 60
        if elapsed_minutes < self._keep_alive_minutes:
            self._logger.debug(
                f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
                "Skipping cache clear."
            )
            return

    # Locked models are in use and cannot be evicted.
    unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]

    if unlocked_models:
        self._logger.info(
            f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
            f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
        )
        # Clear the cache by requesting a very large amount of space (1000 GB), the same logic used by
        # the "Clear Model Cache" button; this evicts every unlocked model. The lock is already held,
        # so the internal (non-locking) variant is called.
        self._make_room_internal(1000 * GB)
    elif self._cached_models:
        # All models are locked, so nothing can be evicted; don't log at info level.
        self._logger.debug(
            f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
            "Skipping cache clear."
        )
    else:
        self._logger.debug("Model cache timeout fired but cache is already empty.")
|
|
308
|
+
|
|
309
|
+
@synchronized
def shutdown(self) -> None:
    """Stop the keep-alive machinery: flag shutdown and cancel any pending timeout timer.

    ``_on_timeout`` checks ``self._shutdown_event`` first and returns early once it is set.
    Setting it therefore neutralises a timer callback that can no longer be cancelled.
    """
    self._shutdown_event.set()
    pending, self._timeout_timer = self._timeout_timer, None
    if pending is not None:
        pending.cancel()
|
|
316
|
+
|
|
317
|
+
@synchronized
|
|
318
|
+
@record_activity
|
|
222
319
|
def put(self, key: str, model: AnyModel) -> None:
|
|
223
320
|
"""Add a model to the cache."""
|
|
224
321
|
if key in self._cached_models:
|
|
@@ -228,7 +325,7 @@ class ModelCache:
|
|
|
228
325
|
return
|
|
229
326
|
|
|
230
327
|
size = calc_model_size_by_data(self._logger, model)
|
|
231
|
-
self.
|
|
328
|
+
self._make_room_internal(size)
|
|
232
329
|
|
|
233
330
|
# Inject custom modules into the model.
|
|
234
331
|
if isinstance(model, torch.nn.Module):
|
|
@@ -272,6 +369,7 @@ class ModelCache:
|
|
|
272
369
|
return overview
|
|
273
370
|
|
|
274
371
|
@synchronized
|
|
372
|
+
@record_activity
|
|
275
373
|
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
|
|
276
374
|
"""Retrieve a model from the cache.
|
|
277
375
|
|
|
@@ -309,9 +407,11 @@ class ModelCache:
|
|
|
309
407
|
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
|
|
310
408
|
for cb in self._on_cache_hit_callbacks:
|
|
311
409
|
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
|
|
410
|
+
|
|
312
411
|
return cache_entry
|
|
313
412
|
|
|
314
413
|
@synchronized
|
|
414
|
+
@record_activity
|
|
315
415
|
def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
|
|
316
416
|
"""Lock a model for use and move it into VRAM."""
|
|
317
417
|
if cache_entry.key not in self._cached_models:
|
|
@@ -348,6 +448,7 @@ class ModelCache:
|
|
|
348
448
|
self._log_cache_state()
|
|
349
449
|
|
|
350
450
|
@synchronized
|
|
451
|
+
@record_activity
|
|
351
452
|
def unlock(self, cache_entry: CacheRecord) -> None:
|
|
352
453
|
"""Unlock a model."""
|
|
353
454
|
if cache_entry.key not in self._cached_models:
|
|
@@ -691,6 +792,10 @@ class ModelCache:
|
|
|
691
792
|
external references to the model, there's nothing that the cache can do about it, and those models will not be
|
|
692
793
|
garbage-collected.
|
|
693
794
|
"""
|
|
795
|
+
self._make_room_internal(bytes_needed)
|
|
796
|
+
|
|
797
|
+
def _make_room_internal(self, bytes_needed: int) -> None:
|
|
798
|
+
"""Internal implementation of make_room(). Assumes the lock is already held."""
|
|
694
799
|
self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
|
|
695
800
|
self._log_cache_state(title="Before dropping models:")
|
|
696
801
|
|
|
@@ -45,12 +45,13 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
|
|
|
45
45
|
model_path,
|
|
46
46
|
torch_dtype=dtype,
|
|
47
47
|
variant=variant,
|
|
48
|
+
local_files_only=True,
|
|
48
49
|
)
|
|
49
50
|
except OSError as e:
|
|
50
51
|
if variant and "no file named" in str(
|
|
51
52
|
e
|
|
52
53
|
): # try without the variant, just in case user's preferences changed
|
|
53
|
-
result = load_class.from_pretrained(model_path, torch_dtype=dtype)
|
|
54
|
+
result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
|
|
54
55
|
else:
|
|
55
56
|
raise e
|
|
56
57
|
|