cat-llm 0.0.17__tar.gz → 0.0.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-llm
3
- Version: 0.0.17
3
+ Version: 0.0.19
4
4
  Summary: A tool for categorizing text data and images using LLMs and vision models
5
- Project-URL: Documentation, https://github.com/Christopher Soria/cat-llm#readme
5
+ Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
7
7
  Project-URL: Source, https://github.com/chrissoria/cat-llm
8
8
  Author-email: Christopher Soria <chrissoria@berkeley.edu>
@@ -30,7 +30,7 @@ dependencies = [
30
30
  ]
31
31
 
32
32
  [project.urls]
33
- Documentation = "https://github.com/Christopher Soria/cat-llm#readme"
33
+ Documentation = "https://github.com/chrissoria/cat-llm#readme"
34
34
  Issues = "https://github.com/chrissoria/cat-llm/issues"
35
35
  Source = "https://github.com/chrissoria/cat-llm"
36
36
 
@@ -45,13 +45,15 @@ extra-dependencies = [
45
45
  check = "mypy --install-types --non-interactive {args:src/catllm tests}"
46
46
 
47
47
  [tool.hatch.build.targets.wheel]
48
- packages = ["src/catllm"]
48
+ packages = ["src/catllm"] # Keep the src/ prefix for wheel
49
49
  include = [
50
+ "src/catllm/**/*.py",
50
51
  "src/catllm/images/*",
51
52
  ]
52
53
 
53
54
  [tool.hatch.build.targets.sdist]
54
55
  include = [
56
+ "src/catllm/**/*.py",
55
57
  "src/catllm/images/*",
56
58
  ]
57
59
 
@@ -0,0 +1,10 @@
1
+ # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+ __version__ = "0.0.19"
5
+ __author__ = "Chris Soria"
6
+ __email__ = "chrissoria@berkeley.edu"
7
+ __title__ = "cat-llm"
8
+ __description__ = "A tool for categorizing and exploring text data and images using LLMs and vision models"
9
+ __url__ = "https://github.com/chrissoria/cat-llm"
10
+ __license__ = "MIT"
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
+ #
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ from .__about__ import (
6
+ __version__,
7
+ __author__,
8
+ __description__,
9
+ __title__,
10
+ __url__,
11
+ __license__,
12
+ )
13
+
14
+ from .cat_llm import *
@@ -0,0 +1,1075 @@
1
#multi-class text classification
def extract_multi_class(
    survey_question,
    survey_input,
    categories,
    api_key,
    columns="numbered",
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Categorize open-ended survey responses into multiple categories with an LLM.

    For each response in ``survey_input`` the model is asked which of the
    ``categories`` apply and must answer with a JSON object mapping category
    numbers ("1", "2", ...) to 1 (present) or 0 (absent).

    Parameters
    ----------
    survey_question : str
        The question the respondents were asked (gives the model context).
    survey_input : pandas.Series
        Open-ended responses to categorize; NaN entries are skipped.
    categories : list of str
        Category labels the model chooses among.
    api_key : str
        Key for the provider selected via ``model_source``.
    columns : str, optional
        "numbered" keeps the numeric indicator column names; any other value
        renames the indicator columns to the category text.
    user_model : str, optional
        Provider model identifier.
    creativity : float, optional
        Sampling temperature passed to the model.
    to_csv : bool, optional
        Write the final DataFrame to ``save_directory/filename``.
    safety : bool, optional
        Re-write partial progress to CSV after every response.
    filename, save_directory : str, optional
        Output location; ``save_directory`` defaults to the working directory.
    model_source : str, optional
        One of "OpenAI", "Perplexity", "Anthropic", "Mistral".

    Returns
    -------
    pandas.DataFrame
        Columns: survey_response, link1 (raw reply), json (extracted JSON),
        plus one indicator column per category.

    Raises
    ------
    ValueError
        If ``model_source`` is not a recognized provider.
    """
    import os
    import json
    import pandas as pd
    import regex  # third-party `regex`, not stdlib `re`: the recursive (?R) pattern needs it
    from tqdm import tqdm

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    cat_num = len(categories)
    # All-zero JSON object used as the placeholder row for skipped NaN inputs.
    category_dict = {str(i + 1): "0" for i in range(cat_num)}
    example_JSON = json.dumps(category_dict, indent=4)

    # Echo the categories so the user can confirm the list is what they intended.
    print("\nThe categories you entered:")
    for i, cat in enumerate(categories, 1):
        print(f"{i}. {cat}")

    link1 = []            # raw model replies (or skip/error notes), one per response
    extracted_jsons = []  # JSON fragment extracted from each reply, one per response

    def _query_model(prompt):
        """Send `prompt` to the configured provider; append the reply (or the
        error text) to `link1` and return the reply, or None on failure."""
        try:
            if model_source == "OpenAI":
                from openai import OpenAI
                client = OpenAI(api_key=api_key)
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Perplexity":
                from openai import OpenAI
                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Anthropic":
                import anthropic
                client = anthropic.Anthropic(api_key=api_key)
                message = client.messages.create(
                    model=user_model,
                    max_tokens=1024,
                    temperature=creativity,
                    messages=[{"role": "user", "content": prompt}]
                )
                reply = message.content[0].text  # Anthropic returns content as list
            elif model_source == "Mistral":
                from mistralai import Mistral
                client = Mistral(api_key=api_key)
                # Fixed: previously bound to `response`, shadowing the loop variable.
                response_obj = client.chat.complete(
                    model=user_model,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            else:
                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
        except ValueError:
            raise  # unknown model_source is a caller error, not an API failure
        except Exception as e:
            print(f"An error occurred: {e}")
            link1.append(f"Error processing input: {e}")
            return None
        link1.append(reply)
        return reply

    def _extract_json(reply):
        """Pull the first balanced JSON object out of `reply`; sentinel on failure."""
        if reply is not None:
            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
            if extracted_json:
                # Strip brackets/whitespace so pd.json_normalize sees a flat object.
                return extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
            error_message = """{"1":"e"}"""
            print(error_message)
            return error_message
        return """{"1":"e"}"""

    def _normalize(json_strs):
        """Parse each JSON string into a one-row frame; an 'e' row when unparsable."""
        normalized_data_list = []
        for json_str in json_strs:
            try:
                parsed_obj = json.loads(json_str)
                normalized_data_list.append(pd.json_normalize(parsed_obj))
            except json.JSONDecodeError:
                normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
        return pd.concat(normalized_data_list, ignore_index=True)

    for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
        if pd.isna(response):
            # Nothing to send: record the skip and an all-zero placeholder JSON.
            link1.append("Skipped NaN input")
            extracted_jsons.append(example_JSON)
        else:
            prompt = f"""A respondent was asked: {survey_question}. \
Categorize this survey response "{response}" into the following categories that apply: \
{categories_str} \
Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
            reply = _query_model(prompt)
            extracted_jsons.append(_extract_json(reply))

        # --- Safety Save: persist everything processed so far after each row ---
        if safety:
            temp_df = pd.DataFrame({
                'survey_response': survey_input[:idx + 1],
                'link1': link1,
                'json': extracted_jsons
            })
            temp_df = pd.concat([temp_df, _normalize(extracted_jsons)], axis=1)
            if save_directory is None:
                save_directory = os.getcwd()
            temp_df.to_csv(os.path.join(save_directory, filename), index=False)

    # --- Final DataFrame ---
    normalized_data = _normalize(extracted_jsons)

    categorized_data = pd.DataFrame({
        'survey_response': survey_input.reset_index(drop=True),
        'link1': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)

    if columns != "numbered":  # user asked for text column names
        categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]

    if to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)

    return categorized_data
175
+
176
# image multi-class (binary) function
def extract_image_multi_class(
    image_description,
    image_input,
    categories,
    api_key,
    columns="numbered",
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Tag images with per-category presence (1) / absence (0) using a vision LLM.

    Each image is sent with a prompt asking, for every entry in
    ``categories``, whether it is present; the model must answer with a JSON
    object mapping category numbers to 1 or 0.

    Parameters
    ----------
    image_description : str
        What the images are expected to show (gives the model context).
    image_input : str or list
        A directory to glob for images, or an explicit list of file paths.
    categories : list of str
        Categories to score on each image.
    api_key, columns, user_model, creativity, to_csv, safety, filename,
    save_directory, model_source : see ``extract_multi_class``.

    Returns
    -------
    pandas.DataFrame
        Columns: image_input, link1 (raw reply), json, plus one indicator
        column per category.

    Raises
    ------
    FileNotFoundError
        If ``save_directory`` is given but does not exist.
    ValueError
        If ``model_source`` is not a recognized provider.
    """
    import os
    import json
    import pandas as pd
    import regex  # third-party `regex`: recursive (?R) pattern unsupported by stdlib `re`
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    # Fail fast before any API spend if the output directory is wrong.
    if save_directory is not None and not os.path.isdir(save_directory):
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if not isinstance(image_input, list):
        # image_input is a directory path: collect every recognized image file.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))
        print(f"Found {len(image_files)} images.")
    else:
        # Caller supplied an explicit list of paths.
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    cat_num = len(categories)
    category_dict = {str(i + 1): "0" for i in range(cat_num)}
    example_JSON = json.dumps(category_dict, indent=4)

    # Echo the categories so the user can confirm the list.
    print("Categories to classify:")
    for i, cat in enumerate(categories, 1):
        print(f"{i}. {cat}")

    link1 = []            # raw model replies (or skip/error notes)
    extracted_jsons = []  # JSON fragment extracted from each reply

    def _query_model(prompt):
        """Send `prompt` to the configured provider; append the reply (or the
        error text) to `link1` and return the reply, or None on failure."""
        try:
            if model_source == "OpenAI":
                from openai import OpenAI
                client = OpenAI(api_key=api_key)
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Perplexity":
                from openai import OpenAI
                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Anthropic":
                import anthropic
                client = anthropic.Anthropic(api_key=api_key)
                message = client.messages.create(
                    model=user_model,
                    max_tokens=1024,
                    temperature=creativity,
                    messages=[{"role": "user", "content": prompt}]
                )
                reply = message.content[0].text  # Anthropic returns content as list
            elif model_source == "Mistral":
                from mistralai import Mistral
                client = Mistral(api_key=api_key)
                response_obj = client.chat.complete(
                    model=user_model,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            else:
                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
        except ValueError:
            raise  # unknown model_source is a caller error, not an API failure
        except Exception as e:
            print(f"An error occurred: {e}")
            link1.append(f"Error processing input: {e}")
            return None
        link1.append(reply)
        return reply

    def _extract_json(reply):
        """Pull the first balanced JSON object out of `reply`; sentinel on failure."""
        if reply is not None:
            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
            if extracted_json:
                return extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
            error_message = """{"1":"e"}"""
            print(error_message)
            return error_message
        return """{"1":"e"}"""

    def _normalize(json_strs):
        """Parse each JSON string into a one-row frame; an 'e' row when unparsable."""
        normalized_data_list = []
        for json_str in json_strs:
            try:
                parsed_obj = json.loads(json_str)
                normalized_data_list.append(pd.json_normalize(parsed_obj))
            except json.JSONDecodeError:
                normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
        return pd.concat(normalized_data_list, ignore_index=True)

    for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
        # Fixed: `reply` is now reset each iteration. Previously, a failed API
        # call left the prior iteration's reply in scope (or the name unbound
        # on the first iteration), corrupting the extracted JSON for that row.
        reply = None

        # Skip missing / invalid paths before touching the filesystem.
        if img_path is None or not os.path.exists(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        # Encode the image once per iteration as a base64 data URL.
        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        ext = Path(img_path).suffix.lstrip(".").lower()
        encoded_image = f"data:image/{ext};base64,{encoded}"

        prompt = [
            {
                "type": "text",
                "text": (
                    f"You are an image-tagging assistant.\n"
                    f"Task ► Examine the attached image and decide, **for each category below**, "
                    f"whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
                    f"Image is expected to show: {image_description}\n\n"
                    f"Categories:\n{categories_str}\n\n"
                    f"Output format ► Respond with **only** a JSON object whose keys are the "
                    f"quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
                    f"No additional keys, comments, or text.\n\n"
                    f"Example (three categories):\n"
                    f"{example_JSON}"
                ),
            },
            {
                "type": "image_url",
                "image_url": {"url": encoded_image, "detail": "high"},
            },
        ]

        reply = _query_model(prompt)
        extracted_jsons.append(_extract_json(reply))

        # --- Safety Save: persist everything processed so far after each image ---
        if safety:
            temp_df = pd.DataFrame({
                'image_input': image_files[:i + 1],
                'link1': link1,
                'json': extracted_jsons
            })
            temp_df = pd.concat([temp_df, _normalize(extracted_jsons)], axis=1)
            if save_directory is None:
                save_directory = os.getcwd()
            temp_df.to_csv(os.path.join(save_directory, filename), index=False)

    # --- Final DataFrame ---
    normalized_data = _normalize(extracted_jsons)

    categorized_data = pd.DataFrame({
        'image_input': image_files,
        'link1': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)

    if columns != "numbered":  # user asked for text column names
        categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]

    if to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)

    return categorized_data
401
+
402
#image score function
def extract_image_score(
    reference_image_description,
    image_input,
    reference_image,
    api_key,
    columns="numbered",
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Score each input image's similarity to a reference image on a 1–5 scale.

    The reference and each candidate image are sent together; the model must
    answer with a JSON object containing a "score" (1–5) and a "summary".

    Parameters
    ----------
    reference_image_description : str
        Description of the reference image, included in the prompt.
    image_input : str or list
        A directory to glob for images, or an explicit list of file paths.
    reference_image : str
        Path to the reference image file.
    api_key, user_model, creativity, to_csv, safety, filename,
    save_directory, model_source : see ``extract_multi_class``.
    columns : str, optional
        Accepted for signature consistency with the sibling functions.
        NOTE(review): currently unused here — there is no column renaming
        step in this function.

    Returns
    -------
    pandas.DataFrame
        Columns: image_input, link1 (raw reply), json, plus the normalized
        score/summary columns.

    Raises
    ------
    FileNotFoundError
        If ``save_directory`` is given but does not exist.
    ValueError
        If ``model_source`` is not a recognized provider.
    """
    import os
    import json
    import pandas as pd
    import regex  # third-party `regex`: recursive (?R) pattern unsupported by stdlib `re`
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    # Fail fast before any API spend if the output directory is wrong.
    if save_directory is not None and not os.path.isdir(save_directory):
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if not isinstance(image_input, list):
        # image_input is a directory path: collect every recognized image file.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))
        print(f"Found {len(image_files)} images.")
    else:
        # Caller supplied an explicit list of paths.
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")

    # Encode the reference image once as a data URL. This deliberately rebinds
    # the `reference_image` parameter from a path to the encoded payload.
    with open(reference_image, 'rb') as f:
        reference_image = f"data:image/{reference_image.split('.')[-1]};base64,{base64.b64encode(f.read()).decode('utf-8')}"

    link1 = []            # raw model replies (or skip/error notes)
    extracted_jsons = []  # JSON fragment extracted from each reply

    def _query_model(prompt):
        """Send `prompt` to the configured provider; append the reply (or the
        error text) to `link1` and return the reply, or None on failure."""
        try:
            if model_source == "OpenAI":
                from openai import OpenAI
                client = OpenAI(api_key=api_key)
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Perplexity":
                from openai import OpenAI
                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Anthropic":
                import anthropic
                client = anthropic.Anthropic(api_key=api_key)
                message = client.messages.create(
                    model=user_model,
                    max_tokens=1024,
                    temperature=creativity,
                    messages=[{"role": "user", "content": prompt}]
                )
                reply = message.content[0].text  # Anthropic returns content as list
            elif model_source == "Mistral":
                from mistralai import Mistral
                client = Mistral(api_key=api_key)
                response_obj = client.chat.complete(
                    model=user_model,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            else:
                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
        except ValueError:
            raise  # unknown model_source is a caller error, not an API failure
        except Exception as e:
            print(f"An error occurred: {e}")
            link1.append(f"Error processing input: {e}")
            return None
        link1.append(reply)
        return reply

    def _extract_json(reply):
        """Pull the first balanced JSON object out of `reply`; sentinel on failure."""
        if reply is not None:
            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
            if extracted_json:
                return extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
            error_message = """{"1":"e"}"""
            print(error_message)
            return error_message
        return """{"1":"e"}"""

    def _normalize(json_strs):
        """Parse each JSON string into a one-row frame; an 'e' row when unparsable."""
        normalized_data_list = []
        for json_str in json_strs:
            try:
                parsed_obj = json.loads(json_str)
                normalized_data_list.append(pd.json_normalize(parsed_obj))
            except json.JSONDecodeError:
                normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
        return pd.concat(normalized_data_list, ignore_index=True)

    for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
        # Fixed: `reply` is now reset each iteration. Previously, a failed API
        # call left the prior iteration's reply in scope (or the name unbound
        # on the first iteration), corrupting the extracted JSON for that row.
        reply = None

        # Skip missing / invalid paths before touching the filesystem.
        if img_path is None or not os.path.exists(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        # Encode the candidate image once per iteration as a base64 data URL.
        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        ext = Path(img_path).suffix.lstrip(".").lower()
        encoded_image = f"data:image/{ext};base64,{encoded}"

        prompt = [
            {
                "type": "text",
                "text": (
                    f"You are a visual similarity assessment system.\n"
                    f"Task ► Compare these two images:\n"
                    f"1. REFERENCE (left): {reference_image_description}\n"
                    f"2. INPUT (right): User-provided drawing\n\n"
                    f"Rating criteria:\n"
                    f"1: No meaningful similarity (fundamentally different)\n"
                    f"2: Barely recognizable similarity (25% match)\n"
                    f"3: Partial match (50% key features)\n"
                    f"4: Strong alignment (75% features)\n"
                    f"5: Near-perfect match (90%+ similarity)\n\n"
                    f"Output format ► Return ONLY:\n"
                    "{\n"
                    '  "score": [1-5],\n'
                    '  "summary": "reason you scored"\n'
                    "}\n\n"
                    f"Critical rules:\n"
                    f"- Score must reflect shape, proportions, and key details\n"
                    f"- List only concrete matching elements from reference\n"
                    f"- No markdown or additional text"
                ),
            },
            {
                "type": "image_url",
                "image_url": {"url": reference_image, "detail": "high"}
            },
            {
                "type": "image_url",
                "image_url": {"url": encoded_image, "detail": "high"},
            },
        ]

        reply = _query_model(prompt)
        extracted_jsons.append(_extract_json(reply))

        # --- Safety Save: persist everything processed so far after each image ---
        if safety:
            temp_df = pd.DataFrame({
                'image_input': image_files[:i + 1],
                'link1': link1,
                'json': extracted_jsons
            })
            temp_df = pd.concat([temp_df, _normalize(extracted_jsons)], axis=1)
            if save_directory is None:
                save_directory = os.getcwd()
            temp_df.to_csv(os.path.join(save_directory, filename), index=False)

    # --- Final DataFrame ---
    normalized_data = _normalize(extracted_jsons)

    categorized_data = pd.DataFrame({
        'image_input': image_files,
        'link1': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)

    if to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)

    return categorized_data
629
+
630
# image features function
def extract_image_features(
    image_description,
    image_input,
    features_to_extract,
    api_key,
    columns="numbered",
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Answer a fixed set of questions about each image with a vision LLM.

    For each image, the model is asked every question in
    ``features_to_extract`` and must answer with a JSON object mapping
    question numbers ("1", "2", ...) to concise answers.

    Parameters
    ----------
    image_description : str
        What the images are expected to show (gives the model context).
    image_input : str or list
        A directory to glob for images, or an explicit list of file paths.
    features_to_extract : list of str
        Questions to answer about each image.
    api_key, columns, user_model, creativity, to_csv, safety, filename,
    save_directory, model_source : see ``extract_multi_class``.

    Returns
    -------
    pandas.DataFrame
        Columns: image_input, link1 (raw reply), json, plus one answer
        column per question.

    Raises
    ------
    FileNotFoundError
        If ``save_directory`` is given but does not exist.
    ValueError
        If ``model_source`` is not a recognized provider.
    """
    import os
    import json
    import pandas as pd
    import regex  # third-party `regex`: recursive (?R) pattern unsupported by stdlib `re`
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    # Fail fast before any API spend if the output directory is wrong.
    if save_directory is not None and not os.path.isdir(save_directory):
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if not isinstance(image_input, list):
        # image_input is a directory path: collect every recognized image file.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))
        print(f"Found {len(image_files)} images.")
    else:
        # Caller supplied an explicit list of paths.
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(features_to_extract))

    # Echo the questions so the user can confirm the list.
    print("\nThe image features to be extracted are:")
    for i, cat in enumerate(features_to_extract, 1):
        print(f"{i}. {cat}")

    link1 = []            # raw model replies (or skip/error notes)
    extracted_jsons = []  # JSON fragment extracted from each reply

    def _query_model(prompt):
        """Send `prompt` to the configured provider; append the reply (or the
        error text) to `link1` and return the reply, or None on failure."""
        try:
            if model_source == "OpenAI":
                from openai import OpenAI
                client = OpenAI(api_key=api_key)
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Perplexity":
                from openai import OpenAI
                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            elif model_source == "Anthropic":
                import anthropic
                client = anthropic.Anthropic(api_key=api_key)
                message = client.messages.create(
                    model=user_model,
                    max_tokens=1024,
                    temperature=creativity,
                    messages=[{"role": "user", "content": prompt}]
                )
                reply = message.content[0].text  # Anthropic returns content as list
            elif model_source == "Mistral":
                from mistralai import Mistral
                client = Mistral(api_key=api_key)
                response_obj = client.chat.complete(
                    model=user_model,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
            else:
                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
        except ValueError:
            raise  # unknown model_source is a caller error, not an API failure
        except Exception as e:
            print(f"An error occurred: {e}")
            link1.append(f"Error processing input: {e}")
            return None
        link1.append(reply)
        return reply

    def _extract_json(reply):
        """Pull the first balanced JSON object out of `reply`; sentinel on failure."""
        if reply is not None:
            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
            if extracted_json:
                return extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
            error_message = """{"1":"e"}"""
            print(error_message)
            return error_message
        return """{"1":"e"}"""

    def _normalize(json_strs):
        """Parse each JSON string into a one-row frame; an 'e' row when unparsable."""
        normalized_data_list = []
        for json_str in json_strs:
            try:
                parsed_obj = json.loads(json_str)
                normalized_data_list.append(pd.json_normalize(parsed_obj))
            except json.JSONDecodeError:
                normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
        return pd.concat(normalized_data_list, ignore_index=True)

    for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
        # Fixed: `reply` is now reset each iteration. Previously, a failed API
        # call left the prior iteration's reply in scope (or the name unbound
        # on the first iteration), corrupting the extracted JSON for that row.
        reply = None

        # Fixed: added the invalid-path guard that the sibling image functions
        # already have; previously a bad path crashed the whole run on open().
        if img_path is None or not os.path.exists(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        # Encode the image once per iteration as a base64 data URL.
        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        ext = Path(img_path).suffix.lstrip(".").lower()
        encoded_image = f"data:image/{ext};base64,{encoded}"

        prompt = [
            {
                "type": "text",
                "text": (
                    f"You are a visual question answering assistant.\n"
                    f"Task ► Analyze the attached image and answer these specific questions:\n\n"
                    f"Image context: {image_description}\n\n"
                    f"Questions to answer:\n{categories_str}\n\n"
                    f"Output format ► Return **only** a JSON object where:\n"
                    f"- Keys are question numbers ('1', '2', ...)\n"
                    f"- Values are concise answers (numbers, short phrases)\n\n"
                    f"Example for 3 questions:\n"
                    "{\n"
                    '  "1": "4",\n'
                    '  "2": "blue",\n'
                    '  "3": "yes"\n'
                    "}\n\n"
                    f"Important rules:\n"
                    f"1. Answer directly - no explanations\n"
                    f"2. Use exact numerical values when possible\n"
                    f"3. For yes/no questions, use 'yes' or 'no'\n"
                    f"4. Never add extra keys or formatting"
                ),
            },
            {
                "type": "image_url",
                "image_url": {"url": encoded_image, "detail": "high"},
            },
        ]

        reply = _query_model(prompt)
        extracted_jsons.append(_extract_json(reply))

        # --- Safety Save: persist everything processed so far after each image ---
        if safety:
            temp_df = pd.DataFrame({
                'image_input': image_files[:i + 1],
                'link1': link1,
                'json': extracted_jsons
            })
            temp_df = pd.concat([temp_df, _normalize(extracted_jsons)], axis=1)
            if save_directory is None:
                save_directory = os.getcwd()
            temp_df.to_csv(os.path.join(save_directory, filename), index=False)

    # --- Final DataFrame ---
    normalized_data = _normalize(extracted_jsons)

    categorized_data = pd.DataFrame({
        'image_input': image_files,
        'link1': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)

    if columns != "numbered":  # user asked for text column names
        # Fixed: previously referenced undefined `categories` (NameError);
        # this function's label list is `features_to_extract`.
        categorized_data.columns = list(categorized_data.columns[:3]) + features_to_extract[:len(categorized_data.columns) - 3]

    if to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)

    return categorized_data
856
+
857
#extract categories from corpus
def explore_corpus(
    survey_question,
    survey_input,
    api_key,
    research_question=None,
    specificity="broad",
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    filename="corpus_exploration.csv",
    model_source="OpenAI"
):
    """Extract candidate answer categories from an open-ended survey corpus.

    Samples `divisions` random chunks from `survey_input` and asks an LLM to
    name `cat_num` categories per chunk, then aggregates the labels into a
    frequency-ranked, de-duplicated table.

    Parameters
    ----------
    survey_question : str
        The survey question the responses answer (inserted into the prompt).
    survey_input : pandas.Series
        Open-ended text responses; must support ``.sample()`` and ``len()``.
    api_key : str
        Provider API key.
    research_question : str, optional
        Extra context added to the system prompt when given.
    specificity : str
        Granularity of the requested categories (e.g. "broad").
    cat_num : int
        Number of categories requested per chunk.
    divisions : int
        Number of random chunks sampled from the corpus.
    user_model : str
        Model identifier passed to the provider.
    creativity : float
        Sampling temperature.
    filename : str or None
        CSV path for the resulting table; ``None`` skips saving.
    model_source : str
        Provider name; only "OpenAI" is currently supported.

    Returns
    -------
    pandas.DataFrame
        Columns 'Category' and 'counts', sorted by frequency, de-duplicated.

    Raises
    ------
    ValueError
        If `model_source` is unsupported, the chunks are too small for
        `cat_num`, or the model's context window is exceeded.
    """
    import pandas as pd

    # Fail fast on an unsupported provider before doing any work.
    # (The original only raised this inside the per-chunk loop.)
    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
    print()

    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    # Heavy third-party imports are deferred until after validation so bad
    # arguments fail quickly and cheaply (also removes the duplicate
    # `from openai import OpenAI` of the original).
    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    random_chunks = [survey_input.sample(n=chunk_size).tolist() for _ in range(divisions)]

    responses = []
    responses_list = []
    client = OpenAI(api_key=api_key)

    for i in tqdm(range(divisions), desc="Processing chunks"):
        survey_participant_chunks = '; '.join(random_chunks[i])
        prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
Responses are each separated by a semicolon. \
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""

        reply = ""  # fallback so a failed call still yields an (empty) entry
        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
The specific task is to identify {specificity} categories of responses to a survey question. \
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
                    {'role': 'user', 'content': prompt}
                ],  # BUG FIX: the original omitted this comma -> SyntaxError
                temperature=creativity
            )
            reply = response_obj.choices[0].message.content
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                raise ValueError(f"Token limit exceeded for model {user_model}. "
                                 f"Try increasing the 'iterations' parameter to create smaller chunks.")
            print(f"OpenAI API error: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")
        # BUG FIX: always append, so responses[i] below cannot IndexError
        # when an API call fails.
        responses.append(reply)

        # Keep only the text after "N. " on each numbered line of the reply.
        items = []
        for line in responses[i].split('\n'):
            if '. ' in line:
                try:
                    items.append(line.split('. ', 1)[1])
                except IndexError:
                    pass
        responses_list.append(items)

    flat_list = [item.lower() for sublist in responses_list for item in sublist]

    # Build a frequency-ranked, de-duplicated category table.
    df = pd.DataFrame(flat_list, columns=['Category'])
    counts = pd.Series(flat_list).value_counts()  # count before de-duplication
    df['counts'] = df['Category'].map(counts)
    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)

    if filename is not None:
        df.to_csv(filename, index=False)

    return df
956
+
957
#extract top categories from corpus
def explore_common_categories(
    survey_question,
    survey_input,
    api_key,
    top_n=10,
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    specificity="broad",
    research_question=None,
    filename=None,
    model_source="OpenAI"
):
    """Identify the `top_n` most common answer categories in a survey corpus.

    First extracts `cat_num` categories from each of `divisions` random chunks
    of `survey_input` (same procedure as ``explore_corpus``), then asks the
    model a second time to pick the `top_n` most common categories from the
    aggregated list.

    Parameters
    ----------
    survey_question : str
        The survey question the responses answer (inserted into the prompt).
    survey_input : pandas.Series
        Open-ended text responses; must support ``.sample()`` and ``len()``.
    api_key : str
        Provider API key.
    top_n : int
        Number of most-common categories to return.
    cat_num : int
        Number of categories requested per chunk.
    divisions : int
        Number of random chunks sampled from the corpus.
    user_model : str
        Model identifier passed to the provider.
    creativity : float
        Sampling temperature.
    specificity : str
        Granularity of the requested categories (e.g. "broad").
    research_question : str, optional
        Extra context added to the system prompt when given.
    filename : str or None
        CSV path for the resulting top-category list; ``None`` skips saving.
    model_source : str
        Provider name; only "OpenAI" is currently supported.

    Returns
    -------
    list[str]
        The top categories, most to least common.

    Raises
    ------
    ValueError
        If `model_source` is unsupported, the chunks are too small for
        `cat_num`, or the model's context window is exceeded.
    """
    import pandas as pd

    # Fail fast on an unsupported provider before doing any work.
    # (The original only raised this inside the per-chunk loop.)
    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
    print()

    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    # Heavy third-party imports are deferred until after validation so bad
    # arguments fail quickly and cheaply (also removes the duplicate
    # `from openai import OpenAI` of the original).
    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    random_chunks = [survey_input.sample(n=chunk_size).tolist() for _ in range(divisions)]

    responses = []
    responses_list = []
    client = OpenAI(api_key=api_key)

    for i in tqdm(range(divisions), desc="Processing chunks"):
        survey_participant_chunks = '; '.join(random_chunks[i])
        prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
Responses are each separated by a semicolon. \
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""

        reply = ""  # fallback so a failed call still yields an (empty) entry
        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
The specific task is to identify {specificity} categories of responses to a survey question. \
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
                    {'role': 'user', 'content': prompt}
                ],
                temperature=creativity
            )
            reply = response_obj.choices[0].message.content
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                raise ValueError(f"Token limit exceeded for model {user_model}. "
                                 f"Try increasing the 'iterations' parameter to create smaller chunks.")
            print(f"OpenAI API error: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")
        # BUG FIX: always append, so responses[i] below cannot IndexError
        # when an API call fails.
        responses.append(reply)

        # Keep only the text after "N. " on each numbered line of the reply.
        items = []
        for line in responses[i].split('\n'):
            if '. ' in line:
                try:
                    items.append(line.split('. ', 1)[1])
                except IndexError:
                    pass
        responses_list.append(items)

    flat_list = [item.lower() for sublist in responses_list for item in sublist]

    # Build a frequency-ranked, de-duplicated category table.
    df = pd.DataFrame(flat_list, columns=['Category'])
    counts = pd.Series(flat_list).value_counts()  # count before de-duplication
    df['counts'] = df['Category'].map(counts)
    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)

    # Second pass: ask the model to rank the aggregated categories.
    second_prompt = f"""From this list of categories, extract the top {top_n} most common categories. \
The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` \
Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, with no additional text or explanation."""

    response_obj = client.chat.completions.create(
        model=user_model,
        messages=[{'role': 'user', 'content': second_prompt}],
        temperature=creativity
    )
    top_categories = response_obj.choices[0].message.content
    print(top_categories)

    top_categories_final = []
    for line in top_categories.split('\n'):
        if '. ' in line:
            try:
                top_categories_final.append(line.split('. ', 1)[1])
            except IndexError:
                pass

    # BUG FIX: 'filename' was accepted but silently ignored; persist the
    # result like explore_corpus does (default None keeps old behavior).
    if filename is not None:
        pd.DataFrame(top_categories_final, columns=['Category']).to_csv(filename, index=False)

    return top_categories_final
File without changes
File without changes