cat-llm 0.0.18__tar.gz → 0.0.19__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-llm
3
- Version: 0.0.18
3
+ Version: 0.0.19
4
4
  Summary: A tool for categorizing text data and images using LLMs and vision models
5
5
  Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -1,10 +1,10 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
2
  #
3
3
  # SPDX-License-Identifier: MIT
4
- __version__ = "0.0.18"
4
+ __version__ = "0.0.19"
5
5
  __author__ = "Chris Soria"
6
6
  __email__ = "chrissoria@berkeley.edu"
7
7
  __title__ = "cat-llm"
8
- __description__ = "A tool for categorizing text data and images using LLMs and vision models"
8
+ __description__ = "A tool for categorizing and exploring text data and images using LLMs and vision models"
9
9
  __url__ = "https://github.com/chrissoria/cat-llm"
10
10
  __license__ = "MIT"
@@ -25,12 +25,9 @@ def extract_multi_class(
25
25
  example_JSON = json.dumps(category_dict, indent=4)
26
26
 
27
27
  # ensure number of categories is what user wants
28
- print("\nPlease verify the categories you entered:")
28
+ print("\nThe categories you entered:")
29
29
  for i, cat in enumerate(categories, 1):
30
30
  print(f"{i}. {cat}")
31
- response = input("\nIf the list above is correct, type 'next' and press Enter to continue: ")
32
- while response.strip().lower() != "next":
33
- response = input("Please type 'next' to continue: ")
34
31
 
35
32
  link1 = []
36
33
  extracted_jsons = []
@@ -684,18 +681,9 @@ def extract_image_features(
684
681
  example_JSON = json.dumps(category_dict, indent=4)
685
682
 
686
683
  # ensure number of categories is what user wants
687
- print("\nPlease verify the categories you entered:")
684
+ print("\nThe image features to be extracted are:")
688
685
  for i, cat in enumerate(features_to_extract, 1):
689
686
  print(f"{i}. {cat}")
690
- print("\nIf the list above is correct, type 'next' and press Enter to continue.")
691
- print("Type 'exit' to cancel the operation.")
692
-
693
- print("\nPlease verify the categories you entered:")
694
- for i, cat in enumerate(features_to_extract, 1):
695
- print(f"{i}. {cat}")
696
- response = input("\nIf the list above is correct, type 'next' and press Enter to continue: ")
697
- while response.strip().lower() != "next":
698
- response = input("Please type 'next' to continue: ")
699
687
 
700
688
  link1 = []
701
689
  extracted_jsons = []
@@ -871,6 +859,8 @@ def explore_corpus(
871
859
  survey_question,
872
860
  survey_input,
873
861
  api_key,
862
+ research_question=None,
863
+ specificity="broad",
874
864
  cat_num=10,
875
865
  divisions=5,
876
866
  user_model="gpt-4o-2024-11-20",
@@ -892,7 +882,7 @@ def explore_corpus(
892
882
  chunk_size = int(chunk_size)
893
883
 
894
884
  if chunk_size < (cat_num/2):
895
- raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
885
+ raise ValueError(f"Cannot extract {cat_num} {specificity} categories from chunks of only {chunk_size} responses. \n"
896
886
  f"Choose one solution: \n"
897
887
  f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
898
888
  f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")
@@ -907,7 +897,7 @@ def explore_corpus(
907
897
 
908
898
  for i in tqdm(range(divisions), desc="Processing chunks"):
909
899
  survey_participant_chunks = '; '.join(random_chunks[i])
910
- prompt = f"""Identify {cat_num} broad categories of responses to the question "{survey_question}" in the following list of responses. \
900
+ prompt = f"""Identify {cat_num} {specificity} categories of responses to the question "{survey_question}" in the following list of responses. \
911
901
  Responses are each separated by a semicolon. \
912
902
  Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
913
903
  Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
@@ -917,7 +907,12 @@ Number your categories from 1 through {cat_num} and be concise with the category
917
907
  try:
918
908
  response_obj = client.chat.completions.create(
919
909
  model=user_model,
920
- messages=[{'role': 'user', 'content': prompt}],
910
+ messages=[
911
+ {'role': 'system', 'content': f"""You are a helpful assistant that extracts categories from survey responses. \
912
+ The specific task is to identify {specificity} categories of responses to a survey question. \
913
+ The research question is: {research_question}""" if research_question else "You are a helpful assistant."},
914
+ {'role': 'user', 'content': prompt}
915
+ ]
921
916
  temperature=creativity
922
917
  )
923
918
  reply = response_obj.choices[0].message.content
@@ -957,4 +952,124 @@ Number your categories from 1 through {cat_num} and be concise with the category
957
952
  if filename is not None:
958
953
  df.to_csv(filename, index=False)
959
954
 
960
- return df
955
+ return df
956
+
957
+ #extract top categories from corpus
958
def _parse_numbered_list(text):
    """Return the label part of every ``N. label`` line found in *text*.

    A line is kept when it contains the separator ``'. '``; everything after
    the first occurrence of that separator is taken as the label. Lines
    without the separator are ignored. An empty or ``None`` text yields an
    empty list.
    """
    if not text:
        return []
    labels = []
    for line in text.split('\n'):
        _, separator, label = line.partition('. ')
        if separator:
            labels.append(label)
    return labels


#extract top categories from corpus
def explore_common_categories(
    survey_question,
    survey_input,
    api_key,
    top_n=10,
    cat_num=10,
    divisions=5,
    user_model="gpt-4o-2024-11-20",
    creativity=0,
    specificity="broad",
    research_question=None,
    filename=None,
    model_source="OpenAI"
):
    """Extract candidate categories chunk by chunk, then ask the model for the top ones.

    The responses are sampled into ``divisions`` random chunks. For each chunk
    the model is asked for ``cat_num`` categories; the returned numbered lists
    are parsed, lower-cased and tallied. The distinct categories are then sent
    back to the model, which is asked to return the ``top_n`` most common.

    Args:
        survey_question: The question the responses answer; inserted into the prompts.
        survey_input: Collection of responses. It must provide ``sample(n=...)``
            returning an object with ``tolist()`` (presumably a pandas Series).
        api_key: API key passed to the OpenAI client.
        top_n: Number of categories requested in the final ranking.
        cat_num: Number of categories requested per chunk.
        divisions: Number of random chunks to draw (must be at least 1).
        user_model: Model name passed to the chat-completions endpoint.
        creativity: Value passed as ``temperature``.
        specificity: Adjective describing the categories (e.g. ``"broad"``);
            inserted into the prompts.
        research_question: Optional research question. When given, a
            task-specific system message is used instead of the generic one.
        filename: Optional path. When given, the table of distinct categories
            and their counts is written there as CSV (no index).
        model_source: Only ``"OpenAI"`` is supported.

    Returns:
        list[str]: The labels parsed from the model's final numbered list.

    Raises:
        ValueError: If ``divisions`` is below 1, the chunks are too small for
            ``cat_num``, ``model_source`` is unsupported, or the model reports
            that its context length was exceeded.
        RuntimeError: If no category could be extracted from any chunk.
    """
    print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
    print()

    if divisions < 1:
        raise ValueError(f"'divisions' must be at least 1 (currently {divisions}).")

    chunk_size = int(round(max(1, len(survey_input) / divisions), 0))

    if chunk_size < (cat_num/2):
        raise ValueError(f"Cannot extract {cat_num} categories from chunks of only {chunk_size} responses. \n"
                         f"Choose one solution: \n"
                         f"(1) Reduce 'divisions' parameter (currently {divisions}) to create larger chunks, or \n"
                         f"(2) Reduce 'cat_num' parameter (currently {cat_num}) to extract fewer categories per chunk.")

    # Reject unsupported back ends up front instead of inside the first loop
    # iteration, so no sampling or client setup happens for a call that
    # cannot succeed.
    if model_source != "OpenAI":
        raise ValueError(f"Unsupported model_source: {model_source}")

    # Imported only after validation so argument errors surface before any
    # client library is loaded.
    from collections import Counter
    import pandas as pd
    from openai import OpenAI, BadRequestError
    from tqdm import tqdm

    random_chunks = [
        survey_input.sample(n=chunk_size).tolist() for _ in range(divisions)
    ]

    # Both the client and the system message are loop-invariant.
    client = OpenAI(api_key=api_key)
    if research_question:
        system_message = (
            "You are a helpful assistant that extracts categories from survey responses. "
            f"The specific task is to identify {specificity} categories of responses to a survey question. "
            f"The research question is: {research_question}"
        )
    else:
        system_message = "You are a helpful assistant."

    responses_list = []

    for chunk in tqdm(random_chunks, desc="Processing chunks"):
        # str() keeps the join from failing on non-string responses.
        survey_participant_chunks = '; '.join(str(response) for response in chunk)
        prompt = (
            f"Identify {cat_num} {specificity} categories of responses to the question "
            f"\"{survey_question}\" in the following list of responses. "
            "Responses are each separated by a semicolon. "
            f"Responses are contained within triple backticks here: ```{survey_participant_chunks}``` "
            f"Number your categories from 1 through {cat_num} and be concise with the category labels "
            "and provide no description of the categories."
        )

        try:
            response_obj = client.chat.completions.create(
                model=user_model,
                messages=[
                    {'role': 'system', 'content': system_message},
                    {'role': 'user', 'content': prompt}
                ],
                temperature=creativity
            )
            reply = response_obj.choices[0].message.content
        except BadRequestError as e:
            if "context_length_exceeded" in str(e) or "maximum context length" in str(e):
                error_msg = (f"Token limit exceeded for model {user_model}. "
                             f"Try increasing the 'divisions' parameter to create smaller chunks.")
                raise ValueError(error_msg) from e
            # Best effort: report the failure and move on to the next chunk.
            print(f"OpenAI API error: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")
        else:
            # Parse only when a reply was obtained; a failed chunk simply
            # contributes no categories instead of breaking list indexing.
            responses_list.append(_parse_numbered_list(reply))

    flat_list = [item.lower() for sublist in responses_list for item in sublist]
    if not flat_list:
        raise RuntimeError("No categories could be extracted from any chunk; "
                           "see the errors printed above.")

    # One row per distinct category, most frequent first.
    df = pd.DataFrame(Counter(flat_list).most_common(), columns=['Category', 'counts'])

    if filename is not None:
        df.to_csv(filename, index=False)

    second_prompt = (
        f"From this list of categories, extract the top {top_n} most common categories. "
        f"The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` "
        f"Return the top {top_n} categories as a numbered list sorted from the most to least common "
        f"and keep the categories {specificity}, with no additional text or explanation."
    )

    response_obj = client.chat.completions.create(
        model=user_model,
        messages=[{'role': 'user', 'content': second_prompt}],
        temperature=creativity
    )
    top_categories = response_obj.choices[0].message.content
    print(top_categories)

    return _parse_numbered_list(top_categories)
File without changes
File without changes
File without changes