cat-llm 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.19.dist-info → cat_llm-0.0.21.dist-info}/METADATA +1 -1
- cat_llm-0.0.21.dist-info/RECORD +8 -0
- catllm/CERAD_functions.py +330 -0
- catllm/__about__.py +1 -1
- catllm/__init__.py +2 -1
- catllm/cat_llm.py +509 -181
- cat_llm-0.0.19.dist-info/RECORD +0 -7
- {cat_llm-0.0.19.dist-info → cat_llm-0.0.21.dist-info}/WHEEL +0 -0
- {cat_llm-0.0.19.dist-info → cat_llm-0.0.21.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-llm
|
|
3
|
-
Version: 0.0.19
|
|
3
|
+
Version: 0.0.21
|
|
4
4
|
Summary: A tool for categorizing text data and images using LLMs and vision models
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
catllm/CERAD_functions.py,sha256=Qb6-X4147pLkiGYnhCiTZbUllnI_Phc78NeEGm9IYic,14560
|
|
2
|
+
catllm/__about__.py,sha256=TunDrsvgMTHUI-XmOUy6ZvCm0jyzgtPLSTQ4FQ-i1EE,404
|
|
3
|
+
catllm/__init__.py,sha256=bgH_2K70m3WP9B8GZNciAr7ld7bSNwgpzhYgkoAC8Bo,299
|
|
4
|
+
catllm/cat_llm.py,sha256=z3ohaq2sXrcI6ygvQTKNNlAN5ptmNmwCimkM625j7LQ,58205
|
|
5
|
+
cat_llm-0.0.21.dist-info/METADATA,sha256=1AzIw7qX1yrWpsozHM-XfTvzk2R8nSY35yrogvt93uY,1679
|
|
6
|
+
cat_llm-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
cat_llm-0.0.21.dist-info/licenses/LICENSE,sha256=wJLsvOr6lrFUDcoPXExa01HOKFWrS3JC9f0RudRw8uw,1075
|
|
8
|
+
cat_llm-0.0.21.dist-info/RECORD,,
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
# image multi-class (binary) function
def cerad_score(
    shape,
    image_input,
    api_key,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    safety=False,
    filename="categorized_data.csv",
    model_source="OpenAI"
):
    """Score CERAD Constructional Praxis drawings with a vision LLM.

    For every image, the model is asked which of a fixed set of
    shape-specific binary criteria are PRESENT (1) or NOT PRESENT (0);
    the returned JSON is parsed and folded into a CERAD-style 'score'
    column.

    Parameters
    ----------
    shape : str
        'circle', 'diamond', 'rectangles' (or 'overlapping rectangles'),
        or 'cube'; case-insensitive.
    image_input : str or list
        Directory to scan for image files, or an explicit list of paths.
    api_key : str
        API key for the chosen provider.
    user_model : str
        Provider model identifier.
    creativity : float
        Sampling temperature passed to the provider.
    safety : bool
        If True, write partial results to CSV after every image.
    filename : str or None
        CSV output path; None skips the final save (safety saves then
        fall back to 'catllm_data.csv' in the current directory).
    model_source : str
        'OpenAI', 'Perplexity', 'Anthropic', or 'Mistral'.

    Returns
    -------
    pandas.DataFrame
        One row per image: raw reply ('link1'), extracted JSON ('json'),
        the per-criterion flags, and the derived 'score'.

    Raises
    ------
    ValueError
        If `shape` or `model_source` is not one of the supported values.
    """
    shape = shape.lower()

    # Shape-specific criteria shown to the model. Validated up front,
    # before any heavy imports, so a bad argument fails fast.
    if shape == "circle":
        categories = ["It has a drawing of a circle",
                      "The drawing does not resemble a circle",
                      "The drawing resembles a circle",
                      "The circle is closed",
                      "The circle is almost closed",
                      "The circle is circular",
                      "The circle is almost circular",
                      "None of the above descriptions apply"]
    elif shape == "diamond":
        categories = ["It has a drawing of a diamond",
                      "It has a drawing of a square",
                      "A drawn shape DOES NOT resemble a diamond",
                      "A drawn shape resembles a diamond",
                      "The drawn shape has 4 sides",
                      "The drawn shape sides are about equal",
                      "If a diamond is drawn it's more elaborate than a simple diamond (such as overlapping diamonds or a diamond with an extras lines inside)",
                      "None of the above descriptions apply"]
    elif shape == "rectangles" or shape == "overlapping rectangles":
        categories = ["It has a drawing of overlapping rectangles",
                      "A drawn shape DOES NOT resemble a overlapping rectangles",
                      "A drawn shape resembles a overlapping rectangles",
                      "Rectangle 1 has 4 sides",
                      "Rectangle 2 has 4 sides",
                      "The rectangles are overlapping",
                      "The rectangles overlap contains a longer vertical rectangle with top and bottom portruding",
                      "None of the above descriptions apply"]
    elif shape == "cube":
        categories = ["The image contains a drawing that clearly represents a cube (3D box shape)",
                      "The image does NOT contain any drawing that resembles a cube or 3D box",
                      "The image contains a WELL-DRAWN recognizable cube with proper 3D perspective",
                      "If a cube is present: the front face appears as a square or diamond shape",
                      "If a cube is present: internal/hidden edges are visible (showing 3D depth, not just an outline)",
                      "If a cube is present: the front and back faces appear parallel to each other",
                      "The image contains only a 2D square (flat shape, no 3D appearance)",
                      "None of the above descriptions apply"]
    else:
        raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")

    import os
    import json
    import pandas as pd
    import regex
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if not isinstance(image_input, list):
        # image_input is a directory path: collect every matching file.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))

        print(f"Found {len(image_files)} images.")
    else:
        # image_input is already a list of paths.
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
    cat_num = len(categories)
    category_dict = {str(i + 1): "0" for i in range(cat_num)}
    example_JSON = json.dumps(category_dict, indent=4)

    link1 = []            # raw model reply (or error note), one per image
    extracted_jsons = []  # JSON string extracted from each reply

    for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
        # Skip missing/invalid paths while keeping all lists aligned.
        if img_path is None or not os.path.exists(img_path):
            link1.append("Skipped NaN input or invalid path")
            extracted_jsons.append("""{"no_valid_image": 1}""")
            continue

        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")

        # Derive the data-URI media subtype from the file extension.
        ext = Path(img_path).suffix.lstrip(".").lower()
        encoded_image = f"data:image/{ext};base64,{encoded}"

        prompt = [
            {
                "type": "text",
                "text": (
                    f"You are an image-tagging assistant trained in the CERAD Constructional Praxis test.\n"
                    f"Task ► Examine the attached image and decide, **for each category below**, "
                    f"whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
                    f"Image is expected to show within it a drawing of a {shape}.\n\n"
                    f"Categories:\n{categories_str}\n\n"
                    f"Output format ► Respond with **only** a JSON object whose keys are the "
                    f"quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
                    f"No additional keys, comments, or text.\n\n"
                    f"Example:\n"
                    f"{example_JSON}"
                ),
            },
            {
                "type": "image_url",
                "image_url": {"url": encoded_image, "detail": "high"},
            },
        ]

        # Reset per image so a failed request cannot silently reuse the
        # previous iteration's reply (or NameError on the first image).
        reply = None

        if model_source == "OpenAI":
            from openai import OpenAI
            client = OpenAI(api_key=api_key)
            try:
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")

        elif model_source == "Perplexity":
            from openai import OpenAI
            client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
            try:
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
        elif model_source == "Anthropic":
            import anthropic
            client = anthropic.Anthropic(api_key=api_key)
            try:
                message = client.messages.create(
                    model=user_model,
                    max_tokens=1024,
                    temperature=creativity,
                    messages=[{"role": "user", "content": prompt}]
                )
                reply = message.content[0].text  # Anthropic returns content as list
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
        elif model_source == "Mistral":
            from mistralai import Mistral
            client = Mistral(api_key=api_key)
            try:
                response = client.chat.complete(
                    model=user_model,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    temperature=creativity
                )
                reply = response.choices[0].message.content
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
        else:
            raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")

        # Pull the first (recursive) JSON object out of the reply; fall
        # back to an error marker when nothing parseable came back.
        if reply is not None:
            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
            if extracted_json:
                cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
                extracted_jsons.append(cleaned_json)
            else:
                error_message = """{"1":"e"}"""
                extracted_jsons.append(error_message)
                print(error_message)
        else:
            error_message = """{"1":"e"}"""
            extracted_jsons.append(error_message)

        # --- Safety Save: persist progress after every processed image ---
        if safety:
            temp_df = pd.DataFrame({
                'image_input': image_files[:i + 1],
                'link1': link1,
                'json': extracted_jsons
            })
            normalized_data_list = []
            for json_str in extracted_jsons:
                try:
                    parsed_obj = json.loads(json_str)
                    normalized_data_list.append(pd.json_normalize(parsed_obj))
                except json.JSONDecodeError:
                    normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
            normalized_data = pd.concat(normalized_data_list, ignore_index=True)
            temp_df = pd.concat([temp_df, normalized_data], axis=1)
            if filename is None:
                filepath = os.path.join(os.getcwd(), 'catllm_data.csv')
            else:
                filepath = filename
            temp_df.to_csv(filepath, index=False)

    # --- Final DataFrame ---
    normalized_data_list = []
    for json_str in extracted_jsons:
        try:
            parsed_obj = json.loads(json_str)
            normalized_data_list.append(pd.json_normalize(parsed_obj))
        except json.JSONDecodeError:
            normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
    normalized_data = pd.concat(normalized_data_list, ignore_index=True)

    categorized_data = pd.DataFrame({
        'image_input': image_files,
        'link1': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)

    # Ensure every numbered flag column exists, then coerce all of them to
    # int. (Previously only "1"-"7" were converted — "8" was skipped even
    # though the renames below depend on it — and a missing column raised
    # KeyError when every reply failed to parse.)
    expected_cols = [str(k) for k in range(1, cat_num + 1)]
    for col in expected_cols:
        if col not in categorized_data.columns:
            categorized_data[col] = 0
    categorized_data[expected_cols] = (
        categorized_data[expected_cols]
        .apply(pd.to_numeric, errors='coerce')
        .fillna(0)
        .astype(int)
    )

    if shape == "circle":

        categorized_data = categorized_data.rename(columns={
            "1": "drawing_present",
            "2": "not_similar",
            "3": "similar",
            "4": "cir_closed",
            "5": "cir_almost_closed",
            "6": "cir_round",
            "7": "cir_almost_round",
            "8": "none"
        })

        categorized_data['score'] = categorized_data['cir_almost_closed'] + categorized_data['cir_closed'] + categorized_data['cir_round'] + categorized_data['cir_almost_round']
        categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
        categorized_data.loc[(categorized_data['drawing_present'] == 0) & (categorized_data['score'] == 0), 'score'] = 0

    elif shape == "diamond":

        categorized_data = categorized_data.rename(columns={
            "1": "drawing_present",
            "2": "diamond_square",
            "3": "not_similar",
            "4": "similar",
            "5": "diamond_4_sides",
            "6": "diamond_equal_sides",
            "7": "complex_diamond",
            "8": "none"
        })

        categorized_data['score'] = categorized_data['diamond_4_sides'] + categorized_data['diamond_equal_sides'] + categorized_data['similar']

        categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
        # A square earns partial credit when nothing else scored.
        categorized_data.loc[(categorized_data['diamond_square'] == 1) & (categorized_data['score'] == 0), 'score'] = 2

    elif shape == "rectangles" or shape == "overlapping rectangles":

        categorized_data = categorized_data.rename(columns={
            "1": "drawing_present",
            "2": "not_similar",
            "3": "similar",
            "4": "r1_4_sides",
            "5": "r2_4_sides",
            "6": "rectangles_overlap",
            "7": "rectangles_cross",
            "8": "none"
        })

        categorized_data['score'] = 0
        categorized_data.loc[(categorized_data['r1_4_sides'] == 1) & (categorized_data['r2_4_sides'] == 1), 'score'] = 1
        categorized_data.loc[(categorized_data['rectangles_overlap'] == 1) & (categorized_data['rectangles_cross'] == 1), 'score'] += 1
        categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0

    elif shape == "cube":

        categorized_data = categorized_data.rename(columns={
            "1": "drawing_present",
            "2": "not_similar",
            "3": "similar",
            "4": "cube_front_face",
            "5": "cube_internal_lines",
            "6": "cube_opposite_sides",
            "7": "square_only",
            "8": "none"
        })

        categorized_data['score'] = categorized_data['cube_front_face'] + categorized_data['cube_internal_lines'] + categorized_data['cube_opposite_sides'] + categorized_data['similar']
        categorized_data.loc[categorized_data['similar'] == 1, 'score'] = categorized_data['score'] + 1
        categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
        categorized_data.loc[(categorized_data['drawing_present'] == 0) & (categorized_data['score'] == 0), 'score'] = 0
        categorized_data.loc[(categorized_data['not_similar'] == 1) & (categorized_data['score'] == 0), 'score'] = 0
        categorized_data.loc[categorized_data['score'] > 4, 'score'] = 4

    # Rows for unreadable/missing images carry no score. Guard the column:
    # it only exists when at least one image was skipped.
    if 'no_valid_image' in categorized_data.columns:
        categorized_data.loc[categorized_data['no_valid_image'] == 1, 'score'] = None

    if filename is not None:
        categorized_data.to_csv(filename, index=False)

    return categorized_data
catllm/__about__.py
CHANGED
catllm/__init__.py
CHANGED
catllm/cat_llm.py
CHANGED
|
@@ -1,3 +1,223 @@
|
|
|
1
|
+
#extract categories from corpus
def explore_corpus(
    survey_question,
    survey_input,
    api_key,
    research_question=None,
    specificity="broad",
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    filename="corpus_exploration.csv",
    model_source="OpenAI"
):
    """Discover candidate answer categories in an open-ended survey corpus.

    Splits `survey_input` into `divisions` random chunks, asks the model
    to name `cat_num` `specificity` categories per chunk, and returns the
    deduplicated category labels ranked by how often they recurred.

    Parameters
    ----------
    survey_question : str
        The question the responses answer.
    survey_input : pandas.Series
        The free-text responses (must support .sample/.tolist).
    api_key : str
        Provider API key.
    research_question : str or None
        Optional research framing injected into the system prompt.
    specificity : str
        Qualifier for the requested categories (e.g. "broad").
    cat_num : int
        Categories requested per chunk.
    divisions : int
        Number of random chunks to sample.
    user_model : str
        Provider model identifier.
    creativity : float
        Sampling temperature.
    filename : str or None
        CSV output path; None disables saving.
    model_source : str
        Only "OpenAI" is supported.

    Returns
    -------
    pandas.DataFrame
        Columns 'Category' (lower-cased, deduplicated) and 'counts'.

    Raises
    ------
    ValueError
        For an unsupported `model_source`, chunks too small for `cat_num`,
        or a context-length overflow from the API.
    """
    # Fail fast on an unsupported provider before importing SDKs or
    # touching the data.
    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")

    import pandas as pd
    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
    print()

    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    random_chunks = []
    for _ in range(divisions):
        random_chunks.append(survey_input.sample(n=chunk_size).tolist())

    responses = []
    responses_list = []

    client = OpenAI(api_key=api_key)

    for i in tqdm(range(divisions), desc="Processing chunks"):
        survey_participant_chunks = '; '.join(random_chunks[i])
        prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
Responses are each separated by a semicolon. \
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""

        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
The specific task is to identify {specificity} categories of responses to a survey question. \
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
                    {'role': 'user', 'content': prompt}
                ],  # NOTE: the released 0.0.21 was missing this comma (SyntaxError)
                temperature=creativity
            )
            reply = response_obj.choices[0].message.content
            responses.append(reply)
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                error_msg = (f"Token limit exceeded for model {user_model}. "
                             f"Try increasing the 'iterations' parameter to create smaller chunks.")
                raise ValueError(error_msg)
            else:
                print(f"OpenAI API error: {e}")
                # Keep responses aligned with chunk index so the parse
                # below doesn't IndexError on a failed chunk.
                responses.append("")
        except Exception as e:
            print(f"An error occurred: {e}")
            responses.append("")

        # Parse "N. label" lines out of the reply into bare labels.
        items = []
        for line in responses[i].split('\n'):
            if '. ' in line:
                try:
                    items.append(line.split('. ', 1)[1])
                except IndexError:
                    pass

        responses_list.append(items)

    flat_list = [item.lower() for sublist in responses_list for item in sublist]

    # Rank unique labels by how many chunks produced them.
    df = pd.DataFrame(flat_list, columns=['Category'])
    counts = pd.Series(flat_list).value_counts()
    df['counts'] = df['Category'].map(counts)
    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)

    if filename is not None:
        df.to_csv(filename, index=False)

    return df
|
|
100
|
+
|
|
101
|
+
#extract top categories from corpus
def explore_common_categories(
    survey_question,
    survey_input,
    api_key,
    top_n=10,
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    specificity="broad",
    research_question=None,
    filename=None,
    model_source="OpenAI"
):
    """Find the most common answer categories in an open-ended corpus.

    Runs the same chunked category extraction as `explore_corpus`, then
    asks the model a second time to reduce the full category list to the
    `top_n` most common labels.

    Parameters
    ----------
    survey_question : str
        The question the responses answer.
    survey_input : pandas.Series
        The free-text responses (must support .sample/.tolist).
    api_key : str
        Provider API key.
    top_n : int
        Number of most-common categories to return.
    cat_num : int
        Categories requested per chunk.
    divisions : int
        Number of random chunks to sample.
    user_model : str
        Provider model identifier.
    creativity : float
        Sampling temperature.
    specificity : str
        Qualifier for the requested categories (e.g. "broad").
    research_question : str or None
        Optional research framing injected into the system prompt.
    filename : str or None
        CSV path for the intermediate ranked-category table; None
        disables saving. (Previously accepted but silently ignored.)
    model_source : str
        Only "OpenAI" is supported.

    Returns
    -------
    list[str]
        The `top_n` category labels, most to least common.

    Raises
    ------
    ValueError
        For an unsupported `model_source`, chunks too small for `cat_num`,
        or a context-length overflow from the API.
    """
    # Fail fast on an unsupported provider before importing SDKs or
    # touching the data.
    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")

    import pandas as pd
    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
    print()

    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    random_chunks = []
    for _ in range(divisions):
        random_chunks.append(survey_input.sample(n=chunk_size).tolist())

    responses = []
    responses_list = []

    client = OpenAI(api_key=api_key)

    for i in tqdm(range(divisions), desc="Processing chunks"):
        survey_participant_chunks = '; '.join(random_chunks[i])
        prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
Responses are each separated by a semicolon. \
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""

        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
The specific task is to identify {specificity} categories of responses to a survey question. \
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
                    {'role': 'user', 'content': prompt}
                ],
                temperature=creativity
            )
            reply = response_obj.choices[0].message.content
            responses.append(reply)
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                error_msg = (f"Token limit exceeded for model {user_model}. "
                             f"Try increasing the 'iterations' parameter to create smaller chunks.")
                raise ValueError(error_msg)
            else:
                print(f"OpenAI API error: {e}")
                # Keep responses aligned with chunk index so the parse
                # below doesn't IndexError on a failed chunk.
                responses.append("")
        except Exception as e:
            print(f"An error occurred: {e}")
            responses.append("")

        # Parse "N. label" lines out of the reply into bare labels.
        items = []
        for line in responses[i].split('\n'):
            if '. ' in line:
                try:
                    items.append(line.split('. ', 1)[1])
                except IndexError:
                    pass

        responses_list.append(items)

    flat_list = [item.lower() for sublist in responses_list for item in sublist]

    # Rank unique labels by how many chunks produced them.
    df = pd.DataFrame(flat_list, columns=['Category'])
    counts = pd.Series(flat_list).value_counts()
    df['counts'] = df['Category'].map(counts)
    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)

    # Mirror explore_corpus: optionally persist the intermediate table.
    if filename is not None:
        df.to_csv(filename, index=False)

    second_prompt = f"""From this list of categories, extract the top {top_n} most common categories. \
The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` \
Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, with no additional text or explanation."""

    response_obj = client.chat.completions.create(
        model=user_model,
        messages=[{'role': 'user', 'content': second_prompt}],
        temperature=creativity
    )
    top_categories = response_obj.choices[0].message.content
    print(top_categories)

    # Parse the numbered list back into bare labels.
    top_categories_final = []
    for line in top_categories.split('\n'):
        if '. ' in line:
            try:
                top_categories_final.append(line.split('. ', 1)[1])
            except IndexError:
                pass

    return top_categories_final
|
|
220
|
+
|
|
1
221
|
#multi-class text classification
|
|
2
222
|
def extract_multi_class(
|
|
3
223
|
survey_question,
|
|
@@ -843,9 +1063,6 @@ def extract_image_features(
|
|
|
843
1063
|
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
|
844
1064
|
})
|
|
845
1065
|
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
|
846
|
-
|
|
847
|
-
if columns != "numbered": #if user wants text columns
|
|
848
|
-
categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]
|
|
849
1066
|
|
|
850
1067
|
if to_csv:
|
|
851
1068
|
if save_directory is None:
|
|
@@ -854,222 +1071,333 @@ def extract_image_features(
|
|
|
854
1071
|
|
|
855
1072
|
return categorized_data
|
|
856
1073
|
|
|
857
|
-
#
|
|
858
|
-
def
|
|
859
|
-
|
|
860
|
-
|
|
1074
|
+
# image multi-class (binary) function
|
|
1075
|
+
def cerad_score(
|
|
1076
|
+
shape,
|
|
1077
|
+
image_input,
|
|
861
1078
|
api_key,
|
|
862
|
-
research_question=None,
|
|
863
|
-
specificity="broad",
|
|
864
|
-
cat_num=10,
|
|
865
|
-
divisions=5,
|
|
866
1079
|
user_model="gpt-4o-2024-11-20",
|
|
867
1080
|
creativity=0,
|
|
868
|
-
|
|
1081
|
+
safety=False,
|
|
1082
|
+
filename="categorized_data.csv",
|
|
869
1083
|
model_source="OpenAI"
|
|
870
1084
|
):
|
|
871
1085
|
import os
|
|
1086
|
+
import json
|
|
872
1087
|
import pandas as pd
|
|
873
|
-
import
|
|
874
|
-
from openai import OpenAI
|
|
875
|
-
from openai import OpenAI, BadRequestError
|
|
1088
|
+
import regex
|
|
876
1089
|
from tqdm import tqdm
|
|
1090
|
+
import glob
|
|
1091
|
+
import base64
|
|
1092
|
+
from pathlib import Path
|
|
877
1093
|
|
|
878
|
-
|
|
879
|
-
print()
|
|
880
|
-
|
|
881
|
-
chunk_size = round(max(1, len(survey_input) / divisions),0)
|
|
882
|
-
chunk_size = int(chunk_size)
|
|
1094
|
+
shape = shape.lower()
|
|
883
1095
|
|
|
884
|
-
if
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
1096
|
+
if shape == "circle":
|
|
1097
|
+
categories = ["It has a drawing of a circle",
|
|
1098
|
+
"The drawing does not resemble a circle",
|
|
1099
|
+
"The drawing resembles a circle",
|
|
1100
|
+
"The circle is closed",
|
|
1101
|
+
"The circle is almost closed",
|
|
1102
|
+
"The circle is circular",
|
|
1103
|
+
"The circle is almost circular",
|
|
1104
|
+
"None of the above descriptions apply"]
|
|
1105
|
+
elif shape == "diamond":
|
|
1106
|
+
categories = ["It has a drawing of a diamond",
|
|
1107
|
+
"It has a drawing of a square",
|
|
1108
|
+
"A drawn shape DOES NOT resemble a diamond",
|
|
1109
|
+
"A drawn shape resembles a diamond",
|
|
1110
|
+
"The drawn shape has 4 sides",
|
|
1111
|
+
"The drawn shape sides are about equal",
|
|
1112
|
+
"If a diamond is drawn it's more elaborate than a simple diamond (such as overlapping diamonds or a diamond with an extras lines inside)",
|
|
1113
|
+
"None of the above descriptions apply"]
|
|
1114
|
+
elif shape == "rectangles" or shape == "overlapping rectangles":
|
|
1115
|
+
categories = ["It has a drawing of overlapping rectangles",
|
|
1116
|
+
"A drawn shape DOES NOT resemble a overlapping rectangles",
|
|
1117
|
+
"A drawn shape resembles a overlapping rectangles",
|
|
1118
|
+
"Rectangle 1 has 4 sides",
|
|
1119
|
+
"Rectangle 2 has 4 sides",
|
|
1120
|
+
"The rectangles are overlapping",
|
|
1121
|
+
"The rectangles overlap contains a longer vertical rectangle with top and bottom portruding",
|
|
1122
|
+
"None of the above descriptions apply"]
|
|
1123
|
+
elif shape == "cube":
|
|
1124
|
+
categories = ["The image contains a drawing that clearly represents a cube (3D box shape)",
|
|
1125
|
+
"The image does NOT contain any drawing that resembles a cube or 3D box",
|
|
1126
|
+
"The image contains a WELL-DRAWN recognizable cube with proper 3D perspective",
|
|
1127
|
+
"If a cube is present: the front face appears as a square or diamond shape",
|
|
1128
|
+
"If a cube is present: internal/hidden edges are visible (showing 3D depth, not just an outline)",
|
|
1129
|
+
"If a cube is present: the front and back faces appear parallel to each other",
|
|
1130
|
+
"The image contains only a 2D square (flat shape, no 3D appearance)",
|
|
1131
|
+
"None of the above descriptions apply"]
|
|
1132
|
+
else:
|
|
1133
|
+
raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")
|
|
889
1134
|
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
1135
|
+
image_extensions = [
|
|
1136
|
+
'*.png', '*.jpg', '*.jpeg',
|
|
1137
|
+
'*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
|
|
1138
|
+
'*.tif', '*.tiff', '*.bmp',
|
|
1139
|
+
'*.heif', '*.heic', '*.ico',
|
|
1140
|
+
'*.psd'
|
|
1141
|
+
]
|
|
1142
|
+
|
|
1143
|
+
if not isinstance(image_input, list):
|
|
1144
|
+
# If image_input is a filepath (string)
|
|
1145
|
+
image_files = []
|
|
1146
|
+
for ext in image_extensions:
|
|
1147
|
+
image_files.extend(glob.glob(os.path.join(image_input, ext)))
|
|
894
1148
|
|
|
895
|
-
|
|
896
|
-
|
|
1149
|
+
print(f"Found {len(image_files)} images.")
|
|
1150
|
+
else:
|
|
1151
|
+
# If image_files is already a list
|
|
1152
|
+
image_files = image_input
|
|
1153
|
+
print(f"Provided a list of {len(image_input)} images.")
|
|
897
1154
|
|
|
898
|
-
for i in
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
1155
|
+
categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
|
|
1156
|
+
cat_num = len(categories)
|
|
1157
|
+
category_dict = {str(i+1): "0" for i in range(cat_num)}
|
|
1158
|
+
example_JSON = json.dumps(category_dict, indent=4)
|
|
1159
|
+
|
|
1160
|
+
link1 = []
|
|
1161
|
+
extracted_jsons = []
|
|
1162
|
+
|
|
1163
|
+
for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
|
|
1164
|
+
# Check validity first
|
|
1165
|
+
if img_path is None or not os.path.exists(img_path):
|
|
1166
|
+
link1.append("Skipped NaN input or invalid path")
|
|
1167
|
+
extracted_jsons.append("""{"no_valid_image": 1}""")
|
|
1168
|
+
continue # Skip the rest of the loop iteration
|
|
904
1169
|
|
|
1170
|
+
# Only open the file if path is valid
|
|
1171
|
+
with open(img_path, "rb") as f:
|
|
1172
|
+
encoded = base64.b64encode(f.read()).decode("utf-8")
|
|
1173
|
+
|
|
1174
|
+
# Handle extension safely
|
|
1175
|
+
ext = Path(img_path).suffix.lstrip(".").lower()
|
|
1176
|
+
encoded_image = f"data:image/{ext};base64,{encoded}"
|
|
1177
|
+
|
|
1178
|
+
prompt = [
|
|
1179
|
+
{
|
|
1180
|
+
"type": "text",
|
|
1181
|
+
"text": (
|
|
1182
|
+
f"You are an image-tagging assistant trained in the CERAD Constructional Praxis test.\n"
|
|
1183
|
+
f"Task ► Examine the attached image and decide, **for each category below**, "
|
|
1184
|
+
f"whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
|
|
1185
|
+
f"Image is expected to show within it a drawing of a {shape}.\n\n"
|
|
1186
|
+
f"Categories:\n{categories_str}\n\n"
|
|
1187
|
+
f"Output format ► Respond with **only** a JSON object whose keys are the "
|
|
1188
|
+
f"quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
|
|
1189
|
+
f"No additional keys, comments, or text.\n\n"
|
|
1190
|
+
f"Example:\n"
|
|
1191
|
+
f"{example_JSON}"
|
|
1192
|
+
),
|
|
1193
|
+
},
|
|
1194
|
+
{
|
|
1195
|
+
"type": "image_url",
|
|
1196
|
+
"image_url": {"url": encoded_image, "detail": "high"},
|
|
1197
|
+
},
|
|
1198
|
+
]
|
|
905
1199
|
if model_source == "OpenAI":
|
|
1200
|
+
from openai import OpenAI
|
|
906
1201
|
client = OpenAI(api_key=api_key)
|
|
907
1202
|
try:
|
|
908
1203
|
response_obj = client.chat.completions.create(
|
|
909
1204
|
model=user_model,
|
|
910
|
-
messages=[
|
|
911
|
-
{'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
|
|
912
|
-
The specific task is to identify {specificity} categories of responses to a survey question. \
|
|
913
|
-
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
|
|
914
|
-
{'role': 'user', 'content': prompt}
|
|
915
|
-
]
|
|
1205
|
+
messages=[{'role': 'user', 'content': prompt}],
|
|
916
1206
|
temperature=creativity
|
|
917
1207
|
)
|
|
918
1208
|
reply = response_obj.choices[0].message.content
|
|
919
|
-
|
|
920
|
-
except BadRequestError as e:
|
|
921
|
-
if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
|
|
922
|
-
error_msg = (f"Token limit exceeded for model {user_model}. "
|
|
923
|
-
f"Try increasing the 'iterations' parameter to create smaller chunks.")
|
|
924
|
-
raise ValueError(error_msg)
|
|
925
|
-
else:
|
|
926
|
-
print(f"OpenAI API error: {e}")
|
|
1209
|
+
link1.append(reply)
|
|
927
1210
|
except Exception as e:
|
|
928
1211
|
print(f"An error occurred: {e}")
|
|
929
|
-
|
|
930
|
-
raise ValueError(f"Unsupported model_source: {model_source}")
|
|
931
|
-
|
|
932
|
-
# Extract just the text as a list
|
|
933
|
-
items = []
|
|
934
|
-
for line in responses[i].split('\n'):
|
|
935
|
-
if '. ' in line:
|
|
936
|
-
try:
|
|
937
|
-
items.append(line.split('. ', 1)[1])
|
|
938
|
-
except IndexError:
|
|
939
|
-
pass
|
|
940
|
-
|
|
941
|
-
responses_list.append(items)
|
|
942
|
-
|
|
943
|
-
flat_list = [item.lower() for sublist in responses_list for item in sublist]
|
|
944
|
-
|
|
945
|
-
#convert flat_list to a df
|
|
946
|
-
df = pd.DataFrame(flat_list, columns=['Category'])
|
|
947
|
-
counts = pd.Series(flat_list).value_counts() # Use original list before conversion
|
|
948
|
-
df['counts'] = df['Category'].map(counts)
|
|
949
|
-
df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
|
|
950
|
-
df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)
|
|
951
|
-
|
|
952
|
-
if filename is not None:
|
|
953
|
-
df.to_csv(filename, index=False)
|
|
954
|
-
|
|
955
|
-
return df
|
|
956
|
-
|
|
957
|
-
#extract top categories from corpus
|
|
958
|
-
def explore_common_categories(
|
|
959
|
-
survey_question,
|
|
960
|
-
survey_input,
|
|
961
|
-
api_key,
|
|
962
|
-
top_n=10,
|
|
963
|
-
cat_num=10,
|
|
964
|
-
divisions=5,
|
|
965
|
-
user_model="gpt-4o-2024-11-20",
|
|
966
|
-
creativity=0,
|
|
967
|
-
specificity="broad",
|
|
968
|
-
research_question=None,
|
|
969
|
-
filename=None,
|
|
970
|
-
model_source="OpenAI"
|
|
971
|
-
):
|
|
972
|
-
import os
|
|
973
|
-
import pandas as pd
|
|
974
|
-
import random
|
|
975
|
-
from openai import OpenAI
|
|
976
|
-
from openai import OpenAI, BadRequestError
|
|
977
|
-
from tqdm import tqdm
|
|
978
|
-
|
|
979
|
-
print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
|
|
980
|
-
print()
|
|
981
|
-
|
|
982
|
-
chunk_size = round(max(1, len(survey_input) / divisions),0)
|
|
983
|
-
chunk_size = int(chunk_size)
|
|
984
|
-
|
|
985
|
-
if chunk_size < (cat_num/2):
|
|
986
|
-
raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
|
|
987
|
-
f"Choose one solution: \n"
|
|
988
|
-
f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
|
|
989
|
-
f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
|
|
1212
|
+
link1.append(f"Error processing input: {e}")
|
|
990
1213
|
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
random_chunks.append(chunk)
|
|
995
|
-
|
|
996
|
-
responses = []
|
|
997
|
-
responses_list = []
|
|
998
|
-
|
|
999
|
-
for i in tqdm(range(divisions), desc="Processing chunks"):
|
|
1000
|
-
survey_participant_chunks = '; '.join(random_chunks[i])
|
|
1001
|
-
prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
|
|
1002
|
-
Responses are each separated by a semicolon. \
|
|
1003
|
-
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
|
|
1004
|
-
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
|
|
1005
|
-
|
|
1006
|
-
if model_source == "OpenAI":
|
|
1007
|
-
client = OpenAI(api_key=api_key)
|
|
1214
|
+
elif model_source == "Perplexity":
|
|
1215
|
+
from openai import OpenAI
|
|
1216
|
+
client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
|
|
1008
1217
|
try:
|
|
1009
1218
|
response_obj = client.chat.completions.create(
|
|
1010
1219
|
model=user_model,
|
|
1011
|
-
messages=[
|
|
1012
|
-
{'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
|
|
1013
|
-
The specific task is to identify {specificity} categories of responses to a survey question. \
|
|
1014
|
-
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
|
|
1015
|
-
{'role': 'user', 'content': prompt}
|
|
1016
|
-
],
|
|
1220
|
+
messages=[{'role': 'user', 'content': prompt}],
|
|
1017
1221
|
temperature=creativity
|
|
1018
1222
|
)
|
|
1019
1223
|
reply = response_obj.choices[0].message.content
|
|
1020
|
-
|
|
1021
|
-
except
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1224
|
+
link1.append(reply)
|
|
1225
|
+
except Exception as e:
|
|
1226
|
+
print(f"An error occurred: {e}")
|
|
1227
|
+
link1.append(f"Error processing input: {e}")
|
|
1228
|
+
elif model_source == "Anthropic":
|
|
1229
|
+
import anthropic
|
|
1230
|
+
client = anthropic.Anthropic(api_key=api_key)
|
|
1231
|
+
try:
|
|
1232
|
+
message = client.messages.create(
|
|
1233
|
+
model=user_model,
|
|
1234
|
+
max_tokens=1024,
|
|
1235
|
+
temperature=creativity,
|
|
1236
|
+
messages=[{"role": "user", "content": prompt}]
|
|
1237
|
+
)
|
|
1238
|
+
reply = message.content[0].text # Anthropic returns content as list
|
|
1239
|
+
link1.append(reply)
|
|
1240
|
+
except Exception as e:
|
|
1241
|
+
print(f"An error occurred: {e}")
|
|
1242
|
+
link1.append(f"Error processing input: {e}")
|
|
1243
|
+
elif model_source == "Mistral":
|
|
1244
|
+
from mistralai import Mistral
|
|
1245
|
+
client = Mistral(api_key=api_key)
|
|
1246
|
+
try:
|
|
1247
|
+
response = client.chat.complete(
|
|
1248
|
+
model=user_model,
|
|
1249
|
+
messages=[
|
|
1250
|
+
{'role': 'user', 'content': prompt}
|
|
1251
|
+
],
|
|
1252
|
+
temperature=creativity
|
|
1253
|
+
)
|
|
1254
|
+
reply = response.choices[0].message.content
|
|
1255
|
+
link1.append(reply)
|
|
1028
1256
|
except Exception as e:
|
|
1029
1257
|
print(f"An error occurred: {e}")
|
|
1258
|
+
link1.append(f"Error processing input: {e}")
|
|
1030
1259
|
else:
|
|
1031
|
-
raise ValueError(
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1260
|
+
raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
|
|
1261
|
+
# in situation that no JSON is found
|
|
1262
|
+
if reply is not None:
|
|
1263
|
+
extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
|
|
1264
|
+
if extracted_json:
|
|
1265
|
+
cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
|
|
1266
|
+
extracted_jsons.append(cleaned_json)
|
|
1267
|
+
#print(cleaned_json)
|
|
1268
|
+
else:
|
|
1269
|
+
error_message = """{"1":"e"}"""
|
|
1270
|
+
extracted_jsons.append(error_message)
|
|
1271
|
+
print(error_message)
|
|
1272
|
+
else:
|
|
1273
|
+
error_message = """{"1":"e"}"""
|
|
1274
|
+
extracted_jsons.append(error_message)
|
|
1275
|
+
#print(error_message)
|
|
1276
|
+
|
|
1277
|
+
# --- Safety Save ---
|
|
1278
|
+
if safety:
|
|
1279
|
+
#print(f"Saving CSV to: {save_directory}")
|
|
1280
|
+
# Save progress so far
|
|
1281
|
+
temp_df = pd.DataFrame({
|
|
1282
|
+
'image_input': image_files[:i+1],
|
|
1283
|
+
'link1': link1,
|
|
1284
|
+
'json': extracted_jsons
|
|
1285
|
+
})
|
|
1286
|
+
# Normalize processed jsons so far
|
|
1287
|
+
normalized_data_list = []
|
|
1288
|
+
for json_str in extracted_jsons:
|
|
1037
1289
|
try:
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1290
|
+
parsed_obj = json.loads(json_str)
|
|
1291
|
+
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
1292
|
+
except json.JSONDecodeError:
|
|
1293
|
+
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
1294
|
+
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
1295
|
+
temp_df = pd.concat([temp_df, normalized_data], axis=1)
|
|
1296
|
+
# Save to CSV
|
|
1297
|
+
if filename is None:
|
|
1298
|
+
filepath = os.path.join(os.getcwd(), 'catllm_data.csv')
|
|
1299
|
+
else:
|
|
1300
|
+
filepath = filename
|
|
1301
|
+
temp_df.to_csv(filepath, index=False)
|
|
1041
1302
|
|
|
1042
|
-
|
|
1303
|
+
# --- Final DataFrame ---
|
|
1304
|
+
normalized_data_list = []
|
|
1305
|
+
for json_str in extracted_jsons:
|
|
1306
|
+
try:
|
|
1307
|
+
parsed_obj = json.loads(json_str)
|
|
1308
|
+
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
1309
|
+
except json.JSONDecodeError:
|
|
1310
|
+
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
1311
|
+
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
1043
1312
|
|
|
1044
|
-
|
|
1313
|
+
categorized_data = pd.DataFrame({
|
|
1314
|
+
'image_input': image_files,
|
|
1315
|
+
'link1': pd.Series(link1).reset_index(drop=True),
|
|
1316
|
+
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
|
1317
|
+
})
|
|
1318
|
+
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
|
1319
|
+
columns_to_convert = ["1", "2", "3", "4", "5", "6", "7"]
|
|
1320
|
+
categorized_data[columns_to_convert] = categorized_data[columns_to_convert].apply(pd.to_numeric, errors='coerce').fillna(0).astype(int)
|
|
1045
1321
|
|
|
1046
|
-
|
|
1047
|
-
df = pd.DataFrame(flat_list, columns=['Category'])
|
|
1048
|
-
counts = pd.Series(flat_list).value_counts() # Use original list before conversion
|
|
1049
|
-
df['counts'] = df['Category'].map(counts)
|
|
1050
|
-
df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
|
|
1051
|
-
df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)
|
|
1322
|
+
if shape == "circle":
|
|
1052
1323
|
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
)
|
|
1064
|
-
top_categories = response_obj.choices[0].message.content
|
|
1065
|
-
print(top_categories)
|
|
1324
|
+
categorized_data = categorized_data.rename(columns={
|
|
1325
|
+
"1": "drawing_present",
|
|
1326
|
+
"2": "not_similar",
|
|
1327
|
+
"3": "similar",
|
|
1328
|
+
"4": "cir_closed",
|
|
1329
|
+
"5": "cir_almost_closed",
|
|
1330
|
+
"6": "cir_round",
|
|
1331
|
+
"7": "cir_almost_round",
|
|
1332
|
+
"8": "none"
|
|
1333
|
+
})
|
|
1066
1334
|
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1335
|
+
categorized_data['score'] = categorized_data['cir_almost_closed'] + categorized_data['cir_closed'] + categorized_data['cir_round'] + categorized_data['cir_almost_round']
|
|
1336
|
+
categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
|
|
1337
|
+
categorized_data.loc[(categorized_data['drawing_present'] == 0) & (categorized_data['score'] == 0), 'score'] = 0
|
|
1338
|
+
|
|
1339
|
+
elif shape == "diamond":
|
|
1340
|
+
|
|
1341
|
+
categorized_data = categorized_data.rename(columns={
|
|
1342
|
+
"1": "drawing_present",
|
|
1343
|
+
"2": "diamond_square",
|
|
1344
|
+
"3": "not_similar",
|
|
1345
|
+
"4": "similar",
|
|
1346
|
+
"5": "diamond_4_sides",
|
|
1347
|
+
"6": "diamond_equal_sides",
|
|
1348
|
+
"7": "complex_diamond",
|
|
1349
|
+
"8": "none"
|
|
1350
|
+
})
|
|
1351
|
+
|
|
1352
|
+
categorized_data['score'] = categorized_data['diamond_4_sides'] + categorized_data['diamond_equal_sides'] + categorized_data['similar']
|
|
1353
|
+
|
|
1354
|
+
categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
|
|
1355
|
+
categorized_data.loc[(categorized_data['diamond_square'] == 1) & (categorized_data['score'] == 0), 'score'] = 2
|
|
1356
|
+
|
|
1357
|
+
elif shape == "rectangles" or shape == "overlapping rectangles":
|
|
1358
|
+
|
|
1359
|
+
categorized_data = categorized_data.rename(columns={
|
|
1360
|
+
"1":"drawing_present",
|
|
1361
|
+
"2": "not_similar",
|
|
1362
|
+
"3": "similar",
|
|
1363
|
+
"4": "r1_4_sides",
|
|
1364
|
+
"5": "r2_4_sides",
|
|
1365
|
+
"6": "rectangles_overlap",
|
|
1366
|
+
"7": "rectangles_cross",
|
|
1367
|
+
"8": "none"
|
|
1368
|
+
})
|
|
1074
1369
|
|
|
1075
|
-
|
|
1370
|
+
categorized_data['score'] = 0
|
|
1371
|
+
categorized_data.loc[(categorized_data['r1_4_sides'] == 1) & (categorized_data['r2_4_sides'] == 1), 'score'] = 1
|
|
1372
|
+
categorized_data.loc[(categorized_data['rectangles_overlap'] == 1) & (categorized_data['rectangles_cross'] == 1), 'score'] += 1
|
|
1373
|
+
categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
|
|
1374
|
+
|
|
1375
|
+
elif shape == "cube":
|
|
1376
|
+
|
|
1377
|
+
categorized_data = categorized_data.rename(columns={
|
|
1378
|
+
"1": "drawing_present",
|
|
1379
|
+
"2": "not_similar",
|
|
1380
|
+
"3": "similar",
|
|
1381
|
+
"4": "cube_front_face",
|
|
1382
|
+
"5": "cube_internal_lines",
|
|
1383
|
+
"6": "cube_opposite_sides",
|
|
1384
|
+
"7": "square_only",
|
|
1385
|
+
"8": "none"
|
|
1386
|
+
})
|
|
1387
|
+
|
|
1388
|
+
categorized_data['score'] = categorized_data['cube_front_face'] + categorized_data['cube_internal_lines'] + categorized_data['cube_opposite_sides'] + categorized_data['similar']
|
|
1389
|
+
categorized_data.loc[categorized_data['similar'] == 1, 'score'] = categorized_data['score'] + 1
|
|
1390
|
+
categorized_data.loc[categorized_data['none'] == 1, 'score'] = 0
|
|
1391
|
+
categorized_data.loc[(categorized_data['drawing_present'] == 0) & (categorized_data['score'] == 0), 'score'] = 0
|
|
1392
|
+
categorized_data.loc[(categorized_data['not_similar'] == 1) & (categorized_data['score'] == 0), 'score'] = 0
|
|
1393
|
+
categorized_data.loc[categorized_data['score'] > 4, 'score'] = 4
|
|
1394
|
+
|
|
1395
|
+
else:
|
|
1396
|
+
raise ValueError("Invalid shape! Choose from 'circle', 'diamond', 'rectangles', or 'cube'.")
|
|
1397
|
+
|
|
1398
|
+
categorized_data.loc[categorized_data['no_valid_image'] == 1, 'score'] = None
|
|
1399
|
+
|
|
1400
|
+
if filename is not None:
|
|
1401
|
+
categorized_data.to_csv(filename, index=False)
|
|
1402
|
+
|
|
1403
|
+
return categorized_data
|
cat_llm-0.0.19.dist-info/RECORD
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
catllm/__about__.py,sha256=ht639_Zw-BmQ-6TW3sLLeC_s67LbW9TFDrYT8b4xpIg,404
|
|
2
|
-
catllm/__init__.py,sha256=xDin9x4jymeccuxE9Xf-27ncR9h7247IwLbeYN-m3j8,266
|
|
3
|
-
catllm/cat_llm.py,sha256=pj_xcsFA5OQVhMv9-73YT7tDn_3Ol3UowqYAbPrlrZI,43825
|
|
4
|
-
cat_llm-0.0.19.dist-info/METADATA,sha256=Ln27uQy5nspYHg0-iFH_XvAlDJFlfxREKeye2llHteI,1679
|
|
5
|
-
cat_llm-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
cat_llm-0.0.19.dist-info/licenses/LICENSE,sha256=wJLsvOr6lrFUDcoPXExa01HOKFWrS3JC9f0RudRw8uw,1075
|
|
7
|
-
cat_llm-0.0.19.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|