haiku.rag 0.11.1__tar.gz → 0.11.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of haiku.rag might be problematic.
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/.gitignore +1 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/PKG-INFO +7 -1
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/README.md +6 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/pyproject.toml +5 -1
- haiku_rag-0.11.3/src/evaluations/benchmark.py +320 -0
- haiku_rag-0.11.3/src/evaluations/config.py +46 -0
- haiku_rag-0.11.3/src/evaluations/datasets/__init__.py +8 -0
- haiku_rag-0.11.3/src/evaluations/datasets/repliqa.py +58 -0
- haiku_rag-0.11.3/src/evaluations/datasets/wix.py +81 -0
- {haiku_rag-0.11.1/tests → haiku_rag-0.11.3/src/evaluations}/llm_judge.py +2 -1
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/app.py +36 -2
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/cli.py +11 -1
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/client.py +47 -22
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/config.py +2 -2
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/embeddings/ollama.py +2 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/embeddings/openai.py +2 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/embeddings/vllm.py +2 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/embeddings/voyageai.py +2 -0
- haiku_rag-0.11.3/src/haiku/rag/graph/__init__.py +1 -0
- haiku_rag-0.11.3/src/haiku/rag/graph/base.py +31 -0
- haiku_rag-0.11.3/src/haiku/rag/graph/common.py +33 -0
- haiku_rag-0.11.3/src/haiku/rag/graph/models.py +24 -0
- haiku_rag-0.11.3/src/haiku/rag/graph/nodes/__init__.py +0 -0
- {haiku_rag-0.11.1/src/haiku/rag/research → haiku_rag-0.11.3/src/haiku/rag/graph}/nodes/analysis.py +5 -4
- {haiku_rag-0.11.1/src/haiku/rag/research → haiku_rag-0.11.3/src/haiku/rag/graph}/nodes/plan.py +6 -4
- {haiku_rag-0.11.1/src/haiku/rag/research → haiku_rag-0.11.3/src/haiku/rag/graph}/nodes/search.py +5 -4
- {haiku_rag-0.11.1/src/haiku/rag/research → haiku_rag-0.11.3/src/haiku/rag/graph}/nodes/synthesize.py +3 -4
- haiku_rag-0.11.3/src/haiku/rag/graph/prompts.py +45 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/monitor.py +2 -2
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/__init__.py +1 -0
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/dependencies.py +29 -0
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/graph.py +21 -0
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/models.py +20 -0
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/nodes.py +303 -0
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/prompts.py +57 -0
- haiku_rag-0.11.3/src/haiku/rag/qa/deep/state.py +25 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/reranking/__init__.py +3 -0
- haiku_rag-0.11.3/src/haiku/rag/research/__init__.py +3 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/research/common.py +0 -31
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/research/dependencies.py +1 -1
- haiku_rag-0.11.3/src/haiku/rag/research/graph.py +20 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/research/models.py +0 -25
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/research/prompts.py +0 -46
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/repositories/settings.py +3 -3
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/uv.lock +3 -1
- haiku_rag-0.11.1/.github/FUNDING.yml +0 -3
- haiku_rag-0.11.1/.github/workflows/build-docs.yml +0 -28
- haiku_rag-0.11.1/.github/workflows/build-publish.yml +0 -18
- haiku_rag-0.11.1/docs/agents.md +0 -154
- haiku_rag-0.11.1/docs/benchmarks.md +0 -36
- haiku_rag-0.11.1/docs/cli.md +0 -219
- haiku_rag-0.11.1/docs/configuration.md +0 -267
- haiku_rag-0.11.1/docs/index.md +0 -65
- haiku_rag-0.11.1/docs/installation.md +0 -84
- haiku_rag-0.11.1/docs/mcp.md +0 -30
- haiku_rag-0.11.1/docs/python.md +0 -214
- haiku_rag-0.11.1/docs/server.md +0 -41
- haiku_rag-0.11.1/src/haiku/rag/research/__init__.py +0 -28
- haiku_rag-0.11.1/src/haiku/rag/research/graph.py +0 -31
- haiku_rag-0.11.1/tests/conftest.py +0 -26
- haiku_rag-0.11.1/tests/generate_benchmark_db.py +0 -171
- haiku_rag-0.11.1/tests/test_app.py +0 -248
- haiku_rag-0.11.1/tests/test_chunk.py +0 -195
- haiku_rag-0.11.1/tests/test_chunker.py +0 -39
- haiku_rag-0.11.1/tests/test_cli.py +0 -235
- haiku_rag-0.11.1/tests/test_client.py +0 -796
- haiku_rag-0.11.1/tests/test_document.py +0 -107
- haiku_rag-0.11.1/tests/test_embedder.py +0 -171
- haiku_rag-0.11.1/tests/test_info.py +0 -79
- haiku_rag-0.11.1/tests/test_lancedb_connection.py +0 -86
- haiku_rag-0.11.1/tests/test_monitor.py +0 -93
- haiku_rag-0.11.1/tests/test_preprocessor.py +0 -71
- haiku_rag-0.11.1/tests/test_qa.py +0 -106
- haiku_rag-0.11.1/tests/test_reader.py +0 -23
- haiku_rag-0.11.1/tests/test_rebuild.py +0 -49
- haiku_rag-0.11.1/tests/test_reranker.py +0 -89
- haiku_rag-0.11.1/tests/test_research_graph.py +0 -25
- haiku_rag-0.11.1/tests/test_research_graph_integration.py +0 -138
- haiku_rag-0.11.1/tests/test_search.py +0 -208
- haiku_rag-0.11.1/tests/test_settings.py +0 -84
- haiku_rag-0.11.1/tests/test_utils.py +0 -115
- haiku_rag-0.11.1/tests/test_versioning.py +0 -124
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/.pre-commit-config.yaml +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/.python-version +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/LICENSE +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/mkdocs.yml +0 -0
- {haiku_rag-0.11.1/src/haiku/rag → haiku_rag-0.11.3/src/evaluations}/__init__.py +0 -0
- {haiku_rag-0.11.1/tests → haiku_rag-0.11.3/src/haiku/rag}/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/chunker.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/embeddings/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/embeddings/base.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/logging.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/mcp.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/migration.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/qa/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/qa/agent.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/qa/prompts.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/reader.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/reranking/base.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/reranking/cohere.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/reranking/mxbai.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/reranking/vllm.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/research/state.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/research/stream.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/engine.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/models/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/models/chunk.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/models/document.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/repositories/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/repositories/chunk.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/repositories/document.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/upgrades/__init__.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/upgrades/v0_10_1.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/store/upgrades/v0_9_3.py +0 -0
- {haiku_rag-0.11.1 → haiku_rag-0.11.3}/src/haiku/rag/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: haiku.rag
-Version: 0.11.1
+Version: 0.11.3
 Summary: Agentic Retrieval Augmented Generation (RAG) with LanceDB
 Author-email: Yiorgis Gozadinos <ggozadinos@gmail.com>
 License: MIT

PKG-INFO
@@ -78,6 +78,12 @@ haiku-rag ask "Who is the author of haiku.rag?"
 # Ask questions with citations
 haiku-rag ask "Who is the author of haiku.rag?" --cite

+# Deep QA (multi-agent question decomposition)
+haiku-rag ask "Who is the author of haiku.rag?" --deep --cite
+
+# Deep QA with verbose output
+haiku-rag ask "Who is the author of haiku.rag?" --deep --verbose
+
 # Multi‑agent research (iterative plan/search/evaluate)
 haiku-rag research \
   "What are the main drivers and trends of global temperature anomalies since 1990?" \

README.md
@@ -40,6 +40,12 @@ haiku-rag ask "Who is the author of haiku.rag?"
 # Ask questions with citations
 haiku-rag ask "Who is the author of haiku.rag?" --cite

+# Deep QA (multi-agent question decomposition)
+haiku-rag ask "Who is the author of haiku.rag?" --deep --cite
+
+# Deep QA with verbose output
+haiku-rag ask "Who is the author of haiku.rag?" --deep --verbose
+
 # Multi‑agent research (iterative plan/search/evaluate)
 haiku-rag research \
   "What are the main drivers and trends of global temperature anomalies since 1990?" \

pyproject.toml
@@ -2,7 +2,7 @@

 name = "haiku.rag"
 description = "Agentic Retrieval Augmented Generation (RAG) with LanceDB"
-version = "0.11.1"
+version = "0.11.3"
 authors = [{ name = "Yiorgis Gozadinos", email = "ggozadinos@gmail.com" }]
 license = { text = "MIT" }
 readme = { file = "README.md", content-type = "text/markdown" }

pyproject.toml
@@ -48,6 +48,9 @@ haiku-rag = "haiku.rag.cli:cli"
 requires = ["hatchling"]
 build-backend = "hatchling.build"

+[tool.hatch.build]
+exclude = ["/docs", "/tests", "/.github"]
+
 [tool.hatch.build.targets.wheel]
 packages = ["src/haiku"]

pyproject.toml
@@ -57,6 +60,7 @@ dev = [
     "logfire>=4.7.0",
     "mkdocs>=1.6.1",
     "mkdocs-material>=9.6.14",
+    "pydantic-evals>=1.0.8",
     "pre-commit>=4.2.0",
     "pyright>=1.1.405",
     "pytest>=8.4.2",

src/evaluations/benchmark.py (new file)
@@ -0,0 +1,320 @@
+import asyncio
+from collections.abc import Mapping
+from typing import Any, cast
+
+import logfire
+import typer
+from pydantic_ai.models.openai import OpenAIChatModel
+from pydantic_ai.providers.ollama import OllamaProvider
+from pydantic_evals import Dataset as EvalDataset
+from pydantic_evals.evaluators import IsInstance, LLMJudge
+from pydantic_evals.reporting import ReportCaseFailure
+from rich.console import Console
+from rich.progress import Progress
+
+from evaluations.config import DatasetSpec, RetrievalSample
+from evaluations.datasets import DATASETS
+from evaluations.llm_judge import ANSWER_EQUIVALENCE_RUBRIC
+from haiku.rag import logging  # noqa: F401
+from haiku.rag.client import HaikuRAG
+from haiku.rag.config import Config
+from haiku.rag.logging import configure_cli_logging
+from haiku.rag.qa import get_qa_agent
+
+QA_JUDGE_MODEL = "qwen3"
+
+logfire.configure(send_to_logfire="if-token-present", service_name="evals")
+logfire.instrument_pydantic_ai()
+configure_cli_logging()
+console = Console()
+
+
+async def populate_db(spec: DatasetSpec) -> None:
+    spec.db_path.parent.mkdir(parents=True, exist_ok=True)
+    corpus = spec.document_loader()
+    if spec.document_limit is not None:
+        corpus = corpus.select(range(min(spec.document_limit, len(corpus))))
+
+    with Progress() as progress:
+        task = progress.add_task("[green]Populating database...", total=len(corpus))
+        async with HaikuRAG(spec.db_path) as rag:
+            for doc in corpus:
+                doc_mapping = cast(Mapping[str, Any], doc)
+                payload = spec.document_mapper(doc_mapping)
+                if payload is None:
+                    progress.advance(task)
+                    continue
+
+                existing = await rag.get_document_by_uri(payload.uri)
+                if existing is not None:
+                    assert existing.id
+                    chunks = await rag.chunk_repository.get_by_document_id(existing.id)
+                    if chunks:
+                        progress.advance(task)
+                        continue
+                    await rag.document_repository.delete(existing.id)
+
+                await rag.create_document(
+                    content=payload.content,
+                    uri=payload.uri,
+                    title=payload.title,
+                    metadata=payload.metadata,
+                )
+                progress.advance(task)
+            rag.store.vacuum()
+
+
+def _is_relevant_match(retrieved_uri: str | None, sample: RetrievalSample) -> bool:
+    return retrieved_uri is not None and retrieved_uri in sample.expected_uris
+
+
+async def run_retrieval_benchmark(spec: DatasetSpec) -> dict[str, float] | None:
+    if spec.retrieval_loader is None or spec.retrieval_mapper is None:
+        console.print("Skipping retrieval benchmark; no retrieval config.")
+        return None
+
+    corpus = spec.retrieval_loader()
+
+    recall_totals = {
+        1: 0.0,
+        3: 0.0,
+        5: 0.0,
+    }
+    total_queries = 0
+
+    with Progress() as progress:
+        task = progress.add_task(
+            "[blue]Running retrieval benchmark...", total=len(corpus)
+        )
+        async with HaikuRAG(spec.db_path) as rag:
+            for doc in corpus:
+                doc_mapping = cast(Mapping[str, Any], doc)
+                sample = spec.retrieval_mapper(doc_mapping)
+                if sample is None or sample.skip:
+                    progress.advance(task)
+                    continue
+
+                matches = await rag.search(query=sample.question, limit=5)
+                if not matches:
+                    progress.advance(task)
+                    continue
+
+                total_queries += 1
+
+                retrieved_uris: list[str] = []
+                for chunk, _ in matches:
+                    if chunk.document_id is None:
+                        continue
+                    retrieved_doc = await rag.get_document_by_id(chunk.document_id)
+                    if retrieved_doc and retrieved_doc.uri:
+                        retrieved_uris.append(retrieved_doc.uri)
+
+                # Compute per-query recall@K by counting how many relevant
+                # documents are retrieved within the first K results and
+                # averaging these fractions across all queries.
+                for cutoff in (1, 3, 5):
+                    top_k = set(retrieved_uris[:cutoff])
+                    relevant = set(sample.expected_uris)
+                    if relevant:
+                        matched = len(top_k & relevant)
+                        recall_totals[cutoff] += matched / len(relevant)
+
+                progress.advance(task)
+
+    if total_queries == 0:
+        console.print("No retrieval cases to evaluate.")
+        return None
+
+    recall_at_1 = recall_totals[1] / total_queries
+    recall_at_3 = recall_totals[3] / total_queries
+    recall_at_5 = recall_totals[5] / total_queries
+
+    console.print("\n=== Retrieval Benchmark Results ===", style="bold cyan")
+    console.print(f"Total queries: {total_queries}")
+    console.print(f"Recall@1: {recall_at_1:.4f}")
+    console.print(f"Recall@3: {recall_at_3:.4f}")
+    console.print(f"Recall@5: {recall_at_5:.4f}")
+
+    return {
+        "recall@1": recall_at_1,
+        "recall@3": recall_at_3,
+        "recall@5": recall_at_5,
+    }
+
+
+async def run_qa_benchmark(
+    spec: DatasetSpec, qa_limit: int | None = None
+) -> ReportCaseFailure[str, str, dict[str, str]] | None:
+    corpus = spec.qa_loader()
+    if qa_limit is not None:
+        corpus = corpus.select(range(min(qa_limit, len(corpus))))
+
+    cases = [
+        spec.qa_case_builder(index, cast(Mapping[str, Any], doc))
+        for index, doc in enumerate(corpus, start=1)
+    ]
+
+    judge_model = OpenAIChatModel(
+        model_name=QA_JUDGE_MODEL,
+        provider=OllamaProvider(base_url=f"{Config.OLLAMA_BASE_URL}/v1"),
+    )
+
+    evaluation_dataset = EvalDataset[str, str, dict[str, str]](
+        cases=cases,
+        evaluators=[
+            IsInstance(type_name="str"),
+            LLMJudge(
+                rubric=ANSWER_EQUIVALENCE_RUBRIC,
+                include_input=True,
+                include_expected_output=True,
+                model=judge_model,
+                assertion={
+                    "evaluation_name": "answer_equivalent",
+                    "include_reason": True,
+                },
+            ),
+        ],
+    )
+
+    total_processed = 0
+    passing_cases = 0
+    failures: list[ReportCaseFailure[str, str, dict[str, str]]] = []
+
+    with Progress(console=console) as progress:
+        qa_task = progress.add_task(
+            "[yellow]Evaluating QA cases...",
+            total=len(evaluation_dataset.cases),
+        )
+
+        async with HaikuRAG(spec.db_path) as rag:
+            qa = get_qa_agent(rag)
+
+            async def answer_question(question: str) -> str:
+                return await qa.answer(question)
+
+            for case in evaluation_dataset.cases:
+                progress.console.print(f"\n[bold]Evaluating case:[/bold] {case.name}")
+
+                single_case_dataset = EvalDataset[str, str, dict[str, str]](
+                    cases=[case],
+                    evaluators=evaluation_dataset.evaluators,
+                )
+
+                report = await single_case_dataset.evaluate(
+                    answer_question,
+                    name="qa_answer",
+                    max_concurrency=1,
+                    progress=False,
+                )
+
+                total_processed += 1
+
+                if report.cases:
+                    result_case = report.cases[0]
+
+                    equivalence = result_case.assertions.get("answer_equivalent")
+                    progress.console.print(f"Question: {result_case.inputs}")
+                    progress.console.print(f"Expected: {result_case.expected_output}")
+                    progress.console.print(f"Generated: {result_case.output}")
+                    if equivalence is not None:
+                        progress.console.print(
+                            f"Equivalent: {equivalence.value}"
+                            + (f" — {equivalence.reason}" if equivalence.reason else "")
+                        )
+                        if equivalence.value:
+                            passing_cases += 1
+
+                    progress.console.print("")
+
+                if report.failures:
+                    failures.extend(report.failures)
+                    failure = report.failures[0]
+                    progress.console.print(
+                        "[red]Failure encountered during case evaluation:[/red]"
+                    )
+                    progress.console.print(f"Question: {failure.inputs}")
+                    progress.console.print(f"Error: {failure.error_message}")
+                    progress.console.print("")
+
+                progress.console.print(
+                    f"[green]Accuracy: {(passing_cases / total_processed):.4f} "
+                    f"{passing_cases}/{total_processed}[/green]"
+                )
+                progress.advance(qa_task)
+
+    total_cases = total_processed
+    accuracy = passing_cases / total_cases if total_cases > 0 else 0
+
+    console.print("\n=== QA Benchmark Results ===", style="bold cyan")
+    console.print(f"Total questions: {total_cases}")
+    console.print(f"Correct answers: {passing_cases}")
+    console.print(f"QA Accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
+
+    if failures:
+        console.print("[red]\nSummary of failures:[/red]")
+        for failure in failures:
+            console.print(f"Case: {failure.name}")
+            console.print(f"Question: {failure.inputs}")
+            console.print(f"Error: {failure.error_message}")
+            console.print("")
+
+    return failures[0] if failures else None
+
+
+async def evaluate_dataset(
+    spec: DatasetSpec,
+    skip_db: bool,
+    skip_retrieval: bool,
+    skip_qa: bool,
+    qa_limit: int | None,
+) -> None:
+    if not skip_db:
+        console.print(f"Using dataset: {spec.key}", style="bold magenta")
+        await populate_db(spec)
+
+    if not skip_retrieval:
+        console.print("Running retrieval benchmarks...", style="bold blue")
+        await run_retrieval_benchmark(spec)
+
+    if not skip_qa:
+        console.print("\nRunning QA benchmarks...", style="bold yellow")
+        await run_qa_benchmark(spec, qa_limit=qa_limit)
+
+
+app = typer.Typer(help="Run retrieval and QA benchmarks for configured datasets.")
+
+
+@app.command()
+def run(
+    dataset: str = typer.Argument(..., help="Dataset key to evaluate."),
+    skip_db: bool = typer.Option(
+        False, "--skip-db", help="Skip updateing the evaluation db."
+    ),
+    skip_retrieval: bool = typer.Option(
+        False, "--skip-retrieval", help="Skip retrieval benchmark."
+    ),
+    skip_qa: bool = typer.Option(False, "--skip-qa", help="Skip QA benchmark."),
+    qa_limit: int | None = typer.Option(
+        None, "--qa-limit", help="Limit number of QA cases."
+    ),
+) -> None:
+    spec = DATASETS.get(dataset.lower())
+    if spec is None:
+        valid_datasets = ", ".join(sorted(DATASETS))
+        raise typer.BadParameter(
+            f"Unknown dataset '{dataset}'. Choose from: {valid_datasets}"
+        )
+
+    asyncio.run(
+        evaluate_dataset(
+            spec=spec,
+            skip_db=skip_db,
+            skip_retrieval=skip_retrieval,
+            skip_qa=skip_qa,
+            qa_limit=qa_limit,
+        )
+    )
+
+
+if __name__ == "__main__":
+    app()

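The new benchmark module is driven either through its typer CLI (`run <dataset> [--skip-db] [--skip-retrieval] [--skip-qa] [--qa-limit N]`) or programmatically via `evaluate_dataset`. A minimal sketch of the programmatic path, assuming the package's `src/` layout is importable and that `DATASETS` (defined in the `datasets/__init__.py` added by this release but not shown here) maps the key `"repliqa"` to the spec further below:

```python
import asyncio

from evaluations.benchmark import evaluate_dataset
from evaluations.datasets import DATASETS

# Populate the LanceDB evaluation database if needed, skip the retrieval
# benchmark, and score only the first ten QA cases with the LLM judge.
asyncio.run(
    evaluate_dataset(
        spec=DATASETS["repliqa"],
        skip_db=False,
        skip_retrieval=True,
        skip_qa=False,
        qa_limit=10,
    )
)
```

The equivalent CLI run would be roughly `python src/evaluations/benchmark.py run repliqa --skip-retrieval --qa-limit 10`, again assuming `src/` is on `PYTHONPATH`.
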
src/evaluations/config.py (new file)
@@ -0,0 +1,46 @@
+from collections.abc import Callable, Mapping
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from datasets import Dataset
+from pydantic_evals import Case
+
+
+@dataclass
+class DocumentPayload:
+    uri: str
+    content: str
+    title: str | None = None
+    metadata: dict[str, Any] | None = None
+
+
+@dataclass
+class RetrievalSample:
+    question: str
+    expected_uris: tuple[str, ...]
+    skip: bool = False
+
+
+DocumentLoader = Callable[[], Dataset]
+DocumentMapper = Callable[[Mapping[str, Any]], DocumentPayload | None]
+RetrievalLoader = Callable[[], Dataset]
+RetrievalMapper = Callable[[Mapping[str, Any]], RetrievalSample | None]
+CaseBuilder = Callable[[int, Mapping[str, Any]], Case[str, str, dict[str, str]]]
+
+
+@dataclass
+class DatasetSpec:
+    key: str
+    db_filename: str
+    document_loader: DocumentLoader
+    document_mapper: DocumentMapper
+    qa_loader: DocumentLoader
+    qa_case_builder: CaseBuilder
+    retrieval_loader: RetrievalLoader | None = None
+    retrieval_mapper: RetrievalMapper | None = None
+    document_limit: int | None = None
+
+    @property
+    def db_path(self) -> Path:
+        return Path(__file__).parent / "data" / self.db_filename

src/evaluations/datasets/repliqa.py (new file)
@@ -0,0 +1,58 @@
+from collections.abc import Mapping
+from typing import Any, cast
+
+from datasets import Dataset, DatasetDict, load_dataset
+from pydantic_evals import Case
+
+from evaluations.config import DatasetSpec, DocumentPayload, RetrievalSample
+
+
+def load_repliqa_corpus() -> Dataset:
+    dataset_dict = cast(DatasetDict, load_dataset("ServiceNow/repliqa"))
+    dataset = cast(Dataset, dataset_dict["repliqa_3"])
+    return dataset.filter(lambda doc: doc["document_topic"] == "News Stories")
+
+
+def map_repliqa_document(doc: Mapping[str, Any]) -> DocumentPayload:
+    return DocumentPayload(
+        uri=str(doc["document_id"]),
+        content=doc["document_extracted"],
+    )
+
+
+def map_repliqa_retrieval(doc: Mapping[str, Any]) -> RetrievalSample | None:
+    expected_answer = doc["answer"]
+    if expected_answer == "The answer is not found in the document.":
+        return None
+    return RetrievalSample(
+        question=doc["question"],
+        expected_uris=(str(doc["document_id"]),),
+    )
+
+
+def build_repliqa_case(
+    index: int, doc: Mapping[str, Any]
+) -> Case[str, str, dict[str, str]]:
+    document_id = doc["document_id"]
+    case_name = f"{index}_{document_id}" if document_id is not None else f"case_{index}"
+    return Case(
+        name=case_name,
+        inputs=doc["question"],
+        expected_output=doc["answer"],
+        metadata={
+            "document_id": str(document_id),
+            "case_index": str(index),
+        },
+    )
+
+
+REPLIQ_SPEC = DatasetSpec(
+    key="repliqa",
+    db_filename="repliqa.lancedb",
+    document_loader=load_repliqa_corpus,
+    document_mapper=map_repliqa_document,
+    qa_loader=load_repliqa_corpus,
+    qa_case_builder=build_repliqa_case,
+    retrieval_loader=load_repliqa_corpus,
+    retrieval_mapper=map_repliqa_retrieval,
+)

src/evaluations/datasets/wix.py (new file)
@@ -0,0 +1,81 @@
+import json
+from collections.abc import Iterable, Mapping
+from typing import Any, cast
+
+from datasets import Dataset, DatasetDict, load_dataset
+from pydantic_evals import Case
+
+from evaluations.config import DatasetSpec, DocumentPayload, RetrievalSample
+
+
+def load_wix_corpus() -> Dataset:
+    dataset_dict = cast(DatasetDict, load_dataset("Wix/WixQA", "wix_kb_corpus"))
+    return cast(Dataset, dataset_dict["train"])
+
+
+def map_wix_document(doc: Mapping[str, Any]) -> DocumentPayload:
+    article_id = doc.get("id")
+    url = doc.get("url")
+    uri = str(article_id) if article_id is not None else str(url)
+
+    metadata: dict[str, str] = {}
+    if article_id is not None:
+        metadata["article_id"] = str(article_id)
+    if url:
+        metadata["url"] = str(url)
+
+    return DocumentPayload(
+        uri=uri,
+        content=doc["contents"],
+        title=doc.get("title"),
+        metadata=metadata or None,
+    )
+
+
+def load_wix_qa() -> Dataset:
+    dataset_dict = cast(DatasetDict, load_dataset("Wix/WixQA", "wixqa_expertwritten"))
+    return cast(Dataset, dataset_dict["train"])
+
+
+def map_wix_retrieval(doc: Mapping[str, Any]) -> RetrievalSample | None:
+    article_ids: Iterable[int | str] | None = doc.get("article_ids")
+    if not article_ids:
+        return None
+
+    expected_uris = tuple(str(article_id) for article_id in article_ids)
+    return RetrievalSample(
+        question=doc["question"],
+        expected_uris=expected_uris,
+    )
+
+
+def build_wix_case(
+    index: int, doc: Mapping[str, Any]
+) -> Case[str, str, dict[str, str]]:
+    article_ids = tuple(str(article_id) for article_id in doc.get("article_ids") or [])
+    joined_ids = "-".join(article_ids)
+    case_name = f"{index}_{joined_ids}" if joined_ids else f"case_{index}"
+
+    metadata = {
+        "case_index": str(index),
+        "document_ids": json.dumps(article_ids),
+    }
+
+    return Case(
+        name=case_name,
+        inputs=doc["question"],
+        expected_output=doc["answer"],
+        metadata=metadata,
+    )
+
+
+WIX_SPEC = DatasetSpec(
+    key="wix",
+    db_filename="wix.lancedb",
+    document_loader=load_wix_corpus,
+    document_mapper=map_wix_document,
+    qa_loader=load_wix_qa,
+    qa_case_builder=build_wix_case,
+    retrieval_loader=load_wix_qa,
+    retrieval_mapper=map_wix_retrieval,
+)

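The `src/evaluations/datasets/__init__.py` added in this release (+8 lines) is not shown in the diff, but `benchmark.py` imports a `DATASETS` mapping from it and looks specs up by lowercase key. A hypothetical sketch of what that registry presumably contains, keyed by each spec's `key` field:

```python
# Hypothetical reconstruction; the actual __init__.py is not included in this diff.
from evaluations.datasets.repliqa import REPLIQ_SPEC
from evaluations.datasets.wix import WIX_SPEC

DATASETS = {
    REPLIQ_SPEC.key: REPLIQ_SPEC,  # "repliqa"
    WIX_SPEC.key: WIX_SPEC,  # "wix"
}
```
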
src/evaluations/llm_judge.py (moved from tests/llm_judge.py)
@@ -37,7 +37,7 @@ class LLMJudgeResponseSchema(BaseModel):
 class LLMJudge:
     """LLM-as-judge for evaluating answer equivalence using Pydantic AI."""

-    def __init__(self, model: str = "
+    def __init__(self, model: str = "gpt-oss"):
         # Create Ollama model
         ollama_model = OpenAIChatModel(
             model_name=model,

src/evaluations/llm_judge.py
@@ -49,6 +49,7 @@ class LLMJudge:
             model=ollama_model,
             output_type=LLMJudgeResponseSchema,
             system_prompt=ANSWER_EQUIVALENCE_RUBRIC,
+            retries=3,
         )

     async def judge_answers(

src/haiku/rag/app.py
@@ -194,10 +194,44 @@ class HaikuRAGApp:
         for chunk, score in results:
             self._rich_print_search_result(chunk, score)

-    async def ask(
+    async def ask(
+        self,
+        question: str,
+        cite: bool = False,
+        deep: bool = False,
+        verbose: bool = False,
+    ):
         async with HaikuRAG(db_path=self.db_path) as self.client:
             try:
-
+                if deep:
+                    from rich.console import Console
+
+                    from haiku.rag.qa.deep.dependencies import DeepQAContext
+                    from haiku.rag.qa.deep.graph import build_deep_qa_graph
+                    from haiku.rag.qa.deep.nodes import DeepQAPlanNode
+                    from haiku.rag.qa.deep.state import DeepQADeps, DeepQAState
+
+                    graph = build_deep_qa_graph()
+                    context = DeepQAContext(
+                        original_question=question, use_citations=cite
+                    )
+                    state = DeepQAState(context=context)
+                    deps = DeepQADeps(
+                        client=self.client, console=Console() if verbose else None
+                    )
+
+                    start_node = DeepQAPlanNode(
+                        provider=Config.QA_PROVIDER,
+                        model=Config.QA_MODEL,
+                    )
+
+                    result = await graph.run(
+                        start_node=start_node, state=state, deps=deps
+                    )
+                    answer = result.output.answer
+                else:
+                    answer = await self.client.ask(question, cite=cite)
+
                 self.console.print(f"[bold blue]Question:[/bold blue] {question}")
                 self.console.print()
                 self.console.print("[bold green]Answer:[/bold green]")

src/haiku/rag/cli.py
@@ -299,11 +299,21 @@ def ask(
         "--cite",
         help="Include citations in the response",
     ),
+    deep: bool = typer.Option(
+        False,
+        "--deep",
+        help="Use deep multi-agent QA for complex questions",
+    ),
+    verbose: bool = typer.Option(
+        False,
+        "--verbose",
+        help="Show verbose progress output (only with --deep)",
+    ),
 ):
     from haiku.rag.app import HaikuRAGApp

     app = HaikuRAGApp(db_path=db)
-    asyncio.run(app.ask(question=question, cite=cite))
+    asyncio.run(app.ask(question=question, cite=cite, deep=deep, verbose=verbose))


 @cli.command("research", help="Run multi-agent research and output a concise report")