cat-llm 0.0.62__py3-none-any.whl → 0.0.64__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cat-llm
3
- Version: 0.0.62
3
+ Version: 0.0.64
4
4
  Summary: A tool for categorizing text data and images using LLMs and vision models
5
5
  Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
6
6
  Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues
@@ -22,6 +22,7 @@ Requires-Python: >=3.8
22
22
  Requires-Dist: anthropic
23
23
  Requires-Dist: openai
24
24
  Requires-Dist: pandas
25
+ Requires-Dist: perplexity
25
26
  Requires-Dist: requests
26
27
  Requires-Dist: tqdm
27
28
  Description-Content-Type: text/markdown
@@ -0,0 +1,15 @@
1
+ catllm/CERAD_functions.py,sha256=q4HbP5e2Yu8NnZZ-2eX4sImyj6u3i8xWcq0pYU81iis,22676
2
+ catllm/__about__.py,sha256=ef_C266qfrp7mTd1dpTp_iodPNJXT6D5pQVsfdLEmB8,408
3
+ catllm/__init__.py,sha256=sf02zp7N0NW0mAQi7eQ4gliWR1EwoqvXkHN2HwwjcTE,372
4
+ catllm/build_web_research.py,sha256=880dfE2bEQb-FrXP-42JoLLtyc9ox_sBULDr38xiTiQ,22655
5
+ catllm/image_functions.py,sha256=8_FftRU285x1HT-AgNkaobefQVD-5q7ZY_t7JFdL3Sg,36177
6
+ catllm/text_functions.py,sha256=Jf51lNaFtcS2QGnNLkhM8rFVJSD4tN0Bm_VfELvb47g,18686
7
+ catllm/images/circle.png,sha256=JWujAWAh08-TajAoEr_TAeFNLlfbryOLw6cgIBREBuQ,86202
8
+ catllm/images/cube.png,sha256=nFec3e5bmRe4zrBCJ8QK-HcJLrG7u7dYdKhmdMfacfE,77275
9
+ catllm/images/diamond.png,sha256=rJDZKtsnBGRO8FPA0iHuA8FvHFGi9PkI_DWSFdw6iv0,99568
10
+ catllm/images/overlapping_pentagons.png,sha256=VO5plI6eoVRnjfqinn1nNzsCP2WQhuQy71V0EASouW4,71208
11
+ catllm/images/rectangles.png,sha256=2XM16HO9EYWj2yHgN4bPXaCwPfl7iYQy0tQUGaJX9xg,40692
12
+ cat_llm-0.0.64.dist-info/METADATA,sha256=KKM4UUyBk7ty5oG7IpVjnJvbF-vRbpC0tFTsAl8UdLo,22421
13
+ cat_llm-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
+ cat_llm-0.0.64.dist-info/licenses/LICENSE,sha256=Vje2sS5WV4TnIwY5uQHrF4qnBAM3YOk1pGpdH0ot-2o,34969
15
+ cat_llm-0.0.64.dist-info/RECORD,,
catllm/CERAD_functions.py CHANGED
@@ -44,6 +44,8 @@ def cerad_drawn_score(
44
44
  from pathlib import Path
45
45
  import pkg_resources
46
46
 
47
+ model_source = model_source.lower() # eliminating case sensitivity
48
+
47
49
  shape = shape.lower()
48
50
  shape = "rectangles" if shape == "overlapping rectangles" else shape
49
51
  if shape == "circle":
@@ -155,7 +157,7 @@ def cerad_drawn_score(
155
157
  else:
156
158
  reference_text = f"Image is expected to show within it a drawing of a {shape}.\n\n"
157
159
 
158
- if model_source == "OpenAI" and valid_image:
160
+ if model_source == "openai" and valid_image:
159
161
  prompt = [
160
162
  {
161
163
  "type": "text",
@@ -185,7 +187,7 @@ def cerad_drawn_score(
185
187
  "image_url": {"url": encoded_image, "detail": "high"}
186
188
  })
187
189
 
188
- elif model_source == "Anthropic" and valid_image:
190
+ elif model_source == "anthropic" and valid_image:
189
191
  prompt = [
190
192
  {
191
193
  "type": "text",
@@ -225,7 +227,7 @@ def cerad_drawn_score(
225
227
  }
226
228
  )
227
229
 
228
- elif model_source == "Mistral" and valid_image:
230
+ elif model_source == "mistral" and valid_image:
229
231
  prompt = [
230
232
  {
231
233
  "type": "text",
@@ -254,7 +256,7 @@ def cerad_drawn_score(
254
256
  "image_url": f"data:image/{ext};base64,{encoded_image}"
255
257
  })
256
258
 
257
- if model_source == "OpenAI" and valid_image:
259
+ if model_source == "openai" and valid_image:
258
260
  from openai import OpenAI
259
261
  client = OpenAI(api_key=api_key)
260
262
  try:
@@ -272,7 +274,7 @@ def cerad_drawn_score(
272
274
  print("An error occurred: {e}")
273
275
  link1.append("Error processing input: {e}")
274
276
 
275
- elif model_source == "Anthropic" and valid_image:
277
+ elif model_source == "anthropic" and valid_image:
276
278
  import anthropic
277
279
  client = anthropic.Anthropic(api_key=api_key)
278
280
  try:
@@ -291,7 +293,7 @@ def cerad_drawn_score(
291
293
  print("An error occurred: {e}")
292
294
  link1.append("Error processing input: {e}")
293
295
 
294
- elif model_source == "Mistral" and valid_image:
296
+ elif model_source == "mistral" and valid_image:
295
297
  from mistralai import Mistral
296
298
  reply = None
297
299
  client = Mistral(api_key=api_key)
catllm/__about__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  # SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
2
2
  #
3
3
  # SPDX-License-Identifier: MIT
4
- __version__ = "0.0.62"
4
+ __version__ = "0.0.64"
5
5
  __author__ = "Chris Soria"
6
6
  __email__ = "chrissoria@berkeley.edu"
7
7
  __title__ = "cat-llm"
@@ -5,39 +5,78 @@ def build_web_research_dataset(
5
5
  api_key,
6
6
  answer_format = "concise",
7
7
  additional_instructions = "",
8
- categories = ['Answer','URL'],
8
+ categories = ['Answer'],
9
9
  user_model="claude-sonnet-4-20250514",
10
10
  creativity=None,
11
11
  safety=False,
12
12
  filename="categorized_data.csv",
13
13
  save_directory=None,
14
14
  model_source="Anthropic",
15
- time_delay=15
15
+ start_date=None,
16
+ end_date=None,
17
+ search_depth="", #enables Tavily searches
18
+ tavily_api=None,
19
+ output_urls = True,
20
+ max_retries = 6, #API rate limit error handler retries
21
+ time_delay=5
16
22
  ):
17
23
  import os
24
+ import re
18
25
  import json
19
26
  import pandas as pd
20
27
  import regex
21
28
  from tqdm import tqdm
22
29
  import time
30
+ from datetime import datetime
31
+
32
+ #ensures proper date format
33
+ def _validate_date(date_str):
34
+ """Validates YYYY-MM-DD format"""
35
+ if date_str is None:
36
+ return True # None is acceptable (means no date constraint)
37
+
38
+ if not isinstance(date_str, str):
39
+ return False
40
+
41
+ # Check pattern: YYYY-MM-DD
42
+ pattern = r'^\d{4}-\d{2}-\d{2}$'
43
+ if not re.match(pattern, date_str):
44
+ return False
45
+
46
+ # Validate actual date
47
+ try:
48
+ year, month, day = date_str.split('-')
49
+ datetime(int(year), int(month), int(day))
50
+ return True
51
+ except (ValueError, OverflowError):
52
+ return False
53
+
54
+ # Validate dates at the start of the function
55
+ if not _validate_date(start_date):
56
+ raise ValueError(f"start_date must be in YYYY-MM-DD format, got: {start_date}")
57
+
58
+ if not _validate_date(end_date):
59
+ raise ValueError(f"end_date must be in YYYY-MM-DD format, got: {end_date}")
23
60
 
24
61
  model_source = model_source.lower() # eliminating case sensitivity
62
+
63
+ if model_source == "perplexity" and start_date is not None:
64
+ start_date = datetime.strptime(start_date, "%Y-%m-%d").strftime("%m/%d/%Y")
65
+ if model_source == "perplexity" and end_date is not None:
66
+ end_date = datetime.strptime(end_date, "%Y-%m-%d").strftime("%m/%d/%Y")
67
+
68
+ # in case user switches to google but doesn't switch model
69
+ if model_source == "google" and user_model == "claude-sonnet-4-20250514":
70
+ user_model = "gemini-2.5-flash"
25
71
 
26
72
  categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
27
- print(categories_str)
28
73
  cat_num = len(categories)
29
74
  category_dict = {str(i+1): "0" for i in range(cat_num)}
30
75
  example_JSON = json.dumps(category_dict, indent=4)
31
-
32
- # ensure number of categories is what user wants
33
- #print("\nThe information to be extracted:")
34
- #for i, cat in enumerate(categories, 1):
35
- #print(f"{i}. {cat}")
36
76
 
37
77
  link1 = []
38
78
  extracted_jsons = []
39
-
40
- max_retries = 5 #API rate limit error handler retries
79
+ extracted_urls = []
41
80
 
42
81
  for idx, item in enumerate(tqdm(search_input, desc="Building dataset")):
43
82
  if idx > 0: # Skip delay for first item only
@@ -46,9 +85,9 @@ def build_web_research_dataset(
46
85
 
47
86
  if pd.isna(item):
48
87
  link1.append("Skipped NaN input")
88
+ extracted_urls.append([])
49
89
  default_json = example_JSON
50
90
  extracted_jsons.append(default_json)
51
- #print(f"Skipped NaN input.")
52
91
  else:
53
92
  prompt = f"""<role>You are a research assistant specializing in finding current, factual information.</role>
54
93
 
@@ -59,7 +98,6 @@ def build_web_research_dataset(
59
98
  - Provide your answer as {answer_format}
60
99
  - Prioritize official sources when possible
61
100
  - If information is not found, state "Information not found"
62
- - Include exactly one source URL where you found the information
63
101
  - Do not include any explanatory text or commentary beyond the JSON
64
102
  {additional_instructions}
65
103
  </rules>
@@ -68,14 +106,74 @@ def build_web_research_dataset(
68
106
  Return your response as valid JSON with this exact structure:
69
107
  {{
70
108
  "answer": "Your factual answer or 'Information not found'",
71
- "url": "Source URL or 'No source available'"
109
+ "second_best_answer": "Your second best factual answer or 'Information not found'",
110
+ "confidence": "confidence in response 0-5 or 'Information not found'"
72
111
  }}
112
+
73
113
  </format>"""
74
- #print(prompt)
75
- if model_source == "anthropic":
114
+
115
+ if start_date is not None and end_date is not None:
116
+ append_text = f"\n- Focus on webpages with a page age between {start_date} and {end_date}."
117
+ prompt = prompt.replace("<rules>", "<rules>" + append_text)
118
+ elif start_date is not None:
119
+ append_text = f"\n- Focus on webpages published after {start_date}."
120
+ prompt = prompt.replace("<rules>", "<rules>" + append_text)
121
+ elif end_date is not None:
122
+ append_text = f"\n- Focus on webpages published before {end_date}."
123
+ prompt = prompt.replace("<rules>", "<rules>" + append_text)
124
+
125
+ if search_depth == "advanced" and model_source != "perplexity":
126
+ try:
127
+ from tavily import TavilyClient
128
+ tavily_client = TavilyClient(tavily_api)
129
+ tavily_response = tavily_client.search(
130
+ query=f"{item}'s {search_question}",
131
+ include_answer=True,
132
+ max_results=15,
133
+ search_depth="advanced",
134
+ **({"start_date": start_date} if start_date is not None else {}),
135
+ **({"end_date": end_date} if end_date is not None else {})
136
+ )
137
+
138
+ urls = [
139
+ result['url']
140
+ for result in tavily_response.get('results', [])
141
+ if 'url' in result
142
+ ]
143
+ seen = set()
144
+ urls = [u for u in urls if not (u in seen or seen.add(u))]
145
+ extracted_urls.append(urls)
146
+
147
+ except Exception as e:
148
+ error_msg = str(e).lower()
149
+ if "unauthorized" in error_msg or "403" in error_msg or "401" in error_msg or "api_key" in error_msg:
150
+ raise ValueError("ERROR: Invalid or missing tavily_api required for advanced search. Get one at https://app.tavily.com/home. To install: pip install tavily-python") from e
151
+ else:
152
+ print(f"Tavily search error: {e}")
153
+ link1.append(f"Error with Tavily search: {e}")
154
+ extracted_urls.append([])
155
+ continue
156
+
157
+ #print(tavily_response)
158
+
159
+ advanced_prompt = f"""Based on the following search results about {item}'s {search_question}, provide your answer in this EXACT JSON format and {answer_format}:
160
+ If you can't find the information, respond with 'Information not found'.
161
+ {{"answer": "your answer here or 'Information not found'",
162
+ "second_best_answer": "your second best answer here or 'Information not found'",
163
+ "confidence": "confidence in response 0-5 or 'Information not found'"}}
164
+
165
+ Search results:
166
+ {tavily_response}
167
+
168
+ Additional context from sources:
169
+ {chr(10).join([f"- {r.get('title', '')}: {r.get('content', '')}" for r in tavily_response.get('results', [])[:3]])}
170
+
171
+ Return ONLY the JSON object, no other text."""
172
+
173
+ if model_source == "anthropic" and search_depth != "advanced":
76
174
  import anthropic
77
175
  client = anthropic.Anthropic(api_key=api_key)
78
-
176
+ #print(prompt)
79
177
  attempt = 0
80
178
  while attempt < max_retries:
81
179
  try:
@@ -95,8 +193,21 @@ def build_web_research_dataset(
95
193
  if getattr(block, "type", "") == "text"
96
194
  ).strip()
97
195
  link1.append(reply)
196
+
197
+ urls = [
198
+ item["url"]
199
+ for block in message.content
200
+ if getattr(block, "type", "") == "web_search_tool_result"
201
+ for item in (getattr(block, "content", []) or [])
202
+ if isinstance(item, dict) and item.get("type") == "web_search_result" and "url" in item
203
+ ]
204
+
205
+ seen = set()
206
+ urls = [u for u in urls if not (u in seen or seen.add(u))]
207
+ extracted_urls.append(urls)
208
+
98
209
  break
99
- except anthropic.error.RateLimitError as e:
210
+ except anthropic.RateLimitError as e:
100
211
  wait_time = 2 ** attempt # Exponential backoff, keeps doubling after each attempt
101
212
  print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
102
213
  time.sleep(wait_time) #in case user wants to try and buffer the amount of errors by adding a wait time before attemps
@@ -104,11 +215,59 @@ def build_web_research_dataset(
104
215
  except Exception as e:
105
216
  print(f"A Non-rate-limit error occurred: {e}")
106
217
  link1.append(f"Error processing input: {e}")
218
+ extracted_urls.append([])
107
219
  break #stop retrying
108
220
  else:
109
221
  link1.append("Max retries exceeded for rate limit errors.")
222
+ extracted_urls.append([])
110
223
 
111
- elif model_source == "google":
224
+ elif model_source == "anthropic" and search_depth == "advanced":
225
+ import anthropic
226
+ claude_client = anthropic.Anthropic(api_key=api_key)
227
+
228
+ attempt = 0
229
+ while attempt < max_retries:
230
+ try:
231
+ message = claude_client.messages.create(
232
+ model=user_model,
233
+ max_tokens=1024,
234
+ messages=[{"role": "user", "content": advanced_prompt}],
235
+ **({"temperature": creativity} if creativity is not None else {})
236
+ )
237
+
238
+ reply = " ".join(
239
+ block.text
240
+ for block in message.content
241
+ if getattr(block, "type", "") == "text"
242
+ ).strip()
243
+
244
+ try:
245
+ import json
246
+ json_response = json.loads(reply)
247
+ final_answer = json_response.get('answer', reply)
248
+ link1.append(final_answer)
249
+ except json.JSONDecodeError:
250
+
251
+ print(f"JSON parse error, using raw reply: {reply}")
252
+ link1.append(reply)
253
+
254
+ break # Success
255
+
256
+ except anthropic.RateLimitError as e:
257
+ wait_time = 2 ** attempt
258
+ print(f"Rate limit error encountered. Retrying in {wait_time} seconds...")
259
+ time.sleep(wait_time)
260
+ attempt += 1
261
+
262
+ except Exception as e:
263
+ print(f"A Non-rate-limit error occurred: {e}")
264
+ link1.append(f"Error processing input: {e}")
265
+ break
266
+ else:
267
+ # Max retries exceeded
268
+ link1.append("Max retries exceeded for rate limit errors.")
269
+
270
+ elif model_source == "google" and search_depth != "advanced":
112
271
  import requests
113
272
  url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
114
273
  try:
@@ -121,11 +280,62 @@ def build_web_research_dataset(
121
280
  "tools": [{"google_search": {}}],
122
281
  **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
123
282
  }
124
-
283
+
125
284
  response = requests.post(url, headers=headers, json=payload)
126
285
  response.raise_for_status()
127
286
  result = response.json()
287
+
288
+ urls = []
289
+ for cand in result.get("candidates", []):
290
+ rendered_html = (
291
+ cand.get("groundingMetadata", {})
292
+ .get("searchEntryPoint", {})
293
+ .get("renderedContent", "")
294
+ )
295
+ if rendered_html:
296
+ # regex: capture href="..."; limited to class="chip"
297
+ found = re.findall(
298
+ r'<a[^>]*class=["\']chip["\'][^>]*href=["\']([^"\']+)["\']',
299
+ rendered_html,
300
+ flags=re.IGNORECASE
301
+ )
302
+ urls.extend(found)
303
+
304
+ seen = set()
305
+ urls = [u for u in urls if not (u in seen or seen.add(u))]
306
+ extracted_urls.append(urls)
307
+
308
+ # extract reply from Google's response structure
309
+ if "candidates" in result and result["candidates"]:
310
+ reply = result["candidates"][0]["content"]["parts"][0]["text"]
311
+ else:
312
+ reply = "No response generated"
313
+
314
+ link1.append(reply)
315
+
316
+ except Exception as e:
317
+ print(f"An error occurred: {e}")
318
+ link1.append(f"Error processing input: {e}")
319
+ extracted_urls.append([])
320
+
321
+ elif model_source == "google" and search_depth == "advanced":
322
+ import requests
323
+ url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
324
+ try:
325
+ headers = {
326
+ "x-goog-api-key": api_key,
327
+ "Content-Type": "application/json"
328
+ }
128
329
 
330
+ payload = {
331
+ "contents": [{"parts": [{"text": advanced_prompt}]}],
332
+ **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
333
+ }
334
+
335
+ response = requests.post(url, headers=headers, json=payload)
336
+ response.raise_for_status()
337
+ result = response.json()
338
+
129
339
  # extract reply from Google's response structure
130
340
  if "candidates" in result and result["candidates"]:
131
341
  reply = result["candidates"][0]["content"]["parts"][0]["text"]
@@ -138,6 +348,54 @@ def build_web_research_dataset(
138
348
  print(f"An error occurred: {e}")
139
349
  link1.append(f"Error processing input: {e}")
140
350
 
351
+ elif model_source == "perplexity":
352
+
353
+ from perplexity import Perplexity
354
+ client = Perplexity(api_key=api_key)
355
+ try:
356
+ response = client.chat.completions.create(
357
+ messages=[
358
+ {
359
+ "role": "user",
360
+ "content": prompt
361
+ }
362
+ ],
363
+ model=user_model,
364
+ max_tokens=1024,
365
+ **({"temperature": creativity} if creativity is not None else {}),
366
+ web_search_options={"search_context_size": "high" if search_depth == "advanced" else "medium"},
367
+ **({"search_after_date_filter": start_date} if start_date else {}),
368
+ **({"search_before_date_filter": end_date} if end_date else {}),
369
+ response_format={ #requiring a JSON
370
+ "type": "json_schema",
371
+ "json_schema": {
372
+ "schema": {
373
+ "type": "object",
374
+ "properties": {
375
+ "answer": {"type": "string"},
376
+ "second_best_answer": {"type": "string"},
377
+ "confidence": {"type": "integer"}
378
+ },
379
+ "required": ["answer", "second_best_answer"]
380
+ }
381
+ }
382
+ }
383
+ )
384
+
385
+ reply = response.choices[0].message.content
386
+ #print(response)
387
+ link1.append(reply)
388
+
389
+ urls = list(response.citations) if hasattr(response, 'citations') else []
390
+
391
+ seen = set()
392
+ urls = [u for u in urls if not (u in seen or seen.add(u))]
393
+ extracted_urls.append(urls)
394
+
395
+ except Exception as e:
396
+ print(f"An error occurred: {e}")
397
+ link1.append(f"Error processing input: {e}")
398
+ extracted_urls.append([])
141
399
  else:
142
400
  raise ValueError("Unknown source! Currently this function only supports 'Anthropic' or 'Google' as model_source.")
143
401
  # in situation that no JSON is found
@@ -157,12 +415,12 @@ def build_web_research_dataset(
157
415
  extracted_jsons.append(raw_json)
158
416
  else:
159
417
  # Use consistent schema for errors
160
- error_message = json.dumps({"answer": "e", "url": "e"})
418
+ error_message = json.dumps({"answer": "e"})
161
419
  extracted_jsons.append(error_message)
162
420
  print(error_message)
163
421
  else:
164
422
  # Handle None reply case
165
- error_message = json.dumps({"answer": "e", "url": "e"})
423
+ error_message = json.dumps({"answer": "e"})
166
424
  extracted_jsons.append(error_message)
167
425
  #print(error_message)
168
426
 
@@ -170,9 +428,9 @@ def build_web_research_dataset(
170
428
  if safety:
171
429
  # Save progress so far
172
430
  temp_df = pd.DataFrame({
173
- 'survey_response': search_input[:idx+1],
174
- 'model_response': link1,
175
- 'json': extracted_jsons
431
+ 'raw_response': search_input[:idx+1]
432
+ #'model_response': link1,
433
+ #'json': extracted_jsons
176
434
  })
177
435
  # Normalize processed jsons so far
178
436
  normalized_data_list = []
@@ -181,9 +439,11 @@ def build_web_research_dataset(
181
439
  parsed_obj = json.loads(json_str)
182
440
  normalized_data_list.append(pd.json_normalize(parsed_obj))
183
441
  except json.JSONDecodeError:
184
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
442
+ normalized_data_list.append(pd.DataFrame({"answer": ["e"]}))
185
443
  normalized_data = pd.concat(normalized_data_list, ignore_index=True)
444
+ temp_urls = pd.DataFrame(extracted_urls).add_prefix("url_")
186
445
  temp_df = pd.concat([temp_df, normalized_data], axis=1)
446
+ temp_df = pd.concat([temp_df, temp_urls], axis=1)
187
447
  # Save to CSV
188
448
  if save_directory is None:
189
449
  save_directory = os.getcwd()
@@ -196,17 +456,37 @@ def build_web_research_dataset(
196
456
  parsed_obj = json.loads(json_str)
197
457
  normalized_data_list.append(pd.json_normalize(parsed_obj))
198
458
  except json.JSONDecodeError:
199
- normalized_data_list.append(pd.DataFrame({"1": ["e"]}))
459
+ normalized_data_list.append(pd.DataFrame({"answer": ["e"]}))
200
460
  normalized_data = pd.concat(normalized_data_list, ignore_index=True)
201
461
 
462
+ # converting urls to dataframe and adding prefix
463
+ df_urls = pd.DataFrame(extracted_urls).add_prefix("url_")
464
+
202
465
  categorized_data = pd.DataFrame({
203
- 'survey_response': (
466
+ 'search_input': (
204
467
  search_input.reset_index(drop=True) if isinstance(search_input, (pd.DataFrame, pd.Series))
205
468
  else pd.Series(search_input)
206
469
  ),
207
- 'link1': pd.Series(link1).reset_index(drop=True),
208
- 'json': pd.Series(extracted_jsons).reset_index(drop=True)
470
+ 'raw_response': pd.Series(link1).reset_index(drop=True),
471
+ #'json': pd.Series(extracted_jsons).reset_index(drop=True),
472
+ #"all_urls": pd.Series(extracted_urls).reset_index(drop=True)
209
473
  })
210
474
  categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
475
+ categorized_data = pd.concat([categorized_data, df_urls], axis=1)
476
+
477
+ # drop second best answer column if it exists
478
+ # we only ask for the second best answer to "force" the model to think more carefully about its best answer, but we don't actually need to keep it
479
+ categorized_data = categorized_data.drop(columns=["second_best_answer"], errors='ignore')
480
+
481
+ # dropping this column for advanced searches (this column is mostly useful for basic searches to see what the model saw)
482
+ if search_depth == "advanced":
483
+ categorized_data = categorized_data.drop(columns=["raw_response"], errors='ignore')
484
+
485
+ #for users who don't want the urls included in the final dataframe
486
+ if output_urls is False:
487
+ categorized_data = categorized_data.drop(columns=[col for col in categorized_data.columns if col.startswith("url_")])
488
+
489
+ if save_directory is not None:
490
+ categorized_data.to_csv(os.path.join(save_directory, filename), index=False)
211
491
 
212
492
  return categorized_data
catllm/image_functions.py CHANGED
@@ -33,6 +33,8 @@ def image_multi_class(
33
33
  '*.psd'
34
34
  ]
35
35
 
36
+ model_source = model_source.lower() # eliminating case sensitivity
37
+
36
38
  if not isinstance(image_input, list):
37
39
  # If image_input is a filepath (string)
38
40
  image_files = []
@@ -86,7 +88,7 @@ def image_multi_class(
86
88
 
87
89
  # Handle extension safely
88
90
  ext = Path(img_path).suffix.lstrip(".").lower()
89
- if model_source == "OpenAI" or model_source == "Mistral":
91
+ if model_source == "openai" or model_source == "mistral":
90
92
  encoded_image = f"data:image/{ext};base64,{encoded}"
91
93
  prompt = [
92
94
  {
@@ -110,7 +112,7 @@ def image_multi_class(
110
112
  },
111
113
  ]
112
114
 
113
- elif model_source == "Anthropic":
115
+ elif model_source == "anthropic":
114
116
  encoded_image = f"data:image/{ext};base64,{encoded}"
115
117
  prompt = [
116
118
  {"type": "text",
@@ -136,7 +138,7 @@ def image_multi_class(
136
138
  }
137
139
  }
138
140
  ]
139
- if model_source == "OpenAI":
141
+ if model_source == "openAI":
140
142
  from openai import OpenAI
141
143
  client = OpenAI(api_key=api_key)
142
144
  try:
@@ -154,7 +156,7 @@ def image_multi_class(
154
156
  print("An error occurred: {e}")
155
157
  link1.append("Error processing input: {e}")
156
158
 
157
- elif model_source == "Anthropic":
159
+ elif model_source == "anthropic":
158
160
  import anthropic
159
161
  reply = None
160
162
  client = anthropic.Anthropic(api_key=api_key)
@@ -174,7 +176,7 @@ def image_multi_class(
174
176
  print("An error occurred: {e}")
175
177
  link1.append("Error processing input: {e}")
176
178
 
177
- elif model_source == "Mistral":
179
+ elif model_source == "mistral":
178
180
  from mistralai import Mistral
179
181
  client = Mistral(api_key=api_key)
180
182
  try:
@@ -305,6 +307,8 @@ def image_score_drawing(
305
307
  '*.psd'
306
308
  ]
307
309
 
310
+ model_source = model_source.lower() # eliminating case sensitivity
311
+
308
312
  if not isinstance(image_input, list):
309
313
  # If image_input is a filepath (string)
310
314
  image_files = []
@@ -354,7 +358,7 @@ def image_score_drawing(
354
358
  ext = Path(img_path).suffix.lstrip(".").lower()
355
359
  encoded_image = f"data:image/{ext};base64,{encoded}"
356
360
 
357
- if model_source == "OpenAI" or model_source == "Mistral":
361
+ if model_source == "openai" or model_source == "mistral":
358
362
  prompt = [
359
363
  {
360
364
  "type": "text",
@@ -390,7 +394,7 @@ def image_score_drawing(
390
394
  }
391
395
  ]
392
396
 
393
- elif model_source == "Anthropic": # Changed to elif
397
+ elif model_source == "anthropic": # Changed to elif
394
398
  prompt = [
395
399
  {
396
400
  "type": "text",
@@ -435,7 +439,7 @@ def image_score_drawing(
435
439
  ]
436
440
 
437
441
 
438
- if model_source == "OpenAI":
442
+ if model_source == "openai":
439
443
  from openai import OpenAI
440
444
  client = OpenAI(api_key=api_key)
441
445
  try:
@@ -453,7 +457,7 @@ def image_score_drawing(
453
457
  print("An error occurred: {e}")
454
458
  link1.append("Error processing input: {e}")
455
459
 
456
- elif model_source == "Anthropic":
460
+ elif model_source == "anthropic":
457
461
  import anthropic
458
462
  client = anthropic.Anthropic(api_key=api_key)
459
463
  try:
@@ -472,7 +476,7 @@ def image_score_drawing(
472
476
  print("An error occurred: {e}")
473
477
  link1.append("Error processing input: {e}")
474
478
 
475
- elif model_source == "Mistral":
479
+ elif model_source == "mistral":
476
480
  from mistralai import Mistral
477
481
  client = Mistral(api_key=api_key)
478
482
  try:
@@ -598,6 +602,8 @@ def image_features(
598
602
  '*.psd'
599
603
  ]
600
604
 
605
+ model_source = model_source.lower() # eliminating case sensitivity
606
+
601
607
  if not isinstance(image_input, list):
602
608
  # If image_input is a filepath (string)
603
609
  image_files = []
@@ -644,7 +650,7 @@ def image_features(
644
650
  encoded_image = f"data:image/{ext};base64,{encoded}"
645
651
  valid_image = True
646
652
 
647
- if model_source == "OpenAI" or model_source == "Mistral":
653
+ if model_source == "openai" or model_source == "mistral":
648
654
  prompt = [
649
655
  {
650
656
  "type": "text",
@@ -674,7 +680,7 @@ def image_features(
674
680
  "image_url": {"url": encoded_image, "detail": "high"},
675
681
  },
676
682
  ]
677
- elif model_source == "Anthropic":
683
+ elif model_source == "anthropic":
678
684
  prompt = [
679
685
  {
680
686
  "type": "text",
@@ -708,7 +714,7 @@ def image_features(
708
714
  }
709
715
  }
710
716
  ]
711
- if model_source == "OpenAI":
717
+ if model_source == "openai":
712
718
  from openai import OpenAI
713
719
  client = OpenAI(api_key=api_key)
714
720
  try:
@@ -726,7 +732,7 @@ def image_features(
726
732
  print("An error occurred: {e}")
727
733
  link1.append("Error processing input: {e}")
728
734
 
729
- elif model_source == "Perplexity":
735
+ elif model_source == "perplexity":
730
736
  from openai import OpenAI
731
737
  client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
732
738
  try:
@@ -744,7 +750,7 @@ def image_features(
744
750
  print("An error occurred: {e}")
745
751
  link1.append("Error processing input: {e}")
746
752
 
747
- elif model_source == "Anthropic":
753
+ elif model_source == "anthropic":
748
754
  import anthropic
749
755
  client = anthropic.Anthropic(api_key=api_key)
750
756
  try:
@@ -763,7 +769,7 @@ def image_features(
763
769
  print("An error occurred: {e}")
764
770
  link1.append("Error processing input: {e}")
765
771
 
766
- elif model_source == "Mistral":
772
+ elif model_source == "mistral":
767
773
  from mistralai import Mistral
768
774
  client = Mistral(api_key=api_key)
769
775
  try:
catllm/text_functions.py CHANGED
@@ -22,6 +22,8 @@ def explore_corpus(
22
22
  print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted.")
23
23
  print()
24
24
 
25
+ model_source = model_source.lower() # eliminating case sensitivity
26
+
25
27
  chunk_size = round(max(1, len(survey_input) / divisions),0)
26
28
  chunk_size = int(chunk_size)
27
29
 
@@ -46,7 +48,7 @@ Responses are each separated by a semicolon. \
46
48
  Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
47
49
  Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
48
50
 
49
- if model_source == "OpenAI":
51
+ if model_source == "openai":
50
52
  client = OpenAI(api_key=api_key)
51
53
  try:
52
54
  response_obj = client.chat.completions.create(
@@ -123,6 +125,8 @@ def explore_common_categories(
123
125
  print(f"Exploring class for question: '{survey_question}'.\n {cat_num * divisions} unique categories to be extracted and {top_n} to be identified as the most common.")
124
126
  print()
125
127
 
128
+ model_source = model_source.lower() # eliminating case sensitivity
129
+
126
130
  chunk_size = round(max(1, len(survey_input) / divisions),0)
127
131
  chunk_size = int(chunk_size)
128
132
 
@@ -147,7 +151,7 @@ Responses are each separated by a semicolon. \
147
151
  Responses are contained within triple backticks here: ```{survey_participant_chunks}``` \
148
152
  Number your categories from 1 through {cat_num} and be concise with the category labels and provide no description of the categories."""
149
153
 
150
- if model_source == "OpenAI":
154
+ if model_source == "openai":
151
155
  client = OpenAI(api_key=api_key)
152
156
  try:
153
157
  response_obj = client.chat.completions.create(
@@ -198,7 +202,7 @@ Number your categories from 1 through {cat_num} and be concise with the category
198
202
  The categories are contained within triple backticks here: ```{df['Category'].tolist()}``` \
199
203
  Return the top {top_n} categories as a numbered list sorted from the most to least common and keep the categories {specificity}, with no additional text or explanation."""
200
204
 
201
- if model_source == "OpenAI":
205
+ if model_source == "openai":
202
206
  client = OpenAI(api_key=api_key)
203
207
  response_obj = client.chat.completions.create(
204
208
  model=user_model,
@@ -237,6 +241,8 @@ def multi_class(
237
241
  import pandas as pd
238
242
  import regex
239
243
  from tqdm import tqdm
244
+
245
+ model_source = model_source.lower() # eliminating case sensitivity
240
246
 
241
247
  categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
242
248
  cat_num = len(categories)
@@ -265,7 +271,7 @@ Categorize this survey response "{response}" into the following categories that
265
271
  {categories_str} \
266
272
  Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
267
273
  #print(prompt)
268
- if model_source == ("OpenAI"):
274
+ if model_source == ("openai"):
269
275
  from openai import OpenAI
270
276
  client = OpenAI(api_key=api_key)
271
277
  try:
@@ -279,7 +285,7 @@ Provide your work in JSON format where the number belonging to each category is
279
285
  except Exception as e:
280
286
  print(f"An error occurred: {e}")
281
287
  link1.append(f"Error processing input: {e}")
282
- elif model_source == "Perplexity":
288
+ elif model_source == "perplexity":
283
289
  from openai import OpenAI
284
290
  client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
285
291
  try:
@@ -293,7 +299,7 @@ Provide your work in JSON format where the number belonging to each category is
293
299
  except Exception as e:
294
300
  print(f"An error occurred: {e}")
295
301
  link1.append(f"Error processing input: {e}")
296
- elif model_source == "Anthropic":
302
+ elif model_source == "anthropic":
297
303
  import anthropic
298
304
  client = anthropic.Anthropic(api_key=api_key)
299
305
  try:
@@ -309,7 +315,7 @@ Provide your work in JSON format where the number belonging to each category is
309
315
  print(f"An error occurred: {e}")
310
316
  link1.append(f"Error processing input: {e}")
311
317
 
312
- elif model_source == "Google":
318
+ elif model_source == "google":
313
319
  import requests
314
320
  url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
315
321
  try:
@@ -339,7 +345,7 @@ Provide your work in JSON format where the number belonging to each category is
339
345
  print(f"An error occurred: {e}")
340
346
  link1.append(f"Error processing input: {e}")
341
347
 
342
- elif model_source == "Mistral":
348
+ elif model_source == "mistral":
343
349
  from mistralai import Mistral
344
350
  client = Mistral(api_key=api_key)
345
351
  try:
@@ -1,15 +0,0 @@
1
- catllm/CERAD_functions.py,sha256=05n7h27TuAp3klkOnrH--m1wMreYqYuObM9NIab934o,22603
2
- catllm/__about__.py,sha256=R0Mt1NOAMAQCF7SHD4XDl2P4gF92EnfjYXaJ1Xo0vdc,408
3
- catllm/__init__.py,sha256=sf02zp7N0NW0mAQi7eQ4gliWR1EwoqvXkHN2HwwjcTE,372
4
- catllm/build_web_research.py,sha256=nAKfkg7lihjXrYrLvltsKCvpb5zRFYpNp95A-0zpDb8,9159
5
- catllm/image_functions.py,sha256=KDb2UxDLrioerlqKZjKAX7lqfW-S_TSQCK6YxtJRKwI,35958
6
- catllm/text_functions.py,sha256=xfpwAYivnPnDlsU21vp1Pma9mDR24tn1lcBZQfsyIrc,18467
7
- catllm/images/circle.png,sha256=JWujAWAh08-TajAoEr_TAeFNLlfbryOLw6cgIBREBuQ,86202
8
- catllm/images/cube.png,sha256=nFec3e5bmRe4zrBCJ8QK-HcJLrG7u7dYdKhmdMfacfE,77275
9
- catllm/images/diamond.png,sha256=rJDZKtsnBGRO8FPA0iHuA8FvHFGi9PkI_DWSFdw6iv0,99568
10
- catllm/images/overlapping_pentagons.png,sha256=VO5plI6eoVRnjfqinn1nNzsCP2WQhuQy71V0EASouW4,71208
11
- catllm/images/rectangles.png,sha256=2XM16HO9EYWj2yHgN4bPXaCwPfl7iYQy0tQUGaJX9xg,40692
12
- cat_llm-0.0.62.dist-info/METADATA,sha256=jstvau7l_g2KqYSheIcZJxC8DX2Bf_lA_wLNzPO5-qw,22395
13
- cat_llm-0.0.62.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
- cat_llm-0.0.62.dist-info/licenses/LICENSE,sha256=Vje2sS5WV4TnIwY5uQHrF4qnBAM3YOk1pGpdH0ot-2o,34969
15
- cat_llm-0.0.62.dist-info/RECORD,,