valor-lite 0.37.1 (valor_lite-0.37.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of valor-lite might be problematic.

Files changed (49)
  1. valor_lite/LICENSE +21 -0
  2. valor_lite/__init__.py +0 -0
  3. valor_lite/cache/__init__.py +11 -0
  4. valor_lite/cache/compute.py +154 -0
  5. valor_lite/cache/ephemeral.py +302 -0
  6. valor_lite/cache/persistent.py +529 -0
  7. valor_lite/classification/__init__.py +14 -0
  8. valor_lite/classification/annotation.py +45 -0
  9. valor_lite/classification/computation.py +378 -0
  10. valor_lite/classification/evaluator.py +879 -0
  11. valor_lite/classification/loader.py +97 -0
  12. valor_lite/classification/metric.py +535 -0
  13. valor_lite/classification/numpy_compatibility.py +13 -0
  14. valor_lite/classification/shared.py +184 -0
  15. valor_lite/classification/utilities.py +314 -0
  16. valor_lite/exceptions.py +20 -0
  17. valor_lite/object_detection/__init__.py +17 -0
  18. valor_lite/object_detection/annotation.py +238 -0
  19. valor_lite/object_detection/computation.py +841 -0
  20. valor_lite/object_detection/evaluator.py +805 -0
  21. valor_lite/object_detection/loader.py +292 -0
  22. valor_lite/object_detection/metric.py +850 -0
  23. valor_lite/object_detection/shared.py +185 -0
  24. valor_lite/object_detection/utilities.py +396 -0
  25. valor_lite/schemas.py +11 -0
  26. valor_lite/semantic_segmentation/__init__.py +15 -0
  27. valor_lite/semantic_segmentation/annotation.py +123 -0
  28. valor_lite/semantic_segmentation/computation.py +165 -0
  29. valor_lite/semantic_segmentation/evaluator.py +414 -0
  30. valor_lite/semantic_segmentation/loader.py +205 -0
  31. valor_lite/semantic_segmentation/metric.py +275 -0
  32. valor_lite/semantic_segmentation/shared.py +149 -0
  33. valor_lite/semantic_segmentation/utilities.py +88 -0
  34. valor_lite/text_generation/__init__.py +15 -0
  35. valor_lite/text_generation/annotation.py +56 -0
  36. valor_lite/text_generation/computation.py +611 -0
  37. valor_lite/text_generation/llm/__init__.py +0 -0
  38. valor_lite/text_generation/llm/exceptions.py +14 -0
  39. valor_lite/text_generation/llm/generation.py +903 -0
  40. valor_lite/text_generation/llm/instructions.py +814 -0
  41. valor_lite/text_generation/llm/integrations.py +226 -0
  42. valor_lite/text_generation/llm/utilities.py +43 -0
  43. valor_lite/text_generation/llm/validators.py +68 -0
  44. valor_lite/text_generation/manager.py +697 -0
  45. valor_lite/text_generation/metric.py +381 -0
  46. valor_lite-0.37.1.dist-info/METADATA +174 -0
  47. valor_lite-0.37.1.dist-info/RECORD +49 -0
  48. valor_lite-0.37.1.dist-info/WHEEL +5 -0
  49. valor_lite-0.37.1.dist-info/top_level.txt +1 -0
valor_lite/text_generation/__init__.py
@@ -0,0 +1,15 @@
+ from .annotation import Context, QueryResponse
+ from .llm.integrations import ClientWrapper, MistralWrapper, OpenAIWrapper
+ from .manager import Evaluator
+ from .metric import Metric, MetricType
+
+ __all__ = [
+     "QueryResponse",
+     "Context",
+     "Evaluator",
+     "Metric",
+     "MetricType",
+     "ClientWrapper",
+     "OpenAIWrapper",
+     "MistralWrapper",
+ ]
valor_lite/text_generation/annotation.py
@@ -0,0 +1,56 @@
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class Context:
+     """
+     Contextual ground truth and prediction.
+
+     Attributes
+     ----------
+     groundtruth : list[str]
+         The definitive context.
+     prediction : list[str]
+         Any retrieved context from a retrieval-augmented-generation (RAG) pipeline.
+
+     Examples
+     --------
+     >>> context = Context(
+     ...     groundtruth=[...],
+     ...     prediction=[...],
+     ... )
+     """
+
+     groundtruth: list[str] = field(default_factory=list)
+     prediction: list[str] = field(default_factory=list)
+
+
+ @dataclass
+ class QueryResponse:
+     """
+     Text generation data structure containing ground truths and predictions.
+
+     Attributes
+     ----------
+     query : str
+         The user query.
+     response : str
+         The language model's response.
+     context : Context
+         Any context provided to the model.
+
+     Examples
+     --------
+     >>> query = QueryResponse(
+     ...     query='When was George Washington born?',
+     ...     response="February 22, 1732",
+     ...     context=Context(
+     ...         groundtruth=["02/22/1732"],
+     ...         prediction=["02/22/1732"],
+     ...     ),
+     ... )
+     """
+
+     query: str
+     response: str
+     context: Context | None = field(default=None)
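
For orientation, here is a minimal usage sketch combining the two dataclasses above into a single RAG sample. The import path follows the text_generation __init__ exports shown in the first hunk, and the values are purely illustrative.

from valor_lite.text_generation import Context, QueryResponse

# One evaluation sample: the user query, the model response, and its context.
sample = QueryResponse(
    query="When was George Washington born?",
    response="February 22, 1732",
    context=Context(
        groundtruth=["02/22/1732"],  # the definitive context
        prediction=["02/22/1732"],   # context retrieved by the RAG pipeline
    ),
)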
valor_lite/text_generation/computation.py
@@ -0,0 +1,611 @@
+ from valor_lite.text_generation.llm.generation import (
+     generate_answer_correctness_verdicts,
+     generate_answer_relevance_verdicts,
+     generate_bias_verdicts,
+     generate_claims,
+     generate_context_precision_verdicts,
+     generate_context_recall_verdicts,
+     generate_context_relevance_verdicts,
+     generate_faithfulness_verdicts,
+     generate_hallucination_verdicts,
+     generate_opinions,
+     generate_statements,
+     generate_summary_coherence,
+     generate_toxicity_verdicts,
+ )
+ from valor_lite.text_generation.llm.integrations import ClientWrapper
+
+
+ def calculate_answer_correctness(
+     client: ClientWrapper,
+     system_prompt: str,
+     query: str,
+     response: str,
+     groundtruths: list[str],
+ ) -> float:
+     """
+     Compute answer correctness. Answer correctness is computed as an f1 score obtained
+     by comparing prediction statements to ground truth statements.
+
+     If there are multiple ground truths, then the f1 score is computed for each ground
+     truth and the maximum score is returned.
+
+     This metric was adapted from RAGAS. We follow a similar prompting strategy and
+     computation; however, we do not take a weighted sum with an answer similarity
+     score using embeddings.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     query : str
+         The user query.
+     response : str
+         A generated response.
+     groundtruths : list[str]
+         A list of ground truth references.
+
+     Returns
+     -------
+     float
+         The answer correctness score between 0 and 1. Higher values indicate that the
+         answer is more correct. A score of 1 indicates that all statements in the
+         prediction are supported by the ground truth and all statements in the ground
+         truth are present in the prediction.
+     """
+     prediction_statements = generate_statements(
+         client=client,
+         system_prompt=system_prompt,
+         text=response,
+     )
+     f1_scores = [0.0]
+     for groundtruth in groundtruths:
+         groundtruth_statements = generate_statements(
+             client=client,
+             system_prompt=system_prompt,
+             text=groundtruth,
+         )
+         verdicts = generate_answer_correctness_verdicts(
+             client=client,
+             system_prompt=system_prompt,
+             query=query,
+             groundtruth_statements=groundtruth_statements,
+             prediction_statements=prediction_statements,
+         )
+
+         tp = len(verdicts["TP"])
+         fp = len(verdicts["FP"])
+         fn = len(verdicts["FN"])
+
+         f1_scores.append(tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0)
+
+     return max(f1_scores)
+
+
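The scoring step above reduces to an F1 over statement-level verdicts. Below is a minimal sketch of just that arithmetic, independent of any LLM client; the verdict counts are made up for illustration.

def f1_from_verdicts(tp: int, fp: int, fn: int) -> float:
    # Same formula as in calculate_answer_correctness: 0.0 when there are no true positives.
    return tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0.0

# Example: 3 supported prediction statements (TP), 1 unsupported statement (FP),
# and 2 ground truth statements missing from the prediction (FN).
print(f1_from_verdicts(tp=3, fp=1, fn=2))  # 3 / (3 + 0.5 * 3) = 0.666...
# With multiple ground truths, calculate_answer_correctness keeps the maximum such score.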
+ def calculate_answer_relevance(
+     client: ClientWrapper,
+     system_prompt: str,
+     query: str,
+     response: str,
+ ) -> float:
+     """
+     Compute answer relevance for a single piece of text, i.e. the proportion of
+     statements in the model response that are relevant to the query.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     query : str
+         The user query.
+     response : str
+         A generated response.
+
+     Returns
+     -------
+     float
+         The answer relevance score between 0 and 1. A score of 1 indicates that all
+         statements are relevant to the query.
+     """
+     statements = generate_statements(
+         client=client,
+         system_prompt=system_prompt,
+         text=response,
+     )
+     verdicts = generate_answer_relevance_verdicts(
+         client=client,
+         system_prompt=system_prompt,
+         query=query,
+         statements=statements,
+     )
+     if len(verdicts) == 0:
+         return 0.0
+
+     return sum(verdict["verdict"] == "yes" for verdict in verdicts) / len(
+         verdicts
+     )
+
+
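Several of the metrics below (bias, context relevance, faithfulness, hallucination, toxicity) end with the same reduction used here: the fraction of "yes" verdicts returned by the LLM judge. A small standalone sketch of that shared step, using a hypothetical verdict list:

def proportion_of_yes(verdicts: list[dict[str, str]]) -> float:
    # Mirrors the guard above: an empty verdict list yields 0.0 rather than dividing by zero.
    if not verdicts:
        return 0.0
    return sum(v["verdict"] == "yes" for v in verdicts) / len(verdicts)

# Hypothetical judge output: two of three statements judged relevant.
print(proportion_of_yes([{"verdict": "yes"}, {"verdict": "no"}, {"verdict": "yes"}]))  # 0.666...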
+ def calculate_bias(
+     client: ClientWrapper,
+     system_prompt: str,
+     response: str,
+ ) -> float:
+     """
+     Compute bias, the proportion of model opinions that are biased.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     response : str
+         A generated response.
+
+     Returns
+     -------
+     float
+         The bias score between 0 and 1. A score of 1 indicates that all opinions in
+         the text are biased.
+     """
+
+     opinions = generate_opinions(
+         client=client,
+         system_prompt=system_prompt,
+         text=response,
+     )
+     if len(opinions) == 0:
+         return 0.0
+
+     verdicts = generate_bias_verdicts(
+         client=client,
+         system_prompt=system_prompt,
+         opinions=opinions,
+     )
+     return sum(verdict["verdict"] == "yes" for verdict in verdicts) / len(
+         verdicts
+     )
+
+
+ def calculate_context_precision(
+     client: ClientWrapper,
+     system_prompt: str,
+     query: str,
+     predicted_context: list[str],
+     groundtruth_context: list[str],
+ ) -> float:
+     """
+     Compute context precision, a score for evaluating the retrieval
+     mechanism of a RAG model.
+
+     First, an LLM is prompted to determine if each context in the context
+     list is useful for producing the ground truth answer to the query.
+
+     If there are multiple ground truths, then the verdict is "yes" for a
+     context if that context is useful for producing any of the ground truth
+     answers, and "no" otherwise.
+
+     Then, using these verdicts, the context precision score is computed as
+     the average of the precision at k over each k (from 1 to the length of
+     the context list) whose kth context was judged relevant.
+
+     Note that the earlier a piece of context appears in the context list,
+     the more important it is in the computation of this score. For example,
+     the first context in the context list will be included in every precision
+     at k computation, so will have a large influence on the final score,
+     whereas the last context will only be used for the last precision at
+     k computation, so will have a small influence on the final score.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     query : str
+         The user query.
+     predicted_context : list[str]
+         A list of predicted context.
+     groundtruth_context : list[str]
+         A list of ground truth context.
+
+     Returns
+     -------
+     float
+         The context precision score between 0 and 1. A higher score indicates
+         better context precision.
+     """
+     if len(predicted_context) == 0 and len(groundtruth_context) == 0:
+         return 1.0
+     elif len(predicted_context) == 0 or len(groundtruth_context) == 0:
+         return 0.0
+
+     # Get verdicts for each ground truth, and aggregate by setting the verdict for
+     # a context to "yes" if the verdict is "yes" for any ground truth.
+     aggregate_verdicts = ["no"] * len(predicted_context)
+     for groundtruth in groundtruth_context:
+         verdicts = generate_context_precision_verdicts(
+             client=client,
+             system_prompt=system_prompt,
+             query=query,
+             ordered_context_list=predicted_context,
+             groundtruth=groundtruth,
+         )
+         for i in range(len(verdicts)):
+             if verdicts[i]["verdict"] == "yes":
+                 aggregate_verdicts[i] = "yes"
+
+     # Use the aggregate verdicts to compute the precision at k for each k.
+     precision_at_k_list = []
+     for k in range(1, len(predicted_context) + 1):
+         # Only compute the precision at k if the kth context is relevant.
+         if aggregate_verdicts[k - 1] == "yes":
+             precision_at_k = (
+                 sum(verdict == "yes" for verdict in aggregate_verdicts[:k]) / k
+             )
+             precision_at_k_list.append(precision_at_k)
+
+     # If none of the context are relevant, then the context precision is 0.
+     if len(precision_at_k_list) == 0:
+         return 0.0
+
+     # Average over all the precision at k for which the kth context is relevant.
+     return sum(precision_at_k_list) / len(precision_at_k_list)
+
+
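The precision-at-k aggregation above can be isolated from the verdict generation. Here is a minimal sketch of that step alone, with a hypothetical aggregated verdict list to make the arithmetic concrete:

def context_precision_from_verdicts(aggregate_verdicts: list[str]) -> float:
    # Average precision@k over every position k whose context was judged relevant.
    precisions = [
        sum(v == "yes" for v in aggregate_verdicts[:k]) / k
        for k in range(1, len(aggregate_verdicts) + 1)
        if aggregate_verdicts[k - 1] == "yes"
    ]
    return sum(precisions) / len(precisions) if precisions else 0.0

# Three retrieved contexts judged relevant, irrelevant, relevant:
# P@1 = 1/1 and P@3 = 2/3 are kept (k=2 is skipped), giving (1 + 2/3) / 2 = 0.833...
print(context_precision_from_verdicts(["yes", "no", "yes"]))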
+ def calculate_context_recall(
+     client: ClientWrapper,
+     system_prompt: str,
+     predicted_context: list[str],
+     groundtruth_context: list[str],
+ ) -> float:
+     """
+     Compute context recall, a score for evaluating the retrieval mechanism of a RAG model.
+
+     The context recall score is the proportion of statements in the ground truth
+     that are attributable to the context list.
+
+     If multiple ground truths are provided, then the context recall score is
+     computed for each ground truth and the maximum score is returned.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     predicted_context : list[str]
+         A list of predicted context.
+     groundtruth_context : list[str]
+         A list of ground truth context.
+
+     Returns
+     -------
+     float
+         The context recall score between 0 and 1. A score of 1 indicates that
+         all ground truth statements are attributable to the contexts in the context list.
+     """
+     if len(predicted_context) == 0 and len(groundtruth_context) == 0:
+         return 1.0
+     elif len(predicted_context) == 0 or len(groundtruth_context) == 0:
+         return 0.0
+
+     scores = []
+     for groundtruth in groundtruth_context:
+         groundtruth_statements = generate_statements(
+             client=client,
+             system_prompt=system_prompt,
+             text=groundtruth,
+         )
+         verdicts = generate_context_recall_verdicts(
+             client=client,
+             system_prompt=system_prompt,
+             context_list=predicted_context,
+             groundtruth_statements=groundtruth_statements,
+         )
+         scores.append(
+             sum(verdict["verdict"] == "yes" for verdict in verdicts)
+             / len(verdicts)
+         )
+
+     return max(scores)
+
+
+ def calculate_context_relevance(
+     client: ClientWrapper,
+     system_prompt: str,
+     query: str,
+     context: list[str],
+ ) -> float:
+     """
+     Compute context relevance, the proportion of contexts in the context list
+     that are relevant to the query.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     query : str
+         The user query.
+     context : list[str]
+         A list of predicted context.
+
+     Returns
+     -------
+     float
+         The context relevance score between 0 and 1. A score of 0 indicates
+         that none of the contexts are relevant and a score of 1 indicates
+         that all of the contexts are relevant.
+     """
+     if len(context) == 0:
+         return 0.0
+     verdicts = generate_context_relevance_verdicts(
+         client=client,
+         system_prompt=system_prompt,
+         query=query,
+         context_list=context,
+     )
+     return sum(verdict["verdict"] == "yes" for verdict in verdicts) / len(
+         verdicts
+     )
+
+
+ def calculate_faithfulness(
+     client: ClientWrapper,
+     system_prompt: str,
+     response: str,
+     context: list[str],
+ ) -> float:
+     """
+     Compute the faithfulness score. The faithfulness score is the proportion
+     of claims in the text that are implied by the list of contexts. Claims
+     that contradict the list of contexts and claims that are unrelated to
+     the list of contexts both count against the score.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     response : str
+         A generated response.
+     context : list[str]
+         A list of predicted context.
+
+     Returns
+     -------
+     float
+         The faithfulness score between 0 and 1. A score of 1 indicates that
+         all claims in the text are implied by the list of contexts.
+     """
+     if len(context) == 0:
+         return 0.0
+
+     claims = generate_claims(
+         client=client, system_prompt=system_prompt, text=response
+     )
+
+     # If there aren't any claims, then the text is perfectly faithful, as the text does not contain any non-faithful claims.
+     if len(claims) == 0:
+         return 1.0
+
+     faithfulness_verdicts = generate_faithfulness_verdicts(
+         client=client,
+         system_prompt=system_prompt,
+         claims=claims,
+         context_list=context,
+     )
+     return sum(
+         verdict["verdict"] == "yes" for verdict in faithfulness_verdicts
+     ) / len(faithfulness_verdicts)
+
+
+ def calculate_hallucination(
+     client: ClientWrapper,
+     system_prompt: str,
+     response: str,
+     context: list[str],
+ ) -> float:
+     """
+     Compute the hallucination score, the proportion of contexts in the context
+     list that are contradicted by the text.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     response : str
+         A generated response.
+     context : list[str]
+         A list of predicted context.
+
+     Returns
+     -------
+     float
+         The hallucination score between 0 and 1. A score of 1 indicates that
+         all contexts are contradicted by the text.
+     """
+     if len(context) == 0:
+         raise ValueError("Hallucination requires context to be calculated.")
+
+     verdicts = generate_hallucination_verdicts(
+         client=client,
+         system_prompt=system_prompt,
+         text=response,
+         context_list=context,
+     )
+     return sum(verdict["verdict"] == "yes" for verdict in verdicts) / len(
+         verdicts
+     )
+
+
+ def calculate_summary_coherence(
+     client: ClientWrapper,
+     system_prompt: str,
+     text: str,
+     summary: str,
+ ) -> int:
+     """
+     Compute summary coherence, the collective quality of a summary.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     text : str
+         The original text.
+     summary : str
+         The generated summary.
+
+     Returns
+     -------
+     int
+         The summary coherence score between 1 and 5. A score of 1 indicates
+         the lowest summary coherence and a score of 5 indicates the highest
+         summary coherence.
+     """
+     return generate_summary_coherence(
+         client=client,
+         system_prompt=system_prompt,
+         text=text,
+         summary=summary,
+     )
+
+
+ def calculate_toxicity(
+     client: ClientWrapper,
+     system_prompt: str,
+     response: str,
+ ) -> float:
+     """
+     Compute toxicity, the proportion of opinions that are toxic.
+
+     Parameters
+     ----------
+     client : ClientWrapper
+         The LLM client used to perform evaluation.
+     system_prompt : str
+         A system prompt to pass with the evaluation query.
+     response : str
+         A generated response.
+
+     Returns
+     -------
+     float
+         The toxicity score between 0 and 1. A score of 1 indicates that all
+         opinions in the text are toxic.
+     """
+     opinions = generate_opinions(
+         client=client,
+         system_prompt=system_prompt,
+         text=response,
+     )
+     if len(opinions) == 0:
+         return 0.0
+
+     verdicts = generate_toxicity_verdicts(
+         client=client,
+         system_prompt=system_prompt,
+         opinions=opinions,
+     )
+     return sum(verdict["verdict"] == "yes" for verdict in verdicts) / len(
+         verdicts
+     )
+
+
+ def calculate_rouge_scores(
+     prediction: str,
+     references: str | list[str],
+     rouge_types: list[str],
+     use_stemmer: bool = False,
+ ) -> dict[str, float]:
+     """
+     Calculate ROUGE scores for a prediction given some set of references.
+
+     Parameters
+     ----------
+     prediction : str
+         A generated response to score. Each prediction should be a string with tokens separated by spaces.
+     references : str | list[str]
+         A list of references for a given response. Each reference should be a string with tokens separated by spaces.
+     rouge_types : list[str]
+         A list of rouge types to calculate.
+     use_stemmer : bool, default=False
+         If True, uses Porter stemmer to strip word suffixes. Defaults to False.
+     """
+     import evaluate
+
+     rouge = evaluate.load("rouge")
+
+     metrics = rouge.compute(
+         predictions=[prediction],
+         references=[references],
+         rouge_types=rouge_types,
+         use_stemmer=use_stemmer,
+         use_aggregator=False,  # aggregation gives us an average across all predictions, which isn't what we want
+     )
+
+     # keep the score for the single prediction, clamped at zero
+     results = dict()
+     if metrics is not None:
+         for type_ in rouge_types:
+             results[type_] = max(metrics[type_][0], 0.0)
+     return results
+
+
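A hedged usage sketch for the ROUGE helper above. It assumes the optional evaluate dependency (and its "rouge" metric) is installed and that the module path matches this diff; the strings are illustrative.

from valor_lite.text_generation.computation import calculate_rouge_scores  # path assumed from this diff

scores = calculate_rouge_scores(
    prediction="Mary loves Joe",
    references=["Mary loves Joe", "Mary loves Joseph"],
    rouge_types=["rouge1", "rouge2", "rougeL"],
    use_stemmer=False,
)
# Returns one float per requested rouge type,
# e.g. {"rouge1": 1.0, "rouge2": 1.0, "rougeL": 1.0} when a reference matches exactly.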
+ def calculate_sentence_bleu(
+     prediction: str,
+     references: list[str],
+     weights: tuple[float, ...] | list[float],
+ ) -> float:
+     """
+     Calculate the sentence BLEU score for a prediction and ground truth pair.
+
+     Parameters
+     ----------
+     prediction : str
+         A generated response to score. Each prediction should be a string with tokens separated by spaces.
+     references : list[str]
+         A list of references for a given response. Each reference should be a string with tokens separated by spaces.
+     weights : tuple[float, ...] | list[float]
+         The default BLEU calculates a score for up to 4-grams using uniform
+         weights (this is called BLEU-4). To evaluate your translations with
+         higher/lower order ngrams, use customized weights. Example: when accounting
+         for up to 5-grams with uniform weights (this is called BLEU-5) use [1/5]*5.
+     """
+     from nltk.tokenize import RegexpTokenizer
+     from nltk.translate import bleu_score
+
+     if len(weights) == 0:
+         raise ValueError("At least one weight should be defined.")
+
+     tokenizer = RegexpTokenizer(
+         r"\w+|\$[\d]+|[^\s\.]+"
+     )  # regex tokenizer that ignores periods
+
+     tokenized_prediction = tokenizer.tokenize(prediction)
+     tokenized_references = [tokenizer.tokenize(ref) for ref in references]
+
+     # compute the sentence BLEU score for the tokenized prediction against all references
+     result = float(
+         bleu_score.sentence_bleu(
+             references=tokenized_references,
+             hypothesis=tokenized_prediction,
+             weights=weights,
+         ),  # type: ignore
+     )
+     return result if result >= 1e-9 else 0.0
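
A similar hedged sketch for the sentence BLEU helper, assuming nltk is installed and using the uniform BLEU-4 weights described in the docstring:

from valor_lite.text_generation.computation import calculate_sentence_bleu  # path assumed from this diff

score = calculate_sentence_bleu(
    prediction="Mary loves her dog Joe",
    references=["Mary loves her dog Joe"],
    weights=(0.25, 0.25, 0.25, 0.25),  # BLEU-4 with uniform weights
)
# An identical prediction/reference pair scores 1.0; scores below 1e-9 are clamped to 0.0 by the helper.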
valor_lite/text_generation/llm/__init__.py
File without changes
valor_lite/text_generation/llm/exceptions.py
@@ -0,0 +1,14 @@
+ class InvalidLLMResponseError(Exception):
+     """
+     Raised when the response from the LLM is invalid for a given metric computation.
+     """
+
+     pass
+
+
+ class MismatchingTextGenerationDatumError(Exception):
+     """
+     Raised when datums with the same uid but different text are added to the ValorTextGenerationStreamingManager.
+     """
+
+     pass