nexaai 1.0.16rc9__cp310-cp310-macosx_14_0_universal2.whl → 1.0.16rc10__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/__init__.py +0 -7
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +14 -60
- nexaai/mlx_backend/ml.py +14 -60
- nexaai/mlx_backend/sd/modeling/model_io.py +17 -72
- nexaai/runtime.py +0 -4
- {nexaai-1.0.16rc9.dist-info → nexaai-1.0.16rc10.dist-info}/METADATA +1 -1
- {nexaai-1.0.16rc9.dist-info → nexaai-1.0.16rc10.dist-info}/RECORD +16 -29
- nexaai/log.py +0 -92
- nexaai/mlx_backend/image_gen/__init__.py +0 -1
- nexaai/mlx_backend/image_gen/generate_sd.py +0 -244
- nexaai/mlx_backend/image_gen/interface.py +0 -82
- nexaai/mlx_backend/image_gen/main.py +0 -281
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +0 -65
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +0 -274
- {nexaai-1.0.16rc9.dist-info → nexaai-1.0.16rc10.dist-info}/WHEEL +0 -0
- {nexaai-1.0.16rc9.dist-info → nexaai-1.0.16rc10.dist-info}/top_level.txt +0 -0
|
@@ -1,274 +0,0 @@
|
|
|
1
|
-
# Copyright © 2023 Apple Inc.
|
|
2
|
-
|
|
3
|
-
import math
|
|
4
|
-
from typing import List
|
|
5
|
-
|
|
6
|
-
import mlx.core as mx
|
|
7
|
-
import mlx.nn as nn
|
|
8
|
-
|
|
9
|
-
from .config import AutoencoderConfig
|
|
10
|
-
from .unet import ResnetBlock2D, upsample_nearest
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class Attention(nn.Module):
    """Single-head, unmasked self-attention layer used by the VAE.

    The input is group-normalized, projected to queries, keys and values,
    attended across every spatial position, projected once more, and added
    back onto the input as a residual.
    """

    def __init__(self, dims: int, norm_groups: int = 32):
        """Build the normalization layer and the four ``dims -> dims`` projections.

        Args:
            dims: Channel count of the input; every projection maps
                ``dims`` to ``dims``.
            norm_groups: Number of groups passed to ``nn.GroupNorm``.
        """
        super().__init__()

        self.group_norm = nn.GroupNorm(
            norm_groups,
            dims,
            pytorch_compatible=True,
        )
        self.query_proj = nn.Linear(dims, dims)
        self.key_proj = nn.Linear(dims, dims)
        self.value_proj = nn.Linear(dims, dims)
        self.out_proj = nn.Linear(dims, dims)

    def __call__(self, x):
        """Apply attention over the two middle axes of ``x``.

        Args:
            x: Rank-4 array whose shape is unpacked as ``(B, H, W, C)``.

        Returns:
            ``x`` plus the projected attention output, same shape as ``x``.
        """
        batch, height, width, channels = x.shape
        # The two middle axes are merged into one sequence axis for attention.
        flat_shape = (batch, height * width, channels)

        normed = self.group_norm(x)

        queries = self.query_proj(normed).reshape(*flat_shape)
        keys = self.key_proj(normed).reshape(*flat_shape)
        values = self.value_proj(normed).reshape(*flat_shape)

        # Scale by 1/sqrt(C); the last axis of the queries has size C.
        scale = 1 / math.sqrt(channels)
        scaled_queries = queries * scale
        scores = scaled_queries @ keys.transpose(0, 2, 1)
        weights = mx.softmax(scores, axis=-1)

        attended = (weights @ values).reshape(batch, height, width, channels)

        # Residual connection around the whole attention path.
        return x + self.out_proj(attended)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
class EncoderDecoderBlock2D(nn.Module):
    """A stack of ``ResnetBlock2D`` layers with optional resampling.

    After the residual blocks the input may be downsampled by a stride-2
    convolution and/or upsampled by ``upsample_nearest`` followed by a
    stride-1 convolution, depending on the constructor flags.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        resnet_groups: int = 32,
        add_downsample=True,
        add_upsample=True,
    ):
        """Create the residual blocks and the requested resampling layers.

        Args:
            in_channels: Input channels of the first residual block.
            out_channels: Output channels of every residual block, and the
                input channels of every block after the first.
            num_layers: Number of residual blocks.
            resnet_groups: ``groups`` argument forwarded to each
                ``ResnetBlock2D``.
            add_downsample: When true, create the ``downsample`` convolution.
            add_upsample: When true, create the ``upsample`` convolution.
        """
        super().__init__()

        # Only the first residual block sees `in_channels`; the rest are
        # `out_channels -> out_channels`.
        resnets = []
        block_in = in_channels
        for _ in range(num_layers):
            resnets.append(
                ResnetBlock2D(
                    in_channels=block_in,
                    out_channels=out_channels,
                    groups=resnet_groups,
                )
            )
            block_in = out_channels
        self.resnets = resnets

        # The resampling layers are only defined when requested; __call__
        # checks for their presence by name.
        if add_downsample:
            self.downsample = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            )

        if add_upsample:
            self.upsample = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )

    def __call__(self, x):
        """Run the residual blocks, then whichever resampling layers exist."""
        for block in self.resnets:
            x = block(x)

        # NOTE(review): the membership tests rely on nn.Module exposing
        # assigned sub-layers by name — confirm against the MLX docs.
        if "downsample" in self:
            # The convolution has padding=0, so pad one element at the end
            # of axes 1 and 2 by hand before applying it.
            pad_widths = [(0, 0), (0, 1), (0, 1), (0, 0)]
            x = mx.pad(x, pad_widths)
            x = self.downsample(x)

        if "upsample" in self:
            enlarged = upsample_nearest(x)
            x = self.upsample(enlarged)

        return x
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
class Encoder(nn.Module):
    """Encoder half of the autoencoder.

    Pipeline: input convolution, a chain of ``EncoderDecoderBlock2D`` blocks,
    a ResNet / attention / ResNet middle section, then group norm, SiLU and
    an output convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_out_channels: List[int] = [64],
        layers_per_block: int = 2,
        resnet_groups: int = 32,
    ):
        """Assemble the encoder layers.

        Args:
            in_channels: Channels of the input to the first convolution.
            out_channels: Channels produced by the final convolution.
            block_out_channels: Output width of each down block, in order.
                The list is only read here, never modified.
            layers_per_block: Residual blocks inside each down block.
            resnet_groups: Group count used by every group norm.
        """
        super().__init__()

        first_width = block_out_channels[0]
        last_width = block_out_channels[-1]

        self.conv_in = nn.Conv2d(
            in_channels,
            first_width,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Consecutive pairs of `widths` give each block's (input, output)
        # channel counts; the first block keeps the width of `conv_in`.
        widths = [first_width] + list(block_out_channels)
        final_index = len(block_out_channels) - 1
        down_blocks = []
        for index, (block_in, block_out) in enumerate(zip(widths, widths[1:])):
            down_blocks.append(
                EncoderDecoderBlock2D(
                    block_in,
                    block_out,
                    num_layers=layers_per_block,
                    resnet_groups=resnet_groups,
                    # Every block except the last one downsamples.
                    add_downsample=index < final_index,
                    add_upsample=False,
                )
            )
        self.down_blocks = down_blocks

        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=last_width,
                out_channels=last_width,
                groups=resnet_groups,
            ),
            Attention(last_width, resnet_groups),
            ResnetBlock2D(
                in_channels=last_width,
                out_channels=last_width,
                groups=resnet_groups,
            ),
        ]

        self.conv_norm_out = nn.GroupNorm(
            resnet_groups,
            last_width,
            pytorch_compatible=True,
        )
        self.conv_out = nn.Conv2d(last_width, out_channels, 3, padding=1)

    def __call__(self, x):
        """Run ``x`` through every encoder stage and return the result."""
        x = self.conv_in(x)

        for down_block in self.down_blocks:
            x = down_block(x)

        # The middle section always holds three layers, applied in order.
        for mid_block in self.mid_blocks:
            x = mid_block(x)

        x = nn.silu(self.conv_norm_out(x))
        return self.conv_out(x)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
class Decoder(nn.Module):
    """Decoder half of the autoencoder.

    Pipeline: input convolution, a ResNet / attention / ResNet middle
    section, a chain of ``EncoderDecoderBlock2D`` blocks, then group norm,
    SiLU and an output convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        block_out_channels: List[int] = [64],
        layers_per_block: int = 2,
        resnet_groups: int = 32,
    ):
        """Assemble the decoder layers.

        Args:
            in_channels: Channels of the input to the first convolution.
            out_channels: Channels produced by the final convolution.
            block_out_channels: Block widths; they are walked in reverse
                order to build the up blocks. The list is only read here.
            layers_per_block: Residual blocks inside each up block.
            resnet_groups: Group count used by every group norm.
        """
        super().__init__()

        first_width = block_out_channels[0]
        last_width = block_out_channels[-1]

        self.conv_in = nn.Conv2d(
            in_channels,
            last_width,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_blocks = [
            ResnetBlock2D(
                in_channels=last_width,
                out_channels=last_width,
                groups=resnet_groups,
            ),
            Attention(last_width, resnet_groups),
            ResnetBlock2D(
                in_channels=last_width,
                out_channels=last_width,
                groups=resnet_groups,
            ),
        ]

        # Reverse the widths and repeat the leading one, so that consecutive
        # pairs give each up block's (input, output) channel counts.
        widths = [last_width] + list(reversed(block_out_channels))
        final_index = len(block_out_channels) - 1
        up_blocks = []
        for index, (block_in, block_out) in enumerate(zip(widths, widths[1:])):
            up_blocks.append(
                EncoderDecoderBlock2D(
                    block_in,
                    block_out,
                    num_layers=layers_per_block,
                    resnet_groups=resnet_groups,
                    add_downsample=False,
                    # Every block except the last one upsamples.
                    add_upsample=index < final_index,
                )
            )
        self.up_blocks = up_blocks

        self.conv_norm_out = nn.GroupNorm(
            resnet_groups,
            first_width,
            pytorch_compatible=True,
        )
        self.conv_out = nn.Conv2d(first_width, out_channels, 3, padding=1)

    def __call__(self, x):
        """Run ``x`` through every decoder stage and return the result."""
        x = self.conv_in(x)

        # The middle section always holds three layers, applied in order.
        for mid_block in self.mid_blocks:
            x = mid_block(x)

        for up_block in self.up_blocks:
            x = up_block(x)

        x = nn.silu(self.conv_norm_out(x))
        return self.conv_out(x)
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
class Autoencoder(nn.Module):
    """Autoencoder that maps to and from the latent space used for diffusion."""

    def __init__(self, config: AutoencoderConfig):
        """Build the encoder, the decoder and the two latent projections.

        Args:
            config: Supplies channel counts, block widths, layers per block,
                the group-norm group count and the latent scaling factor.
        """
        super().__init__()

        self.latent_channels = config.latent_channels_in
        self.scaling_factor = config.scaling_factor

        self.encoder = Encoder(
            config.in_channels,
            config.latent_channels_out,
            config.block_out_channels,
            config.layers_per_block,
            resnet_groups=config.norm_num_groups,
        )
        # The decoder is built with one more layer per block than the encoder.
        decoder_layers = config.layers_per_block + 1
        self.decoder = Decoder(
            config.latent_channels_in,
            config.out_channels,
            config.block_out_channels,
            decoder_layers,
            resnet_groups=config.norm_num_groups,
        )

        self.quant_proj = nn.Linear(
            config.latent_channels_out,
            config.latent_channels_out,
        )
        self.post_quant_proj = nn.Linear(
            config.latent_channels_in,
            config.latent_channels_in,
        )

    def decode(self, z):
        """Decode a latent ``z``: undo the scaling, project, run the decoder."""
        unscaled = z / self.scaling_factor
        projected = self.post_quant_proj(unscaled)
        return self.decoder(projected)

    def encode(self, x):
        """Encode ``x`` and return the scaled ``(mean, logvar)`` pair.

        The projected encoder output is split into two parts along its last
        axis. The mean is multiplied by the scaling factor, and the log
        variance is shifted by ``2 * log(scaling_factor)`` to match.
        """
        projected = self.quant_proj(self.encoder(x))
        mean, logvar = projected.split(2, axis=-1)

        mean = mean * self.scaling_factor
        logvar = logvar + 2 * math.log(self.scaling_factor)

        return mean, logvar

    def __call__(self, x, key=None):
        """Encode ``x``, sample a latent, and decode it again.

        Args:
            x: Input passed to :meth:`encode`.
            key: Optional key forwarded to ``mx.random.normal``.

        Returns:
            A dict with the keys ``x_hat``, ``z``, ``mean`` and ``logvar``.
        """
        mean, logvar = self.encode(x)

        # Draw noise first, then scale it by exp(0.5 * logvar) and shift by
        # the mean to obtain the latent sample.
        noise = mx.random.normal(mean.shape, key=key)
        z = noise * mx.exp(0.5 * logvar) + mean

        x_hat = self.decode(z)

        return {"x_hat": x_hat, "z": z, "mean": mean, "logvar": logvar}
|
|
File without changes
|
|
File without changes
|