InvokeAI 6.10.0rc1__py3-none-any.whl → 6.10.0rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- invokeai/app/invocations/flux_denoise.py +15 -1
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/z_image_denoise.py +237 -82
- invokeai/backend/flux/denoise.py +196 -11
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/lora.py +36 -0
- invokeai/backend/model_manager/load/load_default.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +3 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +37 -3
- invokeai/backend/model_manager/starter_models.py +13 -4
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
- invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
- invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-BP0RxJ4G.js} +1 -1
- invokeai/frontend/web/dist/assets/{index-dgSJAY--.js → index-B44qKjrs.js} +51 -51
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +11 -5
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/METADATA +2 -2
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/RECORD +36 -29
- invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/WHEEL +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/entry_points.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/denoise.py
CHANGED
@@ -1,7 +1,9 @@
+import inspect
 import math
 from typing import Callable
 
 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from tqdm import tqdm
 
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,24 +37,207 @@ def denoise(
     # extra img tokens (sequence-wise) - for Kontext conditioning
     img_cond_seq: torch.Tensor | None = None,
     img_cond_seq_ids: torch.Tensor | None = None,
+    # Optional scheduler for alternative sampling methods
+    scheduler: SchedulerMixin | None = None,
 ):
-    #
-
-
-
-
-
-
-
-
-    )
-
+    # Determine if we're using a diffusers scheduler or the built-in Euler method
+    use_scheduler = scheduler is not None
+
+    if use_scheduler:
+        # Initialize scheduler with timesteps
+        # The timesteps list contains values in [0, 1] range (sigmas)
+        # LCM should use num_inference_steps (it has its own sigma schedule),
+        # while other schedulers can use custom sigmas if supported
+        is_lcm = scheduler.__class__.__name__ == "FlowMatchLCMScheduler"
+        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+        if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
+            # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
+            scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+        else:
+            # LCM or scheduler doesn't support custom sigmas - use num_inference_steps
+            # The schedule will be computed by the scheduler itself
+            num_inference_steps = len(timesteps) - 1
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+        # For schedulers like Heun, the number of actual steps may differ
+        # (Heun doubles timesteps internally)
+        num_scheduler_steps = len(scheduler.timesteps)
+        # For user-facing step count, use the original number of denoising steps
+        total_steps = len(timesteps) - 1
+    else:
+        total_steps = len(timesteps) - 1
+        num_scheduler_steps = total_steps
+
     # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
 
     # Store original sequence length for slicing predictions
     original_seq_len = img.shape[1]
 
+    # Track the actual step for user-facing progress (accounts for Heun's double steps)
+    user_step = 0
+
+    if use_scheduler:
+        # Use diffusers scheduler for stepping
+        # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
+        # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
+        pbar = tqdm(total=total_steps, desc="Denoising")
+        for step_index in range(num_scheduler_steps):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # For Heun scheduler, track if we're in first or second order step
+            is_heun = hasattr(scheduler, "state_in_first_order")
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run ControlNet models
+            controlnet_residuals: list[ControlNetFluxOutput] = []
+            for controlnet_extension in controlnet_extensions:
+                controlnet_residuals.append(
+                    controlnet_extension.run_controlnet(
+                        timestep_index=user_step,
+                        total_num_timesteps=total_steps,
+                        img=img,
+                        img_ids=img_ids,
+                        txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                        txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                        y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                        timesteps=t_vec,
+                        guidance=guidance_vec,
+                    )
+                )
+
+            merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+            # Prepare input for model
+            img_input = img
+            img_input_ids = img_ids
+
+            if img_cond is not None:
+                img_input = torch.cat((img_input, img_cond), dim=-1)
+
+            if img_cond_seq is not None:
+                assert img_cond_seq_ids is not None
+                img_input = torch.cat((img_input, img_cond_seq), dim=1)
+                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+            pred = model(
+                img=img_input,
+                img_ids=img_input_ids,
+                txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                timestep_index=user_step,
+                total_num_timesteps=total_steps,
+                controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+                controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+                ip_adapter_extensions=pos_ip_adapter_extensions,
+                regional_prompting_extension=pos_regional_prompting_extension,
+            )
+
+            if img_cond_seq is not None:
+                pred = pred[:, :original_seq_len]
+
+            # Get CFG scale for current user step
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_regional_prompting_extension is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_img_input = img
+                neg_img_input_ids = img_ids
+
+                if img_cond is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+                if img_cond_seq is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+                    neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+                neg_pred = model(
+                    img=neg_img_input,
+                    img_ids=neg_img_input_ids,
+                    txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                    timestep_index=user_step,
+                    total_num_timesteps=total_steps,
+                    controlnet_double_block_residuals=None,
+                    controlnet_single_block_residuals=None,
+                    ip_adapter_extensions=neg_ip_adapter_extensions,
+                    regional_prompting_extension=neg_regional_prompting_extension,
+                )
+
+                if img_cond_seq is not None:
+                    neg_pred = neg_pred[:, :original_seq_len]
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+            # For Heun, only increment user step after second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    # Second order step completed
+                    user_step += 1
+                    # Only call step_callback if we haven't exceeded total_steps
+                    if user_step <= total_steps:
+                        pbar.update(1)
+                        preview_img = img - t_curr * pred
+                        if inpaint_extension is not None:
+                            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                                preview_img, 0.0
+                            )
+                        step_callback(
+                            PipelineIntermediateState(
+                                step=user_step,
+                                order=2,
+                                total_steps=total_steps,
+                                timestep=int(t_curr * 1000),
+                                latents=preview_img,
+                            ),
+                        )
+            else:
+                # For LCM and other first-order schedulers
+                user_step += 1
+                # Only call step_callback if we haven't exceeded total_steps
+                # (LCM scheduler may have more internal steps than user-facing steps)
+                if user_step <= total_steps:
+                    pbar.update(1)
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=1,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+
+        pbar.close()
+        return img
+
+    # Original Euler implementation (when scheduler is None)
     for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
 
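The scheduler branch above follows the standard diffusers flow-matching loop: seed the schedule with set_timesteps(), then advance the latents with step(). A minimal standalone sketch of that pattern (my own illustration, not from the diff: a zero "velocity" prediction stands in for the Flux transformer, and the 4-step sigma schedule is hypothetical):

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

# Sketch of the set_timesteps()/step() pattern used in the scheduler branch.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(sigmas=[1.0, 0.75, 0.5, 0.25], device="cpu")

img = torch.randn(1, 256, 64)  # packed latents: (batch, seq_len, channels)
for timestep in scheduler.timesteps:
    # Normalize the 0-1000 scheduler timestep to the 0-1 range the model expects.
    t_curr = timestep.item() / scheduler.config.num_train_timesteps
    pred = torch.zeros_like(img)  # the real loop calls model(img, ..., timesteps=t_vec)
    img = scheduler.step(model_output=pred, timestep=timestep, sample=img).prev_sample

The sigmas keyword is what lets the diff feed InvokeAI's time-shifted schedule to schedulers that accept it; schedulers whose set_timesteps() lacks that parameter fall back to num_inference_steps, which is exactly what the inspect.signature check above guards.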
invokeai/backend/flux/schedulers.py
ADDED
@@ -0,0 +1,62 @@
+"""Flow Matching scheduler definitions and mapping.
+
+This module provides the scheduler types and mapping for Flow Matching models
+(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
+"""
+
+from typing import Literal, Type
+
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    FlowMatchHeunDiscreteScheduler,
+)
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
+try:
+    from diffusers import FlowMatchLCMScheduler
+
+    _HAS_LCM = True
+except ImportError:
+    _HAS_LCM = False
+
+# Scheduler name literal type for type checking
+FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+FLUX_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes
+FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
+
+
+# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
+# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
+# can be used for experimentation.
+ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+# Human-readable labels for the UI
+ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
+    "euler": "Euler",
+    "heun": "Heun (2nd order)",
+    "lcm": "LCM",
+}
+
+# Mapping from scheduler names to scheduler classes (same as Flux)
+ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+    "euler": FlowMatchEulerDiscreteScheduler,
+    "heun": FlowMatchHeunDiscreteScheduler,
+}
+
+if _HAS_LCM:
+    ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
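Downstream consumers (the invocation files listed above, not shown in this excerpt) presumably resolve the user-selected name through these maps; a minimal hypothetical lookup:

from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP

# Hypothetical consumer: resolve a user-selected name to a scheduler instance.
name = "heun"
print(FLUX_SCHEDULER_LABELS[name])      # -> "Heun (2nd order)"
scheduler = FLUX_SCHEDULER_MAP[name]()  # FlowMatchHeunDiscreteScheduler, default config

Because of the ImportError guard, "lcm" may be absent from the maps on older diffusers installs, so a lookup should be prepared for a KeyError.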
invokeai/backend/image_util/pbr_maps/architecture/block.py
ADDED
@@ -0,0 +1,367 @@
+# Original: https://github.com/joeyballentine/Material-Map-Generator
+# Adopted and optimized for Invoke AI
+
+from collections import OrderedDict
+from typing import Any, List, Literal, Optional
+
+import torch
+import torch.nn as nn
+
+ACTIVATION_LAYER_TYPE = Literal["relu", "leakyrelu", "prelu"]
+NORMALIZATION_LAYER_TYPE = Literal["batch", "instance"]
+PADDING_LAYER_TYPE = Literal["zero", "reflect", "replicate"]
+BLOCK_MODE = Literal["CNA", "NAC", "CNAC"]
+UPCONV_BLOCK_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear"]
+
+
+def act(act_type: ACTIVATION_LAYER_TYPE, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1):
+    """Helper to select Activation Layer"""
+    if act_type == "relu":
+        layer = nn.ReLU(inplace)
+    elif act_type == "leakyrelu":
+        layer = nn.LeakyReLU(neg_slope, inplace)
+    elif act_type == "prelu":
+        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+    return layer
+
+
+def norm(norm_type: NORMALIZATION_LAYER_TYPE, nc: int):
+    """Helper to select Normalization Layer"""
+    if norm_type == "batch":
+        layer = nn.BatchNorm2d(nc, affine=True)
+    elif norm_type == "instance":
+        layer = nn.InstanceNorm2d(nc, affine=False)
+    return layer
+
+
+def pad(pad_type: PADDING_LAYER_TYPE, padding: int):
+    """Helper to select Padding Layer"""
+    if padding == 0 or pad_type == "zero":
+        return None
+    if pad_type == "reflect":
+        layer = nn.ReflectionPad2d(padding)
+    elif pad_type == "replicate":
+        layer = nn.ReplicationPad2d(padding)
+    return layer
+
+
+def get_valid_padding(kernel_size: int, dilation: int):
+    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+    padding = (kernel_size - 1) // 2
+    return padding
+
+
+def sequential(*args: Any):
+    # Flatten Sequential. It unwraps nn.Sequential.
+    if len(args) == 1:
+        if isinstance(args[0], OrderedDict):
+            raise NotImplementedError("sequential does not support OrderedDict input.")
+        return args[0]  # No sequential is needed.
+    modules: List[nn.Module] = []
+    for module in args:
+        if isinstance(module, nn.Sequential):
+            for submodule in module.children():
+                modules.append(submodule)
+        elif isinstance(module, nn.Module):
+            modules.append(module)
+    return nn.Sequential(*modules)
+
+
+def conv_block(
+    in_nc: int,
+    out_nc: int,
+    kernel_size: int,
+    stride: int = 1,
+    dilation: int = 1,
+    groups: int = 1,
+    bias: bool = True,
+    pad_type: Optional[PADDING_LAYER_TYPE] = "zero",
+    norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+    act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
+    mode: BLOCK_MODE = "CNA",
+):
+    """
+    Conv layer with padding, normalization, activation
+    mode: CNA --> Conv -> Norm -> Act
+        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
+    """
+    assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
+    padding = get_valid_padding(kernel_size, dilation)
+    p = pad(pad_type, padding) if pad_type else None
+    padding = padding if pad_type == "zero" else 0
+
+    c = nn.Conv2d(
+        in_nc,
+        out_nc,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        bias=bias,
+        groups=groups,
+    )
+    a = act(act_type) if act_type else None
+    match mode:
+        case "CNA":
+            n = norm(norm_type, out_nc) if norm_type else None
+            return sequential(p, c, n, a)
+        case "NAC":
+            if norm_type is None and act_type is not None:
+                a = act(act_type, inplace=False)
+            n = norm(norm_type, in_nc) if norm_type else None
+            return sequential(n, a, p, c)
+        case "CNAC":
+            n = norm(norm_type, in_nc) if norm_type else None
+            return sequential(n, a, p, c)
+
+
+class ConcatBlock(nn.Module):
+    # Concat the output of a submodule to its input
+    def __init__(self, submodule: nn.Module):
+        super(ConcatBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x: torch.Tensor):
+        output = torch.cat((x, self.sub(x)), dim=1)
+        return output
+
+    def __repr__(self):
+        tmpstr = "Identity .. \n|"
+        modstr = self.sub.__repr__().replace("\n", "\n|")
+        tmpstr = tmpstr + modstr
+        return tmpstr
+
+
+class ShortcutBlock(nn.Module):
+    # Elementwise sum the output of a submodule to its input
+    def __init__(self, submodule: nn.Module):
+        super(ShortcutBlock, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x: torch.Tensor):
+        output = x + self.sub(x)
+        return output
+
+    def __repr__(self):
+        tmpstr = "Identity + \n|"
+        modstr = self.sub.__repr__().replace("\n", "\n|")
+        tmpstr = tmpstr + modstr
+        return tmpstr
+
+
+class ShortcutBlockSPSR(nn.Module):
+    # Elementwise sum the output of a submodule to its input
+    def __init__(self, submodule: nn.Module):
+        super(ShortcutBlockSPSR, self).__init__()
+        self.sub = submodule
+
+    def forward(self, x: torch.Tensor):
+        return x, self.sub
+
+    def __repr__(self):
+        tmpstr = "Identity + \n|"
+        modstr = self.sub.__repr__().replace("\n", "\n|")
+        tmpstr = tmpstr + modstr
+        return tmpstr
+
+
+class ResNetBlock(nn.Module):
+    """
+    ResNet Block, 3-3 style
+    with extra residual scaling used in EDSR
+    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
+    """
+
+    def __init__(
+        self,
+        in_nc: int,
+        mid_nc: int,
+        out_nc: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        pad_type: PADDING_LAYER_TYPE = "zero",
+        norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+        act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
+        mode: BLOCK_MODE = "CNA",
+        res_scale: int = 1,
+    ):
+        super(ResNetBlock, self).__init__()
+        conv0 = conv_block(
+            in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
+        )
+        if mode == "CNA":
+            act_type = None
+        if mode == "CNAC":  # Residual path: |-CNAC-|
+            act_type = None
+            norm_type = None
+        conv1 = conv_block(
+            mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
+        )
+
+        self.res = sequential(conv0, conv1)
+        self.res_scale = res_scale
+
+    def forward(self, x: torch.Tensor):
+        res = self.res(x).mul(self.res_scale)
+        return x + res
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    """
+    Residual Dense Block
+    style: 5 convs
+    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+    """
+
+    def __init__(
+        self,
+        nc: int,
+        kernel_size: int = 3,
+        gc: int = 32,
+        stride: int = 1,
+        bias: bool = True,
+        pad_type: PADDING_LAYER_TYPE = "zero",
+        norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+        act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
+        mode: BLOCK_MODE = "CNA",
+    ):
+        super(ResidualDenseBlock_5C, self).__init__()
+        # gc: growth channel, i.e. intermediate channels
+        self.conv1 = conv_block(
+            nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode
+        )
+        self.conv2 = conv_block(
+            nc + gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+        )
+        self.conv3 = conv_block(
+            nc + 2 * gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+        )
+        self.conv4 = conv_block(
+            nc + 3 * gc,
+            gc,
+            kernel_size,
+            stride,
+            bias=bias,
+            pad_type=pad_type,
+            norm_type=norm_type,
+            act_type=act_type,
+            mode=mode,
+        )
+        if mode == "CNA":
+            last_act = None
+        else:
+            last_act = act_type
+        self.conv5 = conv_block(
+            nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode
+        )
+
+    def forward(self, x: torch.Tensor):
+        x1 = self.conv1(x)
+        x2 = self.conv2(torch.cat((x, x1), 1))
+        x3 = self.conv3(torch.cat((x, x1, x2), 1))
+        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5.mul(0.2) + x
+
+
+class RRDB(nn.Module):
+    """
+    Residual in Residual Dense Block
+    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+    """
+
+    def __init__(
+        self,
+        nc: int,
+        kernel_size: int = 3,
+        gc: int = 32,
+        stride: int = 1,
+        bias: bool = True,
+        pad_type: PADDING_LAYER_TYPE = "zero",
+        norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+        act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
+        mode: BLOCK_MODE = "CNA",
+    ):
+        super(RRDB, self).__init__()
+        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
+        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
+        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
+
+    def forward(self, x: torch.Tensor):
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out.mul(0.2) + x
+
+
+# Upsampler
+def pixelshuffle_block(
+    in_nc: int,
+    out_nc: int,
+    upscale_factor: int = 2,
+    kernel_size: int = 3,
+    stride: int = 1,
+    bias: bool = True,
+    pad_type: PADDING_LAYER_TYPE = "zero",
+    norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+    act_type: ACTIVATION_LAYER_TYPE = "relu",
+):
+    """
+    Pixel shuffle layer
+    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+    Neural Network, CVPR17)
+    """
+    conv = conv_block(
+        in_nc,
+        out_nc * (upscale_factor**2),
+        kernel_size,
+        stride,
+        bias=bias,
+        pad_type=pad_type,
+        norm_type=None,
+        act_type=None,
+    )
+    pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+    n = norm(norm_type, out_nc) if norm_type else None
+    a = act(act_type) if act_type else None
+    return sequential(conv, pixel_shuffle, n, a)
+
+
+def upconv_block(
+    in_nc: int,
+    out_nc: int,
+    upscale_factor: int = 2,
+    kernel_size: int = 3,
+    stride: int = 1,
+    bias: bool = True,
+    pad_type: PADDING_LAYER_TYPE = "zero",
+    norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+    act_type: ACTIVATION_LAYER_TYPE = "relu",
+    mode: UPCONV_BLOCK_MODE = "nearest",
+):
+    # Adopted from https://distill.pub/2016/deconv-checkerboard/
+    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
+    conv = conv_block(
+        in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type
+    )
+    return sequential(upsample, conv)
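As a quick smoke test of the blocks above (my own sketch, not part of the diff), one RRDB followed by an upconv stage doubles the spatial resolution while preserving the channel count:

import torch

from invokeai.backend.image_util.pbr_maps.architecture.block import RRDB, upconv_block

x = torch.randn(1, 64, 32, 32)          # (batch, channels, height, width)
block = RRDB(nc=64, gc=32)              # residual-in-residual dense block, 64 -> 64 channels
up = upconv_block(in_nc=64, out_nc=64)  # nearest-neighbor upsample x2, then 3x3 conv + ReLU
y = up(block(x))
print(y.shape)                          # torch.Size([1, 64, 64, 64])

The upsample-then-conv design of upconv_block (rather than a transposed convolution) follows the checkerboard-artifact advice cited in the code.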