ragaai-catalyst 2.0.3__py3-none-any.whl → 2.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +3 -1
- ragaai_catalyst/evaluation.py +46 -9
- ragaai_catalyst/proxy_call.py +134 -0
- ragaai_catalyst/synthetic_data_generation.py +323 -0
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.4.dist-info}/METADATA +5 -1
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.4.dist-info}/RECORD +8 -6
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.4.dist-info}/WHEEL +0 -0
- {ragaai_catalyst-2.0.3.dist-info → ragaai_catalyst-2.0.4.dist-info}/top_level.txt +0 -0
ragaai_catalyst/__init__.py
CHANGED
@@ -5,5 +5,7 @@ from .utils import response_checker
|
|
5
5
|
from .dataset import Dataset
|
6
6
|
from .prompt_manager import PromptManager
|
7
7
|
from .evaluation import Evaluation
|
8
|
+
from .synthetic_data_generation import SyntheticDataGeneration
|
8
9
|
|
9
|
-
|
10
|
+
|
11
|
+
__all__ = ["Experiment", "RagaAICatalyst", "Tracer", "PromptManager", "Evaluation","SyntheticDataGeneration"]
|
ragaai_catalyst/evaluation.py
CHANGED
@@ -128,27 +128,52 @@ class Evaluation:
|
|
128
128
|
return {}
|
129
129
|
|
130
130
|
def _get_variablename_from_dataset_schema(self, schemaName, metric_name):
|
131
|
+
# pdb.set_trace()
|
132
|
+
# print(schemaName)
|
131
133
|
dataset_schema = self._get_dataset_schema()
|
132
134
|
variableName = None
|
133
135
|
for column in dataset_schema:
|
134
|
-
columnName = column["
|
136
|
+
columnName = column["columnType"]
|
135
137
|
displayName = column["displayName"]
|
136
|
-
|
138
|
+
# print(columnName, displayName)
|
139
|
+
if "".join(columnName.split("_")).lower() == schemaName.lower():
|
137
140
|
variableName = displayName
|
141
|
+
break
|
142
|
+
return variableName
|
143
|
+
# print(variableName)
|
144
|
+
# if variableName:
|
145
|
+
# return variableName
|
146
|
+
# else:
|
147
|
+
# raise ValueError(f"'{schemaName}' column is required for {metric_name} metric evaluation, but not found in dataset")
|
148
|
+
|
149
|
+
|
150
|
+
def _get_variablename_from_user_schema_mapping(self, schemaName, metric_name, schema_mapping):
|
151
|
+
# pdb.set_trace()
|
152
|
+
user_dataset_schema = self._get_dataset_schema()
|
153
|
+
user_dataset_columns = [item["displayName"] for item in user_dataset_schema]
|
154
|
+
variableName = None
|
155
|
+
for key, val in schema_mapping.items():
|
156
|
+
if "".join(val.split("_")).lower()==schemaName:
|
157
|
+
if key in user_dataset_columns:
|
158
|
+
variableName=key
|
159
|
+
else:
|
160
|
+
raise ValueError(f"Column '{key}' is not present in {self.dataset_name}")
|
138
161
|
if variableName:
|
139
162
|
return variableName
|
140
163
|
else:
|
141
|
-
raise ValueError(f"'{schemaName
|
164
|
+
raise ValueError(f"Map '{schemaName}' column in schema_mapping for {metric_name} metric evaluation")
|
142
165
|
|
143
166
|
|
144
|
-
def _get_mapping(self, metric_name, metrics_schema):
|
167
|
+
def _get_mapping(self, metric_name, metrics_schema, schema_mapping):
|
168
|
+
|
145
169
|
mapping = []
|
146
170
|
for schema in metrics_schema:
|
147
171
|
if schema["name"]==metric_name:
|
148
172
|
requiredFields = schema["config"]["requiredFields"]
|
149
173
|
for field in requiredFields:
|
150
174
|
schemaName = field["name"]
|
151
|
-
variableName = self._get_variablename_from_dataset_schema(schemaName, metric_name)
|
175
|
+
# variableName = self._get_variablename_from_dataset_schema(schemaName, metric_name)
|
176
|
+
variableName = self._get_variablename_from_user_schema_mapping(schemaName.lower(), metric_name, schema_mapping)
|
152
177
|
mapping.append({"schemaName": schemaName, "variableName": variableName})
|
153
178
|
return mapping
|
154
179
|
|
@@ -160,7 +185,7 @@ class Evaluation:
|
|
160
185
|
"model": "null",
|
161
186
|
"params": {
|
162
187
|
"model": {
|
163
|
-
"value": "
|
188
|
+
"value": ""
|
164
189
|
}
|
165
190
|
},
|
166
191
|
"mappings": "mappings"
|
@@ -215,7 +240,7 @@ class Evaluation:
|
|
215
240
|
# if metric["config"]["model"]:
|
216
241
|
# base_json["metricSpec"]["config"]["params"]["model"]["value"] = metric["config"]["model"]
|
217
242
|
base_json["metricSpec"]["displayName"] = metric["column_name"]
|
218
|
-
mappings = self._get_mapping(metric["name"], metrics_schema_response)
|
243
|
+
mappings = self._get_mapping(metric["name"], metrics_schema_response, metric["schema_mapping"])
|
219
244
|
base_json["metricSpec"]["config"]["mappings"] = mappings
|
220
245
|
metricParams.append(base_json)
|
221
246
|
metric_schema_mapping["metricParams"] = metricParams
|
@@ -248,6 +273,13 @@ class Evaluation:
|
|
248
273
|
return []
|
249
274
|
|
250
275
|
def add_metrics(self, metrics):
|
276
|
+
#Handle required key if missing
|
277
|
+
required_keys = {"name", "config", "column_name", "schema_mapping"}
|
278
|
+
for metric in metrics:
|
279
|
+
missing_keys = required_keys - metric.keys()
|
280
|
+
if missing_keys:
|
281
|
+
raise ValueError(f"{missing_keys} required for each metric evaluation.")
|
282
|
+
|
251
283
|
executed_metric_list = self._get_executed_metrics_list()
|
252
284
|
metrics_name = self.list_metrics()
|
253
285
|
user_metric_names = [metric["name"] for metric in metrics]
|
@@ -265,7 +297,6 @@ class Evaluation:
|
|
265
297
|
'X-Project-Id': str(self.project_id),
|
266
298
|
}
|
267
299
|
metric_schema_mapping = self._update_base_json(metrics)
|
268
|
-
print(metric_schema_mapping)
|
269
300
|
try:
|
270
301
|
response = requests.post(
|
271
302
|
f'{self.base_url}/playground/metric-evaluation',
|
@@ -303,7 +334,13 @@ class Evaluation:
|
|
303
334
|
headers=headers,
|
304
335
|
json=data)
|
305
336
|
response.raise_for_status()
|
306
|
-
|
337
|
+
status_json = response.json()["data"]["status"]
|
338
|
+
if status_json == "Failed":
|
339
|
+
return print("Job failed. No results to fetch.")
|
340
|
+
elif status_json == "In Progress":
|
341
|
+
return print(f"Job in progress. Please wait while the job completes.\nVisit Job Status: {self.base_url.removesuffix('/api')}/projects/job-status?projectId={self.project_id} to track")
|
342
|
+
elif status_json == "Completed":
|
343
|
+
print(f"Job completed. Fetching results.\nVisit Job Status: {self.base_url.removesuffix('/api')}/projects/job-status?projectId={self.project_id} to check")
|
307
344
|
except requests.exceptions.HTTPError as http_err:
|
308
345
|
logger.error(f"HTTP error occurred: {http_err}")
|
309
346
|
except requests.exceptions.ConnectionError as conn_err:
|
@@ -0,0 +1,134 @@
|
|
1
|
+
import requests
|
2
|
+
import json
|
3
|
+
import subprocess
|
4
|
+
import logging
|
5
|
+
import traceback
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
def api_completion(model, messages, api_base='http://127.0.0.1:8000',
                   api_key='', model_config=None):
    """
    Send one prompt to the proxy endpoint and return the parsed model output.

    Args:
        model: Model identifier; ``convert_input`` places it in the payload.
        messages: Chat messages handed to ``convert_input`` to build the payload.
        api_base: URL the request is POSTed to.
        api_key: Accepted for interface compatibility; not used by this function.
        model_config: Optional dict of settings. Keys read here: ``job_id``
            (used in log lines, default -1), ``log_level`` (``'debug'`` enables
            extra logging) and ``verify`` (TLS verification flag handed to
            ``requests``; defaults to ``False``). The dict is also forwarded
            to ``convert_input``.

    Returns:
        list: A single-element list holding the value returned by
        ``convert_output``, or ``[None]`` when a non-``ValueError`` exception
        occurs while processing the response body.

    Raises:
        ValueError: If the request fails, the status code is not 200, the body
            is not valid JSON, or the body contains an ``error`` entry.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    model_config = {} if model_config is None else model_config
    debug = model_config.get('log_level', '') == 'debug'
    job_id = model_config.get('job_id', -1)

    # Header values must be single-line and latin-1 encodable, so trim the
    # newline `whoami` prints and replace any character that cannot be encoded.
    username = (get_username() or '').strip()
    username = username.encode('latin-1', 'replace').decode('latin-1')

    payload = json.dumps(convert_input(messages, model, model_config))
    headers = {
        'Content-Type': 'application/json',
        'Wd-PCA-Feature-Key': f'your_feature_key, {username}'
    }
    try:
        # NOTE(review): TLS verification is off by default, as before; pass
        # model_config['verify'] = True (or a CA bundle path) to enable it.
        response = requests.request("POST", api_base, headers=headers, data=payload,
                                    verify=model_config.get('verify', False))
        if debug:
            logger.info('Model response Job ID %s %s', job_id, response.text)
        if response.status_code != 200:
            logger.error('Error in model response Job ID %s: %s', job_id, response.text)
            raise ValueError(str(response.text))
    except Exception as e:
        logger.error('Error in calling api Job ID %s: %s', job_id, e)
        raise ValueError(str(e)) from e

    # Keep the raw text so the handlers below can log it even after the body
    # has been decoded into a Python object.
    raw_text = response.text
    all_response = []
    try:
        body = response.json()
        if 'error' in body:
            logger.error('Invalid response from API Job ID %s: %s', job_id, body)
            raise ValueError(str(body.get('error')))
        all_response.append(convert_output(body, job_id))
    except ValueError as e1:
        logger.error('Invalid json response from API Job ID %s: %s', job_id, raw_text)
        raise ValueError(str(e1)) from e1
    except Exception as e1:
        if debug:
            logger.info('Error trace Job ID: %s %s', job_id, traceback.format_exc())
        logger.error('Exception in parsing model response Job ID:%s %s', job_id, e1)
        logger.error('Model response Job ID: %s %s', job_id, raw_text)
        all_response.append(None)
    return all_response
|
47
|
+
|
48
|
+
def get_username():
    """
    Return the current user name as printed by the ``whoami`` command.

    Returns:
        str: The command's standard output with surrounding whitespace
        (including the trailing newline) removed, or an empty string when the
        command cannot be started.
    """
    try:
        completed = subprocess.run(['whoami'], capture_output=True, text=True)
    except OSError as err:
        # The user name is informational only; a missing `whoami` binary
        # should not abort the caller.
        logging.getLogger(__name__).warning('Unable to run whoami: %s', err)
        return ''
    return completed.stdout.strip()
|
52
|
+
|
53
|
+
def convert_output(response, job_id):
    """
    Extract the generated text from a decoded proxy response.

    Args:
        response (dict): Decoded JSON body. ``response['prediction']['type']``
            selects how the output is read.
        job_id: Identifier used only in the warning log line.

    Returns:
        For ``generic-text-generation-v1`` the value of
        ``prediction['output']``; for ``gcp-multimodal-v1`` the concatenated
        text of the first part of the first candidate of every chunk. ``None``
        when the response does not have the expected structure.

    Raises:
        ValueError: If the prediction type is not one of the two supported
            values, or a candidate reports a finish reason other than ``STOP``.
    """
    try:
        prediction = response.get('prediction', {})
        prediction_type = prediction.get('type', '')
        if prediction_type == 'generic-text-generation-v1':
            return prediction['output']
        if prediction_type == 'gcp-multimodal-v1':
            pieces = []
            for chunk in prediction['output']['chunks']:
                candidate = chunk['candidates'][0]
                # Not every chunk carries a finishReason; a missing key must
                # not discard the whole response.
                finish_reason = candidate.get('finishReason')
                if finish_reason and finish_reason not in ['STOP']:
                    raise ValueError(finish_reason)
                pieces.append(candidate['content']['parts'][0]['text'])
            return ''.join(pieces)
        raise ValueError('Invalid prediction type passed in config')
    except ValueError as e1:
        raise ValueError(str(e1)) from e1
    except Exception as e:
        logger.warning('Exception in formatting model response Job ID %s:%s', job_id, e)
        return None
|
73
|
+
|
74
|
+
|
75
|
+
def convert_input(prompt, model, model_config):
    """
    Build the request document sent to the proxy endpoint.

    Args:
        prompt (list): Chat messages. Only ``prompt[0]['content']`` is placed
            in the request; any further messages are ignored.
        model: Model identifier written to ``target.model``.
        model_config (dict | None): Optional overrides. Keys read:
            ``provider`` (default ``'gcp'``), ``task_type``,
            ``prediction_type``, ``safetySettings``, ``generationConfig``,
            ``log_level`` and ``job_id``. ``None`` is treated as an empty dict.

    Returns:
        dict: The request document with ``target`` and ``task`` sections.
    """
    model_config = {} if model_config is None else model_config

    doc_input = {
        "target": {
            "provider": model_config.get('provider', 'gcp'),
            "model": model
        },
        "task": {
            # NOTE(review): these two overrides are applied whether or not a
            # provider was supplied - confirm against the original nesting.
            "type": model_config.get('task_type', 'gcp-multimodal-v1'),
            "prediction_type": model_config.get('prediction_type', 'generic-text-generation-v1'),
            "input": {
                "contents": [
                    {
                        "role": "user",
                        # Placeholder text; replaced with the real prompt below.
                        "parts": [
                            {
                                "text": "Give me a recipe for banana bread."
                            }
                        ]
                    }
                ],
                "safetySettings":
                    [
                        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                    ],
                "generationConfig": {
                    "temperature": 0,
                    "maxOutputTokens": 8000,
                    "topK": 40,
                    "topP": 0.95,
                    "stopSequences": [],
                    "candidateCount": 1
                }
            }
        }
    }
    if 'safetySettings' in model_config:
        doc_input['task']['input']['safetySettings'] = model_config.get('safetySettings')
    if 'generationConfig' in model_config:
        doc_input['task']['input']['generationConfig'] = model_config.get('generationConfig')
    # The debug line is emitted before the real prompt is inserted, so it
    # records the configuration with the placeholder prompt text.
    if model_config.get('log_level', '') == 'debug':
        logger.info(f"Using model configs Job ID {model_config.get('job_id',-1)}{doc_input}")
    doc_input['task']['input']['contents'][0]['parts'] = [{"text": prompt[0]['content']}]
    return doc_input
|
128
|
+
|
129
|
+
|
130
|
+
|
131
|
+
if __name__ == '__main__':
    # Manual smoke test against a locally running proxy. Calls the function
    # this module defines and passes the prompt as a message dict, which is
    # the shape convert_input reads.
    messages = [{'role': 'user', 'content': 'Hi How are you'}]
    response = api_completion('gemini/gemini-1.5-flash', messages, api_base='http://127.0.0.1:5000')
    print(response)
|
@@ -0,0 +1,323 @@
|
|
1
|
+
import os
|
2
|
+
from groq import Groq
|
3
|
+
import google.generativeai as genai
|
4
|
+
import openai
|
5
|
+
import PyPDF2
|
6
|
+
import csv
|
7
|
+
import markdown
|
8
|
+
import pandas as pd
|
9
|
+
import json
|
10
|
+
from ragaai_catalyst import proxy_call
|
11
|
+
import ast
|
12
|
+
|
13
|
+
# dotenv.load_dotenv()
|
14
|
+
|
15
|
+
class SyntheticDataGeneration:
    """
    Generate synthetic question/answer data from text with Groq, Gemini or
    OpenAI models, and extract text from PDF, text, Markdown and CSV files.
    """

    def __init__(self):
        """
        Create the generator.

        No API client is created here; the client for the requested provider
        is configured inside ``generate_qna``.
        """

    def generate_qna(self, text, question_type="simple", n=5, model_config=None, api_key=None):
        """
        Generate questions (and answers or options) from the given text.

        Args:
            text (str): The input text to generate questions from.
            question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
            n (int): The number of questions to generate.
            model_config (dict | None): Model settings. Keys read:
                ``provider`` ('groq', 'gemini', or 'openai'), ``model`` and
                ``api_base``. For 'gemini', a non-None ``api_base`` routes the
                request through ``proxy_call.api_completion`` instead of the
                Gemini SDK.
            api_key (str | None): Provider API key. When omitted, the
                ``GROQ_API_KEY``, ``GEMINI_API_KEY`` or ``OPENAI_API_KEY``
                environment variable is used.

        Returns:
            pandas.DataFrame: A DataFrame containing the generated questions.

        Raises:
            ValueError: If the provider or question type is invalid, the
                required API key is missing, or the proxy returns no usable
                response.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        model_config = {} if model_config is None else model_config
        provider = model_config.get("provider")
        model = model_config.get("model")
        api_base = model_config.get("api_base")

        system_message = self._get_system_message(question_type, n)
        if provider == "groq":
            groq_key = api_key or os.getenv("GROQ_API_KEY")
            if not groq_key:
                raise ValueError("API key must be provided for Groq.")
            self.groq_client = Groq(api_key=groq_key)
            return self._generate_groq(text, system_message, model)
        elif provider == "gemini":
            if api_base is None:
                # Direct SDK call: validate the key before configuring the SDK.
                gemini_key = api_key or os.getenv("GEMINI_API_KEY")
                if not gemini_key:
                    raise ValueError("API key must be provided for Gemini.")
                genai.configure(api_key=gemini_key)
                return self._generate_gemini(text, system_message, model)
            # Proxy route: the whole instruction is sent as one user message.
            messages = [
                {'role': 'user', 'content': system_message + text}
            ]
            replies = proxy_call.api_completion(messages=messages, model=model, api_base=api_base)
            reply = replies[0] if replies else None
            if reply is None:
                raise ValueError("No usable response received from the proxy API.")
            return pd.DataFrame(self._load_records(reply))
        elif provider == "openai":
            openai_key = api_key or os.getenv("OPENAI_API_KEY")
            if not openai_key:
                raise ValueError("API key must be provided for OpenAI.")
            openai.api_key = openai_key
            return self._generate_openai(text, system_message, model, api_key=openai_key)
        else:
            raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")

    @staticmethod
    def _load_records(raw):
        """
        Decode model text into Python objects.

        The prompts request JSON, so JSON decoding is tried first; a Python
        literal (for example a single-quoted list of dicts) is accepted as a
        fallback.
        """
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return ast.literal_eval(raw)

    def _get_system_message(self, question_type, n):
        """
        Get the instruction prompt for the specified question type.

        Args:
            question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
            n (int): The number of questions to generate.

        Returns:
            str: The system message for the AI model.

        Raises:
            ValueError: If an invalid question type is specified.
        """
        if question_type == 'simple':
            return (
                f'Generate a set of {n} very simple questions answerable in a single phrase.\n'
                'Also return the answers for the generated questions.\n'
                'Return the response in a list of object format.\n'
                'Each object in list should have Question and corresponding answer.\n'
                'Do not return any extra strings. Return Generated text strictly in below format.\n'
                '[{"Question":"question","Answer":"answer"}]\n'
            )
        elif question_type == 'mcq':
            return (
                f'Generate a set of {n} questions with 4 probable answers from the given text.\n'
                'The options should not be longer than a phrase. There should be only 1 correct answer.\n'
                'There should not be any ambiguity between correct and incorrect options.\n'
                'Return the response in a list of object format.\n'
                'Each object in list should have Question and a list of options.\n'
                'Do not return any extra strings. Return Generated text strictly in below format.\n'
                '[{"Question":"question","Options":[option1,option2,option3,option4]}]\n'
            )
        elif question_type == 'complex':
            return (
                f'Can you generate a set of {n} complex questions answerable in long form from the below texts.\n'
                'Make sure the questions are important and provide new information to the user.\n'
                'Return the response in a list of object format. Enclose any quotes in single quote.\n'
                'Do not use double quotes within questions or answers.\n'
                'Each object in list should have Question and corresponding answer.\n'
                'Do not return any extra strings. Return generated text strictly in below format.\n'
                '[{"Question":"question","Answer":"answers"}]\n'
            )
        else:
            raise ValueError("Invalid question type")

    def _generate_groq(self, text, system_message, model):
        """
        Generate questions using the Groq API.

        Uses ``self.groq_client``, which ``generate_qna`` sets before calling
        this method.

        Args:
            text (str): The input text to generate questions from.
            system_message (str): The system message for the AI model.
            model (str): The specific Groq model to use.

        Returns:
            pandas.DataFrame: A DataFrame containing the generated questions and answers.
        """
        response = self.groq_client.chat.completions.create(
            model=model,
            messages=[
                {'role': 'system', 'content': system_message},
                {'role': 'user', 'content': text}
            ]
        )
        return self._parse_response(response, provider="groq")

    def _generate_gemini(self, text, system_message, model):
        """
        Generate questions using the Gemini API.

        Args:
            text (str): The input text to generate questions from.
            system_message (str): The system message for the AI model.
            model (str): The specific Gemini model to use.

        Returns:
            pandas.DataFrame: A DataFrame containing the generated questions and answers.
        """
        gemini_model = genai.GenerativeModel(model)
        response = gemini_model.generate_content([system_message, text])
        return self._parse_response(response, provider="gemini")

    def _generate_openai(self, text, system_message, model, api_key=None):
        """
        Generate questions using the OpenAI API.

        Args:
            text (str): The input text to generate questions from.
            system_message (str): The system message for the AI model.
            model (str): The specific OpenAI model to use.
            api_key (str | None): Key passed to the OpenAI client.

        Returns:
            pandas.DataFrame: A DataFrame containing the generated questions and answers.
        """
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": text}
            ]
        )
        return self._parse_response(response, provider="openai")

    def _parse_response(self, response, provider):
        """
        Parse the response from the AI model and return it as a DataFrame.

        The text is decoded as JSON. If that fails, the span from the first
        '[' to the last ']' is tried, which covers replies wrapped in code
        fences or surrounded by extra prose.

        Args:
            response: The provider SDK response object.
            provider (str): The AI provider used ('groq', 'gemini', or 'openai').

        Returns:
            pandas.DataFrame: The parsed records, or a single-column
            ``content`` frame holding the raw text when no JSON can be decoded.

        Raises:
            ValueError: If an invalid provider is specified.
        """
        if provider == "openai":
            data = response.choices[0].message.content
        elif provider == "gemini":
            data = response.candidates[0].content.parts[0].text
        elif provider == "groq":
            data = response.choices[0].message.content.replace('\n', '')
            list_start_index = data.find('[')
            if list_start_index != -1:
                data = data[list_start_index:]
        else:
            raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")

        try:
            return pd.DataFrame(json.loads(data))
        except json.JSONDecodeError:
            pass

        first, last = data.find('['), data.rfind(']')
        if first != -1 and last > first:
            try:
                return pd.DataFrame(json.loads(data[first:last + 1]))
            except json.JSONDecodeError:
                pass

        # If JSON parsing fails, return a DataFrame with a single column
        return pd.DataFrame({'content': [data]})

    def process_document(self, input_data):
        """
        Process the input document and extract its content.

        Args:
            input_data (str): Either a file path or a string of text.

        Returns:
            str: The extracted file content when ``input_data`` is the path of
            an existing .pdf, .txt, .md or .csv file; otherwise ``input_data``
            itself.

        Raises:
            ValueError: If the input is not a string, or is a path to a file
                with an unsupported extension.
        """
        if not isinstance(input_data, str):
            raise ValueError("Input must be either a file path or a string of text")
        if not os.path.isfile(input_data):
            # Not an existing file: treat the string as the text itself.
            return input_data

        _, file_extension = os.path.splitext(input_data)
        readers = {
            '.pdf': self._read_pdf,
            '.txt': self._read_text,
            '.md': self._read_markdown,
            '.csv': self._read_csv,
        }
        reader = readers.get(file_extension.lower())
        if reader is None:
            raise ValueError(f"Unsupported file type: {file_extension}")
        return reader(input_data)

    def _read_pdf(self, file_path):
        """
        Read and extract text from a PDF file.

        Args:
            file_path (str): The path to the PDF file.

        Returns:
            str: The concatenated text of all pages.
        """
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            # Guard against pages for which no text is returned.
            return "".join(page.extract_text() or "" for page in pdf_reader.pages)

    def _read_text(self, file_path):
        """
        Read the contents of a text file.

        Args:
            file_path (str): The path to the text file.

        Returns:
            str: The contents of the text file.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()

    def _read_markdown(self, file_path):
        """
        Read and convert a Markdown file to HTML.

        Args:
            file_path (str): The path to the Markdown file.

        Returns:
            str: The HTML content converted from the Markdown file.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            md_content = file.read()
        return markdown.markdown(md_content)

    def _read_csv(self, file_path):
        """
        Read and extract text from a CSV file.

        Args:
            file_path (str): The path to the CSV file.

        Returns:
            str: One line per CSV row, with the row's fields joined by spaces.
        """
        # newline='' lets the csv module handle line endings inside quoted fields.
        with open(file_path, 'r', encoding='utf-8', newline='') as file:
            return "".join(" ".join(row) + "\n" for row in csv.reader(file))

    def get_supported_qna(self):
        """
        Get a list of supported question types.

        Returns:
            list: A list of supported question types.
        """
        return ['simple', 'mcq', 'complex']

    def get_supported_providers(self):
        """
        Get a list of supported AI providers.

        NOTE(review): ``generate_qna`` also accepts 'groq'; confirm whether it
        should be listed here.

        Returns:
            list: A list of supported AI providers.
        """
        return ['gemini', 'openai']
|
316
|
+
|
317
|
+
# Usage:
|
318
|
+
# from ragaai_catalyst import SyntheticDataGeneration
|
319
|
+
# synthetic_data_generation = SyntheticDataGeneration()
|
320
|
+
# text = synthetic_data_generation.process_document(input_data=text_file)
|
321
|
+
# result = synthetic_data_generation.generate_qna(text, question_type="simple", n=5, model_config={"provider": "openai", "model": "<model-name>"})
|
322
|
+
# supported_question_types = synthetic_data_generation.get_supported_qna()
|
323
|
+
# supported_providers = synthetic_data_generation.get_supported_providers()
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ragaai_catalyst
|
3
|
-
Version: 2.0.
|
3
|
+
Version: 2.0.4
|
4
4
|
Summary: RAGA AI CATALYST
|
5
5
|
Author-email: Kiran Scaria <kiran.scaria@raga.ai>, Kedar Gaikwad <kedar.gaikwad@raga.ai>, Dushyant Mahajan <dushyant.mahajan@raga.ai>, Siddhartha Kosti <siddhartha.kosti@raga.ai>, Ritika Goel <ritika.goel@raga.ai>, Vijay Chaurasia <vijay.chaurasia@raga.ai>
|
6
6
|
Requires-Python: >=3.9
|
@@ -20,6 +20,10 @@ Requires-Dist: langchain-core>=0.2.11
|
|
20
20
|
Requires-Dist: langchain>=0.2.11
|
21
21
|
Requires-Dist: openai>=1.35.10
|
22
22
|
Requires-Dist: pandas>=2.1.1
|
23
|
+
Requires-Dist: groq>=0.11.0
|
24
|
+
Requires-Dist: PyPDF2>=3.0.1
|
25
|
+
Requires-Dist: google-generativeai>=0.8.2
|
26
|
+
Requires-Dist: Markdown>=3.7
|
23
27
|
Requires-Dist: tenacity==8.3.0
|
24
28
|
Provides-Extra: dev
|
25
29
|
Requires-Dist: pytest; extra == "dev"
|
@@ -1,10 +1,12 @@
|
|
1
|
-
ragaai_catalyst/__init__.py,sha256=
|
1
|
+
ragaai_catalyst/__init__.py,sha256=T0-X4yfIAe26-tWx6kLwNkKIjaFoQL2aNLIRp5wBG5w,424
|
2
2
|
ragaai_catalyst/_version.py,sha256=JKt9KaVNOMVeGs8ojO6LvIZr7ZkMzNN-gCcvryy4x8E,460
|
3
3
|
ragaai_catalyst/dataset.py,sha256=XjI06Exs6-64pQPQlky4mtcUllNMCgKP-bnM_t9EWkY,10920
|
4
|
-
ragaai_catalyst/evaluation.py,sha256=
|
4
|
+
ragaai_catalyst/evaluation.py,sha256=qSiSIvUr2FvIGdisUzwWEYWllXDIiLKv5D00URBgGSw,18495
|
5
5
|
ragaai_catalyst/experiment.py,sha256=8KvqgJg5JVnt9ghhGDJvdb4mN7ETBX_E5gNxBT0Nsn8,19010
|
6
6
|
ragaai_catalyst/prompt_manager.py,sha256=ZMIHrmsnPMq20YfeNxWXLtrxnJyMcxpeJ8Uya7S5dUA,16411
|
7
|
+
ragaai_catalyst/proxy_call.py,sha256=nlMdJCSW73sfN0fMbCbtIk6W992Nac5FJvcfNd6UDJk,5497
|
7
8
|
ragaai_catalyst/ragaai_catalyst.py,sha256=5Q1VCE7P33DtjaOtVGRUgBL8dpDL9kjisWGIkOyX4nE,17426
|
9
|
+
ragaai_catalyst/synthetic_data_generation.py,sha256=STpZF-a1mYT3GR4CGdDvhBdctf2ciSLyvDANqJxnQp8,12989
|
8
10
|
ragaai_catalyst/utils.py,sha256=TlhEFwLyRU690HvANbyoRycR3nQ67lxVUQoUOfTPYQ0,3772
|
9
11
|
ragaai_catalyst/tracers/__init__.py,sha256=NppmJhD3sQ5R1q6teaZLS7rULj08Gb6JT8XiPRIe_B0,49
|
10
12
|
ragaai_catalyst/tracers/tracer.py,sha256=eaGJdLEIjadHpbWBXBl5AhMa2vL97SVjik4U1L8gros,9591
|
@@ -17,7 +19,7 @@ ragaai_catalyst/tracers/instrumentators/llamaindex.py,sha256=SMrRlR4xM7k9HK43hak
|
|
17
19
|
ragaai_catalyst/tracers/instrumentators/openai.py,sha256=14R4KW9wQCR1xysLfsP_nxS7cqXrTPoD8En4MBAaZUU,379
|
18
20
|
ragaai_catalyst/tracers/utils/__init__.py,sha256=KeMaZtYaTojilpLv65qH08QmpYclfpacDA0U3wg6Ybw,64
|
19
21
|
ragaai_catalyst/tracers/utils/utils.py,sha256=ViygfJ7vZ7U0CTSA1lbxVloHp4NSlmfDzBRNCJuMhis,2374
|
20
|
-
ragaai_catalyst-2.0.
|
21
|
-
ragaai_catalyst-2.0.
|
22
|
-
ragaai_catalyst-2.0.
|
23
|
-
ragaai_catalyst-2.0.
|
22
|
+
ragaai_catalyst-2.0.4.dist-info/METADATA,sha256=XfU--Yoer4WL2NfZ7RNJcojMkMDR5bfL8J36NV6HqQA,6625
|
23
|
+
ragaai_catalyst-2.0.4.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
24
|
+
ragaai_catalyst-2.0.4.dist-info/top_level.txt,sha256=HpgsdRgEJMk8nqrU6qdCYk3di7MJkDL0B19lkc7dLfM,16
|
25
|
+
ragaai_catalyst-2.0.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|