nexaai 1.0.16rc5__cp310-cp310-macosx_14_0_universal2.whl → 1.0.16rc7__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +60 -14
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/ml.py +60 -14
- nexaai/mlx_backend/sd/modeling/model_io.py +72 -17
- {nexaai-1.0.16rc5.dist-info → nexaai-1.0.16rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.16rc5.dist-info → nexaai-1.0.16rc7.dist-info}/RECORD +23 -11
- {nexaai-1.0.16rc5.dist-info → nexaai-1.0.16rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.16rc5.dist-info → nexaai-1.0.16rc7.dist-info}/top_level.txt +0 -0
|
Binary file
|
nexaai/_version.py
CHANGED
|
Binary file
|
|
Binary file
|
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
# This file defines the python interface that c-lib expects from a python backend
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from dataclasses import dataclass
|
|
4
7
|
|
|
5
8
|
from abc import ABC, abstractmethod
|
|
6
9
|
from dataclasses import dataclass, field
|
|
@@ -101,9 +104,12 @@ class ModelConfig:
|
|
|
101
104
|
n_threads_batch: int = 0 # number of threads to use for batch processing
|
|
102
105
|
n_batch: int = 0 # logical maximum batch size that can be submitted to llama_decode
|
|
103
106
|
n_ubatch: int = 0 # physical maximum batch size
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
+
# max number of sequences (i.e. distinct states for recurrent models)
|
|
108
|
+
n_seq_max: int = 0
|
|
109
|
+
# path to chat template file, optional
|
|
110
|
+
chat_template_path: Optional[Path] = None
|
|
111
|
+
# content of chat template file, optional
|
|
112
|
+
chat_template_content: Optional[str] = None
|
|
107
113
|
|
|
108
114
|
|
|
109
115
|
@dataclass
|
|
@@ -118,7 +124,8 @@ class SamplerConfig:
|
|
|
118
124
|
frequency_penalty: float = 0.0
|
|
119
125
|
seed: int = -1 # –1 for random
|
|
120
126
|
grammar_path: Optional[Path] = None
|
|
121
|
-
|
|
127
|
+
# Optional grammar string (BNF-like format)
|
|
128
|
+
grammar_string: Optional[str] = None
|
|
122
129
|
|
|
123
130
|
|
|
124
131
|
@dataclass
|
|
@@ -128,8 +135,10 @@ class GenerationConfig:
|
|
|
128
135
|
stop: Sequence[str] = field(default_factory=tuple)
|
|
129
136
|
n_past: int = 0
|
|
130
137
|
sampler_config: Optional[SamplerConfig] = None
|
|
131
|
-
|
|
132
|
-
|
|
138
|
+
# Array of image paths for VLM (None if none)
|
|
139
|
+
image_paths: Optional[Sequence[Path]] = None
|
|
140
|
+
# Array of audio paths for VLM (None if none)
|
|
141
|
+
audio_paths: Optional[Sequence[Path]] = None
|
|
133
142
|
|
|
134
143
|
|
|
135
144
|
@dataclass
|
|
@@ -170,6 +179,32 @@ class RerankConfig:
|
|
|
170
179
|
normalize_method: str = "softmax" # "softmax" | "min-max" | "none"
|
|
171
180
|
|
|
172
181
|
|
|
182
|
+
# image-gen
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@dataclass
class ImageGenTxt2ImgInput:
    """Request payload for a text-to-image generation call."""

    # Text prompt the image is generated from.
    prompt: str
    # Generation settings (size, sampler, strength, negative prompts).
    config: ImageGenerationConfig
    # Destination file for the generated image; None when not supplied.
    output_path: Optional[Path] = None
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
@dataclass
class ImageGenImg2ImgInput:
    """Request payload for an image-to-image generation call."""

    # Location of the image the generation starts from.
    init_image_path: Path
    # Text prompt guiding the transformation.
    prompt: str
    # Generation settings (size, sampler, strength, negative prompts).
    config: ImageGenerationConfig
    # Destination file for the generated image; None when not supplied.
    output_path: Optional[Path] = None
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@dataclass
class ImageGenOutput:
    """Result of an image generation call."""

    # Path of the image file produced by the call.
    output_image_path: Path
|
|
206
|
+
|
|
207
|
+
|
|
173
208
|
@dataclass
|
|
174
209
|
class ImageSamplerConfig:
|
|
175
210
|
"""Configuration for image sampling."""
|
|
@@ -180,17 +215,27 @@ class ImageSamplerConfig:
|
|
|
180
215
|
seed: int = -1 # –1 for random
|
|
181
216
|
|
|
182
217
|
|
|
218
|
+
@dataclass
class ImageGenCreateInput:
    """Parameters used to create an image generation backend instance."""

    # Name of the model to load.
    model_name: str
    # Location of the model weights.
    model_path: Path
    # General model/runtime settings.
    config: ModelConfig
    # Location of the scheduler configuration file.
    scheduler_config_path: Path
    # Identifier of the plugin that should serve the model.
    plugin_id: str
    # Optional device identifier; None when not supplied.
    device_id: Optional[str] = None
|
|
227
|
+
|
|
228
|
+
|
|
183
229
|
@dataclass
|
|
184
230
|
class ImageGenerationConfig:
|
|
185
231
|
"""Configuration for image generation."""
|
|
186
|
-
prompts:
|
|
187
|
-
|
|
232
|
+
prompts: List[str]
|
|
233
|
+
sampler_config: ImageSamplerConfig
|
|
234
|
+
scheduler_config: SchedulerConfig
|
|
235
|
+
strength: float
|
|
236
|
+
negative_prompts: Optional[List[str]] = None
|
|
188
237
|
height: int = 512
|
|
189
238
|
width: int = 512
|
|
190
|
-
sampler_config: Optional[ImageSamplerConfig] = None
|
|
191
|
-
lora_id: int = -1 # –1 for none
|
|
192
|
-
init_image: Optional[Image] = None
|
|
193
|
-
strength: float = 1.0
|
|
194
239
|
|
|
195
240
|
|
|
196
241
|
@dataclass
|
|
@@ -261,7 +306,7 @@ class TTSResult:
|
|
|
261
306
|
class BoundingBox:
|
|
262
307
|
"""Generic bounding box structure."""
|
|
263
308
|
x: float # X coordinate (normalized or pixel, depends on model)
|
|
264
|
-
y: float # Y coordinate (normalized or pixel, depends on model)
|
|
309
|
+
y: float # Y coordinate (normalized or pixel, depends on model)
|
|
265
310
|
width: float # Width
|
|
266
311
|
height: float # Height
|
|
267
312
|
|
|
@@ -275,7 +320,8 @@ class CVResult:
|
|
|
275
320
|
confidence: float = 0.0 # Confidence score [0.0-1.0]
|
|
276
321
|
bbox: Optional[BoundingBox] = None # Bounding box (example: YOLO)
|
|
277
322
|
text: Optional[str] = None # Text result (example: OCR)
|
|
278
|
-
|
|
323
|
+
# Feature embedding (example: CLIP embedding)
|
|
324
|
+
embedding: Optional[List[float]] = None
|
|
279
325
|
embedding_dim: int = 0 # Embedding dimension
|
|
280
326
|
|
|
281
327
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Image generation module for MLX backend
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import (
|
|
4
|
+
List,
|
|
5
|
+
Optional,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
import mlx.core as mx
|
|
9
|
+
import numpy as np
|
|
10
|
+
from PIL import Image as PILImage
|
|
11
|
+
import mlx.nn as nn
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
from .stable_diffusion import StableDiffusion, StableDiffusionXL
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Image:
    """In-memory image stored as a flat list of pixel values plus its dimensions.

    Pixel values are kept in (H, W, C) order. ``from_pil`` scales 8-bit data to
    the [0, 1] range and ``to_pil`` scales it back.
    """

    def __init__(self, data: List[float], width: int, height: int, channels: int) -> None:
        """Initialize an image with pixel data.

        Args:
            data: Flattened pixel values in (H, W, C) order.
            width: Image width in pixels.
            height: Image height in pixels.
            channels: Number of channels per pixel.
        """
        self.data = data
        self.width = width
        self.height = height
        self.channels = channels

    @classmethod
    def from_numpy(cls, array: np.ndarray) -> 'Image':
        """Create an Image from a numpy array shaped (H, W, C) or (H, W).

        A 2-D array (single-channel image) is treated as (H, W, 1).

        Raises:
            ValueError: If the array is neither 2-D nor 3-D.
        """
        array = np.asarray(array)
        if array.ndim == 2:
            # Grayscale input: add the missing channel axis instead of failing
            # on the three-way shape unpacking below.
            array = array[:, :, np.newaxis]
        if array.ndim != 3:
            raise ValueError(
                f"Expected an array shaped (H, W, C) or (H, W), got shape {array.shape}")
        height, width, channels = array.shape
        data = array.flatten().tolist()
        return cls(data, width, height, channels)

    @classmethod
    def from_pil(cls, pil_image: PILImage.Image) -> 'Image':
        """Create an Image from a PIL image, scaling 8-bit values to [0, 1]."""
        array = np.array(pil_image).astype(np.float32) / 255.0
        return cls.from_numpy(array)

    def to_numpy(self) -> np.ndarray:
        """Return the pixel data as a numpy array shaped (H, W, C)."""
        return np.array(self.data).reshape(self.height, self.width, self.channels)

    def to_pil(self) -> PILImage.Image:
        """Convert to an 8-bit PIL image.

        Values are clipped to [0, 1] before scaling so that out-of-range
        pixels saturate instead of wrapping around in the uint8 cast.
        """
        pixels = np.clip(self.to_numpy(), 0.0, 1.0)
        array = (pixels * 255).astype(np.uint8)
        if array.shape[2] == 1:
            # PIL expects single-channel data as a 2-D array.
            array = array[:, :, 0]
        return PILImage.fromarray(array)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ImageSamplerConfig:
    """Sampler settings for a diffusion run; the defaults target SDXL Turbo."""

    def __init__(
        self,
        method: str = "ddim",
        steps: int = 4,  # SDXL Turbo typically uses fewer steps
        guidance_scale: float = 0.0,  # SDXL Turbo works well with no guidance
        eta: float = 0.0,
        seed: int = -1,
    ) -> None:
        """Store the sampler parameters unchanged on the instance."""
        self.method, self.steps = method, steps
        self.guidance_scale, self.eta = guidance_scale, eta
        self.seed = seed
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ImageGenerationConfig:
    """Settings for a single image generation request."""

    def __init__(
        self,
        prompts: str | List[str],
        negative_prompts: str | List[str] | None = None,
        height: int = 512,
        width: int = 512,
        sampler_config: Optional[ImageSamplerConfig] = None,
        lora_id: int = -1,  # Not used but kept for compatibility
        init_image: Optional[Image] = None,
        strength: float = 1.0,
        n_images: int = 1,
        n_rows: int = 1,
        decoding_batch_size: int = 1,
    ) -> None:
        """Store the request settings, filling in defaults for omitted values.

        A missing or empty ``negative_prompts`` becomes ``""``, and a missing
        ``sampler_config`` is replaced by a default ``ImageSamplerConfig``.
        """
        self.prompts = prompts
        self.negative_prompts = negative_prompts if negative_prompts else ""
        self.height, self.width = height, width
        if sampler_config:
            self.sampler_config = sampler_config
        else:
            self.sampler_config = ImageSamplerConfig()
        self.lora_id = lora_id
        self.init_image = init_image
        self.strength = strength
        self.n_images, self.n_rows = n_images, n_rows
        self.decoding_batch_size = decoding_batch_size
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _prepare_image_for_sd(image: Image, target_width: int, target_height: int) -> mx.array:
    """Convert an ``Image`` into the array handed to ``generate_latents_from_image``.

    The image is converted to RGB, resized to the requested size when it
    differs, and rescaled from [0, 255] to [-1, 1] as float32.

    NOTE(review): the (H, W, 3) layout and the [-1, 1] range follow the
    upstream MLX stable-diffusion image2image example; confirm against the
    bundled ``stable_diffusion`` package.
    """
    pil_image = image.to_pil().convert("RGB")
    if pil_image.size != (target_width, target_height):
        pil_image = pil_image.resize((target_width, target_height), PILImage.LANCZOS)
    pixels = mx.array(np.array(pil_image))
    return (pixels.astype(mx.float32) / 255) * 2 - 1


class ImageGen:
    """Stable Diffusion / SDXL image generator backed by MLX.

    The model is loaded lazily on the first ``txt2img`` or ``img2img`` call,
    using the ``float16`` and ``quantize`` options given to the constructor.
    """

    def __init__(
        self,
        model_path: str,
        scheduler_config_path: Optional[str] = None,
        device: Optional[str] = None,
        float16: bool = True,
        quantize: bool = False,
    ) -> None:
        """Initialize the image generation model for SDXL Turbo.

        Args:
            model_path: Local path or repository name of the model.
            scheduler_config_path: Optional scheduler configuration path (stored only).
            device: Optional device name (stored only; not used by this class).
            float16: Whether to load the model in half precision.
            quantize: Whether to quantize the text encoder(s) and the UNet after loading.
        """
        self.model_path = model_path
        self.scheduler_config_path = scheduler_config_path
        self.device = device
        self.float16 = float16
        self.quantize = quantize
        self.model = None

    @staticmethod
    def load_model(model_path: str, float16: bool = True, quantize: bool = False) -> StableDiffusion:
        """Load a model from the given path - following txt2img.py pattern.

        The SDXL pipeline is selected when the path contains "xl" or "turbo"
        (case-insensitive); otherwise the base Stable Diffusion pipeline is used.
        Local paths and repository names are handled identically.

        NOTE(review): the substring test runs on the whole path, so an
        unrelated directory name containing "xl" also selects SDXL.
        """
        # Accept path-like objects as well as plain strings.
        model_path = os.fspath(model_path)
        lowered = model_path.lower()
        is_xl = "xl" in lowered or "turbo" in lowered

        model_cls = StableDiffusionXL if is_xl else StableDiffusion
        model = model_cls(model_path, float16=float16)

        # Apply quantization if requested - same as txt2img.py
        if quantize:
            def is_linear(_, module):
                return isinstance(module, nn.Linear)

            if is_xl:
                encoders = (model.text_encoder_1, model.text_encoder_2)
            else:
                encoders = (model.text_encoder,)
            for encoder in encoders:
                nn.quantize(encoder, class_predicate=is_linear)
            nn.quantize(model.unet, group_size=32, bits=8)
        return model

    def _ensure_model(self):
        """Return the loaded model, loading it on first use with the stored options."""
        if self.model is None:
            # Forward the constructor options; previously they were dropped and
            # the model was always loaded with the load_model defaults.
            self.model = self.load_model(
                self.model_path, float16=self.float16, quantize=self.quantize)
        if self.model is None:
            raise RuntimeError("Model not loaded")
        return self.model

    @staticmethod
    def _first_negative_prompt(config: ImageGenerationConfig) -> str:
        """Return the negative prompt as a string ("" when none is configured)."""
        negatives = config.negative_prompts
        if not negatives:
            return ""
        return negatives if isinstance(negatives, str) else negatives[0]

    @staticmethod
    def _seed_or_none(sampler_config: ImageSamplerConfig) -> Optional[int]:
        """Map a negative seed (meaning "random") to None."""
        return sampler_config.seed if sampler_config.seed >= 0 else None

    @staticmethod
    def _decode_final_latents(model, latents_generator, clear_cache: bool) -> Image:
        """Run the denoising generator to completion and decode the last latents."""
        final_latents = None
        for latents in latents_generator:
            final_latents = latents
            # Evaluate every step so the lazy graph does not grow across iterations.
            mx.eval(final_latents)

        if final_latents is None:
            raise RuntimeError("No latents generated")

        decoded_image = model.decode(final_latents)
        mx.eval(decoded_image)

        # A single image is generated, so drop the leading batch axis.
        image_array = np.array(decoded_image.squeeze(0))

        if clear_cache:
            mx.clear_cache()

        return Image.from_numpy(image_array)

    def txt2img(self, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
        """Generate an image from a text prompt - following txt2img.py pattern.

        Raises:
            RuntimeError: If the model could not be loaded or no latents were produced.
        """
        model = self._ensure_model()
        sampler_config = config.sampler_config

        latents_generator = model.generate_latents(
            prompt,
            n_images=1,
            num_steps=sampler_config.steps,
            cfg_weight=sampler_config.guidance_scale,
            negative_text=self._first_negative_prompt(config),
            seed=self._seed_or_none(sampler_config),
        )
        return self._decode_final_latents(model, latents_generator, clear_cache)

    def img2img(self, init_image: Image, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
        """Generate an image from an initial image and a text prompt using SDXL Turbo.

        Raises:
            RuntimeError: If the model could not be loaded or no latents were produced.
        """
        model = self._ensure_model()
        sampler_config = config.sampler_config

        img_tensor = _prepare_image_for_sd(init_image, config.width, config.height)

        latents_generator = model.generate_latents_from_image(
            img_tensor,
            prompt,
            n_images=1,
            strength=config.strength,
            num_steps=sampler_config.steps,
            cfg_weight=sampler_config.guidance_scale,
            negative_text=self._first_negative_prompt(config),
            seed=self._seed_or_none(sampler_config),
        )
        return self._decode_final_latents(model, latents_generator, clear_cache)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import os
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from ml import ImageGenCreateInput, ImageGenerationConfig, ImageGenImg2ImgInput, ImageGenTxt2ImgInput, ImageGenOutput
|
|
6
|
+
from profiling import ProfilingMixin, StopReason
|
|
7
|
+
|
|
8
|
+
from .generate_sd import ImageGen as SDImageGen, Image, ImageGenerationConfig as SDImageGenerationConfig, ImageSamplerConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ImageGen(ProfilingMixin):
    """Public image-generation backend wrapping the MLX Stable Diffusion generator.

    Requests arrive as the ``ml`` dataclasses, are translated into the internal
    generator configuration, and the resulting image is written to the
    requested ``output_path``.
    """

    sd_gen: Optional[SDImageGen] = None

    def __init__(self, input: ImageGenCreateInput):
        """Initialize the image generation model.

        The model path is converted to ``str`` because the underlying generator
        performs substring checks on it, which fail for ``Path`` objects.
        """
        self.sd_gen = SDImageGen(model_path=str(input.model_path))

    def destroy(self) -> None:
        """Clean up resources by dropping the reference to the generator."""
        self.sd_gen = None

    def _require_backend(self) -> SDImageGen:
        """Return the underlying generator, failing clearly after ``destroy()``."""
        if self.sd_gen is None:
            raise RuntimeError("Image generator has been destroyed or was never initialized")
        return self.sd_gen

    @staticmethod
    def _validate_request(config: ImageGenerationConfig, output_path) -> None:
        """Check the size constraints and the output path before doing any work.

        Raises:
            ValueError: If height or width is not a multiple of 16, or no
                output path was supplied.
        """
        height = config.height
        width = config.width
        # Explicit raises instead of asserts so the checks survive ``python -O``.
        if height % 16 != 0:
            raise ValueError(f"Height must be divisible by 16 ({height}/16={height/16})")
        if width % 16 != 0:
            raise ValueError(f"Width must be divisible by 16 ({width}/16={width/16})")
        if output_path is None:
            # Fail before the generation runs rather than with a TypeError
            # when the finished image is about to be saved.
            raise ValueError("output_path is required to save the generated image")

    @staticmethod
    def _build_internal_config(
        prompt: str,
        config: ImageGenerationConfig,
        init_image: Optional[Image] = None,
    ) -> SDImageGenerationConfig:
        """Translate the public generation config into the generator's own config."""
        return SDImageGenerationConfig(
            prompts=prompt,
            negative_prompts=config.negative_prompts,
            height=config.height,
            width=config.width,
            sampler_config=ImageSamplerConfig(
                steps=config.sampler_config.steps,
                guidance_scale=config.sampler_config.guidance_scale,
                seed=config.sampler_config.seed,
            ),
            init_image=init_image,
            strength=config.strength,
        )

    @staticmethod
    def _save_image(image: Image, output_path) -> None:
        """Write ``image`` to ``output_path``, creating the parent directory if needed."""
        parent_dir = os.path.dirname(output_path)
        # A bare file name has no parent component; os.makedirs("") would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        image.to_pil().save(output_path)

    def txt2img(self, input: ImageGenTxt2ImgInput) -> ImageGenOutput:
        """Generate an image from a text prompt - public interface.

        Raises:
            ValueError: If the requested size or output path is invalid.
            RuntimeError: If the backend has been destroyed.
        """
        backend = self._require_backend()
        self._validate_request(input.config, input.output_path)

        internal_config = self._build_internal_config(input.prompt, input.config)
        result_image = backend.txt2img(input.prompt, internal_config)

        self._save_image(result_image, input.output_path)
        return ImageGenOutput(output_image_path=input.output_path)

    def img2img(self, input: ImageGenImg2ImgInput) -> ImageGenOutput:
        """Generate an image from an initial image and a text prompt - public interface.

        Raises:
            ValueError: If the requested size or output path is invalid.
            RuntimeError: If the backend has been destroyed.
        """
        # Local import: this module does not import PIL at file level.
        from PIL import Image as PILImage

        backend = self._require_backend()
        self._validate_request(input.config, input.output_path)

        # ``Image.from_pil`` needs a decoded PIL image, not a path, so open the
        # file first; RGB conversion gives the three channels expected downstream.
        with PILImage.open(input.init_image_path) as source:
            init_image = Image.from_pil(source.convert("RGB"))

        internal_config = self._build_internal_config(
            input.prompt, input.config, init_image=init_image)
        result_image = backend.img2img(init_image, input.prompt, internal_config)

        self._save_image(result_image, input.output_path)
        return ImageGenOutput(output_image_path=input.output_path)
|