evalvault 1.70.1__py3-none-any.whl → 1.71.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalvault/adapters/inbound/api/adapter.py +367 -3
- evalvault/adapters/inbound/api/main.py +17 -1
- evalvault/adapters/inbound/api/routers/calibration.py +133 -0
- evalvault/adapters/inbound/api/routers/runs.py +71 -1
- evalvault/adapters/inbound/cli/commands/__init__.py +2 -0
- evalvault/adapters/inbound/cli/commands/analyze.py +1 -0
- evalvault/adapters/inbound/cli/commands/compare.py +1 -1
- evalvault/adapters/inbound/cli/commands/experiment.py +27 -1
- evalvault/adapters/inbound/cli/commands/graph_rag.py +303 -0
- evalvault/adapters/inbound/cli/commands/history.py +1 -1
- evalvault/adapters/inbound/cli/commands/regress.py +169 -1
- evalvault/adapters/inbound/cli/commands/run.py +225 -1
- evalvault/adapters/inbound/cli/commands/run_helpers.py +57 -0
- evalvault/adapters/outbound/analysis/network_analyzer_module.py +17 -4
- evalvault/adapters/outbound/dataset/__init__.py +6 -0
- evalvault/adapters/outbound/dataset/multiturn_json_loader.py +111 -0
- evalvault/adapters/outbound/report/__init__.py +6 -0
- evalvault/adapters/outbound/report/ci_report_formatter.py +43 -0
- evalvault/adapters/outbound/report/dashboard_generator.py +24 -9
- evalvault/adapters/outbound/report/pr_comment_formatter.py +50 -0
- evalvault/adapters/outbound/retriever/__init__.py +8 -0
- evalvault/adapters/outbound/retriever/graph_rag_adapter.py +326 -0
- evalvault/adapters/outbound/storage/base_sql.py +291 -0
- evalvault/adapters/outbound/storage/postgres_adapter.py +130 -0
- evalvault/adapters/outbound/storage/postgres_schema.sql +60 -0
- evalvault/adapters/outbound/storage/schema.sql +63 -0
- evalvault/adapters/outbound/storage/sqlite_adapter.py +107 -0
- evalvault/domain/entities/__init__.py +20 -0
- evalvault/domain/entities/graph_rag.py +30 -0
- evalvault/domain/entities/multiturn.py +78 -0
- evalvault/domain/metrics/__init__.py +10 -0
- evalvault/domain/metrics/multiturn_metrics.py +113 -0
- evalvault/domain/metrics/registry.py +36 -0
- evalvault/domain/services/__init__.py +8 -0
- evalvault/domain/services/evaluator.py +5 -2
- evalvault/domain/services/graph_rag_experiment.py +155 -0
- evalvault/domain/services/multiturn_evaluator.py +187 -0
- evalvault/ports/inbound/__init__.py +2 -0
- evalvault/ports/inbound/multiturn_port.py +23 -0
- evalvault/ports/inbound/web_port.py +4 -0
- evalvault/ports/outbound/graph_retriever_port.py +24 -0
- evalvault/ports/outbound/storage_port.py +25 -0
- {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/METADATA +1 -1
- {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/RECORD +47 -33
- {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/WHEEL +0 -0
- {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/entry_points.txt +0 -0
- {evalvault-1.70.1.dist-info → evalvault-1.71.0.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import contextlib
|
|
3
4
|
import json
|
|
5
|
+
import os
|
|
4
6
|
import random
|
|
7
|
+
import sys
|
|
5
8
|
from importlib import import_module
|
|
6
9
|
from pathlib import Path
|
|
7
10
|
from typing import Any
|
|
@@ -9,6 +12,12 @@ from typing import Any
|
|
|
9
12
|
|
|
10
13
|
def _import_matplotlib_pyplot() -> Any:
|
|
11
14
|
try:
|
|
15
|
+
if "matplotlib.pyplot" in sys.modules:
|
|
16
|
+
return import_module("matplotlib.pyplot")
|
|
17
|
+
os.environ.setdefault("MPLBACKEND", "Agg")
|
|
18
|
+
matplotlib = import_module("matplotlib")
|
|
19
|
+
with contextlib.suppress(Exception):
|
|
20
|
+
matplotlib.use("Agg", force=True)
|
|
12
21
|
return import_module("matplotlib.pyplot")
|
|
13
22
|
except ModuleNotFoundError as exc:
|
|
14
23
|
raise ImportError(
|
|
@@ -32,14 +41,20 @@ class DashboardGenerator:
|
|
|
32
41
|
plt.rcParams["legend.fontsize"] = 10
|
|
33
42
|
|
|
34
43
|
def generate_evaluation_dashboard(
|
|
35
|
-
self,
|
|
44
|
+
self,
|
|
45
|
+
run_id: str,
|
|
46
|
+
analysis_json_path: str | None = None,
|
|
47
|
+
analysis_data: dict[str, Any] | None = None,
|
|
36
48
|
) -> Any:
|
|
37
49
|
plt = _import_matplotlib_pyplot()
|
|
38
50
|
|
|
39
|
-
|
|
40
|
-
if
|
|
41
|
-
|
|
42
|
-
|
|
51
|
+
analysis_payload: dict[str, Any] = {}
|
|
52
|
+
if analysis_data is None:
|
|
53
|
+
if analysis_json_path and Path(analysis_json_path).exists():
|
|
54
|
+
with open(analysis_json_path, encoding="utf-8") as f:
|
|
55
|
+
analysis_payload = json.load(f)
|
|
56
|
+
elif isinstance(analysis_data, dict):
|
|
57
|
+
analysis_payload = analysis_data
|
|
43
58
|
|
|
44
59
|
fig, axes = plt.subplots(2, 2, figsize=(14, 10), constrained_layout=True)
|
|
45
60
|
fig.suptitle(
|
|
@@ -48,10 +63,10 @@ class DashboardGenerator:
|
|
|
48
63
|
fontweight="bold",
|
|
49
64
|
)
|
|
50
65
|
|
|
51
|
-
self._plot_metric_distribution(axes[0, 0],
|
|
52
|
-
self._plot_correlation_heatmap(axes[0, 1],
|
|
53
|
-
self._plot_pass_rates(axes[1, 0],
|
|
54
|
-
self._plot_failure_causes(axes[1, 1],
|
|
66
|
+
self._plot_metric_distribution(axes[0, 0], analysis_payload)
|
|
67
|
+
self._plot_correlation_heatmap(axes[0, 1], analysis_payload)
|
|
68
|
+
self._plot_pass_rates(axes[1, 0], analysis_payload)
|
|
69
|
+
self._plot_failure_causes(axes[1, 1], analysis_payload)
|
|
55
70
|
|
|
56
71
|
return fig
|
|
57
72
|
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from evalvault.adapters.outbound.report.ci_report_formatter import CIGateMetricRow
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def format_ci_gate_pr_comment(
    rows: list[CIGateMetricRow],
    *,
    baseline_run_id: str,
    current_run_id: str,
    regression_rate: float,
    regression_threshold: float,
    gate_passed: bool,
    threshold_failures: list[str],
    regressed_metrics: list[str],
) -> str:
    """Render a CI gate result as a Markdown pull-request comment.

    Args:
        rows: Per-metric comparison rows. Each row's ``metric``,
            ``baseline_score``, ``current_score``, ``change_percent`` and
            ``status`` attributes are rendered as one table row.
        baseline_run_id: Identifier of the baseline run.
        current_run_id: Identifier of the run being gated.
        regression_rate: Regression rate as a fraction (rendered with ``%``
            formatting, i.e. multiplied by 100).
        regression_threshold: Regression threshold as a fraction.
        gate_passed: Overall gate outcome decided by the caller.
        threshold_failures: Metrics that failed their thresholds; rendered
            de-duplicated and sorted.
        regressed_metrics: Metrics that regressed; rendered de-duplicated
            and sorted.

    Returns:
        The Markdown comment with surrounding whitespace stripped.
    """
    lines: list[str] = [
        "## EvalVault CI Gate",
        "",
        f"- Baseline: `{baseline_run_id}`",
        f"- Current: `{current_run_id}`",
        "",
        "| Metric | Baseline | Current | Change | Status |",
        "|--------|----------|---------|--------|--------|",
    ]
    for row in rows:
        change = f"{row.change_percent:+.1f}%"
        lines.append(
            f"| {row.metric} | {row.baseline_score:.3f} | {row.current_score:.3f} "
            f"| {change} | {row.status} |"
        )

    lines.append("")
    status_line = "✅ PASSED" if gate_passed else "❌ FAILED"
    # gate_passed is supplied by the caller and may be False for reasons other
    # than the regression rate (threshold failures are reported separately
    # below), so the operator must come from the actual numbers; deriving it
    # from gate_passed could print a false inequality.
    comparison = ">=" if regression_rate >= regression_threshold else "<"
    lines.append(
        f"**Gate Status**: {status_line} (regression: {regression_rate:.1%} {comparison} {regression_threshold:.1%} threshold)"
    )

    if threshold_failures or regressed_metrics:
        lines.append("")
    if threshold_failures:
        lines.append("**Threshold Failures**: " + ", ".join(sorted(set(threshold_failures))))
    if regressed_metrics:
        lines.append("**Regressions**: " + ", ".join(sorted(set(regressed_metrics))))

    return "\n".join(lines).strip()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
__all__ = ["format_ci_gate_pr_comment"]
|
|
@@ -0,0 +1,326 @@
|
|
|
1
|
+
"""GraphRAG adapter that exposes graph-centric retrieval helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from evalvault.adapters.outbound.kg.networkx_adapter import NetworkXKnowledgeGraph
|
|
10
|
+
from evalvault.domain.entities.graph_rag import EntityNode, KnowledgeSubgraph, RelationEdge
|
|
11
|
+
from evalvault.domain.entities.kg import EntityModel, RelationModel
|
|
12
|
+
from evalvault.domain.services.entity_extractor import EntityExtractor
|
|
13
|
+
from evalvault.ports.outbound.graph_retriever_port import GraphRetrieverPort
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GraphRAGAdapter(GraphRetrieverPort):
    """GraphRAG adapter over NetworkXKnowledgeGraph.

    Entity mentions found in text are resolved against the knowledge graph,
    expanded through ``kg.find_neighbors`` and returned as a
    ``KnowledgeSubgraph`` that ``generate_context`` renders as plain text.
    """

    def __init__(
        self,
        kg: NetworkXKnowledgeGraph,
        *,
        entity_extractor: EntityExtractor | None = None,
    ) -> None:
        """Create the adapter.

        Args:
            kg: Knowledge graph used for entity lookup and neighbour traversal.
            entity_extractor: Extractor used to find entity mentions in text;
                a default ``EntityExtractor()`` is created when omitted.
        """
        self._kg = kg
        self._entity_extractor = entity_extractor or EntityExtractor()

    def extract_entities(self, text: str) -> list[EntityNode]:
        """Return the graph entities mentioned in ``text``.

        Mentions that do not resolve to an entity in the graph are dropped.
        """
        names = self._extract_entity_names(text)
        return [self._entity_to_node(entity) for entity in self._resolve_entities(names)]

    def build_subgraph(
        self,
        query: str,
        max_hops: int = 2,
        max_nodes: int = 20,
    ) -> KnowledgeSubgraph:
        """Build the subgraph relevant to ``query``.

        Args:
            query: Free-text query; an empty query yields an empty subgraph.
            max_hops: Neighbourhood depth around each seed entity; negative
                values are treated as 0 (seed entities only).
            max_nodes: Maximum number of nodes returned; values below 1 are
                treated as 1.

        Returns:
            A subgraph holding the selected entities, the relations between
            them, and the mean confidence of both as ``relevance_score``.
            The subgraph is empty when no entity in the query resolves.
        """
        if not query:
            return KnowledgeSubgraph(nodes=[], edges=[], relevance_score=0.0)

        seeds = self._resolve_entities(self._extract_entity_names(query))
        if not seeds:
            return KnowledgeSubgraph(nodes=[], edges=[], relevance_score=0.0)

        selected = self._select_entities(seeds, max(max_hops, 0), max(max_nodes, 1))
        edges = self._collect_edges(selected)
        relevance_score = self._compute_relevance(selected, edges)

        return KnowledgeSubgraph(
            nodes=[self._entity_to_node(entity) for entity in selected],
            edges=[self._relation_to_edge(edge) for edge in edges],
            relevance_score=relevance_score,
        )

    def generate_context(self, subgraph: KnowledgeSubgraph) -> str:
        """Render ``subgraph`` as an ``Entities:`` / ``Relations:`` bullet list.

        Returns an empty string when the subgraph has neither nodes nor edges.
        """
        if not subgraph.nodes and not subgraph.edges:
            return ""

        lines: list[str] = []
        if subgraph.nodes:
            lines.append("Entities:")
            for node in subgraph.nodes:
                label = f"{node.name} ({node.entity_type})"
                lines.append(f"- {label}")

        if subgraph.edges:
            if lines:
                # Blank line separating the entity section from relations.
                lines.append("")
            lines.append("Relations:")
            for edge in subgraph.edges:
                label = f"{edge.source_id} -[{edge.relation_type}]-> {edge.target_id}"
                lines.append(f"- {label}")

        return "\n".join(lines)

    def _extract_entity_names(self, text: str) -> list[str]:
        """Collect candidate entity names from the extractor and the graph itself.

        Names come first from the entity extractor (empty names skipped),
        then from direct matches against known graph entities; duplicates
        are removed, keeping first-occurrence order.
        """
        names = [
            entity.name
            for entity in self._entity_extractor.extract_entities(text)
            if entity.name
        ]
        names.extend(self._match_known_entities(text))
        return self._dedupe(names)

    def _match_known_entities(self, text: str) -> list[str]:
        """Return names of graph entities whose name or canonical name occurs in ``text``.

        Matching is a case-insensitive substring test with no word-boundary
        check. Entities without a name are skipped because a match could not
        be resolved back to them.
        """
        if not text:
            return []
        haystack = text.lower()
        matches: list[str] = []
        for entity in self._kg.get_all_entities():
            name = entity.name
            if not name:
                continue
            canonical = entity.canonical_name
            # Lower-case the canonical name as well: comparing it raw against
            # the lower-cased text would miss any canonical name with capitals.
            if name.lower() in haystack or (canonical and canonical.lower() in haystack):
                matches.append(name)
        return matches

    def _resolve_entities(self, names: Iterable[str]) -> list[EntityModel]:
        """Look up each name in the graph, dropping unknown names and duplicates."""
        resolved: dict[str, EntityModel] = {}
        for name in names:
            entity = self._kg.get_entity(name)
            if entity:
                resolved[entity.name] = entity
        return list(resolved.values())

    def _select_entities(
        self,
        seeds: list[EntityModel],
        max_hops: int,
        max_nodes: int,
    ) -> list[EntityModel]:
        """Expand ``seeds`` by their neighbourhood and cap the result at ``max_nodes``.

        When the cap is exceeded, seed entities are kept first, then the rest
        by descending confidence, with the entity name as tie-breaker. If the
        seeds alone exceed the cap, they are truncated by the same ordering.
        """
        selected: dict[str, EntityModel] = {entity.name: entity for entity in seeds}
        if max_hops > 0:
            for seed in seeds:
                for neighbor in self._kg.find_neighbors(seed.name, depth=max_hops):
                    selected.setdefault(neighbor.name, neighbor)

        if len(selected) <= max_nodes:
            return list(selected.values())

        seed_names = {entity.name for entity in seeds}
        prioritized = sorted(
            selected.values(),
            key=lambda entity: (entity.name not in seed_names, -entity.confidence, entity.name),
        )
        return prioritized[:max_nodes]

    def _collect_edges(self, entities: list[EntityModel]) -> list[RelationModel]:
        """Return outgoing relations whose source and target are both in ``entities``.

        Relations are de-duplicated on ``(source, target, relation_type)``.
        """
        selected = {entity.name for entity in entities}
        edges: list[RelationModel] = []
        seen: set[tuple[str, str, str]] = set()
        for entity in entities:
            for relation in self._kg.get_outgoing_relations(entity.name):
                if relation.target not in selected:
                    continue
                key = (relation.source, relation.target, relation.relation_type)
                if key in seen:
                    continue
                seen.add(key)
                edges.append(relation)
        return edges

    @staticmethod
    def _compute_relevance(
        entities: list[EntityModel],
        edges: list[RelationModel],
    ) -> float:
        """Return the mean confidence over all entities and edges (0.0 if both are empty)."""
        scores = [entity.confidence for entity in entities]
        scores.extend(edge.confidence for edge in edges)
        if not scores:
            return 0.0
        return sum(scores) / len(scores)

    @staticmethod
    def _entity_to_node(entity: EntityModel) -> EntityNode:
        """Convert a graph entity into an ``EntityNode`` keyed by its name.

        Confidence, provenance, source document id and canonical name are
        copied into ``attributes`` and override same-named custom attributes.
        """
        attributes = {
            **entity.attributes,
            "confidence": entity.confidence,
            "provenance": entity.provenance,
            "source_document_id": entity.source_document_id,
            "canonical_name": entity.canonical_name,
        }
        return EntityNode(
            entity_id=entity.name,
            name=entity.name,
            entity_type=entity.entity_type,
            attributes=attributes,
        )

    @staticmethod
    def _relation_to_edge(edge: RelationModel) -> RelationEdge:
        """Convert a graph relation into a ``RelationEdge`` weighted by its confidence."""
        attributes = {**edge.attributes, "provenance": edge.provenance}
        return RelationEdge(
            source_id=edge.source,
            target_id=edge.target,
            relation_type=edge.relation_type,
            weight=edge.confidence,
            attributes=attributes,
        )

    @staticmethod
    def _dedupe(values: Iterable[str]) -> list[str]:
        """Return ``values`` without duplicates, keeping first-occurrence order."""
        # dict keys preserve insertion order and keep the first equal key.
        return list(dict.fromkeys(values))
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class LightRAGGraphAdapter(GraphRetrieverPort):
    """LightRAG-backed adapter that returns graph contexts.

    Retrieval is delegated to a LightRAG client. The references it returns
    are exposed as ``EntityNode`` objects of type ``"reference"``. This
    adapter never produces edges.
    """

    def __init__(
        self,
        lightrag_client: Any,
        *,
        query_mode: str = "mix",
        query_param: Any | None = None,
        entity_extractor: EntityExtractor | None = None,
    ) -> None:
        """Create the adapter.

        Args:
            lightrag_client: Object exposing ``query(query, param=...)``, or an
                ``aquery`` coroutine function with the same signature.
            query_mode: ``mode`` used when this adapter builds the LightRAG
                ``QueryParam`` itself.
            query_param: Pre-built query parameter object, used as-is when set.
            entity_extractor: Extractor used by ``extract_entities``; a default
                ``EntityExtractor()`` is created when omitted.
        """
        self._client = lightrag_client
        self._query_mode = query_mode
        self._query_param = query_param
        self._entity_extractor = entity_extractor or EntityExtractor()

    def extract_entities(self, text: str) -> list[EntityNode]:
        """Return de-duplicated entity mentions in ``text`` as ``mention`` nodes.

        Mentions are not resolved against any graph. Empty names are skipped.
        """
        names = [
            entity.name
            for entity in self._entity_extractor.extract_entities(text)
            if entity.name
        ]
        return [
            EntityNode(entity_id=name, name=name, entity_type="mention")
            for name in _dedupe_values(names)
        ]

    def build_subgraph(
        self,
        query: str,
        max_hops: int = 2,
        max_nodes: int = 20,
    ) -> KnowledgeSubgraph:
        """Query LightRAG and wrap the returned references as subgraph nodes.

        Args:
            query: Free-text query; an empty query yields an empty subgraph.
            max_hops: Accepted for interface compatibility; not used here.
            max_nodes: Maximum number of reference nodes; values below 1 are
                treated as 1 (the same clamp ``GraphRAGAdapter`` applies).

        Returns:
            A subgraph with reference nodes and no edges. ``relevance_score``
            is 1.0 when LightRAG returned non-empty context text, else 0.0.
        """
        if not query:
            return KnowledgeSubgraph(nodes=[], edges=[], relevance_score=0.0)

        param = self._build_query_param()
        response = self._run_query(query, param)
        context, references = self._extract_context_and_refs(response)
        nodes = self._references_to_nodes(references, max_nodes=max(max_nodes, 1))
        # NOTE(review): the context text only drives this binary score; it is
        # not kept in the subgraph, so generate_context() can list references
        # only — confirm this is intended.
        relevance_score = 1.0 if context else 0.0
        return KnowledgeSubgraph(nodes=nodes, edges=[], relevance_score=relevance_score)

    def generate_context(self, subgraph: KnowledgeSubgraph) -> str:
        """Render the subgraph's reference nodes as a ``References:`` bullet list.

        A node's ``id`` attribute is appended in parentheses when it is set
        and differs from the node name. Returns "" when there are no nodes.
        """
        if not subgraph.nodes:
            return ""
        lines = ["References:"]
        for node in subgraph.nodes:
            label = node.name
            if node.attributes:
                ref_id = node.attributes.get("id")
                if ref_id and ref_id != node.name:
                    label = f"{node.name} ({ref_id})"
            lines.append(f"- {label}")
        return "\n".join(lines)

    def _build_query_param(self) -> Any | None:
        """Return the query parameter object passed to the client.

        An explicitly supplied ``query_param`` wins. Otherwise this builds a
        LightRAG ``QueryParam`` that requests context only, plus references
        when the installed ``QueryParam`` accepts ``include_references``.
        Returns ``None`` when LightRAG cannot be imported or the parameter
        cannot be built (best effort; never raises).
        """
        if self._query_param is not None:
            return self._query_param
        try:
            from lightrag import QueryParam
        except Exception:
            return None

        try:
            return QueryParam(
                mode=self._query_mode,
                only_need_context=True,
                include_references=True,
            )
        except TypeError:
            # A QueryParam without an ``include_references`` field rejects the
            # keyword; retry without it rather than dropping the param, which
            # would otherwise reach the client as ``param=None``.
            pass
        except Exception:
            return None

        try:
            return QueryParam(mode=self._query_mode, only_need_context=True)
        except Exception:
            return None

    def _run_query(self, query: str, param: Any | None) -> Any:
        """Run ``query`` synchronously against the client.

        Prefers the client's synchronous ``query``. Falls back to ``aquery``
        via ``asyncio.run`` when no event loop is running in this thread.

        Raises:
            RuntimeError: If only ``aquery`` exists and an event loop is
                already running (``asyncio.run`` cannot be nested), or if the
                client has neither method.
        """
        if hasattr(self._client, "query"):
            return self._client.query(query, param=param)
        if hasattr(self._client, "aquery"):
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop running in this thread: drive the coroutine here.
                return asyncio.run(self._client.aquery(query, param=param))
            # get_running_loop() only returns a loop that is running.
            raise RuntimeError("LightRAG aquery requires async context")
        raise RuntimeError("LightRAG client must provide query or aquery")

    @staticmethod
    def _extract_context_and_refs(response: Any) -> tuple[str, list[Any]]:
        """Split a LightRAG response into ``(context_text, references)``.

        A string response is treated as context with no references. A dict
        response is read from ``context``/``response``/``answer`` and
        ``references``/``refs``. Anything else yields ``("", [])``.
        """
        if isinstance(response, str):
            return response, []
        if not isinstance(response, dict):
            return "", []
        context = response.get("context") or response.get("response") or response.get("answer")
        references = response.get("references") or response.get("refs") or []
        if isinstance(references, (str, dict)):
            # A lone reference: list() would split a string into characters
            # and reduce a dict to its keys.
            references = [references]
        return str(context or ""), list(references)

    @staticmethod
    def _references_to_nodes(references: list[Any], *, max_nodes: int) -> list[EntityNode]:
        """Convert at most ``max_nodes`` references into ``reference`` nodes.

        Dict references take their identifier from ``id``, ``doc_id`` or
        ``source_id``. Their name comes from ``title``, falling back to the
        identifier and then ``ref-<position>``. All keys except ``title`` are
        copied into the node attributes. Other references are stringified.
        """
        nodes: list[EntityNode] = []
        for idx, ref in enumerate(references, start=1):
            if len(nodes) >= max_nodes:
                break
            if not isinstance(ref, dict):
                text = str(ref)
                nodes.append(EntityNode(entity_id=text, name=text, entity_type="reference"))
                continue
            ref_id = ref.get("id") or ref.get("doc_id") or ref.get("source_id")
            name = str(ref.get("title") or ref_id or f"ref-{idx}")
            extra = {key: value for key, value in ref.items() if key != "title"}
            nodes.append(
                EntityNode(
                    entity_id=str(ref_id or name),
                    name=name,
                    entity_type="reference",
                    attributes={"id": ref_id, **extra},
                )
            )
        return nodes
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _dedupe_values(values: Iterable[str]) -> list[str]:
|
|
316
|
+
seen: set[str] = set()
|
|
317
|
+
deduped: list[str] = []
|
|
318
|
+
for value in values:
|
|
319
|
+
if value in seen:
|
|
320
|
+
continue
|
|
321
|
+
seen.add(value)
|
|
322
|
+
deduped.append(value)
|
|
323
|
+
return deduped
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
__all__ = ["GraphRAGAdapter", "LightRAGGraphAdapter"]
|