cat-llm 0.0.67__py3-none-any.whl → 0.0.68__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/METADATA +1 -1
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/RECORD +7 -6
- catllm/__about__.py +1 -1
- catllm/model_reference_list.py +93 -0
- catllm/text_functions.py +426 -36
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/WHEEL +0 -0
- {cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/licenses/LICENSE +0 -0
{cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cat-llm
-Version: 0.0.67
+Version: 0.0.68
 Summary: A tool for categorizing text data and images using LLMs and vision models
 Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
 Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues

{cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/RECORD
CHANGED

@@ -1,15 +1,16 @@
 catllm/CERAD_functions.py,sha256=q4HbP5e2Yu8NnZZ-2eX4sImyj6u3i8xWcq0pYU81iis,22676
-catllm/__about__.py,sha256=
+catllm/__about__.py,sha256=i__BfDO7vMnAnfpLD_eRpelQtqul4YDGjfKMTP7PM3Y,430
 catllm/__init__.py,sha256=sf02zp7N0NW0mAQi7eQ4gliWR1EwoqvXkHN2HwwjcTE,372
 catllm/build_web_research.py,sha256=880dfE2bEQb-FrXP-42JoLLtyc9ox_sBULDr38xiTiQ,22655
 catllm/image_functions.py,sha256=8_FftRU285x1HT-AgNkaobefQVD-5q7ZY_t7JFdL3Sg,36177
-catllm/
+catllm/model_reference_list.py,sha256=bakqZinbGCyY_SBJJIyBLnosKvJuna6B6TWne7YHfC8,2202
+catllm/text_functions.py,sha256=nVXB6Z7_AYXZoXuApu5GoE4anSOtHz1Y6t71OBSyRQI,39408
 catllm/images/circle.png,sha256=JWujAWAh08-TajAoEr_TAeFNLlfbryOLw6cgIBREBuQ,86202
 catllm/images/cube.png,sha256=nFec3e5bmRe4zrBCJ8QK-HcJLrG7u7dYdKhmdMfacfE,77275
 catllm/images/diamond.png,sha256=rJDZKtsnBGRO8FPA0iHuA8FvHFGi9PkI_DWSFdw6iv0,99568
 catllm/images/overlapping_pentagons.png,sha256=VO5plI6eoVRnjfqinn1nNzsCP2WQhuQy71V0EASouW4,71208
 catllm/images/rectangles.png,sha256=2XM16HO9EYWj2yHgN4bPXaCwPfl7iYQy0tQUGaJX9xg,40692
-cat_llm-0.0.
-cat_llm-0.0.
-cat_llm-0.0.
-cat_llm-0.0.
+cat_llm-0.0.68.dist-info/METADATA,sha256=Z3SgZqkNNc8LW_eBc6Q8WmIXp0v2Qpjbjevz3kM4o_M,22423
+cat_llm-0.0.68.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+cat_llm-0.0.68.dist-info/licenses/LICENSE,sha256=Vje2sS5WV4TnIwY5uQHrF4qnBAM3YOk1pGpdH0ot-2o,34969
+cat_llm-0.0.68.dist-info/RECORD,,

catllm/__about__.py
CHANGED
catllm/model_reference_list.py
ADDED

@@ -0,0 +1,93 @@
+# openai list of models
+openai_models = [
+    "gpt-5",
+    "gpt-5-mini",
+    "gpt-5-nano",
+    "gpt-4o",
+    "gpt-4o-mini",
+    "gpt-4.1",
+    "gpt-4.1-mini",
+    "gpt-4.1-nano",
+    "gpt-3.5-turbo",
+    "text-davinci-003",
+    "text-davinci-002"
+]
+
+# anthropic list of models
+anthropic_models = [
+    "claude-opus-4-20250514-v1:0",
+    "claude-opus-4-1-20250805-v1:0",
+    "claude-sonnet-4-5-20250929-v1:0",
+    "claude-sonnet-4-20250514-v1:0",
+    "claude-3-7-sonnet-20250219-v1:0",
+    "claude-3-5-sonnet-20240620-v1:0",
+    "claude-3-5-haiku-20241022-v1:0",
+    "claude-3-opus-20240229-v1:0",
+    "claude-3-sonnet-20240229-v1:0",
+    "claude-haiku-4-5-20251001-v1:0",
+    "claude-sonnet-4-5-20250929",
+    "claude-haiku-4-5-20251001",
+    "claude-opus-4-1-20250805"
+]
+
+# google list of models
+
+google_models = [
+    "gemini-2.5-flash",
+    "gemini-2.5-flash-lite",
+    "gemini-2.5-pro",
+    "gemini-2.0-flash",
+    "gemini-2.0-flash-lite",
+    "gemini-2.0-pro",
+    "gemini-2.5",
+    "gemini-2.0",
+    "gemini-2.5-flash-preview",
+    "gemini-2.5-pro-preview"
+]
+
+# perplexity list of models
+
+perplexity_models = [
+    "sonar",
+    "sonar-pro",
+    "sonar-reasoning",
+    "sonar-reasoning-pro",
+    "sonar-deep-research",
+    "r1-1776"
+]
+
+# mistral list of models
+
+mistral_models = [
+    "mistral-large-latest",
+    "mistral-medium-2505",
+    "mistral-large-2411",
+    "codestral-2501",
+    "pixtral-large-2411",
+    "mistral-small-2407",
+    "mistral-embed",
+    "codestral-embed",
+    "mistral-moderation-2411",
+    "ministral-3b-2410",
+    "ministral-8b-2410"
+]
+
+# meta list of models
+meta_llama_models = [
+    "meta/llama-3.1-8b-instruct",
+    "meta/llama-3.1-70b-instruct",
+    "meta/llama-3.1-405b-instruct",
+    "meta/llama-3.2-11b-vision-instruct",
+    "meta/llama-3.2-90b-vision-instruct",
+    "meta/llama-3.3-70b-instruct",
+    "meta/llama-4-scout-17b-16e-instruct",
+    "meta/llama-4-maverick-17b-128e-instruct",
+    "llama-4-maverick-17b-128e-instruct-maas",
+    "llama-4-scout-17b-16e-instruct-maas",
+    "llama-3.3-70b-instruct-maas",
+    "llama-3.2-90b-vision-instruct-maas",
+    "llama-3.1-405b-instruct-maas",
+    "llama-3.1-70b-instruct-maas",
+    "llama-3.1-8b-instruct-maas",
+
+]
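
Note: the new catllm/model_reference_list.py module added above is just plain Python lists of model identifiers, one per provider. A minimal usage sketch, assuming the module is importable under the path shown in RECORD; the helper function is hypothetical and not part of cat-llm:

# Hypothetical usage sketch, not part of cat-llm
from catllm.model_reference_list import openai_models, anthropic_models, google_models

def reference_list_for(provider):
    # Map a provider label to the corresponding list from the new module
    lists = {"openai": openai_models, "anthropic": anthropic_models, "google": google_models}
    return lists.get(provider.lower(), [])

print("gpt-4o" in reference_list_for("OpenAI"))  # True
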
catllm/text_functions.py
CHANGED
@@ -232,12 +232,12 @@ Return the top {top_n} categories as a numbered list sorted from the most to lea
 # GOAL: enable step-back prompting
 # GOAL 2: enable self-consistency
 def multi_class(
-    survey_question,
     survey_input,
     categories,
     api_key,
     user_model="gpt-5",
     user_prompt = None,
+    survey_question = "",
     example1 = None,
     example2 = None,
     example3 = None,

@@ -247,9 +247,10 @@ def multi_class(
     creativity=None,
     safety=False,
     to_csv=False,
+    chain_of_verification=False,
     filename="categorized_data.csv",
     save_directory=None,
-    model_source="
+    model_source="auto"
 ):
     import os
     import json

@@ -257,7 +258,54 @@ def multi_class(
     import regex
     from tqdm import tqdm

+    def remove_numbering(line):
+        line = line.strip()
+
+        # Handle bullet points
+        if line.startswith('- '):
+            return line[2:].strip()
+        if line.startswith('• '):
+            return line[2:].strip()
+
+        # Handle numbered lists "1.", "10.", etc.
+        if line and line[0].isdigit():
+            # Find where the number ends
+            i = 0
+            while i < len(line) and line[i].isdigit():
+                i += 1
+
+            # Check if followed by '.' or ')'
+            if i < len(line) and line[i] in '.':
+                return line[i+1:].strip()
+            elif i < len(line) and line[i] in ')':
+                return line[i+1:].strip()
+
+        return line
+
     model_source = model_source.lower() # eliminating case sensitivity
+
+    # auto-detect model source if not provided
+    if model_source is None or model_source == "auto":
+        user_model_lower = user_model.lower()
+
+        if "gpt" in user_model_lower:
+            model_source = "openai"
+        elif "claude" in user_model_lower:
+            model_source = "anthropic"
+        elif "gemini" in user_model_lower or "gemma" in user_model_lower:
+            model_source = "google"
+        elif "llama" in user_model_lower or "meta" in user_model_lower:
+            model_source = "huggingface"
+        elif "mistral" in user_model_lower or "mixtral" in user_model_lower:
+            model_source = "mistral"
+        elif "sonar" in user_model_lower or "pplx" in user_model_lower:
+            model_source = "perplexity"
+        elif "deepseek" in user_model_lower or "qwen" in user_model_lower:
+            model_source = "huggingface"
+        else:
+            raise ValueError(f"❌ Could not auto-detect model source from '{user_model}'. Please specify model_source explicitly: OpenAI, Anthropic, Perplexity, Google, Huggingface, or Mistral")
+    else:
+        model_source = model_source.lower()

     categories_str = "\n".join(f"{i + 1}. {cat}" for i, cat in enumerate(categories))
     cat_num = len(categories)

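Note: the auto-detection added in this hunk is a substring match on the lowercased model name, evaluated in the order shown. A standalone restatement of that order for quick reference (illustrative only, not the package's code):

# Illustrative restatement of the detection order above, not part of cat-llm
def detect_source(user_model):
    m = user_model.lower()
    if "gpt" in m:
        return "openai"
    elif "claude" in m:
        return "anthropic"
    elif "gemini" in m or "gemma" in m:
        return "google"
    elif "llama" in m or "meta" in m:
        return "huggingface"
    elif "mistral" in m or "mixtral" in m:
        return "mistral"
    elif "sonar" in m or "pplx" in m:
        return "perplexity"
    elif "deepseek" in m or "qwen" in m:
        return "huggingface"
    raise ValueError(f"Could not auto-detect model source from '{user_model}'")

print(detect_source("gemini-2.5-flash"))  # google
print(detect_source("sonar-pro"))         # perplexity
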
@@ -265,17 +313,23 @@ def multi_class(
     example_JSON = json.dumps(category_dict, indent=4)

     # ensure number of categories is what user wants
-    print("\nThe categories you entered:")
+    print(f"\nThe categories you entered to be coded by {model_source} {user_model}:")
     for i, cat in enumerate(categories, 1):
         print(f"{i}. {cat}")

     link1 = []
     extracted_jsons = []
+
     #handling example inputs
     examples = [example1, example2, example3, example4, example5, example6]
     examples_text = "\n".join(
         f"Example {i}: {ex}" for i, ex in enumerate(examples, 1) if ex is not None
     )
+    # allowing users to contextualize the survey question
+    if survey_question != None:
+        survey_question_context = f"A respondent was asked: {survey_question}."
+    else:
+        survey_question_context = ""

     for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
         reply = None

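Note: with this hunk, survey_question becomes optional and, when provided, is folded into the prompt as a leading context sentence. A quick illustration of the string that gets built (illustrative only):

# Illustrative only: how survey_question_context is assembled above
survey_question = "What do you do to stay healthy?"
survey_question_context = f"A respondent was asked: {survey_question}." if survey_question != None else ""
print(survey_question_context)  # A respondent was asked: What do you do to stay healthy?.
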
@@ -287,58 +341,248 @@ def multi_class(
             #print(f"Skipped NaN input.")
         else:

-            prompt = f"""
+            prompt = f"""{survey_question_context} \
 Categorize this survey response "{response}" into the following categories that apply: \
 {categories_str}
 {examples_text}
-Provide your work in JSON format
+Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""
+
+            if chain_of_verification:
+                step2_prompt = f"""You provided this initial categorization:
+<<INITIAL_REPLY>>
+
+Original task: {prompt}
+
+Generate a focused list of 3-5 verification questions to fact-check your categorization. Each question should:
+- Be concise and specific (one sentence)
+- Address a distinct aspect of the categorization
+- Be answerable independently
+
+Focus on verifying:
+- Whether each category assignment is accurate
+- Whether the categories match the criteria in the original task
+- Whether there are any logical inconsistencies
+
+Provide only the verification questions as a numbered list."""
+
+                step3_prompt = f"""Answer the following verification question based on the survey response provided.
+
+Survey response: {response}
+
+Verification question: <<QUESTION>>

-
+Provide a brief, direct answer (1-2 sentences maximum).
+
+Answer:"""
+
+
+                step4_prompt = f"""Original task: {prompt}
+Initial categorization:
+<<INITIAL_REPLY>>
+Verification questions and answers:
+<<VERIFICATION_QA>>
+If no categories are present, assign "0" to all categories.
+Provide the final corrected categorization in the same JSON format:"""
+
+
+            if model_source in ["openai", "perplexity", "huggingface"]:
                 from openai import OpenAI
-
+                from openai import OpenAI, BadRequestError, AuthenticationError
+                # conditional base_url setting based on model source
+                base_url = (
+                    "https://api.perplexity.ai" if model_source == "perplexity"
+                    else "https://router.huggingface.co/v1" if model_source == "huggingface"
+                    else None # default
+                )
+
+                client = OpenAI(api_key=api_key, base_url=base_url)
+
                 try:
                     response_obj = client.chat.completions.create(
                         model=user_model,
                         messages=[{'role': 'user', 'content': prompt}],
                         **({"temperature": creativity} if creativity is not None else {})
-                    )
-                    reply = response_obj.choices[0].message.content
-                    link1.append(reply)
-                except Exception as e:
-                    print(f"An error occurred: {e}")
-                    link1.append(f"Error processing input: {e}")
-            elif model_source == "perplexity":
-                from openai import OpenAI
-                client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
-                try:
-                    response_obj = client.chat.completions.create(
-                        model=user_model,
-                        messages=[{'role': 'user', 'content': prompt}],
-                        **({"temperature": creativity} if creativity is not None else {})
                     )
+
                     reply = response_obj.choices[0].message.content
-
+
+                    if chain_of_verification:
+                        try:
+                            initial_reply = reply
+                            #STEP 2: Generate verification questions
+                            step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+                            verification_response = client.chat.completions.create(
+                                model=user_model,
+                                messages=[{'role': 'user', 'content': step2_filled}],
+                                **({"temperature": creativity} if creativity is not None else {})
+                            )
+
+                            verification_questions = verification_response.choices[0].message.content
+                            #STEP 3: Answer verification questions
+                            questions_list = [
+                                remove_numbering(q)
+                                for q in verification_questions.split('\n')
+                                if q.strip()
+                            ]
+                            verification_qa = []
+
+                            #prompting each question individually
+                            for question in questions_list:
+
+                                step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+                                answer_response = client.chat.completions.create(
+                                    model=user_model,
+                                    messages=[{'role': 'user', 'content': step3_filled}],
+                                    **({"temperature": creativity} if creativity is not None else {})
+                                )
+
+                                answer = answer_response.choices[0].message.content
+                                verification_qa.append(f"Q: {question}\nA: {answer}")
+
+                            #STEP 4: Final corrected categorization
+                            verification_qa_text = "\n\n".join(verification_qa)
+
+                            step4_filled = (step4_prompt
+                                .replace('<<INITIAL_REPLY>>', initial_reply)
+                                .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+                            print(f"Final prompt:\n{step4_filled}\n")
+
+                            final_response = client.chat.completions.create(
+                                model=user_model,
+                                messages=[{'role': 'user', 'content': step4_filled}],
+                                **({"temperature": creativity} if creativity is not None else {})
+                            )
+
+                            reply = final_response.choices[0].message.content
+
+                            print("Chain of verification completed. Final response generated.\n")
+                            link1.append(reply)
+
+                        except Exception as e:
+                            print(f"ERROR in Chain of Verification: {str(e)}")
+                            print("Falling back to initial response.\n")
+                            link1.append(reply)
+                    else:
+                        #if chain of verification is not enabled, just append initial reply
+                        link1.append(reply)
+
+                except BadRequestError as e:
+                    # Model doesn't exist - halt immediately
+                    raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
                 except Exception as e:
                     print(f"An error occurred: {e}")
                     link1.append(f"Error processing input: {e}")
+
             elif model_source == "anthropic":
+
                 import anthropic
                 client = anthropic.Anthropic(api_key=api_key)
+
                 try:
-
+                    response_obj = client.messages.create(
                         model=user_model,
-                        max_tokens=
-
-
-
-
-
+                        max_tokens=4096,
+                        messages=[{'role': 'user', 'content': prompt}],
+                        **({"temperature": creativity} if creativity is not None else {})
+                    )
+
+                    reply = response_obj.content[0].text
+
+                    if chain_of_verification:
+                        try:
+                            initial_reply = reply
+                            #STEP 2: Generate verification questions
+                            step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+                            verification_response = client.messages.create(
+                                model=user_model,
+                                messages=[{'role': 'user', 'content': step2_filled}],
+                                max_tokens=4096,
+                                **({"temperature": creativity} if creativity is not None else {})
+                            )
+
+                            verification_questions = verification_response.content[0].text
+                            #STEP 3: Answer verification questions
+                            questions_list = [
+                                remove_numbering(q)
+                                for q in verification_questions.split('\n')
+                                if q.strip()
+                            ]
+                            print(f"Verification questions:\n{questions_list}\n")
+                            verification_qa = []
+
+                            #prompting each question individually
+                            for question in questions_list:
+
+                                step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+                                answer_response = client.messages.create(
+                                    model=user_model,
+                                    messages=[{'role': 'user', 'content': step3_filled}],
+                                    max_tokens=4096,
+                                    **({"temperature": creativity} if creativity is not None else {})
+                                )
+
+                                answer = answer_response.content[0].text
+                                verification_qa.append(f"Q: {question}\nA: {answer}")
+
+                            #STEP 4: Final corrected categorization
+                            verification_qa_text = "\n\n".join(verification_qa)
+
+                            step4_filled = (step4_prompt
+                                .replace('<<INITIAL_REPLY>>', initial_reply)
+                                .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+                            print(f"Final prompt:\n{step4_filled}\n")
+
+                            final_response = client.messages.create(
+                                model=user_model,
+                                messages=[{'role': 'user', 'content': step4_filled}],
+                                max_tokens=4096,
+                                **({"temperature": creativity} if creativity is not None else {})
+                            )
+
+                            reply = final_response.content[0].text
+
+                            print("Chain of verification completed. Final response generated.\n")
+                            link1.append(reply)
+
+                        except Exception as e:
+                            print(f"ERROR in Chain of Verification: {str(e)}")
+                            print("Falling back to initial response.\n")
+                            link1.append(reply)
+                    else:
+                        #if chain of verification is not enabled, just append initial reply
+                        link1.append(reply)
+
+                except anthropic.NotFoundError as e:
+                    # Model doesn't exist - halt immediately
+                    raise ValueError(f"❌ Model '{user_model}' on {model_source} not found. Please check the model name and try again.") from e
                 except Exception as e:
                     print(f"An error occurred: {e}")
                     link1.append(f"Error processing input: {e}")

             elif model_source == "google":
                 import requests
+
+                def make_google_request(url, headers, payload, max_retries=3):
+                    """Make Google API request with exponential backoff on 429 errors"""
+                    for attempt in range(max_retries):
+                        try:
+                            response = requests.post(url, headers=headers, json=payload)
+                            response.raise_for_status()
+                            return response.json()
+                        except requests.exceptions.HTTPError as e:
+                            if e.response.status_code == 429 and attempt < max_retries - 1:
+                                wait_time = 10 * (2 ** attempt)
+                                print(f"⚠️ Rate limited. Waiting {wait_time}s...")
+                                time.sleep(wait_time)
+                            else:
+                                raise
+
                 url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent"
                 try:
                     headers = {

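Note: the chain_of_verification branches added in this hunk repeat the same four-step pattern for each provider: take the initial categorization, ask the model for 3-5 verification questions, answer each question against the raw survey response, then request a corrected final categorization. A provider-agnostic sketch of that loop, where ask() is a hypothetical stand-in for whichever single-turn client call the branch uses (not part of the package):

# Sketch of the four-step chain-of-verification loop; ask(text) -> str is hypothetical
def run_chain_of_verification(ask, prompt, step2_prompt, step3_prompt, step4_prompt):
    initial = ask(prompt)  # step 1: initial categorization
    questions_text = ask(step2_prompt.replace("<<INITIAL_REPLY>>", initial))  # step 2
    questions = [q.strip() for q in questions_text.split("\n") if q.strip()]
    qa_pairs = [
        f"Q: {q}\nA: {ask(step3_prompt.replace('<<QUESTION>>', q))}"  # step 3
        for q in questions
    ]
    final = ask(
        step4_prompt
        .replace("<<INITIAL_REPLY>>", initial)
        .replace("<<VERIFICATION_QA>>", "\n\n".join(qa_pairs))  # step 4
    )
    return final
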
@@ -353,22 +597,100 @@ def multi_class(
                         **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
                     }

-
-                    response.raise_for_status() # Raise exception for HTTP errors
-                    result = response.json()
+                    result = make_google_request(url, headers, payload)

                     if "candidates" in result and result["candidates"]:
                         reply = result["candidates"][0]["content"]["parts"][0]["text"]
                     else:
                         reply = "No response generated"

-
+                    if chain_of_verification:
+                        try:
+                            import time
+                            initial_reply = reply
+                            # STEP 2: Generate verification questions
+                            step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+                            payload_step2 = {
+                                "contents": [{
+                                    "parts": [{"text": step2_filled}]
+                                }],
+                                **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+                            }
+
+                            result_step2 = make_google_request(url, headers, payload_step2)
+
+                            verification_questions = result_step2["candidates"][0]["content"]["parts"][0]["text"]
+
+                            # STEP 3: Answer verification questions
+                            questions_list = [
+                                remove_numbering(q)
+                                for q in verification_questions.split('\n')
+                                if q.strip()
+                            ]
+                            verification_qa = []
+
+                            for question in questions_list:
+                                time.sleep(2) # temporary rate limit handling
+                                step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+                                payload_step3 = {
+                                    "contents": [{
+                                        "parts": [{"text": step3_filled}]
+                                    }],
+                                    **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+                                }
+
+                                result_step3 = make_google_request(url, headers, payload_step3)
+
+                                answer = result_step3["candidates"][0]["content"]["parts"][0]["text"]
+                                verification_qa.append(f"Q: {question}\nA: {answer}")
+
+                            # STEP 4: Final corrected categorization
+                            verification_qa_text = "\n\n".join(verification_qa)
+
+                            step4_filled = (step4_prompt
+                                .replace('<<PROMPT>>', prompt)
+                                .replace('<<INITIAL_REPLY>>', initial_reply)
+                                .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+                            payload_step4 = {
+                                "contents": [{
+                                    "parts": [{"text": step4_filled}]
+                                }],
+                                **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+                            }
+
+                            result_step4 = make_google_request(url, headers, payload_step4)
+
+                            reply = result_step4["candidates"][0]["content"]["parts"][0]["text"]
+                            print("Chain of verification completed. Final response generated.\n")
+
+                            link1.append(reply)
+
+                        except Exception as e:
+                            print(f"ERROR in Chain of Verification: {str(e)}")
+                            print("Falling back to initial response.\n")
+
+                    else:
+                        # if chain of verification is not enabled, just append initial reply
+                        link1.append(reply)
+
+                except requests.exceptions.HTTPError as e:
+                    if e.response.status_code == 404:
+                        raise ValueError(f"❌ Model '{user_model}' not found. Please check the model name and try again.") from e
+                    elif e.response.status_code == 401 or e.response.status_code == 403:
+                        raise ValueError(f"❌ Authentication failed. Please check your Google API key.") from e
+                    else:
+                        print(f"HTTP error occurred: {e}")
+                        link1.append(f"Error processing input: {e}")
                 except Exception as e:
                     print(f"An error occurred: {e}")
                     link1.append(f"Error processing input: {e}")

             elif model_source == "mistral":
                 from mistralai import Mistral
+                from mistralai.models import SDKError
+
                 client = Mistral(api_key=api_key)
                 try:
                     response = client.chat.complete(

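Note: the make_google_request helper added above retries HTTP 429 responses with exponential backoff. The wait is 10 * 2**attempt seconds, so with the default max_retries=3 a persistently rate-limited request sleeps 10 s, then 20 s, and the third failure is re-raised. A quick check of that schedule (illustrative arithmetic only):

# Backoff schedule implied by wait_time = 10 * (2 ** attempt) with max_retries=3
max_retries = 3
waits = [10 * (2 ** attempt) for attempt in range(max_retries - 1)]
print(waits)  # [10, 20]
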
@@ -379,12 +701,80 @@ def multi_class(
                         **({"temperature": creativity} if creativity is not None else {})
                     )
                     reply = response.choices[0].message.content
-
+
+                    if chain_of_verification:
+                        try:
+                            initial_reply = reply
+                            #STEP 2: Generate verification questions
+                            step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+                            verification_response = client.chat.complete(
+                                model=user_model,
+                                messages=[{'role': 'user', 'content': step2_filled}],
+                                **({"temperature": creativity} if creativity is not None else {})
+                            )
+
+                            verification_questions = verification_response.choices[0].message.content
+                            #STEP 3: Answer verification questions
+                            questions_list = [
+                                remove_numbering(q)
+                                for q in verification_questions.split('\n')
+                                if q.strip()
+                            ]
+                            verification_qa = []
+
+                            #prompting each question individually
+                            for question in questions_list:
+
+                                step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+                                answer_response = client.chat.complete(
+                                    model=user_model,
+                                    messages=[{'role': 'user', 'content': step3_filled}],
+                                    **({"temperature": creativity} if creativity is not None else {})
+                                )
+
+                                answer = answer_response.choices[0].message.content
+                                verification_qa.append(f"Q: {question}\nA: {answer}")
+
+                            #STEP 4: Final corrected categorization
+                            verification_qa_text = "\n\n".join(verification_qa)
+
+                            step4_filled = (step4_prompt
+                                .replace('<<INITIAL_REPLY>>', initial_reply)
+                                .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+                            final_response = client.chat.complete(
+                                model=user_model,
+                                messages=[{'role': 'user', 'content': step4_filled}],
+                                **({"temperature": creativity} if creativity is not None else {})
+                            )
+
+                            reply = final_response.choices[0].message.content
+
+                            link1.append(reply)
+                        except Exception as e:
+                            print(f"ERROR in Chain of Verification: {str(e)}")
+                            print("Falling back to initial response.\n")
+                    else:
+                        #if chain of verification is not enabled, just append initial reply
+                        link1.append(reply)
+
+                except SDKError as e:
+                    error_str = str(e).lower()
+                    if "invalid_model" in error_str or "invalid model" in error_str:
+                        raise ValueError(f"❌ Model '{user_model}' not found.") from e
+                    elif "401" in str(e) or "unauthorized" in str(e).lower():
+                        raise ValueError(f"❌ Authentication failed. Please check your Mistral API key.") from e
+                    else:
+                        print(f"An error occurred: {e}")
+                        link1.append(f"Error processing input: {e}")
                 except Exception as e:
-                    print(f"An error occurred: {e}")
+                    print(f"An unexpected error occurred: {e}")
                     link1.append(f"Error processing input: {e}")
+
             else:
-                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, or Mistral")
+                raise ValueError("Unknown source! Choose from OpenAI, Anthropic, Perplexity, Google, Huggingface, or Mistral")
         # in situation that no JSON is found
         if reply is not None:
             extracted_json = regex.findall(r'\{(?:[^{}]|(?R))*\}', reply, regex.DOTALL)

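Note: taken together, the multi_class changes in 0.0.68 make survey_question an optional keyword argument, default model_source to "auto", and add an opt-in chain_of_verification pass. An illustrative call against the new signature (the import path and argument values are examples, not taken from the package's documentation):

# Illustrative call; values are placeholders
from catllm.text_functions import multi_class

multi_class(
    survey_input=["I walk every morning", "I mostly watch TV"],
    categories=["physical activity", "sedentary behavior"],
    api_key="YOUR_API_KEY",
    user_model="gpt-4o-mini",        # "gpt" in the name auto-detects model_source="openai"
    survey_question="What do you do in a typical day?",
    chain_of_verification=True,      # new opt-in verification loop
    model_source="auto",             # new default, written out here for clarity
)
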
{cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/WHEEL
File without changes

{cat_llm-0.0.67.dist-info → cat_llm-0.0.68.dist-info}/licenses/LICENSE
File without changes