invokeai-6.10.0rc2-py3-none-any.whl → invokeai-6.11.0rc1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +50 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/z_image_denoise.py +8 -3
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +42 -0
- invokeai/backend/flux/denoise.py +239 -204
- invokeai/backend/flux/dype/__init__.py +18 -0
- invokeai/backend/flux/dype/base.py +226 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +141 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +261 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/load_default.py +0 -1
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/flux.py +1007 -2
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +0 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +121 -28
- invokeai/backend/model_manager/starter_models.py +128 -0
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/frontend/web/dist/assets/App-ClpIJstk.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-BP0RxJ4G.js → browser-ponyfill-Cw07u5G1.js} +1 -1
- invokeai/frontend/web/dist/assets/{index-B44qKjrs.js → index-DSKM8iGj.js} +69 -69
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en.json +58 -5
- invokeai/frontend/web/dist/locales/it.json +2 -1
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/METADATA +7 -1
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/RECORD +66 -49
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +0 -161
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/top_level.txt +0 -0
--- a/invokeai/backend/model_manager/configs/vae.py
+++ b/invokeai/backend/model_manager/configs/vae.py
@@ -33,6 +33,25 @@ REGEX_TO_BASE: dict[str, BaseModelType] = {
 }
 
 
+def _is_flux2_vae(state_dict: dict[str | int, Any]) -> bool:
+    """Check if state dict is a FLUX.2 VAE (AutoencoderKLFlux2).
+
+    FLUX.2 VAE can be identified by:
+    1. Batch Normalization layers (bn.running_mean, bn.running_var) - unique to FLUX.2
+    2. 32-dimensional latent space (decoder.conv_in has 32 input channels)
+
+    FLUX.1 VAE has 16-dimensional latent space and no BatchNorm layers.
+    """
+    # Check for BN layer which is unique to FLUX.2 VAE
+    has_bn = "bn.running_mean" in state_dict or "bn.running_var" in state_dict
+
+    # Check for 32-channel latent space (FLUX.2 has 32, FLUX.1 has 16)
+    decoder_conv_in_key = "decoder.conv_in.weight"
+    has_32_latent_channels = decoder_conv_in_key in state_dict and state_dict[decoder_conv_in_key].shape[1] == 32
+
+    return has_bn or has_32_latent_channels
+
+
 class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
     """Model config for standalone VAE models."""
 
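The new `_is_flux2_vae` helper keys off tensor shapes rather than filenames: a conv weight has shape [out_channels, in_channels, k, k], so `decoder.conv_in.weight.shape[1]` is the latent dimensionality of the VAE. A minimal sketch of the heuristic against toy state dicts (the shapes other than the channel counts are made up):

import torch

# Toy state dicts standing in for real checkpoints. Only the latent channel
# count (shape[1]) and the BatchNorm buffer names carry the signal used above;
# everything else is a hypothetical placeholder.
flux1_sd = {"decoder.conv_in.weight": torch.zeros(512, 16, 3, 3)}
flux2_sd = {
    "decoder.conv_in.weight": torch.zeros(512, 32, 3, 3),
    "bn.running_mean": torch.zeros(32),  # BatchNorm buffers only appear in FLUX.2
}

key = "decoder.conv_in.weight"
for name, sd in [("flux1", flux1_sd), ("flux2", flux2_sd)]:
    has_bn = "bn.running_mean" in sd or "bn.running_var" in sd
    has_32 = key in sd and sd[key].shape[1] == 32
    print(name, "-> FLUX.2?", has_bn or has_32)
# flux1 -> FLUX.2? False
# flux2 -> FLUX.2? True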
@@ -61,8 +80,9 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
 
     @classmethod
     def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
+        state_dict = mod.load_state_dict()
         if not state_dict_has_any_keys_starting_with(
-            mod.load_state_dict(),
+            state_dict,
             {
                 "encoder.conv_in",
                 "decoder.conv_in",
@@ -70,9 +90,30 @@ class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base):
         ):
             raise NotAMatchError("model does not match Checkpoint VAE heuristics")
 
+        # Exclude FLUX.2 VAEs - they have their own config class
+        if _is_flux2_vae(state_dict):
+            raise NotAMatchError("model is a FLUX.2 VAE, not a standard VAE")
+
     @classmethod
     def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
-        #
+        # First, try to identify by latent space dimensions (most reliable)
+        state_dict = mod.load_state_dict()
+        decoder_conv_in_key = "decoder.conv_in.weight"
+        if decoder_conv_in_key in state_dict:
+            latent_channels = state_dict[decoder_conv_in_key].shape[1]
+            if latent_channels == 16:
+                # Flux1 VAE has 16-dimensional latent space
+                return BaseModelType.Flux
+            elif latent_channels == 4:
+                # SD/SDXL VAE has 4-dimensional latent space
+                # Try to distinguish SD1/SD2/SDXL by name, fallback to SD1
+                for regexp, base in REGEX_TO_BASE.items():
+                    if re.search(regexp, mod.path.name, re.IGNORECASE):
+                        return base
+                # Default to SD1 if we can't determine from name
+                return BaseModelType.StableDiffusion1
+
+        # Fallback: guess based on name
         for regexp, base in REGEX_TO_BASE.items():
             if re.search(regexp, mod.path.name, re.IGNORECASE):
                 return base
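The reworked `_get_base_or_raise` makes the filename regex a secondary signal: the latent channel count alone separates FLUX.1 (16) from the SD family (4), and the name only disambiguates within the 4-channel family. A sketch of the same dispatch with a hypothetical stand-in mapping (the real REGEX_TO_BASE patterns sit above this hunk and are not part of the diff):

import re

# Hypothetical stand-in for REGEX_TO_BASE; illustrative patterns only.
REGEX_TO_BASE = {
    r"xl": "sdxl",
    r"sd[-_.]?2": "sd-2",
}

def guess_base(latent_channels: int, filename: str) -> str:
    if latent_channels == 16:
        return "flux"          # FLUX.1 is the only 16-channel case handled here
    if latent_channels == 4:   # SD1/SD2/SDXL all use 4 latent channels
        for pattern, base in REGEX_TO_BASE.items():
            if re.search(pattern, filename, re.IGNORECASE):
                return base
        return "sd-1"          # default when the name is uninformative
    raise ValueError(f"unrecognized latent channel count: {latent_channels}")

print(guess_base(4, "my_vae_XL_fix.safetensors"))  # sdxl
print(guess_base(16, "anything.safetensors"))      # flux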
@@ -96,6 +137,44 @@ class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base):
     base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux)
 
 
+class VAE_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Config_Base):
+    """Model config for FLUX.2 VAE checkpoint models (AutoencoderKLFlux2)."""
+
+    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
+    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_file(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        cls._validate_looks_like_vae(mod)
+
+        cls._validate_is_flux2_vae(mod)
+
+        return cls(**override_fields)
+
+    @classmethod
+    def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None:
+        if not state_dict_has_any_keys_starting_with(
+            mod.load_state_dict(),
+            {
+                "encoder.conv_in",
+                "decoder.conv_in",
+            },
+        ):
+            raise NotAMatchError("model does not match Checkpoint VAE heuristics")
+
+    @classmethod
+    def _validate_is_flux2_vae(cls, mod: ModelOnDisk) -> None:
+        """Validate that this is a FLUX.2 VAE, not FLUX.1."""
+        state_dict = mod.load_state_dict()
+        if not _is_flux2_vae(state_dict):
+            raise NotAMatchError("state dict does not look like a FLUX.2 VAE")
+
+
 class VAE_Diffusers_Config_Base(Diffusers_Config_Base):
     """Model config for standalone VAE models (diffusers version)."""
 
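Config classes in this module act as probes: `from_model_on_disk` either returns a populated config or raises `NotAMatchError`, and the factory (configs/factory.py, also touched in this release) tries candidate classes in turn. The loop below is an illustrative reduction of that dispatch pattern, not the actual factory code:

# Hypothetical sketch of probe-and-dispatch; class names mirror the diff,
# the candidate ordering and loop are illustrative assumptions.
candidates = [VAE_Checkpoint_Flux2_Config, VAE_Checkpoint_FLUX_Config]

def identify(mod, overrides):
    for cls in candidates:
        try:
            return cls.from_model_on_disk(mod, overrides)
        except NotAMatchError:
            continue
    raise NotAMatchError("no config class matched this model")

Probing the more specific FLUX.2 class first (and having the generic checkpoint-VAE probe explicitly reject FLUX.2 state dicts, as in the hunk above) keeps the two classes from shadowing each other regardless of ordering.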
@@ -161,3 +240,26 @@ class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base):
 
 class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base):
     base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL)
+
+
+class VAE_Diffusers_Flux2_Config(Diffusers_Config_Base, Config_Base):
+    """Model config for FLUX.2 VAE models in diffusers format (AutoencoderKLFlux2)."""
+
+    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
+    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
+    base: Literal[BaseModelType.Flux2] = Field(default=BaseModelType.Flux2)
+
+    @classmethod
+    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
+        raise_if_not_dir(mod)
+
+        raise_for_override_fields(cls, override_fields)
+
+        raise_for_class_name(
+            common_config_paths(mod.path),
+            {
+                "AutoencoderKLFlux2",
+            },
+        )
+
+        return cls(**override_fields)
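The diffusers-format probe does not need shape heuristics at all: diffusers model folders carry a config.json whose "_class_name" field names the Python class that should load the weights, which is what `raise_for_class_name` matches against. A hedged, standalone sketch of that check (the helper name below is ours, not InvokeAI's):

import json
from pathlib import Path

def looks_like_flux2_vae_dir(model_dir: Path) -> bool:
    """Illustrative check: does this diffusers folder declare AutoencoderKLFlux2?"""
    config_path = model_dir / "config.json"
    if not config_path.exists():
        return False
    config = json.loads(config_path.read_text())
    # diffusers writes the loading class into "_class_name" on export.
    return config.get("_class_name") == "AutoencoderKLFlux2"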
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -75,7 +75,6 @@ class ModelLoader(ModelLoaderBase):
 
         config.path = str(self._get_model_path(config))
         self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
-        self._logger.info(f"Loading model '{stats_name}' into RAM cache..., config={config}")
         loaded_model = self._load_model(config, submodel_type)
 
         self._ram_cache.put(
--- a/invokeai/backend/model_manager/load/model_cache/model_cache.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py
@@ -55,6 +55,21 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
     return wrapper
 
 
+def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that records activity after a method completes successfully.
+
+    Note: This decorator should be applied to methods that already hold self._lock.
+    """
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        result = method(self, *args, **kwargs)
+        self._record_activity()
+        return result
+
+    return wrapper
+
+
 @dataclass
 class CacheEntrySnapshot:
     cache_key: str
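`record_activity` is written to sit under `@synchronized` so that the activity timestamp is updated while the cache lock is still held (per its docstring note). The ordering matters because decorators apply bottom-up: `@synchronized` must be outermost so it wraps the activity-recording wrapper. A minimal, self-contained sketch of the stacking:

import threading
from functools import wraps

# Minimal stand-ins for the two decorators in model_cache.py.
def synchronized(method):
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        with self._lock:                 # lock is taken first (outermost)
            return method(self, *args, **kwargs)
    return wrapper

def record_activity(method):
    @wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self._record_activity()          # runs inside the lock held by @synchronized
        return result
    return wrapper

class Cache:
    def __init__(self):
        self._lock = threading.RLock()
        self.touched = 0

    def _record_activity(self):
        self.touched += 1

    @synchronized       # outermost: acquires the lock
    @record_activity    # innermost: records activity before the lock is released
    def get(self, key):
        return key

cache = Cache()
cache.get("demo")
assert cache.touched == 1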
@@ -132,6 +147,7 @@ class ModelCache:
         storage_device: torch.device | str = "cpu",
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
+        keep_alive_minutes: float = 0,
     ):
         """Initialize the model RAM cache.
 
@@ -151,6 +167,7 @@ class ModelCache:
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
         :param logger: InvokeAILogger to use (otherwise creates one)
+        :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
         """
         self._enable_partial_loading = enable_partial_loading
         self._keep_ram_copy_of_weights = keep_ram_copy_of_weights
@@ -182,6 +199,12 @@ class ModelCache:
         self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
         self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
 
+        # Keep-alive timeout support
+        self._keep_alive_minutes = keep_alive_minutes
+        self._last_activity_time: Optional[float] = None
+        self._timeout_timer: Optional[threading.Timer] = None
+        self._shutdown_event = threading.Event()
+
     def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
         self._on_cache_hit_callbacks.add(cb)
 
@@ -190,7 +213,7 @@ class ModelCache:
 
         return unsubscribe
 
-    def on_cache_miss(self, cb:
+    def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
         self._on_cache_miss_callbacks.add(cb)
 
         def unsubscribe() -> None:
@@ -217,8 +240,82 @@ class ModelCache:
     def stats(self, stats: CacheStats) -> None:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats
+        # Populate the cache size in the stats object when it's set
+        if self._stats is not None:
+            self._stats.cache_size = self._ram_cache_size_bytes
+
+    def _record_activity(self) -> None:
+        """Record model activity and reset the timeout timer if configured.
+
+        Note: This method should only be called when self._lock is already held.
+        """
+        if self._keep_alive_minutes <= 0:
+            return
+
+        self._last_activity_time = time.time()
+
+        # Cancel any existing timer
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+
+        # Start a new timer
+        timeout_seconds = self._keep_alive_minutes * 60
+        self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
+        # Set as daemon so it doesn't prevent application shutdown
+        self._timeout_timer.daemon = True
+        self._timeout_timer.start()
+        self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
 
     @synchronized
+    @record_activity
+    def _on_timeout(self) -> None:
+        """Called when the keep-alive timeout expires. Clears the model cache."""
+        if self._shutdown_event.is_set():
+            return
+
+        # Double-check if there has been activity since the timer was set
+        # This handles the race condition where activity occurred just before the timer fired
+        if self._last_activity_time is not None and self._keep_alive_minutes > 0:
+            elapsed_minutes = (time.time() - self._last_activity_time) / 60
+            if elapsed_minutes < self._keep_alive_minutes:
+                # Activity occurred, don't clear cache
+                self._logger.debug(
+                    f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
+                    f"Skipping cache clear."
+                )
+                return
+
+        # Check if there are any unlocked models that can be cleared
+        unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]
+
+        if len(unlocked_models) > 0:
+            self._logger.info(
+                f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
+                f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
+            )
+            # Clear the cache by requesting a very large amount of space.
+            # This is the same logic used by the "Clear Model Cache" button.
+            # Using 1000 GB ensures all unlocked models are removed.
+            self._make_room_internal(1000 * GB)
+        elif len(self._cached_models) > 0:
+            # All models are locked, don't log at info level
+            self._logger.debug(
+                f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
+                f"Skipping cache clear."
+            )
+        else:
+            self._logger.debug("Model cache timeout fired but cache is already empty.")
+
+    @synchronized
+    def shutdown(self) -> None:
+        """Shutdown the model cache, cancelling any pending timers."""
+        self._shutdown_event.set()
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+            self._timeout_timer = None
+
+    @synchronized
+    @record_activity
     def put(self, key: str, model: AnyModel) -> None:
         """Add a model to the cache."""
         if key in self._cached_models:
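The keep-alive machinery is a classic reset-on-activity watchdog: every cache touch cancels the pending threading.Timer and arms a fresh one, and the timeout handler re-checks the last-activity timestamp to close the race where a touch lands just before the timer fires. A condensed, standalone sketch of the same pattern (not the InvokeAI class itself):

import threading
import time

class KeepAliveWatchdog:
    """Condensed illustration of the reset-on-activity timer used above."""

    def __init__(self, timeout_s: float, on_expire):
        self._timeout_s = timeout_s
        self._on_expire = on_expire
        self._last_activity = 0.0
        self._timer: threading.Timer | None = None

    def touch(self) -> None:
        self._last_activity = time.time()
        if self._timer is not None:
            self._timer.cancel()        # re-arm: cancel the pending timer
        self._timer = threading.Timer(self._timeout_s, self._fire)
        self._timer.daemon = True       # don't block interpreter shutdown
        self._timer.start()

    def _fire(self) -> None:
        # Guard against the race where touch() ran just before the timer fired.
        if time.time() - self._last_activity < self._timeout_s:
            return
        self._on_expire()

watchdog = KeepAliveWatchdog(timeout_s=0.1, on_expire=lambda: print("cache cleared"))
watchdog.touch()
time.sleep(0.2)  # no further activity -> on_expire runs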
@@ -228,7 +325,7 @@ class ModelCache:
             return
 
         size = calc_model_size_by_data(self._logger, model)
-        self.make_room(size)
+        self._make_room_internal(size)
 
         # Inject custom modules into the model.
         if isinstance(model, torch.nn.Module):
@@ -272,6 +369,7 @@ class ModelCache:
         return overview
 
     @synchronized
+    @record_activity
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.
 
@@ -309,9 +407,11 @@ class ModelCache:
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         for cb in self._on_cache_hit_callbacks:
             cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
+
         return cache_entry
 
     @synchronized
+    @record_activity
     def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:
@@ -348,6 +448,7 @@ class ModelCache:
         self._log_cache_state()
 
     @synchronized
+    @record_activity
     def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
         if cache_entry.key not in self._cached_models:
@@ -691,6 +792,10 @@ class ModelCache:
         external references to the model, there's nothing that the cache can do about it, and those models will not be
         garbage-collected.
         """
+        self._make_room_internal(bytes_needed)
+
+    def _make_room_internal(self, bytes_needed: int) -> None:
+        """Internal implementation of make_room(). Assumes the lock is already held."""
         self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
         self._log_cache_state(title="Before dropping models:")
 
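Splitting make_room into a locked public wrapper and an unlocked `_make_room_internal` lets already-synchronized callers (`put`, `_on_timeout`) free space without re-acquiring the lock; with a non-reentrant lock this split is what prevents self-deadlock. A minimal sketch of the pattern (our own reduction, assuming a decorator-free illustration of the same locking discipline):

import threading

class Cache:
    def __init__(self):
        self._lock = threading.Lock()  # non-reentrant: re-acquiring deadlocks

    def make_room(self, bytes_needed: int) -> None:
        # Public entry point: takes the lock, then delegates.
        with self._lock:
            self._make_room_internal(bytes_needed)

    def put(self, key, model, size: int) -> None:
        with self._lock:
            # Calling self.make_room(size) here would deadlock on a plain
            # Lock; the _internal variant assumes the lock is already held.
            self._make_room_internal(size)

    def _make_room_internal(self, bytes_needed: int) -> None:
        """Assumes self._lock is already held by the caller."""
        ...  # evict unlocked models until bytes_needed fits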