nexaai-1.0.15rc1-cp310-cp310-macosx_13_0_x86_64.whl → nexaai-1.0.16-cp310-cp310-macosx_13_0_x86_64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/__init__.py +7 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +1 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/ml.py +60 -14
- nexaai/mlx_backend/sd/modeling/model_io.py +72 -17
- nexaai/mlx_backend/vlm/interface.py +33 -2
- nexaai/runtime.py +4 -0
- nexaai/utils/quantization_utils.py +7 -1
- nexaai/vlm.py +4 -3
- nexaai/vlm_impl/mlx_vlm_impl.py +3 -1
- nexaai/vlm_impl/pybind_vlm_impl.py +3 -1
- {nexaai-1.0.15rc1.dist-info → nexaai-1.0.16.dist-info}/METADATA +1 -1
- {nexaai-1.0.15rc1.dist-info → nexaai-1.0.16.dist-info}/RECORD +35 -22
- {nexaai-1.0.15rc1.dist-info → nexaai-1.0.16.dist-info}/WHEEL +0 -0
- {nexaai-1.0.15rc1.dist-info → nexaai-1.0.16.dist-info}/top_level.txt +0 -0
nexaai/__init__.py
CHANGED
@@ -21,6 +21,9 @@ except ImportError:
 # Import common configuration classes first (no external dependencies)
 from .common import ModelConfig, GenerationConfig, ChatMessage, SamplerConfig, PluginID
 
+# Import logging functionality
+from .log import set_logger, get_error_message
+
 # Create alias for PluginID to be accessible as plugin_id
 plugin_id = PluginID
 
@@ -45,6 +48,10 @@ __all__ = [
     "EmbeddingConfig",
    "PluginID",
    "plugin_id",
+
+    # Logging functionality
+    "set_logger",
+    "get_error_message",
 
    "LLM",
    "Embedder",
nexaai/_stub.cpython-310-darwin.so
Binary file

nexaai/_version.py
CHANGED

nexaai/binds/common_bind.cpython-310-darwin.so
Binary file

nexaai/binds/embedder_bind.cpython-310-darwin.so
Binary file

nexaai/binds/libnexa_bridge.dylib
Binary file

nexaai/binds/llm_bind.cpython-310-darwin.so
Binary file

nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib
Binary file

nexaai/binds/vlm_bind.cpython-310-darwin.so
Binary file

nexaai/common.py
CHANGED

nexaai/log.py
ADDED
@@ -0,0 +1,92 @@
+"""
+Logging configuration for NexaAI bridge.
+
+This module provides a minimal API to configure bridge-wide logging
+to route into Python's logging system.
+"""
+
+import logging
+import threading
+from enum import IntEnum
+from typing import Optional
+
+from nexaai.binds import common_bind
+from nexaai.runtime import is_initialized
+
+
+class LogLevel(IntEnum):
+    """Log levels matching ml_LogLevel from ml.h"""
+    TRACE = 0
+    DEBUG = 1
+    INFO = 2
+    WARN = 3
+    ERROR = 4
+
+
+# Module-level state
+_config_lock = threading.Lock()
+_current_logger: Optional[logging.Logger] = None
+
+
+def set_logger(logger: Optional[logging.Logger] = None, *, strict: bool = True) -> None:
+    """
+    Set the process-wide bridge logger.
+
+    Args:
+        logger: Python logger to receive bridge logs. If None, uses "nexaai.ml" logger.
+        strict: If True, raises if called after runtime initialization.
+            If False, attempts to set anyway (best-effort).
+
+    Raises:
+        RuntimeError: If strict=True and runtime is already initialized.
+    """
+    global _current_logger
+
+    with _config_lock:
+        # Check initialization state if strict mode
+        if strict and is_initialized():
+            raise RuntimeError(
+                "Cannot configure logging after runtime initialization. "
+                "Call set_logger() before creating any models, or use strict=False for best-effort."
+            )
+
+        # Use default logger if none provided
+        if logger is None:
+            logger = logging.getLogger("nexaai.ml")
+
+        _current_logger = logger
+
+        # Set the C callback
+        common_bind.ml_set_log(_log_callback)
+
+
+def _log_callback(level: int, message: str) -> None:
+    """Internal callback that forwards bridge logs to Python logger."""
+    if _current_logger is None:
+        return
+
+    # Map bridge log levels to Python logging levels
+    if level == LogLevel.TRACE or level == LogLevel.DEBUG:
+        _current_logger.debug(message)
+    elif level == LogLevel.INFO:
+        _current_logger.info(message)
+    elif level == LogLevel.WARN:
+        _current_logger.warning(message)
+    elif level == LogLevel.ERROR:
+        _current_logger.error(message)
+    else:
+        # Fallback for unknown levels
+        _current_logger.info(f"[Level {level}] {message}")
+
+
+def get_error_message(error_code: int) -> str:
+    """
+    Get error message string for error code.
+
+    Args:
+        error_code: ML error code (typically negative)
+
+    Returns:
+        Human-readable error message
+    """
+    return common_bind.ml_get_error_message(error_code)
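For reference, a minimal usage sketch of the new logging API, going through the re-exports added to nexaai/__init__.py in this release. The logger name and the error code below are illustrative, not taken from the package:

    import logging
    import nexaai

    # Route bridge logs into Python logging. Per the docstring above,
    # set_logger must be called before the runtime is initialized,
    # i.e. before any model is created.
    logging.basicConfig(level=logging.DEBUG)
    nexaai.set_logger(logging.getLogger("myapp.nexa"))

    # Translate a (hypothetical) bridge error code into readable text.
    print(nexaai.get_error_message(-1))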
nexaai/mlx_backend/image_gen/__init__.py
ADDED
@@ -0,0 +1 @@
+# Image generation module for MLX backend
nexaai/mlx_backend/image_gen/generate_sd.py
ADDED
@@ -0,0 +1,244 @@
+from __future__ import annotations
+
+from typing import (
+    List,
+    Optional,
+)
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image as PILImage
+import mlx.nn as nn
+import os
+
+from .stable_diffusion import StableDiffusion, StableDiffusionXL
+
+
+class Image:
+    def __init__(self, data: List[float], width: int, height: int, channels: int) -> None:
+        """Initialize an image with pixel data"""
+        self.data = data
+        self.width = width
+        self.height = height
+        self.channels = channels
+
+    @classmethod
+    def from_numpy(cls, array: np.ndarray) -> 'Image':
+        """Create Image from numpy array (H, W, C)"""
+        height, width, channels = array.shape
+        data = array.flatten().tolist()
+        return cls(data, width, height, channels)
+
+    @classmethod
+    def from_pil(cls, pil_image: PILImage.Image) -> 'Image':
+        """Create Image from PIL Image"""
+        array = np.array(pil_image).astype(np.float32) / 255.0
+        return cls.from_numpy(array)
+
+    def to_numpy(self) -> np.ndarray:
+        """Convert to numpy array (H, W, C)"""
+        return np.array(self.data).reshape(self.height, self.width, self.channels)
+
+    def to_pil(self) -> PILImage.Image:
+        """Convert to PIL Image"""
+        array = (self.to_numpy() * 255).astype(np.uint8)
+        return PILImage.fromarray(array)
+
+
+class ImageSamplerConfig:
+    def __init__(
+        self,
+        method: str = "ddim",
+        steps: int = 4,  # SDXL Turbo typically uses fewer steps
+        guidance_scale: float = 0.0,  # SDXL Turbo works well with no guidance
+        eta: float = 0.0,
+        seed: int = -1,
+    ) -> None:
+        """Initialize sampler configuration optimized for SDXL Turbo"""
+        self.method = method
+        self.steps = steps
+        self.guidance_scale = guidance_scale
+        self.eta = eta
+        self.seed = seed
+
+
+class ImageGenerationConfig:
+    def __init__(
+        self,
+        prompts: str | List[str],
+        negative_prompts: str | List[str] | None = None,
+        height: int = 512,
+        width: int = 512,
+        sampler_config: Optional[ImageSamplerConfig] = None,
+        lora_id: int = -1,  # Not used but kept for compatibility
+        init_image: Optional[Image] = None,
+        strength: float = 1.0,
+        n_images: int = 1,
+        n_rows: int = 1,
+        decoding_batch_size: int = 1,
+    ) -> None:
+        """Initialize image generation configuration"""
+        self.prompts = prompts
+        self.negative_prompts = negative_prompts or ""
+        self.height = height
+        self.width = width
+        self.sampler_config = sampler_config or ImageSamplerConfig()
+        self.lora_id = lora_id
+        self.init_image = init_image
+        self.strength = strength
+        self.n_images = n_images
+        self.n_rows = n_rows
+        self.decoding_batch_size = decoding_batch_size
+
+
+class ImageGen:
+    def __init__(
+        self,
+        model_path: str,
+        scheduler_config_path: Optional[str] = None,
+        device: Optional[str] = None,
+        float16: bool = True,
+        quantize: bool = False,
+    ) -> None:
+        """Initialize the image generation model for SDXL Turbo"""
+        self.model_path = model_path
+        self.scheduler_config_path = scheduler_config_path
+        self.float16 = float16
+        self.quantize = quantize
+        self.model = None
+
+    @staticmethod
+    def load_model(model_path: str, float16: bool = True, quantize: bool = False) -> StableDiffusion:
+        """Load a model from the given path - following txt2img.py pattern"""
+
+        # Check if it's a local path or HuggingFace repo
+        # If it contains path separators or exists as a file/directory, treat as local
+        is_local_path = (
+            '/' in model_path or '\\' in model_path or os.path.exists(model_path))
+
+        if is_local_path:
+            # For local paths, determine model type from the path or model files
+            if "xl" in model_path.lower() or "turbo" in model_path.lower():
+                model = StableDiffusionXL(model_path, float16=float16)
+            else:
+                model = StableDiffusion(model_path, float16=float16)
+        else:
+            # For HuggingFace repo names, use the original logic
+            if "xl" in model_path.lower() or "turbo" in model_path.lower():
+                model = StableDiffusionXL(model_path, float16=float16)
+            else:
+                model = StableDiffusion(model_path, float16=float16)
+
+        # Apply quantization if requested - same as txt2img.py
+        if quantize:
+            if "xl" in model_path.lower() or "turbo" in model_path.lower():
+                nn.quantize(
+                    model.text_encoder_1, class_predicate=lambda _, m: isinstance(
+                        m, nn.Linear)
+                )
+                nn.quantize(
+                    model.text_encoder_2, class_predicate=lambda _, m: isinstance(
+                        m, nn.Linear)
+                )
+            else:
+                nn.quantize(
+                    model.text_encoder, class_predicate=lambda _, m: isinstance(
+                        m, nn.Linear)
+                )
+            nn.quantize(model.unet, group_size=32, bits=8)
+        return model
+
+    def txt2img(self, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
+        """Generate an image from a text prompt - following txt2img.py pattern"""
+        if not self.model:
+            self.model = self.load_model(self.model_path)
+        if not self.model:
+            raise RuntimeError("Model not loaded")
+
+        sampler_config = config.sampler_config
+
+        negative_prompt = ""
+        if config.negative_prompts:
+            negative_prompt = config.negative_prompts if isinstance(
+                config.negative_prompts, str) else config.negative_prompts[0]
+
+        # Generate latents - following txt2img.py approach
+        latents_generator = self.model.generate_latents(
+            prompt,
+            n_images=1,
+            num_steps=sampler_config.steps,
+            cfg_weight=sampler_config.guidance_scale,
+            negative_text=negative_prompt,
+            seed=sampler_config.seed if sampler_config.seed >= 0 else None
+        )
+
+        # Get final latents - following txt2img.py pattern
+        final_latents = None
+        for latents in latents_generator:
+            final_latents = latents
+            mx.eval(final_latents)
+
+        if final_latents is None:
+            raise RuntimeError("No latents generated")
+
+        # Decode to image - following txt2img.py pattern
+        decoded_image = self.model.decode(final_latents)
+        mx.eval(decoded_image)
+
+        # Convert to numpy array
+        image_array = np.array(decoded_image.squeeze(0))
+
+        if clear_cache:
+            mx.clear_cache()
+
+        return Image.from_numpy(image_array)
+
+    def img2img(self, init_image: Image, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
+        """Generate an image from an initial image and a text prompt using SDXL Turbo"""
+        if not self.model:
+            self.model = self.load_model(self.model_path)
+        if not self.model:
+            raise RuntimeError("Model not loaded")
+
+        sampler_config = config.sampler_config
+
+        negative_prompt = ""
+        if config.negative_prompts:
+            negative_prompt = config.negative_prompts if isinstance(
+                config.negative_prompts, str) else config.negative_prompts[0]
+
+        img_tensor = _prepare_image_for_sd(
+            init_image, config.width, config.height)
+
+        # Generate latents from image
+        latents_generator = self.model.generate_latents_from_image(
+            img_tensor,
+            prompt,
+            n_images=1,
+            strength=config.strength,
+            num_steps=sampler_config.steps,
+            cfg_weight=sampler_config.guidance_scale,
+            negative_text=negative_prompt,
+            seed=sampler_config.seed if sampler_config.seed >= 0 else None
+        )
+
+        # Get final latents
+        final_latents = None
+        for latents in latents_generator:
+            final_latents = latents
+            mx.eval(final_latents)
+
+        if final_latents is None:
+            raise RuntimeError("No latents generated")
+
+        # Decode to image
+        decoded_image = self.model.decode(final_latents)
+        mx.eval(decoded_image)
+
+        # Convert to numpy array
+        image_array = np.array(decoded_image.squeeze(0))
+
+        if clear_cache:
+            mx.clear_cache()
+
+        return Image.from_numpy(image_array)
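For reference, a minimal txt2img sketch against the module above. The checkpoint path is hypothetical; "turbo" in the name routes it to the StableDiffusionXL loader in ImageGen.load_model(), and the sampler values match the SDXL Turbo defaults in ImageSamplerConfig:

    from nexaai.mlx_backend.image_gen.generate_sd import (
        ImageGen, ImageGenerationConfig, ImageSamplerConfig,
    )

    # Hypothetical local SDXL Turbo checkpoint directory.
    gen = ImageGen(model_path="./models/sdxl-turbo")

    config = ImageGenerationConfig(
        prompts="a watercolor fox in a snowy forest",
        height=512,
        width=512,
        sampler_config=ImageSamplerConfig(steps=4, guidance_scale=0.0, seed=42),
    )

    # txt2img lazily loads the model on first call, iterates the latent
    # generator to completion, and decodes into the Image wrapper.
    image = gen.txt2img("a watercolor fox in a snowy forest", config)
    image.to_pil().save("fox.png")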
nexaai/mlx_backend/image_gen/interface.py
ADDED
@@ -0,0 +1,82 @@
+from __future__ import annotations
+import os
+from typing import Optional
+
+from ml import ImageGenCreateInput, ImageGenerationConfig, ImageGenImg2ImgInput, ImageGenTxt2ImgInput, ImageGenOutput
+from profiling import ProfilingMixin, StopReason
+
+from .generate_sd import ImageGen as SDImageGen, Image, ImageGenerationConfig as SDImageGenerationConfig, ImageSamplerConfig
+
+
+class ImageGen(ProfilingMixin):
+    sd_gen: Optional[SDImageGen] = None
+
+    def __init__(self, input: ImageGenCreateInput):
+        """Initialize the image generation model"""
+        self.sd_gen = SDImageGen(model_path=input.model_path)
+
+    def destroy(self) -> None:
+        """Clean up resources"""
+        self.sd_gen = None
+
+    def txt2img(self, input: ImageGenTxt2ImgInput) -> ImageGenOutput:
+        """Generate an image from a text prompt - public interface"""
+        height = input.config.height
+        width = input.config.width
+        assert height % 16 == 0, f"Height must be divisible by 16 ({height}/16={height/16})"
+        assert width % 16 == 0, f"Width must be divisible by 16 ({width}/16={width/16})"
+
+        internal_config = SDImageGenerationConfig(
+            prompts=input.prompt,
+            negative_prompts=input.config.negative_prompts,
+            height=height,
+            width=width,
+            sampler_config=ImageSamplerConfig(
+                steps=input.config.sampler_config.steps,
+                guidance_scale=input.config.sampler_config.guidance_scale,
+                seed=input.config.sampler_config.seed
+            ),
+            strength=input.config.strength
+        )
+
+        result_image = self.sd_gen.txt2img(input.prompt, internal_config)
+
+        parent_dir = os.path.dirname(input.output_path)
+        if parent_dir:
+            os.makedirs(parent_dir, exist_ok=True)
+        result_image.to_pil().save(input.output_path)
+
+        return ImageGenOutput(output_image_path=input.output_path)
+
+    def img2img(self, input: ImageGenImg2ImgInput) -> ImageGenOutput:
+        """Generate an image from an initial image and a text prompt - public interface"""
+        height = input.config.height
+        width = input.config.width
+        assert height % 16 == 0, f"Height must be divisible by 16 ({height}/16={height/16})"
+        assert width % 16 == 0, f"Width must be divisible by 16 ({width}/16={width/16})"
+
+        init_image = Image.from_pil(input.init_image_path)
+
+        internal_config = SDImageGenerationConfig(
+            prompts=input.prompt,
+            negative_prompts=input.config.negative_prompts,
+            height=height,
+            width=width,
+            sampler_config=ImageSamplerConfig(
+                steps=input.config.sampler_config.steps,
+                guidance_scale=input.config.sampler_config.guidance_scale,
+                seed=input.config.sampler_config.seed
+            ),
+            init_image=init_image,
+            strength=input.config.strength
+        )
+
+        result_image = self.sd_gen.img2img(
+            init_image, input.prompt, internal_config)
+
+        parent_dir = os.path.dirname(input.output_path)
+        if parent_dir:
+            os.makedirs(parent_dir, exist_ok=True)
+        result_image.to_pil().save(input.output_path)
+
+        return ImageGenOutput(output_image_path=input.output_path)
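And a sketch of the plugin-facing wrapper. ImageGenCreateInput, ImageGenTxt2ImgInput, and ml.ImageGenerationConfig come from the bundled ml module, which this diff does not show; every constructor keyword below is inferred from the attribute accesses in interface.py and should be treated as an assumption:

    # Sketch only: constructor signatures of the ml input types are not
    # visible in this diff and are inferred from how interface.py reads
    # them (input.model_path, input.prompt, input.output_path,
    # input.config.height, ...).
    from ml import ImageGenCreateInput, ImageGenerationConfig, ImageGenTxt2ImgInput
    from nexaai.mlx_backend.image_gen.interface import ImageGen

    gen = ImageGen(ImageGenCreateInput(model_path="./models/sdxl-turbo"))

    out = gen.txt2img(ImageGenTxt2ImgInput(
        prompt="a lighthouse at dusk, oil painting",
        output_path="out/lighthouse.png",  # parent directory is created if missing
        config=ImageGenerationConfig(height=512, width=512),  # must be multiples of 16
    ))
    print(out.output_image_path)
    gen.destroy()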