ragaai-catalyst 2.0.5__py3-none-any.whl → 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +3 -1
- ragaai_catalyst/dataset.py +50 -61
- ragaai_catalyst/evaluation.py +48 -30
- ragaai_catalyst/guard_executor.py +97 -0
- ragaai_catalyst/guardrails_manager.py +259 -0
- ragaai_catalyst/internal_api_completion.py +83 -0
- ragaai_catalyst/prompt_manager.py +1 -1
- ragaai_catalyst/proxy_call.py +1 -1
- ragaai_catalyst/ragaai_catalyst.py +1 -1
- ragaai_catalyst/synthetic_data_generation.py +206 -77
- ragaai_catalyst/tracers/llamaindex_callback.py +361 -0
- ragaai_catalyst/tracers/tracer.py +62 -28
- ragaai_catalyst-2.0.6.dist-info/METADATA +386 -0
- ragaai_catalyst-2.0.6.dist-info/RECORD +29 -0
- {ragaai_catalyst-2.0.5.dist-info → ragaai_catalyst-2.0.6.dist-info}/WHEEL +1 -1
- ragaai_catalyst-2.0.5.dist-info/METADATA +0 -228
- ragaai_catalyst-2.0.5.dist-info/RECORD +0 -25
- {ragaai_catalyst-2.0.5.dist-info → ragaai_catalyst-2.0.6.dist-info}/top_level.txt +0 -0
ragaai_catalyst/internal_api_completion.py ADDED
@@ -0,0 +1,83 @@
+import requests
+import json
+import subprocess
+import logging
+import traceback
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+def api_completion(messages, model_config, kwargs):
+    attempts = 0
+    while attempts < 3:
+
+        user_id = kwargs.get('user_id', '1')
+        internal_llm_proxy = kwargs.get('internal_llm_proxy', -1)
+
+
+        job_id = model_config.get('job_id',-1)
+        converted_message = convert_input(messages,model_config, user_id)
+        payload = json.dumps(converted_message)
+        headers = {
+            'Content-Type': 'application/json',
+            # 'Wd-PCA-Feature-Key':f'your_feature_key, $(whoami)'
+        }
+        try:
+            response = requests.request("POST", internal_llm_proxy, headers=headers, data=payload)
+            if model_config.get('log_level','')=='debug':
+                logger.info(f'Model response Job ID {job_id} {response.text}')
+            if response.status_code!=200:
+                # logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
+                raise ValueError(str(response.text))
+
+            if response.status_code==200:
+                response = response.json()
+                if "error" in response:
+                    raise ValueError(response["error"]["message"])
+                else:
+                    result = response["choices"][0]["message"]["content"]
+                    response1 = result.replace('\n', '')
+                    try:
+                        json_data = json.loads(response1)
+                        df = pd.DataFrame(json_data)
+                        return(df)
+                    except json.JSONDecodeError:
+                        attempts += 1  # Increment attempts if JSON parsing fails
+                        if attempts == 3:
+                            raise Exception("Failed to generate a valid response after multiple attempts.")
+
+        except Exception as e:
+            raise ValueError(f"{e}")
+
+
+def get_username():
+    result = subprocess.run(['whoami'], capture_output=True, text=True)
+    result = result.stdout
+    return result
+
+
+def convert_input(messages, model_config, user_id):
+    doc_input = {
+        "model": model_config.get('model'),
+        **model_config,
+        "messages": messages,
+        "user_id": user_id
+    }
+    return doc_input
+
+
+if __name__=='__main__':
+    messages = [
+        {
+            "role": "system",
+            "content": "you are a poet well versed in shakespeare literature"
+        },
+        {
+            "role": "user",
+            "content": "write a poem on pirates and penguins"
+        }
+    ]
+    kwargs = {"internal_llm_proxy": "http://13.200.11.66:4000/chat/completions", "user_id": 1}
+    model_config = {"model": "workday_gateway", "provider":"openai", "max_tokens": 10}
+    answer = api_completion(messages, model_config, kwargs)
+    print(answer)
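Note on the new helper: api_completion() only returns a DataFrame when the proxy replies with an OpenAI-style chat-completions payload whose assistant message content is itself a JSON array of records. A minimal sketch of that contract (the payload below is illustrative, not taken from the package):

import json
import pandas as pd

# Hypothetical proxy payload, shaped like an OpenAI chat-completions response.
proxy_response = {
    "choices": [{
        "message": {
            "content": '[{"Question": "Who wrote Hamlet?", "Answer": "William Shakespeare"}]'
        }
    }]
}

# Mirrors what api_completion() does with a 200 response:
content = proxy_response["choices"][0]["message"]["content"].replace("\n", "")
df = pd.DataFrame(json.loads(content))
print(df)  # one row per generated record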
ragaai_catalyst/prompt_manager.py CHANGED
@@ -23,7 +23,7 @@ class PromptManager:
         self.project_name = project_name
         self.base_url = f"{RagaAICatalyst.BASE_URL}/playground/prompt"
         self.timeout = 10
-        self.size =
+        self.size = 99999 #Number of projects to fetch
 
         try:
             response = requests.get(
ragaai_catalyst/proxy_call.py CHANGED
@@ -23,7 +23,7 @@ def api_completion(model,messages, api_base='http://127.0.0.1:8000',
         if model_config.get('log_level','')=='debug':
             logger.info(f'Model response Job ID {job_id} {response.text}')
         if response.status_code!=200:
-            logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
+            # logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
             raise ValueError(str(response.text))
     except Exception as e:
         logger.error(f'Error in calling api Job ID {job_id}:',str(e))
ragaai_catalyst/ragaai_catalyst.py CHANGED
@@ -287,7 +287,7 @@ class RagaAICatalyst:
     def get_project_id(self, project_name):
         pass
 
-    def list_projects(self, num_projects=
+    def list_projects(self, num_projects=99999):
         """
         Retrieves a list of projects with the specified number of projects.
 
ragaai_catalyst/synthetic_data_generation.py CHANGED
@@ -7,7 +7,14 @@ import csv
 import markdown
 import pandas as pd
 import json
-from
+from litellm import completion
+from tqdm import tqdm
+# import internal_api_completion
+# import proxy_call
+from .internal_api_completion import api_completion as internal_api_completion
+from .proxy_call import api_completion as proxy_api_completion
+# from ragaai_catalyst import internal_api_completion
+# from ragaai_catalyst import proxy_call
 import ast
 
 # dotenv.load_dotenv()
@@ -21,55 +28,170 @@ class SyntheticDataGeneration:
         """
         Initialize the SyntheticDataGeneration class with API clients for Groq, Gemini, and OpenAI.
         """
-
+
+    def generate_qna(self, text, question_type="simple", n=5, model_config=dict(), api_key=None, **kwargs):
         """
         Generate questions based on the given text using the specified model and provider.
+        Uses batch processing for larger values of n to maintain response quality.
 
         Args:
             text (str): The input text to generate questions from.
             question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
-            model (str): The specific model to use for generation.
-            provider (str): The AI provider to use ('groq', 'gemini', or 'openai').
             n (int): The number of question/answer pairs to generate.
+            model_config (dict): Configuration for the model including provider and model name.
+            api_key (str, optional): The API key for the selected provider.
+            **kwargs: Additional keyword arguments.
 
         Returns:
-            pandas.DataFrame: A DataFrame containing
+            pandas.DataFrame: A DataFrame containing exactly n generated questions and answers.
 
         Raises:
-            ValueError: If an invalid provider is specified.
+            ValueError: If an invalid provider is specified or API key is missing.
         """
+        BATCH_SIZE = 5  # Optimal batch size for maintaining response quality
         provider = model_config.get("provider")
         model = model_config.get("model")
         api_base = model_config.get("api_base")
 
-
+        # Initialize the appropriate client based on provider
+        self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
+
+        # Initialize progress bar
+        pbar = tqdm(total=n, desc="Generating QA pairs")
+
+        # Initial generation phase
+        num_batches = (n + BATCH_SIZE - 1) // BATCH_SIZE
+        all_responses = []
+
+        FAILURE_CASES = [
+            "Invalid API key provided",
+            "No connection adapters",
+            "Required API Keys are not set",
+            "litellm.BadRequestError",
+            "litellm.AuthenticationError"]
+
+        for _ in range(num_batches):
+            current_batch_size = min(BATCH_SIZE, n - len(all_responses))
+            if current_batch_size <= 0:
+                break
+
+            try:
+                system_message = self._get_system_message(question_type, current_batch_size)
+
+                if "internal_llm_proxy" in kwargs:
+                    batch_df = self._generate_internal_response(text, system_message, model_config, kwargs)
+                else:
+                    batch_df = self._generate_batch_response(text, system_message, provider, model_config, api_key, api_base)
+
+                if not batch_df.empty and len(batch_df) > 0:
+                    all_responses.extend(batch_df.to_dict('records'))
+                    pbar.update(len(batch_df))
+
+            except Exception as e:
+                print(f"Batch generation failed.")
+
+                if any(error in str(e) for error in FAILURE_CASES):
+                    raise Exception(f"{e}")
+
+                else:
+                    print(f"Retrying...")
+                    continue
+
+
+        # Convert to DataFrame and remove duplicates
+        result_df = pd.DataFrame(all_responses)
+        result_df = result_df.drop_duplicates(subset=['Question'])
+
+        # Replenish phase - generate additional questions if needed due to duplicates
+        while (len(result_df) < n) and ((len(result_df) >= 1)):
+            questions_needed = n - len(result_df)
+            try:
+                system_message = self._get_system_message(question_type, questions_needed)
+
+                if "internal_llm_proxy" in kwargs:
+                    additional_df = self._generate_internal_response(text, system_message, model_config, kwargs)
+                else:
+                    additional_df = self._generate_batch_response(text, system_message, provider, model_config, api_key, api_base)
+
+                if not additional_df.empty and len(additional_df) > 0:
+                    # Only add questions that aren't already in result_df
+                    new_questions = additional_df[~additional_df['Question'].isin(result_df['Question'])]
+                    if not new_questions.empty:
+                        result_df = pd.concat([result_df, new_questions], ignore_index=True)
+                        result_df = result_df.drop_duplicates(subset=['Question'])
+                        pbar.update(len(new_questions))
+
+            except Exception as e:
+                print(f"Replenishment generation failed")
+
+                if any(error in str(e) for error in FAILURE_CASES):
+                    raise Exception(f"{e}")
+
+                else:
+                    print("An unexpected error occurred. Retrying...")
+                    continue
+
+        pbar.close()
+
+        # Ensure exactly n rows and reset index starting from 1
+        final_df = result_df.head(n)
+        final_df.index = range(1, len(final_df) + 1)
+
+        return final_df
+
+    def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
+        """Initialize the appropriate client based on provider."""
+        if not provider:
+            raise ValueError("Model configuration must be provided with a valid provider and model.")
+
         if provider == "groq":
             if api_key is None and os.getenv("GROQ_API_KEY") is None:
                 raise ValueError("API key must be provided for Groq.")
             self.groq_client = Groq(api_key=api_key or os.getenv("GROQ_API_KEY"))
-
+
         elif provider == "gemini":
+            if api_key is None and os.getenv("GEMINI_API_KEY") is None and api_base is None and internal_llm_proxy is None:
+                raise ValueError("API key must be provided for Gemini.")
             genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
-
-            if api_key is None and os.getenv("GEMINI_API_KEY") is None:
-                raise ValueError("API key must be provided for Gemini.")
-            genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
-            return self._generate_gemini(text, system_message, model)
-        else:
-            messages=[
-                {'role': 'user', 'content': system_message+text}
-            ]
-            a= proxy_call.api_completion(messages=messages ,model=model ,api_base=api_base)
-            b= ast.literal_eval(a[0])
-            return pd.DataFrame(b)
+
         elif provider == "openai":
-            if api_key is None and os.getenv("OPENAI_API_KEY") is None:
+            if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
                 raise ValueError("API key must be provided for OpenAI.")
             openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-
+
         else:
-            raise ValueError("
+            raise ValueError(f"Provider is not recognized.")
+
+    def _generate_batch_response(self, text, system_message, provider, model_config, api_key, api_base):
+        """Generate a batch of responses using the specified provider."""
+        MAX_RETRIES = 3
+
+        for attempt in range(MAX_RETRIES):
+            try:
+                if provider == "gemini" and api_base:
+                    messages = [{'role': 'user', 'content': system_message + text}]
+                    response = proxy_api_completion(messages=messages, model=model_config["model"], api_base=api_base)
+                    # response = proxy_call.api_completion(messages=messages, model=model_config["model"], api_base=api_base)
+                    return pd.DataFrame(ast.literal_eval(response[0]))
+                else:
+                    return self._generate_llm_response(text, system_message, model_config, api_key)
+            except (json.JSONDecodeError, ValueError) as e:
+                if attempt == MAX_RETRIES - 1:
+                    raise Exception(f"Failed to generate valid response after {MAX_RETRIES} attempts: {str(e)}")
+                continue
+
+    def _generate_internal_response(self, text, system_message, model_config, kwargs):
+        """Generate response using internal API."""
+        messages = [{'role': 'user', 'content': system_message + text}]
+        return internal_api_completion(
+            messages=messages,
+            model_config=model_config,
+            kwargs=kwargs
+        )
 
+
+
+
     def _get_system_message(self, question_type, n):
         """
         Get the appropriate system message for the specified question type.
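Taken together, the reworked generate_qna() batches requests in groups of BATCH_SIZE=5, deduplicates on the Question column, and tops the result back up to n in the replenish loop. A minimal usage sketch (model name, key, and document text are placeholders, not taken from this diff; it assumes the class is imported from ragaai_catalyst as in the package README):

from ragaai_catalyst import SyntheticDataGeneration

sdg = SyntheticDataGeneration()
text = "Your source document text goes here..."

# n=12 is split into ceil(12/5) = 3 batches of at most BATCH_SIZE=5 pairs;
# duplicates are dropped and the replenish loop tops the result back up to n.
qa_df = sdg.generate_qna(
    text,
    question_type="simple",
    n=12,
    model_config={"provider": "openai", "model": "gpt-4o-mini"},  # placeholder model
    api_key="sk-...",  # or set OPENAI_API_KEY in the environment
)
print(len(qa_df))  # up to 12 rows (exactly 12 when enough unique pairs are produced), indexed from 1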
@@ -113,64 +235,68 @@ class SyntheticDataGeneration:
         else:
             raise ValueError("Invalid question type")
 
-    def
+    def _generate_llm_response(self, text, system_message, model_config, api_key=None):
         """
-        Generate questions using
+        Generate questions using LiteLLM which supports multiple providers (OpenAI, Groq, Gemini, etc.).
 
         Args:
             text (str): The input text to generate questions from.
             system_message (str): The system message for the AI model.
-
+            model_config (dict): Configuration dictionary containing model details.
+                Required keys:
+                - model: The model identifier (e.g., "gpt-4", "gemini-pro", "mixtral-8x7b-32768")
+                Optional keys:
+                - api_base: Custom API base URL if needed
+                - max_tokens: Maximum tokens in response
+                - temperature: Temperature for response generation
+            api_key (str, optional): The API key for the model provider.
 
         Returns:
             pandas.DataFrame: A DataFrame containing the generated questions and answers.
-        """
-        response = self.groq_client.chat.completions.create(
-            model=model,
-            messages=[
-                {'role': 'system', 'content': system_message},
-                {'role': 'user', 'content': text}
-            ]
-        )
-        return self._parse_response(response, provider="groq")
 
-
+        Raises:
+            Exception: If there's an error in generating the response.
         """
-        Generate questions using the Gemini API.
 
-
-
-
-
-
-
-
-
-
-
-
+        # Prepare the messages in the format expected by LiteLLM
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": text}
+        ]
+
+        # Set up the completion parameters
+        completion_params = {
+            "model": model_config["model"],
+            "messages": messages,
+            "api_key": api_key
+        }
+
+        # Add optional parameters if they exist in model_config
+        if "api_base" in model_config:
+            completion_params["api_base"] = model_config["api_base"]
+        if "max_tokens" in model_config:
+            completion_params["max_tokens"] = model_config["max_tokens"]
+        if "temperature" in model_config:
+            completion_params["temperature"] = model_config["temperature"]
+
+        # Make the API call using LiteLLM
+        try:
+            response = completion(**completion_params)
+        except Exception as e:
+            if any(error in str(e).lower() for error in ["invalid api key", "incorrect api key", "unauthorized", "authentication"]):
+                raise ValueError(f"Invalid API key provided for {model_config.get('provider', 'the specified')} provider")
+            raise Exception(f"Error calling LLM API: {str(e)}")
 
-
-
-        Generate questions using the OpenAI API.
+        # Extract the content from the response
+        content = response.choices[0].message.content
 
-
-
-
-
+        # Clean the response if needed (remove any prefix before the JSON list)
+        list_start_index = content.find('[')
+        if list_start_index != -1:
+            content = content[list_start_index:]
 
-
-
-        """
-        client = openai.OpenAI(api_key=api_key)
-        response = client.chat.completions.create(
-            model=model,
-            messages=[
-                {"role": "system", "content": system_message},
-                {"role": "user", "content": text}
-            ]
-        )
-        return self._parse_response(response, provider="openai")
+        json_data = json.loads(content)
+        return pd.DataFrame(json_data)
 
     def _parse_response(self, response, provider):
         """
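The parsing at the end of _generate_llm_response() assumes the model returns a JSON array of Question/Answer records, possibly preceded by prose. A small illustration of that cleanup step (the content string is invented):

import json
import pandas as pd

# Made-up model output: prose prefix followed by a JSON array.
content = 'Sure, here are the pairs: [{"Question": "Q1", "Answer": "A1"}, {"Question": "Q2", "Answer": "A2"}]'

# Same cleanup as _generate_llm_response: keep everything from the first '['.
list_start_index = content.find('[')
if list_start_index != -1:
    content = content[list_start_index:]

df = pd.DataFrame(json.loads(content))  # columns: Question, Answer
print(df)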
@@ -218,16 +344,19 @@ class SyntheticDataGeneration:
         if os.path.isfile(input_data):
             # If input_data is a file path
             _, file_extension = os.path.splitext(input_data)
-
-
-
-
-
-
-
-
-
-
+            try:
+                if file_extension.lower() == '.pdf':
+                    return self._read_pdf(input_data)
+                elif file_extension.lower() == '.txt':
+                    return self._read_text(input_data)
+                elif file_extension.lower() == '.md':
+                    return self._read_markdown(input_data)
+                elif file_extension.lower() == '.csv':
+                    return self._read_csv(input_data)
+                else:
+                    raise ValueError(f"Unsupported file type: {file_extension}")
+            except Exception as e:
+                raise ValueError(f"Error reading the file. Upload a valid file. \n{e}")
         else:
             # If input_data is a string of text
             return input_data
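A short usage sketch of the extension dispatch above; the public entry point process_document is assumed from the package's documented API rather than shown in this hunk, and the file path is a placeholder:

from ragaai_catalyst import SyntheticDataGeneration

sdg = SyntheticDataGeneration()

# Assumed entry point (not visible in this hunk): routes .pdf/.txt/.md/.csv
# files through the matching _read_* helper and returns the extracted text.
text = sdg.process_document(input_data="./docs/sample.pdf")

# Plain strings skip the file branch and are returned unchanged.
same_text = sdg.process_document(input_data="Already-extracted text.")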