cat-llm 0.0.8__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.8 → cat_llm-0.0.9}/PKG-INFO +3 -1
- {cat_llm-0.0.8 → cat_llm-0.0.9}/README.md +2 -0
- {cat_llm-0.0.8 → cat_llm-0.0.9}/pyproject.toml +8 -0
- cat_llm-0.0.8/setup.py +0 -23
- cat_llm-0.0.8/src/catllm/__about__.py +0 -10
- cat_llm-0.0.8/src/catllm/__init__.py +0 -14
- cat_llm-0.0.8/src/catllm/cat_llm.py +0 -959
- cat_llm-0.0.8/tests/__init__.py +0 -3
- {cat_llm-0.0.8 → cat_llm-0.0.9}/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-llm
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.9
|
|
4
4
|
Summary: A tool for categorizing text data and images using LLMs and vision models
|
|
5
5
|
Project-URL: Documentation, https://github.com/Christopher Soria/cat-llm#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
|
|
@@ -23,6 +23,8 @@ Requires-Dist: pandas
|
|
|
23
23
|
Requires-Dist: tqdm
|
|
24
24
|
Description-Content-Type: text/markdown
|
|
25
25
|
|
|
26
|
+

|
|
27
|
+
|
|
26
28
|
# cat_llm
|
|
27
29
|
|
|
28
30
|
[](https://pypi.org/project/cat-llm)
|
|
@@ -46,6 +46,14 @@ check = "mypy --install-types --non-interactive {args:src/catllm tests}"
|
|
|
46
46
|
|
|
47
47
|
[tool.hatch.build.targets.wheel]
|
|
48
48
|
packages = ["src/catllm"]
|
|
49
|
+
include = [
|
|
50
|
+
"src/catllm/images/*",
|
|
51
|
+
]
|
|
52
|
+
|
|
53
|
+
[tool.hatch.build.targets.sdist]
|
|
54
|
+
include = [
|
|
55
|
+
"src/catllm/images/*",
|
|
56
|
+
]
|
|
49
57
|
|
|
50
58
|
[tool.coverage.run]
|
|
51
59
|
source_pkgs = ["catllm", "tests"]
|
cat_llm-0.0.8/setup.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
from setuptools import setup
|
|
2
|
-
import os
|
|
3
|
-
|
|
4
|
-
VERSION = ""0.0.3""
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def get_long_description():
|
|
8
|
-
with open(
|
|
9
|
-
os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
|
|
10
|
-
encoding="utf8",
|
|
11
|
-
) as fp:
|
|
12
|
-
return fp.read()
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
setup(
|
|
16
|
-
name="catllm",
|
|
17
|
-
description="catllm is now cat-llm",
|
|
18
|
-
long_description=get_long_description(),
|
|
19
|
-
long_description_content_type="text/markdown",
|
|
20
|
-
version=VERSION,
|
|
21
|
-
install_requires=["cat-llm"],
|
|
22
|
-
classifiers=["Development Status :: 7 - Inactive"],
|
|
23
|
-
)
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: MIT
|
|
4
|
-
__version__ = "0.0.8"
|
|
5
|
-
__author__ = "Chris Soria"
|
|
6
|
-
__email__ = "chrissoria@berkeley.edu"
|
|
7
|
-
__title__ = "cat-llm"
|
|
8
|
-
__description__ = "A tool for categorizing text data and images using LLMs and vision models"
|
|
9
|
-
__url__ = "https://github.com/chrissoria/cat-llm" # Replace with your actual repository URL
|
|
10
|
-
__license__ = "MIT"
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
|
-
#
|
|
3
|
-
# SPDX-License-Identifier: MIT
|
|
4
|
-
|
|
5
|
-
from .__about__ import (
|
|
6
|
-
__version__,
|
|
7
|
-
__author__,
|
|
8
|
-
__description__,
|
|
9
|
-
__title__,
|
|
10
|
-
__url__,
|
|
11
|
-
__license__,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
from .cat_llm import *
|
|
@@ -1,959 +0,0 @@
|
|
|
1
|
-
#multi-class text classification
|
|
2
|
-
def extract_multi_class(
|
|
3
|
-
survey_question,
|
|
4
|
-
survey_input,
|
|
5
|
-
categories,
|
|
6
|
-
api_key,
|
|
7
|
-
columns="numbered",
|
|
8
|
-
user_model="gpt-4o-2024-11-20",
|
|
9
|
-
creativity=0,
|
|
10
|
-
to_csv=False,
|
|
11
|
-
safety=False,
|
|
12
|
-
filename="categorized_data.csv",
|
|
13
|
-
save_directory=None,
|
|
14
|
-
model_source="OpenAI"
|
|
15
|
-
):
|
|
16
|
-
import os
|
|
17
|
-
import json
|
|
18
|
-
import pandas as pd
|
|
19
|
-
import regex
|
|
20
|
-
from tqdm import tqdm
|
|
21
|
-
|
|
22
|
-
categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
|
|
23
|
-
cat_num = len(categories)
|
|
24
|
-
category_dict = {str(i+1): "0" for i in range(cat_num)}
|
|
25
|
-
example_JSON = json.dumps(category_dict, indent=4)
|
|
26
|
-
|
|
27
|
-
# ensure number of categories is what user wants
|
|
28
|
-
print("\nPlease verify the categories you entered:")
|
|
29
|
-
for i, cat in enumerate(categories, 1):
|
|
30
|
-
print(f"{i}. {cat}")
|
|
31
|
-
response = input("\nIf the list above is correct, type 'next' and press Enter to continue: ")
|
|
32
|
-
while response.strip().lower() != "next":
|
|
33
|
-
response = input("Please type 'next' to continue: ")
|
|
34
|
-
|
|
35
|
-
link1 = []
|
|
36
|
-
extracted_jsons = []
|
|
37
|
-
|
|
38
|
-
for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
|
|
39
|
-
reply = None
|
|
40
|
-
|
|
41
|
-
if pd.isna(response):
|
|
42
|
-
link1.append("Skipped NaN input")
|
|
43
|
-
default_json = example_JSON
|
|
44
|
-
extracted_jsons.append(default_json)
|
|
45
|
-
#print(f"Skipped NaN input.")
|
|
46
|
-
else:
|
|
47
|
-
prompt = f"""A respondent was asked: {survey_question}. \
|
|
48
|
-
Categorize this survey response "{response}" into the following categories that apply: \
|
|
49
|
-
{categories_str} \
|
|
50
|
-
Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
|
|
51
|
-
#print(prompt)
|
|
52
|
-
if model_source == ("OpenAI"):
|
|
53
|
-
from openai import OpenAI
|
|
54
|
-
client = OpenAI(api_key=api_key)
|
|
55
|
-
try:
|
|
56
|
-
response_obj = client.chat.completions.create(
|
|
57
|
-
model=user_model,
|
|
58
|
-
messages=[{'role': 'user', 'content': prompt}],
|
|
59
|
-
temperature=creativity
|
|
60
|
-
)
|
|
61
|
-
reply = response_obj.choices[0].message.content
|
|
62
|
-
link1.append(reply)
|
|
63
|
-
except Exception as e:
|
|
64
|
-
print(f"An error occurred: {e}")
|
|
65
|
-
link1.append(f"Error processing input: {e}")
|
|
66
|
-
elif model_source == "Perplexity":
|
|
67
|
-
from openai import OpenAI
|
|
68
|
-
client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
|
|
69
|
-
try:
|
|
70
|
-
response_obj = client.chat.completions.create(
|
|
71
|
-
model=user_model,
|
|
72
|
-
messages=[{'role': 'user', 'content': prompt}],
|
|
73
|
-
temperature=creativity
|
|
74
|
-
)
|
|
75
|
-
reply = response_obj.choices[0].message.content
|
|
76
|
-
link1.append(reply)
|
|
77
|
-
except Exception as e:
|
|
78
|
-
print(f"An error occurred: {e}")
|
|
79
|
-
link1.append(f"Error processing input: {e}")
|
|
80
|
-
elif model_source == "Anthropic":
|
|
81
|
-
import anthropic
|
|
82
|
-
client = anthropic.Anthropic(api_key=api_key)
|
|
83
|
-
try:
|
|
84
|
-
message = client.messages.create(
|
|
85
|
-
model=user_model,
|
|
86
|
-
max_tokens=1024,
|
|
87
|
-
temperature=creativity,
|
|
88
|
-
messages=[{"role": "user", "content": prompt}]
|
|
89
|
-
)
|
|
90
|
-
reply = message.content[0].text # Anthropic returns content as list
|
|
91
|
-
link1.append(reply)
|
|
92
|
-
except Exception as e:
|
|
93
|
-
print(f"An error occurred: {e}")
|
|
94
|
-
link1.append(f"Error processing input: {e}")
|
|
95
|
-
elif model_source == "Mistral":
|
|
96
|
-
from mistralai import Mistral
|
|
97
|
-
client = Mistral(api_key=api_key)
|
|
98
|
-
try:
|
|
99
|
-
response = client.chat.complete(
|
|
100
|
-
model=user_model,
|
|
101
|
-
messages=[
|
|
102
|
-
{'role': 'user', 'content': prompt}
|
|
103
|
-
],
|
|
104
|
-
temperature=creativity
|
|
105
|
-
)
|
|
106
|
-
reply = response.choices[0].message.content
|
|
107
|
-
link1.append(reply)
|
|
108
|
-
except Exception as e:
|
|
109
|
-
print(f"An error occurred: {e}")
|
|
110
|
-
link1.append(f"Error processing input: {e}")
|
|
111
|
-
else:
|
|
112
|
-
raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
|
|
113
|
-
# in situation that no JSON is found
|
|
114
|
-
if reply is not None:
|
|
115
|
-
extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
|
|
116
|
-
if extracted_json:
|
|
117
|
-
cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
|
|
118
|
-
extracted_jsons.append(cleaned_json)
|
|
119
|
-
#print(cleaned_json)
|
|
120
|
-
else:
|
|
121
|
-
error_message = """{"1":"e"}"""
|
|
122
|
-
extracted_jsons.append(error_message)
|
|
123
|
-
print(error_message)
|
|
124
|
-
else:
|
|
125
|
-
error_message = """{"1":"e"}"""
|
|
126
|
-
extracted_jsons.append(error_message)
|
|
127
|
-
#print(error_message)
|
|
128
|
-
|
|
129
|
-
# --- Safety Save ---
|
|
130
|
-
if safety:
|
|
131
|
-
# Save progress so far
|
|
132
|
-
temp_df = pd.DataFrame({
|
|
133
|
-
'survey_response': survey_input[:idx+1],
|
|
134
|
-
'link1': link1,
|
|
135
|
-
'json': extracted_jsons
|
|
136
|
-
})
|
|
137
|
-
# Normalize processed jsons so far
|
|
138
|
-
normalized_data_list = []
|
|
139
|
-
for json_str in extracted_jsons:
|
|
140
|
-
try:
|
|
141
|
-
parsed_obj = json.loads(json_str)
|
|
142
|
-
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
143
|
-
except json.JSONDecodeError:
|
|
144
|
-
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
145
|
-
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
146
|
-
temp_df = pd.concat([temp_df, normalized_data], axis=1)
|
|
147
|
-
# Save to CSV
|
|
148
|
-
if save_directory is None:
|
|
149
|
-
save_directory = os.getcwd()
|
|
150
|
-
temp_df.to_csv(os.path.join(save_directory, filename), index=False)
|
|
151
|
-
|
|
152
|
-
# --- Final DataFrame ---
|
|
153
|
-
normalized_data_list = []
|
|
154
|
-
for json_str in extracted_jsons:
|
|
155
|
-
try:
|
|
156
|
-
parsed_obj = json.loads(json_str)
|
|
157
|
-
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
158
|
-
except json.JSONDecodeError:
|
|
159
|
-
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
160
|
-
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
161
|
-
|
|
162
|
-
categorized_data = pd.DataFrame({
|
|
163
|
-
'survey_response': survey_input.reset_index(drop=True),
|
|
164
|
-
'link1': pd.Series(link1).reset_index(drop=True),
|
|
165
|
-
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
|
166
|
-
})
|
|
167
|
-
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
|
168
|
-
|
|
169
|
-
if columns != "numbered": #if user wants text columns
|
|
170
|
-
categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]
|
|
171
|
-
|
|
172
|
-
if to_csv:
|
|
173
|
-
if save_directory is None:
|
|
174
|
-
save_directory = os.getcwd()
|
|
175
|
-
categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
|
|
176
|
-
|
|
177
|
-
return categorized_data
|
|
178
|
-
|
|
179
|
-
# image multi-class (binary) function
|
|
180
|
-
def extract_image_multi_class(
|
|
181
|
-
image_description,
|
|
182
|
-
image_input,
|
|
183
|
-
categories,
|
|
184
|
-
api_key,
|
|
185
|
-
columns="numbered",
|
|
186
|
-
user_model="gpt-4o-2024-11-20",
|
|
187
|
-
creativity=0,
|
|
188
|
-
to_csv=False,
|
|
189
|
-
safety=False,
|
|
190
|
-
filename="categorized_data.csv",
|
|
191
|
-
save_directory=None,
|
|
192
|
-
model_source="OpenAI"
|
|
193
|
-
):
|
|
194
|
-
import os
|
|
195
|
-
import json
|
|
196
|
-
import pandas as pd
|
|
197
|
-
import regex
|
|
198
|
-
from tqdm import tqdm
|
|
199
|
-
import glob
|
|
200
|
-
import base64
|
|
201
|
-
from pathlib import Path
|
|
202
|
-
|
|
203
|
-
if save_directory is not None and not os.path.isdir(save_directory):
|
|
204
|
-
# Directory doesn't exist - raise an exception to halt execution
|
|
205
|
-
raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
|
|
206
|
-
|
|
207
|
-
image_extensions = [
|
|
208
|
-
'*.png', '*.jpg', '*.jpeg',
|
|
209
|
-
'*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
|
|
210
|
-
'*.tif', '*.tiff', '*.bmp',
|
|
211
|
-
'*.heif', '*.heic', '*.ico',
|
|
212
|
-
'*.psd'
|
|
213
|
-
]
|
|
214
|
-
|
|
215
|
-
if not isinstance(image_input, list):
|
|
216
|
-
# If image_input is a filepath (string)
|
|
217
|
-
image_files = []
|
|
218
|
-
for ext in image_extensions:
|
|
219
|
-
image_files.extend(glob.glob(os.path.join(image_input, ext)))
|
|
220
|
-
|
|
221
|
-
print(f"Found {len(image_files)} images.")
|
|
222
|
-
else:
|
|
223
|
-
# If image_files is already a list
|
|
224
|
-
image_files = image_input
|
|
225
|
-
print(f"Provided a list of {len(image_input)} images.")
|
|
226
|
-
|
|
227
|
-
categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
|
|
228
|
-
cat_num = len(categories)
|
|
229
|
-
category_dict = {str(i+1): "0" for i in range(cat_num)}
|
|
230
|
-
example_JSON = json.dumps(category_dict, indent=4)
|
|
231
|
-
|
|
232
|
-
# ensure number of categories is what user wants
|
|
233
|
-
print("Categories to classify:")
|
|
234
|
-
for i, cat in enumerate(categories, 1):
|
|
235
|
-
print(f"{i}. {cat}")
|
|
236
|
-
|
|
237
|
-
link1 = []
|
|
238
|
-
extracted_jsons = []
|
|
239
|
-
|
|
240
|
-
for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
|
|
241
|
-
# Check validity first
|
|
242
|
-
if img_path is None or not os.path.exists(img_path):
|
|
243
|
-
link1.append("Skipped NaN input or invalid path")
|
|
244
|
-
extracted_jsons.append("""{"no_valid_image": 1}""")
|
|
245
|
-
continue # Skip the rest of the loop iteration
|
|
246
|
-
|
|
247
|
-
# Only open the file if path is valid
|
|
248
|
-
with open(img_path, "rb") as f:
|
|
249
|
-
encoded = base64.b64encode(f.read()).decode("utf-8")
|
|
250
|
-
|
|
251
|
-
# Handle extension safely
|
|
252
|
-
ext = Path(img_path).suffix.lstrip(".").lower()
|
|
253
|
-
encoded_image = f"data:image/{ext};base64,{encoded}"
|
|
254
|
-
|
|
255
|
-
prompt = [
|
|
256
|
-
{
|
|
257
|
-
"type": "text",
|
|
258
|
-
"text": (
|
|
259
|
-
f"You are an image-tagging assistant.\n"
|
|
260
|
-
f"Task ► Examine the attached image and decide, **for each category below**, "
|
|
261
|
-
f"whether it is PRESENT (1) or NOT PRESENT (0).\n\n"
|
|
262
|
-
f"Image is expected to show: {image_description}\n\n"
|
|
263
|
-
f"Categories:\n{categories_str}\n\n"
|
|
264
|
-
f"Output format ► Respond with **only** a JSON object whose keys are the "
|
|
265
|
-
f"quoted category numbers ('1', '2', …) and whose values are 1 or 0. "
|
|
266
|
-
f"No additional keys, comments, or text.\n\n"
|
|
267
|
-
f"Example (three categories):\n"
|
|
268
|
-
f"{example_JSON}"
|
|
269
|
-
),
|
|
270
|
-
},
|
|
271
|
-
{
|
|
272
|
-
"type": "image_url",
|
|
273
|
-
"image_url": {"url": encoded_image, "detail": "high"},
|
|
274
|
-
},
|
|
275
|
-
]
|
|
276
|
-
if model_source == "OpenAI":
|
|
277
|
-
from openai import OpenAI
|
|
278
|
-
client = OpenAI(api_key=api_key)
|
|
279
|
-
try:
|
|
280
|
-
response_obj = client.chat.completions.create(
|
|
281
|
-
model=user_model,
|
|
282
|
-
messages=[{'role': 'user', 'content': prompt}],
|
|
283
|
-
temperature=creativity
|
|
284
|
-
)
|
|
285
|
-
reply = response_obj.choices[0].message.content
|
|
286
|
-
link1.append(reply)
|
|
287
|
-
except Exception as e:
|
|
288
|
-
print(f"An error occurred: {e}")
|
|
289
|
-
link1.append(f"Error processing input: {e}")
|
|
290
|
-
|
|
291
|
-
elif model_source == "Perplexity":
|
|
292
|
-
from openai import OpenAI
|
|
293
|
-
client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
|
|
294
|
-
try:
|
|
295
|
-
response_obj = client.chat.completions.create(
|
|
296
|
-
model=user_model,
|
|
297
|
-
messages=[{'role': 'user', 'content': prompt}],
|
|
298
|
-
temperature=creativity
|
|
299
|
-
)
|
|
300
|
-
reply = response_obj.choices[0].message.content
|
|
301
|
-
link1.append(reply)
|
|
302
|
-
except Exception as e:
|
|
303
|
-
print(f"An error occurred: {e}")
|
|
304
|
-
link1.append(f"Error processing input: {e}")
|
|
305
|
-
elif model_source == "Anthropic":
|
|
306
|
-
import anthropic
|
|
307
|
-
client = anthropic.Anthropic(api_key=api_key)
|
|
308
|
-
try:
|
|
309
|
-
message = client.messages.create(
|
|
310
|
-
model=user_model,
|
|
311
|
-
max_tokens=1024,
|
|
312
|
-
temperature=creativity,
|
|
313
|
-
messages=[{"role": "user", "content": prompt}]
|
|
314
|
-
)
|
|
315
|
-
reply = message.content[0].text # Anthropic returns content as list
|
|
316
|
-
link1.append(reply)
|
|
317
|
-
except Exception as e:
|
|
318
|
-
print(f"An error occurred: {e}")
|
|
319
|
-
link1.append(f"Error processing input: {e}")
|
|
320
|
-
elif model_source == "Mistral":
|
|
321
|
-
from mistralai import Mistral
|
|
322
|
-
client = Mistral(api_key=api_key)
|
|
323
|
-
try:
|
|
324
|
-
response = client.chat.complete(
|
|
325
|
-
model=user_model,
|
|
326
|
-
messages=[
|
|
327
|
-
{'role': 'user', 'content': prompt}
|
|
328
|
-
],
|
|
329
|
-
temperature=creativity
|
|
330
|
-
)
|
|
331
|
-
reply = response.choices[0].message.content
|
|
332
|
-
link1.append(reply)
|
|
333
|
-
except Exception as e:
|
|
334
|
-
print(f"An error occurred: {e}")
|
|
335
|
-
link1.append(f"Error processing input: {e}")
|
|
336
|
-
else:
|
|
337
|
-
raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
|
|
338
|
-
# in situation that no JSON is found
|
|
339
|
-
if reply is not None:
|
|
340
|
-
extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
|
|
341
|
-
if extracted_json:
|
|
342
|
-
cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
|
|
343
|
-
extracted_jsons.append(cleaned_json)
|
|
344
|
-
#print(cleaned_json)
|
|
345
|
-
else:
|
|
346
|
-
error_message = """{"1":"e"}"""
|
|
347
|
-
extracted_jsons.append(error_message)
|
|
348
|
-
print(error_message)
|
|
349
|
-
else:
|
|
350
|
-
error_message = """{"1":"e"}"""
|
|
351
|
-
extracted_jsons.append(error_message)
|
|
352
|
-
#print(error_message)
|
|
353
|
-
|
|
354
|
-
# --- Safety Save ---
|
|
355
|
-
if safety:
|
|
356
|
-
#print(f"Saving CSV to: {save_directory}")
|
|
357
|
-
# Save progress so far
|
|
358
|
-
temp_df = pd.DataFrame({
|
|
359
|
-
'image_input': image_files[:i+1],
|
|
360
|
-
'link1': link1,
|
|
361
|
-
'json': extracted_jsons
|
|
362
|
-
})
|
|
363
|
-
# Normalize processed jsons so far
|
|
364
|
-
normalized_data_list = []
|
|
365
|
-
for json_str in extracted_jsons:
|
|
366
|
-
try:
|
|
367
|
-
parsed_obj = json.loads(json_str)
|
|
368
|
-
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
369
|
-
except json.JSONDecodeError:
|
|
370
|
-
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
371
|
-
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
372
|
-
temp_df = pd.concat([temp_df, normalized_data], axis=1)
|
|
373
|
-
# Save to CSV
|
|
374
|
-
if save_directory is None:
|
|
375
|
-
save_directory = os.getcwd()
|
|
376
|
-
temp_df.to_csv(os.path.join(save_directory, filename), index=False)
|
|
377
|
-
|
|
378
|
-
# --- Final DataFrame ---
|
|
379
|
-
normalized_data_list = []
|
|
380
|
-
for json_str in extracted_jsons:
|
|
381
|
-
try:
|
|
382
|
-
parsed_obj = json.loads(json_str)
|
|
383
|
-
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
384
|
-
except json.JSONDecodeError:
|
|
385
|
-
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
386
|
-
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
387
|
-
|
|
388
|
-
categorized_data = pd.DataFrame({
|
|
389
|
-
'image_input': image_files,
|
|
390
|
-
'link1': pd.Series(link1).reset_index(drop=True),
|
|
391
|
-
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
|
392
|
-
})
|
|
393
|
-
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
|
394
|
-
|
|
395
|
-
if columns != "numbered": #if user wants text columns
|
|
396
|
-
categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]
|
|
397
|
-
|
|
398
|
-
if to_csv:
|
|
399
|
-
if save_directory is None:
|
|
400
|
-
save_directory = os.getcwd()
|
|
401
|
-
categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
|
|
402
|
-
|
|
403
|
-
return categorized_data
|
|
404
|
-
|
|
405
|
-
#image score function
|
|
406
|
-
def extract_image_score(
|
|
407
|
-
reference_image_description,
|
|
408
|
-
image_input,
|
|
409
|
-
reference_image,
|
|
410
|
-
api_key,
|
|
411
|
-
columns="numbered",
|
|
412
|
-
user_model="gpt-4o-2024-11-20",
|
|
413
|
-
creativity=0,
|
|
414
|
-
to_csv=False,
|
|
415
|
-
safety=False,
|
|
416
|
-
filename="categorized_data.csv",
|
|
417
|
-
save_directory=None,
|
|
418
|
-
model_source="OpenAI"
|
|
419
|
-
):
|
|
420
|
-
import os
|
|
421
|
-
import json
|
|
422
|
-
import pandas as pd
|
|
423
|
-
import regex
|
|
424
|
-
from tqdm import tqdm
|
|
425
|
-
import glob
|
|
426
|
-
import base64
|
|
427
|
-
from pathlib import Path
|
|
428
|
-
|
|
429
|
-
if save_directory is not None and not os.path.isdir(save_directory):
|
|
430
|
-
# Directory doesn't exist - raise an exception to halt execution
|
|
431
|
-
raise FileNotFoundError(f"Directory {save_directory} doesn't exist")
|
|
432
|
-
|
|
433
|
-
image_extensions = [
|
|
434
|
-
'*.png', '*.jpg', '*.jpeg',
|
|
435
|
-
'*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
|
|
436
|
-
'*.tif', '*.tiff', '*.bmp',
|
|
437
|
-
'*.heif', '*.heic', '*.ico',
|
|
438
|
-
'*.psd'
|
|
439
|
-
]
|
|
440
|
-
|
|
441
|
-
if not isinstance(image_input, list):
|
|
442
|
-
# If image_input is a filepath (string)
|
|
443
|
-
image_files = []
|
|
444
|
-
for ext in image_extensions:
|
|
445
|
-
image_files.extend(glob.glob(os.path.join(image_input, ext)))
|
|
446
|
-
|
|
447
|
-
print(f"Found {len(image_files)} images.")
|
|
448
|
-
else:
|
|
449
|
-
# If image_files is already a list
|
|
450
|
-
image_files = image_input
|
|
451
|
-
print(f"Provided a list of {len(image_input)} images.")
|
|
452
|
-
|
|
453
|
-
with open(reference_image, 'rb') as f:
|
|
454
|
-
reference_image = f"data:image/{reference_image.split('.')[-1]};base64,{base64.b64encode(f.read()).decode('utf-8')}"
|
|
455
|
-
|
|
456
|
-
link1 = []
|
|
457
|
-
extracted_jsons = []
|
|
458
|
-
|
|
459
|
-
for i, img_path in enumerate(tqdm(image_files, desc="Categorising images"), start=0):
|
|
460
|
-
# Check validity first
|
|
461
|
-
if img_path is None or not os.path.exists(img_path):
|
|
462
|
-
link1.append("Skipped NaN input or invalid path")
|
|
463
|
-
extracted_jsons.append("""{"no_valid_image": 1}""")
|
|
464
|
-
continue # Skip the rest of the loop iteration
|
|
465
|
-
|
|
466
|
-
# Only open the file if path is valid
|
|
467
|
-
with open(img_path, "rb") as f:
|
|
468
|
-
encoded = base64.b64encode(f.read()).decode("utf-8")
|
|
469
|
-
|
|
470
|
-
# Handle extension safely
|
|
471
|
-
ext = Path(img_path).suffix.lstrip(".").lower()
|
|
472
|
-
encoded_image = f"data:image/{ext};base64,{encoded}"
|
|
473
|
-
|
|
474
|
-
prompt = [
|
|
475
|
-
{
|
|
476
|
-
"type": "text",
|
|
477
|
-
"text": (
|
|
478
|
-
f"You are a visual similarity assessment system.\n"
|
|
479
|
-
f"Task ► Compare these two images:\n"
|
|
480
|
-
f"1. REFERENCE (left): {reference_image_description}\n"
|
|
481
|
-
f"2. INPUT (right): User-provided drawing\n\n"
|
|
482
|
-
f"Rating criteria:\n"
|
|
483
|
-
f"1: No meaningful similarity (fundamentally different)\n"
|
|
484
|
-
f"2: Barely recognizable similarity (25% match)\n"
|
|
485
|
-
f"3: Partial match (50% key features)\n"
|
|
486
|
-
f"4: Strong alignment (75% features)\n"
|
|
487
|
-
f"5: Near-perfect match (90%+ similarity)\n\n"
|
|
488
|
-
f"Output format ► Return ONLY:\n"
|
|
489
|
-
"{\n"
|
|
490
|
-
' "score": [1-5],\n'
|
|
491
|
-
' "summary": "reason you scored"\n'
|
|
492
|
-
"}\n\n"
|
|
493
|
-
f"Critical rules:\n"
|
|
494
|
-
f"- Score must reflect shape, proportions, and key details\n"
|
|
495
|
-
f"- List only concrete matching elements from reference\n"
|
|
496
|
-
f"- No markdown or additional text"
|
|
497
|
-
),
|
|
498
|
-
},
|
|
499
|
-
{"type": "image_url",
|
|
500
|
-
"image_url": {"url": reference_image, "detail": "high"}
|
|
501
|
-
},
|
|
502
|
-
{
|
|
503
|
-
"type": "image_url",
|
|
504
|
-
|
|
505
|
-
"image_url": {"url": encoded_image, "detail": "high"},
|
|
506
|
-
},
|
|
507
|
-
]
|
|
508
|
-
if model_source == "OpenAI":
|
|
509
|
-
from openai import OpenAI
|
|
510
|
-
client = OpenAI(api_key=api_key)
|
|
511
|
-
try:
|
|
512
|
-
response_obj = client.chat.completions.create(
|
|
513
|
-
model=user_model,
|
|
514
|
-
messages=[{'role': 'user', 'content': prompt}],
|
|
515
|
-
temperature=creativity
|
|
516
|
-
)
|
|
517
|
-
reply = response_obj.choices[0].message.content
|
|
518
|
-
link1.append(reply)
|
|
519
|
-
except Exception as e:
|
|
520
|
-
print(f"An error occurred: {e}")
|
|
521
|
-
link1.append(f"Error processing input: {e}")
|
|
522
|
-
|
|
523
|
-
elif model_source == "Perplexity":
|
|
524
|
-
from openai import OpenAI
|
|
525
|
-
client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
|
|
526
|
-
try:
|
|
527
|
-
response_obj = client.chat.completions.create(
|
|
528
|
-
model=user_model,
|
|
529
|
-
messages=[{'role': 'user', 'content': prompt}],
|
|
530
|
-
temperature=creativity
|
|
531
|
-
)
|
|
532
|
-
reply = response_obj.choices[0].message.content
|
|
533
|
-
link1.append(reply)
|
|
534
|
-
except Exception as e:
|
|
535
|
-
print(f"An error occurred: {e}")
|
|
536
|
-
link1.append(f"Error processing input: {e}")
|
|
537
|
-
elif model_source == "Anthropic":
|
|
538
|
-
import anthropic
|
|
539
|
-
client = anthropic.Anthropic(api_key=api_key)
|
|
540
|
-
try:
|
|
541
|
-
message = client.messages.create(
|
|
542
|
-
model=user_model,
|
|
543
|
-
max_tokens=1024,
|
|
544
|
-
temperature=creativity,
|
|
545
|
-
messages=[{"role": "user", "content": prompt}]
|
|
546
|
-
)
|
|
547
|
-
reply = message.content[0].text # Anthropic returns content as list
|
|
548
|
-
link1.append(reply)
|
|
549
|
-
except Exception as e:
|
|
550
|
-
print(f"An error occurred: {e}")
|
|
551
|
-
link1.append(f"Error processing input: {e}")
|
|
552
|
-
elif model_source == "Mistral":
|
|
553
|
-
from mistralai import Mistral
|
|
554
|
-
client = Mistral(api_key=api_key)
|
|
555
|
-
try:
|
|
556
|
-
response = client.chat.complete(
|
|
557
|
-
model=user_model,
|
|
558
|
-
messages=[
|
|
559
|
-
{'role': 'user', 'content': prompt}
|
|
560
|
-
],
|
|
561
|
-
temperature=creativity
|
|
562
|
-
)
|
|
563
|
-
reply = response.choices[0].message.content
|
|
564
|
-
link1.append(reply)
|
|
565
|
-
except Exception as e:
|
|
566
|
-
print(f"An error occurred: {e}")
|
|
567
|
-
link1.append(f"Error processing input: {e}")
|
|
568
|
-
else:
|
|
569
|
-
raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
|
|
570
|
-
# in situation that no JSON is found
|
|
571
|
-
if reply is not None:
|
|
572
|
-
extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
|
|
573
|
-
if extracted_json:
|
|
574
|
-
cleaned_json = extracted_json[0].replace('[', '').replace(']', '').replace('\n', '').replace(" ", '').replace(" ", '')
|
|
575
|
-
extracted_jsons.append(cleaned_json)
|
|
576
|
-
#print(cleaned_json)
|
|
577
|
-
else:
|
|
578
|
-
error_message = """{"1":"e"}"""
|
|
579
|
-
extracted_jsons.append(error_message)
|
|
580
|
-
print(error_message)
|
|
581
|
-
else:
|
|
582
|
-
error_message = """{"1":"e"}"""
|
|
583
|
-
extracted_jsons.append(error_message)
|
|
584
|
-
#print(error_message)
|
|
585
|
-
|
|
586
|
-
# --- Safety Save ---
|
|
587
|
-
if safety:
|
|
588
|
-
# Save progress so far
|
|
589
|
-
temp_df = pd.DataFrame({
|
|
590
|
-
'image_input': image_files[:i+1],
|
|
591
|
-
'link1': link1,
|
|
592
|
-
'json': extracted_jsons
|
|
593
|
-
})
|
|
594
|
-
# Normalize processed jsons so far
|
|
595
|
-
normalized_data_list = []
|
|
596
|
-
for json_str in extracted_jsons:
|
|
597
|
-
try:
|
|
598
|
-
parsed_obj = json.loads(json_str)
|
|
599
|
-
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
600
|
-
except json.JSONDecodeError:
|
|
601
|
-
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
602
|
-
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
603
|
-
temp_df = pd.concat([temp_df, normalized_data], axis=1)
|
|
604
|
-
# Save to CSV
|
|
605
|
-
if save_directory is None:
|
|
606
|
-
save_directory = os.getcwd()
|
|
607
|
-
temp_df.to_csv(os.path.join(save_directory, filename), index=False)
|
|
608
|
-
|
|
609
|
-
# --- Final DataFrame ---
|
|
610
|
-
normalized_data_list = []
|
|
611
|
-
for json_str in extracted_jsons:
|
|
612
|
-
try:
|
|
613
|
-
parsed_obj = json.loads(json_str)
|
|
614
|
-
normalized_data_list.append(pd.json_normalize(parsed_obj))
|
|
615
|
-
except json.JSONDecodeError:
|
|
616
|
-
normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
|
|
617
|
-
normalized_data = pd.concat(normalized_data_list, ignore_index=True)
|
|
618
|
-
|
|
619
|
-
categorized_data = pd.DataFrame({
|
|
620
|
-
'image_input': image_files,
|
|
621
|
-
'link1': pd.Series(link1).reset_index(drop=True),
|
|
622
|
-
'json': pd.Series(extracted_jsons).reset_index(drop=True)
|
|
623
|
-
})
|
|
624
|
-
categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
|
|
625
|
-
|
|
626
|
-
if to_csv:
|
|
627
|
-
if save_directory is None:
|
|
628
|
-
save_directory = os.getcwd()
|
|
629
|
-
categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
|
|
630
|
-
|
|
631
|
-
return categorized_data
|
|
632
|
-
|
|
633
|
-
# image features function
def extract_image_features(
    image_description,
    image_input,
    features_to_extract,
    api_key,
    columns="numbered",
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    to_csv=False,
    safety=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Answer a fixed list of questions ("features") about each image with a vision LLM.

    For every image, the selected model is prompted to answer the questions in
    ``features_to_extract`` and reply with a JSON object keyed by question
    number. Each reply's JSON is normalized into columns of the returned
    DataFrame.

    Parameters
    ----------
    image_description : str
        Short context describing what the images depict; inserted into the prompt.
    image_input : str or list
        Directory to scan for image files, or an explicit list of image paths.
    features_to_extract : list of str
        Questions the model must answer for each image.
    api_key : str
        API key for the provider named by ``model_source``.
    columns : str
        ``"numbered"`` keeps the numeric JSON keys as column names; any other
        value renames the feature columns to the question texts.
    user_model : str
        Provider model identifier.
    creativity : float
        Sampling temperature passed to the provider.
    to_csv : bool
        If True, write the final DataFrame to ``save_directory/filename``.
    safety : bool
        If True, re-write partial progress to CSV after every image.
    filename : str
        Output CSV filename.
    save_directory : str or None
        Output directory; defaults to the current working directory.
    model_source : str
        One of "OpenAI", "Perplexity", "Anthropic", "Mistral".

    Returns
    -------
    pandas.DataFrame
        One row per image: input path, raw model reply, extracted JSON string,
        plus one column per extracted feature.

    Raises
    ------
    FileNotFoundError
        If ``save_directory`` is given but does not exist.
    ValueError
        If ``model_source`` is unknown, or the user cancels at the prompt.
    """
    import os
    import json
    import pandas as pd
    import regex
    from tqdm import tqdm
    import glob
    import base64
    from pathlib import Path

    if save_directory is not None and not os.path.isdir(save_directory):
        # Fail fast: nothing should run if the output can never be written.
        raise FileNotFoundError(f"Directory {save_directory} doesn't exist")

    image_extensions = [
        '*.png', '*.jpg', '*.jpeg',
        '*.gif', '*.webp', '*.svg', '*.svgz', '*.avif', '*.apng',
        '*.tif', '*.tiff', '*.bmp',
        '*.heif', '*.heic', '*.ico',
        '*.psd'
    ]

    if not isinstance(image_input, list):
        # image_input is a directory path (string): collect matching files.
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(image_input, ext)))
        print(f"Found {len(image_files)} images.")
    else:
        # Caller supplied an explicit list of image paths.
        image_files = image_input
        print(f"Provided a list of {len(image_input)} images.")

    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(features_to_extract))

    # Interactive confirmation of the question list.
    # Fix: the original printed this block twice and advertised an 'exit'
    # option that the input loop never honored.
    print("\nPlease verify the categories you entered:")
    for i, cat in enumerate(features_to_extract, 1):
        print(f"{i}. {cat}")
    print("Type 'exit' to cancel the operation.")
    response = input("\nIf the list above is correct, type 'next' and press Enter to continue: ")
    while response.strip().lower() != "next":
        if response.strip().lower() == "exit":
            raise ValueError("Operation cancelled by user.")
        response = input("Please type 'next' to continue: ")

    link1 = []            # raw model replies (or error strings) per image
    extracted_jsons = []  # JSON object string extracted from each reply

    for i, img_path in enumerate(
            tqdm(image_files, desc="Categorising images"), start=0):
        # Encode this specific image once per iteration.
        with open(img_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        ext = Path(img_path).suffix.lstrip(".").lower()
        encoded_image = f"data:image/{ext};base64,{encoded}"

        prompt = [
            {
                "type": "text",
                "text": (
                    f"You are a visual question answering assistant.\n"
                    f"Task ► Analyze the attached image and answer these specific questions:\n\n"
                    f"Image context: {image_description}\n\n"
                    f"Questions to answer:\n{categories_str}\n\n"
                    f"Output format ► Return **only** a JSON object where:\n"
                    f"- Keys are question numbers ('1', '2', ...)\n"
                    f"- Values are concise answers (numbers, short phrases)\n\n"
                    f"Example for 3 questions:\n"
                    "{\n"
                    '  "1": "4",\n'
                    '  "2": "blue",\n'
                    '  "3": "yes"\n'
                    "}\n\n"
                    f"Important rules:\n"
                    f"1. Answer directly - no explanations\n"
                    f"2. Use exact numerical values when possible\n"
                    f"3. For yes/no questions, use 'yes' or 'no'\n"
                    f"4. Never add extra keys or formatting"
                ),
            },
            {
                "type": "image_url",
                "image_url": {"url": encoded_image, "detail": "high"},
            },
        ]

        # Fix: ensure `reply` exists even when the provider call raises,
        # otherwise `if reply is not None` below hits an unbound local on the
        # first failing image.
        reply = None
        if model_source in ("OpenAI", "Perplexity"):
            # The two providers share an OpenAI-compatible API; only the
            # base URL differs.
            from openai import OpenAI
            if model_source == "OpenAI":
                client = OpenAI(api_key=api_key)
            else:
                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
            try:
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
        elif model_source == "Anthropic":
            import anthropic
            client = anthropic.Anthropic(api_key=api_key)
            try:
                message = client.messages.create(
                    model=user_model,
                    max_tokens=1024,
                    temperature=creativity,
                    messages=[{"role": "user", "content": prompt}]
                )
                reply = message.content[0].text  # Anthropic returns content as list
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
        elif model_source == "Mistral":
            from mistralai import Mistral
            client = Mistral(api_key=api_key)
            try:
                response_obj = client.chat.complete(
                    model=user_model,
                    messages=[
                        {'role': 'user', 'content': prompt}
                    ],
                    temperature=creativity
                )
                reply = response_obj.choices[0].message.content
                link1.append(reply)
            except Exception as e:
                print(f"An error occurred: {e}")
                link1.append(f"Error processing input: {e}")
        else:
            raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")

        # Extract the first balanced JSON object from the reply; fall back to
        # a sentinel error object when no reply / no JSON is available.
        if reply is not None:
            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
            if extracted_json:
                # Fix: the original stripped every space from the JSON, which
                # corrupted multi-word answer values (e.g. "light blue").
                extracted_jsons.append(extracted_json[0])
            else:
                error_message = """{"1":"e"}"""
                extracted_jsons.append(error_message)
                print(error_message)
        else:
            error_message = """{"1":"e"}"""
            extracted_jsons.append(error_message)

        # --- Safety Save: persist partial progress after each image ---
        if safety:
            temp_df = pd.DataFrame({
                'image_input': image_files[:i+1],
                'link1': link1,
                'json': extracted_jsons
            })
            # Normalize the JSONs processed so far into feature columns.
            normalized_data_list = []
            for json_str in extracted_jsons:
                try:
                    parsed_obj = json.loads(json_str)
                    normalized_data_list.append(pd.json_normalize(parsed_obj))
                except json.JSONDecodeError:
                    normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
            normalized_data = pd.concat(normalized_data_list, ignore_index=True)
            temp_df = pd.concat([temp_df, normalized_data], axis=1)
            if save_directory is None:
                save_directory = os.getcwd()
            temp_df.to_csv(os.path.join(save_directory, filename), index=False)

    # --- Final DataFrame ---
    normalized_data_list = []
    for json_str in extracted_jsons:
        try:
            parsed_obj = json.loads(json_str)
            normalized_data_list.append(pd.json_normalize(parsed_obj))
        except json.JSONDecodeError:
            # Unparseable reply: keep a sentinel row so indices stay aligned.
            normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
    normalized_data = pd.concat(normalized_data_list, ignore_index=True)

    categorized_data = pd.DataFrame({
        'image_input': image_files,
        'link1': pd.Series(link1).reset_index(drop=True),
        'json': pd.Series(extracted_jsons).reset_index(drop=True)
    })
    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)

    if columns != "numbered":  # user wants question-text column names
        # Fix: the original referenced an undefined name `categories` here,
        # raising NameError; the question list is `features_to_extract`.
        categorized_data.columns = (
            list(categorized_data.columns[:3])
            + features_to_extract[:len(categorized_data.columns) - 3]
        )

    if to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)

    return categorized_data
|
|
868
|
-
|
|
869
|
-
#extract categories from corpus
def explore_corpus(
    survey_question,
    survey_input,
    api_key,
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    to_csv=False,
    filename="categorized_data.csv",
    save_directory=None,
    model_source="OpenAI"
):
    """Discover candidate response categories for an open-ended survey question.

    Draws ``divisions`` random chunks from ``survey_input``, asks the model for
    ``cat_num`` broad categories per chunk, then aggregates, counts, and
    de-duplicates the category labels across all chunks.

    Parameters
    ----------
    survey_question : str
        The open-ended question the responses answer.
    survey_input : pandas.Series
        Free-text survey responses (must support ``.sample``/``.tolist``).
    api_key : str
        API key for the provider named by ``model_source``.
    cat_num : int
        Categories to request per chunk.
    divisions : int
        Number of random chunks to sample and query.
    user_model : str
        Provider model identifier.
    creativity : float
        Sampling temperature.
    to_csv : bool
        If True, write the resulting DataFrame to ``save_directory/filename``.
    filename : str
        Output CSV filename.
    save_directory : str or None
        Output directory; defaults to the current working directory.
    model_source : str
        Currently only "OpenAI" is supported.

    Returns
    -------
    pandas.DataFrame
        Columns 'Category' (lower-cased, de-duplicated labels) and 'counts'
        (how often each label appeared across chunks), sorted by frequency.

    Raises
    ------
    ValueError
        If chunks are too small for ``cat_num``, the model's context length is
        exceeded, or ``model_source`` is unsupported.
    """
    import os
    import pandas as pd
    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
    print()

    # Size of each random sample drawn from the corpus.
    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    random_chunks = []
    for i in range(divisions):
        chunk = survey_input.sample(n=chunk_size).tolist()
        random_chunks.append(chunk)

    responses = []       # raw model reply per chunk ("" on failure)
    responses_list = []  # parsed category labels per chunk

    for i in tqdm(range(divisions), desc="Processing chunks"):
        survey_participant_chunks = '; '.join(random_chunks[i])
        prompt = f"""Identify {cat_num} broad categories of responses to the question "{survey_question}" in the following list of responses. \
Responses are each separated by a semicolon. \
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""

        if model_source == "OpenAI":
            client = OpenAI(api_key=api_key)
            try:
                response_obj = client.chat.completions.create(
                    model=user_model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=creativity
                )
                responses.append(response_obj.choices[0].message.content)
            except BadRequestError as e:
                if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                    # Fix: the original message named a nonexistent
                    # 'iterations' parameter; the real knob is 'divisions'.
                    error_msg = (f"Token limit exceeded for model {user_model}. "
                                 f"Try increasing the 'divisions' parameter to create smaller chunks.")
                    raise ValueError(error_msg)
                else:
                    print(f"OpenAI API error: {e}")
                    # Fix: keep `responses` aligned with chunk indices so the
                    # parsing step below cannot raise IndexError.
                    responses.append("")
            except Exception as e:
                print(f"An error occurred: {e}")
                responses.append("")
        else:
            raise ValueError(f"Unsupported model_source: {model_source}")

        # Parse "N. label" lines into bare labels.
        items = []
        for line in responses[i].split('\n'):
            if '. ' in line:
                try:
                    items.append(line.split('. ', 1)[1])
                except IndexError:
                    pass

        responses_list.append(items)

    # Lower-case and flatten all chunk results into one label list.
    flat_list = [item.lower() for sublist in responses_list for item in sublist]

    # Count occurrences, then keep one row per unique label, most frequent first.
    df = pd.DataFrame(flat_list, columns=['Category'])
    counts = pd.Series(flat_list).value_counts()
    df['counts'] = df['Category'].map(counts)
    df = df.sort_values(by='counts', ascending=False).reset_index(drop=True)
    df = df.drop_duplicates(subset='Category', keep='first').reset_index(drop=True)

    # Fix: to_csv/filename/save_directory were accepted but silently ignored.
    if to_csv:
        if save_directory is None:
            save_directory = os.getcwd()
        df.to_csv(os.path.join(save_directory, filename), index=False)

    return df
|
cat_llm-0.0.8/tests/__init__.py
DELETED
|
File without changes
|