cat-llm 0.0.63__tar.gz → 0.0.65__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cat-llm
- Version: 0.0.63
+ Version: 0.0.65
  Summary: A tool for categorizing text data and images using LLMs and vision models
  Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
  Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -22,6 +22,7 @@ Requires-Python: >=3.8
  Requires-Dist: anthropic
  Requires-Dist: openai
  Requires-Dist: pandas
+ Requires-Dist: perplexityai
  Requires-Dist: requests
  Requires-Dist: tqdm
  Description-Content-Type: text/markdown
@@ -29,7 +29,8 @@ dependencies = [
  "tqdm",
  "requests",
  "openai",
- "anthropic"
+ "anthropic",
+ "perplexityai"
  ]
 
  [project.urls]
@@ -1,7 +1,7 @@
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
  #
  # SPDX-License-Identifier: MIT
- __version__ = "0.0.63"
+ __version__ = "0.0.65"
  __author__ = "Chris Soria"
  __email__ = "chrissoria@berkeley.edu"
  __title__ = "cat-llm"
@@ -0,0 +1,492 @@
+ #build dataset classification
+ def build_web_research_dataset(
+     search_question,
+     search_input,
+     api_key,
+     answer_format = "concise",
+     additional_instructions = "",
+     categories = ['Answer'],
+     user_model="claude-sonnet-4-20250514",
+     creativity=None,
+     safety=False,
+     filename="categorized_data.csv",
+     save_directory=None,
+     model_source="Anthropic",
+     start_date=None,
+     end_date=None,
+     search_depth="", #enables Tavily searches
+     tavily_api=None,
+     output_urls = True,
+     max_retries = 6, #API rate limit error handler retries
+     time_delay=5
+     ):
+     import os
+     import re
+     import json
+     import pandas as pd
+     import regex
+     from tqdm import tqdm
+     import time
+     from datetime import datetime
+
+     #ensures proper date format
+     def _validate_date(date_str):
+         """Validates YYYY-MM-DD format"""
+         if date_str is None:
+             return True # None is acceptable (means no date constraint)
+
+         if not isinstance(date_str, str):
+             return False
+
+         # Check pattern: YYYY-MM-DD
+         pattern = r'^\d{4}-\d{2}-\d{2}$'
+         if not re.match(pattern, date_str):
+             return False
+
+         # Validate actual date
+         try:
+             year, month, day = date_str.split('-')
+             datetime(int(year), int(month), int(day))
+             return True
+         except (ValueError, OverflowError):
+             return False
+
+     # Validate dates at the start of the function
+     if not _validate_date(start_date):
+         raise ValueError(f"start_date must be in YYYY-MM-DD format, got: {start_date}")
+
+     if not _validate_date(end_date):
+         raise ValueError(f"end_date must be in YYYY-MM-DD format, got: {end_date}")
+
+     model_source = model_source.lower() # eliminating case sensitivity
+
+     if model_source == "perplexity" and start_date is not None:
+         start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%m/%d/%Y")
+     if model_source == "perplexity" and end_date is not None:
+         end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%m/%d/%Y")
+
+     # in case user switches to google but doesn't switch model
+     if model_source == "google" and user_model == "claude-sonnet-4-20250514":
+         user_model = "gemini-2.5-flash"
+
+     categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
+     cat_num = len(categories)
+     category_dict = {str(i+1): "0" for i in range(cat_num)}
+     example_JSON = json.dumps(category_dict, indent=4)
+
+     link1 = []
+     extracted_jsons = []
+     extracted_urls = []
+
+     for idx, item in enumerate(tqdm(search_input, desc="Building dataset")):
+         if idx > 0: # Skip delay for first item only
+             time.sleep(time_delay)
+         reply = None
+
+         if pd.isna(item):
+             link1.append("Skipped NaN input")
+             extracted_urls.append([])
+             default_json = example_JSON
+             extracted_jsons.append(default_json)
+         else:
+             prompt = f"""<role>You are a research assistant specializing in finding current, factual information.</role>
+
+ <task>Find information about {item}'s {search_question}</task>
+
+ <rules>
+ - Search for the most current and authoritative information available
+ - Provide your answer as {answer_format}
+ - Prioritize official sources when possible
+ - If information is not found, state "Information not found"
+ - Do not include any explanatory text or commentary beyond the JSON
+ {additional_instructions}
+ </rules>
+
+ <format>
+ Return your response as valid JSON with this exact structure:
+ {{
+ "answer": "Your factual answer or 'Information not found'",
+ "second_best_answer": "Your second best factual answer or 'Information not found'",
+ "confidence": "confidence in response 0-5 or 'Information not found'"
+ }}
+
+ </format>"""
+
+             if start_date is not None and end_date is not None:
+                 append_text = f"\n- Focus on webpages with a page age between {start_date} and {end_date}."
+                 prompt = prompt.replace("<rules>", "<rules>" + append_text)
+             elif start_date is not None:
+                 append_text = f"\n- Focus on webpages published after {start_date}."
+                 prompt = prompt.replace("<rules>", "<rules>" + append_text)
+             elif end_date is not None:
+                 append_text = f"\n- Focus on webpages published before {end_date}."
+                 prompt = prompt.replace("<rules>", "<rules>" + append_text)
+
+             if search_depth == "advanced" and model_source != "perplexity":
+                 try:
+                     from tavily import TavilyClient
+                     tavily_client = TavilyClient(tavily_api)
+                     tavily_response = tavily_client.search(
+                         query=f"{item}'s {search_question}",
+                         include_answer=True,
+                         max_results=15,
+                         search_depth="advanced",
+                         **({"start_date": start_date} if start_date is not None else {}),
+                         **({"end_date": end_date} if end_date is not None else {})
+                     )
+
+                     urls = [
+                         result['url']
+                         for result in tavily_response.get('results', [])
+                         if 'url' in result
+                     ]
+                     seen = set()
+                     urls = [u for u in urls if not (u in seen or seen.add(u))]
+                     extracted_urls.append(urls)
+
+                 except Exception as e:
+                     error_msg = str(e).lower()
+                     if "unauthorized" in error_msg or "403" in error_msg or "401" in error_msg or "api_key" in error_msg:
+                         raise ValueError("ERROR: Invalid or missing tavily_api required for advanced search. Get one at https://app.tavily.com/home. To install: pip install tavily-python") from e
+                     else:
+                         print(f"Tavily search error: {e}")
+                         link1.append(f"Error with Tavily search: {e}")
+                         extracted_urls.append([])
+                         continue
+
+                 #print(tavily_response)
+
+                 advanced_prompt = f"""Based on the following search results about {item}'s {search_question}, provide your answer in this EXACT JSON format and {answer_format}:
+ If you can't find the information, respond with 'Information not found'.
+ {{"answer": "your answer here or 'Information not found'",
+ "second_best_answer": "your second best answer here or 'Information not found'",
+ "confidence": "confidence in response 0-5 or 'Information not found'"}}
+
+ Search results:
+ {tavily_response}
+
+ Additional context from sources:
+ {chr(10).join([f"- {r.get('title', '')}: {r.get('content', '')}" for r in tavily_response.get('results', [])[:3]])}
+
+ Return ONLY the JSON object, no other text."""
+
+             if model_source == "anthropic" and search_depth != "advanced":
+                 import anthropic
+                 client = anthropic.Anthropic(api_key=api_key)
+                 #print(prompt)
+                 attempt = 0
+                 while attempt < max_retries:
+                     try:
+                         message = client.messages.create(
+                             model=user_model,
+                             max_tokens=1024,
+                             messages=[{"role": "user", "content": prompt}],
+                             **({"temperature": creativity} if creativity is not None else {}),
+                             tools=[{
+                                 "type": "web_search_20250305",
+                                 "name": "web_search"
+                             }]
+                         )
+                         reply = " ".join(
+                             block.text
+                             for block in message.content
+                             if getattr(block, "type", "") == "text"
+                         ).strip()
+                         link1.append(reply)
+
+                         urls = [
+                             item["url"]
+                             for block in message.content
+                             if getattr(block, "type", "") == "web_search_tool_result"
+                             for item in (getattr(block, "content", []) or [])
+                             if isinstance(item, dict) and item.get("type") == "web_search_result" and "url" in item
+                         ]
+
+                         seen = set()
+                         urls = [u for u in urls if not (u in seen or seen.add(u))]
+                         extracted_urls.append(urls)
+
+                         break
+                     except anthropic.RateLimitError as e:
+                         wait_time = 2 ** attempt # Exponential backoff, keeps doubling after each attempt
+                         print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
+                         time.sleep(wait_time) #in case user wants to try and buffer the amount of errors by adding a wait time before attempts
+                         attempt += 1
+                     except Exception as e:
+                         print(f"A non-rate-limit error occurred: {e}")
+                         link1.append(f"Error processing input: {e}")
+                         extracted_urls.append([])
+                         break #stop retrying
+                 else:
+                     link1.append("Max retries exceeded for rate limit errors.")
+                     extracted_urls.append([])
+
+             elif model_source == "anthropic" and search_depth == "advanced":
+                 import anthropic
+                 claude_client = anthropic.Anthropic(api_key=api_key)
+
+                 attempt = 0
+                 while attempt < max_retries:
+                     try:
+                         message = claude_client.messages.create(
+                             model=user_model,
+                             max_tokens=1024,
+                             messages=[{"role": "user", "content": advanced_prompt}],
+                             **({"temperature": creativity} if creativity is not None else {})
+                         )
+
+                         reply = " ".join(
+                             block.text
+                             for block in message.content
+                             if getattr(block, "type", "") == "text"
+                         ).strip()
+
+                         try:
+                             import json
+                             json_response = json.loads(reply)
+                             final_answer = json_response.get('answer', reply)
+                             link1.append(final_answer)
+                         except json.JSONDecodeError:
+
+                             print(f"JSON parse error, using raw reply: {reply}")
+                             link1.append(reply)
+
+                         break # Success
+
+                     except anthropic.RateLimitError as e:
+                         wait_time = 2 ** attempt
+                         print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
+                         time.sleep(wait_time)
+                         attempt += 1
+
+                     except Exception as e:
+                         print(f"A non-rate-limit error occurred: {e}")
+                         link1.append(f"Error processing input: {e}")
+                         break
+                 else:
+                     # Max retries exceeded
+                     link1.append("Max retries exceeded for rate limit errors.")
+
+             elif model_source == "google" and search_depth != "advanced":
+                 import requests
+                 url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
+                 try:
+                     headers = {
+                         "x-goog-api-key": api_key,
+                         "Content-Type": "application/json"
+                     }
+                     payload = {
+                         "contents": [{"parts": [{"text": prompt}]}],
+                         "tools": [{"google_search": {}}],
+                         **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+                     }
+
+                     response = requests.post(url, headers=headers, json=payload)
+                     response.raise_for_status()
+                     result = response.json()
+
+                     urls = []
+                     for cand in result.get("candidates", []):
+                         rendered_html = (
+                             cand.get("groundingMetadata", {})
+                             .get("searchEntryPoint", {})
+                             .get("renderedContent", "")
+                         )
+                         if rendered_html:
+                             # regex: capture href="..."; limited to class="chip"
+                             found = re.findall(
+                                 r'<a[^>]*class=["\']chip["\'][^>]*href=["\']([^"\']+)["\']',
+                                 rendered_html,
+                                 flags=re.IGNORECASE
+                             )
+                             urls.extend(found)
+
+                     seen = set()
+                     urls = [u for u in urls if not (u in seen or seen.add(u))]
+                     extracted_urls.append(urls)
+
+                     # extract reply from Google's response structure
+                     if "candidates" in result and result["candidates"]:
+                         reply = result["candidates"][0]["content"]["parts"][0]["text"]
+                     else:
+                         reply = "No response generated"
+
+                     link1.append(reply)
+
+                 except Exception as e:
+                     print(f"An error occurred: {e}")
+                     link1.append(f"Error processing input: {e}")
+                     extracted_urls.append([])
+
+             elif model_source == "google" and search_depth == "advanced":
+                 import requests
+                 url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
+                 try:
+                     headers = {
+                         "x-goog-api-key": api_key,
+                         "Content-Type": "application/json"
+                     }
+
+                     payload = {
+                         "contents": [{"parts": [{"text": advanced_prompt}]}],
+                         **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+                     }
+
+                     response = requests.post(url, headers=headers, json=payload)
+                     response.raise_for_status()
+                     result = response.json()
+
+                     # extract reply from Google's response structure
+                     if "candidates" in result and result["candidates"]:
+                         reply = result["candidates"][0]["content"]["parts"][0]["text"]
+                     else:
+                         reply = "No response generated"
+
+                     link1.append(reply)
+
+                 except Exception as e:
+                     print(f"An error occurred: {e}")
+                     link1.append(f"Error processing input: {e}")
+
+             elif model_source == "perplexity":
+
+                 from perplexity import Perplexity
+                 client = Perplexity(api_key=api_key)
+                 try:
+                     response = client.chat.completions.create(
+                         messages=[
+                             {
+                                 "role": "user",
+                                 "content": prompt
+                             }
+                         ],
+                         model=user_model,
+                         max_tokens=1024,
+                         **({"temperature": creativity} if creativity is not None else {}),
+                         web_search_options={"search_context_size": "high" if search_depth == "advanced" else "medium"},
+                         **({"search_after_date_filter": start_date} if start_date else {}),
+                         **({"search_before_date_filter": end_date} if end_date else {}),
+                         response_format={ #requiring a JSON
+                             "type": "json_schema",
+                             "json_schema": {
+                                 "schema": {
+                                     "type": "object",
+                                     "properties": {
+                                         "answer": {"type": "string"},
+                                         "second_best_answer": {"type": "string"},
+                                         "confidence": {"type": "integer"}
+                                     },
+                                     "required": ["answer", "second_best_answer"]
+                                 }
+                             }
+                         }
+                     )
+
+                     reply = response.choices[0].message.content
+                     #print(response)
+                     link1.append(reply)
+
+                     urls = list(response.citations) if hasattr(response, 'citations') else []
+
+                     seen = set()
+                     urls = [u for u in urls if not (u in seen or seen.add(u))]
+                     extracted_urls.append(urls)
+
+                 except Exception as e:
+                     print(f"An error occurred: {e}")
+                     link1.append(f"Error processing input: {e}")
+                     extracted_urls.append([])
+             else:
+                 raise ValueError("Unknown source! Currently this function only supports 'Anthropic', 'Google', or 'Perplexity' as model_source.")
+             # in situation that no JSON is found
+             if reply is not None:
+                 extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
+                 if extracted_json:
+                     raw_json = extracted_json[0].strip() # Only strip leading/trailing whitespace
+                     try:
+                         # Parse to validate JSON structure
+                         parsed_obj = json.loads(raw_json)
+                         # Re-serialize for consistent formatting (optional)
+                         cleaned_json = json.dumps(parsed_obj)
+                         extracted_jsons.append(cleaned_json)
+                     except json.JSONDecodeError as e:
+                         print(f"JSON parsing error: {e}")
+                         # Fallback to raw extraction if parsing fails
+                         extracted_jsons.append(raw_json)
+                 else:
+                     # Use consistent schema for errors
+                     error_message = json.dumps({"answer": "e"})
+                     extracted_jsons.append(error_message)
+                     print(error_message)
+             else:
+                 # Handle None reply case
+                 error_message = json.dumps({"answer": "e"})
+                 extracted_jsons.append(error_message)
+                 #print(error_message)
+
+         # --- Safety Save ---
+         if safety:
+             # Save progress so far
+             temp_df = pd.DataFrame({
+                 'raw_response': search_input[:idx+1]
+                 #'model_response': link1,
+                 #'json': extracted_jsons
+             })
+             # Normalize processed jsons so far
+             normalized_data_list = []
+             for json_str in extracted_jsons:
+                 try:
+                     parsed_obj = json.loads(json_str)
+                     normalized_data_list.append(pd.json_normalize(parsed_obj))
+                 except json.JSONDecodeError:
+                     normalized_data_list.append(pd.DataFrame({"answer": ["e"]}))
+             normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+             temp_urls = pd.DataFrame(extracted_urls).add_prefix("url_")
+             temp_df = pd.concat([temp_df, normalized_data], axis=1)
+             temp_df = pd.concat([temp_df, temp_urls], axis=1)
+             # Save to CSV
+             if save_directory is None:
+                 save_directory = os.getcwd()
+             temp_df.to_csv(os.path.join(save_directory, filename), index=False)
+
+     # --- Final DataFrame ---
+     normalized_data_list = []
+     for json_str in extracted_jsons:
+         try:
+             parsed_obj = json.loads(json_str)
+             normalized_data_list.append(pd.json_normalize(parsed_obj))
+         except json.JSONDecodeError:
+             normalized_data_list.append(pd.DataFrame({"answer": ["e"]}))
+     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
+
+     # converting urls to dataframe and adding prefix
+     df_urls = pd.DataFrame(extracted_urls).add_prefix("url_")
+
+     categorized_data = pd.DataFrame({
+         'search_input': (
+             search_input.reset_index(drop=True) if isinstance(search_input, (pd.DataFrame, pd.Series))
+             else pd.Series(search_input)
+         ),
+         'raw_response': pd.Series(link1).reset_index(drop=True),
+         #'json': pd.Series(extracted_jsons).reset_index(drop=True),
+         #"all_urls": pd.Series(extracted_urls).reset_index(drop=True)
+     })
+     categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
+     categorized_data = pd.concat([categorized_data, df_urls], axis=1)
+
+     # drop second best answer column if it exists
+     # we only ask for the second best answer to "force" the model to think more carefully about its best answer, but we don't actually need to keep it
+     categorized_data = categorized_data.drop(columns=["second_best_answer"], errors='ignore')
+
+     # dropping this column for advanced searches (this column is mostly useful for basic searches to see what the model saw)
+     if search_depth == "advanced":
+         categorized_data = categorized_data.drop(columns=["raw_response"], errors='ignore')
+
+     #for users who don't want the urls included in the final dataframe
+     if output_urls is False:
+         categorized_data = categorized_data.drop(columns=[col for col in categorized_data.columns if col.startswith("url_")])
+
+     if save_directory is not None:
+         categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
+
+     return categorized_data
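For orientation, the sketch below shows one plausible way to call the function added above. It is illustrative only: the parameter names come from the new signature, but the import path, the example Series, the API keys, and the printed output are assumptions, not documented usage.

# Hypothetical usage sketch for build_web_research_dataset (added in 0.0.65).
# Assumptions: the function is importable from the catllm package, and valid
# Anthropic and Tavily keys are supplied; all values below are placeholders.
import pandas as pd
from catllm import build_web_research_dataset  # assumed export location

cities = pd.Series(["Berkeley", "Oakland", "San Francisco"])  # entities to research

results = build_web_research_dataset(
    search_question="current mayor",       # what to look up for each entity
    search_input=cities,
    api_key="ANTHROPIC_API_KEY",            # key for the chosen model_source
    model_source="Anthropic",               # "Anthropic", "Google", or "Perplexity"
    start_date="2024-01-01",                # optional YYYY-MM-DD recency window
    end_date="2025-01-01",
    search_depth="advanced",                # advanced mode routes through Tavily here
    tavily_api="TAVILY_API_KEY",
    output_urls=True,                       # keep the url_* source columns
    save_directory=".",                     # also writes categorized_data.csv
)
print(results.columns.tolist())

Per the code above, the returned DataFrame has one row per item in search_input, with the parsed answer and confidence fields normalized into columns and the cited sources spread across url_* columns.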
@@ -1,212 +0,0 @@
- #build dataset classification
- def build_web_research_dataset(
-     search_question,
-     search_input,
-     api_key,
-     answer_format = "concise",
-     additional_instructions = "",
-     categories = ['Answer','URL'],
-     user_model="claude-sonnet-4-20250514",
-     creativity=None,
-     safety=False,
-     filename="categorized_data.csv",
-     save_directory=None,
-     model_source="Anthropic",
-     time_delay=15
-     ):
-     import os
-     import json
-     import pandas as pd
-     import regex
-     from tqdm import tqdm
-     import time
-
-     model_source = model_source.lower() # eliminating case sensitivity
-
-     categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
-     print(categories_str)
-     cat_num = len(categories)
-     category_dict = {str(i+1): "0" for i in range(cat_num)}
-     example_JSON = json.dumps(category_dict, indent=4)
-
-     # ensure number of categories is what user wants
-     #print("\nThe information to be extracted:")
-     #for i, cat in enumerate(categories, 1):
-     #print(f"{i}. {cat}")
-
-     link1 = []
-     extracted_jsons = []
-
-     max_retries = 5 #API rate limit error handler retries
-
-     for idx, item in enumerate(tqdm(search_input, desc="Building dataset")):
-         if idx > 0: # Skip delay for first item only
-             time.sleep(time_delay)
-         reply = None
-
-         if pd.isna(item):
-             link1.append("Skipped NaN input")
-             default_json = example_JSON
-             extracted_jsons.append(default_json)
-             #print(f"Skipped NaN input.")
-         else:
-             prompt = f"""<role>You are a research assistant specializing in finding current, factual information.</role>
-
- <task>Find information about {item}'s {search_question}</task>
-
- <rules>
- - Search for the most current and authoritative information available
- - Provide your answer as {answer_format}
- - Prioritize official sources when possible
- - If information is not found, state "Information not found"
- - Include exactly one source URL where you found the information
- - Do not include any explanatory text or commentary beyond the JSON
- {additional_instructions}
- </rules>
-
- <format>
- Return your response as valid JSON with this exact structure:
- {{
- "answer": "Your factual answer or 'Information not found'",
- "url": "Source URL or 'No source available'"
- }}
- </format>"""
-             #print(prompt)
-             if model_source == "anthropic":
-                 import anthropic
-                 client = anthropic.Anthropic(api_key=api_key)
-
-                 attempt = 0
-                 while attempt < max_retries:
-                     try:
-                         message = client.messages.create(
-                             model=user_model,
-                             max_tokens=1024,
-                             messages=[{"role": "user", "content": prompt}],
-                             **({"temperature": creativity} if creativity is not None else {}),
-                             tools=[{
-                                 "type": "web_search_20250305",
-                                 "name": "web_search"
-                             }]
-                         )
-                         reply = " ".join(
-                             block.text
-                             for block in message.content
-                             if getattr(block, "type", "") == "text"
-                         ).strip()
-                         link1.append(reply)
-                         break
-                     except anthropic.error.RateLimitError as e:
-                         wait_time = 2 ** attempt # Exponential backoff, keeps doubling after each attempt
-                         print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
-                         time.sleep(wait_time) #in case user wants to try and buffer the amount of errors by adding a wait time before attemps
-                         attempt += 1
-                     except Exception as e:
-                         print(f"A Non-rate-limit error occurred: {e}")
-                         link1.append(f"Error processing input: {e}")
-                         break #stop retrying
-                 else:
-                     link1.append("Max retries exceeded for rate limit errors.")
-
-             elif model_source == "google":
-                 import requests
-                 url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
-                 try:
-                     headers = {
-                         "x-goog-api-key": api_key,
-                         "Content-Type": "application/json"
-                     }
-                     payload = {
-                         "contents": [{"parts": [{"text": prompt}]}],
-                         "tools": [{"google_search": {}}],
-                         **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
-                     }
-
-                     response = requests.post(url, headers=headers, json=payload)
-                     response.raise_for_status()
-                     result = response.json()
-
-                     # extract reply from Google's response structure
-                     if "candidates" in result and result["candidates"]:
-                         reply = result["candidates"][0]["content"]["parts"][0]["text"]
-                     else:
-                         reply = "No response generated"
-
-                     link1.append(reply)
-
-                 except Exception as e:
-                     print(f"An error occurred: {e}")
-                     link1.append(f"Error processing input: {e}")
-
-             else:
-                 raise ValueError("Unknown source! Currently this function only supports 'Anthropic' or 'Google' as model_source.")
-             # in situation that no JSON is found
-             if reply is not None:
-                 extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)
-                 if extracted_json:
-                     raw_json = extracted_json[0].strip() # Only strip leading/trailing whitespace
-                     try:
-                         # Parse to validate JSON structure
-                         parsed_obj = json.loads(raw_json)
-                         # Re-serialize for consistent formatting (optional)
-                         cleaned_json = json.dumps(parsed_obj)
-                         extracted_jsons.append(cleaned_json)
-                     except json.JSONDecodeError as e:
-                         print(f"JSON parsing error: {e}")
-                         # Fallback to raw extraction if parsing fails
-                         extracted_jsons.append(raw_json)
-                 else:
-                     # Use consistent schema for errors
-                     error_message = json.dumps({"answer": "e", "url": "e"})
-                     extracted_jsons.append(error_message)
-                     print(error_message)
-             else:
-                 # Handle None reply case
-                 error_message = json.dumps({"answer": "e", "url": "e"})
-                 extracted_jsons.append(error_message)
-                 #print(error_message)
-
-         # --- Safety Save ---
-         if safety:
-             # Save progress so far
-             temp_df = pd.DataFrame({
-                 'survey_response': search_input[:idx+1],
-                 'model_response': link1,
-                 'json': extracted_jsons
-             })
-             # Normalize processed jsons so far
-             normalized_data_list = []
-             for json_str in extracted_jsons:
-                 try:
-                     parsed_obj = json.loads(json_str)
-                     normalized_data_list.append(pd.json_normalize(parsed_obj))
-                 except json.JSONDecodeError:
-                     normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
-             normalized_data = pd.concat(normalized_data_list, ignore_index=True)
-             temp_df = pd.concat([temp_df, normalized_data], axis=1)
-             # Save to CSV
-             if save_directory is None:
-                 save_directory = os.getcwd()
-             temp_df.to_csv(os.path.join(save_directory, filename), index=False)
-
-     # --- Final DataFrame ---
-     normalized_data_list = []
-     for json_str in extracted_jsons:
-         try:
-             parsed_obj = json.loads(json_str)
-             normalized_data_list.append(pd.json_normalize(parsed_obj))
-         except json.JSONDecodeError:
-             normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
-     normalized_data = pd.concat(normalized_data_list, ignore_index=True)
-
-     categorized_data = pd.DataFrame({
-         'survey_response': (
-             search_input.reset_index(drop=True) if isinstance(search_input, (pd.DataFrame, pd.Series))
-             else pd.Series(search_input)
-         ),
-         'link1': pd.Series(link1).reset_index(drop=True),
-         'json': pd.Series(extracted_jsons).reset_index(drop=True)
-     })
-     categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
-
-     return categorized_data