graphrag-eval 6.2.0__tar.gz → 6.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. graphrag_eval-6.4.0/PKG-INFO +46 -0
  2. graphrag_eval-6.4.0/README.md +27 -0
  3. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/aggregation.py +4 -4
  4. graphrag_eval-6.4.0/graphrag_eval/answer_correctness.py +176 -0
  5. graphrag_eval-6.4.0/graphrag_eval/answer_relevance.py +61 -0
  6. graphrag_eval-6.4.0/graphrag_eval/cli/answer_correctness.py +122 -0
  7. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/custom_evaluation.py +34 -18
  8. graphrag_eval-6.4.0/graphrag_eval/evaluation.py +154 -0
  9. graphrag_eval-6.4.0/graphrag_eval/evaluator.py +14 -0
  10. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/llm_factory.py +20 -10
  11. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/prompts/template.md +1 -1
  12. graphrag_eval-6.4.0/graphrag_eval/steps/__init__.py +0 -0
  13. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/evaluation.py +11 -3
  14. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/pyproject.toml +2 -2
  15. graphrag_eval-6.2.0/PKG-INFO +0 -1310
  16. graphrag_eval-6.2.0/README.md +0 -1291
  17. graphrag_eval-6.2.0/graphrag_eval/answer_correctness.py +0 -192
  18. graphrag_eval-6.2.0/graphrag_eval/answer_relevance.py +0 -29
  19. graphrag_eval-6.2.0/graphrag_eval/evaluation.py +0 -101
  20. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/LICENSE +0 -0
  21. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/__init__.py +0 -0
  22. {graphrag_eval-6.2.0/graphrag_eval/steps → graphrag_eval-6.4.0/graphrag_eval/cli}/__init__.py +0 -0
  23. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/iri_discovery.py +0 -0
  24. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/retrieval_answer.py +0 -0
  25. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/retrieval_context_ids.py +0 -0
  26. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/retrieval_context_texts.py +0 -0
  27. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/sparql.py +0 -0
  28. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/steps/timeseries.py +0 -0
  29. {graphrag_eval-6.2.0 → graphrag_eval-6.4.0}/graphrag_eval/util.py +0 -0
@@ -0,0 +1,46 @@
1
+ Metadata-Version: 2.3
2
+ Name: graphrag-eval
3
+ Version: 6.4.0
4
+ Summary: For assessing question answering systems' final answers and intermediate steps, against a given set of questions, reference answers and steps.
5
+ License: Apache-2.0
6
+ Author: Philip Ganchev
7
+ Author-email: philip.ganchev@graphwise.ai
8
+ Requires-Python: >=3.12,<3.13
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.12
12
+ Provides-Extra: llm
13
+ Requires-Dist: pydantic (==2.12.5)
14
+ Requires-Dist: python-dateutil (==2.9.0.post0)
15
+ Requires-Dist: ragas (==0.4.3) ; extra == "llm"
16
+ Project-URL: Repository, https://github.com/Ontotext-AD/graphrag-eval
17
+ Description-Content-Type: text/markdown
18
+
19
+ <p align="center">
20
+ <img alt="Graphwise Logo" src="https://github.com/Ontotext-AD/graphrag-eval/blob/main/.github/Graphwise_Logo.jpg">
21
+ </p>
22
+
23
+ # QA Evaluation
24
+
25
+ This is a Python library for assessing the quality of question-answering systems, such as systems built with LLM-based agents. It is agnostic to the agent implementation and the LLM it uses.
26
+
27
+ The evaluation is based on a user-provided reference dataset containing queries, reference responses, and optional reference steps, such as expected tool uses. The evaluator compares these references with the agent's actual responses and executed steps. Reference steps can be grouped to allow some expected steps to occur in any order.
28
+
29
+ The library provides built-in evaluation metrics and supports user-defined custom metrics ([§ Metrics](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md)).
30
+
31
+ ## Documentation
32
+
33
+ - [Quickstart](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/quickstart.md)
34
+ - [Metrics](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md)
35
+ - [Configuration](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/config.md)
36
+ - [Input](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/input.md)
37
+ - [Output](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/output.md)
38
+
39
+ ## Maintainers
40
+
41
+ Developed and maintained by [Graphwise](https://graphwise.ai/). For issues and feature requests, please open a [GitHub issue](https://github.com/Ontotext-AD/graphrag-eval/issues).
42
+
43
+ ## License
44
+
45
+ Apache-2.0 License. See the [LICENSE](https://github.com/Ontotext-AD/graphrag-eval/blob/main/LICENSE) file for details.
46
+
@@ -0,0 +1,27 @@
1
+ <p align="center">
2
+ <img alt="Graphwise Logo" src="https://github.com/Ontotext-AD/graphrag-eval/blob/main/.github/Graphwise_Logo.jpg">
3
+ </p>
4
+
5
+ # QA Evaluation
6
+
7
+ This is a Python library for assessing the quality of question-answering systems, such as systems built with LLM-based agents. It is agnostic to the agent implementation and the LLM it uses.
8
+
9
+ The evaluation is based on a user-provided reference dataset containing queries, reference responses, and optional reference steps, such as expected tool uses. The evaluator compares these references with the agent's actual responses and executed steps. Reference steps can be grouped to allow some expected steps to occur in any order.
10
+
11
+ The library provides built-in evaluation metrics and supports user-defined custom metrics ([§ Metrics](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md)).
12
+
13
+ ## Documentation
14
+
15
+ - [Quickstart](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/quickstart.md)
16
+ - [Metrics](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/metrics.md)
17
+ - [Configuration](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/config.md)
18
+ - [Input](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/input.md)
19
+ - [Output](https://github.com/Ontotext-AD/graphrag-eval/blob/main/docs/output.md)
20
+
21
+ ## Maintainers
22
+
23
+ Developed and maintained by [Graphwise](https://graphwise.ai/). For issues and feature requests, please open a [GitHub issue](https://github.com/Ontotext-AD/graphrag-eval/issues).
24
+
25
+ ## License
26
+
27
+ Apache-2.0 License. See the [LICENSE](https://github.com/Ontotext-AD/graphrag-eval/blob/main/LICENSE) file for details.
@@ -1,13 +1,13 @@
1
1
  import json
2
- import yaml
3
2
  from collections import defaultdict
4
3
  from collections.abc import Sequence
5
4
  from pathlib import Path
6
5
  from statistics import mean, median
7
6
  from typing import Any, Collection, Iterable
8
7
 
9
- from . import evaluation
8
+ import yaml
10
9
 
10
+ from . import evaluation
11
11
 
12
12
  METRICS = [
13
13
  "answer_recall",
@@ -155,7 +155,7 @@ def compute_micro_stats(
155
155
  ) -> dict:
156
156
  if custom_metrics is None:
157
157
  custom_metrics = []
158
-
158
+
159
159
  values = number_of_samples_per_template_by_status.values()
160
160
  micro_summary = defaultdict(dict, {
161
161
  "number_of_error_samples": sum(v["error"] for v in values),
@@ -197,7 +197,7 @@ def compute_macro_stats(
197
197
  ) -> dict:
198
198
  if custom_metrics is None:
199
199
  custom_metrics = []
200
-
200
+
201
201
  macro_summary = defaultdict(dict)
202
202
  for metric in METRICS + custom_metrics:
203
203
  means = [
@@ -0,0 +1,176 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Self, TYPE_CHECKING
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+ from graphrag_eval.util import compute_f1
9
+ from .evaluator import Evaluator
10
+
11
+ if TYPE_CHECKING:
12
+ from ragas.llms.base import InstructorBaseRagasLLM
13
+
14
+
15
+ def load_default_prompt() -> str:
16
+ with open(
17
+ Path(__file__).parent / "prompts" / "template.md",
18
+ encoding="utf-8"
19
+ ) as f:
20
+ return f.read()
21
+
22
+
23
+ class AnswerCorrectnessConfig(BaseModel):
24
+ enabled: bool = Field(default=True)
25
+ prompt: str = Field(default_factory=load_default_prompt)
26
+
27
+
28
+ class InvalidPromptException(Exception):
29
+ def __init__(
30
+ self,
31
+ message="The prompt template is invalid and cannot be "
32
+ "formatted."
33
+ ):
34
+ self.message = message
35
+ super().__init__(self.message)
36
+
37
+
38
+ class AnswerCorrectnessEvaluator:
39
+ def __init__(
40
+ self,
41
+ ragas_llm: InstructorBaseRagasLLM,
42
+ config: AnswerCorrectnessConfig | None = None,
43
+ ):
44
+ self.config = config or AnswerCorrectnessConfig()
45
+ self.__validate_prompt_template(self.config.prompt)
46
+ self.prompt_template = self.config.prompt
47
+ self.ragas_llm = ragas_llm
48
+
49
+ @classmethod
50
+ def from_config(
51
+ cls,
52
+ ragas_llm: InstructorBaseRagasLLM | None,
53
+ config: AnswerCorrectnessConfig | None
54
+ ) -> Self | None:
55
+ if ragas_llm is None:
56
+ return None
57
+ if config is None or not config.enabled:
58
+ return None
59
+ return cls(ragas_llm=ragas_llm, config=config)
60
+
61
+ @staticmethod
62
+ def __validate_prompt_template(prompt_template: str):
63
+ try:
64
+ prompt_template.format(
65
+ question="Q?",
66
+ reference_answer="R",
67
+ actual_answer="A",
68
+ )
69
+ except Exception as exc:
70
+ raise InvalidPromptException(
71
+ "Invalid prompt template. Must only contain placeholders: "
72
+ "{question}, {reference_answer}, and {actual_answer}. "
73
+ f"Original error: {exc}"
74
+ ) from exc
75
+
76
+ async def _agenerate(self, prompt):
77
+ """Wrapper method for easier testing"""
78
+ return (await self.ragas_llm.agenerate(prompt, None)).choices[0].message.content
79
+
80
+ async def evaluate_answer(
81
+ self,
82
+ question: str,
83
+ reference_answer: str,
84
+ actual_answer: str
85
+ ) -> tuple[int, int, int, str]:
86
+ if any(
87
+ not s.strip() for s in [question, reference_answer, actual_answer]
88
+ ):
89
+ raise ValueError(
90
+ "The question of the reference or the actual answer is a blank "
91
+ "string!"
92
+ )
93
+ prompt = self.prompt_template.format(
94
+ question=question,
95
+ reference_answer=reference_answer,
96
+ actual_answer=actual_answer,
97
+ )
98
+ response_str = await self._agenerate(prompt)
99
+ return self.extract_response_values(response_str)
100
+
101
+ async def evaluate(
102
+ self,
103
+ reference: dict[str, Any],
104
+ actual: dict[str, Any]
105
+ ) -> dict[str, Any]:
106
+ if "actual_answer" not in actual or "reference_answer" not in reference:
107
+ return {}
108
+ result = {}
109
+ try:
110
+ num_ref_claims, num_actual_claims, num_matching_claims, reason = \
111
+ await self.evaluate_answer(
112
+ reference["question_text"],
113
+ reference["reference_answer"],
114
+ actual["actual_answer"],
115
+ )
116
+ result.update({
117
+ "answer_reference_claims_count": num_ref_claims,
118
+ "answer_actual_claims_count": num_actual_claims,
119
+ "answer_matching_claims_count": num_matching_claims,
120
+ "answer_correctness_reason": reason,
121
+ })
122
+ recall, precision, f1 = self.compute_recall_precision_f1(
123
+ num_ref_claims, num_actual_claims, num_matching_claims
124
+ )
125
+ if recall is not None:
126
+ result["answer_recall"] = recall
127
+ if precision is not None:
128
+ result["answer_precision"] = precision
129
+ if f1 is not None:
130
+ result["answer_f1"] = f1
131
+ except Exception as exc:
132
+ result["answer_correctness_error"] = str(exc)
133
+ return result
134
+
135
+ @staticmethod
136
+ def compute_recall_precision_f1(
137
+ n_pos: int,
138
+ n_pred_pos: int,
139
+ n_true_pos: int,
140
+ ) -> tuple[float | None, float | None, float | None]:
141
+ recall = None
142
+ precision = None
143
+ if n_pos:
144
+ recall = n_true_pos / n_pos
145
+ if n_pred_pos:
146
+ precision = n_true_pos / n_pred_pos
147
+ return recall, precision, compute_f1(recall, precision)
148
+
149
+ @staticmethod
150
+ def extract_response_values(
151
+ response: str
152
+ ) -> tuple[int, int, int, str]:
153
+ vals = response.split("\t")
154
+ n = len(vals)
155
+ if n < 4:
156
+ raise ValueError(f"Expected 4 tab-separated values: {response}")
157
+ vals = vals[:4]
158
+ try:
159
+ n_ref, n_actual, n_matching = map(int, vals[:3])
160
+ except ValueError:
161
+ raise ValueError(f"Claims counts should be ints: {vals}")
162
+ if any([
163
+ n_ref < 1,
164
+ n_actual < 1,
165
+ n_matching < 0,
166
+ n_matching > n_ref,
167
+ n_matching > n_actual
168
+ ]):
169
+ raise ValueError(
170
+ "Invalid claims counts combination: "
171
+ f"{n_ref}\t{n_actual}\t{n_matching}"
172
+ )
173
+ return n_ref, n_actual, n_matching, vals[3]
174
+
175
+
176
+ _: Evaluator = AnswerCorrectnessEvaluator
@@ -0,0 +1,61 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Self, TYPE_CHECKING
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+ from .evaluator import Evaluator
8
+
9
+ if TYPE_CHECKING:
10
+ from ragas.llms.base import InstructorBaseRagasLLM
11
+ from ragas.embeddings.base import BaseRagasEmbeddings, BaseRagasEmbedding
12
+
13
+
14
+ class AnswerRelevanceConfig(BaseModel):
15
+ enabled: bool = Field(default=True)
16
+
17
+
18
+ class AnswerRelevanceEvaluator:
19
+ def __init__(
20
+ self,
21
+ ragas_llm: InstructorBaseRagasLLM,
22
+ ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding
23
+ ):
24
+ from ragas.metrics.collections import AnswerRelevancy
25
+ self.scorer = AnswerRelevancy(llm=ragas_llm, embeddings=ragas_embedder)
26
+
27
+ @classmethod
28
+ def from_config(
29
+ cls,
30
+ ragas_llm: InstructorBaseRagasLLM | None,
31
+ ragas_embedder: BaseRagasEmbeddings | BaseRagasEmbedding | None,
32
+ config: AnswerRelevanceConfig | None
33
+ ) -> Self | None:
34
+ if ragas_llm is None or ragas_embedder is None:
35
+ return None
36
+ if config is None or not config.enabled:
37
+ return None
38
+ return cls(ragas_llm=ragas_llm, ragas_embedder=ragas_embedder)
39
+
40
+ async def evaluate(
41
+ self,
42
+ reference: dict[str, Any],
43
+ actual: dict[str, Any]
44
+ ) -> dict[str, Any]:
45
+ if "actual_answer" not in actual:
46
+ return {}
47
+ try:
48
+ result = await self.scorer.ascore(
49
+ user_input=reference["question_text"],
50
+ response=actual["actual_answer"]
51
+ )
52
+ return {
53
+ "answer_relevance": result.value
54
+ }
55
+ except Exception as e:
56
+ return {
57
+ "answer_relevance_error": str(e)
58
+ }
59
+
60
+
61
+ _: Evaluator = AnswerRelevanceEvaluator
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import asyncio
5
+ import csv
6
+ from argparse import ArgumentParser
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ from tqdm import tqdm
11
+
12
+ from graphrag_eval import llm_factory
13
+ from graphrag_eval.answer_correctness import AnswerCorrectnessEvaluator
14
+ from graphrag_eval.evaluation import Config
15
+
16
+ if TYPE_CHECKING:
17
+ from ragas.llms.base import InstructorBaseRagasLLM
18
+
19
+
20
+ def parse_args() -> argparse.Namespace:
21
+ parser = ArgumentParser(
22
+ description="Calculates answer correctness over the entries from the "
23
+ "input tsv file and stores the output in the output tsv "
24
+ "file.",
25
+ )
26
+ parser.add_argument(
27
+ "-i",
28
+ "--input-tsv-file-path",
29
+ type=Path,
30
+ required=True,
31
+ help="Input tsv file path with columns `Question`, `Reference answer` "
32
+ "and `Actual answer`",
33
+ )
34
+ parser.add_argument(
35
+ "-o",
36
+ "--output-tsv-file-path",
37
+ type=Path,
38
+ required=True,
39
+ help="Output tsv file path with columns `#Reference`, `#PTarget`, "
40
+ "`#Matching`, `Reasoning`, `Error`",
41
+ )
42
+ parser.add_argument(
43
+ "-c",
44
+ "--config-yaml-file-path",
45
+ type=Path,
46
+ required=True,
47
+ help="Config yaml file path with definition of the LLM to use and "
48
+ "optionally a custom prompt.",
49
+ )
50
+ return parser.parse_args()
51
+
52
+
53
+ async def evaluate_and_write(
54
+ input_tsv_file_path: Path,
55
+ output_tsv_file_path: Path,
56
+ evaluator: AnswerCorrectnessEvaluator,
57
+ ) -> None:
58
+ with open(input_tsv_file_path, encoding="utf-8") as f:
59
+ reader = csv.DictReader(f, delimiter="\t")
60
+ rows = [row for row in reader]
61
+ print(f"Writing results to {output_tsv_file_path}")
62
+ output_tsv_file_path.parent.mkdir(parents=True, exist_ok=True)
63
+ with open(output_tsv_file_path, "w", encoding="utf-8") as f:
64
+ writer = csv.writer(f, delimiter="\t")
65
+ writer.writerow(
66
+ ["#Reference", "#PTarget", "#Matching", "Reasoning", "Error"]
67
+ )
68
+
69
+ for row in tqdm(rows):
70
+ if "Question" not in row or \
71
+ "Reference answer" not in row or \
72
+ "Actual answer" not in row:
73
+ raise ValueError("Unexpected input format!")
74
+
75
+ try:
76
+ vals = await evaluator.evaluate_answer(
77
+ row["Question"],
78
+ row["Reference answer"],
79
+ row["Actual answer"]
80
+ )
81
+ vals = vals + ("",)
82
+ writer.writerow(vals)
83
+ except Exception as exc:
84
+ writer.writerow(["", "", "", "", str(exc)])
85
+ f.flush()
86
+
87
+
88
+ def run(
89
+ config_yaml_file_path: Path,
90
+ input_tsv_file_path: Path,
91
+ output_tsv_file_path: Path,
92
+ ):
93
+ config = Config.parse(config_yaml_file_path)
94
+ ragas_llm: InstructorBaseRagasLLM | None = llm_factory.create_llm(
95
+ config.llm
96
+ )
97
+ if ragas_llm is None:
98
+ raise ValueError(
99
+ "LLM must be configured to calculate the answer correctness!"
100
+ )
101
+ if config.answer_correctness and not config.answer_correctness.enabled:
102
+ raise ValueError(
103
+ "Can't disable answer correctness, when running this script!"
104
+ )
105
+ evaluator = AnswerCorrectnessEvaluator(
106
+ ragas_llm=ragas_llm,
107
+ config=config.answer_correctness,
108
+ )
109
+ asyncio.run(evaluate_and_write(
110
+ input_tsv_file_path,
111
+ output_tsv_file_path,
112
+ evaluator,
113
+ ))
114
+
115
+
116
+ def main():
117
+ args = parse_args()
118
+ run(
119
+ args.config_yaml_file_path,
120
+ args.input_tsv_file_path,
121
+ args.output_tsv_file_path,
122
+ )
@@ -1,9 +1,14 @@
1
+ from __future__ import annotations
2
+
1
3
  import json
2
- from typing import Literal
4
+ from typing import Literal, Self, TYPE_CHECKING, Any
3
5
 
4
6
  from pydantic import BaseModel, ConfigDict, Field, model_validator
5
7
 
6
- from graphrag_eval.llm_factory import create_llm
8
+ from .evaluator import Evaluator
9
+
10
+ if TYPE_CHECKING:
11
+ from ragas.llms.base import InstructorBaseRagasLLM
7
12
 
8
13
  RESERVED_KEYS = {
9
14
  "template_id",
@@ -43,7 +48,7 @@ Inputs = Literal[
43
48
  StepsKey = Literal["args", "output"]
44
49
 
45
50
 
46
- class Config(BaseModel):
51
+ class EvaluatorConfig(BaseModel):
47
52
  model_config = ConfigDict(extra='forbid')
48
53
  name: str
49
54
  inputs: list[Inputs] = Field(..., min_length=1)
@@ -53,7 +58,7 @@ class Config(BaseModel):
53
58
  steps_keys: set[StepsKey] | None = Field(default=None, min_length=1)
54
59
 
55
60
  @model_validator(mode='after')
56
- def validate_step_dependencies(self) -> 'Config':
61
+ def validate_step_dependencies(self) -> Self:
57
62
  if set(self.inputs) & {"reference_steps", "actual_steps"}:
58
63
  suffix = "is required when steps are in inputs"
59
64
  for var_name in ["steps_name", "steps_keys"]:
@@ -62,7 +67,7 @@ class Config(BaseModel):
62
67
  return self
63
68
 
64
69
  @model_validator(mode='after')
65
- def validate_name_and_outputs(self) -> 'Config':
70
+ def validate_name_and_outputs(self) -> Self:
66
71
  if self.name + "_error" in RESERVED_KEYS:
67
72
  raise ValueError(f"Name {self.name} is reserved")
68
73
  conflicting_keys = set(self.outputs.keys()) & RESERVED_KEYS
@@ -76,7 +81,7 @@ def create_input_template(input_key: str) -> str:
76
81
  return f"# {header}\n{{{input_key}}}"
77
82
 
78
83
 
79
- def create_prompt_template(config: Config, output_variables: list[str]) -> str:
84
+ def create_prompt_template(config: EvaluatorConfig, output_variables: list[str]) -> str:
80
85
  """
81
86
  Return a template for the LLM prompt, with placeholders for the inputs,
82
87
  instructions, outputs etc. We use this template at evaluation time to
@@ -99,8 +104,8 @@ def create_prompt_template(config: Config, output_variables: list[str]) -> str:
99
104
  class CustomEvaluator:
100
105
  def __init__(
101
106
  self,
102
- config: Config,
103
- eval_config: "evaluation.Config",
107
+ ragas_llm: InstructorBaseRagasLLM,
108
+ config: EvaluatorConfig,
104
109
  ):
105
110
  self.name = config.name
106
111
  self.input_variables = config.inputs
@@ -111,11 +116,24 @@ class CustomEvaluator:
111
116
  config,
112
117
  self.output_variables
113
118
  )
114
- self.llm = create_llm(eval_config)
119
+ self.ragas_llm = ragas_llm
120
+
121
+ @classmethod
122
+ def from_config(
123
+ cls,
124
+ ragas_llm: InstructorBaseRagasLLM | None,
125
+ evaluation_configs: list[EvaluatorConfig] | None
126
+ ) -> list[Self]:
127
+ if ragas_llm and evaluation_configs:
128
+ return [
129
+ cls(ragas_llm, evaluation_config)
130
+ for evaluation_config in evaluation_configs
131
+ ]
132
+ return []
115
133
 
116
134
  async def _agenerate(self, prompt: str) -> str:
117
135
  """Wrapper method for easier testing"""
118
- return (await self.llm.agenerate(prompt, None)).choices[0].message.content
136
+ return (await self.ragas_llm.agenerate(prompt, None)).choices[0].message.content
119
137
 
120
138
  def format_steps(self, steps: list) -> str:
121
139
  steps_formatted = []
@@ -157,7 +175,11 @@ class CustomEvaluator:
157
175
  return result
158
176
  return self.error(f"Expected {n_exp} tab-separated values, got: {response}")
159
177
 
160
- async def evaluate(self, reference: dict, actual: dict) -> dict[str, str | None]:
178
+ async def evaluate(
179
+ self,
180
+ reference: dict[str, Any],
181
+ actual: dict[str, Any]
182
+ ) -> dict[str, Any]:
161
183
  inputs = {}
162
184
  if "question" in self.input_variables:
163
185
  if "question_text" not in reference:
@@ -195,10 +217,4 @@ class CustomEvaluator:
195
217
  return self.parse_outputs(response)
196
218
 
197
219
 
198
- def create_evaluators(config: "evaluation.Config") -> list[CustomEvaluator]:
199
- if config.custom_evaluations and config.llm:
200
- return [
201
- CustomEvaluator(custom_evaluation_config, config)
202
- for custom_evaluation_config in config.custom_evaluations
203
- ]
204
- return []
220
+ _: Evaluator = CustomEvaluator