valor-lite 0.37.1 (valor_lite-0.37.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of valor-lite might be problematic.
- valor_lite/LICENSE +21 -0
- valor_lite/__init__.py +0 -0
- valor_lite/cache/__init__.py +11 -0
- valor_lite/cache/compute.py +154 -0
- valor_lite/cache/ephemeral.py +302 -0
- valor_lite/cache/persistent.py +529 -0
- valor_lite/classification/__init__.py +14 -0
- valor_lite/classification/annotation.py +45 -0
- valor_lite/classification/computation.py +378 -0
- valor_lite/classification/evaluator.py +879 -0
- valor_lite/classification/loader.py +97 -0
- valor_lite/classification/metric.py +535 -0
- valor_lite/classification/numpy_compatibility.py +13 -0
- valor_lite/classification/shared.py +184 -0
- valor_lite/classification/utilities.py +314 -0
- valor_lite/exceptions.py +20 -0
- valor_lite/object_detection/__init__.py +17 -0
- valor_lite/object_detection/annotation.py +238 -0
- valor_lite/object_detection/computation.py +841 -0
- valor_lite/object_detection/evaluator.py +805 -0
- valor_lite/object_detection/loader.py +292 -0
- valor_lite/object_detection/metric.py +850 -0
- valor_lite/object_detection/shared.py +185 -0
- valor_lite/object_detection/utilities.py +396 -0
- valor_lite/schemas.py +11 -0
- valor_lite/semantic_segmentation/__init__.py +15 -0
- valor_lite/semantic_segmentation/annotation.py +123 -0
- valor_lite/semantic_segmentation/computation.py +165 -0
- valor_lite/semantic_segmentation/evaluator.py +414 -0
- valor_lite/semantic_segmentation/loader.py +205 -0
- valor_lite/semantic_segmentation/metric.py +275 -0
- valor_lite/semantic_segmentation/shared.py +149 -0
- valor_lite/semantic_segmentation/utilities.py +88 -0
- valor_lite/text_generation/__init__.py +15 -0
- valor_lite/text_generation/annotation.py +56 -0
- valor_lite/text_generation/computation.py +611 -0
- valor_lite/text_generation/llm/__init__.py +0 -0
- valor_lite/text_generation/llm/exceptions.py +14 -0
- valor_lite/text_generation/llm/generation.py +903 -0
- valor_lite/text_generation/llm/instructions.py +814 -0
- valor_lite/text_generation/llm/integrations.py +226 -0
- valor_lite/text_generation/llm/utilities.py +43 -0
- valor_lite/text_generation/llm/validators.py +68 -0
- valor_lite/text_generation/manager.py +697 -0
- valor_lite/text_generation/metric.py +381 -0
- valor_lite-0.37.1.dist-info/METADATA +174 -0
- valor_lite-0.37.1.dist-info/RECORD +49 -0
- valor_lite-0.37.1.dist-info/WHEEL +5 -0
- valor_lite-0.37.1.dist-info/top_level.txt +1 -0
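For orientation, the wheel installs a single top-level package, valor_lite, whose subpackages are visible in the file listing above. A minimal import check is sketched below; it assumes this release is installed from PyPI (for example via `pip install valor-lite==0.37.1`), which is not shown on this page.

# Sketch only: assumes valor-lite 0.37.1 is installed; each import mirrors a
# subpackage from the file listing above.
import valor_lite.cache
import valor_lite.classification
import valor_lite.object_detection
import valor_lite.semantic_segmentation
import valor_lite.text_generation

print(valor_lite.__name__)  # top_level.txt above lists a single package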
valor_lite/text_generation/manager.py

@@ -0,0 +1,697 @@
from functools import wraps

from valor_lite.text_generation.annotation import QueryResponse
from valor_lite.text_generation.computation import (
    calculate_answer_correctness,
    calculate_answer_relevance,
    calculate_bias,
    calculate_context_precision,
    calculate_context_recall,
    calculate_context_relevance,
    calculate_faithfulness,
    calculate_hallucination,
    calculate_rouge_scores,
    calculate_sentence_bleu,
    calculate_summary_coherence,
    calculate_toxicity,
)
from valor_lite.text_generation.llm.exceptions import InvalidLLMResponseError
from valor_lite.text_generation.llm.integrations import (
    ClientWrapper,
    MistralWrapper,
    OpenAIWrapper,
)
from valor_lite.text_generation.metric import Metric, MetricType


def llm_guided_metric(fn):
    """
    Call the LLMClient class function with retries for InvalidLLMResponseError.

    If retries is set to 0, then the function will only be called once and not retried.

    If, for example, retries is set to 3, then the function will be retried in the
    event of an InvalidLLMResponseError up to 3 times, for a maximum of 4 calls.
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        client = getattr(self, "client", None)
        if client is None:
            raise ValueError(
                f"{fn.__name__} requires the definition of an LLM client."
            )
        if getattr(client, "model_name", None) is None:
            raise AttributeError(
                "Client wrapper should contain 'model_name' as a string attribute."
            )

        error = None
        retries = getattr(self, "retries", 0)
        for _ in range(1 + retries):
            try:
                return fn(self, *args, **kwargs)
            except InvalidLLMResponseError as e:
                error = e
        if error is not None:
            return Metric.error(
                error_type=type(error).__name__,
                error_message=str(error),
                model_name=client.model_name,
                retries=retries,
            )

    return wrapper


class Evaluator:
    """
    Parent class for all LLM clients.

    Attributes
    ----------
    client : ClientWrapper, optional
        An optional client to compute llm-guided metrics.
    retries : int
        The number of times to retry the API call if it fails. Defaults to 0, indicating
        that the call will not be retried.
    """

    def __init__(
        self,
        client: ClientWrapper | None = None,
        retries: int = 0,
        default_system_prompt: str = "You are a helpful assistant.",
    ):
        """
        Creates an instance of a generic LLM client.

        Parameters
        ----------
        client : ClientWrapper, optional
            Any LLM client that conforms to _ClientWrapper. Required for LLM-guided metrics.
        retries : int, default=0
            The number of times to retry the API call if it fails. Defaults to 0, indicating
            that the call will not be retried.
        default_system_prompt : str, default="You are a helpful assistant."
            The default system prompt that is given to the evaluating LLM.
        """

        self.client = client
        self.retries = retries
        self.default_system_prompt = default_system_prompt

    @classmethod
    def openai(
        cls,
        model_name: str = "gpt-3.5-turbo",
        api_key: str | None = None,
        retries: int = 0,
        seed: int | None = None,
        default_system_prompt: str = "You are a helpful assistant.",
    ):
        """
        Create an evaluator using OpenAI's client.

        Parameters
        ----------
        model_name : str, default="gpt-3.5-turbo"
            The model to use. Defaults to "gpt-3.5-turbo".
        api_key : str, optional
            The OpenAI API key to use. If not specified, then the OPENAI_API_KEY environment
            variable will be used.
        retries : int, default=0
            The number of times to retry the API call if it fails. Defaults to 0, indicating
            that the call will not be retried. For example, if self.retries is set to 3,
            this means that the call will be retried up to 3 times, for a maximum of 4 calls.
        seed : int, optional
            An optional seed can be provided to GPT to get deterministic results.
        default_system_prompt : str, default="You are a helpful assistant."
            The default system prompt that is given to the evaluating LLM.
        """
        if seed is not None:
            if retries != 0:
                raise ValueError(
                    "Seed is provided, but retries is not 0. Retries should be 0 when seed is provided."
                )
        client = OpenAIWrapper(
            api_key=api_key,
            model_name=model_name,
            seed=seed,
        )
        return cls(
            client=client,
            retries=retries,
            default_system_prompt=default_system_prompt,
        )

    @classmethod
    def mistral(
        cls,
        model_name: str = "mistral-small-latest",
        api_key: str | None = None,
        retries: int = 0,
        default_system_prompt: str = "You are a helpful assistant.",
    ):
        """
        Create an evaluator using the Mistral API.

        Parameters
        ----------
        model_name : str, default="mistral-small-latest"
            The model to use. Defaults to "mistral-small-latest".
        api_key : str, optional
            The Mistral API key to use. If not specified, then the MISTRAL_API_KEY environment
            variable will be used.
        retries : int, default=0
            The number of times to retry the API call if it fails. Defaults to 0, indicating
            that the call will not be retried. For example, if self.retries is set to 3,
            this means that the call will be retried up to 3 times, for a maximum of 4 calls.
        default_system_prompt : str, default="You are a helpful assistant."
            The default system prompt that is given to the evaluating LLM.
        """
        client = MistralWrapper(
            api_key=api_key,
            model_name=model_name,
        )
        return cls(
            client=client,
            retries=retries,
            default_system_prompt=default_system_prompt,
        )

    @llm_guided_metric
    def compute_answer_correctness(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute answer correctness. Answer correctness is computed as an f1 score obtained
        by comparing prediction statements to ground truth statements.

        If there are multiple ground truths, then the f1 score is computed for each ground
        truth and the maximum score is returned.

        This metric was adapted from RAGAS. We follow a similar prompting strategy and
        computation, however we do not do a weighted sum with an answer similarity score
        using embeddings.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The answer correctness score between 0 and 1. Higher values indicate that the
            answer is more correct. A score of 1 indicates that all statements in the
            prediction are supported by the ground truth and all statements in the ground
            truth are present in the prediction.
        """
        if not response.context:
            raise ValueError("The answer correctness metric requires context.")

        result = calculate_answer_correctness(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            query=response.query,
            response=response.response,
            groundtruths=response.context.groundtruth,
        )
        return Metric.answer_correctness(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_answer_relevance(self, response: QueryResponse) -> Metric:
        """
        Compute answer relevance, the proportion of the model response that is
        relevant to the query, for a single piece of text.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The answer relevance score between 0 and 1. A score of 1 indicates that all
            statements are relevant to the query.
        """
        result = calculate_answer_relevance(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            query=response.query,
            response=response.response,
        )
        return Metric.answer_relevance(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_bias(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute bias, the proportion of model opinions that are biased.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        float
            The bias score between 0 and 1. A score of 1 indicates that all opinions in
            the text are biased.
        """
        result = calculate_bias(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            response=response.response,
        )
        return Metric.bias(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_context_precision(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute context precision, a score for evaluating the retrieval
        mechanism of a RAG model.

        First, an LLM is prompted to determine if each context in the context
        list is useful for producing the ground truth answer to the query.

        If there are multiple ground truths, then the verdict is "yes" for a
        context if that context is useful for producing any of the ground truth
        answers, and "no" otherwise.

        Then, using these verdicts, the context precision score is computed as
        a weighted sum of the precision at k for each k from 1 to the length
        of the context list.

        Note that the earlier a piece of context appears in the context list,
        the more important it is in the computation of this score. For example,
        the first context in the context list will be included in every precision
        at k computation, so will have a large influence on the final score,
        whereas the last context will only be used for the last precision at
        k computation, so will have a small influence on the final score.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The context precision score between 0 and 1. A higher score indicates
            better context precision.
        """
        if not response.context:
            raise ValueError("The context precision metric requires context.")

        result = calculate_context_precision(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            query=response.query,
            predicted_context=response.context.prediction,
            groundtruth_context=response.context.groundtruth,
        )
        return Metric.context_precision(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_context_recall(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute context recall, a score for evaluating the retrieval mechanism of a RAG model.

        The context recall score is the proportion of statements in the ground truth
        that are attributable to the context list.

        If multiple ground truths are provided, then the context recall score is
        computed for each ground truth and the maximum score is returned.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The context recall score between 0 and 1. A score of 1 indicates that
            all ground truth statements are attributable to the contexts in the context list.
        """
        if not response.context:
            raise ValueError("The context recall metric requires context.")

        result = calculate_context_recall(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            predicted_context=response.context.prediction,
            groundtruth_context=response.context.groundtruth,
        )
        return Metric.context_recall(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_context_relevance(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute context relevance, the proportion of contexts in the context list
        that are relevant to the query.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The context relevance score between 0 and 1. A score of 0 indicates
            that none of the contexts are relevant and a score of 1 indicates
            that all of the contexts are relevant.
        """
        if not response.context:
            raise ValueError("The context relevance metric requires context.")

        result = calculate_context_relevance(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            query=response.query,
            context=response.context.prediction,
        )
        return Metric.context_relevance(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_faithfulness(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute the faithfulness score. The faithfulness score is the proportion
        of claims in the text that are implied by the list of contexts. Claims
        that contradict the list of contexts and claims that are unrelated to
        the list of contexts both count against the score.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The faithfulness score between 0 and 1. A score of 1 indicates that
            all claims in the text are implied by the list of contexts.
        """

        if not response.context:
            raise ValueError("The faithfulness metric requires context.")

        result = calculate_faithfulness(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            response=response.response,
            context=response.context.prediction,
        )
        return Metric.faithfulness(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_hallucination(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute the hallucination score, the proportion of contexts in the context
        list that are contradicted by the text.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The hallucination score between 0 and 1. A score of 1 indicates that
            all contexts are contradicted by the text.
        """

        if not response.context:
            raise ValueError("The hallucination metric requires context.")

        result = calculate_hallucination(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            response=response.response,
            context=response.context.prediction,
        )
        return Metric.hallucination(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_summary_coherence(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute summary coherence, the collective quality of a summary.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The summary coherence score between 1 and 5. A score of 1 indicates
            the lowest summary coherence and a score of 5 indicates the highest
            summary coherence.
        """
        result = calculate_summary_coherence(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            text=response.query,
            summary=response.response,
        )
        return Metric.summary_coherence(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @llm_guided_metric
    def compute_toxicity(
        self,
        response: QueryResponse,
    ) -> Metric:
        """
        Compute toxicity, the portion of opinions that are toxic.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.

        Returns
        -------
        Metric
            The toxicity score will be evaluated as a float between 0 and 1, with
            1 indicating that all opinions in the text are toxic.
        """
        result = calculate_toxicity(
            client=self.client,  # type: ignore - wrapper handles None case
            system_prompt=self.default_system_prompt,
            response=response.response,
        )
        return Metric.toxicity(
            value=result,
            model_name=self.client.model_name,  # type: ignore - wrapper handles None case
            retries=self.retries,
        )

    @staticmethod
    def compute_rouge(
        response: QueryResponse,
        rouge_types: list[str] = [
            "rouge1",
            "rouge2",
            "rougeL",
            "rougeLsum",
        ],
        use_stemmer: bool = False,
    ) -> list[Metric]:
        """
        Calculate ROUGE scores for a model response given some set of references.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.
        rouge_types : list[str], optional
            A list of rouge types to calculate.
            Defaults to ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'].
        use_stemmer: bool, default=False
            If True, uses Porter stemmer to strip word suffixes. Defaults to False.

        Returns
        -------
        list[Metric]
        """

        if not response.context:
            raise ValueError("ROUGE metrics require context.")

        results = calculate_rouge_scores(
            prediction=response.response,
            references=response.context.groundtruth,
            rouge_types=rouge_types,
            use_stemmer=use_stemmer,
        )
        return [
            Metric.rouge(
                value=result,
                rouge_type=rouge_type,
                use_stemmer=use_stemmer,
            )
            for rouge_type, result in results.items()
        ]

    @staticmethod
    def compute_sentence_bleu(
        response: QueryResponse,
        weights: list[float] = [0.25, 0.25, 0.25, 0.25],
    ) -> Metric:
        """
        Calculate sentence BLEU scores for a set of model response - ground truth pairs.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.
        weights: list[float], default=[0.25, 0.25, 0.25, 0.25]
            The default BLEU calculates a score for up to 4-grams using uniform
            weights (this is called BLEU-4). To evaluate your translations with
            higher/lower order ngrams, use customized weights. Example: when accounting
            for up to 5-grams with uniform weights (this is called BLEU-5) use [1/5]*5
        """

        if not response.context:
            raise ValueError("The sentence BLEU metric requires context.")

        result = calculate_sentence_bleu(
            prediction=response.response,
            references=response.context.groundtruth,
            weights=weights,
        )
        return Metric.bleu(
            value=result,
            weights=weights,
        )

    def compute_all(
        self,
        response: QueryResponse,
        bleu_weights: list[float] = [0.25, 0.25, 0.25, 0.25],
        rouge_types: list[str] = [
            "rouge1",
            "rouge2",
            "rougeL",
            "rougeLsum",
        ],
        rouge_use_stemmer: bool = False,
    ) -> dict[MetricType, list[Metric]]:
        """
        Computes all available metrics.

        Parameters
        ----------
        response: QueryResponse
            A user query with ground truth and generated response.
        bleu_weights: list[float], default=[0.25, 0.25, 0.25, 0.25]
            The default BLEU calculates a score for up to 4-grams using uniform
            weights (this is called BLEU-4). To evaluate your translations with
            higher/lower order ngrams, use customized weights. Example: when accounting
            for up to 5-grams with uniform weights (this is called BLEU-5) use [1/5]*5
        rouge_types : list[str], optional
            A list of rouge types to calculate.
            Defaults to ['rouge1', 'rouge2', 'rougeL', 'rougeLsum'].
        rouge_use_stemmer: bool, default=False
            If True, uses Porter stemmer to strip word suffixes. Defaults to False.
        """
        results = dict()
        results[MetricType.AnswerCorrectness] = [
            self.compute_answer_correctness(response)
        ]
        results[MetricType.AnswerRelevance] = [
            self.compute_answer_relevance(response)
        ]
        results[MetricType.Bias] = [self.compute_bias(response)]
        results[MetricType.ContextPrecision] = [
            self.compute_context_precision(response)
        ]
        results[MetricType.ContextRecall] = [
            self.compute_context_recall(response)
        ]
        results[MetricType.ContextRelevance] = [
            self.compute_context_relevance(response)
        ]
        results[MetricType.Faithfulness] = [
            self.compute_faithfulness(response)
        ]
        results[MetricType.Hallucination] = [
            self.compute_hallucination(response)
        ]
        results[MetricType.SummaryCoherence] = [
            self.compute_summary_coherence(response)
        ]
        results[MetricType.Toxicity] = [self.compute_toxicity(response)]
        results[MetricType.ROUGE] = self.compute_rouge(
            response=response,
            rouge_types=rouge_types,
            use_stemmer=rouge_use_stemmer,
        )
        results[MetricType.BLEU] = [
            self.compute_sentence_bleu(response=response, weights=bleu_weights)
        ]
        return results
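Below the diff, for orientation only: a minimal sketch of how the Evaluator added in valor_lite/text_generation/manager.py might be used. The QueryResponse construction is an assumption inferred from the attribute access in the code above (query, response, and a context object with prediction and groundtruth lists); the actual dataclass lives in valor_lite/text_generation/annotation.py, which is not shown in this diff.

# Illustrative sketch; not part of the valor-lite 0.37.1 diff above.
# Assumes OPENAI_API_KEY is set in the environment (or pass api_key=...), and
# that QueryResponse accepts the keyword arguments shown here, which is
# inferred, not confirmed, from manager.py's attribute access.
from valor_lite.text_generation.annotation import QueryResponse
from valor_lite.text_generation.manager import Evaluator

evaluator = Evaluator.openai(model_name="gpt-3.5-turbo", retries=3)

response = QueryResponse(
    query="Summarize the retrieved passages in one sentence.",
    response="The passages describe the retry behavior of the evaluator.",
    context=None,  # assumed optional; RAG metrics raise ValueError without it
)

# LLM-guided metrics that only need the query and the generated response.
print(evaluator.compute_answer_relevance(response))
print(evaluator.compute_toxicity(response))

The RAG-oriented metrics (answer correctness, context precision/recall/relevance, faithfulness, hallucination, ROUGE, BLEU) and compute_all additionally require response.context, as the guard clauses in manager.py show.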