cat-llm 0.0.43__tar.gz → 0.0.51__tar.gz

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cat-llm
-Version: 0.0.43
+Version: 0.0.51
 Summary: A tool for categorizing text data and images using LLMs and vision models
 Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
 Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -374,7 +374,10 @@ def cerad_drawn_score(
     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
 
     categorized_data = pd.DataFrame({
-        'image_input': image_files,
+        'image_input': (
+            image_files.reset_index(drop=True) if isinstance(image_files, (pd.DataFrame, pd.Series))
+            else pd.Series(image_files)
+        ),
         'link1': pd.Series(link1).reset_index(drop=True),
         'json': pd.Series(extracted_jsons).reset_index(drop=True)
     })
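
The same guard recurs throughout this release: here in cerad_drawn_score and again below in image_multi_class, image_score_drawing, image_features, and the text-categorization function, the input column is coerced to a clean zero-based index before the other columns are attached. A minimal standalone sketch (illustrative data, not from the package) of the misalignment the old code allowed:

import pandas as pd

# e.g. a Series that kept gaps in its index after an upstream dropna()
image_files = pd.Series(["a.png", "b.png"], index=[5, 9])
links = ["url1", "url2"]

# Old behavior: mixing the [5, 9]-indexed Series with 0-based Series inside
# pd.DataFrame(...) aligns on the index union, yielding 4 NaN-padded rows.
misaligned = pd.DataFrame({
    "image_input": image_files,
    "link1": pd.Series(links),
})
print(len(misaligned))  # 4

# New behavior, mirroring the diff: reset pandas indexes, wrap plain lists.
normalized = (
    image_files.reset_index(drop=True)
    if isinstance(image_files, (pd.DataFrame, pd.Series))
    else pd.Series(image_files)
)
aligned = pd.DataFrame({
    "image_input": normalized,
    "link1": pd.Series(links).reset_index(drop=True),
})
print(len(aligned))  # 2
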
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
 #
 # SPDX-License-Identifier: MIT
-__version__ = "0.0.43"
+__version__ = "0.0.51"
 __author__ = "Chris Soria"
 __email__ = "chrissoria@berkeley.edu"
 __title__ = "cat-llm"
@@ -13,4 +13,5 @@ from .__about__ import (
 
 from .text_functions import *
 from .CERAD_functions import *
-from .image_functions import *
+from .image_functions import *
+from .build_web_research import *
@@ -0,0 +1,169 @@
+#build dataset classification
+def build_web_research_dataset(
+    search_question,
+    search_input,
+    api_key,
+    answer_format = "concise",
+    additional_instructions = "",
+    categories = ['Answer','URL'],
+    user_model="claude-3-7-sonnet-20250219",
+    creativity=0,
+    safety=False,
+    filename="categorized_data.csv",
+    save_directory=None,
+    model_source="Anthropic",
+    time_delay=5
+):
+    import os
+    import json
+    import pandas as pd
+    import regex
+    from tqdm import tqdm
+    import time
+
+    categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
+    print(categories_str)
+    cat_num = len(categories)
+    category_dict = {str(i+1): "0" for i in range(cat_num)}
+    example_JSON = json.dumps(category_dict, indent=4)
+
+    # ensure number of categories is what user wants
+    #print("\nThe information to be extracted:")
+    #for i, cat in enumerate(categories, 1):
+        #print(f"{i}. {cat}")
+
+    link1 = []
+    extracted_jsons = []
+
+    for idx, item in enumerate(tqdm(search_input, desc="Building dataset")):
+        if idx == 0: # delay the first item just to be safe
+            time.sleep(time_delay)
+        reply = None
+
+        if pd.isna(item):
+            link1.append("Skipped NaN input")
+            default_json = example_JSON
+            extracted_jsons.append(default_json)
+            #print(f"Skipped NaN input.")
+        else:
+            prompt = f"""<role>You are a research assistant specializing in finding current, factual information.</role>
+
+<task>Find information about {item}'s {search_question}</task>
+
+<rules>
+- Search for the most current and authoritative information available
+- Provide your answer as {answer_format}
+- Prioritize official sources when possible
+- If information is not found, state "Information not found"
+- Include exactly one source URL where you found the information
+- Do not include any explanatory text or commentary beyond the JSON
+{additional_instructions}
+</rules>
+
+<format>
+Return your response as valid JSON with this exact structure:
+{{
+    "answer": "Your factual answer or 'Information not found'",
+    "url": "Source URL or 'No source available'"
+}}
+</format>"""
+            #print(prompt)
+            if model_source == "Anthropic":
+                import anthropic
+                client = anthropic.Anthropic(api_key=api_key)
+                try:
+                    message = client.messages.create(
+                        model=user_model,
+                        max_tokens=1024,
+                        temperature=creativity,
+                        messages=[{"role": "user", "content": prompt}],
+                        tools=[{
+                            "type": "web_search_20250305",
+                            "name": "web_search"
+                        }]
+                    )
+                    reply = " ".join(
+                        block.text
+                        for block in message.content
+                        if getattr(block, "type", "") == "text"
+                    ).strip()
+                    link1.append(reply)
+                    time.sleep(time_delay)
+                    print(reply)
+
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+                    time.sleep(time_delay)
+            else:
+                raise ValueError("Unknown source! Currently this function only supports 'Anthropic' as model_source.")
+        # in situation that no JSON is found
+        if reply is not None:
+            extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
+            if extracted_json:
+                raw_json = extracted_json[0].strip() # Only strip leading/trailing whitespace
+                try:
+                    # Parse to validate JSON structure
+                    parsed_obj = json.loads(raw_json)
+                    # Re-serialize for consistent formatting (optional)
+                    cleaned_json = json.dumps(parsed_obj)
+                    extracted_jsons.append(cleaned_json)
+                except json.JSONDecodeError as e:
+                    print(f"JSON parsing error: {e}")
+                    # Fallback to raw extraction if parsing fails
+                    extracted_jsons.append(raw_json)
+            else:
+                # Use consistent schema for errors
+                error_message = json.dumps({"answer": "e", "url": "e"})
+                extracted_jsons.append(error_message)
+                print(error_message)
+        else:
+            # Handle None reply case
+            error_message = json.dumps({"answer": "e", "url": "e"})
+            extracted_jsons.append(error_message)
+            #print(error_message)
+
+        # --- Safety Save ---
+        if safety:
+            # Save progress so far
+            temp_df = pd.DataFrame({
+                'survey_response': search_input[:idx+1],
+                'link1': link1,
+                'json': extracted_jsons
+            })
+            # Normalize processed jsons so far
+            normalized_data_list = []
+            for json_str in extracted_jsons:
+                try:
+                    parsed_obj = json.loads(json_str)
+                    normalized_data_list.append(pd.json_normalize(parsed_obj))
+                except json.JSONDecodeError:
+                    normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
+            normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+            temp_df = pd.concat([temp_df, normalized_data], axis=1)
+            # Save to CSV
+            if save_directory is None:
+                save_directory = os.getcwd()
+            temp_df.to_csv(os.path.join(save_directory, filename), index=False)
+
+    # --- Final DataFrame ---
+    normalized_data_list = []
+    for json_str in extracted_jsons:
+        try:
+            parsed_obj = json.loads(json_str)
+            normalized_data_list.append(pd.json_normalize(parsed_obj))
+        except json.JSONDecodeError:
+            normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
+    normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+
+    categorized_data = pd.DataFrame({
+        'survey_response': (
+            search_input.reset_index(drop=True) if isinstance(search_input, (pd.DataFrame, pd.Series))
+            else pd.Series(search_input)
+        ),
+        'link1': pd.Series(link1).reset_index(drop=True),
+        'json': pd.Series(extracted_jsons).reset_index(drop=True)
+    })
+    categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
+
+    return categorized_data
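
The centerpiece of 0.0.51 is this new build_web_research.py module, which loops over a list of entities, asks Claude (via the Anthropic web-search tool) one question per entity, and collects the JSON answers into a DataFrame. A hypothetical invocation, assuming the package imports as catllm and with placeholder argument values:

import pandas as pd
from catllm import build_web_research_dataset  # import name assumed

entities = pd.Series(["University of California, Berkeley", "Stanford University"])

df = build_web_research_dataset(
    search_question="current number of enrolled students",  # illustrative
    search_input=entities,
    api_key="sk-ant-...",  # placeholder Anthropic key
    safety=True,           # checkpoint progress to CSV after each item
    time_delay=5,          # seconds to sleep between web-search calls
)
# Expected columns: 'survey_response', 'link1' (raw model reply), 'json',
# plus the normalized 'answer' and 'url' fields extracted from the JSON.
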
@@ -252,9 +252,11 @@ def image_multi_class(
         except json.JSONDecodeError:
             normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
-
     categorized_data = pd.DataFrame({
-        'image_input': image_files,
+        'image_input': (
+            image_files.reset_index(drop=True) if isinstance(image_files, (pd.DataFrame, pd.Series))
+            else pd.Series(image_files)
+        ),
         'link1': pd.Series(link1).reset_index(drop=True),
         'json': pd.Series(extracted_jsons).reset_index(drop=True)
     })
@@ -549,7 +551,10 @@ def image_score_drawing(
     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
 
     categorized_data = pd.DataFrame({
-        'image_input': image_files,
+        'image_input': (
+            image_files.reset_index(drop=True) if isinstance(image_files, (pd.DataFrame, pd.Series))
+            else pd.Series(image_files)
+        ),
         'link1': pd.Series(link1).reset_index(drop=True),
         'json': pd.Series(extracted_jsons).reset_index(drop=True)
     })
@@ -835,7 +840,10 @@ def image_features(
     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
 
     categorized_data = pd.DataFrame({
-        'image_input': image_files,
+        'image_input': (
+            image_files.reset_index(drop=True) if isinstance(image_files, (pd.DataFrame, pd.Series))
+            else pd.Series(image_files)
+        ),
         'link1': pd.Series(link1).reset_index(drop=True),
         'json': pd.Series(extracted_jsons).reset_index(drop=True)
     })
@@ -373,20 +373,14 @@ Provide your work in JSON format where the number belonging to each category is
         except json.JSONDecodeError:
             normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
-
     categorized_data = pd.DataFrame({
-        'survey_response': survey_input.reset_index(drop=True),
+        'image_input': (
+            survey_input.reset_index(drop=True) if isinstance(survey_input, (pd.DataFrame, pd.Series))
+            else pd.Series(survey_input)
+        ),
         'link1': pd.Series(link1).reset_index(drop=True),
         'json': pd.Series(extracted_jsons).reset_index(drop=True)
     })
     categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
 
-    if columns != "numbered": #if user wants text columns
-        categorized_data.columns = list(categorized_data.columns[:3]) + categories[:len(categorized_data.columns) - 3]
-
-    if to_csv:
-        if save_directory is None:
-            save_directory = os.getcwd()
-        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
-
     return categorized_data
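
One detail shared by the new module and the existing functions: the first JSON object is pulled out of the model reply with a recursive pattern, which requires the third-party regex package (the stdlib re module has no (?R)). A quick standalone check of that pattern:

import json
import regex  # third-party: pip install regex

reply = 'Sure! {"answer": "Information not found", "url": "No source available"} Hope that helps.'
matches = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
print(json.loads(matches[0]))
# {'answer': 'Information not found', 'url': 'No source available'}
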