InvokeAI invokeai-6.9.0rc3-py3-none-any.whl → invokeai-6.10.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. invokeai/app/api/dependencies.py +2 -0
  2. invokeai/app/api/routers/model_manager.py +91 -2
  3. invokeai/app/api/routers/workflows.py +9 -0
  4. invokeai/app/invocations/fields.py +19 -0
  5. invokeai/app/invocations/image_to_latents.py +23 -5
  6. invokeai/app/invocations/latents_to_image.py +2 -25
  7. invokeai/app/invocations/metadata.py +9 -1
  8. invokeai/app/invocations/model.py +8 -0
  9. invokeai/app/invocations/primitives.py +12 -0
  10. invokeai/app/invocations/prompt_template.py +57 -0
  11. invokeai/app/invocations/z_image_control.py +112 -0
  12. invokeai/app/invocations/z_image_denoise.py +610 -0
  13. invokeai/app/invocations/z_image_image_to_latents.py +102 -0
  14. invokeai/app/invocations/z_image_latents_to_image.py +103 -0
  15. invokeai/app/invocations/z_image_lora_loader.py +153 -0
  16. invokeai/app/invocations/z_image_model_loader.py +135 -0
  17. invokeai/app/invocations/z_image_text_encoder.py +197 -0
  18. invokeai/app/services/model_install/model_install_common.py +14 -1
  19. invokeai/app/services/model_install/model_install_default.py +119 -19
  20. invokeai/app/services/model_records/model_records_base.py +12 -0
  21. invokeai/app/services/model_records/model_records_sql.py +17 -0
  22. invokeai/app/services/shared/graph.py +132 -77
  23. invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
  24. invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
  25. invokeai/app/util/step_callback.py +3 -0
  26. invokeai/backend/model_manager/configs/controlnet.py +47 -1
  27. invokeai/backend/model_manager/configs/factory.py +26 -1
  28. invokeai/backend/model_manager/configs/lora.py +43 -1
  29. invokeai/backend/model_manager/configs/main.py +113 -0
  30. invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
  31. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
  32. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
  33. invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
  34. invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
  35. invokeai/backend/model_manager/load/model_loaders/z_image.py +935 -0
  36. invokeai/backend/model_manager/load/model_util.py +6 -1
  37. invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
  38. invokeai/backend/model_manager/model_on_disk.py +3 -0
  39. invokeai/backend/model_manager/starter_models.py +70 -0
  40. invokeai/backend/model_manager/taxonomy.py +5 -0
  41. invokeai/backend/model_manager/util/select_hf_files.py +23 -8
  42. invokeai/backend/patches/layer_patcher.py +34 -16
  43. invokeai/backend/patches/layers/lora_layer_base.py +2 -1
  44. invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
  45. invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
  46. invokeai/backend/patches/lora_conversions/formats.py +5 -0
  47. invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
  48. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +155 -0
  49. invokeai/backend/quantization/gguf/ggml_tensor.py +27 -4
  50. invokeai/backend/quantization/gguf/loaders.py +47 -12
  51. invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
  52. invokeai/backend/util/devices.py +25 -0
  53. invokeai/backend/util/hotfixes.py +2 -2
  54. invokeai/backend/z_image/__init__.py +16 -0
  55. invokeai/backend/z_image/extensions/__init__.py +1 -0
  56. invokeai/backend/z_image/extensions/regional_prompting_extension.py +207 -0
  57. invokeai/backend/z_image/text_conditioning.py +74 -0
  58. invokeai/backend/z_image/z_image_control_adapter.py +238 -0
  59. invokeai/backend/z_image/z_image_control_transformer.py +643 -0
  60. invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
  61. invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
  62. invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
  63. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +161 -0
  64. invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-DHZxq1nk.js} +1 -1
  65. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +530 -0
  66. invokeai/frontend/web/dist/index.html +1 -1
  67. invokeai/frontend/web/dist/locales/de.json +24 -6
  68. invokeai/frontend/web/dist/locales/en.json +70 -1
  69. invokeai/frontend/web/dist/locales/es.json +0 -5
  70. invokeai/frontend/web/dist/locales/fr.json +0 -6
  71. invokeai/frontend/web/dist/locales/it.json +17 -64
  72. invokeai/frontend/web/dist/locales/ja.json +379 -44
  73. invokeai/frontend/web/dist/locales/ru.json +0 -6
  74. invokeai/frontend/web/dist/locales/vi.json +7 -54
  75. invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
  76. invokeai/version/invokeai_version.py +1 -1
  77. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/METADATA +3 -3
  78. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/RECORD +84 -60
  79. invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
  80. invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
  81. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/WHEEL +0 -0
  82. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/entry_points.txt +0 -0
  83. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE +0 -0
  84. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  85. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  86. {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0rc1.dist-info}/top_level.txt +0 -0
invokeai/app/invocations/z_image_text_encoder.py
@@ -0,0 +1,197 @@
+from contextlib import ExitStack
+from typing import Iterator, Optional, Tuple
+
+import torch
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    TensorField,
+    UIComponent,
+    ZImageConditioningField,
+)
+from invokeai.app.invocations.model import Qwen3EncoderField
+from invokeai.app.invocations.primitives import ZImageConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.patches.layer_patcher import LayerPatcher
+from invokeai.backend.patches.lora_conversions.z_image_lora_constants import Z_IMAGE_LORA_QWEN3_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningFieldData,
+    ZImageConditioningInfo,
+)
+from invokeai.backend.util.devices import TorchDevice
+
+# Z-Image max sequence length based on diffusers default
+Z_IMAGE_MAX_SEQ_LEN = 512
+
+
+@invocation(
+    "z_image_text_encoder",
+    title="Prompt - Z-Image",
+    tags=["prompt", "conditioning", "z-image"],
+    category="conditioning",
+    version="1.1.0",
+    classification=Classification.Prototype,
+)
+class ZImageTextEncoderInvocation(BaseInvocation):
+    """Encodes and preps a prompt for a Z-Image image.
+
+    Supports regional prompting by connecting a mask input.
+    """
+
+    prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
+    qwen3_encoder: Qwen3EncoderField = InputField(
+        title="Qwen3 Encoder",
+        description=FieldDescriptions.qwen3_encoder,
+        input=Input.Connection,
+    )
+    mask: Optional[TensorField] = InputField(
+        default=None,
+        description="A mask defining the region that this conditioning prompt applies to.",
+    )
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ZImageConditioningOutput:
+        prompt_embeds = self._encode_prompt(context, max_seq_len=Z_IMAGE_MAX_SEQ_LEN)
+        conditioning_data = ConditioningFieldData(conditionings=[ZImageConditioningInfo(prompt_embeds=prompt_embeds)])
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return ZImageConditioningOutput(
+            conditioning=ZImageConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
+
+    def _encode_prompt(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
+        """Encode prompt using Qwen3 text encoder.
+
+        Based on the ZImagePipeline._encode_prompt method from diffusers.
+        """
+        prompt = self.prompt
+        device = TorchDevice.choose_torch_device()
+
+        text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
+        tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
+
+        with ExitStack() as exit_stack:
+            (_, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
+            (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
+
+            # Apply LoRA models to the text encoder
+            lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
+            exit_stack.enter_context(
+                LayerPatcher.apply_smart_model_patches(
+                    model=text_encoder,
+                    patches=self._lora_iterator(context),
+                    prefix=Z_IMAGE_LORA_QWEN3_PREFIX,
+                    dtype=lora_dtype,
+                )
+            )
+
+            context.util.signal_progress("Running Qwen3 text encoder")
+            if not isinstance(text_encoder, PreTrainedModel):
+                raise TypeError(
+                    f"Expected PreTrainedModel for text encoder, got {type(text_encoder).__name__}. "
+                    "The Qwen3 encoder model may be corrupted or incompatible."
+                )
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise TypeError(
+                    f"Expected PreTrainedTokenizerBase for tokenizer, got {type(tokenizer).__name__}. "
+                    "The Qwen3 tokenizer may be corrupted or incompatible."
+                )
+
+            # Apply chat template similar to diffusers ZImagePipeline
+            # The chat template formats the prompt for the Qwen3 model
+            try:
+                prompt_formatted = tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    enable_thinking=True,
+                )
+            except (AttributeError, TypeError) as e:
+                # Fallback if tokenizer doesn't support apply_chat_template or enable_thinking
+                context.logger.warning(f"Chat template failed ({e}), using raw prompt.")
+                prompt_formatted = prompt
+
+            # Tokenize the formatted prompt
+            text_inputs = tokenizer(
+                prompt_formatted,
+                padding="max_length",
+                max_length=max_seq_len,
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            attention_mask = text_inputs.attention_mask
+            if not isinstance(text_input_ids, torch.Tensor):
+                raise TypeError(
+                    f"Expected torch.Tensor for input_ids, got {type(text_input_ids).__name__}. "
+                    "Tokenizer returned unexpected type."
+                )
+            if not isinstance(attention_mask, torch.Tensor):
+                raise TypeError(
+                    f"Expected torch.Tensor for attention_mask, got {type(attention_mask).__name__}. "
+                    "Tokenizer returned unexpected type."
+                )
+
+            # Check for truncation
+            untruncated_ids = tokenizer(prompt_formatted, padding="longest", return_tensors="pt").input_ids
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
+                context.logger.warning(
+                    f"The following part of your input was truncated because `max_sequence_length` is set to "
+                    f"{max_seq_len} tokens: {removed_text}"
+                )
+
+            # Get hidden states from the text encoder
+            # Use the second-to-last hidden state like diffusers does
+            prompt_mask = attention_mask.to(device).bool()
+            outputs = text_encoder(
+                text_input_ids.to(device),
+                attention_mask=prompt_mask,
+                output_hidden_states=True,
+            )
+
+            # Validate hidden_states output
+            if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
+                raise RuntimeError(
+                    "Text encoder did not return hidden_states. "
+                    "Ensure output_hidden_states=True is supported by this model."
+                )
+            if len(outputs.hidden_states) < 2:
+                raise RuntimeError(
+                    f"Expected at least 2 hidden states from text encoder, got {len(outputs.hidden_states)}. "
+                    "This may indicate an incompatible model or configuration."
+                )
+            prompt_embeds = outputs.hidden_states[-2]
+
+            # Z-Image expects a 2D tensor [seq_len, hidden_dim] with only valid tokens
+            # Based on diffusers ZImagePipeline implementation:
+            #     embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
+            # Since batch_size=1, we take the first item and filter by mask
+            prompt_embeds = prompt_embeds[0][prompt_mask[0]]
+
+            if not isinstance(prompt_embeds, torch.Tensor):
+                raise TypeError(
+                    f"Expected torch.Tensor for prompt embeddings, got {type(prompt_embeds).__name__}. "
+                    "Text encoder returned unexpected type."
+                )
+            return prompt_embeds
+
+    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
+        """Iterate over LoRA models to apply to the Qwen3 text encoder."""
+        for lora in self.qwen3_encoder.loras:
+            lora_info = context.models.load(lora.lora)
+            if not isinstance(lora_info.model, ModelPatchRaw):
+                raise TypeError(
+                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}. "
+                    "The LoRA model may be corrupted or incompatible."
+                )
+            yield (lora_info.model, lora.weight)
+            del lora_info
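
As a quick illustration of the final masking step above (a standalone sketch, not part of the diff; shapes are toy values): with padding to max_length, indexing by the boolean attention mask drops the pad positions and yields the 2D [num_valid_tokens, hidden_dim] tensor that Z-Image expects.

    import torch

    # Toy shapes: batch of 1, seq_len of 6, hidden_dim of 4; the last 2 positions are padding.
    prompt_embeds = torch.randn(1, 6, 4)
    prompt_mask = torch.tensor([[1, 1, 1, 1, 0, 0]]).bool()

    # Mirrors `prompt_embeds[0][prompt_mask[0]]` from the hunk above.
    filtered = prompt_embeds[0][prompt_mask[0]]
    print(filtered.shape)  # torch.Size([4, 4]) -> [num_valid_tokens, hidden_dim]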
invokeai/app/services/model_install/model_install_common.py
@@ -85,9 +85,12 @@ class LocalModelSource(StringLikeSource):
 
 class HFModelSource(StringLikeSource):
     """
-    A HuggingFace repo_id with optional variant, sub-folder and access token.
+    A HuggingFace repo_id with optional variant, sub-folder(s) and access token.
     Note that the variant option, if not provided to the constructor, will default to fp16, which is
     what people (almost) always want.
+
+    The subfolder can be a single path or multiple paths joined by '+' (e.g., "text_encoder+tokenizer").
+    When multiple subfolders are specified, all of them will be downloaded and combined into the model directory.
     """
 
     repo_id: str
@@ -103,6 +106,16 @@ class HFModelSource(StringLikeSource):
            raise ValueError(f"{v}: invalid repo_id format")
        return v
 
+    @property
+    def subfolders(self) -> list[Path]:
+        """Return list of subfolders (supports '+' separated multiple subfolders)."""
+        if self.subfolder is None:
+            return []
+        subfolder_str = self.subfolder.as_posix()
+        if "+" in subfolder_str:
+            return [Path(s.strip()) for s in subfolder_str.split("+")]
+        return [self.subfolder]
+
    def __str__(self) -> str:
        """Return string version of repoid when string rep needed."""
        base: str = self.repo_id
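
For reference, the new subfolders property amounts to this standalone parsing helper (a minimal sketch mirroring the hunk above; split_subfolders is a hypothetical name used only for illustration):

    from pathlib import Path
    from typing import Optional

    def split_subfolders(subfolder: Optional[Path]) -> list[Path]:
        # '+' joins multiple subfolders into one source string, as in HFModelSource.subfolders.
        if subfolder is None:
            return []
        s = subfolder.as_posix()
        return [Path(p.strip()) for p in s.split("+")] if "+" in s else [subfolder]

    assert split_subfolders(None) == []
    assert split_subfolders(Path("vae")) == [Path("vae")]
    assert split_subfolders(Path("text_encoder+tokenizer")) == [Path("text_encoder"), Path("tokenizer")]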
invokeai/app/services/model_install/model_install_default.py
@@ -1,8 +1,10 @@
 """Model installation class."""
 
+import gc
 import locale
 import os
 import re
+import sys
 import threading
 import time
 from copy import deepcopy
@@ -187,6 +189,22 @@ class ModelInstallService(ModelInstallServiceBase):
         config.source_type = ModelSourceType.Path
         return self._register(model_path, config)
 
+    # TODO: Replace this with a proper fix for the underlying problem of Windows holding
+    # the file open when it needs to be moved.
+    @staticmethod
+    def _move_with_retries(src: Path, dst: Path, attempts: int = 5, delay: float = 0.5) -> None:
+        """Workaround for Windows file-handle issues when moving files."""
+        for tries_left in range(attempts, 0, -1):
+            try:
+                move(src, dst)
+                return
+            except PermissionError:
+                gc.collect()
+                if tries_left == 1:
+                    raise
+                time.sleep(delay)
+                delay *= 2  # Exponential backoff
+
     def install_path(
         self,
         model_path: Union[Path, str],
@@ -205,7 +223,7 @@ class ModelInstallService(ModelInstallServiceBase):
             dest_dir.mkdir(parents=True)
         dest_path = dest_dir / model_path.name if model_path.is_file() else dest_dir
         if model_path.is_file():
-            move(model_path, dest_path)
+            self._move_with_retries(model_path, dest_path)  # Windows workaround TODO: fix root cause
         elif model_path.is_dir():
             # Move the contents of the directory, not the directory itself
             for item in model_path.iterdir():
417
435
  model_path.mkdir(parents=True, exist_ok=True)
418
436
  model_source = self._guess_source(str(source))
419
437
  remote_files, _ = self._remote_files_from_source(model_source)
438
+ # Handle multiple subfolders for HFModelSource
439
+ subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else []
420
440
  job = self._multifile_download(
421
441
  dest=model_path,
422
442
  remote_files=remote_files,
423
- subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
443
+ subfolder=model_source.subfolder
444
+ if isinstance(model_source, HFModelSource) and len(subfolders) <= 1
445
+ else None,
446
+ subfolders=subfolders if len(subfolders) > 1 else None,
424
447
  )
425
448
  files_string = "file" if len(remote_files) == 1 else "files"
426
449
  self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -438,10 +461,13 @@ class ModelInstallService(ModelInstallServiceBase):
         if isinstance(source, HFModelSource):
             metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
             assert isinstance(metadata, ModelMetadataWithFiles)
+            # Use subfolders property which handles '+' separated multiple subfolders
+            subfolders = source.subfolders
             return (
                 metadata.download_urls(
                     variant=source.variant or self._guess_variant(),
-                    subfolder=source.subfolder,
+                    subfolder=source.subfolder if len(subfolders) <= 1 else None,
+                    subfolders=subfolders if len(subfolders) > 1 else None,
                     session=self._session,
                 ),
                 metadata,
492
518
  self._install_thread.start()
493
519
  self._running = True
494
520
 
521
+ @staticmethod
522
+ def _safe_rmtree(path: Path, logger: Any) -> None:
523
+ """Remove a directory tree with retry logic for Windows file locking issues.
524
+
525
+ On Windows, memory-mapped files may not be immediately released even after
526
+ the file handle is closed. This function retries the removal with garbage
527
+ collection to help release any lingering references.
528
+ """
529
+ max_retries = 3
530
+ retry_delay = 0.5 # seconds
531
+
532
+ for attempt in range(max_retries):
533
+ try:
534
+ # Force garbage collection to release any lingering file references
535
+ gc.collect()
536
+ rmtree(path)
537
+ return
538
+ except PermissionError as e:
539
+ if attempt < max_retries - 1 and sys.platform == "win32":
540
+ logger.warning(
541
+ f"Failed to remove {path} (attempt {attempt + 1}/{max_retries}): {e}. "
542
+ f"Retrying in {retry_delay}s..."
543
+ )
544
+ time.sleep(retry_delay)
545
+ retry_delay *= 2 # Exponential backoff
546
+ else:
547
+ logger.error(f"Failed to remove temporary directory {path}: {e}")
548
+ # On final failure, don't raise - the temp dir will be cleaned up on next startup
549
+ return
550
+ except Exception as e:
551
+ logger.error(f"Unexpected error removing {path}: {e}")
552
+ return
553
+
495
554
  def _install_next_item(self) -> None:
496
555
  self._logger.debug(f"Installer thread {threading.get_ident()} starting")
497
556
  while True:
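
Both _move_with_retries and _safe_rmtree follow the same retry-with-exponential-backoff pattern. As a rough illustration (a standalone sketch, not part of the diff), the default _move_with_retries schedule of attempts=5 and delay=0.5 works out to:

    # Sleeps between attempts double each time; the final attempt raises instead of sleeping.
    delay, sleeps = 0.5, []
    for _ in range(5 - 1):
        sleeps.append(delay)
        delay *= 2
    print(sleeps)  # [0.5, 1.0, 2.0, 4.0] -> up to 7.5 s of waiting before re-raising PermissionError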
@@ -521,7 +580,7 @@
         finally:
             # if this is an install of a remote file, then clean up the temporary directory
             if job._install_tmpdir is not None:
-                rmtree(job._install_tmpdir)
+                self._safe_rmtree(job._install_tmpdir, self._logger)
             self._install_completed_event.set()
             self._install_queue.task_done()
         self._logger.info(f"Installer thread {threading.get_ident()} exiting")
@@ -566,7 +625,7 @@
         path = self._app_config.models_path
         for tmpdir in path.glob(f"{TMPDIR_PREFIX}*"):
             self._logger.info(f"Removing dangling temporary directory {tmpdir}")
-            rmtree(tmpdir)
+            self._safe_rmtree(tmpdir, self._logger)
 
     def _scan_for_missing_models(self) -> list[AnyModelConfig]:
         """Scan the models directory for missing models and return a list of them."""
@@ -741,10 +800,13 @@
         install_job._install_tmpdir = destdir
         install_job.total_bytes = sum((x.size or 0) for x in remote_files)
 
+        # Handle multiple subfolders for HFModelSource
+        subfolders = source.subfolders if isinstance(source, HFModelSource) else []
         multifile_job = self._multifile_download(
             remote_files=remote_files,
             dest=destdir,
-            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
+            subfolder=source.subfolder if isinstance(source, HFModelSource) and len(subfolders) <= 1 else None,
+            subfolders=subfolders if len(subfolders) > 1 else None,
             access_token=source.access_token,
             submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
         )
@@ -771,6 +833,7 @@
         remote_files: List[RemoteModelFile],
         dest: Path,
         subfolder: Optional[Path] = None,
+        subfolders: Optional[List[Path]] = None,
         access_token: Optional[str] = None,
         submit_job: bool = True,
     ) -> MultiFileDownloadJob:
@@ -778,24 +841,61 @@
         # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
         # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
         # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
-        if subfolder:
+        #
+        # For multiple subfolders (e.g., text_encoder+tokenizer), we create a combined folder name
+        # (e.g., sdxl-turbo_text_encoder_tokenizer) and keep each subfolder's contents in its own
+        # subdirectory within the model folder.
+
+        if subfolders and len(subfolders) > 1:
+            # Multiple subfolders: create combined name and keep subfolder structure
+            top = Path(remote_files[0].path.parts[0])  # e.g. "Z-Image-Turbo/"
+            subfolder_names = [sf.name.replace("/", "_").replace("\\", "_") for sf in subfolders]
+            combined_name = "_".join(subfolder_names)
+            path_to_add = Path(f"{top}_{combined_name}")
+
+            parts: List[RemoteModelFile] = []
+            for model_file in remote_files:
+                assert model_file.size is not None
+                # Determine which subfolder this file belongs to
+                file_path = model_file.path
+                new_path: Optional[Path] = None
+                for sf in subfolders:
+                    try:
+                        # Try to get relative path from this subfolder
+                        relative = file_path.relative_to(top / sf)
+                        # Keep the subfolder name as a subdirectory
+                        new_path = path_to_add / sf.name / relative
+                        break
+                    except ValueError:
+                        continue
+
+                if new_path is None:
+                    # File doesn't match any subfolder, keep original path structure
+                    new_path = path_to_add / file_path.relative_to(top)
+
+                parts.append(RemoteModelFile(url=model_file.url, path=new_path))
+        elif subfolder:
+            # Single subfolder: flatten into renamed folder
             top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
             path_to_remove = top / subfolder  # sdxl-turbo/vae/
             subfolder_rename = subfolder.name.replace("/", "_").replace("\\", "_")
             path_to_add = Path(f"{top}_{subfolder_rename}")
-        else:
-            path_to_remove = Path(".")
-            path_to_add = Path(".")
-
-        parts: List[RemoteModelFile] = []
-        for model_file in remote_files:
-            assert model_file.size is not None
-            parts.append(
-                RemoteModelFile(
-                    url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
-                    path=path_to_add / model_file.path.relative_to(path_to_remove),
+
+            parts = []
+            for model_file in remote_files:
+                assert model_file.size is not None
+                parts.append(
+                    RemoteModelFile(
+                        url=model_file.url,
+                        path=path_to_add / model_file.path.relative_to(path_to_remove),
+                    )
                 )
-            )
+        else:
+            # No subfolder specified - pass through unchanged
+            parts = []
+            for model_file in remote_files:
+                assert model_file.size is not None
+                parts.append(RemoteModelFile(url=model_file.url, path=model_file.path))
 
         return self._download_queue.multifile_download(
             parts=parts,
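
To make the multi-subfolder path rewriting concrete, here is a hedged sketch (the repo and file names are illustrative, not taken from the diff) of how files from two subfolders land on disk:

    from pathlib import Path

    top = Path("Z-Image-Turbo")
    subfolders = [Path("text_encoder"), Path("tokenizer")]
    dest_root = Path(f"{top}_" + "_".join(sf.name for sf in subfolders))  # Z-Image-Turbo_text_encoder_tokenizer

    for remote in [top / "text_encoder" / "config.json", top / "tokenizer" / "tokenizer.json"]:
        for sf in subfolders:
            try:
                rel = remote.relative_to(top / sf)
                print(remote, "->", dest_root / sf.name / rel)
                break
            except ValueError:
                continue
    # Z-Image-Turbo/text_encoder/config.json -> Z-Image-Turbo_text_encoder_tokenizer/text_encoder/config.json
    # Z-Image-Turbo/tokenizer/tokenizer.json -> Z-Image-Turbo_text_encoder_tokenizer/tokenizer/tokenizer.json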
invokeai/app/services/model_records/model_records_base.py
@@ -138,6 +138,18 @@ class ModelRecordServiceBase(ABC):
         """
         pass
 
+    @abstractmethod
+    def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
+        """
+        Replace the model record entirely, returning the new record.
+
+        This is used when we re-identify a model and have a new config object.
+
+        :param key: Unique key for the model to be updated.
+        :param new_config: The new model config to write.
+        """
+        pass
+
     @abstractmethod
     def get_model(self, key: str) -> AnyModelConfig:
         """
invokeai/app/services/model_records/model_records_sql.py
@@ -179,6 +179,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
 
         return self.get_model(key)
 
+    def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
+        if key != new_config.key:
+            raise ValueError("key does not match new_config.key")
+        with self._db.transaction() as cursor:
+            cursor.execute(
+                """--sql
+                UPDATE models
+                SET
+                    config=?
+                WHERE id=?;
+                """,
+                (new_config.model_dump_json(), key),
+            )
+            if cursor.rowcount == 0:
+                raise UnknownModelException("model not found")
+        return self.get_model(key)
+
     def get_model(self, key: str) -> AnyModelConfig:
         """
         Retrieve the ModelConfigBase instance for the indicated model.
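
A usage sketch for the new replace_model() API (hypothetical variable names; record_store stands for a ModelRecordServiceSQL instance and reidentified_config for an AnyModelConfig produced by re-probing the model on disk):

    # The keys must match, or replace_model raises ValueError;
    # an unknown key raises UnknownModelException.
    reidentified_config.key = existing_config.key
    updated = record_store.replace_model(existing_config.key, reidentified_config)
    assert updated.key == existing_config.key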