datarobot-moderations 11.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,232 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
12
+ import json
13
+ import os
14
+
15
+ import datarobot as dr
16
+ import trafaret as t
17
+
18
+ from datarobot_dome.constants import AWS_ACCOUNT_SECRET_DEFINITION_SUFFIX
19
+ from datarobot_dome.constants import GOOGLE_SERVICE_ACCOUNT_SECRET_DEFINITION_SUFFIX
20
+ from datarobot_dome.constants import OPENAI_SECRET_DEFINITION_SUFFIX
21
+ from datarobot_dome.constants import SECRET_DEFINITION_PREFIX
22
+ from datarobot_dome.constants import GuardLLMType
23
+ from datarobot_dome.constants import GuardType
24
+ from datarobot_dome.constants import OOTBType
25
+ from datarobot_dome.guard_helpers import get_azure_openai_client
26
+ from datarobot_dome.guard_helpers import get_bedrock_client
27
+ from datarobot_dome.guard_helpers import get_datarobot_llm
28
+ from datarobot_dome.guard_helpers import get_vertex_client
29
+ from datarobot_dome.guard_helpers import try_to_fallback_to_llm_gateway
30
+
31
# Credential payload carrying a plain password ("basic" credential type).
# `to_name` remaps the camelCase wire keys to snake_case for internal use.
basic_credential_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("basic"),
        t.Key("password", to_name="password", optional=False): t.String,
    },
    allow_extra=["*"],
)

# Credential payload carrying an API token ("api_token" credential type).
api_token_credential_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("api_token"),
        t.Key("apiToken", to_name="api_token", optional=False): t.String,
    },
    allow_extra=["*"],
)

# GCP service-account credential; `gcpKey` is the service-account key as a
# free-form JSON object (any keys allowed).
google_service_account_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("gcp"),
        t.Key("gcpKey", to_name="gcp_key", optional=False): t.Dict(allow_extra=["*"]),
    },
    allow_extra=["*"],
)

# AWS credential ("s3" credential type); the session token is optional and
# defaults to None for long-lived (non-STS) credentials.
aws_account_trafaret = t.Dict(
    {
        t.Key("credentialType", to_name="credential_type", optional=False): t.Enum("s3"),
        t.Key("awsAccessKeyId", to_name="aws_access_key_id", optional=False): t.String,
        t.Key("awsSecretAccessKey", to_name="aws_secret_access_key", optional=False): t.String,
        t.Key("awsSessionToken", to_name="aws_session_token", optional=True, default=None): t.String
        | t.Null,
    },
    allow_extra=["*"],
)


# Top-level schema for a secret injected via an environment variable: a
# {"type": "credential", "payload": {...}} wrapper where the payload matches
# exactly one of the credential schemas above.
credential_trafaret = t.Dict(
    {
        t.Key("type", optional=False): t.Enum("credential"),
        t.Key("payload", optional=False): t.Or(
            basic_credential_trafaret,
            api_token_credential_trafaret,
            google_service_account_trafaret,
            aws_account_trafaret,
        ),
    }
)
78
+
79
+
80
class GuardLLMMixin:
    """Mixin that resolves guard credentials from injected environment
    variables and constructs the LLM client a guard should use.

    Secrets are expected in the environment as JSON blobs matching
    ``credential_trafaret``; the env var name is derived from the guard's
    type/stage (see :meth:`get_secret_env_var_base`).
    """

    def get_secret_env_var_base(self, config, llm_type_str):
        """Return the base env-var name for this guard's secret.

        The name is assembled from the guard type and stage, plus the OOTB
        metric type for OOTB guards, with ``llm_type_str`` appended.

        Raises:
            Exception: if the guard config does not map to a known layout.
        """
        guard_type = config["type"]
        guard_stage = config["stage"]
        secret_env_var_name_prefix = f"{SECRET_DEFINITION_PREFIX}_{guard_type}_{guard_stage}_"
        if guard_type == GuardType.NEMO_GUARDRAILS:
            return f"{secret_env_var_name_prefix}{llm_type_str}"
        elif guard_type == GuardType.OOTB:
            if config["ootb_type"] == OOTBType.FAITHFULNESS:
                return f"{secret_env_var_name_prefix}{OOTBType.FAITHFULNESS}_{llm_type_str}"
            elif config["ootb_type"] == OOTBType.AGENT_GOAL_ACCURACY:
                return f"{secret_env_var_name_prefix}{OOTBType.AGENT_GOAL_ACCURACY}_{llm_type_str}"
            elif config["ootb_type"] == OOTBType.TASK_ADHERENCE:
                return f"{secret_env_var_name_prefix}{OOTBType.TASK_ADHERENCE}_{llm_type_str}"
            else:
                raise Exception("Invalid guard config for building env var name")
        else:
            raise Exception("Invalid guard config for building env var name")

    def build_open_ai_api_key_env_var_name(self, config, llm_type):
        """Build the upper-cased env-var name holding the OpenAI-style API key.

        Azure and NIM variants get an extra "AZURE_"/"NIM_" marker in the name.
        """
        llm_type_str = ""
        if llm_type == GuardLLMType.AZURE_OPENAI:
            llm_type_str = "AZURE_"
        elif llm_type == GuardLLMType.NIM:
            llm_type_str = "NIM_"
        var_name = self.get_secret_env_var_base(config, llm_type_str)
        var_name += OPENAI_SECRET_DEFINITION_SUFFIX
        return var_name.upper()

    def get_openai_api_key(self, config, llm_type):
        """Read and validate the OpenAI / Azure OpenAI / NIM API key.

        Returns:
            The secret value, or ``None`` for NIM when no secret env var is
            set (NIM is allowed to run without an API key).

        Raises:
            Exception: if the env var is missing for non-NIM LLM types.
        """
        api_key_env_var_name = self.build_open_ai_api_key_env_var_name(config, llm_type)
        if api_key_env_var_name not in os.environ:
            if llm_type == GuardLLMType.NIM:
                return None
            raise Exception(f"Expected environment variable '{api_key_env_var_name}' not found")

        env_var_value = json.loads(os.environ[api_key_env_var_name])
        credential_config = credential_trafaret.check(env_var_value)
        if credential_config["payload"]["credential_type"] == "basic":
            return credential_config["payload"]["password"]
        else:
            # NOTE(review): a "gcp"/"s3" payload would KeyError here; this path
            # assumes an api_token credential was configured — confirm callers
            # only wire OpenAI-compatible credential types to this guard.
            return credential_config["payload"]["api_token"]

    def get_google_service_account(self, config):
        """Read the GCP service-account key (dict) for this guard.

        Raises:
            Exception: if the env var is missing or the credential is not
                of type 'gcp'.
        """
        service_account_env_var_name = self.get_secret_env_var_base(
            config, GOOGLE_SERVICE_ACCOUNT_SECRET_DEFINITION_SUFFIX
        ).upper()
        if service_account_env_var_name not in os.environ:
            raise Exception(
                f"Expected environment variable '{service_account_env_var_name}' not found"
            )

        env_var_value = json.loads(os.environ[service_account_env_var_name])
        credential_config = credential_trafaret.check(env_var_value)
        if credential_config["payload"]["credential_type"] == "gcp":
            return credential_config["payload"]["gcp_key"]
        else:
            raise Exception("Google model requires a credential of type 'gcp'")

    def get_aws_account(self, config):
        """Read the AWS credential payload (access key, secret, session token).

        Raises:
            Exception: if the env var is missing or the credential is not
                of type 's3'.
        """
        service_account_env_var_name = self.get_secret_env_var_base(
            config, AWS_ACCOUNT_SECRET_DEFINITION_SUFFIX
        ).upper()
        if service_account_env_var_name not in os.environ:
            raise Exception(
                f"Expected environment variable '{service_account_env_var_name}' not found"
            )

        env_var_value = json.loads(os.environ[service_account_env_var_name])
        credential_config = credential_trafaret.check(env_var_value)
        if credential_config["payload"]["credential_type"] == "s3":
            return credential_config["payload"]
        else:
            raise Exception("Amazon model requires a credential of type 's3'")

    def get_llm(self, config, llm_type):
        """Construct the LLM client for this guard.

        Any failure while building the provider-specific client (including the
        deliberate NotImplementedError for DataRobot agent metrics and NIM)
        falls through to the LLM gateway via
        ``try_to_fallback_to_llm_gateway`` — the broad except is intentional.
        """
        openai_api_base = config.get("openai_api_base")
        openai_deployment_id = config.get("openai_deployment_id")
        llm_id = None
        try:
            if llm_type in [GuardLLMType.OPENAI, GuardLLMType.AZURE_OPENAI]:
                openai_api_key = self.get_openai_api_key(config, llm_type)
                if openai_api_key is None:
                    # NOTE(review): get_openai_api_key only returns None for
                    # NIM, so this branch is defensive for OPENAI/AZURE.
                    raise ValueError("OpenAI API key is required for Faithfulness guard")

                if llm_type == GuardLLMType.OPENAI:
                    os.environ["OPENAI_API_KEY"] = openai_api_key
                    llm = "default"
                elif llm_type == GuardLLMType.AZURE_OPENAI:
                    if openai_api_base is None:
                        raise ValueError("OpenAI API base url is required for LLM Guard")
                    if openai_deployment_id is None:
                        raise ValueError("OpenAI deployment ID is required for LLM Guard")
                    azure_openai_client = get_azure_openai_client(
                        openai_api_key=openai_api_key,
                        openai_api_base=openai_api_base,
                        openai_deployment_id=openai_deployment_id,
                    )
                    llm = azure_openai_client
            elif llm_type == GuardLLMType.GOOGLE:
                llm_id = config["google_model"]
                if llm_id is None:
                    raise ValueError("Google model is required for LLM Guard")
                if config.get("google_region") is None:
                    raise ValueError("Google region is required for LLM Guard")
                llm = get_vertex_client(
                    google_model=llm_id,
                    google_service_account=self.get_google_service_account(config),
                    google_region=config["google_region"],
                )
            elif llm_type == GuardLLMType.AMAZON:
                llm_id = config["aws_model"]
                if llm_id is None:
                    raise ValueError("AWS model is required for LLM Guard")
                if config.get("aws_region") is None:
                    raise ValueError("AWS region is required for LLM Guard")
                credential_config = self.get_aws_account(config)
                llm = get_bedrock_client(
                    aws_model=llm_id,
                    aws_access_key_id=credential_config["aws_access_key_id"],
                    aws_secret_access_key=credential_config["aws_secret_access_key"],
                    aws_session_token=credential_config["aws_session_token"],
                    aws_region=config["aws_region"],
                )
            elif llm_type == GuardLLMType.DATAROBOT:
                if config["type"] == GuardType.OOTB and config["ootb_type"] in [
                    OOTBType.AGENT_GOAL_ACCURACY,
                    OOTBType.TASK_ADHERENCE,
                ]:
                    # DataRobot LLM does not implement generate / agenerate yet
                    # so can't support it for Agent Goal Accuracy
                    raise NotImplementedError
                if config.get("deployment_id") is None:
                    raise ValueError("Deployment ID is required for LLM Guard")
                deployment = dr.Deployment.get(config["deployment_id"])
                llm = get_datarobot_llm(deployment)
            elif llm_type == GuardLLMType.NIM:
                raise NotImplementedError
            else:
                raise ValueError(f"Invalid LLMType: {llm_type}")

        except Exception as e:
            llm = try_to_fallback_to_llm_gateway(
                # For Bedrock and Vertex the model in the config is actually the LLM ID
                # For OpenAI we use the default model defined in get_llm_gateway_client
                # For Azure we use the deployment ID
                llm_id=llm_id,
                openai_deployment_id=openai_deployment_id,
                llm_type=llm_type,
                e=e,
            )

        return llm
datarobot_dome/llm.py ADDED
@@ -0,0 +1,148 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
12
+ from typing import Any
13
+
14
+ import datarobot as dr
15
+ import pandas as pd
16
+ from datarobot_predict.deployment import predict
17
+ from llama_index.core.base.llms.types import CompletionResponse
18
+ from llama_index.core.base.llms.types import LLMMetadata
19
+ from llama_index.core.bridge.pydantic import PrivateAttr
20
+ from llama_index.core.llms.callbacks import llm_chat_callback
21
+ from llama_index.core.llms.callbacks import llm_completion_callback
22
+ from llama_index.core.llms.llm import LLM
23
+
24
+ from datarobot_dome.async_http_client import AsyncHTTPClient
25
+
26
+ DEFAULT_TEMPERATURE = 1.0
27
+ MAX_TOKENS = 512
28
+ DEFAULT_TIMEOUT = 30
29
+ MAX_RETRIES = 5
30
+
31
+
32
class DataRobotLLM(LLM):
    """llama-index ``LLM`` backed by a DataRobot TextGeneration deployment.

    Synchronous completions are served by sending the prompt as a one-row
    DataFrame through ``datarobot_predict.deployment.predict``; async
    completions go through the shared ``AsyncHTTPClient``. Chat and streaming
    endpoints are not implemented.
    """

    # DataRobot deployment object. Only one of `deployment` or `deployment_id` is required
    _deployment: dr.Deployment = PrivateAttr()
    # DataRobot endpoint URL, Only one of `dr_client` or the pair
    # (datarobot_endpoint, datarobot_api_token) is required
    _datarobot_endpoint: str = PrivateAttr()
    # DataRobot API Token to use, Only one of `dr_client` or the pair
    # (datarobot_endpoint, datarobot_api_token) is required
    _datarobot_api_token: str = PrivateAttr()
    # Async HTTP Client for all async prediction requests with DataRobot Deployment
    _async_http_client: Any = PrivateAttr()

    # Feature name the deployment's model expects the prompt under.
    _prompt_column_name: str = PrivateAttr()
    # Column in the prediction response frame carrying the generated text
    # (target name + "_PREDICTION").
    _target_column_name: str = PrivateAttr()

    def __init__(
        self,
        deployment,
        datarobot_endpoint=None,
        datarobot_api_token=None,
        callback_manager=None,
    ):
        """Validate the deployment and set up prediction plumbing.

        Args:
            deployment: ``dr.Deployment`` whose model has target type
                "TextGeneration" and a configured prompt column.
            datarobot_endpoint: DataRobot API endpoint URL.
            datarobot_api_token: DataRobot API token.
            callback_manager: Optional llama-index callback manager.

        Raises:
            ValueError: if the deployment is missing or not an LLM deployment,
                if connection parameters are absent, or if the prompt column
                is not set on the deployment's model.
        """
        super().__init__(
            model="DataRobot LLM",
            temperature=DEFAULT_TEMPERATURE,
            max_tokens=MAX_TOKENS,
            timeout=DEFAULT_TIMEOUT,
            max_retries=MAX_RETRIES,
            callback_manager=callback_manager,
        )
        if deployment is None:
            raise ValueError("DataRobot deployment is required")

        # NOTE(review): this only raises when BOTH parameters are missing even
        # though the message implies both are required — confirm whether `or`
        # was intended before tightening the check.
        if datarobot_api_token is None and datarobot_endpoint is None:
            raise ValueError(
                "Connection parameters 'datarobot_endpoint' and 'datarobot_api_token' "
                "needs to be provided"
            )
        self._deployment = deployment
        self._datarobot_endpoint = datarobot_endpoint
        self._datarobot_api_token = datarobot_api_token

        if self._deployment.model["target_type"] != "TextGeneration":
            raise ValueError(
                f"Invalid deployment '{self._deployment.label}' for LLM. Expecting an LLM "
                f"deployment, but is a '{self._deployment.model['target_type']}' deployment"
            )

        self._prompt_column_name = self._deployment.model.get("prompt")
        if self._prompt_column_name is None:
            raise ValueError("Prompt column name 'prompt' is not set on the deployment / model")

        self._target_column_name = self._deployment.model["target_name"] + "_PREDICTION"
        self._async_http_client = AsyncHTTPClient(DEFAULT_TIMEOUT)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "datarobot-llm"

    @property
    def _identifying_params(self):
        """Get the identifying parameters."""
        # NOTE(review): `model_kwargs` is not set anywhere in this class —
        # confirm the base class provides it, otherwise this raises
        # AttributeError when accessed.
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"endpoint_url": self._datarobot_endpoint, "deployment_id": str(self._deployment.id)},
            **{"model_kwargs": _model_kwargs},
        }

    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
        """Synchronously predict a single prompt and return the raw text."""
        df = pd.DataFrame({self._prompt_column_name: [prompt]})
        result_df, _ = predict(self._deployment, df)
        return result_df[self._target_column_name].iloc[0]

    @classmethod
    def class_name(cls) -> str:
        return "DataRobotLLM"

    @property
    def metadata(self):
        # Completion-only model: chat APIs below all raise NotImplementedError.
        return LLMMetadata(is_chat_model=False)

    @llm_chat_callback()
    def chat(self, messages, **kwargs: Any):
        raise NotImplementedError

    @llm_completion_callback()
    def complete(self, prompt, formatted=False, **kwargs):
        """Synchronous completion via the deployment's prediction API.

        Fix: `formatted` now defaults to False, matching the llama-index
        ``LLM.complete`` signature and the other *complete methods here —
        previously it was a required positional argument.
        """
        df = pd.DataFrame({self._prompt_column_name: [prompt]})
        result_df, _ = predict(self._deployment, df)
        return CompletionResponse(text=result_df[self._target_column_name].iloc[0], raw={})

    @llm_chat_callback()
    def stream_chat(self, messages, **kwargs):
        raise NotImplementedError

    @llm_completion_callback()
    def stream_complete(self, prompt, formatted=False, **kwargs):
        raise NotImplementedError

    @llm_chat_callback()
    async def achat(self, messages, **kwargs):
        raise NotImplementedError

    @llm_completion_callback()
    async def acomplete(self, prompt, formatted=False, **kwargs):
        """Async completion via the shared AsyncHTTPClient."""
        input_df_to_guard = pd.DataFrame({self._prompt_column_name: [prompt]})
        result_df = await self._async_http_client.predict(self._deployment, input_df_to_guard)
        return CompletionResponse(text=result_df[self._target_column_name].iloc[0], raw={})

    @llm_chat_callback()
    async def astream_chat(self, messages, **kwargs):
        raise NotImplementedError

    @llm_completion_callback()
    async def astream_complete(self, prompt, formatted=False, **kwargs):
        raise NotImplementedError
@@ -0,0 +1,11 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
@@ -0,0 +1,98 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
12
+ import pandas as pd
13
+ from datarobot.enums import CustomMetricAggregationType
14
+ from datarobot.enums import CustomMetricDirectionality
15
+
16
+ from datarobot_dome.constants import CUSTOM_METRIC_DESCRIPTION_SUFFIX
17
+ from datarobot_dome.guard_helpers import get_token_count
18
+ from datarobot_dome.metrics.metric_scorer import MetricScorer
19
+
20
+ CITATION_COLUMN = "response.citations"
21
+
22
+
23
class CitationTokenCountScorer(MetricScorer):
    """Sums the token counts of every citation string in the batch."""

    NAME = "Total Citation Tokens"
    DESCRIPTION = f"Total number of citation tokens. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.SUM
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Total citation tokens across all rows; 0.0 if the column is absent."""
        target_column = self.input_column
        if target_column not in df.columns:
            return 0.0

        encoding = self.encoding
        total_tokens = 0
        for citations in df[target_column].values:
            for citation in citations:
                total_tokens += get_token_count(citation, encoding)
        return total_tokens
40
+
41
+
42
class CitationTokenAverageScorer(MetricScorer):
    """Average token count per non-empty citation across the batch."""

    NAME = "Average Citation Tokens"
    DESCRIPTION = f"Average number of citation tokens. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.AVERAGE
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Return total citation tokens divided by the number of non-empty
        citations; 0.0 if the column is missing or no citations exist.

        Fix: the previous implementation divided inside the loop and raised
        ZeroDivisionError whenever every citation in the batch was empty
        (count stayed 0); the division now happens once, guarded against 0.
        """
        column = self.input_column
        if column not in df.columns:
            return 0.0

        total = 0
        count = 0
        for cell in df[column].values:
            total += sum(get_token_count(v, self.encoding) for v in cell)
            # Only non-empty citation strings contribute to the denominator.
            count += sum(v != "" for v in cell)

        if count == 0:
            return 0.0
        return total / count
65
+
66
+
67
class DocumentCountScorer(MetricScorer):
    """Counts the non-empty citation documents present in the batch."""

    NAME = "Total Documents"
    DESCRIPTION = f"Total number of documents. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.SUM
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Total truthy (non-empty) documents; 0.0 if the column is absent."""
        target_column = self.input_column
        if target_column not in df.columns:
            return 0.0

        document_total = 0
        for citations in df[target_column].values:
            # bool(v) filters out empty strings / falsy placeholders.
            document_total += sum(bool(citation) for citation in citations)
        return document_total
82
+
83
+
84
class DocumentAverageScorer(MetricScorer):
    """Average number of citation documents per prediction row."""

    NAME = "Average Documents"
    DESCRIPTION = f"Average number of documents. {CUSTOM_METRIC_DESCRIPTION_SUFFIX}"
    DIRECTIONALITY = CustomMetricDirectionality.LOWER_IS_BETTER
    UNITS = "count"
    AGGREGATION_TYPE = CustomMetricAggregationType.AVERAGE
    BASELINE_VALUE = 0
    INPUT_COLUMN = CITATION_COLUMN

    def score(self, df: pd.DataFrame) -> float:
        """Return the mean non-empty document count per row; 0.0 if the
        column is missing or the frame is empty.

        Fix: the previous body was identical to DocumentCountScorer.score and
        returned the TOTAL document count, contradicting this scorer's name
        and its AVERAGE aggregation type. It now divides by the row count.
        """
        column = self.input_column
        if column not in df.columns or len(df) == 0:
            return 0.0

        total = sum(sum(bool(v) for v in cell) for cell in df[column].values)
        return total / len(df)
@@ -0,0 +1,52 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
12
+ from typing import Any
13
+ from typing import ClassVar
14
+ from typing import Optional
15
+
16
+ from datarobot_dome.metrics.citation_metrics import CitationTokenAverageScorer
17
+ from datarobot_dome.metrics.citation_metrics import CitationTokenCountScorer
18
+ from datarobot_dome.metrics.citation_metrics import DocumentAverageScorer
19
+ from datarobot_dome.metrics.citation_metrics import DocumentCountScorer
20
+ from datarobot_dome.metrics.metric_scorer import MetricScorer
21
+ from datarobot_dome.metrics.metric_scorer import ScorerType
22
+
23
+ METRIC_SCORE_CLASS_MAP: dict[ScorerType, ClassVar] = {
24
+ ScorerType.CITATION_TOKEN_AVERAGE: CitationTokenAverageScorer,
25
+ ScorerType.CITATION_TOKEN_COUNT: CitationTokenCountScorer,
26
+ ScorerType.DOCUMENT_AVERAGE: DocumentAverageScorer,
27
+ ScorerType.DOCUMENT_COUNT: DocumentCountScorer,
28
+ }
29
+
30
+
31
+ class MetricScorerFactory:
32
+ @staticmethod
33
+ def get_class(metric_type: ScorerType) -> ClassVar:
34
+ clazz = METRIC_SCORE_CLASS_MAP.get(metric_type)
35
+ if clazz is None:
36
+ raise ValueError(f"Unknown metric type: {metric_type}")
37
+
38
+ return clazz
39
+
40
+ @staticmethod
41
+ def create(metric_type: ScorerType, config: Optional[dict[str, Any]] = None) -> MetricScorer:
42
+ _config = config or {}
43
+ clazz = MetricScorerFactory.get_class(metric_type)
44
+ return clazz(_config)
45
+
46
+ @staticmethod
47
+ def custom_metric_config(
48
+ metric_type: ScorerType, config: Optional[dict[str, Any]] = None
49
+ ) -> dict[str, Any]:
50
+ _config = config or {}
51
+ clazz = MetricScorerFactory.get_class(metric_type)
52
+ return clazz.custom_metric_definition(_config)
@@ -0,0 +1,78 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------
12
+ from abc import ABC
13
+ from abc import abstractmethod
14
+ from enum import Enum
15
+ from typing import Any
16
+
17
+ import pandas as pd
18
+
19
+
20
class ScorerType(str, Enum):
    """Identifiers for the built-in metric scorers.

    Subclasses ``str`` so members compare equal to, and serialize as, their
    plain string values.
    """

    CITATION_TOKEN_AVERAGE = "CITATION_TOKEN_AVERAGE"
    CITATION_TOKEN_COUNT = "CITATION_TOKEN_COUNT"
    DOCUMENT_AVERAGE = "DOCUMENT_AVERAGE"
    DOCUMENT_COUNT = "DOCUMENT_COUNT"
25
+
26
+
27
class MetricScorer(ABC):
    """Abstract base class for custom-metric scorers.

    Concrete subclasses supply the class attributes NAME, DESCRIPTION,
    DIRECTIONALITY, UNITS, AGGREGATION_TYPE, BASELINE_VALUE and INPUT_COLUMN,
    and implement :meth:`score`. Per-instance overrides for several of these
    come from the ``config`` dict passed at construction time.
    """

    # Defaults shared by every scorer; subclasses may override.
    IS_MODEL_SPECIFIC = True
    TIME_STEP = "hour"

    def __init__(
        self,
        config: dict[str, Any],
    ):
        # Raw configuration dict; the properties below read overrides from it.
        self.config = config

    def __repr__(self) -> str:
        return "{}()".format(type(self).__name__)

    @classmethod
    def custom_metric_definition(cls, config: dict[str, Any]) -> dict[str, Any]:
        """
        Generates a custom-metric configuration/definition that is used to create the
        custom-metric in the DR application.

        This is done as a class method because we create custom-metrics before creating
        the scorer.
        """
        definition = {
            "name": config.get("name", cls.NAME),
            "directionality": cls.DIRECTIONALITY,
            "units": cls.UNITS,
            "type": cls.AGGREGATION_TYPE,
            "baselineValue": config.get("baseline_value", cls.BASELINE_VALUE),
            "isModelSpecific": config.get("is_model_specific", cls.IS_MODEL_SPECIFIC),
            "timeStep": cls.TIME_STEP,
            "description": config.get("description", cls.DESCRIPTION),
        }
        return definition

    @property
    def name(self) -> str:
        # Config override wins over the class-level NAME.
        return self.config.get("name", self.NAME)

    @property
    def per_prediction(self) -> bool:
        # Defaults to batch-level scoring unless the config opts in.
        return self.config.get("per_prediction", False)

    @property
    def input_column(self) -> str:
        # Config override wins over the class-level INPUT_COLUMN.
        return self.config.get("input_column", self.INPUT_COLUMN)

    @property
    def encoding(self) -> str:
        # tiktoken encoding name used for token counting.
        return self.config.get("encoding", "cl100k_base")

    @abstractmethod
    def score(self, df: pd.DataFrame) -> float:
        """Compute this metric's value over the given prediction frame."""
@@ -0,0 +1,11 @@
1
+ # ---------------------------------------------------------------------------------
2
+ # Copyright (c) 2025 DataRobot, Inc. and its affiliates. All rights reserved.
3
+ # Last updated 2025.
4
+ #
5
+ # DataRobot, Inc. Confidential.
6
+ # This is proprietary source code of DataRobot, Inc. and its affiliates.
7
+ #
8
+ # This file and its contents are subject to DataRobot Tool and Utility Agreement.
9
+ # For details, see
10
+ # https://www.datarobot.com/wp-content/uploads/2021/07/DataRobot-Tool-and-Utility-Agreement.pdf.
11
+ # ---------------------------------------------------------------------------------