datarobot-moderations 11.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,736 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
12
+ import logging
13
+ import os
14
+ from abc import ABC
15
+
16
+ import datarobot as dr
17
+ import trafaret as t
18
+ from datarobot.enums import CustomMetricAggregationType
19
+ from datarobot.enums import CustomMetricDirectionality
20
+ from deepeval.metrics import TaskCompletionMetric
21
+ from llama_index.core import Settings
22
+ from llama_index.core.evaluation import FaithfulnessEvaluator
23
+ from nemoguardrails import LLMRails
24
+ from nemoguardrails import RailsConfig
25
+ from ragas.llms import LangchainLLMWrapper
26
+ from ragas.metrics import AgentGoalAccuracyWithoutReference
27
+
28
+ from datarobot_dome.constants import AGENT_GOAL_ACCURACY_COLUMN_NAME
29
+ from datarobot_dome.constants import COST_COLUMN_NAME
30
+ from datarobot_dome.constants import CUSTOM_METRIC_DESCRIPTION_SUFFIX
31
+ from datarobot_dome.constants import DEFAULT_GUARD_PREDICTION_TIMEOUT_IN_SEC
32
+ from datarobot_dome.constants import DEFAULT_PROMPT_COLUMN_NAME
33
+ from datarobot_dome.constants import DEFAULT_RESPONSE_COLUMN_NAME
34
+ from datarobot_dome.constants import FAITHFULLNESS_COLUMN_NAME
35
+ from datarobot_dome.constants import NEMO_GUARD_COLUMN_NAME
36
+ from datarobot_dome.constants import NEMO_GUARDRAILS_DIR
37
+ from datarobot_dome.constants import ROUGE_1_COLUMN_NAME
38
+ from datarobot_dome.constants import TASK_ADHERENCE_SCORE_COLUMN_NAME
39
+ from datarobot_dome.constants import TOKEN_COUNT_COLUMN_NAME
40
+ from datarobot_dome.constants import AwsModel
41
+ from datarobot_dome.constants import CostCurrency
42
+ from datarobot_dome.constants import GoogleModel
43
+ from datarobot_dome.constants import GuardAction
44
+ from datarobot_dome.constants import GuardLLMType
45
+ from datarobot_dome.constants import GuardModelTargetType
46
+ from datarobot_dome.constants import GuardOperatorType
47
+ from datarobot_dome.constants import GuardStage
48
+ from datarobot_dome.constants import GuardTimeoutAction
49
+ from datarobot_dome.constants import GuardType
50
+ from datarobot_dome.constants import OOTBType
51
+ from datarobot_dome.guard_helpers import ModerationDeepEvalLLM
52
+ from datarobot_dome.guard_helpers import get_azure_openai_client
53
+ from datarobot_dome.guard_helpers import get_chat_nvidia_llm
54
+ from datarobot_dome.guard_helpers import get_datarobot_endpoint_and_token
55
+ from datarobot_dome.guard_helpers import try_to_fallback_to_llm_gateway
56
+ from datarobot_dome.guards.guard_llm_mixin import GuardLLMMixin
57
+
58
+ MAX_GUARD_NAME_LENGTH = 255
59
+ MAX_COLUMN_NAME_LENGTH = 255
60
+ MAX_GUARD_COLUMN_NAME_LENGTH = 255
61
+ MAX_GUARD_MESSAGE_LENGTH = 4096
62
+ MAX_GUARD_DESCRIPTION_LENGTH = 4096
63
+ OBJECT_ID_LENGTH = 24
64
+ MAX_REGEX_LENGTH = 255
65
+ MAX_URL_LENGTH = 255
66
+ MAX_TOKEN_LENGTH = 255
67
+ NEMO_THRESHOLD = "TRUE"
68
+ MAX_GUARD_CUSTOM_METRIC_BASELINE_VALUE_LIST_LENGTH = 5
69
+
70
+
71
+ cost_metric_trafaret = t.Dict(
72
+ {
73
+ t.Key("currency", to_name="currency", optional=True, default=CostCurrency.USD): t.Enum(
74
+ *CostCurrency.ALL
75
+ ),
76
+ t.Key("input_price", to_name="input_price", optional=False): t.Float(),
77
+ t.Key("input_unit", to_name="input_unit", optional=False): t.Int(),
78
+ t.Key("output_price", to_name="output_price", optional=False): t.Float(),
79
+ t.Key("output_unit", to_name="output_unit", optional=False): t.Int(),
80
+ }
81
+ )
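For orientation, a minimal sketch of a cost block that this trafaret accepts; the prices and units below are illustrative placeholders, not DataRobot defaults:

example_cost_config = {
    "currency": CostCurrency.USD,   # optional; defaults to CostCurrency.USD
    "input_price": 0.01,            # illustrative price charged per input_unit prompt tokens
    "input_unit": 1000,
    "output_price": 0.03,           # illustrative price charged per output_unit completion tokens
    "output_unit": 1000,
}
validated_cost = cost_metric_trafaret.check(example_cost_config)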
82
+
83
+
84
+ model_info_trafaret = t.Dict(
85
+ {
86
+ t.Key("class_names", to_name="class_names", optional=True): t.List(
87
+ t.String(max_length=MAX_COLUMN_NAME_LENGTH)
88
+ ),
89
+ t.Key("model_id", to_name="model_id", optional=True): t.String(max_length=OBJECT_ID_LENGTH),
90
+ t.Key("input_column_name", to_name="input_column_name", optional=False): t.String(
91
+ max_length=MAX_COLUMN_NAME_LENGTH
92
+ ),
93
+ t.Key("target_name", to_name="target_name", optional=False): t.String(
94
+ max_length=MAX_COLUMN_NAME_LENGTH
95
+ ),
96
+ t.Key(
97
+ "replacement_text_column_name", to_name="replacement_text_column_name", optional=True
98
+ ): t.Or(t.String(allow_blank=True, max_length=MAX_COLUMN_NAME_LENGTH), t.Null),
99
+ t.Key("target_type", to_name="target_type", optional=False): t.Enum(
100
+ *GuardModelTargetType.ALL
101
+ ),
102
+ },
103
+ allow_extra=["*"],
104
+ )
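As a shape reference only, a hypothetical model_info block that would pass this check; the column names and the 24-character model_id are made up, and "Multiclass" is assumed to be one of the allowed GuardModelTargetType.ALL values:

example_model_info = {
    "input_column_name": "text",
    "target_name": "toxicity",
    "target_type": "Multiclass",              # assumed member of GuardModelTargetType.ALL
    "class_names": ["toxic", "safe"],
    "model_id": "0123456789abcdef01234567",   # illustrative 24-character ObjectId
}
model_info_trafaret.check(example_model_info)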
105
+
106
+
107
+ model_guard_intervention_trafaret = t.Dict(
108
+ {
109
+ t.Key("comparand", to_name="comparand", optional=False): t.Or(
110
+ t.String(max_length=MAX_GUARD_NAME_LENGTH),
111
+ t.Float(),
112
+ t.Bool(),
113
+ t.List(t.String(max_length=MAX_GUARD_NAME_LENGTH)),
114
+ t.List(t.Float()),
115
+ ),
116
+ t.Key("comparator", to_name="comparator", optional=False): t.Enum(*GuardOperatorType.ALL),
117
+ },
118
+ allow_extra=["*"],
119
+ )
120
+
121
+
122
+ guard_intervention_trafaret = t.Dict(
123
+ {
124
+ t.Key("action", to_name="action", optional=False): t.Enum(*GuardAction.ALL),
125
+ t.Key("message", to_name="message", optional=True): t.String(
126
+ max_length=MAX_GUARD_MESSAGE_LENGTH, allow_blank=True
127
+ ),
128
+ t.Key("conditions", to_name="conditions", optional=True): t.Or(
129
+ t.List(
130
+ model_guard_intervention_trafaret,
131
+ max_length=1,
132
+ min_length=0,
133
+ ),
134
+ t.Null,
135
+ ),
136
+ t.Key("send_notification", to_name="send_notification", optional=True): t.Bool(),
137
+ },
138
+ allow_extra=["*"],
139
+ )
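For shape reference, a hypothetical intervention block that blocks the input when a single condition matches. GuardOperatorType.GREATER_THAN is an assumed member name used purely for illustration; the comparator must actually be one of GuardOperatorType.ALL:

example_intervention = {
    "action": GuardAction.BLOCK,
    "message": "This prompt was blocked by moderation.",
    "conditions": [
        # GuardOperatorType.GREATER_THAN is an assumed name, for illustration only
        {"comparand": 0.5, "comparator": GuardOperatorType.GREATER_THAN},
    ],
    "send_notification": False,
}
guard_intervention_trafaret.check(example_intervention)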
140
+
141
+ additional_guard_config_trafaret = t.Dict(
142
+ {
143
+ t.Key("cost", to_name="cost", optional=True): t.Or(cost_metric_trafaret, t.Null),
144
+ }
145
+ )
146
+
147
+
148
+ guard_trafaret = t.Dict(
149
+ {
150
+ t.Key("name", to_name="name", optional=False): t.String(max_length=MAX_GUARD_NAME_LENGTH),
151
+ t.Key("description", to_name="description", optional=True): t.String(
152
+ max_length=MAX_GUARD_DESCRIPTION_LENGTH
153
+ ),
154
+ t.Key("type", to_name="type", optional=False): t.Enum(*GuardType.ALL),
155
+ t.Key("stage", to_name="stage", optional=False): t.Or(
156
+ t.List(t.Enum(*GuardStage.ALL)), t.Enum(*GuardStage.ALL)
157
+ ),
158
+ t.Key("llm_type", to_name="llm_type", optional=True): t.Enum(*GuardLLMType.ALL),
159
+ t.Key("ootb_type", to_name="ootb_type", optional=True): t.Enum(*OOTBType.ALL),
160
+ t.Key("deployment_id", to_name="deployment_id", optional=True): t.Or(
161
+ t.String(max_length=OBJECT_ID_LENGTH), t.Null
162
+ ),
163
+ t.Key("model_info", to_name="model_info", optional=True): model_info_trafaret,
164
+ t.Key("intervention", to_name="intervention", optional=True): t.Or(
165
+ guard_intervention_trafaret, t.Null
166
+ ),
167
+ t.Key("openai_api_key", to_name="openai_api_key", optional=True): t.Or(
168
+ t.String(max_length=MAX_TOKEN_LENGTH), t.Null
169
+ ),
170
+ t.Key("openai_deployment_id", to_name="openai_deployment_id", optional=True): t.Or(
171
+ t.String(max_length=OBJECT_ID_LENGTH), t.Null
172
+ ),
173
+ t.Key("openai_api_base", to_name="openai_api_base", optional=True): t.Or(
174
+ t.String(max_length=MAX_URL_LENGTH), t.Null
175
+ ),
176
+ t.Key("google_region", to_name="google_region", optional=True): t.Or(t.String, t.Null),
177
+ t.Key("google_model", to_name="google_model", optional=True): t.Or(
178
+ t.Enum(*GoogleModel.ALL), t.Null
179
+ ),
180
+ t.Key("aws_region", to_name="aws_region", optional=True): t.Or(t.String, t.Null),
181
+ t.Key("aws_model", to_name="aws_model", optional=True): t.Or(t.Enum(*AwsModel.ALL), t.Null),
182
+ t.Key("faas_url", optional=True): t.Or(t.String(max_length=MAX_URL_LENGTH), t.Null),
183
+ t.Key("copy_citations", optional=True, default=False): t.Bool(),
184
+ t.Key("is_agentic", to_name="is_agentic", optional=True, default=False): t.Bool(),
185
+ t.Key(
186
+ "additional_guard_config",
187
+ to_name="additional_guard_config",
188
+ optional=True,
189
+ default=None,
190
+ ): t.Or(additional_guard_config_trafaret, t.Null),
191
+ },
192
+ allow_extra=["*"],
193
+ )
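A hypothetical minimal guard entry that satisfies this schema, using the constants imported above; the guard name is arbitrary, and most other keys are optional or defaulted:

example_guard_config = {
    "name": "Prompt token count",
    "type": GuardType.OOTB,
    "stage": GuardStage.PROMPT,
    "ootb_type": OOTBType.TOKEN_COUNT,
}
checked_guard = guard_trafaret.check(example_guard_config)  # fills copy_citations / is_agentic defaults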
194
+
195
+
196
+ moderation_config_trafaret = t.Dict(
197
+ {
198
+ t.Key(
199
+ "timeout_sec",
200
+ to_name="timeout_sec",
201
+ optional=True,
202
+ default=DEFAULT_GUARD_PREDICTION_TIMEOUT_IN_SEC,
203
+ ): t.Int(gt=1),
204
+ t.Key(
205
+ "timeout_action",
206
+ to_name="timeout_action",
207
+ optional=True,
208
+ default=GuardTimeoutAction.SCORE,
209
+ ): t.Enum(*GuardTimeoutAction.ALL),
210
+ # Why does this default to True?
211
+ # We manually tested sending extra output under the "datarobot_moderations" field of the
212
+ # OpenAI completion object, and it works by default, even with the OpenAI client.
213
+ # It always works with the raw API response, because the field is simply treated as extra
214
+ # data in the JSON response, so it works in most cases. If a future OpenAI client stops
215
+ # accepting extra data, we can simply disable this flag so that it does not break the
216
+ # client or the user flow.

217
+ t.Key(
218
+ "enable_extra_model_output_for_chat",
219
+ to_name="enable_extra_model_output_for_chat",
220
+ optional=True,
221
+ default=True,
222
+ ): t.Bool(),
223
+ t.Key("guards", to_name="guards", optional=False): t.List(guard_trafaret),
224
+ },
225
+ allow_extra=["*"],
226
+ )
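Putting the pieces together, a sketch of a full moderation config that this top-level trafaret would accept, reusing the guard entry sketched above; the timeout value is illustrative:

example_moderation_config = {
    "timeout_sec": 30,                           # defaults to DEFAULT_GUARD_PREDICTION_TIMEOUT_IN_SEC
    "timeout_action": GuardTimeoutAction.SCORE,  # GuardTimeoutAction.SCORE is the default
    "enable_extra_model_output_for_chat": True,  # send moderation data under "datarobot_moderations" (see comment above)
    "guards": [example_guard_config],            # each entry is validated by guard_trafaret
}
moderation_config_trafaret.check(example_moderation_config)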
227
+
228
+
229
+ class Guard(ABC):
230
+ def __init__(self, config: dict, stage=None):
231
+ self._name = config["name"]
232
+ self._description = config.get("description")
233
+ self._type = config["type"]
234
+ self._stage = stage if stage else config["stage"]
235
+ self._pipeline = None
236
+ self._model_info = None
237
+ self.intervention = None
238
+ self._deployment_id = config.get("deployment_id")
239
+ self._dr_cm = None
240
+ self._faas_url = config.get("faas_url")
241
+ self._copy_citations = config["copy_citations"]
242
+ self.is_agentic = config.get("is_agentic", False)
243
+
244
+ if config.get("intervention"):
245
+ self.intervention = GuardIntervention(config["intervention"])
246
+ if config.get("model_info"):
247
+ self._model_info = GuardModelInfo(config["model_info"])
248
+
249
+ @property
250
+ def name(self) -> str:
251
+ return self._name
252
+
253
+ @property
254
+ def description(self) -> str:
255
+ return self._description
256
+
257
+ @property
258
+ def type(self) -> GuardType:
259
+ return self._type
260
+
261
+ @property
262
+ def stage(self) -> GuardStage:
263
+ return self._stage
264
+
265
+ @property
266
+ def faas_url(self) -> str:
267
+ return self._faas_url
268
+
269
+ @property
270
+ def copy_citations(self) -> bool:
271
+ return self._copy_citations
272
+
273
+ def set_pipeline(self, pipeline):
274
+ self._pipeline = pipeline
275
+
276
+ @property
277
+ def llm_type(self):
278
+ return self._llm_type
279
+
280
+ @staticmethod
281
+ def get_stage_str(stage):
282
+ return "Prompts" if stage == GuardStage.PROMPT else "Responses"
283
+
284
+ def has_latency_custom_metric(self) -> bool:
285
+ """Determines if latency metric is tracked for this guard type. Default is True."""
286
+ return True
287
+
288
+ def get_latency_custom_metric_name(self):
289
+ return f"{self.name} Guard Latency"
290
+
291
+ def get_latency_custom_metric(self):
292
+ return {
293
+ "name": self.get_latency_custom_metric_name(),
294
+ "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
295
+ "units": "seconds",
296
+ "type": CustomMetricAggregationType.AVERAGE,
297
+ "baselineValue": 0,
298
+ "isModelSpecific": True,
299
+ "timeStep": "hour",
300
+ "description": (
301
+ f"{self.get_latency_custom_metric_name()}. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
302
+ ),
303
+ }
304
+
305
+ def has_average_score_custom_metric(self) -> bool:
306
+ """Determines if an average score metric is tracked for this guard type. Default is True."""
307
+ return True
308
+
309
+ def get_average_score_custom_metric_name(self, stage):
310
+ return f"{self.name} Guard Average Score for {self.get_stage_str(stage)}"
311
+
312
+ def get_average_score_metric(self, stage):
313
+ return {
314
+ "name": self.get_average_score_custom_metric_name(stage),
315
+ "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
316
+ "units": "probability",
317
+ "type": CustomMetricAggregationType.AVERAGE,
318
+ "baselineValue": 0,
319
+ "isModelSpecific": True,
320
+ "timeStep": "hour",
321
+ "description": (
322
+ f"{self.get_average_score_custom_metric_name(stage)}. "
323
+ f" {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
324
+ ),
325
+ }
326
+
327
+ def get_guard_enforced_custom_metric_name(self, stage, moderation_method):
328
+ if moderation_method == GuardAction.REPLACE:
329
+ return f"{self.name} Guard replaced {self.get_stage_str(stage)}"
330
+ return f"{self.name} Guard {moderation_method}ed {self.get_stage_str(stage)}"
331
+
332
+ def get_enforced_custom_metric(self, stage, moderation_method):
333
+ return {
334
+ "name": self.get_guard_enforced_custom_metric_name(stage, moderation_method),
335
+ "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
336
+ "units": "count",
337
+ "type": CustomMetricAggregationType.SUM,
338
+ "baselineValue": 0,
339
+ "isModelSpecific": True,
340
+ "timeStep": "hour",
341
+ "description": (
342
+ f"Number of {self.get_stage_str(stage)} {moderation_method}ed by the "
343
+ f"{self.name} guard. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
344
+ ),
345
+ }
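For example, assuming GuardAction.BLOCK is the string "block", a prompt-stage guard named "Toxicity" would get the enforcement metric name shown below; GuardAction.REPLACE is special-cased above because naive suffixing would otherwise produce "replaceed":

metric_name = guard.get_guard_enforced_custom_metric_name(GuardStage.PROMPT, GuardAction.BLOCK)
# -> "Toxicity Guard blocked Prompts"  (assuming guard.name == "Toxicity" and GuardAction.BLOCK == "block")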
346
+
347
+ def get_input_column(self, stage):
348
+ if stage == GuardStage.PROMPT:
349
+ return (
350
+ self._model_info.input_column_name
351
+ if (self._model_info and self._model_info.input_column_name)
352
+ else DEFAULT_PROMPT_COLUMN_NAME
353
+ )
354
+ else:
355
+ return (
356
+ self._model_info.input_column_name
357
+ if (self._model_info and self._model_info.input_column_name)
358
+ else DEFAULT_RESPONSE_COLUMN_NAME
359
+ )
360
+
361
+ def get_intervention_action(self):
362
+ if not self.intervention:
363
+ return GuardAction.NONE
364
+ return self.intervention.action
365
+
366
+ def get_comparand(self):
367
+ return self.intervention.threshold
368
+
369
+
370
+ class GuardModelInfo:
371
+ def __init__(self, model_config: dict):
372
+ self._model_id = model_config.get("model_id")
373
+ self._input_column_name = model_config["input_column_name"]
374
+ self._target_name = model_config["target_name"]
375
+ self._target_type = model_config["target_type"]
376
+ self._class_names = model_config.get("class_names")
377
+ self.replacement_text_column_name = model_config.get("replacement_text_column_name")
378
+
379
+ @property
380
+ def model_id(self) -> str:
381
+ return self._model_id
382
+
383
+ @property
384
+ def input_column_name(self) -> str:
385
+ return self._input_column_name
386
+
387
+ @property
388
+ def target_name(self) -> str:
389
+ return self._target_name
390
+
391
+ @property
392
+ def target_type(self) -> str:
393
+ return self._target_type
394
+
395
+ @property
396
+ def class_names(self):
397
+ return self._class_names
398
+
399
+
400
+ class GuardIntervention:
401
+ def __init__(self, intervention_config: dict) -> None:
402
+ self.action = intervention_config["action"]
403
+ self.message = intervention_config.get("message")
404
+ self.threshold = None
405
+ self.comparator = None
406
+ if (
407
+ "conditions" in intervention_config
408
+ and intervention_config["conditions"] is not None
409
+ and len(intervention_config["conditions"]) > 0
410
+ ):
411
+ self.threshold = intervention_config["conditions"][0].get("comparand")
412
+ self.comparator = intervention_config["conditions"][0].get("comparator")
413
+
414
+
415
+ class ModelGuard(Guard):
416
+ def __init__(self, config: dict, stage=None):
417
+ super().__init__(config, stage)
418
+ self._deployment_id = config["deployment_id"]
419
+ self._model_info = GuardModelInfo(config["model_info"])
420
+ # dr.Client is set in the Pipeline init, so let's query the deployment
421
+ # to get the prediction server information
422
+ self.deployment = dr.Deployment.get(self._deployment_id)
423
+
424
+ @property
425
+ def deployment_id(self) -> str:
426
+ return self._deployment_id
427
+
428
+ @property
429
+ def model_info(self):
430
+ return self._model_info
431
+
432
+ def get_metric_column_name(self, stage):
433
+ if self.model_info is None:
434
+ raise NotImplementedError("Missing model_info for model guard")
435
+ return self.get_stage_str(stage) + "_" + self._model_info.target_name
436
+
437
+ def has_average_score_custom_metric(self) -> bool:
438
+ """A couple ModelGuard types do not have an average score metric"""
439
+ return self.model_info.target_type not in [
440
+ "Multiclass",
441
+ "TextGeneration",
442
+ ]
443
+
444
+
445
+ class NeMoGuard(Guard, GuardLLMMixin):
446
+ def __init__(self, config: dict, stage=None, model_dir: str = os.getcwd()):
447
+ super().__init__(config, stage)
448
+ # The NeMo guard only takes a boolean threshold and the "equals" comparator.
449
+ # The threshold value "TRUE" is defined in the Colang file as the output of
450
+ # `bot should intervene`.
451
+ if self.intervention:
452
+ if not self.intervention.threshold:
453
+ self.intervention.threshold = NEMO_THRESHOLD
454
+ if not self.intervention.comparator:
455
+ self.intervention.comparator = GuardOperatorType.EQUALS
456
+
457
+ # Default LLM Type for NeMo is set to OpenAI
458
+ self._llm_type = config.get("llm_type", GuardLLMType.OPENAI)
459
+ self.openai_api_base = config.get("openai_api_base")
460
+ self.openai_deployment_id = config.get("openai_deployment_id")
461
+ llm_id = None
462
+ try:
463
+ self.openai_api_key = self.get_openai_api_key(config, self._llm_type)
464
+ if self._llm_type != GuardLLMType.NIM and self.openai_api_key is None:
465
+ raise ValueError("OpenAI API key is required for NeMo Guardrails")
466
+
467
+ if self.llm_type == GuardLLMType.OPENAI:
468
+ os.environ["OPENAI_API_KEY"] = self.openai_api_key
469
+ llm = None
470
+ elif self.llm_type == GuardLLMType.AZURE_OPENAI:
471
+ if self.openai_api_base is None:
472
+ raise ValueError("Azure OpenAI API base url is required for LLM Guard")
473
+ if self.openai_deployment_id is None:
474
+ raise ValueError("Azure OpenAI deployment ID is required for LLM Guard")
475
+ azure_openai_client = get_azure_openai_client(
476
+ openai_api_key=self.openai_api_key,
477
+ openai_api_base=self.openai_api_base,
478
+ openai_deployment_id=self.openai_deployment_id,
479
+ )
480
+ llm = azure_openai_client
481
+ elif self.llm_type == GuardLLMType.GOOGLE:
482
+ # llm_id = config["google_model"]
483
+ raise NotImplementedError
484
+ elif self.llm_type == GuardLLMType.AMAZON:
485
+ # llm_id = config["aws_model"]
486
+ raise NotImplementedError
487
+ elif self.llm_type == GuardLLMType.DATAROBOT:
488
+ raise NotImplementedError
489
+ elif self.llm_type == GuardLLMType.NIM:
490
+ if config.get("deployment_id") is None:
491
+ if self.openai_api_base is None:
492
+ raise ValueError("NIM DataRobot deployment id is required for NIM LLM Type")
493
+ else:
494
+ logging.warning(
495
+ "Using 'openai_api_base' is being deprecated and will be removed "
496
+ "in the next release. Please configure NIM DataRobot deployment "
497
+ "using deployment_id"
498
+ )
499
+ if self.openai_api_key is None:
500
+ raise ValueError("OpenAI API key is required for NeMo Guardrails")
501
+ else:
502
+ self.deployment = dr.Deployment.get(self._deployment_id)
503
+ datarobot_endpoint, self.openai_api_key = get_datarobot_endpoint_and_token()
504
+ self.openai_api_base = (
505
+ f"{datarobot_endpoint}/deployments/{str(self._deployment_id)}"
506
+ )
507
+ llm = get_chat_nvidia_llm(
508
+ api_key=self.openai_api_key,
509
+ base_url=self.openai_api_base,
510
+ )
511
+ else:
512
+ raise ValueError(f"Invalid LLMType: {self.llm_type}")
513
+
514
+ except Exception as e:
515
+ llm = try_to_fallback_to_llm_gateway(
516
+ # Currently only OPENAI and AZURE_OPENAI are supported by NeMoGuard
517
+ # For Bedrock and Vertex the model in the config is actually the LLM ID
518
+ # For OpenAI we use the default model defined in get_llm_gateway_client
519
+ # For Azure we use the deployment ID
520
+ llm_id=llm_id,
521
+ openai_deployment_id=self.openai_deployment_id,
522
+ llm_type=self.llm_type,
523
+ e=e,
524
+ )
525
+
526
+ # Use the guard stage to determine whether to read the NeMo configuration from the
527
+ # prompt or the response subdirectory. The "nemo_guardrails" folder sits at the same
528
+ # level as custom.py, so the config path becomes model_dir + "nemo_guardrails".
529
+ nemo_config_path = os.path.join(model_dir, NEMO_GUARDRAILS_DIR)
530
+ self.nemo_rails_config_path = os.path.join(nemo_config_path, self.stage)
531
+ nemo_rails_config = RailsConfig.from_path(config_path=self.nemo_rails_config_path)
532
+ self._nemo_llm_rails = LLMRails(nemo_rails_config, llm=llm)
533
+
534
+ def has_average_score_custom_metric(self) -> bool:
535
+ """No average score metrics for NemoGuard's"""
536
+ return False
537
+
538
+ def get_metric_column_name(self, stage):
539
+ return self.get_stage_str(stage) + "_" + NEMO_GUARD_COLUMN_NAME
540
+
541
+ @property
542
+ def nemo_llm_rails(self):
543
+ return self._nemo_llm_rails
544
+
545
+
546
+ class OOTBGuard(Guard):
547
+ def __init__(self, config: dict, stage=None):
548
+ super().__init__(config, stage)
549
+ self._ootb_type = config["ootb_type"]
550
+
551
+ @property
552
+ def ootb_type(self):
553
+ return self._ootb_type
554
+
555
+ def has_latency_custom_metric(self):
556
+ """Latency is not tracked for token counts guards"""
557
+ return self._ootb_type != OOTBType.TOKEN_COUNT
558
+
559
+ def get_metric_column_name(self, stage):
560
+ if self._ootb_type == OOTBType.TOKEN_COUNT:
561
+ return self.get_stage_str(stage) + "_" + TOKEN_COUNT_COLUMN_NAME
562
+ elif self._ootb_type == OOTBType.ROUGE_1:
563
+ return self.get_stage_str(stage) + "_" + ROUGE_1_COLUMN_NAME
564
+ elif self._ootb_type == OOTBType.CUSTOM_METRIC:
565
+ return self.get_stage_str(stage) + "_" + self.name
566
+ else:
567
+ raise NotImplementedError(f"No metric column name defined for {self._ootb_type} guard")
568
+
569
+
570
+ class OOTBCostMetric(OOTBGuard):
571
+ def __init__(self, config, stage):
572
+ super().__init__(config, stage)
573
+ # The cost is calculated based on the usage metrics returned by the
574
+ # completion object, so it can only be evaluated at the response stage
575
+ self._stage = GuardStage.RESPONSE
576
+ cost_config = config["additional_guard_config"]["cost"]
577
+ self.currency = cost_config["currency"]
578
+ self.input_price = cost_config["input_price"]
579
+ self.input_unit = cost_config["input_unit"]
580
+ self.input_multiplier = self.input_price / self.input_unit
581
+ self.output_price = cost_config["output_price"]
582
+ self.output_unit = cost_config["output_unit"]
583
+ self.output_multiplier = self.output_price / self.output_unit
584
+
585
+ def get_metric_column_name(self, _):
586
+ return COST_COLUMN_NAME
587
+
588
+ def get_average_score_custom_metric_name(self, _):
589
+ return f"Total cost in {self.currency}"
590
+
591
+ def get_average_score_metric(self, _):
592
+ return {
593
+ "name": self.get_average_score_custom_metric_name(_),
594
+ "directionality": CustomMetricDirectionality.LOWER_IS_BETTER,
595
+ "units": "value",
596
+ "type": CustomMetricAggregationType.SUM,
597
+ "baselineValue": 0,
598
+ "isModelSpecific": True,
599
+ "timeStep": "hour",
600
+ "description": (
601
+ f"{self.get_average_score_custom_metric_name(_)}. "
602
+ f" {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
603
+ ),
604
+ }
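To make the multiplier arithmetic concrete, a rough hand-worked example using the illustrative pricing from the cost config sketch earlier (0.01 per 1,000 input tokens, 0.03 per 1,000 output tokens); how the token usage counts are read from the completion object is outside this class:

input_multiplier = 0.01 / 1000     # price per single prompt token
output_multiplier = 0.03 / 1000    # price per single completion token
prompt_tokens, completion_tokens = 500, 200
cost = prompt_tokens * input_multiplier + completion_tokens * output_multiplier
# cost == 0.011 (in the configured currency)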
605
+
606
+
607
+ class FaithfulnessGuard(OOTBGuard, GuardLLMMixin):
608
+ def __init__(self, config: dict, stage=None):
609
+ super().__init__(config, stage)
610
+
611
+ if self.stage == GuardStage.PROMPT:
612
+ raise Exception("Faithfulness cannot be configured for the Prompt stage")
613
+
614
+ # Default LLM Type for Faithfulness is set to Azure OpenAI
615
+ self._llm_type = config.get("llm_type", GuardLLMType.AZURE_OPENAI)
616
+ Settings.llm = self.get_llm(config, self._llm_type)
617
+ Settings.embed_model = None
618
+ self._evaluator = FaithfulnessEvaluator()
619
+
620
+ def get_metric_column_name(self, stage):
621
+ return self.get_stage_str(stage) + "_" + FAITHFULLNESS_COLUMN_NAME
622
+
623
+ @property
624
+ def faithfulness_evaluator(self):
625
+ return self._evaluator
626
+
627
+
628
+ class AgentGoalAccuracyGuard(OOTBGuard, GuardLLMMixin):
629
+ def __init__(self, config: dict, stage=None):
630
+ super().__init__(config, stage)
631
+
632
+ if self.stage == GuardStage.PROMPT:
633
+ raise Exception("Agent Goal Accuracy guard cannot be configured for the Prompt stage")
634
+
635
+ # Default LLM Type for Agent Goal Accuracy is set to Azure OpenAI
636
+ self._llm_type = config.get("llm_type", GuardLLMType.AZURE_OPENAI)
637
+ llm = self.get_llm(config, self._llm_type)
638
+ evaluator_llm = LangchainLLMWrapper(llm)
639
+ self.scorer = AgentGoalAccuracyWithoutReference(llm=evaluator_llm)
640
+
641
+ def get_metric_column_name(self, _):
642
+ return AGENT_GOAL_ACCURACY_COLUMN_NAME
643
+
644
+ @property
645
+ def accuracy_scorer(self):
646
+ return self.scorer
647
+
648
+
649
+ class TaskAdherenceGuard(OOTBGuard, GuardLLMMixin):
650
+ def __init__(self, config: dict, stage=None):
651
+ super().__init__(config, stage)
652
+
653
+ if self.stage == GuardStage.PROMPT:
654
+ raise Exception("Agent Goal Accuracy guard cannot be configured for the Prompt stage")
655
+
656
+ # Default LLM Type for Task Adherence is set to Azure OpenAI
657
+ self._llm_type = config.get("llm_type", GuardLLMType.AZURE_OPENAI)
658
+ llm = self.get_llm(config, self._llm_type)
659
+ deepeval_llm = ModerationDeepEvalLLM(llm)
660
+ self.scorer = TaskCompletionMetric(model=deepeval_llm, include_reason=True)
661
+
662
+ def get_metric_column_name(self, _):
663
+ return TASK_ADHERENCE_SCORE_COLUMN_NAME
664
+
665
+ @property
666
+ def task_adherence_scorer(self):
667
+ return self.scorer
668
+
669
+
670
+ class GuardFactory:
671
+ @classmethod
672
+ def _perform_post_validation_checks(cls, guard_config):
673
+ if not guard_config.get("intervention"):
674
+ return
675
+
676
+ if guard_config["intervention"]["action"] == GuardAction.BLOCK and (
677
+ guard_config["intervention"]["message"] is None
678
+ or len(guard_config["intervention"]["message"]) == 0
679
+ ):
680
+ raise ValueError("Blocked action needs a blocking message")
681
+
682
+ if guard_config["intervention"]["action"] == GuardAction.REPLACE:
683
+ if "model_info" not in guard_config:
684
+ raise ValueError("'Replace' action needs model_info section")
685
+ if (
686
+ "replacement_text_column_name" not in guard_config["model_info"]
687
+ or guard_config["model_info"]["replacement_text_column_name"] is None
688
+ or len(guard_config["model_info"]["replacement_text_column_name"]) == 0
689
+ ):
690
+ raise ValueError(
691
+ "'Replace' action needs valid 'replacement_text_column_name' "
692
+ "in 'model_info' section of the guard"
693
+ )
694
+
695
+ if not guard_config["intervention"].get("conditions"):
696
+ return
697
+
698
+ if len(guard_config["intervention"]["conditions"]) == 0:
699
+ return
700
+
701
+ condition = guard_config["intervention"]["conditions"][0]
702
+ if condition["comparator"] in GuardOperatorType.REQUIRES_LIST_COMPARAND:
703
+ if not isinstance(condition["comparand"], list):
704
+ raise ValueError(
705
+ f"Comparand needs to be a list with {condition['comparator']} comparator"
706
+ )
707
+ elif isinstance(condition["comparand"], list):
708
+ raise ValueError(
709
+ f"Comparand needs to be a scalar with {condition['comparator']} comparator"
710
+ )
711
+
712
+ @staticmethod
713
+ def create(input_config: dict, stage=None, model_dir: str = os.getcwd()) -> Guard:
714
+ config = guard_trafaret.check(input_config)
715
+
716
+ GuardFactory._perform_post_validation_checks(config)
717
+
718
+ if config["type"] == GuardType.MODEL:
719
+ guard = ModelGuard(config, stage)
720
+ elif config["type"] == GuardType.OOTB:
721
+ if config["ootb_type"] == OOTBType.FAITHFULNESS:
722
+ guard = FaithfulnessGuard(config, stage)
723
+ elif config["ootb_type"] == OOTBType.COST:
724
+ guard = OOTBCostMetric(config, stage)
725
+ elif config["ootb_type"] == OOTBType.AGENT_GOAL_ACCURACY:
726
+ guard = AgentGoalAccuracyGuard(config, stage)
727
+ elif config["ootb_type"] == OOTBType.TASK_ADHERENCE:
728
+ guard = TaskAdherenceGuard(config, stage)
729
+ else:
730
+ guard = OOTBGuard(config, stage)
731
+ elif config["type"] == GuardType.NEMO_GUARDRAILS:
732
+ guard = NeMoGuard(config, stage, model_dir)
733
+ else:
734
+ raise ValueError(f"Invalid guard type: {config['type']}")
735
+
736
+ return guard
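Finally, a usage sketch for the factory, assuming dr.Client (and any LLM credentials a given guard type needs) has already been configured by the surrounding pipeline; the config shape mirrors the guard_trafaret example above:

guard = GuardFactory.create(
    {
        "name": "Prompt token count",
        "type": GuardType.OOTB,
        "stage": GuardStage.PROMPT,
        "ootb_type": OOTBType.TOKEN_COUNT,
    },
    stage=GuardStage.PROMPT,
)
print(guard.get_metric_column_name(GuardStage.PROMPT))   # "Prompts_" + TOKEN_COUNT_COLUMN_NAME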