datarobot-moderations 11.2.9__py3-none-any.whl → 11.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_dome/__init__.py +1 -1
- datarobot_dome/async_http_client.py +1 -1
- datarobot_dome/chat_helper.py +1 -1
- datarobot_dome/constants.py +26 -2
- datarobot_dome/drum_integration.py +2 -3
- datarobot_dome/guard_executor.py +67 -16
- datarobot_dome/guard_factory.py +126 -0
- datarobot_dome/guard_helpers.py +16 -1
- datarobot_dome/guards/__init__.py +16 -1
- datarobot_dome/guards/base.py +259 -0
- datarobot_dome/guards/guard_llm_mixin.py +3 -1
- datarobot_dome/guards/model_guard.py +84 -0
- datarobot_dome/guards/nemo_evaluator.py +73 -0
- datarobot_dome/guards/nemo_guard.py +146 -0
- datarobot_dome/guards/ootb_guard.py +209 -0
- datarobot_dome/guards/validation.py +201 -0
- datarobot_dome/llm.py +1 -1
- datarobot_dome/metrics/__init__.py +1 -1
- datarobot_dome/metrics/citation_metrics.py +1 -1
- datarobot_dome/metrics/factory.py +3 -4
- datarobot_dome/metrics/metric_scorer.py +1 -1
- datarobot_dome/pipeline/__init__.py +1 -1
- datarobot_dome/pipeline/llm_pipeline.py +3 -3
- datarobot_dome/pipeline/pipeline.py +20 -17
- datarobot_dome/pipeline/vdb_pipeline.py +2 -3
- datarobot_dome/runtime.py +1 -1
- datarobot_dome/streaming.py +2 -2
- {datarobot_moderations-11.2.9.dist-info → datarobot_moderations-11.2.11.dist-info}/METADATA +3 -1
- datarobot_moderations-11.2.11.dist-info/RECORD +30 -0
- {datarobot_moderations-11.2.9.dist-info → datarobot_moderations-11.2.11.dist-info}/WHEEL +1 -1
- datarobot_dome/guard.py +0 -845
- datarobot_moderations-11.2.9.dist-info/RECORD +0 -24
datarobot_dome/guard.py
DELETED
|
@@ -1,845 +0,0 @@
|
|
|
1
|
-
# ---------------------------------------------------------------------------------
|
|
2
|
-
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
-
# Last updated 2025.
|
|
4
|
-
#
|
|
5
|
-
# DataRobot, Inc. Confidential.
|
|
6
|
-
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
-
#
|
|
8
|
-
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
-
# For details, see
|
|
10
|
-
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
-
# ---------------------------------------------------------------------------------
|
|
12
|
-
import logging
|
|
13
|
-
import os
|
|
14
|
-
from abc import ABC
|
|
15
|
-
|
|
16
|
-
import datarobot as dr
|
|
17
|
-
import trafaret as t
|
|
18
|
-
from datarobot.enums import CustomMetricAggregationType
|
|
19
|
-
from datarobot.enums import CustomMetricDirectionality
|
|
20
|
-
from deepeval.metrics import TaskCompletionMetric
|
|
21
|
-
from llama_index.core import Settings
|
|
22
|
-
from llama_index.core.evaluation import FaithfulnessEvaluator
|
|
23
|
-
from nemoguardrails import LLMRails
|
|
24
|
-
from nemoguardrails import RailsConfig
|
|
25
|
-
from ragas.llms import LangchainLLMWrapper
|
|
26
|
-
from ragas.llms import LlamaIndexLLMWrapper
|
|
27
|
-
from ragas.metrics import AgentGoalAccuracyWithoutReference
|
|
28
|
-
|
|
29
|
-
from datarobot_dome.constants import AGENT_GOAL_ACCURACY_COLUMN_NAME
|
|
30
|
-
from datarobot_dome.constants import COST_COLUMN_NAME
|
|
31
|
-
from datarobot_dome.constants import CUSTOM_METRIC_DESCRIPTION_SUFFIX
|
|
32
|
-
from datarobot_dome.constants import DEFAULT_GUARD_PREDICTION_TIMEOUT_IN_SEC
|
|
33
|
-
from datarobot_dome.constants import DEFAULT_PROMPT_COLUMN_NAME
|
|
34
|
-
from datarobot_dome.constants import DEFAULT_RESPONSE_COLUMN_NAME
|
|
35
|
-
from datarobot_dome.constants import FAITHFULLNESS_COLUMN_NAME
|
|
36
|
-
from datarobot_dome.constants import NEMO_GUARD_COLUMN_NAME
|
|
37
|
-
from datarobot_dome.constants import NEMO_GUARDRAILS_DIR
|
|
38
|
-
from datarobot_dome.constants import ROUGE_1_COLUMN_NAME
|
|
39
|
-
from datarobot_dome.constants import SPAN_PREFIX
|
|
40
|
-
from datarobot_dome.constants import TASK_ADHERENCE_SCORE_COLUMN_NAME
|
|
41
|
-
from datarobot_dome.constants import TOKEN_COUNT_COLUMN_NAME
|
|
42
|
-
from datarobot_dome.constants import AwsModel
|
|
43
|
-
from datarobot_dome.constants import CostCurrency
|
|
44
|
-
from datarobot_dome.constants import GoogleModel
|
|
45
|
-
from datarobot_dome.constants import GuardAction
|
|
46
|
-
from datarobot_dome.constants import GuardLLMType
|
|
47
|
-
from datarobot_dome.constants import GuardModelTargetType
|
|
48
|
-
from datarobot_dome.constants import GuardOperatorType
|
|
49
|
-
from datarobot_dome.constants import GuardStage
|
|
50
|
-
from datarobot_dome.constants import GuardTimeoutAction
|
|
51
|
-
from datarobot_dome.constants import GuardType
|
|
52
|
-
from datarobot_dome.constants import OOTBType
|
|
53
|
-
from datarobot_dome.guard_helpers import DEFAULT_OPEN_AI_API_VERSION
|
|
54
|
-
from datarobot_dome.guard_helpers import ModerationDeepEvalLLM
|
|
55
|
-
from datarobot_dome.guard_helpers import get_azure_openai_client
|
|
56
|
-
from datarobot_dome.guard_helpers import get_chat_nvidia_llm
|
|
57
|
-
from datarobot_dome.guard_helpers import get_datarobot_endpoint_and_token
|
|
58
|
-
from datarobot_dome.guard_helpers import get_llm_gateway_client
|
|
59
|
-
from datarobot_dome.guard_helpers import use_llm_gateway_inference
|
|
60
|
-
from datarobot_dome.guards.guard_llm_mixin import GuardLLMMixin
|
|
61
|
-
|
|
62
|
-
# Length limits for guard configuration fields; most are enforced by the trafaret
# schemas defined in this module.
MAX_GUARD_NAME_LENGTH = MAX_COLUMN_NAME_LENGTH = MAX_GUARD_COLUMN_NAME_LENGTH = 255
MAX_GUARD_MESSAGE_LENGTH = MAX_GUARD_DESCRIPTION_LENGTH = 4096
OBJECT_ID_LENGTH = 24
MAX_REGEX_LENGTH = MAX_URL_LENGTH = MAX_TOKEN_LENGTH = 255
# Default comparand assigned to NeMo guard interventions that do not set a threshold.
NEMO_THRESHOLD = "TRUE"
MAX_GUARD_CUSTOM_METRIC_BASELINE_VALUE_LIST_LENGTH = 5
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
# Schema for the "cost" section of ``additional_guard_config``.
# Prices are expressed per ``*_unit`` units (presumably tokens — TODO confirm); the
# cost metric guard divides each price by its unit to get a per-unit multiplier.
cost_metric_trafaret = t.Dict(
    {
        # Falls back to USD when the currency is omitted.
        t.Key("currency", to_name="currency", optional=True, default=CostCurrency.USD): t.Enum(
            *CostCurrency.ALL
        ),
        t.Key("input_price", to_name="input_price", optional=False): t.Float(),
        t.Key("input_unit", to_name="input_unit", optional=False): t.Int(),
        t.Key("output_price", to_name="output_price", optional=False): t.Float(),
        t.Key("output_unit", to_name="output_unit", optional=False): t.Int(),
    }
)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
# Schema for ``model_info``: describes the model a guard uses (ModelGuard requires it).
# ``input_column_name``, ``target_name`` and ``target_type`` are mandatory; extra keys
# are accepted via ``allow_extra=["*"]``.
model_info_trafaret = t.Dict(
    {
        t.Key("class_names", to_name="class_names", optional=True): t.List(
            t.String(max_length=MAX_COLUMN_NAME_LENGTH)
        ),
        t.Key("model_id", to_name="model_id", optional=True): t.String(max_length=OBJECT_ID_LENGTH),
        t.Key("input_column_name", to_name="input_column_name", optional=False): t.String(
            max_length=MAX_COLUMN_NAME_LENGTH
        ),
        t.Key("target_name", to_name="target_name", optional=False): t.String(
            max_length=MAX_COLUMN_NAME_LENGTH
        ),
        # May be blank or null.
        t.Key(
            "replacement_text_column_name", to_name="replacement_text_column_name", optional=True
        ): t.Or(t.String(allow_blank=True, max_length=MAX_COLUMN_NAME_LENGTH), t.Null),
        t.Key("target_type", to_name="target_type", optional=False): t.Enum(
            *GuardModelTargetType.ALL
        ),
    },
    allow_extra=["*"],
)
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
# Schema for a single intervention condition: a ``comparand`` plus the ``comparator``
# used to evaluate it (presumably against the guard's metric value — TODO confirm).
# The comparand may be a string, float, bool, or a list of strings or floats.
model_guard_intervention_trafaret = t.Dict(
    {
        t.Key("comparand", to_name="comparand", optional=False): t.Or(
            t.String(max_length=MAX_GUARD_NAME_LENGTH),
            t.Float(),
            t.Bool(),
            t.List(t.String(max_length=MAX_GUARD_NAME_LENGTH)),
            t.List(t.Float()),
        ),
        t.Key("comparator", to_name="comparator", optional=False): t.Enum(*GuardOperatorType.ALL),
    },
    allow_extra=["*"],
)
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
# Schema for a guard's ``intervention`` section: the action to take, an optional
# (possibly blank) message, and at most one condition. GuardIntervention reads only
# the first condition.
guard_intervention_trafaret = t.Dict(
    {
        t.Key("action", to_name="action", optional=False): t.Enum(*GuardAction.ALL),
        t.Key("message", to_name="message", optional=True): t.String(
            max_length=MAX_GUARD_MESSAGE_LENGTH, allow_blank=True
        ),
        # Zero or one condition, or null.
        t.Key("conditions", to_name="conditions", optional=True): t.Or(
            t.List(
                model_guard_intervention_trafaret,
                max_length=1,
                min_length=0,
            ),
            t.Null,
        ),
        t.Key("send_notification", to_name="send_notification", optional=True): t.Bool(),
    },
    allow_extra=["*"],
)
|
|
144
|
-
|
|
145
|
-
# Schema for ``additional_guard_config``. ``cost`` is read by the cost metric guard;
# ``tool_call`` accepts any value. No ``allow_extra`` is given here, so (per trafaret's
# default) unknown keys are rejected — unlike the other schemas in this module.
additional_guard_config_trafaret = t.Dict(
    {
        t.Key("cost", to_name="cost", optional=True): t.Or(cost_metric_trafaret, t.Null),
        t.Key("tool_call", to_name="tool_call", optional=True): t.Or(t.Any(), t.Null),
    }
)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
# Schema for a single entry of the moderation config's ``guards`` list.
# Only ``name``, ``type`` and ``stage`` are mandatory; the remaining keys configure
# particular guard types. Extra keys are accepted via ``allow_extra=["*"]``.
guard_trafaret = t.Dict(
    {
        t.Key("name", to_name="name", optional=False): t.String(max_length=MAX_GUARD_NAME_LENGTH),
        t.Key("description", to_name="description", optional=True): t.String(
            max_length=MAX_GUARD_DESCRIPTION_LENGTH
        ),
        t.Key("type", to_name="type", optional=False): t.Enum(*GuardType.ALL),
        # Either a single stage or a list of stages.
        t.Key("stage", to_name="stage", optional=False): t.Or(
            t.List(t.Enum(*GuardStage.ALL)), t.Enum(*GuardStage.ALL)
        ),
        t.Key("llm_type", to_name="llm_type", optional=True): t.Enum(*GuardLLMType.ALL),
        t.Key("ootb_type", to_name="ootb_type", optional=True): t.Enum(*OOTBType.ALL),
        t.Key("deployment_id", to_name="deployment_id", optional=True): t.Or(
            t.String(max_length=OBJECT_ID_LENGTH), t.Null
        ),
        t.Key("model_info", to_name="model_info", optional=True): model_info_trafaret,
        t.Key("intervention", to_name="intervention", optional=True): t.Or(
            guard_intervention_trafaret, t.Null
        ),
        # OpenAI-style connection settings (NeMoGuard also uses them for Azure OpenAI and NIM).
        t.Key("openai_api_key", to_name="openai_api_key", optional=True): t.Or(
            t.String(max_length=MAX_TOKEN_LENGTH), t.Null
        ),
        t.Key("openai_deployment_id", to_name="openai_deployment_id", optional=True): t.Or(
            t.String(max_length=OBJECT_ID_LENGTH), t.Null
        ),
        t.Key("openai_api_base", to_name="openai_api_base", optional=True): t.Or(
            t.String(max_length=MAX_URL_LENGTH), t.Null
        ),
        # Google / AWS model selection.
        t.Key("google_region", to_name="google_region", optional=True): t.Or(t.String, t.Null),
        t.Key("google_model", to_name="google_model", optional=True): t.Or(
            t.Enum(*GoogleModel.ALL), t.Null
        ),
        t.Key("aws_region", to_name="aws_region", optional=True): t.Or(t.String, t.Null),
        t.Key("aws_model", to_name="aws_model", optional=True): t.Or(t.Enum(*AwsModel.ALL), t.Null),
        # These two keys have no ``to_name``; the key name is kept as-is.
        t.Key("faas_url", optional=True): t.Or(t.String(max_length=MAX_URL_LENGTH), t.Null),
        # Guard.__init__ indexes config["copy_citations"] directly and relies on this default.
        t.Key("copy_citations", optional=True, default=False): t.Bool(),
        t.Key("is_agentic", to_name="is_agentic", optional=True, default=False): t.Bool(),
        t.Key(
            "additional_guard_config",
            to_name="additional_guard_config",
            optional=True,
            default=None,
        ): t.Or(additional_guard_config_trafaret, t.Null),
    },
    allow_extra=["*"],
)
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
# Top-level moderation configuration schema: global timeout handling plus the list of
# guards (each validated by ``guard_trafaret``). Extra keys are accepted.
moderation_config_trafaret = t.Dict(
    {
        # Must be strictly greater than 1 (``gt=1``).
        t.Key(
            "timeout_sec",
            to_name="timeout_sec",
            optional=True,
            default=DEFAULT_GUARD_PREDICTION_TIMEOUT_IN_SEC,
        ): t.Int(gt=1),
        t.Key(
            "timeout_action",
            to_name="timeout_action",
            optional=True,
            default=GuardTimeoutAction.SCORE,
        ): t.Enum(*GuardTimeoutAction.ALL),
        # Why is the default True?
        # We manually tested sending extra output with the OpenAI completion object under
        # the "datarobot_moderations" field, and it works by default, "EVEN WITH" the OpenAI
        # client. It always works with the raw API response, because the field is simply
        # treated as extra data in the JSON response, so most of the time this is safe. If
        # a future OpenAI client cannot cope with the extra data, we can disable this flag so
        # that it won't break the client and the user flow.
        t.Key(
            "enable_extra_model_output_for_chat",
            to_name="enable_extra_model_output_for_chat",
            optional=True,
            default=True,
        ): t.Bool(),
        t.Key("guards", to_name="guards", optional=False): t.List(guard_trafaret),
    },
    allow_extra=["*"],
)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
def get_metric_column_name(
|
|
235
|
-
guard_type: GuardType,
|
|
236
|
-
ootb_type: OOTBType | None,
|
|
237
|
-
stage: GuardStage,
|
|
238
|
-
model_guard_target_name: str | None = None,
|
|
239
|
-
metric_name: str | None = None,
|
|
240
|
-
) -> str:
|
|
241
|
-
"""Gets the metric column name. Note that this function gets used in buzok code. If you update
|
|
242
|
-
it, please also update the moderation library in the buzok worker image.
|
|
243
|
-
"""
|
|
244
|
-
if guard_type == GuardType.MODEL:
|
|
245
|
-
if model_guard_target_name is None:
|
|
246
|
-
raise ValueError(
|
|
247
|
-
"For the model guard type, a valid model_guard_target_name has to be provided."
|
|
248
|
-
)
|
|
249
|
-
metric_result_key = Guard.get_stage_str(stage) + "_" + model_guard_target_name
|
|
250
|
-
elif guard_type == GuardType.OOTB:
|
|
251
|
-
if ootb_type is None:
|
|
252
|
-
raise ValueError("For the OOTB type, a valid OOTB guard type has to be provided.")
|
|
253
|
-
elif ootb_type == OOTBType.TOKEN_COUNT:
|
|
254
|
-
metric_result_key = Guard.get_stage_str(stage) + "_" + TOKEN_COUNT_COLUMN_NAME
|
|
255
|
-
elif ootb_type == OOTBType.ROUGE_1:
|
|
256
|
-
metric_result_key = Guard.get_stage_str(stage) + "_" + ROUGE_1_COLUMN_NAME
|
|
257
|
-
elif ootb_type == OOTBType.FAITHFULNESS:
|
|
258
|
-
metric_result_key = Guard.get_stage_str(stage) + "_" + FAITHFULLNESS_COLUMN_NAME
|
|
259
|
-
elif ootb_type == OOTBType.AGENT_GOAL_ACCURACY:
|
|
260
|
-
metric_result_key = AGENT_GOAL_ACCURACY_COLUMN_NAME
|
|
261
|
-
elif ootb_type == OOTBType.CUSTOM_METRIC:
|
|
262
|
-
if metric_name is None:
|
|
263
|
-
raise ValueError(
|
|
264
|
-
"For the custom metric type, a valid metric_name has to be provided."
|
|
265
|
-
)
|
|
266
|
-
metric_result_key = Guard.get_stage_str(stage) + "_" + metric_name
|
|
267
|
-
elif ootb_type == OOTBType.COST:
|
|
268
|
-
metric_result_key = COST_COLUMN_NAME
|
|
269
|
-
elif ootb_type == OOTBType.TASK_ADHERENCE:
|
|
270
|
-
metric_result_key = TASK_ADHERENCE_SCORE_COLUMN_NAME
|
|
271
|
-
else:
|
|
272
|
-
raise ValueError("The provided OOTB type is not implemented.")
|
|
273
|
-
elif guard_type == GuardType.NEMO_GUARDRAILS:
|
|
274
|
-
metric_result_key = Guard.get_stage_str(stage) + "_" + NEMO_GUARD_COLUMN_NAME
|
|
275
|
-
else:
|
|
276
|
-
raise ValueError("The provided guard type is not implemented.")
|
|
277
|
-
return metric_result_key
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
class Guard(ABC):
    """Common base class for every configured guard.

    Stores the configuration shared by all guard types: name, description, type, stage,
    optional intervention and model info, deployment id, and so on. It also builds the
    DataRobot custom-metric definitions and span attribute names reported for the guard.
    Subclasses override the span-name helpers and the ``has_*_custom_metric`` predicates
    as needed.
    """

    def __init__(self, config: dict, stage=None):
        """Initialize the guard from a guard ``config`` dict (see ``guard_trafaret``).

        Args:
            config: Guard configuration. ``name``, ``type`` and ``copy_citations`` are
                read by indexing; ``copy_citations`` is normally filled in by the
                schema default.
            stage: Stage to run at. When falsy, ``config["stage"]`` is used.
        """
        self._name = config["name"]
        self._description = config.get("description")
        self._type = config["type"]
        self._stage = stage if stage else config["stage"]
        self._pipeline = None
        self._model_info = None
        self.intervention = None
        self._deployment_id = config.get("deployment_id")
        self._dr_cm = None
        self._faas_url = config.get("faas_url")
        self._copy_citations = config["copy_citations"]
        self.is_agentic = config.get("is_agentic", False)
        self.metric_column_name = get_metric_column_name(
            config["type"],
            config.get("ootb_type"),
            self._stage,
            config.get("model_info", {}).get("target_name"),
            config["name"],
        )

        if config.get("intervention"):
            self.intervention = GuardIntervention(config["intervention"])
        if config.get("model_info"):
            self._model_info = GuardModelInfo(config["model_info"])

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str | None:
        return self._description

    @property
    def type(self) -> GuardType:
        return self._type

    @property
    def stage(self) -> GuardStage:
        return self._stage

    @property
    def faas_url(self) -> str | None:
        return self._faas_url

    @property
    def copy_citations(self) -> bool:
        return self._copy_citations

    def set_pipeline(self, pipeline):
        self._pipeline = pipeline

    @property
    def llm_type(self):
        """LLM type of the guard, or ``None`` when the concrete guard never set one."""
        # The base class never assigns ``_llm_type``; only LLM-backed subclasses do.
        # Return None instead of raising AttributeError for the others.
        return getattr(self, "_llm_type", None)

    @staticmethod
    def get_stage_str(stage):
        return "Prompts" if stage == GuardStage.PROMPT else "Responses"

    @staticmethod
    def _enforced_verb(moderation_method) -> str:
        """Past-tense label for an action: ``"replaced"`` for REPLACE, else ``"<action>ed"``."""
        if moderation_method == GuardAction.REPLACE:
            return "replaced"
        return f"{moderation_method}ed"

    def has_latency_custom_metric(self) -> bool:
        """Determines if latency metric is tracked for this guard type. Default is True."""
        return True

    def get_latency_custom_metric_name(self):
        return f"{self.name} Guard Latency"

    def get_latency_custom_metric(self):
        """Custom-metric definition for this guard's average latency, in seconds."""
        return {
            "name": self.get_latency_custom_metric_name(),
            "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
            "units": "seconds",
            "type": CustomMetricAggregationType.AVERAGE,
            "baselineValue": 0,
            "isModelSpecific": True,
            "timeStep": "hour",
            "description": (
                f"{self.get_latency_custom_metric_name()}. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
            ),
        }

    def has_average_score_custom_metric(self) -> bool:
        """Determines if an average score metric is tracked for this guard type. Default is True."""
        return True

    def get_average_score_custom_metric_name(self, stage):
        return f"{self.name} Guard Average Score for {self.get_stage_str(stage)}"

    def get_average_score_metric(self, stage):
        """Custom-metric definition for this guard's average score at ``stage``."""
        return {
            "name": self.get_average_score_custom_metric_name(stage),
            "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
            "units": "probability",
            "type": CustomMetricAggregationType.AVERAGE,
            "baselineValue": 0,
            "isModelSpecific": True,
            "timeStep": "hour",
            # Single space after the period, matching the latency metric description.
            "description": (
                f"{self.get_average_score_custom_metric_name(stage)}. "
                f"{CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
            ),
        }

    def get_guard_enforced_custom_metric_name(self, stage, moderation_method):
        return (
            f"{self.name} Guard {self._enforced_verb(moderation_method)} "
            f"{self.get_stage_str(stage)}"
        )

    def get_enforced_custom_metric(self, stage, moderation_method):
        """Custom-metric definition counting how often the guard enforced ``moderation_method``."""
        return {
            "name": self.get_guard_enforced_custom_metric_name(stage, moderation_method),
            "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
            "units": "count",
            "type": CustomMetricAggregationType.SUM,
            "baselineValue": 0,
            "isModelSpecific": True,
            "timeStep": "hour",
            # Use the same past-tense form as the metric name (avoids "replaceed").
            "description": (
                f"Number of {self.get_stage_str(stage)} "
                f"{self._enforced_verb(moderation_method)} by the "
                f"{self.name} guard. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
            ),
        }

    def get_input_column(self, stage):
        """Return the input column name the guard reads at ``stage``.

        Uses ``model_info.input_column_name`` when model info with a non-empty input
        column is configured. Otherwise falls back to the default prompt or response
        column for the stage.
        """
        # Guard against a missing model_info for both stages. Previously only the
        # response path did, so prompt-stage guards without model_info crashed.
        if self._model_info and self._model_info.input_column_name:
            return self._model_info.input_column_name
        if stage == GuardStage.PROMPT:
            return DEFAULT_PROMPT_COLUMN_NAME
        return DEFAULT_RESPONSE_COLUMN_NAME

    def get_intervention_action(self):
        if not self.intervention:
            return GuardAction.NONE
        return self.intervention.action

    def get_comparand(self):
        """Return the intervention threshold, or ``None`` if no intervention is configured."""
        if not self.intervention:
            return None
        return self.intervention.threshold

    def get_enforced_span_attribute_name(self, stage):
        """Span attribute recording that the guard's intervention action was enforced.

        Raises:
            NotImplementedError: if the action is not BLOCK, REPORT or REPLACE.
        """
        intervention_action = self.get_intervention_action()
        if intervention_action not in (GuardAction.BLOCK, GuardAction.REPORT, GuardAction.REPLACE):
            raise NotImplementedError
        return f"{SPAN_PREFIX}.{stage.lower()}.{self._enforced_verb(intervention_action)}"

    def get_span_column_name(self, _):
        raise NotImplementedError

    def get_span_attribute_name(self, _):
        raise NotImplementedError
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
class GuardModelInfo:
    """Read-only view over a guard's ``model_info`` configuration.

    ``input_column_name``, ``target_name`` and ``target_type`` are required; a missing
    key raises ``KeyError``. ``model_id``, ``class_names`` and
    ``replacement_text_column_name`` default to ``None``.
    """

    _REQUIRED_KEYS = ("input_column_name", "target_name", "target_type")

    def __init__(self, model_config: dict):
        self._model_id = model_config.get("model_id")
        # Looked up in declaration order, so the first missing key is the one reported.
        self._input_column_name, self._target_name, self._target_type = (
            model_config[key] for key in self._REQUIRED_KEYS
        )
        self._class_names = model_config.get("class_names")
        self.replacement_text_column_name = model_config.get("replacement_text_column_name")

    @property
    def model_id(self) -> str:
        return self._model_id

    @property
    def input_column_name(self) -> str:
        return self._input_column_name

    @property
    def target_name(self) -> str:
        return self._target_name

    @property
    def target_type(self) -> str:
        return self._target_type

    @property
    def class_names(self):
        return self._class_names
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
class GuardIntervention:
    """Intervention settings of a guard.

    Holds the action, the optional message, and the comparand/comparator of the first
    configured condition. ``threshold`` and ``comparator`` stay ``None`` when no
    condition is given.
    """

    def __init__(self, intervention_config: dict) -> None:
        self.action = intervention_config["action"]
        self.message = intervention_config.get("message")
        self.threshold = None
        self.comparator = None

        conditions = intervention_config.get("conditions")
        if conditions is None or len(conditions) == 0:
            return
        # Only the first condition is used.
        first_condition = conditions[0]
        self.threshold = first_condition.get("comparand")
        self.comparator = first_condition.get("comparator")
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
class ModelGuard(Guard):
    """Guard backed by a DataRobot deployment (``deployment_id``) described by ``model_info``."""

    def __init__(self, config: dict, stage=None):
        super().__init__(config, stage)
        self._deployment_id = config["deployment_id"]
        self._model_info = GuardModelInfo(config["model_info"])
        # dr.Client is configured in the Pipeline init; fetch the deployment now so its
        # prediction server information is available.
        self.deployment = dr.Deployment.get(self._deployment_id)

    @property
    def deployment_id(self) -> str:
        return self._deployment_id

    @property
    def model_info(self):
        return self._model_info

    def get_span_column_name(self, _):
        if self.model_info is None:
            raise NotImplementedError("Missing model_info for model guard")
        # Typically the part before the first "_" is the target name.
        head, _sep, _tail = self._model_info.target_name.partition("_")
        return head

    def get_span_attribute_name(self, stage):
        column = self.get_span_column_name(stage)
        return f"{SPAN_PREFIX}.{stage.lower()}.{column}"

    def has_average_score_custom_metric(self) -> bool:
        """A couple ModelGuard types do not have an average score metric"""
        no_average_score_types = ("Multiclass", "TextGeneration")
        return self.model_info.target_type not in no_average_score_types
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
class NeMoGuard(Guard, GuardLLMMixin):
    """Guard that evaluates text with NeMo Guardrails.

    The rails configuration is loaded from ``<model_dir>/<NEMO_GUARDRAILS_DIR>/<stage>``.
    It runs with an LLM selected by ``llm_type`` (OpenAI by default), or with the LLM
    gateway client when LLM gateway inference is enabled.
    """

    def __init__(self, config: dict, stage=None, model_dir: str | None = None):
        """Build the NeMo rails for this guard.

        Args:
            config: Guard configuration (see ``guard_trafaret``).
            stage: Stage override; see ``Guard.__init__``.
            model_dir: Directory containing the NeMo guardrails folder. Defaults to the
                current working directory at the time the guard is constructed.

        Raises:
            ValueError: if credentials or settings required by the selected ``llm_type``
                are missing and the LLM gateway is not in use.
            NotImplementedError: for the GOOGLE, AMAZON and DATAROBOT LLM types when the
                LLM gateway is not in use.
        """
        super().__init__(config, stage)
        if model_dir is None:
            # Resolved here, not as the signature default. A default of os.getcwd() in the
            # signature would be evaluated once, when this module was imported.
            model_dir = os.getcwd()

        # NeMo guard only takes a boolean as threshold and equal to as comparator.
        # Threshold bool == TRUE is defined in the colang file as the output of
        # `bot should intervene`
        if self.intervention:
            if not self.intervention.threshold:
                self.intervention.threshold = NEMO_THRESHOLD
            if not self.intervention.comparator:
                self.intervention.comparator = GuardOperatorType.EQUALS

        # Default LLM Type for NeMo is set to OpenAI
        self._llm_type = config.get("llm_type", GuardLLMType.OPENAI)
        self.openai_api_base = config.get("openai_api_base")
        self.openai_deployment_id = config.get("openai_deployment_id")
        llm = None
        llm_id = None
        credentials = None
        use_llm_gateway = use_llm_gateway_inference(self._llm_type)
        try:
            self.openai_api_key = self.get_openai_api_key(config, self._llm_type)
            if self._llm_type != GuardLLMType.NIM and self.openai_api_key is None:
                raise ValueError("OpenAI API key is required for NeMo Guardrails")

            if self.llm_type == GuardLLMType.OPENAI:
                credentials = {
                    "credential_type": "openai",
                    "api_key": self.openai_api_key,
                }
                os.environ["OPENAI_API_KEY"] = self.openai_api_key
                llm = None
            elif self.llm_type == GuardLLMType.AZURE_OPENAI:
                if self.openai_api_base is None:
                    raise ValueError("Azure OpenAI API base url is required for LLM Guard")
                if self.openai_deployment_id is None:
                    raise ValueError("Azure OpenAI deployment ID is required for LLM Guard")
                credentials = {
                    "credential_type": "azure_openai",
                    "api_base": self.openai_api_base,
                    "api_version": DEFAULT_OPEN_AI_API_VERSION,
                    "api_key": self.openai_api_key,
                }
                azure_openai_client = get_azure_openai_client(
                    openai_api_key=self.openai_api_key,
                    openai_api_base=self.openai_api_base,
                    openai_deployment_id=self.openai_deployment_id,
                )
                llm = azure_openai_client
            elif self.llm_type == GuardLLMType.GOOGLE:
                # llm_id = config["google_model"]
                raise NotImplementedError
            elif self.llm_type == GuardLLMType.AMAZON:
                # llm_id = config["aws_model"]
                raise NotImplementedError
            elif self.llm_type == GuardLLMType.DATAROBOT:
                raise NotImplementedError
            elif self.llm_type == GuardLLMType.NIM:
                if config.get("deployment_id") is None:
                    if self.openai_api_base is None:
                        raise ValueError("NIM DataRobot deployment id is required for NIM LLM Type")
                    else:
                        logging.warning(
                            "Using 'openai_api_base' is being deprecated and will be removed "
                            "in the next release. Please configure NIM DataRobot deployment "
                            "using deployment_id"
                        )
                        if self.openai_api_key is None:
                            raise ValueError("OpenAI API key is required for NeMo Guardrails")
                else:
                    self.deployment = dr.Deployment.get(self._deployment_id)
                    datarobot_endpoint, self.openai_api_key = get_datarobot_endpoint_and_token()
                    self.openai_api_base = (
                        f"{datarobot_endpoint}/deployments/{str(self._deployment_id)}"
                    )
                llm = get_chat_nvidia_llm(
                    api_key=self.openai_api_key,
                    base_url=self.openai_api_base,
                )
            else:
                raise ValueError(f"Invalid LLMType: {self.llm_type}")

        except Exception:
            # No valid user credentials were provided. This is only fatal when the LLM
            # gateway cannot be used instead.
            credentials = None
            if not use_llm_gateway:
                raise

        if use_llm_gateway:
            # Currently only OPENAI and AZURE_OPENAI are supported by NeMoGuard
            # For Bedrock and Vertex the model in the config is actually the LLM ID
            # For OpenAI we use the default model defined in get_llm_gateway_client
            # For Azure we use the deployment ID
            llm = get_llm_gateway_client(
                llm_id=llm_id,
                openai_deployment_id=self.openai_deployment_id,
                credentials=credentials,
            )

        # Use the guard stage to decide whether to read from the prompt or response
        # subdirectory of the nemo configurations. The "nemo_guardrails" folder is at the
        # same level as custom.py, so the config path becomes model_dir + "nemo_guardrails".
        nemo_config_path = os.path.join(model_dir, NEMO_GUARDRAILS_DIR)
        self.nemo_rails_config_path = os.path.join(nemo_config_path, self.stage)
        nemo_rails_config = RailsConfig.from_path(config_path=self.nemo_rails_config_path)
        self._nemo_llm_rails = LLMRails(nemo_rails_config, llm=llm)

    def has_average_score_custom_metric(self) -> bool:
        """No average score metrics for NemoGuard's"""
        return False

    @property
    def nemo_llm_rails(self):
        return self._nemo_llm_rails
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
class OOTBGuard(Guard):
    """Guard backed by an out-of-the-box (OOTB) metric chosen via ``config["ootb_type"]``."""

    def __init__(self, config: dict, stage=None):
        super().__init__(config, stage)
        # "ootb_type" is read unconditionally; a config without it raises KeyError.
        self._ootb_type = config["ootb_type"]

    @property
    def ootb_type(self):
        """The OOTB metric type taken from the guard configuration."""
        return self._ootb_type

    def has_latency_custom_metric(self):
        """Return ``False`` for token-count guards (latency is not tracked), ``True`` otherwise."""
        return self._ootb_type != OOTBType.TOKEN_COUNT

    def get_span_column_name(self, _):
        """Return the column name used when emitting span attributes for this guard.

        The stage argument is ignored.

        Raises:
            NotImplementedError: if the OOTB type has no span column defined here.
        """
        metric_kind = self._ootb_type
        if metric_kind == OOTBType.TOKEN_COUNT:
            return TOKEN_COUNT_COLUMN_NAME
        if metric_kind == OOTBType.ROUGE_1:
            return ROUGE_1_COLUMN_NAME
        if metric_kind == OOTBType.CUSTOM_METRIC:
            # Custom-metric guards are reported under the guard's own name.
            return self.name
        raise NotImplementedError(f"No span attribute name defined for {metric_kind} guard")

    def get_span_attribute_name(self, stage):
        """Build ``"<SPAN_PREFIX>.<stage lowercased>.<span column name>"`` for ``stage``."""
        # Lowercase the stage before resolving the column, matching the original
        # left-to-right evaluation of the f-string (matters only on error paths).
        stage_segment = stage.lower()
        column_segment = self.get_span_column_name(stage)
        return f"{SPAN_PREFIX}.{stage_segment}.{column_segment}"
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
class OOTBCostMetric(OOTBGuard):
    """OOTB metric reporting the LLM cost of a completion from configured token prices.

    Prices are read from ``config["additional_guard_config"]["cost"]``, which must
    contain ``currency``, ``input_price``, ``input_unit``, ``output_price`` and
    ``output_unit``.
    """

    def __init__(self, config, stage):
        super().__init__(config, stage)
        # The cost is calculated based on the usage metrics returned by the
        # completion object, so it can be evaluated only at response stage;
        # whatever stage was requested is overridden here.
        self._stage = GuardStage.RESPONSE
        cost_config = config["additional_guard_config"]["cost"]
        self.currency = cost_config["currency"]
        self.input_price = cost_config["input_price"]
        self.input_unit = cost_config["input_unit"]
        # Per-unit multiplier: the configured price divided by its unit size.
        self.input_multiplier = self.input_price / self.input_unit
        self.output_price = cost_config["output_price"]
        self.output_unit = cost_config["output_unit"]
        self.output_multiplier = self.output_price / self.output_unit

    def get_average_score_custom_metric_name(self, _):
        """Return the custom metric name, e.g. ``"Total cost in USD"``; the stage is ignored."""
        return f"Total cost in {self.currency}"

    def get_average_score_metric(self, _):
        """Return the custom-metric definition dict for the total (summed, hourly) cost."""
        metric_name = self.get_average_score_custom_metric_name(_)
        return {
            "name": metric_name,
            "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
            "units": "value",
            "type": CustomMetricAggregationType.SUM,
            "baselineValue": 0,
            "isModelSpecific": True,
            "timeStep": "hour",
            # Single space after the period; the old "... " + " ..." pieces
            # produced a double space in the description.
            "description": f"{metric_name}. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}",
        }

    def get_span_column_name(self, _):
        """Return the span column name, suffixed with the lowercased currency."""
        return f"{COST_COLUMN_NAME}.{self.currency.lower()}"

    def get_span_attribute_name(self, _):
        """Build the span attribute name; always uses the forced response stage."""
        return f"{SPAN_PREFIX}.{self._stage.lower()}.{self.get_span_column_name(_)}"
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
class FaithfulnessGuard(OOTBGuard, GuardLLMMixin):
    """OOTB guard scoring response faithfulness with a LlamaIndex ``FaithfulnessEvaluator``.

    Raises:
        Exception: if configured for the prompt stage.
    """

    def __init__(self, config: dict, stage=None):
        super().__init__(config, stage)

        if self.stage == GuardStage.PROMPT:
            raise Exception("Faithfulness cannot be configured for the Prompt stage")

        # Azure OpenAI is the LLM type used when the config does not specify one.
        self._llm_type = config.get("llm_type", GuardLLMType.AZURE_OPENAI)
        judge_llm = self.get_llm(config, self._llm_type)
        # NOTE(review): ``Settings`` is presumably LlamaIndex's process-wide
        # configuration, so the last faithfulness guard constructed decides the
        # LLM used by every evaluator — confirm this is intended.
        Settings.llm = judge_llm
        Settings.embed_model = None
        self._evaluator = FaithfulnessEvaluator()

    @property
    def faithfulness_evaluator(self):
        """The evaluator created during construction."""
        evaluator = self._evaluator
        return evaluator

    def get_span_column_name(self, _):
        """Return the faithfulness span column name; the stage is ignored."""
        return FAITHFULLNESS_COLUMN_NAME

    def get_span_attribute_name(self, _):
        """Build the span attribute name from this guard's own stage."""
        stage_segment = self._stage.lower()
        column_segment = self.get_span_column_name(_)
        return f"{SPAN_PREFIX}.{stage_segment}.{column_segment}"
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
class AgentGoalAccuracyGuard(OOTBGuard, GuardLLMMixin):
    """OOTB guard scoring agent goal accuracy with ``AgentGoalAccuracyWithoutReference``.

    Raises:
        Exception: if configured for the prompt stage.
    """

    def __init__(self, config: dict, stage=None):
        super().__init__(config, stage)

        if self.stage == GuardStage.PROMPT:
            raise Exception("Agent Goal Accuracy guard cannot be configured for the Prompt stage")

        # Azure OpenAI is the LLM type used when the config does not specify one.
        self._llm_type = config.get("llm_type", GuardLLMType.AZURE_OPENAI)
        base_llm = self.get_llm(config, self._llm_type)
        # Azure OpenAI LLMs are wrapped with the LangChain adapter; every other
        # LLM type is wrapped with the LlamaIndex adapter.
        wrapper_cls = (
            LangchainLLMWrapper
            if self._llm_type == GuardLLMType.AZURE_OPENAI
            else LlamaIndexLLMWrapper
        )
        self.scorer = AgentGoalAccuracyWithoutReference(llm=wrapper_cls(base_llm))

    @property
    def accuracy_scorer(self):
        """The scorer created during construction."""
        scorer = self.scorer
        return scorer

    def get_span_column_name(self, _):
        """Return the agent-goal-accuracy span column name; the stage is ignored."""
        return AGENT_GOAL_ACCURACY_COLUMN_NAME

    def get_span_attribute_name(self, _):
        """Build the span attribute name from this guard's own stage."""
        stage_segment = self._stage.lower()
        column_segment = self.get_span_column_name(_)
        return f"{SPAN_PREFIX}.{stage_segment}.{column_segment}"
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
class TaskAdherenceGuard(OOTBGuard, GuardLLMMixin):
    """OOTB guard scoring task adherence with a DeepEval ``TaskCompletionMetric``.

    Raises:
        Exception: if configured for the prompt stage.
    """

    def __init__(self, config: dict, stage=None):
        super().__init__(config, stage)

        if self.stage == GuardStage.PROMPT:
            # The message previously named the Agent Goal Accuracy guard (copy-paste).
            raise Exception("Task Adherence guard cannot be configured for the Prompt stage")

        # Default LLM Type for Task Adherence is set to Azure OpenAI
        self._llm_type = config.get("llm_type", GuardLLMType.AZURE_OPENAI)
        llm = self.get_llm(config, self._llm_type)
        # Adapt the configured LLM to the model interface DeepEval metrics expect.
        deepeval_llm = ModerationDeepEvalLLM(llm)
        self.scorer = TaskCompletionMetric(model=deepeval_llm, include_reason=True)

    @property
    def task_adherence_scorer(self):
        """The ``TaskCompletionMetric`` scorer created during construction."""
        return self.scorer

    def get_span_column_name(self, _):
        """Return the task-adherence span column name; the stage is ignored."""
        return TASK_ADHERENCE_SCORE_COLUMN_NAME

    def get_span_attribute_name(self, _):
        """Build the span attribute name from this guard's own stage."""
        return f"{SPAN_PREFIX}.{self._stage.lower()}.{self.get_span_column_name(_)}"
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
class GuardFactory:
    """Validate raw guard configurations and instantiate the matching guard class."""

    @classmethod
    def _perform_post_validation_checks(cls, guard_config):
        """Run cross-field checks on an already schema-validated guard config.

        Checks performed when an ``intervention`` section is present:
          * a BLOCK action must carry a non-empty ``message``;
          * a REPLACE action needs ``model_info`` with a non-empty
            ``replacement_text_column_name``;
          * the first condition's ``comparand`` must be a list exactly when its
            ``comparator`` is in ``GuardOperatorType.REQUIRES_LIST_COMPARAND``.

        Raises:
            ValueError: if any of the checks above fails.
        """
        intervention = guard_config.get("intervention")
        if not intervention:
            return

        action = intervention["action"]
        if action == GuardAction.BLOCK:
            message = intervention["message"]
            if message is None or len(message) == 0:
                raise ValueError("Blocked action needs a blocking message")

        if action == GuardAction.REPLACE:
            if "model_info" not in guard_config:
                raise ValueError("'Replace' action needs model_info section")
            model_info = guard_config["model_info"]
            if (
                "replacement_text_column_name" not in model_info
                or model_info["replacement_text_column_name"] is None
                or len(model_info["replacement_text_column_name"]) == 0
            ):
                raise ValueError(
                    "'Replace' action needs valid 'replacement_text_column_name' "
                    "in 'model_info' section of the guard"
                )

        # A missing, None or empty conditions list needs no further checks.
        conditions = intervention.get("conditions")
        if not conditions:
            return

        # NOTE(review): only the first condition is validated — presumably a guard
        # carries a single condition; confirm before relying on more.
        condition = conditions[0]
        if condition["comparator"] in GuardOperatorType.REQUIRES_LIST_COMPARAND:
            if not isinstance(condition["comparand"], list):
                raise ValueError(
                    f"Comparand needs to be a list with {condition['comparator']} comparator"
                )
        elif isinstance(condition["comparand"], list):
            raise ValueError(
                f"Comparand needs to be a scalar with {condition['comparator']} comparator"
            )

    @staticmethod
    def create(input_config: dict, stage=None, model_dir: "str | None" = None) -> "Guard":
        """Validate ``input_config`` and build the guard it describes.

        Args:
            input_config: Raw guard configuration, checked with ``guard_trafaret``.
            stage: Guard stage passed through to the guard constructor.
            model_dir: Directory passed to ``NeMoGuard`` (it looks for the NeMo
                guardrails folder there). Defaults to the current working
                directory at call time.

        Returns:
            The constructed guard instance.

        Raises:
            ValueError: for an unknown guard type or a failed post-validation check.
            Whatever ``guard_trafaret.check`` raises for a schema violation.
        """
        if model_dir is None:
            # Resolve here: an ``os.getcwd()`` default in the signature would be
            # evaluated once at import time and go stale if the cwd changes.
            model_dir = os.getcwd()

        config = guard_trafaret.check(input_config)

        GuardFactory._perform_post_validation_checks(config)

        if config["type"] == GuardType.MODEL:
            guard = ModelGuard(config, stage)
        elif config["type"] == GuardType.OOTB:
            if config["ootb_type"] == OOTBType.FAITHFULNESS:
                guard = FaithfulnessGuard(config, stage)
            elif config["ootb_type"] == OOTBType.COST:
                guard = OOTBCostMetric(config, stage)
            elif config["ootb_type"] == OOTBType.AGENT_GOAL_ACCURACY:
                guard = AgentGoalAccuracyGuard(config, stage)
            elif config["ootb_type"] == OOTBType.TASK_ADHERENCE:
                guard = TaskAdherenceGuard(config, stage)
            else:
                # Remaining OOTB types (token count, ROUGE-1, custom metric, ...)
                # share the generic OOTB guard.
                guard = OOTBGuard(config, stage)
        elif config["type"] == GuardType.NEMO_GUARDRAILS:
            guard = NeMoGuard(config, stage, model_dir)
        else:
            raise ValueError(f"Invalid guard type: {config['type']}")

        return guard
|