nexaai 1.0.16rc8__cp310-cp310-macosx_14_0_universal2.whl → 1.0.16rc10__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. See the package's registry page for more details.

@@ -1,274 +0,0 @@
1
- # Copyright © 2023 Apple Inc.
2
-
3
import math
from typing import List, Sequence

import mlx.core as mx
import mlx.nn as nn

from .config import AutoencoderConfig
from .unet import ResnetBlock2D, upsample_nearest
11
-
12
-
13
class Attention(nn.Module):
    """A single head unmasked attention for use with the VAE.

    The input is unpacked as ``(B, H, W, C)``. The spatial grid is flattened
    into ``H * W`` tokens of width ``C``, and full (unmasked) softmax
    attention is computed over those tokens. The projected result is added
    back onto the input as a residual.

    Args:
        dims: Channel width ``C``; every projection maps ``dims -> dims``.
        norm_groups: Number of groups for the pre-attention ``GroupNorm``.
    """

    def __init__(self, dims: int, norm_groups: int = 32):
        super().__init__()

        # These attribute names become this module's parameter names.
        # NOTE(review): presumably they must match checkpoint keys used by
        # the weight loader — confirm before renaming.
        self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
        self.query_proj = nn.Linear(dims, dims)
        self.key_proj = nn.Linear(dims, dims)
        self.value_proj = nn.Linear(dims, dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x):
        batch, height, width, channels = x.shape
        tokens = height * width

        normed = self.group_norm(x)

        # Flatten the spatial grid into one token axis for each projection.
        q = self.query_proj(normed).reshape(batch, tokens, channels)
        k = self.key_proj(normed).reshape(batch, tokens, channels)
        v = self.value_proj(normed).reshape(batch, tokens, channels)

        # Scaled dot-product attention with a single head spanning all channels.
        inv_sqrt_dim = 1 / math.sqrt(q.shape[-1])
        weights = mx.softmax((q * inv_sqrt_dim) @ k.transpose(0, 2, 1), axis=-1)
        attended = (weights @ v).reshape(batch, height, width, channels)

        # Residual connection around the attention branch.
        return x + self.out_proj(attended)
43
-
44
-
45
class EncoderDecoderBlock2D(nn.Module):
    """A stack of ResNet blocks followed by optional down/upsampling.

    Args:
        in_channels: Channels fed to the first ResNet block.
        out_channels: Channels produced by every ResNet block and by the
            optional resampling convolutions.
        num_layers: Number of ``ResnetBlock2D`` layers in the stack.
        resnet_groups: Group count passed to each ``ResnetBlock2D``.
        add_downsample: If true, finish with a stride-2, unpadded 3x3 conv
            applied after one element of end-padding (see ``__call__``).
        add_upsample: If true, finish with ``upsample_nearest`` followed by a
            stride-1, padding-1 3x3 conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        resnet_groups: int = 32,
        add_downsample=True,
        add_upsample=True,
    ):
        super().__init__()

        # Only the first block changes the channel count; every later block
        # maps out_channels -> out_channels.
        stack = []
        for layer_idx in range(num_layers):
            block_in = out_channels if layer_idx > 0 else in_channels
            stack.append(
                ResnetBlock2D(
                    in_channels=block_in,
                    out_channels=out_channels,
                    groups=resnet_groups,
                )
            )
        self.resnets = stack

        # The resampling convs are registered only when requested; __call__
        # tests for their presence by attribute name.
        if add_downsample:
            self.downsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=2, padding=0
            )

        if add_upsample:
            self.upsample = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1
            )

    def __call__(self, x):
        h = x
        for block in self.resnets:
            h = block(h)

        if "downsample" in self:
            # Pad one element at the end of axes 1 and 2 only (asymmetric
            # padding), then apply the padding-0 stride-2 conv.
            # NOTE(review): assumes channels-last (B, H, W, C) input, as the
            # Attention block in this module unpacks it — confirm.
            h = self.downsample(mx.pad(h, [(0, 0), (0, 1), (0, 1), (0, 0)]))

        if "upsample" in self:
            upsampled = upsample_nearest(h)
            h = self.upsample(upsampled)

        return h
91
-
92
-
93
class Encoder(nn.Module):
    """Implements the encoder side of the Autoencoder.

    Pipeline:

    1. ``conv_in`` lifts the input to ``block_out_channels[0]`` channels.
    2. One ``EncoderDecoderBlock2D`` runs per entry of ``block_out_channels``;
       all but the last also downsample.
    3. A ResNet -> Attention -> ResNet mid section follows.
    4. A GroupNorm -> SiLU -> 3x3 conv head produces ``out_channels``
       channels.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by the final convolution.
        block_out_channels: Output width of each down block, in order. Any
            sequence of ints is accepted; it is only read, never mutated.
        layers_per_block: ResNet layers inside each down block.
        resnet_groups: Group count for every GroupNorm in the encoder.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        # Immutable default: a list literal here would be one object shared
        # by every call that relies on the default.
        block_out_channels: Sequence[int] = (64,),
        layers_per_block: int = 2,
        resnet_groups: int = 32,
    ):
        super().__init__()

        widths = list(block_out_channels)
        num_blocks = len(widths)
        mid_width = widths[-1]

        self.conv_in = nn.Conv2d(
            in_channels, widths[0], kernel_size=3, stride=1, padding=1
        )

        # Pair each down block's input width with its output width. The first
        # block consumes conv_in's output, which already has widths[0] channels.
        # Loop names are distinct from the in_channels/out_channels parameters
        # to avoid shadowing them.
        channels = [widths[0]] + widths
        self.down_blocks = [
            EncoderDecoderBlock2D(
                block_in,
                block_out,
                num_layers=layers_per_block,
                resnet_groups=resnet_groups,
                # Every block except the last halves the spatial size.
                add_downsample=i < num_blocks - 1,
                add_upsample=False,
            )
            for i, (block_in, block_out) in enumerate(zip(channels, channels[1:]))
        ]

        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=mid_width,
                out_channels=mid_width,
                groups=resnet_groups,
            ),
            Attention(mid_width, resnet_groups),
            ResnetBlock2D(
                in_channels=mid_width,
                out_channels=mid_width,
                groups=resnet_groups,
            ),
        ]

        self.conv_norm_out = nn.GroupNorm(
            resnet_groups, mid_width, pytorch_compatible=True
        )
        self.conv_out = nn.Conv2d(mid_width, out_channels, 3, padding=1)

    def __call__(self, x):
        x = self.conv_in(x)

        for block in self.down_blocks:
            x = block(x)

        # ResNet -> Attention -> ResNet, applied in order.
        for block in self.mid_blocks:
            x = block(x)

        x = self.conv_norm_out(x)
        x = nn.silu(x)
        return self.conv_out(x)
157
-
158
-
159
class Decoder(nn.Module):
    """Implements the decoder side of the Autoencoder.

    Pipeline:

    1. ``conv_in`` lifts the input to ``block_out_channels[-1]`` channels.
    2. A ResNet -> Attention -> ResNet mid section follows.
    3. One ``EncoderDecoderBlock2D`` runs per entry of ``block_out_channels``
       in reverse order; all but the last also upsample.
    4. A GroupNorm -> SiLU -> 3x3 conv head produces ``out_channels``
       channels.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by the final convolution.
        block_out_channels: Per-level widths, listed in encoder order; the
            decoder walks them in reverse. Any sequence of ints is accepted;
            it is only read, never mutated.
        layers_per_block: ResNet layers inside each up block.
        resnet_groups: Group count for every GroupNorm in the decoder.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        # Immutable default: a list literal here would be one object shared
        # by every call that relies on the default.
        block_out_channels: Sequence[int] = (64,),
        layers_per_block: int = 2,
        resnet_groups: int = 32,
    ):
        super().__init__()

        widths = list(block_out_channels)
        num_blocks = len(widths)
        mid_width = widths[-1]

        self.conv_in = nn.Conv2d(
            in_channels, mid_width, kernel_size=3, stride=1, padding=1
        )

        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=mid_width,
                out_channels=mid_width,
                groups=resnet_groups,
            ),
            Attention(mid_width, resnet_groups),
            ResnetBlock2D(
                in_channels=mid_width,
                out_channels=mid_width,
                groups=resnet_groups,
            ),
        ]

        # Walk the widths from widest level back to widths[0]. The first up
        # block consumes the mid section's output, which has widths[-1]
        # channels. Loop names avoid shadowing the constructor parameters.
        reversed_widths = widths[::-1]
        channels = [reversed_widths[0]] + reversed_widths
        self.up_blocks = [
            EncoderDecoderBlock2D(
                block_in,
                block_out,
                num_layers=layers_per_block,
                resnet_groups=resnet_groups,
                add_downsample=False,
                # Every block except the last doubles the spatial size.
                add_upsample=i < num_blocks - 1,
            )
            for i, (block_in, block_out) in enumerate(zip(channels, channels[1:]))
        ]

        # After the last up block the feature width is widths[0].
        self.conv_norm_out = nn.GroupNorm(
            resnet_groups, widths[0], pytorch_compatible=True
        )
        self.conv_out = nn.Conv2d(widths[0], out_channels, 3, padding=1)

    def __call__(self, x):
        x = self.conv_in(x)

        # ResNet -> Attention -> ResNet, applied in order.
        for block in self.mid_blocks:
            x = block(x)

        for block in self.up_blocks:
            x = block(x)

        x = self.conv_norm_out(x)
        x = nn.silu(x)
        return self.conv_out(x)
224
-
225
-
226
class Autoencoder(nn.Module):
    """The autoencoder that allows us to perform diffusion in the latent space.

    ``encode`` maps an input to a (mean, log-variance) pair. Both are
    expressed in the scaled latent space, i.e. multiplied by
    ``config.scaling_factor``. ``decode`` undoes that scaling before running
    the decoder. Calling the module draws one latent sample and decodes it.

    Args:
        config: Supplies channel counts, block widths, layers per block,
            GroupNorm group count and the latent scaling factor.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()

        enc_latent_width = config.latent_channels_out
        dec_latent_width = config.latent_channels_in

        self.latent_channels = dec_latent_width
        self.scaling_factor = config.scaling_factor
        self.encoder = Encoder(
            config.in_channels,
            enc_latent_width,
            config.block_out_channels,
            config.layers_per_block,
            resnet_groups=config.norm_num_groups,
        )
        # The decoder's blocks each get one more ResNet layer than the encoder's.
        self.decoder = Decoder(
            dec_latent_width,
            config.out_channels,
            config.block_out_channels,
            config.layers_per_block + 1,
            resnet_groups=config.norm_num_groups,
        )

        self.quant_proj = nn.Linear(enc_latent_width, enc_latent_width)
        self.post_quant_proj = nn.Linear(dec_latent_width, dec_latent_width)

    def decode(self, z):
        # Undo the latent scaling applied in encode(), then project and decode.
        unscaled = z / self.scaling_factor
        projected = self.post_quant_proj(unscaled)
        return self.decoder(projected)

    def encode(self, x):
        moments = self.quant_proj(self.encoder(x))
        # The last axis holds the mean and the log-variance, split in halves.
        mean, logvar = moments.split(2, axis=-1)
        # Scaling a Gaussian by s multiplies its variance by s**2, which adds
        # 2 * log(s) to the log-variance.
        return mean * self.scaling_factor, logvar + 2 * math.log(self.scaling_factor)

    def __call__(self, x, key=None):
        mean, logvar = self.encode(x)
        # Reparameterised sample: noise * std + mean.
        noise = mx.random.normal(mean.shape, key=key)
        std = mx.exp(0.5 * logvar)
        z = noise * std + mean
        x_hat = self.decode(z)

        return {"x_hat": x_hat, "z": z, "mean": mean, "logvar": logvar}