invokeai-6.10.0rc1-py3-none-any.whl → invokeai-6.11.0-py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- invokeai/app/api/routers/model_manager.py +43 -1
- invokeai/app/invocations/fields.py +1 -1
- invokeai/app/invocations/flux2_denoise.py +499 -0
- invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
- invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
- invokeai/app/invocations/flux2_vae_decode.py +106 -0
- invokeai/app/invocations/flux2_vae_encode.py +88 -0
- invokeai/app/invocations/flux_denoise.py +77 -3
- invokeai/app/invocations/flux_lora_loader.py +1 -1
- invokeai/app/invocations/flux_model_loader.py +2 -5
- invokeai/app/invocations/ideal_size.py +6 -1
- invokeai/app/invocations/metadata.py +4 -0
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +1 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +244 -84
- invokeai/app/invocations/z_image_image_to_latents.py +9 -1
- invokeai/app/invocations/z_image_latents_to_image.py +9 -1
- invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
- invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +4 -2
- invokeai/app/services/shared/invocation_context.py +15 -0
- invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
- invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
- invokeai/app/util/step_callback.py +58 -2
- invokeai/backend/flux/denoise.py +338 -118
- invokeai/backend/flux/dype/__init__.py +31 -0
- invokeai/backend/flux/dype/base.py +260 -0
- invokeai/backend/flux/dype/embed.py +116 -0
- invokeai/backend/flux/dype/presets.py +148 -0
- invokeai/backend/flux/dype/rope.py +110 -0
- invokeai/backend/flux/extensions/dype_extension.py +91 -0
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/flux/util.py +35 -1
- invokeai/backend/flux2/__init__.py +4 -0
- invokeai/backend/flux2/denoise.py +280 -0
- invokeai/backend/flux2/ref_image_extension.py +294 -0
- invokeai/backend/flux2/sampling_utils.py +209 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/factory.py +19 -1
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/configs/main.py +395 -3
- invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
- invokeai/backend/model_manager/configs/vae.py +104 -2
- invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
- invokeai/backend/model_manager/starter_models.py +141 -4
- invokeai/backend/model_manager/taxonomy.py +31 -4
- invokeai/backend/model_manager/util/select_hf_files.py +3 -2
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/util/vae_working_memory.py +0 -2
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
- invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +85 -6
- invokeai/frontend/web/dist/locales/it.json +135 -15
- invokeai/frontend/web/dist/locales/ru.json +11 -11
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/dype/base.py
@@ -0,0 +1,260 @@
+"""DyPE base configuration and utilities."""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+from torch import Tensor
+
+
+@dataclass
+class DyPEConfig:
+    """Configuration for Dynamic Position Extrapolation."""
+
+    enable_dype: bool = True
+    base_resolution: int = 1024  # Native training resolution
+    method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
+    dype_scale: float = 2.0  # Magnitude λs (0.0-8.0)
+    dype_exponent: float = 2.0  # Decay speed λt (0.0-1000.0)
+    dype_start_sigma: float = 1.0  # When DyPE decay starts
+
+
+def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
+    """Calculate magnitude scaling factor.
+
+    Args:
+        scale: The resolution scaling factor
+        mscale_factor: Adjustment factor for the scaling
+
+    Returns:
+        The magnitude scaling factor
+    """
+    if scale <= 1.0:
+        return 1.0
+    return mscale_factor * math.log(scale) + 1.0
+
+
+def get_timestep_mscale(
+    scale: float,
+    current_sigma: float,
+    dype_scale: float,
+    dype_exponent: float,
+    dype_start_sigma: float,
+) -> float:
+    """Calculate timestep-dependent magnitude scaling.
+
+    The key insight of DyPE: early steps focus on low frequencies (global structure),
+    late steps on high frequencies (details). This function modulates the scaling
+    based on the current timestep/sigma.
+
+    Args:
+        scale: Resolution scaling factor
+        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+        dype_scale: DyPE magnitude (λs)
+        dype_exponent: DyPE decay speed (λt)
+        dype_start_sigma: Sigma threshold to start decay
+
+    Returns:
+        Timestep-modulated scaling factor
+    """
+    if scale <= 1.0:
+        return 1.0
+
+    # Normalize sigma to [0, 1] range relative to start_sigma
+    if current_sigma >= dype_start_sigma:
+        t_normalized = 1.0
+    else:
+        t_normalized = current_sigma / dype_start_sigma
+
+    # Apply exponential decay: stronger extrapolation early, weaker late
+    # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
+    decay = math.exp(-dype_exponent * (1.0 - t_normalized))
+
+    # Base mscale from resolution
+    base_mscale = get_mscale(scale)
+
+    # Interpolate between base_mscale and 1.0 based on decay and dype_scale
+    # When decay=1 (early): apply the full dype_scale-amplified value
+    # When decay=0 (late): fall back to 1.0 (no magnitude scaling)
+    scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay
+
+    return scaled_mscale
+
+
+def compute_vision_yarn_freqs(
+    pos: Tensor,
+    dim: int,
+    theta: int,
+    scale_h: float,
+    scale_w: float,
+    current_sigma: float,
+    dype_config: DyPEConfig,
+) -> tuple[Tensor, Tensor]:
+    """Compute RoPE frequencies using NTK-aware scaling for high-resolution.
+
+    This method extends FLUX's position encoding to handle resolutions beyond
+    the 1024px training resolution by scaling the base frequency (theta).
+
+    The NTK-aware approach smoothly interpolates frequencies to cover larger
+    position ranges without breaking the attention patterns.
+
+    DyPE (Dynamic Position Extrapolation) modulates the NTK scaling based on
+    the current timestep - stronger extrapolation in early steps (global structure),
+    weaker in late steps (fine details).
+
+    Args:
+        pos: Position tensor
+        dim: Embedding dimension
+        theta: RoPE base frequency
+        scale_h: Height scaling factor
+        scale_w: Width scaling factor
+        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+        dype_config: DyPE configuration
+
+    Returns:
+        Tuple of (cos, sin) frequency tensors
+    """
+    assert dim % 2 == 0
+
+    # Use the larger scale for NTK calculation
+    scale = max(scale_h, scale_w)
+
+    device = pos.device
+    dtype = torch.float64 if device.type != "mps" else torch.float32
+
+    # NTK-aware theta scaling: extends position coverage for high-res
+    # Formula: theta_scaled = theta * scale^(dim/(dim-2))
+    # This increases the wavelength of position encodings proportionally
+    if scale > 1.0:
+        ntk_alpha = scale ** (dim / (dim - 2))
+
+        # Apply timestep-dependent DyPE modulation
+        # mscale controls how strongly we apply the NTK extrapolation
+        # Early steps (high sigma): stronger extrapolation for global structure
+        # Late steps (low sigma): weaker extrapolation for fine details
+        mscale = get_timestep_mscale(
+            scale=scale,
+            current_sigma=current_sigma,
+            dype_scale=dype_config.dype_scale,
+            dype_exponent=dype_config.dype_exponent,
+            dype_start_sigma=dype_config.dype_start_sigma,
+        )
+
+        # Modulate NTK alpha by mscale
+        # When mscale > 1: interpolate towards stronger extrapolation
+        # When mscale = 1: use base NTK alpha
+        modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
+        scaled_theta = theta * modulated_alpha
+    else:
+        scaled_theta = theta
+
+    # Standard RoPE frequency computation
+    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+    freqs = 1.0 / (scaled_theta**freq_seq)
+
+    # Compute angles = position * frequency
+    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+    cos = torch.cos(angles)
+    sin = torch.sin(angles)
+
+    return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+def compute_yarn_freqs(
+    pos: Tensor,
+    dim: int,
+    theta: int,
+    scale: float,
+    current_sigma: float,
+    dype_config: DyPEConfig,
+) -> tuple[Tensor, Tensor]:
+    """Compute RoPE frequencies using YARN/NTK method.
+
+    Uses NTK-aware theta scaling for high-resolution support with
+    timestep-dependent DyPE modulation.
+
+    Args:
+        pos: Position tensor
+        dim: Embedding dimension
+        theta: RoPE base frequency
+        scale: Uniform scaling factor
+        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+        dype_config: DyPE configuration
+
+    Returns:
+        Tuple of (cos, sin) frequency tensors
+    """
+    assert dim % 2 == 0
+
+    device = pos.device
+    dtype = torch.float64 if device.type != "mps" else torch.float32
+
+    # NTK-aware theta scaling with DyPE modulation
+    if scale > 1.0:
+        ntk_alpha = scale ** (dim / (dim - 2))
+
+        # Apply timestep-dependent DyPE modulation
+        mscale = get_timestep_mscale(
+            scale=scale,
+            current_sigma=current_sigma,
+            dype_scale=dype_config.dype_scale,
+            dype_exponent=dype_config.dype_exponent,
+            dype_start_sigma=dype_config.dype_start_sigma,
+        )
+
+        # Modulate NTK alpha by mscale
+        modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
+        scaled_theta = theta * modulated_alpha
+    else:
+        scaled_theta = theta
+
+    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+    freqs = 1.0 / (scaled_theta**freq_seq)
+
+    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+    cos = torch.cos(angles)
+    sin = torch.sin(angles)
+
+    return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+def compute_ntk_freqs(
+    pos: Tensor,
+    dim: int,
+    theta: int,
+    scale: float,
+) -> tuple[Tensor, Tensor]:
+    """Compute RoPE frequencies using NTK method.
+
+    Neural Tangent Kernel approach - continuous frequency scaling without
+    timestep dependency.
+
+    Args:
+        pos: Position tensor
+        dim: Embedding dimension
+        theta: RoPE base frequency
+        scale: Scaling factor
+
+    Returns:
+        Tuple of (cos, sin) frequency tensors
+    """
+    assert dim % 2 == 0
+
+    device = pos.device
+    dtype = torch.float64 if device.type != "mps" else torch.float32
+
+    # NTK scaling
+    scaled_theta = theta * (scale ** (dim / (dim - 2)))
+
+    freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+    freqs = 1.0 / (scaled_theta**freq_seq)
+
+    angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+    cos = torch.cos(angles)
+    sin = torch.sin(angles)
+
+    return cos.to(pos.dtype), sin.to(pos.dtype)
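
To make the decay behavior concrete, here is a minimal sketch (not part of the diff) showing how `get_timestep_mscale` relaxes the magnitude scaling as denoising progresses. It assumes only that the file above is importable as `invokeai.backend.flux.dype.base`:

```python
from invokeai.backend.flux.dype.base import DyPEConfig, get_timestep_mscale

config = DyPEConfig()  # defaults: dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0
scale = 2048 / 1024    # generating at 2048px with a 1024px-trained model

for sigma in (1.0, 0.75, 0.5, 0.25, 0.0):
    mscale = get_timestep_mscale(
        scale=scale,
        current_sigma=sigma,
        dype_scale=config.dype_scale,
        dype_exponent=config.dype_exponent,
        dype_start_sigma=config.dype_start_sigma,
    )
    print(f"sigma={sigma:.2f} -> mscale={mscale:.3f}")

# Early steps (sigma near 1.0) get the strongest extrapolation
# (~2.386 for scale=2.0); as sigma falls, decay = exp(-2 * (1 - sigma))
# pulls mscale back toward 1.0 (~1.187 at sigma=0.0).
```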

invokeai/backend/flux/dype/embed.py
@@ -0,0 +1,116 @@
+"""DyPE-enhanced position embedding module."""
+
+import torch
+from torch import Tensor, nn
+
+from invokeai.backend.flux.dype.base import DyPEConfig
+from invokeai.backend.flux.dype.rope import rope_dype
+
+
+class DyPEEmbedND(nn.Module):
+    """N-dimensional position embedding with DyPE support.
+
+    This class replaces the standard EmbedND from FLUX with a DyPE-aware version
+    that dynamically scales position embeddings based on resolution and timestep.
+
+    The key difference from EmbedND:
+    - Maintains step state (current_sigma, target dimensions)
+    - Uses rope_dype() instead of rope() for frequency computation
+    - Applies timestep-dependent scaling for better high-resolution generation
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        theta: int,
+        axes_dim: list[int],
+        dype_config: DyPEConfig,
+    ):
+        """Initialize DyPE position embedder.
+
+        Args:
+            dim: Total embedding dimension (sum of axes_dim)
+            theta: RoPE base frequency
+            axes_dim: Dimension allocation per axis (e.g., [16, 56, 56] for FLUX)
+            dype_config: DyPE configuration
+        """
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.dype_config = dype_config
+
+        # Step state - updated before each denoising step
+        self._current_sigma: float = 1.0
+        self._target_height: int = 1024
+        self._target_width: int = 1024
+
+    def set_step_state(self, sigma: float, height: int, width: int) -> None:
+        """Update the step state before each denoising step.
+
+        This method should be called by the DyPE extension before each step
+        to update the current noise level and target dimensions.
+
+        Args:
+            sigma: Current noise level (timestep value, 1.0 = full noise)
+            height: Target image height in pixels
+            width: Target image width in pixels
+        """
+        self._current_sigma = sigma
+        self._target_height = height
+        self._target_width = width
+
+    def forward(self, ids: Tensor) -> Tensor:
+        """Compute position embeddings with DyPE scaling.
+
+        Args:
+            ids: Position indices tensor with shape (batch, seq_len, n_axes)
+                For FLUX: n_axes=3 (time/channel, height, width)
+
+        Returns:
+            Position embedding tensor with shape (batch, 1, seq_len, dim)
+        """
+        n_axes = ids.shape[-1]
+
+        # Compute RoPE for each axis with DyPE scaling
+        embeddings = []
+        for i in range(n_axes):
+            axis_emb = rope_dype(
+                pos=ids[..., i],
+                dim=self.axes_dim[i],
+                theta=self.theta,
+                current_sigma=self._current_sigma,
+                target_height=self._target_height,
+                target_width=self._target_width,
+                dype_config=self.dype_config,
+            )
+            embeddings.append(axis_emb)
+
+        # Concatenate embeddings from all axes
+        emb = torch.cat(embeddings, dim=-3)
+
+        return emb.unsqueeze(1)
+
+    @classmethod
+    def from_embednd(
+        cls,
+        embed_nd: nn.Module,
+        dype_config: DyPEConfig,
+    ) -> "DyPEEmbedND":
+        """Create a DyPEEmbedND from an existing EmbedND.
+
+        This is a convenience method for patching an existing FLUX model.
+
+        Args:
+            embed_nd: Original EmbedND module from FLUX
+            dype_config: DyPE configuration
+
+        Returns:
+            New DyPEEmbedND with same parameters
+        """
+        return cls(
+            dim=embed_nd.dim,
+            theta=embed_nd.theta,
+            axes_dim=embed_nd.axes_dim,
+            dype_config=dype_config,
+        )
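
A sketch of how `from_embednd` might be used to patch a model. It assumes the FLUX transformer exposes its `EmbedND` instance as `pe_embedder` (as in InvokeAI's vendored FLUX implementation); the helper name `patch_pe_embedder` is ours, not part of the diff:

```python
from torch import nn

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND


def patch_pe_embedder(model: nn.Module, dype_config: DyPEConfig) -> DyPEEmbedND:
    """Swap a FLUX model's EmbedND for a DyPE-aware embedder (hypothetical helper)."""
    dype_embedder = DyPEEmbedND.from_embednd(model.pe_embedder, dype_config)
    model.pe_embedder = dype_embedder
    return dype_embedder


# Inside the denoising loop, refresh the step state before each forward pass
# so the RoPE frequencies track the current noise level, e.g.:
#   embedder.set_step_state(sigma=float(sigma), height=2048, width=2048)
```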

invokeai/backend/flux/dype/presets.py
@@ -0,0 +1,148 @@
+"""DyPE presets and automatic configuration."""
+
+from dataclasses import dataclass
+from typing import Literal
+
+from invokeai.backend.flux.dype.base import DyPEConfig
+
+# DyPE preset type - using Literal for proper frontend dropdown support
+DyPEPreset = Literal["off", "manual", "auto", "4k"]
+
+# Constants for preset values
+DYPE_PRESET_OFF: DyPEPreset = "off"
+DYPE_PRESET_MANUAL: DyPEPreset = "manual"
+DYPE_PRESET_AUTO: DyPEPreset = "auto"
+DYPE_PRESET_4K: DyPEPreset = "4k"
+
+# Human-readable labels for the UI
+DYPE_PRESET_LABELS: dict[str, str] = {
+    "off": "Off",
+    "manual": "Manual",
+    "auto": "Auto (>1536px)",
+    "4k": "4K Optimized",
+}
+
+
+@dataclass
+class DyPEPresetConfig:
+    """Preset configuration values."""
+
+    base_resolution: int
+    method: str
+    dype_scale: float
+    dype_exponent: float
+    dype_start_sigma: float
+
+
+# Predefined preset configurations
+DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
+    DYPE_PRESET_4K: DyPEPresetConfig(
+        base_resolution=1024,
+        method="vision_yarn",
+        dype_scale=2.0,
+        dype_exponent=2.0,
+        dype_start_sigma=1.0,
+    ),
+}
+
+
+def get_dype_config_for_resolution(
+    width: int,
+    height: int,
+    base_resolution: int = 1024,
+    activation_threshold: int = 1536,
+) -> DyPEConfig | None:
+    """Automatically determine DyPE config based on target resolution.
+
+    FLUX can handle resolutions up to ~1.5x natively without significant artifacts.
+    DyPE is only activated when the resolution exceeds the activation threshold.
+
+    Args:
+        width: Target image width in pixels
+        height: Target image height in pixels
+        base_resolution: Native training resolution of the model (for scale calculation)
+        activation_threshold: Resolution threshold above which DyPE is activated
+
+    Returns:
+        DyPEConfig if DyPE should be enabled, None otherwise
+    """
+    max_dim = max(width, height)
+
+    if max_dim <= activation_threshold:
+        return None  # FLUX can handle this natively
+
+    # Calculate scaling factor based on base_resolution
+    scale = max_dim / base_resolution
+
+    # Dynamic parameters based on scaling
+    # Higher resolution = higher dype_scale, capped at 8.0
+    dynamic_dype_scale = min(2.0 * scale, 8.0)
+
+    return DyPEConfig(
+        enable_dype=True,
+        base_resolution=base_resolution,
+        method="vision_yarn",
+        dype_scale=dynamic_dype_scale,
+        dype_exponent=2.0,
+        dype_start_sigma=1.0,
+    )
+
+
+def get_dype_config_from_preset(
+    preset: DyPEPreset,
+    width: int,
+    height: int,
+    custom_scale: float | None = None,
+    custom_exponent: float | None = None,
+) -> DyPEConfig | None:
+    """Get DyPE configuration from a preset or custom values.
+
+    Args:
+        preset: The DyPE preset to use
+        width: Target image width
+        height: Target image height
+        custom_scale: Optional custom dype_scale (only used with 'manual' preset)
+        custom_exponent: Optional custom dype_exponent (only used with 'manual' preset)
+
+    Returns:
+        DyPEConfig if DyPE should be enabled, None otherwise
+    """
+    if preset == DYPE_PRESET_OFF:
+        return None
+
+    if preset == DYPE_PRESET_MANUAL:
+        # Manual mode - custom values can override defaults
+        max_dim = max(width, height)
+        scale = max_dim / 1024
+        dynamic_dype_scale = min(2.0 * scale, 8.0)
+        return DyPEConfig(
+            enable_dype=True,
+            base_resolution=1024,
+            method="vision_yarn",
+            dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale,
+            dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
+            dype_start_sigma=1.0,
+        )
+
+    if preset == DYPE_PRESET_AUTO:
+        # Auto preset - custom values are ignored
+        return get_dype_config_for_resolution(
+            width=width,
+            height=height,
+            base_resolution=1024,
+            activation_threshold=1536,
+        )
+
+    # Use preset configuration (4K etc.) - custom values are ignored
+    preset_config = DYPE_PRESETS.get(preset)
+    if preset_config is None:
+        return None
+
+    return DyPEConfig(
+        enable_dype=True,
+        base_resolution=preset_config.base_resolution,
+        method=preset_config.method,
+        dype_scale=preset_config.dype_scale,
+        dype_exponent=preset_config.dype_exponent,
+        dype_start_sigma=preset_config.dype_start_sigma,
+    )
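
For reference, a couple of worked values from the preset logic above (a sketch; the numbers follow directly from `get_dype_config_for_resolution`):

```python
from invokeai.backend.flux.dype.presets import get_dype_config_from_preset

# Below the 1536px activation threshold, the "auto" preset leaves DyPE off.
assert get_dype_config_from_preset("auto", width=1024, height=1024) is None

# At 4096px, scale = 4096 / 1024 = 4.0, so dype_scale = min(2.0 * 4.0, 8.0) = 8.0.
config = get_dype_config_from_preset("auto", width=4096, height=4096)
assert config is not None
assert config.dype_scale == 8.0 and config.method == "vision_yarn"
```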

invokeai/backend/flux/dype/rope.py
@@ -0,0 +1,110 @@
+"""DyPE-enhanced RoPE (Rotary Position Embedding) functions."""
+
+import torch
+from einops import rearrange
+from torch import Tensor
+
+from invokeai.backend.flux.dype.base import (
+    DyPEConfig,
+    compute_ntk_freqs,
+    compute_vision_yarn_freqs,
+    compute_yarn_freqs,
+)
+
+
+def rope_dype(
+    pos: Tensor,
+    dim: int,
+    theta: int,
+    current_sigma: float,
+    target_height: int,
+    target_width: int,
+    dype_config: DyPEConfig,
+) -> Tensor:
+    """Compute RoPE with Dynamic Position Extrapolation.
+
+    This is the core DyPE function that replaces the standard rope() function.
+    It applies resolution-aware and timestep-aware scaling to position embeddings.
+
+    Args:
+        pos: Position indices tensor
+        dim: Embedding dimension per axis
+        theta: RoPE base frequency (typically 10000)
+        current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+        target_height: Target image height in pixels
+        target_width: Target image width in pixels
+        dype_config: DyPE configuration
+
+    Returns:
+        Rotary position embedding tensor with shape suitable for FLUX attention
+    """
+    assert dim % 2 == 0
+
+    # Calculate scaling factors
+    base_res = dype_config.base_resolution
+    scale_h = target_height / base_res
+    scale_w = target_width / base_res
+    scale = max(scale_h, scale_w)
+
+    # If DyPE is disabled or no scaling is needed, use the base method
+    if not dype_config.enable_dype or scale <= 1.0:
+        return _rope_base(pos, dim, theta)
+
+    # Select method and compute frequencies
+    method = dype_config.method
+
+    if method == "vision_yarn":
+        cos, sin = compute_vision_yarn_freqs(
+            pos=pos,
+            dim=dim,
+            theta=theta,
+            scale_h=scale_h,
+            scale_w=scale_w,
+            current_sigma=current_sigma,
+            dype_config=dype_config,
+        )
+    elif method == "yarn":
+        cos, sin = compute_yarn_freqs(
+            pos=pos,
+            dim=dim,
+            theta=theta,
+            scale=scale,
+            current_sigma=current_sigma,
+            dype_config=dype_config,
+        )
+    elif method == "ntk":
+        cos, sin = compute_ntk_freqs(
+            pos=pos,
+            dim=dim,
+            theta=theta,
+            scale=scale,
+        )
+    else:  # "base"
+        return _rope_base(pos, dim, theta)
+
+    # Construct rotation matrix from cos/sin
+    # Output shape: (batch, seq_len, dim/2, 2, 2)
+    out = torch.stack([cos, -sin, sin, cos], dim=-1)
+    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+    return out.to(dtype=pos.dtype, device=pos.device)
+
+
+def _rope_base(pos: Tensor, dim: int, theta: int) -> Tensor:
+    """Standard RoPE without DyPE scaling.
+
+    This matches the original rope() function from invokeai.backend.flux.math.
+    """
+    assert dim % 2 == 0
+
+    device = pos.device
+    dtype = torch.float64 if device.type != "mps" else torch.float32
+
+    scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+    omega = 1.0 / (theta**scale)
+
+    out = torch.einsum("...n,d->...nd", pos.to(dtype), omega)
+    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+    return out.to(dtype=pos.dtype, device=pos.device)
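
Finally, a small shape check for `rope_dype` (a sketch assuming FLUX's usual axes_dim of [16, 56, 56] and theta of 10000):

```python
import torch

from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.rope import rope_dype

pos = torch.arange(4, dtype=torch.float32).unsqueeze(0)  # (batch=1, seq_len=4)

emb = rope_dype(
    pos=pos,
    dim=56,              # one spatial axis of FLUX's axes_dim=[16, 56, 56]
    theta=10_000,
    current_sigma=1.0,   # first denoising step: strongest extrapolation
    target_height=2048,  # 2x the 1024px base resolution, so DyPE engages
    target_width=2048,
    dype_config=DyPEConfig(),
)

# One 2x2 rotation matrix per position and frequency pair.
assert emb.shape == (1, 4, 28, 2, 2)
```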