ragaai-catalyst 2.0.3__py3-none-any.whl → 2.0.5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the supported public registries. It is provided for informational purposes only.
- ragaai_catalyst/__init__.py +3 -1
- ragaai_catalyst/evaluation.py +77 -25
- ragaai_catalyst/proxy_call.py +134 -0
- ragaai_catalyst/synthetic_data_generation.py +323 -0
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/METADATA +5 -1
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/RECORD +8 -6
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/WHEEL +0 -0
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/top_level.txt +0 -0
ragaai_catalyst/__init__.py
CHANGED

```diff
@@ -5,5 +5,7 @@ from .utils import response_checker
 from .dataset import Dataset
 from .prompt_manager import PromptManager
 from .evaluation import Evaluation
+from .synthetic_data_generation import SyntheticDataGeneration
 
-__all__ = ["Experiment", "RagaAICatalyst", "Tracer", "PromptManager", "Evaluation"]
+
+__all__ = ["Experiment", "RagaAICatalyst", "Tracer", "PromptManager", "Evaluation","SyntheticDataGeneration"]
```
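The practical effect of this change is that the new generator class is importable from the package root. A minimal smoke test, assuming ragaai-catalyst 2.0.5 is installed (the printed lists come from the class added later in this diff):

```python
# Quick check that the new top-level export resolves (ragaai-catalyst 2.0.5).
from ragaai_catalyst import SyntheticDataGeneration

sdg = SyntheticDataGeneration()
print(sdg.get_supported_qna())        # ['simple', 'mcq', 'complex']
print(sdg.get_supported_providers())  # ['gemini', 'openai']
```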
ragaai_catalyst/evaluation.py
CHANGED

```diff
@@ -80,7 +80,8 @@ class Evaluation:
         try:
             response = requests.get(
                 f'{self.base_url}/v1/llm/llm-metrics',
-                headers=headers
+                headers=headers,
+                timeout=self.timeout)
             response.raise_for_status()
             metric_names = [metric["name"] for metric in response.json()["data"]["metrics"]]
             return metric_names
@@ -111,7 +112,8 @@ class Evaluation:
             response = requests.post(
                 f'{self.base_url}/v1/llm/docs',
                 headers=headers,
-                json=data
+                json=data,
+                timeout=self.timeout)
             response.raise_for_status()
             if response.status_code == 200:
                 return response.json()["data"]["columns"]
@@ -128,27 +130,52 @@ class Evaluation:
             return {}
 
     def _get_variablename_from_dataset_schema(self, schemaName, metric_name):
+        # pdb.set_trace()
+        # print(schemaName)
         dataset_schema = self._get_dataset_schema()
         variableName = None
         for column in dataset_schema:
-            columnName = column["
+            columnName = column["columnType"]
             displayName = column["displayName"]
-
+            # print(columnName, displayName)
+            if "".join(columnName.split("_")).lower() == schemaName.lower():
                 variableName = displayName
+                break
+        return variableName
+        # print(variableName)
+        # if variableName:
+        #     return variableName
+        # else:
+        #     raise ValueError(f"'{schemaName}' column is required for {metric_name} metric evaluation, but not found in dataset")
+
+
+    def _get_variablename_from_user_schema_mapping(self, schemaName, metric_name, schema_mapping):
+        # pdb.set_trace()
+        user_dataset_schema = self._get_dataset_schema()
+        user_dataset_columns = [item["displayName"] for item in user_dataset_schema]
+        variableName = None
+        for key, val in schema_mapping.items():
+            if "".join(val.split("_")).lower()==schemaName:
+                if key in user_dataset_columns:
+                    variableName=key
+                else:
+                    raise ValueError(f"Column '{key}' is not present in {self.dataset_name}")
         if variableName:
             return variableName
         else:
-            raise ValueError(f"'{schemaName
+            raise ValueError(f"Map '{schemaName}' column in schema_mapping for {metric_name} metric evaluation")
 
 
-    def _get_mapping(self, metric_name, metrics_schema):
+    def _get_mapping(self, metric_name, metrics_schema, schema_mapping):
+
         mapping = []
         for schema in metrics_schema:
             if schema["name"]==metric_name:
                 requiredFields = schema["config"]["requiredFields"]
                 for field in requiredFields:
                     schemaName = field["name"]
-                    variableName = self._get_variablename_from_dataset_schema(schemaName, metric_name)
+                    # variableName = self._get_variablename_from_dataset_schema(schemaName, metric_name)
+                    variableName = self._get_variablename_from_user_schema_mapping(schemaName.lower(), metric_name, schema_mapping)
                 mapping.append({"schemaName": schemaName, "variableName": variableName})
         return mapping
 
@@ -160,7 +187,7 @@ class Evaluation:
             "model": "null",
             "params": {
                 "model": {
-                    "value": "
+                    "value": ""
                 }
             },
             "mappings": "mappings"
@@ -178,7 +205,8 @@ class Evaluation:
         try:
             response = requests.get(
                 f'{self.base_url}/v1/llm/llm-metrics',
-                headers=headers
+                headers=headers,
+                timeout=self.timeout)
             response.raise_for_status()
             metrics_schema = [metric for metric in response.json()["data"]["metrics"]]
             return metrics_schema
@@ -208,14 +236,21 @@ class Evaluation:
             #checking if provider is one of the allowed providers
             if key.lower()=="provider" and value.lower() not in sub_providers:
                 raise ValueError("Enter a valid provider name. The following Provider names are supported: OpenAI, Azure, Gemini, Groq")
-
-
+
+            if key.lower()=="threshold":
+                if len(value)>1:
+                    raise ValueError("'threshold' can only take one argument gte/lte/eq")
+                else:
+                    for key_thres, value_thres in value.items():
+                        base_json["metricSpec"]["config"]["params"][key] = {f"{key_thres}":value_thres}
+            else:
+                base_json["metricSpec"]["config"]["params"][key] = {"value": value}
 
 
         # if metric["config"]["model"]:
         #     base_json["metricSpec"]["config"]["params"]["model"]["value"] = metric["config"]["model"]
         base_json["metricSpec"]["displayName"] = metric["column_name"]
-        mappings = self._get_mapping(metric["name"], metrics_schema_response)
+        mappings = self._get_mapping(metric["name"], metrics_schema_response, metric["schema_mapping"])
         base_json["metricSpec"]["config"]["mappings"] = mappings
         metricParams.append(base_json)
         metric_schema_mapping["metricParams"] = metricParams
@@ -228,12 +263,15 @@ class Evaluation:
         }
         try:
             response = requests.get(
-                f
-                headers=headers
-
+                f"{self.base_url}/v2/llm/dataset/{str(self.dataset_id)}?initialCols=0",
+                headers=headers,
+                timeout=self.timeout,
+            )
             response.raise_for_status()
-
-
+            dataset_columns = response.json()["data"]["datasetColumnsResponses"]
+            dataset_columns = [item["displayName"] for item in dataset_columns]
+            executed_metric_list = [data for data in dataset_columns if not data.startswith('_')]
+
             return executed_metric_list
         except requests.exceptions.HTTPError as http_err:
             logger.error(f"HTTP error occurred: {http_err}")
@@ -248,6 +286,13 @@ class Evaluation:
             return []
 
     def add_metrics(self, metrics):
+        #Handle required key if missing
+        required_keys = {"name", "config", "column_name", "schema_mapping"}
+        for metric in metrics:
+            missing_keys = required_keys - metric.keys()
+            if missing_keys:
+                raise ValueError(f"{missing_keys} required for each metric evaluation.")
+
         executed_metric_list = self._get_executed_metrics_list()
         metrics_name = self.list_metrics()
         user_metric_names = [metric["name"] for metric in metrics]
@@ -265,12 +310,12 @@ class Evaluation:
             'X-Project-Id': str(self.project_id),
         }
         metric_schema_mapping = self._update_base_json(metrics)
-        print(metric_schema_mapping)
         try:
             response = requests.post(
                 f'{self.base_url}/playground/metric-evaluation',
                 headers=headers,
-                json=metric_schema_mapping
+                json=metric_schema_mapping,
+                timeout=self.timeout
             )
             if response.status_code == 400:
                 raise ValueError(response.json()["message"])
@@ -296,14 +341,20 @@ class Evaluation:
             "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
             'X-Project-Id': str(self.project_id),
         }
-        data = {"jobId": self.jobId}
         try:
-            response = requests.
+            response = requests.get(
                 f'{self.base_url}/job/status',
                 headers=headers,
-
+                timeout=self.timeout)
             response.raise_for_status()
-
+            if response.json()["success"]:
+                status_json = [item["status"] for item in response.json()["data"]["content"] if item["id"]==self.jobId][0]
+                if status_json == "Failed":
+                    return print("Job failed. No results to fetch.")
+                elif status_json == "In Progress":
+                    return print(f"Job in progress. Please wait while the job completes.\nVisit Job Status: {self.base_url.removesuffix('/api')}/projects/job-status?projectId={self.project_id} to track")
+                elif status_json == "Completed":
+                    print(f"Job completed. Fetching results.\nVisit Job Status: {self.base_url.removesuffix('/api')}/projects/job-status?projectId={self.project_id} to check")
         except requests.exceptions.HTTPError as http_err:
             logger.error(f"HTTP error occurred: {http_err}")
         except requests.exceptions.ConnectionError as conn_err:
@@ -336,7 +387,8 @@ class Evaluation:
             response = requests.post(
                 f'{self.base_url}/v1/llm/docs',
                 headers=headers,
-                json=data
+                json=data,
+                timeout=self.timeout)
             response.raise_for_status()
             return response.json()
         except requests.exceptions.HTTPError as http_err:
@@ -355,7 +407,7 @@ class Evaluation:
         try:
            response = get_presignedUrl()
            preSignedURL = response["data"]["preSignedURL"]
-           response = requests.get(preSignedURL)
+           response = requests.get(preSignedURL, timeout=self.timeout)
            response.raise_for_status()
            return response.text
         except requests.exceptions.HTTPError as http_err:
```
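Taken together, these hunks change the `add_metrics` contract: every metric dict must now carry `name`, `config`, `column_name`, and `schema_mapping`; a `threshold` in `config` accepts exactly one of `gte`/`lte`/`eq`; and every HTTP call gains a `timeout`. A hedged sketch of the new call shape, inferred from the validation code above; the constructor arguments, metric name, and column names are illustrative assumptions, not taken from this diff:

```python
from ragaai_catalyst import Evaluation

# Assumed constructor arguments; the __init__ signature is not shown in this diff.
evaluation = Evaluation(project_name="my_project", dataset_name="my_dataset")

evaluation.add_metrics(
    metrics=[
        {
            "name": "Hallucination",            # hypothetical; must appear in list_metrics()
            "config": {
                "provider": "openai",            # one of: OpenAI, Azure, Gemini, Groq
                "threshold": {"gte": 0.5},       # exactly one of gte / lte / eq
            },
            "column_name": "hallucination_v1",   # display name of the result column
            # Keys are dataset column names; values are the metric's schema field names,
            # matched case-insensitively by _get_variablename_from_user_schema_mapping.
            "schema_mapping": {
                "question": "prompt",
                "answer": "response",
                "context": "context",
            },
        }
    ]
)
```

Omitting any of the four required keys now raises a `ValueError` up front, instead of failing later inside the mapping lookup.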
ragaai_catalyst/proxy_call.py
ADDED

```diff
@@ -0,0 +1,134 @@
+import requests
+import json
+import subprocess
+import logging
+import traceback
+
+logger = logging.getLogger(__name__)
+
+def api_completion(model,messages, api_base='http://127.0.0.1:8000',
+                   api_key='',model_config=dict()):
+    whoami = get_username()
+    all_response = list()
+    job_id = model_config.get('job_id',-1)
+    converted_message = convert_input(messages,model,model_config)
+    payload = json.dumps(converted_message)
+    response = payload
+    headers = {
+        'Content-Type': 'application/json',
+        'Wd-PCA-Feature-Key':f'your_feature_key, $(whoami)'
+    }
+    try:
+        response = requests.request("POST", api_base, headers=headers, data=payload, verify=False)
+        if model_config.get('log_level','')=='debug':
+            logger.info(f'Model response Job ID {job_id} {response.text}')
+        if response.status_code!=200:
+            logger.error(f'Error in model response Job ID {job_id}:',str(response.text))
+            raise ValueError(str(response.text))
+    except Exception as e:
+        logger.error(f'Error in calling api Job ID {job_id}:',str(e))
+        raise ValueError(str(e))
+    try:
+        response = response.json()
+        if 'error' in response:
+            logger.error(f'Invalid response from API Job ID {job_id}:'+str(response))
+            raise ValueError(str(response.get('error')))
+        all_response.append(convert_output(response,job_id))
+    except ValueError as e1:
+        logger.error(f'Invalid json response from API Job ID {job_id}:'+response)
+        raise ValueError(str(e1))
+    except Exception as e1:
+        if model_config.get('log_level','')=='debug':
+            logger.info(f"Error trace Job ID: {job_id} {traceback.print_exc()}")
+        logger.error(f"Exception in parsing model response Job ID:{job_id} {str(e1)}")
+        logger.error(f"Model response Job ID: {job_id} {response.text}")
+        all_response.append(None)
+    return all_response
+
+def get_username():
+    result = subprocess.run(['whoami'], capture_output=True, text=True)
+    result = result.stdout
+    return result
+
+def convert_output(response,job_id):
+    try:
+        if response.get('prediction',{}).get('type','')=='generic-text-generation-v1':
+            return response['prediction']['output']
+        elif response.get('prediction',{}).get('type','')=='gcp-multimodal-v1':
+            full_response = ''
+            for chunk in response['prediction']['output']['chunks']:
+                candidate = chunk['candidates'][0]
+                if candidate['finishReason'] and candidate['finishReason'] not in ['STOP']:
+                    raise ValueError(candidate['finishReason'])
+                part = candidate['content']['parts'][0]
+                full_response += part['text']
+            return full_response
+        else:
+            raise ValueError('Invalid prediction type passed in config')
+    except ValueError as e1:
+        raise ValueError(str(e1))
+    except Exception as e:
+        logger.warning(f'Exception in formatting model response Job ID {job_id}:'+str(e))
+        return None
+
+
+def convert_input(prompt,model,model_config):
+    doc_input = {
+        "target": {
+            "provider": "echo",
+            "model": "echo"
+        },
+        "task": {
+            "type": "gcp-multimodal-v1",
+            "prediction_type": "gcp-multimodal-v1",
+            "input": {
+                "contents": [
+                    {
+                        "role": "user",
+                        "parts": [
+                            {
+                                "text": "Give me a recipe for banana bread."
+                            }
+                        ]
+                    }
+                ],
+                "safetySettings":
+                [
+                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
+                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
+                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
+                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
+                ],
+                "generationConfig": {
+                    "temperature": 0,
+                    "maxOutputTokens": 8000,
+                    "topK": 40,
+                    "topP": 0.95,
+                    "stopSequences": [],
+                    "candidateCount": 1
+                }
+            }
+        }
+    }
+    if 'provider' not in model_config:
+        doc_input['target']['provider'] = 'gcp'
+    else:
+        doc_input['target']['provider'] = model_config['provider']
+    doc_input['task']['type'] = model_config.get('task_type','gcp-multimodal-v1')
+    doc_input['task']['prediction_type'] = model_config.get('prediction_type','generic-text-generation-v1')
+    if 'safetySettings' in model_config:
+        doc_input['task']['input']['safetySettings'] = model_config.get('safetySettings')
+    if 'generationConfig' in model_config:
+        doc_input['task']['input']['generationConfig'] = model_config.get('generationConfig')
+    doc_input['target']['model'] = model
+    if model_config.get('log_level','')=='debug':
+        logger.info(f"Using model configs Job ID {model_config.get('job_id',-1)}{doc_input}")
+    doc_input['task']['input']['contents'][0]['parts'] = [{"text":prompt[0]['content']}]
+    return doc_input
+
+
+
+if __name__=='__main__':
+    message_list = ["Hi How are you","I am good","How are you"]
+    response = batch_completion('gemini/gemini-1.5-flash',message_list,0,1,100,api_base='http://127.0.0.1:5000')
+    print(response)
```
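A few things worth noting before relying on this module as shipped: the `__main__` block calls `batch_completion`, which is not defined anywhere in the file, so running it directly raises `NameError`; the `Wd-PCA-Feature-Key` header is an f-string that never actually interpolates the `whoami` value gathered above it; and `verify=False` disables TLS verification on the proxy request. The importable entry point is `api_completion`. A hedged usage sketch against a local proxy (the endpoint and model name are placeholders, not from the diff):

```python
from ragaai_catalyst import proxy_call

# convert_input() reads prompt[0]['content'], so messages is a one-element chat list.
messages = [{"role": "user", "content": "Give me a recipe for banana bread."}]

model_config = {
    "provider": "gcp",                                  # default when omitted
    "task_type": "gcp-multimodal-v1",
    "prediction_type": "generic-text-generation-v1",
    "job_id": 42,                                       # used only to tag log lines
    "log_level": "debug",                               # logs payload and raw response
}

# Placeholder endpoint; api_completion POSTs the converted payload to api_base.
responses = proxy_call.api_completion(
    model="gemini-1.5-flash",
    messages=messages,
    api_base="http://127.0.0.1:8000",
    model_config=model_config,
)
print(responses[0])   # parsed text, or None if response parsing failed
```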
ragaai_catalyst/synthetic_data_generation.py
ADDED

```diff
@@ -0,0 +1,323 @@
+import os
+from groq import Groq
+import google.generativeai as genai
+import openai
+import PyPDF2
+import csv
+import markdown
+import pandas as pd
+import json
+from ragaai_catalyst import proxy_call
+import ast
+
+# dotenv.load_dotenv()
+
+class SyntheticDataGeneration:
+    """
+    A class for generating synthetic data using various AI models and processing different document types.
+    """
+
+    def __init__(self):
+        """
+        Initialize the SyntheticDataGeneration class with API clients for Groq, Gemini, and OpenAI.
+        """
+    def generate_qna(self, text, question_type="simple", n=5,model_config=dict(),api_key=None):
+        """
+        Generate questions based on the given text using the specified model and provider.
+
+        Args:
+            text (str): The input text to generate questions from.
+            question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
+            model (str): The specific model to use for generation.
+            provider (str): The AI provider to use ('groq', 'gemini', or 'openai').
+            n (int): The number of question/answer pairs to generate.
+
+        Returns:
+            pandas.DataFrame: A DataFrame containing the generated questions and answers.
+
+        Raises:
+            ValueError: If an invalid provider is specified.
+        """
+        provider = model_config.get("provider")
+        model = model_config.get("model")
+        api_base = model_config.get("api_base")
+
+        system_message = self._get_system_message(question_type, n)
+        if provider == "groq":
+            if api_key is None and os.getenv("GROQ_API_KEY") is None:
+                raise ValueError("API key must be provided for Groq.")
+            self.groq_client = Groq(api_key=api_key or os.getenv("GROQ_API_KEY"))
+            return self._generate_groq(text, system_message, model)
+        elif provider == "gemini":
+            genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
+            if api_base is None:
+                if api_key is None and os.getenv("GEMINI_API_KEY") is None:
+                    raise ValueError("API key must be provided for Gemini.")
+                genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
+                return self._generate_gemini(text, system_message, model)
+            else:
+                messages=[
+                    {'role': 'user', 'content': system_message+text}
+                ]
+                a= proxy_call.api_completion(messages=messages ,model=model ,api_base=api_base)
+                b= ast.literal_eval(a[0])
+                return pd.DataFrame(b)
+        elif provider == "openai":
+            if api_key is None and os.getenv("OPENAI_API_KEY") is None:
+                raise ValueError("API key must be provided for OpenAI.")
+            openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
+            return self._generate_openai(text, system_message, model,api_key=api_key)
+        else:
+            raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
+
+    def _get_system_message(self, question_type, n):
+        """
+        Get the appropriate system message for the specified question type.
+
+        Args:
+            question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
+            n (int): The number of question/answer pairs to generate.
+
+        Returns:
+            str: The system message for the AI model.
+
+        Raises:
+            ValueError: If an invalid question type is specified.
+        """
+        if question_type == 'simple':
+            return f'''Generate a set of {n} very simple questions answerable in a single phrase.
+                Also return the answers for the generated questions.
+                Return the response in a list of object format.
+                Each object in list should have Question and corresponding answer.
+                Do not return any extra strings. Return Generated text strictly in below format.
+                [{{"Question":"question,"Answer":"answer"}}]
+            '''
+        elif question_type == 'mcq':
+            return f'''Generate a set of {n} questions with 4 probable answers from the given text.
+                The options should not be longer than a phrase. There should be only 1 correct answer.
+                There should not be any ambiguity between correct and incorrect options.
+                Return the response in a list of object format.
+                Each object in list should have Question and a list of options.
+                Do not return any extra strings. Return Generated text strictly in below format.
+                [{{"Question":"question","Options":[option1,option2,option3,option4]}}]
+            '''
+        elif question_type == 'complex':
+            return f'''Can you generate a set of {n} complex questions answerable in long form from the below texts.
+                Make sure the questions are important and provide new information to the user.
+                Return the response in a list of object format. Enclose any quotes in single quote.
+                Do not use double quotes within questions or answers.
+                Each object in list should have Question and corresponding answer.
+                Do not return any extra strings. Return generated text strictly in below format.
+                [{{"Question":"question","Answer":"answers"}}]
+            '''
+        else:
+            raise ValueError("Invalid question type")
+
+    def _generate_groq(self, text, system_message, model):
+        """
+        Generate questions using the Groq API.
+
+        Args:
+            text (str): The input text to generate questions from.
+            system_message (str): The system message for the AI model.
+            model (str): The specific Groq model to use.
+
+        Returns:
+            pandas.DataFrame: A DataFrame containing the generated questions and answers.
+        """
+        response = self.groq_client.chat.completions.create(
+            model=model,
+            messages=[
+                {'role': 'system', 'content': system_message},
+                {'role': 'user', 'content': text}
+            ]
+        )
+        return self._parse_response(response, provider="groq")
+
+    def _generate_gemini(self, text, system_message, model):
+        """
+        Generate questions using the Gemini API.
+
+        Args:
+            text (str): The input text to generate questions from.
+            system_message (str): The system message for the AI model.
+            model (str): The specific Gemini model to use.
+
+        Returns:
+            pandas.DataFrame: A DataFrame containing the generated questions and answers.
+        """
+        model = genai.GenerativeModel(model)
+        response = model.generate_content([system_message, text])
+        return self._parse_response(response, provider="gemini")
+
+    def _generate_openai(self, text, system_message, model,api_key=None):
+        """
+        Generate questions using the OpenAI API.
+
+        Args:
+            text (str): The input text to generate questions from.
+            system_message (str): The system message for the AI model.
+            model (str): The specific OpenAI model to use.
+
+        Returns:
+            pandas.DataFrame: A DataFrame containing the generated questions and answers.
+        """
+        client = openai.OpenAI(api_key=api_key)
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": system_message},
+                {"role": "user", "content": text}
+            ]
+        )
+        return self._parse_response(response, provider="openai")
+
+    def _parse_response(self, response, provider):
+        """
+        Parse the response from the AI model and return it as a DataFrame.
+
+        Args:
+            response (str): The response from the AI model.
+            provider (str): The AI provider used ('groq', 'gemini', or 'openai').
+        Returns:
+            pandas.DataFrame: The parsed response as a DataFrame.
+        """
+        if provider == "openai":
+            data = response.choices[0].message.content
+        elif provider == "gemini":
+            data = response.candidates[0].content.parts[0].text
+        elif provider == "groq":
+            data = response.choices[0].message.content.replace('\n', '')
+            list_start_index = data.find('[')  # Find the index of the first '['
+            substring_data = data[list_start_index:] if list_start_index != -1 else data  # Slice from the list start
+            data = substring_data
+
+        else:
+            raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
+        try:
+            json_data = json.loads(data)
+            return pd.DataFrame(json_data)
+        except json.JSONDecodeError:
+            # If JSON parsing fails, return a DataFrame with a single column
+            return pd.DataFrame({'content': [data]})
+
+    def process_document(self, input_data):
+        """
+        Process the input document and extract its content.
+
+        Args:
+            input_data (str): Either a file path or a string of text.
+
+        Returns:
+            str: The extracted text content from the document.
+
+        Raises:
+            ValueError: If the input is neither a valid file path nor a string of text.
+        """
+        if isinstance(input_data, str):
+            if os.path.isfile(input_data):
+                # If input_data is a file path
+                _, file_extension = os.path.splitext(input_data)
+                if file_extension.lower() == '.pdf':
+                    return self._read_pdf(input_data)
+                elif file_extension.lower() == '.txt':
+                    return self._read_text(input_data)
+                elif file_extension.lower() == '.md':
+                    return self._read_markdown(input_data)
+                elif file_extension.lower() == '.csv':
+                    return self._read_csv(input_data)
+                else:
+                    raise ValueError(f"Unsupported file type: {file_extension}")
+            else:
+                # If input_data is a string of text
+                return input_data
+        else:
+            raise ValueError("Input must be either a file path or a string of text")
+
+    def _read_pdf(self, file_path):
+        """
+        Read and extract text from a PDF file.
+
+        Args:
+            file_path (str): The path to the PDF file.
+
+        Returns:
+            str: The extracted text content from the PDF.
+        """
+        text = ""
+        with open(file_path, 'rb') as file:
+            pdf_reader = PyPDF2.PdfReader(file)
+            for page in pdf_reader.pages:
+                text += page.extract_text()
+        return text
+
+    def _read_text(self, file_path):
+        """
+        Read the contents of a text file.
+
+        Args:
+            file_path (str): The path to the text file.
+
+        Returns:
+            str: The contents of the text file.
+        """
+        with open(file_path, 'r', encoding='utf-8') as file:
+            return file.read()
+
+    def _read_markdown(self, file_path):
+        """
+        Read and convert a Markdown file to HTML.
+
+        Args:
+            file_path (str): The path to the Markdown file.
+
+        Returns:
+            str: The HTML content converted from the Markdown file.
+        """
+        with open(file_path, 'r', encoding='utf-8') as file:
+            md_content = file.read()
+            html_content = markdown.markdown(md_content)
+            return html_content
+
+    def _read_csv(self, file_path):
+        """
+        Read and extract text from a CSV file.
+
+        Args:
+            file_path (str): The path to the CSV file.
+
+        Returns:
+            str: The extracted text content from the CSV, with each row joined and separated by newlines.
+        """
+        text = ""
+        with open(file_path, 'r', encoding='utf-8') as file:
+            csv_reader = csv.reader(file)
+            for row in csv_reader:
+                text += " ".join(row) + "\n"
+        return text
+
+    def get_supported_qna(self):
+        """
+        Get a list of supported question types.
+
+        Returns:
+            list: A list of supported question types.
+        """
+        return ['simple', 'mcq', 'complex']
+
+    def get_supported_providers(self):
+        """
+        Get a list of supported AI providers.
+
+        Returns:
+            list: A list of supported AI providers.
+        """
+        return ['gemini', 'openai']
+
+# Usage:
+# from synthetic_data_generation import SyntheticDataGeneration
+# synthetic_data_generation = SyntheticDataGeneration()
+# text = synthetic_data_generation.process_document(input_data=text_file)
+# result = synthetic_data_generation.generate_question(text)
+# supported_question_types = synthetic_data_generation.get_supported_question_types()
+# supported_providers = synthetic_data_generation.get_supported_providers()
```
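One caveat: the usage comment at the end of the file refers to `generate_question` and `get_supported_question_types`, but the methods actually defined are `generate_qna` and `get_supported_qna`. A hedged end-to-end sketch using the defined names; the file path and model are placeholders, and an `OPENAI_API_KEY` environment variable is assumed:

```python
import os
from ragaai_catalyst import SyntheticDataGeneration

sdg = SyntheticDataGeneration()

# Extract text from a .pdf/.txt/.md/.csv file, or pass a raw string through unchanged.
text = sdg.process_document(input_data="docs/handbook.pdf")   # placeholder path

# Generate five question/answer pairs as a pandas DataFrame.
df = sdg.generate_qna(
    text,
    question_type="mcq",          # one of 'simple', 'mcq', 'complex'
    n=5,
    model_config={"provider": "openai", "model": "gpt-4o-mini"},  # placeholder model
    api_key=os.environ.get("OPENAI_API_KEY"),
)
print(df.head())
```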
{ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ragaai_catalyst
-Version: 2.0.3
+Version: 2.0.5
 Summary: RAGA AI CATALYST
 Author-email: Kiran Scaria <kiran.scaria@raga.ai>, Kedar Gaikwad <kedar.gaikwad@raga.ai>, Dushyant Mahajan <dushyant.mahajan@raga.ai>, Siddhartha Kosti <siddhartha.kosti@raga.ai>, Ritika Goel <ritika.goel@raga.ai>, Vijay Chaurasia <vijay.chaurasia@raga.ai>
 Requires-Python: >=3.9
@@ -20,6 +20,10 @@ Requires-Dist: langchain-core>=0.2.11
 Requires-Dist: langchain>=0.2.11
 Requires-Dist: openai>=1.35.10
 Requires-Dist: pandas>=2.1.1
+Requires-Dist: groq>=0.11.0
+Requires-Dist: PyPDF2>=3.0.1
+Requires-Dist: google-generativeai>=0.8.2
+Requires-Dist: Markdown>=3.7
 Requires-Dist: tenacity==8.3.0
 Provides-Extra: dev
 Requires-Dist: pytest; extra == "dev"
```
{ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/RECORD
CHANGED

```diff
@@ -1,10 +1,12 @@
-ragaai_catalyst/__init__.py,sha256=
+ragaai_catalyst/__init__.py,sha256=T0-X4yfIAe26-tWx6kLwNkKIjaFoQL2aNLIRp5wBG5w,424
 ragaai_catalyst/_version.py,sha256=JKt9KaVNOMVeGs8ojO6LvIZr7ZkMzNN-gCcvryy4x8E,460
 ragaai_catalyst/dataset.py,sha256=XjI06Exs6-64pQPQlky4mtcUllNMCgKP-bnM_t9EWkY,10920
-ragaai_catalyst/evaluation.py,sha256=
+ragaai_catalyst/evaluation.py,sha256=PR7rMkvZ4km26B24sSc60GPNS0JkrUMIYo5CPEqX2Qw,19315
 ragaai_catalyst/experiment.py,sha256=8KvqgJg5JVnt9ghhGDJvdb4mN7ETBX_E5gNxBT0Nsn8,19010
 ragaai_catalyst/prompt_manager.py,sha256=ZMIHrmsnPMq20YfeNxWXLtrxnJyMcxpeJ8Uya7S5dUA,16411
+ragaai_catalyst/proxy_call.py,sha256=nlMdJCSW73sfN0fMbCbtIk6W992Nac5FJvcfNd6UDJk,5497
 ragaai_catalyst/ragaai_catalyst.py,sha256=5Q1VCE7P33DtjaOtVGRUgBL8dpDL9kjisWGIkOyX4nE,17426
+ragaai_catalyst/synthetic_data_generation.py,sha256=STpZF-a1mYT3GR4CGdDvhBdctf2ciSLyvDANqJxnQp8,12989
 ragaai_catalyst/utils.py,sha256=TlhEFwLyRU690HvANbyoRycR3nQ67lxVUQoUOfTPYQ0,3772
 ragaai_catalyst/tracers/__init__.py,sha256=NppmJhD3sQ5R1q6teaZLS7rULj08Gb6JT8XiPRIe_B0,49
 ragaai_catalyst/tracers/tracer.py,sha256=eaGJdLEIjadHpbWBXBl5AhMa2vL97SVjik4U1L8gros,9591
@@ -17,7 +19,7 @@ ragaai_catalyst/tracers/instrumentators/llamaindex.py,sha256=SMrRlR4xM7k9HK43hak
 ragaai_catalyst/tracers/instrumentators/openai.py,sha256=14R4KW9wQCR1xysLfsP_nxS7cqXrTPoD8En4MBAaZUU,379
 ragaai_catalyst/tracers/utils/__init__.py,sha256=KeMaZtYaTojilpLv65qH08QmpYclfpacDA0U3wg6Ybw,64
 ragaai_catalyst/tracers/utils/utils.py,sha256=ViygfJ7vZ7U0CTSA1lbxVloHp4NSlmfDzBRNCJuMhis,2374
-ragaai_catalyst-2.0.3.dist-info/METADATA,sha256=
-ragaai_catalyst-2.0.3.dist-info/WHEEL,sha256=
-ragaai_catalyst-2.0.3.dist-info/top_level.txt,sha256=
-ragaai_catalyst-2.0.3.dist-info/RECORD,,
+ragaai_catalyst-2.0.5.dist-info/METADATA,sha256=tWppjo0sERHjjugIOAWdwD1p05HO6T6N_E1KYd9G9hY,6625
+ragaai_catalyst-2.0.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+ragaai_catalyst-2.0.5.dist-info/top_level.txt,sha256=HpgsdRgEJMk8nqrU6qdCYk3di7MJkDL0B19lkc7dLfM,16
+ragaai_catalyst-2.0.5.dist-info/RECORD,,
```

{ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/WHEEL
File without changes

{ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.5.dist-info}/top_level.txt
File without changes