ragaai-catalyst 2.0.4__py3-none-any.whl → 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -7,7 +7,14 @@ import csv
 import markdown
 import pandas as pd
 import json
-from ragaai_catalyst import proxy_call
+from litellm import completion
+from tqdm import tqdm
+# import internal_api_completion
+# import proxy_call
+from .internal_api_completion import api_completion as internal_api_completion
+from .proxy_call import api_completion as proxy_api_completion
+# from ragaai_catalyst import internal_api_completion
+# from ragaai_catalyst import proxy_call
 import ast
 
 # dotenv.load_dotenv()
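
Note: the key change in this hunk is the move from the bundled proxy_call helper to LiteLLM's provider-agnostic completion(), plus tqdm for progress reporting. A minimal sketch of the call style these imports enable follows; the model names are illustrative, and LiteLLM routes to a provider by the model-string prefix:

    from litellm import completion

    # One call shape for every provider; swap the model string to switch backends.
    response = completion(
        model="gpt-4o-mini",  # e.g. "gemini/gemini-pro" or "groq/mixtral-8x7b-32768"
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(response.choices[0].message.content)
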
@@ -21,55 +28,170 @@ class SyntheticDataGeneration:
         """
         Initialize the SyntheticDataGeneration class with API clients for Groq, Gemini, and OpenAI.
         """
-    def generate_qna(self, text, question_type="simple", n=5,model_config=dict(),api_key=None):
+
+    def generate_qna(self, text, question_type="simple", n=5, model_config=dict(), api_key=None, **kwargs):
         """
         Generate questions based on the given text using the specified model and provider.
+        Uses batch processing for larger values of n to maintain response quality.
 
         Args:
             text (str): The input text to generate questions from.
             question_type (str): The type of questions to generate ('simple', 'mcq', or 'complex').
-            model (str): The specific model to use for generation.
-            provider (str): The AI provider to use ('groq', 'gemini', or 'openai').
             n (int): The number of question/answer pairs to generate.
+            model_config (dict): Configuration for the model including provider and model name.
+            api_key (str, optional): The API key for the selected provider.
+            **kwargs: Additional keyword arguments.
 
         Returns:
-            pandas.DataFrame: A DataFrame containing the generated questions and answers.
+            pandas.DataFrame: A DataFrame containing exactly n generated questions and answers.
 
         Raises:
-            ValueError: If an invalid provider is specified.
+            ValueError: If an invalid provider is specified or API key is missing.
         """
+        BATCH_SIZE = 5 # Optimal batch size for maintaining response quality
         provider = model_config.get("provider")
         model = model_config.get("model")
         api_base = model_config.get("api_base")
 
-        system_message = self._get_system_message(question_type, n)
+        # Initialize the appropriate client based on provider
+        self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
+
+        # Initialize progress bar
+        pbar = tqdm(total=n, desc="Generating QA pairs")
+
+        # Initial generation phase
+        num_batches = (n + BATCH_SIZE - 1) // BATCH_SIZE
+        all_responses = []
+
+        FAILURE_CASES = [
+            "Invalid API key provided",
+            "No connection adapters",
+            "Required API Keys are not set",
+            "litellm.BadRequestError",
+            "litellm.AuthenticationError"]
+
+        for _ in range(num_batches):
+            current_batch_size = min(BATCH_SIZE, n - len(all_responses))
+            if current_batch_size <= 0:
+                break
+
+            try:
+                system_message = self._get_system_message(question_type, current_batch_size)
+
+                if "internal_llm_proxy" in kwargs:
+                    batch_df = self._generate_internal_response(text, system_message, model_config, kwargs)
+                else:
+                    batch_df = self._generate_batch_response(text, system_message, provider, model_config, api_key, api_base)
+
+                if not batch_df.empty and len(batch_df) > 0:
+                    all_responses.extend(batch_df.to_dict('records'))
+                    pbar.update(len(batch_df))
+
+            except Exception as e:
+                print(f"Batch generation failed.")
+
+                if any(error in str(e) for error in FAILURE_CASES):
+                    raise Exception(f"{e}")
+
+                else:
+                    print(f"Retrying...")
+                    continue
+
+
+        # Convert to DataFrame and remove duplicates
+        result_df = pd.DataFrame(all_responses)
+        result_df = result_df.drop_duplicates(subset=['Question'])
+
+        # Replenish phase - generate additional questions if needed due to duplicates
+        while (len(result_df) < n) and ((len(result_df) >= 1)):
+            questions_needed = n - len(result_df)
+            try:
+                system_message = self._get_system_message(question_type, questions_needed)
+
+                if "internal_llm_proxy" in kwargs:
+                    additional_df = self._generate_internal_response(text, system_message, model_config, kwargs)
+                else:
+                    additional_df = self._generate_batch_response(text, system_message, provider, model_config, api_key, api_base)
+
+                if not additional_df.empty and len(additional_df) > 0:
+                    # Only add questions that aren't already in result_df
+                    new_questions = additional_df[~additional_df['Question'].isin(result_df['Question'])]
+                    if not new_questions.empty:
+                        result_df = pd.concat([result_df, new_questions], ignore_index=True)
+                        result_df = result_df.drop_duplicates(subset=['Question'])
+                        pbar.update(len(new_questions))
+
+            except Exception as e:
+                print(f"Replenishment generation failed")
+
+                if any(error in str(e) for error in FAILURE_CASES):
+                    raise Exception(f"{e}")
+
+                else:
+                    print("An unexpected error occurred. Retrying...")
+                    continue
+
+        pbar.close()
+
+        # Ensure exactly n rows and reset index starting from 1
+        final_df = result_df.head(n)
+        final_df.index = range(1, len(final_df) + 1)
+
+        return final_df
+
+    def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
+        """Initialize the appropriate client based on provider."""
+        if not provider:
+            raise ValueError("Model configuration must be provided with a valid provider and model.")
+
         if provider == "groq":
             if api_key is None and os.getenv("GROQ_API_KEY") is None:
                 raise ValueError("API key must be provided for Groq.")
             self.groq_client = Groq(api_key=api_key or os.getenv("GROQ_API_KEY"))
-            return self._generate_groq(text, system_message, model)
+
         elif provider == "gemini":
+            if api_key is None and os.getenv("GEMINI_API_KEY") is None and api_base is None and internal_llm_proxy is None:
+                raise ValueError("API key must be provided for Gemini.")
             genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
-            if api_base is None:
-                if api_key is None and os.getenv("GEMINI_API_KEY") is None:
-                    raise ValueError("API key must be provided for Gemini.")
-                genai.configure(api_key=api_key or os.getenv("GEMINI_API_KEY"))
-                return self._generate_gemini(text, system_message, model)
-            else:
-                messages=[
-                    {'role': 'user', 'content': system_message+text}
-                ]
-                a= proxy_call.api_completion(messages=messages ,model=model ,api_base=api_base)
-                b= ast.literal_eval(a[0])
-                return pd.DataFrame(b)
+
         elif provider == "openai":
-            if api_key is None and os.getenv("OPENAI_API_KEY") is None:
+            if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
                 raise ValueError("API key must be provided for OpenAI.")
             openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-            return self._generate_openai(text, system_message, model,api_key=api_key)
+
         else:
-            raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
+            raise ValueError(f"Provider is not recognized.")
+
+    def _generate_batch_response(self, text, system_message, provider, model_config, api_key, api_base):
+        """Generate a batch of responses using the specified provider."""
+        MAX_RETRIES = 3
+
+        for attempt in range(MAX_RETRIES):
+            try:
+                if provider == "gemini" and api_base:
+                    messages = [{'role': 'user', 'content': system_message + text}]
+                    response = proxy_api_completion(messages=messages, model=model_config["model"], api_base=api_base)
+                    # response = proxy_call.api_completion(messages=messages, model=model_config["model"], api_base=api_base)
+                    return pd.DataFrame(ast.literal_eval(response[0]))
+                else:
+                    return self._generate_llm_response(text, system_message, model_config, api_key)
+            except (json.JSONDecodeError, ValueError) as e:
+                if attempt == MAX_RETRIES - 1:
+                    raise Exception(f"Failed to generate valid response after {MAX_RETRIES} attempts: {str(e)}")
+                continue
+
+    def _generate_internal_response(self, text, system_message, model_config, kwargs):
+        """Generate response using internal API."""
+        messages = [{'role': 'user', 'content': system_message + text}]
+        return internal_api_completion(
+            messages=messages,
+            model_config=model_config,
+            kwargs=kwargs
+        )
 
+
+
+
     def _get_system_message(self, question_type, n):
         """
         Get the appropriate system message for the specified question type.
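
Note: generate_qna now caps each request at BATCH_SIZE = 5 pairs and tops up after deduplication. A standalone sketch of that ceiling-division batching follows; the function name is local to the sketch, not part of the package:

    BATCH_SIZE = 5

    def batch_sizes(n):
        # Ceiling division: how many requests are needed to cover n pairs.
        num_batches = (n + BATCH_SIZE - 1) // BATCH_SIZE
        sizes, remaining = [], n
        for _ in range(num_batches):
            sizes.append(min(BATCH_SIZE, remaining))
            remaining -= sizes[-1]
        return sizes

    assert batch_sizes(12) == [5, 5, 2]  # three requests for n=12

The replenish while-loop then requests only the shortfall (n minus the deduplicated row count) until the frame reaches n rows or empties entirely, which is how the method can promise exactly n unique questions.
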
@@ -113,64 +235,68 @@ class SyntheticDataGeneration:
         else:
             raise ValueError("Invalid question type")
 
-    def _generate_groq(self, text, system_message, model):
+    def _generate_llm_response(self, text, system_message, model_config, api_key=None):
         """
-        Generate questions using the Groq API.
+        Generate questions using LiteLLM which supports multiple providers (OpenAI, Groq, Gemini, etc.).
 
         Args:
             text (str): The input text to generate questions from.
             system_message (str): The system message for the AI model.
-            model (str): The specific Groq model to use.
+            model_config (dict): Configuration dictionary containing model details.
+                Required keys:
+                - model: The model identifier (e.g., "gpt-4", "gemini-pro", "mixtral-8x7b-32768")
+                Optional keys:
+                - api_base: Custom API base URL if needed
+                - max_tokens: Maximum tokens in response
+                - temperature: Temperature for response generation
+            api_key (str, optional): The API key for the model provider.
 
         Returns:
             pandas.DataFrame: A DataFrame containing the generated questions and answers.
-        """
-        response = self.groq_client.chat.completions.create(
-            model=model,
-            messages=[
-                {'role': 'system', 'content': system_message},
-                {'role': 'user', 'content': text}
-            ]
-        )
-        return self._parse_response(response, provider="groq")
 
-    def _generate_gemini(self, text, system_message, model):
+        Raises:
+            Exception: If there's an error in generating the response.
         """
-        Generate questions using the Gemini API.
 
-        Args:
-            text (str): The input text to generate questions from.
-            system_message (str): The system message for the AI model.
-            model (str): The specific Gemini model to use.
-
-        Returns:
-            pandas.DataFrame: A DataFrame containing the generated questions and answers.
-        """
-        model = genai.GenerativeModel(model)
-        response = model.generate_content([system_message, text])
-        return self._parse_response(response, provider="gemini")
+        # Prepare the messages in the format expected by LiteLLM
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": text}
+        ]
+
+        # Set up the completion parameters
+        completion_params = {
+            "model": model_config["model"],
+            "messages": messages,
+            "api_key": api_key
+        }
+
+        # Add optional parameters if they exist in model_config
+        if "api_base" in model_config:
+            completion_params["api_base"] = model_config["api_base"]
+        if "max_tokens" in model_config:
+            completion_params["max_tokens"] = model_config["max_tokens"]
+        if "temperature" in model_config:
+            completion_params["temperature"] = model_config["temperature"]
+
+        # Make the API call using LiteLLM
+        try:
+            response = completion(**completion_params)
+        except Exception as e:
+            if any(error in str(e).lower() for error in ["invalid api key", "incorrect api key", "unauthorized", "authentication"]):
+                raise ValueError(f"Invalid API key provided for {model_config.get('provider', 'the specified')} provider")
+            raise Exception(f"Error calling LLM API: {str(e)}")
 
-    def _generate_openai(self, text, system_message, model,api_key=None):
-        """
-        Generate questions using the OpenAI API.
+        # Extract the content from the response
+        content = response.choices[0].message.content
 
-        Args:
-            text (str): The input text to generate questions from.
-            system_message (str): The system message for the AI model.
-            model (str): The specific OpenAI model to use.
+        # Clean the response if needed (remove any prefix before the JSON list)
+        list_start_index = content.find('[')
+        if list_start_index != -1:
+            content = content[list_start_index:]
 
-        Returns:
-            pandas.DataFrame: A DataFrame containing the generated questions and answers.
-        """
-        client = openai.OpenAI(api_key=api_key)
-        response = client.chat.completions.create(
-            model=model,
-            messages=[
-                {"role": "system", "content": system_message},
-                {"role": "user", "content": text}
-            ]
-        )
-        return self._parse_response(response, provider="openai")
+        json_data = json.loads(content)
+        return pd.DataFrame(json_data)
 
     def _parse_response(self, response, provider):
         """
@@ -218,16 +344,19 @@ class SyntheticDataGeneration:
         if os.path.isfile(input_data):
             # If input_data is a file path
             _, file_extension = os.path.splitext(input_data)
-            if file_extension.lower() == '.pdf':
-                return self._read_pdf(input_data)
-            elif file_extension.lower() == '.txt':
-                return self._read_text(input_data)
-            elif file_extension.lower() == '.md':
-                return self._read_markdown(input_data)
-            elif file_extension.lower() == '.csv':
-                return self._read_csv(input_data)
-            else:
-                raise ValueError(f"Unsupported file type: {file_extension}")
+            try:
+                if file_extension.lower() == '.pdf':
+                    return self._read_pdf(input_data)
+                elif file_extension.lower() == '.txt':
+                    return self._read_text(input_data)
+                elif file_extension.lower() == '.md':
+                    return self._read_markdown(input_data)
+                elif file_extension.lower() == '.csv':
+                    return self._read_csv(input_data)
+                else:
+                    raise ValueError(f"Unsupported file type: {file_extension}")
+            except Exception as e:
+                raise ValueError(f"Error reading the file. Upload a valid file. \n{e}")
         else:
             # If input_data is a string of text
             return input_data
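
Note: putting the pieces together, a hedged usage sketch based on the signatures visible in this diff. The import path, model name, and sample text are assumptions rather than facts established by the diff:

    from ragaai_catalyst import SyntheticDataGeneration  # assumed import path

    sdg = SyntheticDataGeneration()
    df = sdg.generate_qna(
        text="Photosynthesis converts light energy into chemical energy.",
        question_type="simple",   # 'simple', 'mcq', or 'complex'
        n=12,                     # fetched in batches of 5, deduplicated, trimmed to 12
        model_config={"provider": "openai", "model": "gpt-4o-mini"},
        api_key="sk-...",         # or rely on the OPENAI_API_KEY environment variable
    )
    print(df)                     # DataFrame indexed 1..12 with Question/Answer columns
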