cat-stack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cat_stack/__about__.py +10 -0
- cat_stack/__init__.py +128 -0
- cat_stack/_batch.py +1388 -0
- cat_stack/_category_analysis.py +348 -0
- cat_stack/_chunked.py +424 -0
- cat_stack/_embeddings.py +189 -0
- cat_stack/_formatter.py +169 -0
- cat_stack/_providers.py +1048 -0
- cat_stack/_tiebreaker.py +277 -0
- cat_stack/_utils.py +512 -0
- cat_stack/_web_fetch.py +194 -0
- cat_stack/calls/CoVe.py +287 -0
- cat_stack/calls/__init__.py +25 -0
- cat_stack/calls/all_calls.py +622 -0
- cat_stack/calls/image_CoVe.py +386 -0
- cat_stack/calls/image_stepback.py +210 -0
- cat_stack/calls/pdf_CoVe.py +386 -0
- cat_stack/calls/pdf_stepback.py +210 -0
- cat_stack/calls/stepback.py +180 -0
- cat_stack/calls/top_n.py +217 -0
- cat_stack/classify.py +682 -0
- cat_stack/explore.py +111 -0
- cat_stack/extract.py +218 -0
- cat_stack/image_functions.py +2078 -0
- cat_stack/images/circle.png +0 -0
- cat_stack/images/cube.png +0 -0
- cat_stack/images/diamond.png +0 -0
- cat_stack/images/overlapping_pentagons.png +0 -0
- cat_stack/images/rectangles.png +0 -0
- cat_stack/model_reference_list.py +94 -0
- cat_stack/pdf_functions.py +2087 -0
- cat_stack/summarize.py +290 -0
- cat_stack/text_functions.py +1358 -0
- cat_stack/text_functions_ensemble.py +3644 -0
- cat_stack-0.1.0.dist-info/METADATA +150 -0
- cat_stack-0.1.0.dist-info/RECORD +38 -0
- cat_stack-0.1.0.dist-info/WHEEL +4 -0
- cat_stack-0.1.0.dist-info/licenses/LICENSE +672 -0
|
@@ -0,0 +1,2078 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
from .text_functions import _detect_model_source
|
|
4
|
+
from .calls.image_stepback import get_image_stepback_insight
|
|
5
|
+
|
|
6
|
+
# Exported names (excludes deprecated image_multi_class).
# The two underscore-prefixed helpers are listed explicitly; without that,
# `from ... import *` would skip them because of their leading underscore
# (presumably they are re-exported for sibling modules — confirm with callers).
__all__ = [
    "_load_image_files",
    "_encode_image",
    "image_score_drawing",
    "image_features",
    "explore_image_categories",
]
|
|
14
|
+
from .calls.image_CoVe import (
|
|
15
|
+
image_chain_of_verification_openai,
|
|
16
|
+
image_chain_of_verification_anthropic,
|
|
17
|
+
image_chain_of_verification_google,
|
|
18
|
+
image_chain_of_verification_mistral
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _load_image_files(image_input):
    """Resolve *image_input* into a list of image file paths.

    Accepts:
        * a ``list`` of paths — returned unchanged (the same list object);
        * a path to a single file — wrapped in a one-element list;
        * a path to a directory — every non-hidden regular file directly inside
          it whose extension (matched case-insensitively) is a supported image
          type, sorted by path so the processing order is deterministic.

    Returns:
        list of paths.

    Raises:
        FileNotFoundError: if *image_input* is neither a list, an existing
            file, nor an existing directory.
    """
    import os

    # Supported suffixes; compared against the lower-cased file name.
    image_extensions = (
        ".png", ".jpg", ".jpeg",
        ".gif", ".webp", ".svg", ".svgz", ".avif", ".apng",
        ".tif", ".tiff", ".bmp",
        ".heif", ".heic", ".ico",
        ".psd",
    )

    if isinstance(image_input, list):
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")
    elif os.path.isfile(image_input):
        # Single file path
        image_files = [image_input]
        print("Provided 1 image file.")
    elif os.path.isdir(image_input):
        # List the directory directly instead of globbing: glob would treat
        # "[", "*" or "?" in the directory name as pattern syntax (a folder
        # named "imgs[1]" matched nothing), and its matching is case-sensitive
        # on POSIX, which skipped files such as "photo.JPG".
        image_files = sorted(
            os.path.join(image_input, name)
            for name in os.listdir(image_input)
            # glob's "*" never matched dot-files; keep skipping them (this also
            # skips macOS "._name.jpg" resource-fork files).
            if not name.startswith(".")
            and name.lower().endswith(image_extensions)
            and os.path.isfile(os.path.join(image_input, name))
        )
        print(f"Found {len(image_files)} images in directory.")
    else:
        raise FileNotFoundError(f"Image input not found: {image_input}")

    return image_files
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _encode_image(img_path):
    """Read an image file and base64-encode it.

    Returns:
        ``(encoded, ext, True)`` on success — ``encoded`` is the file content
        base64-encoded as a UTF-8 ``str`` and ``ext`` is the lower-cased suffix
        without its dot, with ``"jpg"`` normalised to ``"jpeg"``.
        ``(None, None, False)`` when the path is None, does not exist, is a
        directory, or reading/encoding fails (the error is printed).
    """
    import os
    import base64
    from pathlib import Path

    failure = (None, None, False)

    # Guard clauses: nothing usable to read.
    if img_path is None or not os.path.exists(img_path) or os.path.isdir(img_path):
        return failure

    try:
        with open(img_path, "rb") as handle:
            payload = base64.b64encode(handle.read()).decode("utf-8")
            suffix = Path(img_path).suffix.lstrip(".").lower()
        normalized = "jpeg" if suffix == "jpg" else suffix
        return payload, normalized, True
    except Exception as e:
        print(f"Error encoding image: {e}")
        return failure
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# image multi-class (binary) function
|
|
79
|
+
def image_multi_class(
|
|
80
|
+
image_description,
|
|
81
|
+
image_input,
|
|
82
|
+
categories,
|
|
83
|
+
api_key,
|
|
84
|
+
user_model="gpt-4o",
|
|
85
|
+
creativity=None,
|
|
86
|
+
safety=False,
|
|
87
|
+
chain_of_verification=False,
|
|
88
|
+
chain_of_thought=True,
|
|
89
|
+
step_back_prompt=False,
|
|
90
|
+
context_prompt=False,
|
|
91
|
+
thinking_budget=0,
|
|
92
|
+
example1=None,
|
|
93
|
+
example2=None,
|
|
94
|
+
example3=None,
|
|
95
|
+
example4=None,
|
|
96
|
+
example5=None,
|
|
97
|
+
example6=None,
|
|
98
|
+
filename=None,
|
|
99
|
+
save_directory=None,
|
|
100
|
+
model_source="auto"
|
|
101
|
+
):
|
|
102
|
+
"""
|
|
103
|
+
Classify images using LLMs.
|
|
104
|
+
|
|
105
|
+
.. deprecated::
|
|
106
|
+
Use :func:`cat_stack.classify` instead. This function will be removed in a future version.
|
|
107
|
+
"""
|
|
108
|
+
warnings.warn(
|
|
109
|
+
"image_multi_class() is deprecated and will be removed in a future version. "
|
|
110
|
+
"Use cat_stack.classify() instead, which auto-detects image input.",
|
|
111
|
+
DeprecationWarning,
|
|
112
|
+
stacklevel=2,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
import os
|
|
116
|
+
import json
|
|
117
|
+
import pandas as pd
|
|
118
|
+
import regex
|
|
119
|
+
import time
|
|
120
|
+
from tqdm import tqdm
|
|
121
|
+
|
|
122
|
+
if save_directory is not None and not os.path.isdir(save_directory):
|
|
123
|
+
raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
|
|
124
|
+
|
|
125
|
+
model_source = _detect_model_source(user_model, model_source)
|
|
126
|
+
|
|
127
|
+
image_files = _load_image_files(image_input)
|
|
128
|
+
|
|
129
|
+
# Handle "auto" categories - extract categories first
|
|
130
|
+
if categories == "auto":
|
|
131
|
+
if not image_description:
|
|
132
|
+
raise ValueError("image_description is required when using categories='auto'")
|
|
133
|
+
|
|
134
|
+
print("\nAuto-extracting categories from images...")
|
|
135
|
+
auto_result = explore_image_categories(
|
|
136
|
+
image_input=image_input,
|
|
137
|
+
api_key=api_key,
|
|
138
|
+
image_description=image_description,
|
|
139
|
+
user_model=user_model,
|
|
140
|
+
model_source=model_source,
|
|
141
|
+
creativity=creativity
|
|
142
|
+
)
|
|
143
|
+
categories = auto_result["top_categories"]
|
|
144
|
+
print(f"Extracted {len(categories)} categories: {categories}\n")
|
|
145
|
+
|
|
146
|
+
categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
|
|
147
|
+
cat_num = len(categories)
|
|
148
|
+
category_dict = {str(i+1): "0" for i in range(cat_num)}
|
|
149
|
+
example_JSON = json.dumps(category_dict, indent=4)
|
|
150
|
+
|
|
151
|
+
print(f"\nCategories to classify by {model_source} {user_model}:")
|
|
152
|
+
for i, cat in enumerate(categories, 1):
|
|
153
|
+
print(f"{i}. {cat}")
|
|
154
|
+
|
|
155
|
+
# Build examples text from provided examples
|
|
156
|
+
examples = [example1, example2, example3, example4, example5, example6]
|
|
157
|
+
examples = [ex for ex in examples if ex is not None]
|
|
158
|
+
if examples:
|
|
159
|
+
examples_text = "Here are some examples of how to categorize:\n" + "\n".join(examples)
|
|
160
|
+
else:
|
|
161
|
+
examples_text = ""
|
|
162
|
+
|
|
163
|
+
# Helper function for CoVe
|
|
164
|
+
def remove_numbering(line):
    """Strip a leading list marker from one line of model output.

    Handles "- " and "• " bullets and numeric markers such as "3." or "12)".
    Surrounding whitespace is always removed; a line with no recognised
    marker is returned stripped but otherwise unchanged.
    """
    text = line.strip()

    for bullet in ("- ", "• "):
        if text.startswith(bullet):
            return text[2:].strip()

    if not text or not text[0].isdigit():
        return text

    # Walk past the run of leading digits, then look for "." or ")".
    end = 0
    while end < len(text) and text[end].isdigit():
        end += 1
    if end < len(text) and text[end] in ".)":
        return text[end + 1:].strip()
    return text
|
|
179
|
+
|
|
180
|
+
# Step-back insight initialization
|
|
181
|
+
if step_back_prompt:
|
|
182
|
+
stepback = f"""What are the key visual features or patterns that typically indicate the presence of these categories in images showing "{image_description}"?
|
|
183
|
+
|
|
184
|
+
Categories to consider:
|
|
185
|
+
{categories_str}
|
|
186
|
+
|
|
187
|
+
Provide a brief analysis of what visual cues to look for when categorizing such images."""
|
|
188
|
+
|
|
189
|
+
stepback_insight, step_back_added = get_image_stepback_insight(
|
|
190
|
+
model_source, stepback, api_key, user_model, creativity
|
|
191
|
+
)
|
|
192
|
+
else:
|
|
193
|
+
stepback_insight = None
|
|
194
|
+
step_back_added = False
|
|
195
|
+
|
|
196
|
+
link1 = []
|
|
197
|
+
extracted_jsons = []
|
|
198
|
+
|
|
199
|
+
def _build_base_prompt_text():
    """Assemble the text portion of the image-classification prompt.

    Pieces, in order:
      1. an expert-analyst preamble (only when ``context_prompt`` is set);
      2. the task statement with ``image_description`` and the numbered
         ``categories_str``;
      3. step-by-step reasoning instructions (only when ``chain_of_thought``);
      4. ``examples_text`` followed by the JSON output contract and
         ``example_JSON``.
    All inputs come from the enclosing ``image_multi_class`` scope.
    """
    preamble = ""
    if context_prompt:
        preamble = (
            "You are an expert visual analyst specializing in image categorization. "
            "Apply multi-label classification based on explicit and implicit visual cues. "
            "When uncertain, prioritize precision over recall.\n\n"
        )

    task = (
        "You are an image-tagging assistant.\n"
        "Task ► Examine the attached image and decide, **for each category below**, "
        "whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
        f"Image is expected to show: {image_description}\n\n"
        f"Categories:\n{categories_str}\n\n"
    )

    reasoning = ""
    if chain_of_thought:
        reasoning = (
            "Let's analyze step by step:\n"
            "1. First, identify the key visual elements in the image\n"
            "2. Then, match each element to the relevant categories\n"
            "3. Finally, assign 1 to matching categories and 0 to non-matching categories\n\n"
        )

    output_contract = (
        f"{examples_text}\n\n"
        "Output format ► Respond with **only** a JSON object whose keys are the "
        "quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
        "No additional keys, comments, or text.\n\n"
        "Example (three categories):\n"
        f"{example_JSON}"
    )

    return "".join((preamble, task, reasoning, output_contract))
|
|
243
|
+
|
|
244
|
+
def _build_cove_prompts(base_prompt_text):
    """Build chain of verification prompts for images.

    Args:
        base_prompt_text: the original classification prompt; embedded
            verbatim in the step-2 and step-4 templates.

    Returns:
        ``(step2_prompt, step3_prompt, step4_prompt)``. In order, they ask
        the model to generate 3-5 verification questions, answer a single
        question by re-examining the image, and emit the corrected JSON.
        The ``<<INITIAL_REPLY>>``, ``<<QUESTION>>`` and
        ``<<VERIFICATION_QA>>`` markers are left in place. They are
        presumably filled in by the CoVe helpers in ``calls.image_CoVe``,
        which are not visible here; confirm there.
    """
    # Step 2: ask for verification questions about the initial answer.
    step2_prompt = f"""You provided this initial categorization:
<<INITIAL_REPLY>>

Original task: {base_prompt_text}

Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
- Be concise and specific (one sentence)
- Address a distinct visual element or category assignment
- Be answerable by re-examining the image

Focus on verifying:
- Whether each category assignment matches what's visible in the image
- Whether any visual elements were missed or misinterpreted
- Whether there are any logical inconsistencies

Provide only the verification questions as a numbered list."""

    # Step 3: answer one verification question (image_description comes
    # from the enclosing scope).
    step3_prompt = f"""Re-examine the attached image and answer the following verification question.

Image description: {image_description}

Verification question: <<QUESTION>>

Provide a brief, direct answer (1-2 sentences maximum) based on what you observe in the image.

Answer:"""

    # Step 4: produce the final, corrected categorization from the Q&A.
    step4_prompt = f"""Original task: {base_prompt_text}
Initial categorization:
<<INITIAL_REPLY>>
Verification questions and answers:
<<VERIFICATION_QA>>
Based on this verification, provide the final corrected categorization.
If no categories are present, assign "0" to all categories.
Provide the final categorization in the same JSON format:"""

    return step2_prompt, step3_prompt, step4_prompt
|
|
283
|
+
|
|
284
|
+
def _build_prompt_openai_mistral(encoded, ext, base_text):
    """Return chat content parts for OpenAI/Mistral.

    The first part is the prompt text. The second is the image as a base64
    ``data:`` URL, requested at "high" detail.
    """
    data_url = f"data:image/{ext};base64,{encoded}"
    text_part = {"type": "text", "text": base_text}
    image_part = {
        "type": "image_url",
        "image_url": {"url": data_url, "detail": "high"},
    }
    return [text_part, image_part]
|
|
291
|
+
|
|
292
|
+
def _build_prompt_anthropic(encoded, ext, base_text):
    """Return Anthropic Messages content blocks: prompt text, then a base64 image.

    A missing/empty *ext* falls back to the ``image/jpeg`` media type.
    """
    if ext:
        media_type = f"image/{ext}"
    else:
        media_type = "image/jpeg"

    image_block = {
        "type": "image",
        "source": {"type": "base64", "media_type": media_type, "data": encoded},
    }
    return [{"type": "text", "text": base_text}, image_block]
|
|
306
|
+
|
|
307
|
+
def _build_prompt_google(encoded, ext, base_text):
    """Bundle the pieces ``_call_google`` needs into one dict.

    Keys are ``text_prompt``, ``image_data`` (base64) and ``mime_type``.
    The MIME type defaults to ``image/jpeg`` when *ext* is empty or None.
    """
    mime_type = "image/jpeg" if not ext else f"image/{ext}"
    return dict(text_prompt=base_text, image_data=encoded, mime_type=mime_type)
|
|
314
|
+
|
|
315
|
+
def _call_openai_compatible(prompt, step2_prompt, step3_prompt, step4_prompt, image_content):
    """Handle OpenAI-compatible API calls (OpenAI, Perplexity, HuggingFace, xAI).

    Configuration is read from the enclosing ``image_multi_class`` scope:
    ``model_source``, ``api_key``, ``user_model``, ``creativity``, the
    step-back state and ``chain_of_verification``.

    Args:
        prompt: user-message content (text part + image part).
        step2_prompt, step3_prompt, step4_prompt: CoVe templates, or None
            when chain-of-verification is disabled.
        image_content: the image part alone, re-sent during CoVe.

    Returns:
        ``(reply, None)`` on success. On failure, returns
        ``('{"1":"e"}', error_message)``.

    Raises:
        ValueError: when the endpoint answers 404 (unknown model).
    """
    import requests as req

    # Determine the base URL based on model source
    if model_source == "huggingface":
        from cat_stack.text_functions import _detect_huggingface_endpoint
        base_url = _detect_huggingface_endpoint(api_key, user_model)
    elif model_source == "huggingface-together":
        base_url = "https://router.huggingface.co/together/v1"
    elif model_source == "perplexity":
        base_url = "https://api.perplexity.ai"
    elif model_source == "xai":
        base_url = "https://api.x.ai/v1"
    else:
        base_url = "https://api.openai.com/v1"

    endpoint = f"{base_url}/chat/completions"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    max_retries = 8
    delay = 2
    # Transient statuses retried with exponential backoff (429 = rate limited).
    retryable_statuses = (429, 500, 502, 503, 504)

    for attempt in range(max_retries):
        try:
            # Optional step-back: the step-back question and the model's
            # earlier answer are replayed as a prior conversation turn.
            messages = []
            if step_back_prompt and step_back_added:
                messages.append({'role': 'user', 'content': stepback})
                messages.append({'role': 'assistant', 'content': stepback_insight})
            messages.append({'role': 'user', 'content': prompt})

            payload = {
                "model": user_model,
                "messages": messages,
            }
            if creativity is not None:
                payload["temperature"] = creativity

            response = req.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            reply = result["choices"][0]["message"]["content"]

            if chain_of_verification:
                reply = image_chain_of_verification_openai(
                    initial_reply=reply,
                    step2_prompt=step2_prompt,
                    step3_prompt=step3_prompt,
                    step4_prompt=step4_prompt,
                    client=None,  # Not used anymore, CoVe needs refactoring too
                    user_model=user_model,
                    creativity=creativity,
                    remove_numbering=remove_numbering,
                    image_content=image_content
                )

            return reply, None

        except req.exceptions.HTTPError as e:
            # Compare with None explicitly: requests.Response.__bool__ returns
            # ``self.ok``, so an error response is falsy and ``if e.response``
            # would throw away the status code of every 4xx/5xx.
            error_response = e.response
            status_code = error_response.status_code if error_response is not None else None
            # Provider error codes such as "json_validate_failed" are in the
            # response body, not in str(e).
            body_text = error_response.text if error_response is not None else ""
            error_str = f"{e} {body_text}".lower()

            if status_code == 400 and "json_validate_failed" in error_str and attempt < max_retries - 1:
                wait_time = delay * (2 ** attempt)
                print(f"⚠️ JSON validation failed. Attempt {attempt + 1}/{max_retries}")
                print(f"Retrying in {wait_time}s...")
                time.sleep(wait_time)
            elif status_code == 404:
                raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
            elif status_code in retryable_statuses and attempt < max_retries - 1:
                wait_time = delay * (2 ** attempt)
                print(f"Attempt {attempt + 1} failed with error: {e}")
                print(f"Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                print(f"❌ Failed after {max_retries} attempts: {e}")
                return """{"1":"e"}""", f"Error processing input: {e}"

        except Exception as e:
            if ("500" in str(e) or "504" in str(e)) and attempt < max_retries - 1:
                wait_time = delay * (2 ** attempt)
                print(f"Attempt {attempt + 1} failed with error: {e}")
                print(f"Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                print(f"❌ Failed after {max_retries} attempts: {e}")
                return """{"1":"e"}""", f"Error processing input: {e}"

    return """{"1":"e"}""", "Max retries exceeded"
|
|
409
|
+
|
|
410
|
+
def _call_anthropic(prompt, step2_prompt, step3_prompt, step4_prompt, image_content):
    """Handle Anthropic API calls using direct HTTP requests.

    Configuration is read from the enclosing ``image_multi_class`` scope:
    ``api_key``, ``user_model``, ``creativity``, the step-back state and
    ``chain_of_verification``. Transient failures (429, 5xx, and the 529
    "overloaded" status) are retried with exponential backoff, matching
    the other provider helpers.

    Returns:
        ``(reply, None)`` on success. On failure, returns
        ``('{"1":"e"}', error_message)``.

    Raises:
        ValueError: when the API answers 404 (unknown model).
    """
    import requests as req

    endpoint = "https://api.anthropic.com/v1/messages"
    headers = {
        "Content-Type": "application/json",
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01"
    }

    # Optional step-back: replay the step-back Q&A as a prior turn.
    messages = []
    if step_back_prompt and step_back_added:
        messages.append({'role': 'user', 'content': stepback})
        messages.append({'role': 'assistant', 'content': stepback_insight})
    messages.append({'role': 'user', 'content': prompt})

    payload = {
        "model": user_model,
        "max_tokens": 1024,
        "messages": messages,
    }
    if creativity is not None:
        payload["temperature"] = creativity

    max_retries = 8
    delay = 2
    retryable_statuses = (429, 500, 502, 503, 504, 529)

    for attempt in range(max_retries):
        try:
            response = req.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()

            content = result.get("content", [])
            if content and content[0].get("type") == "text":
                reply = content[0].get("text", "")
            else:
                return """{"1":"e"}""", "No text content in response"

            if chain_of_verification:
                reply = image_chain_of_verification_anthropic(
                    initial_reply=reply,
                    step2_prompt=step2_prompt,
                    step3_prompt=step3_prompt,
                    step4_prompt=step4_prompt,
                    client=None,  # No longer using SDK client
                    user_model=user_model,
                    creativity=creativity,
                    remove_numbering=remove_numbering,
                    image_content=image_content,
                    api_key=api_key  # Pass api_key for HTTP calls
                )

            return reply, None

        except req.exceptions.HTTPError as e:
            # Explicit None check: an error Response is falsy (__bool__ == ok).
            status_code = e.response.status_code if e.response is not None else None
            if status_code == 404:
                raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
            if status_code in retryable_statuses and attempt < max_retries - 1:
                wait_time = delay * (2 ** attempt)
                print(f"⚠️ Server error {status_code}. Attempt {attempt + 1}/{max_retries}")
                print(f"Retrying in {wait_time}s...")
                time.sleep(wait_time)
                continue
            print(f"An error occurred: {e}")
            return """{"1":"e"}""", f"Error processing input: {e}"
        except Exception as e:
            print(f"An error occurred: {e}")
            return """{"1":"e"}""", f"Error processing input: {e}"

    return """{"1":"e"}""", "Max retries exceeded"
|
|
471
|
+
|
|
472
|
+
def _call_google(prompt_data, step2_prompt, step3_prompt, step4_prompt, base_prompt_text):
    """Handle Google API calls (Gemini ``generateContent`` REST endpoint).

    Configuration is read from the enclosing ``image_multi_class`` scope:
    ``api_key``, ``user_model``, ``creativity``, ``thinking_budget``, the
    step-back state and ``chain_of_verification``.

    Args:
        prompt_data: dict with ``text_prompt``, ``image_data`` (base64) and
            ``mime_type``.
        step2_prompt, step3_prompt, step4_prompt: CoVe templates or None.
        base_prompt_text: plain task prompt, forwarded to CoVe.

    Returns:
        ``(reply, None)`` on success. On failure, returns
        ``('{"1":"e"}', error_message)``.

    Raises:
        ValueError: on 404 (unknown model) or 401/403 (bad API key).
    """
    import requests

    def make_google_request(url, headers, payload, max_retries=8, timeout=120):
        """POST *payload*; retry 429/5xx with exponential backoff; return decoded JSON.

        Re-raises the HTTPError when the status is not retryable or the
        last attempt fails.
        """
        for attempt in range(max_retries):
            try:
                # Without a timeout a stalled connection blocks forever; 120 s
                # matches the timeout the other provider helpers use.
                response = requests.post(url, headers=headers, json=payload, timeout=timeout)
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                status_code = e.response.status_code if e.response is not None else None
                retryable_errors = [429, 500, 502, 503, 504]

                if status_code in retryable_errors and attempt < max_retries - 1:
                    # Rate limits back off more aggressively than server errors.
                    wait_time = 10 * (2 ** attempt) if status_code == 429 else 2 * (2 ** attempt)
                    error_type = "Rate limited" if status_code == 429 else f"Server error {status_code}"
                    print(f"⚠️ {error_type}. Attempt {attempt + 1}/{max_retries}")
                    print(f"Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    raise

    url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json"
    }

    # Build parts with optional stepback context (prepended as plain text).
    parts = []
    if step_back_prompt and step_back_added:
        parts.append({"text": f"Context from step-back analysis:\n{stepback_insight}\n\n"})
    parts.append({"text": prompt_data["text_prompt"]})
    parts.append({
        "inline_data": {
            "mime_type": prompt_data["mime_type"],
            "data": prompt_data["image_data"]
        }
    })

    payload = {
        "contents": [{"parts": parts}],
        "generationConfig": {
            "responseMimeType": "application/json",
            **({"temperature": creativity} if creativity is not None else {}),
            **({"thinkingConfig": {"thinkingBudget": thinking_budget}} if thinking_budget else {})
        }
    }

    try:
        result = make_google_request(url, headers, payload)

        if "candidates" in result and result["candidates"]:
            reply = result["candidates"][0]["content"]["parts"][0]["text"]
        else:
            # Report as an error, consistent with the other provider helpers.
            return """{"1":"e"}""", "No response generated"

        if chain_of_verification:
            reply = image_chain_of_verification_google(
                initial_reply=reply,
                prompt=base_prompt_text,
                step2_prompt=step2_prompt,
                step3_prompt=step3_prompt,
                step4_prompt=step4_prompt,
                url=url,
                headers=headers,
                creativity=creativity,
                remove_numbering=remove_numbering,
                make_google_request=make_google_request,
                image_data=prompt_data["image_data"],
                mime_type=prompt_data["mime_type"]
            )

        return reply, None

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else None
        if status_code == 404:
            raise ValueError(f"❌ Model '{user_model}' not found. Please check the model name and try again.") from e
        elif status_code in [401, 403]:
            raise ValueError("❌ Authentication failed. Please check your Google API key.") from e
        else:
            print(f"HTTP error occurred: {e}")
            return """{"1":"e"}""", f"Error processing input: {e}"
    except Exception as e:
        print(f"An error occurred: {e}")
        return """{"1":"e"}""", f"Error processing input: {e}"
|
|
559
|
+
|
|
560
|
+
def _call_mistral(prompt, step2_prompt, step3_prompt, step4_prompt, image_content):
    """Handle Mistral API calls - uses requests directly.

    Configuration is read from the enclosing ``image_multi_class`` scope:
    ``api_key``, ``user_model``, ``creativity``, the step-back state and
    ``chain_of_verification``.

    Returns:
        ``(reply, None)`` on success. On failure, returns
        ``('{"1":"e"}', error_message)``.

    Raises:
        ValueError: on an unknown model (404 / "invalid model") or a 401.
    """
    import requests as req

    endpoint = "https://api.mistral.ai/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    max_retries = 8
    delay = 2
    # Transient statuses retried with exponential backoff (429 = rate limited).
    retryable_statuses = (429, 500, 502, 503, 504)

    for attempt in range(max_retries):
        try:
            # Optional step-back: replay the step-back Q&A as a prior turn.
            messages = []
            if step_back_prompt and step_back_added:
                messages.append({'role': 'user', 'content': stepback})
                messages.append({'role': 'assistant', 'content': stepback_insight})
            messages.append({'role': 'user', 'content': prompt})

            payload = {
                "model": user_model,
                "messages": messages,
            }
            if creativity is not None:
                payload["temperature"] = creativity

            response = req.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            reply = result["choices"][0]["message"]["content"]

            if chain_of_verification:
                reply = image_chain_of_verification_mistral(
                    initial_reply=reply,
                    step2_prompt=step2_prompt,
                    step3_prompt=step3_prompt,
                    step4_prompt=step4_prompt,
                    client=None,  # Not used anymore, CoVe needs refactoring too
                    user_model=user_model,
                    creativity=creativity,
                    remove_numbering=remove_numbering,
                    image_content=image_content
                )

            return reply, None

        except req.exceptions.HTTPError as e:
            # Compare with None explicitly: requests.Response.__bool__ returns
            # ``self.ok``, so ``if e.response`` is False for every 4xx/5xx.
            error_response = e.response
            status_code = error_response.status_code if error_response is not None else None
            # Mistral's error codes (e.g. "invalid_model") are in the body.
            body_text = error_response.text if error_response is not None else ""
            error_str = f"{e} {body_text}".lower()

            if status_code == 404 or "invalid_model" in error_str or "invalid model" in error_str:
                raise ValueError(f"❌ Model '{user_model}' not found.") from e
            elif status_code == 401 or "unauthorized" in error_str:
                raise ValueError("❌ Authentication failed. Please check your Mistral API key.") from e
            elif status_code in retryable_statuses and attempt < max_retries - 1:
                wait_time = delay * (2 ** attempt)
                print(f"⚠️ Server error {status_code}. Attempt {attempt + 1}/{max_retries}")
                print(f"Retrying in {wait_time}s...")
                time.sleep(wait_time)
            else:
                print(f"❌ Failed after {max_retries} attempts: {e}")
                return """{"1":"e"}""", f"Error processing input: {e}"

        except Exception as e:
            print(f"❌ Unexpected error: {e}")
            return """{"1":"e"}""", f"Error processing input: {e}"

    return """{"1":"e"}""", "Max retries exceeded"
|
|
631
|
+
|
|
632
|
+
def _process_single_image(img_path):
    """Process a single image and return (reply, error_msg).

    The function:
      1. encodes *img_path*;
      2. builds the text prompt, plus the CoVe templates when
         ``chain_of_verification`` is on;
      3. wraps text and image in the provider's message format;
      4. dispatches to the matching ``_call_*`` helper based on
         ``model_source`` (from the enclosing scope).

    Returns:
        The ``(reply, error_msg)`` pair from the provider helper. If the
        file cannot be read or encoded, returns
        ``(None, "Invalid image path or encoding failed")``.

    Raises:
        ValueError: for an unrecognised ``model_source``. Provider helpers
            may also raise, e.g. for an unknown model.
    """
    encoded, ext, is_valid = _encode_image(img_path)

    if not is_valid:
        return None, "Invalid image path or encoding failed"

    base_prompt_text = _build_base_prompt_text()

    if chain_of_verification:
        step2_prompt, step3_prompt, step4_prompt = _build_cove_prompts(base_prompt_text)
    else:
        step2_prompt = step3_prompt = step4_prompt = None

    # Every source _call_openai_compatible can route must be listed here,
    # including "huggingface-together". Otherwise it falls through to the
    # "Unknown source" error below even though a base URL exists for it.
    if model_source in ("openai", "perplexity", "huggingface", "huggingface-together", "xai"):
        prompt = _build_prompt_openai_mistral(encoded, ext, base_prompt_text)
        # Image content for CoVe (just the image part)
        encoded_image = f"data:image/{ext};base64,{encoded}"
        image_content = {"type": "image_url", "image_url": {"url": encoded_image, "detail": "high"}}
        return _call_openai_compatible(prompt, step2_prompt, step3_prompt, step4_prompt, image_content)

    elif model_source == "anthropic":
        prompt = _build_prompt_anthropic(encoded, ext, base_prompt_text)
        media_type = f"image/{ext}" if ext else "image/jpeg"
        image_content = {
            "type": "image",
            "source": {"type": "base64", "media_type": media_type, "data": encoded}
        }
        return _call_anthropic(prompt, step2_prompt, step3_prompt, step4_prompt, image_content)

    elif model_source == "google":
        prompt_data = _build_prompt_google(encoded, ext, base_prompt_text)
        return _call_google(prompt_data, step2_prompt, step3_prompt, step4_prompt, base_prompt_text)

    elif model_source == "mistral":
        prompt = _build_prompt_openai_mistral(encoded, ext, base_prompt_text)
        encoded_image = f"data:image/{ext};base64,{encoded}"
        image_content = {"type": "image_url", "image_url": {"url": encoded_image, "detail": "high"}}
        return _call_mistral(prompt, step2_prompt, step3_prompt, step4_prompt, image_content)

    else:
        raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, Google, xAI, Huggingface, or Mistral")
|
|
674
|
+
|
|
675
|
+
def _extract_json(reply):
|
|
676
|
+
"""Extract JSON from model reply."""
|
|
677
|
+
if reply is None:
|
|
678
|
+
return """{"1":"e"}"""
|
|
679
|
+
|
|
680
|
+
if reply == "invalid image path":
|
|
681
|
+
return """{"no_valid_path": 1}"""
|
|
682
|
+
|
|
683
|
+
extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
|
|
684
|
+
if extracted_json:
|
|
685
|
+
return extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
|
|
686
|
+
else:
|
|
687
|
+
print("""{"1":"e"}""")
|
|
688
|
+
return """{"1":"e"}"""
|
|
689
|
+
|
|
690
|
+
# Main processing loop
|
|
691
|
+
for idx, img_path in enumerate(tqdm(image_files, desc="Categorizing images")):
|
|
692
|
+
if img_path is None:
|
|
693
|
+
link1.append("Skipped NaN input")
|
|
694
|
+
extracted_jsons.append("""{"no_valid_image": 1}""")
|
|
695
|
+
continue
|
|
696
|
+
|
|
697
|
+
reply, error_msg = _process_single_image(img_path)
|
|
698
|
+
|
|
699
|
+
if error_msg:
|
|
700
|
+
link1.append(error_msg)
|
|
701
|
+
if "Invalid image" in error_msg:
|
|
702
|
+
extracted_jsons.append("""{"no_valid_path": 1}""")
|
|
703
|
+
else:
|
|
704
|
+
extracted_jsons.append(_extract_json(reply))
|
|
705
|
+
else:
|
|
706
|
+
link1.append(reply)
|
|
707
|
+
extracted_jsons.append(_extract_json(reply))
|
|
708
|
+
|
|
709
|
+
# --- Safety Save ---
|
|
710
|
+
if safety:
|
|
711
|
+
if filename is None:
|
|
712
|
+
raise TypeError("filename is required when using safety. Please provide the filename.")
|
|
713
|
+
|
|
714
|
+
normalized_data_list = []
|
|
715
|
+
for json_str in extracted_jsons:
|
|
716
|
+
try:
|
|
717
|
+
parsed_obj = json.loads(json_str)
|
|
718
|
+
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
719
|
+
except json.JSONDecodeError:
|
|
720
|
+
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
721
|
+
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
722
|
+
|
|
723
|
+
temp_df = pd.DataFrame({
|
|
724
|
+
'image_input': image_files[:idx+1],
|
|
725
|
+
'model_response': link1,
|
|
726
|
+
'json': extracted_jsons
|
|
727
|
+
})
|
|
728
|
+
temp_df = pd.concat([temp_df, normalized_data], axis=1)
|
|
729
|
+
|
|
730
|
+
save_path = os.path.join(save_directory, filename) if save_directory else filename
|
|
731
|
+
temp_df.to_csv(save_path, index=False)
|
|
732
|
+
|
|
733
|
+
# --- Final DataFrame ---
|
|
734
|
+
normalized_data_list = []
|
|
735
|
+
for json_str in extracted_jsons:
|
|
736
|
+
try:
|
|
737
|
+
parsed_obj = json.loads(json_str)
|
|
738
|
+
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
739
|
+
except json.JSONDecodeError:
|
|
740
|
+
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
741
|
+
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
742
|
+
|
|
743
|
+
categorized_data = pd.DataFrame({
|
|
744
|
+
'image_input': pd.Series(image_files),
|
|
745
|
+
'model_response': pd.Series(link1).reset_index(drop=True),
|
|
746
|
+
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
|
747
|
+
})
|
|
748
|
+
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
|
749
|
+
categorized_data = categorized_data.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)
|
|
750
|
+
|
|
751
|
+
# Identify rows with invalid strings (like "e")
|
|
752
|
+
cat_cols = [col for col in categorized_data.columns if col.startswith('category_')]
|
|
753
|
+
has_invalid_strings = categorized_data[cat_cols].apply(
|
|
754
|
+
lambda col: pd.to_numeric(col, errors='coerce').isna() & col.notna()
|
|
755
|
+
).any(axis=1)
|
|
756
|
+
|
|
757
|
+
categorized_data['processing_status'] = (~has_invalid_strings).map({True: 'success', False: 'error'})
|
|
758
|
+
categorized_data.loc[has_invalid_strings, cat_cols] = pd.NA
|
|
759
|
+
|
|
760
|
+
for col in cat_cols:
|
|
761
|
+
categorized_data[col] = pd.to_numeric(categorized_data[col], errors='coerce')
|
|
762
|
+
|
|
763
|
+
categorized_data.loc[~has_invalid_strings, cat_cols] = (
|
|
764
|
+
categorized_data.loc[~has_invalid_strings, cat_cols].fillna(0)
|
|
765
|
+
)
|
|
766
|
+
categorized_data[cat_cols] = categorized_data[cat_cols].astype('Int64')
|
|
767
|
+
|
|
768
|
+
# Create categories_id (comma-separated binary values for each category)
|
|
769
|
+
categorized_data['categories_id'] = categorized_data[cat_cols].apply(
|
|
770
|
+
lambda x: ','.join(x.dropna().astype(int).astype(str)), axis=1
|
|
771
|
+
)
|
|
772
|
+
|
|
773
|
+
if filename:
|
|
774
|
+
save_path = os.path.join(save_directory, filename) if save_directory else filename
|
|
775
|
+
categorized_data.to_csv(save_path, index=False)
|
|
776
|
+
|
|
777
|
+
return categorized_data
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
# image score function
|
|
781
|
+
def image_score_drawing(
    reference_image_description,
    image_input,
    reference_image,
    api_key,
    columns="numbered",
    user_model="gpt-4o-2024-11-20",
    creativity=None,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Rate how closely each drawing matches a reference image with a vision model.

    Every input image is sent together with the reference image and a rubric
    that asks for a 1-5 similarity ``score`` and a ``summary``, returned as JSON.

    Args:
        reference_image_description: Text describing the reference image; it is
            inserted into the prompt.
        image_input: Either a directory path (searched for common image
            extensions) or an iterable of image paths. Entries that are not
            existing paths (``None``, NaN, missing files) are recorded as skipped.
        reference_image: Path to the reference image file.
        api_key: API key for the selected provider.
        columns: Currently unused; kept for backward compatibility.
        user_model: Model name sent to the provider.
        creativity: Sampling temperature; omitted from the request when None.
        to_csv: If True, write the final DataFrame to ``filename``.
        safety: If True, rewrite ``filename`` after every processed image as a
            checkpoint (same columns as the final output).
        filename: CSV file name used by ``to_csv`` and ``safety``.
        save_directory: Directory for CSV output; the current working
            directory is used when None.
        model_source: "OpenAI", "Anthropic" or "Mistral" (case-insensitive).

    Returns:
        pandas.DataFrame with ``image_input``, ``link1`` (raw model reply or
        error text), ``json`` (extracted JSON string) and one column per key
        of the extracted JSON.

    Raises:
        FileNotFoundError: ``save_directory`` is given but does not exist, or
            ``reference_image`` cannot be opened.
        ValueError: ``model_source`` is not supported, or the provider answered
            404 for ``user_model``.
    """
    import os
    import json
    import pandas as pd
    import regex
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    if save_directory is not None and not os.path.isdir(save_directory):
        # Directory doesn't exist - raise an exception to halt execution
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    model_source = model_source.lower()  # eliminating case sensitivity
    providers = {
        "openai": ("https://api.openai.com/v1/chat/completions", "OpenAI"),
        "anthropic": ("https://api.anthropic.com/v1/messages", "Anthropic"),
        "mistral": ("https://api.mistral.ai/v1/chat/completions", "Mistral"),
    }
    # Validate up front so nothing is read or sent for an unsupported provider.
    if model_source not in providers:
        raise ValueError("Unknown source! Choose from OpenAI, Anthropic, or Mistral")

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if isinstance(image_input, (str, os.PathLike)):
        # A directory: collect every file matching a known image extension.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))
        print(f"Found {len(image_files)} images.")
    else:
        # Any iterable of paths (list, tuple, pandas Series, ...).
        image_files = list(image_input)
        print(f"Provided a list of {len(image_files)} images.")

    error_json = """{"1":"e"}"""

    def _media_type(path):
        """Return the MIME type for path's extension ("jpg" or no suffix -> jpeg)."""
        ext = Path(path).suffix.lstrip(".").lower()
        return "image/jpeg" if ext in ("", "jpg") else f"image/{ext}"

    def _read_base64(path):
        """Return (base64 text, None), or (None, reason) if path is not a readable file."""
        if os.path.isdir(path):
            return None, "Not a Valid Image, contains file path"
        try:
            with open(path, "rb") as fh:
                return base64.b64encode(fh.read()).decode("utf-8"), None
        except OSError as exc:
            return None, f"Error: {exc}"

    # The reference image is mandatory; let open() errors propagate to the caller.
    with open(reference_image, "rb") as f:
        reference_b64 = base64.b64encode(f.read()).decode("utf-8")
    reference_media_type = _media_type(reference_image)
    reference_url = f"data:{reference_media_type};base64,{reference_b64}"

    prompt_text = (
        "You are a visual similarity assessment system.\n"
        "Task ► Compare these two images:\n"
        f"1. REFERENCE (left): {reference_image_description}\n"
        "2. INPUT (right): User-provided drawing\n\n"
        "Rating criteria:\n"
        "1: No meaningful similarity (fundamentally different)\n"
        "2: Barely recognizable similarity (25% match)\n"
        "3: Partial match (50% key features)\n"
        "4: Strong alignment (75% features)\n"
        "5: Near-perfect match (90%+ similarity)\n\n"
        "Output format ► Return ONLY:\n"
        "{\n"
        ' "score": [1-5],\n'
        ' "summary": "reason you scored"\n'
        "}\n\n"
        "Critical rules:\n"
        "- Score must reflect shape, proportions, and key details\n"
        "- List only concrete matching elements from reference\n"
        "- No markdown or additional text"
    )

    def _build_prompt(encoded, media_type):
        """Build message content: rubric text, then the reference image, then the drawing."""
        if model_source == "anthropic":
            return [
                {"type": "text", "text": prompt_text},
                {"type": "image", "source": {
                    "type": "base64", "media_type": reference_media_type, "data": reference_b64}},
                {"type": "image", "source": {
                    "type": "base64", "media_type": media_type, "data": encoded}},
            ]
        return [
            {"type": "text", "text": prompt_text},
            {"type": "image_url", "image_url": {"url": reference_url, "detail": "high"}},
            {"type": "image_url",
             "image_url": {"url": f"data:{media_type};base64,{encoded}", "detail": "high"}},
        ]

    def _call_model(prompt):
        """Send one request; return (reply text or None, text for the link1 column)."""
        import requests as req

        endpoint, provider_name = providers[model_source]
        if model_source == "anthropic":
            headers = {
                "Content-Type": "application/json",
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01"
            }
            payload = {
                "model": user_model,
                "max_tokens": 1024,
                "messages": [{"role": "user", "content": prompt}],
            }
        else:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }
            payload = {
                "model": user_model,
                "messages": [{"role": "user", "content": prompt}],
            }
        if creativity is not None:
            payload["temperature"] = creativity
        try:
            response = req.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            if model_source == "anthropic":
                content = result.get("content", [])
                if content and content[0].get("type") == "text":
                    reply = content[0].get("text", "")
                    return reply, reply
                return None, "Error processing input: No text content in response"
            reply = result["choices"][0]["message"]["content"]
            return reply, reply
        except req.exceptions.HTTPError as e:
            # requests.Response is falsy for 4xx/5xx, so compare with None explicitly.
            if e.response is not None and e.response.status_code == 404:
                raise ValueError(f"Invalid {provider_name} model '{user_model}': {e}") from e
            print(f"An error occurred: {e}")
            return None, f"Error processing input: {e}"
        except Exception as e:
            print(f"An error occurred: {e}")
            return None, f"Error processing input: {e}"

    def _extract_json(reply):
        """Return the first JSON object in reply as a compact string, else '{"1":"e"}'."""
        if not isinstance(reply, str):
            print(error_json)
            return error_json
        found = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
        if not found:
            print(error_json)
            return error_json
        candidate = found[0]
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            # Re-serialising keeps spaces inside the summary text and unwraps
            # one-element lists such as "score": [4] (the prompt shows [1-5]).
            parsed = {
                key: value[0] if isinstance(value, list) and len(value) == 1 else value
                for key, value in parsed.items()
            }
            return json.dumps(parsed, separators=(",", ":"), ensure_ascii=False)
        # Not strict JSON: legacy clean-up (drop brackets and whitespace).
        return candidate.replace('[', '').replace(']', '').replace('\n', '').replace(" ", '')

    link1 = []
    extracted_jsons = []

    def _to_frame(json_strings):
        """Flatten each JSON string into one row; unparseable strings become {"1": "e"}."""
        rows = []
        for json_str in json_strings:
            try:
                rows.append(pd.json_normalize(json.loads(json_str)))
            except json.JSONDecodeError:
                rows.append(pd.DataFrame({"1": ["e"]}))
        # pd.concat([]) raises, so an empty input yields an empty frame.
        return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()

    def _assemble(inputs):
        """Combine inputs, replies and JSON strings with the flattened JSON columns."""
        base = pd.DataFrame({
            'image_input': pd.Series(list(inputs), dtype=object),
            'link1': pd.Series(link1, dtype=object),
            'json': pd.Series(extracted_jsons, dtype=object)
        })
        return pd.concat([base, _to_frame(extracted_jsons)], axis=1)

    def _save(frame):
        """Write frame to filename inside save_directory (default: cwd)."""
        directory = save_directory if save_directory is not None else os.getcwd()
        frame.to_csv(os.path.join(directory, filename), index=False)

    for i, img_path in enumerate(tqdm(image_files, desc="Categorising images")):
        # None/NaN entries and missing files are skipped without an API call.
        if not isinstance(img_path, (str, os.PathLike)) or not os.path.exists(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        encoded, read_error = _read_base64(img_path)
        if read_error is not None:
            # Directories and unreadable files are recorded, never sent to the API.
            print("Skipped NaN input or invalid path")
            link1.append(f"Error processing input: {read_error}")
            extracted_jsons.append("""{"no_valid_path": 1}""")
        else:
            # reply is fresh for every image, so a failed call never reuses the
            # previous image's answer.
            reply, response_text = _call_model(_build_prompt(encoded, _media_type(img_path)))
            link1.append(response_text)
            extracted_jsons.append(_extract_json(reply))

        # --- Safety Save ---
        if safety:
            _save(_assemble(image_files[:i + 1]))

    # --- Final DataFrame ---
    categorized_data = _assemble(image_files)

    if to_csv:
        _save(categorized_data)

    return categorized_data
|
|
1115
|
+
|
|
1116
|
+
# image features function
|
|
1117
|
+
def image_features(
    image_description,
    image_input,
    features_to_extract,
    api_key,
    user_model="gpt-4o-2024-11-20",
    creativity=None,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Answer a fixed list of questions about each image with a vision model.

    Each image is sent with a prompt listing ``features_to_extract`` as numbered
    questions. The model is asked to return a JSON object keyed by question
    number ("1", "2", ...).

    Args:
        image_description: Context about the images, inserted into the prompt.
        image_input: Either a directory path (searched for common image
            extensions) or an iterable of image paths. Entries that are not
            existing paths (``None``, NaN, missing files) are recorded as skipped.
        features_to_extract: Sequence of questions, numbered from 1 in the prompt.
        api_key: API key for the selected provider.
        user_model: Model name sent to the provider.
        creativity: Sampling temperature; omitted from the request when None.
        to_csv: If True, write the final DataFrame to ``filename``.
        safety: If True, rewrite ``filename`` after every processed image as a
            checkpoint (same columns as the final output).
        filename: CSV file name used by ``to_csv`` and ``safety``.
        save_directory: Directory for CSV output; the current working
            directory is used when None.
        model_source: "OpenAI", "Anthropic", "Perplexity" or "Mistral"
            (case-insensitive).

    Returns:
        pandas.DataFrame with ``image_input``, ``model_response`` (raw reply or
        error text), ``json`` (extracted JSON string) and one column per key
        of the extracted JSON.

    Raises:
        FileNotFoundError: a CSV is requested and ``save_directory`` does not
            exist. This is checked before any API call is made.
        ValueError: ``model_source`` is not supported, or the provider answered
            404 for ``user_model``.
    """
    import os
    import json
    import pandas as pd
    import regex
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    model_source = model_source.lower()  # eliminating case sensitivity
    providers = {
        "openai": ("https://api.openai.com/v1/chat/completions", "OpenAI"),
        "perplexity": ("https://api.perplexity.ai/chat/completions", "Perplexity"),
        "anthropic": ("https://api.anthropic.com/v1/messages", "Anthropic"),
        "mistral": ("https://api.mistral.ai/v1/chat/completions", "Mistral"),
    }
    # Validate up front so nothing is read or sent for an unsupported provider.
    if model_source not in providers:
        raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")

    if (to_csv or safety) and save_directory is not None and not os.path.isdir(save_directory):
        # Fail before spending API calls on results that could not be saved.
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if isinstance(image_input, (str, os.PathLike)):
        # A directory: collect every file matching a known image extension.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))
        print(f"Found {len(image_files)} images.")
    else:
        # Any iterable of paths (list, tuple, pandas Series, ...).
        image_files = list(image_input)
        print(f"Provided a list of {len(image_files)} images.")

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(features_to_extract))
    prompt_text = (
        "You are a visual question answering assistant.\n"
        "Task ► Analyze the attached image and answer these specific questions:\n\n"
        f"Image context: {image_description}\n\n"
        f"Questions to answer:\n{categories_str}\n\n"
        "Output format ► Return **only** a JSON object where:\n"
        "- Keys are question numbers ('1', '2', ...)\n"
        "- Values are concise answers (numbers, short phrases)\n\n"
        "Example for 3 questions:\n"
        "{\n"
        ' "1": "4",\n'
        ' "2": "blue",\n'
        ' "3": "yes"\n'
        "}\n\n"
        "Important rules:\n"
        "1. Answer directly - no explanations\n"
        "2. Use exact numerical values when possible\n"
        "3. For yes/no questions, use 'yes' or 'no'\n"
        "4. Never add extra keys or formatting"
    )

    error_json = """{"1":"e"}"""

    def _media_type(path):
        """Return the MIME type for path's extension ("jpg" or no suffix -> jpeg)."""
        ext = Path(path).suffix.lstrip(".").lower()
        return "image/jpeg" if ext in ("", "jpg") else f"image/{ext}"

    def _read_base64(path):
        """Return base64 text of the file, or None if path is a directory or unreadable."""
        if os.path.isdir(path):
            return None
        try:
            with open(path, "rb") as fh:
                return base64.b64encode(fh.read()).decode("utf-8")
        except OSError:
            return None

    def _build_prompt(encoded, media_type):
        """Build message content (question text + image) in the provider's format."""
        if model_source == "anthropic":
            return [
                {"type": "text", "text": prompt_text},
                {"type": "image", "source": {
                    "type": "base64", "media_type": media_type, "data": encoded}},
            ]
        # OpenAI, Mistral and Perplexity share the OpenAI chat format.
        return [
            {"type": "text", "text": prompt_text},
            {"type": "image_url",
             "image_url": {"url": f"data:{media_type};base64,{encoded}", "detail": "high"}},
        ]

    def _call_model(prompt):
        """Send one request; return (reply text or None, text for model_response)."""
        import requests as req

        endpoint, provider_name = providers[model_source]
        if model_source == "anthropic":
            headers = {
                "Content-Type": "application/json",
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01"
            }
            payload = {
                "model": user_model,
                "max_tokens": 1024,
                "messages": [{"role": "user", "content": prompt}],
            }
        else:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}"
            }
            payload = {
                "model": user_model,
                "messages": [{"role": "user", "content": prompt}],
            }
        if creativity is not None:
            payload["temperature"] = creativity
        try:
            response = req.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()
            if model_source == "anthropic":
                content = result.get("content", [])
                if content and content[0].get("type") == "text":
                    reply = content[0].get("text", "")
                    return reply, reply
                return None, "Error processing input: No text content in response"
            reply = result["choices"][0]["message"]["content"]
            return reply, reply
        except req.exceptions.HTTPError as e:
            # requests.Response is falsy for 4xx/5xx, so compare with None explicitly.
            if e.response is not None and e.response.status_code == 404:
                raise ValueError(f"Invalid {provider_name} model '{user_model}': {e}") from e
            print(f"An error occurred: {e}")
            return None, f"Error processing input: {e}"
        except Exception as e:
            print(f"An error occurred: {e}")
            return None, f"Error processing input: {e}"

    def _extract_json(reply):
        """Return the first JSON object in reply as a compact string, else '{"1":"e"}'."""
        if not isinstance(reply, str):
            return error_json
        found = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
        if not found:
            print(error_json)
            return error_json
        candidate = found[0]
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            # Re-serialising keeps spaces inside answers ("dark blue") and
            # unwraps one-element lists such as "1": [4].
            parsed = {
                key: value[0] if isinstance(value, list) and len(value) == 1 else value
                for key, value in parsed.items()
            }
            return json.dumps(parsed, separators=(",", ":"), ensure_ascii=False)
        # Not strict JSON: legacy clean-up (drop brackets and whitespace).
        return candidate.replace('[', '').replace(']', '').replace('\n', '').replace(" ", '')

    link1 = []
    extracted_jsons = []

    def _to_frame(json_strings):
        """Flatten each JSON string into one row; unparseable strings become {"1": "e"}."""
        rows = []
        for json_str in json_strings:
            try:
                rows.append(pd.json_normalize(json.loads(json_str)))
            except json.JSONDecodeError:
                rows.append(pd.DataFrame({"1": ["e"]}))
        # pd.concat([]) raises, so an empty input yields an empty frame.
        return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()

    def _assemble(inputs):
        """Combine inputs, replies and JSON strings with the flattened JSON columns."""
        base = pd.DataFrame({
            'image_input': pd.Series(list(inputs), dtype=object),
            'model_response': pd.Series(link1, dtype=object),
            'json': pd.Series(extracted_jsons, dtype=object)
        })
        return pd.concat([base, _to_frame(extracted_jsons)], axis=1)

    def _save(frame):
        """Write frame to filename inside save_directory (default: cwd)."""
        directory = save_directory if save_directory is not None else os.getcwd()
        frame.to_csv(os.path.join(directory, filename), index=False)

    for i, img_path in enumerate(tqdm(image_files, desc="Scoring images")):
        # None/NaN entries and missing files are skipped without an API call.
        if not isinstance(img_path, (str, os.PathLike)) or not os.path.exists(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        encoded = _read_base64(img_path)
        if encoded is None:
            # Directories and unreadable files are recorded, never sent to the API.
            print("Skipped NaN input or invalid path")
            link1.append("Error processing input: invalid image")
            extracted_jsons.append(error_json)
        else:
            # reply is fresh for every image, so a failed call never reuses the
            # previous image's answer.
            reply, response_text = _call_model(_build_prompt(encoded, _media_type(img_path)))
            link1.append(response_text)
            extracted_jsons.append(_extract_json(reply))

        # --- Safety Save ---
        if safety:
            _save(_assemble(image_files[:i + 1]))

    # --- Final DataFrame ---
    categorized_data = _assemble(image_files)

    if to_csv:
        _save(categorized_data)

    return categorized_data
|
|
1453
|
+
|
|
1454
|
+
|
|
1455
|
+
def explore_image_categories(
    image_input,
    api_key,
    image_description="",
    max_categories=12,
    categories_per_chunk=10,
    divisions=5,
    user_model="gpt-4o",
    creativity=None,
    specificity="broad",
    research_question=None,
    mode="image",
    filename=None,
    model_source="auto",
    iterations=3,
    random_state=None,
    progress_callback=None,
):
    """
    Explore and extract common categories from a collection of images.

    Modes:
    - "image" (default): Samples random images and sends them directly to
      a vision model for category extraction. Best for visual categorization.

    - "both": Samples random images, uses vision model to describe each
      image's content (including any text), then extracts categories from
      those descriptions. Best for images that contain text or mixed content.

    Each pass makes one model call per chunk (at most ``divisions`` calls per
    pass). Every call looks at a single image drawn at random from the whole
    collection, so an image may be sampled more than once. All extracted labels
    are then de-duplicated, counted, and merged semantically by one final
    model call.

    Args:
        image_input: Path to image file, directory of images, or list of image paths
        api_key: API key for the model provider
        image_description: Description of what the images contain
        max_categories: Maximum number of final categories to return
        categories_per_chunk: Categories to extract per chunk of images
        divisions: Number of chunks to divide images into (clamped to the
            range 1 .. max(1, n_images // 2))
        user_model: Model to use (must support vision)
        creativity: Temperature setting (None for default)
        specificity: "broad" or "specific" category granularity
        research_question: Optional research context
        mode: "image" or "both"
        filename: Optional CSV filename to save the category counts
        model_source: "auto", "openai", "anthropic", "google", "mistral",
            "huggingface", "huggingface-together", "xai"
        iterations: Number of passes over the data
        random_state: Random seed for reproducibility
        progress_callback: Optional callback function for progress updates.
            Called as progress_callback(current_step, total_steps, step_label).

    Returns:
        dict with keys:
        - counts_df: DataFrame with columns ``Category`` and ``counts``,
          sorted by count (descending)
        - top_categories: List of top category names
        - raw_top_text: Raw model output from final merge step ("" on failure)

    Raises:
        ValueError: if ``mode`` or ``model_source`` is unsupported, no images
            are found, or no categories could be extracted.
    """
    import re
    import time

    import numpy as np
    import pandas as pd
    import requests as req
    from tqdm import tqdm

    # Providers that speak the OpenAI chat-completions wire format.
    openai_compatible = ("openai", "huggingface", "huggingface-together", "xai", "mistral")
    supported_sources = ["openai", "huggingface", "huggingface-together", "xai", "anthropic", "google", "mistral"]

    # Fail fast on bad arguments, before loading images or making any requests.
    if mode not in ("image", "both"):
        raise ValueError(f"Invalid mode: {mode}. Must be 'image' or 'both'.")

    model_source = _detect_model_source(user_model, model_source)
    if model_source not in supported_sources:
        raise ValueError(f"Unsupported model_source: {model_source}")

    # Load all images
    image_files = _load_image_files(image_input)
    if not image_files:
        raise ValueError("No image files found in the specified input.")
    n = len(image_files)

    # Auto-adjust divisions for small datasets: aim for at least 2 images per
    # chunk, and never go below one chunk (divisions <= 0 would divide by zero).
    original_divisions = divisions
    divisions = max(1, min(divisions, max(1, n // 2)))
    if divisions != original_divisions:
        print(f"Auto-adjusted divisions from {original_divisions} to {divisions} for {n} images.")

    # Images often contain multiple categories each, so categories_per_chunk is
    # only reduced when chunks are tiny.
    chunk_size = int(round(max(1, n / divisions), 0))
    if chunk_size < 2:
        old_categories_per_chunk = categories_per_chunk
        categories_per_chunk = max(5, chunk_size * 4)
        print(f"Auto-adjusted categories_per_chunk from {old_categories_per_chunk} to {categories_per_chunk} for chunk size {chunk_size}.")

    print(
        f"Exploring categories in images: '{image_description}'.\n"
        f" {n} total images, {categories_per_chunk * divisions} categories to extract, "
        f"{max_categories} final categories. Mode: {mode}\n"
    )

    # RNG for reproducible sampling
    rng = np.random.default_rng(random_state)

    # Base URL for OpenAI-compatible providers (None for anthropic/google).
    if model_source == "huggingface":
        from cat_stack.text_functions import _detect_huggingface_endpoint
        openai_base_url = _detect_huggingface_endpoint(api_key, user_model)
    elif model_source == "huggingface-together":
        openai_base_url = "https://router.huggingface.co/together/v1"
    elif model_source == "xai":
        openai_base_url = "https://api.x.ai/v1"
    elif model_source == "openai":
        openai_base_url = "https://api.openai.com/v1"
    elif model_source == "mistral":
        openai_base_url = "https://api.mistral.ai/v1"
    else:
        openai_base_url = None

    def make_image_prompt() -> str:
        """Build prompt for image mode - direct category extraction."""
        return (
            f"Identify {categories_per_chunk} {specificity} categories of content found in this image. "
            f"The image is: {image_description}. "
            f"{'Research context: ' + research_question if research_question else ''}\n\n"
            f"Number your categories from 1 through {categories_per_chunk} and provide concise labels only (no descriptions)."
        )

    def make_describe_prompt() -> str:
        """Build prompt for 'both' mode - describe image content."""
        return (
            f"Describe the content of this image in detail. "
            f"Include all visual elements, text, objects, people, and any other content. "
            f"The image is: {image_description}. "
            f"{'Research context: ' + research_question if research_question else ''}\n\n"
            f"Provide a comprehensive text description that captures both visual and textual content."
        )

    def make_text_prompt(text_blob: str) -> str:
        """Build prompt for extracting categories from text descriptions."""
        return (
            f"Identify {categories_per_chunk} {specificity} categories of content found in this description. "
            f"The content is: {image_description}. "
            f"{'Research context: ' + research_question + '. ' if research_question else ''}"
            f"The description is contained within triple backticks: ```{text_blob}``` "
            f"Number your categories from 1 through {categories_per_chunk} and provide concise labels only (no descriptions)."
        )

    def _request_text(prompt_text, encoded=None, ext=None, max_tokens=2048):
        """Send one prompt (optionally with a base64 image) and return the reply text.

        Raises on HTTP errors or malformed responses. Returns None when the
        provider answered but produced no text block / candidate.
        ``max_tokens`` is only sent to Anthropic, which requires it.
        """
        if model_source in openai_compatible:
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            }
            if encoded is None:
                content = prompt_text
            else:
                content = [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{encoded}"}},
                ]
            payload = {"model": user_model, "messages": [{"role": "user", "content": content}]}
            if creativity is not None:
                payload["temperature"] = creativity
            response = req.post(f"{openai_base_url}/chat/completions", headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]

        if model_source == "anthropic":
            headers = {
                "Content-Type": "application/json",
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
            }
            if encoded is None:
                content = prompt_text
            else:
                media_type = f"image/{ext}" if ext else "image/jpeg"
                content = [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": encoded}},
                ]
            payload = {
                "model": user_model,
                "max_tokens": max_tokens,
                "messages": [{"role": "user", "content": content}],
            }
            if creativity is not None:
                payload["temperature"] = creativity
            response = req.post("https://api.anthropic.com/v1/messages", headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            blocks = response.json().get("content", [])
            if blocks and blocks[0].get("type") == "text":
                return blocks[0].get("text", "")
            return None

        # Google Gemini: the only remaining supported source.
        parts = [{"text": prompt_text}]
        if encoded is not None:
            mime_type = f"image/{ext}" if ext else "image/jpeg"
            parts.append({"inline_data": {"mime_type": mime_type, "data": encoded}})
        payload = {
            "contents": [{"parts": parts}],
            "generationConfig": {"temperature": creativity} if creativity is not None else {},
        }
        response = req.post(
            f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent",
            headers={"x-goog-api-key": api_key, "Content-Type": "application/json"},
            json=payload,
            timeout=120,
        )
        response.raise_for_status()
        data = response.json()
        if data.get("candidates"):
            return data["candidates"][0]["content"]["parts"][0]["text"]
        return None

    def _with_retries(send, img_path, verb, max_retries):
        """Call ``send()`` up to ``max_retries`` times with exponential backoff.

        Returns ``send()``'s result, or None once every attempt has raised.
        """
        for attempt in range(max_retries):
            try:
                return send()
            except Exception as e:
                delay = 2 ** attempt
                if attempt < max_retries - 1:
                    print(f"Error {verb} image {img_path}: {e}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                else:
                    print(f"Error {verb} image {img_path}: {e}. All {max_retries} attempts failed.")
        return None

    def call_model_with_image(img_path, prompt_text, max_retries=6):
        """Send an image to the model and get category extraction."""
        encoded, ext, is_valid = _encode_image(img_path)
        if not is_valid:
            return None
        return _with_retries(
            lambda: _request_text(prompt_text, encoded, ext, max_tokens=2048),
            img_path, "processing", max_retries,
        )

    def describe_image_with_vision(img_path, max_retries=6):
        """Use vision model to describe an image's content as text."""
        encoded, ext, is_valid = _encode_image(img_path)
        if not is_valid:
            return None
        prompt_text = make_describe_prompt()
        return _with_retries(
            lambda: _request_text(prompt_text, encoded, ext, max_tokens=4096),
            img_path, "describing", max_retries,
        )

    def call_model_with_text(prompt_text):
        """Send text to the model for category extraction (single attempt)."""
        try:
            return _request_text(prompt_text)
        except Exception as e:
            print(f"Error in text mode: {e}")
            return None

    # Parse numbered list pattern
    line_pat = re.compile(r"^\s*\d+\s*[\.\)\-]\s*(.+)$")

    def _numbered_items(text):
        """Return the labels of lines shaped like '1. foo', '2) bar' or '3 - baz'."""
        matches = (line_pat.match(line.strip()) for line in text.splitlines())
        return [m.group(1).strip() for m in matches if m]

    def _categories_for_image(img_path):
        """Run the configured mode on one image and return its category labels."""
        if mode == "image":
            # IMAGE MODE: send the image directly for category extraction.
            reply = call_model_with_image(img_path, make_image_prompt())
        else:
            # BOTH MODE: describe the image first, then extract categories
            # from that description.
            description = describe_image_with_vision(img_path)
            if not description:
                return []
            reply = call_model_with_text(make_text_prompt(description))
        if not reply:
            return []
        # Fall back to every non-empty line if the model ignored numbering.
        return _numbered_items(reply) or [s for s in (line.strip() for line in reply.splitlines()) if s]

    all_items = []

    # Progress tracking: one step per chunk per pass, plus 1 for the final merge.
    total_steps = (iterations * divisions) + 1
    current_step = 0

    for pass_idx in range(iterations):
        image_indices = list(range(n))
        rng.shuffle(image_indices)

        # Chunks set how many model calls this pass makes (at most `divisions`).
        chunks = [image_indices[i:i + chunk_size] for i in range(0, len(image_indices), chunk_size)][:divisions]

        for chunk_idx, chunk in enumerate(tqdm(chunks, desc=f"Processing chunks (pass {pass_idx+1}/{iterations})")):
            if chunk:
                # Sample one random image from the full pool
                random_idx = rng.choice(image_indices)
                all_items.extend(_categories_for_image(image_files[random_idx]))

            # Advance progress even when a chunk produced nothing, so the
            # callback's step count stays consistent with total_steps.
            current_step += 1
            if progress_callback:
                progress_callback(current_step, total_steps, f"Pass {pass_idx+1}/{iterations}, chunk {chunk_idx+1}/{len(chunks)}")

    # Normalize and count: "B / a" and "a/b" collapse to the same key.
    def normalize_category(cat):
        terms = sorted([t.strip().lower() for t in str(cat).split("/")])
        return "/".join(terms)

    flat_list = [str(x).strip() for x in all_items if str(x).strip()]
    if not flat_list:
        raise ValueError("No categories were extracted from the images.")

    df = pd.DataFrame(flat_list, columns=["Category"])
    df["normalized"] = df["Category"].map(normalize_category)

    # Kept distinct from any API-response variable: this DataFrame is returned
    # as counts_df and written to `filename`.
    counts_df = (
        df.groupby("normalized")
        .agg(Category=("Category", lambda x: x.value_counts().index[0]),
             counts=("Category", "size"))
        .sort_values("counts", ascending=False)
        .reset_index(drop=True)
    )

    # Second-pass semantic merge
    seed_list = counts_df["Category"].head(max_categories * 3).tolist()

    second_prompt = f"""
You are a data analyst reviewing categorized image data.

Task: From the provided categories, identify and return the top {max_categories} CONCEPTUALLY UNIQUE categories.

Critical Instructions:
1) Exact duplicates are already removed.
2) Merge SEMANTIC duplicates (same concept, different wording).
3) When merging:
   - Combine frequencies mentally
   - Keep the most frequent OR clearest label
   - Each concept appears ONLY ONCE
4) Keep category names {specificity}.
5) Return ONLY a numbered list of {max_categories} categories. No extra text.

Pre-processed Categories (sorted by frequency, top sample):
{seed_list}

Output:
1. category
2. category
...
{max_categories}. category
""".strip()

    try:
        # A null/missing reply is normalized to "" so parsing below cannot crash.
        top_categories_text = _request_text(second_prompt, max_tokens=2048) or ""
    except Exception as e:
        print(f"Error in second-pass merge: {e}")
        top_categories_text = ""

    # Final progress callback for the merge step
    if progress_callback:
        progress_callback(total_steps, total_steps, "Merging categories")

    # Parse final list, falling back to bullet/plain lines when not numbered.
    final = _numbered_items(top_categories_text) or [
        line.strip("-*• ").strip() for line in top_categories_text.splitlines() if line.strip()
    ]

    print("\nTop categories:\n" + "\n".join(f"{i+1}. {c}" for i, c in enumerate(final[:max_categories])))

    if filename:
        counts_df.to_csv(filename, index=False)

    return {
        "counts_df": counts_df,
        "top_categories": final[:max_categories],
        "raw_top_text": top_categories_text
    }
|