spindle-eval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spindle_eval/__init__.py +3 -0
- spindle_eval/baselines/__init__.py +5 -0
- spindle_eval/baselines/base.py +21 -0
- spindle_eval/baselines/bm25_baseline.py +42 -0
- spindle_eval/baselines/hybrid_search.py +69 -0
- spindle_eval/baselines/naive_rag.py +36 -0
- spindle_eval/baselines/no_rag.py +22 -0
- spindle_eval/baselines/oracle.py +24 -0
- spindle_eval/ci/__init__.py +1 -0
- spindle_eval/ci/regression.py +76 -0
- spindle_eval/ci/reporter.py +40 -0
- spindle_eval/compat.py +153 -0
- spindle_eval/conf/config.yaml +28 -0
- spindle_eval/conf/evaluation/full.yaml +8 -0
- spindle_eval/conf/evaluation/quick.yaml +6 -0
- spindle_eval/conf/extraction/finetuned.yaml +4 -0
- spindle_eval/conf/extraction/llm.yaml +4 -0
- spindle_eval/conf/extraction/nlp.yaml +4 -0
- spindle_eval/conf/generation/claude.yaml +3 -0
- spindle_eval/conf/generation/gemini.yaml +3 -0
- spindle_eval/conf/generation/gpt4.yaml +3 -0
- spindle_eval/conf/ontology/hybrid.yaml +12 -0
- spindle_eval/conf/ontology/schema_first.yaml +12 -0
- spindle_eval/conf/ontology/schema_free.yaml +4 -0
- spindle_eval/conf/preprocessing/default.yaml +4 -0
- spindle_eval/conf/preprocessing/large_chunks.yaml +4 -0
- spindle_eval/conf/preprocessing/small_chunks.yaml +4 -0
- spindle_eval/conf/retrieval/drift.yaml +4 -0
- spindle_eval/conf/retrieval/global.yaml +4 -0
- spindle_eval/conf/retrieval/hybrid.yaml +4 -0
- spindle_eval/conf/retrieval/local.yaml +4 -0
- spindle_eval/conf/sweep/chunk_size.yaml +11 -0
- spindle_eval/conf/sweep/er_threshold.yaml +10 -0
- spindle_eval/conf/sweep/none.yaml +1 -0
- spindle_eval/conf/sweep/retrieval.yaml +12 -0
- spindle_eval/datasets/__init__.py +15 -0
- spindle_eval/datasets/generator.py +82 -0
- spindle_eval/datasets/golden.py +151 -0
- spindle_eval/datasets/kos_reference.py +32 -0
- spindle_eval/datasets/versioning.py +37 -0
- spindle_eval/events/__init__.py +21 -0
- spindle_eval/events/analysis.py +117 -0
- spindle_eval/events/store.py +118 -0
- spindle_eval/golden_data/gold_kg/annotation_guidelines.md +30 -0
- spindle_eval/golden_data/questions.jsonl +3 -0
- spindle_eval/metrics/__init__.py +8 -0
- spindle_eval/metrics/chunk_metrics.py +30 -0
- spindle_eval/metrics/extraction_metrics.py +101 -0
- spindle_eval/metrics/graph_metrics.py +218 -0
- spindle_eval/metrics/kos_loader.py +42 -0
- spindle_eval/metrics/kos_metrics.py +367 -0
- spindle_eval/metrics/provenance_metrics.py +14 -0
- spindle_eval/metrics/ragas_scorers.py +49 -0
- spindle_eval/metrics/statistical.py +147 -0
- spindle_eval/mocks.py +227 -0
- spindle_eval/pipeline.py +120 -0
- spindle_eval/production/__init__.py +1 -0
- spindle_eval/production/feedback_loop.py +53 -0
- spindle_eval/production/staleness.py +39 -0
- spindle_eval/protocols.py +183 -0
- spindle_eval/runner.py +333 -0
- spindle_eval/tracking/__init__.py +39 -0
- spindle_eval/tracking/composite_tracker.py +53 -0
- spindle_eval/tracking/file_tracker.py +95 -0
- spindle_eval/tracking/langfuse_integration.py +39 -0
- spindle_eval/tracking/mlflow_tracker.py +106 -0
- spindle_eval/tracking/noop_tracker.py +44 -0
- spindle_eval-0.1.0.dist-info/METADATA +262 -0
- spindle_eval-0.1.0.dist-info/RECORD +73 -0
- spindle_eval-0.1.0.dist-info/WHEEL +5 -0
- spindle_eval-0.1.0.dist-info/entry_points.txt +2 -0
- spindle_eval-0.1.0.dist-info/licenses/LICENSE +21 -0
- spindle_eval-0.1.0.dist-info/top_level.txt +1 -0
spindle_eval/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Baseline protocol and shared result structures."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Protocol
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class BaselineResult:
    """Outcome of one baseline run.

    Bundles the generated answer with the context passages that accompanied
    it and a free-form metadata mapping.
    """

    # Text produced by the baseline for the question.
    answer: str
    # Context passages associated with the answer.
    contexts: list[str]
    # Extra per-run details; every instance receives its own fresh dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BaselineRunner(Protocol):
    """Structural interface shared by the baseline systems.

    An implementation exposes a ``name`` attribute and a ``run`` method that
    answers a single question and returns a ``BaselineResult``.
    """

    # Identifier of the baseline.
    name: str

    def run(self, question: str, **kwargs: Any) -> BaselineResult:
        """Answer ``question``; extra keyword arguments are implementation-defined."""
        ...
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""BM25 lexical retrieval baseline."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
from rank_bm25 import BM25Okapi
|
|
8
|
+
|
|
9
|
+
from spindle_eval.baselines.base import BaselineResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BM25Baseline:
    """Lexical baseline: rank the corpus with BM25, then ask the LLM.

    Documents and questions are tokenized by lower-casing and splitting on
    whitespace. The highest-scoring documents are joined into the prompt as
    context for ``llm_callable``.
    """

    name = "bm25"

    def __init__(
        self,
        corpus: list[str],
        llm_callable: Callable[[str], str],
        top_k: int = 5,
    ) -> None:
        """Index ``corpus`` for BM25 scoring.

        Args:
            corpus: Documents to search over; may be empty.
            llm_callable: Function mapping a prompt string to an answer string.
            top_k: Default number of documents passed to the model.
        """
        # Snapshot the corpus so later mutation by the caller cannot drift
        # out of sync with the scores produced by the index.
        self._corpus = list(corpus)
        self._tokenized = [doc.lower().split() for doc in self._corpus]
        # BM25Okapi divides by the number of documents while computing the
        # average document length, so it cannot be built from an empty corpus.
        self._bm25 = BM25Okapi(self._tokenized) if self._tokenized else None
        self._llm_callable = llm_callable
        self._top_k = top_k

    def run(self, question: str, **kwargs: Any) -> BaselineResult:
        """Retrieve the best-matching documents and generate an answer.

        Args:
            question: Natural-language question to answer.
            **kwargs: ``top_k`` overrides the default number of contexts.

        Returns:
            A ``BaselineResult`` with the answer, the selected contexts and
            metadata recording the baseline name and the effective ``top_k``.
        """
        # A negative value would make the slice below drop documents from the
        # end of the ranking instead of limiting it.
        top_k = max(0, int(kwargs.get("top_k", self._top_k)))

        if self._bm25 is None:
            # Empty corpus: nothing to rank.
            contexts: list[str] = []
        else:
            scores = self._bm25.get_scores(question.lower().split())
            ranked = sorted(
                zip(self._corpus, scores, strict=False),
                key=lambda item: item[1],
                reverse=True,
            )
            contexts = [doc for doc, _ in ranked[:top_k]]

        prompt = f"Question: {question}\n\nContext:\n" + "\n".join(contexts)
        answer = self._llm_callable(prompt)
        return BaselineResult(
            answer=answer,
            contexts=contexts,
            metadata={"baseline": self.name, "top_k": top_k},
        )
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Hybrid baseline: weighted fusion of BM25 and dense retrieval."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from typing import Any, Callable, Protocol
|
|
7
|
+
|
|
8
|
+
from rank_bm25 import BM25Okapi
|
|
9
|
+
|
|
10
|
+
from spindle_eval.baselines.base import BaselineResult
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DenseRetriever(Protocol):
    """Structural type for retrievers that return scored documents."""

    def retrieve_with_scores(self, query: str, top_k: int) -> list[tuple[str, float]]:
        """Return ``(document, score)`` pairs for ``query``, limited by ``top_k``."""
        ...
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class HybridSearchBaseline:
    """Baseline that fuses BM25 and dense retrieval scores with fixed weights.

    Each retriever contributes a candidate pool. A document's fused score is
    the weighted sum of the scores it received, and the best ``top_k``
    documents become the prompt context.
    """

    name = "hybrid_search"

    def __init__(
        self,
        corpus: list[str],
        dense_retriever: DenseRetriever,
        llm_callable: Callable[[str], str],
        bm25_weight: float = 0.70,
        dense_weight: float = 0.30,
        top_k: int = 5,
    ) -> None:
        """Index ``corpus`` for BM25 and store the fusion settings.

        Args:
            corpus: Documents scored lexically; may be empty.
            dense_retriever: Object providing ``retrieve_with_scores``.
            llm_callable: Function mapping a prompt string to an answer string.
            bm25_weight: Multiplier applied to BM25 scores.
            dense_weight: Multiplier applied to dense scores.
            top_k: Default number of documents passed to the model.
        """
        # Snapshot the corpus so later mutation by the caller cannot drift
        # out of sync with the scores produced by the index.
        self._corpus = list(corpus)
        tokenized = [doc.lower().split() for doc in self._corpus]
        # BM25Okapi divides by the number of documents while computing the
        # average document length, so it cannot be built from an empty corpus.
        self._bm25 = BM25Okapi(tokenized) if tokenized else None
        self._dense_retriever = dense_retriever
        self._llm_callable = llm_callable
        self._bm25_weight = bm25_weight
        self._dense_weight = dense_weight
        self._top_k = top_k

    def run(self, question: str, **kwargs: Any) -> BaselineResult:
        """Retrieve with both systems, fuse the scores and generate an answer.

        Args:
            question: Natural-language question to answer.
            **kwargs: ``top_k`` overrides the default number of contexts.

        Returns:
            A ``BaselineResult`` with the answer, the fused contexts and
            metadata recording the baseline name, ``top_k`` and both weights.
        """
        # A negative value would make the final slice drop documents from the
        # end of the ranking instead of limiting it.
        top_k = max(0, int(kwargs.get("top_k", self._top_k)))
        # Both retrievers over-fetch the same number of candidates so fusion
        # has more than ``top_k`` documents to choose from.
        pool_size = max(top_k * 2, 10)

        if self._bm25 is None:
            # Empty corpus: only the dense retriever can contribute.
            bm25_ranked: list[tuple[str, float]] = []
        else:
            bm25_scores = self._bm25.get_scores(question.lower().split())
            bm25_ranked = sorted(
                zip(self._corpus, bm25_scores, strict=False),
                key=lambda item: item[1],
                reverse=True,
            )[:pool_size]

        dense_ranked = self._dense_retriever.retrieve_with_scores(
            question,
            top_k=pool_size,
        )

        # NOTE(review): the raw scores are summed without normalization, which
        # assumes both retrievers score on comparable scales — confirm.
        fused: dict[str, float] = defaultdict(float)
        for doc, score in bm25_ranked:
            fused[doc] += self._bm25_weight * float(score)
        for doc, score in dense_ranked:
            fused[doc] += self._dense_weight * float(score)

        best = sorted(fused.items(), key=lambda x: x[1], reverse=True)[:top_k]
        contexts = [doc for doc, _ in best]
        prompt = f"Question: {question}\n\nContext:\n" + "\n".join(contexts)
        answer = self._llm_callable(prompt)
        return BaselineResult(
            answer=answer,
            contexts=contexts,
            metadata={
                "baseline": self.name,
                "top_k": top_k,
                "bm25_weight": self._bm25_weight,
                "dense_weight": self._dense_weight,
            },
        )
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Naive vector top-k RAG baseline."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable, Protocol
|
|
6
|
+
|
|
7
|
+
from spindle_eval.baselines.base import BaselineResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class VectorRetriever(Protocol):
    """Structural type for retrievers that return plain document strings."""

    def retrieve(self, query: str, top_k: int) -> list[str]:
        """Return documents for ``query``, limited by ``top_k``."""
        ...
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NaiveRAGBaseline:
    """Baseline that feeds the retriever's top-k results straight to the LLM."""

    name = "naive_rag"

    def __init__(
        self,
        retriever: VectorRetriever,
        llm_callable: Callable[[str], str],
        top_k: int = 5,
    ) -> None:
        """Store the retriever, the answer function and the default ``top_k``."""
        self._top_k = top_k
        self._llm_callable = llm_callable
        self._retriever = retriever

    def run(self, question: str, **kwargs: Any) -> BaselineResult:
        """Retrieve contexts for ``question`` and generate an answer from them.

        A ``top_k`` keyword argument overrides the value given at construction.
        """
        limit = int(kwargs.get("top_k", self._top_k))
        passages = self._retriever.retrieve(question, top_k=limit)
        header = f"Question: {question}\n\nContext:\n"
        reply = self._llm_callable(header + "\n".join(passages))
        details = {"baseline": self.name, "top_k": limit}
        return BaselineResult(answer=reply, contexts=passages, metadata=details)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""No-RAG baseline: answer from model parametric memory only."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
from spindle_eval.baselines.base import BaselineResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NoRAGBaseline:
    """Baseline that sends the bare question to the LLM with no context."""

    name = "no_rag"

    def __init__(self, llm_callable: Callable[[str], str]) -> None:
        """Store the function used to answer questions."""
        self._llm_callable = llm_callable

    def run(self, question: str, **kwargs: Any) -> BaselineResult:
        """Answer ``question`` directly; keyword arguments are ignored."""
        reply = self._llm_callable(question)
        details = {"baseline": self.name}
        return BaselineResult(answer=reply, contexts=[], metadata=details)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Oracle baseline: answer generation with ground-truth contexts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
|
|
7
|
+
from spindle_eval.baselines.base import BaselineResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OracleBaseline:
    """Baseline that answers using caller-supplied ground-truth contexts."""

    name = "oracle"

    def __init__(self, llm_callable: Callable[[str], str]) -> None:
        """Store the function used to answer questions."""
        self._llm_callable = llm_callable

    def run(self, question: str, **kwargs: Any) -> BaselineResult:
        """Generate an answer from the contexts passed in ``kwargs``.

        Args:
            question: Natural-language question to answer.
            **kwargs: ``contexts`` holds the ground-truth passages. A missing
                or ``None`` value means no context; a single string is
                treated as one passage.

        Returns:
            A ``BaselineResult`` with the answer, the contexts used and
            metadata recording the baseline name.
        """
        supplied = kwargs.get("contexts")
        if supplied is None:
            # An explicit ``contexts=None`` would make ``list(None)`` raise.
            contexts: list[str] = []
        elif isinstance(supplied, str):
            # ``list("text")`` would split one passage into single characters.
            contexts = [supplied]
        else:
            contexts = list(supplied)

        prompt = f"Question: {question}\n\nGround truth context:\n" + "\n".join(contexts)
        answer = self._llm_callable(prompt)
        return BaselineResult(
            answer=answer,
            contexts=contexts,
            metadata={"baseline": self.name},
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""CI utilities for regression detection and reporting."""
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Regression detection with tolerance thresholds and confidence intervals."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from spindle_eval.metrics.statistical import bootstrap_ci
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Per-metric lower bounds on (current - baseline); a delta below its bound is
# reported as a violation by compare_against_baseline.
DEFAULT_TOLERANCES = dict(
    faithfulness=-0.02,
    answer_relevancy=-0.03,
    context_precision=-0.05,
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class RegressionResult:
    """Outcome of comparing current metrics against a baseline."""

    # True when no metric violated its tolerance.
    passed: bool
    # current - baseline for every metric present in both inputs.
    deltas: dict[str, float]
    # Subset of ``deltas`` that fell below the configured tolerance.
    violations: dict[str, float]
    # Per-metric bootstrap interval with "lower", "upper" and "point_estimate".
    confidence_intervals: dict[str, dict[str, float]]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_baseline_metrics(path: str | Path) -> dict[str, float]:
    """Read baseline metrics from a UTF-8 encoded JSON object file.

    Keys that start with an underscore are skipped; every remaining value is
    converted with ``float()``.
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    metrics: dict[str, float] = {}
    for key, value in json.loads(raw_text).items():
        if key.startswith("_"):
            continue
        metrics[key] = float(value)
    return metrics
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def compare_against_baseline(
    current_metrics: dict[str, float],
    baseline_metrics: dict[str, float],
    tolerances: dict[str, float] | None = None,
    per_query_deltas: dict[str, list[float]] | None = None,
) -> RegressionResult:
    """Compare metrics with tolerance checks and optional bootstrap intervals.

    Args:
        current_metrics: Metric values from the run under test.
        baseline_metrics: Reference values; metrics absent from
            ``current_metrics`` are skipped.
        tolerances: Per-metric lower bounds on the delta, merged over
            ``DEFAULT_TOLERANCES``.
        per_query_deltas: Optional per-metric samples passed to
            ``bootstrap_ci`` at a 0.95 confidence level.

    Returns:
        A ``RegressionResult`` whose ``passed`` flag is True only when no
        metric violated its tolerance.
    """
    import math

    limits = {**DEFAULT_TOLERANCES, **(tolerances or {})}
    deltas: dict[str, float] = {}
    violations: dict[str, float] = {}
    intervals: dict[str, dict[str, float]] = {}

    for metric, baseline in baseline_metrics.items():
        if metric not in current_metrics:
            continue
        delta = float(current_metrics[metric] - baseline)
        deltas[metric] = delta

        if metric in limits:
            # Every ordering comparison with NaN is False, so a NaN delta
            # would slip through ``delta < limit`` and pass the gate silently.
            if math.isnan(delta) or delta < limits[metric]:
                violations[metric] = delta

        if per_query_deltas and metric in per_query_deltas:
            ci = bootstrap_ci(per_query_deltas[metric], confidence_level=0.95)
            intervals[metric] = {
                "lower": ci.lower,
                "upper": ci.upper,
                "point_estimate": ci.point_estimate,
            }

    return RegressionResult(
        passed=len(violations) == 0,
        deltas=deltas,
        violations=violations,
        confidence_intervals=intervals,
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def to_dict(result: RegressionResult) -> dict[str, Any]:
    """Expose the fields of ``result`` as a plain dictionary.

    The values are the same objects held by ``result``, not copies.
    """
    exported = ("passed", "deltas", "violations", "confidence_intervals")
    return {field_name: getattr(result, field_name) for field_name in exported}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""PR comment/report generation for evaluation results."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from spindle_eval.ci.regression import RegressionResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _fmt(value: float) -> str:
    """Format ``value`` with an explicit sign and four decimal places."""
    return format(value, "+.4f")


def render_regression_report(result: RegressionResult) -> str:
    """Build a markdown regression report for use in a GitHub PR comment.

    The report always holds a status heading and a per-metric delta table.
    A bootstrap interval table and a violation list are appended only when
    ``result`` carries the corresponding data.
    """
    verdict = "PASS" if result.passed else "FAIL"
    report: list[str] = [
        f"## spindle-eval Regression Check: {verdict}",
        "",
        "| Metric | Delta vs Baseline | Status |",
        "|---|---:|---|",
    ]
    for metric, delta in sorted(result.deltas.items()):
        label = "OK" if metric not in result.violations else "REGRESSION"
        report.append(f"| `{metric}` | `{_fmt(delta)}` | {label} |")

    if result.confidence_intervals:
        report.extend(
            [
                "",
                "### Bootstrap 95% CI",
                "| Metric | Lower | Upper | Point Estimate |",
                "|---|---:|---:|---:|",
            ]
        )
        for metric, ci in sorted(result.confidence_intervals.items()):
            lower, upper, point = ci["lower"], ci["upper"], ci["point_estimate"]
            report.append(
                f"| `{metric}` | `{lower:.4f}` | `{upper:.4f}` | `{point:.4f}` |"
            )

    if result.violations:
        report.extend(["", "### Violations"])
        for metric, delta in sorted(result.violations.items()):
            report.append(f"- `{metric}` dropped by `{_fmt(delta)}`")

    return "\n".join(report)
|
spindle_eval/compat.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
"""Compatibility layer for wrapping legacy component dicts as Stage pipelines.
|
|
2
|
+
|
|
3
|
+
Bridges the old 6-component dict interface (preprocessor, ontology_extractor,
|
|
4
|
+
triple_extractor, graph_store, retriever, generator) to the new StageDef/Stage
|
|
5
|
+
pipeline model.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from spindle_eval.protocols import StageDef, StageResult
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class WrappedCallableStage:
    """Adapter exposing a legacy component through the Stage ``run`` interface.

    The stage name selects the calling convention used to invoke the wrapped
    component; names without a dedicated convention fall back to calling the
    component with ``(inputs, cfg)``.
    """

    def __init__(self, name: str, component: Any, **extras: Any) -> None:
        """Remember the stage name, the wrapped component and any extras."""
        self.name = name
        self._component = component
        self._extras = extras

    def run(self, inputs: dict[str, Any], cfg: Any) -> StageResult:
        """Invoke the component using the convention that matches ``name``."""
        if self.name == "preprocessing":
            # The preprocessor is the only convention that ignores ``inputs``.
            return self._run_preprocessor(cfg)

        conventions = (
            ("ontology_extraction", self._run_ontology),
            ("triple_extraction", self._run_triple_extraction),
            ("retrieval", self._run_retrieval),
            ("generation", self._run_generation),
        )
        for stage_name, handler in conventions:
            if self.name == stage_name:
                return handler(inputs, cfg)

        # Generic fallback: call component with (inputs, cfg)
        return StageResult(outputs={"result": self._component(inputs, cfg)})

    def _run_preprocessor(self, cfg: Any) -> StageResult:
        """Call the component with ``cfg`` and report how many chunks it made."""
        produced = self._component(cfg)
        chunk_total = float(len(produced))
        return StageResult(
            outputs={"chunks": produced},
            metrics={"num_chunks": chunk_total},
        )

    def _run_ontology(self, inputs: dict[str, Any], cfg: Any) -> StageResult:
        """Derive an ontology from ``inputs["chunks"]`` and count its types."""
        ontology = self._component(inputs["chunks"], cfg)
        counts = {
            "entity_type_count": float(len(ontology.entity_types)),
            "relation_type_count": float(len(ontology.relation_types)),
        }
        return StageResult(outputs={"ontology": ontology}, metrics=counts)

    def _run_triple_extraction(
        self, inputs: dict[str, Any], cfg: Any
    ) -> StageResult:
        """Extract triples and ingest them into the graph store, if one was given."""
        source_chunks = inputs["chunks"]
        schema = inputs["ontology"]
        triples = self._component(source_chunks, schema, cfg)
        store = self._extras.get("graph_store")
        if store is not None:
            store.ingest(triples)
        triple_total = float(len(triples))
        return StageResult(
            outputs={"triples": triples},
            metrics={"triple_count": triple_total},
        )

    def _run_retrieval(self, inputs: dict[str, Any], cfg: Any) -> StageResult:
        """Fetch contexts via the component itself or its ``retrieve`` method."""
        retriever = self._component
        question = inputs.get("query", "default query")
        if callable(retriever):
            fetch = retriever
        elif hasattr(retriever, "retrieve"):
            fetch = retriever.retrieve
        else:
            raise TypeError(
                "Retriever must be callable or expose a retrieve() method."
            )
        found = fetch(question, cfg)
        return StageResult(
            outputs={"contexts": found},
            metrics={"num_contexts": float(len(found))},
        )

    def _run_generation(self, inputs: dict[str, Any], cfg: Any) -> StageResult:
        """Produce an answer via the component itself or its ``generate`` method."""
        generator = self._component
        passages = inputs["contexts"]
        question = inputs.get("query", "default query")
        if callable(generator):
            produce = generator
        elif hasattr(generator, "generate"):
            produce = generator.generate
        else:
            raise TypeError(
                "Generator must be callable or expose a generate() method."
            )
        reply = produce(question, passages, cfg)
        return StageResult(
            outputs={"answer": reply},
            metrics={"answer_length": float(len(reply))},
        )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def from_component_dict(
    components: dict[str, Any],
    *,
    tracker: Any = None,
) -> list[StageDef]:
    """Build the five-stage pipeline definition from a legacy component dict.

    Required keys: preprocessor, ontology_extractor, triple_extractor,
    retriever, generator. The optional ``graph_store`` entry is handed to the
    triple extraction stage. ``tracker`` is accepted but not used in this
    function body.
    """
    graph_store = components.get("graph_store")

    preprocessing = StageDef(
        name="preprocessing",
        stage=WrappedCallableStage("preprocessing", components["preprocessor"]),
    )
    ontology = StageDef(
        name="ontology_extraction",
        stage=WrappedCallableStage(
            "ontology_extraction", components["ontology_extractor"]
        ),
        input_keys={"chunks": "preprocessing.chunks"},
    )
    triples = StageDef(
        name="triple_extraction",
        stage=WrappedCallableStage(
            "triple_extraction",
            components["triple_extractor"],
            graph_store=graph_store,
        ),
        input_keys={
            "chunks": "preprocessing.chunks",
            "ontology": "ontology_extraction.ontology",
        },
    )
    retrieval = StageDef(
        name="retrieval",
        stage=WrappedCallableStage("retrieval", components["retriever"]),
    )
    generation = StageDef(
        name="generation",
        stage=WrappedCallableStage("generation", components["generator"]),
        input_keys={"contexts": "retrieval.contexts"},
    )
    return [preprocessing, ontology, triples, retrieval, generation]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
defaults:
|
|
2
|
+
- preprocessing: default
|
|
3
|
+
- ontology: hybrid
|
|
4
|
+
- extraction: llm
|
|
5
|
+
- retrieval: hybrid
|
|
6
|
+
- generation: gpt4
|
|
7
|
+
- evaluation: quick
|
|
8
|
+
- sweep: none
|
|
9
|
+
- _self_
|
|
10
|
+
|
|
11
|
+
run_name: spindle-eval-default
|
|
12
|
+
seed: 7
|
|
13
|
+
dataset_path: golden_data/questions.jsonl
|
|
14
|
+
oracle:
|
|
15
|
+
enabled: false
|
|
16
|
+
stage: ""
|
|
17
|
+
|
|
18
|
+
tracking:
|
|
19
|
+
backend: mlflow
|
|
20
|
+
mlflow_tracking_uri: ${oc.env:MLFLOW_TRACKING_URI,http://mlflow.internal:5000}
|
|
21
|
+
experiment_name: ${oc.env:MLFLOW_EXPERIMENT_NAME,graph-rag-sweeps}
|
|
22
|
+
enable_langfuse: false
|
|
23
|
+
langfuse_endpoint: ${oc.env:LANGFUSE_OTEL_ENDPOINT,http://langfuse.internal:4318/api/public/otel}
|
|
24
|
+
|
|
25
|
+
runner:
|
|
26
|
+
use_spindle_when_available: true
|
|
27
|
+
allow_mock_fallback: true
|
|
28
|
+
artifact_output_dir: ${oc.env:SPINDLE_EVAL_ARTIFACT_DIR,artifacts}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
enabled: true
|
|
2
|
+
hydra:
|
|
3
|
+
sweeper:
|
|
4
|
+
direction: maximize
|
|
5
|
+
n_trials: 50
|
|
6
|
+
sampler:
|
|
7
|
+
_target_: optuna.samplers.TPESampler
|
|
8
|
+
seed: 7
|
|
9
|
+
params:
|
|
10
|
+
preprocessing.chunk_size: choice(128,256,384,512,640,768,896,1024,1280,1536,1792,2048)
|
|
11
|
+
preprocessing.overlap: choice(0.0,0.1,0.2)
|