cat-llm 0.0.19__tar.gz → 0.0.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-llm
3
- Version: 0.0.19
3
+ Version: 0.0.20
4
4
  Summary: A tool for categorizing text data and images using LLMs and vision models
5
5
  Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -0,0 +1,330 @@
1
# image multi-class (binary) function
#
# Shape-specific rubric: each entry is (checklist text shown to the model,
# name of the output column). The model refers to items by their 1-based
# position, so the order of each list matters. The checklist texts are part
# of the prompt and are kept verbatim.
_CERAD_SHAPES = {
    "circle": [
        ("It has a drawing of a circle", "drawing_present"),
        ("The drawing does not resemble a circle", "not_similar"),
        ("The drawing resembles a circle", "similar"),
        ("The circle is closed", "cir_closed"),
        ("The circle is almost closed", "cir_almost_closed"),
        ("The circle is circular", "cir_round"),
        ("The circle is almost circular", "cir_almost_round"),
        ("None of the above descriptions apply", "none"),
    ],
    "diamond": [
        ("It has a drawing of a diamond", "drawing_present"),
        ("It has a drawing of a square", "diamond_square"),
        ("A drawn shape DOES NOT resemble a diamond", "not_similar"),
        ("A drawn shape resembles a diamond", "similar"),
        ("The drawn shape has 4 sides", "diamond_4_sides"),
        ("The drawn shape sides are about equal", "diamond_equal_sides"),
        ("If a diamond is drawn it's more elaborate than a simple diamond (such as overlapping diamonds or a diamond with an extras lines inside)", "complex_diamond"),
        ("None of the above descriptions apply", "none"),
    ],
    "rectangles": [
        ("It has a drawing of overlapping rectangles", "drawing_present"),
        ("A drawn shape DOES NOT resemble a overlapping rectangles", "not_similar"),
        ("A drawn shape resembles a overlapping rectangles", "similar"),
        ("Rectangle 1 has 4 sides", "r1_4_sides"),
        ("Rectangle 2 has 4 sides", "r2_4_sides"),
        ("The rectangles are overlapping", "rectangles_overlap"),
        ("The rectangles overlap contains a longer vertical rectangle with top and bottom portruding", "rectangles_cross"),
        ("None of the above descriptions apply", "none"),
    ],
    "cube": [
        ("The image contains a drawing that clearly represents a cube (3D box shape)", "drawing_present"),
        ("The image does NOT contain any drawing that resembles a cube or 3D box", "not_similar"),
        ("The image contains a WELL-DRAWN recognizable cube with proper 3D perspective", "similar"),
        ("If a cube is present: the front face appears as a square or diamond shape", "cube_front_face"),
        ("If a cube is present: internal/hidden edges are visible (showing 3D depth, not just an outline)", "cube_internal_lines"),
        ("If a cube is present: the front and back faces appear parallel to each other", "cube_opposite_sides"),
        ("The image contains only a 2D square (flat shape, no 3D appearance)", "square_only"),
        ("None of the above descriptions apply", "none"),
    ],
}

# Glob patterns used when ``image_input`` is a directory.
_CERAD_IMAGE_PATTERNS = (
    '*.png', '*.jpg', '*.jpeg',
    '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
    '*.tif', '*.tiff', '*.bmp',
    '*.heif', '*.heic', '*.ico',
    '*.psd',
)

# Placeholder JSON recorded when a model call failed or its reply held no JSON.
_CERAD_ERROR_JSON = '{"1":"e"}'


def _cerad_make_client(model_source, api_key):
    """Create the SDK client for ``model_source``; SDKs are imported lazily."""
    if model_source == "OpenAI":
        from openai import OpenAI
        return OpenAI(api_key=api_key)
    if model_source == "Perplexity":
        from openai import OpenAI
        return OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
    if model_source == "Anthropic":
        import anthropic
        return anthropic.Anthropic(api_key=api_key)
    if model_source == "Mistral":
        from mistralai import Mistral
        return Mistral(api_key=api_key)
    raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")


def _cerad_ask_model(client, model_source, user_model, creativity, prompt_text, img_path):
    """Send one image plus the checklist prompt to the model; return the reply text."""
    import base64
    from pathlib import Path

    with open(img_path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("utf-8")
    ext = Path(img_path).suffix.lstrip(".").lower()
    # "jpg"/"tif" are not registered MIME subtypes; send the canonical names.
    media_type = "image/" + {"jpg": "jpeg", "tif": "tiff"}.get(ext, ext)

    if model_source == "Anthropic":
        # Anthropic's Messages API expects base64 "image" blocks rather than
        # OpenAI-style "image_url" parts.
        message = client.messages.create(
            model=user_model,
            max_tokens=1024,
            temperature=creativity,
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {
                        "type": "image",
                        "source": {"type": "base64", "media_type": media_type, "data": encoded},
                    },
                ],
            }],
        )
        return message.content[0].text  # Anthropic returns content as a list

    messages = [{
        'role': 'user',
        'content': [
            {"type": "text", "text": prompt_text},
            {
                "type": "image_url",
                "image_url": {"url": f"data:{media_type};base64,{encoded}", "detail": "high"},
            },
        ],
    }]
    if model_source == "Mistral":
        response = client.chat.complete(model=user_model, messages=messages, temperature=creativity)
    else:  # OpenAI and Perplexity share the OpenAI-compatible chat endpoint
        response = client.chat.completions.create(
            model=user_model, messages=messages, temperature=creativity
        )
    return response.choices[0].message.content


def _cerad_extract_json(reply):
    """Return the first balanced ``{...}`` in ``reply`` with brackets and whitespace removed.

    Returns the error placeholder when there is no reply or no JSON object.
    """
    if reply is None:
        return _CERAD_ERROR_JSON
    import regex  # third-party; needed for the recursive (?R) pattern

    found = regex.findall(r'\{(?:[^{}]|(?R))*\}', str(reply), regex.DOTALL)
    if not found:
        print(_CERAD_ERROR_JSON)
        return _CERAD_ERROR_JSON
    cleaned = found[0].replace('[', '').replace(']', '')
    return "".join(cleaned.split())


def _cerad_normalize_jsons(json_strings):
    """Flatten each JSON string into one DataFrame row; undecodable ones become ``{"1": "e"}``."""
    import json
    import pandas as pd

    frames = []
    for text in json_strings:
        try:
            frames.append(pd.json_normalize(json.loads(text)))
        except json.JSONDecodeError:
            frames.append(pd.DataFrame({"1": ["e"]}))
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def _cerad_score_rows(data, shape_key):
    """Return the integer CERAD score for each row of renamed 0/1 rubric columns."""
    none_applies = data["none"] == 1
    if shape_key == "circle":
        score = (data["cir_almost_closed"] + data["cir_closed"]
                 + data["cir_round"] + data["cir_almost_round"])
        return score.mask(none_applies, 0)
    if shape_key == "diamond":
        score = data["diamond_4_sides"] + data["diamond_equal_sides"] + data["similar"]
        score = score.mask(none_applies, 0)
        # A square with no other credit is still awarded 2 points.
        return score.mask((data["diamond_square"] == 1) & (score == 0), 2)
    if shape_key == "rectangles":
        score = ((data["r1_4_sides"] == 1) & (data["r2_4_sides"] == 1)).astype(int)
        score = score + ((data["rectangles_overlap"] == 1) & (data["rectangles_cross"] == 1)).astype(int)
        return score.mask(none_applies, 0)
    if shape_key == "cube":
        score = (data["cube_front_face"] + data["cube_internal_lines"]
                 + data["cube_opposite_sides"] + data["similar"])
        # `similar` is added once in the sum and once more here, so a
        # well-drawn cube contributes 2 points from that item.
        score = score + (data["similar"] == 1).astype(int)
        score = score.mask(none_applies, 0)
        return score.clip(upper=4)
    raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")


def cerad_score(
    shape,
    image_input,
    api_key,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    safety=False,
    filename="categorized_data.csv",
    model_source="OpenAI"
):
    """Score CERAD Constructional Praxis drawings with a vision-capable LLM.

    Each image is sent to the model together with a shape-specific checklist.
    The model answers 1/0 per checklist item as JSON, and the answers are
    combined into a per-image ``score`` column.

    Args:
        shape: ``"circle"``, ``"diamond"``, ``"rectangles"`` /
            ``"overlapping rectangles"`` or ``"cube"`` (case-insensitive).
        image_input: A directory path, searched for common image extensions,
            or an iterable of image file paths. Entries that are not paths to
            existing files (e.g. ``None``/NaN) are kept as unscored rows.
        api_key: API key for ``model_source``.
        user_model: Model name passed to the provider.
        creativity: Sampling temperature.
        safety: If true, rewrite the CSV after every image so partial
            progress survives a crash.
        filename: CSV output path. ``None`` disables the final save; the
            safety save then falls back to ``./catllm_data.csv``.
        model_source: ``"OpenAI"``, ``"Perplexity"``, ``"Anthropic"`` or
            ``"Mistral"``.

    Returns:
        DataFrame with these columns:

        - ``image_input``.
        - ``link1``: the raw model reply or error text.
        - ``json``: the extracted JSON.
        - One named 0/1 column per checklist item.
        - ``score``.
        - ``no_valid_image``: present only when some input was skipped.

        ``score`` is NaN for skipped inputs and for calls that failed or
        returned no parsable JSON.

    Raises:
        ValueError: For an unknown ``shape`` or ``model_source``.
    """
    import os
    import glob
    import json
    import pandas as pd

    shape = shape.lower()
    shape_key = "rectangles" if shape == "overlapping rectangles" else shape
    if shape_key not in _CERAD_SHAPES:
        raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")
    # Fail fast instead of discovering a bad source after the first image.
    if model_source not in ("OpenAI", "Perplexity", "Anthropic", "Mistral"):
        raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")

    try:
        from tqdm import tqdm
    except ImportError:  # the progress bar is cosmetic; run without it
        def tqdm(iterable, **_kwargs):
            return iterable

    rubric = _CERAD_SHAPES[shape_key]
    categories = [text for text, _ in rubric]
    cat_cols = [str(n) for n in range(1, len(rubric) + 1)]
    rename_map = {col: name for col, (_, name) in zip(cat_cols, rubric)}

    if isinstance(image_input, (str, os.PathLike)):
        # Sorted so the row order is deterministic across platforms.
        image_files = sorted(
            path
            for pattern in _CERAD_IMAGE_PATTERNS
            for path in glob.glob(os.path.join(image_input, pattern))
        )
        print(f"Found {len(image_files)} images.")
    else:
        image_files = list(image_input)
        print(f"Provided a list of {len(image_files)} images.")

    categories_str = "\n".join(f"{n}. {cat}" for n, cat in enumerate(categories, start=1))
    example_json = json.dumps({col: "0" for col in cat_cols}, indent=4)
    # The prompt text is identical for every image, so build it once.
    prompt_text = (
        "You are an image-tagging assistant trained in the CERAD Constructional Praxis test.\n"
        "Task ► Examine the attached image and decide, **for each category below**, "
        "whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
        f"Image is expected to show within it a drawing of a {shape}.\n\n"
        f"Categories:\n{categories_str}\n\n"
        "Output format ► Respond with **only** a JSON object whose keys are the "
        "quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
        "No additional keys, comments, or text.\n\n"
        "Example:\n"
        f"{example_json}"
    )

    link1 = []
    extracted_jsons = []
    client = None
    for i, img_path in enumerate(tqdm(image_files, desc="Categorising images")):
        # isinstance guards against None/NaN; isfile rejects directories too.
        if not isinstance(img_path, (str, os.PathLike)) or not os.path.isfile(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append('{"no_valid_image": 1}')
        else:
            if client is None:
                # Created lazily so runs with no valid image never need the SDK.
                client = _cerad_make_client(model_source, api_key)
            reply = None  # reset so a failed call never reuses the previous reply
            try:
                reply = _cerad_ask_model(
                    client, model_source, user_model, creativity, prompt_text, img_path
                )
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
            extracted_jsons.append(_cerad_extract_json(reply))

        # --- Safety Save ---
        if safety:
            progress = pd.concat(
                [
                    pd.DataFrame({
                        'image_input': image_files[:i + 1],
                        'link1': link1,
                        'json': extracted_jsons,
                    }),
                    _cerad_normalize_jsons(extracted_jsons),
                ],
                axis=1,
            )
            safety_path = filename if filename is not None else os.path.join(os.getcwd(), 'catllm_data.csv')
            progress.to_csv(safety_path, index=False)

    # --- Final DataFrame ---
    normalized = _cerad_normalize_jsons(extracted_jsons)
    # Remember which rows are error placeholders before "e" is coerced to 0.
    if "1" in normalized.columns:
        parse_failed = normalized["1"].astype(str) == "e"
    else:
        parse_failed = pd.Series(False, index=normalized.index)
    # Every rubric column must exist and be numeric; the model may omit keys
    # or answer with strings such as "1".
    for col in cat_cols:
        if col in normalized.columns:
            normalized[col] = pd.to_numeric(normalized[col], errors="coerce").fillna(0).astype(int)
        else:
            normalized[col] = 0

    categorized_data = pd.concat(
        [
            pd.DataFrame({
                'image_input': image_files,
                'link1': link1,
                'json': extracted_jsons,
            }),
            normalized,
        ],
        axis=1,
    ).rename(columns=rename_map)
    categorized_data['score'] = _cerad_score_rows(categorized_data, shape_key)

    # Skipped inputs and failed/unparsable replies have no real score.
    unscorable = parse_failed
    if 'no_valid_image' in categorized_data.columns:
        unscorable = unscorable | (categorized_data['no_valid_image'] == 1)
    if unscorable.any():
        categorized_data['score'] = categorized_data['score'].astype("float64").mask(unscorable)

    if filename is not None:
        categorized_data.to_csv(filename, index=False)

    return categorized_data
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
2
  #
3
3
  # SPDX-License-Identifier: MIT
4
- __version__ = "0.0.19"
4
+ __version__ = "0.0.20"
5
5
  __author__ = "Chris Soria"
6
6
  __email__ = "chrissoria@berkeley.edu"
7
7
  __title__ = "cat-llm"
@@ -11,4 +11,5 @@ from .__about__ import (
11
11
  __license__,
12
12
  )
13
13
 
14
- from .cat_llm import *
14
+ from .cat_llm import *
15
+ from .CERAD_functions import *
@@ -1,3 +1,223 @@
1
#extract categories from corpus
def explore_corpus(
    survey_question,
    survey_input,
    api_key,
    research_question=None,
    specificity="broad",
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    filename="corpus_exploration.csv",
    model_source="OpenAI"
):
    """Extract candidate response categories from random chunks of a corpus.

    The non-missing responses are sampled ``divisions`` times. Each chunk has
    about ``len / divisions`` responses, and the samples are drawn
    independently, so chunks may overlap. Each chunk is sent to the model with
    a request for ``cat_num`` numbered categories, and the lower-cased labels
    from all chunks are tallied. Chunks whose API call fails are reported and
    skipped.

    Args:
        survey_question: The question the responses answer (used in the prompt).
        survey_input: Iterable (e.g. a pandas Series) of free-text responses.
        api_key: OpenAI API key.
        research_question: Optional context added to the system message.
        specificity: Adjective describing the wanted categories (e.g. "broad").
        cat_num: Categories requested per chunk.
        divisions: Number of chunks to sample.
        user_model: Model name.
        creativity: Sampling temperature.
        filename: CSV path for the result; ``None`` skips saving.
        model_source: Only ``"OpenAI"`` is supported.

    Returns:
        DataFrame with columns ``Category`` and ``counts``, one row per
        distinct label, most frequent first.

    Raises:
        ValueError: Raised in these cases:

            - Unsupported ``model_source``.
            - ``divisions < 1``.
            - No usable responses.
            - Chunks too small for ``cat_num``.
            - A chunk exceeds the model's context window.
    """
    import pandas as pd

    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")
    if divisions < 1:
        raise ValueError("'divisions' must be at least 1.")

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
    print()

    # Missing answers would break the '; '.join below, so drop them up front.
    answers = pd.Series(list(survey_input), dtype=object).dropna().astype(str)
    if answers.empty:
        raise ValueError("survey_input contains no non-missing responses.")

    chunk_size = int(round(max(1, len(answers) / divisions), 0))
    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    random_chunks = [answers.sample(n=chunk_size).tolist() for _ in range(divisions)]

    if research_question:
        system_content = (
            "You are a helpful assistant that extracts categories from survey responses. "
            f"The specific task is to identify {specificity} categories of responses to a survey question. "
            f"The research question is: {research_question}"
        )
    else:
        system_content = "You are a helpful assistant."

    client = OpenAI(api_key=api_key)  # one client reused for every chunk
    category_lists = []
    for chunk in tqdm(random_chunks, desc="Processing chunks"):
        survey_participant_chunks = '; '.join(chunk)
        prompt = (
            f'Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. '
            "Responses are each separated by a semicolon. "
            f"Responses are contained within triple backticks here: ```{survey_participant_chunks}``` "
            f"Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."
        )
        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': system_content},
                    {'role': 'user', 'content': prompt},
                ],
                temperature=creativity,
            )
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                raise ValueError(f"Token limit exceeded for model {user_model}. "
                                 f"Try increasing the 'divisions' parameter to create smaller chunks.") from e
            print(f"OpenAI API error: {e}")
            continue  # skip this chunk rather than misaligning the results
        except Exception as e:
            print(f"An error occurred: {e}")
            continue
        reply = response_obj.choices[0].message.content or ""
        # Keep the text after "N. " on each numbered line.
        category_lists.append([line.split('. ', 1)[1] for line in reply.split('\n') if '. ' in line])

    flat_list = [item.lower() for items in category_lists for item in items]

    # One row per distinct label with its frequency, most frequent first.
    df = (
        pd.Series(flat_list, dtype=object)
        .value_counts()
        .rename_axis('Category')
        .reset_index(name='counts')
    )

    if filename is not None:
        df.to_csv(filename, index=False)

    return df
100
+
101
#extract top categories from corpus
def explore_common_categories(
    survey_question,
    survey_input,
    api_key,
    top_n=10,
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    specificity="broad",
    research_question=None,
    filename=None,
    model_source="OpenAI"
):
    """Extract categories from random chunks of responses, then ask for the top ``top_n``.

    First phase: the non-missing responses are sampled ``divisions`` times,
    independently, so chunks may overlap. Each chunk is sent to the model
    for ``cat_num`` numbered categories, and the lower-cased labels are
    deduplicated and ordered by frequency. Chunks whose API call fails are
    reported and skipped.

    Second phase: the distinct labels, in that frequency order, are sent back
    to the model, which is asked for the ``top_n`` most common as a numbered
    list.

    Args:
        survey_question: The question the responses answer (used in the prompt).
        survey_input: Iterable (e.g. a pandas Series) of free-text responses.
        api_key: OpenAI API key.
        top_n: Number of top categories to request.
        cat_num: Categories requested per chunk.
        divisions: Number of chunks to sample.
        user_model: Model name.
        creativity: Sampling temperature.
        specificity: Adjective describing the wanted categories (e.g. "broad").
        research_question: Optional context added to the system message.
        filename: Accepted for API symmetry; currently not used.
        model_source: Only ``"OpenAI"`` is supported.

    Returns:
        List of category labels parsed from the model's numbered reply. The
        list is empty if no categories were extracted.

    Raises:
        ValueError: Raised in these cases:

            - Unsupported ``model_source``.
            - ``divisions < 1``.
            - No usable responses.
            - Chunks too small for ``cat_num``.
            - A chunk exceeds the model's context window.
    """
    import pandas as pd

    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")
    if divisions < 1:
        raise ValueError("'divisions' must be at least 1.")

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
    print()

    # Missing answers would break the '; '.join below, so drop them up front.
    answers = pd.Series(list(survey_input), dtype=object).dropna().astype(str)
    if answers.empty:
        raise ValueError("survey_input contains no non-missing responses.")

    chunk_size = int(round(max(1, len(answers) / divisions), 0))
    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    def numbered_items(text):
        # Keep the text after "N. " on each numbered line.
        return [line.split('. ', 1)[1] for line in text.split('\n') if '. ' in line]

    random_chunks = [answers.sample(n=chunk_size).tolist() for _ in range(divisions)]

    if research_question:
        system_content = (
            "You are a helpful assistant that extracts categories from survey responses. "
            f"The specific task is to identify {specificity} categories of responses to a survey question. "
            f"The research question is: {research_question}"
        )
    else:
        system_content = "You are a helpful assistant."

    client = OpenAI(api_key=api_key)  # one client reused for every request
    category_lists = []
    for chunk in tqdm(random_chunks, desc="Processing chunks"):
        survey_participant_chunks = '; '.join(chunk)
        prompt = (
            f'Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. '
            "Responses are each separated by a semicolon. "
            f"Responses are contained within triple backticks here: ```{survey_participant_chunks}``` "
            f"Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."
        )
        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': system_content},
                    {'role': 'user', 'content': prompt},
                ],
                temperature=creativity,
            )
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                raise ValueError(f"Token limit exceeded for model {user_model}. "
                                 f"Try increasing the 'divisions' parameter to create smaller chunks.") from e
            print(f"OpenAI API error: {e}")
            continue  # skip this chunk rather than misaligning the results
        except Exception as e:
            print(f"An error occurred: {e}")
            continue
        category_lists.append(numbered_items(response_obj.choices[0].message.content or ""))

    flat_list = [item.lower() for items in category_lists for item in items]

    # Distinct labels ordered by frequency, most frequent first.
    df = (
        pd.Series(flat_list, dtype=object)
        .value_counts()
        .rename_axis('Category')
        .reset_index(name='counts')
    )
    if df.empty:
        print("No categories were extracted; nothing to rank.")
        return []

    second_prompt = (
        f"From this list of categories, extract the top {top_n} most common categories. "
        f"The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` "
        f"Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, with no additional text or explanation."
    )
    response_obj = client.chat.completions.create(
        model=user_model,
        messages=[{'role': 'user', 'content': second_prompt}],
        temperature=creativity,
    )
    top_categories = response_obj.choices[0].message.content or ""
    print(top_categories)

    return numbered_items(top_categories)
220
+
1
221
  #multi-class text classification
2
222
  def extract_multi_class(
3
223
  survey_question,
@@ -852,224 +1072,4 @@ def extract_image_features(
852
1072
  save_directory = os.getcwd()
853
1073
  categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
854
1074
 
855
- return categorized_data
856
-
857
- #extract categories from corpus
858
- def explore_corpus(
859
- survey_question,
860
- survey_input,
861
- api_key,
862
- research_question=None,
863
- specificity="broad",
864
- cat_num=10,
865
- divisions=5,
866
- user_model="gpt-4o-2024-11-20",
867
- creativity=0,
868
- filename="corpus_exploration.csv",
869
- model_source="OpenAI"
870
- ):
871
- import os
872
- import pandas as pd
873
- import random
874
- from openai import OpenAI
875
- from openai import OpenAI, BadRequestError
876
- from tqdm import tqdm
877
-
878
- print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
879
- print()
880
-
881
- chunk_size = round(max(1, len(survey_input) / divisions),0)
882
- chunk_size = int(chunk_size)
883
-
884
- if chunk_size < (cat_num/2):
885
- raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
886
- f"Choose one solution: \n"
887
- f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
888
- f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
889
-
890
- random_chunks = []
891
- for i in range(divisions):
892
- chunk = survey_input.sample(n=chunk_size).tolist()
893
- random_chunks.append(chunk)
894
-
895
- responses = []
896
- responses_list = []
897
-
898
- for i in tqdm(range(divisions), desc="Processing chunks"):
899
- survey_participant_chunks = '; '.join(random_chunks[i])
900
- prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
901
- Responses are each separated by a semicolon. \
902
- Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
903
- Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
904
-
905
- if model_source == "OpenAI":
906
- client = OpenAI(api_key=api_key)
907
- try:
908
- response_obj = client.chat.completions.create(
909
- model=user_model,
910
- messages=[
911
- {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
912
- The specific task is to identify {specificity} categories of responses to a survey question. \
913
- The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
914
- {'role': 'user', 'content': prompt}
915
- ]
916
- temperature=creativity
917
- )
918
- reply = response_obj.choices[0].message.content
919
- responses.append(reply)
920
- except BadRequestError as e:
921
- if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
922
- error_msg = (f"Token limit exceeded for model {user_model}. "
923
- f"Try increasing the 'iterations' parameter to create smaller chunks.")
924
- raise ValueError(error_msg)
925
- else:
926
- print(f"OpenAI API error: {e}")
927
- except Exception as e:
928
- print(f"An error occurred: {e}")
929
- else:
930
- raise ValueError(f"Unsupported model_source: {model_source}")
931
-
932
- # Extract just the text as a list
933
- items = []
934
- for line in responses[i].split('\n'):
935
- if '. ' in line:
936
- try:
937
- items.append(line.split('. ', 1)[1])
938
- except IndexError:
939
- pass
940
-
941
- responses_list.append(items)
942
-
943
- flat_list = [item.lower() for sublist in responses_list for item in sublist]
944
-
945
- #convert flat_list to a df
946
- df = pd.DataFrame(flat_list, columns=['Category'])
947
- counts = pd.Series(flat_list).value_counts() # Use original list before conversion
948
- df['counts'] = df['Category'].map(counts)
949
- df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
950
- df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)
951
-
952
- if filename is not None:
953
- df.to_csv(filename, index=False)
954
-
955
- return df
956
-
957
- #extract top categories from corpus
958
- def explore_common_categories(
959
- survey_question,
960
- survey_input,
961
- api_key,
962
- top_n=10,
963
- cat_num=10,
964
- divisions=5,
965
- user_model="gpt-4o-2024-11-20",
966
- creativity=0,
967
- specificity="broad",
968
- research_question=None,
969
- filename=None,
970
- model_source="OpenAI"
971
- ):
972
- import os
973
- import pandas as pd
974
- import random
975
- from openai import OpenAI
976
- from openai import OpenAI, BadRequestError
977
- from tqdm import tqdm
978
-
979
- print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
980
- print()
981
-
982
- chunk_size = round(max(1, len(survey_input) / divisions),0)
983
- chunk_size = int(chunk_size)
984
-
985
- if chunk_size < (cat_num/2):
986
- raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
987
- f"Choose one solution: \n"
988
- f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
989
- f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
990
-
991
- random_chunks = []
992
- for i in range(divisions):
993
- chunk = survey_input.sample(n=chunk_size).tolist()
994
- random_chunks.append(chunk)
995
-
996
- responses = []
997
- responses_list = []
998
-
999
- for i in tqdm(range(divisions), desc="Processing chunks"):
1000
- survey_participant_chunks = '; '.join(random_chunks[i])
1001
- prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
1002
- Responses are each separated by a semicolon. \
1003
- Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
1004
- Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
1005
-
1006
- if model_source == "OpenAI":
1007
- client = OpenAI(api_key=api_key)
1008
- try:
1009
- response_obj = client.chat.completions.create(
1010
- model=user_model,
1011
- messages=[
1012
- {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
1013
- The specific task is to identify {specificity} categories of responses to a survey question. \
1014
- The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
1015
- {'role': 'user', 'content': prompt}
1016
- ],
1017
- temperature=creativity
1018
- )
1019
- reply = response_obj.choices[0].message.content
1020
- responses.append(reply)
1021
- except BadRequestError as e:
1022
- if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
1023
- error_msg = (f"Token limit exceeded for model {user_model}. "
1024
- f"Try increasing the 'iterations' parameter to create smaller chunks.")
1025
- raise ValueError(error_msg)
1026
- else:
1027
- print(f"OpenAI API error: {e}")
1028
- except Exception as e:
1029
- print(f"An error occurred: {e}")
1030
- else:
1031
- raise ValueError(f"Unsupported model_source: {model_source}")
1032
-
1033
- # Extract just the text as a list
1034
- items = []
1035
- for line in responses[i].split('\n'):
1036
- if '. ' in line:
1037
- try:
1038
- items.append(line.split('. ', 1)[1])
1039
- except IndexError:
1040
- pass
1041
-
1042
- responses_list.append(items)
1043
-
1044
- flat_list = [item.lower() for sublist in responses_list for item in sublist]
1045
-
1046
- #convert flat_list to a df
1047
- df = pd.DataFrame(flat_list, columns=['Category'])
1048
- counts = pd.Series(flat_list).value_counts() # Use original list before conversion
1049
- df['counts'] = df['Category'].map(counts)
1050
- df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
1051
- df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)
1052
-
1053
- second_prompt = f"""From this list of categories, extract the top {top_n} most common categories. \
1054
- The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` \
1055
- Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, with no additional text or explanation."""
1056
-
1057
- if model_source == "OpenAI":
1058
- client = OpenAI(api_key=api_key)
1059
- response_obj = client.chat.completions.create(
1060
- model=user_model,
1061
- messages=[{'role': 'user', 'content': second_prompt}],
1062
- temperature=creativity
1063
- )
1064
- top_categories = response_obj.choices[0].message.content
1065
- print(top_categories)
1066
-
1067
- top_categories_final = []
1068
- for line in top_categories.split('\n'):
1069
- if '. ' in line:
1070
- try:
1071
- top_categories_final.append(line.split('. ', 1)[1])
1072
- except IndexError:
1073
- pass
1074
-
1075
- return top_categories_final
1075
+ return categorized_data
File without changes
File without changes
File without changes