evalvault 1.70.1__py3-none-any.whl → 1.71.0__py3-none-any.whl

Files changed (47)
  1. evalvault/adapters/inbound/api/adapter.py +367 -3
  2. evalvault/adapters/inbound/api/main.py +17 -1
  3. evalvault/adapters/inbound/api/routers/calibration.py +133 -0
  4. evalvault/adapters/inbound/api/routers/runs.py +71 -1
  5. evalvault/adapters/inbound/cli/commands/__init__.py +2 -0
  6. evalvault/adapters/inbound/cli/commands/analyze.py +1 -0
  7. evalvault/adapters/inbound/cli/commands/compare.py +1 -1
  8. evalvault/adapters/inbound/cli/commands/experiment.py +27 -1
  9. evalvault/adapters/inbound/cli/commands/graph_rag.py +303 -0
  10. evalvault/adapters/inbound/cli/commands/history.py +1 -1
  11. evalvault/adapters/inbound/cli/commands/regress.py +169 -1
  12. evalvault/adapters/inbound/cli/commands/run.py +225 -1
  13. evalvault/adapters/inbound/cli/commands/run_helpers.py +57 -0
  14. evalvault/adapters/outbound/analysis/network_analyzer_module.py +17 -4
  15. evalvault/adapters/outbound/dataset/__init__.py +6 -0
  16. evalvault/adapters/outbound/dataset/multiturn_json_loader.py +111 -0
  17. evalvault/adapters/outbound/report/__init__.py +6 -0
  18. evalvault/adapters/outbound/report/ci_report_formatter.py +43 -0
  19. evalvault/adapters/outbound/report/dashboard_generator.py +24 -9
  20. evalvault/adapters/outbound/report/pr_comment_formatter.py +50 -0
  21. evalvault/adapters/outbound/retriever/__init__.py +8 -0
  22. evalvault/adapters/outbound/retriever/graph_rag_adapter.py +326 -0
  23. evalvault/adapters/outbound/storage/base_sql.py +291 -0
  24. evalvault/adapters/outbound/storage/postgres_adapter.py +130 -0
  25. evalvault/adapters/outbound/storage/postgres_schema.sql +60 -0
  26. evalvault/adapters/outbound/storage/schema.sql +63 -0
  27. evalvault/adapters/outbound/storage/sqlite_adapter.py +107 -0
  28. evalvault/domain/entities/__init__.py +20 -0
  29. evalvault/domain/entities/graph_rag.py +30 -0
  30. evalvault/domain/entities/multiturn.py +78 -0
  31. evalvault/domain/metrics/__init__.py +10 -0
  32. evalvault/domain/metrics/multiturn_metrics.py +113 -0
  33. evalvault/domain/metrics/registry.py +36 -0
  34. evalvault/domain/services/__init__.py +8 -0
  35. evalvault/domain/services/evaluator.py +5 -2
  36. evalvault/domain/services/graph_rag_experiment.py +155 -0
  37. evalvault/domain/services/multiturn_evaluator.py +187 -0
  38. evalvault/ports/inbound/__init__.py +2 -0
  39. evalvault/ports/inbound/multiturn_port.py +23 -0
  40. evalvault/ports/inbound/web_port.py +4 -0
  41. evalvault/ports/outbound/graph_retriever_port.py +24 -0
  42. evalvault/ports/outbound/storage_port.py +25 -0
  43. {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/METADATA +1 -1
  44. {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/RECORD +47 -33
  45. {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/WHEEL +0 -0
  46. {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/entry_points.txt +0 -0
  47. {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/licenses/LICENSE.md +0 -0
evalvault/adapters/outbound/storage/sqlite_adapter.py
@@ -182,6 +182,71 @@ class SQLiteStorageAdapter(BaseSQLStorageAdapter):
             if "metadata" not in pipeline_columns:
                 conn.execute("ALTER TABLE pipeline_results ADD COLUMN metadata TEXT")
 
+            multiturn_cursor = conn.execute("PRAGMA table_info(multiturn_runs)")
+            multiturn_columns = {row[1] for row in multiturn_cursor.fetchall()}
+            if not multiturn_columns:
+                conn.executescript(
+                    """
+                    CREATE TABLE IF NOT EXISTS multiturn_runs (
+                        run_id TEXT PRIMARY KEY,
+                        dataset_name TEXT NOT NULL,
+                        dataset_version TEXT,
+                        model_name TEXT,
+                        started_at TIMESTAMP NOT NULL,
+                        finished_at TIMESTAMP,
+                        conversation_count INTEGER DEFAULT 0,
+                        turn_count INTEGER DEFAULT 0,
+                        metrics_evaluated TEXT,
+                        drift_threshold REAL,
+                        summary TEXT,
+                        metadata TEXT,
+                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                    );
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_runs_dataset ON multiturn_runs(dataset_name);
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_runs_started_at ON multiturn_runs(started_at DESC);
+
+                    CREATE TABLE IF NOT EXISTS multiturn_conversations (
+                        id INTEGER PRIMARY KEY AUTOINCREMENT,
+                        run_id TEXT NOT NULL,
+                        conversation_id TEXT NOT NULL,
+                        turn_count INTEGER DEFAULT 0,
+                        drift_score REAL,
+                        drift_threshold REAL,
+                        drift_detected INTEGER DEFAULT 0,
+                        summary TEXT,
+                        FOREIGN KEY (run_id) REFERENCES multiturn_runs(run_id) ON DELETE CASCADE
+                    );
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_conversations_run_id ON multiturn_conversations(run_id);
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_conversations_conv_id ON multiturn_conversations(conversation_id);
+
+                    CREATE TABLE IF NOT EXISTS multiturn_turn_results (
+                        id INTEGER PRIMARY KEY AUTOINCREMENT,
+                        run_id TEXT NOT NULL,
+                        conversation_id TEXT NOT NULL,
+                        turn_id TEXT NOT NULL,
+                        turn_index INTEGER,
+                        role TEXT NOT NULL,
+                        passed INTEGER DEFAULT 0,
+                        latency_ms INTEGER,
+                        metadata TEXT,
+                        FOREIGN KEY (run_id) REFERENCES multiturn_runs(run_id) ON DELETE CASCADE
+                    );
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_turns_run_id ON multiturn_turn_results(run_id);
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_turns_conv_id ON multiturn_turn_results(conversation_id);
+
+                    CREATE TABLE IF NOT EXISTS multiturn_metric_scores (
+                        id INTEGER PRIMARY KEY AUTOINCREMENT,
+                        turn_result_id INTEGER NOT NULL,
+                        metric_name TEXT NOT NULL,
+                        score REAL NOT NULL,
+                        threshold REAL,
+                        FOREIGN KEY (turn_result_id) REFERENCES multiturn_turn_results(id) ON DELETE CASCADE
+                    );
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_scores_turn_id ON multiturn_metric_scores(turn_result_id);
+                    CREATE INDEX IF NOT EXISTS idx_multiturn_scores_metric_name ON multiturn_metric_scores(metric_name);
+                    """
+                )
+
     # Prompt set methods
 
     def save_prompt_set(self, bundle: PromptSetBundle) -> None:
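The migration above creates the multi-turn tables only when they are missing. As a rough sketch of how the new schema can be queried directly with the standard sqlite3 module (the database path and run_id are assumptions for illustration, not values taken from this diff):

    import sqlite3

    # Hypothetical local path; EvalVault's actual database location may differ.
    conn = sqlite3.connect("evalvault.db")
    conn.row_factory = sqlite3.Row

    # Average score per metric for one multi-turn run, joining turn results to their metric scores.
    rows = conn.execute(
        """
        SELECT s.metric_name, AVG(s.score) AS avg_score
        FROM multiturn_metric_scores AS s
        JOIN multiturn_turn_results AS t ON t.id = s.turn_result_id
        WHERE t.run_id = ?
        GROUP BY s.metric_name
        """,
        ("example-run-id",),  # placeholder run id
    ).fetchall()
    for row in rows:
        print(row["metric_name"], row["avg_score"])
    conn.close()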
@@ -990,6 +1055,48 @@ class SQLiteStorageAdapter(BaseSQLStorageAdapter):
 
         return report_id
 
+    def list_analysis_reports(
+        self,
+        *,
+        run_id: str,
+        report_type: str | None = None,
+        format: str | None = None,
+        limit: int = 20,
+    ) -> list[dict[str, Any]]:
+        query = (
+            "SELECT report_id, run_id, experiment_id, report_type, format, content, metadata, created_at "
+            "FROM analysis_reports WHERE run_id = ?"
+        )
+        params: list[Any] = [run_id]
+        if report_type:
+            query += " AND report_type = ?"
+            params.append(report_type)
+        if format:
+            query += " AND format = ?"
+            params.append(format)
+        query += " ORDER BY created_at DESC LIMIT ?"
+        params.append(limit)
+
+        with self._get_connection() as conn:
+            conn = cast(Any, conn)
+            rows = conn.execute(query, tuple(params)).fetchall()
+
+        reports: list[dict[str, Any]] = []
+        for row in rows:
+            reports.append(
+                {
+                    "report_id": row["report_id"],
+                    "run_id": row["run_id"],
+                    "experiment_id": row["experiment_id"],
+                    "report_type": row["report_type"],
+                    "format": row["format"],
+                    "content": row["content"],
+                    "metadata": self._deserialize_json(row["metadata"]),
+                    "created_at": row["created_at"],
+                }
+            )
+        return reports
+
     def list_pipeline_results(self, limit: int = 50) -> list[dict[str, Any]]:
         """Retrieves the list of pipeline analysis results."""
         query = """
evalvault/domain/entities/__init__.py
@@ -21,6 +21,7 @@ from evalvault.domain.entities.feedback import (
     FeedbackSummary,
     SatisfactionFeedback,
 )
+from evalvault.domain.entities.graph_rag import EntityNode, KnowledgeSubgraph, RelationEdge
 from evalvault.domain.entities.improvement import (
     EffortLevel,
     EvidenceSource,
@@ -42,6 +43,15 @@ from evalvault.domain.entities.judge_calibration import (
 )
 from evalvault.domain.entities.kg import EntityModel, RelationModel
 from evalvault.domain.entities.method import MethodInput, MethodInputDataset, MethodOutput
+from evalvault.domain.entities.multiturn import (
+    ConversationTurn,
+    DriftAnalysis,
+    MultiTurnConversationRecord,
+    MultiTurnEvaluationResult,
+    MultiTurnRunRecord,
+    MultiTurnTestCase,
+    MultiTurnTurnResult,
+)
 from evalvault.domain.entities.prompt import Prompt, PromptSet, PromptSetBundle, PromptSetItem
 from evalvault.domain.entities.prompt_suggestion import (
     PromptCandidate,
@@ -114,6 +124,16 @@ __all__ = [
     "JudgeCalibrationMetric",
     "JudgeCalibrationResult",
     "JudgeCalibrationSummary",
+    "ConversationTurn",
+    "MultiTurnConversationRecord",
+    "MultiTurnTestCase",
+    "MultiTurnTurnResult",
+    "MultiTurnEvaluationResult",
+    "DriftAnalysis",
+    "MultiTurnRunRecord",
+    "EntityNode",
+    "KnowledgeSubgraph",
+    "RelationEdge",
     # KG
     "EntityModel",
     "RelationModel",
evalvault/domain/entities/graph_rag.py (new file)
@@ -0,0 +1,30 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class EntityNode:
+    entity_id: str
+    name: str
+    entity_type: str
+    attributes: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class RelationEdge:
+    source_id: str
+    target_id: str
+    relation_type: str
+    weight: float = 1.0
+    attributes: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class KnowledgeSubgraph:
+    """A relevant subgraph extracted for a query."""
+
+    nodes: list[EntityNode]
+    edges: list[RelationEdge]
+    relevance_score: float
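These are plain dataclasses, so a subgraph can be assembled directly; a small sketch with illustrative values only:

    from evalvault.domain.entities.graph_rag import EntityNode, KnowledgeSubgraph, RelationEdge

    nodes = [
        EntityNode(entity_id="e1", name="term life insurance", entity_type="product"),
        EntityNode(entity_id="e2", name="death benefit", entity_type="benefit"),
    ]
    edges = [
        RelationEdge(source_id="e1", target_id="e2", relation_type="provides", weight=0.9),
    ]
    subgraph = KnowledgeSubgraph(nodes=nodes, edges=edges, relevance_score=0.82)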
evalvault/domain/entities/multiturn.py (new file)
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, Literal
+
+
+@dataclass
+class ConversationTurn:
+    turn_id: str
+    role: Literal["user", "assistant"]
+    content: str
+    contexts: list[str] | None = None
+    ground_truth: str | None = None
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class MultiTurnTestCase:
+    conversation_id: str
+    turns: list[ConversationTurn]
+    expected_final_answer: str | None = None
+    drift_tolerance: float = 0.1
+
+
+@dataclass
+class MultiTurnTurnResult:
+    conversation_id: str
+    turn_id: str
+    turn_index: int | None
+    role: Literal["user", "assistant"]
+    metrics: dict[str, float] = field(default_factory=dict)
+    passed: bool = False
+    latency_ms: int | None = None
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class MultiTurnEvaluationResult:
+    conversation_id: str
+    turn_results: list[MultiTurnTurnResult] = field(default_factory=list)
+    summary: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class DriftAnalysis:
+    conversation_id: str
+    drift_score: float
+    drift_threshold: float
+    drift_detected: bool
+    notes: list[str] = field(default_factory=list)
+
+
+@dataclass
+class MultiTurnRunRecord:
+    run_id: str
+    dataset_name: str
+    dataset_version: str | None
+    model_name: str | None
+    started_at: datetime
+    finished_at: datetime | None
+    conversation_count: int
+    turn_count: int
+    metrics_evaluated: list[str] = field(default_factory=list)
+    drift_threshold: float | None = None
+    summary: dict[str, Any] = field(default_factory=dict)
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class MultiTurnConversationRecord:
+    run_id: str
+    conversation_id: str
+    turn_count: int
+    drift_score: float | None = None
+    drift_threshold: float | None = None
+    drift_detected: bool = False
+    summary: dict[str, Any] = field(default_factory=dict)
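A multi-turn test case is built from ConversationTurn entries; a minimal example using only the fields defined above (the content and ids are illustrative):

    from evalvault.domain.entities.multiturn import ConversationTurn, MultiTurnTestCase

    case = MultiTurnTestCase(
        conversation_id="conv-001",
        turns=[
            ConversationTurn(turn_id="t1", role="user", content="What does the rider cover?"),
            ConversationTurn(
                turn_id="t2",
                role="assistant",
                content="The rider covers accidental injury.",
                contexts=["Rider clause: covers accidental injury up to the policy limit."],
            ),
        ],
        expected_final_answer="Accidental injury coverage",
        drift_tolerance=0.2,
    )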
evalvault/domain/metrics/__init__.py
@@ -4,6 +4,12 @@ from evalvault.domain.metrics.confidence import ConfidenceScore
 from evalvault.domain.metrics.contextual_relevancy import ContextualRelevancy
 from evalvault.domain.metrics.entity_preservation import EntityPreservation
 from evalvault.domain.metrics.insurance import InsuranceTermAccuracy
+from evalvault.domain.metrics.multiturn_metrics import (
+    calculate_context_coherence,
+    calculate_drift_rate,
+    calculate_turn_faithfulness,
+    calculate_turn_latency_p95,
+)
 from evalvault.domain.metrics.no_answer import NoAnswerAccuracy, is_no_answer
 from evalvault.domain.metrics.retrieval_rank import MRR, NDCG, HitRate
 from evalvault.domain.metrics.summary_accuracy import SummaryAccuracy
@@ -28,4 +34,8 @@ __all__ = [
     "SummaryNonDefinitive",
     "SummaryRiskCoverage",
     "is_no_answer",
+    "calculate_context_coherence",
+    "calculate_drift_rate",
+    "calculate_turn_faithfulness",
+    "calculate_turn_latency_p95",
 ]
evalvault/domain/metrics/multiturn_metrics.py (new file)
@@ -0,0 +1,113 @@
+"""
+Utilities for multi-turn evaluation metrics.
+
+Metrics:
+- turn_faithfulness: average per-turn faithfulness
+- context_coherence: coherence across turn contexts
+- drift_rate: distance between initial intent and final response
+- turn_latency: p95 latency across turns
+"""
+
+from __future__ import annotations
+
+import math
+import re
+import unicodedata
+from collections.abc import Iterable
+
+from evalvault.domain.entities.multiturn import ConversationTurn, MultiTurnTurnResult
+
+
+def _normalize_text(text: str) -> str:
+    if not text:
+        return ""
+    text = unicodedata.normalize("NFC", text)
+    text = text.lower()
+    text = re.sub(r"\s+", " ", text).strip()
+    return text
+
+
+def _tokenize(text: str) -> set[str]:
+    if not text:
+        return set()
+    text = _normalize_text(text)
+    tokens = re.findall(r"[\w가-힣]+", text)
+    return set(tokens)
+
+
+def _jaccard_similarity(left: str, right: str) -> float:
+    left_tokens = _tokenize(left)
+    right_tokens = _tokenize(right)
+    if not left_tokens and not right_tokens:
+        return 1.0
+    if not left_tokens or not right_tokens:
+        return 0.0
+    intersection = left_tokens.intersection(right_tokens)
+    union = left_tokens.union(right_tokens)
+    if not union:
+        return 0.0
+    return len(intersection) / len(union)
+
+
+def _turn_context_text(turn: ConversationTurn) -> str:
+    if turn.contexts:
+        return " ".join([ctx for ctx in turn.contexts if ctx])
+    return turn.content or ""
+
+
+def calculate_turn_faithfulness(turn_results: Iterable[MultiTurnTurnResult]) -> float:
+    scores: list[float] = []
+    for result in turn_results:
+        score = result.metrics.get("faithfulness") if result.metrics else None
+        if score is not None:
+            scores.append(score)
+    if not scores:
+        return 0.0
+    return sum(scores) / len(scores)
+
+
+def calculate_context_coherence(turns: Iterable[ConversationTurn]) -> float:
+    turn_list = list(turns)
+    if len(turn_list) < 2:
+        return 1.0
+    scores: list[float] = []
+    for prev, curr in zip(turn_list, turn_list[1:], strict=False):
+        left = _turn_context_text(prev)
+        right = _turn_context_text(curr)
+        scores.append(_jaccard_similarity(left, right))
+    if not scores:
+        return 0.0
+    return sum(scores) / len(scores)
+
+
+def calculate_drift_rate(turns: Iterable[ConversationTurn]) -> float:
+    turn_list = list(turns)
+    if not turn_list:
+        return 0.0
+    first_user = next((t for t in turn_list if t.role == "user"), None)
+    last_assistant = next((t for t in reversed(turn_list) if t.role == "assistant"), None)
+    if not first_user or not last_assistant:
+        return 0.0
+    similarity = _jaccard_similarity(first_user.content, last_assistant.content)
+    drift = 1.0 - similarity
+    if drift < 0.0:
+        return 0.0
+    if drift > 1.0:
+        return 1.0
+    return drift
+
+
+def calculate_turn_latency_p95(latencies_ms: Iterable[int | None]) -> float:
+    values = [float(value) for value in latencies_ms if value is not None]
+    if not values:
+        return 0.0
+    values.sort()
+    if len(values) == 1:
+        return values[0]
+    rank = 0.95 * (len(values) - 1)
+    lower = int(math.floor(rank))
+    upper = int(math.ceil(rank))
+    if lower == upper:
+        return values[lower]
+    fraction = rank - lower
+    return values[lower] + (values[upper] - values[lower]) * fraction
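Taken together, these helpers operate on a conversation's turns and its per-turn results; a small sketch with made-up scores and latencies:

    from evalvault.domain.entities.multiturn import ConversationTurn, MultiTurnTurnResult
    from evalvault.domain.metrics.multiturn_metrics import (
        calculate_context_coherence,
        calculate_drift_rate,
        calculate_turn_faithfulness,
        calculate_turn_latency_p95,
    )

    turns = [
        ConversationTurn(turn_id="t1", role="user", content="Explain the waiting period."),
        ConversationTurn(turn_id="t2", role="assistant", content="The waiting period is 90 days."),
    ]
    results = [
        MultiTurnTurnResult(
            conversation_id="conv-001",
            turn_id="t2",
            turn_index=1,
            role="assistant",
            metrics={"faithfulness": 0.91},
            latency_ms=420,
        ),
    ]

    print(calculate_turn_faithfulness(results))    # mean of per-turn "faithfulness" scores
    print(calculate_context_coherence(turns))      # Jaccard overlap between adjacent turn contexts
    print(calculate_drift_rate(turns))             # 1 - similarity(first user turn, last assistant turn)
    print(calculate_turn_latency_p95([r.latency_ms for r in results]))  # p95 latency in ms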
evalvault/domain/metrics/registry.py
@@ -139,6 +139,42 @@ _METRIC_SPECS: tuple[MetricSpec, ...] = (
         category="summary",
         signal_group="summary_fidelity",
     ),
+    MetricSpec(
+        name="turn_faithfulness",
+        description="(Multi-turn) Average faithfulness across assistant turns",
+        requires_ground_truth=False,
+        requires_embeddings=False,
+        source="custom",
+        category="qa",
+        signal_group="groundedness",
+    ),
+    MetricSpec(
+        name="context_coherence",
+        description="(Multi-turn) Context continuity across turns",
+        requires_ground_truth=False,
+        requires_embeddings=False,
+        source="custom",
+        category="qa",
+        signal_group="intent_alignment",
+    ),
+    MetricSpec(
+        name="drift_rate",
+        description="(Multi-turn) Distance between initial intent and final response",
+        requires_ground_truth=False,
+        requires_embeddings=False,
+        source="custom",
+        category="qa",
+        signal_group="intent_alignment",
+    ),
+    MetricSpec(
+        name="turn_latency",
+        description="(Multi-turn) P95 response latency across turns (ms)",
+        requires_ground_truth=False,
+        requires_embeddings=False,
+        source="custom",
+        category="qa",
+        signal_group="efficiency",
+    ),
     MetricSpec(
         name="entity_preservation",
         description="(Rule) Measures preservation of key insurance entities in summaries",
evalvault/domain/services/__init__.py
@@ -4,9 +4,14 @@ from evalvault.domain.services.analysis_service import AnalysisService
 from evalvault.domain.services.dataset_preprocessor import DatasetPreprocessor
 from evalvault.domain.services.domain_learning_hook import DomainLearningHook
 from evalvault.domain.services.evaluator import RagasEvaluator
+from evalvault.domain.services.graph_rag_experiment import (
+    GraphRAGExperiment,
+    GraphRAGExperimentResult,
+)
 from evalvault.domain.services.holdout_splitter import split_dataset_holdout
 from evalvault.domain.services.improvement_guide_service import ImprovementGuideService
 from evalvault.domain.services.method_runner import MethodRunnerService, MethodRunResult
+from evalvault.domain.services.multiturn_evaluator import MultiTurnEvaluator
 from evalvault.domain.services.prompt_scoring_service import PromptScoringService
 from evalvault.domain.services.prompt_suggestion_reporter import PromptSuggestionReporter
 
@@ -17,8 +22,11 @@ __all__ = [
     "ImprovementGuideService",
     "MethodRunnerService",
     "MethodRunResult",
+    "GraphRAGExperiment",
+    "GraphRAGExperimentResult",
     "PromptScoringService",
     "PromptSuggestionReporter",
     "RagasEvaluator",
+    "MultiTurnEvaluator",
     "split_dataset_holdout",
 ]
evalvault/domain/services/evaluator.py
@@ -63,9 +63,12 @@ _SUMMARY_FAITHFULNESS_PROMPT_EN = (
 
 def _patch_ragas_faithfulness_output() -> None:
     try:
-        from ragas.metrics import Faithfulness
+        from ragas.metrics.collections import Faithfulness
     except Exception:
-        return
+        try:
+            from ragas.metrics import Faithfulness
+        except Exception:
+            return
 
     prompt = getattr(Faithfulness, "nli_statements_prompt", None)
     if prompt is None:
evalvault/domain/services/graph_rag_experiment.py (new file)
@@ -0,0 +1,155 @@
+"""GraphRAG experiment helper for baseline vs graph comparison."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from evalvault.domain.entities import Dataset, EvaluationRun, TestCase
+from evalvault.domain.entities.analysis import ComparisonResult
+from evalvault.domain.entities.graph_rag import KnowledgeSubgraph
+from evalvault.domain.services.analysis_service import AnalysisService
+from evalvault.domain.services.evaluator import RagasEvaluator
+from evalvault.ports.outbound.graph_retriever_port import GraphRetrieverPort
+from evalvault.ports.outbound.korean_nlp_port import RetrieverPort
+from evalvault.ports.outbound.llm_port import LLMPort
+
+
+@dataclass
+class GraphRAGExperimentResult:
+    baseline_run: EvaluationRun
+    graph_run: EvaluationRun
+    comparisons: list[ComparisonResult]
+    graph_subgraphs: dict[str, KnowledgeSubgraph]
+    graph_contexts: dict[str, str]
+
+
+class GraphRAGExperiment:
+    """Compare baseline retrieval with GraphRAG context generation."""
+
+    def __init__(
+        self,
+        *,
+        evaluator: RagasEvaluator,
+        analysis_service: AnalysisService,
+    ) -> None:
+        self._evaluator = evaluator
+        self._analysis = analysis_service
+
+    async def run_comparison(
+        self,
+        *,
+        dataset: Dataset,
+        baseline_retriever: RetrieverPort,
+        graph_retriever: GraphRetrieverPort,
+        metrics: list[str],
+        llm: LLMPort,
+        thresholds: dict[str, float] | None = None,
+        retriever_top_k: int = 5,
+        graph_max_hops: int = 2,
+        graph_max_nodes: int = 20,
+        parallel: bool = False,
+        batch_size: int = 5,
+        prompt_overrides: dict[str, str] | None = None,
+        claim_level: bool = False,
+        language: str | None = None,
+    ) -> GraphRAGExperimentResult:
+        baseline_dataset = self._clone_dataset(dataset)
+        graph_dataset = self._clone_dataset(dataset)
+
+        graph_subgraphs, graph_contexts = self._apply_graph_contexts(
+            graph_dataset,
+            graph_retriever,
+            max_hops=graph_max_hops,
+            max_nodes=graph_max_nodes,
+        )
+
+        baseline_run = await self._evaluator.evaluate(
+            baseline_dataset,
+            metrics,
+            llm,
+            thresholds=thresholds,
+            parallel=parallel,
+            batch_size=batch_size,
+            retriever=baseline_retriever,
+            retriever_top_k=retriever_top_k,
+            prompt_overrides=prompt_overrides,
+            claim_level=claim_level,
+            language=language,
+        )
+
+        graph_run = await self._evaluator.evaluate(
+            graph_dataset,
+            metrics,
+            llm,
+            thresholds=thresholds,
+            parallel=parallel,
+            batch_size=batch_size,
+            retriever=None,
+            prompt_overrides=prompt_overrides,
+            claim_level=claim_level,
+            language=language,
+        )
+
+        comparisons = self._analysis.compare_runs(
+            baseline_run,
+            graph_run,
+            metrics=metrics,
+        )
+
+        return GraphRAGExperimentResult(
+            baseline_run=baseline_run,
+            graph_run=graph_run,
+            comparisons=comparisons,
+            graph_subgraphs=graph_subgraphs,
+            graph_contexts=graph_contexts,
+        )
+
+    @staticmethod
+    def _clone_dataset(dataset: Dataset) -> Dataset:
+        test_cases = [
+            TestCase(
+                id=case.id,
+                question=case.question,
+                answer=case.answer,
+                contexts=list(case.contexts),
+                ground_truth=case.ground_truth,
+                metadata=dict(case.metadata),
+            )
+            for case in dataset.test_cases
+        ]
+        return Dataset(
+            name=dataset.name,
+            version=dataset.version,
+            test_cases=test_cases,
+            metadata=dict(dataset.metadata),
+            source_file=dataset.source_file,
+            thresholds=dict(dataset.thresholds),
+        )
+
+    @staticmethod
+    def _apply_graph_contexts(
+        dataset: Dataset,
+        graph_retriever: GraphRetrieverPort,
+        *,
+        max_hops: int,
+        max_nodes: int,
+    ) -> tuple[dict[str, KnowledgeSubgraph], dict[str, str]]:
+        subgraphs: dict[str, KnowledgeSubgraph] = {}
+        contexts: dict[str, str] = {}
+        for case in dataset.test_cases:
+            if case.contexts and any(context.strip() for context in case.contexts):
+                continue
+            subgraph = graph_retriever.build_subgraph(
+                case.question,
+                max_hops=max_hops,
+                max_nodes=max_nodes,
+            )
+            context_text = graph_retriever.generate_context(subgraph)
+            if context_text:
+                case.contexts = [context_text]
+                contexts[case.id] = context_text
+            subgraphs[case.id] = subgraph
+        return subgraphs, contexts
+
+
+__all__ = ["GraphRAGExperiment", "GraphRAGExperimentResult"]
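The experiment wires an evaluator, an analysis service, two retrievers, and an LLM port together; a minimal usage sketch written as a coroutine whose collaborators are supplied by the caller, since their construction is outside this diff (the metric name is only an example value):

    from evalvault.domain.services import GraphRAGExperiment

    async def compare_baseline_vs_graph(
        evaluator,          # RagasEvaluator instance, assumed to exist
        analysis_service,   # AnalysisService instance, assumed to exist
        dataset,            # Dataset to evaluate
        baseline_retriever, # RetrieverPort implementation, assumed to exist
        graph_retriever,    # GraphRetrieverPort implementation, assumed to exist
        llm,                # LLMPort implementation, assumed to exist
    ):
        experiment = GraphRAGExperiment(evaluator=evaluator, analysis_service=analysis_service)
        result = await experiment.run_comparison(
            dataset=dataset,
            baseline_retriever=baseline_retriever,
            graph_retriever=graph_retriever,
            metrics=["faithfulness"],  # example metric name
            llm=llm,
        )
        # Each ComparisonResult contrasts baseline_run and graph_run on one metric.
        return result.comparisons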