nexaai 1.0.16rc10__cp310-cp310-macosx_14_0_universal2.whl → 1.0.16rc12__cp310-cp310-macosx_14_0_universal2.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of nexaai might be problematic; see the registry page for details.
- nexaai/__init__.py +7 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +60 -14
- nexaai/log.py +92 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/ml.py +60 -14
- nexaai/mlx_backend/sd/modeling/model_io.py +72 -17
- nexaai/runtime.py +4 -0
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/METADATA +1 -1
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/RECORD +29 -16
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/WHEEL +0 -0
- {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/top_level.txt +0 -0
nexaai/mlx_backend/image_gen/stable_diffusion/vae.py
ADDED

@@ -0,0 +1,274 @@
+# Copyright © 2023 Apple Inc.
+
+import math
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import AutoencoderConfig
+from .unet import ResnetBlock2D, upsample_nearest
+
+
+class Attention(nn.Module):
+    """A single head unmasked attention for use with the VAE."""
+
+    def __init__(self, dims: int, norm_groups: int = 32):
+        super().__init__()
+
+        self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
+        self.query_proj = nn.Linear(dims, dims)
+        self.key_proj = nn.Linear(dims, dims)
+        self.value_proj = nn.Linear(dims, dims)
+        self.out_proj = nn.Linear(dims, dims)
+
+    def __call__(self, x):
+        B, H, W, C = x.shape
+
+        y = self.group_norm(x)
+
+        queries = self.query_proj(y).reshape(B, H * W, C)
+        keys = self.key_proj(y).reshape(B, H * W, C)
+        values = self.value_proj(y).reshape(B, H * W, C)
+
+        scale = 1 / math.sqrt(queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 2, 1)
+        attn = mx.softmax(scores, axis=-1)
+        y = (attn @ values).reshape(B, H, W, C)
+
+        y = self.out_proj(y)
+        x = x + y
+
+        return x
+
+
+class EncoderDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+    ):
+        super().__init__()
+
+        # Add the resnet blocks
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels if i == 0 else out_channels,
+                out_channels=out_channels,
+                groups=resnet_groups,
+            )
+            for i in range(num_layers)
+        ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=0
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(self, x):
+        for resnet in self.resnets:
+            x = resnet(x)
+
+        if "downsample" in self:
+            x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+            x = self.downsample(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+
+        return x
+
+
+class Encoder(nn.Module):
+    """Implements the encoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+        )
+
+        channels = [block_out_channels[0]] + list(block_out_channels)
+        self.down_blocks = [
+            EncoderDecoderBlock2D(
+                in_channels,
+                out_channels,
+                num_layers=layers_per_block,
+                resnet_groups=resnet_groups,
+                add_downsample=i < len(block_out_channels) - 1,
+                add_upsample=False,
+            )
+            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
+        ]
+
+        self.mid_blocks = [
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+            Attention(block_out_channels[-1], resnet_groups),
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+        ]
+
+        self.conv_norm_out = nn.GroupNorm(
+            resnet_groups, block_out_channels[-1], pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
+
+    def __call__(self, x):
+        x = self.conv_in(x)
+
+        for l in self.down_blocks:
+            x = l(x)
+
+        x = self.mid_blocks[0](x)
+        x = self.mid_blocks[1](x)
+        x = self.mid_blocks[2](x)
+
+        x = self.conv_norm_out(x)
+        x = nn.silu(x)
+        x = self.conv_out(x)
+
+        return x
+
+
+class Decoder(nn.Module):
+    """Implements the decoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
+        )
+
+        self.mid_blocks = [
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+            Attention(block_out_channels[-1], resnet_groups),
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+        ]
+
+        channels = list(reversed(block_out_channels))
+        channels = [channels[0]] + channels
+        self.up_blocks = [
+            EncoderDecoderBlock2D(
+                in_channels,
+                out_channels,
+                num_layers=layers_per_block,
+                resnet_groups=resnet_groups,
+                add_downsample=False,
+                add_upsample=i < len(block_out_channels) - 1,
+            )
+            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
+        ]
+
+        self.conv_norm_out = nn.GroupNorm(
+            resnet_groups, block_out_channels[0], pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+    def __call__(self, x):
+        x = self.conv_in(x)
+
+        x = self.mid_blocks[0](x)
+        x = self.mid_blocks[1](x)
+        x = self.mid_blocks[2](x)
+
+        for l in self.up_blocks:
+            x = l(x)
+
+        x = self.conv_norm_out(x)
+        x = nn.silu(x)
+        x = self.conv_out(x)
+
+        return x
+
+
+class Autoencoder(nn.Module):
+    """The autoencoder that allows us to perform diffusion in the latent space."""
+
+    def __init__(self, config: AutoencoderConfig):
+        super().__init__()
+
+        self.latent_channels = config.latent_channels_in
+        self.scaling_factor = config.scaling_factor
+        self.encoder = Encoder(
+            config.in_channels,
+            config.latent_channels_out,
+            config.block_out_channels,
+            config.layers_per_block,
+            resnet_groups=config.norm_num_groups,
+        )
+        self.decoder = Decoder(
+            config.latent_channels_in,
+            config.out_channels,
+            config.block_out_channels,
+            config.layers_per_block + 1,
+            resnet_groups=config.norm_num_groups,
+        )
+
+        self.quant_proj = nn.Linear(
+            config.latent_channels_out, config.latent_channels_out
+        )
+        self.post_quant_proj = nn.Linear(
+            config.latent_channels_in, config.latent_channels_in
+        )
+
+    def decode(self, z):
+        z = z / self.scaling_factor
+        return self.decoder(self.post_quant_proj(z))
+
+    def encode(self, x):
+        x = self.encoder(x)
+        x = self.quant_proj(x)
+        mean, logvar = x.split(2, axis=-1)
+        mean = mean * self.scaling_factor
+        logvar = logvar + 2 * math.log(self.scaling_factor)
+
+        return mean, logvar
+
+    def __call__(self, x, key=None):
+        mean, logvar = self.encode(x)
+        z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
+        x_hat = self.decode(z)
+
+        return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
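
For orientation, here is a minimal usage sketch of the Autoencoder added above. It is not from the package: the import paths follow the file list, inputs are assumed NHWC (as `B, H, W, C = x.shape` implies), and AutoencoderConfig is assumed to carry Stable-Diffusion-style defaults.

    # Hypothetical sketch; config defaults and import paths are assumptions.
    import mlx.core as mx

    from nexaai.mlx_backend.image_gen.stable_diffusion.config import AutoencoderConfig
    from nexaai.mlx_backend.image_gen.stable_diffusion.vae import Autoencoder

    vae = Autoencoder(AutoencoderConfig())

    x = mx.zeros((1, 512, 512, 3))  # NHWC image batch
    mean, logvar = vae.encode(x)  # parameters of a diagonal Gaussian over latents
    z = mx.random.normal(mean.shape) * mx.exp(0.5 * logvar) + mean  # reparameterized sample
    x_hat = vae.decode(z)  # back to image space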
nexaai/mlx_backend/ml.py
CHANGED

@@ -1,6 +1,9 @@
 # This file defines the python interface that c-lib expects from a python backend
 
 from __future__ import annotations
+from typing import Optional
+from pathlib import Path
+from dataclasses import dataclass
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field

@@ -101,9 +104,12 @@ class ModelConfig:
     n_threads_batch: int = 0  # number of threads to use for batch processing
     n_batch: int = 0  # logical maximum batch size that can be submitted to llama_decode
     n_ubatch: int = 0  # physical maximum batch size
-
-
-
+    # max number of sequences (i.e. distinct states for recurrent models)
+    n_seq_max: int = 0
+    # path to chat template file, optional
+    chat_template_path: Optional[Path] = None
+    # content of chat template file, optional
+    chat_template_content: Optional[str] = None
 
 
 @dataclass
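
The two new chat-template fields are alternative ways to supply the same thing: a file path or the raw template text. A hypothetical sketch (it assumes the remaining ModelConfig fields keep their defaults; the path and template string are illustrative):

    from pathlib import Path

    # Either point at a template file on disk...
    cfg = ModelConfig(chat_template_path=Path("templates/chatml.jinja"))

    # ...or pass the template content inline.
    cfg = ModelConfig(chat_template_content="{% for m in messages %}...{% endfor %}")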
@@ -118,7 +124,8 @@ class SamplerConfig:
     frequency_penalty: float = 0.0
     seed: int = -1  # –1 for random
     grammar_path: Optional[Path] = None
-
+    # Optional grammar string (BNF-like format)
+    grammar_string: Optional[str] = None
 
 
 @dataclass

@@ -128,8 +135,10 @@ class GenerationConfig:
     stop: Sequence[str] = field(default_factory=tuple)
     n_past: int = 0
     sampler_config: Optional[SamplerConfig] = None
-
-
+    # Array of image paths for VLM (None if none)
+    image_paths: Optional[Sequence[Path]] = None
+    # Array of audio paths for VLM (None if none)
+    audio_paths: Optional[Sequence[Path]] = None
 
 
 @dataclass

@@ -170,6 +179,32 @@ class RerankConfig:
     normalize_method: str = "softmax"  # "softmax" | "min-max" | "none"
 
 
+# image-gen
+
+
+@dataclass
+class ImageGenTxt2ImgInput:
+    """Input structure for text-to-image generation."""
+    prompt: str
+    config: ImageGenerationConfig
+    output_path: Optional[Path] = None
+
+
+@dataclass
+class ImageGenImg2ImgInput:
+    """Input structure for image-to-image generation."""
+    init_image_path: Path
+    prompt: str
+    config: ImageGenerationConfig
+    output_path: Optional[Path] = None
+
+
+@dataclass
+class ImageGenOutput:
+    """Output structure for image generation."""
+    output_image_path: Path
+
+
 @dataclass
 class ImageSamplerConfig:
     """Configuration for image sampling."""
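
A sketch combining the new SamplerConfig and GenerationConfig fields for a grammar-constrained VLM request. The grammar and image path are illustrative, and it assumes the remaining fields of both dataclasses have defaults:

    from pathlib import Path

    gen = GenerationConfig(
        sampler_config=SamplerConfig(grammar_string='root ::= "yes" | "no"'),  # BNF-like grammar
        image_paths=[Path("examples/cat.png")],  # image input for a VLM
    )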
@@ -180,17 +215,27 @@ class ImageSamplerConfig:
     seed: int = -1  # –1 for random
 
 
+@dataclass
+class ImageGenCreateInput:
+    """Configuration for image generation."""
+    model_name: str
+    model_path: Path
+    config: ModelConfig
+    scheduler_config_path: Path
+    plugin_id: str
+    device_id: Optional[str] = None
+
+
 @dataclass
 class ImageGenerationConfig:
     """Configuration for image generation."""
-    prompts:
-
+    prompts: List[str]
+    sampler_config: ImageSamplerConfig
+    scheduler_config: SchedulerConfig
+    strength: float
+    negative_prompts: Optional[List[str]] = None
     height: int = 512
     width: int = 512
-    sampler_config: Optional[ImageSamplerConfig] = None
-    lora_id: int = -1  # –1 for none
-    init_image: Optional[Image] = None
-    strength: float = 1.0
 
 
 @dataclass

@@ -261,7 +306,7 @@ class TTSResult:
 class BoundingBox:
     """Generic bounding box structure."""
     x: float  # X coordinate (normalized or pixel, depends on model)
-    y: float  # Y coordinate (normalized or pixel, depends on model)
+    y: float  # Y coordinate (normalized or pixel, depends on model)
     width: float  # Width
     height: float  # Height

@@ -275,7 +320,8 @@ class CVResult:
     confidence: float = 0.0  # Confidence score [0.0-1.0]
     bbox: Optional[BoundingBox] = None  # Bounding box (example: YOLO)
     text: Optional[str] = None  # Text result (example: OCR)
-
+    # Feature embedding (example: CLIP embedding)
+    embedding: Optional[List[float]] = None
     embedding_dim: int = 0  # Embedding dimension
 
 
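
Putting the image-gen dataclasses together, a hypothetical text-to-image request. It assumes ImageSamplerConfig and SchedulerConfig are default-constructible; the prompt and output path are illustrative:

    from pathlib import Path

    cfg = ImageGenerationConfig(
        prompts=["an astronaut riding a horse"],
        sampler_config=ImageSamplerConfig(),
        scheduler_config=SchedulerConfig(),
        strength=1.0,  # no init image, so full denoising strength
    )
    req = ImageGenTxt2ImgInput(prompt=cfg.prompts[0], config=cfg, output_path=Path("out.png"))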
nexaai/mlx_backend/sd/modeling/model_io.py
CHANGED

@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import json
+import os
 from typing import Optional
 
 import mlx.core as mx

@@ -176,19 +177,37 @@ def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
 
 
 def _check_key(key: str, part: str):
+    # Check if it's a local path
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # For local paths, we'll use a default model structure
+        return
     if key not in _MODELS:
         raise ValueError(
             f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
         )
 
+def _get_model_path(key: str, file_path: str):
+    """Get the full path for a model file, supporting both local and HuggingFace paths"""
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path
+        return os.path.join(key, file_path)
+    else:
+        # HuggingFace path
+        return hf_hub_download(key, file_path)
+
 
 def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     """Load the stable diffusion UNet from Hugging Face Hub."""
     _check_key(key, "load_unet")
 
-    #
-
-
+    # Get the config path
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        unet_config = "unet/config.json"
+    else:
+        unet_config = _MODELS[key]["unet_config"]
+
+    with open(_get_model_path(key, unet_config)) as f:
         config = json.load(f)
 
     n_blocks = len(config["block_out_channels"])

@@ -219,8 +238,13 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     )
 
     # Download the weights and map them into the model
-
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        unet_weights = "unet/diffusion_pytorch_model.safetensors"
+    else:
+        unet_weights = _MODELS[key]["unet"]
+
+    weight_file = _get_model_path(key, unet_weights)
     _load_safetensor_weights(map_unet_weights, model, weight_file, float16)
 
     return model

@@ -238,8 +262,13 @@ def load_text_encoder(
     config_key = config_key or (model_key + "_config")
 
     # Download the config and create the model
-
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        text_encoder_config = f"{model_key}/config.json"
+    else:
+        text_encoder_config = _MODELS[key][config_key]
+
+    with open(_get_model_path(key, text_encoder_config)) as f:
         config = json.load(f)
 
     with_projection = "WithProjection" in config["architectures"][0]

@@ -257,8 +286,13 @@ def load_text_encoder(
     )
 
     # Download the weights and map them into the model
-
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        text_encoder_weights = f"{model_key}/model.safetensors"
+    else:
+        text_encoder_weights = _MODELS[key][model_key]
+
+    weight_file = _get_model_path(key, text_encoder_weights)
     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
 
     return model

@@ -269,8 +303,13 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
     _check_key(key, "load_autoencoder")
 
     # Download the config and create the model
-
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        vae_config = "vae/config.json"
+    else:
+        vae_config = _MODELS[key]["vae_config"]
+
+    with open(_get_model_path(key, vae_config)) as f:
         config = json.load(f)
 
     model = Autoencoder(

@@ -287,8 +326,13 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
     )
 
     # Download the weights and map them into the model
-
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        vae_weights = "vae/diffusion_pytorch_model.safetensors"
+    else:
+        vae_weights = _MODELS[key]["vae"]
+
+    weight_file = _get_model_path(key, vae_weights)
     _load_safetensor_weights(map_vae_weights, model, weight_file, float16)
 
     return model

@@ -298,8 +342,13 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
     """Load the stable diffusion config from Hugging Face Hub."""
     _check_key(key, "load_diffusion_config")
 
-
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        diffusion_config = "scheduler/scheduler_config.json"
+    else:
+        diffusion_config = _MODELS[key]["diffusion_config"]
+
+    with open(_get_model_path(key, diffusion_config)) as f:
         config = json.load(f)
 
     return DiffusionConfig(

@@ -317,11 +366,17 @@ def load_tokenizer(
 ):
     _check_key(key, "load_tokenizer")
 
-
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        vocab_file = _get_model_path(key, f"tokenizer/{vocab_key.split('_')[1]}.json")
+        merges_file = _get_model_path(key, f"tokenizer/{merges_key.split('_')[1]}.txt")
+    else:
+        vocab_file = _get_model_path(key, _MODELS[key][vocab_key])
+        merges_file = _get_model_path(key, _MODELS[key][merges_key])
+
     with open(vocab_file, encoding="utf-8") as f:
         vocab = json.load(f)
 
-    merges_file = hf_hub_download(key, _MODELS[key][merges_key])
     with open(merges_file, encoding="utf-8") as f:
         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
         bpe_merges = [tuple(m.split()) for m in bpe_merges]
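
Note the branch condition used throughout: a key that exists on disk or contains a path separator takes the local branch with the fixed SDXL-Turbo file layout, so only bare keys fall through to the `_MODELS` lookup and the Hub download. A sketch of local loading (the directory is illustrative, and load_tokenizer is assumed to keep its default vocab/merges keys):

    # Expects <dir>/unet/config.json, <dir>/vae/..., <dir>/tokenizer/... on disk.
    unet = load_unet("/models/sdxl-turbo", float16=True)
    vae = load_autoencoder("/models/sdxl-turbo", float16=True)
    tokenizer = load_tokenizer("/models/sdxl-turbo")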
nexaai/runtime.py
CHANGED

@@ -28,6 +28,10 @@ def _shutdown_runtime() -> None:
 # Public helper so advanced users can reclaim memory on demand
 shutdown = _shutdown_runtime
 
+def is_initialized() -> bool:
+    """Check if the runtime has been initialized."""
+    return _runtime_alive
+
 # ----------------------------------------------------------------------
 # Single public class
 # ----------------------------------------------------------------------
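
A sketch of the new helper in use (the import path is assumed from the file name):

    from nexaai import runtime

    if runtime.is_initialized():
        runtime.shutdown()  # reclaim memory on demand, per the comment above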