nexaai 1.0.16rc10-cp310-cp310-macosx_14_0_universal2.whl → 1.0.16rc11-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai has been flagged as potentially problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +60 -14
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/ml.py +60 -14
- nexaai/mlx_backend/sd/modeling/model_io.py +72 -17
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc11.dist-info}/METADATA +1 -1
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc11.dist-info}/RECORD +23 -11
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc11.dist-info}/WHEEL +0 -0
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc11.dist-info}/top_level.txt +0 -0
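The headline change is a new MLX image-generation backend: nexaai/mlx_backend/image_gen/ vendors a Stable Diffusion stack split across tokenizer.py, clip.py, unet.py, sampler.py, and vae.py. As orientation only, these pieces compose in the usual text-to-image loop. The sketch below is illustrative, not the package's API: only the module split comes from the file list above, and every method name on the components is hypothetical.

# Illustrative composition of the new stable_diffusion modules. Every method
# name on the components below is hypothetical; only the module split
# (tokenizer / clip / unet / sampler / vae) comes from the file list above.
def text_to_image_sketch(tokenizer, clip, unet, sampler, vae, prompt, steps=50):
    tokens = tokenizer.tokenize(prompt)         # hypothetical API
    cond = clip(tokens)                         # text features for UNet cross-attention
    x_t = sampler.sample_prior((1, 64, 64, 4))  # hypothetical: Gaussian latents, NHWC
    for t in sampler.timesteps(steps):          # hypothetical API
        eps = unet(x_t, t, cond)                # noise prediction (see unet.py below)
        x_t = sampler.step(eps, t, x_t)         # hypothetical API
    return vae.decode(x_t)                      # latents -> image

The largest single addition is the UNet itself; its full diff follows.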
nexaai/mlx_backend/image_gen/stable_diffusion/unet.py
@@ -0,0 +1,460 @@
+# Copyright © 2023 Apple Inc.
+
+import math
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import UNetConfig
+
+
+def upsample_nearest(x, scale: int = 2):
+    B, H, W, C = x.shape
+    x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
+    x = x.reshape(B, H * scale, W * scale, C)
+
+    return x
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, in_channels: int, time_embed_dim: int):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def __call__(self, x):
+        x = self.linear_1(x)
+        x = nn.silu(x)
+        x = self.linear_2(x)
+
+        return x
+
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        model_dims: int,
+        num_heads: int,
+        hidden_dims: Optional[int] = None,
+        memory_dims: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.norm1 = nn.LayerNorm(model_dims)
+        self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
+        self.attn1.out_proj.bias = mx.zeros(model_dims)
+
+        memory_dims = memory_dims or model_dims
+        self.norm2 = nn.LayerNorm(model_dims)
+        self.attn2 = nn.MultiHeadAttention(
+            model_dims, num_heads, key_input_dims=memory_dims
+        )
+        self.attn2.out_proj.bias = mx.zeros(model_dims)
+
+        hidden_dims = hidden_dims or 4 * model_dims
+        self.norm3 = nn.LayerNorm(model_dims)
+        self.linear1 = nn.Linear(model_dims, hidden_dims)
+        self.linear2 = nn.Linear(model_dims, hidden_dims)
+        self.linear3 = nn.Linear(hidden_dims, model_dims)
+
+    def __call__(self, x, memory, attn_mask, memory_mask):
+        # Self attention
+        y = self.norm1(x)
+        y = self.attn1(y, y, y, attn_mask)
+        x = x + y
+
+        # Cross attention
+        y = self.norm2(x)
+        y = self.attn2(y, memory, memory, memory_mask)
+        x = x + y
+
+        # FFN
+        y = self.norm3(x)
+        y_a = self.linear1(y)
+        y_b = self.linear2(y)
+        y = y_a * nn.gelu(y_b)
+        y = self.linear3(y)
+        x = x + y
+
+        return x
+
+
+class Transformer2D(nn.Module):
+    """A transformer model for inputs with 2 spatial dimensions."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        model_dims: int,
+        encoder_dims: int,
+        num_heads: int,
+        num_layers: int = 1,
+        norm_num_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
+        self.proj_in = nn.Linear(in_channels, model_dims)
+        self.transformer_blocks = [
+            TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
+            for i in range(num_layers)
+        ]
+        self.proj_out = nn.Linear(model_dims, in_channels)
+
+    def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
+        # Save the input to add to the output
+        input_x = x
+        dtype = x.dtype
+
+        # Perform the input norm and projection
+        B, H, W, C = x.shape
+        x = self.norm(x).reshape(B, -1, C)
+        x = self.proj_in(x)
+
+        # Apply the transformer
+        for block in self.transformer_blocks:
+            x = block(x, encoder_x, attn_mask, encoder_attn_mask)
+
+        # Apply the output projection and reshape
+        x = self.proj_out(x)
+        x = x.reshape(B, H, W, C)
+
+        return x + input_x
+
+
+class ResnetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        groups: int = 32,
+        temb_channels: Optional[int] = None,
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if temb_channels is not None:
+            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+        self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        if in_channels != out_channels:
+            self.conv_shortcut = nn.Linear(in_channels, out_channels)
+
+    def __call__(self, x, temb=None):
+        dtype = x.dtype
+
+        if temb is not None:
+            temb = self.time_emb_proj(nn.silu(temb))
+
+        y = self.norm1(x)
+        y = nn.silu(y)
+        y = self.conv1(y)
+        if temb is not None:
+            y = y + temb[:, None, None, :]
+        y = self.norm2(y)
+        y = nn.silu(y)
+        y = self.conv2(y)
+
+        x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
+
+        return x
+
+
+class UNetBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        prev_out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        num_attention_heads: int = 8,
+        cross_attention_dim=1280,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+        add_cross_attention=True,
+    ):
+        super().__init__()
+
+        # Prepare the in channels list for the resnets
+        if prev_out_channels is None:
+            in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
+        else:
+            in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
+            res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
+            in_channels_list = [
+                a + b for a, b in zip(in_channels_list, res_channels_list)
+            ]
+
+        # Add resnet blocks that also process the time embedding
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=ic,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                groups=resnet_groups,
+            )
+            for ic in in_channels_list
+        ]
+
+        # Add optional cross attention layers
+        if add_cross_attention:
+            self.attentions = [
+                Transformer2D(
+                    in_channels=out_channels,
+                    model_dims=out_channels,
+                    num_heads=num_attention_heads,
+                    num_layers=transformer_layers_per_block,
+                    encoder_dims=cross_attention_dim,
+                )
+                for i in range(num_layers)
+            ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=1
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(
+        self,
+        x,
+        encoder_x=None,
+        temb=None,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        residual_hidden_states=None,
+    ):
+        output_states = []
+
+        for i in range(len(self.resnets)):
+            if residual_hidden_states is not None:
+                x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
+
+            x = self.resnets[i](x, temb)
+
+            if "attentions" in self:
+                x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
+
+            output_states.append(x)
+
+        if "downsample" in self:
+            x = self.downsample(x)
+            output_states.append(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+            output_states.append(x)
+
+        return x, output_states
+
+
+class UNetModel(nn.Module):
+    """The conditional 2D UNet model that actually performs the denoising."""
+
+    def __init__(self, config: UNetConfig):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(
+            config.in_channels,
+            config.block_out_channels[0],
+            config.conv_in_kernel,
+            padding=(config.conv_in_kernel - 1) // 2,
+        )
+
+        self.timesteps = nn.SinusoidalPositionalEncoding(
+            config.block_out_channels[0],
+            max_freq=1,
+            min_freq=math.exp(
+                -math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
+            ),
+            scale=1.0,
+            cos_first=True,
+            full_turns=False,
+        )
+        self.time_embedding = TimestepEmbedding(
+            config.block_out_channels[0],
+            config.block_out_channels[0] * 4,
+        )
+
+        if config.addition_embed_type == "text_time":
+            self.add_time_proj = nn.SinusoidalPositionalEncoding(
+                config.addition_time_embed_dim,
+                max_freq=1,
+                min_freq=math.exp(
+                    -math.log(10000)
+                    + 2 * math.log(10000) / config.addition_time_embed_dim
+                ),
+                scale=1.0,
+                cos_first=True,
+                full_turns=False,
+            )
+            self.add_embedding = TimestepEmbedding(
+                config.projection_class_embeddings_input_dim,
+                config.block_out_channels[0] * 4,
+            )
+
+        # Make the downsampling blocks
+        block_channels = [config.block_out_channels[0]] + list(
+            config.block_out_channels
+        )
+        self.down_blocks = [
+            UNetBlock2D(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                temb_channels=config.block_out_channels[0] * 4,
+                num_layers=config.layers_per_block[i],
+                transformer_layers_per_block=config.transformer_layers_per_block[i],
+                num_attention_heads=config.num_attention_heads[i],
+                cross_attention_dim=config.cross_attention_dim[i],
+                resnet_groups=config.norm_num_groups,
+                add_downsample=(i < len(config.block_out_channels) - 1),
+                add_upsample=False,
+                add_cross_attention="CrossAttn" in config.down_block_types[i],
+            )
+            for i, (in_channels, out_channels) in enumerate(
+                zip(block_channels, block_channels[1:])
+            )
+        ]
+
+        # Make the middle block
+        self.mid_blocks = [
+            ResnetBlock2D(
+                in_channels=config.block_out_channels[-1],
+                out_channels=config.block_out_channels[-1],
+                temb_channels=config.block_out_channels[0] * 4,
+                groups=config.norm_num_groups,
+            ),
+            Transformer2D(
+                in_channels=config.block_out_channels[-1],
+                model_dims=config.block_out_channels[-1],
+                num_heads=config.num_attention_heads[-1],
+                num_layers=config.transformer_layers_per_block[-1],
+                encoder_dims=config.cross_attention_dim[-1],
+            ),
+            ResnetBlock2D(
+                in_channels=config.block_out_channels[-1],
+                out_channels=config.block_out_channels[-1],
+                temb_channels=config.block_out_channels[0] * 4,
+                groups=config.norm_num_groups,
+            ),
+        ]
+
+        # Make the upsampling blocks
+        block_channels = (
+            [config.block_out_channels[0]]
+            + list(config.block_out_channels)
+            + [config.block_out_channels[-1]]
+        )
+        self.up_blocks = [
+            UNetBlock2D(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                temb_channels=config.block_out_channels[0] * 4,
+                prev_out_channels=prev_out_channels,
+                num_layers=config.layers_per_block[i] + 1,
+                transformer_layers_per_block=config.transformer_layers_per_block[i],
+                num_attention_heads=config.num_attention_heads[i],
+                cross_attention_dim=config.cross_attention_dim[i],
+                resnet_groups=config.norm_num_groups,
+                add_downsample=False,
+                add_upsample=(i > 0),
+                add_cross_attention="CrossAttn" in config.up_block_types[i],
+            )
+            for i, (in_channels, out_channels, prev_out_channels) in reversed(
+                list(
+                    enumerate(
+                        zip(block_channels, block_channels[1:], block_channels[2:])
+                    )
+                )
+            )
+        ]
+
+        self.conv_norm_out = nn.GroupNorm(
+            config.norm_num_groups,
+            config.block_out_channels[0],
+            pytorch_compatible=True,
+        )
+        self.conv_out = nn.Conv2d(
+            config.block_out_channels[0],
+            config.out_channels,
+            config.conv_out_kernel,
+            padding=(config.conv_out_kernel - 1) // 2,
+        )
+
+    def __call__(
+        self,
+        x,
+        timestep,
+        encoder_x,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        text_time=None,
+    ):
+        # Compute the time embeddings
+        temb = self.timesteps(timestep).astype(x.dtype)
+        temb = self.time_embedding(temb)
+
+        # Add the extra text_time conditioning
+        if text_time is not None:
+            text_emb, time_ids = text_time
+            emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
+            emb = mx.concatenate([text_emb, emb], axis=-1)
+            emb = self.add_embedding(emb)
+            temb = temb + emb
+
+        # Preprocess the input
+        x = self.conv_in(x)
+
+        # Run the downsampling part of the unet
+        residuals = [x]
+        for block in self.down_blocks:
+            x, res = block(
+                x,
+                encoder_x=encoder_x,
+                temb=temb,
+                attn_mask=attn_mask,
+                encoder_attn_mask=encoder_attn_mask,
+            )
+            residuals.extend(res)
+
+        # Run the middle part of the unet
+        x = self.mid_blocks[0](x, temb)
+        x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
+        x = self.mid_blocks[2](x, temb)
+
+        # Run the upsampling part of the unet
+        for block in self.up_blocks:
+            x, _ = block(
+                x,
+                encoder_x=encoder_x,
+                temb=temb,
+                attn_mask=attn_mask,
+                encoder_attn_mask=encoder_attn_mask,
+                residual_hidden_states=residuals,
+            )
+
+        # Postprocess the output
+        x = self.conv_norm_out(x)
+        x = nn.silu(x)
+        x = self.conv_out(x)
+
+        return x
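The upsample_nearest helper at the top of the file does 2x nearest-neighbor upsampling with a broadcast instead of a gather: each pixel is expanded into a scale x scale tile by inserting singleton axes, broadcasting, and reshaping. A quick check of that behavior (this uses the real function from the diff; only the import path is inferred from the RECORD above):

import mlx.core as mx
from nexaai.mlx_backend.image_gen.stable_diffusion.unet import upsample_nearest

x = mx.arange(4).reshape(1, 2, 2, 1)  # NHWC input with pixel values 0..3
y = upsample_nearest(x)               # shape (1, 4, 4, 1)
# Each source pixel becomes a 2x2 tile, so y[0, :, :, 0] holds the rows
# [0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3].
print(y[0, :, :, 0])

Broadcasting keeps the whole operation a view-plus-reshape on the GPU, which is why the code prefers it over an explicit index-based upsample.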
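For a forward-pass smoke test of UNetModel, something like the following should work. Caveat: config.py is not shown in this hunk, so the UNetConfig field names below are assumptions mirroring the attributes unet.py reads (in_channels, block_out_channels, layers_per_block, and so on), filled with SD-1.5-style values; treat this as a sketch, not the package's documented API.

import mlx.core as mx
from nexaai.mlx_backend.image_gen.stable_diffusion.config import UNetConfig
from nexaai.mlx_backend.image_gen.stable_diffusion.unet import UNetModel

# Assumed constructor: UNetConfig taking these fields as keywords, with
# per-block values as sequences (each is indexed with [i] in unet.py).
config = UNetConfig(
    in_channels=4,
    out_channels=4,
    conv_in_kernel=3,
    conv_out_kernel=3,
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=(2, 2, 2, 2),
    transformer_layers_per_block=(1, 1, 1, 1),
    num_attention_heads=(8, 8, 8, 8),
    cross_attention_dim=(768, 768, 768, 768),
    norm_num_groups=32,
    down_block_types=("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",),
    up_block_types=("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3,
    addition_embed_type=None,  # skip the SDXL-style text_time branch
)
unet = UNetModel(config)

x = mx.random.normal((1, 64, 64, 4))    # NHWC latents, per the B, H, W, C unpacking
t = mx.array([999])                     # diffusion timestep
cond = mx.random.normal((1, 77, 768))   # text features for cross-attention
eps = unet(x, t, cond)                  # predicted noise, same shape as x
print(eps.shape)                        # (1, 64, 64, 4)

Note the channels-last layout throughout: MLX convolutions and this UNet use NHWC, unlike the NCHW layout the equivalent PyTorch weights assume, which is what the pytorch_compatible GroupNorm flags are compensating for.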