cat-stack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2078 @@
1
+ import warnings
2
+
3
+ from .text_functions import _detect_model_source
4
+ from .calls.image_stepback import get_image_stepback_insight
5
+
6
+ # Exported names (excludes deprecated image_multi_class)
7
+ __all__ = [
8
+ "_load_image_files",
9
+ "_encode_image",
10
+ "image_score_drawing",
11
+ "image_features",
12
+ "explore_image_categories",
13
+ ]
14
+ from .calls.image_CoVe import (
15
+ image_chain_of_verification_openai,
16
+ image_chain_of_verification_anthropic,
17
+ image_chain_of_verification_google,
18
+ image_chain_of_verification_mistral
19
+ )
20
+
21
+
22
def _load_image_files(image_input):
    """Resolve ``image_input`` into a list of image paths.

    Args:
        image_input: One of
            * a list or tuple of paths -- returned unchanged (tuples are
              converted to a list); entries are not validated here,
            * a path to a single existing file -- returned as a one-item list,
            * a path to a directory -- its entries whose names end with a
              known image extension are returned.

    Returns:
        list: The image paths. For a directory, the paths are sorted so the
        processing order is reproducible between runs and platforms.

    Raises:
        TypeError: If ``image_input`` is neither a list/tuple nor path-like.
        FileNotFoundError: If the path is neither an existing file nor an
            existing directory.
    """
    import os

    image_suffixes = (
        '.png', '.jpg', '.jpeg',
        '.gif', '.webp', '.svg', '.svgz', '.avif', '.apng',
        '.tif', '.tiff', '.bmp',
        '.heif', '.heic', '.ico',
        '.psd',
    )

    if isinstance(image_input, (list, tuple)):
        print(f"Provided a list of {len(image_input)} images.")
        return image_input if isinstance(image_input, list) else list(image_input)

    if not isinstance(image_input, (str, bytes, os.PathLike)):
        # os.path.isfile() would raise an opaque TypeError from os.stat (or
        # treat an int as a file descriptor); fail with a clear message.
        raise TypeError(
            "image_input must be a list of paths, a file path, or a directory path, "
            f"not {type(image_input).__name__}"
        )

    if os.path.isfile(image_input):
        print("Provided 1 image file.")
        return [image_input]

    if os.path.isdir(image_input):
        directory = os.fsdecode(image_input)
        # List the directory instead of globbing: glob patterns built from the
        # directory name break when it contains '[', ']', '*' or '?', and glob
        # matching is case-sensitive on POSIX (missing e.g. 'photo.JPG').
        # Dot-files are skipped, as a '*.ext' glob pattern would skip them.
        image_files = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if not name.startswith('.') and name.lower().endswith(image_suffixes)
        )
        print(f"Found {len(image_files)} images in directory.")
        return image_files

    raise FileNotFoundError(f"Image input not found: {image_input}")
52
+
53
+
54
def _encode_image(img_path):
    """Read an image file and base64-encode its bytes.

    Args:
        img_path: Path to the image file (``str``, ``bytes`` or path-like).

    Returns:
        tuple: ``(encoded_data, extension, is_valid)``.
        On success ``encoded_data`` is the base64 text, ``extension`` is the
        lower-cased file suffix without the dot (``"jpg"`` is reported as
        ``"jpeg"``) and ``is_valid`` is ``True``.
        ``(None, None, False)`` is returned when the input is not a path
        (e.g. ``None`` or a float NaN), does not exist, is a directory, or
        cannot be read.
    """
    import os
    import base64
    from pathlib import Path

    # Reject non-path values up front: os.path.exists() raises TypeError on a
    # float NaN and interprets an int as a file descriptor.
    if not isinstance(img_path, (str, bytes, os.PathLike)):
        return None, None, False

    if not os.path.exists(img_path) or os.path.isdir(img_path):
        return None, None, False

    try:
        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
    except OSError as e:
        print(f"Error encoding image: {e}")
        return None, None, False

    ext = Path(os.fsdecode(img_path)).suffix.lstrip(".").lower()
    if ext == "jpg":
        ext = "jpeg"
    return encoded, ext, True
76
+
77
+
78
+ # image multi-class (binary) function
79
def image_multi_class(
    image_description,
    image_input,
    categories,
    api_key,
    user_model="gpt-4o",
    creativity=None,
    safety=False,
    chain_of_verification=False,
    chain_of_thought=True,
    step_back_prompt=False,
    context_prompt=False,
    thinking_budget=0,
    example1=None,
    example2=None,
    example3=None,
    example4=None,
    example5=None,
    example6=None,
    filename=None,
    save_directory=None,
    model_source="auto"
):
    """
    Classify images using LLMs.

    .. deprecated::
        Use :func:`cat_stack.classify` instead. This function will be removed in a future version.

    Each image is sent to the selected provider together with a prompt that
    asks for a JSON object mapping category numbers ("1", "2", ...) to 1
    (present) or 0 (not present).

    Args:
        image_description: Text inserted into the prompts as what the images
            are expected to show. Required when ``categories="auto"``.
        image_input: Passed to ``_load_image_files`` (list of paths, a single
            file path, or a directory path).
        categories: Sequence of category labels, or the string ``"auto"`` to
            take the ``"top_categories"`` returned by
            ``explore_image_categories``.
        api_key: API key sent to the provider.
        user_model: Model name sent to the provider.
        creativity: Sent as ``temperature`` when not ``None``.
        safety: When true, the partial results are written to ``filename``
            after every image. Requires ``filename``.
        chain_of_verification: When true, the first reply is passed through
            the provider-specific ``image_chain_of_verification_*`` helper.
        chain_of_thought: When true, step-by-step instructions are added to
            the prompt.
        step_back_prompt: When true, a step-back insight is requested once
            via ``get_image_stepback_insight`` and supplied as context with
            every image request.
        context_prompt: When true, an expert-analyst preamble is prepended to
            the prompt.
        thinking_budget: Google only; sent as ``thinkingConfig.thinkingBudget``
            when truthy.
        example1..example6: Optional example strings appended to the prompt.
        filename: CSV file name for the results (optional unless ``safety``).
        save_directory: Existing directory in which ``filename`` is written.
        model_source: Provider name, resolved through ``_detect_model_source``.

    Returns:
        pandas.DataFrame: One row per image with the columns ``image_input``,
        ``model_response``, ``json``, one ``category_<n>`` column per numeric
        key found in the replies, ``processing_status`` (``"success"`` or
        ``"error"``) and ``categories_id`` (comma-separated category values).

    Raises:
        FileNotFoundError: If ``save_directory`` is given but does not exist.
        TypeError: If ``safety`` is set without ``filename``.
        ValueError: If no images are found, if ``categories="auto"`` is used
            without ``image_description``, if the provider is unknown, or if
            the provider reports an unknown model / failed authentication.
    """
    warnings.warn(
        "image_multi_class() is deprecated and will be removed in a future version. "
        "Use cat_stack.classify() instead, which auto-detects image input.",
        DeprecationWarning,
        stacklevel=2,
    )

    import os
    import json
    import pandas as pd
    import regex
    import time
    from tqdm import tqdm

    if save_directory is not None and not os.path.isdir(save_directory):
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    # Validate before any (paid) API call is made rather than after the first image.
    if safety and filename is None:
        raise TypeError("filename is required when using safety. Please provide the filename.")

    model_source = _detect_model_source(user_model, model_source)

    image_files = _load_image_files(image_input)
    if len(image_files) == 0:
        # Without this guard the function fails at the very end with pandas'
        # "No objects to concatenate" ValueError.
        raise ValueError("No images to classify: image_input did not yield any image files.")

    # Handle "auto" categories - extract categories first
    if isinstance(categories, str) and categories == "auto":
        if not image_description:
            raise ValueError("image_description is required when using categories='auto'")

        print("\nAuto-extracting categories from images...")
        auto_result = explore_image_categories(
            image_input=image_input,
            api_key=api_key,
            image_description=image_description,
            user_model=user_model,
            model_source=model_source,
            creativity=creativity
        )
        categories = auto_result["top_categories"]
        print(f"Extracted {len(categories)} categories: {categories}\n")

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    cat_num = len(categories)
    category_dict = {str(i + 1): "0" for i in range(cat_num)}
    example_JSON = json.dumps(category_dict, indent=4)

    print(f"\nCategories to classify by {model_source} {user_model}:")
    for i, cat in enumerate(categories, 1):
        print(f"{i}. {cat}")

    # Build examples text from provided examples
    examples = [example1, example2, example3, example4, example5, example6]
    examples = [ex for ex in examples if ex is not None]
    if examples:
        examples_text = "Here are some examples of how to categorize:\n" + "\n".join(examples)
    else:
        examples_text = ""

    # Value returned for an image whose classification failed.
    error_json = """{"1":"e"}"""

    # Helper function for CoVe: strips a leading "- ", "• ", "12." or "12)".
    def remove_numbering(line):
        line = line.strip()
        if line.startswith('- '):
            return line[2:].strip()
        if line.startswith('• '):
            return line[2:].strip()
        if line and line[0].isdigit():
            i = 0
            while i < len(line) and line[i].isdigit():
                i += 1
            if i < len(line) and line[i] in '.)':
                return line[i + 1:].strip()
        return line

    def _http_error_details(err):
        """Return (status_code, lower-cased searchable text) for a requests HTTPError."""
        response = err.response
        # Compare with None explicitly: requests.Response is falsy for every
        # 4xx/5xx status, so `if err.response` is False for error responses.
        if response is None:
            return None, str(err).lower()
        # Provider error codes live in the response body, not in str(err).
        return response.status_code, f"{err} {response.text}".lower()

    # Step-back insight initialization
    if step_back_prompt:
        stepback = f"""What are the key visual features or patterns that typically indicate the presence of these categories in images showing "{image_description}"?

Categories to consider:
{categories_str}

Provide a brief analysis of what visual cues to look for when categorizing such images."""

        stepback_insight, step_back_added = get_image_stepback_insight(
            model_source, stepback, api_key, user_model, creativity
        )
    else:
        stepback_insight = None
        step_back_added = False

    link1 = []
    extracted_jsons = []

    def _build_base_prompt_text():
        """Build the base text portion of the prompt."""
        if chain_of_thought:
            reasoning_steps = (
                "Let's analyze step by step:\n"
                "1. First, identify the key visual elements in the image\n"
                "2. Then, match each element to the relevant categories\n"
                "3. Finally, assign 1 to matching categories and 0 to non-matching categories\n\n"
            )
        else:
            reasoning_steps = ""

        base_text = (
            "You are an image-tagging assistant.\n"
            "Task ► Examine the attached image and decide, **for each category below**, "
            "whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
            f"Image is expected to show: {image_description}\n\n"
            f"Categories:\n{categories_str}\n\n"
            f"{reasoning_steps}"
            f"{examples_text}\n\n"
            "Output format ► Respond with **only** a JSON object whose keys are the "
            "quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
            "No additional keys, comments, or text.\n\n"
            "Example (three categories):\n"
            f"{example_JSON}"
        )

        if context_prompt:
            context = (
                "You are an expert visual analyst specializing in image categorization. "
                "Apply multi-label classification based on explicit and implicit visual cues. "
                "When uncertain, prioritize precision over recall.\n\n"
            )
            base_text = context + base_text

        return base_text

    def _build_cove_prompts(base_prompt_text):
        """Build chain of verification prompts for images."""
        # The <<...>> markers are left in the text as placeholders for the CoVe helpers.
        step2_prompt = f"""You provided this initial categorization:
<<INITIAL_REPLY>>

Original task: {base_prompt_text}

Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
- Be concise and specific (one sentence)
- Address a distinct visual element or category assignment
- Be answerable by re-examining the image

Focus on verifying:
- Whether each category assignment matches what's visible in the image
- Whether any visual elements were missed or misinterpreted
- Whether there are any logical inconsistencies

Provide only the verification questions as a numbered list."""

        step3_prompt = f"""Re-examine the attached image and answer the following verification question.

Image description: {image_description}

Verification question: <<QUESTION>>

Provide a brief, direct answer (1-2 sentences maximum) based on what you observe in the image.

Answer:"""

        step4_prompt = f"""Original task: {base_prompt_text}
Initial categorization:
<<INITIAL_REPLY>>
Verification questions and answers:
<<VERIFICATION_QA>>
Based on this verification, provide the final corrected categorization.
If no categories are present, assign "0" to all categories.
Provide the final categorization in the same JSON format:"""

        return step2_prompt, step3_prompt, step4_prompt

    def _build_prompt_openai_mistral(encoded, ext, base_text):
        """Build prompt for OpenAI/Mistral format."""
        encoded_image = f"data:image/{ext};base64,{encoded}"
        return [
            {"type": "text", "text": base_text},
            {"type": "image_url", "image_url": {"url": encoded_image, "detail": "high"}},
        ]

    def _build_prompt_anthropic(encoded, ext, base_text):
        """Build prompt for Anthropic format."""
        media_type = f"image/{ext}" if ext else "image/jpeg"
        return [
            {"type": "text", "text": base_text},
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": encoded
                }
            }
        ]

    def _build_prompt_google(encoded, ext, base_text):
        """Build prompt for Google format."""
        return {
            "text_prompt": base_text,
            "image_data": encoded,
            "mime_type": f"image/{ext}" if ext else "image/jpeg"
        }

    def _build_chat_messages(prompt):
        """Build the chat message list, preceded by the step-back exchange when available."""
        messages = []
        if step_back_prompt and step_back_added:
            messages.append({'role': 'user', 'content': stepback})
            messages.append({'role': 'assistant', 'content': stepback_insight})
        messages.append({'role': 'user', 'content': prompt})
        return messages

    def _call_openai_compatible(prompt, step2_prompt, step3_prompt, step4_prompt, image_content):
        """Handle OpenAI-compatible API calls (OpenAI, Perplexity, HuggingFace, xAI)."""
        import requests as req

        # Determine the base URL based on model source
        if model_source == "huggingface":
            from cat_stack.text_functions import _detect_huggingface_endpoint
            base_url = _detect_huggingface_endpoint(api_key, user_model)
        elif model_source == "huggingface-together":
            base_url = "https://router.huggingface.co/together/v1"
        elif model_source == "perplexity":
            base_url = "https://api.perplexity.ai"
        elif model_source == "xai":
            base_url = "https://api.x.ai/v1"
        else:
            base_url = "https://api.openai.com/v1"

        endpoint = f"{base_url}/chat/completions"

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        max_retries = 8
        delay = 2

        for attempt in range(max_retries):
            try:
                payload = {
                    "model": user_model,
                    "messages": _build_chat_messages(prompt),
                }
                if creativity is not None:
                    payload["temperature"] = creativity

                response = req.post(endpoint, headers=headers, json=payload, timeout=120)
                response.raise_for_status()
                result = response.json()
                reply = result["choices"][0]["message"]["content"]

                if chain_of_verification:
                    reply = image_chain_of_verification_openai(
                        initial_reply=reply,
                        step2_prompt=step2_prompt,
                        step3_prompt=step3_prompt,
                        step4_prompt=step4_prompt,
                        client=None,  # Not used anymore, CoVe needs refactoring too
                        user_model=user_model,
                        creativity=creativity,
                        remove_numbering=remove_numbering,
                        image_content=image_content
                    )

                return reply, None

            except req.exceptions.HTTPError as e:
                status_code, error_str = _http_error_details(e)
                can_retry = attempt < max_retries - 1

                if status_code == 400 and "json_validate_failed" in error_str and can_retry:
                    wait_time = delay * (2 ** attempt)
                    print(f"⚠️ JSON validation failed. Attempt {attempt + 1}/{max_retries}")
                    print(f"Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                elif status_code == 404:
                    raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
                elif status_code in [500, 502, 503, 504] and can_retry:
                    wait_time = delay * (2 ** attempt)
                    print(f"Attempt {attempt + 1} failed with error: {e}")
                    print(f"Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"❌ Failed after {max_retries} attempts: {e}")
                    return error_json, f"Error processing input: {e}"

            except Exception as e:
                if ("500" in str(e) or "504" in str(e)) and attempt < max_retries - 1:
                    wait_time = delay * (2 ** attempt)
                    print(f"Attempt {attempt + 1} failed with error: {e}")
                    print(f"Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"❌ Failed after {max_retries} attempts: {e}")
                    return error_json, f"Error processing input: {e}"

        return error_json, "Max retries exceeded"

    def _call_anthropic(prompt, step2_prompt, step3_prompt, step4_prompt, image_content):
        """Handle Anthropic API calls using direct HTTP requests."""
        import requests as req

        endpoint = "https://api.anthropic.com/v1/messages"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01"
        }

        try:
            payload = {
                "model": user_model,
                "max_tokens": 1024,
                "messages": _build_chat_messages(prompt),
            }
            if creativity is not None:
                payload["temperature"] = creativity

            response = req.post(endpoint, headers=headers, json=payload, timeout=120)
            response.raise_for_status()
            result = response.json()

            content = result.get("content", [])
            if content and content[0].get("type") == "text":
                reply = content[0].get("text", "")
            else:
                return error_json, "No text content in response"

            if chain_of_verification:
                reply = image_chain_of_verification_anthropic(
                    initial_reply=reply,
                    step2_prompt=step2_prompt,
                    step3_prompt=step3_prompt,
                    step4_prompt=step4_prompt,
                    client=None,  # No longer using SDK client
                    user_model=user_model,
                    creativity=creativity,
                    remove_numbering=remove_numbering,
                    image_content=image_content,
                    api_key=api_key  # Pass api_key for HTTP calls
                )

            return reply, None

        except req.exceptions.HTTPError as e:
            if e.response is not None and e.response.status_code == 404:
                raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
            print(f"An error occurred: {e}")
            return error_json, f"Error processing input: {e}"
        except Exception as e:
            print(f"An error occurred: {e}")
            return error_json, f"Error processing input: {e}"

    def _call_google(prompt_data, step2_prompt, step3_prompt, step4_prompt, base_prompt_text):
        """Handle Google API calls."""
        import requests

        def make_google_request(url, headers, payload, max_retries=8):
            """POST to the Google endpoint, retrying 429/5xx responses with exponential backoff."""
            retryable_errors = [429, 500, 502, 503, 504]
            for attempt in range(max_retries):
                try:
                    # Same timeout as the other providers so a stalled
                    # connection cannot hang the whole batch.
                    response = requests.post(url, headers=headers, json=payload, timeout=120)
                    response.raise_for_status()
                    return response.json()
                except requests.exceptions.HTTPError as e:
                    status_code = e.response.status_code

                    if status_code in retryable_errors and attempt < max_retries - 1:
                        wait_time = 10 * (2 ** attempt) if status_code == 429 else 2 * (2 ** attempt)
                        error_type = "Rate limited" if status_code == 429 else f"Server error {status_code}"
                        print(f"⚠️ {error_type}. Attempt {attempt + 1}/{max_retries}")
                        print(f"Retrying in {wait_time}s...")
                        time.sleep(wait_time)
                    else:
                        raise

        url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
        headers = {
            "x-goog-api-key": api_key,
            "Content-Type": "application/json"
        }

        # Build parts with optional stepback context
        parts = []
        if step_back_prompt and step_back_added:
            parts.append({"text": f"Context from step-back analysis:\n{stepback_insight}\n\n"})
        parts.append({"text": prompt_data["text_prompt"]})
        parts.append({
            "inline_data": {
                "mime_type": prompt_data["mime_type"],
                "data": prompt_data["image_data"]
            }
        })

        generation_config = {"responseMimeType": "application/json"}
        if creativity is not None:
            generation_config["temperature"] = creativity
        if thinking_budget:
            generation_config["thinkingConfig"] = {"thinkingBudget": thinking_budget}

        payload = {
            "contents": [{"parts": parts}],
            "generationConfig": generation_config
        }

        try:
            result = make_google_request(url, headers, payload)

            if "candidates" in result and result["candidates"]:
                reply = result["candidates"][0]["content"]["parts"][0]["text"]
            else:
                return "No response generated", None

            if chain_of_verification:
                reply = image_chain_of_verification_google(
                    initial_reply=reply,
                    prompt=base_prompt_text,
                    step2_prompt=step2_prompt,
                    step3_prompt=step3_prompt,
                    step4_prompt=step4_prompt,
                    url=url,
                    headers=headers,
                    creativity=creativity,
                    remove_numbering=remove_numbering,
                    make_google_request=make_google_request,
                    image_data=prompt_data["image_data"],
                    mime_type=prompt_data["mime_type"]
                )

            return reply, None

        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 404:
                raise ValueError(f"❌ Model '{user_model}' not found. Please check the model name and try again.") from e
            elif e.response.status_code in [401, 403]:
                raise ValueError("❌ Authentication failed. Please check your Google API key.") from e
            else:
                print(f"HTTP error occurred: {e}")
                return error_json, f"Error processing input: {e}"
        except Exception as e:
            print(f"An error occurred: {e}")
            return error_json, f"Error processing input: {e}"

    def _call_mistral(prompt, step2_prompt, step3_prompt, step4_prompt, image_content):
        """Handle Mistral API calls - uses requests directly."""
        import requests as req

        endpoint = "https://api.mistral.ai/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        max_retries = 8
        delay = 2

        for attempt in range(max_retries):
            try:
                payload = {
                    "model": user_model,
                    "messages": _build_chat_messages(prompt),
                }
                if creativity is not None:
                    payload["temperature"] = creativity

                response = req.post(endpoint, headers=headers, json=payload, timeout=120)
                response.raise_for_status()
                result = response.json()
                reply = result["choices"][0]["message"]["content"]

                if chain_of_verification:
                    reply = image_chain_of_verification_mistral(
                        initial_reply=reply,
                        step2_prompt=step2_prompt,
                        step3_prompt=step3_prompt,
                        step4_prompt=step4_prompt,
                        client=None,  # Not used anymore, CoVe needs refactoring too
                        user_model=user_model,
                        creativity=creativity,
                        remove_numbering=remove_numbering,
                        image_content=image_content
                    )

                return reply, None

            except req.exceptions.HTTPError as e:
                status_code, error_str = _http_error_details(e)

                if status_code == 404 or "invalid_model" in error_str or "invalid model" in error_str:
                    raise ValueError(f"❌ Model '{user_model}' not found.") from e
                elif status_code == 401 or "unauthorized" in error_str:
                    raise ValueError("❌ Authentication failed. Please check your Mistral API key.") from e
                elif status_code in [500, 502, 503, 504] and attempt < max_retries - 1:
                    wait_time = delay * (2 ** attempt)
                    print(f"⚠️ Server error {status_code}. Attempt {attempt + 1}/{max_retries}")
                    print(f"Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"❌ Failed after {max_retries} attempts: {e}")
                    return error_json, f"Error processing input: {e}"

            except Exception as e:
                print(f"❌ Unexpected error: {e}")
                return error_json, f"Error processing input: {e}"

        return error_json, "Max retries exceeded"

    def _process_single_image(img_path):
        """Process a single image and return (reply, error_msg)."""
        encoded, ext, is_valid = _encode_image(img_path)

        if not is_valid:
            return None, "Invalid image path or encoding failed"

        base_prompt_text = _build_base_prompt_text()

        if chain_of_verification:
            step2_prompt, step3_prompt, step4_prompt = _build_cove_prompts(base_prompt_text)
        else:
            step2_prompt = step3_prompt = step4_prompt = None

        # "huggingface-together" has its own base URL in _call_openai_compatible,
        # so it must be routed there as well.
        if model_source in ["openai", "perplexity", "huggingface", "huggingface-together", "xai"]:
            prompt = _build_prompt_openai_mistral(encoded, ext, base_prompt_text)
            # Image content for CoVe (just the image part)
            encoded_image = f"data:image/{ext};base64,{encoded}"
            image_content = {"type": "image_url", "image_url": {"url": encoded_image, "detail": "high"}}
            return _call_openai_compatible(prompt, step2_prompt, step3_prompt, step4_prompt, image_content)

        elif model_source == "anthropic":
            prompt = _build_prompt_anthropic(encoded, ext, base_prompt_text)
            media_type = f"image/{ext}" if ext else "image/jpeg"
            image_content = {
                "type": "image",
                "source": {"type": "base64", "media_type": media_type, "data": encoded}
            }
            return _call_anthropic(prompt, step2_prompt, step3_prompt, step4_prompt, image_content)

        elif model_source == "google":
            prompt_data = _build_prompt_google(encoded, ext, base_prompt_text)
            return _call_google(prompt_data, step2_prompt, step3_prompt, step4_prompt, base_prompt_text)

        elif model_source == "mistral":
            prompt = _build_prompt_openai_mistral(encoded, ext, base_prompt_text)
            encoded_image = f"data:image/{ext};base64,{encoded}"
            image_content = {"type": "image_url", "image_url": {"url": encoded_image, "detail": "high"}}
            return _call_mistral(prompt, step2_prompt, step3_prompt, step4_prompt, image_content)

        else:
            raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, Google, xAI, Huggingface, or Mistral")

    def _extract_json(reply):
        """Extract JSON from model reply."""
        if reply is None:
            return error_json

        if reply == "invalid image path":
            return """{"no_valid_path": 1}"""

        # Recursive pattern (regex module): first balanced {...} group in the reply.
        extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
        if extracted_json:
            # NOTE(review): the second space replacement is redundant as
            # rendered here; confirm the original did not target a different
            # whitespace character.
            return extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
        else:
            print(error_json)
            return error_json

    def _normalize_jsons(json_strings):
        """Turn the extracted JSON strings into one DataFrame with a row per string."""
        frames = []
        for json_str in json_strings:
            try:
                parsed_obj = json.loads(json_str)
                frames.append(pd.json_normalize(parsed_obj))
            except json.JSONDecodeError:
                frames.append(pd.DataFrame({"1": ["e"]}))
        return pd.concat(frames, ignore_index=True)

    # Main processing loop
    for idx, img_path in enumerate(tqdm(image_files, desc="Categorizing images")):
        if img_path is None:
            link1.append("Skipped NaN input")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        reply, error_msg = _process_single_image(img_path)

        if error_msg:
            link1.append(error_msg)
            if "Invalid image" in error_msg:
                extracted_jsons.append("""{"no_valid_path": 1}""")
            else:
                extracted_jsons.append(_extract_json(reply))
        else:
            link1.append(reply)
            extracted_jsons.append(_extract_json(reply))

        # --- Safety Save ---
        if safety:
            temp_df = pd.DataFrame({
                'image_input': image_files[:idx + 1],
                'model_response': link1,
                'json': extracted_jsons
            })
            temp_df = pd.concat([temp_df, _normalize_jsons(extracted_jsons)], axis=1)

            save_path = os.path.join(save_directory, filename) if save_directory else filename
            temp_df.to_csv(save_path, index=False)

    # --- Final DataFrame ---
    normalized_data = _normalize_jsons(extracted_jsons)

    categorized_data = pd.DataFrame({
        'image_input': pd.Series(image_files),
        'model_response': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
    categorized_data = categorized_data.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)

    # Identify rows with invalid strings (like "e")
    cat_cols = [col for col in categorized_data.columns if col.startswith('category_')]
    has_invalid_strings = categorized_data[cat_cols].apply(
        lambda col: pd.to_numeric(col, errors='coerce').isna() & col.notna()
    ).any(axis=1)

    categorized_data['processing_status'] = (~has_invalid_strings).map({True: 'success', False: 'error'})
    categorized_data.loc[has_invalid_strings, cat_cols] = pd.NA

    for col in cat_cols:
        categorized_data[col] = pd.to_numeric(categorized_data[col], errors='coerce')

    # Successful rows: a category missing from the reply counts as 0.
    categorized_data.loc[~has_invalid_strings, cat_cols] = (
        categorized_data.loc[~has_invalid_strings, cat_cols].fillna(0)
    )
    categorized_data[cat_cols] = categorized_data[cat_cols].astype('Int64')

    # Create categories_id (comma-separated binary values for each category)
    categorized_data['categories_id'] = categorized_data[cat_cols].apply(
        lambda x: ','.join(x.dropna().astype(int).astype(str)), axis=1
    )

    if filename:
        save_path = os.path.join(save_directory, filename) if save_directory else filename
        categorized_data.to_csv(save_path, index=False)

    return categorized_data
778
+
779
+
780
+ # image score function
781
+ def image_score_drawing(
782
+ reference_image_description,
783
+ image_input,
784
+ reference_image,
785
+ api_key,
786
+ columns="numbered",
787
+ user_model="gpt-4o-2024-11-20",
788
+ creativity=None,
789
+ to_csv=False,
790
+ safety=False,
791
+ filename="categorized_data.csv",
792
+ save_directory=None,
793
+ model_source="OpenAI"
794
+ ):
795
+ import os
796
+ import json
797
+ import pandas as pd
798
+ import regex
799
+ from tqdm import tqdm
800
+ import glob
801
+ import base64
802
+ from pathlib import Path
803
+
804
+ if save_directory is not None and not os.path.isdir(save_directory):
805
+ # Directory doesn't exist - raise an exception to halt execution
806
+ raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
807
+
808
+ image_extensions = [
809
+ '*.png', '*.jpg', '*.jpeg',
810
+ '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
811
+ '*.tif', '*.tiff', '*.bmp',
812
+ '*.heif', '*.heic', '*.ico',
813
+ '*.psd'
814
+ ]
815
+
816
+ model_source = model_source.lower() # eliminating case sensitivity
817
+
818
+ if not isinstance(image_input, list):
819
+ # If image_input is a filepath (string)
820
+ image_files = []
821
+ for ext in image_extensions:
822
+ image_files.extend(glob.glob(os.path.join(image_input, ext)))
823
+
824
+ print(f"Found {len(image_files)} images.")
825
+ else:
826
+ # If image_files is already a list
827
+ image_files = image_input
828
+ print(f"Provided a list of {len(image_input)} images.")
829
+
830
+ with open(reference_image, 'rb') as f:
831
+ reference = base64.b64encode(f.read()).decode('utf-8')
832
+ reference_image = f"data:image/{reference_image.split('.')[-1]};base64,{reference}"
833
+
834
+ link1 = []
835
+ extracted_jsons = []
836
+
837
+ for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
838
+ # Check validity first
839
+ if img_path is None or not os.path.exists(img_path):
840
+ link1.append("Skipped NaN input or invalid path")
841
+ extracted_jsons.append("""{"no_valid_image": 1}""")
842
+ continue # Skip the rest of the loop iteration
843
+
844
+ # Only open the file if path is valid
845
+ if os.path.isdir(img_path):
846
+ encoded = "Not a Valid Image, contains file path"
847
+ else:
848
+ try:
849
+ with open(img_path, "rb") as f:
850
+ encoded = base64.b64encode(f.read()).decode("utf-8")
851
+ except Exception as e:
852
+ encoded = f"Error: {str(e)}"
853
+ # Handle extension safely
854
+ if encoded.startswith("Error:") or encoded == "Not a Valid Image, contains file path":
855
+ encoded_image = encoded
856
+ valid_image = False
857
+
858
+ else:
859
+ ext = Path(img_path).suffix.lstrip(".").lower()
860
+ encoded_image = f"data:image/{ext};base64,{encoded}"
861
+ valid_image = True
862
+
863
+ # Handle extension safely
864
+ ext = Path(img_path).suffix.lstrip(".").lower()
865
+ encoded_image = f"data:image/{ext};base64,{encoded}"
866
+
867
+ if model_source == "openai" or model_source == "mistral":
868
+ prompt = [
869
+ {
870
+ "type": "text",
871
+ "text": (
872
+ f"You are a visual similarity assessment system.\n"
873
+ f"Task ► Compare these two images:\n"
874
+ f"1. REFERENCE (left): {reference_image_description}\n"
875
+ f"2. INPUT (right): User-provided drawing\n\n"
876
+ f"Rating criteria:\n"
877
+ f"1: No meaningful similarity (fundamentally different)\n"
878
+ f"2: Barely recognizable similarity (25% match)\n"
879
+ f"3: Partial match (50% key features)\n"
880
+ f"4: Strong alignment (75% features)\n"
881
+ f"5: Near-perfect match (90%+ similarity)\n\n"
882
+ f"Output format ► Return ONLY:\n"
883
+ "{\n"
884
+ ' "score": [1-5],\n'
885
+ ' "summary": "reason you scored"\n'
886
+ "}\n\n"
887
+ f"Critical rules:\n"
888
+ f"- Score must reflect shape, proportions, and key details\n"
889
+ f"- List only concrete matching elements from reference\n"
890
+ f"- No markdown or additional text"
891
+ )
892
+ },
893
+ {
894
+ "type": "image_url",
895
+ "image_url": {"url": reference_image, "detail": "high"}
896
+ },
897
+ {
898
+ "type": "image_url",
899
+ "image_url": {"url": encoded_image, "detail": "high"}
900
+ }
901
+ ]
902
+
903
+ elif model_source == "anthropic": # Changed to elif
904
+ prompt = [
905
+ {
906
+ "type": "text",
907
+ "text": (
908
+ f"You are a visual similarity assessment system.\n"
909
+ f"Task ► Compare these two images:\n"
910
+ f"1. REFERENCE (left): {reference_image_description}\n"
911
+ f"2. INPUT (right): User-provided drawing\n\n"
912
+ f"Rating criteria:\n"
913
+ f"1: No meaningful similarity (fundamentally different)\n"
914
+ f"2: Barely recognizable similarity (25% match)\n"
915
+ f"3: Partial match (50% key features)\n"
916
+ f"4: Strong alignment (75% features)\n"
917
+ f"5: Near-perfect match (90%+ similarity)\n\n"
918
+ f"Output format ► Return ONLY:\n"
919
+ "{\n"
920
+ ' "score": [1-5],\n'
921
+ ' "summary": "reason you scored"\n'
922
+ "}\n\n"
923
+ f"Critical rules:\n"
924
+ f"- Score must reflect shape, proportions, and key details\n"
925
+ f"- List only concrete matching elements from reference\n"
926
+ f"- No markdown or additional text"
927
+ )
928
+ },
929
+ {
930
+ "type": "image", # Added missing type
931
+ "source": {
932
+ "type": "base64",
933
+ "media_type": "image/png",
934
+ "data": reference
935
+ }
936
+ },
937
+ {
938
+ "type": "image", # Added missing type
939
+ "source": {
940
+ "type": "base64",
941
+ "media_type": "image/jpeg",
942
+ "data": encoded
943
+ }
944
+ }
945
+ ]
946
+
947
+
948
+ if model_source == "openai":
949
+ import requests as req
950
+ endpoint = "https://api.openai.com/v1/chat/completions"
951
+ headers = {
952
+ "Content-Type": "application/json",
953
+ "Authorization": f"Bearer {api_key}"
954
+ }
955
+ payload = {
956
+ "model": user_model,
957
+ "messages": [{'role': 'user', 'content': prompt}],
958
+ }
959
+ if creativity is not None:
960
+ payload["temperature"] = creativity
961
+ try:
962
+ response = req.post(endpoint, headers=headers, json=payload, timeout=120)
963
+ response.raise_for_status()
964
+ result = response.json()
965
+ reply = result["choices"][0]["message"]["content"]
966
+ link1.append(reply)
967
+ except req.exceptions.HTTPError as e:
968
+ if e.response and e.response.status_code == 404:
969
+ raise ValueError(f"Invalid OpenAI model '{user_model}': {e}")
970
+ else:
971
+ print(f"An error occurred: {e}")
972
+ link1.append(f"Error processing input: {e}")
973
+ except Exception as e:
974
+ print(f"An error occurred: {e}")
975
+ link1.append(f"Error processing input: {e}")
976
+
977
+ elif model_source == "anthropic":
978
+ import requests as req
979
+ endpoint = "https://api.anthropic.com/v1/messages"
980
+ headers = {
981
+ "Content-Type": "application/json",
982
+ "x-api-key": api_key,
983
+ "anthropic-version": "2023-06-01"
984
+ }
985
+ payload = {
986
+ "model": user_model,
987
+ "max_tokens": 1024,
988
+ "messages": [{"role": "user", "content": prompt}],
989
+ }
990
+ if creativity is not None:
991
+ payload["temperature"] = creativity
992
+ try:
993
+ response = req.post(endpoint, headers=headers, json=payload, timeout=120)
994
+ response.raise_for_status()
995
+ result = response.json()
996
+ content = result.get("content", [])
997
+ if content and content[0].get("type") == "text":
998
+ reply = content[0].get("text", "")
999
+ link1.append(reply)
1000
+ else:
1001
+ link1.append("Error processing input: No text content in response")
1002
+ except req.exceptions.HTTPError as e:
1003
+ if e.response is not None and e.response.status_code == 404:
1004
+ raise ValueError(f"Invalid Anthropic model '{user_model}': {e}")
1005
+ else:
1006
+ print(f"An error occurred: {e}")
1007
+ link1.append(f"Error processing input: {e}")
1008
+ except Exception as e:
1009
+ print(f"An error occurred: {e}")
1010
+ link1.append(f"Error processing input: {e}")
1011
+
1012
+ elif model_source == "mistral":
1013
+ import requests as req
1014
+ endpoint = "https://api.mistral.ai/v1/chat/completions"
1015
+ headers = {
1016
+ "Content-Type": "application/json",
1017
+ "Authorization": f"Bearer {api_key}"
1018
+ }
1019
+ payload = {
1020
+ "model": user_model,
1021
+ "messages": [{'role': 'user', 'content': prompt}],
1022
+ }
1023
+ if creativity is not None:
1024
+ payload["temperature"] = creativity
1025
+ try:
1026
+ response = req.post(endpoint, headers=headers, json=payload, timeout=120)
1027
+ response.raise_for_status()
1028
+ result = response.json()
1029
+ reply = result["choices"][0]["message"]["content"]
1030
+ link1.append(reply)
1031
+ except req.exceptions.HTTPError as e:
1032
+ if e.response and e.response.status_code == 404:
1033
+ raise ValueError(f"Invalid Mistral model '{user_model}': {e}")
1034
+ else:
1035
+ print(f"An error occurred: {e}")
1036
+ link1.append(f"Error processing input: {e}")
1037
+ except Exception as e:
1038
+ print(f"An error occurred: {e}")
1039
+ link1.append(f"Error processing input: {e}")
1040
+ #if no valid image path is provided
1041
+ elif valid_image == False:
1042
+ reply = "invalid image path"
1043
+ print("Skipped NaN input or invalid path")
1044
+ #extracted_jsons.append("""{"no_valid_path": 1}""")
1045
+ link1.append("Error processing input: {e}")
1046
+ else:
1047
+ raise ValueError("Unknown source! Choose from OpenAI, Perplexity, or Mistral")
1048
+ # in situation that no JSON is found
1049
+ if reply is not None:
1050
+ if reply == "invalid image path":
1051
+ extracted_jsons.append("""{"no_valid_path": 1}""")
1052
+ else:
1053
+ extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
1054
+ if extracted_json:
1055
+ cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
1056
+ extracted_jsons.append(cleaned_json)
1057
+ else:
1058
+ error_message = """{"1":"e"}"""
1059
+ extracted_jsons.append(error_message)
1060
+ print(error_message)
1061
+ else:
1062
+ error_message = """{"1":"e"}"""
1063
+ extracted_jsons.append(error_message)
1064
+ print(error_message)
1065
+
1066
+ # --- Safety Save ---
1067
+ if safety:
1068
+ # Save progress so far
1069
+ temp_df = pd.DataFrame({
1070
+ 'image_input': image_files[:i+1],
1071
+ 'model_response': link1,
1072
+ 'json': extracted_jsons
1073
+ })
1074
+ # Normalize processed jsons so far
1075
+ normalized_data_list = []
1076
+ for json_str in extracted_jsons:
1077
+ try:
1078
+ parsed_obj = json.loads(json_str)
1079
+ normalized_data_list.append(pd.json_normalize(parsed_obj))
1080
+ except json.JSONDecodeError:
1081
+ normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
1082
+ normalized_data = pd.concat(normalized_data_list, ignore_index=True)
1083
+ temp_df = pd.concat([temp_df, normalized_data], axis=1)
1084
+ # Save to CSV
1085
+ if save_directory is None:
1086
+ save_directory = os.getcwd()
1087
+ temp_df.to_csv(os.path.join(save_directory, filename), index=False)
1088
+
1089
+ # --- Final DataFrame ---
1090
+ normalized_data_list = []
1091
+ for json_str in extracted_jsons:
1092
+ try:
1093
+ parsed_obj = json.loads(json_str)
1094
+ normalized_data_list.append(pd.json_normalize(parsed_obj))
1095
+ except json.JSONDecodeError:
1096
+ normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
1097
+ normalized_data = pd.concat(normalized_data_list, ignore_index=True)
1098
+
1099
+ categorized_data = pd.DataFrame({
1100
+ 'image_input': (
1101
+ image_files.reset_index(drop=True) if isinstance(image_files, (pd.DataFrame, pd.Series))
1102
+ else pd.Series(image_files)
1103
+ ),
1104
+ 'link1': pd.Series(link1).reset_index(drop=True),
1105
+ 'json': pd.Series(extracted_jsons).reset_index(drop=True)
1106
+ })
1107
+ categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
1108
+
1109
+ if to_csv:
1110
+ if save_directory is None:
1111
+ save_directory = os.getcwd()
1112
+ categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
1113
+
1114
+ return categorized_data
1115
+
1116
+ # image features function
1117
def _image_features_prompt_text(image_description, categories_str):
    """Return the instruction text that accompanies every image.

    Args:
        image_description: Caller-supplied context describing the images.
        categories_str: Numbered, newline-separated list of the questions.

    Returns:
        The full prompt string (identical for every provider).
    """
    return (
        "You are a visual question answering assistant.\n"
        "Task ► Analyze the attached image and answer these specific questions:\n\n"
        f"Image context: {image_description}\n\n"
        f"Questions to answer:\n{categories_str}\n\n"
        "Output format ► Return **only** a JSON object where:\n"
        "- Keys are question numbers ('1', '2', ...)\n"
        "- Values are concise answers (numbers, short phrases)\n\n"
        "Example for 3 questions:\n"
        "{\n"
        ' "1": "4",\n'
        ' "2": "blue",\n'
        ' "3": "yes"\n'
        "}\n\n"
        "Important rules:\n"
        "1. Answer directly - no explanations\n"
        "2. Use exact numerical values when possible\n"
        "3. For yes/no questions, use 'yes' or 'no'\n"
        "4. Never add extra keys or formatting"
    )


def _request_image_features_reply(model_source, user_model, api_key, prompt, creativity):
    """POST one prompt to the selected provider.

    Args:
        model_source: Lower-cased provider key; one of "openai",
            "perplexity", "anthropic" or "mistral".
        user_model: Model identifier placed in the request payload.
        api_key: Credential sent in the provider's auth header.
        prompt: Message content list (text part plus image part).
        creativity: Temperature, or None to omit it from the payload.

    Returns:
        A ``(reply, log_entry)`` tuple. ``reply`` is the model's text, or
        None when no usable text was obtained; ``log_entry`` is what should
        be recorded in the ``model_response`` column for this image.

    Raises:
        ValueError: If the provider answers HTTP 404 (treated as an invalid
            model name).
    """
    import requests as req

    chat_endpoints = {
        "openai": ("OpenAI", "https://api.openai.com/v1/chat/completions"),
        "perplexity": ("Perplexity", "https://api.perplexity.ai/chat/completions"),
        "mistral": ("Mistral", "https://api.mistral.ai/v1/chat/completions"),
    }

    if model_source == "anthropic":
        provider_label = "Anthropic"
        endpoint = "https://api.anthropic.com/v1/messages"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
        }
        payload = {
            "model": user_model,
            "max_tokens": 1024,
            "messages": [{"role": "user", "content": prompt}],
        }
    else:
        provider_label, endpoint = chat_endpoints[model_source]
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        payload = {
            "model": user_model,
            "messages": [{"role": "user", "content": prompt}],
        }
    if creativity is not None:
        payload["temperature"] = creativity

    try:
        response = req.post(endpoint, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        if model_source == "anthropic":
            content = result.get("content", [])
            if content and content[0].get("type") == "text":
                reply = content[0].get("text", "")
                return reply, reply
            return None, "Error processing input: No text content in response"
        reply = result["choices"][0]["message"]["content"]
        return reply, reply
    except req.exceptions.HTTPError as e:
        # Compare against None explicitly: a Response object is falsy for
        # 4xx/5xx statuses, so a bare truthiness test would hide every 404.
        if e.response is not None and e.response.status_code == 404:
            raise ValueError(f"Invalid {provider_label} model '{user_model}': {e}") from e
        print(f"An error occurred: {e}")
        return None, f"Error processing input: {e}"
    except Exception as e:
        # Best-effort batch processing: record the failure and move on.
        print(f"An error occurred: {e}")
        return None, f"Error processing input: {e}"


def _normalize_feature_json(json_str):
    """Parse one JSON string into a one-row DataFrame.

    Falls back to the ``{"1": "e"}`` error row when the string is not
    valid JSON.
    """
    import json
    import pandas as pd

    try:
        return pd.json_normalize(json.loads(json_str))
    except json.JSONDecodeError:
        return pd.DataFrame({"1": ["e"]})


def _assemble_image_features_frame(image_paths, responses, json_strings, normalized_rows):
    """Combine per-image results with their normalized JSON columns.

    All four arguments are parallel sequences of equal length.
    """
    import pandas as pd

    base = pd.DataFrame({
        'image_input': list(image_paths),
        'model_response': list(responses),
        'json': list(json_strings),
    })
    normalized = pd.concat(normalized_rows, ignore_index=True)
    return pd.concat([base, normalized], axis=1)


def image_features(
    image_description,
    image_input,
    features_to_extract,
    api_key,
    user_model="gpt-4o-2024-11-20",
    creativity=None,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Ask a vision model a fixed list of questions about each image.

    Args:
        image_description: Context describing what the images contain.
        image_input: A list of image paths, a directory to glob for images,
            or a single image file path.
        features_to_extract: Iterable of questions; they are numbered from 1
            and the model is asked to answer with those numbers as JSON keys.
        api_key: API key for the chosen provider.
        user_model: Model identifier sent to the provider.
        creativity: Temperature; omitted from the request when None.
        to_csv: If true, write the final table to ``filename``.
        safety: If true, write the results gathered so far to ``filename``
            after every image, using the same columns as the final table.
        filename: CSV file name used by ``to_csv`` and ``safety``.
        save_directory: Directory for the CSV; the current working directory
            is used when None.
        model_source: "OpenAI", "Anthropic", "Perplexity" or "Mistral"
            (case-insensitive).

    Returns:
        pandas.DataFrame with columns ``image_input``, ``model_response`` and
        ``json``, followed by one column per key found in the parsed JSON
        answers.

    Raises:
        ValueError: If no images are found, if ``model_source`` is unknown
            when the first readable image is reached, or if the provider
            reports HTTP 404 for ``user_model``.
    """
    import os
    import glob
    import base64

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]
    known_sources = ("openai", "perplexity", "anthropic", "mistral")

    model_source = model_source.lower()  # eliminating case sensitivity

    if isinstance(image_input, list):
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")
    elif os.path.isfile(image_input):
        # A single image file rather than a directory.
        image_files = [image_input]
        print(f"Found {len(image_files)} images.")
    else:
        image_files = []
        for pattern in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, pattern)))
        print(f"Found {len(image_files)} images.")

    if not image_files:
        # Explicit message instead of the opaque "No objects to concatenate"
        # ValueError pandas would raise later on an empty result list.
        raise ValueError("No image files found in the specified input.")

    # Third-party imports are deferred until there is work to do, so input
    # errors above surface without requiring them.
    import regex
    from tqdm import tqdm

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(features_to_extract))
    prompt_text = _image_features_prompt_text(image_description, categories_str)
    json_pattern = regex.compile(r'\{(?:[^{}]|(?R))*\}', regex.DOTALL)
    error_json = """{"1":"e"}"""

    csv_path = None
    if safety or to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        csv_path = os.path.join(save_directory, filename)

    link1 = []
    extracted_jsons = []
    normalized_rows = []

    for i, img_path in enumerate(tqdm(image_files, desc="Scoring images"), start=0):
        # Start every image from a clean slate so a failure can never reuse
        # the previous image's answer.
        reply = None

        is_path_like = isinstance(img_path, (str, bytes, os.PathLike))
        if not is_path_like or not os.path.exists(img_path):
            # Covers None, NaN and other non-path values as well as missing files.
            link1.append("Skipped NaN input or invalid path")
            json_str = """{"no_valid_image": 1}"""
        else:
            encoded = None
            if not os.path.isdir(img_path):
                try:
                    with open(img_path, "rb") as f:
                        encoded = base64.b64encode(f.read()).decode("utf-8")
                except OSError:
                    encoded = None

            if encoded is None:
                # Directory or unreadable file: never send it to the provider.
                print("Skipped NaN input or invalid path")
                link1.append("Error processing input: invalid image")
                json_str = error_json
            else:
                if model_source not in known_sources:
                    raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")

                ext = os.path.splitext(os.fsdecode(img_path))[1].lstrip(".").lower()
                # "jpg" is not a registered subtype; fall back to JPEG when
                # the file has no extension at all.
                subtype = "jpeg" if ext in ("", "jpg") else ext
                media_type = f"image/{subtype}"

                if model_source == "anthropic":
                    image_part = {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": encoded,
                        },
                    }
                else:
                    # OpenAI-style chat message format (also used here for
                    # Mistral and Perplexity).
                    image_part = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{media_type};base64,{encoded}",
                            "detail": "high",
                        },
                    }
                prompt = [{"type": "text", "text": prompt_text}, image_part]

                reply, log_entry = _request_image_features_reply(
                    model_source, user_model, api_key, prompt, creativity
                )
                link1.append(log_entry)

                json_str = error_json
                if isinstance(reply, str):
                    matches = json_pattern.findall(reply)
                    if matches:
                        # NOTE: this cleanup also strips brackets and spaces
                        # that occur inside answer strings; kept as-is so the
                        # stored JSON matches earlier outputs. The second
                        # space replace is a no-op after the first.
                        json_str = (
                            matches[0]
                            .replace('[', '')
                            .replace(']', '')
                            .replace('\n', '')
                            .replace(" ", '')
                            .replace(" ", '')
                        )
                    else:
                        print(error_json)

        extracted_jsons.append(json_str)
        normalized_rows.append(_normalize_feature_json(json_str))

        # --- Safety Save ---
        if safety:
            checkpoint = _assemble_image_features_frame(
                image_files[:i + 1], link1, extracted_jsons, normalized_rows
            )
            checkpoint.to_csv(csv_path, index=False)

    # --- Final DataFrame ---
    categorized_data = _assemble_image_features_frame(
        image_files, link1, extracted_jsons, normalized_rows
    )

    if to_csv:
        categorized_data.to_csv(csv_path, index=False)

    return categorized_data
1453
+
1454
+
1455
+ def explore_image_categories(
1456
+ image_input,
1457
+ api_key,
1458
+ image_description="",
1459
+ max_categories=12,
1460
+ categories_per_chunk=10,
1461
+ divisions=5,
1462
+ user_model="gpt-4o",
1463
+ creativity=None,
1464
+ specificity="broad",
1465
+ research_question=None,
1466
+ mode="image",
1467
+ filename=None,
1468
+ model_source="auto",
1469
+ iterations=3,
1470
+ random_state=None,
1471
+ progress_callback=None,
1472
+ ):
1473
+ """
1474
+ Explore and extract common categories from a collection of images.
1475
+
1476
+ Modes:
1477
+ - "image" (default): Samples random images and sends them directly to
1478
+ a vision model for category extraction. Best for visual categorization.
1479
+
1480
+ - "both": Samples random images, uses vision model to describe each
1481
+ image's content (including any text), then extracts categories from
1482
+ those descriptions. Best for images that contain text or mixed content.
1483
+
1484
+ Args:
1485
+ image_input: Path to image file, directory of images, or list of image paths
1486
+ api_key: API key for the model provider
1487
+ image_description: Description of what the images contain
1488
+ max_categories: Maximum number of final categories to return
1489
+ categories_per_chunk: Categories to extract per chunk of images
1490
+ divisions: Number of chunks to divide images into
1491
+ user_model: Model to use (must support vision)
1492
+ creativity: Temperature setting (None for default)
1493
+ specificity: "broad" or "specific" category granularity
1494
+ research_question: Optional research context
1495
+ mode: "image" or "both"
1496
+ filename: Optional CSV filename to save results
1497
+ model_source: "auto", "openai", "anthropic", "google", "mistral"
1498
+ iterations: Number of passes over the data
1499
+ random_state: Random seed for reproducibility
1500
+ progress_callback: Optional callback function for progress updates.
1501
+ Called as progress_callback(current_step, total_steps, step_label).
1502
+
1503
+ Returns:
1504
+ dict with keys:
1505
+ - counts_df: DataFrame of categories with counts
1506
+ - top_categories: List of top category names
1507
+ - raw_top_text: Raw model output from final merge step
1508
+ """
1509
+ import os
1510
+ import re
1511
+ import pandas as pd
1512
+ import numpy as np
1513
+ from tqdm import tqdm
1514
+
1515
+ model_source = _detect_model_source(user_model, model_source)
1516
+
1517
+ # Load all images
1518
+ image_files = _load_image_files(image_input)
1519
+ if not image_files:
1520
+ raise ValueError("No image files found in the specified input.")
1521
+
1522
+ n = len(image_files)
1523
+ if n == 0:
1524
+ raise ValueError("No images found.")
1525
+
1526
+ # Auto-adjust divisions for small datasets
1527
+ # Images can have multiple categories each, so we can use fewer divisions
1528
+ original_divisions = divisions
1529
+ divisions = min(divisions, max(1, n // 2)) # At least 2 images per chunk
1530
+ if divisions != original_divisions:
1531
+ print(f"Auto-adjusted divisions from {original_divisions} to {divisions} for {n} images.")
1532
+
1533
+ # Chunk sizing - images often contain multiple categories each
1534
+ chunk_size = int(round(max(1, n / divisions), 0))
1535
+ # Don't reduce categories_per_chunk as aggressively for images since each image can yield many categories
1536
+ if chunk_size < 2:
1537
+ # Only reduce if we have very few images
1538
+ old_categories_per_chunk = categories_per_chunk
1539
+ categories_per_chunk = max(5, chunk_size * 4)
1540
+ print(f"Auto-adjusted categories_per_chunk from {old_categories_per_chunk} to {categories_per_chunk} for chunk size {chunk_size}.")
1541
+
1542
+ print(
1543
+ f"Exploring categories in images: '{image_description}'.\n"
1544
+ f" {n} total images, {categories_per_chunk * divisions} categories to extract, "
1545
+ f"{max_categories} final categories. Mode: {mode}\n"
1546
+ )
1547
+
1548
+ # RNG for reproducible sampling
1549
+ rng = np.random.default_rng(random_state)
1550
+
1551
+ # Validate model_source (clients initialized per-call using requests)
1552
+ import requests as req
1553
+ if model_source not in ["openai", "huggingface", "huggingface-together", "xai", "anthropic", "google", "mistral"]:
1554
+ raise ValueError(f"Unsupported model_source: {model_source}")
1555
+
1556
+ # Determine base URL for OpenAI-compatible providers
1557
+ if model_source == "huggingface":
1558
+ from cat_stack.text_functions import _detect_huggingface_endpoint
1559
+ openai_base_url = _detect_huggingface_endpoint(api_key, user_model)
1560
+ elif model_source == "huggingface-together":
1561
+ openai_base_url = "https://router.huggingface.co/together/v1"
1562
+ elif model_source == "xai":
1563
+ openai_base_url = "https://api.x.ai/v1"
1564
+ elif model_source == "openai":
1565
+ openai_base_url = "https://api.openai.com/v1"
1566
+ else:
1567
+ openai_base_url = None # Not an OpenAI-compatible provider
1568
+
1569
+ def make_image_prompt() -> str:
1570
+ """Build prompt for image mode - direct category extraction."""
1571
+ return (
1572
+ f"Identify {categories_per_chunk} {specificity} categories of content found in this image. "
1573
+ f"The image is: {image_description}. "
1574
+ f"{'Research context: ' + research_question if research_question else ''}\n\n"
1575
+ f"Number your categories from 1 through {categories_per_chunk} and provide concise labels only (no descriptions)."
1576
+ )
1577
+
1578
+ def make_describe_prompt() -> str:
1579
+ """Build prompt for 'both' mode - describe image content."""
1580
+ return (
1581
+ f"Describe the content of this image in detail. "
1582
+ f"Include all visual elements, text, objects, people, and any other content. "
1583
+ f"The image is: {image_description}. "
1584
+ f"{'Research context: ' + research_question if research_question else ''}\n\n"
1585
+ f"Provide a comprehensive text description that captures both visual and textual content."
1586
+ )
1587
+
1588
+ def make_text_prompt(text_blob: str) -> str:
1589
+ """Build prompt for extracting categories from text descriptions."""
1590
+ return (
1591
+ f"Identify {categories_per_chunk} {specificity} categories of content found in this description. "
1592
+ f"The content is: {image_description}. "
1593
+ f"{'Research context: ' + research_question + '. ' if research_question else ''}"
1594
+ f"The description is contained within triple backticks: ```{text_blob}``` "
1595
+ f"Number your categories from 1 through {categories_per_chunk} and provide concise labels only (no descriptions)."
1596
+ )
1597
+
1598
+ def call_model_with_image(img_path, prompt_text, max_retries=6):
1599
+ """Send an image to the model and get category extraction."""
1600
+ encoded, ext, is_valid = _encode_image(img_path)
1601
+ if not is_valid:
1602
+ return None
1603
+
1604
+ for attempt in range(max_retries):
1605
+ try:
1606
+ if model_source in ["openai", "huggingface", "huggingface-together", "xai"]:
1607
+ endpoint = f"{openai_base_url}/chat/completions"
1608
+ headers = {
1609
+ "Content-Type": "application/json",
1610
+ "Authorization": f"Bearer {api_key}"
1611
+ }
1612
+ messages = [{
1613
+ "role": "user",
1614
+ "content": [
1615
+ {"type": "text", "text": prompt_text},
1616
+ {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{encoded}"}}
1617
+ ]
1618
+ }]
1619
+ payload = {"model": user_model, "messages": messages}
1620
+ if creativity is not None:
1621
+ payload["temperature"] = creativity
1622
+ response = req.post(endpoint, headers=headers, json=payload, timeout=120)
1623
+ response.raise_for_status()
1624
+ result = response.json()
1625
+ return result["choices"][0]["message"]["content"]
1626
+
1627
+ elif model_source == "anthropic":
1628
+ endpoint = "https://api.anthropic.com/v1/messages"
1629
+ headers = {
1630
+ "Content-Type": "application/json",
1631
+ "x-api-key": api_key,
1632
+ "anthropic-version": "2023-06-01"
1633
+ }
1634
+ media_type = f"image/{ext}" if ext else "image/jpeg"
1635
+ content = [
1636
+ {"type": "text", "text": prompt_text},
1637
+ {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": encoded}}
1638
+ ]
1639
+ payload = {
1640
+ "model": user_model,
1641
+ "max_tokens": 2048,
1642
+ "messages": [{"role": "user", "content": content}],
1643
+ }
1644
+ if creativity is not None:
1645
+ payload["temperature"] = creativity
1646
+ response = req.post(endpoint, headers=headers, json=payload, timeout=120)
1647
+ response.raise_for_status()
1648
+ result = response.json()
1649
+ resp_content = result.get("content", [])
1650
+ if resp_content and resp_content[0].get("type") == "text":
1651
+ return resp_content[0].get("text", "")
1652
+ return None
1653
+
1654
+ elif model_source == "google":
1655
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
1656
+ headers = {"x-goog-api-key": api_key, "Content-Type": "application/json"}
1657
+ mime_type = f"image/{ext}" if ext else "image/jpeg"
1658
+ parts = [
1659
+ {"text": prompt_text},
1660
+ {"inline_data": {"mime_type": mime_type, "data": encoded}}
1661
+ ]
1662
+ payload = {
1663
+ "contents": [{"parts": parts}],
1664
+ "generationConfig": {**({"temperature": creativity} if creativity is not None else {})}
1665
+ }
1666
+ response = req.post(url, headers=headers, json=payload, timeout=120)
1667
+ response.raise_for_status()
1668
+ result = response.json()
1669
+ if "candidates" in result and result["candidates"]:
1670
+ return result["candidates"][0]["content"]["parts"][0]["text"]
1671
+ return None
1672
+
1673
+ elif model_source == "mistral":
1674
+ endpoint = "https://api.mistral.ai/v1/chat/completions"
1675
+ headers = {
1676
+ "Content-Type": "application/json",
1677
+ "Authorization": f"Bearer {api_key}"
1678
+ }
1679
+ messages = [{
1680
+ "role": "user",
1681
+ "content": [
1682
+ {"type": "text", "text": prompt_text},
1683
+ {"type": "image_url", "image_url": {"url": f"data:image/{ext};base64,{encoded}"}}
1684
+ ]
1685
+ }]
1686
+ payload = {"model": user_model, "messages": messages}
1687
+ if creativity is not None:
1688
+ payload["temperature"] = creativity
1689
+ response = req.post(endpoint, headers=headers, json=payload, timeout=120)
1690
+ response.raise_for_status()
1691
+ result = response.json()
1692
+ return result["choices"][0]["message"]["content"]
1693
+
1694
+ except Exception as e:
1695
+ delay = 2 ** attempt
1696
+ if attempt < max_retries - 1:
1697
+ print(f"Error processing image {img_path}: {e}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
1698
+ import time as _time
1699
+ _time.sleep(delay)
1700
+ else:
1701
+ print(f"Error processing image {img_path}: {e}. All {max_retries} attempts failed.")
1702
+ return None
1703
+
1704
def describe_image_with_vision(img_path, max_retries=6):
    """Use the vision model to describe an image's content as text.

    The image is encoded with ``_encode_image`` and sent, together with the
    prompt from ``make_describe_prompt()``, to the provider selected by the
    enclosing scope's ``model_source``. Any exception raised while sending
    or parsing a request is retried with exponential backoff (1s, 2s, 4s, ...).

    Args:
        img_path: Path of the image to describe.
        max_retries: Maximum number of request attempts.

    Returns:
        The model's text reply, or ``None`` when the image is not valid,
        the provider's reply carries no text, ``model_source`` is not one
        of the handled providers, or every attempt failed.
    """
    import time as _time  # stdlib; only needed for the retry backoff

    encoded, ext, is_valid = _encode_image(img_path)
    if not is_valid:
        return None
    prompt_text = make_describe_prompt()

    # Resolve the MIME type once so every provider sees the same value.
    # Without the fallback, an empty extension produced the malformed URL
    # "data:image/;base64,..." for the chat-completions style providers.
    mime_type = f"image/{ext}" if ext else "image/jpeg"
    data_url = f"data:{mime_type};base64,{encoded}"

    for attempt in range(max_retries):
        try:
            if (
                model_source in ["openai", "huggingface", "huggingface-together", "xai"]
                or model_source == "mistral"
            ):
                # Mistral uses the same chat-completions request and
                # response shape; only the endpoint differs.
                if model_source == "mistral":
                    endpoint = "https://api.mistral.ai/v1/chat/completions"
                else:
                    endpoint = f"{openai_base_url}/chat/completions"
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {api_key}",
                }
                payload = {
                    "model": user_model,
                    "messages": [{
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt_text},
                            {"type": "image_url", "image_url": {"url": data_url}},
                        ],
                    }],
                }
                if creativity is not None:
                    payload["temperature"] = creativity
                response = req.post(endpoint, headers=headers, json=payload, timeout=120)
                response.raise_for_status()
                return response.json()["choices"][0]["message"]["content"]

            elif model_source == "anthropic":
                endpoint = "https://api.anthropic.com/v1/messages"
                headers = {
                    "Content-Type": "application/json",
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01",
                }
                payload = {
                    "model": user_model,
                    "max_tokens": 4096,
                    "messages": [{
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt_text},
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": mime_type,
                                    "data": encoded,
                                },
                            },
                        ],
                    }],
                }
                if creativity is not None:
                    payload["temperature"] = creativity
                response = req.post(endpoint, headers=headers, json=payload, timeout=120)
                response.raise_for_status()
                resp_content = response.json().get("content", [])
                if resp_content and resp_content[0].get("type") == "text":
                    return resp_content[0].get("text", "")
                return None

            elif model_source == "google":
                url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
                headers = {"x-goog-api-key": api_key, "Content-Type": "application/json"}
                payload = {
                    "contents": [{
                        "parts": [
                            {"text": prompt_text},
                            {"inline_data": {"mime_type": mime_type, "data": encoded}},
                        ],
                    }],
                    "generationConfig": {"temperature": creativity} if creativity is not None else {},
                }
                response = req.post(url, headers=headers, json=payload, timeout=120)
                response.raise_for_status()
                result = response.json()
                if "candidates" in result and result["candidates"]:
                    return result["candidates"][0]["content"]["parts"][0]["text"]
                return None

            else:
                # Unhandled provider: no request can be built, so looping
                # through the remaining attempts would change nothing.
                return None

        except Exception as e:
            if attempt < max_retries - 1:
                delay = 2 ** attempt
                print(f"Error describing image {img_path}: {e}. Retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                _time.sleep(delay)
            else:
                print(f"Error describing image {img_path}: {e}. All {max_retries} attempts failed.")
                return None

    return None
1810
+
1811
def call_model_with_text(prompt_text):
    """Send ``prompt_text`` to the configured model and return its reply.

    Returns the reply text, or ``None`` when the provider is not handled,
    the reply carries no text, or any error occurs (the error is printed).
    """
    try:
        # Optional sampling temperature, merged into each payload when set.
        sampling = {} if creativity is None else {"temperature": creativity}
        user_turn = [{"role": "user", "content": prompt_text}]

        if model_source == "anthropic":
            http_reply = req.post(
                "https://api.anthropic.com/v1/messages",
                headers={
                    "Content-Type": "application/json",
                    "x-api-key": api_key,
                    "anthropic-version": "2023-06-01",
                },
                json={
                    "model": user_model,
                    "max_tokens": 2048,
                    "messages": user_turn,
                    **sampling,
                },
                timeout=120,
            )
            http_reply.raise_for_status()
            blocks = http_reply.json().get("content", [])
            if blocks and blocks[0].get("type") == "text":
                return blocks[0].get("text", "")
            return None

        if model_source == "google":
            http_reply = req.post(
                f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent",
                headers={"x-goog-api-key": api_key, "Content-Type": "application/json"},
                json={
                    "contents": [{"parts": [{"text": prompt_text}]}],
                    "generationConfig": dict(sampling),
                },
                timeout=120,
            )
            http_reply.raise_for_status()
            body = http_reply.json()
            if "candidates" in body and body["candidates"]:
                return body["candidates"][0]["content"]["parts"][0]["text"]
            return None

        # Remaining providers share the chat-completions request shape.
        if model_source == "mistral":
            chat_endpoint = "https://api.mistral.ai/v1/chat/completions"
        elif model_source in ["openai", "huggingface", "huggingface-together", "xai"]:
            chat_endpoint = f"{openai_base_url}/chat/completions"
        else:
            return None

        http_reply = req.post(
            chat_endpoint,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json={"model": user_model, "messages": user_turn, **sampling},
            timeout=120,
        )
        http_reply.raise_for_status()
        return http_reply.json()["choices"][0]["message"]["content"]

    except Exception as err:
        print(f"Error in text mode: {err}")
        return None
1881
+
1882
+ # Parse numbered list pattern
1883
+ line_pat = re.compile(r"^\s*\d+\s*[\.\)\-]\s*(.+)$")
1884
+
1885
+ all_items = []
1886
+
1887
+ # Calculate total steps for progress tracking: (iterations * divisions) + 1 for final merge
1888
+ total_steps = (iterations * divisions) + 1
1889
+ current_step = 0
1890
+
1891
+ for pass_idx in range(iterations):
1892
+ # Sample images for this pass
1893
+ image_indices = list(range(n))
1894
+ rng.shuffle(image_indices)
1895
+
1896
+ # Create chunks
1897
+ chunks = [image_indices[i:i + chunk_size] for i in range(0, len(image_indices), chunk_size)][:divisions]
1898
+
1899
+ for chunk_idx, chunk in enumerate(tqdm(chunks, desc=f"Processing chunks (pass {pass_idx+1}/{iterations})")):
1900
+ if not chunk:
1901
+ continue
1902
+
1903
+ # Sample one random image from the full pool
1904
+ random_idx = rng.choice(image_indices)
1905
+ img_path = image_files[random_idx]
1906
+
1907
+ if mode == "image":
1908
+ # IMAGE MODE: Send image directly for category extraction
1909
+ prompt = make_image_prompt()
1910
+ reply = call_model_with_image(img_path, prompt)
1911
+
1912
+ elif mode == "both":
1913
+ # BOTH MODE: Describe image first, then extract categories from description
1914
+ image_description_text = describe_image_with_vision(img_path)
1915
+ if not image_description_text:
1916
+ continue
1917
+
1918
+ prompt = make_text_prompt(image_description_text)
1919
+ reply = call_model_with_text(prompt)
1920
+
1921
+ else:
1922
+ raise ValueError(f"Invalid mode: {mode}. Must be 'image' or 'both'.")
1923
+
1924
+ if reply:
1925
+ # Extract numbered items
1926
+ items = []
1927
+ for raw_line in reply.splitlines():
1928
+ m = line_pat.match(raw_line.strip())
1929
+ if m:
1930
+ items.append(m.group(1).strip())
1931
+ # Fallback for unnumbered lines
1932
+ if not items:
1933
+ for raw_line in reply.splitlines():
1934
+ s = raw_line.strip()
1935
+ if s:
1936
+ items.append(s)
1937
+ all_items.extend(items)
1938
+
1939
+ # Progress callback
1940
+ current_step += 1
1941
+ if progress_callback:
1942
+ progress_callback(current_step, total_steps, f"Pass {pass_idx+1}/{iterations}, chunk {chunk_idx+1}/{len(chunks)}")
1943
+
1944
+ # Normalize and count
1945
def normalize_category(cat):
    """Return a canonical grouping key for a category label.

    The label is split on "/", each term is stripped and lower-cased, and
    the terms are sorted so that "Dogs / Cats" and "cats/dogs" share a key.
    Empty terms (from a trailing, leading or doubled slash) are dropped so
    that "cats/" groups with "cats" instead of producing the key "/cats".

    Args:
        cat: Category label; converted with ``str()`` before processing.

    Returns:
        The sorted, lower-cased terms joined with "/".
    """
    cleaned = (part.strip().lower() for part in str(cat).split("/"))
    return "/".join(sorted(term for term in cleaned if term))
1948
+
1949
+ flat_list = [str(x).strip() for x in all_items if str(x).strip()]
1950
+ if not flat_list:
1951
+ raise ValueError("No categories were extracted from the images.")
1952
+
1953
+ df = pd.DataFrame(flat_list, columns=["Category"])
1954
+ df["normalized"] = df["Category"].map(normalize_category)
1955
+
1956
+ result = (
1957
+ df.groupby("normalized")
1958
+ .agg(Category=("Category", lambda x: x.value_counts().index[0]),
1959
+ counts=("Category", "size"))
1960
+ .sort_values("counts", ascending=False)
1961
+ .reset_index(drop=True)
1962
+ )
1963
+
1964
+ # Second-pass semantic merge
1965
+ seed_list = result["Category"].head(max_categories * 3).tolist()
1966
+
1967
+ second_prompt = f"""
1968
+ You are a data analyst reviewing categorized image data.
1969
+
1970
+ Task: From the provided categories, identify and return the top {max_categories} CONCEPTUALLY UNIQUE categories.
1971
+
1972
+ Critical Instructions:
1973
+ 1) Exact duplicates are already removed.
1974
+ 2) Merge SEMANTIC duplicates (same concept, different wording).
1975
+ 3) When merging:
1976
+ - Combine frequencies mentally
1977
+ - Keep the most frequent OR clearest label
1978
+ - Each concept appears ONLY ONCE
1979
+ 4) Keep category names {specificity}.
1980
+ 5) Return ONLY a numbered list of {max_categories} categories. No extra text.
1981
+
1982
+ Pre-processed Categories (sorted by frequency, top sample):
1983
+ {seed_list}
1984
+
1985
+ Output:
1986
+ 1. category
1987
+ 2. category
1988
+ ...
1989
+ {max_categories}. category
1990
+ """.strip()
1991
+
1992
# Second-pass merge request.
# The provider's JSON reply is stored in its own local (`merge_reply`).
# It must not be bound to `result`: that name holds the counts DataFrame
# which is written to CSV and returned as "counts_df" further down.
top_categories_text = ""  # stays empty if no provider branch matches
try:
    if model_source in ["openai", "huggingface", "huggingface-together", "xai"]:
        endpoint = f"{openai_base_url}/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        payload = {"model": user_model, "messages": [{"role": "user", "content": second_prompt}]}
        if creativity is not None:
            payload["temperature"] = creativity
        response = req.post(endpoint, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        merge_reply = response.json()
        top_categories_text = merge_reply["choices"][0]["message"]["content"]
    elif model_source == "anthropic":
        endpoint = "https://api.anthropic.com/v1/messages"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01"
        }
        payload = {
            "model": user_model,
            "max_tokens": 2048,
            "messages": [{"role": "user", "content": second_prompt}],
        }
        if creativity is not None:
            payload["temperature"] = creativity
        response = req.post(endpoint, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        merge_reply = response.json()
        resp_content = merge_reply.get("content", [])
        if resp_content and resp_content[0].get("type") == "text":
            top_categories_text = resp_content[0].get("text", "")
        else:
            top_categories_text = ""
    elif model_source == "google":
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
        headers = {"x-goog-api-key": api_key, "Content-Type": "application/json"}
        payload = {
            "contents": [{"parts": [{"text": second_prompt}]}],
            "generationConfig": {**({"temperature": creativity} if creativity is not None else {})}
        }
        response = req.post(url, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        merge_reply = response.json()
        top_categories_text = merge_reply["candidates"][0]["content"]["parts"][0]["text"]
    elif model_source == "mistral":
        endpoint = "https://api.mistral.ai/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        payload = {"model": user_model, "messages": [{"role": "user", "content": second_prompt}]}
        if creativity is not None:
            payload["temperature"] = creativity
        response = req.post(endpoint, headers=headers, json=payload, timeout=120)
        response.raise_for_status()
        merge_reply = response.json()
        top_categories_text = merge_reply["choices"][0]["message"]["content"]
except Exception as e:
    print(f"Error in second-pass merge: {e}")
    top_categories_text = ""

# The parsing that follows calls str methods on the reply, so fall back
# to an empty string if the provider returned a null or non-text field.
if not isinstance(top_categories_text, str):
    top_categories_text = ""
2055
+
2056
+ # Final progress callback for the merge step
2057
+ if progress_callback:
2058
+ progress_callback(total_steps, total_steps, "Merging categories")
2059
+
2060
+ # Parse final list
2061
+ final = []
2062
+ for line in top_categories_text.splitlines():
2063
+ m = line_pat.match(line.strip())
2064
+ if m:
2065
+ final.append(m.group(1).strip())
2066
+ if not final:
2067
+ final = [l.strip("-*• ").strip() for l in top_categories_text.splitlines() if l.strip()]
2068
+
2069
+ print("\nTop categories:\n" + "\n".join(f"{i+1}. {c}" for i, c in enumerate(final[:max_categories])))
2070
+
2071
+ if filename:
2072
+ result.to_csv(filename, index=False)
2073
+
2074
+ return {
2075
+ "counts_df": result,
2076
+ "top_categories": final[:max_categories],
2077
+ "raw_top_text": top_categories_text
2078
+ }