ai-edge-torch-nightly 0.2.0.dev20240606__py3-none-any.whl → 0.2.0.dev20240609__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/conversion.py +2 -2
- ai_edge_torch/convert/fx_passes/__init__.py +1 -1
- ai_edge_torch/convert/fx_passes/{build_upsample_bilinear2d_composite_pass.py → build_interpolate_composite_pass.py} +22 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +8 -4
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +275 -82
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +54 -3
- ai_edge_torch/generative/layers/attention.py +25 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +287 -0
- ai_edge_torch/generative/layers/unet/builder.py +29 -0
- ai_edge_torch/generative/layers/unet/model_config.py +117 -0
- ai_edge_torch/generative/utilities/autoencoder_loader.py +298 -0
- ai_edge_torch/generative/utilities/loader.py +7 -5
- {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/RECORD +20 -15
- {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240606.dist-info → ai_edge_torch_nightly-0.2.0.dev20240609.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import nn
|
|
20
|
+
|
|
21
|
+
from ai_edge_torch.generative.layers.attention import SelfAttention
|
|
22
|
+
import ai_edge_torch.generative.layers.builder as layers_builder
|
|
23
|
+
import ai_edge_torch.generative.layers.unet.builder as unet_builder
|
|
24
|
+
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ResidualBlock2D(nn.Module):
    """2D residual block: two (norm -> activation -> 3x3 conv) stages plus a skip path.

    When the block is configured with time embedding channels, the time embedding
    is linearly projected to `out_channels` and added to the feature map between
    the two stages.
    """

    def __init__(self, config: unet_cfg.ResidualBlock2DConfig):
        """Create the sub-layers described by `config`.

        Args:
          config (unet_cfg.ResidualBlock2DConfig): the configuration of this block.
        """
        super().__init__()
        self.config = config
        in_ch = config.in_channels
        out_ch = config.out_channels
        time_ch = config.time_embedding_channels

        # First stage: normalize the input channels, then project in_ch -> out_ch.
        self.norm_1 = layers_builder.build_norm(in_ch, config.normalization_config)
        self.conv_1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)

        # The projection only exists when the block is configured to take a time embedding.
        self.time_emb_proj = nn.Linear(time_ch, out_ch) if time_ch is not None else None

        # Second stage keeps the channel count at out_ch.
        self.norm_2 = layers_builder.build_norm(out_ch, config.normalization_config)
        self.conv_2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.act_fn = layers_builder.get_activation(config.activation_type)

        # The skip path needs a 1x1 conv only when the channel count changes.
        self.residual_layer = (
            nn.Identity()
            if in_ch == out_ch
            else nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
        )

    def forward(
        self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run the residual block.

        Args:
          input_tensor (torch.Tensor): the input tensor.
          time_emb (Optional[torch.Tensor]): time embedding tensor; it is only read when
            the block was built with `time_embedding_channels` set.

        Returns:
          output hidden_states tensor after ResidualBlock2D.
        """
        hidden = self.conv_1(self.act_fn(self.norm_1(input_tensor)))
        if self.time_emb_proj is not None:
            # Two trailing singleton axes let the projected embedding broadcast
            # over the last two dimensions of the convolution output.
            hidden = hidden + self.time_emb_proj(time_emb)[:, :, None, None]
        hidden = self.conv_2(self.act_fn(self.norm_2(hidden)))
        return hidden + self.residual_layer(input_tensor)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class AttentionBlock2D(nn.Module):
    """2D self attention block with a residual connection.

    output = input_tensor + SelfAttention(Norm(input_tensor))

    The last two dimensions of the normalized input are flattened into a single
    axis before attention and restored afterwards.
    """

    def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
        """Create the normalization and attention layers.

        Args:
          config (unet_cfg.AttentionBlock2DConfig): the configuration of this block.
        """
        super().__init__()
        self.norm = layers_builder.build_norm(config.dims, config.normalization_config)
        # NOTE(review): the meaning of the trailing positional arguments (0, True) is
        # defined by SelfAttention's signature, which is not visible here - confirm.
        self.attention = SelfAttention(config.dims, config.attention_config, 0, True)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Apply normalization and self attention, then add the input back.

        Args:
          input_tensor (torch.Tensor): the input tensor; it must be 4-dimensional
            because its normalized shape is unpacked into four sizes.

        Returns:
          output activation tensor after self attention.
        """
        normed = self.norm(input_tensor)
        batch, channels, height, width = normed.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        tokens = normed.view(batch, channels, height * width).transpose(-1, -2)
        # (B, H*W, C) -> (B, C, H*W)
        attended = self.attention(tokens).transpose(-1, -2)
        return attended.view(batch, channels, height, width) + input_tensor
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class UpDecoderBlock2D(nn.Module):
    """Decoder block containing several residual blocks followed by an optional upsampler.

          input_tensor
               |
               ▼
     ┌───────────────────┐
     │  ResidualBlock2D  │ num_layers
     └─────────┬─────────┘
               │
     ┌─────────▼─────────┐
     │    (Optional)     │
     │     Upsampler     │
     └─────────┬─────────┘
               │
     ┌─────────▼─────────┐
     │    (Optional)     │
     │      Conv2D       │
     └─────────┬─────────┘
               │
               ▼
         hidden_states
    """

    def __init__(self, config: unet_cfg.UpDecoderBlock2DConfig):
        """Initialize an instance of the UpDecoderBlock2D.

        Args:
          config (unet_cfg.UpDecoderBlock2DConfig): the configuration of this block.

        Raises:
          ValueError: if `config.add_upsample` is set but `config.sampling_config` is None.
        """
        super().__init__()
        self.config = config

        # Only the first residual block changes the channel count (in -> out);
        # the remaining ones map out_channels to out_channels.
        resnets = []
        for i in range(config.num_layers):
            input_channels = config.in_channels if i == 0 else config.out_channels
            resnets.append(
                ResidualBlock2D(
                    unet_cfg.ResidualBlock2DConfig(
                        in_channels=input_channels,
                        out_channels=config.out_channels,
                        time_embedding_channels=config.time_embedding_channels,
                        normalization_config=config.normalization_config,
                        activation_type=config.activation_type,
                    )
                )
            )
        self.resnets = nn.ModuleList(resnets)

        # Both optional layers are always defined (possibly as None) so that forward()
        # can test them without hitting nn.Module's AttributeError for unset attributes.
        self.upsampler = None
        self.upsample_conv = None
        if config.add_upsample:
            if config.sampling_config is None:
                raise ValueError("sampling_config is required when add_upsample is True.")
            self.upsampler = unet_builder.build_upsampling(config.sampling_config)
            if config.upsample_conv:
                self.upsample_conv = nn.Conv2d(
                    config.out_channels,
                    config.out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )

    def forward(
        self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward function of the UpDecoderBlock2D.

        Args:
          input_tensor (torch.Tensor): the input tensor.
          time_emb (torch.Tensor): optional time embedding tensor, if the block is
            configured to accept time embedding context.

        Returns:
          output hidden_states tensor after UpDecoderBlock2D.
        """
        hidden_states = input_tensor
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, time_emb)
        if self.upsampler is not None:
            hidden_states = self.upsampler(hidden_states)
            # The post-upsample convolution is only ever built together with an upsampler.
            if self.upsample_conv is not None:
                hidden_states = self.upsample_conv(hidden_states)
        return hidden_states
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class MidBlock2D(nn.Module):
    """Middle block containing at least one residual block with optional interleaved attention blocks.

                 input_tensor
                      |
                      ▼
            ┌───────────────────┐
            │  ResidualBlock2D  │
            └─────────┬─────────┘
                      │
        ┌─────────────▼─────────────┐
        │   ┌───────────────────┐   │
        │   │    (Optional)     │   │
        │   │ AttentionBlock2D  │   │
        │   └─────────┬─────────┘   │ num_layers
        │             │             │
        │   ┌─────────▼─────────┐   │
        │   │  ResidualBlock2D  │   │
        │   └───────────────────┘   │
        └─────────────┬─────────────┘
                      │
                      ▼
                hidden_states
    """

    def __init__(self, config: unet_cfg.MidBlock2DConfig):
        """Initialize an instance of the MidBlock2D.

        Args:
          config (unet_cfg.MidBlock2DConfig): the configuration of this block.
        """
        super().__init__()
        self.config = config
        # Every residual block keeps the channel count at config.in_channels.
        resnets = [
            ResidualBlock2D(
                unet_cfg.ResidualBlock2DConfig(
                    in_channels=config.in_channels,
                    out_channels=config.in_channels,
                    time_embedding_channels=config.time_embedding_channels,
                    normalization_config=config.normalization_config,
                    activation_type=config.activation_type,
                )
            )
        ]
        attentions = []
        for _ in range(config.num_layers):
            # Attention blocks are created only when an attention config is supplied,
            # so `attentions` is either empty or holds exactly num_layers entries.
            if config.attention_block_config is not None:
                attentions.append(AttentionBlock2D(config.attention_block_config))
            resnets.append(
                ResidualBlock2D(
                    unet_cfg.ResidualBlock2DConfig(
                        in_channels=config.in_channels,
                        out_channels=config.in_channels,
                        time_embedding_channels=config.time_embedding_channels,
                        normalization_config=config.normalization_config,
                        activation_type=config.activation_type,
                    )
                )
            )
        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

    def forward(
        self, input_tensor: torch.Tensor, time_emb: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward function of the MidBlock2D.

        Args:
          input_tensor (torch.Tensor): the input tensor.
          time_emb (torch.Tensor): optional time embedding tensor, if the block is
            configured to accept time embedding context.

        Returns:
          output hidden_states tensor after MidBlock2D.
        """
        hidden_states = self.resnets[0](input_tensor, time_emb)
        has_attention = len(self.attentions) > 0
        # Drive the loop from the residual blocks: pairing them with the attention
        # list via zip() would run nothing at all when no attention is configured.
        for i, resnet in enumerate(self.resnets[1:]):
            if has_attention:
                hidden_states = self.attentions[i](hidden_states)
            hidden_states = resnet(hidden_states, time_emb)
        return hidden_states
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Builder utils for individual components.
|
|
16
|
+
|
|
17
|
+
from torch import nn
|
|
18
|
+
import torch.nn.functional as F
|
|
19
|
+
|
|
20
|
+
import ai_edge_torch.generative.layers.unet.model_config as unet_config
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def build_upsampling(config: unet_config.SamplingConfig) -> nn.Module:
    """Build the upsampling layer selected by `config.mode`.

    Args:
      config (unet_config.SamplingConfig): the sampling mode and scale factor.

    Returns:
      `nn.UpsamplingNearest2d` for `SamplingType.NEAREST` or `nn.UpsamplingBilinear2d`
      for `SamplingType.BILINEAR`, constructed with `config.scale_factor`.

    Raises:
      ValueError: if `config.mode` is neither of the supported sampling types.
    """
    mode = config.mode
    # Enum members are singletons, so identity is the reliable comparison; it does
    # not depend on how (or whether) the enum class customizes `__eq__`.
    if mode is unet_config.SamplingType.NEAREST:
        return nn.UpsamplingNearest2d(scale_factor=config.scale_factor)
    if mode is unet_config.SamplingType.BILINEAR:
        return nn.UpsamplingBilinear2d(scale_factor=config.scale_factor)
    raise ValueError("Unsupported upsampling type.")
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
# UNet configuration class.
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from dataclasses import field
|
|
19
|
+
import enum
|
|
20
|
+
from typing import List, Optional
|
|
21
|
+
|
|
22
|
+
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SamplingType(enum.Enum):
    """Interpolation algorithms available for up-sampling layers.

    This is a plain `Enum` on purpose: decorating an `Enum` with `@dataclass`
    installs a field-based `__eq__` (every member compares equal to every other
    member, since there are no fields) and makes the members unhashable.
    """

    NEAREST = enum.auto()
    BILINEAR = enum.auto()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass
class SamplingConfig:
    """Parameters describing an up-sampling layer."""

    # Multiplier handed to the up-sampling layer as its `scale_factor`.
    scale_factor: float
    # Which interpolation algorithm the layer uses.
    mode: SamplingType
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class ResidualBlock2DConfig:
    """Configuration of a 2D residual block."""

    # Number of channels of the block input.
    in_channels: int
    # Number of channels produced by the block.
    out_channels: int
    # Configuration of the normalization layers inside the block.
    normalization_config: layers_cfg.NormalizationConfig
    # Activation function applied inside the block.
    activation_type: layers_cfg.ActivationType
    # Optional time embedding channels if the residual block takes a time embedding context as input
    time_embedding_channels: Optional[int] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class AttentionBlock2DConfig:
    """Configuration of a 2D self attention block."""

    # Dimension passed to both the normalization layer and the attention layer.
    dims: int
    # Configuration of the normalization applied before attention.
    normalization_config: layers_cfg.NormalizationConfig
    # Configuration of the self attention layer.
    attention_config: layers_cfg.AttentionConfig
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass
class UpDecoderBlock2DConfig:
    """Configuration of a decoder block: residual blocks plus an optional upsampler."""

    # Number of channels of the block input (consumed by the first residual block).
    in_channels: int
    # Number of channels produced by every residual block and by the optional conv.
    out_channels: int
    # Configuration of the normalization layers inside the residual blocks.
    normalization_config: layers_cfg.NormalizationConfig
    # Activation function used inside the residual blocks.
    activation_type: layers_cfg.ActivationType
    # Number of residual blocks in this decoder block.
    num_layers: int
    # Optional time embedding channels if the residual blocks take a time embedding context as input
    time_embedding_channels: Optional[int] = None
    # Whether to add upsample operation after residual blocks
    add_upsample: bool = True
    # Whether to add a conv2d layer after upsample
    upsample_conv: bool = True
    # Optional sampling config if add_upsample is True.
    sampling_config: Optional[SamplingConfig] = None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class MidBlock2DConfig:
    """Configuration of a middle block: residual blocks with optional attention blocks."""

    # Number of channels of the block input; the residual blocks keep this channel count.
    in_channels: int
    # Configuration of the normalization layers inside the residual blocks.
    normalization_config: layers_cfg.NormalizationConfig
    # Activation function used inside the residual blocks.
    activation_type: layers_cfg.ActivationType
    # Number of residual blocks that follow the first residual block.
    num_layers: int
    # Optional time embedding channels if the residual blocks take a time embedding context as input
    time_embedding_channels: Optional[int] = None
    # Optional config of attention blocks interleaved with residual blocks
    attention_block_config: Optional[AttentionBlock2DConfig] = None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class AutoEncoderConfig:
    """Configurations of encoder/decoder in the autoencoder model."""

    # The activation type of encoder/decoder blocks.
    activation_type: layers_cfg.ActivationType

    # The output channels of each block.
    block_out_channels: List[int]

    # Number of channels in the input image.
    in_channels: int

    # Number of channels in the output.
    out_channels: int

    # Number of channels in the latent space.
    latent_channels: int

    # The component-wise standard deviation of the trained latent space computed
    # using the first batch of the training set. This is used to scale the latent
    # space to have unit variance when training the diffusion model. The latents
    # are scaled with the formula `z = z * scaling_factor` before being passed to
    # the diffusion model. When decoding, the latents are scaled back to the
    # original scale with the formula: `z = 1 / scaling_factor * z`. For more
    # details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
    # Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    scaling_factor: float

    # The number of layers in each encoder/decoder block.
    layers_per_block: int

    # The normalization config.
    normalization_config: layers_cfg.NormalizationConfig

    # The configuration of middle blocks, that is, after the last block of the
    # encoder and before the first block of the decoder.
    mid_block_config: MidBlock2DConfig
|