nexaai 1.0.16rc10-cp310-cp310-macosx_14_0_universal2.whl → 1.0.16rc12-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic.

Files changed (29)
  1. nexaai/__init__.py +7 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +1 -1
  4. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  5. nexaai/binds/libnexa_bridge.dylib +0 -0
  6. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  7. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  8. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/ml.py +60 -14
  10. nexaai/log.py +92 -0
  11. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  12. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  13. nexaai/mlx_backend/image_gen/interface.py +82 -0
  14. nexaai/mlx_backend/image_gen/main.py +281 -0
  15. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  16. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  17. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  18. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  19. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  20. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  21. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  22. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  23. nexaai/mlx_backend/ml.py +60 -14
  24. nexaai/mlx_backend/sd/modeling/model_io.py +72 -17
  25. nexaai/runtime.py +4 -0
  26. {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/METADATA +1 -1
  27. {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/RECORD +29 -16
  28. {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/WHEEL +0 -0
  29. {nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/top_level.txt +0 -0
nexaai/mlx_backend/image_gen/stable_diffusion/vae.py ADDED
@@ -0,0 +1,274 @@
+# Copyright © 2023 Apple Inc.
+
+import math
+from typing import List
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import AutoencoderConfig
+from .unet import ResnetBlock2D, upsample_nearest
+
+
+class Attention(nn.Module):
+    """A single head unmasked attention for use with the VAE."""
+
+    def __init__(self, dims: int, norm_groups: int = 32):
+        super().__init__()
+
+        self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
+        self.query_proj = nn.Linear(dims, dims)
+        self.key_proj = nn.Linear(dims, dims)
+        self.value_proj = nn.Linear(dims, dims)
+        self.out_proj = nn.Linear(dims, dims)
+
+    def __call__(self, x):
+        B, H, W, C = x.shape
+
+        y = self.group_norm(x)
+
+        queries = self.query_proj(y).reshape(B, H * W, C)
+        keys = self.key_proj(y).reshape(B, H * W, C)
+        values = self.value_proj(y).reshape(B, H * W, C)
+
+        scale = 1 / math.sqrt(queries.shape[-1])
+        scores = (queries * scale) @ keys.transpose(0, 2, 1)
+        attn = mx.softmax(scores, axis=-1)
+        y = (attn @ values).reshape(B, H, W, C)
+
+        y = self.out_proj(y)
+        x = x + y
+
+        return x
+
+
+class EncoderDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        resnet_groups: int = 32,
+        add_downsample=True,
+        add_upsample=True,
+    ):
+        super().__init__()
+
+        # Add the resnet blocks
+        self.resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels if i == 0 else out_channels,
+                out_channels=out_channels,
+                groups=resnet_groups,
+            )
+            for i in range(num_layers)
+        ]
+
+        # Add an optional downsampling layer
+        if add_downsample:
+            self.downsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=2, padding=0
+            )
+
+        # or upsampling layer
+        if add_upsample:
+            self.upsample = nn.Conv2d(
+                out_channels, out_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def __call__(self, x):
+        for resnet in self.resnets:
+            x = resnet(x)
+
+        if "downsample" in self:
+            x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
+            x = self.downsample(x)
+
+        if "upsample" in self:
+            x = self.upsample(upsample_nearest(x))
+
+        return x
+
+
+class Encoder(nn.Module):
+    """Implements the encoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
+        )
+
+        channels = [block_out_channels[0]] + list(block_out_channels)
+        self.down_blocks = [
+            EncoderDecoderBlock2D(
+                in_channels,
+                out_channels,
+                num_layers=layers_per_block,
+                resnet_groups=resnet_groups,
+                add_downsample=i < len(block_out_channels) - 1,
+                add_upsample=False,
+            )
+            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
+        ]
+
+        self.mid_blocks = [
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+            Attention(block_out_channels[-1], resnet_groups),
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+        ]
+
+        self.conv_norm_out = nn.GroupNorm(
+            resnet_groups, block_out_channels[-1], pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
+
+    def __call__(self, x):
+        x = self.conv_in(x)
+
+        for l in self.down_blocks:
+            x = l(x)
+
+        x = self.mid_blocks[0](x)
+        x = self.mid_blocks[1](x)
+        x = self.mid_blocks[2](x)
+
+        x = self.conv_norm_out(x)
+        x = nn.silu(x)
+        x = self.conv_out(x)
+
+        return x
+
+
+class Decoder(nn.Module):
+    """Implements the decoder side of the Autoencoder."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        block_out_channels: List[int] = [64],
+        layers_per_block: int = 2,
+        resnet_groups: int = 32,
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(
+            in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
+        )
+
+        self.mid_blocks = [
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+            Attention(block_out_channels[-1], resnet_groups),
+            ResnetBlock2D(
+                in_channels=block_out_channels[-1],
+                out_channels=block_out_channels[-1],
+                groups=resnet_groups,
+            ),
+        ]
+
+        channels = list(reversed(block_out_channels))
+        channels = [channels[0]] + channels
+        self.up_blocks = [
+            EncoderDecoderBlock2D(
+                in_channels,
+                out_channels,
+                num_layers=layers_per_block,
+                resnet_groups=resnet_groups,
+                add_downsample=False,
+                add_upsample=i < len(block_out_channels) - 1,
+            )
+            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
+        ]
+
+        self.conv_norm_out = nn.GroupNorm(
+            resnet_groups, block_out_channels[0], pytorch_compatible=True
+        )
+        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+    def __call__(self, x):
+        x = self.conv_in(x)
+
+        x = self.mid_blocks[0](x)
+        x = self.mid_blocks[1](x)
+        x = self.mid_blocks[2](x)
+
+        for l in self.up_blocks:
+            x = l(x)
+
+        x = self.conv_norm_out(x)
+        x = nn.silu(x)
+        x = self.conv_out(x)
+
+        return x
+
+
+class Autoencoder(nn.Module):
+    """The autoencoder that allows us to perform diffusion in the latent space."""
+
+    def __init__(self, config: AutoencoderConfig):
+        super().__init__()
+
+        self.latent_channels = config.latent_channels_in
+        self.scaling_factor = config.scaling_factor
+        self.encoder = Encoder(
+            config.in_channels,
+            config.latent_channels_out,
+            config.block_out_channels,
+            config.layers_per_block,
+            resnet_groups=config.norm_num_groups,
+        )
+        self.decoder = Decoder(
+            config.latent_channels_in,
+            config.out_channels,
+            config.block_out_channels,
+            config.layers_per_block + 1,
+            resnet_groups=config.norm_num_groups,
+        )
+
+        self.quant_proj = nn.Linear(
+            config.latent_channels_out, config.latent_channels_out
+        )
+        self.post_quant_proj = nn.Linear(
+            config.latent_channels_in, config.latent_channels_in
+        )
+
+    def decode(self, z):
+        z = z / self.scaling_factor
+        return self.decoder(self.post_quant_proj(z))
+
+    def encode(self, x):
+        x = self.encoder(x)
+        x = self.quant_proj(x)
+        mean, logvar = x.split(2, axis=-1)
+        mean = mean * self.scaling_factor
+        logvar = logvar + 2 * math.log(self.scaling_factor)
+
+        return mean, logvar
+
+    def __call__(self, x, key=None):
+        mean, logvar = self.encode(x)
+        z = mx.random.normal(mean.shape, key=key) * mx.exp(0.5 * logvar) + mean
+        x_hat = self.decode(z)
+
+        return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
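
The new vae.py ports the latent-space autoencoder to MLX. For orientation, a minimal round-trip sketch (not part of the diff; it assumes AutoencoderConfig is default-constructible, as in the upstream MLX example this file derives from, and uses randomly initialized weights, so the output is only shape-correct):

# Hypothetical usage sketch — not part of the package.
import mlx.core as mx

from nexaai.mlx_backend.image_gen.stable_diffusion.config import AutoencoderConfig
from nexaai.mlx_backend.image_gen.stable_diffusion.vae import Autoencoder

vae = Autoencoder(AutoencoderConfig())

# __call__ expects an NHWC batch: (B, H, W, C)
image = mx.random.uniform(shape=(1, 512, 512, 3))
out = vae(image, key=mx.random.key(0))

print(out["z"].shape)      # latents that diffusion runs on
print(out["x_hat"].shape)  # reconstruction, same shape as the input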
nexaai/mlx_backend/ml.py CHANGED
@@ -1,6 +1,9 @@
 # This file defines the python interface that c-lib expects from a python backend
 
 from __future__ import annotations
+from typing import Optional
+from pathlib import Path
+from dataclasses import dataclass
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
@@ -101,9 +104,12 @@ class ModelConfig:
     n_threads_batch: int = 0 # number of threads to use for batch processing
     n_batch: int = 0 # logical maximum batch size that can be submitted to llama_decode
     n_ubatch: int = 0 # physical maximum batch size
-    n_seq_max: int = 0 # max number of sequences (i.e. distinct states for recurrent models)
-    chat_template_path: Optional[Path] = None # path to chat template file, optional
-    chat_template_content: Optional[str] = None # content of chat template file, optional
+    # max number of sequences (i.e. distinct states for recurrent models)
+    n_seq_max: int = 0
+    # path to chat template file, optional
+    chat_template_path: Optional[Path] = None
+    # content of chat template file, optional
+    chat_template_content: Optional[str] = None
 
 
 @dataclass
@@ -118,7 +124,8 @@ class SamplerConfig:
     frequency_penalty: float = 0.0
     seed: int = -1 # –1 for random
     grammar_path: Optional[Path] = None
-    grammar_string: Optional[str] = None # Optional grammar string (BNF-like format)
+    # Optional grammar string (BNF-like format)
+    grammar_string: Optional[str] = None
 
 
 @dataclass
@@ -128,8 +135,10 @@ class GenerationConfig:
     stop: Sequence[str] = field(default_factory=tuple)
     n_past: int = 0
     sampler_config: Optional[SamplerConfig] = None
-    image_paths: Optional[Sequence[Path]] = None # Array of image paths for VLM (None if none)
-    audio_paths: Optional[Sequence[Path]] = None # Array of audio paths for VLM (None if none)
+    # Array of image paths for VLM (None if none)
+    image_paths: Optional[Sequence[Path]] = None
+    # Array of audio paths for VLM (None if none)
+    audio_paths: Optional[Sequence[Path]] = None
 
 
 @dataclass
@@ -170,6 +179,32 @@ class RerankConfig:
     normalize_method: str = "softmax" # "softmax" | "min-max" | "none"
 
 
+# image-gen
+
+
+@dataclass
+class ImageGenTxt2ImgInput:
+    """Input structure for text-to-image generation."""
+    prompt: str
+    config: ImageGenerationConfig
+    output_path: Optional[Path] = None
+
+
+@dataclass
+class ImageGenImg2ImgInput:
+    """Input structure for image-to-image generation."""
+    init_image_path: Path
+    prompt: str
+    config: ImageGenerationConfig
+    output_path: Optional[Path] = None
+
+
+@dataclass
+class ImageGenOutput:
+    """Output structure for image generation."""
+    output_image_path: Path
+
+
 @dataclass
 class ImageSamplerConfig:
     """Configuration for image sampling."""
@@ -180,17 +215,27 @@ class ImageSamplerConfig:
     seed: int = -1 # –1 for random
 
 
+@dataclass
+class ImageGenCreateInput:
+    """Configuration for image generation."""
+    model_name: str
+    model_path: Path
+    config: ModelConfig
+    scheduler_config_path: Path
+    plugin_id: str
+    device_id: Optional[str] = None
+
+
 @dataclass
 class ImageGenerationConfig:
     """Configuration for image generation."""
-    prompts: str | List[str]
-    negative_prompts: str | List[str] | None = None
+    prompts: List[str]
+    sampler_config: ImageSamplerConfig
+    scheduler_config: SchedulerConfig
+    strength: float
+    negative_prompts: Optional[List[str]] = None
     height: int = 512
     width: int = 512
-    sampler_config: Optional[ImageSamplerConfig] = None
-    lora_id: int = -1 # –1 for none
-    init_image: Optional[Image] = None
-    strength: float = 1.0
 
 
 @dataclass
@@ -261,7 +306,7 @@ class TTSResult:
 class BoundingBox:
     """Generic bounding box structure."""
     x: float # X coordinate (normalized or pixel, depends on model)
-    y: float # Y coordinate (normalized or pixel, depends on model)
+    y: float  # Y coordinate (normalized or pixel, depends on model)
     width: float # Width
     height: float # Height
 
@@ -275,7 +320,8 @@ class CVResult:
     confidence: float = 0.0 # Confidence score [0.0-1.0]
     bbox: Optional[BoundingBox] = None # Bounding box (example: YOLO)
     text: Optional[str] = None # Text result (example: OCR)
-    embedding: Optional[List[float]] = None # Feature embedding (example: CLIP embedding)
+    # Feature embedding (example: CLIP embedding)
+    embedding: Optional[List[float]] = None
     embedding_dim: int = 0 # Embedding dimension
 
nexaai/mlx_backend/sd/modeling/model_io.py CHANGED
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import json
+import os
 from typing import Optional
 
 import mlx.core as mx
@@ -176,19 +177,37 @@ def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
 
 
 def _check_key(key: str, part: str):
+    # Check if it's a local path
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # For local paths, we'll use a default model structure
+        return
     if key not in _MODELS:
         raise ValueError(
             f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
         )
 
+def _get_model_path(key: str, file_path: str):
+    """Get the full path for a model file, supporting both local and HuggingFace paths"""
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path
+        return os.path.join(key, file_path)
+    else:
+        # HuggingFace path
+        return hf_hub_download(key, file_path)
+
 
 
 def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     """Load the stable diffusion UNet from Hugging Face Hub."""
     _check_key(key, "load_unet")
 
-    # Download the config and create the model
-    unet_config = _MODELS[key]["unet_config"]
-    with open(hf_hub_download(key, unet_config)) as f:
+    # Get the config path
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        unet_config = "unet/config.json"
+    else:
+        unet_config = _MODELS[key]["unet_config"]
+
+    with open(_get_model_path(key, unet_config)) as f:
        config = json.load(f)
 
     n_blocks = len(config["block_out_channels"])
@@ -219,8 +238,13 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     )
 
     # Download the weights and map them into the model
-    unet_weights = _MODELS[key]["unet"]
-    weight_file = hf_hub_download(key, unet_weights)
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        unet_weights = "unet/diffusion_pytorch_model.safetensors"
+    else:
+        unet_weights = _MODELS[key]["unet"]
+
+    weight_file = _get_model_path(key, unet_weights)
     _load_safetensor_weights(map_unet_weights, model, weight_file, float16)
 
     return model
@@ -238,8 +262,13 @@ def load_text_encoder(
     config_key = config_key or (model_key + "_config")
 
     # Download the config and create the model
-    text_encoder_config = _MODELS[key][config_key]
-    with open(hf_hub_download(key, text_encoder_config)) as f:
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        text_encoder_config = f"{model_key}/config.json"
+    else:
+        text_encoder_config = _MODELS[key][config_key]
+
+    with open(_get_model_path(key, text_encoder_config)) as f:
         config = json.load(f)
 
     with_projection = "WithProjection" in config["architectures"][0]
@@ -257,8 +286,13 @@
     )
 
     # Download the weights and map them into the model
-    text_encoder_weights = _MODELS[key][model_key]
-    weight_file = hf_hub_download(key, text_encoder_weights)
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        text_encoder_weights = f"{model_key}/model.safetensors"
+    else:
+        text_encoder_weights = _MODELS[key][model_key]
+
+    weight_file = _get_model_path(key, text_encoder_weights)
     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
 
     return model
@@ -269,8 +303,13 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
     _check_key(key, "load_autoencoder")
 
     # Download the config and create the model
-    vae_config = _MODELS[key]["vae_config"]
-    with open(hf_hub_download(key, vae_config)) as f:
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        vae_config = "vae/config.json"
+    else:
+        vae_config = _MODELS[key]["vae_config"]
+
+    with open(_get_model_path(key, vae_config)) as f:
         config = json.load(f)
 
     model = Autoencoder(
@@ -287,8 +326,13 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
     )
 
     # Download the weights and map them into the model
-    vae_weights = _MODELS[key]["vae"]
-    weight_file = hf_hub_download(key, vae_weights)
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        vae_weights = "vae/diffusion_pytorch_model.safetensors"
+    else:
+        vae_weights = _MODELS[key]["vae"]
+
+    weight_file = _get_model_path(key, vae_weights)
     _load_safetensor_weights(map_vae_weights, model, weight_file, float16)
 
     return model
@@ -298,8 +342,13 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
     """Load the stable diffusion config from Hugging Face Hub."""
     _check_key(key, "load_diffusion_config")
 
-    diffusion_config = _MODELS[key]["diffusion_config"]
-    with open(hf_hub_download(key, diffusion_config)) as f:
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        diffusion_config = "scheduler/scheduler_config.json"
+    else:
+        diffusion_config = _MODELS[key]["diffusion_config"]
+
+    with open(_get_model_path(key, diffusion_config)) as f:
         config = json.load(f)
 
     return DiffusionConfig(
@@ -317,11 +366,17 @@ def load_tokenizer(
 ):
     _check_key(key, "load_tokenizer")
 
-    vocab_file = hf_hub_download(key, _MODELS[key][vocab_key])
+    if os.path.exists(key) or '/' in key or '\\' in key:
+        # Local path - use SDXL Turbo structure
+        vocab_file = _get_model_path(key, f"tokenizer/{vocab_key.split('_')[1]}.json")
+        merges_file = _get_model_path(key, f"tokenizer/{merges_key.split('_')[1]}.txt")
+    else:
+        vocab_file = _get_model_path(key, _MODELS[key][vocab_key])
+        merges_file = _get_model_path(key, _MODELS[key][merges_key])
+
     with open(vocab_file, encoding="utf-8") as f:
         vocab = json.load(f)
 
-    merges_file = hf_hub_download(key, _MODELS[key][merges_key])
     with open(merges_file, encoding="utf-8") as f:
         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
         bpe_merges = [tuple(m.split()) for m in bpe_merges]
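
The model_io.py changes teach every loader to accept either a registry/Hub key or a local model directory, with _get_model_path doing the dispatch. A behavior sketch (directory paths are placeholders):

# Behavior sketch for the new resolution logic; paths are illustrative.
# A key that exists on disk, or that merely contains a path separator,
# resolves with os.path.join; anything else goes through hf_hub_download.
from nexaai.mlx_backend.sd.modeling.model_io import _get_model_path

local = _get_model_path("/models/sdxl-turbo", "unet/config.json")
# -> "/models/sdxl-turbo/unet/config.json"

remote = _get_model_path("model-key-without-separators", "unet/config.json")
# -> cached file path returned by hf_hub_download(...)

Note that the separator test also matches Hugging Face repo ids of the form org/name, so such keys are routed down the local-path branch (with the hard-coded SDXL Turbo layout) rather than the Hub branch.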
nexaai/runtime.py CHANGED
@@ -28,6 +28,10 @@ def _shutdown_runtime() -> None:
 # Public helper so advanced users can reclaim memory on demand
 shutdown = _shutdown_runtime
 
+def is_initialized() -> bool:
+    """Check if the runtime has been initialized."""
+    return _runtime_alive
+
 # ----------------------------------------------------------------------
 # Single public class
 # ----------------------------------------------------------------------
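
A small usage sketch for the new helper (only shutdown and is_initialized come from the diff; the post-shutdown behavior is inferred from the _runtime_alive flag, not confirmed by the hunk):

# Illustrative use of the new runtime probe.
from nexaai import runtime

if runtime.is_initialized():
    runtime.shutdown()  # reclaim native-runtime memory on demand

# Assuming _shutdown_runtime clears _runtime_alive, this reports False.
print(runtime.is_initialized())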
{nexaai-1.0.16rc10.dist-info → nexaai-1.0.16rc12.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nexaai
-Version: 1.0.16rc10
+Version: 1.0.16rc12
 Summary: Python bindings for NexaSDK C-lib backend
 Author-email: "Nexa AI, Inc." <dev@nexa.ai>
 Project-URL: Homepage, https://github.com/NexaAI/nexasdk-bridge