ragaai-catalyst 2.0.4__py3-none-any.whl → 2.0.6b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +2 -1
- ragaai_catalyst/dataset.py +49 -60
- ragaai_catalyst/evaluation.py +79 -46
- ragaai_catalyst/guardrails_manager.py +233 -0
- ragaai_catalyst/internal_api_completion.py +83 -0
- ragaai_catalyst/proxy_call.py +1 -1
- ragaai_catalyst/synthetic_data_generation.py +201 -78
- ragaai_catalyst/tracers/llamaindex_callback.py +361 -0
- ragaai_catalyst/tracers/tracer.py +62 -28
- {ragaai_catalyst-2.0.4.dist-info → ragaai_catalyst-2.0.6b0.dist-info}/METADATA +139 -72
- {ragaai_catalyst-2.0.4.dist-info → ragaai_catalyst-2.0.6b0.dist-info}/RECORD +13 -10
- {ragaai_catalyst-2.0.4.dist-info → ragaai_catalyst-2.0.6b0.dist-info}/WHEEL +1 -1
- {ragaai_catalyst-2.0.4.dist-info → ragaai_catalyst-2.0.6b0.dist-info}/top_level.txt +0 -0
ragaai_catalyst/internal_api_completion.py
ADDED

@@ -0,0 +1,83 @@
+import requests
+import json
+import subprocess
+import logging
+import traceback
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+def api_completion(messages, model_config, kwargs):
+    attempts = 0
+    while attempts < 3:
+
+        user_id = kwargs.get('user_id', '1')
+        internal_llm_proxy = kwargs.get('internal_llm_proxy', -1)
+
+
+        job_id = model_config.get('job_id',-1)
+        converted_message = convert_input(messages,model_config, user_id)
+        payload = json.dumps(converted_message)
+        headers = {
+            'Content-Type': 'application/json',
+            # 'Wd-PCA-Feature-Key':f'your_feature_key, $(whoami)'
+        }
+        try:
+            response = requests.request("POST", internal_llm_proxy, headers=headers, data=payload)
+            if model_config.get('log_level','')=='debug':
+                logger.info(f'Model response Job ID {job_id} {response.text}')
+            if response.status_code!=200:
+                # logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
+                raise ValueError(str(response.text))
+
+            if response.status_code==200:
+                response = response.json()
+                if "error" in response:
+                    raise ValueError(response["error"]["message"])
+                else:
+                    result= response["choices"][0]["message"]["content"]
+                    response1 = result.replace('\n', '')
+                    try:
+                        json_data = json.loads(response1)
+                        df = pd.DataFrame(json_data)
+                        return(df)
+                    except json.JSONDecodeError:
+                        attempts += 1  # Increment attempts if JSON parsing fails
+                        if attempts == 3:
+                            raise Exception("Failed to generate a valid response after multiple attempts.")
+
+        except Exception as e:
+            raise ValueError(f"{e}")
+
+
+def get_username():
+    result = subprocess.run(['whoami'], capture_output=True, text=True)
+    result = result.stdout
+    return result
+
+
+def convert_input(messages, model_config, user_id):
+    doc_input = {
+        "model": model_config.get('model'),
+        **model_config,
+        "messages": messages,
+        "user_id": user_id
+    }
+    return doc_input
+
+
+if __name__=='__main__':
+    messages = [
+        {
+            "role": "system",
+            "content": "you are a poet well versed in shakespeare literature"
+        },
+        {
+            "role": "user",
+            "content": "write a poem on pirates and penguins"
+        }
+    ]
+    kwargs = {"internal_llm_proxy": "http://13.200.11.66:4000/chat/completions", "user_id": 1}
+    model_config = {"model": "workday_gateway", "provider":"openai", "max_tokens": 10}
+    answer = api_completion(messages, model_config, kwargs)
+    print(answer)
ragaai_catalyst/proxy_call.py
CHANGED
@@ -23,7 +23,7 @@ def api_completion(model,messages, api_base='http://127.0.0.1:8000',
         if model_config.get('log_level','')=='debug':
             logger.info(f'Model response Job ID {job_id} {response.text}')
         if response.status_code!=200:
-            logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
+            # logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
             raise ValueError(str(response.text))
     except Exception as e:
         logger.error(f'Error in calling api Job ID {job_id}:',str(e))
ragaai_catalyst/synthetic_data_generation.py
CHANGED

@@ -7,7 +7,14 @@ import csv
 import markdown
 import pandas as pd
 import json
-from
+from litellm import completion
+from tqdm import tqdm
+# import internal_api_completion
+# import proxy_call
+from .internal_api_completion import api_completion as internal_api_completion
+from .proxy_call import api_completion as proxy_api_completion
+# from ragaai_catalyst import internal_api_completion
+# from ragaai_catalyst import proxy_call
 import ast
 
 # dotenv.load_dotenv()
@@ -21,55 +28,164 @@ class SyntheticDataGeneration:
         """
         Initialize the SyntheticDataGeneration class with API clients for Groq, Gemini, and OpenAI.
         """
-
+
+    def generate_qna(self, text, question_type="simple", n=5, model_config=dict(), api_key=None, **kwargs):
         """
         Generate questions based on the given text using the specified model and provider.
+        Uses batch processing for larger values of n to maintain response quality.
 
         Args:
             text (str): The input text to generate questions from.
             question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
-            model (str): The specific model to use for generation.
-            provider (str): The AI provider to use ('groq', 'gemini', or 'openai').
             n (int): The number of question/answer pairs to generate.
+            model_config (dict): Configuration for the model including provider and model name.
+            api_key (str, optional): The API key for the selected provider.
+            **kwargs: Additional keyword arguments.
 
         Returns:
-            pandas.DataFrame: A DataFrame containing
+            pandas.DataFrame: A DataFrame containing exactly n generated questions and answers.
 
         Raises:
-            ValueError: If an invalid provider is specified.
+            ValueError: If an invalid provider is specified or API key is missing.
         """
+        BATCH_SIZE = 5  # Optimal batch size for maintaining response quality
         provider = model_config.get("provider")
         model = model_config.get("model")
         api_base = model_config.get("api_base")
 
-
+        # Initialize the appropriate client based on provider
+        self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
+
+        # Initialize progress bar
+        pbar = tqdm(total=n, desc="Generating QA pairs")
+
+        # Initial generation phase
+        num_batches = (n + BATCH_SIZE - 1) // BATCH_SIZE
+        all_responses = []
+
+        FAILURE_CASES = [
+            "Invalid API key provided",
+            "No connection adapters",
+            "Required API Keys are not set",
+            "litellm.BadRequestError",
+            "litellm.AuthenticationError"]
+
+        for _ in range(num_batches):
+            current_batch_size = min(BATCH_SIZE, n - len(all_responses))
+            if current_batch_size <= 0:
+                break
+
+            try:
+                system_message = self._get_system_message(question_type, current_batch_size)
+
+                if "internal_llm_proxy" in kwargs:
+                    batch_df = self._generate_internal_response(text, system_message, model_config, kwargs)
+                else:
+                    batch_df = self._generate_batch_response(text, system_message, provider, model_config, api_key, api_base)
+
+                if not batch_df.empty and len(batch_df) > 0:
+                    all_responses.extend(batch_df.to_dict('records'))
+                    pbar.update(len(batch_df))
+
+            except Exception as e:
+                print(f"Batch generation failed.")
+
+                if any(error in str(e) for error in FAILURE_CASES):
+                    raise Exception(f"{e}")
+
+                else:
+                    print(f"Retrying...")
+                    continue
+
+
+        # Convert to DataFrame and remove duplicates
+        result_df = pd.DataFrame(all_responses)
+        result_df = result_df.drop_duplicates(subset=['Question'])
+
+        # Replenish phase - generate additional questions if needed due to duplicates
+        while (len(result_df) < n) and ((len(result_df) >= 1)):
+            questions_needed = n - len(result_df)
+            try:
+                system_message = self._get_system_message(question_type, questions_needed)
+
+                if "internal_llm_proxy" in kwargs:
+                    additional_df = self._generate_internal_response(text, system_message, model_config, kwargs)
+                else:
+                    additional_df = self._generate_batch_response(text, system_message, provider, model_config, api_key, api_base)
+
+                if not additional_df.empty and len(additional_df) > 0:
+                    # Only add questions that aren't already in result_df
+                    new_questions = additional_df[~additional_df['Question'].isin(result_df['Question'])]
+                    if not new_questions.empty:
+                        result_df = pd.concat([result_df, new_questions], ignore_index=True)
+                        result_df = result_df.drop_duplicates(subset=['Question'])
+                        pbar.update(len(new_questions))
+
+            except Exception as e:
+                print(f"Replenishment generation failed")
+
+                if any(error in str(e) for error in FAILURE_CASES):
+                    raise Exception(f"{e}")
+
+                else:
+                    print("An unexpected error occurred. Retrying...")
+                    continue
+
+        pbar.close()
+
+        # Ensure exactly n rows and reset index starting from 1
+        final_df = result_df.head(n)
+        final_df.index = range(1, len(final_df) + 1)
+
+        return final_df
+
+    def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
+        """Initialize the appropriate client based on provider."""
         if provider == "groq":
             if api_key is None and os.getenv("GROQ_API_KEY") is None:
                 raise ValueError("API key must be provided for Groq.")
             self.groq_client = Groq(api_key=api_key or os.getenv("GROQ_API_KEY"))
-
+
         elif provider == "gemini":
+            if api_key is None and os.getenv("GEMINI_API_KEY") is None and api_base is None and internal_llm_proxy is None:
+                raise ValueError("API key must be provided for Gemini.")
             genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
-
-            if api_key is None and os.getenv("GEMINI_API_KEY") is None:
-                raise ValueError("API key must be provided for Gemini.")
-            genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
-            return self._generate_gemini(text, system_message, model)
-        else:
-            messages=[
-                {'role': 'user', 'content': system_message+text}
-            ]
-            a= proxy_call.api_completion(messages=messages ,model=model ,api_base=api_base)
-            b= ast.literal_eval(a[0])
-            return pd.DataFrame(b)
+
         elif provider == "openai":
-            if api_key is None and os.getenv("OPENAI_API_KEY") is None:
+            if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
                 raise ValueError("API key must be provided for OpenAI.")
             openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-            return self._generate_openai(text, system_message, model,api_key=api_key)
-        else:
-            raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
 
+    def _generate_batch_response(self, text, system_message, provider, model_config, api_key, api_base):
+        """Generate a batch of responses using the specified provider."""
+        MAX_RETRIES = 3
+
+        for attempt in range(MAX_RETRIES):
+            try:
+                if provider == "gemini" and api_base:
+                    messages = [{'role': 'user', 'content': system_message + text}]
+                    response = proxy_api_completion(messages=messages, model=model_config["model"], api_base=api_base)
+                    # response = proxy_call.api_completion(messages=messages, model=model_config["model"], api_base=api_base)
+                    return pd.DataFrame(ast.literal_eval(response[0]))
+                else:
+                    return self._generate_llm_response(text, system_message, model_config, api_key)
+            except (json.JSONDecodeError, ValueError) as e:
+                if attempt == MAX_RETRIES - 1:
+                    raise Exception(f"Failed to generate valid response after {MAX_RETRIES} attempts: {str(e)}")
+                continue
+
+    def _generate_internal_response(self, text, system_message, model_config, kwargs):
+        """Generate response using internal API."""
+        messages = [{'role': 'user', 'content': system_message + text}]
+        return internal_api_completion(
+            messages=messages,
+            model_config=model_config,
+            kwargs=kwargs
+        )
+
+
+
+
     def _get_system_message(self, question_type, n):
         """
         Get the appropriate system message for the specified question type.
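
To make the reworked calling convention above concrete, here is a minimal usage sketch. The `model_config` keys and the `internal_llm_proxy` keyword come from the hunk itself; the import path is an assumption about the package's public exports, and the model, key, and URL values are placeholders.

```python
# Hypothetical usage of the reworked generate_qna() shown in the hunk above.
from ragaai_catalyst import SyntheticDataGeneration  # assumed export path

sdg = SyntheticDataGeneration()
text = "RagaAI Catalyst is a platform for evaluating and tracing LLM applications."

# Provider-backed generation: provider/model/api_key as in the new signature.
df = sdg.generate_qna(
    text,
    question_type="simple",          # 'simple', 'mcq', or 'complex'
    n=12,                            # generated in batches of 5, deduplicated, then topped up
    model_config={"provider": "openai", "model": "gpt-4", "max_tokens": 1024},
    api_key="YOUR_OPENAI_KEY",       # placeholder
)
print(df)

# Routing through an internal proxy instead (handled by _generate_internal_response):
df_proxy = sdg.generate_qna(
    text,
    n=5,
    model_config={"provider": "openai", "model": "workday_gateway"},
    internal_llm_proxy="http://localhost:4000/chat/completions",  # placeholder URL
    user_id="1",
)
```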
@@ -113,64 +229,68 @@ class SyntheticDataGeneration:
         else:
             raise ValueError("Invalid question type")
 
-    def
+    def _generate_llm_response(self, text, system_message, model_config, api_key=None):
         """
-        Generate questions using
+        Generate questions using LiteLLM which supports multiple providers (OpenAI, Groq, Gemini, etc.).
 
         Args:
             text (str): The input text to generate questions from.
             system_message (str): The system message for the AI model.
-
+            model_config (dict): Configuration dictionary containing model details.
+                Required keys:
+                - model: The model identifier (e.g., "gpt-4", "gemini-pro", "mixtral-8x7b-32768")
+                Optional keys:
+                - api_base: Custom API base URL if needed
+                - max_tokens: Maximum tokens in response
+                - temperature: Temperature for response generation
             api_key (str, optional): The API key for the model provider.
 
         Returns:
             pandas.DataFrame: A DataFrame containing the generated questions and answers.
-        """
-        response = self.groq_client.chat.completions.create(
-            model=model,
-            messages=[
-                {'role': 'system', 'content': system_message},
-                {'role': 'user', 'content': text}
-            ]
-        )
-        return self._parse_response(response, provider="groq")
 
-
+        Raises:
+            Exception: If there's an error in generating the response.
         """
-        Generate questions using the Gemini API.
 
-
-
-
-
-
-
-
-
-
-
-
+        # Prepare the messages in the format expected by LiteLLM
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": text}
+        ]
+
+        # Set up the completion parameters
+        completion_params = {
+            "model": model_config["model"],
+            "messages": messages,
+            "api_key": api_key
+        }
+
+        # Add optional parameters if they exist in model_config
+        if "api_base" in model_config:
+            completion_params["api_base"] = model_config["api_base"]
+        if "max_tokens" in model_config:
+            completion_params["max_tokens"] = model_config["max_tokens"]
+        if "temperature" in model_config:
+            completion_params["temperature"] = model_config["temperature"]
+
+        # Make the API call using LiteLLM
+        try:
+            response = completion(**completion_params)
+        except Exception as e:
+            if any(error in str(e).lower() for error in ["invalid api key", "incorrect api key", "unauthorized", "authentication"]):
+                raise ValueError(f"Invalid API key provided for {model_config.get('provider', 'the specified')} provider")
+            raise Exception(f"Error calling LLM API: {str(e)}")
 
-
-
-        Generate questions using the OpenAI API.
+        # Extract the content from the response
+        content = response.choices[0].message.content
 
-
-
-
-
+        # Clean the response if needed (remove any prefix before the JSON list)
+        list_start_index = content.find('[')
+        if list_start_index != -1:
+            content = content[list_start_index:]
 
-
-
-        """
-        client = openai.OpenAI(api_key=api_key)
-        response = client.chat.completions.create(
-            model=model,
-            messages=[
-                {"role": "system", "content": system_message},
-                {"role": "user", "content": text}
-            ]
-        )
-        return self._parse_response(response, provider="openai")
+        json_data = json.loads(content)
+        return pd.DataFrame(json_data)
 
     def _parse_response(self, response, provider):
         """
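
The cleanup step above tolerates models that prefix the JSON list with prose. A tiny standalone illustration of that parsing path (the sample content string is invented):

```python
# Standalone illustration of the content-cleaning logic in _generate_llm_response().
import json
import pandas as pd

content = 'Sure, here are your questions:\n[{"Question": "What is RAG?", "Answer": "Retrieval-augmented generation."}]'

# Drop anything before the first '[' so json.loads sees only the JSON list.
list_start_index = content.find('[')
if list_start_index != -1:
    content = content[list_start_index:]

df = pd.DataFrame(json.loads(content))
print(df)
```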
@@ -218,16 +338,19 @@ class SyntheticDataGeneration:
         if os.path.isfile(input_data):
             # If input_data is a file path
             _, file_extension = os.path.splitext(input_data)
-
-
-
-
-
-
-
-
-
-
+            try:
+                if file_extension.lower() == '.pdf':
+                    return self._read_pdf(input_data)
+                elif file_extension.lower() == '.txt':
+                    return self._read_text(input_data)
+                elif file_extension.lower() == '.md':
+                    return self._read_markdown(input_data)
+                elif file_extension.lower() == '.csv':
+                    return self._read_csv(input_data)
+                else:
+                    raise ValueError(f"Unsupported file type: {file_extension}")
+            except Exception as e:
+                raise ValueError(f"Error reading the file. Upload a valid file. \n{e}")
         else:
             # If input_data is a string of text
             return input_data