InvokeAI 6.10.0rc2-py3-none-any.whl → 6.11.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +50 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/ideal_size.py +6 -1
  11. invokeai/app/invocations/metadata.py +4 -0
  12. invokeai/app/invocations/metadata_linked.py +47 -0
  13. invokeai/app/invocations/model.py +1 -0
  14. invokeai/app/invocations/z_image_denoise.py +8 -3
  15. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  16. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  17. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  18. invokeai/app/services/config/config_default.py +3 -1
  19. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  20. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  21. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  22. invokeai/app/services/model_records/model_records_base.py +4 -2
  23. invokeai/app/services/shared/invocation_context.py +15 -0
  24. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  25. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  26. invokeai/app/util/step_callback.py +42 -0
  27. invokeai/backend/flux/denoise.py +239 -204
  28. invokeai/backend/flux/dype/__init__.py +18 -0
  29. invokeai/backend/flux/dype/base.py +226 -0
  30. invokeai/backend/flux/dype/embed.py +116 -0
  31. invokeai/backend/flux/dype/presets.py +141 -0
  32. invokeai/backend/flux/dype/rope.py +110 -0
  33. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  34. invokeai/backend/flux/util.py +35 -1
  35. invokeai/backend/flux2/__init__.py +4 -0
  36. invokeai/backend/flux2/denoise.py +261 -0
  37. invokeai/backend/flux2/ref_image_extension.py +294 -0
  38. invokeai/backend/flux2/sampling_utils.py +209 -0
  39. invokeai/backend/model_manager/configs/factory.py +19 -1
  40. invokeai/backend/model_manager/configs/main.py +395 -3
  41. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  42. invokeai/backend/model_manager/configs/vae.py +104 -2
  43. invokeai/backend/model_manager/load/load_default.py +0 -1
  44. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  45. invokeai/backend/model_manager/load/model_loaders/flux.py +1007 -2
  46. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +0 -1
  47. invokeai/backend/model_manager/load/model_loaders/z_image.py +121 -28
  48. invokeai/backend/model_manager/starter_models.py +128 -0
  49. invokeai/backend/model_manager/taxonomy.py +31 -4
  50. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  51. invokeai/backend/util/vae_working_memory.py +0 -2
  52. invokeai/frontend/web/dist/assets/App-ClpIJstk.js +161 -0
  53. invokeai/frontend/web/dist/assets/{browser-ponyfill-BP0RxJ4G.js → browser-ponyfill-Cw07u5G1.js} +1 -1
  54. invokeai/frontend/web/dist/assets/{index-B44qKjrs.js → index-DSKM8iGj.js} +69 -69
  55. invokeai/frontend/web/dist/index.html +1 -1
  56. invokeai/frontend/web/dist/locales/en.json +58 -5
  57. invokeai/frontend/web/dist/locales/it.json +2 -1
  58. invokeai/version/invokeai_version.py +1 -1
  59. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/METADATA +7 -1
  60. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/RECORD +66 -49
  61. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/WHEEL +1 -1
  62. invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +0 -161
  63. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/entry_points.txt +0 -0
  64. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE +0 -0
  65. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  66. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  67. {invokeai-6.10.0rc2.dist-info → invokeai-6.11.0rc1.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/dype/base.py
@@ -0,0 +1,226 @@
+ """DyPE base configuration and utilities."""
+
+ import math
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ from torch import Tensor
+
+
+ @dataclass
+ class DyPEConfig:
+     """Configuration for Dynamic Position Extrapolation."""
+
+     enable_dype: bool = True
+     base_resolution: int = 1024  # Native training resolution
+     method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
+     dype_scale: float = 2.0  # Magnitude λs (0.0-8.0)
+     dype_exponent: float = 2.0  # Decay speed λt (0.0-1000.0)
+     dype_start_sigma: float = 1.0  # When DyPE decay starts
+
+
+ def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
+     """Calculate the magnitude scaling factor.
+
+     Args:
+         scale: The resolution scaling factor
+         mscale_factor: Adjustment factor for the scaling
+
+     Returns:
+         The magnitude scaling factor
+     """
+     if scale <= 1.0:
+         return 1.0
+     return mscale_factor * math.log(scale) + 1.0
+
+
+ def get_timestep_mscale(
+     scale: float,
+     current_sigma: float,
+     dype_scale: float,
+     dype_exponent: float,
+     dype_start_sigma: float,
+ ) -> float:
+     """Calculate timestep-dependent magnitude scaling.
+
+     The key insight of DyPE: early steps focus on low frequencies (global structure),
+     late steps on high frequencies (details). This function modulates the scaling
+     based on the current timestep/sigma.
+
+     Args:
+         scale: Resolution scaling factor
+         current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+         dype_scale: DyPE magnitude (λs)
+         dype_exponent: DyPE decay speed (λt)
+         dype_start_sigma: Sigma threshold to start decay
+
+     Returns:
+         Timestep-modulated scaling factor
+     """
+     if scale <= 1.0:
+         return 1.0
+
+     # Normalize sigma to [0, 1] range relative to start_sigma
+     if current_sigma >= dype_start_sigma:
+         t_normalized = 1.0
+     else:
+         t_normalized = current_sigma / dype_start_sigma
+
+     # Apply exponential decay: stronger extrapolation early, weaker late
+     # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
+     decay = math.exp(-dype_exponent * (1.0 - t_normalized))
+
+     # Base mscale from resolution
+     base_mscale = get_mscale(scale)
+
+     # Interpolate between base_mscale and 1.0 based on decay and dype_scale
+     # When decay=1 (early): use scaled value
+     # When decay=0 (late): use base value
+     scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay
+
+     return scaled_mscale
+
+
+ def compute_vision_yarn_freqs(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     scale_h: float,
+     scale_w: float,
+     current_sigma: float,
+     dype_config: DyPEConfig,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute RoPE frequencies using NTK-aware scaling for high resolutions.
+
+     This method extends FLUX's position encoding to handle resolutions beyond
+     the 1024px training resolution by scaling the base frequency (theta).
+
+     The NTK-aware approach smoothly interpolates frequencies to cover larger
+     position ranges without breaking the attention patterns.
+
+     Args:
+         pos: Position tensor
+         dim: Embedding dimension
+         theta: RoPE base frequency
+         scale_h: Height scaling factor
+         scale_w: Width scaling factor
+         current_sigma: Current noise level (reserved for future timestep-aware scaling)
+         dype_config: DyPE configuration
+
+     Returns:
+         Tuple of (cos, sin) frequency tensors
+     """
+     assert dim % 2 == 0
+
+     # Use the larger scale for the NTK calculation
+     scale = max(scale_h, scale_w)
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     # NTK-aware theta scaling: extends position coverage for high-res
+     # Formula: theta_scaled = theta * scale^(dim/(dim-2))
+     # This increases the wavelength of position encodings proportionally
+     if scale > 1.0:
+         ntk_alpha = scale ** (dim / (dim - 2))
+         scaled_theta = theta * ntk_alpha
+     else:
+         scaled_theta = theta
+
+     # Standard RoPE frequency computation
+     freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     freqs = 1.0 / (scaled_theta**freq_seq)
+
+     # Compute angles = position * frequency
+     angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+
+     return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+ def compute_yarn_freqs(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     scale: float,
+     current_sigma: float,
+     dype_config: DyPEConfig,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute RoPE frequencies using the YARN/NTK method.
+
+     Uses NTK-aware theta scaling for high-resolution support.
+
+     Args:
+         pos: Position tensor
+         dim: Embedding dimension
+         theta: RoPE base frequency
+         scale: Uniform scaling factor
+         current_sigma: Current noise level (reserved for future use)
+         dype_config: DyPE configuration
+
+     Returns:
+         Tuple of (cos, sin) frequency tensors
+     """
+     assert dim % 2 == 0
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     # NTK-aware theta scaling
+     if scale > 1.0:
+         ntk_alpha = scale ** (dim / (dim - 2))
+         scaled_theta = theta * ntk_alpha
+     else:
+         scaled_theta = theta
+
+     freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     freqs = 1.0 / (scaled_theta**freq_seq)
+
+     angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+
+     return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+ def compute_ntk_freqs(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     scale: float,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute RoPE frequencies using the NTK method.
+
+     Neural Tangent Kernel approach - continuous frequency scaling without
+     timestep dependency.
+
+     Args:
+         pos: Position tensor
+         dim: Embedding dimension
+         theta: RoPE base frequency
+         scale: Scaling factor
+
+     Returns:
+         Tuple of (cos, sin) frequency tensors
+     """
+     assert dim % 2 == 0
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     # NTK scaling
+     scaled_theta = theta * (scale ** (dim / (dim - 2)))
+
+     freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     freqs = 1.0 / (scaled_theta**freq_seq)
+
+     angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+
+     return cos.to(pos.dtype), sin.to(pos.dtype)
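
The interplay of λs and λt above is easier to see with numbers. The following sketch (illustrative, not part of the diff) evaluates get_timestep_mscale for a 4x upscale (a 4096px target on the 1024px base) with the default λs=2.0 and λt=2.0:

    from invokeai.backend.flux.dype.base import get_timestep_mscale

    # scale=4.0 corresponds to a 4096px target on a model trained at 1024px
    for sigma in (1.0, 0.5, 0.0):
        m = get_timestep_mscale(
            scale=4.0, current_sigma=sigma, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0
        )
        print(f"sigma={sigma:.1f} -> mscale={m:.3f}")

    # sigma=1.0 -> mscale=3.773   (full noise: strongest extrapolation)
    # sigma=0.5 -> mscale=2.020
    # sigma=0.0 -> mscale=1.375   (clean: decay exp(-2) leaves only a small boost)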
invokeai/backend/flux/dype/embed.py
@@ -0,0 +1,116 @@
+ """DyPE-enhanced position embedding module."""
+
+ import torch
+ from torch import Tensor, nn
+
+ from invokeai.backend.flux.dype.base import DyPEConfig
+ from invokeai.backend.flux.dype.rope import rope_dype
+
+
+ class DyPEEmbedND(nn.Module):
+     """N-dimensional position embedding with DyPE support.
+
+     This class replaces the standard EmbedND from FLUX with a DyPE-aware version
+     that dynamically scales position embeddings based on resolution and timestep.
+
+     The key difference from EmbedND:
+     - Maintains step state (current_sigma, target dimensions)
+     - Uses rope_dype() instead of rope() for frequency computation
+     - Applies timestep-dependent scaling for better high-resolution generation
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         theta: int,
+         axes_dim: list[int],
+         dype_config: DyPEConfig,
+     ):
+         """Initialize DyPE position embedder.
+
+         Args:
+             dim: Total embedding dimension (sum of axes_dim)
+             theta: RoPE base frequency
+             axes_dim: Dimension allocation per axis (e.g., [16, 56, 56] for FLUX)
+             dype_config: DyPE configuration
+         """
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+         self.axes_dim = axes_dim
+         self.dype_config = dype_config
+
+         # Step state - updated before each denoising step
+         self._current_sigma: float = 1.0
+         self._target_height: int = 1024
+         self._target_width: int = 1024
+
+     def set_step_state(self, sigma: float, height: int, width: int) -> None:
+         """Update the step state before each denoising step.
+
+         This method should be called by the DyPE extension before each step
+         to update the current noise level and target dimensions.
+
+         Args:
+             sigma: Current noise level (timestep value, 1.0 = full noise)
+             height: Target image height in pixels
+             width: Target image width in pixels
+         """
+         self._current_sigma = sigma
+         self._target_height = height
+         self._target_width = width
+
+     def forward(self, ids: Tensor) -> Tensor:
+         """Compute position embeddings with DyPE scaling.
+
+         Args:
+             ids: Position indices tensor with shape (batch, seq_len, n_axes).
+                 For FLUX: n_axes=3 (time/channel, height, width)
+
+         Returns:
+             Position embedding tensor with shape (batch, 1, seq_len, dim)
+         """
+         n_axes = ids.shape[-1]
+
+         # Compute RoPE for each axis with DyPE scaling
+         embeddings = []
+         for i in range(n_axes):
+             axis_emb = rope_dype(
+                 pos=ids[..., i],
+                 dim=self.axes_dim[i],
+                 theta=self.theta,
+                 current_sigma=self._current_sigma,
+                 target_height=self._target_height,
+                 target_width=self._target_width,
+                 dype_config=self.dype_config,
+             )
+             embeddings.append(axis_emb)
+
+         # Concatenate embeddings from all axes
+         emb = torch.cat(embeddings, dim=-3)
+
+         return emb.unsqueeze(1)
+
+     @classmethod
+     def from_embednd(
+         cls,
+         embed_nd: nn.Module,
+         dype_config: DyPEConfig,
+     ) -> "DyPEEmbedND":
+         """Create a DyPEEmbedND from an existing EmbedND.
+
+         This is a convenience method for patching an existing FLUX model.
+
+         Args:
+             embed_nd: Original EmbedND module from FLUX
+             dype_config: DyPE configuration
+
+         Returns:
+             New DyPEEmbedND with same parameters
+         """
+         return cls(
+             dim=embed_nd.dim,
+             theta=embed_nd.theta,
+             axes_dim=embed_nd.axes_dim,
+             dype_config=dype_config,
+         )
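
As a rough usage sketch (not part of the diff), the embedder can be exercised standalone; the [16, 56, 56] axis split and theta=10000 are the values FLUX uses, and the output shape follows from the forward() docstring:

    import torch

    from invokeai.backend.flux.dype.base import DyPEConfig
    from invokeai.backend.flux.dype.embed import DyPEEmbedND

    embedder = DyPEEmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56], dype_config=DyPEConfig())
    embedder.set_step_state(sigma=1.0, height=2048, width=2048)  # 2x the 1024px base

    ids = torch.zeros(1, 4096, 3)  # (batch, seq_len, n_axes) position indices
    pe = embedder(ids)
    print(pe.shape)  # torch.Size([1, 1, 4096, 64, 2, 2]) -- per-axis dims 8+28+28 concatenated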
invokeai/backend/flux/dype/presets.py
@@ -0,0 +1,141 @@
+ """DyPE presets and automatic configuration."""
+
+ from dataclasses import dataclass
+ from enum import Enum
+
+ from invokeai.backend.flux.dype.base import DyPEConfig
+
+
+ class DyPEPreset(str, Enum):
+     """Predefined DyPE configurations."""
+
+     OFF = "off"  # DyPE disabled
+     AUTO = "auto"  # Automatically enable based on resolution
+     PRESET_4K = "4k"  # Optimized for 3840x2160 / 4096x2160
+
+
+ @dataclass
+ class DyPEPresetConfig:
+     """Preset configuration values."""
+
+     base_resolution: int
+     method: str
+     dype_scale: float
+     dype_exponent: float
+     dype_start_sigma: float
+
+
+ # Predefined preset configurations
+ DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
+     DyPEPreset.PRESET_4K: DyPEPresetConfig(
+         base_resolution=1024,
+         method="vision_yarn",
+         dype_scale=2.0,
+         dype_exponent=2.0,
+         dype_start_sigma=1.0,
+     ),
+ }
+
+
+ def get_dype_config_for_resolution(
+     width: int,
+     height: int,
+     base_resolution: int = 1024,
+     activation_threshold: int = 1536,
+ ) -> DyPEConfig | None:
+     """Automatically determine DyPE config based on target resolution.
+
+     FLUX can handle resolutions up to ~1.5x natively without significant artifacts.
+     DyPE is only activated when the resolution exceeds the activation threshold.
+
+     Args:
+         width: Target image width in pixels
+         height: Target image height in pixels
+         base_resolution: Native training resolution of the model (for scale calculation)
+         activation_threshold: Resolution threshold above which DyPE is activated
+
+     Returns:
+         DyPEConfig if DyPE should be enabled, None otherwise
+     """
+     max_dim = max(width, height)
+
+     if max_dim <= activation_threshold:
+         return None  # FLUX can handle this natively
+
+     # Calculate scaling factor based on base_resolution
+     scale = max_dim / base_resolution
+
+     # Dynamic parameters based on scaling
+     # Higher resolution = higher dype_scale, capped at 8.0
+     dynamic_dype_scale = min(2.0 * scale, 8.0)
+
+     return DyPEConfig(
+         enable_dype=True,
+         base_resolution=base_resolution,
+         method="vision_yarn",
+         dype_scale=dynamic_dype_scale,
+         dype_exponent=2.0,
+         dype_start_sigma=1.0,
+     )
+
+
+ def get_dype_config_from_preset(
+     preset: DyPEPreset,
+     width: int,
+     height: int,
+     custom_scale: float | None = None,
+     custom_exponent: float | None = None,
+ ) -> DyPEConfig | None:
+     """Get DyPE configuration from a preset or custom values.
+
+     Args:
+         preset: The DyPE preset to use
+         width: Target image width
+         height: Target image height
+         custom_scale: Optional custom dype_scale (overrides preset)
+         custom_exponent: Optional custom dype_exponent (overrides preset)
+
+     Returns:
+         DyPEConfig if DyPE should be enabled, None otherwise
+     """
+     if preset == DyPEPreset.OFF:
+         # Check if custom values are provided even with preset=OFF
+         if custom_scale is not None:
+             return DyPEConfig(
+                 enable_dype=True,
+                 base_resolution=1024,
+                 method="vision_yarn",
+                 dype_scale=custom_scale,
+                 dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
+                 dype_start_sigma=1.0,
+             )
+         return None
+
+     if preset == DyPEPreset.AUTO:
+         config = get_dype_config_for_resolution(
+             width=width,
+             height=height,
+             base_resolution=1024,
+             activation_threshold=1536,
+         )
+         # Apply custom overrides if provided
+         if config is not None:
+             if custom_scale is not None:
+                 config.dype_scale = custom_scale
+             if custom_exponent is not None:
+                 config.dype_exponent = custom_exponent
+         return config
+
+     # Use preset configuration
+     preset_config = DYPE_PRESETS.get(preset)
+     if preset_config is None:
+         return None
+
+     return DyPEConfig(
+         enable_dype=True,
+         base_resolution=preset_config.base_resolution,
+         method=preset_config.method,
+         dype_scale=custom_scale if custom_scale is not None else preset_config.dype_scale,
+         dype_exponent=custom_exponent if custom_exponent is not None else preset_config.dype_exponent,
+         dype_start_sigma=preset_config.dype_start_sigma,
+     )
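
A quick sketch (not part of the diff) of how the preset logic resolves, following the thresholds and caps defined above:

    from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_from_preset

    # 1024x1024 is under the 1536px activation threshold, so AUTO stays off
    assert get_dype_config_from_preset(DyPEPreset.AUTO, width=1024, height=1024) is None

    # 4096x4096 activates AUTO: dype_scale = min(2.0 * 4096/1024, 8.0) = 8.0
    cfg = get_dype_config_from_preset(DyPEPreset.AUTO, width=4096, height=4096)
    assert cfg is not None and cfg.dype_scale == 8.0 and cfg.method == "vision_yarn"

    # The explicit "4k" preset skips the threshold check and uses its fixed values
    cfg_4k = get_dype_config_from_preset(DyPEPreset.PRESET_4K, width=3840, height=2160)
    assert cfg_4k is not None and cfg_4k.dype_scale == 2.0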
invokeai/backend/flux/dype/rope.py
@@ -0,0 +1,110 @@
+ """DyPE-enhanced RoPE (Rotary Position Embedding) functions."""
+
+ import torch
+ from einops import rearrange
+ from torch import Tensor
+
+ from invokeai.backend.flux.dype.base import (
+     DyPEConfig,
+     compute_ntk_freqs,
+     compute_vision_yarn_freqs,
+     compute_yarn_freqs,
+ )
+
+
+ def rope_dype(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     current_sigma: float,
+     target_height: int,
+     target_width: int,
+     dype_config: DyPEConfig,
+ ) -> Tensor:
+     """Compute RoPE with Dynamic Position Extrapolation.
+
+     This is the core DyPE function that replaces the standard rope() function.
+     It applies resolution-aware and timestep-aware scaling to position embeddings.
+
+     Args:
+         pos: Position indices tensor
+         dim: Embedding dimension per axis
+         theta: RoPE base frequency (typically 10000)
+         current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+         target_height: Target image height in pixels
+         target_width: Target image width in pixels
+         dype_config: DyPE configuration
+
+     Returns:
+         Rotary position embedding tensor with shape suitable for FLUX attention
+     """
+     assert dim % 2 == 0
+
+     # Calculate scaling factors
+     base_res = dype_config.base_resolution
+     scale_h = target_height / base_res
+     scale_w = target_width / base_res
+     scale = max(scale_h, scale_w)
+
+     # If no scaling needed and DyPE disabled, use base method
+     if not dype_config.enable_dype or scale <= 1.0:
+         return _rope_base(pos, dim, theta)
+
+     # Select method and compute frequencies
+     method = dype_config.method
+
+     if method == "vision_yarn":
+         cos, sin = compute_vision_yarn_freqs(
+             pos=pos,
+             dim=dim,
+             theta=theta,
+             scale_h=scale_h,
+             scale_w=scale_w,
+             current_sigma=current_sigma,
+             dype_config=dype_config,
+         )
+     elif method == "yarn":
+         cos, sin = compute_yarn_freqs(
+             pos=pos,
+             dim=dim,
+             theta=theta,
+             scale=scale,
+             current_sigma=current_sigma,
+             dype_config=dype_config,
+         )
+     elif method == "ntk":
+         cos, sin = compute_ntk_freqs(
+             pos=pos,
+             dim=dim,
+             theta=theta,
+             scale=scale,
+         )
+     else:  # "base"
+         return _rope_base(pos, dim, theta)
+
+     # Construct rotation matrix from cos/sin
+     # Output shape: (batch, seq_len, dim/2, 2, 2)
+     out = torch.stack([cos, -sin, sin, cos], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+     return out.to(dtype=pos.dtype, device=pos.device)
+
+
+ def _rope_base(pos: Tensor, dim: int, theta: int) -> Tensor:
+     """Standard RoPE without DyPE scaling.
+
+     This matches the original rope() function from invokeai.backend.flux.math.
+     """
+     assert dim % 2 == 0
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     omega = 1.0 / (theta**scale)
+
+     out = torch.einsum("...n,d->...nd", pos.to(dtype), omega)
+     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+     return out.to(dtype=pos.dtype, device=pos.device)
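
To confirm the dispatch behavior, a small sketch (not part of the diff; _rope_base is module-private and imported here only for comparison): at or below the base resolution rope_dype reduces to standard RoPE, while above it the NTK-scaled theta changes the embedding:

    import torch

    from invokeai.backend.flux.dype.base import DyPEConfig
    from invokeai.backend.flux.dype.rope import _rope_base, rope_dype

    pos = torch.arange(64, dtype=torch.float32)[None, :]  # (batch=1, seq_len=64)
    cfg = DyPEConfig()

    at_base = rope_dype(pos, dim=56, theta=10_000, current_sigma=1.0,
                        target_height=1024, target_width=1024, dype_config=cfg)
    assert torch.equal(at_base, _rope_base(pos, dim=56, theta=10_000))  # scale <= 1.0 fallback

    hi_res = rope_dype(pos, dim=56, theta=10_000, current_sigma=1.0,
                       target_height=4096, target_width=4096, dype_config=cfg)
    assert not torch.equal(hi_res, at_base)  # vision_yarn path with scaled theta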
invokeai/backend/flux/extensions/dype_extension.py
@@ -0,0 +1,91 @@
+ """DyPE extension for FLUX denoising pipeline."""
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING
+
+ from invokeai.backend.flux.dype.base import DyPEConfig
+ from invokeai.backend.flux.dype.embed import DyPEEmbedND
+
+ if TYPE_CHECKING:
+     from invokeai.backend.flux.model import Flux
+
+
+ @dataclass
+ class DyPEExtension:
+     """Extension for Dynamic Position Extrapolation in FLUX models.
+
+     This extension manages the patching of the FLUX model's position embedder
+     and updates the step state during denoising.
+
+     Usage:
+         1. Create extension with config and target dimensions
+         2. Call patch_model() to replace pe_embedder with DyPE version
+         3. Call update_step_state() before each denoising step
+         4. Call restore_model() after denoising to restore original embedder
+     """
+
+     config: DyPEConfig
+     target_height: int
+     target_width: int
+
+     def patch_model(self, model: "Flux") -> tuple[DyPEEmbedND, object]:
+         """Patch the model's position embedder with DyPE version.
+
+         Args:
+             model: The FLUX model to patch
+
+         Returns:
+             Tuple of (new DyPE embedder, original embedder for restoration)
+         """
+         original_embedder = model.pe_embedder
+
+         dype_embedder = DyPEEmbedND.from_embednd(
+             embed_nd=original_embedder,
+             dype_config=self.config,
+         )
+
+         # Set initial state
+         dype_embedder.set_step_state(
+             sigma=1.0,
+             height=self.target_height,
+             width=self.target_width,
+         )
+
+         # Replace the embedder
+         model.pe_embedder = dype_embedder
+
+         return dype_embedder, original_embedder
+
+     def update_step_state(
+         self,
+         embedder: DyPEEmbedND,
+         timestep: float,
+         timestep_index: int,
+         total_steps: int,
+     ) -> None:
+         """Update the step state in the DyPE embedder.
+
+         This should be called before each denoising step to update the
+         current noise level for timestep-dependent scaling.
+
+         Args:
+             embedder: The DyPE embedder to update
+             timestep: Current timestep value (sigma/noise level)
+             timestep_index: Current step index (0-based)
+             total_steps: Total number of denoising steps
+         """
+         embedder.set_step_state(
+             sigma=timestep,
+             height=self.target_height,
+             width=self.target_width,
+         )
+
+     @staticmethod
+     def restore_model(model: "Flux", original_embedder: object) -> None:
+         """Restore the original position embedder.
+
+         Args:
+             model: The FLUX model to restore
+             original_embedder: The original embedder saved from patch_model()
+         """
+         model.pe_embedder = original_embedder
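
Pulling the pieces together, the intended lifecycle from the Usage docstring looks roughly like this (a sketch, not part of the diff; `transformer` and `sigmas` stand in for the loaded Flux model and the denoising schedule):

    from invokeai.backend.flux.dype.presets import DyPEPreset, get_dype_config_from_preset
    from invokeai.backend.flux.extensions.dype_extension import DyPEExtension

    config = get_dype_config_from_preset(DyPEPreset.AUTO, width=4096, height=4096)
    if config is not None:
        ext = DyPEExtension(config=config, target_height=4096, target_width=4096)
        dype_embedder, original = ext.patch_model(transformer)  # swaps pe_embedder
        try:
            for i, sigma in enumerate(sigmas):  # schedule runs from high to low noise
                ext.update_step_state(dype_embedder, timestep=sigma, timestep_index=i, total_steps=len(sigmas))
                # ... FLUX forward pass / denoising step for this sigma ...
        finally:
            DyPEExtension.restore_model(transformer, original)  # always undo the patch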