cat-llm 0.0.68__py3-none-any.whl → 0.0.69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {cat_llm-0.0.68.dist-info → cat_llm-0.0.69.dist-info}/METADATA +2 -2
- {cat_llm-0.0.68.dist-info → cat_llm-0.0.69.dist-info}/RECORD +10 -7
- catllm/__about__.py +1 -1
- catllm/calls/CoVe.py +304 -0
- catllm/calls/__init__.py +25 -0
- catllm/calls/all_calls.py +433 -0
- catllm/model_reference_list.py +1 -0
- catllm/text_functions.py +147 -244
- {cat_llm-0.0.68.dist-info → cat_llm-0.0.69.dist-info}/WHEEL +0 -0
- {cat_llm-0.0.68.dist-info → cat_llm-0.0.69.dist-info}/licenses/LICENSE +0 -0
{cat_llm-0.0.68.dist-info → cat_llm-0.0.69.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cat-llm
-Version: 0.0.68
+Version: 0.0.69
 Summary: A tool for categorizing text data and images using LLMs and vision models
 Project-URL: Documentation, https://github.com/chrissoria/cat-llm#readme
 Project-URL: Issues, https://github.com/chrissoria/cat-llm/issues

@@ -29,7 +29,7 @@ Description-Content-Type: text/markdown



-#
+# cat-llm

 [](https://pypi.org/project/cat-llm)
 [](https://pypi.org/project/cat-llm)
{cat_llm-0.0.68.dist-info → cat_llm-0.0.69.dist-info}/RECORD
CHANGED

@@ -1,16 +1,19 @@
 catllm/CERAD_functions.py,sha256=q4HbP5e2Yu8NnZZ-2eX4sImyj6u3i8xWcq0pYU81iis,22676
-catllm/__about__.py,sha256=
+catllm/__about__.py,sha256=qQkN04YWoxAJ5HglANO-XGwexy9aL_qFoZSv_CueaUs,430
 catllm/__init__.py,sha256=sf02zp7N0NW0mAQi7eQ4gliWR1EwoqvXkHN2HwwjcTE,372
 catllm/build_web_research.py,sha256=880dfE2bEQb-FrXP-42JoLLtyc9ox_sBULDr38xiTiQ,22655
 catllm/image_functions.py,sha256=8_FftRU285x1HT-AgNkaobefQVD-5q7ZY_t7JFdL3Sg,36177
-catllm/model_reference_list.py,sha256=
-catllm/text_functions.py,sha256=
+catllm/model_reference_list.py,sha256=37pWwMcgnf4biE3BVRluH5oz2P6ccdJJiCVNHodBH8k,2307
+catllm/text_functions.py,sha256=Vd9tAPDCDEhoXVW6O-jXeftJiZQmsyyrKeEUneYeobw,32533
+catllm/calls/CoVe.py,sha256=Y9OGJbaeJ3Odwira92cPXUlnm_ADFqvpOSFSNjFzMMU,10847
+catllm/calls/__init__.py,sha256=fWuMwLeSGa6zXJYd4s8IyNblsD62G-1NMUsOKrNIkoI,725
+catllm/calls/all_calls.py,sha256=E25KpZ_MakMDeCpNCOOM8kQvlfex6UMjnGN1wHkA4AI,14356
 catllm/images/circle.png,sha256=JWujAWAh08-TajAoEr_TAeFNLlfbryOLw6cgIBREBuQ,86202
 catllm/images/cube.png,sha256=nFec3e5bmRe4zrBCJ8QK-HcJLrG7u7dYdKhmdMfacfE,77275
 catllm/images/diamond.png,sha256=rJDZKtsnBGRO8FPA0iHuA8FvHFGi9PkI_DWSFdw6iv0,99568
 catllm/images/overlapping_pentagons.png,sha256=VO5plI6eoVRnjfqinn1nNzsCP2WQhuQy71V0EASouW4,71208
 catllm/images/rectangles.png,sha256=2XM16HO9EYWj2yHgN4bPXaCwPfl7iYQy0tQUGaJX9xg,40692
-cat_llm-0.0.
-cat_llm-0.0.
-cat_llm-0.0.
-cat_llm-0.0.
+cat_llm-0.0.69.dist-info/METADATA,sha256=E2q6apmvq1sDDiisnfyyQZzxqjNnqjCSecpalb5MgWQ,22424
+cat_llm-0.0.69.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+cat_llm-0.0.69.dist-info/licenses/LICENSE,sha256=Vje2sS5WV4TnIwY5uQHrF4qnBAM3YOk1pGpdH0ot-2o,34969
+cat_llm-0.0.69.dist-info/RECORD,,
catllm/__about__.py
CHANGED
catllm/calls/CoVe.py
ADDED
@@ -0,0 +1,304 @@
+# openai chain of verification calls
+
+def chain_of_verification_openai(
+    initial_reply,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    client,
+    user_model,
+    creativity,
+    remove_numbering
+):
+    """
+    Execute Chain of Verification (CoVe) process.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        verification_response = client.chat.completions.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step2_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verification_questions = verification_response.choices[0].message.content
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        verification_qa = []
+
+        # Prompting each question individually
+        for question in questions_list:
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            answer_response = client.chat.completions.create(
+                model=user_model,
+                messages=[{'role': 'user', 'content': step3_filled}],
+                **({"temperature": creativity} if creativity is not None else {})
+            )
+
+            answer = answer_response.choices[0].message.content
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        print(f"Final prompt:\n{step4_filled}\n")
+
+        final_response = client.chat.completions.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step4_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verified_reply = final_response.choices[0].message.content
+        print("Chain of verification completed. Final response generated.\n")
+
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
+
+# anthropic chain of verification calls
+
+def chain_of_verification_anthropic(
+    initial_reply,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    client,
+    user_model,
+    creativity,
+    remove_numbering
+):
+    """
+    Execute Chain of Verification (CoVe) process for Anthropic Claude.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        verification_response = client.messages.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step2_filled}],
+            max_tokens=4096,
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verification_questions = verification_response.content[0].text
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        print(f"Verification questions:\n{questions_list}\n")
+        verification_qa = []
+
+        # Prompting each question individually
+        for question in questions_list:
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            answer_response = client.messages.create(
+                model=user_model,
+                messages=[{'role': 'user', 'content': step3_filled}],
+                max_tokens=4096,
+                **({"temperature": creativity} if creativity is not None else {})
+            )
+
+            answer = answer_response.content[0].text
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        print(f"Final prompt:\n{step4_filled}\n")
+
+        final_response = client.messages.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step4_filled}],
+            max_tokens=4096,
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verified_reply = final_response.content[0].text
+        print("Chain of verification completed. Final response generated.\n")
+
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
+
+# google chain of verification calls
+def chain_of_verification_google(
+    initial_reply,
+    prompt,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    url,
+    headers,
+    creativity,
+    remove_numbering,
+    make_google_request
+):
+    import time
+    """
+    Execute Chain of Verification (CoVe) process for Google Gemini.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        payload_step2 = {
+            "contents": [{
+                "parts": [{"text": step2_filled}]
+            }],
+            **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+        }
+
+        result_step2 = make_google_request(url, headers, payload_step2)
+        verification_questions = result_step2["candidates"][0]["content"]["parts"][0]["text"]
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        verification_qa = []
+
+        for question in questions_list:
+            time.sleep(2)  # temporary rate limit handling
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            payload_step3 = {
+                "contents": [{
+                    "parts": [{"text": step3_filled}]
+                }],
+                **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+            }
+
+            result_step3 = make_google_request(url, headers, payload_step3)
+            answer = result_step3["candidates"][0]["content"]["parts"][0]["text"]
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<PROMPT>>', prompt)
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        payload_step4 = {
+            "contents": [{
+                "parts": [{"text": step4_filled}]
+            }],
+            **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+        }
+
+        result_step4 = make_google_request(url, headers, payload_step4)
+        verified_reply = result_step4["candidates"][0]["content"]["parts"][0]["text"]
+
+        print("Chain of verification completed. Final response generated.\n")
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
+
+# mistral chain of verification calls
+
+def chain_of_verification_mistral(
+    initial_reply,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    client,
+    user_model,
+    creativity,
+    remove_numbering
+):
+    """
+    Execute Chain of Verification (CoVe) process for Mistral AI.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        verification_response = client.chat.complete(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step2_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verification_questions = verification_response.choices[0].message.content
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        verification_qa = []
+
+        # Prompting each question individually
+        for question in questions_list:
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            answer_response = client.chat.complete(
+                model=user_model,
+                messages=[{'role': 'user', 'content': step3_filled}],
+                **({"temperature": creativity} if creativity is not None else {})
+            )
+
+            answer = answer_response.choices[0].message.content
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        final_response = client.chat.complete(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step4_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verified_reply = final_response.choices[0].message.content
+        print("Chain of verification completed. Final response generated.\n")
+
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
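For orientation, a minimal usage sketch of one of these new helpers follows. It is not part of the package: the client setup, prompt strings, model id, and the `strip_numbering` stand-in for the expected `remove_numbering` callable are all illustrative assumptions; only the function signature comes from the diff above.

```python
# Hypothetical driver for chain_of_verification_openai (illustrative only).
from openai import OpenAI
from catllm.calls.CoVe import chain_of_verification_openai

def strip_numbering(line: str) -> str:
    # Stand-in for the remove_numbering callable the helper expects:
    # drops a leading "1. " / "2) " style prefix from a generated question.
    return line.lstrip("0123456789.) ").strip()

client = OpenAI(api_key="sk-...")  # assumed credential

verified = chain_of_verification_openai(
    initial_reply='{"1": 1, "2": 0}',  # a prior categorization to verify
    step2_prompt=("You provided this initial categorization:\n<<INITIAL_REPLY>>\n"
                  "List verification questions, one per line."),
    step3_prompt="Answer this verification question concisely: <<QUESTION>>",
    step4_prompt=("Initial reply:\n<<INITIAL_REPLY>>\n"
                  "Verification Q&A:\n<<VERIFICATION_QA>>\n"
                  "Return the corrected JSON categorization."),
    client=client,
    user_model="gpt-4o-mini",  # assumed model id
    creativity=None,           # None omits temperature from the API calls
    remove_numbering=strip_numbering,
)
print(verified)
```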
catllm/calls/__init__.py
ADDED
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: 2025-present Christopher Soria <chrissoria@berkeley.edu>
+#
+# SPDX-License-Identifier: MIT
+
+from .all_calls import (
+    get_stepback_insight_openai,
+    get_stepback_insight_anthropic,
+    get_stepback_insight_google,
+    get_stepback_insight_mistral,
+    chain_of_verification_openai,
+    chain_of_verification_google,
+    chain_of_verification_anthropic,
+    chain_of_verification_mistral
+)
+
+__all__ = [
+    'get_stepback_insight_openai',
+    'get_stepback_insight_anthropic',
+    'get_stepback_insight_google',
+    'get_stepback_insight_mistral',
+    'chain_of_verification_openai',
+    'chain_of_verification_anthropic',
+    'chain_of_verification_google',
+    'chain_of_verification_mistral',
+]
```
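Because this `__init__.py` re-exports both helper families, downstream code can (assuming the wheel is installed) import them from the subpackage without naming the submodules, for example:

```python
# Imports resolved through catllm/calls/__init__.py shown above.
from catllm.calls import (
    get_stepback_insight_mistral,
    chain_of_verification_mistral,
)
```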
catllm/calls/all_calls.py
ADDED

@@ -0,0 +1,433 @@
+# openai stepback prompt
+
+def get_stepback_insight_openai(
+    stepback,
+    api_key,
+    user_model,
+    model_source="openai",
+    creativity=None
+):
+    from openai import OpenAI
+    # Conditional base_url setting based on model source
+    base_url = (
+        "https://api.perplexity.ai" if model_source == "perplexity"
+        else "https://router.huggingface.co/v1" if model_source == "huggingface"
+        else None
+    )
+
+    client = OpenAI(api_key=api_key, base_url=base_url)
+
+    try:
+        stepback_response = client.chat.completions.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': stepback}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+        stepback_insight = stepback_response.choices[0].message.content
+
+        return stepback_insight, True
+
+    except Exception as e:
+        print(f"An error occurred during step-back prompting: {e}")
+        return None, False
+
+
+# claude stepback prompt
+
+def get_stepback_insight_anthropic(
+    stepback,
+    api_key,
+    user_model,
+    model_source="anthropic",
+    creativity=None
+):
+    import anthropic
+
+    client = anthropic.Anthropic(api_key=api_key)
+
+    try:
+        stepback_response = client.messages.create(
+            model=user_model,
+            max_tokens=4096,
+            messages=[{'role': 'user', 'content': stepback}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+        stepback_insight = stepback_response.content[0].text
+
+        return stepback_insight, True
+
+    except Exception as e:
+        print(f"An error occurred during step-back prompting: {e}")
+        return None, False
+
+# google stepback prompt
+
+def get_stepback_insight_google(
+    stepback,
+    api_key,
+    user_model,
+    model_source="google",
+    creativity=None
+):
+
+    import requests
+
+    url = f"https://generativelanguage.googleapis.com/v1beta/models/{user_model}:generateContent?key={api_key}"
+
+    headers = {
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "contents": [{
+            "parts": [{"text": stepback}],
+
+            **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+        }]
+    }
+
+    try:
+        response = requests.post(url, headers=headers, json=payload)
+        response.raise_for_status()  # Raise error for bad status codes
+
+        result = response.json()
+        stepback_insight = result['candidates'][0]['content']['parts'][0]['text']
+
+        return stepback_insight, True
+
+    except Exception as e:
+        print(f"An error occurred during step-back prompting: {e}")
+        return None, False
+
+# mistral stepback prompt
+
+def get_stepback_insight_mistral(
+    stepback,
+    api_key,
+    user_model,
+    model_source="mistral",
+    creativity=None
+):
+
+    from mistralai import Mistral
+
+    client = Mistral(api_key=api_key)
+
+    try:
+        stepback_response = client.chat.complete(
+            model=user_model,
+            messages=[{'role': 'user', 'content': stepback}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+        stepback_insight = stepback_response.choices[0].message.content
+
+        return stepback_insight, True
+
+    except Exception as e:
+        print(f"An error occurred during step-back prompting: {e}")
+        return None, False
+
+# openai chain of verification calls
+
+def chain_of_verification_openai(
+    initial_reply,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    client,
+    user_model,
+    creativity,
+    remove_numbering
+):
+    """
+    Execute Chain of Verification (CoVe) process.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        verification_response = client.chat.completions.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step2_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verification_questions = verification_response.choices[0].message.content
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        verification_qa = []
+
+        # Prompting each question individually
+        for question in questions_list:
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            answer_response = client.chat.completions.create(
+                model=user_model,
+                messages=[{'role': 'user', 'content': step3_filled}],
+                **({"temperature": creativity} if creativity is not None else {})
+            )
+
+            answer = answer_response.choices[0].message.content
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        print(f"Final prompt:\n{step4_filled}\n")
+
+        final_response = client.chat.completions.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step4_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verified_reply = final_response.choices[0].message.content
+        print("Chain of verification completed. Final response generated.\n")
+
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
+
+# anthropic chain of verification calls
+
+def chain_of_verification_anthropic(
+    initial_reply,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    client,
+    user_model,
+    creativity,
+    remove_numbering
+):
+    """
+    Execute Chain of Verification (CoVe) process for Anthropic Claude.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        verification_response = client.messages.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step2_filled}],
+            max_tokens=4096,
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verification_questions = verification_response.content[0].text
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        print(f"Verification questions:\n{questions_list}\n")
+        verification_qa = []
+
+        # Prompting each question individually
+        for question in questions_list:
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            answer_response = client.messages.create(
+                model=user_model,
+                messages=[{'role': 'user', 'content': step3_filled}],
+                max_tokens=4096,
+                **({"temperature": creativity} if creativity is not None else {})
+            )
+
+            answer = answer_response.content[0].text
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        print(f"Final prompt:\n{step4_filled}\n")
+
+        final_response = client.messages.create(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step4_filled}],
+            max_tokens=4096,
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verified_reply = final_response.content[0].text
+        print("Chain of verification completed. Final response generated.\n")
+
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
+
+# google chain of verification calls
+def chain_of_verification_google(
+    initial_reply,
+    prompt,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    url,
+    headers,
+    creativity,
+    remove_numbering,
+    make_google_request
+):
+    import time
+    """
+    Execute Chain of Verification (CoVe) process for Google Gemini.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        payload_step2 = {
+            "contents": [{
+                "parts": [{"text": step2_filled}]
+            }],
+            **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+        }
+
+        result_step2 = make_google_request(url, headers, payload_step2)
+        verification_questions = result_step2["candidates"][0]["content"]["parts"][0]["text"]
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        verification_qa = []
+
+        for question in questions_list:
+            time.sleep(2)  # temporary rate limit handling
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            payload_step3 = {
+                "contents": [{
+                    "parts": [{"text": step3_filled}]
+                }],
+                **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+            }
+
+            result_step3 = make_google_request(url, headers, payload_step3)
+            answer = result_step3["candidates"][0]["content"]["parts"][0]["text"]
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<PROMPT>>', prompt)
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        payload_step4 = {
+            "contents": [{
+                "parts": [{"text": step4_filled}]
+            }],
+            **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
+        }
+
+        result_step4 = make_google_request(url, headers, payload_step4)
+        verified_reply = result_step4["candidates"][0]["content"]["parts"][0]["text"]
+
+        print("Chain of verification completed. Final response generated.\n")
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
+
+# mistral chain of verification calls
+
+def chain_of_verification_mistral(
+    initial_reply,
+    step2_prompt,
+    step3_prompt,
+    step4_prompt,
+    client,
+    user_model,
+    creativity,
+    remove_numbering
+):
+    """
+    Execute Chain of Verification (CoVe) process for Mistral AI.
+    Returns the verified reply or initial reply if error occurs.
+    """
+    try:
+        # STEP 2: Generate verification questions
+        step2_filled = step2_prompt.replace('<<INITIAL_REPLY>>', initial_reply)
+
+        verification_response = client.chat.complete(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step2_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verification_questions = verification_response.choices[0].message.content
+
+        # STEP 3: Answer verification questions
+        questions_list = [
+            remove_numbering(q)
+            for q in verification_questions.split('\n')
+            if q.strip()
+        ]
+        verification_qa = []
+
+        # Prompting each question individually
+        for question in questions_list:
+            step3_filled = step3_prompt.replace('<<QUESTION>>', question)
+
+            answer_response = client.chat.complete(
+                model=user_model,
+                messages=[{'role': 'user', 'content': step3_filled}],
+                **({"temperature": creativity} if creativity is not None else {})
+            )
+
+            answer = answer_response.choices[0].message.content
+            verification_qa.append(f"Q: {question}\nA: {answer}")
+
+        # STEP 4: Final corrected categorization
+        verification_qa_text = "\n\n".join(verification_qa)
+
+        step4_filled = (step4_prompt
+                        .replace('<<INITIAL_REPLY>>', initial_reply)
+                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
+
+        final_response = client.chat.complete(
+            model=user_model,
+            messages=[{'role': 'user', 'content': step4_filled}],
+            **({"temperature": creativity} if creativity is not None else {})
+        )
+
+        verified_reply = final_response.choices[0].message.content
+        print("Chain of verification completed. Final response generated.\n")
+
+        return verified_reply
+
+    except Exception as e:
+        print(f"ERROR in Chain of Verification: {str(e)}")
+        print("Falling back to initial response.\n")
+        return initial_reply
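The `get_stepback_insight_*` helpers return an `(insight, success_flag)` tuple, which is how `text_functions.py` decides whether to prepend the step-back turn. A hedged sketch of calling the Anthropic variant directly follows; the question text, API key, model id, and temperature value are placeholders, not values from the package.

```python
# Illustrative call of the step-back helper (credentials and model id assumed).
from catllm.calls.all_calls import get_stepback_insight_anthropic

question = "What do you like most about your neighborhood?"  # example survey question
stepback = (
    "What are the underlying factors or dimensions that explain how people "
    f'typically answer "{question}"?'
)

insight, ok = get_stepback_insight_anthropic(
    stepback=stepback,
    api_key="sk-ant-...",                    # assumed credential
    user_model="claude-3-5-sonnet-latest",   # assumed model id
    creativity=0.2,                          # passed through as temperature
)
if ok:
    print(insight)
else:
    print("Step-back prompting failed; proceeding without the insight.")
```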
catllm/model_reference_list.py
CHANGED
catllm/text_functions.py
CHANGED
@@ -1,3 +1,15 @@
+from .calls.all_calls import (
+    get_stepback_insight_openai,
+    get_stepback_insight_anthropic,
+    get_stepback_insight_google,
+    get_stepback_insight_mistral,
+    chain_of_verification_openai,
+    chain_of_verification_google,
+    chain_of_verification_anthropic,
+    chain_of_verification_mistral
+)
+
+
 #extract categories from corpus
 def explore_corpus(
     survey_question,
@@ -244,13 +256,15 @@ def multi_class(
     example4 = None,
     example5 = None,
     example6 = None,
-    creativity=None,
-    safety=False,
-    to_csv=False,
-    chain_of_verification=False,
-
-
-
+    creativity = None,
+    safety = False,
+    to_csv = False,
+    chain_of_verification = False,
+    step_back_prompt = False,
+    context_prompt = False,
+    filename = "categorized_data.csv",
+    save_directory = None,
+    model_source = "auto"
 ):
     import os
     import json
@@ -331,6 +345,49 @@ def multi_class(
     else:
         survey_question_context = ""

+    # step back insight initialization
+    if step_back_prompt:
+        if survey_question == "":  # step back requires the survey question to function well
+            raise TypeError("survey_question is required when using step_back_prompt. Please provide the survey question you are analyzing.")
+
+        stepback = f"""What are the underlying factors or dimensions that explain how people typically answer "{survey_question}"?"""
+
+        if model_source in ["openai", "perplexity", "huggingface"]:
+            stepback_insight, step_back_added = get_stepback_insight_openai(
+                stepback=stepback,
+                api_key=api_key,
+                user_model=user_model,
+                model_source=model_source,
+                creativity=creativity
+            )
+        elif model_source == "anthropic":
+            stepback_insight, step_back_added = get_stepback_insight_anthropic(
+                stepback=stepback,
+                api_key=api_key,
+                user_model=user_model,
+                model_source=model_source,
+                creativity=creativity
+            )
+        elif model_source == "google":
+            stepback_insight, step_back_added = get_stepback_insight_google(
+                stepback=stepback,
+                api_key=api_key,
+                user_model=user_model,
+                model_source=model_source,
+                creativity=creativity
+            )
+        elif model_source == "mistral":
+            stepback_insight, step_back_added = get_stepback_insight_mistral(
+                stepback=stepback,
+                api_key=api_key,
+                user_model=user_model,
+                model_source=model_source,
+                creativity=creativity
+            )
+        else:
+            stepback_insight = None
+            step_back_added = False
+
     for idx, response in enumerate(tqdm(survey_input, desc="Categorizing responses")):
         reply = None

@@ -347,6 +404,14 @@ def multi_class(
         {examples_text}
         Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values."""

+        if context_prompt:
+            context = """You are an expert researcher in survey data categorization.
+            Apply multi-label classification and base decisions on explicit and implicit meanings.
+            When uncertain, prioritize precision over recall."""
+
+            prompt = context + prompt
+            print(prompt)
+
         if chain_of_verification:
             step2_prompt = f"""You provided this initial categorization:
             <<INITIAL_REPLY>>
@@ -384,7 +449,7 @@ def multi_class(
             If no categories are present, assign "0" to all categories.
             Provide the final corrected categorization in the same JSON format:"""

-
+        # Main model interaction
         if model_source in ["openai", "perplexity", "huggingface"]:
             from openai import OpenAI
             from openai import OpenAI, BadRequestError, AuthenticationError
@@ -398,73 +463,33 @@ def multi_class(
             client = OpenAI(api_key=api_key, base_url=base_url)

             try:
+                messages = [
+                    *([{'role': 'user', 'content': stepback}] if step_back_prompt and step_back_added else []),  # only if step back is enabled and successful
+                    *([{'role': 'assistant', 'content': stepback_insight}] if step_back_added else {}),  # include insight if step back succeeded
+                    {'role': 'user', 'content': prompt}
+                ]
+
                 response_obj = client.chat.completions.create(
                     model=user_model,
-                    messages=
+                    messages=messages,
                     **({"temperature": creativity} if creativity is not None else {})
                 )

                 reply = response_obj.choices[0].message.content

                 if chain_of_verification:
-
-                    initial_reply
-
-
-
-
-
-
-
-
-
-                    verification_questions = verification_response.choices[0].message.content
-                    #STEP 3: Answer verification questions
-                    questions_list = [
-                        remove_numbering(q)
-                        for q in verification_questions.split('\n')
-                        if q.strip()
-                    ]
-                    verification_qa = []
-
-                    #prompting each question individually
-                    for question in questions_list:
-
-                        step3_filled = step3_prompt.replace('<<QUESTION>>', question)
-
-                        answer_response = client.chat.completions.create(
-                            model=user_model,
-                            messages=[{'role': 'user', 'content': step3_filled}],
-                            **({"temperature": creativity} if creativity is not None else {})
-                        )
-
-                        answer = answer_response.choices[0].message.content
-                        verification_qa.append(f"Q: {question}\nA: {answer}")
-
-                    #STEP 4: Final corrected categorization
-                    verification_qa_text = "\n\n".join(verification_qa)
-
-                    step4_filled = (step4_prompt
-                        .replace('<<INITIAL_REPLY>>', initial_reply)
-                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
-
-                    print(f"Final prompt:\n{step4_filled}\n")
-
-                    final_response = client.chat.completions.create(
-                        model=user_model,
-                        messages=[{'role': 'user', 'content': step4_filled}],
-                        **({"temperature": creativity} if creativity is not None else {})
-                    )
-
-                    reply = final_response.choices[0].message.content
-
-                    print("Chain of verification completed. Final response generated.\n")
-                    link1.append(reply)
+                    reply = chain_of_verification_openai(
+                        initial_reply=reply,
+                        step2_prompt=step2_prompt,
+                        step3_prompt=step3_prompt,
+                        step4_prompt=step4_prompt,
+                        client=client,
+                        user_model=user_model,
+                        creativity=creativity,
+                        remove_numbering=remove_numbering
+                    )

-
-                    print(f"ERROR in Chain of Verification: {str(e)}")
-                    print("Falling back to initial response.\n")
-                    link1.append(reply)
+                    link1.append(reply)
                 else:
                     #if chain of verification is not enabled, just append initial reply
                     link1.append(reply)
@@ -492,68 +517,18 @@ def multi_class(
                 reply = response_obj.content[0].text

                 if chain_of_verification:
-
-                    initial_reply
-
-
-
-
-
-
-
-
-                    )
-
-                    verification_questions = verification_response.content[0].text
-                    #STEP 3: Answer verification questions
-                    questions_list = [
-                        remove_numbering(q)
-                        for q in verification_questions.split('\n')
-                        if q.strip()
-                    ]
-                    print(f"Verification questions:\n{questions_list}\n")
-                    verification_qa = []
-
-                    #prompting each question individually
-                    for question in questions_list:
-
-                        step3_filled = step3_prompt.replace('<<QUESTION>>', question)
-
-                        answer_response = client.messages.create(
-                            model=user_model,
-                            messages=[{'role': 'user', 'content': step3_filled}],
-                            max_tokens=4096,
-                            **({"temperature": creativity} if creativity is not None else {})
-                        )
-
-                        answer = answer_response.content[0].text
-                        verification_qa.append(f"Q: {question}\nA: {answer}")
-
-                    #STEP 4: Final corrected categorization
-                    verification_qa_text = "\n\n".join(verification_qa)
-
-                    step4_filled = (step4_prompt
-                        .replace('<<INITIAL_REPLY>>', initial_reply)
-                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
-
-                    print(f"Final prompt:\n{step4_filled}\n")
-
-                    final_response = client.messages.create(
-                        model=user_model,
-                        messages=[{'role': 'user', 'content': step4_filled}],
-                        max_tokens=4096,
-                        **({"temperature": creativity} if creativity is not None else {})
-                    )
-
-                    reply = final_response.content[0].text
-
-                    print("Chain of verification completed. Final response generated.\n")
-                    link1.append(reply)
+                    reply = chain_of_verification_anthropic(
+                        initial_reply=reply,
+                        step2_prompt=step2_prompt,
+                        step3_prompt=step3_prompt,
+                        step4_prompt=step4_prompt,
+                        client=client,
+                        user_model=user_model,
+                        creativity=creativity,
+                        remove_numbering=remove_numbering
+                    )

-
-                    print(f"ERROR in Chain of Verification: {str(e)}")
-                    print("Falling back to initial response.\n")
-                    link1.append(reply)
+                    link1.append(reply)
                 else:
                     #if chain of verification is not enabled, just append initial reply
                     link1.append(reply)
@@ -605,71 +580,20 @@ def multi_class(
                     reply = "No response generated"

                 if chain_of_verification:
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    result_step2 = make_google_request(url, headers, payload_step2)
-
-                    verification_questions = result_step2["candidates"][0]["content"]["parts"][0]["text"]
-
-                    # STEP 3: Answer verification questions
-                    questions_list = [
-                        remove_numbering(q)
-                        for q in verification_questions.split('\n')
-                        if q.strip()
-                    ]
-                    verification_qa = []
-
-                    for question in questions_list:
-                        time.sleep(2) # temporary rate limit handling
-                        step3_filled = step3_prompt.replace('<<QUESTION>>', question)
-                        payload_step3 = {
-                            "contents": [{
-                                "parts": [{"text": step3_filled}]
-                            }],
-                            **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
-                        }
-
-                        result_step3 = make_google_request(url, headers, payload_step3)
-
-                        answer = result_step3["candidates"][0]["content"]["parts"][0]["text"]
-                        verification_qa.append(f"Q: {question}\nA: {answer}")
-
-                    # STEP 4: Final corrected categorization
-                    verification_qa_text = "\n\n".join(verification_qa)
-
-                    step4_filled = (step4_prompt
-                        .replace('<<PROMPT>>', prompt)
-                        .replace('<<INITIAL_REPLY>>', initial_reply)
-                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
-
-                    payload_step4 = {
-                        "contents": [{
-                            "parts": [{"text": step4_filled}]
-                        }],
-                        **({"generationConfig": {"temperature": creativity}} if creativity is not None else {})
-                    }
-
-                    result_step4 = make_google_request(url, headers, payload_step4)
-
-                    reply = result_step4["candidates"][0]["content"]["parts"][0]["text"]
-                    print("Chain of verification completed. Final response generated.\n")
+                    reply = chain_of_verification_google(
+                        initial_reply=reply,
+                        prompt=prompt,
+                        step2_prompt=step2_prompt,
+                        step3_prompt=step3_prompt,
+                        step4_prompt=step4_prompt,
+                        url=url,
+                        headers=headers,
+                        creativity=creativity,
+                        remove_numbering=remove_numbering,
+                        make_google_request=make_google_request
+                    )

-
-
-                    except Exception as e:
-                    print(f"ERROR in Chain of Verification: {str(e)}")
-                    print("Falling back to initial response.\n")
+                    link1.append(reply)

                 else:
                     # if chain of verification is not enabled, just append initial reply
@@ -703,59 +627,19 @@ def multi_class(
                 reply = response.choices[0].message.content

                 if chain_of_verification:
-
-                    initial_reply
-
-
-
-
-
-
-
-
-
-
-
-                    questions_list = [
-                        remove_numbering(q)
-                        for q in verification_questions.split('\n')
-                        if q.strip()
-                    ]
-                    verification_qa = []
-
-                    #prompting each question individually
-                    for question in questions_list:
-
-                        step3_filled = step3_prompt.replace('<<QUESTION>>', question)
-
-                        answer_response = client.chat.complete(
-                            model=user_model,
-                            messages=[{'role': 'user', 'content': step3_filled}],
-                            **({"temperature": creativity} if creativity is not None else {})
-                        )
-
-                        answer = answer_response.choices[0].message.content
-                        verification_qa.append(f"Q: {question}\nA: {answer}")
-
-                    #STEP 4: Final corrected categorization
-                    verification_qa_text = "\n\n".join(verification_qa)
-
-                    step4_filled = (step4_prompt
-                        .replace('<<INITIAL_REPLY>>', initial_reply)
-                        .replace('<<VERIFICATION_QA>>', verification_qa_text))
-
-                    final_response = client.chat.complete(
-                        model=user_model,
-                        messages=[{'role': 'user', 'content': step4_filled}],
-                        **({"temperature": creativity} if creativity is not None else {})
-                    )
-
-                    reply = final_response.choices[0].message.content
-
-                    link1.append(reply)
-                    except Exception as e:
-                    print(f"ERROR in Chain of Verification: {str(e)}")
-                    print("Falling back to initial response.\n")
+                    reply = chain_of_verification_mistral(
+                        initial_reply=reply,
+                        step2_prompt=step2_prompt,
+                        step3_prompt=step3_prompt,
+                        step4_prompt=step4_prompt,
+                        client=client,
+                        user_model=user_model,
+                        creativity=creativity,
+                        remove_numbering=remove_numbering
+                    )
+
+                    link1.append(reply)
+
                 else:
                     #if chain of verification is not enabled, just append initial reply
                     link1.append(reply)
@@ -832,6 +716,25 @@ def multi_class(
         'json': pd.Series(extracted_jsons).reset_index(drop=True)
     })
     categorized_data = pd.concat([categorized_data, normalized_data], axis=1)
+    categorized_data = categorized_data.rename(columns=lambda x: f'category_{x}' if str(x).isdigit() else x)
+
+    #converting to numeric
+    cat_cols = [col for col in categorized_data.columns if col.startswith('category_')]
+
+    categorized_data['processing_status'] = np.where(
+        categorized_data[cat_cols].isna().all(axis=1),
+        'error',
+        'success'
+    )
+
+    categorized_data.loc[categorized_data[cat_cols].apply(pd.to_numeric, errors='coerce').isna().any(axis=1), cat_cols] = np.nan
+    categorized_data[cat_cols] = categorized_data[cat_cols].astype('Int64')
+
+    categorized_data['categories_present'] = categorized_data[cat_cols].apply(
+        lambda x: ','.join(x.dropna().astype(str)), axis=1
+    )
+
+    categorized_data['categories_counted'] = categorized_data[cat_cols].count(axis=1)

     if to_csv:
         if save_directory is None:
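To show how the new switches fit together, here is a hedged usage sketch of `multi_class`. Parameter names are taken from the hunks above, but the full signature (notably the category definitions and example arguments) is not visible in this diff, so treat the call as an illustration rather than the documented API; the model id, API key, and sample inputs are placeholders.

```python
# Hedged sketch of calling multi_class with the options added in 0.0.69.
from catllm.text_functions import multi_class  # assumed import path (module shown in this diff)

results = multi_class(
    survey_question="What do you like most about your neighborhood?",
    survey_input=["The parks and the quiet", "Close to work", ""],
    api_key="sk-...",                  # assumed credential
    user_model="gpt-4o-mini",          # assumed model id
    model_source="openai",
    step_back_prompt=True,             # new: prepend a step-back insight turn when it succeeds
    context_prompt=True,               # new: prefix the expert-researcher context to the prompt
    chain_of_verification=True,        # now delegated to the catllm.calls helpers
    to_csv=True,
    filename="categorized_data.csv",
    save_directory=None,
    # ...plus the category definitions and any other required arguments not shown in this diff
)
```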
File without changes
|
|
File without changes
|