invokeai-6.9.0rc3-py3-none-any.whl → invokeai-6.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/flux_denoise.py +15 -1
  6. invokeai/app/invocations/image_to_latents.py +23 -5
  7. invokeai/app/invocations/latents_to_image.py +2 -25
  8. invokeai/app/invocations/metadata.py +9 -1
  9. invokeai/app/invocations/metadata_linked.py +47 -0
  10. invokeai/app/invocations/model.py +8 -0
  11. invokeai/app/invocations/pbr_maps.py +59 -0
  12. invokeai/app/invocations/primitives.py +12 -0
  13. invokeai/app/invocations/prompt_template.py +57 -0
  14. invokeai/app/invocations/z_image_control.py +112 -0
  15. invokeai/app/invocations/z_image_denoise.py +770 -0
  16. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  17. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  18. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  19. invokeai/app/invocations/z_image_model_loader.py +135 -0
  20. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  21. invokeai/app/services/config/config_default.py +3 -1
  22. invokeai/app/services/model_install/model_install_common.py +14 -1
  23. invokeai/app/services/model_install/model_install_default.py +119 -19
  24. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  25. invokeai/app/services/model_records/model_records_base.py +12 -0
  26. invokeai/app/services/model_records/model_records_sql.py +17 -0
  27. invokeai/app/services/shared/graph.py +132 -77
  28. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  29. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  30. invokeai/app/util/step_callback.py +3 -0
  31. invokeai/backend/flux/denoise.py +196 -11
  32. invokeai/backend/flux/schedulers.py +62 -0
  33. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  34. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  35. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  36. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  37. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  38. invokeai/backend/model_manager/configs/factory.py +26 -1
  39. invokeai/backend/model_manager/configs/lora.py +79 -1
  40. invokeai/backend/model_manager/configs/main.py +113 -0
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  42. invokeai/backend/model_manager/load/model_cache/model_cache.py +104 -2
  43. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  44. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  45. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  46. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  47. invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
  48. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  49. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  50. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  51. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  52. invokeai/backend/model_manager/load/model_loaders/z_image.py +969 -0
  53. invokeai/backend/model_manager/load/model_util.py +6 -1
  54. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  55. invokeai/backend/model_manager/model_on_disk.py +3 -0
  56. invokeai/backend/model_manager/starter_models.py +79 -0
  57. invokeai/backend/model_manager/taxonomy.py +5 -0
  58. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  59. invokeai/backend/patches/layer_patcher.py +34 -16
  60. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  61. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  62. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  63. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  64. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  65. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +189 -0
  66. invokeai/backend/quantization/gguf/ggml_tensor.py +38 -4
  67. invokeai/backend/quantization/gguf/loaders.py +47 -12
  68. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  69. invokeai/backend/util/devices.py +25 -0
  70. invokeai/backend/util/hotfixes.py +2 -2
  71. invokeai/backend/z_image/__init__.py +16 -0
  72. invokeai/backend/z_image/extensions/__init__.py +1 -0
  73. invokeai/backend/z_image/extensions/regional_prompting_extension.py +205 -0
  74. invokeai/backend/z_image/text_conditioning.py +74 -0
  75. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  76. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  77. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  78. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  79. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  80. invokeai/frontend/web/dist/assets/App-BBELGD-n.js +161 -0
  81. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-4xPFTMT3.js} +1 -1
  82. invokeai/frontend/web/dist/assets/index-vCDSQboA.js +530 -0
  83. invokeai/frontend/web/dist/index.html +1 -1
  84. invokeai/frontend/web/dist/locales/de.json +24 -6
  85. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  86. invokeai/frontend/web/dist/locales/en.json +78 -3
  87. invokeai/frontend/web/dist/locales/es.json +0 -5
  88. invokeai/frontend/web/dist/locales/fr.json +0 -6
  89. invokeai/frontend/web/dist/locales/it.json +17 -64
  90. invokeai/frontend/web/dist/locales/ja.json +379 -44
  91. invokeai/frontend/web/dist/locales/ru.json +0 -6
  92. invokeai/frontend/web/dist/locales/vi.json +7 -54
  93. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  94. invokeai/version/invokeai_version.py +1 -1
  95. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/METADATA +4 -4
  96. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/RECORD +102 -71
  97. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  98. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  99. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/WHEEL +0 -0
  100. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/entry_points.txt +0 -0
  101. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE +0 -0
  102. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  103. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  104. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/top_level.txt +0 -0
@@ -60,6 +60,8 @@ class MainModelDefaultSettings(BaseModel):
                 return cls(width=768, height=768)
             case BaseModelType.StableDiffusionXL:
                 return cls(width=1024, height=1024)
+            case BaseModelType.ZImage:
+                return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
             case _:
                 # TODO(psyche): Do we want defaults for other base types?
                 return None
@@ -111,6 +113,28 @@ def _has_main_keys(state_dict: dict[str | int, Any]) -> bool:
     return False
 
 
+def _has_z_image_keys(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains Z-Image S3-DiT transformer keys."""
+    # Z-Image specific keys that distinguish it from other models
+    z_image_specific_keys = {
+        "cap_embedder",  # Caption embedder - unique to Z-Image
+        "context_refiner",  # Context refiner blocks
+        "cap_pad_token",  # Caption padding token
+    }
+
+    for key in state_dict.keys():
+        if isinstance(key, int):
+            continue
+        # Check for Z-Image specific key prefixes.
+        # Handle both direct keys (cap_embedder.0.weight) and
+        # ComfyUI-style keys (model.diffusion_model.cap_embedder.0.weight).
+        key_parts = key.split(".")
+        for part in key_parts:
+            if part in z_image_specific_keys:
+                return True
+    return False
+
+
 class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
     """Model config for main checkpoint models."""
 
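Note: _has_z_image_keys matches on dotted key segments rather than raw prefixes, which is what makes both native and ComfyUI-exported checkpoints detectable. A standalone sketch of the rule (not InvokeAI code; the example keys are hypothetical):

    Z_IMAGE_SEGMENTS = {"cap_embedder", "context_refiner", "cap_pad_token"}

    def looks_like_z_image(keys: list[str]) -> bool:
        # Split each key on "." and look for a Z-Image-only segment, so the
        # "model.diffusion_model." nesting used by ComfyUI exports matches too.
        return any(part in Z_IMAGE_SEGMENTS for key in keys for part in key.split("."))

    assert looks_like_z_image(["model.diffusion_model.cap_embedder.0.weight"])
    assert not looks_like_z_image(["model.diffusion_model.input_blocks.0.0.weight"])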
@@ -657,3 +681,92 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
             **override_fields,
             repo_variant=repo_variant,
         )
+
+
+class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""
+
+    base: Literal[BaseModelType.ZImage] = Field(BaseModelType.ZImage)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # This check implies the base type - no further validation needed.
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "ZImagePipeline",
+            },
+        )
+
+        repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
+
+        return cls(
+            **override_fields,
+            repo_variant=repo_variant,
+        )
+
+
+class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for Z-Image single-file checkpoint models (safetensors, etc)."""
+
+    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_z_image_model(mod)
+
+        cls._validate_does_not_look_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
+        has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
+        if not has_z_image_keys:
+            raise NotAMatchError("state dict does not look like a Z-Image model")
+
+    @classmethod
+    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if has_ggml_tensors:
+            raise NotAMatchError("state dict looks like GGUF quantized")
+
+
+class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
+    """Model config for GGUF-quantized Z-Image transformer models."""
+
+    base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_z_image_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
+        has_z_image_keys = _has_z_image_keys(mod.load_state_dict())
+        if not has_z_image_keys:
+            raise NotAMatchError("state dict does not look like a Z-Image model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml_tensors:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
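These three classes partition the Z-Image main-model space by on-disk shape: a diffusers folder whose config names ZImagePipeline, a plain single-file checkpoint, and a GGUF-quantized file. A rough sketch of the dispatch this implies; the probe loop itself is illustrative, not InvokeAI's actual config factory:

    candidates = [
        Main_Diffusers_ZImage_Config,   # directory with a ZImagePipeline config
        Main_GGUF_ZImage_Config,        # single file containing GGML tensors
        Main_Checkpoint_ZImage_Config,  # single file with plain tensors
    ]

    for config_cls in candidates:
        try:
            config = config_cls.from_model_on_disk(mod, override_fields={})
            break  # the first class whose validations pass claims the model
        except NotAMatchError:
            continue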
@@ -0,0 +1,156 @@
+from typing import Any, Literal, Self
+
+from pydantic import Field
+
+from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Config_Base
+from invokeai.backend.model_manager.configs.identification_utils import (
+    NotAMatchError,
+    raise_for_class_name,
+    raise_for_override_fields,
+    raise_if_not_dir,
+    raise_if_not_file,
+)
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
+
+
+def _has_qwen3_keys(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains Qwen3 model keys.
+
+    Supports both:
+    - PyTorch/diffusers format: model.layers.0., model.embed_tokens.weight
+    - GGUF/llama.cpp format: blk.0., token_embd.weight
+    """
+    # PyTorch/diffusers format indicators
+    pytorch_indicators = ["model.layers.0.", "model.embed_tokens.weight"]
+    # GGUF/llama.cpp format indicators
+    gguf_indicators = ["blk.0.", "token_embd.weight"]
+
+    for key in state_dict.keys():
+        if isinstance(key, str):
+            # Check PyTorch format
+            for indicator in pytorch_indicators:
+                if key.startswith(indicator) or key == indicator:
+                    return True
+            # Check GGUF format
+            for indicator in gguf_indicators:
+                if key.startswith(indicator) or key == indicator:
+                    return True
+    return False
+
+
+def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict contains GGML tensors (GGUF quantized)."""
+    return any(isinstance(v, GGMLTensor) for v in state_dict.values())
+
+
+class Qwen3Encoder_Checkpoint_Config(Checkpoint_Config_Base, Config_Base):
+    """Configuration for single-file Qwen3 Encoder models (safetensors)."""
+
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_qwen3_model(mod)
+
+        cls._validate_does_not_look_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
+        has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
+        if not has_qwen3_keys:
+            raise NotAMatchError("state dict does not look like a Qwen3 model")
+
+    @classmethod
+    def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml = _has_ggml_tensors(mod.load_state_dict())
+        if has_ggml:
+            raise NotAMatchError("state dict looks like GGUF quantized")
+
+
+class Qwen3Encoder_Qwen3Encoder_Config(Config_Base):
+    """Configuration for Qwen3 Encoder models in a diffusers-like format.
+
+    The model weights are expected to be in a folder called text_encoder inside the model directory,
+    compatible with Qwen2VLForConditionalGeneration or similar architectures used by Z-Image.
+    """
+
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
+    format: Literal[ModelFormat.Qwen3Encoder] = Field(default=ModelFormat.Qwen3Encoder)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        # Check for text_encoder config - support both:
+        # 1. Full model structure: model_root/text_encoder/config.json
+        # 2. Standalone text_encoder download: model_root/config.json (when the text_encoder subfolder is downloaded separately)
+        config_path_nested = mod.path / "text_encoder" / "config.json"
+        config_path_direct = mod.path / "config.json"
+
+        if config_path_nested.exists():
+            expected_config_path = config_path_nested
+        elif config_path_direct.exists():
+            expected_config_path = config_path_direct
+        else:
+            from invokeai.backend.model_manager.configs.identification_utils import NotAMatchError
+
+            raise NotAMatchError(
+                f"unable to load config file(s): {{PosixPath('{config_path_nested}'): 'file does not exist'}}"
+            )
+
+        # Qwen3 uses Qwen2VLForConditionalGeneration or similar
+        raise_for_class_name(
+            expected_config_path,
+            {
+                "Qwen2VLForConditionalGeneration",
+                "Qwen2ForCausalLM",
+                "Qwen3ForCausalLM",
+            },
+        )
+
+        return cls(**override_fields)
+
+
+class Qwen3Encoder_GGUF_Config(Checkpoint_Config_Base, Config_Base):
+    """Configuration for GGUF-quantized Qwen3 Encoder models."""
+
+    base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any)
+    type: Literal[ModelType.Qwen3Encoder] = Field(default=ModelType.Qwen3Encoder)
+    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_qwen3_model(mod)
+
+        cls._validate_looks_like_gguf_quantized(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_qwen3_model(cls, mod: ModelOnDisk) -> None:
+        has_qwen3_keys = _has_qwen3_keys(mod.load_state_dict())
+        if not has_qwen3_keys:
+            raise NotAMatchError("state dict does not look like a Qwen3 model")
+
+    @classmethod
+    def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None:
+        has_ggml = _has_ggml_tensors(mod.load_state_dict())
+        if not has_ggml:
+            raise NotAMatchError("state dict does not look like GGUF quantized")
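Note: the Qwen3 probe accepts two different key layouts for the same architecture. A standalone sketch of the matching rule (not InvokeAI code; the example keys are hypothetical):

    PYTORCH_INDICATORS = ("model.layers.0.", "model.embed_tokens.weight")
    GGUF_INDICATORS = ("blk.0.", "token_embd.weight")

    def looks_like_qwen3(keys: list[str]) -> bool:
        # Mirrors _has_qwen3_keys: one prefix hit in either naming scheme
        # (PyTorch/diffusers or GGUF/llama.cpp) is enough to claim the file.
        return any(k.startswith(PYTORCH_INDICATORS + GGUF_INDICATORS) for k in keys)

    assert looks_like_qwen3(["model.layers.0.self_attn.q_proj.weight"])  # PyTorch/diffusers
    assert looks_like_qwen3(["blk.0.attn_q.weight"])                     # GGUF/llama.cpp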
@@ -55,6 +55,21 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
     return wrapper
 
 
+def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that records activity after a method completes successfully.
+
+    Note: This decorator should be applied to methods that already hold self._lock.
+    """
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        result = method(self, *args, **kwargs)
+        self._record_activity()
+        return result
+
+    return wrapper
+
+
 @dataclass
 class CacheEntrySnapshot:
     cache_key: str
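Note: record_activity calls self._record_activity() without taking the cache lock itself, so it must sit below @synchronized in the decorator stack; the outer decorator then holds the lock for the whole wrapped call. A self-contained toy showing that ordering (not InvokeAI code):

    import threading
    from functools import wraps

    def synchronized(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            with self._lock:  # outermost decorator: acquires the lock first
                return method(self, *args, **kwargs)
        return wrapper

    def record_activity(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            result = method(self, *args, **kwargs)
            self._record_activity()  # runs while the lock is still held
            return result
        return wrapper

    class Demo:
        def __init__(self):
            self._lock = threading.RLock()

        def _record_activity(self):
            print("activity recorded under lock")

        @synchronized      # applied last -> wraps record_activity
        @record_activity   # applied first -> runs inside the lock
        def get(self):
            return "model"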
@@ -132,6 +147,7 @@ class ModelCache:
         storage_device: torch.device | str = "cpu",
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
+        keep_alive_minutes: float = 0,
     ):
         """Initialize the model RAM cache.
 
@@ -151,6 +167,7 @@ class ModelCache:
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
         :param logger: InvokeAILogger to use (otherwise creates one)
+        :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
         """
         self._enable_partial_loading = enable_partial_loading
         self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
@@ -182,6 +199,12 @@ class ModelCache:
         self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
         self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
 
+        # Keep-alive timeout support
+        self._keep_alive_minutes = keep_alive_minutes
+        self._last_activity_time: Optional[float] = None
+        self._timeout_timer: Optional[threading.Timer] = None
+        self._shutdown_event = threading.Event()
+
     def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
         self._on_cache_hit_callbacks.add(cb)
 
@@ -190,7 +213,7 @@ class ModelCache:
 
         return unsubscribe
 
-    def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
+    def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
         self._on_cache_miss_callbacks.add(cb)
 
         def unsubscribe() -> None:
@@ -218,7 +241,78 @@ class ModelCache:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats
 
+    def _record_activity(self) -> None:
+        """Record model activity and reset the timeout timer if configured.
+
+        Note: This method should only be called when self._lock is already held.
+        """
+        if self._keep_alive_minutes <= 0:
+            return
+
+        self._last_activity_time = time.time()
+
+        # Cancel any existing timer
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+
+        # Start a new timer
+        timeout_seconds = self._keep_alive_minutes * 60
+        self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
+        # Set as daemon so it doesn't prevent application shutdown
+        self._timeout_timer.daemon = True
+        self._timeout_timer.start()
+        self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
+
     @synchronized
+    @record_activity
+    def _on_timeout(self) -> None:
+        """Called when the keep-alive timeout expires. Clears the model cache."""
+        if self._shutdown_event.is_set():
+            return
+
+        # Double-check if there has been activity since the timer was set.
+        # This handles the race condition where activity occurred just before the timer fired.
+        if self._last_activity_time is not None and self._keep_alive_minutes > 0:
+            elapsed_minutes = (time.time() - self._last_activity_time) / 60
+            if elapsed_minutes < self._keep_alive_minutes:
+                # Activity occurred, don't clear cache
+                self._logger.debug(
+                    f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
+                    f"Skipping cache clear."
+                )
+                return
+
+        # Check if there are any unlocked models that can be cleared
+        unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]
+
+        if len(unlocked_models) > 0:
+            self._logger.info(
+                f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
+                f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
+            )
+            # Clear the cache by requesting a very large amount of space.
+            # This is the same logic used by the "Clear Model Cache" button.
+            # Using 1000 GB ensures all unlocked models are removed.
+            self._make_room_internal(1000 * GB)
+        elif len(self._cached_models) > 0:
+            # All models are locked, don't log at info level
+            self._logger.debug(
+                f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
+                f"Skipping cache clear."
+            )
+        else:
+            self._logger.debug("Model cache timeout fired but cache is already empty.")
+
+    @synchronized
+    def shutdown(self) -> None:
+        """Shutdown the model cache, cancelling any pending timers."""
+        self._shutdown_event.set()
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+            self._timeout_timer = None
+
+    @synchronized
+    @record_activity
     def put(self, key: str, model: AnyModel) -> None:
         """Add a model to the cache."""
         if key in self._cached_models:
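The keep-alive machinery is a standard idle-timeout pattern: every cache touch cancels the pending threading.Timer and arms a fresh daemonized one, and the callback re-checks elapsed time to close the race where activity lands just before the timer fires. A self-contained sketch of the core pattern, independent of ModelCache:

    import threading
    import time

    class IdleTimeout:
        def __init__(self, seconds, on_idle):
            self._seconds = seconds
            self._on_idle = on_idle
            self._timer = None

        def touch(self):
            # Fire on_idle only after `seconds` with no further touch() calls.
            if self._timer is not None:
                self._timer.cancel()  # reset: discard the pending timer
            self._timer = threading.Timer(self._seconds, self._on_idle)
            self._timer.daemon = True  # don't block interpreter shutdown
            self._timer.start()

    idle = IdleTimeout(1.0, lambda: print("cache cleared"))
    idle.touch()     # each use re-arms the countdown
    time.sleep(1.5)  # no further activity -> on_idle fires once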
@@ -228,7 +322,7 @@ class ModelCache:
             return
 
         size = calc_model_size_by_data(self._logger, model)
-        self.make_room(size)
+        self._make_room_internal(size)
 
         # Inject custom modules into the model.
         if isinstance(model, torch.nn.Module):
@@ -272,6 +366,7 @@ class ModelCache:
         return overview
 
     @synchronized
+    @record_activity
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.
 
@@ -309,9 +404,11 @@ class ModelCache:
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         for cb in self._on_cache_hit_callbacks:
             cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
+
         return cache_entry
 
     @synchronized
+    @record_activity
     def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:
@@ -348,6 +445,7 @@ class ModelCache:
         self._log_cache_state()
 
     @synchronized
+    @record_activity
     def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
         if cache_entry.key not in self._cached_models:
@@ -691,6 +789,10 @@ class ModelCache:
         external references to the model, there's nothing that the cache can do about it, and those models will not be
         garbage-collected.
         """
+        self._make_room_internal(bytes_needed)
+
+    def _make_room_internal(self, bytes_needed: int) -> None:
+        """Internal implementation of make_room(). Assumes the lock is already held."""
         self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
         self._log_cache_state(title="Before dropping models:")
 
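Splitting make_room() into a public wrapper and _make_room_internal() is the usual way to make a locked API callable from methods that already hold the lock (put() and _on_timeout() above call the internal variant directly). A generic sketch of the pattern, assuming the public method takes the lock once:

    import threading

    class Cache:
        def __init__(self):
            self._lock = threading.Lock()

        def make_room(self, n: int) -> None:
            # Public entry point: acquires the lock, then delegates.
            with self._lock:
                self._make_room_internal(n)

        def _make_room_internal(self, n: int) -> None:
            # Assumes self._lock is already held; safe to call from other
            # locked methods without deadlocking on a non-reentrant Lock.
            ...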
@@ -0,0 +1,40 @@
+import torch
+from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+
+
+class CustomDiffusersRMSNorm(DiffusersRMSNorm, CustomModuleMixin):
+    """Custom wrapper for diffusers RMSNorm that supports device autocasting for partial model loading."""
+
+    def _autocast_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, hidden_states.device) if self.weight is not None else None
+        bias = cast_to_device(self.bias, hidden_states.device) if self.bias is not None else None
+
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if weight is not None:
+            # convert into half-precision if necessary
+            if weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(weight.dtype)
+            hidden_states = hidden_states * weight
+            if bias is not None:
+                hidden_states = hidden_states + bias
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            raise RuntimeError("DiffusersRMSNorm layers do not support patches")
+
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(hidden_states)
+        else:
+            return super().forward(hidden_states)
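Both new wrappers follow the same autocast recipe: parameters stay on their storage device and are cast to the input tensor's device on each forward pass, which is what lets a partially-loaded model keep some layers resident in CPU RAM. A toy illustration of the idea; this cast_to_device is a simplified stand-in for InvokeAI's helper, and a CUDA device is assumed to be available:

    import torch

    def cast_to_device(t, device):
        # Simplified stand-in: copy only when the tensor lives elsewhere.
        if t is None or t.device == device:
            return t
        return t.to(device)

    weight = torch.ones(8)                    # stored on the CPU
    x = torch.randn(2, 8, device="cuda")      # activations on the GPU
    y = x * cast_to_device(weight, x.device)  # copied per-forward, never moved permanently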
@@ -0,0 +1,25 @@
+import torch
+import torch.nn.functional as F
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+
+
+class CustomLayerNorm(torch.nn.LayerNorm, CustomModuleMixin):
+    """Custom wrapper for torch.nn.LayerNorm that supports device autocasting for partial model loading."""
+
+    def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
+        weight = cast_to_device(self.weight, input.device) if self.weight is not None else None
+        bias = cast_to_device(self.bias, input.device) if self.bias is not None else None
+        return F.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if len(self._patches_and_weights) > 0:
+            raise RuntimeError("LayerNorm layers do not support patches")
+
+        if self._device_autocasting_enabled:
+            return self._autocast_forward(input)
+        else:
+            return super().forward(input)
@@ -1,14 +1,18 @@
 from typing import TypeVar
 
 import torch
+from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
 
-from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.flux.modules.layers import RMSNorm as FluxRMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
     CustomConv1d,
 )
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import (
     CustomConv2d,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_diffusers_rms_norm import (
+    CustomDiffusersRMSNorm,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
     CustomEmbedding,
 )
@@ -18,6 +22,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custo
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import (
     CustomGroupNorm,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_layer_norm import (
+    CustomLayerNorm,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import (
     CustomLinear,
 )
@@ -31,7 +38,9 @@ AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]]
     torch.nn.Conv2d: CustomConv2d,
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
-    RMSNorm: CustomFluxRMSNorm,
+    torch.nn.LayerNorm: CustomLayerNorm,
+    FluxRMSNorm: CustomFluxRMSNorm,
+    DiffusersRMSNorm: CustomDiffusersRMSNorm,
 }
 
 try:
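AUTOCAST_MODULE_TYPE_MAPPING drives a type-based swap of standard layers for their autocast-aware wrappers. A minimal sketch of how such a mapping can be applied in place; the re-classing helper is illustrative, not InvokeAI's actual apply function:

    import torch

    class CustomLayerNormDemo(torch.nn.LayerNorm):
        pass  # stand-in for an autocast-aware wrapper with the same layout

    MAPPING = {torch.nn.LayerNorm: CustomLayerNormDemo}

    def wrap_custom_modules(module: torch.nn.Module) -> None:
        # Re-class matching leaf modules in place so existing weights are kept.
        for child in module.modules():
            custom_cls = MAPPING.get(type(child))
            if custom_cls is not None:
                child.__class__ = custom_cls

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
    wrap_custom_modules(model)
    assert isinstance(model[1], CustomLayerNormDemo)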
@@ -45,12 +45,13 @@ class CogView4DiffusersModel(GenericDiffusersLoader):
                 model_path,
                 torch_dtype=dtype,
                 variant=variant,
+                local_files_only=True,
             )
         except OSError as e:
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = load_class.from_pretrained(model_path, torch_dtype=dtype)
+                result = load_class.from_pretrained(model_path, torch_dtype=dtype, local_files_only=True)
             else:
                 raise e
 
@@ -122,9 +122,9 @@ class CLIPDiffusersLoader(ModelLoader):
 
         match submodel_type:
             case SubModelType.Tokenizer:
-                return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer")
+                return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer", local_files_only=True)
             case SubModelType.TextEncoder:
-                return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder")
+                return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder", local_files_only=True)
 
         raise ValueError(
             f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -148,10 +148,12 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
         )
         match submodel_type:
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
-                return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
+                return T5TokenizerFast.from_pretrained(
+                    Path(config.path) / "tokenizer_2", max_length=512, local_files_only=True
+                )
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 te2_model_path = Path(config.path) / "text_encoder_2"
-                model_config = AutoConfig.from_pretrained(te2_model_path)
+                model_config = AutoConfig.from_pretrained(te2_model_path, local_files_only=True)
                 with accelerate.init_empty_weights():
                     model = AutoModelForTextEncoding.from_config(model_config)
                     model = quantize_model_llm_int8(model, modules_to_not_convert=set())
@@ -192,10 +194,15 @@ class T5EncoderCheckpointModel(ModelLoader):
 
         match submodel_type:
             case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
-                return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
+                return T5TokenizerFast.from_pretrained(
+                    Path(config.path) / "tokenizer_2", max_length=512, local_files_only=True
+                )
             case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
                 return T5EncoderModel.from_pretrained(
-                    Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
+                    Path(config.path) / "text_encoder_2",
+                    torch_dtype="auto",
+                    low_cpu_mem_usage=True,
+                    local_files_only=True,
                 )
 
         raise ValueError(
@@ -37,12 +37,14 @@ class GenericDiffusersLoader(ModelLoader):
         repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
         variant = repo_variant.value if repo_variant else None
         try:
-            result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant)
+            result: AnyModel = model_class.from_pretrained(
+                model_path, torch_dtype=self._torch_dtype, variant=variant, local_files_only=True
+            )
         except OSError as e:
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
+                result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, local_files_only=True)
             else:
                 raise e
         return result
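This hunk and the loader hunks above all thread local_files_only=True into Hugging Face from_pretrained() calls, which restricts resolution to files already on disk: a locally installed model can then never trigger a surprise Hub download at load time. For example (the path is hypothetical):

    from transformers import CLIPTokenizer

    # Raises OSError if anything is missing locally instead of contacting the Hub.
    tokenizer = CLIPTokenizer.from_pretrained(
        "/path/to/model/tokenizer",  # hypothetical local path
        local_files_only=True,
    )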