datarobot-moderations 11.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_dome/__init__.py +11 -0
- datarobot_dome/async_http_client.py +248 -0
- datarobot_dome/chat_helper.py +227 -0
- datarobot_dome/constants.py +318 -0
- datarobot_dome/drum_integration.py +977 -0
- datarobot_dome/guard.py +736 -0
- datarobot_dome/guard_executor.py +755 -0
- datarobot_dome/guard_helpers.py +457 -0
- datarobot_dome/guards/__init__.py +11 -0
- datarobot_dome/guards/guard_llm_mixin.py +232 -0
- datarobot_dome/llm.py +148 -0
- datarobot_dome/metrics/__init__.py +11 -0
- datarobot_dome/metrics/citation_metrics.py +98 -0
- datarobot_dome/metrics/factory.py +52 -0
- datarobot_dome/metrics/metric_scorer.py +78 -0
- datarobot_dome/pipeline/__init__.py +11 -0
- datarobot_dome/pipeline/llm_pipeline.py +474 -0
- datarobot_dome/pipeline/pipeline.py +376 -0
- datarobot_dome/pipeline/vdb_pipeline.py +127 -0
- datarobot_dome/streaming.py +395 -0
- datarobot_moderations-11.1.12.dist-info/METADATA +113 -0
- datarobot_moderations-11.1.12.dist-info/RECORD +23 -0
- datarobot_moderations-11.1.12.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
|
|
15
|
+
import datarobot as dr
|
|
16
|
+
import trafaret as t
|
|
17
|
+
|
|
18
|
+
from datarobot_dome.constants import AWS_ACCOUNT_SECRET_DEFINITION_SUFFIX
|
|
19
|
+
from datarobot_dome.constants import GOOGLE_SERVICE_ACCOUNT_SECRET_DEFINITION_SUFFIX
|
|
20
|
+
from datarobot_dome.constants import OPENAI_SECRET_DEFINITION_SUFFIX
|
|
21
|
+
from datarobot_dome.constants import SECRET_DEFINITION_PREFIX
|
|
22
|
+
from datarobot_dome.constants import GuardLLMType
|
|
23
|
+
from datarobot_dome.constants import GuardType
|
|
24
|
+
from datarobot_dome.constants import OOTBType
|
|
25
|
+
from datarobot_dome.guard_helpers import get_azure_openai_client
|
|
26
|
+
from datarobot_dome.guard_helpers import get_bedrock_client
|
|
27
|
+
from datarobot_dome.guard_helpers import get_datarobot_llm
|
|
28
|
+
from datarobot_dome.guard_helpers import get_vertex_client
|
|
29
|
+
from datarobot_dome.guard_helpers import try_to_fallback_to_llm_gateway
|
|
30
|
+
|
|
31
|
+
# Trafarets validating the JSON credential payloads that guard secrets are
# delivered in (as environment-variable values; see GuardLLMMixin below).

# Username/password style credential; only "password" is consumed downstream.
basic_credential_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("basic"),
        t.Key("password", to_name="password", optional=False): t.String,
    },
    allow_extra=["*"],
)

# Plain API-token credential (e.g. OpenAI-style API keys).
api_token_credential_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("api_token"),
        t.Key("apiToken", to_name="api_token", optional=False): t.String,
    },
    allow_extra=["*"],
)

# GCP service-account credential; "gcpKey" carries the service-account JSON dict.
google_service_account_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("gcp"),
        t.Key("gcpKey", to_name="gcp_key", optional=False): t.Dict(allow_extra=["*"]),
    },
    allow_extra=["*"],
)

# AWS credential (credentialType "s3"); the session token is optional and
# defaults to None.
aws_account_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("s3"),
        t.Key("awsAccessKeyId", to_name="aws_access_key_id", optional=False): t.String,
        t.Key("awsSecretAccessKey", to_name="aws_secret_access_key", optional=False): t.String,
        t.Key("awsSessionToken", to_name="aws_session_token", optional=True, default=None): t.String
        | t.Null,
    },
    allow_extra=["*"],
)


# Envelope wrapping any one of the credential payload shapes above.
credential_trafaret = t.Dict(
    {
        t.Key("type", optional=False): t.Enum("credential"),
        t.Key("payload", optional=False): t.Or(
            basic_credential_trafaret,
            api_token_credential_trafaret,
            google_service_account_trafaret,
            aws_account_trafaret,
        ),
    }
)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class GuardLLMMixin:
    """Mixin that resolves guard credentials from the environment and builds
    the LLM client a guard should use.

    Secrets are injected into the process environment as JSON blobs under
    deterministic variable names derived from the guard configuration
    (type, stage, OOTB type, LLM type); the module-level trafarets validate
    those blobs.
    """

    def get_secret_env_var_base(self, config, llm_type_str):
        """Build the base of the secret environment-variable name for this guard.

        Args:
            config: Guard configuration dict; "type" and "stage" are required,
                and "ootb_type" is read for OOTB guards.
            llm_type_str: LLM-type infix (e.g. "AZURE_", "" for plain OpenAI),
                or a full suffix for the GCP/AWS helpers below.

        Raises:
            Exception: if the guard type / OOTB type combination does not
                support secrets.
        """
        guard_type = config["type"]
        guard_stage = config["stage"]
        secret_env_var_name_prefix = f"{SECRET_DEFINITION_PREFIX}_{guard_type}_{guard_stage}_"
        if guard_type == GuardType.NEMO_GUARDRAILS:
            return f"{secret_env_var_name_prefix}{llm_type_str}"
        elif guard_type == GuardType.OOTB:
            # Only these three OOTB guard flavors carry LLM secrets.
            if config["ootb_type"] == OOTBType.FAITHFULNESS:
                return f"{secret_env_var_name_prefix}{OOTBType.FAITHFULNESS}_{llm_type_str}"
            elif config["ootb_type"] == OOTBType.AGENT_GOAL_ACCURACY:
                return f"{secret_env_var_name_prefix}{OOTBType.AGENT_GOAL_ACCURACY}_{llm_type_str}"
            elif config["ootb_type"] == OOTBType.TASK_ADHERENCE:
                return f"{secret_env_var_name_prefix}{OOTBType.TASK_ADHERENCE}_{llm_type_str}"
            else:
                raise Exception("Invalid guard config for building env var name")
        else:
            raise Exception("Invalid guard config for building env var name")

    def build_open_ai_api_key_env_var_name(self, config, llm_type):
        """Return the upper-cased env-var name holding the OpenAI-style API key.

        Azure and NIM LLM types get a distinguishing infix; plain OpenAI uses
        no infix.
        """
        llm_type_str = ""
        if llm_type == GuardLLMType.AZURE_OPENAI:
            llm_type_str = "AZURE_"
        elif llm_type == GuardLLMType.NIM:
            llm_type_str = "NIM_"
        var_name = self.get_secret_env_var_base(config, llm_type_str)
        var_name += OPENAI_SECRET_DEFINITION_SUFFIX
        return var_name.upper()

    def get_openai_api_key(self, config, llm_type):
        """Fetch and validate the OpenAI API key for this guard.

        Returns:
            The API key string, or None for NIM guards when the secret env var
            is absent (NIM may run unauthenticated).

        Raises:
            Exception: if the env var is missing for a non-NIM LLM type, or
                the payload fails trafaret validation.
        """
        api_key_env_var_name = self.build_open_ai_api_key_env_var_name(config, llm_type)
        if api_key_env_var_name not in os.environ:
            if llm_type == GuardLLMType.NIM:
                return None
            raise Exception(f"Expected environment variable '{api_key_env_var_name}' not found")

        env_var_value = json.loads(os.environ[api_key_env_var_name])
        credential_config = credential_trafaret.check(env_var_value)
        if credential_config["payload"]["credential_type"] == "basic":
            return credential_config["payload"]["password"]
        else:
            # NOTE(review): assumes any non-"basic" credential carries
            # "api_token"; a "gcp" or "s3" payload reaching here would raise
            # KeyError — confirm that cannot happen for OpenAI-type guards.
            return credential_config["payload"]["api_token"]

    def get_google_service_account(self, config):
        """Fetch the GCP service-account JSON (dict) for this guard.

        Raises:
            Exception: if the env var is missing or the credential is not of
                type "gcp".
        """
        service_account_env_var_name = self.get_secret_env_var_base(
            config, GOOGLE_SERVICE_ACCOUNT_SECRET_DEFINITION_SUFFIX
        ).upper()
        if service_account_env_var_name not in os.environ:
            raise Exception(
                f"Expected environment variable '{service_account_env_var_name}' not found"
            )

        env_var_value = json.loads(os.environ[service_account_env_var_name])
        credential_config = credential_trafaret.check(env_var_value)
        if credential_config["payload"]["credential_type"] == "gcp":
            return credential_config["payload"]["gcp_key"]
        else:
            raise Exception("Google model requires a credential of type 'gcp'")

    def get_aws_account(self, config):
        """Fetch the AWS credential payload (dict with access key id / secret /
        optional session token) for this guard.

        Raises:
            Exception: if the env var is missing or the credential is not of
                type "s3".
        """
        service_account_env_var_name = self.get_secret_env_var_base(
            config, AWS_ACCOUNT_SECRET_DEFINITION_SUFFIX
        ).upper()
        if service_account_env_var_name not in os.environ:
            raise Exception(
                f"Expected environment variable '{service_account_env_var_name}' not found"
            )

        env_var_value = json.loads(os.environ[service_account_env_var_name])
        credential_config = credential_trafaret.check(env_var_value)
        if credential_config["payload"]["credential_type"] == "s3":
            return credential_config["payload"]
        else:
            raise Exception("Amazon model requires a credential of type 's3'")

    def get_llm(self, config, llm_type):
        """Construct the LLM client for this guard from its configuration.

        Any failure (missing config, missing credentials, unsupported type)
        is caught and handed to ``try_to_fallback_to_llm_gateway``, which may
        return a gateway-backed client instead of re-raising.

        Returns:
            An LLM client object; for plain OpenAI the sentinel string
            "default" is returned and OPENAI_API_KEY is exported.
        """
        openai_api_base = config.get("openai_api_base")
        openai_deployment_id = config.get("openai_deployment_id")
        llm_id = None
        try:
            if llm_type in [GuardLLMType.OPENAI, GuardLLMType.AZURE_OPENAI]:
                openai_api_key = self.get_openai_api_key(config, llm_type)
                if openai_api_key is None:
                    raise ValueError("OpenAI API key is required for Faithfulness guard")

                if llm_type == GuardLLMType.OPENAI:
                    # Downstream OpenAI clients read the key from the
                    # environment; "default" is the sentinel for that path.
                    os.environ["OPENAI_API_KEY"] = openai_api_key
                    llm = "default"
                elif llm_type == GuardLLMType.AZURE_OPENAI:
                    if openai_api_base is None:
                        raise ValueError("OpenAI API base url is required for LLM Guard")
                    if openai_deployment_id is None:
                        raise ValueError("OpenAI deployment ID is required for LLM Guard")
                    azure_openai_client = get_azure_openai_client(
                        openai_api_key=openai_api_key,
                        openai_api_base=openai_api_base,
                        openai_deployment_id=openai_deployment_id,
                    )
                    llm = azure_openai_client
            elif llm_type == GuardLLMType.GOOGLE:
                # NOTE(review): config["google_model"] raises KeyError (not the
                # ValueError below) when the key is absent; the blanket except
                # still routes it to the gateway fallback — confirm intended.
                llm_id = config["google_model"]
                if llm_id is None:
                    raise ValueError("Google model is required for LLM Guard")
                if config.get("google_region") is None:
                    raise ValueError("Google region is required for LLM Guard")
                llm = get_vertex_client(
                    google_model=llm_id,
                    google_service_account=self.get_google_service_account(config),
                    google_region=config["google_region"],
                )
            elif llm_type == GuardLLMType.AMAZON:
                llm_id = config["aws_model"]
                if llm_id is None:
                    raise ValueError("AWS model is required for LLM Guard")
                if config.get("aws_region") is None:
                    raise ValueError("AWS region is required for LLM Guard")
                credential_config = self.get_aws_account(config)
                llm = get_bedrock_client(
                    aws_model=llm_id,
                    aws_access_key_id=credential_config["aws_access_key_id"],
                    aws_secret_access_key=credential_config["aws_secret_access_key"],
                    aws_session_token=credential_config["aws_session_token"],
                    aws_region=config["aws_region"],
                )
            elif llm_type == GuardLLMType.DATAROBOT:
                if config["type"] == GuardType.OOTB and config["ootb_type"] in [
                    OOTBType.AGENT_GOAL_ACCURACY,
                    OOTBType.TASK_ADHERENCE,
                ]:
                    # DataRobot LLM does not implement generate / agenerate yet
                    # so can't support it for Agent Goal Accuracy
                    raise NotImplementedError
                if config.get("deployment_id") is None:
                    raise ValueError("Deployment ID is required for LLM Guard")
                deployment = dr.Deployment.get(config["deployment_id"])
                llm = get_datarobot_llm(deployment)
            elif llm_type == GuardLLMType.NIM:
                raise NotImplementedError
            else:
                raise ValueError(f"Invalid LLMType: {llm_type}")

        except Exception as e:
            # Best-effort fallback: the gateway helper decides whether to
            # serve the request via the DataRobot LLM gateway or re-raise.
            llm = try_to_fallback_to_llm_gateway(
                # For Bedrock and Vertex the model in the config is actually the LLM ID
                # For OpenAI we use the default model defined in get_llm_gateway_client
                # For Azure we use the deployment ID
                llm_id=llm_id,
                openai_deployment_id=openai_deployment_id,
                llm_type=llm_type,
                e=e,
            )

        return llm
|
datarobot_dome/llm.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import datarobot as dr
|
|
15
|
+
import pandas as pd
|
|
16
|
+
from datarobot_predict.deployment import predict
|
|
17
|
+
from llama_index.core.base.llms.types import CompletionResponse
|
|
18
|
+
from llama_index.core.base.llms.types import LLMMetadata
|
|
19
|
+
from llama_index.core.bridge.pydantic import PrivateAttr
|
|
20
|
+
from llama_index.core.llms.callbacks import llm_chat_callback
|
|
21
|
+
from llama_index.core.llms.callbacks import llm_completion_callback
|
|
22
|
+
from llama_index.core.llms.llm import LLM
|
|
23
|
+
|
|
24
|
+
from datarobot_dome.async_http_client import AsyncHTTPClient
|
|
25
|
+
|
|
26
|
+
# Defaults handed to the llama-index LLM base class in DataRobotLLM.__init__.
DEFAULT_TEMPERATURE = 1.0
# Maximum number of tokens requested per completion.
MAX_TOKENS = 512
# Request timeout in seconds; also used for the shared AsyncHTTPClient.
DEFAULT_TIMEOUT = 30
# Maximum retry attempts for a failed prediction request.
MAX_RETRIES = 5
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class DataRobotLLM(LLM):
    """llama-index ``LLM`` wrapper around a DataRobot TextGeneration deployment.

    Synchronous completion goes through ``datarobot_predict.deployment.predict``;
    asynchronous completion goes through the shared ``AsyncHTTPClient``.
    Chat and streaming entry points are not implemented.
    """

    # DataRobot deployment object. Only one of `deployment` or `deployment_id` is required
    _deployment: dr.Deployment = PrivateAttr()
    # DataRobot endpoint URL, Only one of `dr_client` or the pair
    # (datarobot_endpoint, datarobot_api_token) is required
    _datarobot_endpoint: str = PrivateAttr()
    # DataRobot API Token to use, Only one of `dr_client` or the pair
    # (datarobot_endpoint, datarobot_api_token) is required
    _datarobot_api_token: str = PrivateAttr()
    # Async HTTP Client for all async prediction requests with DataRobot Deployment
    _async_http_client: Any = PrivateAttr()

    # Input column the deployment expects the prompt in.
    _prompt_column_name: str = PrivateAttr()
    # Output column the prediction text is read from ("<target>_PREDICTION").
    _target_column_name: str = PrivateAttr()

    def __init__(
        self,
        deployment,
        datarobot_endpoint=None,
        datarobot_api_token=None,
        callback_manager=None,
    ):
        """Validate the deployment and cache its prompt/target column names.

        Raises:
            ValueError: if the deployment is missing, connection parameters
                are missing, the deployment is not a TextGeneration model, or
                the model has no prompt column configured.
        """
        super().__init__(
            model="DataRobot LLM",
            temperature=DEFAULT_TEMPERATURE,
            max_tokens=MAX_TOKENS,
            timeout=DEFAULT_TIMEOUT,
            max_retries=MAX_RETRIES,
            callback_manager=callback_manager,
        )
        if deployment is None:
            raise ValueError("DataRobot deployment is required")

        # NOTE(review): `and` means this only fires when BOTH values are
        # missing, while the message implies both are required — confirm
        # whether `or` was intended.
        if datarobot_api_token is None and datarobot_endpoint is None:
            raise ValueError(
                "Connection parameters 'datarobot_endpoint' and 'datarobot_api_token' "
                "needs to be provided"
            )
        self._deployment = deployment
        self._datarobot_endpoint = datarobot_endpoint
        self._datarobot_api_token = datarobot_api_token

        # Only LLM (TextGeneration) deployments can back this wrapper.
        if self._deployment.model["target_type"] != "TextGeneration":
            raise ValueError(
                f"Invalid deployment '{self._deployment.label}' for LLM. Expecting an LLM "
                f"deployment, but is a '{self._deployment.model['target_type']}' deployment"
            )

        self._prompt_column_name = self._deployment.model.get("prompt")
        if self._prompt_column_name is None:
            raise ValueError("Prompt column name 'prompt' is not set on the deployment / model")

        # Prediction output column follows the "<target_name>_PREDICTION" convention.
        self._target_column_name = self._deployment.model["target_name"] + "_PREDICTION"
        self._async_http_client = AsyncHTTPClient(DEFAULT_TIMEOUT)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "datarobot-llm"

    @property
    def _identifying_params(self):
        """Get the identifying parameters."""
        # NOTE(review): relies on `self.model_kwargs` being provided by the
        # base class — verify it exists on this llama-index version.
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self._datarobot_endpoint, "deployment_id": str(self._deployment.id)},
            **{"model_kwargs": _model_kwargs},
        }

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        """Synchronous single-prompt prediction; returns the raw text.

        NOTE(review): `stop`/`run_manager` are accepted but unused —
        presumably a LangChain-style signature kept for compatibility.
        """
        df = pd.DataFrame({self._prompt_column_name: [prompt]})
        result_df, _ = predict(self._deployment, df)
        return result_df[self._target_column_name].iloc[0]

    @classmethod
    def class_name(cls) -> str:
        """Class identifier used by llama-index serialization."""
        return "DataRobotLLM"

    @property
    def metadata(self):
        # Completion-only model: chat APIs below are unimplemented.
        return LLMMetadata(is_chat_model=False)

    @llm_chat_callback()
    def chat(self, messages, **kwargs: Any):
        """Chat API is not supported by this wrapper."""
        raise NotImplementedError

    @llm_completion_callback()
    def complete(self, prompt, formatted, **kwargs):
        """Synchronous completion via a one-row prediction request."""
        df = pd.DataFrame({self._prompt_column_name: [prompt]})
        result_df, _ = predict(self._deployment, df)
        return CompletionResponse(text=result_df[self._target_column_name].iloc[0], raw={})

    @llm_chat_callback()
    def stream_chat(self, messages, **kwargs):
        """Streaming chat is not supported by this wrapper."""
        raise NotImplementedError

    @llm_completion_callback()
    def stream_complete(self, prompt, formatted=False, **kwargs):
        """Streaming completion is not supported by this wrapper."""
        raise NotImplementedError

    @llm_chat_callback()
    async def achat(self, messages, **kwargs):
        """Async chat is not supported by this wrapper."""
        raise NotImplementedError

    @llm_completion_callback()
    async def acomplete(self, prompt, formatted=False, **kwargs):
        """Async completion via the shared AsyncHTTPClient."""
        input_df_to_guard = pd.DataFrame({self._prompt_column_name: [prompt]})
        result_df = await self._async_http_client.predict(self._deployment, input_df_to_guard)
        return CompletionResponse(text=result_df[self._target_column_name].iloc[0], raw={})

    @llm_chat_callback()
    async def astream_chat(self, messages, **kwargs):
        """Async streaming chat is not supported by this wrapper."""
        raise NotImplementedError

    @llm_completion_callback()
    async def astream_complete(self, prompt, formatted=False, **kwargs):
        """Async streaming completion is not supported by this wrapper."""
        raise NotImplementedError
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|
|
12
|
+
import pandas as pd
|
|
13
|
+
from datarobot.enums import CustomMetricAggregationType
|
|
14
|
+
from datarobot.enums import CustomMetricDirectionality
|
|
15
|
+
|
|
16
|
+
from datarobot_dome.constants import CUSTOM_METRIC_DESCRIPTION_SUFFIX
|
|
17
|
+
from datarobot_dome.guard_helpers import get_token_count
|
|
18
|
+
from datarobot_dome.metrics.metric_scorer import MetricScorer
|
|
19
|
+
|
|
20
|
+
# Column in the scored frame holding the list of citation texts per row.
CITATION_COLUMN = "response.citations"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class CitationTokenCountScorer(MetricScorer):
    """Reports the total number of tokens across all citations in a batch."""

    NAME = "Total Citation Tokens"
    DESCRIPTION = f"Total number of citation tokens. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.SUM
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Sum token counts over every citation in every row of ``df``.

        Returns 0.0 when the citation column is absent from ``df``.
        """
        col = self.input_column
        if col not in df.columns:
            return 0.0

        encoding = self.encoding
        token_total = 0
        for citations in df[col].values:
            for citation in citations:
                token_total += get_token_count(citation, encoding)
        return token_total
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CitationTokenAverageScorer(MetricScorer):
    """Reports the average number of tokens per non-empty citation in a batch."""

    NAME = "Average Citation Tokens"
    DESCRIPTION = f"Average number of citation tokens. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.AVERAGE
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Average token count over the non-empty citations in ``df``.

        Returns 0.0 when the citation column is absent or no non-empty
        citations are present (previously this raised ZeroDivisionError,
        and the division was performed on every loop iteration).
        """
        column = self.input_column
        if column not in df.columns:
            return 0.0

        total = 0
        count = 0
        for cell in df[column].values:
            total += sum(get_token_count(v, self.encoding) for v in cell)
            # Only non-empty citations contribute to the denominator.
            count += sum(v != "" for v in cell)

        # Divide once, after accumulation, and guard against an empty batch.
        return total / count if count else 0.0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class DocumentCountScorer(MetricScorer):
    """Reports the total number of non-empty citation documents in a batch."""

    NAME = "Total Documents"
    DESCRIPTION = f"Total number of documents. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.SUM
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Count truthy (non-empty) documents across all rows of ``df``.

        Returns 0.0 when the citation column is absent.
        """
        col = self.input_column
        if col not in df.columns:
            return 0.0

        doc_count = 0
        for citations in df[col].values:
            doc_count += sum(1 for doc in citations if doc)
        return doc_count
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class DocumentAverageScorer(MetricScorer):
    """Reports the per-batch document count, aggregated as an AVERAGE in DR."""

    NAME = "Average Documents"
    DESCRIPTION = f"Average number of documents. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.AVERAGE
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Count truthy (non-empty) documents across all rows of ``df``.

        Returns 0.0 when the citation column is absent.

        NOTE(review): this computes the same per-batch value as
        DocumentCountScorer; only the AVERAGE aggregation type differs —
        confirm that a per-row average was not intended instead.
        """
        col = self.input_column
        if col not in df.columns:
            return 0.0

        doc_count = 0
        for citations in df[col].values:
            doc_count += len([doc for doc in citations if doc])
        return doc_count
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|
|
12
|
+
from typing import Any
|
|
13
|
+
from typing import ClassVar
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
from datarobot_dome.metrics.citation_metrics import CitationTokenAverageScorer
|
|
17
|
+
from datarobot_dome.metrics.citation_metrics import CitationTokenCountScorer
|
|
18
|
+
from datarobot_dome.metrics.citation_metrics import DocumentAverageScorer
|
|
19
|
+
from datarobot_dome.metrics.citation_metrics import DocumentCountScorer
|
|
20
|
+
from datarobot_dome.metrics.metric_scorer import MetricScorer
|
|
21
|
+
from datarobot_dome.metrics.metric_scorer import ScorerType
|
|
22
|
+
|
|
23
|
+
# Registry mapping each ScorerType to its concrete MetricScorer subclass.
# (Annotation corrected: ClassVar is not a value type; values are classes.)
METRIC_SCORE_CLASS_MAP: dict[ScorerType, type[MetricScorer]] = {
    ScorerType.CITATION_TOKEN_AVERAGE: CitationTokenAverageScorer,
    ScorerType.CITATION_TOKEN_COUNT: CitationTokenCountScorer,
    ScorerType.DOCUMENT_AVERAGE: DocumentAverageScorer,
    ScorerType.DOCUMENT_COUNT: DocumentCountScorer,
}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MetricScorerFactory:
    """Factory for resolving and instantiating MetricScorer subclasses."""

    @staticmethod
    def get_class(metric_type: ScorerType) -> type[MetricScorer]:
        """Look up the scorer class registered for ``metric_type``.

        Raises:
            ValueError: if the type is not in METRIC_SCORE_CLASS_MAP.
        """
        clazz = METRIC_SCORE_CLASS_MAP.get(metric_type)
        if clazz is None:
            raise ValueError(f"Unknown metric type: {metric_type}")

        return clazz

    @staticmethod
    def create(metric_type: ScorerType, config: Optional[dict[str, Any]] = None) -> MetricScorer:
        """Instantiate the scorer for ``metric_type`` with ``config`` (default {})."""
        _config = config or {}
        clazz = MetricScorerFactory.get_class(metric_type)
        return clazz(_config)

    @staticmethod
    def custom_metric_config(
        metric_type: ScorerType, config: Optional[dict[str, Any]] = None
    ) -> dict[str, Any]:
        """Return the DR custom-metric definition for ``metric_type``
        without instantiating the scorer."""
        _config = config or {}
        clazz = MetricScorerFactory.get_class(metric_type)
        return clazz.custom_metric_definition(_config)
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|
|
12
|
+
from abc import ABC
|
|
13
|
+
from abc import abstractmethod
|
|
14
|
+
from enum import Enum
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import pandas as pd
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ScorerType(str, Enum):
    """Identifiers for the built-in citation/document metric scorers."""

    CITATION_TOKEN_AVERAGE = "CITATION_TOKEN_AVERAGE"
    CITATION_TOKEN_COUNT = "CITATION_TOKEN_COUNT"
    DOCUMENT_AVERAGE = "DOCUMENT_AVERAGE"
    DOCUMENT_COUNT = "DOCUMENT_COUNT"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MetricScorer(ABC):
    """Base class for batch metric scorers backing DataRobot custom metrics.

    Concrete subclasses must define the class attributes ``NAME``,
    ``DESCRIPTION``, ``DIRECTIONALITY``, ``UNITS``, ``AGGREGATION_TYPE``,
    ``BASELINE_VALUE`` and ``INPUT_COLUMN`` — they are read below via
    ``cls.``/``self.`` but intentionally not declared here — plus the
    abstract ``score`` method.
    """

    # Whether the custom metric is tied to a specific model (vs deployment-wide).
    IS_MODEL_SPECIFIC = True
    # Aggregation bucket size used in the DR custom-metric definition.
    TIME_STEP = "hour"

    def __init__(
        self,
        config: dict[str, Any],
    ):
        # Free-form per-metric configuration; known keys are surfaced through
        # the properties below (name, per_prediction, input_column, encoding).
        self.config = config

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    @classmethod
    def custom_metric_definition(cls, config: dict[str, Any]) -> dict[str, Any]:
        """
        Generates a custom-metric configuration/definition that is used to create the
        custom-metric in the DR application.

        This is done as a class method because we create custom-metrics before creating
        the scorer.
        """
        return {
            "name": config.get("name", cls.NAME),
            "directionality": cls.DIRECTIONALITY,
            "units": cls.UNITS,
            "type": cls.AGGREGATION_TYPE,
            "baselineValue": config.get("baseline_value", cls.BASELINE_VALUE),
            "isModelSpecific": config.get("is_model_specific", cls.IS_MODEL_SPECIFIC),
            "timeStep": cls.TIME_STEP,
            "description": config.get("description", cls.DESCRIPTION),
        }

    @property
    def name(self) -> str:
        """Metric display name; config override falls back to the class NAME."""
        return self.config.get("name", self.NAME)

    @property
    def per_prediction(self) -> bool:
        """Whether the metric is reported per prediction (default False)."""
        return self.config.get("per_prediction", False)

    @property
    def input_column(self) -> str:
        """Column scored by this metric; config override falls back to INPUT_COLUMN."""
        return self.config.get("input_column", self.INPUT_COLUMN)

    @property
    def encoding(self) -> str:
        """tiktoken encoding name used for token counts (default cl100k_base)."""
        return self.config.get("encoding", "cl100k_base")

    @abstractmethod
    def score(self, df: pd.DataFrame) -> float:
        """Compute the metric value for one batch of scored rows."""
        pass
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
|
|
3
|
+
# Last updated 2025.
|
|
4
|
+
#
|
|
5
|
+
# DataRobot, Inc. Confidential.
|
|
6
|
+
# This is proprietary source code of DataRobot, Inc. and its affiliates.
|
|
7
|
+
#
|
|
8
|
+
# This file and its contents are subject to DataRobot Tool and Utility Agreement.
|
|
9
|
+
# For details, see
|
|
10
|
+
# https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
|
|
11
|
+
# ---------------------------------------------------------------------------------
|