InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/extensions/dype_extension.py
ADDED

@@ -0,0 +1,91 @@
+"""DyPE extension for FLUX denoising pipeline."""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from invokeai.backend.flux.dype.base import DyPEConfig
+from invokeai.backend.flux.dype.embed import DyPEEmbedND
+
+if TYPE_CHECKING:
+    from invokeai.backend.flux.model import Flux
+
+
+@dataclass
+class DyPEExtension:
+    """Extension for Dynamic Position Extrapolation in FLUX models.
+
+    This extension manages the patching of the FLUX model's position embedder
+    and updates the step state during denoising.
+
+    Usage:
+    1. Create extension with config and target dimensions
+    2. Call patch_model() to replace pe_embedder with DyPE version
+    3. Call update_step_state() before each denoising step
+    4. Call restore_model() after denoising to restore original embedder
+    """
+
+    config: DyPEConfig
+    target_height: int
+    target_width: int
+
+    def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]:
+        """Patch the model's position embedder with DyPE version.
+
+        Args:
+            model: The FLUX model to patch
+
+        Returns:
+            Tuple of (new DyPE embedder, original embedder for restoration)
+        """
+        original_embedder = model.pe_embedder
+
+        dype_embedder = DyPEEmbedND.from_embednd(
+            embed_nd=original_embedder,
+            dype_config=self.config,
+        )
+
+        # Set initial state
+        dype_embedder.set_step_state(
+            sigma=1.0,
+            height=self.target_height,
+            width=self.target_width,
+        )
+
+        # Replace the embedder
+        model.pe_embedder = dype_embedder
+
+        return dype_embedder, original_embedder
+
+    def update_step_state(
+        self,
+        embedder: DyPEEmbedND,
+        timestep: float,
+        timestep_index: int,
+        total_steps: int,
+    ) -> None:
+        """Update the step state in the DyPE embedder.
+
+        This should be called before each denoising step to update the
+        current noise level for timestep-dependent scaling.
+
+        Args:
+            embedder: The DyPE embedder to update
+            timestep: Current timestep value (sigma/noise level)
+            timestep_index: Current step index (0-based)
+            total_steps: Total number of denoising steps
+        """
+        embedder.set_step_state(
+            sigma=timestep,
+            height=self.target_height,
+            width=self.target_width,
+        )
+
+    @staticmethod
+    def restore_model(model: "Flux", original_embedder: object) -> None:
+        """Restore the original position embedder.
+
+        Args:
+            model: The FLUX model to restore
+            original_embedder: The original embedder saved from patch_model()
+        """
+        model.pe_embedder = original_embedder
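For orientation, a minimal sketch of the patch/update/restore cycle described in the class docstring above. The dype_config, flux_model, and timesteps values are illustrative placeholders, not values taken from this release:

# Hypothetical usage sketch; DyPEConfig construction and the model/timesteps are assumed, not shown in this diff.
from invokeai.backend.flux.extensions.dype_extension import DyPEExtension

extension = DyPEExtension(config=dype_config, target_height=2048, target_width=2048)
dype_embedder, original_embedder = extension.patch_model(flux_model)
try:
    for i, sigma in enumerate(timesteps):
        # Refresh the noise level before each denoising step.
        extension.update_step_state(dype_embedder, timestep=sigma, timestep_index=i, total_steps=len(timesteps))
        # ... run one denoising step with flux_model ...
finally:
    # Put the original position embedder back once denoising finishes.
    DyPEExtension.restore_model(flux_model, original_embedder)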
invokeai/backend/flux/schedulers.py
ADDED

@@ -0,0 +1,62 @@
+"""Flow Matching scheduler definitions and mapping.
+
+This module provides the scheduler types and mapping for Flow Matching models
+(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
+"""
+
+from typing import Literal, Type
+
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    FlowMatchHeunDiscreteScheduler,
+)
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
+try:
+    from diffusers import FlowMatchLCMScheduler
+
+    _HAS_LCM = True
+except ImportError:
+    _HAS_LCM = False
+
+# Scheduler name literal type for type checking
+FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+FLUX_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes
+FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
+
+
+# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
+# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
+# can be used for experimentation.
+ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes (same as Flux)
+ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
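A small sketch of how this name-to-class mapping can be consumed; the fallback-to-Euler behaviour shown here is an assumption about calling code, not something defined in the module itself:

from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP

# Look up a scheduler class by its short name, falling back to Euler if the
# requested scheduler (e.g. "lcm" on an older diffusers install) is unavailable.
name = "heun"
scheduler_cls = FLUX_SCHEDULER_MAP.get(name, FLUX_SCHEDULER_MAP["euler"])
scheduler = scheduler_cls()  # default config; real loaders would pass model-specific kwargs
print(f"{FLUX_SCHEDULER_LABELS[name]} -> {type(scheduler).__name__}")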
invokeai/backend/flux/util.py
CHANGED
@@ -5,7 +5,7 @@ from typing import Literal
 
 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
-from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType
+from invokeai.backend.model_manager.taxonomy import AnyVariant, Flux2VariantType, FluxVariantType
 
 
 @dataclass
@@ -46,6 +46,8 @@ _flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = {
     FluxVariantType.Dev: 512,
     FluxVariantType.DevFill: 512,
     FluxVariantType.Schnell: 256,
+    Flux2VariantType.Klein4B: 512,
+    Flux2VariantType.Klein9B: 512,
 }
 
 
@@ -117,6 +119,38 @@ _flux_transformer_params: dict[AnyVariant, FluxParams] = {
         qkv_bias=True,
         guidance_embed=True,
     ),
+    # Flux2 Klein 4B uses Qwen3 4B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 2560 = 7680)
+    Flux2VariantType.Klein4B: FluxParams(
+        in_channels=64,
+        vec_in_dim=2560,  # Qwen3-4B hidden size (used for pooled output)
+        context_in_dim=7680,  # 3 layers * 2560 = 7680 for Qwen3-4B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
+    # Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
+    # The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
+    Flux2VariantType.Klein9B: FluxParams(
+        in_channels=64,
+        vec_in_dim=4096,  # Qwen3-8B hidden size (used for pooled output)
+        context_in_dim=12288,  # 3 layers * 4096 = 12288 for Qwen3-8B
+        hidden_size=3072,
+        mlp_ratio=4.0,
+        num_heads=24,
+        depth=19,
+        depth_single_blocks=38,
+        axes_dim=[16, 56, 56],
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+    ),
 }
 
 
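The Klein parameter comments above tie context_in_dim to the Qwen3 hidden size; a quick check of that arithmetic, restating only values already given in the comments:

# context_in_dim = number of stacked Qwen3 hidden layers * hidden size, per the comments above.
STACKED_LAYERS = 3  # layers [9, 18, 27]

assert STACKED_LAYERS * 2560 == 7680   # Klein 4B, Qwen3-4B hidden size
assert STACKED_LAYERS * 4096 == 12288  # Klein 9B, Qwen3-8B hidden size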
invokeai/backend/flux2/denoise.py
ADDED

@@ -0,0 +1,280 @@
+"""Flux2 Klein Denoising Function.
+
+This module provides the denoising function for FLUX.2 Klein models,
+which use Qwen3 as the text encoder instead of CLIP+T5.
+"""
+
+import math
+from typing import Any, Callable
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
+from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+
+
+def denoise(
+    model: torch.nn.Module,
+    # model input
+    img: torch.Tensor,
+    img_ids: torch.Tensor,
+    txt: torch.Tensor,
+    txt_ids: torch.Tensor,
+    # sampling parameters
+    timesteps: list[float],
+    step_callback: Callable[[PipelineIntermediateState], None],
+    cfg_scale: list[float],
+    # Negative conditioning for CFG
+    neg_txt: torch.Tensor | None = None,
+    neg_txt_ids: torch.Tensor | None = None,
+    # Scheduler for stepping (e.g., FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler)
+    scheduler: Any = None,
+    # Dynamic shifting parameter for FLUX.2 Klein (computed from image resolution)
+    mu: float | None = None,
+    # Inpainting extension for merging latents during denoising
+    inpaint_extension: RectifiedFlowInpaintExtension | None = None,
+    # Reference image conditioning (multi-reference image editing)
+    img_cond_seq: torch.Tensor | None = None,
+    img_cond_seq_ids: torch.Tensor | None = None,
+) -> torch.Tensor:
+    """Denoise latents using a FLUX.2 Klein transformer model.
+
+    This is a simplified denoise function for FLUX.2 Klein models that uses
+    the diffusers Flux2Transformer2DModel interface.
+
+    Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
+    CFG is applied externally using negative conditioning when cfg_scale != 1.0.
+
+    Args:
+        model: The Flux2Transformer2DModel from diffusers.
+        img: Packed latent image tensor of shape (B, seq_len, channels).
+        img_ids: Image position IDs tensor.
+        txt: Text encoder hidden states (Qwen3 embeddings).
+        txt_ids: Text position IDs tensor.
+        timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
+        step_callback: Callback function for progress updates.
+        cfg_scale: List of CFG scale values per step.
+        neg_txt: Negative text embeddings for CFG (optional).
+        neg_txt_ids: Negative text position IDs (optional).
+        scheduler: Optional diffusers scheduler (Euler, Heun, LCM). If None, uses manual Euler.
+        mu: Dynamic shifting parameter computed from image resolution. Required when scheduler
+            has use_dynamic_shifting=True.
+
+    Returns:
+        Denoised latent tensor.
+    """
+    total_steps = len(timesteps) - 1
+
+    # Store original sequence length for extracting output later (before concatenating reference images)
+    original_seq_len = img.shape[1]
+
+    # Concatenate reference image conditioning if provided (multi-reference image editing)
+    if img_cond_seq is not None and img_cond_seq_ids is not None:
+        img = torch.cat([img, img_cond_seq], dim=1)
+        img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)
+
+    # Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
+    # We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
+    guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
+
+    # Use scheduler if provided
+    use_scheduler = scheduler is not None
+    if use_scheduler:
+        # Set up scheduler with sigmas and mu for dynamic shifting
+        # Convert timesteps (0-1 range) to sigmas for the scheduler
+        # The scheduler will apply dynamic shifting internally using mu (if enabled in scheduler config)
+        sigmas = np.array(timesteps[:-1], dtype=np.float32)  # Exclude final 0.0
+
+        # Pass mu if provided - it will only be used if scheduler has use_dynamic_shifting=True
+        if mu is not None:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), mu=mu, device=img.device)
+        else:
+            scheduler.set_timesteps(sigmas=sigmas.tolist(), device=img.device)
+        num_scheduler_steps = len(scheduler.timesteps)
+        is_heun = hasattr(scheduler, "state_in_first_order")
+        user_step = 0
+
+        pbar = tqdm(total=total_steps, desc="Denoising")
+        for step_index in range(num_scheduler_steps):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Track if we're in first or second order step (for Heun)
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                # Separate the generated latents from the reference conditioning
+                gen_img = img[:, :original_seq_len, :]
+                ref_img = img[:, original_seq_len:, :]
+
+                # Merge only the generated part
+                gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)
+
+                # Concatenate back together
+                img = torch.cat([gen_img, ref_img], dim=1)
+
+            # For Heun, only increment user step after second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    user_step += 1
+                    if user_step <= total_steps:
+                        pbar.update(1)
+                        preview_img = img - t_curr * pred
+                        if inpaint_extension is not None:
+                            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                                preview_img, 0.0
+                            )
+                        step_callback(
+                            PipelineIntermediateState(
+                                step=user_step,
+                                order=2,
+                                total_steps=total_steps,
+                                timestep=int(t_curr * 1000),
+                                latents=preview_img,
+                            ),
+                        )
+            else:
+                user_step += 1
+                if user_step <= total_steps:
+                    pbar.update(1)
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                    # Extract only the generated image portion for preview (exclude reference images)
+                    callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=1,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=callback_latents,
+                        ),
+                    )
+
+        pbar.close()
+    else:
+        # Manual Euler stepping (original behavior)
+        for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # Run the transformer model (matching diffusers: guidance=guidance, return_dict=False)
+            output = model(
+                hidden_states=img,
+                encoder_hidden_states=txt,
+                timestep=t_vec,
+                img_ids=img_ids,
+                txt_ids=txt_ids,
+                guidance=guidance,
+                return_dict=False,
+            )
+
+            # Extract the sample from the output (return_dict=False returns tuple)
+            pred = output[0] if isinstance(output, tuple) else output
+
+            step_cfg_scale = cfg_scale[step_index]
+
+            # Apply CFG if scale is not 1.0
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_txt is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_output = model(
+                    hidden_states=img,
+                    encoder_hidden_states=neg_txt,
+                    timestep=t_vec,
+                    img_ids=img_ids,
+                    txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
+                    guidance=guidance,
+                    return_dict=False,
+                )
+
+                neg_pred = neg_output[0] if isinstance(neg_output, tuple) else neg_output
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Euler step
+            preview_img = img - t_curr * pred
+            img = img + (t_prev - t_curr) * pred
+
+            # Apply inpainting merge at each step
+            if inpaint_extension is not None:
+                # Separate the generated latents from the reference conditioning
+                gen_img = img[:, :original_seq_len, :]
+                ref_img = img[:, original_seq_len:, :]
+
+                # Merge only the generated part
+                gen_img = inpaint_extension.merge_intermediate_latents_with_init_latents(gen_img, t_prev)
+
+                # Concatenate back together
+                img = torch.cat([gen_img, ref_img], dim=1)
+
+                # Handling preview images
+                preview_gen = preview_img[:, :original_seq_len, :]
+                preview_gen = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_gen, 0.0)
+
+            # Extract only the generated image portion for preview (exclude reference images)
+            callback_latents = preview_img[:, :original_seq_len, :] if img_cond_seq is not None else preview_img
+            step_callback(
+                PipelineIntermediateState(
+                    step=step_index + 1,
+                    order=1,
+                    total_steps=total_steps,
+                    timestep=int(t_curr),
+                    latents=callback_latents,
+                ),
+            )
+
+    # Extract only the generated image portion (exclude concatenated reference images)
+    if img_cond_seq is not None:
+        img = img[:, :original_seq_len, :]
+
+    return img
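As a standalone illustration of the two per-step relations used above (the external CFG blend and the manual Euler update), with random tensors standing in for real model outputs; shapes and values are illustrative only:

import torch

# Illustrative shapes only; real latents come from the packed FLUX.2 latent sequence.
pred = torch.randn(1, 1024, 64)      # positive-prompt velocity prediction
neg_pred = torch.randn(1, 1024, 64)  # negative-prompt velocity prediction
img = torch.randn(1, 1024, 64)       # current latents

cfg_scale = 4.0
guided = neg_pred + cfg_scale * (pred - neg_pred)  # classifier-free guidance blend

t_curr, t_prev = 0.75, 0.625
preview = img - t_curr * guided         # x0-style preview passed to the step callback
img = img + (t_prev - t_curr) * guided  # Euler step toward the next sigma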