cat-llm 0.0.63__py3-none-any.whl → 0.0.65__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/METADATA +2 -1
- {cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/RECORD +6 -6
- catllm/__about__.py +1 -1
- catllm/build_web_research.py +309 -29
- {cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/WHEEL +0 -0
- {cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/licenses/LICENSE +0 -0

{cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cat-llm
-Version: 0.0.63
+Version: 0.0.65
 Summary: A tool for categorizing text data and images using LLMs and vision models
 Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
 Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -22,6 +22,7 @@ Requires-Python: >=3.8
 Requires-Dist: anthropic
 Requires-Dist: openai
 Requires-Dist: pandas
+Requires-Dist: perplexityai
 Requires-Dist: requests
 Requires-Dist: tqdm
 Description-Content-Type: text/markdown
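Version 0.0.65 adds one runtime dependency, perplexityai. Note the naming split: the distribution is `perplexityai`, while the module imported by `catllm/build_web_research.py` below is `perplexity`. A minimal, hedged sanity check after upgrading (the key is a placeholder and no request is sent):

```python
# Hedged check, not part of the package: verify the new Perplexity dependency
# installs and imports. The wheel declares "perplexityai", but the import name
# used by catllm/build_web_research.py is "perplexity".
from perplexity import Perplexity

client = Perplexity(api_key="pplx-...")  # placeholder key; constructing the client makes no request
print(type(client).__name__)
```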
{cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/RECORD CHANGED

@@ -1,7 +1,7 @@
 catllm/CERAD_functions.py,sha256=q4HbP5e2Yu8NnZZ-2eX4sImyj6u3i8xWcq0pYU81iis,22676
-catllm/__about__.py,sha256=
+catllm/__about__.py,sha256=QeI7x2I4oYiFhztRrDnRvZOLW_kEShiCK7Y_hax8U8o,408
 catllm/__init__.py,sha256=sf02zp7N0NW0mAQi7eQ4gliWR1EwoqvXkHN2HwwjcTE,372
-catllm/build_web_research.py,sha256=
+catllm/build_web_research.py,sha256=880dfE2bEQb-FrXP-42JoLLtyc9ox_sBULDr38xiTiQ,22655
 catllm/image_functions.py,sha256=8_FftRU285x1HT-AgNkaobefQVD-5q7ZY_t7JFdL3Sg,36177
 catllm/text_functions.py,sha256=Jf51lNaFtcS2QGnNLkhM8rFVJSD4tN0Bm_VfELvb47g,18686
 catllm/images/circle.png,sha256=JWujAWAh08-TajAoEr_TAeFNLlfbryOLw6cgIBREBuQ,86202
@@ -9,7 +9,7 @@ catllm/images/cube.png,sha256=nFec3e5bmRe4zrBCJ8QK-HcJLrG7u7dYdKhmdMfacfE,77275
 catllm/images/diamond.png,sha256=rJDZKtsnBGRO8FPA0iHuA8FvHFGi9PkI_DWSFdw6iv0,99568
 catllm/images/overlapping_pentagons.png,sha256=VO5plI6eoVRnjfqinn1nNzsCP2WQhuQy71V0EASouW4,71208
 catllm/images/rectangles.png,sha256=2XM16HO9EYWj2yHgN4bPXaCwPfl7iYQy0tQUGaJX9xg,40692
-cat_llm-0.0.63.dist-info/METADATA,sha256=
-cat_llm-0.0.63.dist-info/WHEEL,sha256=
-cat_llm-0.0.63.dist-info/licenses/LICENSE,sha256=
-cat_llm-0.0.63.dist-info/RECORD,,
+cat_llm-0.0.65.dist-info/METADATA,sha256=77WCioobgfzMsP_o76XHbRncfNrXYayxFgZDrUVFv7k,22423
+cat_llm-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+cat_llm-0.0.65.dist-info/licenses/LICENSE,sha256=Vje2sS5WV4TnIwY5uQHrF4qnBAM3YOk1pGpdH0ot-2o,34969
+cat_llm-0.0.65.dist-info/RECORD,,
catllm/__about__.py CHANGED
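No hunk is rendered for this file; judging by the RECORD entry above (408 bytes, new hash), the change is presumably just the version-string bump to 0.0.65. A quick, hedged way to confirm the installed version, assuming the conventional Hatch layout where `__about__.py` defines `__version__`:

```python
# Hedged sketch: a __version__ attribute in catllm/__about__.py is assumed (standard Hatch layout).
from catllm.__about__ import __version__

print(__version__)  # expected to print "0.0.65" once the new wheel is installed
```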
catllm/build_web_research.py CHANGED

@@ -5,39 +5,78 @@ def build_web_research_dataset(
     api_key,
     answer_format = "concise",
     additional_instructions = "",
-    categories = ['Answer'
+    categories = ['Answer'],
     user_model="claude-sonnet-4-20250514",
     creativity=None,
     safety=False,
     filename="categorized_data.csv",
     save_directory=None,
     model_source="Anthropic",
-
+    start_date=None,
+    end_date=None,
+    search_depth="", #enables Tavily searches
+    tavily_api=None,
+    output_urls = True,
+    max_retries = 6, #API rate limit error handler retries
+    time_delay=5
 ):
     import os
+    import re
     import json
     import pandas as pd
     import regex
     from tqdm import tqdm
     import time
+    from datetime import datetime
+
+    #ensures proper date format
+    def _validate_date(date_str):
+        """Validates YYYY-MM-DD format"""
+        if date_str is None:
+            return True # None is acceptable (means no date constraint)
+
+        if not isinstance(date_str, str):
+            return False
+
+        # Check pattern: YYYY_MM_DD
+        pattern = r'^\d{4}-\d{2}-\d{2}$'
+        if not re.match(pattern, date_str):
+            return False
+
+        # Validate actual date
+        try:
+            year, month, day = date_str.split('-')
+            datetime(int(year), int(month), int(day))
+            return True
+        except (ValueError, OverflowError):
+            return False
+
+    # Validate dates at the start of the function
+    if not _validate_date(start_date):
+        raise ValueError(f"start_date must be in YYYY-MM-DD format, got: {start_date}")
+
+    if not _validate_date(end_date):
+        raise ValueError(f"end_date must be in YYYY-MM-DD format, got: {end_date}")
 
     model_source = model_source.lower() # eliminating case sensitivity
+
+    if model_source == "perplexity" and start_date is not None:
+        start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%m/%d/%Y")
+    if model_source == "perplexity" and end_date is not None:
+        end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%m/%d/%Y")
+
+    # in case user switches to google but doesn't switch model
+    if model_source == "google" and user_model == "claude-sonnet-4-20250514":
+        user_model = "gemini-2.5-flash"
 
     categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
-    print(categories_str)
     cat_num = len(categories)
     category_dict = {str(i+1): "0" for i in range(cat_num)}
     example_JSON = json.dumps(category_dict, indent=4)
-
-    # ensure number of categories is what user wants
-    #print("\nThe information to be extracted:")
-    #for i, cat in enumerate(categories, 1):
-        #print(f"{i}. {cat}")
 
     link1 = []
     extracted_jsons = []
-
-    max_retries = 5 #API rate limit error handler retries
+    extracted_urls = []
 
     for idx, item in enumerate(tqdm(search_input, desc="Building dataset")):
         if idx > 0: # Skip delay for first item only
@@ -46,9 +85,9 @@ def build_web_research_dataset(
 
         if pd.isna(item):
             link1.append("Skipped NaN input")
+            extracted_urls.append([])
             default_json = example_JSON
             extracted_jsons.append(default_json)
-            #print(f"Skipped NaN input.")
         else:
             prompt = f"""<role>You are a research assistant specializing in finding current, factual information.</role>
 
@@ -59,7 +98,6 @@ def build_web_research_dataset(
 - Provide your answer as {answer_format}
 - Prioritize official sources when possible
 - If information is not found, state "Information not found"
-- Include exactly one source URL where you found the information
 - Do not include any explanatory text or commentary beyond the JSON
 {additional_instructions}
 </rules>
@@ -68,14 +106,74 @@ def build_web_research_dataset(
 Return your response as valid JSON with this exact structure:
 {{
 "answer": "Your factual answer or 'Information not found'",
-"
+"second_best_answer": "Your second best factual answer or 'Information not found'",
+"confidence": "confidence in response 0-5 or 'Information not found'"
 }}
+
 </format>"""
-
-            if
+
+            if start_date is not None and end_date is not None:
+                append_text = f"\n- Focus on webpages with a page age between {start_date} and {end_date}."
+                prompt = prompt.replace("<rules>", "<rules>" + append_text)
+            elif start_date is not None:
+                append_text = f"\n- Focus on webpages published after {start_date}."
+                prompt = prompt.replace("<rules>", "<rules>" + append_text)
+            elif end_date is not None:
+                append_text = f"\n- Focus on webpages published before {end_date}."
+                prompt = prompt.replace("<rules>", "<rules>" + append_text)
+
+            if search_depth == "advanced" and model_source != "perplexity":
+                try:
+                    from tavily import TavilyClient
+                    tavily_client = TavilyClient(tavily_api)
+                    tavily_response = tavily_client.search(
+                        query=f"{item}'s {search_question}",
+                        include_answer=True,
+                        max_results=15,
+                        search_depth="advanced",
+                        **({"start_date": start_date} if start_date is not None else {}),
+                        **({"end_date": end_date} if end_date is not None else {})
+                    )
+
+                    urls = [
+                        result['url']
+                        for result in tavily_response.get('results', [])
+                        if 'url' in result
+                    ]
+                    seen = set()
+                    urls = [u for u in urls if not (u in seen or seen.add(u))]
+                    extracted_urls.append(urls)
+
+                except Exception as e:
+                    error_msg = str(e).lower()
+                    if "unauthorized" in error_msg or "403" in error_msg or "401" in error_msg or "api_key" in error_msg:
+                        raise ValueError("ERROR: Invalid or missing tavily_api required for advanced search. Get one at https://app.tavily.com/home. To install: pip install tavily-python") from e
+                    else:
+                        print(f"Tavily search error: {e}")
+                        link1.append(f"Error with Tavily search: {e}")
+                        extracted_urls.append([])
+                        continue
+
+                #print(tavily_response)
+
+                advanced_prompt = f"""Based on the following search results about {item}'s {search_question}, provide your answer in this EXACT JSON format and {answer_format}:
+If you can't find the information, respond with 'Information not found'.
+{{"answer": "your answer here or 'Information not found'",
+"second_best_answer": "your second best answer here or 'Information not found'",
+"confidence": "confidence in response 0-5 or 'Information not found'"}}
+
+Search results:
+{tavily_response}
+
+Additional context from sources:
+{chr(10).join([f"- {r.get('title', '')}: {r.get('content', '')}" for r in tavily_response.get('results', [])[:3]])}
+
+Return ONLY the JSON object, no other text."""
+
+            if model_source == "anthropic" and search_depth != "advanced":
                 import anthropic
                 client = anthropic.Anthropic(api_key=api_key)
-
+                #print(prompt)
                 attempt = 0
                 while attempt < max_retries:
                     try:
@@ -95,8 +193,21 @@ def build_web_research_dataset(
                             if getattr(block, "type", "") == "text"
                         ).strip()
                         link1.append(reply)
+
+                        urls = [
+                            item["url"]
+                            for block in message.content
+                            if getattr(block, "type", "") == "web_search_tool_result"
+                            for item in (getattr(block, "content", []) or [])
+                            if isinstance(item, dict) and item.get("type") == "web_search_result" and "url" in item
+                        ]
+
+                        seen = set()
+                        urls = [u for u in urls if not (u in seen or seen.add(u))]
+                        extracted_urls.append(urls)
+
                         break
-                    except anthropic.
+                    except anthropic.RateLimitError as e:
                         wait_time = 2 ** attempt # Exponential backoff, keeps doubling after each attempt
                         print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
                         time.sleep(wait_time) #in case user wants to try and buffer the amount of errors by adding a wait time before attemps
@@ -104,11 +215,59 @@ def build_web_research_dataset(
                     except Exception as e:
                         print(f"A Non-rate-limit error occurred: {e}")
                         link1.append(f"Error processing input: {e}")
+                        extracted_urls.append([])
                         break #stop retrying
                 else:
                     link1.append("Max retries exceeded for rate limit errors.")
+                    extracted_urls.append([])
 
-            elif model_source == "
+            elif model_source == "anthropic" and search_depth == "advanced":
+                import anthropic
+                claude_client = anthropic.Anthropic(api_key=api_key)
+
+                attempt = 0
+                while attempt < max_retries:
+                    try:
+                        message = claude_client.messages.create(
+                            model=user_model,
+                            max_tokens=1024,
+                            messages=[{"role": "user", "content": advanced_prompt}],
+                            **({"temperature": creativity} if creativity is not None else {})
+                        )
+
+                        reply = " ".join(
+                            block.text
+                            for block in message.content
+                            if getattr(block, "type", "") == "text"
+                        ).strip()
+
+                        try:
+                            import json
+                            json_response = json.loads(reply)
+                            final_answer = json_response.get('answer', reply)
+                            link1.append(final_answer)
+                        except json.JSONDecodeError:
+
+                            print(f"JSON parse error, using raw reply: {reply}")
+                            link1.append(reply)
+
+                        break # Success
+
+                    except anthropic.RateLimitError as e:
+                        wait_time = 2 ** attempt
+                        print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
+                        time.sleep(wait_time)
+                        attempt += 1
+
+                    except Exception as e:
+                        print(f"A Non-rate-limit error occurred: {e}")
+                        link1.append(f"Error processing input: {e}")
+                        break
+                else:
+                    # Max retries exceeded
+                    link1.append("Max retries exceeded for rate limit errors.")
+
+            elif model_source == "google" and search_depth != "advanced":
                 import requests
                 url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
                 try:
@@ -121,11 +280,62 @@ def build_web_research_dataset(
                         "tools": [{"google_search": {}}],
                         **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
                     }
-
+
                     response = requests.post(url, headers=headers, json=payload)
                     response.raise_for_status()
                     result = response.json()
+
+                    urls = []
+                    for cand in result.get("candidates", []):
+                        rendered_html = (
+                            cand.get("groundingMetadata", {})
+                            .get("searchEntryPoint", {})
+                            .get("renderedContent", "")
+                        )
+                        if rendered_html:
+                            # regex: capture href="..."; limited to class="chip"
+                            found = re.findall(
+                                r'<a[^>]*class=["\']chip["\'][^>]*href=["\']([^"\']+)["\']',
+                                rendered_html,
+                                flags=re.IGNORECASE
+                            )
+                            urls.extend(found)
+
+                    seen = set()
+                    urls = [u for u in urls if not (u in seen or seen.add(u))]
+                    extracted_urls.append(urls)
+
+                    # extract reply from Google's response structure
+                    if "candidates" in result and result["candidates"]:
+                        reply = result["candidates"][0]["content"]["parts"][0]["text"]
+                    else:
+                        reply = "No response generated"
+
+                    link1.append(reply)
+
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+                    extracted_urls.append([])
+
+            elif model_source == "google" and search_depth == "advanced":
+                import requests
+                url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
+                try:
+                    headers = {
+                        "x-goog-api-key": api_key,
+                        "Content-Type": "application/json"
+                    }
 
+                    payload = {
+                        "contents": [{"parts": [{"text": advanced_prompt}]}],
+                        **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+                    }
+
+                    response = requests.post(url, headers=headers, json=payload)
+                    response.raise_for_status()
+                    result = response.json()
+
                     # extract reply from Google's response structure
                     if "candidates" in result and result["candidates"]:
                         reply = result["candidates"][0]["content"]["parts"][0]["text"]
@@ -138,6 +348,54 @@ def build_web_research_dataset(
                     print(f"An error occurred: {e}")
                     link1.append(f"Error processing input: {e}")
 
+            elif model_source == "perplexity":
+
+                from perplexity import Perplexity
+                client = Perplexity(api_key=api_key)
+                try:
+                    response = client.chat.completions.create(
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": prompt
+                            }
+                        ],
+                        model=user_model,
+                        max_tokens=1024,
+                        **({"temperature": creativity} if creativity is not None else {}),
+                        web_search_options={"search_context_size": "high" if search_depth == "advanced" else "medium"},
+                        **({"search_after_date_filter": start_date} if start_date else {}),
+                        **({"search_before_date_filter": end_date} if end_date else {}),
+                        response_format={ #requiring a JSON
+                            "type": "json_schema",
+                            "json_schema": {
+                                "schema": {
+                                    "type": "object",
+                                    "properties": {
+                                        "answer": {"type": "string"},
+                                        "second_best_answer": {"type": "string"},
+                                        "confidence": {"type": "integer"}
+                                    },
+                                    "required": ["answer", "second_best_answer"]
+                                }
+                            }
+                        }
+                    )
+
+                    reply = response.choices[0].message.content
+                    #print(response)
+                    link1.append(reply)
+
+                    urls = list(response.citations) if hasattr(response, 'citations') else []
+
+                    seen = set()
+                    urls = [u for u in urls if not (u in seen or seen.add(u))]
+                    extracted_urls.append(urls)
+
+                except Exception as e:
+                    print(f"An error occurred: {e}")
+                    link1.append(f"Error processing input: {e}")
+                    extracted_urls.append([])
             else:
                 raise ValueError("Unknown source! Currently this function only supports 'Anthropic' or 'Google' as model_source.")
             # in situation that no JSON is found
@@ -157,12 +415,12 @@ def build_web_research_dataset(
                     extracted_jsons.append(raw_json)
                 else:
                     # Use consistent schema for errors
-                    error_message = json.dumps({"answer": "e"
+                    error_message = json.dumps({"answer": "e"})
                     extracted_jsons.append(error_message)
                     print(error_message)
             else:
                 # Handle None reply case
-                error_message = json.dumps({"answer": "e"
+                error_message = json.dumps({"answer": "e"})
                 extracted_jsons.append(error_message)
                 #print(error_message)
 
@@ -170,9 +428,9 @@ def build_web_research_dataset(
         if safety:
             # Save progress so far
            temp_df = pd.DataFrame({
-                '
-                'model_response': link1,
-                'json': extracted_jsons
+                'raw_response': search_input[:idx+1]
+                #'model_response': link1,
+                #'json': extracted_jsons
             })
             # Normalize processed jsons so far
             normalized_data_list = []
@@ -181,9 +439,11 @@ def build_web_research_dataset(
                     parsed_obj = json.loads(json_str)
                     normalized_data_list.append(pd.json_normalize(parsed_obj))
                 except json.JSONDecodeError:
-                    normalized_data_list.append(pd.DataFrame({"
+                    normalized_data_list.append(pd.DataFrame({"answer": ["e"]}))
             normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+            temp_urls = pd.DataFrame(extracted_urls).add_prefix("url_")
             temp_df = pd.concat([temp_df, normalized_data], axis=1)
+            temp_df = pd.concat([temp_df, temp_urls], axis=1)
             # Save to CSV
             if save_directory is None:
                 save_directory = os.getcwd()
@@ -196,17 +456,37 @@ def build_web_research_dataset(
             parsed_obj = json.loads(json_str)
             normalized_data_list.append(pd.json_normalize(parsed_obj))
         except json.JSONDecodeError:
-            normalized_data_list.append(pd.DataFrame({"
+            normalized_data_list.append(pd.DataFrame({"answer": ["e"]}))
     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
 
+    # converting urls to dataframe and adding prefix
+    df_urls = pd.DataFrame(extracted_urls).add_prefix("url_")
+
     categorized_data = pd.DataFrame({
-        '
+        'search_input': (
             search_input.reset_index(drop=True) if isinstance(search_input, (pd.DataFrame, pd.Series))
             else pd.Series(search_input)
         ),
-        '
-        'json': pd.Series(extracted_jsons).reset_index(drop=True)
+        'raw_response': pd.Series(link1).reset_index(drop=True),
+        #'json': pd.Series(extracted_jsons).reset_index(drop=True),
+        #"all_urls": pd.Series(extracted_urls).reset_index(drop=True)
     })
     categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
+    categorized_data = pd.concat([categorized_data, df_urls], axis=1)
+
+    # drop second best answer column if it exists
+    # we only ask for the second best answer to "force" the model to think more carefully about its best answer, but we don't actually need to keep it
+    categorized_data = categorized_data.drop(columns=["second_best_answer"], errors='ignore')
+
+    # dropping this column for advanced searches (this column is mostly useful for basic searches to see what the model saw)
+    if search_depth == "advanced":
+        categorized_data = categorized_data.drop(columns=["raw_response"], errors='ignore')
+
+    #for users who don't want the urls included in the final dataframe
+    if output_urls is False:
+        categorized_data = categorized_data.drop(columns=[col for col in categorized_data.columns if col.startswith("url_")])
+
+    if save_directory is not None:
+        categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
 
     return categorized_data
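Taken together, 0.0.65 extends `build_web_research_dataset` with date filtering, a Tavily-backed "advanced" search depth, Perplexity support, and per-row source URL columns. A hedged usage sketch of the new parameters follows; the `search_question` and `search_input` argument names are assumed from the function body (they sit in the part of the signature above the first hunk shown here), and all keys are placeholders:

```python
# Hedged example, not from the package docs: exercises the parameters added in 0.0.65.
from catllm.build_web_research import build_web_research_dataset

df = build_web_research_dataset(
    search_question="year founded",             # assumed parameter name (used in the prompt as {search_question})
    search_input=["UC Berkeley", "Stanford"],   # assumed parameter name (iterated as `item`)
    api_key="sk-ant-...",                       # key for the chosen model_source (placeholder)
    model_source="Anthropic",
    start_date="2020-01-01",                    # new in 0.0.65; must be YYYY-MM-DD
    end_date="2024-12-31",                      # new in 0.0.65; must be YYYY-MM-DD
    search_depth="advanced",                    # new in 0.0.65; routes non-Perplexity searches through Tavily
    tavily_api="tvly-...",                      # required for advanced search outside Perplexity (placeholder)
    output_urls=True,                           # new in 0.0.65; keep the url_0, url_1, ... columns
)

# Columns include 'search_input', the parsed answer fields (e.g. 'answer', 'confidence'),
# and url_* columns when output_urls=True; 'raw_response' is dropped for advanced searches.
print(df.columns.tolist())
```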
{cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/WHEEL: File without changes
{cat_llm-0.0.63.dist-info → cat_llm-0.0.65.dist-info}/licenses/LICENSE: File without changes