npcpy 1.2.34__py3-none-any.whl → 1.2.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
npcpy/ft/diff.py CHANGED
@@ -1,11 +1,11 @@
-# finetuning diffuser models
 try:
     import torch
     import torch.nn as nn
     import torch.nn.functional as F
     from torch.utils.data import DataLoader, Dataset as TorchDataset
     from transformers import CLIPTextModel, CLIPTokenizer
-except:
+    TORCH_AVAILABLE = True
+except ImportError:
     torch = None
     nn = None
     F = None
@@ -13,9 +13,11 @@ except:
     TorchDataset = None
     CLIPTextModel = None
     CLIPTokenizer = None
+    TORCH_AVAILABLE = False
+
 import math
 from dataclasses import dataclass, field
-from typing import List, Optional, Callable
+from typing import List, Optional
 import numpy as np
 from PIL import Image
 import os
@@ -34,77 +36,336 @@ class DiffusionConfig:
     num_epochs: int = 100
     batch_size: int = 4
     learning_rate: float = 1e-5
-    checkpoint_frequency: int = 1000
-    output_dir: str = "diffusion_model"
-    use_clip: bool = True
-    num_channels: int = 1
+    checkpoint_frequency: int = 10
+    output_model_path: str = "diffusion_model"
+    use_clip: bool = False
+    num_channels: int = 3
 
 
-class SinusoidalPositionEmbeddings(nn.Module):
-
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, time):
-        device = time.device
-        half_dim = self.dim // 2
-        embeddings = math.log(10000) / (half_dim - 1)
-        embeddings = torch.exp(
-            torch.arange(half_dim, device=device) * -embeddings
-        )
-        embeddings = time[:, None] * embeddings[None, :]
-        embeddings = torch.cat(
-            (embeddings.sin(), embeddings.cos()),
-            dim=-1
-        )
-        return embeddings
+if TORCH_AVAILABLE:
+    class SinusoidalPositionEmbeddings(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.dim = dim
 
+        def forward(self, time):
+            device = time.device
+            half_dim = self.dim // 2
+            embeddings = math.log(10000) / (half_dim - 1)
+            embeddings = torch.exp(
+                torch.arange(half_dim, device=device) * -embeddings
+            )
+            embeddings = time[:, None] * embeddings[None, :]
+            embeddings = torch.cat(
+                (embeddings.sin(), embeddings.cos()),
+                dim=-1
+            )
+            return embeddings
 
-class SimpleUNet(nn.Module):
-
-    def __init__(
-        self,
-        image_size=128,
-        channels=256,
-        time_emb_dim=128,
-        num_channels=1
-    ):
-        super().__init__()
-
-        self.image_size = image_size
-
-        self.time_mlp = nn.Sequential(
-            SinusoidalPositionEmbeddings(time_emb_dim),
-            nn.Linear(time_emb_dim, time_emb_dim * 4),
-            nn.GELU(),
-            nn.Linear(time_emb_dim * 4, channels),
-        )
-
-        self.text_mlp = nn.Sequential(
-            nn.Linear(768, time_emb_dim),
-            nn.GELU(),
-            nn.Linear(time_emb_dim, time_emb_dim),
-            nn.GELU(),
-            nn.Linear(time_emb_dim, channels),
-        )
-
-        self.conv_in = nn.Conv2d(num_channels, channels, 1, padding=0)
-
-        self.down1 = nn.Sequential(
-            nn.Conv2d(channels, channels * 2, 4, 2, 1),
-            nn.GroupNorm(8, channels * 2),
-            nn.GELU(),
+    class SimpleUNet(nn.Module):
+        def __init__(self, image_size=128, channels=256,
+                     time_emb_dim=128, num_channels=3):
+            super().__init__()
+            self.image_size = image_size
+
+            self.time_mlp = nn.Sequential(
+                SinusoidalPositionEmbeddings(time_emb_dim),
+                nn.Linear(time_emb_dim, time_emb_dim * 4),
+                nn.GELU(),
+                nn.Linear(time_emb_dim * 4, channels),
+            )
+
+            self.conv_in = nn.Conv2d(num_channels, channels, 3, padding=1)
+
+            self.down1 = nn.Sequential(
+                nn.Conv2d(channels, channels * 2, 4, 2, 1),
+                nn.GroupNorm(8, channels * 2),
+                nn.GELU(),
+            )
+
+            self.down2 = nn.Sequential(
+                nn.Conv2d(channels * 2, channels * 4, 4, 2, 1),
+                nn.GroupNorm(8, channels * 4),
+                nn.GELU(),
+            )
+
+            self.mid = nn.Sequential(
+                nn.Conv2d(channels * 4, channels * 4, 3, 1, 1),
+                nn.GroupNorm(8, channels * 4),
+                nn.GELU(),
+            )
+
+            self.up1 = nn.Sequential(
+                nn.ConvTranspose2d(channels * 4, channels * 2, 4, 2, 1),
+                nn.GroupNorm(8, channels * 2),
+                nn.GELU(),
+            )
+
+            self.up2 = nn.Sequential(
+                nn.ConvTranspose2d(channels * 4, channels, 4, 2, 1),
+                nn.GroupNorm(8, channels),
+                nn.GELU(),
+            )
+
+            self.conv_out = nn.Conv2d(channels * 2, num_channels, 3, padding=1)
+
+        def forward(self, x, t):
+            t_emb = self.time_mlp(t)
+
+            x = self.conv_in(x)
+            h1 = x + t_emb[:, :, None, None]
+
+            h2 = self.down1(h1)
+            h3 = self.down2(h2)
+
+            h3 = self.mid(h3)
+
+            h = self.up1(h3)
+            h = torch.cat([h, h2], dim=1)
+            h = self.up2(h)
+            h = torch.cat([h, h1], dim=1)
+
+            return self.conv_out(h)
+
+    class ImageDataset(TorchDataset):
+        def __init__(self, image_paths, captions, image_size=128):
+            self.image_paths = image_paths
+            self.captions = captions if captions else [''] * len(image_paths)
+            self.image_size = image_size
+
+        def __len__(self):
+            return len(self.image_paths)
+
+        def __getitem__(self, idx):
+            img_path = self.image_paths[idx]
+            img = Image.open(img_path).convert('RGB')
+            img = img.resize((self.image_size, self.image_size))
+            img = np.array(img).astype(np.float32) / 255.0
+            img = (img - 0.5) * 2.0
+            img = torch.from_numpy(img).permute(2, 0, 1)
+            caption = self.captions[idx] if idx < len(self.captions) else ''
+            return img, caption
+
+    class DiffusionTrainer:
+        def __init__(self, config):
+            self.config = config
+            self.device = torch.device(
+                'cuda' if torch.cuda.is_available() else 'cpu'
+            )
+
+            self.model = SimpleUNet(
+                image_size=config.image_size,
+                channels=config.channels,
+                time_emb_dim=config.time_emb_dim,
+                num_channels=config.num_channels
+            ).to(self.device)
+
+            self.betas = torch.linspace(
+                config.beta_start,
+                config.beta_end,
+                config.timesteps
+            ).to(self.device)
+            self.alphas = 1.0 - self.betas
+            self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+            self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+            self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
+                1.0 - self.alphas_cumprod
+            )
+
+        def add_noise(self, x, t):
+            sqrt_alpha = self.sqrt_alphas_cumprod[t][:, None, None, None]
+            sqrt_one_minus = self.sqrt_one_minus_alphas_cumprod[t][
+                :, None, None, None
+            ]
+            noise = torch.randn_like(x)
+            return sqrt_alpha * x + sqrt_one_minus * noise, noise
+
+        def train(self, dataloader):
+            optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=self.config.learning_rate
+            )
+
+            os.makedirs(self.config.output_model_path, exist_ok=True)
+            checkpoint_dir = os.path.join(
+                self.config.output_model_path,
+                'checkpoints'
+            )
+            os.makedirs(checkpoint_dir, exist_ok=True)
+
+            global_step = 0
+
+            for epoch in range(self.config.num_epochs):
+                self.model.train()
+                epoch_loss = 0.0
+
+                pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
+                for batch_idx, (images, captions) in enumerate(pbar):
+                    images = images.to(self.device)
+                    batch_size = images.shape[0]
+
+                    t = torch.randint(
+                        0,
+                        self.config.timesteps,
+                        (batch_size,),
+                        device=self.device
+                    ).long()
+
+                    noisy_images, noise = self.add_noise(images, t)
+
+                    predicted_noise = self.model(noisy_images, t)
+
+                    loss = F.mse_loss(predicted_noise, noise)
+
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+                    epoch_loss += loss.item()
+                    global_step += 1
+
+                    pbar.set_postfix({'loss': loss.item()})
+
+                    if global_step % self.config.checkpoint_frequency == 0:
+                        ckpt_path = os.path.join(
+                            checkpoint_dir,
+                            f'checkpoint-epoch{epoch+1}-step{global_step}.pt'
+                        )
+                        torch.save({
+                            'epoch': epoch,
+                            'step': global_step,
+                            'model_state_dict': self.model.state_dict(),
+                            'optimizer_state_dict': optimizer.state_dict(),
+                            'loss': loss.item(),
+                        }, ckpt_path)
+
+                avg_loss = epoch_loss / len(dataloader)
+                print(f'Epoch {epoch+1} avg loss: {avg_loss:.6f}')
+
+            final_path = os.path.join(
+                self.config.output_model_path,
+                'model_final.pt'
+            )
+            torch.save({
+                'model_state_dict': self.model.state_dict(),
+                'config': self.config,
+            }, final_path)
+
+            return self.config.output_model_path
+
+        @torch.no_grad()
+        def sample(self, num_samples=1):
+            self.model.eval()
+
+            x = torch.randn(
+                num_samples,
+                self.config.num_channels,
+                self.config.image_size,
+                self.config.image_size,
+                device=self.device
+            )
+
+            for t in reversed(range(self.config.timesteps)):
+                t_batch = torch.full(
+                    (num_samples,),
+                    t,
+                    device=self.device,
+                    dtype=torch.long
+                )
+
+                predicted_noise = self.model(x, t_batch)
+
+                alpha = self.alphas[t]
+                alpha_cumprod = self.alphas_cumprod[t]
+                beta = self.betas[t]
+
+                if t > 0:
+                    noise = torch.randn_like(x)
+                else:
+                    noise = torch.zeros_like(x)
+
+                x = (1 / torch.sqrt(alpha)) * (
+                    x - (beta / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
+                ) + torch.sqrt(beta) * noise
+
+            x = (x + 1) / 2
+            x = torch.clamp(x, 0, 1)
+
+            return x
+
+else:
+    SinusoidalPositionEmbeddings = None
+    SimpleUNet = None
+    ImageDataset = None
+    DiffusionTrainer = None
+
+
+def train_diffusion(image_paths, captions=None, config=None,
+                    resume_from=None):
+    if not TORCH_AVAILABLE:
+        raise ImportError(
+            "PyTorch not available. Install: pip install torch torchvision"
         )
-
-        self.down2 = nn.Sequential(
-            nn.Conv2d(channels * 2, channels * 4, 4, 2, 1),
-            nn.GroupNorm(8, channels * 4),
-            nn.GELU(),
+
+    if config is None:
+        config = DiffusionConfig()
+
+    if captions is None:
+        captions = [''] * len(image_paths)
+
+    dataset = ImageDataset(image_paths, captions, config.image_size)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=0
+    )
+
+    trainer = DiffusionTrainer(config)
+
+    if resume_from and os.path.exists(resume_from):
+        checkpoint = torch.load(resume_from, map_location=trainer.device)
+        trainer.model.load_state_dict(checkpoint['model_state_dict'])
+        print(f'Resumed from {resume_from}')
+
+    output_path = trainer.train(dataloader)
+
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    return output_path
+
+
+def generate_image(model_path, prompt=None, num_samples=1, image_size=128):
+    if not TORCH_AVAILABLE:
+        raise ImportError(
+            "PyTorch not available. Install: pip install torch torchvision"
         )
-
-        self.down3 = nn.Sequential(
-            nn.Conv2d(channels * 4, channels * 8, 4, 2, 1),
-            nn.GroupNorm(8, channels * 8),
-            nn.GELU(),
-        )
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    # Fix: Load with weights_only=False for your custom checkpoint
+    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
+
+    if 'config' in checkpoint:
+        config = checkpoint['config']
+    else:
+        config = DiffusionConfig(image_size=image_size)
+
+    trainer = DiffusionTrainer(config)
+    trainer.model.load_state_dict(checkpoint['model_state_dict'])
+
+    samples = trainer.sample(num_samples)
+
+    images = []
+    for i in range(num_samples):
+        img_tensor = samples[i].cpu()
+        img_np = img_tensor.permute(1, 2, 0).numpy()
+        img_np = (img_np * 255).astype(np.uint8)
+        img = Image.fromarray(img_np)
+        images.append(img)
+
+    if num_samples == 1:
+        return images[0]
+    return images
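For orientation, here is a minimal usage sketch of the reworked npcpy.ft.diff API added in 1.2.35, read directly off the new train_diffusion and generate_image functions above. This is not an official example from the package; the image paths, output directory name, and epoch count are placeholders.

    # Illustrative sketch only; paths and the output directory are placeholders.
    from npcpy.ft.diff import DiffusionConfig, train_diffusion, generate_image

    config = DiffusionConfig(
        num_epochs=10,
        batch_size=4,
        checkpoint_frequency=10,            # new default in 1.2.35
        output_model_path="my_diffusion_model",
        num_channels=3,                     # RGB is now the default
    )

    # Trains the SimpleUNet denoiser and returns config.output_model_path,
    # which then contains model_final.pt plus a checkpoints/ subdirectory.
    model_dir = train_diffusion(
        image_paths=["img_0.png", "img_1.png"],
        config=config,
    )

    # Runs the reverse-diffusion sampling loop; returns a single PIL.Image
    # when num_samples == 1, otherwise a list of images.
    sample = generate_image(f"{model_dir}/model_final.pt", num_samples=1)
    sample.save("sample.png")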
npcpy/gen/image_gen.py CHANGED
@@ -21,29 +21,86 @@ def generate_image_diffusers(
     """Generate an image using the Stable Diffusion API with memory optimization."""
     import torch
     import gc
-
+    import os
+    from diffusers import DiffusionPipeline, StableDiffusionPipeline
+
     try:
         torch_dtype = torch.float16 if device != "cpu" and torch.cuda.is_available() else torch.float32
 
-        if 'Qwen' in model:
-            from diffusers import DiffusionPipeline
-
-            pipe = DiffusionPipeline.from_pretrained(
-                model,
-                torch_dtype=torch_dtype,
-                use_safetensors=True,
-                variant="fp16" if torch_dtype == torch.float16 else None,
-            )
-        else:
-            from diffusers import StableDiffusionPipeline
+        if os.path.isdir(model):
+            print(f"🌋 Loading fine-tuned Diffusers model from local path: {model}")
 
-            pipe = StableDiffusionPipeline.from_pretrained(
-                model,
-                torch_dtype=torch_dtype,
-                use_safetensors=True,
-                variant="fp16" if torch_dtype == torch.float16 else None,
-            )
+            checkpoint_path = os.path.join(model, 'model_final.pt')
+            if os.path.exists(checkpoint_path):
+                print(f"🌋 Found model_final.pt at {checkpoint_path}.")
+
+                # Load checkpoint to inspect it
+                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+
+                # Check if this is a custom SimpleUNet model (from your training code)
+                # vs a Stable Diffusion UNet2DConditionModel
+                if 'config' in checkpoint and hasattr(checkpoint['config'], 'image_size'):
+                    print(f"🌋 Detected custom SimpleUNet model, using custom generation")
+                    # Use your custom generate_image function from npcpy.ft.diff
+                    from npcpy.ft.diff import generate_image as custom_generate_image
+
+                    # Your custom model ignores prompts and generates based on training data
+                    image = custom_generate_image(
+                        model_path=checkpoint_path,
+                        prompt=prompt,
+                        num_samples=1,
+                        image_size=height  # Use the requested height
+                    )
+                    return image
+
+                else:
+                    # This is a Stable Diffusion checkpoint
+                    print(f"🌋 Detected Stable Diffusion UNet checkpoint")
+                    base_model_id = "runwayml/stable-diffusion-v1-5"
+                    print(f"🌋 Loading base pipeline: {base_model_id}")
+                    pipe = StableDiffusionPipeline.from_pretrained(
+                        base_model_id,
+                        torch_dtype=torch_dtype,
+                        use_safetensors=True,
+                        variant="fp16" if torch_dtype == torch.float16 else None,
+                    )
+
+                    print(f"🌋 Loading custom UNet weights from {checkpoint_path}")
+
+                    # Extract the actual model state dict
+                    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                        unet_state_dict = checkpoint['model_state_dict']
+                        print(f"🌋 Extracted model_state_dict from checkpoint")
+                    else:
+                        unet_state_dict = checkpoint
+                        print(f"🌋 Using checkpoint directly as state_dict")
+
+                    # Load the state dict into the UNet
+                    pipe.unet.load_state_dict(unet_state_dict)
+                    pipe = pipe.to(device)
+                    print(f"🌋 Successfully loaded fine-tuned UNet weights")
+
+            else:
+                raise OSError(f"Error: Fine-tuned model directory {model} does not contain 'model_final.pt'")
+
+        else:
+            print(f"🌋 Loading standard Diffusers model: {model}")
+            if 'Qwen' in model:
+                pipe = DiffusionPipeline.from_pretrained(
+                    model,
+                    torch_dtype=torch_dtype,
+                    use_safetensors=True,
+                    variant="fp16" if torch_dtype == torch.float16 else None,
+                )
+            else:
+                pipe = StableDiffusionPipeline.from_pretrained(
+                    model,
+                    torch_dtype=torch_dtype,
+                    use_safetensors=True,
+                    variant="fp16" if torch_dtype == torch.float16 else None,
+                )
 
+        # Common pipeline setup for Stable Diffusion models
         if hasattr(pipe, 'enable_attention_slicing'):
             pipe.enable_attention_slicing()
 
@@ -85,7 +142,6 @@ def generate_image_diffusers(
             raise MemoryError(f"Insufficient memory for image generation with model {model}. Try a smaller model or reduce image size.")
         else:
             raise e
-
 import os
 import base64
 import io
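Taken together, the two hunks above change generate_image_diffusers so that model can be either a Hugging Face model id (the previous behavior) or a local directory produced by npcpy.ft.diff.train_diffusion that contains model_final.pt. Below is a hedged sketch of the two call patterns. Only the parameters visible at the call sites in this diff (prompt, model, height, width) are used, and the local directory name is a placeholder.

    from npcpy.gen.image_gen import generate_image_diffusers

    # 1) Hub model id: loads a Diffusers pipeline as in 1.2.34.
    img = generate_image_diffusers(
        prompt="a watercolor fox",
        model="runwayml/stable-diffusion-v1-5",
        height=512,
        width=512,
    )

    # 2) Local fine-tuned directory (placeholder path): model_final.pt is loaded
    #    with torch.load(..., weights_only=False); custom SimpleUNet checkpoints
    #    are routed to npcpy.ft.diff.generate_image, while Stable Diffusion
    #    checkpoints are loaded into the base pipeline's UNet.
    img = generate_image_diffusers(
        prompt="ignored by custom SimpleUNet checkpoints",
        model="my_diffusion_model",
        height=128,
        width=128,
    )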
@@ -294,6 +350,8 @@ def gemini_image_gen(
 
     else:
         raise ValueError(f"Unsupported Gemini image model or API usage for new generation: '{model}'")
+# In npcpy/gen/image_gen.py, find the generate_image function and replace it with this:
+
 def generate_image(
     prompt: str,
     model: str ,
@@ -305,6 +363,7 @@ def generate_image(
     api_url: Optional[str] = None,
     attachments: Union[List[Union[str, bytes, Image.Image]], None] = None,
     save_path: Optional[str] = None,
+    custom_model_path: Optional[str] = None, # <--- NEW: Accept custom_model_path
 ):
     """
     Unified function to generate or edit images using various providers.
@@ -320,13 +379,15 @@ def generate_image(
         api_url (str): API URL for the provider.
         attachments (list): List of images for editing. Can be file paths, bytes, or PIL Images.
         save_path (str): Path to save the generated image.
+        custom_model_path (str): Path to a locally fine-tuned Diffusers model. <--- NEW
 
     Returns:
         List[PIL.Image.Image]: A list of generated PIL Image objects.
     """
     from urllib.request import urlopen
+    import os # Ensure os is imported for path checks
 
-    if model is None:
+    if model is None and custom_model_path is None: # Only set default if no model or custom path is provided
         if provider == "openai":
             model = "dall-e-2"
         elif provider == "diffusers":
@@ -336,12 +397,22 @@ def generate_image(
 
     all_generated_pil_images = []
 
+    # <--- CRITICAL FIX: Handle custom_model_path for Diffusers here
     if provider == "diffusers":
+        # If a custom_model_path is provided and exists, use it instead of a generic model name
+        if custom_model_path and os.path.isdir(custom_model_path):
+            print(f"🌋 Using custom Diffusers model from path: {custom_model_path}")
+            model_to_use = custom_model_path
+        else:
+            # Otherwise, use the standard model name (e.g., "runwayml/stable-diffusion-v1-5")
+            model_to_use = model
+            print(f"🌋 Using standard Diffusers model: {model_to_use}")
+
         for _ in range(n_images):
             try:
                 image = generate_image_diffusers(
                     prompt=prompt,
-                    model=model,
+                    model=model_to_use, # <--- Pass the resolved model_to_use
                     height=height,
                     width=width
                 )
@@ -373,15 +444,42 @@ def generate_image(
             all_generated_pil_images.extend(images)
 
     else:
+        # This is the fallback for other providers or if provider is not explicitly handled
         valid_sizes = ["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]
         size = f"{width}x{height}"
 
         if attachments is not None:
             raise ValueError("Image editing not supported with litellm provider")
 
+        # The litellm.image_generation function expects the provider as part of the model string
+        # e.g., "huggingface/starcoder" or "openai/dall-e-3"
+        # Since we've already handled "diffusers", "openai", "gemini" above,
+        # this 'else' block implies a generic litellm call.
+        # We need to ensure the model string is correctly formatted for litellm.
+        # However, the error message "LLM Provider NOT provided" suggests litellm
+        # is not even getting the `provider` correctly.
+        # The fix for this is ensuring the `provider` is explicitly passed to litellm.image_generation
+        # which is already happening in `gen_image` in `llm_funcs.py`
+
+        # If we reach here, it means the provider is not 'diffusers', 'openai', or 'gemini',
+        # and litellm is the intended route. We need to pass the provider explicitly.
+        # The original code here was trying to construct `model=f"{provider}/{model}"`
+        # but the error indicates `provider` itself was missing.
+        # The `image_generation` from litellm expects `model` to be `provider/model_name`.
+        # Since the `provider` variable is available, we can construct this.
+
+        # This block is for generic litellm providers (not diffusers, openai, gemini)
+        # The error indicates `provider` itself was not making it to litellm.
+        # This `generate_image` function already receives `provider`.
+        # The issue is likely how `gen_image` in `llm_funcs.py` calls this `generate_image`.
+        # However, if this `else` branch is hit, we ensure litellm gets the provider.
+
+        # Construct the model string for litellm
+        litellm_model_string = f"{provider}/{model}" if provider and model else model
+
         image_response = image_generation(
             prompt=prompt,
-            model=f"{provider}/{model}",
+            model=litellm_model_string, # <--- Ensure model string includes provider for litellm
             n=n_images,
             size=size,
             api_key=api_key,
@@ -407,7 +505,6 @@ def generate_image(
 
     return all_generated_pil_images
 
-
 def edit_image(
     prompt: str,
     image_path: str,
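Finally, a hedged sketch of the new custom_model_path plumbing in the unified generate_image: when provider is "diffusers" and custom_model_path points at an existing directory, that path is forwarded to generate_image_diffusers in place of a model id. Parameter names not visible in the hunks above (provider, n_images, height, width) are assumed from the function body, and the directory is a placeholder.

    from npcpy.gen.image_gen import generate_image

    images = generate_image(
        prompt="a pixel-art volcano",
        model=None,                              # resolved from custom_model_path
        provider="diffusers",
        custom_model_path="my_diffusion_model",  # placeholder fine-tuned model dir
        height=128,
        width=128,
        n_images=1,
    )
    images[0].save("volcano.png")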