InvokeAI 6.10.0rc1__py3-none-any.whl → 6.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (83)
  1. invokeai/app/api/routers/model_manager.py +43 -1
  2. invokeai/app/invocations/fields.py +1 -1
  3. invokeai/app/invocations/flux2_denoise.py +499 -0
  4. invokeai/app/invocations/flux2_klein_model_loader.py +222 -0
  5. invokeai/app/invocations/flux2_klein_text_encoder.py +222 -0
  6. invokeai/app/invocations/flux2_vae_decode.py +106 -0
  7. invokeai/app/invocations/flux2_vae_encode.py +88 -0
  8. invokeai/app/invocations/flux_denoise.py +77 -3
  9. invokeai/app/invocations/flux_lora_loader.py +1 -1
  10. invokeai/app/invocations/flux_model_loader.py +2 -5
  11. invokeai/app/invocations/ideal_size.py +6 -1
  12. invokeai/app/invocations/metadata.py +4 -0
  13. invokeai/app/invocations/metadata_linked.py +47 -0
  14. invokeai/app/invocations/model.py +1 -0
  15. invokeai/app/invocations/pbr_maps.py +59 -0
  16. invokeai/app/invocations/z_image_denoise.py +244 -84
  17. invokeai/app/invocations/z_image_image_to_latents.py +9 -1
  18. invokeai/app/invocations/z_image_latents_to_image.py +9 -1
  19. invokeai/app/invocations/z_image_seed_variance_enhancer.py +110 -0
  20. invokeai/app/services/config/config_default.py +3 -1
  21. invokeai/app/services/invocation_stats/invocation_stats_common.py +6 -6
  22. invokeai/app/services/invocation_stats/invocation_stats_default.py +9 -4
  23. invokeai/app/services/model_manager/model_manager_default.py +7 -0
  24. invokeai/app/services/model_records/model_records_base.py +4 -2
  25. invokeai/app/services/shared/invocation_context.py +15 -0
  26. invokeai/app/services/shared/sqlite/sqlite_util.py +2 -0
  27. invokeai/app/services/shared/sqlite_migrator/migrations/migration_25.py +61 -0
  28. invokeai/app/util/step_callback.py +58 -2
  29. invokeai/backend/flux/denoise.py +338 -118
  30. invokeai/backend/flux/dype/__init__.py +31 -0
  31. invokeai/backend/flux/dype/base.py +260 -0
  32. invokeai/backend/flux/dype/embed.py +116 -0
  33. invokeai/backend/flux/dype/presets.py +148 -0
  34. invokeai/backend/flux/dype/rope.py +110 -0
  35. invokeai/backend/flux/extensions/dype_extension.py +91 -0
  36. invokeai/backend/flux/schedulers.py +62 -0
  37. invokeai/backend/flux/util.py +35 -1
  38. invokeai/backend/flux2/__init__.py +4 -0
  39. invokeai/backend/flux2/denoise.py +280 -0
  40. invokeai/backend/flux2/ref_image_extension.py +294 -0
  41. invokeai/backend/flux2/sampling_utils.py +209 -0
  42. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  43. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  44. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  45. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  46. invokeai/backend/model_manager/configs/factory.py +19 -1
  47. invokeai/backend/model_manager/configs/lora.py +36 -0
  48. invokeai/backend/model_manager/configs/main.py +395 -3
  49. invokeai/backend/model_manager/configs/qwen3_encoder.py +116 -7
  50. invokeai/backend/model_manager/configs/vae.py +104 -2
  51. invokeai/backend/model_manager/load/model_cache/model_cache.py +107 -2
  52. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  53. invokeai/backend/model_manager/load/model_loaders/flux.py +1020 -8
  54. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  55. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  56. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
  57. invokeai/backend/model_manager/load/model_loaders/z_image.py +158 -31
  58. invokeai/backend/model_manager/starter_models.py +141 -4
  59. invokeai/backend/model_manager/taxonomy.py +31 -4
  60. invokeai/backend/model_manager/util/select_hf_files.py +3 -2
  61. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  62. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  63. invokeai/backend/util/vae_working_memory.py +0 -2
  64. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  65. invokeai/frontend/web/dist/assets/App-D13dX7be.js +161 -0
  66. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-u_ZjhQTI.js} +1 -1
  67. invokeai/frontend/web/dist/assets/index-BB0nHmDe.js +530 -0
  68. invokeai/frontend/web/dist/index.html +1 -1
  69. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  70. invokeai/frontend/web/dist/locales/en.json +85 -6
  71. invokeai/frontend/web/dist/locales/it.json +135 -15
  72. invokeai/frontend/web/dist/locales/ru.json +11 -11
  73. invokeai/version/invokeai_version.py +1 -1
  74. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/METADATA +8 -2
  75. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/RECORD +81 -57
  76. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/WHEEL +1 -1
  77. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  78. invokeai/frontend/web/dist/assets/index-dgSJAY--.js +0 -530
  79. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/entry_points.txt +0 -0
  80. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE +0 -0
  81. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  82. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  83. {invokeai-6.10.0rc1.dist-info → invokeai-6.11.0.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/dype/base.py
@@ -0,0 +1,260 @@
+ """DyPE base configuration and utilities."""
+
+ import math
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ from torch import Tensor
+
+
+ @dataclass
+ class DyPEConfig:
+     """Configuration for Dynamic Position Extrapolation."""
+
+     enable_dype: bool = True
+     base_resolution: int = 1024  # Native training resolution
+     method: Literal["vision_yarn", "yarn", "ntk", "base"] = "vision_yarn"
+     dype_scale: float = 2.0  # Magnitude λs (0.0-8.0)
+     dype_exponent: float = 2.0  # Decay speed λt (0.0-1000.0)
+     dype_start_sigma: float = 1.0  # When DyPE decay starts
+
+
+ def get_mscale(scale: float, mscale_factor: float = 1.0) -> float:
+     """Calculate the magnitude scaling factor.
+
+     Args:
+         scale: The resolution scaling factor
+         mscale_factor: Adjustment factor for the scaling
+
+     Returns:
+         The magnitude scaling factor
+     """
+     if scale <= 1.0:
+         return 1.0
+     return mscale_factor * math.log(scale) + 1.0
+
+
+ def get_timestep_mscale(
+     scale: float,
+     current_sigma: float,
+     dype_scale: float,
+     dype_exponent: float,
+     dype_start_sigma: float,
+ ) -> float:
+     """Calculate timestep-dependent magnitude scaling.
+
+     The key insight of DyPE: early steps focus on low frequencies (global structure),
+     late steps on high frequencies (details). This function modulates the scaling
+     based on the current timestep/sigma.
+
+     Args:
+         scale: Resolution scaling factor
+         current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+         dype_scale: DyPE magnitude (λs)
+         dype_exponent: DyPE decay speed (λt)
+         dype_start_sigma: Sigma threshold to start decay
+
+     Returns:
+         Timestep-modulated scaling factor
+     """
+     if scale <= 1.0:
+         return 1.0
+
+     # Normalize sigma to [0, 1] range relative to start_sigma
+     if current_sigma >= dype_start_sigma:
+         t_normalized = 1.0
+     else:
+         t_normalized = current_sigma / dype_start_sigma
+
+     # Apply exponential decay: stronger extrapolation early, weaker late
+     # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late
+     decay = math.exp(-dype_exponent * (1.0 - t_normalized))
+
+     # Base mscale from resolution
+     base_mscale = get_mscale(scale)
+
+     # Interpolate between base_mscale and 1.0 based on decay and dype_scale
+     # When decay=1 (early): full dype_scale-weighted extrapolation
+     # As decay approaches 0 (late): fall back toward 1.0 (no extra scaling)
+     scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay
+
+     return scaled_mscale
+
+
+ def compute_vision_yarn_freqs(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     scale_h: float,
+     scale_w: float,
+     current_sigma: float,
+     dype_config: DyPEConfig,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute RoPE frequencies using NTK-aware scaling for high resolutions.
+
+     This method extends FLUX's position encoding to handle resolutions beyond
+     the 1024px training resolution by scaling the base frequency (theta).
+
+     The NTK-aware approach smoothly interpolates frequencies to cover larger
+     position ranges without breaking the attention patterns.
+
+     DyPE (Dynamic Position Extrapolation) modulates the NTK scaling based on
+     the current timestep - stronger extrapolation in early steps (global structure),
+     weaker in late steps (fine details).
+
+     Args:
+         pos: Position tensor
+         dim: Embedding dimension
+         theta: RoPE base frequency
+         scale_h: Height scaling factor
+         scale_w: Width scaling factor
+         current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+         dype_config: DyPE configuration
+
+     Returns:
+         Tuple of (cos, sin) frequency tensors
+     """
+     assert dim % 2 == 0
+
+     # Use the larger scale for NTK calculation
+     scale = max(scale_h, scale_w)
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     # NTK-aware theta scaling: extends position coverage for high-res
+     # Formula: theta_scaled = theta * scale^(dim/(dim-2))
+     # This increases the wavelength of position encodings proportionally
+     if scale > 1.0:
+         ntk_alpha = scale ** (dim / (dim - 2))
+
+         # Apply timestep-dependent DyPE modulation
+         # mscale controls how strongly we apply the NTK extrapolation
+         # Early steps (high sigma): stronger extrapolation for global structure
+         # Late steps (low sigma): weaker extrapolation for fine details
+         mscale = get_timestep_mscale(
+             scale=scale,
+             current_sigma=current_sigma,
+             dype_scale=dype_config.dype_scale,
+             dype_exponent=dype_config.dype_exponent,
+             dype_start_sigma=dype_config.dype_start_sigma,
+         )
+
+         # Modulate NTK alpha by mscale
+         # When mscale > 1: interpolate towards stronger extrapolation
+         # When mscale = 1: use base NTK alpha
+         modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
+         scaled_theta = theta * modulated_alpha
+     else:
+         scaled_theta = theta
+
+     # Standard RoPE frequency computation
+     freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     freqs = 1.0 / (scaled_theta**freq_seq)
+
+     # Compute angles = position * frequency
+     angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+
+     return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+ def compute_yarn_freqs(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     scale: float,
+     current_sigma: float,
+     dype_config: DyPEConfig,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute RoPE frequencies using the YaRN/NTK method.
+
+     Uses NTK-aware theta scaling for high-resolution support with
+     timestep-dependent DyPE modulation.
+
+     Args:
+         pos: Position tensor
+         dim: Embedding dimension
+         theta: RoPE base frequency
+         scale: Uniform scaling factor
+         current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+         dype_config: DyPE configuration
+
+     Returns:
+         Tuple of (cos, sin) frequency tensors
+     """
+     assert dim % 2 == 0
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     # NTK-aware theta scaling with DyPE modulation
+     if scale > 1.0:
+         ntk_alpha = scale ** (dim / (dim - 2))
+
+         # Apply timestep-dependent DyPE modulation
+         mscale = get_timestep_mscale(
+             scale=scale,
+             current_sigma=current_sigma,
+             dype_scale=dype_config.dype_scale,
+             dype_exponent=dype_config.dype_exponent,
+             dype_start_sigma=dype_config.dype_start_sigma,
+         )
+
+         # Modulate NTK alpha by mscale
+         modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale
+         scaled_theta = theta * modulated_alpha
+     else:
+         scaled_theta = theta
+
+     freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     freqs = 1.0 / (scaled_theta**freq_seq)
+
+     angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+
+     return cos.to(pos.dtype), sin.to(pos.dtype)
+
+
+ def compute_ntk_freqs(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     scale: float,
+ ) -> tuple[Tensor, Tensor]:
+     """Compute RoPE frequencies using the NTK method.
+
+     Neural Tangent Kernel approach - continuous frequency scaling without
+     timestep dependency.
+
+     Args:
+         pos: Position tensor
+         dim: Embedding dimension
+         theta: RoPE base frequency
+         scale: Scaling factor
+
+     Returns:
+         Tuple of (cos, sin) frequency tensors
+     """
+     assert dim % 2 == 0
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     # NTK scaling
+     scaled_theta = theta * (scale ** (dim / (dim - 2)))
+
+     freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     freqs = 1.0 / (scaled_theta**freq_seq)
+
+     angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs)
+
+     cos = torch.cos(angles)
+     sin = torch.sin(angles)
+
+     return cos.to(pos.dtype), sin.to(pos.dtype)
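The timestep decay in get_timestep_mscale is the core DyPE mechanism, so a quick numeric trace helps. The following is an illustrative sketch, not part of the diff; the 2048px target and the printed values simply exercise the defaults defined above.

# Illustrative only - not part of the package diff.
from invokeai.backend.flux.dype.base import DyPEConfig, get_timestep_mscale

config = DyPEConfig()  # defaults: dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0
scale = 2048 / config.base_resolution  # 2.0 for a 2048px target

for sigma in (1.0, 0.75, 0.5, 0.25, 0.0):
    mscale = get_timestep_mscale(
        scale=scale,
        current_sigma=sigma,
        dype_scale=config.dype_scale,
        dype_exponent=config.dype_exponent,
        dype_start_sigma=config.dype_start_sigma,
    )
    print(f"sigma={sigma:.2f} -> mscale={mscale:.3f}")

# With these defaults, mscale ~= 2.386 at sigma=1.0 (strong extrapolation for
# global structure) and ~= 1.188 at sigma=0.0 (close to unscaled, for detail).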
invokeai/backend/flux/dype/embed.py
@@ -0,0 +1,116 @@
+ """DyPE-enhanced position embedding module."""
+
+ import torch
+ from torch import Tensor, nn
+
+ from invokeai.backend.flux.dype.base import DyPEConfig
+ from invokeai.backend.flux.dype.rope import rope_dype
+
+
+ class DyPEEmbedND(nn.Module):
+     """N-dimensional position embedding with DyPE support.
+
+     This class replaces the standard EmbedND from FLUX with a DyPE-aware version
+     that dynamically scales position embeddings based on resolution and timestep.
+
+     The key differences from EmbedND:
+     - Maintains step state (current_sigma, target dimensions)
+     - Uses rope_dype() instead of rope() for frequency computation
+     - Applies timestep-dependent scaling for better high-resolution generation
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         theta: int,
+         axes_dim: list[int],
+         dype_config: DyPEConfig,
+     ):
+         """Initialize the DyPE position embedder.
+
+         Args:
+             dim: Total embedding dimension (sum of axes_dim)
+             theta: RoPE base frequency
+             axes_dim: Dimension allocation per axis (e.g., [16, 56, 56] for FLUX)
+             dype_config: DyPE configuration
+         """
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+         self.axes_dim = axes_dim
+         self.dype_config = dype_config
+
+         # Step state - updated before each denoising step
+         self._current_sigma: float = 1.0
+         self._target_height: int = 1024
+         self._target_width: int = 1024
+
+     def set_step_state(self, sigma: float, height: int, width: int) -> None:
+         """Update the step state before each denoising step.
+
+         This method should be called by the DyPE extension before each step
+         to update the current noise level and target dimensions.
+
+         Args:
+             sigma: Current noise level (timestep value, 1.0 = full noise)
+             height: Target image height in pixels
+             width: Target image width in pixels
+         """
+         self._current_sigma = sigma
+         self._target_height = height
+         self._target_width = width
+
+     def forward(self, ids: Tensor) -> Tensor:
+         """Compute position embeddings with DyPE scaling.
+
+         Args:
+             ids: Position indices tensor with shape (batch, seq_len, n_axes).
+                 For FLUX: n_axes=3 (time/channel, height, width)
+
+         Returns:
+             Position embedding tensor with shape (batch, 1, seq_len, dim/2, 2, 2)
+         """
+         n_axes = ids.shape[-1]
+
+         # Compute RoPE for each axis with DyPE scaling
+         embeddings = []
+         for i in range(n_axes):
+             axis_emb = rope_dype(
+                 pos=ids[..., i],
+                 dim=self.axes_dim[i],
+                 theta=self.theta,
+                 current_sigma=self._current_sigma,
+                 target_height=self._target_height,
+                 target_width=self._target_width,
+                 dype_config=self.dype_config,
+             )
+             embeddings.append(axis_emb)
+
+         # Concatenate embeddings from all axes
+         emb = torch.cat(embeddings, dim=-3)
+
+         return emb.unsqueeze(1)
+
+     @classmethod
+     def from_embednd(
+         cls,
+         embed_nd: nn.Module,
+         dype_config: DyPEConfig,
+     ) -> "DyPEEmbedND":
+         """Create a DyPEEmbedND from an existing EmbedND.
+
+         This is a convenience method for patching an existing FLUX model.
+
+         Args:
+             embed_nd: Original EmbedND module from FLUX
+             dype_config: DyPE configuration
+
+         Returns:
+             New DyPEEmbedND with the same parameters
+         """
+         return cls(
+             dim=embed_nd.dim,
+             theta=embed_nd.theta,
+             axes_dim=embed_nd.axes_dim,
+             dype_config=dype_config,
+         )
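To show how this class is intended to slot into a model, here is a hedged sketch of the patching flow. It assumes a loaded FLUX transformer that exposes its position embedder as pe_embedder (the attribute name used by invokeai.backend.flux.model.Flux); adjust if your model names it differently. The per-step set_step_state call is what the DyPE extension is expected to drive.

# A sketch, not part of the diff. Assumes `transformer` is a loaded
# invokeai.backend.flux.model.Flux instance with a `pe_embedder` attribute.
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.embed import DyPEEmbedND

config = DyPEConfig(dype_scale=2.0, dype_exponent=2.0)

# Swap the stock EmbedND for the DyPE-aware version, reusing its parameters.
transformer.pe_embedder = DyPEEmbedND.from_embednd(transformer.pe_embedder, config)

# Before each denoising step, push the current sigma and target size so that
# forward() can modulate the RoPE frequencies for that step.
transformer.pe_embedder.set_step_state(sigma=0.8, height=2048, width=2048)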
invokeai/backend/flux/dype/presets.py
@@ -0,0 +1,148 @@
+ """DyPE presets and automatic configuration."""
+
+ from dataclasses import dataclass
+ from typing import Literal
+
+ from invokeai.backend.flux.dype.base import DyPEConfig
+
+ # DyPE preset type - using Literal for proper frontend dropdown support
+ DyPEPreset = Literal["off", "manual", "auto", "4k"]
+
+ # Constants for preset values
+ DYPE_PRESET_OFF: DyPEPreset = "off"
+ DYPE_PRESET_MANUAL: DyPEPreset = "manual"
+ DYPE_PRESET_AUTO: DyPEPreset = "auto"
+ DYPE_PRESET_4K: DyPEPreset = "4k"
+
+ # Human-readable labels for the UI
+ DYPE_PRESET_LABELS: dict[str, str] = {
+     "off": "Off",
+     "manual": "Manual",
+     "auto": "Auto (>1536px)",
+     "4k": "4K Optimized",
+ }
+
+
+ @dataclass
+ class DyPEPresetConfig:
+     """Preset configuration values."""
+
+     base_resolution: int
+     method: str
+     dype_scale: float
+     dype_exponent: float
+     dype_start_sigma: float
+
+
+ # Predefined preset configurations
+ DYPE_PRESETS: dict[DyPEPreset, DyPEPresetConfig] = {
+     DYPE_PRESET_4K: DyPEPresetConfig(
+         base_resolution=1024,
+         method="vision_yarn",
+         dype_scale=2.0,
+         dype_exponent=2.0,
+         dype_start_sigma=1.0,
+     ),
+ }
+
+
+ def get_dype_config_for_resolution(
+     width: int,
+     height: int,
+     base_resolution: int = 1024,
+     activation_threshold: int = 1536,
+ ) -> DyPEConfig | None:
+     """Automatically determine DyPE config based on target resolution.
+
+     FLUX can handle resolutions up to ~1.5x natively without significant artifacts.
+     DyPE is only activated when the resolution exceeds the activation threshold.
+
+     Args:
+         width: Target image width in pixels
+         height: Target image height in pixels
+         base_resolution: Native training resolution of the model (for scale calculation)
+         activation_threshold: Resolution threshold above which DyPE is activated
+
+     Returns:
+         DyPEConfig if DyPE should be enabled, None otherwise
+     """
+     max_dim = max(width, height)
+
+     if max_dim <= activation_threshold:
+         return None  # FLUX can handle this natively
+
+     # Calculate scaling factor based on base_resolution
+     scale = max_dim / base_resolution
+
+     # Dynamic parameters based on scaling
+     # Higher resolution = higher dype_scale, capped at 8.0
+     dynamic_dype_scale = min(2.0 * scale, 8.0)
+
+     return DyPEConfig(
+         enable_dype=True,
+         base_resolution=base_resolution,
+         method="vision_yarn",
+         dype_scale=dynamic_dype_scale,
+         dype_exponent=2.0,
+         dype_start_sigma=1.0,
+     )
+
+
+ def get_dype_config_from_preset(
+     preset: DyPEPreset,
+     width: int,
+     height: int,
+     custom_scale: float | None = None,
+     custom_exponent: float | None = None,
+ ) -> DyPEConfig | None:
+     """Get DyPE configuration from a preset or custom values.
+
+     Args:
+         preset: The DyPE preset to use
+         width: Target image width
+         height: Target image height
+         custom_scale: Optional custom dype_scale (only used with 'manual' preset)
+         custom_exponent: Optional custom dype_exponent (only used with 'manual' preset)
+
+     Returns:
+         DyPEConfig if DyPE should be enabled, None otherwise
+     """
+     if preset == DYPE_PRESET_OFF:
+         return None
+
+     if preset == DYPE_PRESET_MANUAL:
+         # Manual mode - custom values can override defaults
+         max_dim = max(width, height)
+         scale = max_dim / 1024
+         dynamic_dype_scale = min(2.0 * scale, 8.0)
+         return DyPEConfig(
+             enable_dype=True,
+             base_resolution=1024,
+             method="vision_yarn",
+             dype_scale=custom_scale if custom_scale is not None else dynamic_dype_scale,
+             dype_exponent=custom_exponent if custom_exponent is not None else 2.0,
+             dype_start_sigma=1.0,
+         )
+
+     if preset == DYPE_PRESET_AUTO:
+         # Auto preset - custom values are ignored
+         return get_dype_config_for_resolution(
+             width=width,
+             height=height,
+             base_resolution=1024,
+             activation_threshold=1536,
+         )
+
+     # Use preset configuration (4K etc.) - custom values are ignored
+     preset_config = DYPE_PRESETS.get(preset)
+     if preset_config is None:
+         return None
+
+     return DyPEConfig(
+         enable_dype=True,
+         base_resolution=preset_config.base_resolution,
+         method=preset_config.method,
+         dype_scale=preset_config.dype_scale,
+         dype_exponent=preset_config.dype_exponent,
+         dype_start_sigma=preset_config.dype_start_sigma,
+     )
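A short usage sketch (not part of the diff) showing how the preset resolver behaves at a few resolutions; the asserted values follow directly from the formulas above.

# Illustrative only.
from invokeai.backend.flux.dype.presets import get_dype_config_from_preset

# Below the 1536px activation threshold, "auto" leaves DyPE off entirely.
assert get_dype_config_from_preset("auto", width=1024, height=1024) is None

# Above the threshold, "auto" scales dype_scale with resolution, capped at 8.0.
cfg = get_dype_config_from_preset("auto", width=2048, height=2048)
assert cfg is not None and cfg.dype_scale == 4.0  # min(2.0 * 2048/1024, 8.0)

# "manual" computes the same dynamic default but honors explicit overrides.
cfg = get_dype_config_from_preset("manual", width=4096, height=4096, custom_scale=6.0)
assert cfg is not None and cfg.dype_scale == 6.0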
invokeai/backend/flux/dype/rope.py
@@ -0,0 +1,110 @@
+ """DyPE-enhanced RoPE (Rotary Position Embedding) functions."""
+
+ import torch
+ from einops import rearrange
+ from torch import Tensor
+
+ from invokeai.backend.flux.dype.base import (
+     DyPEConfig,
+     compute_ntk_freqs,
+     compute_vision_yarn_freqs,
+     compute_yarn_freqs,
+ )
+
+
+ def rope_dype(
+     pos: Tensor,
+     dim: int,
+     theta: int,
+     current_sigma: float,
+     target_height: int,
+     target_width: int,
+     dype_config: DyPEConfig,
+ ) -> Tensor:
+     """Compute RoPE with Dynamic Position Extrapolation.
+
+     This is the core DyPE function that replaces the standard rope() function.
+     It applies resolution-aware and timestep-aware scaling to position embeddings.
+
+     Args:
+         pos: Position indices tensor
+         dim: Embedding dimension per axis
+         theta: RoPE base frequency (typically 10000)
+         current_sigma: Current noise level (1.0 = full noise, 0.0 = clean)
+         target_height: Target image height in pixels
+         target_width: Target image width in pixels
+         dype_config: DyPE configuration
+
+     Returns:
+         Rotary position embedding tensor with shape suitable for FLUX attention
+     """
+     assert dim % 2 == 0
+
+     # Calculate scaling factors
+     base_res = dype_config.base_resolution
+     scale_h = target_height / base_res
+     scale_w = target_width / base_res
+     scale = max(scale_h, scale_w)
+
+     # If no scaling is needed or DyPE is disabled, use the base method
+     if not dype_config.enable_dype or scale <= 1.0:
+         return _rope_base(pos, dim, theta)
+
+     # Select method and compute frequencies
+     method = dype_config.method
+
+     if method == "vision_yarn":
+         cos, sin = compute_vision_yarn_freqs(
+             pos=pos,
+             dim=dim,
+             theta=theta,
+             scale_h=scale_h,
+             scale_w=scale_w,
+             current_sigma=current_sigma,
+             dype_config=dype_config,
+         )
+     elif method == "yarn":
+         cos, sin = compute_yarn_freqs(
+             pos=pos,
+             dim=dim,
+             theta=theta,
+             scale=scale,
+             current_sigma=current_sigma,
+             dype_config=dype_config,
+         )
+     elif method == "ntk":
+         cos, sin = compute_ntk_freqs(
+             pos=pos,
+             dim=dim,
+             theta=theta,
+             scale=scale,
+         )
+     else:  # "base"
+         return _rope_base(pos, dim, theta)
+
+     # Construct rotation matrix from cos/sin
+     # Output shape: (batch, seq_len, dim/2, 2, 2)
+     out = torch.stack([cos, -sin, sin, cos], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+     return out.to(dtype=pos.dtype, device=pos.device)
+
+
+ def _rope_base(pos: Tensor, dim: int, theta: int) -> Tensor:
+     """Standard RoPE without DyPE scaling.
+
+     This matches the original rope() function from invokeai.backend.flux.math.
+     """
+     assert dim % 2 == 0
+
+     device = pos.device
+     dtype = torch.float64 if device.type != "mps" else torch.float32
+
+     scale = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim
+     omega = 1.0 / (theta**scale)
+
+     out = torch.einsum("...n,d->...nd", pos.to(dtype), omega)
+     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+
+     return out.to(dtype=pos.dtype, device=pos.device)
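Finally, a self-contained sketch (again, not part of the diff) that calls rope_dype on a toy position grid and shows the branch selection and output shape.

# Illustrative only. dim=56 matches one of FLUX's spatial axes ([16, 56, 56]).
import torch
from invokeai.backend.flux.dype.base import DyPEConfig
from invokeai.backend.flux.dype.rope import rope_dype

pos = torch.arange(16, dtype=torch.float32).unsqueeze(0)  # (batch=1, seq_len=16)
config = DyPEConfig()

# At 2048px the scale is 2.0 > 1.0, so the vision_yarn branch is taken.
pe = rope_dype(pos, dim=56, theta=10_000, current_sigma=1.0,
               target_height=2048, target_width=2048, dype_config=config)
print(pe.shape)  # torch.Size([1, 16, 28, 2, 2]) - (batch, seq_len, dim/2, 2, 2)

# At the base resolution, scale <= 1.0, so this falls through to _rope_base.
pe_base = rope_dype(pos, dim=56, theta=10_000, current_sigma=1.0,
                    target_height=1024, target_width=1024, dype_config=config)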