npcpy-1.1.28-py3-none-any.whl → npcpy-1.2.32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- npcpy/data/audio.py +16 -38
- npcpy/data/image.py +29 -29
- npcpy/data/load.py +4 -3
- npcpy/data/text.py +28 -28
- npcpy/data/video.py +6 -6
- npcpy/data/web.py +49 -21
- npcpy/ft/__init__.py +0 -0
- npcpy/ft/diff.py +110 -0
- npcpy/ft/ge.py +115 -0
- npcpy/ft/memory_trainer.py +171 -0
- npcpy/ft/model_ensembler.py +357 -0
- npcpy/ft/rl.py +360 -0
- npcpy/ft/sft.py +248 -0
- npcpy/ft/usft.py +128 -0
- npcpy/gen/audio_gen.py +24 -0
- npcpy/gen/embeddings.py +13 -13
- npcpy/gen/image_gen.py +37 -15
- npcpy/gen/response.py +287 -111
- npcpy/gen/video_gen.py +10 -9
- npcpy/llm_funcs.py +447 -79
- npcpy/memory/command_history.py +201 -48
- npcpy/memory/kg_vis.py +74 -74
- npcpy/memory/knowledge_graph.py +482 -115
- npcpy/memory/memory_processor.py +81 -0
- npcpy/memory/search.py +70 -70
- npcpy/mix/debate.py +192 -3
- npcpy/npc_compiler.py +1541 -879
- npcpy/npc_sysenv.py +250 -78
- npcpy/serve.py +1036 -321
- npcpy/sql/ai_function_tools.py +257 -0
- npcpy/sql/database_ai_adapters.py +186 -0
- npcpy/sql/database_ai_functions.py +163 -0
- npcpy/sql/model_runner.py +19 -19
- npcpy/sql/npcsql.py +706 -507
- npcpy/sql/sql_model_compiler.py +156 -0
- npcpy/tools.py +20 -20
- npcpy/work/plan.py +8 -8
- npcpy/work/trigger.py +3 -3
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/METADATA +169 -9
- npcpy-1.2.32.dist-info/RECORD +54 -0
- npcpy-1.1.28.dist-info/RECORD +0 -40
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/WHEEL +0 -0
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/licenses/LICENSE +0 -0
- {npcpy-1.1.28.dist-info → npcpy-1.2.32.dist-info}/top_level.txt +0 -0
npcpy/ft/sft.py
ADDED
@@ -0,0 +1,248 @@
+# structured fine tuning of LLMs to produce structured output
+from dataclasses import dataclass, field
+from datasets import Dataset
+import json
+import numpy as np
+import os
+try:
+    import torch
+    from transformers import (
+        AutoModelForCausalLM,
+        AutoTokenizer,
+        TrainingArguments
+    )
+    from trl import SFTTrainer
+    from peft import LoraConfig
+except:
+    torch = None
+    SFTTrainer = None
+    LoraConfig = None
+    AutoModelForCausalLM = None
+    AutoTokenizer = None
+    TrainingArguments = None
+
+from typing import List, Dict, Any, Optional
+
+
+@dataclass
+class SFTConfig:
+    base_model_name: str = "google/gemma-3-270m-it"
+    output_model_path: str = "models/sft_model"
+    lora_r: int = 8
+    lora_alpha: int = 16
+    use_4bit: bool = False
+    fp16: bool = False
+    bf16: bool = False
+    lora_dropout: float = 0.15
+    lora_target_modules: List[str] = field(
+        default_factory=lambda: ["q_proj", "v_proj"]
+    )
+    num_train_epochs: int = 20
+    per_device_train_batch_size: int = 2
+    gradient_accumulation_steps: int = 4
+    learning_rate: float = 3e-5
+    logging_steps: int = 10
+    optim: str = "adamw_torch"
+    lr_scheduler_type: str = "cosine_with_restarts"
+    weight_decay: float = 0.01
+    max_length: int = 512
+    save_steps: int = 50
+
+
+def format_training_examples(
+    inputs: List[str],
+    outputs: List[str],
+    format_style: str = "gemma"
+) -> List[Dict[str, str]]:
+
+    formatted = []
+
+    for inp, out in zip(inputs, outputs):
+        if format_style == "gemma":
+            text = (
+                f"<start_of_turn>user\n{inp}<end_of_turn>\n"
+                f"<start_of_turn>model\n{out}<end_of_turn>"
+            )
+        elif format_style == "llama":
+            text = (
+                f"<|begin_of_text|><|start_header_id|>user"
+                f"<|end_header_id|>\n\n{inp}<|eot_id|>"
+                f"<|start_header_id|>assistant<|end_header_id|>"
+                f"\n\n{out}<|eot_id|>"
+            )
+        else:
+            text = f"Input: {inp}\nOutput: {out}"
+
+        formatted.append({"text": text})
+
+    return formatted
+
+
+def run_sft(
+    X: List[str],
+    y: List[str],
+    config: Optional[SFTConfig] = None,
+    validation_split: float = 0.0,
+    format_style: str = "gemma"
+) -> str:
+
+    if config is None:
+        config = SFTConfig()
+
+    if len(X) != len(y):
+        raise ValueError(
+            f"X and y must have same length: {len(X)} vs {len(y)}"
+        )
+
+    formatted_examples = format_training_examples(
+        X, y, format_style
+    )
+
+    if validation_split > 0:
+        split_idx = int(len(formatted_examples) * (1 - validation_split))
+        train_examples = formatted_examples[:split_idx]
+        val_examples = formatted_examples[split_idx:]
+        print(
+            f"Split: {len(train_examples)} train, "
+            f"{len(val_examples)} val"
+        )
+    else:
+        train_examples = formatted_examples
+        val_examples = []
+
+    dataset = Dataset.from_list(train_examples)
+
+    model = AutoModelForCausalLM.from_pretrained(
+        config.base_model_name,
+        trust_remote_code=True,
+        attn_implementation="eager"
+    )
+    model.config.use_cache = False
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        config.base_model_name,
+        trust_remote_code=True
+    )
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "right"
+
+    peft_config = LoraConfig(
+        r=config.lora_r,
+        lora_alpha=config.lora_alpha,
+        lora_dropout=config.lora_dropout,
+        target_modules=config.lora_target_modules,
+        bias="none",
+        task_type="CAUSAL_LM"
+    )
+
+    training_args = TrainingArguments(
+        output_dir=config.output_model_path,
+        num_train_epochs=config.num_train_epochs,
+        per_device_train_batch_size=(
+            config.per_device_train_batch_size
+        ),
+        gradient_accumulation_steps=(
+            config.gradient_accumulation_steps
+        ),
+        optim=config.optim,
+        logging_steps=config.logging_steps,
+        learning_rate=config.learning_rate,
+        fp16=config.fp16,
+        bf16=config.bf16,
+        lr_scheduler_type=config.lr_scheduler_type,
+        group_by_length=True,
+        save_steps=config.save_steps,
+        weight_decay=config.weight_decay,
+    )
+
+    def formatting_func(example):
+        return example["text"]
+
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=dataset,
+        peft_config=peft_config,
+        args=training_args,
+        processing_class=tokenizer,
+        formatting_func=formatting_func
+    )
+
+    print(f"Training on {len(dataset)} examples")
+    trainer.train()
+
+    trainer.save_model(config.output_model_path)
+    print(f"Model saved to {config.output_model_path}")
+
+    return config.output_model_path
+
+
+def load_sft_model(model_path: str):
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=torch.float32,
+        device_map="auto",
+        attn_implementation="eager"
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path,
+        trust_remote_code=True
+    )
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return model, tokenizer
+def predict_sft(
+    model,
+    tokenizer,
+    prompt: str,
+    max_new_tokens: int = 128,
+    temperature: float = 0.7
+) -> str:
+
+    device = next(model.parameters()).device
+
+    formatted_prompt = (
+        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
+        f"<start_of_turn>model\n"
+    )
+
+    inputs = tokenizer(
+        formatted_prompt,
+        return_tensors="pt",
+        truncation=True,
+        max_length=512
+    )
+
+    input_ids = inputs.input_ids.to(device)
+    attention_mask = inputs.attention_mask.to(device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            do_sample=temperature > 0,
+            pad_token_id=tokenizer.eos_token_id
+        )
+
+    full_response = tokenizer.decode(
+        outputs[0],
+        skip_special_tokens=False
+    )
+
+    if "<start_of_turn>model\n" in full_response:
+        response = full_response.split(
+            "<start_of_turn>model\n"
+        )[-1]
+        response = response.split("<end_of_turn>")[0].strip()
+    else:
+        response = tokenizer.decode(
+            outputs[0][len(input_ids[0]):],
+            skip_special_tokens=True
+        )
+
+    return response
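The new npcpy.ft.sft module exposes run_sft, load_sft_model, and predict_sft as added above. A minimal usage sketch, assuming the optional torch/trl/peft extras are installed; the toy training pairs and output path are illustrative only:

    from npcpy.ft.sft import SFTConfig, run_sft, load_sft_model, predict_sft

    # toy input/output pairs; real structured-output training data would be much larger
    X = ["Extract the city: 'I flew to Paris last week.'"]
    y = ['{"city": "Paris"}']

    config = SFTConfig(output_model_path="models/sft_demo", num_train_epochs=1)
    model_path = run_sft(X, y, config=config, format_style="gemma")

    model, tokenizer = load_sft_model(model_path)
    print(predict_sft(model, tokenizer, "Extract the city: 'She lives in Kyoto.'"))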
npcpy/ft/usft.py
ADDED
@@ -0,0 +1,128 @@
+from dataclasses import dataclass, field
+try:
+    from datasets import Dataset, load_dataset
+    import torch
+    from transformers import (
+        AutoModelForCausalLM,
+        AutoTokenizer,
+        TrainingArguments
+    )
+    from trl import SFTTrainer
+    from peft import LoraConfig
+except:
+    Dataset = None
+    load_dataset = None
+    torch = None
+    AutoModelForCausalLM = None
+    AutoTokenizer = None
+    TrainingArguments = None
+    SFTTrainer = None
+
+from typing import List, Optional
+
+
+@dataclass
+class USFTConfig:
+    base_model_name: str = "Qwen/Qwen3-0.6B"
+    output_model_path: str = "models/usft_model"
+    lora_r: int = 8
+    lora_alpha: int = 16
+    lora_dropout: float = 0.15
+    lora_target_modules: List[str] = field(
+        default_factory=lambda: ["q_proj", "v_proj"]
+    )
+    num_train_epochs: int = 3
+    per_device_train_batch_size: int = 4
+    gradient_accumulation_steps: int = 4
+    learning_rate: float = 2e-5
+    logging_steps: int = 10
+    optim: str = "adamw_torch"
+    lr_scheduler_type: str = "cosine"
+    weight_decay: float = 0.01
+    max_length: int = 512
+    save_steps: int = 100
+
+
+def run_usft(
+    texts: List[str],
+    config: Optional[USFTConfig] = None
+) -> str:
+
+    if config is None:
+        config = USFTConfig()
+
+    dataset = Dataset.from_dict({"text": texts})
+
+    model = AutoModelForCausalLM.from_pretrained(
+        config.base_model_name,
+        trust_remote_code=True,
+        attn_implementation="eager"
+    )
+    model.config.use_cache = False
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        config.base_model_name,
+        trust_remote_code=True
+    )
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    tokenizer.padding_side = "right"
+
+    peft_config = LoraConfig(
+        r=config.lora_r,
+        lora_alpha=config.lora_alpha,
+        lora_dropout=config.lora_dropout,
+        target_modules=config.lora_target_modules,
+        bias="none",
+        task_type="CAUSAL_LM"
+    )
+
+    training_args = TrainingArguments(
+        output_dir=config.output_model_path,
+        num_train_epochs=config.num_train_epochs,
+        per_device_train_batch_size=(
+            config.per_device_train_batch_size
+        ),
+        gradient_accumulation_steps=(
+            config.gradient_accumulation_steps
+        ),
+        optim=config.optim,
+        logging_steps=config.logging_steps,
+        learning_rate=config.learning_rate,
+        fp16=False,
+        bf16=torch.cuda.is_available(),
+        lr_scheduler_type=config.lr_scheduler_type,
+        save_steps=config.save_steps,
+        weight_decay=config.weight_decay,
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        train_dataset=dataset,
+        peft_config=peft_config,
+        args=training_args,
+        max_seq_length=config.max_length,
+        dataset_text_field="text"
+    )
+
+    print(f"Starting USFT on {len(dataset)} texts")
+    trainer.train()
+
+    trainer.save_model(config.output_model_path)
+    print(f"Model saved to {config.output_model_path}")
+
+    return config.output_model_path
+
+
+def load_corpus_from_hf(dataset_name: str, split: str = "train"):
+
+    ds = load_dataset(dataset_name, split=split)
+
+    if "text" in ds.column_names:
+        return ds["text"]
+    elif "content" in ds.column_names:
+        return ds["content"]
+    else:
+        return [str(item) for item in ds]
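run_usft performs unstructured (plain-text) fine-tuning over a list of raw texts, and load_corpus_from_hf pulls such a list from a Hugging Face dataset with a "text" or "content" column. A minimal sketch, assuming the training extras are installed; the dataset name is purely illustrative:

    from npcpy.ft.usft import USFTConfig, run_usft, load_corpus_from_hf

    # any dataset exposing a "text" or "content" column works; this name is an example
    texts = load_corpus_from_hf("karpathy/tiny_shakespeare", split="train")

    config = USFTConfig(output_model_path="models/usft_demo", num_train_epochs=1)
    model_path = run_usft(list(texts)[:100], config=config)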
npcpy/gen/audio_gen.py
ADDED
@@ -0,0 +1,24 @@
+import os
+def tts_elevenlabs(text,
+                   api_key=None,
+                   voice_id='JBFqnCBsd6RMkjVDRZzb',
+                   model_id='eleven_multilingual_v2',
+                   output_format='mp3_44100_128'):
+    if api_key is None:
+        api_key = os.environ.get('ELEVENLABS_API_KEY')
+    from elevenlabs.client import ElevenLabs
+    from elevenlabs import play
+
+    client = ElevenLabs(
+        api_key=api_key,
+    )
+
+    audio = client.text_to_speech.convert(
+        text=text,
+        voice_id=voice_id,
+        model_id=model_id,
+        output_format=output_format
+    )
+
+    play(audio)
+    return audio
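tts_elevenlabs reads ELEVENLABS_API_KEY from the environment when no key is passed, synthesizes the text with the ElevenLabs client, plays it, and returns the audio. A minimal sketch, assuming the elevenlabs package is installed and an API key is configured:

    from npcpy.gen.audio_gen import tts_elevenlabs

    # plays the synthesized speech locally and returns the audio stream
    audio = tts_elevenlabs("Hello from npcpy!")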
npcpy/gen/embeddings.py
CHANGED
@@ -1,9 +1,9 @@
-
-
-
-
-
-
+
+
+
+
+
+
 from typing import List, Dict, Optional
 import numpy as np
 from datetime import datetime
@@ -49,17 +49,17 @@ def store_embeddings_for_model(
     collection_name = f"{provider}_{model}_embeddings"
     collection = chroma_client.get_collection(collection_name)
 
-
+
     if metadata is None:
-        metadata = [{"text_length": len(text)} for text in texts]
+        metadata = [{"text_length": len(text)} for text in texts]
         print(
            "metadata is none, creating metadata for each document as the length of the text"
         )
-
+
     collection.add(
         ids=[str(i) for i in range(len(texts))],
         embeddings=embeddings,
-        metadatas=metadata,
+        metadatas=metadata,
         documents=texts,
     )
 
@@ -67,7 +67,7 @@ def store_embeddings_for_model(
 def delete_embeddings_from_collection(collection, ids):
     """Delete embeddings by id from Chroma collection."""
     if ids:
-        collection.delete(ids=ids)
+        collection.delete(ids=ids)
 
 
 def get_embeddings(
@@ -83,6 +83,6 @@ def get_embeddings(
     else:
         raise ValueError(f"Unsupported provider: {provider}")
 
-
-
+
+
     return embeddings
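The visible embeddings.py changes are formatting-level; the helper in the second hunk simply guards collection.delete against an empty id list. A short sketch of that helper against a local Chroma collection, assuming chromadb is installed and using illustrative ids and vectors:

    import chromadb
    from npcpy.gen.embeddings import delete_embeddings_from_collection

    client = chromadb.Client()
    collection = client.get_or_create_collection("demo_embeddings")
    collection.add(ids=["0", "1"], documents=["alpha", "beta"],
                   embeddings=[[0.1, 0.2], [0.3, 0.4]])

    # no-op when ids is empty, otherwise removes the matching entries
    delete_embeddings_from_collection(collection, ["0"])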
npcpy/gen/image_gen.py
CHANGED
@@ -86,6 +86,16 @@ def generate_image_diffusers(
     else:
         raise e
 
+import os
+import base64
+import io
+from typing import Union, List, Optional
+
+import PIL
+from PIL import Image
+
+import requests
+from urllib.request import urlopen
 
 def openai_image_gen(
     prompt: str,
@@ -97,36 +107,47 @@ def openai_image_gen(
 ):
     """Generate or edit an image using the OpenAI API."""
     from openai import OpenAI
-
+
     client = OpenAI()
-
+
     if height is None:
         height = 1024
     if width is None:
-        width = 1024
-
-    size_str = f"{width}x{height}"
+        width = 1024
+
+    size_str = f"{width}x{height}"
 
     if attachments is not None:
         processed_images = []
+        files_to_close = []
         for attachment in attachments:
             if isinstance(attachment, str):
-
+                file_handle = open(attachment, "rb")
+                processed_images.append(file_handle)
+                files_to_close.append(file_handle)
             elif isinstance(attachment, bytes):
-
+                img_byte_arr = io.BytesIO(attachment)
+                img_byte_arr.name = 'image.png'  # FIX: Add filename hint
+                processed_images.append(img_byte_arr)
             elif isinstance(attachment, Image.Image):
                 img_byte_arr = io.BytesIO()
                 attachment.save(img_byte_arr, format='PNG')
                 img_byte_arr.seek(0)
+                img_byte_arr.name = 'image.png'  # FIX: Add filename hint
                 processed_images.append(img_byte_arr)
 
-
-
-
-
-
-
-
+        try:
+            result = client.images.edit(
+                model=model,
+                image=processed_images[0],
+                prompt=prompt,
+                n=n_images,
+                size=size_str,
+            )
+        finally:
+            # This ensures any files we opened are properly closed
+            for f in files_to_close:
+                f.close()
     else:
         result = client.images.generate(
             model=model,
@@ -134,7 +155,7 @@ def openai_image_gen(
             n=n_images,
             size=size_str,
         )
-
+
     collected_images = []
     for item_data in result.data:
         if model == 'gpt-image-1':
@@ -153,6 +174,7 @@ def openai_image_gen(
     return collected_images
 
 
+
 def gemini_image_gen(
     prompt: str,
     model: str = "gemini-2.5-flash",
|