raglab 0.2.2__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {raglab-0.2.2 → raglab-0.2.4}/PKG-INFO +2 -2
- {raglab-0.2.2 → raglab-0.2.4}/pyproject.toml +2 -2
- {raglab-0.2.2 → raglab-0.2.4}/raglab/__init__.py +8 -0
- {raglab-0.2.2 → raglab-0.2.4}/raglab/agent.py +114 -18
- raglab-0.2.4/raglab/llm.py +297 -0
- {raglab-0.2.2 → raglab-0.2.4}/tests/test_agent.py +52 -2
- raglab-0.2.4/tests/test_fanin.py +288 -0
- raglab-0.2.4/tests/test_llm_roles.py +250 -0
- {raglab-0.2.2 → raglab-0.2.4}/.claude/CLAUDE.md +0 -0
- {raglab-0.2.2 → raglab-0.2.4}/.gitattributes +0 -0
- {raglab-0.2.2 → raglab-0.2.4}/.github/workflows/ci.yml +0 -0
- {raglab-0.2.2 → raglab-0.2.4}/.gitignore +0 -0
- {raglab-0.2.2 → raglab-0.2.4}/LICENSE +0 -0
- {raglab-0.2.2 → raglab-0.2.4}/README.md +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: raglab
|
|
3
|
-
Version: 0.2.2
|
|
3
|
+
Version: 0.2.4
|
|
4
4
|
Summary: A medley of tools to make RAG-based applications.
|
|
5
5
|
Project-URL: Homepage, https://github.com/thorwhalen/raglab
|
|
6
6
|
Project-URL: Repository, https://github.com/thorwhalen/raglab
|
|
@@ -9,7 +9,7 @@ Author: thorwhalen
|
|
|
9
9
|
License: mit
|
|
10
10
|
License-File: LICENSE
|
|
11
11
|
Requires-Python: >=3.10
|
|
12
|
-
Requires-Dist: ir>=0.1.
|
|
12
|
+
Requires-Dist: ir>=0.1.16
|
|
13
13
|
Provides-Extra: dev
|
|
14
14
|
Requires-Dist: pytest-cov>=4.0; extra == 'dev'
|
|
15
15
|
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
@@ -6,7 +6,7 @@ build-backend = "hatchling.build"
|
|
|
6
6
|
|
|
7
7
|
[project]
|
|
8
8
|
name = "raglab"
|
|
9
|
-
version = "0.2.2"
|
|
9
|
+
version = "0.2.4"
|
|
10
10
|
description = "A medley of tools to make RAG-based applications."
|
|
11
11
|
readme = "README.md"
|
|
12
12
|
requires-python = ">=3.10"
|
|
@@ -18,7 +18,7 @@ authors = [
|
|
|
18
18
|
# strategies (Planner/Formulator/Evaluator) use `oa` lazily — the `llm` extra
|
|
19
19
|
# below — so `import raglab` stays offline by default.
|
|
20
20
|
dependencies = [
|
|
21
|
-
"ir>=0.1.
|
|
21
|
+
"ir>=0.1.16",
|
|
22
22
|
]
|
|
23
23
|
|
|
24
24
|
[project.license]
|
|
@@ -42,11 +42,14 @@ from .agent import (
|
|
|
42
42
|
identity_citer,
|
|
43
43
|
identity_formulator,
|
|
44
44
|
ir_sources,
|
|
45
|
+
make_rrf_reranker,
|
|
45
46
|
make_search_agent,
|
|
46
47
|
passthrough_evaluator,
|
|
48
|
+
rrf_reranker,
|
|
47
49
|
score_reranker,
|
|
48
50
|
single_subtask_planner,
|
|
49
51
|
)
|
|
52
|
+
from .llm import EVALUATION_PROMPT, make_llm_evaluator, make_llm_formulator
|
|
50
53
|
|
|
51
54
|
__all__ = [
|
|
52
55
|
"Query",
|
|
@@ -68,5 +71,10 @@ __all__ = [
|
|
|
68
71
|
"identity_formulator",
|
|
69
72
|
"passthrough_evaluator",
|
|
70
73
|
"score_reranker",
|
|
74
|
+
"rrf_reranker",
|
|
75
|
+
"make_rrf_reranker",
|
|
71
76
|
"identity_citer",
|
|
77
|
+
"make_llm_formulator",
|
|
78
|
+
"make_llm_evaluator",
|
|
79
|
+
"EVALUATION_PROMPT",
|
|
72
80
|
]
|
|
@@ -6,6 +6,14 @@ This module is the v1 foundation: the immutable value types, the role *Protocols
|
|
|
6
6
|
is fully parametrized by injected roles. Concrete tools live at the leaves — an
|
|
7
7
|
`ir` corpus becomes one `Retriever` via :func:`ir.as_retriever`.
|
|
8
8
|
|
|
9
|
+
The agent is **multi-source by default**: the loop stamps each hit's
|
|
10
|
+
provenance (``hit.source``), and the fan-in :class:`Reranker` —
|
|
11
|
+
:func:`rrf_reranker` — merges heterogeneous sources by *rank*, never by raw
|
|
12
|
+
score (scores from different corpora / embedders / modes are incommensurable;
|
|
13
|
+
ir_07/ir_08). Raw magnitudes order and dedup hits *within* one source, and the
|
|
14
|
+
loop's pool always carries them; fused (ordinal) scores appear only at the
|
|
15
|
+
fan-in boundary.
|
|
16
|
+
|
|
9
17
|
The shape follows ir_09 §3/§6: a small set of named roles —
|
|
10
18
|
``Planner / Formulator / Retriever / Evaluator / Reranker / Citer`` — and a loop
|
|
11
19
|
whose defining feature is the **back-edge** (evaluator → reformulate) that makes
|
|
@@ -27,9 +35,12 @@ from collections.abc import Mapping, Sequence
|
|
|
27
35
|
from dataclasses import dataclass, field
|
|
28
36
|
from typing import Any, Protocol, runtime_checkable
|
|
29
37
|
|
|
30
|
-
# ir owns the retrieval substrate: the Result type
|
|
31
|
-
# contract
|
|
32
|
-
|
|
38
|
+
# ir owns the retrieval substrate: the Result type, the Retriever leaf
|
|
39
|
+
# contract, and the hit operations (dedup, cross-source fusion) live there
|
|
40
|
+
# (one-way dependency, ir is the SSOT).
|
|
41
|
+
from ir import Retriever, SearchHit, fuse_hits, tag_source
|
|
42
|
+
from ir.base import best_per_artifact
|
|
43
|
+
from ir.retrieve import DFLT_RRF_K, Identity
|
|
33
44
|
|
|
34
45
|
#: A retrieved item — ir's :class:`~ir.base.SearchHit` (ir_09's ``Result``):
|
|
35
46
|
#: a *pointer + snippet* (``text``) with a ``score`` and ``metadata``.
|
|
@@ -55,6 +66,8 @@ __all__ = [
|
|
|
55
66
|
"identity_formulator",
|
|
56
67
|
"passthrough_evaluator",
|
|
57
68
|
"score_reranker",
|
|
69
|
+
"rrf_reranker",
|
|
70
|
+
"make_rrf_reranker",
|
|
58
71
|
"identity_citer",
|
|
59
72
|
]
|
|
60
73
|
|
|
@@ -188,13 +201,82 @@ def passthrough_evaluator(task: SubTask, results: Sequence[Result]) -> Judgement
|
|
|
188
201
|
|
|
189
202
|
|
|
190
203
|
def score_reranker(results: Sequence[Result]) -> Sequence[Result]:
    """Merge by raw magnitude: one hit per artifact, highest raw score first.

    The work is delegated to :func:`ir.base.best_per_artifact`, because ir is
    the single source of truth for hit operations. When the back-edge
    re-queries, the same artifact can come back from several queries or
    rounds; it is kept once, at its best score. Artifacts are identified by
    ``(source, artifact_id)``, so equal ids from two different sources stay
    separate.

    Sorting on the raw score compares magnitudes **across** sources. That is
    only meaningful when all sources share a single score scale (same embedder
    and mode), so this reranker is the explicit opt-in for homogeneous sources.
    The default fan-in, :func:`rrf_reranker`, never compares raw scores across
    sources.
    """
    deduped = best_per_artifact(results)
    return deduped
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _rrf_rerank(
    results: Sequence[Result],
    *,
    rrf_k: int,
    weights: Mapping[str, float] | None,
    identity: Identity,
) -> Sequence[Result]:
    """Group by ``hit.source`` and rank-fuse via :func:`ir.fuse_hits`.

    The pool is materialized once up front. It is iterated to build the source
    groups and, on the single-source path, handed again to
    :func:`ir.base.best_per_artifact`. A one-shot iterator would arrive there
    already exhausted and silently yield no results.
    """
    pool = list(results)
    # None is preserved as the untagged pseudo-source key: its hits fuse as
    # one rank group and stay unattributed (never an empty-string stamp).
    groups: dict[str | None, list[Result]] = {}
    for h in pool:
        groups.setdefault(h.source, []).append(h)
    if len(groups) <= 1:
        # One scale: the magnitude merge, with hits passed through untouched.
        return best_per_artifact(pool)
    return fuse_hits(groups, rrf_k=rrf_k, weights=weights, identity=identity)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def rrf_reranker(results: Sequence[Result]) -> Sequence[Result]:
    """Default fan-in: merge across sources by rank, never by raw score.

    The accumulated pool is grouped by ``hit.source``, and the merge is
    delegated to :func:`ir.fuse_hits` (ir is the SSOT for hit operations).
    Inside one source, raw scores order and de-duplicate that source's hits,
    which is sound because they share a scale. Across sources, only **ranks**
    interact (Reciprocal Rank Fusion). Heterogeneous embedders or modes
    therefore cannot distort the merge. Equal ``artifact_id``\\ s coming from
    different sources remain distinct results, since identity is
    ``(source, artifact_id)``.

    A pool from a single source, or one where no hit carries a ``source``,
    keeps its raw scores and has exactly :func:`score_reranker`'s ordering.
    Fused, rank-derived scores appear only when there is genuinely something
    to fuse. Each fused hit keeps its pre-fusion magnitude in
    ``metadata["source_score"]``.

    Use :func:`make_rrf_reranker` for per-source weights, a different
    ``rrf_k``, or opt-in cross-source duplicate merging.
    """
    return _rrf_rerank(
        results,
        identity=None,
        weights=None,
        rrf_k=DFLT_RRF_K,
    )
|
|
257
|
+
|
|
192
258
|
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
259
|
+
def make_rrf_reranker(
    *,
    rrf_k: int = DFLT_RRF_K,
    weights: Mapping[str, float] | None = None,
    identity: Identity = None,
) -> Reranker:
    """A parametrized :func:`rrf_reranker` (per-source trust weights, ``rrf_k``).

    Args:
        rrf_k: the RRF rank constant (standard default 60). Must be
            non-negative.
        weights: optional per-source trust dial, by source name (default 1.0
            each). It biases the merge without ever comparing raw scores. The
            mapping is copied at construction time, so mutating it afterwards
            does not change the returned reranker.
        identity: opt-in cross-source duplicate detection (e.g. ``"pointer"``)
            — see :data:`ir.retrieve.Identity`. Default: never merge across
            sources.

    Raises:
        ValueError: if ``rrf_k`` is negative. This is reported here, when the
            reranker is built, rather than on the first fused round.
    """
    if rrf_k < 0:
        raise ValueError(f"rrf_k must be non-negative, got {rrf_k!r}")
    # Snapshot the weights: the closure must not observe later caller-side
    # mutation of the mapping it was configured with.
    frozen_weights = dict(weights) if weights is not None else None

    def reranker(results: Sequence[Result]) -> Sequence[Result]:
        return _rrf_rerank(
            results, rrf_k=rrf_k, weights=frozen_weights, identity=identity
        )

    return reranker
|
|
198
280
|
|
|
199
281
|
|
|
200
282
|
def identity_citer(results: Sequence[Result]) -> Sequence[Result]:
|
|
@@ -221,7 +303,7 @@ class SingleContextAgent:
|
|
|
221
303
|
planner: Planner = single_subtask_planner
|
|
222
304
|
formulator: Formulator = identity_formulator
|
|
223
305
|
evaluator: Evaluator = passthrough_evaluator
|
|
224
|
-
reranker: Reranker =
|
|
306
|
+
reranker: Reranker = rrf_reranker
|
|
225
307
|
citer: Citer = identity_citer
|
|
226
308
|
budget: Budget = field(default_factory=Budget)
|
|
227
309
|
|
|
@@ -243,7 +325,14 @@ class SingleContextAgent:
|
|
|
243
325
|
if retriever is None:
|
|
244
326
|
continue
|
|
245
327
|
for llq in self.formulator(current, source):
|
|
246
|
-
|
|
328
|
+
# Stamp the registry key on hits the retriever did not
|
|
329
|
+
# self-attribute (ir-backed retrievers stamp the corpus
|
|
330
|
+
# name themselves, and their tags win), so any custom
|
|
331
|
+
# Retriever still yields attributable hits — the fan-in
|
|
332
|
+
# reranker merges by source.
|
|
333
|
+
found.extend(
|
|
334
|
+
tag_source(retriever(llq.query, **dict(llq.params)), source)
|
|
335
|
+
)
|
|
247
336
|
judged = self.evaluator(current, found[: self.budget.max_results_per_task])
|
|
248
337
|
found = list(judged.relevant)
|
|
249
338
|
if judged.sufficient or judged.refinement is None:
|
|
@@ -265,17 +354,22 @@ def make_search_agent(
|
|
|
265
354
|
"""Build a :class:`SingleContextAgent` over *sources* with smart defaults.
|
|
266
355
|
|
|
267
356
|
``sources`` is a ``Mapping[name, Retriever]`` — e.g.
|
|
268
|
-
``{"skills": ir.as_retriever("skills")}
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
the
|
|
357
|
+
``{"skills": ir.as_retriever("skills")}``, :func:`ir_sources`, or the lazy
|
|
358
|
+
``ir.retrievers()`` view. Every role defaults to its no-LLM thin-slice
|
|
359
|
+
implementation, so ``make_search_agent(sources)("query")`` just works;
|
|
360
|
+
inject an LLM ``formulator`` / ``evaluator`` to turn on rewriting and the
|
|
361
|
+
back-edge.
|
|
362
|
+
|
|
363
|
+
A custom ``Retriever`` must return :class:`ir.SearchHit` instances (the
|
|
364
|
+
``Result`` alias): the loop stamps provenance (``hit.source``) on its
|
|
365
|
+
output, so duck-typed hit objects raise at the tagging step.
|
|
272
366
|
"""
|
|
273
367
|
return SingleContextAgent(
|
|
274
368
|
sources=sources,
|
|
275
369
|
planner=planner or single_subtask_planner,
|
|
276
370
|
formulator=formulator or identity_formulator,
|
|
277
371
|
evaluator=evaluator or passthrough_evaluator,
|
|
278
|
-
reranker=reranker or
|
|
372
|
+
reranker=reranker or rrf_reranker,
|
|
279
373
|
citer=citer or identity_citer,
|
|
280
374
|
budget=budget or Budget(),
|
|
281
375
|
)
|
|
@@ -284,9 +378,11 @@ def make_search_agent(
|
|
|
284
378
|
def ir_sources(*names: str, **search_defaults: Any) -> dict[str, Retriever]:
|
|
285
379
|
"""A source registry ``{name: Retriever}`` backed by named ``ir`` corpora.
|
|
286
380
|
|
|
287
|
-
Each name is bound to ``ir.as_retriever(name, **search_defaults)
|
|
288
|
-
|
|
289
|
-
|
|
381
|
+
Each name is bound to ``ir.as_retriever(name, **search_defaults)``, opened
|
|
382
|
+
eagerly. For the *lazy* live view over everything registered (a corpus
|
|
383
|
+
opens only when its key is first used), use ``ir.retrievers()`` instead —
|
|
384
|
+
the agent accepts either, or any ``Mapping[name, Retriever]``.
|
|
385
|
+
``search_defaults`` (e.g. ``mode="hybrid"``) apply to every source.
|
|
290
386
|
"""
|
|
291
387
|
import ir
|
|
292
388
|
|
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""LLM-backed roles — the Formulator and Evaluator (ir_09 §8 steps 1 & 3).
|
|
2
|
+
|
|
3
|
+
`raglab`'s no-LLM thin slice (:mod:`raglab.agent`) runs offline: an identity
|
|
4
|
+
formulator and a pass-through evaluator. This module supplies the two LLM roles
|
|
5
|
+
that turn that slice into a real agent — query understanding and the back-edge:
|
|
6
|
+
|
|
7
|
+
- :func:`make_llm_formulator` — query rewrite / expand / HyDE. A thin adapter
|
|
8
|
+
over ir's :func:`ir.make_llm_formulator` (the ``formulate=`` seam, ir_09 §3):
|
|
9
|
+
ir owns the lazy-:mod:`oa` rewriter and the identity fallback; raglab only
|
|
10
|
+
adapts the shape ``str -> str | [str]`` to the role contract
|
|
11
|
+
``(SubTask, source) -> [LowLevelQuery]``.
|
|
12
|
+
- :func:`make_llm_evaluator` — sufficiency + refinement. ``ir.select`` owns the
|
|
13
|
+
*relevance* decision (the calibrated committed subset, ir_09 §3); the LLM owns
|
|
14
|
+
*sufficiency* — informed by ir's model-free :attr:`ir.Selection.sufficient`
|
|
15
|
+
hint, it judges whether the committed subset actually satisfies the goal and,
|
|
16
|
+
when it does not, emits a ``refinement`` SubTask. That refinement is the
|
|
17
|
+
**back-edge** that makes the loop an agent rather than a DAG.
|
|
18
|
+
|
|
19
|
+
Two load-bearing boundaries (ir ↔ raglab), guarded here:
|
|
20
|
+
|
|
21
|
+
1. A Formulator returns **queries, never SubTasks** — decomposition is the
|
|
22
|
+
Planner's job. :func:`make_llm_formulator` only ever yields
|
|
23
|
+
:class:`~raglab.agent.LowLevelQuery`\\ s.
|
|
24
|
+
2. The **back-edge lives in raglab**, never in ir: ir derives a ``sufficient``
|
|
25
|
+
*signal* from its own selection; raglab's Evaluator is what reads it, decides,
|
|
26
|
+
and re-queries.
|
|
27
|
+
|
|
28
|
+
Both builders mirror :func:`ir.select.make_llm_selector`: an injectable callable
|
|
29
|
+
(a test double, or your own router), built lazily on :mod:`oa` only when omitted
|
|
30
|
+
(so ``import raglab`` stays offline), with a **safe fallback** — a formulator must
|
|
31
|
+
never make retrieval worse than the raw query, and an evaluator must never
|
|
32
|
+
fabricate an endless loop (on any failure it returns no refinement, which is the
|
|
33
|
+
loop's break condition).
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
from __future__ import annotations
|
|
37
|
+
|
|
38
|
+
from collections.abc import Mapping, Sequence
|
|
39
|
+
from typing import Any, Callable
|
|
40
|
+
|
|
41
|
+
import ir
|
|
42
|
+
from ir.base import best_per_artifact
|
|
43
|
+
|
|
44
|
+
from .agent import (
|
|
45
|
+
Evaluator,
|
|
46
|
+
Formulator,
|
|
47
|
+
Judgement,
|
|
48
|
+
LowLevelQuery,
|
|
49
|
+
Reranker,
|
|
50
|
+
Result,
|
|
51
|
+
SubTask,
|
|
52
|
+
rrf_reranker,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
__all__ = ["make_llm_formulator", "make_llm_evaluator", "EVALUATION_PROMPT"]
|
|
56
|
+
|
|
57
|
+
#: An ir-style query formulator: a query string -> one query or several.
|
|
58
|
+
QueryFormulator = Callable[[str], "str | Sequence[str]"]
|
|
59
|
+
|
|
60
|
+
#: An evaluator judge: ``(goal, results) -> (sufficient, refinement)``.
|
|
61
|
+
#: ``sufficient`` ends the loop; a non-empty ``refinement`` query becomes the
|
|
62
|
+
#: next sub-goal (the back-edge). A raw text reply is also accepted and parsed.
|
|
63
|
+
Judge = Callable[..., "tuple[bool, str | None] | str"]
|
|
64
|
+
|
|
65
|
+
#: Truncate each rendered result's text to this many chars in the judge prompt.
|
|
66
|
+
DFLT_MAX_RESULT_CHARS = 500
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# --------------------------------------------------------------------------- #
|
|
70
|
+
# LLM Formulator — adapt ir's query-level formulator to the role contract
|
|
71
|
+
# --------------------------------------------------------------------------- #
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def make_llm_formulator(
    *,
    formulate: QueryFormulator | None = None,
    params: Mapping[str, Any] | None = None,
    **make_kwargs: Any,
) -> Formulator:
    """An LLM-backed raglab :class:`~raglab.agent.Formulator`.

    Adapts an ir-style query formulator (``str -> str | [str]``) to the role
    contract ``(SubTask, source) -> [LowLevelQuery]``. The sub-goal text is
    rewritten or expanded into one or more search queries, each wrapped as a
    :class:`~raglab.agent.LowLevelQuery` against ``source``.

    Args:
        formulate: an injectable ir-style formulator (a test double, or one you
            built). When omitted it is built once via
            :func:`ir.make_llm_formulator`, with ``**make_kwargs`` forwarded
            (e.g. ``n``, ``prompt``, ``rewriter``). That call is offline:
            ``oa`` is imported lazily, only when the formulator is *invoked*.
            So ``import raglab`` stays offline.
        params: per-call retriever overrides (e.g. ``{"mode": "hybrid", "k": 5}``)
            attached to every emitted :class:`~raglab.agent.LowLevelQuery`.

    The boundary: this returns **queries, never SubTasks**. The formulator's
    output is normalized before use:

    - anything that is not a non-blank string is dropped;
    - queries are whitespace-stripped;
    - queries are de-duplicated, first occurrence winning, so one retriever is
      not queried twice with the same text.

    If nothing usable remains, the raw sub-goal is used. The same fallback
    applies when the call raises, or when it returns something that is
    neither a string nor an iterable of strings. The emitted list is
    therefore never empty: a formulator must never make retrieval worse than
    the raw sub-goal.
    """
    extra = dict(params or {})
    fn: QueryFormulator = (
        formulate if formulate is not None else ir.make_llm_formulator(**make_kwargs)
    )

    def formulator(task: SubTask, source: str) -> list[LowLevelQuery]:
        try:
            out = fn(task.goal)
            # Coercion lives inside the try: a non-iterable reply (e.g. an int)
            # must fall back like any other failure, not crash the agent loop.
            candidates = [out] if isinstance(out, str) else list(out or [])
        except Exception:
            candidates = []  # a formulator must never make retrieval worse
        # dict.fromkeys de-duplicates while preserving the formulator's order.
        queries = list(
            dict.fromkeys(
                q.strip() for q in candidates if isinstance(q, str) and q.strip()
            )
        ) or [task.goal]
        return [
            LowLevelQuery(source=source, query=q, params=dict(extra)) for q in queries
        ]

    return formulator
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# --------------------------------------------------------------------------- #
|
|
123
|
+
# LLM Evaluator — ir.select owns relevance; the LLM owns sufficiency + the
|
|
124
|
+
# back-edge (ir_09 §3/§4)
|
|
125
|
+
# --------------------------------------------------------------------------- #
|
|
126
|
+
|
|
127
|
+
#: Default prompt for :func:`make_llm_evaluator` — judge sufficiency, else emit a
|
|
128
|
+
#: single improved query (the refinement that drives the back-edge).
|
|
129
|
+
EVALUATION_PROMPT = """\
|
|
130
|
+
A search agent is pursuing this goal:
|
|
131
|
+
|
|
132
|
+
{goal}
|
|
133
|
+
|
|
134
|
+
A calibrated selector reviewed the retrieved candidates and committed to the
|
|
135
|
+
results below (it abstained if this list is empty):
|
|
136
|
+
|
|
137
|
+
{results}
|
|
138
|
+
|
|
139
|
+
Decide whether these results are SUFFICIENT to satisfy the goal.
|
|
140
|
+
- If they are, reply with exactly: SUFFICIENT
|
|
141
|
+
- If they are not, reply with: INSUFFICIENT
|
|
142
|
+
then, on the next line, a single improved search query that would retrieve what
|
|
143
|
+
is still missing. Keep it a terse search phrase — no prose, no numbering.
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def make_llm_evaluator(
    *,
    judge: Judge | None = None,
    select_strategy: str = "conservative",
    select_kwargs: Mapping[str, Any] | None = None,
    prerank: Reranker | None = None,
    prompt: str = EVALUATION_PROMPT,
    max_result_chars: int = DFLT_MAX_RESULT_CHARS,
    **prompt_function_kwargs: Any,
) -> Evaluator:
    """An LLM-backed raglab :class:`~raglab.agent.Evaluator` (turns on the back-edge).

    Relevance is ir's. Each round the accumulated results are passed through
    :func:`ir.select`, and the committed subset becomes ``Judgement.relevant``.
    This keeps the LLM in its lane, since LLM relevance is known-fragile
    (ir_01 §3).

    Sufficiency is the LLM's. It reads the committed subset, informed by ir's
    model-free :attr:`~ir.Selection.sufficient` hint via abstention, and
    decides whether the goal is satisfied. When it is not, the judge's
    improved query becomes a ``refinement`` :class:`~raglab.agent.SubTask`
    over the same sources. That is the **back-edge**.

    Args:
        judge: an injectable ``(goal, results) -> (sufficient, refinement)``
            callable (a test double, or your own router). A raw text reply is
            also accepted and parsed. When omitted, a default judge is built
            lazily on :mod:`oa` on first use and reused across rounds.
        select_strategy: ir selection strategy for the relevance decision
            (default ``"conservative"`` — distractor-robust).
        select_kwargs: extra args forwarded to :func:`ir.select`
            (e.g. ``max_k``, ``rel``). ``min_score`` is allowed only while the
            round's pool keeps raw scores (single-source, or
            ``prerank=score_reranker``). An absolute floor is per-(corpus,
            mode, embedder). A round whose pool was rank-fused across sources
            therefore **raises**, rather than silently mass-abstaining on
            ordinal RRF scores.
        prerank: the per-round merge that puts the accumulated pool best-first
            (and deduped) before ``ir.select``. Defaults to
            :func:`~raglab.agent.rrf_reranker`. A round's pool can already mix
            sources (a SubTask spans several), so the same rank-based,
            scale-safe merge as the fan-in applies; single-source rounds keep
            raw scores. Inject :func:`~raglab.agent.score_reranker` for
            sources known to share one score scale.
        prompt: the template for the default oa judge. Its ``{goal}`` and
            ``{results}`` placeholders are filled by the judge call. Unused
            when *judge* is given.
        max_result_chars: per-result text truncation in the judge prompt.
        **prompt_function_kwargs: forwarded to ``oa.prompt_function`` when the
            default judge is built.

    Safe fallback: any judge error (including failing to build the default
    judge) returns ``refinement=None``, the control loop's break condition.
    A failing judge can therefore never fabricate an endless loop.

    An INSUFFICIENT verdict with no usable refinement is likewise a stop. That
    covers a missing refinement, or one identical to the current goal, which
    would only repeat this round. Such a stop is reported truthfully as
    ``sufficient=False, refinement=None``.
    """
    sel_kw = dict(select_kwargs or {})
    # The floor is handled here, not blindly forwarded: it must only ever meet
    # raw (per-source-scale) scores — see the guard in the evaluator below.
    floor = sel_kw.pop("min_score", None)
    prerank = prerank or rrf_reranker
    default_judge: Judge | None = None

    def _resolve_judge() -> Judge:
        # Build the oa-backed judge once, on first use, rather than every
        # round. A failed build is not cached, so the next round retries it.
        nonlocal default_judge
        if judge is not None:
            return judge
        if default_judge is None:
            default_judge = _default_llm_judge(prompt, **prompt_function_kwargs)
        return default_judge

    def _ask_judge(goal: str, rendered: str) -> tuple[bool, str | None]:
        return _normalize_verdict(_resolve_judge()(goal=goal, results=rendered))

    def _same_query(a: str, b: str) -> bool:
        # Whitespace- and case-insensitive: "Foo  bar" just repeats "foo bar".
        return " ".join(str(a).split()).casefold() == " ".join(str(b).split()).casefold()

    def evaluator(task: SubTask, results: Sequence[Result]) -> Judgement:
        # ``ir.select`` documents a best-first precondition; the loop accumulates
        # hits across rounds/sources in arbitrary order, so merge them first —
        # the same (scale-safe) merge the final fan-in reranker uses.
        ranked = list(prerank(results))
        if floor is not None and any("source_rank" in h.metadata for h in ranked):
            # fuse_hits stamps source_rank exactly when it fused: the pool
            # spans sources, so the scores below are ordinal RRF values — an
            # absolute floor against them is the mis-scaled comparison ir
            # refuses loudly (ir_07). Never silently mass-abstain.
            raise ValueError(
                "min_score with a multi-source round: the pool was rank-fused, "
                "so an absolute floor would compare against ordinal RRF scores. "
                "Floors are per-(corpus, mode, embedder) — drop min_score, or "
                "inject prerank=score_reranker for sources known to share one "
                "score scale."
            )
        selection = ir.select(
            ranked, strategy=select_strategy, min_score=floor, **sel_kw
        )
        # Map the committed subset back to the PRE-fusion originals: fused
        # (ordinal) scores are local to this selection. ``Judgement.relevant``
        # re-enters the loop's pool and the final fan-in re-fuses it, so it
        # must carry raw per-source magnitudes — otherwise round N+1 compares
        # round N's fused scores against raw ones inside one source group,
        # and ``source_score`` gets overwritten by an already-fused value.
        raw = {(h.source, h.artifact_id): h for h in best_per_artifact(results)}
        relevant = [raw.get((h.source, h.artifact_id), h) for h in selection.selected]
        rendered = (
            _render_results(relevant, max_result_chars)
            if relevant
            else "(none — the selector abstained)"
        )
        try:
            sufficient, refinement = _ask_judge(task.goal, rendered)
        except Exception:
            # Trust ir's model-free signal for the report, but never re-query on a
            # judge failure: refinement=None is the loop's break condition.
            return Judgement(
                relevant=relevant, sufficient=selection.sufficient, refinement=None
            )
        if sufficient:
            return Judgement(relevant=relevant, sufficient=True, refinement=None)
        if not refinement or _same_query(refinement, task.goal):
            # Insufficient, but there is nothing new to ask: stop, and report
            # the verdict honestly rather than claiming sufficiency.
            return Judgement(relevant=relevant, sufficient=False, refinement=None)
        return Judgement(
            relevant=relevant,
            sufficient=False,
            refinement=SubTask(goal=refinement, sources=task.sources),
        )

    return evaluator
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _render_results(results: Sequence[Result], max_chars: int) -> str:
|
|
261
|
+
"""Render committed hits as ``- id: text`` lines for the judge prompt."""
|
|
262
|
+
return "\n".join(f"- {h.artifact_id}: {str(h.text)[:max_chars]}" for h in results)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _normalize_verdict(out: "tuple[bool, str | None] | str") -> tuple[bool, str | None]:
|
|
266
|
+
"""Coerce a judge reply (a ``(sufficient, refinement)`` tuple, or raw text)."""
|
|
267
|
+
if isinstance(out, tuple):
|
|
268
|
+
sufficient, refinement = out
|
|
269
|
+
refinement = str(refinement).strip() if refinement else None
|
|
270
|
+
return bool(sufficient), (refinement or None)
|
|
271
|
+
return _parse_verdict(str(out or ""))
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _parse_verdict(text: str) -> tuple[bool, str | None]:
|
|
275
|
+
"""Parse a SUFFICIENT / INSUFFICIENT + refinement-query text reply."""
|
|
276
|
+
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
|
277
|
+
if not lines or lines[0].upper().startswith("SUFFICIENT"):
|
|
278
|
+
return True, None
|
|
279
|
+
refinement = lines[1] if len(lines) > 1 else None
|
|
280
|
+
return False, (refinement or None)
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def _default_llm_judge(prompt: str, **prompt_function_kwargs: Any) -> Judge:
    """Return the stock sufficiency judge, backed by an :mod:`oa` prompt function.

    ``oa`` is imported inside this function rather than at module level, so
    importing raglab never requires it. The model's reply is parsed by
    :func:`_parse_verdict`, and any extra keyword arguments are passed through
    to ``oa.prompt_function``.
    """
    import oa

    llm_call = oa.prompt_function(
        prompt,
        egress=_parse_verdict,
        name="evaluate_sufficiency",
        **prompt_function_kwargs,
    )

    def judge(*, goal: str, results: str) -> tuple[bool, str | None]:
        # Keyword-only: the evaluator calls a judge as judge(goal=..., results=...).
        verdict = llm_call(goal=goal, results=results)
        return verdict

    return judge
|
|
@@ -39,6 +39,8 @@ def test_defaults_satisfy_role_protocols():
|
|
|
39
39
|
assert isinstance(raglab.identity_formulator, raglab.Formulator)
|
|
40
40
|
assert isinstance(raglab.passthrough_evaluator, raglab.Evaluator)
|
|
41
41
|
assert isinstance(raglab.score_reranker, raglab.Reranker)
|
|
42
|
+
assert isinstance(raglab.rrf_reranker, raglab.Reranker)
|
|
43
|
+
assert isinstance(raglab.make_rrf_reranker(), raglab.Reranker)
|
|
42
44
|
assert isinstance(raglab.identity_citer, raglab.Citer)
|
|
43
45
|
|
|
44
46
|
|
|
@@ -57,13 +59,41 @@ def test_query_string_or_object_equivalent():
|
|
|
57
59
|
assert agent("q") == agent(Query(text="q"))
|
|
58
60
|
|
|
59
61
|
|
|
60
|
-
def
|
|
62
|
+
def test_cross_source_merge_is_rank_based_not_score_based():
    # Raw scores from different sources live on different scales, so the
    # default fan-in (rrf_reranker) fuses by per-source rank instead. Each hit
    # below is rank 1 within its own source. The tie is therefore broken by
    # source order, and the incomparable raw scores play no part. Injecting
    # score_reranker is the explicit opt-in to a raw-score merge.
    s1_hits = _hits(("a", 0.3))
    s2_hits = _hits(("b", 0.9))
    sources = {"s1": _fake_retriever(s1_hits), "s2": _fake_retriever(s2_hits)}

    fused_ids = [hit.artifact_id for hit in make_search_agent(sources)("q")]
    assert fused_ids == ["a", "b"]  # source order, not score

    scored_agent = make_search_agent(sources, reranker=raglab.score_reranker)
    scored_ids = [hit.artifact_id for hit in scored_agent("q")]
    assert scored_ids == ["b", "a"]  # the explicit opt-in
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_final_results_have_no_duplicate_artifacts():
    # The back-edge re-retrieves the same artifacts. Each must still collapse
    # to a single final result carrying its best score.
    retr = _fake_retriever(_hits(("a", 0.9), ("b", 0.4)))
    round_count = 0

    def refining(task, results):
        nonlocal round_count
        round_count += 1
        if round_count >= 2:
            return Judgement(list(results), sufficient=True)
        return Judgement(
            list(results),
            sufficient=False,
            refinement=SubTask(task.goal, task.sources),
        )

    agent = make_search_agent({"s": retr}, evaluator=refining)
    ids = [r.artifact_id for r in agent("q")]
    assert ids == ["a", "b"]  # best score per artifact, descending
    assert len(set(ids)) == len(ids)  # the re-query did not duplicate artifacts
|
|
67
97
|
|
|
68
98
|
|
|
69
99
|
def test_passthrough_evaluator_does_not_loop():
|
|
@@ -133,6 +163,26 @@ def test_budget_bounds_a_never_sufficient_loop():
|
|
|
133
163
|
assert len(retr.calls) == 3 # exactly max_rounds — the safety net holds
|
|
134
164
|
|
|
135
165
|
|
|
166
|
+
def test_budget_caps_results_per_task_seen_by_evaluator():
    twenty_hits = _hits(*((f"a{i}", 1.0 - i / 100) for i in range(20)))
    retr = _fake_retriever(twenty_hits)
    seen_counts = []

    def evaluator(task, results):
        seen_counts.append(len(results))
        return Judgement(list(results), sufficient=True)

    agent = make_search_agent(
        {"s": retr}, evaluator=evaluator, budget=Budget(max_results_per_task=5)
    )
    agent("q")
    # The evaluator sees at most max_results_per_task hits.
    assert seen_counts[-1] == 5
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def test_budget_caps_sources_per_task():
|
|
181
|
+
retrs = {f"s{i}": _fake_retriever(_hits((f"a{i}", 0.5))) for i in range(6)}
|
|
182
|
+
make_search_agent(retrs, budget=Budget(max_sources_per_task=2))("q")
|
|
183
|
+
assert sum(1 for r in retrs.values() if r.calls) == 2 # only first 2 sources hit
|
|
184
|
+
|
|
185
|
+
|
|
136
186
|
# ----- end-to-end over a REAL ir corpus (hermetic: light embedder) ---------- #
|
|
137
187
|
|
|
138
188
|
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
"""Tests for the Reranker at fan-in — rank-based cross-source merge (ir_09 §3).
|
|
2
|
+
|
|
3
|
+
The property under test: raw scores never cross a source boundary. Within one
|
|
4
|
+
source they order and dedup that source's hits; across sources only ranks
|
|
5
|
+
interact (RRF via ``ir.fuse_hits``). Hermetic: fake retrievers with canned
|
|
6
|
+
hits; one end-to-end test wires two REAL ir corpora (light embedder, in-memory
|
|
7
|
+
stores) whose artifact ids deliberately collide.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
|
|
12
|
+
import raglab
|
|
13
|
+
from ir import SearchHit
|
|
14
|
+
from raglab import (
|
|
15
|
+
Judgement,
|
|
16
|
+
SubTask,
|
|
17
|
+
make_rrf_reranker,
|
|
18
|
+
make_search_agent,
|
|
19
|
+
rrf_reranker,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _hits(*specs, source=None):
|
|
24
|
+
"""``(artifact_id, score)`` pairs -> ir.SearchHits (optionally source-tagged)."""
|
|
25
|
+
return [
|
|
26
|
+
SearchHit(aid, "k", score, f"text {aid}", {}, source) for aid, score in specs
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _fake_retriever(hits):
|
|
31
|
+
"""A Retriever that records its calls and returns canned hits."""
|
|
32
|
+
calls = []
|
|
33
|
+
|
|
34
|
+
def retrieve(query, **kw):
|
|
35
|
+
calls.append((query, kw))
|
|
36
|
+
return list(hits)
|
|
37
|
+
|
|
38
|
+
retrieve.calls = calls
|
|
39
|
+
return retrieve
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ----- rrf_reranker: the role in isolation ---------------------------------- #
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_heterogeneous_scales_interleave_by_rank():
|
|
46
|
+
# A cosine-scale source (~[0,1]) and a BM25-scale source (~[0,50]): a raw
|
|
47
|
+
# score sort would bury the cosine source entirely; rank fusion interleaves.
|
|
48
|
+
pool = _hits(("c1", 0.92), ("c2", 0.85), source="cos") + _hits(
|
|
49
|
+
("b1", 31.0), ("b2", 24.0), source="bm25"
|
|
50
|
+
)
|
|
51
|
+
fused = rrf_reranker(pool)
|
|
52
|
+
assert {h.artifact_id for h in fused[:2]} == {"c1", "b1"} # both rank-1s lead
|
|
53
|
+
assert {h.artifact_id for h in fused[2:]} == {"c2", "b2"}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def test_colliding_ids_across_sources_stay_distinct():
|
|
57
|
+
pool = _hits(("dol", 0.9), source="skills") + _hits(("dol", 28.0), source="pkgs")
|
|
58
|
+
fused = rrf_reranker(pool)
|
|
59
|
+
assert len(fused) == 2 # same id, different corpus = different artifact
|
|
60
|
+
assert {h.source for h in fused} == {"skills", "pkgs"}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def test_single_source_pool_keeps_raw_scores():
|
|
64
|
+
pool = _hits(("a", 0.9), ("b", 0.5), source="s")
|
|
65
|
+
fused = rrf_reranker(pool)
|
|
66
|
+
assert [h.score for h in fused] == [0.9, 0.5] # = score_reranker's ordering
|
|
67
|
+
assert list(fused) == list(raglab.score_reranker(pool))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def test_untagged_pool_keeps_raw_scores_and_no_stamp():
|
|
71
|
+
pool = _hits(("a", 0.9), ("b", 0.5)) # no source anywhere
|
|
72
|
+
fused = rrf_reranker(pool)
|
|
73
|
+
assert [h.score for h in fused] == [0.9, 0.5]
|
|
74
|
+
assert all(h.source is None for h in fused) # passthrough, no "" stamping
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_cross_round_duplicates_collapse_per_source():
|
|
78
|
+
# The same (source, artifact) re-retrieved across back-edge rounds counts
|
|
79
|
+
# once, at its best rank — no duplicate ids, no double RRF mass.
|
|
80
|
+
pool = (
|
|
81
|
+
_hits(("a", 0.7), ("a", 0.9), source="s1")
|
|
82
|
+
+ _hits(("z", 5.0), source="s2")
|
|
83
|
+
+ _hits(("a", 0.8), source="s1")
|
|
84
|
+
)
|
|
85
|
+
fused = rrf_reranker(pool)
|
|
86
|
+
keyed = [(h.source, h.artifact_id) for h in fused]
|
|
87
|
+
assert len(keyed) == len(set(keyed)) == 2
|
|
88
|
+
a = next(h for h in fused if h.artifact_id == "a")
|
|
89
|
+
assert a.metadata["source_score"] == 0.9 # the best raw magnitude survives
|
|
90
|
+
assert a.metadata["source_rank"] == 1
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_make_rrf_reranker_weights_bias_the_merge():
|
|
94
|
+
pool = _hits(("a", 0.9), source="s1") + _hits(("b", 0.9), source="s2")
|
|
95
|
+
fused = make_rrf_reranker(weights={"s2": 2.0})(pool)
|
|
96
|
+
assert fused[0].artifact_id == "b" # trust dial, no score comparability needed
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# ----- the agent loop: tagging + fan-in ------------------------------------- #
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def test_loop_stamps_registry_key_on_untagged_hits():
|
|
103
|
+
sources = {
|
|
104
|
+
"s1": _fake_retriever(_hits(("a", 0.9))),
|
|
105
|
+
"s2": _fake_retriever(_hits(("b", 7.0))),
|
|
106
|
+
}
|
|
107
|
+
results = make_search_agent(sources)("q")
|
|
108
|
+
assert {h.source for h in results} == {"s1", "s2"}
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def test_retriever_self_attribution_wins_over_registry_key():
|
|
112
|
+
# An ir-backed retriever stamps the corpus name itself; the loop must not
|
|
113
|
+
# overwrite it with the (possibly different) registry key.
|
|
114
|
+
sources = {"alias": _fake_retriever(_hits(("a", 0.9), source="corpus_x"))}
|
|
115
|
+
results = make_search_agent(sources)("q")
|
|
116
|
+
assert [h.source for h in results] == ["corpus_x"]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def test_agent_keeps_colliding_ids_from_two_sources():
|
|
120
|
+
sources = {
|
|
121
|
+
"skills": _fake_retriever(_hits(("dol", 0.9))),
|
|
122
|
+
"pkgs": _fake_retriever(_hits(("dol", 28.0))),
|
|
123
|
+
}
|
|
124
|
+
results = make_search_agent(sources)("q")
|
|
125
|
+
assert sorted((h.source, h.artifact_id) for h in results) == [
|
|
126
|
+
("pkgs", "dol"),
|
|
127
|
+
("skills", "dol"),
|
|
128
|
+
]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def test_back_edge_rounds_do_not_duplicate_across_sources():
|
|
132
|
+
sources = {
|
|
133
|
+
"s1": _fake_retriever(_hits(("a", 0.9))),
|
|
134
|
+
"s2": _fake_retriever(_hits(("a", 6.0))), # same id, other source
|
|
135
|
+
}
|
|
136
|
+
rounds = {"n": 0}
|
|
137
|
+
|
|
138
|
+
def refining(task, results):
|
|
139
|
+
rounds["n"] += 1
|
|
140
|
+
if rounds["n"] < 3:
|
|
141
|
+
return Judgement(
|
|
142
|
+
list(results),
|
|
143
|
+
sufficient=False,
|
|
144
|
+
refinement=SubTask(task.goal, task.sources),
|
|
145
|
+
)
|
|
146
|
+
return Judgement(list(results), sufficient=True)
|
|
147
|
+
|
|
148
|
+
results = make_search_agent(sources, evaluator=refining)("q")
|
|
149
|
+
keyed = [(h.source, h.artifact_id) for h in results]
|
|
150
|
+
assert sorted(keyed) == [("s1", "a"), ("s2", "a")] # one per (source, artifact)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
# ----- the LLM evaluator's per-round prerank --------------------------------- #
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def test_llm_evaluator_prerank_is_scale_safe_by_default():
|
|
157
|
+
from raglab import make_llm_evaluator
|
|
158
|
+
|
|
159
|
+
# A round's pool already mixes sources: the per-round merge must keep both
|
|
160
|
+
# same-id artifacts (distinct sources) visible to ir.select and the judge.
|
|
161
|
+
pool = _hits(("readme", 0.9), source="s1") + _hits(("readme", 30.0), source="s2")
|
|
162
|
+
evaluator = make_llm_evaluator(judge=lambda *, goal, results: (True, None))
|
|
163
|
+
judgement = evaluator(SubTask("g", ("s1", "s2")), pool)
|
|
164
|
+
assert sorted((h.source, h.artifact_id) for h in judgement.relevant) == [
|
|
165
|
+
("s1", "readme"),
|
|
166
|
+
("s2", "readme"),
|
|
167
|
+
]
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def test_llm_evaluator_prerank_is_injectable():
|
|
171
|
+
from raglab import make_llm_evaluator, score_reranker
|
|
172
|
+
|
|
173
|
+
pool = _hits(("readme", 0.9), source="s1") + _hits(("readme", 30.0), source="s2")
|
|
174
|
+
evaluator = make_llm_evaluator(
|
|
175
|
+
judge=lambda *, goal, results: (True, None), prerank=score_reranker
|
|
176
|
+
)
|
|
177
|
+
judgement = evaluator(SubTask("g", ("s1", "s2")), pool)
|
|
178
|
+
# The explicit magnitude opt-in keeps both too (identity is per source)…
|
|
179
|
+
# but ranks them by raw score: the BM25-scale hit wins the top slot.
|
|
180
|
+
assert judgement.relevant[0].source == "s2"
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def test_llm_evaluator_relevant_carries_prefusion_magnitudes():
|
|
184
|
+
from raglab import make_llm_evaluator
|
|
185
|
+
|
|
186
|
+
# Judgement.relevant re-enters the loop's pool, so it must carry RAW
|
|
187
|
+
# per-source scores — the fused (ordinal) view is local to selection.
|
|
188
|
+
pool = _hits(("a", 0.9), source="s1") + _hits(("b", 30.0), source="s2")
|
|
189
|
+
evaluator = make_llm_evaluator(judge=lambda *, goal, results: (True, None))
|
|
190
|
+
judgement = evaluator(SubTask("g", ("s1", "s2")), pool)
|
|
191
|
+
assert {h.score for h in judgement.relevant} == {0.9, 30.0} # not ~1/61
|
|
192
|
+
assert all("source_rank" not in h.metadata for h in judgement.relevant)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def test_agent_with_llm_evaluator_never_mixes_fused_and_raw_scores():
|
|
196
|
+
from raglab import make_llm_evaluator
|
|
197
|
+
|
|
198
|
+
# Regression for the double-fusion feedback: round-1 fused scores must not
|
|
199
|
+
# re-enter the pool, or round 2 compares ordinal RRF values against raw
|
|
200
|
+
# magnitudes INSIDE one source group (and source_score gets overwritten by
|
|
201
|
+
# an already-fused value).
|
|
202
|
+
def varying_retriever(rounds_hits):
|
|
203
|
+
state = {"round": 0}
|
|
204
|
+
|
|
205
|
+
def retrieve(query, **kw):
|
|
206
|
+
hits = rounds_hits[min(state["round"], len(rounds_hits) - 1)]
|
|
207
|
+
state["round"] += 1
|
|
208
|
+
return list(hits)
|
|
209
|
+
|
|
210
|
+
return retrieve
|
|
211
|
+
|
|
212
|
+
sources = {
|
|
213
|
+
"s1": varying_retriever([_hits(("a1", 0.9)), _hits(("a2", 0.5))]),
|
|
214
|
+
"s2": varying_retriever([_hits(("b1", 30.0)), _hits(("b2", 24.0))]),
|
|
215
|
+
}
|
|
216
|
+
verdicts = iter([(False, "refined q"), (True, None)])
|
|
217
|
+
evaluator = make_llm_evaluator(judge=lambda *, goal, results: next(verdicts))
|
|
218
|
+
results = make_search_agent(
|
|
219
|
+
sources, evaluator=evaluator, formulator=raglab.identity_formulator
|
|
220
|
+
)("q")
|
|
221
|
+
raw = {("s1", "a1"): 0.9, ("s1", "a2"): 0.5, ("s2", "b1"): 30.0, ("s2", "b2"): 24.0}
|
|
222
|
+
for h in results:
|
|
223
|
+
assert h.metadata["source_score"] == raw[(h.source, h.artifact_id)]
|
|
224
|
+
# Within one source, the final order follows raw magnitudes.
|
|
225
|
+
s1 = [h.artifact_id for h in results if h.source == "s1"]
|
|
226
|
+
assert s1 == sorted(s1, key=lambda a: -raw[("s1", a)])
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def test_llm_evaluator_min_score_works_single_source_raises_multi():
|
|
230
|
+
import pytest
|
|
231
|
+
|
|
232
|
+
from raglab import make_llm_evaluator
|
|
233
|
+
|
|
234
|
+
judge = lambda *, goal, results: (True, None) # noqa: E731
|
|
235
|
+
single = _hits(("a", 0.9), ("b", 0.1), source="s1")
|
|
236
|
+
evaluator = make_llm_evaluator(judge=judge, select_kwargs={"min_score": 0.5})
|
|
237
|
+
judgement = evaluator(SubTask("g", ("s1",)), single)
|
|
238
|
+
assert [h.artifact_id for h in judgement.relevant] == ["a"] # floor met raw scores
|
|
239
|
+
|
|
240
|
+
multi = _hits(("a", 0.9), source="s1") + _hits(("b", 30.0), source="s2")
|
|
241
|
+
with pytest.raises(ValueError, match="rank-fused"):
|
|
242
|
+
evaluator(SubTask("g", ("s1", "s2")), multi) # never silently mass-abstain
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def test_mixed_tagged_untagged_pool_keeps_none_provenance():
|
|
246
|
+
pool = _hits(("u", 0.5)) + _hits(("t", 9.0), source="s1")
|
|
247
|
+
fused = rrf_reranker(pool)
|
|
248
|
+
by_id = {h.artifact_id: h for h in fused}
|
|
249
|
+
assert by_id["u"].source is None # untagged pseudo-source, never ""
|
|
250
|
+
assert by_id["t"].source == "s1"
|
|
251
|
+
assert by_id["u"].to_dict()["source"] is None
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# ----- end-to-end over two REAL ir corpora (hermetic: light embedder) -------- #
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_agent_federates_two_real_corpora_with_colliding_ids():
|
|
258
|
+
import ir
|
|
259
|
+
from ir.store import CorpusStore
|
|
260
|
+
|
|
261
|
+
def corpus(docs, name):
|
|
262
|
+
return ir.build(
|
|
263
|
+
ir.CorpusSource.from_mapping(docs, name=name, strategy=ir.WholeText()),
|
|
264
|
+
store=CorpusStore.memory(),
|
|
265
|
+
embedder="light",
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
docs_a = {"shared": "zebra zephyr zucchini", "alpha": "alpha apple avocado"}
|
|
269
|
+
docs_b = {"shared": "zebra zephyr zucchini", "beta": "beta banana blueberry"}
|
|
270
|
+
agent = make_search_agent(
|
|
271
|
+
{
|
|
272
|
+
"one": ir.as_retriever(corpus(docs_a, "one"), k=3),
|
|
273
|
+
"two": ir.as_retriever(corpus(docs_b, "two"), k=3),
|
|
274
|
+
}
|
|
275
|
+
)
|
|
276
|
+
results = agent("zebra zephyr zucchini")
|
|
277
|
+
keyed = {(h.source, h.artifact_id) for h in results}
|
|
278
|
+
# The colliding "shared" artifact survives from BOTH corpora, attributed.
|
|
279
|
+
assert {("one", "shared"), ("two", "shared")} <= keyed
|
|
280
|
+
assert results[0].to_dict()["source"] in {"one", "two"} # serialization-clean
|
|
281
|
+
fused_scores = [h.score for h in results]
|
|
282
|
+
assert fused_scores == sorted(fused_scores, reverse=True) # best-first
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def test_score_reranker_remains_the_homogeneous_opt_in():
|
|
286
|
+
pytest.importorskip("ir")
|
|
287
|
+
pool = _hits(("a", 0.3), source="s1") + _hits(("b", 0.9), source="s2")
|
|
288
|
+
assert [h.artifact_id for h in raglab.score_reranker(pool)] == ["b", "a"]
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
"""Tests for raglab's LLM-backed roles (the Formulator and Evaluator).
|
|
2
|
+
|
|
3
|
+
Hermetic and deterministic: every "LLM" is an injected test double (no model, no
|
|
4
|
+
network). The Formulator adapts an ir-style ``str -> [str]`` rewriter; the
|
|
5
|
+
Evaluator delegates relevance to ``ir.select`` and sufficiency to the injected
|
|
6
|
+
judge. One end-to-end demo wires a real ``ir`` corpus (the light, numpy-only
|
|
7
|
+
embedder, in-memory store) and shows the **back-edge** recovering a gold document
|
|
8
|
+
that single-shot retrieval misses.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import ir
|
|
12
|
+
from ir import SearchHit
|
|
13
|
+
from ir.store import CorpusStore
|
|
14
|
+
|
|
15
|
+
from raglab import (
|
|
16
|
+
Budget,
|
|
17
|
+
LowLevelQuery,
|
|
18
|
+
SubTask,
|
|
19
|
+
make_llm_evaluator,
|
|
20
|
+
make_llm_formulator,
|
|
21
|
+
make_search_agent,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _hits(*specs):
|
|
26
|
+
"""``(artifact_id, score)`` pairs -> ir.SearchHits (no corpus needed)."""
|
|
27
|
+
return [SearchHit(aid, "k", score, f"text {aid}", {}) for aid, score in specs]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _fake_retriever(hits):
|
|
31
|
+
"""A Retriever that records its calls and returns canned hits."""
|
|
32
|
+
calls = []
|
|
33
|
+
|
|
34
|
+
def retrieve(query, **kw):
|
|
35
|
+
calls.append((query, kw))
|
|
36
|
+
return list(hits)
|
|
37
|
+
|
|
38
|
+
retrieve.calls = calls
|
|
39
|
+
return retrieve
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# ----- LLM Formulator: adapt str -> [str] to (SubTask, source) -> [LLQ] ------ #
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_llm_formulator_fans_out_one_llq_per_query():
|
|
46
|
+
formulator = make_llm_formulator(formulate=lambda q: [q, q + " alt"])
|
|
47
|
+
llqs = formulator(SubTask(goal="deploy", sources=("s",)), "s")
|
|
48
|
+
assert [q.query for q in llqs] == ["deploy", "deploy alt"]
|
|
49
|
+
assert all(isinstance(q, LowLevelQuery) and q.source == "s" for q in llqs)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_llm_formulator_accepts_a_bare_string():
|
|
53
|
+
formulator = make_llm_formulator(formulate=lambda q: q + "!")
|
|
54
|
+
llqs = formulator(SubTask(goal="x", sources=("s",)), "s")
|
|
55
|
+
assert [q.query for q in llqs] == ["x!"]
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def test_llm_formulator_attaches_params_to_every_query():
|
|
59
|
+
formulator = make_llm_formulator(
|
|
60
|
+
formulate=lambda q: [q, q + " b"], params={"mode": "hybrid", "k": 5}
|
|
61
|
+
)
|
|
62
|
+
llqs = formulator(SubTask(goal="g", sources=("s",)), "s")
|
|
63
|
+
assert all(q.params == {"mode": "hybrid", "k": 5} for q in llqs)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def test_llm_formulator_empty_output_falls_back_to_the_goal():
|
|
67
|
+
# A formulator must never make retrieval worse than the raw sub-goal.
|
|
68
|
+
formulator = make_llm_formulator(formulate=lambda q: [])
|
|
69
|
+
llqs = formulator(SubTask(goal="the goal", sources=("s",)), "s")
|
|
70
|
+
assert [q.query for q in llqs] == ["the goal"]
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def test_llm_formulator_swallows_a_raising_callable():
|
|
74
|
+
# A failing custom formulator must fall back to the goal, never propagate.
|
|
75
|
+
def boom(_q):
|
|
76
|
+
raise RuntimeError("rewriter down")
|
|
77
|
+
|
|
78
|
+
formulator = make_llm_formulator(formulate=boom)
|
|
79
|
+
llqs = formulator(SubTask(goal="the goal", sources=("s",)), "s")
|
|
80
|
+
assert [q.query for q in llqs] == ["the goal"]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def test_llm_formulator_drives_the_agent_loop():
|
|
84
|
+
retr = _fake_retriever(_hits(("a", 0.9)))
|
|
85
|
+
formulator = make_llm_formulator(formulate=lambda q: [q, q + " expanded"])
|
|
86
|
+
make_search_agent({"s": retr}, formulator=formulator)("q")
|
|
87
|
+
assert [c[0] for c in retr.calls] == ["q", "q expanded"]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# ----- LLM Evaluator: ir.select owns relevance, the LLM owns sufficiency ----- #
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_evaluator_relevance_comes_from_ir_select():
|
|
94
|
+
# Conservative selection keeps only the near-top hit; "b" is a distractor.
|
|
95
|
+
evaluator = make_llm_evaluator(judge=lambda **kw: (True, None))
|
|
96
|
+
judged = evaluator(SubTask("g", ("s",)), _hits(("a", 0.9), ("b", 0.1)))
|
|
97
|
+
assert [h.artifact_id for h in judged.relevant] == ["a"]
|
|
98
|
+
assert judged.sufficient and judged.refinement is None
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def test_evaluator_emits_a_refinement_when_insufficient():
|
|
102
|
+
evaluator = make_llm_evaluator(judge=lambda **kw: (False, "better query"))
|
|
103
|
+
judged = evaluator(SubTask("g", ("s1", "s2")), _hits(("a", 0.9)))
|
|
104
|
+
assert judged.sufficient is False
|
|
105
|
+
assert judged.refinement == SubTask(goal="better query", sources=("s1", "s2"))
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def test_evaluator_insufficient_without_a_query_stops_the_loop():
|
|
109
|
+
# Insufficient but no refinement query -> nothing better to try -> stop.
|
|
110
|
+
evaluator = make_llm_evaluator(judge=lambda **kw: (False, None))
|
|
111
|
+
judged = evaluator(SubTask("g", ("s",)), _hits(("a", 0.9)))
|
|
112
|
+
assert judged.sufficient is True and judged.refinement is None
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def test_evaluator_parses_a_raw_text_reply():
|
|
116
|
+
evaluator = make_llm_evaluator(
|
|
117
|
+
judge=lambda **kw: "INSUFFICIENT\nvector database filtering"
|
|
118
|
+
)
|
|
119
|
+
judged = evaluator(SubTask("g", ("s",)), _hits(("a", 0.9)))
|
|
120
|
+
assert judged.refinement.goal == "vector database filtering"
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_evaluator_judge_error_falls_back_to_signal_no_loop():
|
|
124
|
+
def boom(**kw):
|
|
125
|
+
raise RuntimeError("model down")
|
|
126
|
+
|
|
127
|
+
evaluator = make_llm_evaluator(judge=boom)
|
|
128
|
+
judged = evaluator(SubTask("g", ("s",)), _hits(("a", 0.9)))
|
|
129
|
+
# refinement=None is the loop's break condition: a judge error never spins.
|
|
130
|
+
assert judged.refinement is None
|
|
131
|
+
assert judged.sufficient is True # ir.select committed to "a" -> sufficient
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_evaluator_renders_abstention_to_the_judge():
|
|
135
|
+
seen = {}
|
|
136
|
+
|
|
137
|
+
def judge(*, goal, results):
|
|
138
|
+
seen["results"] = results
|
|
139
|
+
return (True, None)
|
|
140
|
+
|
|
141
|
+
evaluator = make_llm_evaluator(judge=judge)
|
|
142
|
+
evaluator(SubTask("g", ("s",)), []) # no results -> ir.select abstains
|
|
143
|
+
assert "abstained" in seen["results"]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def test_evaluator_forwards_select_kwargs_to_ir_select():
|
|
147
|
+
# Two near-tied hits: conservative keeps only "a" by default, but a loose
|
|
148
|
+
# rel threshold (forwarded via select_kwargs) admits "b" too.
|
|
149
|
+
hits = _hits(("a", 0.9), ("b", 0.6))
|
|
150
|
+
strict = make_llm_evaluator(judge=lambda **kw: (True, None))
|
|
151
|
+
loose = make_llm_evaluator(
|
|
152
|
+
judge=lambda **kw: (True, None), select_kwargs={"rel": 0.5}
|
|
153
|
+
)
|
|
154
|
+
assert [h.artifact_id for h in strict(SubTask("g", ("s",)), hits).relevant] == ["a"]
|
|
155
|
+
assert [h.artifact_id for h in loose(SubTask("g", ("s",)), hits).relevant] == [
|
|
156
|
+
"a",
|
|
157
|
+
"b",
|
|
158
|
+
]
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def test_evaluator_ranks_heterogeneous_results_before_selecting():
|
|
162
|
+
# Accumulated cross-source results arrive unordered; the evaluator must rank
|
|
163
|
+
# best-first before ir.select (which trusts input order).
|
|
164
|
+
captured = {}
|
|
165
|
+
|
|
166
|
+
def judge(*, goal, results):
|
|
167
|
+
captured["results"] = results
|
|
168
|
+
return (True, None)
|
|
169
|
+
|
|
170
|
+
evaluator = make_llm_evaluator(judge=judge, select_kwargs={"rel": 0.0})
|
|
171
|
+
unordered = _hits(("lo", 0.1), ("hi", 0.9), ("mid", 0.5))
|
|
172
|
+
judged = evaluator(SubTask("g", ("s",)), unordered)
|
|
173
|
+
assert [h.artifact_id for h in judged.relevant] == ["hi", "mid", "lo"]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# ----- the back-edge end-to-end, wired through the agent loop ---------------- #
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def test_evaluator_back_edge_loops_until_sufficient():
|
|
180
|
+
retr = _fake_retriever(_hits(("a", 0.9)))
|
|
181
|
+
rounds = {"n": 0}
|
|
182
|
+
|
|
183
|
+
def judge(*, goal, results):
|
|
184
|
+
rounds["n"] += 1
|
|
185
|
+
if rounds["n"] < 2:
|
|
186
|
+
return (False, goal + " more")
|
|
187
|
+
return (True, None)
|
|
188
|
+
|
|
189
|
+
make_search_agent({"s": retr}, evaluator=make_llm_evaluator(judge=judge))("q")
|
|
190
|
+
assert rounds["n"] == 2 # looped once via the back-edge
|
|
191
|
+
assert [c[0] for c in retr.calls] == ["q", "q more"] # refinement re-queried
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def test_evaluator_back_edge_is_bounded_by_budget():
|
|
195
|
+
retr = _fake_retriever(_hits(("a", 0.9)))
|
|
196
|
+
evaluator = make_llm_evaluator(judge=lambda **kw: (False, "again"))
|
|
197
|
+
make_search_agent({"s": retr}, evaluator=evaluator, budget=Budget(max_rounds=3))(
|
|
198
|
+
"q"
|
|
199
|
+
)
|
|
200
|
+
assert len(retr.calls) == 3 # the safety net holds even if never sufficient
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# ----- end-to-end over a REAL ir corpus (hermetic: light embedder) ---------- #
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _light_corpus():
|
|
207
|
+
docs = {
|
|
208
|
+
"embed": "embed and cache model vectors",
|
|
209
|
+
"systemd": "configure systemd units and restart services",
|
|
210
|
+
"filtering": "narrow similarity search using metadata filters",
|
|
211
|
+
}
|
|
212
|
+
return ir.build(
|
|
213
|
+
ir.CorpusSource.from_mapping(docs, name="t", strategy=ir.WholeText()),
|
|
214
|
+
store=CorpusStore.memory(),
|
|
215
|
+
embedder="light",
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def test_back_edge_recovers_a_doc_single_shot_misses():
|
|
220
|
+
"""A query that overlaps a distractor misses the gold; the refinement recovers it.
|
|
221
|
+
|
|
222
|
+
Deterministic with the light embedder: the round-1 query shares vocabulary
|
|
223
|
+
with ``embed`` (a positive-score distractor) but none with the gold
|
|
224
|
+
``filtering``, so single-shot ranks ``embed`` first; the injected judge
|
|
225
|
+
declares it insufficient and reformulates to the gold doc's own vocabulary, so
|
|
226
|
+
round 2 retrieves ``filtering`` to the top via the back-edge.
|
|
227
|
+
"""
|
|
228
|
+
corpus = _light_corpus()
|
|
229
|
+
sources = {"t": ir.as_retriever(corpus, k=3)}
|
|
230
|
+
vague = "cache model results" # overlaps the `embed` distractor, not the gold
|
|
231
|
+
gold_query = "narrow similarity search using metadata filters"
|
|
232
|
+
|
|
233
|
+
# Baseline: single-shot (no LLM evaluator) surfaces the distractor, not the gold.
|
|
234
|
+
baseline = make_search_agent(sources)(vague)
|
|
235
|
+
assert baseline[0].artifact_id == "embed"
|
|
236
|
+
|
|
237
|
+
# With the back-edge: reformulate to the gold's vocabulary, then it wins.
|
|
238
|
+
rounds = {"n": 0}
|
|
239
|
+
|
|
240
|
+
def judge(*, goal, results):
|
|
241
|
+
rounds["n"] += 1
|
|
242
|
+
if rounds["n"] < 2:
|
|
243
|
+
return (False, gold_query)
|
|
244
|
+
return (True, None)
|
|
245
|
+
|
|
246
|
+
agent = make_search_agent(sources, evaluator=make_llm_evaluator(judge=judge))
|
|
247
|
+
results = agent(vague)
|
|
248
|
+
assert rounds["n"] == 2 # the back-edge fired
|
|
249
|
+
assert results[0].artifact_id == "filtering" # gold recovered
|
|
250
|
+
assert isinstance(results[0], SearchHit)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|