ragaai-catalyst 2.0.4__py3-none-any.whl → 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/__init__.py +3 -1
- ragaai_catalyst/dataset.py +50 -61
- ragaai_catalyst/evaluation.py +80 -47
- ragaai_catalyst/guard_executor.py +97 -0
- ragaai_catalyst/guardrails_manager.py +259 -0
- ragaai_catalyst/internal_api_completion.py +83 -0
- ragaai_catalyst/prompt_manager.py +1 -1
- ragaai_catalyst/proxy_call.py +1 -1
- ragaai_catalyst/ragaai_catalyst.py +1 -1
- ragaai_catalyst/synthetic_data_generation.py +206 -77
- ragaai_catalyst/tracers/llamaindex_callback.py +361 -0
- ragaai_catalyst/tracers/tracer.py +62 -28
- ragaai_catalyst-2.0.6.dist-info/METADATA +386 -0
- ragaai_catalyst-2.0.6.dist-info/RECORD +29 -0
- {ragaai_catalyst-2.0.4.dist-info → ragaai_catalyst-2.0.6.dist-info}/WHEEL +1 -1
- ragaai_catalyst-2.0.4.dist-info/METADATA +0 -228
- ragaai_catalyst-2.0.4.dist-info/RECORD +0 -25
- {ragaai_catalyst-2.0.4.dist-info → ragaai_catalyst-2.0.6.dist-info}/top_level.txt +0 -0
ragaai_catalyst/__init__.py
CHANGED
@@ -6,6 +6,8 @@ from .dataset import Dataset
|
|
6
6
|
from .prompt_manager import PromptManager
|
7
7
|
from .evaluation import Evaluation
|
8
8
|
from .synthetic_data_generation import SyntheticDataGeneration
|
9
|
+
from .guardrails_manager import GuardrailsManager
|
10
|
+
from .guard_executor import GuardExecutor
|
9
11
|
|
10
12
|
|
11
|
-
__all__ = ["Experiment", "RagaAICatalyst", "Tracer", "PromptManager", "Evaluation","SyntheticDataGeneration"]
|
13
|
+
# Public API of the package. GuardExecutor is imported above alongside
# GuardrailsManager, so it is exported here as well.
__all__ = ["Experiment", "RagaAICatalyst", "Tracer", "PromptManager", "Evaluation", "SyntheticDataGeneration", "GuardrailsManager", "GuardExecutor"]
|
ragaai_catalyst/dataset.py
CHANGED
@@ -16,7 +16,7 @@ class Dataset:
|
|
16
16
|
|
17
17
|
def __init__(self, project_name):
|
18
18
|
self.project_name = project_name
|
19
|
-
self.num_projects =
|
19
|
+
self.num_projects = 99999
|
20
20
|
Dataset.BASE_URL = (
|
21
21
|
os.getenv("RAGAAI_CATALYST_BASE_URL")
|
22
22
|
if os.getenv("RAGAAI_CATALYST_BASE_URL")
|
@@ -99,82 +99,71 @@ class Dataset:
|
|
99
99
|
raise
|
100
100
|
|
101
101
|
def get_schema_mapping(self):
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
Creates a new dataset with the given `dataset_name` and `filter_list`.
|
107
|
-
|
108
|
-
Args:
|
109
|
-
dataset_name (str): The name of the dataset to be created.
|
110
|
-
filter_list (list): A list of filters to be applied to the dataset.
|
111
|
-
|
112
|
-
Returns:
|
113
|
-
str: A message indicating the success of the dataset creation and the name of the created dataset.
|
114
|
-
|
115
|
-
Raises:
|
116
|
-
None
|
117
|
-
|
118
|
-
"""
|
119
|
-
|
120
|
-
def request_trace_creation():
|
121
|
-
headers = {
|
122
|
-
"Content-Type": "application/json",
|
123
|
-
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
124
|
-
"X-Project-Name": self.project_name,
|
125
|
-
}
|
126
|
-
json_data = {
|
127
|
-
"projectName": self.project_name,
|
128
|
-
"subDatasetName": dataset_name,
|
129
|
-
"filterList": filter_list,
|
130
|
-
}
|
131
|
-
try:
|
132
|
-
response = requests.post(
|
133
|
-
f"{Dataset.BASE_URL}/v1/llm/sub-dataset",
|
134
|
-
headers=headers,
|
135
|
-
json=json_data,
|
136
|
-
timeout=Dataset.TIMEOUT,
|
137
|
-
)
|
138
|
-
response.raise_for_status()
|
139
|
-
return response
|
140
|
-
except requests.exceptions.RequestException as e:
|
141
|
-
logger.error(f"Failed to create dataset from trace: {e}")
|
142
|
-
raise
|
143
|
-
|
102
|
+
headers = {
|
103
|
+
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
104
|
+
"X-Project-Name": self.project_name,
|
105
|
+
}
|
144
106
|
try:
|
145
|
-
response =
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
107
|
+
response = requests.get(
|
108
|
+
f"{Dataset.BASE_URL}/v1/llm/schema-elements",
|
109
|
+
headers=headers,
|
110
|
+
timeout=Dataset.TIMEOUT,
|
111
|
+
)
|
112
|
+
response.raise_for_status()
|
113
|
+
response_data = response.json()["data"]["schemaElements"]
|
114
|
+
if not response.json()['success']:
|
115
|
+
raise ValueError('Unable to fetch Schema Elements for the CSV')
|
116
|
+
return response_data
|
117
|
+
except requests.exceptions.RequestException as e:
|
118
|
+
logger.error(f"Failed to get CSV schema: {e}")
|
156
119
|
raise
|
157
120
|
|
158
121
|
###################### CSV Upload APIs ###################
|
159
122
|
|
160
|
-
def
|
123
|
+
def get_dataset_columns(self, dataset_name):
|
124
|
+
list_dataset = self.list_datasets()
|
125
|
+
if dataset_name not in list_dataset:
|
126
|
+
raise ValueError(f"Dataset {dataset_name} does not exists. Please enter a valid dataset name")
|
127
|
+
|
161
128
|
headers = {
|
162
129
|
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
163
130
|
"X-Project-Name": self.project_name,
|
164
131
|
}
|
132
|
+
headers = {
|
133
|
+
'Content-Type': 'application/json',
|
134
|
+
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
135
|
+
"X-Project-Id": str(self.project_id),
|
136
|
+
}
|
137
|
+
json_data = {"size": 12, "page": "0", "projectId": str(self.project_id), "search": ""}
|
138
|
+
try:
|
139
|
+
response = requests.post(
|
140
|
+
f"{Dataset.BASE_URL}/v2/llm/dataset",
|
141
|
+
headers=headers,
|
142
|
+
json=json_data,
|
143
|
+
timeout=Dataset.TIMEOUT,
|
144
|
+
)
|
145
|
+
response.raise_for_status()
|
146
|
+
datasets = response.json()["data"]["content"]
|
147
|
+
dataset_id = [dataset["id"] for dataset in datasets if dataset["name"]==dataset_name][0]
|
148
|
+
except requests.exceptions.RequestException as e:
|
149
|
+
logger.error(f"Failed to list datasets: {e}")
|
150
|
+
raise
|
151
|
+
|
165
152
|
try:
|
166
153
|
response = requests.get(
|
167
|
-
f"{Dataset.BASE_URL}/
|
154
|
+
f"{Dataset.BASE_URL}/v2/llm/dataset/{dataset_id}?initialCols=0",
|
168
155
|
headers=headers,
|
169
156
|
timeout=Dataset.TIMEOUT,
|
170
157
|
)
|
171
158
|
response.raise_for_status()
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
159
|
+
dataset_columns = response.json()["data"]["datasetColumnsResponses"]
|
160
|
+
dataset_columns = [item["displayName"] for item in dataset_columns]
|
161
|
+
dataset_columns = [data for data in dataset_columns if not data.startswith('_')]
|
162
|
+
if not response.json()['success']:
|
163
|
+
raise ValueError('Unable to fetch details of for the CSV')
|
164
|
+
return dataset_columns
|
176
165
|
except requests.exceptions.RequestException as e:
|
177
|
-
logger.error(f"Failed to get CSV
|
166
|
+
logger.error(f"Failed to get CSV columns: {e}")
|
178
167
|
raise
|
179
168
|
|
180
169
|
def create_from_csv(self, csv_path, dataset_name, schema_mapping):
|
ragaai_catalyst/evaluation.py
CHANGED
@@ -16,7 +16,7 @@ class Evaluation:
|
|
16
16
|
self.base_url = f"{RagaAICatalyst.BASE_URL}"
|
17
17
|
self.timeout = 10
|
18
18
|
self.jobId = None
|
19
|
-
self.num_projects=
|
19
|
+
self.num_projects=99999
|
20
20
|
|
21
21
|
try:
|
22
22
|
response = requests.get(
|
@@ -80,7 +80,8 @@ class Evaluation:
|
|
80
80
|
try:
|
81
81
|
response = requests.get(
|
82
82
|
f'{self.base_url}/v1/llm/llm-metrics',
|
83
|
-
headers=headers
|
83
|
+
headers=headers,
|
84
|
+
timeout=self.timeout)
|
84
85
|
response.raise_for_status()
|
85
86
|
metric_names = [metric["name"] for metric in response.json()["data"]["metrics"]]
|
86
87
|
return metric_names
|
@@ -96,14 +97,45 @@ class Evaluation:
|
|
96
97
|
logger.error(f"An unexpected error occurred: {e}")
|
97
98
|
return []
|
98
99
|
|
99
|
-
def
|
100
|
+
def _get_dataset_id_based_on_dataset_type(self, metric_to_evaluate):
    """Return the dataset id that ``metric_to_evaluate`` should run against.

    Looks up ``self.dataset_name`` in the listing returned by
    ``POST {base_url}/v2/llm/dataset``. The dataset's own ``id`` is returned
    when its ``datasetType`` is ``None`` or equals ``metric_to_evaluate``
    (only "prompt" and "chat" are matched); otherwise its
    ``derivedDatasetId`` is returned.

    Args:
        metric_to_evaluate: "prompt" or "chat", as chosen by the caller.

    Returns:
        The selected dataset id value from the listing.

    Raises:
        ValueError: if ``self.dataset_name`` is not in the returned listing.
        requests.exceptions.RequestException: if the listing request fails
            (logged before re-raising).
    """
    headers = {
        'Content-Type': 'application/json',
        "Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
        "X-Project-Id": str(self.project_id),
    }
    # NOTE(review): only page "0" with size 12 is requested; presumably a
    # dataset beyond the first 12 will not be found — confirm server paging.
    json_data = {"size": 12, "page": "0", "projectId": str(self.project_id), "search": ""}
    try:
        response = requests.post(
            f"{self.base_url}/v2/llm/dataset",
            headers=headers,
            json=json_data,
            timeout=self.timeout,
        )
        response.raise_for_status()
        datasets_content = response.json()["data"]["content"]
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to retrieve dataset list: {e}")
        raise

    dataset = next(
        (item for item in datasets_content if item["name"] == self.dataset_name),
        None,
    )
    if dataset is None:
        # Previously an opaque IndexError from ``[...][0]``.
        raise ValueError(f"Dataset '{self.dataset_name}' not found in project")

    dataset_type = dataset["datasetType"]
    if dataset_type is None or (
        metric_to_evaluate in ("prompt", "chat") and dataset_type == metric_to_evaluate
    ):
        return dataset["id"]
    return dataset["derivedDatasetId"]
|
125
|
+
|
126
|
+
|
127
|
+
def _get_dataset_schema(self, metric_to_evaluate=None):
|
128
|
+
#this dataset_id is based on which type of metric_to_evaluate
|
129
|
+
data_set_id=self._get_dataset_id_based_on_dataset_type(metric_to_evaluate)
|
130
|
+
self.dataset_id=data_set_id
|
131
|
+
|
100
132
|
headers = {
|
101
133
|
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
102
134
|
'Content-Type': 'application/json',
|
103
135
|
'X-Project-Id': str(self.project_id),
|
104
136
|
}
|
105
137
|
data = {
|
106
|
-
"datasetId": str(
|
138
|
+
"datasetId": str(data_set_id),
|
107
139
|
"fields": [],
|
108
140
|
"rowFilterList": []
|
109
141
|
}
|
@@ -111,7 +143,8 @@ class Evaluation:
|
|
111
143
|
response = requests.post(
|
112
144
|
f'{self.base_url}/v1/llm/docs',
|
113
145
|
headers=headers,
|
114
|
-
json=data
|
146
|
+
json=data,
|
147
|
+
timeout=self.timeout)
|
115
148
|
response.raise_for_status()
|
116
149
|
if response.status_code == 200:
|
117
150
|
return response.json()["data"]["columns"]
|
@@ -127,29 +160,9 @@ class Evaluation:
|
|
127
160
|
logger.error(f"An unexpected error occurred: {e}")
|
128
161
|
return {}
|
129
162
|
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
dataset_schema = self._get_dataset_schema()
|
134
|
-
variableName = None
|
135
|
-
for column in dataset_schema:
|
136
|
-
columnName = column["columnType"]
|
137
|
-
displayName = column["displayName"]
|
138
|
-
# print(columnName, displayName)
|
139
|
-
if "".join(columnName.split("_")).lower() == schemaName.lower():
|
140
|
-
variableName = displayName
|
141
|
-
break
|
142
|
-
return variableName
|
143
|
-
# print(variableName)
|
144
|
-
# if variableName:
|
145
|
-
# return variableName
|
146
|
-
# else:
|
147
|
-
# raise ValueError(f"'{schemaName}' column is required for {metric_name} metric evaluation, but not found in dataset")
|
148
|
-
|
149
|
-
|
150
|
-
def _get_variablename_from_user_schema_mapping(self, schemaName, metric_name, schema_mapping):
|
151
|
-
# pdb.set_trace()
|
152
|
-
user_dataset_schema = self._get_dataset_schema()
|
163
|
+
|
164
|
+
def _get_variablename_from_user_schema_mapping(self, schemaName, metric_name, schema_mapping, metric_to_evaluate):
|
165
|
+
user_dataset_schema = self._get_dataset_schema(metric_to_evaluate)
|
153
166
|
user_dataset_columns = [item["displayName"] for item in user_dataset_schema]
|
154
167
|
variableName = None
|
155
168
|
for key, val in schema_mapping.items():
|
@@ -157,7 +170,7 @@ class Evaluation:
|
|
157
170
|
if key in user_dataset_columns:
|
158
171
|
variableName=key
|
159
172
|
else:
|
160
|
-
raise ValueError(f"Column '{key}' is not present in {self.dataset_name}")
|
173
|
+
raise ValueError(f"Column '{key}' is not present in '{self.dataset_name}' dataset")
|
161
174
|
if variableName:
|
162
175
|
return variableName
|
163
176
|
else:
|
@@ -170,10 +183,17 @@ class Evaluation:
|
|
170
183
|
for schema in metrics_schema:
|
171
184
|
if schema["name"]==metric_name:
|
172
185
|
requiredFields = schema["config"]["requiredFields"]
|
186
|
+
|
187
|
+
#this is added to check if "Chat" column is required for metric evaluation
|
188
|
+
required_variables = [_["name"].lower() for _ in requiredFields]
|
189
|
+
if "chat" in required_variables:
|
190
|
+
metric_to_evaluate = "chat"
|
191
|
+
else:
|
192
|
+
metric_to_evaluate = "prompt"
|
193
|
+
|
173
194
|
for field in requiredFields:
|
174
195
|
schemaName = field["name"]
|
175
|
-
|
176
|
-
variableName = self._get_variablename_from_user_schema_mapping(schemaName.lower(), metric_name, schema_mapping)
|
196
|
+
variableName = self._get_variablename_from_user_schema_mapping(schemaName.lower(), metric_name, schema_mapping, metric_to_evaluate)
|
177
197
|
mapping.append({"schemaName": schemaName, "variableName": variableName})
|
178
198
|
return mapping
|
179
199
|
|
@@ -203,7 +223,8 @@ class Evaluation:
|
|
203
223
|
try:
|
204
224
|
response = requests.get(
|
205
225
|
f'{self.base_url}/v1/llm/llm-metrics',
|
206
|
-
headers=headers
|
226
|
+
headers=headers,
|
227
|
+
timeout=self.timeout)
|
207
228
|
response.raise_for_status()
|
208
229
|
metrics_schema = [metric for metric in response.json()["data"]["metrics"]]
|
209
230
|
return metrics_schema
|
@@ -220,7 +241,6 @@ class Evaluation:
|
|
220
241
|
return []
|
221
242
|
|
222
243
|
def _update_base_json(self, metrics):
|
223
|
-
metric_schema_mapping = {"datasetId":self.dataset_id}
|
224
244
|
metrics_schema_response = self._get_metrics_schema_response()
|
225
245
|
sub_providers = ["openai","azure","gemini","groq"]
|
226
246
|
metricParams = []
|
@@ -233,8 +253,15 @@ class Evaluation:
|
|
233
253
|
#checking if provider is one of the allowed providers
|
234
254
|
if key.lower()=="provider" and value.lower() not in sub_providers:
|
235
255
|
raise ValueError("Enter a valid provider name. The following Provider names are supported: OpenAI, Azure, Gemini, Groq")
|
236
|
-
|
237
|
-
|
256
|
+
|
257
|
+
if key.lower()=="threshold":
|
258
|
+
if len(value)>1:
|
259
|
+
raise ValueError("'threshold' can only take one argument gte/lte/eq")
|
260
|
+
else:
|
261
|
+
for key_thres, value_thres in value.items():
|
262
|
+
base_json["metricSpec"]["config"]["params"][key] = {f"{key_thres}":value_thres}
|
263
|
+
else:
|
264
|
+
base_json["metricSpec"]["config"]["params"][key] = {"value": value}
|
238
265
|
|
239
266
|
|
240
267
|
# if metric["config"]["model"]:
|
@@ -243,6 +270,7 @@ class Evaluation:
|
|
243
270
|
mappings = self._get_mapping(metric["name"], metrics_schema_response, metric["schema_mapping"])
|
244
271
|
base_json["metricSpec"]["config"]["mappings"] = mappings
|
245
272
|
metricParams.append(base_json)
|
273
|
+
metric_schema_mapping = {"datasetId":self.dataset_id}
|
246
274
|
metric_schema_mapping["metricParams"] = metricParams
|
247
275
|
return metric_schema_mapping
|
248
276
|
|
@@ -253,12 +281,15 @@ class Evaluation:
|
|
253
281
|
}
|
254
282
|
try:
|
255
283
|
response = requests.get(
|
256
|
-
f
|
257
|
-
headers=headers
|
258
|
-
|
284
|
+
f"{self.base_url}/v2/llm/dataset/{str(self.dataset_id)}?initialCols=0",
|
285
|
+
headers=headers,
|
286
|
+
timeout=self.timeout,
|
287
|
+
)
|
259
288
|
response.raise_for_status()
|
260
|
-
|
261
|
-
|
289
|
+
dataset_columns = response.json()["data"]["datasetColumnsResponses"]
|
290
|
+
dataset_columns = [item["displayName"] for item in dataset_columns]
|
291
|
+
executed_metric_list = [data for data in dataset_columns if not data.startswith('_')]
|
292
|
+
|
262
293
|
return executed_metric_list
|
263
294
|
except requests.exceptions.HTTPError as http_err:
|
264
295
|
logger.error(f"HTTP error occurred: {http_err}")
|
@@ -301,7 +332,8 @@ class Evaluation:
|
|
301
332
|
response = requests.post(
|
302
333
|
f'{self.base_url}/playground/metric-evaluation',
|
303
334
|
headers=headers,
|
304
|
-
json=metric_schema_mapping
|
335
|
+
json=metric_schema_mapping,
|
336
|
+
timeout=self.timeout
|
305
337
|
)
|
306
338
|
if response.status_code == 400:
|
307
339
|
raise ValueError(response.json()["message"])
|
@@ -327,14 +359,14 @@ class Evaluation:
|
|
327
359
|
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
328
360
|
'X-Project-Id': str(self.project_id),
|
329
361
|
}
|
330
|
-
data = {"jobId": self.jobId}
|
331
362
|
try:
|
332
|
-
response = requests.
|
363
|
+
response = requests.get(
|
333
364
|
f'{self.base_url}/job/status',
|
334
365
|
headers=headers,
|
335
|
-
|
366
|
+
timeout=self.timeout)
|
336
367
|
response.raise_for_status()
|
337
|
-
|
368
|
+
if response.json()["success"]:
|
369
|
+
status_json = [item["status"] for item in response.json()["data"]["content"] if item["id"]==self.jobId][0]
|
338
370
|
if status_json == "Failed":
|
339
371
|
return print("Job failed. No results to fetch.")
|
340
372
|
elif status_json == "In Progress":
|
@@ -373,7 +405,8 @@ class Evaluation:
|
|
373
405
|
response = requests.post(
|
374
406
|
f'{self.base_url}/v1/llm/docs',
|
375
407
|
headers=headers,
|
376
|
-
json=data
|
408
|
+
json=data,
|
409
|
+
timeout=self.timeout)
|
377
410
|
response.raise_for_status()
|
378
411
|
return response.json()
|
379
412
|
except requests.exceptions.HTTPError as http_err:
|
@@ -392,7 +425,7 @@ class Evaluation:
|
|
392
425
|
try:
|
393
426
|
response = get_presignedUrl()
|
394
427
|
preSignedURL = response["data"]["preSignedURL"]
|
395
|
-
response = requests.get(preSignedURL)
|
428
|
+
response = requests.get(preSignedURL, timeout=self.timeout)
|
396
429
|
response.raise_for_status()
|
397
430
|
return response.text
|
398
431
|
except requests.exceptions.HTTPError as http_err:
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import litellm
|
2
|
+
import json
|
3
|
+
import requests
|
4
|
+
import os
|
5
|
+
import logging
|
6
|
+
# Restrict the 'LiteLLM' logger to ERROR-level records and above.
_LITELLM_LOGGER_NAME = "LiteLLM"
logger = logging.getLogger(_LITELLM_LOGGER_NAME)
logger.setLevel(logging.ERROR)
|
8
|
+
|
9
|
+
class GuardExecutor:
    """Run an LLM call and then check its output with a guardrail deployment.

    ``field_map`` maps guardrail fields ("prompt", "context", "response",
    "instruction") to the names of entries in the ``prompt_params`` passed to
    ``__call__``.
    """

    # Field-map keys understood by the guardrail payload.
    _ALLOWED_FIELD_KEYS = ('prompt', 'context', 'response', 'instruction')

    def __init__(self, id, guard_manager, field_map=None):
        """Fetch deployment ``id`` through ``guard_manager`` and store the mapping.

        Args:
            id: deployment id (name kept for keyword compatibility even though
                it shadows the builtin).
            guard_manager: object providing ``get_deployment``, ``base_url``,
                ``project_id`` and ``timeout``.
            field_map: optional guardrail-field -> prompt-parameter mapping.
                Defaults to a fresh empty dict per instance.

        Raises:
            ValueError: if ``guard_manager.get_deployment(id)`` returns a falsy value.
        """
        field_map = {} if field_map is None else field_map
        self.deployment_id = id
        self.field_map = field_map
        self.guard_manager = guard_manager
        self.deployment_details = self.guard_manager.get_deployment(id)
        if not self.deployment_details:
            raise ValueError('Error in getting deployment details')
        self.base_url = guard_manager.base_url
        for key in field_map:
            if key not in self._ALLOWED_FIELD_KEYS:
                print('Keys in field map should be in ["prompt","context","response","instruction"]')

    def execute_deployment(self, payload):
        """POST ``payload`` as JSON to this deployment's ingest endpoint.

        Returns:
            The decoded JSON body when it reports ``success``; otherwise prints
            the problem and returns ``None``. Request failures and non-JSON
            responses also return ``None``.
        """
        api = self.base_url + f'/guardrail/deployment/{self.deployment_id}/ingest'
        body_text = json.dumps(payload)
        headers = {
            'x-project-id': str(self.guard_manager.project_id),
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}'
        }
        try:
            response = requests.request(
                "POST", api, headers=headers, data=body_text,
                timeout=self.guard_manager.timeout,
            )
        except requests.exceptions.RequestException as e:
            print('Failed running guardrail: ', str(e))
            return None
        # Decode once; an error page that is not JSON must not crash the caller.
        try:
            body = response.json()
        except ValueError:
            print('Error in running deployment ', response.text)
            return None
        if response.status_code != 200:
            print('Error in running deployment ', body.get('message'))
        if body.get('success'):
            return body
        print(body.get('message'))
        return None

    def llm_executor(self, messages, model_params, llm_caller):
        """Send ``messages`` to the LLM with ``model_params``.

        Only ``llm_caller == 'litellm'`` is supported; any other value prints a
        notice and returns ``None``. ``model_params`` is not modified.
        """
        if llm_caller == 'litellm':
            # Merge into a new dict so the caller's params are not mutated.
            return litellm.completion(**{**model_params, 'messages': messages})
        print(f"{llm_caller} not supported currently, use litellm as llm caller")
        return None

    def __call__(self, messages, prompt_params, model_params, llm_caller='litellm'):
        """Run the LLM on ``messages`` and evaluate the result with the deployment.

        Args:
            messages: chat messages (mappings with ``role``/``content``). The
                last ``user`` message is sent as the guardrail prompt; when a
                "context" field is mapped, that prompt parameter is appended
                to each ``user`` message for the LLM call. The caller's list
                and dicts are left unmodified.
            prompt_params: values referenced by ``field_map``.
            model_params: keyword arguments for the LLM call (not modified).
            llm_caller: LLM client name; only ``'litellm'`` is supported.

        Returns:
            ``None`` if the LLM call raises or ``llm_caller`` is unsupported.
            Otherwise ``(alternate_response, llm_response, guardrail_response)``,
            where ``alternate_response`` is the deployment's
            ``alternateResponse`` when its status is ``'FAIL'`` and ``None``
            otherwise.

        Raises:
            ValueError: if a mapped field other than prompt/response names a
                parameter missing from ``prompt_params``.
        """
        for key, param_name in self.field_map.items():
            if key not in ('prompt', 'response') and param_name not in prompt_params:
                raise ValueError(f'{key} added as field map but not passed as prompt parameter')

        context_var = self.field_map.get('context')

        # Build a new message list: appending the context to the caller's
        # dicts would make repeated calls accumulate it.
        prompt = None
        llm_messages = []
        for msg in messages:
            if 'role' in msg and msg['role'] == 'user':
                prompt = msg['content']
                if context_var:
                    msg = {**msg, 'content': msg['content'] + '\n' + prompt_params[context_var]}
            llm_messages.append(msg)

        doc = {'prompt': prompt}
        if context_var:
            doc['context'] = prompt_params[context_var]

        # The deployment runs once, after the LLM call, so the payload
        # includes the response.
        try:
            llm_response = self.llm_executor(llm_messages, model_params, llm_caller)
        except Exception as e:
            print('Error in running llm:', str(e))
            return None
        if llm_response is None:
            # Unsupported llm_caller: there is no response to guard.
            return None
        doc['response'] = llm_response['choices'][0].message.content
        if 'instruction' in self.field_map:
            doc['instruction'] = prompt_params[self.field_map['instruction']]

        response = self.execute_deployment(doc)
        data = (response or {}).get('data') or {}
        if data.get('status') == 'FAIL':
            print('Guardrail deployment run returned failed status, replacing with alternate response')
            return data.get('alternateResponse'), llm_response, response
        return None, llm_response, response