cat-llm 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.18.dist-info → cat_llm-0.0.19.dist-info}/METADATA +1 -1
- cat_llm-0.0.19.dist-info/RECORD +7 -0
- catllm/__about__.py +2 -2
- catllm/cat_llm.py +133 -18
- cat_llm-0.0.18.dist-info/RECORD +0 -7
- {cat_llm-0.0.18.dist-info → cat_llm-0.0.19.dist-info}/WHEEL +0 -0
- {cat_llm-0.0.18.dist-info → cat_llm-0.0.19.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cat-llm
|
|
3
|
-
Version: 0.0.18
|
|
3
|
+
Version: 0.0.19
|
|
4
4
|
Summary: A tool for categorizing text data and images using LLMs and vision models
|
|
5
5
|
Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
|
|
6
6
|
Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
catllm/__about__.py,sha256=ht639_Zw-BmQ-6TW3sLLeC_s67LbW9TFDrYT8b4xpIg,404
|
|
2
|
+
catllm/__init__.py,sha256=xDin9x4jymeccuxE9Xf-27ncR9h7247IwLbeYN-m3j8,266
|
|
3
|
+
catllm/cat_llm.py,sha256=pj_xcsFA5OQVhMv9-73YT7tDn_3Ol3UowqYAbPrlrZI,43825
|
|
4
|
+
cat_llm-0.0.19.dist-info/METADATA,sha256=Ln27uQy5nspYHg0-iFH_XvAlDJFlfxREKeye2llHteI,1679
|
|
5
|
+
cat_llm-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
+
cat_llm-0.0.19.dist-info/licenses/LICENSE,sha256=wJLsvOr6lrFUDcoPXExa01HOKFWrS3JC9f0RudRw8uw,1075
|
|
7
|
+
cat_llm-0.0.19.dist-info/RECORD,,
|
catllm/__about__.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
|
|
2
2
|
#
|
|
3
3
|
# SPDX-License-Identifier: MIT
|
|
4
|
-
__version__ = "0.0.
|
|
4
|
+
__version__ = "0.0.19"
|
|
5
5
|
__author__ = "Chris Soria"
|
|
6
6
|
__email__ = "chrissoria@berkeley.edu"
|
|
7
7
|
__title__ = "cat-llm"
|
|
8
|
-
__description__ = "A tool for categorizing text data and images using LLMs and vision models"
|
|
8
|
+
__description__ = "A tool for categorizing and exploring text data and images using LLMs and vision models"
|
|
9
9
|
__url__ = "https://github.com/chrissoria/cat-llm"
|
|
10
10
|
__license__ = "MIT"
|
catllm/cat_llm.py
CHANGED
|
@@ -25,12 +25,9 @@ def extract_multi_class(
|
|
|
25
25
|
example_JSON = json.dumps(category_dict, indent=4)
|
|
26
26
|
|
|
27
27
|
# ensure number of categories is what user wants
|
|
28
|
-
print("\
|
|
28
|
+
print("\nThe categories you entered:")
|
|
29
29
|
for i, cat in enumerate(categories, 1):
|
|
30
30
|
print(f"{i}. {cat}")
|
|
31
|
-
response = input("\nIf the list above is correct, type 'next' and press Enter to continue: ")
|
|
32
|
-
while response.strip().lower() != "next":
|
|
33
|
-
response = input("Please type 'next' to continue: ")
|
|
34
31
|
|
|
35
32
|
link1 = []
|
|
36
33
|
extracted_jsons = []
|
|
@@ -684,18 +681,9 @@ def extract_image_features(
|
|
|
684
681
|
example_JSON = json.dumps(category_dict, indent=4)
|
|
685
682
|
|
|
686
683
|
# ensure number of categories is what user wants
|
|
687
|
-
print("\
|
|
684
|
+
print("\nThe image features to be extracted are:")
|
|
688
685
|
for i, cat in enumerate(features_to_extract, 1):
|
|
689
686
|
print(f"{i}. {cat}")
|
|
690
|
-
print("\nIf the list above is correct, type 'next' and press Enter to continue.")
|
|
691
|
-
print("Type 'exit' to cancel the operation.")
|
|
692
|
-
|
|
693
|
-
print("\nPlease verify the categories you entered:")
|
|
694
|
-
for i, cat in enumerate(features_to_extract, 1):
|
|
695
|
-
print(f"{i}. {cat}")
|
|
696
|
-
response = input("\nIf the list above is correct, type 'next' and press Enter to continue: ")
|
|
697
|
-
while response.strip().lower() != "next":
|
|
698
|
-
response = input("Please type 'next' to continue: ")
|
|
699
687
|
|
|
700
688
|
link1 = []
|
|
701
689
|
extracted_jsons = []
|
|
@@ -871,6 +859,8 @@ def explore_corpus(
|
|
|
871
859
|
survey_question,
|
|
872
860
|
survey_input,
|
|
873
861
|
api_key,
|
|
862
|
+
research_question=None,
|
|
863
|
+
specificity="broad",
|
|
874
864
|
cat_num=10,
|
|
875
865
|
divisions=5,
|
|
876
866
|
user_model="gpt-4o-2024-11-20",
|
|
@@ -892,7 +882,7 @@ def explore_corpus(
|
|
|
892
882
|
chunk_size = int(chunk_size)
|
|
893
883
|
|
|
894
884
|
if chunk_size < (cat_num/2):
|
|
895
|
-
raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
|
|
885
|
+
raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
|
|
896
886
|
f"Choose one solution: \n"
|
|
897
887
|
f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
|
|
898
888
|
f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
|
|
@@ -907,7 +897,7 @@ def explore_corpus(
|
|
|
907
897
|
|
|
908
898
|
for i in tqdm(range(divisions), desc="Processing chunks"):
|
|
909
899
|
survey_participant_chunks = '; '.join(random_chunks[i])
|
|
910
|
-
prompt = f"""Identify {cat_num}
|
|
900
|
+
prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
|
|
911
901
|
Responses are each separated by a semicolon. \
|
|
912
902
|
Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
|
|
913
903
|
Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
|
|
@@ -917,7 +907,12 @@ Number your categories from 1 through {cat_num} and be concise with the category
|
|
|
917
907
|
try:
|
|
918
908
|
response_obj = client.chat.completions.create(
|
|
919
909
|
model=user_model,
|
|
920
|
-
messages=[
|
|
910
|
+
messages=[
|
|
911
|
+
{'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
|
|
912
|
+
The specific task is to identify {specificity} categories of responses to a survey question. \
|
|
913
|
+
The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
|
|
914
|
+
{'role': 'user', 'content': prompt}
|
|
915
|
+
]
|
|
921
916
|
temperature=creativity
|
|
922
917
|
)
|
|
923
918
|
reply = response_obj.choices[0].message.content
|
|
@@ -957,4 +952,124 @@ Number your categories from 1 through {cat_num} and be concise with the category
|
|
|
957
952
|
if filename is not None:
|
|
958
953
|
df.to_csv(filename, index=False)
|
|
959
954
|
|
|
960
|
-
return df
|
|
955
|
+
return df
|
|
956
|
+
|
|
957
|
+
#extract top categories from corpus
|
|
958
|
+
def explore_common_categories(
    survey_question,
    survey_input,
    api_key,
    top_n=10,
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    specificity="broad",
    research_question=None,
    filename=None,
    model_source="OpenAI"
):
    """Identify the most common categories of responses to a survey question.

    The responses are split into ``divisions`` random samples. For each sample
    the model is asked for ``cat_num`` categories. All extracted categories are
    pooled and counted; the labels and their counts are then sent back to the
    model, which is asked to return the ``top_n`` most common ones.

    Args:
        survey_question: The question the responses answer; embedded in prompts.
        survey_input: The responses, as a pandas Series or any sequence accepted
            by ``pd.Series``. Missing values are dropped and the rest are
            converted to ``str``.
        api_key: API key for the model provider.
        top_n: Number of top categories to ask for.
        cat_num: Number of categories requested from each chunk.
        divisions: Number of random chunks to sample.
        user_model: Model name passed to the chat-completions API.
        creativity: Sampling temperature.
        specificity: Adjective describing the desired categories (e.g. "broad").
        research_question: Optional research context added to the system prompt.
        filename: If not None, the returned categories are written to this CSV
            path (single column ``Category``).
        model_source: Provider name; only "OpenAI" is supported.

    Returns:
        list[str]: The top categories, in the order the model listed them.

    Raises:
        ValueError: If ``model_source`` is unsupported, if there are no
            non-missing responses, if the chunks would be too small for
            ``cat_num`` categories, if a prompt exceeds the model's context
            length, or if no categories could be extracted from any chunk.
    """
    import pandas as pd

    # Fail fast on arguments that make the run impossible, before importing the
    # provider SDK or spending any API calls.
    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")

    # Accept any sequence; drop missing answers and coerce to str so the
    # '; '.join below cannot fail on NaN/None or numeric responses.
    if not isinstance(survey_input, pd.Series):
        survey_input = pd.Series(survey_input, dtype=object)
    survey_input = survey_input.dropna().astype(str)
    if survey_input.empty:
        raise ValueError("survey_input contains no non-missing responses to categorize.")

    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
    print()

    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num / 2):
        raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    # Each chunk is an independent sample (Series.sample defaults to no
    # replacement within a chunk), so the same response may land in several chunks.
    random_chunks = [survey_input.sample(n=chunk_size).tolist() for _ in range(divisions)]

    if research_question:
        system_content = (
            "You are a helpful assistant that extracts categories from survey responses. "
            f"The specific task is to identify {specificity} categories of responses to a survey question. "
            f"The research question is: {research_question}"
        )
    else:
        system_content = "You are a helpful assistant."

    client = OpenAI(api_key=api_key)

    # One parsed category list per chunk that succeeded; failed chunks are skipped
    # so results never get misaligned with their chunk index.
    chunk_categories = []
    for chunk in tqdm(random_chunks, desc="Processing chunks"):
        survey_participant_chunks = '; '.join(chunk)
        prompt = (
            f"Identify {cat_num} {specificity} categories of responses to the question \"{survey_question}\" in the following list of responses. "
            "Responses are each separated by a semicolon. "
            f"Responses are contained within triple backticks here: ```{survey_participant_chunks}``` "
            f"Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."
        )

        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': system_content},
                    {'role': 'user', 'content': prompt},
                ],
                temperature=creativity,
            )
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                # More divisions -> fewer responses per chunk -> shorter prompts.
                raise ValueError(f"Token limit exceeded for model {user_model}. "
                                 f"Try increasing the 'divisions' parameter to create smaller chunks.") from e
            print(f"OpenAI API error: {e}")
            continue
        except Exception as e:
            # Best-effort: report and skip a failed chunk rather than abort the run.
            print(f"An error occurred: {e}")
            continue

        chunk_categories.append(_parse_numbered_list(response_obj.choices[0].message.content))

    # Normalize labels so trivially different spellings ("Cost" / "cost ") are
    # counted as the same category.
    flat_list = [item.strip().lower() for items in chunk_categories for item in items]
    flat_list = [item for item in flat_list if item]
    if not flat_list:
        raise ValueError("No categories could be extracted from the model responses.")

    # value_counts() is sorted from most to least frequent; plain ints keep the
    # prompt text free of numpy scalar reprs.
    counts = pd.Series(flat_list).value_counts()
    category_counts = {category: int(n) for category, n in counts.items()}

    second_prompt = (
        f"From this list of categories, extract the top {top_n} most common categories. "
        "Each category is paired with the number of times it was identified across the response chunks. "
        f"The categories are contained within triple backticks here: ```{category_counts}``` "
        f"Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, "
        "with no additional text or explanation. Do not include the counts."
    )

    response_obj = client.chat.completions.create(
        model=user_model,
        messages=[{'role': 'user', 'content': second_prompt}],
        temperature=creativity,
    )
    top_categories = response_obj.choices[0].message.content
    print(top_categories)

    top_categories_final = _parse_numbered_list(top_categories)

    if filename is not None:
        pd.DataFrame({'Category': top_categories_final}).to_csv(filename, index=False)

    return top_categories_final


def _parse_numbered_list(text):
    """Return the text after the first '. ' on each line of a numbered-list reply.

    Lines that do not contain '. ' are ignored; an empty or None reply yields [].
    """
    if not text:
        return []
    return [line.split('. ', 1)[1] for line in text.split('\n') if '. ' in line]
|
cat_llm-0.0.18.dist-info/RECORD
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
catllm/__about__.py,sha256=Sm0CFgjLJWaUvvKF0YjVAEqklVtmslYxtTxI_EKTI3M,390
|
|
2
|
-
catllm/__init__.py,sha256=xDin9x4jymeccuxE9Xf-27ncR9h7247IwLbeYN-m3j8,266
|
|
3
|
-
catllm/cat_llm.py,sha256=rb0pQCQb1tBzeRMrZyNM2kePj-Yf3D_X0S4UJPtctVM,38738
|
|
4
|
-
cat_llm-0.0.18.dist-info/METADATA,sha256=OVwmvr7dil53X6VpDQE8ysSBWnlT7QK8RrArSnd6WNc,1679
|
|
5
|
-
cat_llm-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
6
|
-
cat_llm-0.0.18.dist-info/licenses/LICENSE,sha256=wJLsvOr6lrFUDcoPXExa01HOKFWrS3JC9f0RudRw8uw,1075
|
|
7
|
-
cat_llm-0.0.18.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|