npcpy-1.2.34-py3-none-any.whl → npcpy-1.2.35-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- npcpy/data/audio.py +35 -1
- npcpy/data/load.py +149 -7
- npcpy/data/video.py +72 -0
- npcpy/ft/diff.py +332 -71
- npcpy/gen/image_gen.py +120 -23
- npcpy/gen/ocr.py +187 -0
- npcpy/memory/command_history.py +231 -40
- npcpy/npc_compiler.py +14 -5
- npcpy/serve.py +1206 -547
- {npcpy-1.2.34.dist-info → npcpy-1.2.35.dist-info}/METADATA +1 -1
- {npcpy-1.2.34.dist-info → npcpy-1.2.35.dist-info}/RECORD +14 -13
- {npcpy-1.2.34.dist-info → npcpy-1.2.35.dist-info}/WHEEL +0 -0
- {npcpy-1.2.34.dist-info → npcpy-1.2.35.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.2.34.dist-info → npcpy-1.2.35.dist-info}/top_level.txt +0 -0
npcpy/ft/diff.py
CHANGED
@@ -1,11 +1,11 @@
-# finetuning diffuser models
 try:
     import torch
     import torch.nn as nn
     import torch.nn.functional as F
     from torch.utils.data import DataLoader, Dataset as TorchDataset
     from transformers import CLIPTextModel, CLIPTokenizer
-except:
+    TORCH_AVAILABLE = True
+except ImportError:
     torch = None
     nn = None
     F = None
@@ -13,9 +13,11 @@ except:
     TorchDataset = None
     CLIPTextModel = None
     CLIPTokenizer = None
+    TORCH_AVAILABLE = False
+
 import math
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional
 import numpy as np
 from PIL import Image
 import os
@@ -34,77 +36,336 @@ class DiffusionConfig:
     num_epochs: int = 100
     batch_size: int = 4
     learning_rate: float = 1e-5
-    checkpoint_frequency: int =
-
-    use_clip: bool =
-    num_channels: int =
+    checkpoint_frequency: int = 10
+    output_model_path: str = "diffusion_model"
+    use_clip: bool = False
+    num_channels: int = 3


-
-
-
-
-
-
-    def forward(self, time):
-        device = time.device
-        half_dim = self.dim // 2
-        embeddings = math.log(10000) / (half_dim - 1)
-        embeddings = torch.exp(
-            torch.arange(half_dim, device=device) * -embeddings
-        )
-        embeddings = time[:, None] * embeddings[None, :]
-        embeddings = torch.cat(
-            (embeddings.sin(), embeddings.cos()),
-            dim=-1
-        )
-        return embeddings
+if TORCH_AVAILABLE:
+    class SinusoidalPositionEmbeddings(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim

+        def forward(self, time):
+            device = time.device
+            half_dim = self.dim // 2
+            embeddings = math.log(10000) / (half_dim - 1)
+            embeddings = torch.exp(
+                torch.arange(half_dim, device=device) * -embeddings
+            )
+            embeddings = time[:, None] * embeddings[None, :]
+            embeddings = torch.cat(
+                (embeddings.sin(), embeddings.cos()),
+                dim=-1
+            )
+            return embeddings

-class SimpleUNet(nn.Module):
-
-
-
-
-
-
-
-
-
-
-
-
-
-            nn.
-
-
-
-
-
-            nn.
-
-
-
-
-
-
-
-
-
-
-
-            nn.
+    class SimpleUNet(nn.Module):
+        def __init__(self, image_size=128, channels=256,
+                     time_emb_dim=128, num_channels=3):
+            super().__init__()
+            self.image_size = image_size
+
+            self.time_mlp = nn.Sequential(
+                SinusoidalPositionEmbeddings(time_emb_dim),
+                nn.Linear(time_emb_dim, time_emb_dim * 4),
+                nn.GELU(),
+                nn.Linear(time_emb_dim * 4, channels),
+            )
+
+            self.conv_in = nn.Conv2d(num_channels, channels, 3, padding=1)
+
+            self.down1 = nn.Sequential(
+                nn.Conv2d(channels, channels * 2, 4, 2, 1),
+                nn.GroupNorm(8, channels * 2),
+                nn.GELU(),
+            )
+
+            self.down2 = nn.Sequential(
+                nn.Conv2d(channels * 2, channels * 4, 4, 2, 1),
+                nn.GroupNorm(8, channels * 4),
+                nn.GELU(),
+            )
+
+            self.mid = nn.Sequential(
+                nn.Conv2d(channels * 4, channels * 4, 3, 1, 1),
+                nn.GroupNorm(8, channels * 4),
+                nn.GELU(),
+            )
+
+            self.up1 = nn.Sequential(
+                nn.ConvTranspose2d(channels * 4, channels * 2, 4, 2, 1),
+                nn.GroupNorm(8, channels * 2),
+                nn.GELU(),
+            )
+
+            self.up2 = nn.Sequential(
+                nn.ConvTranspose2d(channels * 4, channels, 4, 2, 1),
+                nn.GroupNorm(8, channels),
+                nn.GELU(),
+            )
+
+            self.conv_out = nn.Conv2d(channels * 2, num_channels, 3, padding=1)
+
+        def forward(self, x, t):
+            t_emb = self.time_mlp(t)
+
+            x = self.conv_in(x)
+            h1 = x + t_emb[:, :, None, None]
+
+            h2 = self.down1(h1)
+            h3 = self.down2(h2)
+
+            h3 = self.mid(h3)
+
+            h = self.up1(h3)
+            h = torch.cat([h, h2], dim=1)
+            h = self.up2(h)
+            h = torch.cat([h, h1], dim=1)
+
+            return self.conv_out(h)
+
+    class ImageDataset(TorchDataset):
+        def __init__(self, image_paths, captions, image_size=128):
+            self.image_paths = image_paths
+            self.captions = captions if captions else [''] * len(image_paths)
+            self.image_size = image_size
+
+        def __len__(self):
+            return len(self.image_paths)
+
+        def __getitem__(self, idx):
+            img_path = self.image_paths[idx]
+            img = Image.open(img_path).convert('RGB')
+            img = img.resize((self.image_size, self.image_size))
+            img = np.array(img).astype(np.float32) / 255.0
+            img = (img - 0.5) * 2.0
+            img = torch.from_numpy(img).permute(2, 0, 1)
+            caption = self.captions[idx] if idx < len(self.captions) else ''
+            return img, caption
+
+    class DiffusionTrainer:
+        def __init__(self, config):
+            self.config = config
+            self.device = torch.device(
+                'cuda' if torch.cuda.is_available() else 'cpu'
+            )
+
+            self.model = SimpleUNet(
+                image_size=config.image_size,
+                channels=config.channels,
+                time_emb_dim=config.time_emb_dim,
+                num_channels=config.num_channels
+            ).to(self.device)
+
+            self.betas = torch.linspace(
+                config.beta_start,
+                config.beta_end,
+                config.timesteps
+            ).to(self.device)
+            self.alphas = 1.0 - self.betas
+            self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+            self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+            self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
+                1.0 - self.alphas_cumprod
+            )
+
+        def add_noise(self, x, t):
+            sqrt_alpha = self.sqrt_alphas_cumprod[t][:, None, None, None]
+            sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t][
+                :, None, None, None
+            ]
+            noise = torch.randn_like(x)
+            return sqrt_alpha * x + sqrt_one_minus * noise, noise
+
+        def train(self, dataloader):
+            optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=self.config.learning_rate
+            )
+
+            os.makedirs(self.config.output_model_path, exist_ok=True)
+            checkpoint_dir = os.path.join(
+                self.config.output_model_path,
+                'checkpoints'
+            )
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
+            global_step = 0
+
+            for epoch in range(self.config.num_epochs):
+                self.model.train()
+                epoch_loss = 0.0
+
+                pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
+                for batch_idx, (images, captions) in enumerate(pbar):
+                    images = images.to(self.device)
+                    batch_size = images.shape[0]
+
+                    t = torch.randint(
+                        0,
+                        self.config.timesteps,
+                        (batch_size,),
+                        device=self.device
+                    ).long()
+
+                    noisy_images, noise = self.add_noise(images, t)
+
+                    predicted_noise = self.model(noisy_images, t)
+
+                    loss = F.mse_loss(predicted_noise, noise)
+
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+                    epoch_loss += loss.item()
+                    global_step += 1
+
+                    pbar.set_postfix({'loss': loss.item()})
+
+                    if global_step % self.config.checkpoint_frequency == 0:
+                        ckpt_path = os.path.join(
+                            checkpoint_dir,
+                            f'checkpoint-epoch{epoch+1}-step{global_step}.pt'
+                        )
+                        torch.save({
+                            'epoch': epoch,
+                            'step': global_step,
+                            'model_state_dict': self.model.state_dict(),
+                            'optimizer_state_dict': optimizer.state_dict(),
+                            'loss': loss.item(),
+                        }, ckpt_path)
+
+                avg_loss = epoch_loss / len(dataloader)
+                print(f'Epoch {epoch+1} avg loss: {avg_loss:.6f}')
+
+            final_path = os.path.join(
+                self.config.output_model_path,
+                'model_final.pt'
+            )
+            torch.save({
+                'model_state_dict': self.model.state_dict(),
+                'config': self.config,
+            }, final_path)
+
+            return self.config.output_model_path
+
+        @torch.no_grad()
+        def sample(self, num_samples=1):
+            self.model.eval()
+
+            x = torch.randn(
+                num_samples,
+                self.config.num_channels,
+                self.config.image_size,
+                self.config.image_size,
+                device=self.device
+            )
+
+            for t in reversed(range(self.config.timesteps)):
+                t_batch = torch.full(
+                    (num_samples,),
+                    t,
+                    device=self.device,
+                    dtype=torch.long
+                )
+
+                predicted_noise = self.model(x, t_batch)
+
+                alpha = self.alphas[t]
+                alpha_cumprod = self.alphas_cumprod[t]
+                beta = self.betas[t]
+
+                if t > 0:
+                    noise = torch.randn_like(x)
+                else:
+                    noise = torch.zeros_like(x)
+
+                x = (1 / torch.sqrt(alpha)) * (
+                    x - (beta / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
+                ) + torch.sqrt(beta) * noise
+
+            x = (x + 1) / 2
+            x = torch.clamp(x, 0, 1)
+
+            return x
+
+else:
+    SinusoidalPositionEmbeddings = None
+    SimpleUNet = None
+    ImageDataset = None
+    DiffusionTrainer = None
+
+
+def train_diffusion(image_paths, captions=None, config=None,
+                    resume_from=None):
+    if not TORCH_AVAILABLE:
+        raise ImportError(
+            "PyTorch not available. Install: pip install torch torchvision"
         )
-
-
-
-
-
+
+    if config is None:
+        config = DiffusionConfig()
+
+    if captions is None:
+        captions = [''] * len(image_paths)
+
+    dataset = ImageDataset(image_paths, captions, config.image_size)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=0
+    )
+
+    trainer = DiffusionTrainer(config)
+
+    if resume_from and os.path.exists(resume_from):
+        checkpoint = torch.load(resume_from, map_location=trainer.device)
+        trainer.model.load_state_dict(checkpoint['model_state_dict'])
+        print(f'Resumed from {resume_from}')
+
+    output_path = trainer.train(dataloader)
+
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return output_path
+
+
+def generate_image(model_path, prompt=None, num_samples=1, image_size=128):
+    if not TORCH_AVAILABLE:
+        raise ImportError(
+            "PyTorch not available. Install: pip install torch torchvision"
        )
-
-
-
-
-
-
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # Fix: Load with weights_only=False for your custom checkpoint
+    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
+
+    if 'config' in checkpoint:
+        config = checkpoint['config']
+    else:
+        config = DiffusionConfig(image_size=image_size)
+
+    trainer = DiffusionTrainer(config)
+    trainer.model.load_state_dict(checkpoint['model_state_dict'])
+
+    samples = trainer.sample(num_samples)
+
+    images = []
+    for i in range(num_samples):
+        img_tensor = samples[i].cpu()
+        img_np = img_tensor.permute(1, 2, 0).numpy()
+        img_np = (img_np * 255).astype(np.uint8)
+        img = Image.fromarray(img_np)
+        images.append(img)
+
+    if num_samples == 1:
+        return images[0]
+    return images
npcpy/gen/image_gen.py
CHANGED
@@ -21,29 +21,86 @@ def generate_image_diffusers(
     """Generate an image using the Stable Diffusion API with memory optimization."""
     import torch
     import gc
-
+    import os
+    from diffusers import DiffusionPipeline, StableDiffusionPipeline
+
     try:
         torch_dtype = torch.float16 if device != "cpu" and torch.cuda.is_available() else torch.float32

-        if
-            from
-
-            pipe = DiffusionPipeline.from_pretrained(
-                model,
-                torch_dtype=torch_dtype,
-                use_safetensors=True,
-                variant="fp16" if torch_dtype == torch.float16 else None,
-            )
-        else:
-            from diffusers import StableDiffusionPipeline
+        if os.path.isdir(model):
+            print(f"🌋 Loading fine-tuned Diffusers model from local path: {model}")

-
-
-
-
-
+            checkpoint_path = os.path.join(model, 'model_final.pt')
+            if os.path.exists(checkpoint_path):
+                print(f"🌋 Found model_final.pt at {checkpoint_path}.")
+
+                # Load checkpoint to inspect it
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+
+                # Check if this is a custom SimpleUNet model (from your training code)
+                # vs a Stable Diffusion UNet2DConditionModel
+                if 'config' in checkpoint and hasattr(checkpoint['config'], 'image_size'):
+                    print(f"🌋 Detected custom SimpleUNet model, using custom generation")
+                    # Use your custom generate_image function from npcpy.ft.diff
+                    from npcpy.ft.diff import generate_image as custom_generate_image
+
+                    # Your custom model ignores prompts and generates based on training data
+                    image = custom_generate_image(
+                        model_path=checkpoint_path,
+                        prompt=prompt,
+                        num_samples=1,
+                        image_size=height # Use the requested height
+                    )
+                    return image
+
+                else:
+                    # This is a Stable Diffusion checkpoint
+                    print(f"🌋 Detected Stable Diffusion UNet checkpoint")
+                    base_model_id = "runwayml/stable-diffusion-v1-5"
+                    print(f"🌋 Loading base pipeline: {base_model_id}")
+                    pipe = StableDiffusionPipeline.from_pretrained(
+                        base_model_id,
+                        torch_dtype=torch_dtype,
+                        use_safetensors=True,
+                        variant="fp16" if torch_dtype == torch.float16 else None,
+                    )
+
+                    print(f"🌋 Loading custom UNet weights from {checkpoint_path}")
+
+                    # Extract the actual model state dict
+                    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                        unet_state_dict = checkpoint['model_state_dict']
+                        print(f"🌋 Extracted model_state_dict from checkpoint")
+                    else:
+                        unet_state_dict = checkpoint
+                        print(f"🌋 Using checkpoint directly as state_dict")
+
+                    # Load the state dict into the UNet
+                    pipe.unet.load_state_dict(unet_state_dict)
+                    pipe = pipe.to(device)
+                    print(f"🌋 Successfully loaded fine-tuned UNet weights")
+
+            else:
+                raise OSError(f"Error: Fine-tuned model directory {model} does not contain 'model_final.pt'")
+
+        else:
+            print(f"🌋 Loading standard Diffusers model: {model}")
+            if 'Qwen' in model:
+                pipe = DiffusionPipeline.from_pretrained(
+                    model,
+                    torch_dtype=torch_dtype,
+                    use_safetensors=True,
+                    variant="fp16" if torch_dtype == torch.float16 else None,
+                )
+            else:
+                pipe = StableDiffusionPipeline.from_pretrained(
+                    model,
+                    torch_dtype=torch_dtype,
+                    use_safetensors=True,
+                    variant="fp16" if torch_dtype == torch.float16 else None,
+                )

+        # Common pipeline setup for Stable Diffusion models
         if hasattr(pipe, 'enable_attention_slicing'):
             pipe.enable_attention_slicing()

@@ -85,7 +142,6 @@ def generate_image_diffusers(
             raise MemoryError(f"Insufficient memory for image generation with model {model}. Try a smaller model or reduce image size.")
         else:
             raise e
-
 import os
 import base64
 import io
@@ -294,6 +350,8 @@ def gemini_image_gen(

     else:
         raise ValueError(f"Unsupported Gemini image model or API usage for new generation: '{model}'")
+# In npcpy/gen/image_gen.py, find the generate_image function and replace it with this:
+
 def generate_image(
     prompt: str,
     model: str ,
@@ -305,6 +363,7 @@ def generate_image(
     api_url: Optional[str] = None,
     attachments: Union[List[Union[str, bytes, Image.Image]], None] = None,
     save_path: Optional[str] = None,
+    custom_model_path: Optional[str] = None, # <--- NEW: Accept custom_model_path
 ):
     """
     Unified function to generate or edit images using various providers.
@@ -320,13 +379,15 @@ def generate_image(
         api_url (str): API URL for the provider.
         attachments (list): List of images for editing. Can be file paths, bytes, or PIL Images.
         save_path (str): Path to save the generated image.
+        custom_model_path (str): Path to a locally fine-tuned Diffusers model. <--- NEW

     Returns:
         List[PIL.Image.Image]: A list of generated PIL Image objects.
     """
     from urllib.request import urlopen
+    import os # Ensure os is imported for path checks

-    if model is None:
+    if model is None and custom_model_path is None: # Only set default if no model or custom path is provided
         if provider == "openai":
             model = "dall-e-2"
         elif provider == "diffusers":
@@ -336,12 +397,22 @@ def generate_image(

     all_generated_pil_images = []

+    # <--- CRITICAL FIX: Handle custom_model_path for Diffusers here
     if provider == "diffusers":
+        # If a custom_model_path is provided and exists, use it instead of a generic model name
+        if custom_model_path and os.path.isdir(custom_model_path):
+            print(f"🌋 Using custom Diffusers model from path: {custom_model_path}")
+            model_to_use = custom_model_path
+        else:
+            # Otherwise, use the standard model name (e.g., "runwayml/stable-diffusion-v1-5")
+            model_to_use = model
+            print(f"🌋 Using standard Diffusers model: {model_to_use}")
+
         for _ in range(n_images):
             try:
                 image = generate_image_diffusers(
                     prompt=prompt,
-                    model=
+                    model=model_to_use, # <--- Pass the resolved model_to_use
                     height=height,
                     width=width
                 )
@@ -373,15 +444,42 @@ def generate_image(
             all_generated_pil_images.extend(images)

     else:
+        # This is the fallback for other providers or if provider is not explicitly handled
         valid_sizes = ["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]
         size = f"{width}x{height}"

         if attachments is not None:
             raise ValueError("Image editing not supported with litellm provider")

+        # The litellm.image_generation function expects the provider as part of the model string
+        # e.g., "huggingface/starcoder" or "openai/dall-e-3"
+        # Since we've already handled "diffusers", "openai", "gemini" above,
+        # this 'else' block implies a generic litellm call.
+        # We need to ensure the model string is correctly formatted for litellm.
+        # However, the error message "LLM Provider NOT provided" suggests litellm
+        # is not even getting the `provider` correctly.
+        # The fix for this is ensuring the `provider` is explicitly passed to litellm.image_generation
+        # which is already happening in `gen_image` in `llm_funcs.py`
+
+        # If we reach here, it means the provider is not 'diffusers', 'openai', or 'gemini',
+        # and litellm is the intended route. We need to pass the provider explicitly.
+        # The original code here was trying to construct `model=f"{provider}/{model}"`
+        # but the error indicates `provider` itself was missing.
+        # The `image_generation` from litellm expects `model` to be `provider/model_name`.
+        # Since the `provider` variable is available, we can construct this.
+
+        # This block is for generic litellm providers (not diffusers, openai, gemini)
+        # The error indicates `provider` itself was not making it to litellm.
+        # This `generate_image` function already receives `provider`.
+        # The issue is likely how `gen_image` in `llm_funcs.py` calls this `generate_image`.
+        # However, if this `else` branch is hit, we ensure litellm gets the provider.
+
+        # Construct the model string for litellm
+        litellm_model_string = f"{provider}/{model}" if provider and model else model
+
         image_response = image_generation(
             prompt=prompt,
-            model=
+            model=litellm_model_string, # <--- Ensure model string includes provider for litellm
             n=n_images,
             size=size,
             api_key=api_key,
@@ -407,7 +505,6 @@ def generate_image(

     return all_generated_pil_images

-
 def edit_image(
     prompt: str,
     image_path: str,