cat-llm 0.0.25__tar.gz → 0.0.26__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cat-llm
-Version: 0.0.25
+Version: 0.0.26
 Summary: A tool for categorizing text data and images using LLMs and vision models
 Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
 Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -2,6 +2,12 @@
 # specifically for pictures of drawings of shapes like circles, diamonds, rectangles, and cubes
 
 """
+Key features:
+1. Shape-specific scoring: The function can handle different shapes (circle, diamond, rectangles, cube) and provides tailored categories for each shape.
+2. Image input handling: It accepts image inputs either as file paths or as a list of images.
+3. Model flexibility: The function allows users to specify different models (OpenAI, Anthropic, Perplexity, Mistral) for image analysis.
+4. Safety and progress saving: It can save progress to a CSV file, which is useful for long-running tasks or when processing many images.
+
 Areas for improvement:
 1. Prompt refinement: adjusting the prompt so that it produces a more accurate score.
 2. Image preprocessing: adjusting the images so that they are easier to be analyzed by the models.
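For orientation, here is a minimal, hypothetical call sketch for the scorer this hunk documents, using only the parameters visible in this diff; the import path, file path, and API key are placeholder assumptions, not part of the package:

# Hypothetical usage sketch for cerad_drawn_score, inferred from this diff.
from catllm.CERAD_functions import cerad_drawn_score  # top-level package name assumed

scores = cerad_drawn_score(
    shape="circle",             # circle, diamond, rectangles / overlapping rectangles, or cube
    image_input="drawing.png",  # a file path or a list of images, per the docstring above
    api_key="YOUR_API_KEY",     # placeholder
)                               # user_model now defaults to "gpt-4o" (see the next hunk)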
@@ -19,7 +25,7 @@ def cerad_drawn_score(
     shape,
     image_input,
     api_key,
-    user_model="gpt-4o-2024-11-20",
+    user_model="gpt-4o",
     creativity=0,
     reference_in_image=False,
     provide_reference=False,
@@ -40,8 +46,8 @@ def cerad_drawn_score(
 
     if shape == "circle":
         categories = ["The image contains a drawing that clearly represents a circle",
-                      "The drawing does not resemble a circle",
-                      "The drawing resembles a circle",
+                      "The image does NOT contain any drawing that resembles a circle",
+                      "The image contains a drawing that resembles a circle",
                       "The circle is closed",
                       "The circle is almost closed",
                       "The circle is circular",
@@ -58,12 +64,12 @@ def cerad_drawn_score(
                       "None of the above descriptions apply"]
     elif shape == "rectangles" or shape == "overlapping rectangles":
         categories = ["The image contains a drawing that clearly represents overlapping rectangles",
-                      "A drawn shape DOES NOT resemble a overlapping rectangles",
-                      "A drawn shape resembles a overlapping rectangles",
-                      "Rectangle 1 has 4 sides",
-                      "Rectangle 2 has 4 sides",
-                      "The rectangles are overlapping",
-                      "The rectangles overlap contains a longer vertical rectangle with top and bottom portruding",
+                      "The image does NOT contain any drawing that resembles overlapping rectangles",
+                      "The image contains a drawing that resembles overlapping rectangles",
+                      "If rectangle 1 is present it has 4 sides",
+                      "If rectangle 2 is present it has 4 sides",
+                      "The drawn rectangles are overlapping",
+                      "The drawn rectangles overlap to form a longer vertical rectangle with top and bottom sticking out",
                       "None of the above descriptions apply"]
     elif shape == "cube":
         categories = ["The image contains a drawing that clearly represents a cube (3D box shape)",
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
 #
 # SPDX-License-Identifier: MIT
-__version__ = "0.0.25"
+__version__ = "0.0.26"
 __author__ = "Chris Soria"
 __email__ = "chrissoria@berkeley.edu"
 __title__ = "cat-llm"
@@ -12,4 +12,5 @@ from .__about__ import (
 )
 
 from .cat_llm import *
-from .CERAD_functions import *
+from .CERAD_functions import *
+from .image_functions import *
@@ -0,0 +1,395 @@
+#extract categories from corpus
+def explore_corpus(
+    survey_question,
+    survey_input,
+    api_key,
+    research_question=None,
+    specificity="broad",
+    cat_num=10,
+    divisions=5,
+    user_model="gpt-4o-2024-11-20",
+    creativity=0,
+    filename="corpus_exploration.csv",
+    model_source="OpenAI"
+):
+    import os
+    import pandas as pd
+    import random
+    from openai import OpenAI, BadRequestError
+    from tqdm import tqdm
+
+    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
+    print()
+
+    chunk_size = round(max(1, len(survey_input) / divisions), 0)
+    chunk_size = int(chunk_size)
+
+    if chunk_size < (cat_num / 2):
+        raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
+                         f"Choose one solution: \n"
+                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
+                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
+
+    random_chunks = []
+    for i in range(divisions):
+        chunk = survey_input.sample(n=chunk_size).tolist()
+        random_chunks.append(chunk)
+
+    responses = []
+    responses_list = []
+
+    for i in tqdm(range(divisions), desc="Processing chunks"):
+        survey_participant_chunks = '; '.join(random_chunks[i])
+        prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
+Responses are each separated by a semicolon. \
+Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
+Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
+
+        if model_source == "OpenAI":
+            client = OpenAI(api_key=api_key)
+            try:
+                response_obj = client.chat.completions.create(
+                    model=user_model,
+                    messages=[
+                        {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
+The specific task is to identify {specificity} categories of responses to a survey question. \
+The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
+                        {'role': 'user', 'content': prompt}
+                    ],
+                    temperature=creativity
+                )
+                reply = response_obj.choices[0].message.content
+                responses.append(reply)
+            except BadRequestError as e:
+                if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
+                    error_msg = (f"Token limit exceeded for model {user_model}. "
+                                 f"Try increasing the 'divisions' parameter to create smaller chunks.")
+                    raise ValueError(error_msg)
+                else:
+                    print(f"OpenAI API error: {e}")
+            except Exception as e:
+                print(f"An error occurred: {e}")
+        else:
+            raise ValueError(f"Unsupported model_source: {model_source}")
+
+        # Extract just the text as a list
+        items = []
+        for line in responses[i].split('\n'):
+            if '. ' in line:
+                try:
+                    items.append(line.split('. ', 1)[1])
+                except IndexError:
+                    pass
+
+        responses_list.append(items)
+
+    flat_list = [item.lower() for sublist in responses_list for item in sublist]
+
+    #convert flat_list to a df
+    df = pd.DataFrame(flat_list, columns=['Category'])
+    counts = pd.Series(flat_list).value_counts()  # Use original list before deduplication
+    df['counts'] = df['Category'].map(counts)
+    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
+    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)
+
+    if filename is not None:
+        df.to_csv(filename, index=False)
+
+    return df
+
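The explore_corpus function above samples the input into 'divisions' random chunks, asks the model for 'cat_num' category labels per chunk, and returns a deduplicated DataFrame of labels with occurrence counts. A minimal usage sketch under stated assumptions (the top-level import name and the data are hypothetical; survey_input is assumed to be a pandas Series, since the code calls survey_input.sample(...).tolist()):

# Hypothetical example; data, key, and import name are illustrative only.
import pandas as pd
from catllm import explore_corpus  # top-level package name assumed

responses = pd.Series(["I worry about rent", "job security", "my health", "nothing much"])
df = explore_corpus(
    survey_question="What do you worry about most?",
    survey_input=responses,
    api_key="YOUR_OPENAI_KEY",  # placeholder
    cat_num=2,                  # categories extracted per chunk
    divisions=2,                # number of random chunks sampled
)
# df has 'Category' and 'counts' columns and is also written to corpus_exploration.csv by default.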
+#extract top categories from corpus
+def explore_common_categories(
+    survey_question,
+    survey_input,
+    api_key,
+    top_n=10,
+    cat_num=10,
+    divisions=5,
+    user_model="gpt-4o-2024-11-20",
+    creativity=0,
+    specificity="broad",
+    research_question=None,
+    filename=None,
+    model_source="OpenAI"
+):
+    import os
+    import pandas as pd
+    import random
+    from openai import OpenAI, BadRequestError
+    from tqdm import tqdm
+
+    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
+    print()
+
+    chunk_size = round(max(1, len(survey_input) / divisions), 0)
+    chunk_size = int(chunk_size)
+
+    if chunk_size < (cat_num / 2):
+        raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
+                         f"Choose one solution: \n"
+                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
+                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
+
+    random_chunks = []
+    for i in range(divisions):
+        chunk = survey_input.sample(n=chunk_size).tolist()
+        random_chunks.append(chunk)
+
+    responses = []
+    responses_list = []
+
+    for i in tqdm(range(divisions), desc="Processing chunks"):
+        survey_participant_chunks = '; '.join(random_chunks[i])
+        prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
+Responses are each separated by a semicolon. \
+Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
+Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
+
+        if model_source == "OpenAI":
+            client = OpenAI(api_key=api_key)
+            try:
+                response_obj = client.chat.completions.create(
+                    model=user_model,
+                    messages=[
+                        {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
+The specific task is to identify {specificity} categories of responses to a survey question. \
+The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
+                        {'role': 'user', 'content': prompt}
+                    ],
+                    temperature=creativity
+                )
+                reply = response_obj.choices[0].message.content
+                responses.append(reply)
+            except BadRequestError as e:
+                if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
+                    error_msg = (f"Token limit exceeded for model {user_model}. "
+                                 f"Try increasing the 'divisions' parameter to create smaller chunks.")
+                    raise ValueError(error_msg)
+                else:
+                    print(f"OpenAI API error: {e}")
+            except Exception as e:
+                print(f"An error occurred: {e}")
+        else:
+            raise ValueError(f"Unsupported model_source: {model_source}")
+
+        # Extract just the text as a list
+        items = []
+        for line in responses[i].split('\n'):
+            if '. ' in line:
+                try:
+                    items.append(line.split('. ', 1)[1])
+                except IndexError:
+                    pass
+
+        responses_list.append(items)
+
+    flat_list = [item.lower() for sublist in responses_list for item in sublist]
+
+    #convert flat_list to a df
+    df = pd.DataFrame(flat_list, columns=['Category'])
+    counts = pd.Series(flat_list).value_counts()  # Use original list before deduplication
+    df['counts'] = df['Category'].map(counts)
+    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
+    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)
+
+    second_prompt = f"""From this list of categories, extract the top {top_n} most common categories. \
+The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` \
+Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, with no additional text or explanation."""
+
+    if model_source == "OpenAI":
+        client = OpenAI(api_key=api_key)
+        response_obj = client.chat.completions.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': second_prompt}],
+            temperature=creativity
+        )
+        top_categories = response_obj.choices[0].message.content
+        print(top_categories)
+
+    top_categories_final = []
+    for line in top_categories.split('\n'):
+        if '. ' in line:
+            try:
+                top_categories_final.append(line.split('. ', 1)[1])
+            except IndexError:
+                pass
+
+    return top_categories_final
+
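explore_common_categories repeats the same chunked extraction and then asks the model a second time to rank the pooled labels, returning a plain list rather than a DataFrame. A sketch under the same assumptions as the example above:

# Hypothetical continuation of the example above.
from catllm import explore_common_categories  # top-level package name assumed

top = explore_common_categories(
    survey_question="What do you worry about most?",
    survey_input=responses,
    api_key="YOUR_OPENAI_KEY",  # placeholder
    top_n=3,                    # how many of the pooled categories to keep
    cat_num=2,
    divisions=2,
)
# top is a list of up to 3 category labels, ordered most to least common.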
+#multi-class text classification
+def extract_multi_class(
+    survey_question,
+    survey_input,
+    categories,
+    api_key,
+    columns="numbered",
+    user_model="gpt-4o-2024-11-20",
+    creativity=0,
+    to_csv=False,
+    safety=False,
+    filename="categorized_data.csv",
+    save_directory=None,
+    model_source="OpenAI"
+):
+    import os
+    import json
+    import pandas as pd
+    import regex
+    from tqdm import tqdm
+
+    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
+    cat_num = len(categories)
+    category_dict = {str(i + 1): "0" for i in range(cat_num)}
+    example_JSON = json.dumps(category_dict, indent=4)
+
+    # ensure the number of categories is what the user wants
+    print("\nThe categories you entered:")
+    for i, cat in enumerate(categories, 1):
+        print(f"{i}. {cat}")
+
+    link1 = []
+    extracted_jsons = []
+
+    for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
+        reply = None
+
+        if pd.isna(response):
+            link1.append("Skipped NaN input")
+            default_json = example_JSON
+            extracted_jsons.append(default_json)
+            #print(f"Skipped NaN input.")
+        else:
+            prompt = f"""A respondent was asked: {survey_question}. \
+Categorize this survey response "{response}" into the following categories that apply: \
+{categories_str} \
+Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
+            #print(prompt)
+            if model_source == "OpenAI":
+                from openai import OpenAI
+                client = OpenAI(api_key=api_key)
+                try:
+                    response_obj = client.chat.completions.create(
+                        model=user_model,
+                        messages=[{'role': 'user', 'content': prompt}],
+                        temperature=creativity
+                    )
+                    reply = response_obj.choices[0].message.content
+                    link1.append(reply)
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+            elif model_source == "Perplexity":
+                from openai import OpenAI
+                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
+                try:
+                    response_obj = client.chat.completions.create(
+                        model=user_model,
+                        messages=[{'role': 'user', 'content': prompt}],
+                        temperature=creativity
+                    )
+                    reply = response_obj.choices[0].message.content
+                    link1.append(reply)
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+            elif model_source == "Anthropic":
+                import anthropic
+                client = anthropic.Anthropic(api_key=api_key)
+                try:
+                    message = client.messages.create(
+                        model=user_model,
+                        max_tokens=1024,
+                        temperature=creativity,
+                        messages=[{"role": "user", "content": prompt}]
+                    )
+                    reply = message.content[0].text  # Anthropic returns content as a list
+                    link1.append(reply)
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+            elif model_source == "Mistral":
+                from mistralai import Mistral
+                client = Mistral(api_key=api_key)
+                try:
+                    response = client.chat.complete(
+                        model=user_model,
+                        messages=[
+                            {'role': 'user', 'content': prompt}
+                        ],
+                        temperature=creativity
+                    )
+                    reply = response.choices[0].message.content
+                    link1.append(reply)
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+            else:
+                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
+            # in case no JSON is found in the reply
+            if reply is not None:
+                extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
+                if extracted_json:
+                    cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '')
+                    extracted_jsons.append(cleaned_json)
+                    #print(cleaned_json)
+                else:
+                    error_message = """{"1":"e"}"""
+                    extracted_jsons.append(error_message)
+                    print(error_message)
+            else:
+                error_message = """{"1":"e"}"""
+                extracted_jsons.append(error_message)
+                #print(error_message)
+
+        # --- Safety Save ---
+        if safety:
+            # Save progress so far
+            temp_df = pd.DataFrame({
+                'survey_response': survey_input[:idx + 1],
+                'link1': link1,
+                'json': extracted_jsons
+            })
+            # Normalize the JSONs processed so far
+            normalized_data_list = []
+            for json_str in extracted_jsons:
+                try:
+                    parsed_obj = json.loads(json_str)
+                    normalized_data_list.append(pd.json_normalize(parsed_obj))
+                except json.JSONDecodeError:
+                    normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
+            normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+            temp_df = pd.concat([temp_df, normalized_data], axis=1)
+            # Save to CSV
+            if save_directory is None:
+                save_directory = os.getcwd()
+            temp_df.to_csv(os.path.join(save_directory, filename), index=False)
+
+    # --- Final DataFrame ---
+    normalized_data_list = []
+    for json_str in extracted_jsons:
+        try:
+            parsed_obj = json.loads(json_str)
+            normalized_data_list.append(pd.json_normalize(parsed_obj))
+        except json.JSONDecodeError:
+            normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
+    normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+
+    categorized_data = pd.DataFrame({
+        'survey_response': survey_input.reset_index(drop=True),
+        'link1': pd.Series(link1).reset_index(drop=True),
+        'json': pd.Series(extracted_jsons).reset_index(drop=True)
+    })
+    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
+
+    if columns != "numbered":  # if the user wants text columns
+        categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]
+
+    if to_csv:
+        if save_directory is None:
+            save_directory = os.getcwd()
+        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
+
+    return categorized_data
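Finally, extract_multi_class performs multi-label classification: for each response it requests a JSON object mapping category numbers to 1/0 flags, extracts that object with a recursive regex, and normalizes it into indicator columns. A hedged sketch, with all inputs hypothetical:

# Hypothetical example; all inputs are placeholders.
import pandas as pd
from catllm import extract_multi_class  # top-level package name assumed

answers = pd.Series(["I worry about rent and my health", None, "job security"])
out = extract_multi_class(
    survey_question="What do you worry about most?",
    survey_input=answers,
    categories=["housing", "health", "employment"],
    api_key="YOUR_OPENAI_KEY",  # placeholder
    columns="text",             # any value other than "numbered" renames indicator columns to the category labels
    safety=True,                # checkpoint progress to categorized_data.csv after each response
)
# out holds the raw response, the model reply, the extracted JSON, and one 0/1 column per category;
# NaN inputs are skipped and filled with the all-zero default JSON.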