InvokeAI 6.10.0rc1__py3-none-any.whl → 6.10.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. invokeai/app/invocations/flux_denoise.py +15 -1
  2. invokeai/app/invocations/pbr_maps.py +59 -0
  3. invokeai/app/invocations/z_image_denoise.py +237 -82
  4. invokeai/backend/flux/denoise.py +196 -11
  5. invokeai/backend/flux/schedulers.py +62 -0
  6. invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
  7. invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
  8. invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
  9. invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
  10. invokeai/backend/model_manager/configs/lora.py +36 -0
  11. invokeai/backend/model_manager/load/load_default.py +1 -0
  12. invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
  13. invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
  14. invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
  15. invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
  16. invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +3 -1
  17. invokeai/backend/model_manager/load/model_loaders/z_image.py +37 -3
  18. invokeai/backend/model_manager/starter_models.py +13 -4
  19. invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +39 -5
  20. invokeai/backend/quantization/gguf/ggml_tensor.py +15 -4
  21. invokeai/backend/z_image/extensions/regional_prompting_extension.py +10 -12
  22. invokeai/frontend/web/dist/assets/App-DllqPQ3j.js +161 -0
  23. invokeai/frontend/web/dist/assets/{browser-ponyfill-DHZxq1nk.js → browser-ponyfill-BP0RxJ4G.js} +1 -1
  24. invokeai/frontend/web/dist/assets/{index-dgSJAY--.js → index-B44qKjrs.js} +51 -51
  25. invokeai/frontend/web/dist/index.html +1 -1
  26. invokeai/frontend/web/dist/locales/en-GB.json +1 -0
  27. invokeai/frontend/web/dist/locales/en.json +11 -5
  28. invokeai/version/invokeai_version.py +1 -1
  29. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/METADATA +2 -2
  30. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/RECORD +36 -29
  31. invokeai/frontend/web/dist/assets/App-CYhlZO3Q.js +0 -161
  32. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/WHEEL +0 -0
  33. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/entry_points.txt +0 -0
  34. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/licenses/LICENSE +0 -0
  35. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
  36. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
  37. {invokeai-6.10.0rc1.dist-info → invokeai-6.10.0rc2.dist-info}/top_level.txt +0 -0
invokeai/backend/flux/denoise.py
@@ -1,7 +1,9 @@
+ import inspect
  import math
  from typing import Callable

  import torch
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
  from tqdm import tqdm

  from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,24 +37,207 @@ def denoise(
      # extra img tokens (sequence-wise) - for Kontext conditioning
      img_cond_seq: torch.Tensor | None = None,
      img_cond_seq_ids: torch.Tensor | None = None,
+     # Optional scheduler for alternative sampling methods
+     scheduler: SchedulerMixin | None = None,
  ):
-     # step 0 is the initial state
-     total_steps = len(timesteps) - 1
-     step_callback(
-         PipelineIntermediateState(
-             step=0,
-             order=1,
-             total_steps=total_steps,
-             timestep=int(timesteps[0]),
-             latents=img,
-         ),
-     )
+     # Determine if we're using a diffusers scheduler or the built-in Euler method
+     use_scheduler = scheduler is not None
+
+     if use_scheduler:
+         # Initialize scheduler with timesteps
+         # The timesteps list contains values in [0, 1] range (sigmas)
+         # LCM should use num_inference_steps (it has its own sigma schedule),
+         # while other schedulers can use custom sigmas if supported
+         is_lcm = scheduler.__class__.__name__ == "FlowMatchLCMScheduler"
+         set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+         if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
+             # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
+             scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+         else:
+             # LCM or scheduler doesn't support custom sigmas - use num_inference_steps
+             # The schedule will be computed by the scheduler itself
+             num_inference_steps = len(timesteps) - 1
+             scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+         # For schedulers like Heun, the number of actual steps may differ
+         # (Heun doubles timesteps internally)
+         num_scheduler_steps = len(scheduler.timesteps)
+         # For user-facing step count, use the original number of denoising steps
+         total_steps = len(timesteps) - 1
+     else:
+         total_steps = len(timesteps) - 1
+         num_scheduler_steps = total_steps
+
      # guidance_vec is ignored for schnell.
      guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

      # Store original sequence length for slicing predictions
      original_seq_len = img.shape[1]

+     # Track the actual step for user-facing progress (accounts for Heun's double steps)
+     user_step = 0
+
+     if use_scheduler:
+         # Use diffusers scheduler for stepping
+         # Use tqdm with total_steps (user-facing steps) not num_scheduler_steps (internal steps)
+         # This ensures progress bar shows 1/8, 2/8, etc. even when scheduler uses more internal steps
+         pbar = tqdm(total=total_steps, desc="Denoising")
+         for step_index in range(num_scheduler_steps):
+             timestep = scheduler.timesteps[step_index]
+             # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+             t_curr = timestep.item() / scheduler.config.num_train_timesteps
+             t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+             # For Heun scheduler, track if we're in first or second order step
+             is_heun = hasattr(scheduler, "state_in_first_order")
+             in_first_order = scheduler.state_in_first_order if is_heun else True
+
+             # Run ControlNet models
+             controlnet_residuals: list[ControlNetFluxOutput] = []
+             for controlnet_extension in controlnet_extensions:
+                 controlnet_residuals.append(
+                     controlnet_extension.run_controlnet(
+                         timestep_index=user_step,
+                         total_num_timesteps=total_steps,
+                         img=img,
+                         img_ids=img_ids,
+                         txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                         txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                         y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                         timesteps=t_vec,
+                         guidance=guidance_vec,
+                     )
+                 )
+
+             merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+             # Prepare input for model
+             img_input = img
+             img_input_ids = img_ids
+
+             if img_cond is not None:
+                 img_input = torch.cat((img_input, img_cond), dim=-1)
+
+             if img_cond_seq is not None:
+                 assert img_cond_seq_ids is not None
+                 img_input = torch.cat((img_input, img_cond_seq), dim=1)
+                 img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+             pred = model(
+                 img=img_input,
+                 img_ids=img_input_ids,
+                 txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                 txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                 y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                 timesteps=t_vec,
+                 guidance=guidance_vec,
+                 timestep_index=user_step,
+                 total_num_timesteps=total_steps,
+                 controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+                 controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+                 ip_adapter_extensions=pos_ip_adapter_extensions,
+                 regional_prompting_extension=pos_regional_prompting_extension,
+             )
+
+             if img_cond_seq is not None:
+                 pred = pred[:, :original_seq_len]
+
+             # Get CFG scale for current user step
+             step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+             if not math.isclose(step_cfg_scale, 1.0):
+                 if neg_regional_prompting_extension is None:
+                     raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                 neg_img_input = img
+                 neg_img_input_ids = img_ids
+
+                 if img_cond is not None:
+                     neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+                 if img_cond_seq is not None:
+                     neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+                     neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+                 neg_pred = model(
+                     img=neg_img_input,
+                     img_ids=neg_img_input_ids,
+                     txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                     txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                     y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                     timesteps=t_vec,
+                     guidance=guidance_vec,
+                     timestep_index=user_step,
+                     total_num_timesteps=total_steps,
+                     controlnet_double_block_residuals=None,
+                     controlnet_single_block_residuals=None,
+                     ip_adapter_extensions=neg_ip_adapter_extensions,
+                     regional_prompting_extension=neg_regional_prompting_extension,
+                 )
+
+                 if img_cond_seq is not None:
+                     neg_pred = neg_pred[:, :original_seq_len]
+                 pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+             # Use scheduler.step() for the update
+             step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+             img = step_output.prev_sample
+
+             # Get t_prev for inpainting (next sigma value)
+             if step_index + 1 < len(scheduler.sigmas):
+                 t_prev = scheduler.sigmas[step_index + 1].item()
+             else:
+                 t_prev = 0.0
+
+             if inpaint_extension is not None:
+                 img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+             # For Heun, only increment user step after second-order step completes
+             if is_heun:
+                 if not in_first_order:
+                     # Second order step completed
+                     user_step += 1
+                     # Only call step_callback if we haven't exceeded total_steps
+                     if user_step <= total_steps:
+                         pbar.update(1)
+                         preview_img = img - t_curr * pred
+                         if inpaint_extension is not None:
+                             preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                                 preview_img, 0.0
+                             )
+                         step_callback(
+                             PipelineIntermediateState(
+                                 step=user_step,
+                                 order=2,
+                                 total_steps=total_steps,
+                                 timestep=int(t_curr * 1000),
+                                 latents=preview_img,
+                             ),
+                         )
+             else:
+                 # For LCM and other first-order schedulers
+                 user_step += 1
+                 # Only call step_callback if we haven't exceeded total_steps
+                 # (LCM scheduler may have more internal steps than user-facing steps)
+                 if user_step <= total_steps:
+                     pbar.update(1)
+                     preview_img = img - t_curr * pred
+                     if inpaint_extension is not None:
+                         preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                     step_callback(
+                         PipelineIntermediateState(
+                             step=user_step,
+                             order=1,
+                             total_steps=total_steps,
+                             timestep=int(t_curr * 1000),
+                             latents=preview_img,
+                         ),
+                     )
+
+         pbar.close()
+         return img
+
+     # Original Euler implementation (when scheduler is None)
      for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
          t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

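For context, the scheduler branch added above relies only on the standard diffusers flow-matching scheduler API (set_timesteps, timesteps, config.num_train_timesteps, step). A minimal standalone sketch of that stepping pattern, assuming only diffusers and torch are installed; the stand-in velocity prediction and latent shape are illustrative and not part of this diff:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=8, device="cpu")

img = torch.randn(1, 4096, 64)  # packed latent tokens; shape chosen for illustration only
for timestep in scheduler.timesteps:
    # Normalize the scheduler timestep (0-1000) to [0, 1], mirroring t_curr in the hunk above.
    t_curr = timestep.item() / scheduler.config.num_train_timesteps
    pred = -img * t_curr  # stand-in for model(...); a real run predicts the flow velocity here
    img = scheduler.step(model_output=pred, timestep=timestep, sample=img).prev_sample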
invokeai/backend/flux/schedulers.py
@@ -0,0 +1,62 @@
+ """Flow Matching scheduler definitions and mapping.
+
+ This module provides the scheduler types and mapping for Flow Matching models
+ (Flux and Z-Image), supporting multiple schedulers from the diffusers library.
+ """
+
+ from typing import Literal, Type
+
+ from diffusers import (
+     FlowMatchEulerDiscreteScheduler,
+     FlowMatchHeunDiscreteScheduler,
+ )
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+ # Note: FlowMatchLCMScheduler may not be available in all diffusers versions
+ try:
+     from diffusers import FlowMatchLCMScheduler
+
+     _HAS_LCM = True
+ except ImportError:
+     _HAS_LCM = False
+
+ # Scheduler name literal type for type checking
+ FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+ # Human-readable labels for the UI
+ FLUX_SCHEDULER_LABELS: dict[str, str] = {
+     "euler": "Euler",
+     "heun": "Heun (2nd order)",
+     "lcm": "LCM",
+ }
+
+ # Mapping from scheduler names to scheduler classes
+ FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+     "euler": FlowMatchEulerDiscreteScheduler,
+     "heun": FlowMatchHeunDiscreteScheduler,
+ }
+
+ if _HAS_LCM:
+     FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
+
+
+ # Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
+ # Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
+ # can be used for experimentation.
+ ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
+
+ # Human-readable labels for the UI
+ ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
+     "euler": "Euler",
+     "heun": "Heun (2nd order)",
+     "lcm": "LCM",
+ }
+
+ # Mapping from scheduler names to scheduler classes (same as Flux)
+ ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
+     "euler": FlowMatchEulerDiscreteScheduler,
+     "heun": FlowMatchHeunDiscreteScheduler,
+ }
+
+ if _HAS_LCM:
+     ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
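A small usage sketch for the new mapping (not taken from the diff; the invocation-side code that actually selects and configures the scheduler, presumably in flux_denoise.py and z_image_denoise.py, is not shown in this excerpt): resolve a user-facing name to a diffusers class and instantiate it with its default config.

from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP

name = "heun"
scheduler_cls = FLUX_SCHEDULER_MAP[name]  # -> FlowMatchHeunDiscreteScheduler
scheduler = scheduler_cls()  # default config; could then be passed as denoise(..., scheduler=scheduler)
print(f"{FLUX_SCHEDULER_LABELS[name]} -> {scheduler_cls.__name__}")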
invokeai/backend/image_util/pbr_maps/architecture/block.py
@@ -0,0 +1,367 @@
+ # Original: https://github.com/joeyballentine/Material-Map-Generator
+ # Adopted and optimized for Invoke AI
+
+ from collections import OrderedDict
+ from typing import Any, List, Literal, Optional
+
+ import torch
+ import torch.nn as nn
+
+ ACTIVATION_LAYER_TYPE = Literal["relu", "leakyrelu", "prelu"]
+ NORMALIZATION_LAYER_TYPE = Literal["batch", "instance"]
+ PADDING_LAYER_TYPE = Literal["zero", "reflect", "replicate"]
+ BLOCK_MODE = Literal["CNA", "NAC", "CNAC"]
+ UPCONV_BLOCK_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear"]
+
+
+ def act(act_type: ACTIVATION_LAYER_TYPE, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1):
+     """Helper to select Activation Layer"""
+     if act_type == "relu":
+         layer = nn.ReLU(inplace)
+     elif act_type == "leakyrelu":
+         layer = nn.LeakyReLU(neg_slope, inplace)
+     elif act_type == "prelu":
+         layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+     return layer
+
+
+ def norm(norm_type: NORMALIZATION_LAYER_TYPE, nc: int):
+     """Helper to select Normalization Layer"""
+     if norm_type == "batch":
+         layer = nn.BatchNorm2d(nc, affine=True)
+     elif norm_type == "instance":
+         layer = nn.InstanceNorm2d(nc, affine=False)
+     return layer
+
+
+ def pad(pad_type: PADDING_LAYER_TYPE, padding: int):
+     """Helper to select Padding Layer"""
+     if padding == 0 or pad_type == "zero":
+         return None
+     if pad_type == "reflect":
+         layer = nn.ReflectionPad2d(padding)
+     elif pad_type == "replicate":
+         layer = nn.ReplicationPad2d(padding)
+     return layer
+
+
+ def get_valid_padding(kernel_size: int, dilation: int):
+     kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+     padding = (kernel_size - 1) // 2
+     return padding
+
+
+ def sequential(*args: Any):
+     # Flatten Sequential. It unwraps nn.Sequential.
+     if len(args) == 1:
+         if isinstance(args[0], OrderedDict):
+             raise NotImplementedError("sequential does not support OrderedDict input.")
+         return args[0]  # No sequential is needed.
+     modules: List[nn.Module] = []
+     for module in args:
+         if isinstance(module, nn.Sequential):
+             for submodule in module.children():
+                 modules.append(submodule)
+         elif isinstance(module, nn.Module):
+             modules.append(module)
+     return nn.Sequential(*modules)
+
+
+ def conv_block(
+     in_nc: int,
+     out_nc: int,
+     kernel_size: int,
+     stride: int = 1,
+     dilation: int = 1,
+     groups: int = 1,
+     bias: bool = True,
+     pad_type: Optional[PADDING_LAYER_TYPE] = "zero",
+     norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+     act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
+     mode: BLOCK_MODE = "CNA",
+ ):
+     """
+     Conv layer with padding, normalization, activation
+     mode: CNA --> Conv -> Norm -> Act
+         NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
+     """
+     assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
+     padding = get_valid_padding(kernel_size, dilation)
+     p = pad(pad_type, padding) if pad_type else None
+     padding = padding if pad_type == "zero" else 0
+
+     c = nn.Conv2d(
+         in_nc,
+         out_nc,
+         kernel_size=kernel_size,
+         stride=stride,
+         padding=padding,
+         dilation=dilation,
+         bias=bias,
+         groups=groups,
+     )
+     a = act(act_type) if act_type else None
+     match mode:
+         case "CNA":
+             n = norm(norm_type, out_nc) if norm_type else None
+             return sequential(p, c, n, a)
+         case "NAC":
+             if norm_type is None and act_type is not None:
+                 a = act(act_type, inplace=False)
+             n = norm(norm_type, in_nc) if norm_type else None
+             return sequential(n, a, p, c)
+         case "CNAC":
+             n = norm(norm_type, in_nc) if norm_type else None
+             return sequential(n, a, p, c)
+
+
+ class ConcatBlock(nn.Module):
+     # Concat the output of a submodule to its input
+     def __init__(self, submodule: nn.Module):
+         super(ConcatBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x: torch.Tensor):
+         output = torch.cat((x, self.sub(x)), dim=1)
+         return output
+
+     def __repr__(self):
+         tmpstr = "Identity .. \n|"
+         modstr = self.sub.__repr__().replace("\n", "\n|")
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ class ShortcutBlock(nn.Module):
+     # Elementwise sum the output of a submodule to its input
+     def __init__(self, submodule: nn.Module):
+         super(ShortcutBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x: torch.Tensor):
+         output = x + self.sub(x)
+         return output
+
+     def __repr__(self):
+         tmpstr = "Identity + \n|"
+         modstr = self.sub.__repr__().replace("\n", "\n|")
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ class ShortcutBlockSPSR(nn.Module):
+     # Elementwise sum the output of a submodule to its input
+     def __init__(self, submodule: nn.Module):
+         super(ShortcutBlockSPSR, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x: torch.Tensor):
+         return x, self.sub
+
+     def __repr__(self):
+         tmpstr = "Identity + \n|"
+         modstr = self.sub.__repr__().replace("\n", "\n|")
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ class ResNetBlock(nn.Module):
+     """
+     ResNet Block, 3-3 style
+     with extra residual scaling used in EDSR
+     (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
+     """
+
+     def __init__(
+         self,
+         in_nc: int,
+         mid_nc: int,
+         out_nc: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         pad_type: PADDING_LAYER_TYPE = "zero",
+         norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+         act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
+         mode: BLOCK_MODE = "CNA",
+         res_scale: int = 1,
+     ):
+         super(ResNetBlock, self).__init__()
+         conv0 = conv_block(
+             in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
+         )
+         if mode == "CNA":
+             act_type = None
+         if mode == "CNAC":  # Residual path: |-CNAC-|
+             act_type = None
+             norm_type = None
+         conv1 = conv_block(
+             mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
+         )
+
+         self.res = sequential(conv0, conv1)
+         self.res_scale = res_scale
+
+     def forward(self, x: torch.Tensor):
+         res = self.res(x).mul(self.res_scale)
+         return x + res
+
+
+ class ResidualDenseBlock_5C(nn.Module):
+     """
+     Residual Dense Block
+     style: 5 convs
+     The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+     """
+
+     def __init__(
+         self,
+         nc: int,
+         kernel_size: int = 3,
+         gc: int = 32,
+         stride: int = 1,
+         bias: bool = True,
+         pad_type: PADDING_LAYER_TYPE = "zero",
+         norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+         act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
+         mode: BLOCK_MODE = "CNA",
+     ):
+         super(ResidualDenseBlock_5C, self).__init__()
+         # gc: growth channel, i.e. intermediate channels
+         self.conv1 = conv_block(
+             nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode
+         )
+         self.conv2 = conv_block(
+             nc + gc,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+         )
+         self.conv3 = conv_block(
+             nc + 2 * gc,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+         )
+         self.conv4 = conv_block(
+             nc + 3 * gc,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+         )
+         if mode == "CNA":
+             last_act = None
+         else:
+             last_act = act_type
+         self.conv5 = conv_block(
+             nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode
+         )
+
+     def forward(self, x: torch.Tensor):
+         x1 = self.conv1(x)
+         x2 = self.conv2(torch.cat((x, x1), 1))
+         x3 = self.conv3(torch.cat((x, x1, x2), 1))
+         x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         return x5.mul(0.2) + x
+
+
+ class RRDB(nn.Module):
+     """
+     Residual in Residual Dense Block
+     (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+     """
+
+     def __init__(
+         self,
+         nc: int,
+         kernel_size: int = 3,
+         gc: int = 32,
+         stride: int = 1,
+         bias: bool = True,
+         pad_type: PADDING_LAYER_TYPE = "zero",
+         norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+         act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
+         mode: BLOCK_MODE = "CNA",
+     ):
+         super(RRDB, self).__init__()
+         self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
+         self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
+         self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
+
+     def forward(self, x: torch.Tensor):
+         out = self.RDB1(x)
+         out = self.RDB2(out)
+         out = self.RDB3(out)
+         return out.mul(0.2) + x
+
+
+ # Upsampler
+ def pixelshuffle_block(
+     in_nc: int,
+     out_nc: int,
+     upscale_factor: int = 2,
+     kernel_size: int = 3,
+     stride: int = 1,
+     bias: bool = True,
+     pad_type: PADDING_LAYER_TYPE = "zero",
+     norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+     act_type: ACTIVATION_LAYER_TYPE = "relu",
+ ):
+     """
+     Pixel shuffle layer
+     (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+     Neural Network, CVPR17)
+     """
+     conv = conv_block(
+         in_nc,
+         out_nc * (upscale_factor**2),
+         kernel_size,
+         stride,
+         bias=bias,
+         pad_type=pad_type,
+         norm_type=None,
+         act_type=None,
+     )
+     pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+     n = norm(norm_type, out_nc) if norm_type else None
+     a = act(act_type) if act_type else None
+     return sequential(conv, pixel_shuffle, n, a)
+
+
+ def upconv_block(
+     in_nc: int,
+     out_nc: int,
+     upscale_factor: int = 2,
+     kernel_size: int = 3,
+     stride: int = 1,
+     bias: bool = True,
+     pad_type: PADDING_LAYER_TYPE = "zero",
+     norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
+     act_type: ACTIVATION_LAYER_TYPE = "relu",
+     mode: UPCONV_BLOCK_MODE = "nearest",
+ ):
+     # Adopted from https://distill.pub/2016/deconv-checkerboard/
+     upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
+     conv = conv_block(
+         in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type
+     )
+     return sequential(upsample, conv)
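These are generic ESRGAN-style building blocks; the actual material-map network is assembled in pbr_rrdb_net.py, which this excerpt does not show. A rough smoke-test sketch, with channel counts chosen purely for illustration, showing how the helpers compose:

import torch

from invokeai.backend.image_util.pbr_maps.architecture.block import RRDB, conv_block, pixelshuffle_block, sequential

nf = 64  # feature channels (illustrative)
model = sequential(
    conv_block(3, nf, kernel_size=3, norm_type=None, act_type=None),  # RGB -> features
    RRDB(nf, gc=32),  # residual-in-residual dense block
    pixelshuffle_block(nf, nf, upscale_factor=2),  # 2x upsample via PixelShuffle
    conv_block(nf, 3, kernel_size=3, norm_type=None, act_type=None),  # features -> RGB
)
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 128, 128])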