InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/model_manager/configs/qwen3_encoder.py

@@ -1,4 +1,5 @@
-from typing import Any, Literal, Self
+import json
+from typing import Any, Literal, Optional, Self
 
 from pydantic import Field
 
@@ -11,7 +12,7 @@ from invokeai.backend.model_manager.configs.identification_utils import (
     raise_if_not_file,
 )
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
-from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, Qwen3VariantType
 from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
@@ -45,12 +46,67 @@ def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
     return any(isinstance(v, GGMLTensor) for v in state_dict.values())
 
 
+def _get_qwen3_variant_from_state_dict(state_dict: dict[str | int, Any]) -> Optional[Qwen3VariantType]:
+    """Determine Qwen3 variant (4B vs 8B) from state dict based on hidden_size.
+
+    The hidden_size can be determined from the embed_tokens.weight tensor shape:
+    - Qwen3 4B: hidden_size = 2560
+    - Qwen3 8B: hidden_size = 4096
+
+    For GGUF format, the key is 'token_embd.weight'.
+    For PyTorch format, the key is 'model.embed_tokens.weight'.
+    """
+    # Hidden size thresholds
+    QWEN3_4B_HIDDEN_SIZE = 2560
+    QWEN3_8B_HIDDEN_SIZE = 4096
+
+    # Try to find embed_tokens weight
+    embed_key = None
+    for key in state_dict.keys():
+        if isinstance(key, str):
+            if key == "model.embed_tokens.weight" or key == "token_embd.weight":
+                embed_key = key
+                break
+
+    if embed_key is None:
+        return None
+
+    tensor = state_dict[embed_key]
+
+    # Get hidden_size from tensor shape
+    # Shape is [vocab_size, hidden_size]
+    if isinstance(tensor, GGMLTensor):
+        # GGUF tensor
+        if hasattr(tensor, "shape") and len(tensor.shape) >= 2:
+            hidden_size = tensor.shape[1]
+        else:
+            return None
+    elif hasattr(tensor, "shape"):
+        # PyTorch tensor
+        if len(tensor.shape) >= 2:
+            hidden_size = tensor.shape[1]
+        else:
+            return None
+    else:
+        return None
+
+    # Determine variant based on hidden_size
+    if hidden_size == QWEN3_4B_HIDDEN_SIZE:
+        return Qwen3VariantType.Qwen3_4B
+    elif hidden_size == QWEN3_8B_HIDDEN_SIZE:
+        return Qwen3VariantType.Qwen3_8B
+    else:
+        # Unknown size, default to 4B (more common)
+        return Qwen3VariantType.Qwen3_4B
+
+
 class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
     """Configuration for single-file Qwen3 Encoder models (safetensors)."""
 
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
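Note: the variant probe above keys entirely off the second axis of the token-embedding tensor. A standalone sketch of the same heuristic, runnable outside InvokeAI (the guess_qwen3_variant helper and the vocab size below are illustrative, not part of the package):

import torch

def guess_qwen3_variant(state_dict: dict) -> str | None:
    # Checkpoint key: "model.embed_tokens.weight"; GGUF key: "token_embd.weight".
    for key in ("model.embed_tokens.weight", "token_embd.weight"):
        tensor = state_dict.get(key)
        if tensor is not None and len(tensor.shape) >= 2:
            # Shape is [vocab_size, hidden_size]: 2560 -> 4B, 4096 -> 8B.
            return {2560: "4B", 4096: "8B"}.get(int(tensor.shape[1]), "4B")
    return None  # no embedding tensor found

assert guess_qwen3_variant({"model.embed_tokens.weight": torch.zeros(1000, 2560)}) == "4B"
assert guess_qwen3_variant({"token_embd.weight": torch.zeros(1000, 4096)}) == "8B"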
@@ -62,7 +118,17 @@ class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
 
         cls._validate_does_not_look_like_gguf_quantized(mod)
 
-        return cls(**override_fields)
+        # Determine variant from state dict
+        variant = cls._get_variant_or_default(mod)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
+        """Get variant from state dict, defaulting to 4B if unknown."""
+        state_dict = mod.load_state_dict()
+        variant = _get_qwen3_variant_from_state_dict(state_dict)
+        return variant if variant is not None else Qwen3VariantType.Qwen3_4B
 
     @classmethod
     def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
@@ -87,6 +153,7 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -94,6 +161,16 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
 
         raise_for_override_fields(cls, override_fields)
 
+        # Exclude full pipeline models - these should be matched as main models, not just Qwen3 encoders.
+        # Full pipelines have model_index.json at root (diffusers format) or a transformer subfolder.
+        model_index_path = mod.path / "model_index.json"
+        transformer_path = mod.path / "transformer"
+        if model_index_path.exists() or transformer_path.exists():
+            raise NotAMatchError(
+                "directory looks like a full diffusers pipeline (has model_index.json or transformer folder), "
+                "not a standalone Qwen3 encoder"
+            )
+
         # Check for text_encoder config - support both:
         # 1. Full model structure: model_root/text_encoder/config.json
         # 2. Standalone text_encoder download: model_root/config.json (when text_encoder subfolder is downloaded separately)
@@ -105,8 +182,6 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
         elif config_path_direct.exists():
             expected_config_path = config_path_direct
         else:
-            from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
-
             raise NotAMatchError(
                 f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
             )
@@ -121,7 +196,30 @@ class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
             },
         )
 
-        return cls(**override_fields)
+        # Determine variant from config.json hidden_size
+        variant = cls._get_variant_from_config(expected_config_path)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_from_config(cls, config_path) -> Qwen3VariantType:
+        """Get variant from config.json based on hidden_size."""
+        QWEN3_4B_HIDDEN_SIZE = 2560
+        QWEN3_8B_HIDDEN_SIZE = 4096
+
+        try:
+            with open(config_path, "r", encoding="utf-8") as f:
+                config = json.load(f)
+            hidden_size = config.get("hidden_size")
+            if hidden_size == QWEN3_8B_HIDDEN_SIZE:
+                return Qwen3VariantType.Qwen3_8B
+            elif hidden_size == QWEN3_4B_HIDDEN_SIZE:
+                return Qwen3VariantType.Qwen3_4B
+            else:
+                # Default to 4B for unknown sizes
+                return Qwen3VariantType.Qwen3_4B
+        except (json.JSONDecodeError, OSError):
+            return Qwen3VariantType.Qwen3_4B
 
 
 class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
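The diffusers-format config class above reads the same 4B/8B split from config.json instead of tensor shapes. A sketch of that probe against a throwaway file; hidden_size is the only field the heuristic consults:

import json
import tempfile

# Minimal stand-in config.json; real Qwen3 configs carry many more fields.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"hidden_size": 4096}, f)
    config_path = f.name

with open(config_path, "r", encoding="utf-8") as f:
    hidden_size = json.load(f).get("hidden_size")

# Mirrors _get_variant_from_config: 4096 -> 8B, anything else -> 4B.
print("8B" if hidden_size == 4096 else "4B")  # prints: 8B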
@@ -130,6 +228,7 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
     type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
     format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+    variant: Qwen3VariantType = Field(description="Qwen3 model size variant (4B or 8B)")
 
     @classmethod
     def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -141,7 +240,17 @@ class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
 
         cls._validate_looks_like_gguf_quantized(mod)
 
-        return cls(**override_fields)
+        # Determine variant from state dict
+        variant = cls._get_variant_or_default(mod)
+
+        return cls(variant=variant, **override_fields)
+
+    @classmethod
+    def _get_variant_or_default(cls, mod: ModelOnDisk) -> Qwen3VariantType:
+        """Get variant from state dict, defaulting to 4B if unknown."""
+        state_dict = mod.load_state_dict()
+        variant = _get_qwen3_variant_from_state_dict(state_dict)
+        return variant if variant is not None else Qwen3VariantType.Qwen3_4B
 
     @classmethod
     def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
invokeai/backend/model_manager/configs/vae.py

@@ -33,6 +33,25 @@ REGEX_TO_BASE: dict[str, BaseModelType] = {
 }
 
 
+def _is_flux2_vae(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict is a FLUX.2 VAE (AutoencoderKLFlux2).
+
+    FLUX.2 VAE can be identified by:
+    1. Batch Normalization layers (bn.running_mean, bn.running_var) - unique to FLUX.2
+    2. 32-dimensional latent space (decoder.conv_in has 32 input channels)
+
+    FLUX.1 VAE has 16-dimensional latent space and no BatchNorm layers.
+    """
+    # Check for BN layer which is unique to FLUX.2 VAE
+    has_bn = "bn.running_mean" in state_dict or "bn.running_var" in state_dict
+
+    # Check for 32-channel latent space (FLUX.2 has 32, FLUX.1 has 16)
+    decoder_conv_in_key = "decoder.conv_in.weight"
+    has_32_latent_channels = decoder_conv_in_key in state_dict and state_dict[decoder_conv_in_key].shape[1] == 32
+
+    return has_bn or has_32_latent_channels
+
+
 class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
     """Model config for standalone VAE models."""
 
@@ -61,8 +80,9 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
 
     @classmethod
     def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
+        state_dict = mod.load_state_dict()
         if not state_dict_has_any_keys_starting_with(
-            mod.load_state_dict(),
+            state_dict,
             {
                 "encoder.conv_in",
                 "decoder.conv_in",
@@ -70,9 +90,30 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
         ):
             raise NotAMatchError("model does not match Checkpoint VAE heuristics")
 
+        # Exclude FLUX.2 VAEs - they have their own config class
+        if _is_flux2_vae(state_dict):
+            raise NotAMatchError("model is a FLUX.2 VAE, not a standard VAE")
+
     @classmethod
     def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
-        # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name
+        # First, try to identify by latent space dimensions (most reliable)
+        state_dict = mod.load_state_dict()
+        decoder_conv_in_key = "decoder.conv_in.weight"
+        if decoder_conv_in_key in state_dict:
+            latent_channels = state_dict[decoder_conv_in_key].shape[1]
+            if latent_channels == 16:
+                # Flux1 VAE has 16-dimensional latent space
+                return BaseModelType.Flux
+            elif latent_channels == 4:
+                # SD/SDXL VAE has 4-dimensional latent space
+                # Try to distinguish SD1/SD2/SDXL by name, fallback to SD1
+                for regexp, base in REGEX_TO_BASE.items():
+                    if re.search(regexp, mod.path.name, re.IGNORECASE):
+                        return base
+                # Default to SD1 if we can't determine from name
+                return BaseModelType.StableDiffusion1
+
+        # Fallback: guess based on name
         for regexp, base in REGEX_TO_BASE.items():
             if re.search(regexp, mod.path.name, re.IGNORECASE):
                 return base
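The reworked _get_base_or_raise resolves the base model in two passes: latent channel count first (shape-based and reliable), filename regex only as a tie-breaker or fallback. A compressed sketch of that ordering with an illustrative regex table standing in for REGEX_TO_BASE:

import re
import torch

REGEX_TO_BASE = {r"xl": "sdxl", r"sd2": "sd2"}  # illustrative subset of the real table

def guess_base(state_dict: dict, filename: str) -> str:
    w = state_dict.get("decoder.conv_in.weight")
    if w is not None:
        if w.shape[1] == 16:
            return "flux"  # FLUX.1 VAEs have 16 latent channels
        if w.shape[1] == 4:  # SD-family VAEs have 4 latent channels
            for pattern, base in REGEX_TO_BASE.items():
                if re.search(pattern, filename, re.IGNORECASE):
                    return base
            return "sd1"  # default when the name is uninformative
    return "unknown"  # shape unavailable: fall back to the name-only heuristic

assert guess_base({"decoder.conv_in.weight": torch.zeros(512, 4, 3, 3)}, "sdxl_vae.safetensors") == "sdxl"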
@@ -96,6 +137,44 @@ class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
 
 
+class VAE_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Config_Base):
+    """Model config for FLUX.2 VAE checkpoint models (AutoencoderKLFlux2)."""
+
+    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_vae(mod)
+
+        cls._validate_is_flux2_vae(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
+        if not state_dict_has_any_keys_starting_with(
+            mod.load_state_dict(),
+            {
+                "encoder.conv_in",
+                "decoder.conv_in",
+            },
+        ):
+            raise NotAMatchError("model does not match Checkpoint VAE heuristics")
+
+    @classmethod
+    def _validate_is_flux2_vae(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is a FLUX.2 VAE, not FLUX.1."""
+        state_dict = mod.load_state_dict()
+        if not _is_flux2_vae(state_dict):
+            raise NotAMatchError("state dict does not look like a FLUX.2 VAE")
+
+
 class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
     """Model config for standalone VAE models (diffusers version)."""
 
@@ -161,3 +240,26 @@ class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
 
 class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
     base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
+
+
+class VAE_Diffusers_Flux2_Config(Diffusers_Config_Base, Config_Base):
+    """Model config for FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2)."""
+
+    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
+    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "AutoencoderKLFlux2",
+            },
+        )
+
+        return cls(**override_fields)
invokeai/backend/model_manager/load/model_cache/model_cache.py

@@ -55,6 +55,21 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
     return wrapper
 
 
+def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that records activity after a method completes successfully.
+
+    Note: This decorator should be applied to methods that already hold self._lock.
+    """
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        result = method(self, *args, **kwargs)
+        self._record_activity()
+        return result
+
+    return wrapper
+
+
 @dataclass
 class CacheEntrySnapshot:
     cache_key: str
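Decorator order matters here: synchronized must be outermost so that record_activity runs while the lock is still held. A self-contained sketch of the composition (the Cache class below is a toy stand-in, not the real ModelCache):

import threading
from functools import wraps

def synchronized(method):
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        with self._lock:  # outermost decorator: acquires the lock
            return method(self, *args, **kwargs)
    return wrapper

def record_activity(method):
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self._record_activity()  # runs under the lock taken by synchronized
        return result
    return wrapper

class Cache:
    def __init__(self):
        self._lock = threading.RLock()
        self.touches = 0

    def _record_activity(self):
        self.touches += 1

    @synchronized
    @record_activity
    def get(self, key):
        return key

c = Cache()
c.get("k")
assert c.touches == 1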
@@ -132,6 +147,7 @@ class ModelCache:
         storage_device: torch.device | str = "cpu",
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
+        keep_alive_minutes: float = 0,
     ):
         """Initialize the model RAM cache.
 
@@ -151,6 +167,7 @@
         snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
         behaviour.
         :param logger: InvokeAILogger to use (otherwise creates one)
+        :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
         """
         self._enable_partial_loading = enable_partial_loading
         self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
@@ -182,6 +199,12 @@
         self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
         self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
 
+        # Keep-alive timeout support
+        self._keep_alive_minutes = keep_alive_minutes
+        self._last_activity_time: Optional[float] = None
+        self._timeout_timer: Optional[threading.Timer] = None
+        self._shutdown_event = threading.Event()
+
     def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
         self._on_cache_hit_callbacks.add(cb)
 
@@ -190,7 +213,7 @@
 
         return unsubscribe
 
-    def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
+    def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
         self._on_cache_miss_callbacks.add(cb)
 
         def unsubscribe() -> None:
@@ -217,8 +240,82 @@
     def stats(self, stats: CacheStats) -> None:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats
+        # Populate the cache size in the stats object when it's set
+        if self._stats is not None:
+            self._stats.cache_size = self._ram_cache_size_bytes
+
+    def _record_activity(self) -> None:
+        """Record model activity and reset the timeout timer if configured.
+
+        Note: This method should only be called when self._lock is already held.
+        """
+        if self._keep_alive_minutes <= 0:
+            return
+
+        self._last_activity_time = time.time()
+
+        # Cancel any existing timer
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+
+        # Start a new timer
+        timeout_seconds = self._keep_alive_minutes * 60
+        self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
+        # Set as daemon so it doesn't prevent application shutdown
+        self._timeout_timer.daemon = True
+        self._timeout_timer.start()
+        self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
 
     @synchronized
+    @record_activity
+    def _on_timeout(self) -> None:
+        """Called when the keep-alive timeout expires. Clears the model cache."""
+        if self._shutdown_event.is_set():
+            return
+
+        # Double-check if there has been activity since the timer was set
+        # This handles the race condition where activity occurred just before the timer fired
+        if self._last_activity_time is not None and self._keep_alive_minutes > 0:
+            elapsed_minutes = (time.time() - self._last_activity_time) / 60
+            if elapsed_minutes < self._keep_alive_minutes:
+                # Activity occurred, don't clear cache
+                self._logger.debug(
+                    f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
+                    f"Skipping cache clear."
+                )
+                return
+
+        # Check if there are any unlocked models that can be cleared
+        unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]
+
+        if len(unlocked_models) > 0:
+            self._logger.info(
+                f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
+                f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
+            )
+            # Clear the cache by requesting a very large amount of space.
+            # This is the same logic used by the "Clear Model Cache" button.
+            # Using 1000 GB ensures all unlocked models are removed.
+            self._make_room_internal(1000 * GB)
+        elif len(self._cached_models) > 0:
+            # All models are locked, don't log at info level
+            self._logger.debug(
+                f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
+                f"Skipping cache clear."
+            )
+        else:
+            self._logger.debug("Model cache timeout fired but cache is already empty.")
+
+    @synchronized
+    def shutdown(self) -> None:
+        """Shutdown the model cache, cancelling any pending timers."""
+        self._shutdown_event.set()
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+            self._timeout_timer = None
+
+    @synchronized
+    @record_activity
     def put(self, key: str, model: AnyModel) -> None:
         """Add a model to the cache."""
         if key in self._cached_models:
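The keep-alive mechanism above is a re-armed one-shot timer: every cache touch cancels the pending threading.Timer and starts a fresh one, so the expiry callback only fires after a full idle window. A toy reduction of just that pattern (timings in seconds rather than minutes so the sketch runs quickly):

import threading
import time

class KeepAlive:
    def __init__(self, window_seconds: float, on_expire):
        self._window = window_seconds
        self._on_expire = on_expire
        self._timer = None

    def touch(self):
        if self._window <= 0:
            return  # 0 disables the timeout, as in ModelCache
        if self._timer is not None:
            self._timer.cancel()  # re-arm: drop the pending timer
        self._timer = threading.Timer(self._window, self._on_expire)
        self._timer.daemon = True  # don't block interpreter shutdown
        self._timer.start()

expired = threading.Event()
ka = KeepAlive(0.2, expired.set)
ka.touch()
time.sleep(0.1)
ka.touch()  # activity inside the window restarts the countdown
time.sleep(0.1)
assert not expired.is_set()  # only 0.1s idle since the last touch
time.sleep(0.3)
assert expired.is_set()  # idle past the window: callback fired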
@@ -228,7 +325,7 @@
             return
 
         size = calc_model_size_by_data(self._logger, model)
-        self.make_room(size)
+        self._make_room_internal(size)
 
         # Inject custom modules into the model.
         if isinstance(model, torch.nn.Module):
@@ -272,6 +369,7 @@
         return overview
 
     @synchronized
+    @record_activity
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.
 
@@ -309,9 +407,11 @@
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         for cb in self._on_cache_hit_callbacks:
             cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
+
         return cache_entry
 
     @synchronized
+    @record_activity
     def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:
@@ -348,6 +448,7 @@
         self._log_cache_state()
 
     @synchronized
+    @record_activity
     def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
         if cache_entry.key not in self._cached_models:
@@ -691,6 +792,10 @@
         external references to the model, there's nothing that the cache can do about it, and those models will not be
         garbage-collected.
         """
+        self._make_room_internal(bytes_needed)
+
+    def _make_room_internal(self, bytes_needed: int) -> None:
+        """Internal implementation of make_room(). Assumes the lock is already held."""
         self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
         self._log_cache_state(title="Before dropping models:")
 
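The make_room/_make_room_internal split is the standard public-wrapper pattern for non-reentrant locks: the @synchronized public method takes the lock once, and callers that already hold it (put, _on_timeout) go straight to the internal variant. A minimal sketch of why the indirection is needed:

import threading

class Cache:
    def __init__(self):
        self._lock = threading.Lock()  # non-reentrant, like a plain mutex

    def make_room(self, bytes_needed: int) -> None:
        with self._lock:  # public entry point: acquires the lock
            self._make_room_internal(bytes_needed)

    def _make_room_internal(self, bytes_needed: int) -> None:
        # Assumes self._lock is already held by the caller.
        print(f"evicting until {bytes_needed} bytes are free")

    def put(self) -> None:
        with self._lock:
            # Calling make_room() here would deadlock on a plain Lock;
            # the internal variant is safe because the lock is held.
            self._make_room_internal(1024)

Cache().put()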
invokeai/backend/model_manager/load/model_loaders/cogview4.py

@@ -45,12 +45,13 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
                 model_path,
                 torch_dtype=dtype,
                 variant=variant,
+                local_files_only=True,
             )
         except OSError as e:
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = load_class.from_pretrained(model_path, torch_dtype=dtype)
+                result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
             else:
                 raise e
 
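Both call sites now pass local_files_only=True, which tells from_pretrained() to resolve the model solely from the local path or cache and to raise instead of falling back to a Hub download. A hedged sketch of that behaviour (assumes diffusers is installed; the path is hypothetical):

from diffusers import AutoencoderKL  # any diffusers loader class accepts the same flag

try:
    vae = AutoencoderKL.from_pretrained(
        "/models/some-local-vae",  # hypothetical local path
        local_files_only=True,  # never reach out to the network
    )
except OSError as e:
    print(f"missing local files, and no download was attempted: {e}")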