scroot 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scroot/__init__.py +109 -0
- scroot/agents.py +345 -0
- scroot/audit.py +131 -0
- scroot/cli/__init__.py +167 -0
- scroot/cli/download.py +49 -0
- scroot/cli/eval.py +230 -0
- scroot/cli/model_info.py +28 -0
- scroot/composite.py +170 -0
- scroot/config/__init__.py +0 -0
- scroot/config/corrector.py +92 -0
- scroot/connectors/__init__.py +5 -0
- scroot/connectors/database.py +357 -0
- scroot/context/__init__.py +9 -0
- scroot/context/adapters.py +86 -0
- scroot/context/builder.py +514 -0
- scroot/context/dedup.py +99 -0
- scroot/context/payload.py +66 -0
- scroot/context/pii.py +101 -0
- scroot/context/tokenizer.py +42 -0
- scroot/core.py +349 -0
- scroot/corrector/__init__.py +38 -0
- scroot/corrector/api.py +145 -0
- scroot/corrector/base.py +20 -0
- scroot/corrector/disabled.py +13 -0
- scroot/corrector/local.py +112 -0
- scroot/corrector/models.py +69 -0
- scroot/dashboard/__init__.py +0 -0
- scroot/dashboard/__main__.py +37 -0
- scroot/dashboard/routers/__init__.py +0 -0
- scroot/dashboard/routers/analytics.py +236 -0
- scroot/dashboard/routers/corrector.py +230 -0
- scroot/dashboard/routers/export.py +150 -0
- scroot/dashboard/routers/guardrails.py +41 -0
- scroot/dashboard/routers/pipeline.py +218 -0
- scroot/dashboard/routers/queue.py +188 -0
- scroot/dashboard/routers/records.py +252 -0
- scroot/dashboard/routers/settings.py +291 -0
- scroot/dashboard/security.py +135 -0
- scroot/dashboard/server.py +181 -0
- scroot/evidence.py +228 -0
- scroot/exceptions.py +62 -0
- scroot/feedback/__init__.py +6 -0
- scroot/feedback/injector.py +160 -0
- scroot/feedback/sanitizer.py +56 -0
- scroot/feedback/store.py +650 -0
- scroot/flags.py +42 -0
- scroot/metrics/__init__.py +15 -0
- scroot/metrics/_utils.py +9 -0
- scroot/metrics/completeness.py +139 -0
- scroot/metrics/confidence.py +83 -0
- scroot/metrics/consistency.py +125 -0
- scroot/metrics/groundedness.py +193 -0
- scroot/metrics/relevance.py +73 -0
- scroot/models.py +214 -0
- scroot/result.py +276 -0
- scroot/sampling.py +306 -0
- scroot/text_utils.py +136 -0
- scroot/ui/dist/assets/index-DW1dLzDl.js +101 -0
- scroot/ui/dist/assets/index-WOhrVVSM.css +2 -0
- scroot/ui/dist/favicon.svg +27 -0
- scroot/ui/dist/index.html +20 -0
- scroot-0.2.0.dist-info/METADATA +832 -0
- scroot-0.2.0.dist-info/RECORD +67 -0
- scroot-0.2.0.dist-info/WHEEL +5 -0
- scroot-0.2.0.dist-info/entry_points.txt +2 -0
- scroot-0.2.0.dist-info/licenses/LICENSE +201 -0
- scroot-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""CorrectorConfig - persisted to ~/.scroot/config.json."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import warnings
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Literal
|
|
10
|
+
|
|
11
|
+
# Allowed values for CorrectorConfig.mode (which backend the corrector uses).
CorrectorMode = Literal["disabled", "local", "api"]
|
|
12
|
+
|
|
13
|
+
# Default system prompt for the API corrector backend; it is the default value
# of APIConfig.system_prompt.
_DEFAULT_SYSTEM_PROMPT = (
    "You are a correction assistant. Rewrite the response to be more "
    "accurate, complete, and grounded in the provided context. "
    "Return only the corrected response text, nothing else."
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class LocalConfig:
    """Settings for the "local" corrector backend.

    Attributes:
        model_id: Identifier of the local model (default "phi4-mini").
        n_threads: Thread count setting (default -1; presumably "auto" -
            confirm against the local corrector implementation).
        n_gpu_layers: Number of layers to offload to a GPU (default 0).
        context_window: Context window size (default 4096).
    """

    model_id: str = field(default="phi4-mini")
    n_threads: int = field(default=-1)
    n_gpu_layers: int = field(default=0)
    context_window: int = field(default=4096)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class APIConfig:
    """Settings for the "api" corrector backend.

    Attributes:
        api_key: Provider API key. Excluded from ``repr()`` so that printing
            or logging a config object does not leak the secret; it is still
            part of ``vars()``/``asdict()`` and therefore still persisted.
        base_url: Provider endpoint; empty by default (presumably meaning
            "use the client default" - confirm in the API corrector).
        model: Model name to request (default "gpt-4o-mini").
        system_prompt: System prompt for the corrector model; defaults to
            ``_DEFAULT_SYSTEM_PROMPT``.
    """

    # repr=False: the dataclass-generated __repr__ (also used when a
    # CorrectorConfig containing this object is printed) would otherwise
    # include the raw key.
    api_key: str = field(default="", repr=False)
    base_url: str = ""
    model: str = "gpt-4o-mini"
    system_prompt: str = _DEFAULT_SYSTEM_PROMPT
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class CorrectorConfig:
    """Corrector settings, persisted as JSON (see ``save`` / ``load``).

    Attributes:
        mode: Selected backend - "disabled", "local" or "api".
        local: Settings for the local backend.
        api: Settings for the API backend, including the provider API key.
    """

    mode: CorrectorMode = "disabled"
    local: LocalConfig = field(default_factory=LocalConfig)
    api: APIConfig = field(default_factory=APIConfig)

    def save(self, path: Path) -> None:
        """Persist this config to *path* as indented UTF-8 JSON.

        The file holds the provider API key. It is therefore written
        owner-only (0600) and replaced atomically: the new content goes to a
        temporary file in the same directory, is fsync'ed, and is then renamed
        over the old file. A crash mid-write can never leave a truncated file
        behind. ``load`` would reject such a file and fall back to defaults,
        and the next ``save`` would then lose the stored key.

        Args:
            path: Destination file (a ``str`` is accepted too). Missing
                parent directories are created.

        Raises:
            OSError: If the directory cannot be created or the file cannot
                be written/replaced.
        """
        import tempfile
        from dataclasses import asdict

        path = Path(path)
        # M-1: this file holds the provider API key. Create the parent dir as
        # owner-only so other local accounts cannot read the key.
        # NOTE(review): this also tightens a parent directory that already
        # existed - worth keeping in mind if callers pass a shared directory.
        path.parent.mkdir(parents=True, exist_ok=True)
        try:
            os.chmod(path.parent, 0o700)
        except OSError:
            pass  # best effort: not permitted or not supported here
        payload = json.dumps(
            {
                "mode": self.mode,
                "local": asdict(self.local),
                "api": asdict(self.api),
            },
            indent=2,
        )
        # Follow a symlinked config so the link target is updated, instead of
        # the link itself being replaced by a regular file.
        target = path.resolve()
        # mkstemp creates the file with mode 0600 from the start, so the key
        # is never readable by other accounts, not even briefly.
        fd, tmp_name = tempfile.mkstemp(
            dir=target.parent, prefix=f".{target.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                f.write(payload)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_name, target)
        except BaseException:
            # Never leave a stray temp file (which contains the key) behind.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    @classmethod
    def load(cls, path: Path) -> "CorrectorConfig":
        """Load a config from *path*, falling back to defaults.

        A missing file yields ``cls()`` silently. An unreadable or malformed
        file also yields ``cls()``, but emits a warning first. Keys inside the
        "local"/"api" sections that this version does not declare are
        ignored, with a warning, instead of discarding the whole file.

        Args:
            path: Config file location (a ``str`` is accepted too).

        Returns:
            The loaded config, or the defaults for a missing/invalid file.
        """
        from dataclasses import fields

        path = Path(path)
        if not path.exists():
            return cls()

        def build_section(section_cls, raw):
            # Build LocalConfig/APIConfig from its JSON object, keeping only
            # declared fields so a newer/older file does not reset everything.
            if not isinstance(raw, dict):
                raise TypeError(
                    f"expected a JSON object for {section_cls.__name__}, "
                    f"got {type(raw).__name__}"
                )
            known = {fld.name for fld in fields(section_cls)}
            unknown = sorted(set(raw) - known)
            if unknown:
                warnings.warn(
                    f"Ignoring unknown {section_cls.__name__} keys {unknown} "
                    f"in corrector config at {path}.",
                    stacklevel=3,
                )
            return section_cls(**{k: v for k, v in raw.items() if k in known})

        try:
            # save() writes UTF-8; read it back as UTF-8 rather than the
            # platform default encoding (UnicodeDecodeError is a ValueError).
            data = json.loads(path.read_text(encoding="utf-8"))
            if not isinstance(data, dict):
                # e.g. a JSON list would otherwise raise an uncaught
                # AttributeError on data.get().
                raise TypeError(
                    f"expected a JSON object, got {type(data).__name__}"
                )
            return cls(
                mode=data.get("mode", "disabled"),
                local=build_section(LocalConfig, data.get("local", {})),
                api=build_section(APIConfig, data.get("api", {})),
            )
        except (OSError, ValueError, TypeError) as exc:
            # A corrupt/unreadable config silently resetting to defaults would
            # wipe the stored key/settings on the next save(). Surface it so the
            # operator can recover the file instead of losing it unknowingly.
            warnings.warn(
                f"Could not read corrector config at {path} ({exc}); "
                "using defaults. The file will be overwritten on the next save.",
                stacklevel=2,
            )
            return cls()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def default_config_path() -> Path:
    """Return the per-user corrector config location, ``~/.scroot/config.json``."""
    home_dir = Path.home()
    return home_dir.joinpath(".scroot", "config.json")
|
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
"""Database connector for scoring stored LLM responses.
|
|
2
|
+
|
|
3
|
+
Uses SQLAlchemy for database abstraction. Supports PostgreSQL, MySQL,
|
|
4
|
+
SQLite, BigQuery, Snowflake, and any SQLAlchemy-compatible backend.
|
|
5
|
+
|
|
6
|
+
Reads responses from a source table, scores them via Auditor, and writes
|
|
7
|
+
results to a result table. The result table is auto-created if absent.
|
|
8
|
+
|
|
9
|
+
.. warning::
|
|
10
|
+
SQL injection risk: ``source_table``, ``result_table``, ``column_map``
|
|
11
|
+
values, and the ``where`` / ``cursor_column`` arguments accepted by
|
|
12
|
+
``fetch()``, ``write_result()``, and ``score_incremental()`` are
|
|
13
|
+
interpolated directly into SQL strings (table/column identifiers
|
|
14
|
+
cannot be parameterised via SQLAlchemy bind parameters). Only pass
|
|
15
|
+
values you control - never pass user-supplied input directly as a
|
|
16
|
+
table name, column name, or WHERE clause. See ``docs/security.md``
|
|
17
|
+
for details and the planned hardening (allowlist validation,
|
|
18
|
+
``dry_run`` mode).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import logging
|
|
25
|
+
import warnings
|
|
26
|
+
from datetime import datetime, timezone
|
|
27
|
+
|
|
28
|
+
# Logger for the database connector (used for result-table creation and
# scoring-progress messages).
logger = logging.getLogger("scroot.connectors")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SecurityWarning(Warning):
    """Warns about a potential security risk in connector configuration.

    DatabaseConnector emits this on construction because it builds SQL by
    string interpolation. As a ``Warning`` subclass, it can be filtered or
    escalated with the standard ``warnings`` machinery.
    """
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Message of the SecurityWarning that DatabaseConnector.__init__ emits on every
# construction.
_SQL_INJECTION_WARNING = (
    "DatabaseConnector builds SQL using string interpolation for table "
    "names, column names, and WHERE clauses. These are NOT parameterised. "
    "Validate/allowlist any externally-influenced table/column names or "
    "WHERE clauses before passing them to this connector. "
    "See docs/security.md."
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DatabaseConnector:
    """Connector for scoring LLM responses stored in a SQL database.

    Args:
        connection_string: SQLAlchemy connection string.
            Examples:
                "postgresql://user:pass@host:5432/db"
                "mysql+pymysql://user:pass@host/db"
                "sqlite:///local.db"
                "bigquery://project/dataset"
        source_table: Name of the table containing LLM responses.
        column_map: Dict mapping entail field names to your column names.
            Required keys: "query", "response".
            Optional: "context" (JSON array or NULL), "id" (row identifier).
            Mapping "id" to a unique column also makes batch pagination
            deterministic.
        result_table: Table to write scores to. Auto-created if absent.
        batch_size: Rows fetched and scored per batch. Default 100. Must be
            a positive integer.

    Raises:
        ImportError: If SQLAlchemy is not installed.
        ValueError: If ``column_map`` lacks "query"/"response", or if
            ``batch_size`` is not positive.

    Warning:
        Table/column names and WHERE clauses are interpolated into SQL
        verbatim; only pass trusted values. A ``SecurityWarning`` is emitted
        on every construction as a reminder.
    """

    def __init__(
        self,
        connection_string: str,
        source_table: str,
        column_map: dict,
        result_table: str = "scroot_scores",
        batch_size: int = 100,
    ):
        try:
            import sqlalchemy as sa
        except ImportError as exc:
            raise ImportError(
                "SQLAlchemy is required for database connectors: "
                "pip install 'scroot[database]'"
            ) from exc

        if "query" not in column_map or "response" not in column_map:
            raise ValueError("column_map must include 'query' and 'response' keys")
        # A zero batch makes score_all() stop immediately. A negative one can
        # loop forever on backends that treat negative LIMIT/OFFSET leniently
        # (SQLite reads a negative LIMIT as "no limit").
        if int(batch_size) < 1:
            raise ValueError("batch_size must be a positive integer")

        warnings.warn(_SQL_INJECTION_WARNING, SecurityWarning, stacklevel=2)

        self.connection_string = connection_string
        self.source_table = source_table
        self.column_map = column_map
        self.result_table = result_table
        self.batch_size = int(batch_size)
        # Column names of source_table, probed lazily by _source_columns().
        self._source_column_names: frozenset[str] | None = None

        self._engine = sa.create_engine(connection_string)
        self._metadata = sa.MetaData()
        self._ensure_result_table()

    def _ensure_result_table(self) -> None:
        """Create the result table if it does not exist."""
        import sqlalchemy as sa

        if not sa.inspect(self._engine).has_table(self.result_table):
            table = sa.Table(
                self.result_table,
                self._metadata,
                sa.Column("id", sa.Integer, primary_key=True, autoincrement=True),
                sa.Column("source_row_id", sa.String(255), index=True),
                sa.Column("scored_at", sa.DateTime),
                sa.Column("iqs", sa.Float),
                sa.Column("groundedness", sa.Float, nullable=True),
                sa.Column("completeness", sa.Float),
                sa.Column("relevance", sa.Float),
                sa.Column("consistency", sa.Float),
                sa.Column("confidence", sa.Float),
                sa.Column("flags", sa.Text),
                sa.Column("details", sa.Text),
                sa.Column("strategy", sa.String(50), nullable=True),
                sa.Column("sample_seed", sa.Integer, nullable=True),
            )
            table.create(self._engine)
            logger.info("Created result table: %s", self.result_table)

    def _parse_context(self, raw_value) -> list[str] | None:
        """Parse a context column value into a list of strings.

        - ``None``, blank strings, and the JSON literal ``null`` -> ``None``
          (no context).
        - list/tuple values (e.g. native array/JSON columns) -> each element
          stringified.
        - A string holding a JSON array -> its elements stringified.
        - A string holding a JSON string -> that single string.
        - Any other string (plain text, or JSON numbers/objects) -> the raw
          string as a single chunk, so text that merely happens to be valid
          JSON is not dropped.
        - Any other type -> ``None``.
        """
        if raw_value is None:
            return None
        if isinstance(raw_value, (list, tuple)):
            return [str(c) for c in raw_value]
        if not isinstance(raw_value, str) or not raw_value.strip():
            return None
        try:
            parsed = json.loads(raw_value)
        except json.JSONDecodeError:
            return [raw_value]
        if isinstance(parsed, list):
            return [str(c) for c in parsed]
        if parsed is None:
            return None
        if isinstance(parsed, str):
            return [parsed]
        return [raw_value]

    def _source_columns(self) -> frozenset[str]:
        """Return (and cache) the column names of ``source_table``."""
        if self._source_column_names is None:
            import sqlalchemy as sa

            # LIMIT 0 returns the result columns without transferring rows.
            probe = sa.text(f"SELECT * FROM {self.source_table} LIMIT 0")
            with self._engine.connect() as conn:
                self._source_column_names = frozenset(conn.execute(probe).keys())
        return self._source_column_names

    def _pagination_column(self) -> str | None:
        """Return the id column to ORDER BY for stable paging, or None."""
        if "id" in self.column_map:
            return self.column_map["id"]
        return "id" if "id" in self._source_columns() else None

    def fetch(
        self,
        limit: int | None = None,
        where: str | None = None,
        offset: int = 0,
        *,
        order_by: str | None = None,
        params: dict | None = None,
    ) -> list[dict]:
        """Fetch rows from the source table.

        Args:
            limit: Max rows to fetch. None fetches all rows.
            where: Optional SQL WHERE clause (without the WHERE keyword).
                Interpolated verbatim - trusted input only.
            offset: Row offset for pagination. Must be >= 0.
            order_by: Optional ORDER BY column/expression. Without it, the
                database decides the row order, and therefore which rows
                LIMIT/OFFSET select. Interpolated verbatim.
            params: Bind parameters referenced as ``:name`` in ``where``.

        Returns:
            List of dicts with entail field keys plus "_row_id" and "_raw".

        Raises:
            ValueError: If ``offset`` is negative.
        """
        import itertools

        import sqlalchemy as sa

        # int() keeps LIMIT/OFFSET from becoming another injection vector.
        limit = None if limit is None else int(limit)
        offset = int(offset or 0)
        if offset < 0:
            raise ValueError("offset must be >= 0")

        parts = [f"SELECT * FROM {self.source_table}"]
        if where:
            parts.append(f"WHERE {where}")
        if order_by:
            parts.append(f"ORDER BY {order_by}")
        # SQLite and MySQL reject OFFSET without LIMIT, so an offset with no
        # limit is applied client-side instead of in SQL.
        client_skip = 0
        if limit is not None:
            parts.append(f"LIMIT {limit}")
            if offset:
                parts.append(f"OFFSET {offset}")
        else:
            client_skip = offset
        statement = sa.text(" ".join(parts))

        id_col = self.column_map.get("id", "id")
        rows = []
        with self._engine.connect() as conn:
            cursor = (
                conn.execute(statement, params) if params else conn.execute(statement)
            )
            col_names = list(cursor.keys())
            for row in itertools.islice(cursor, client_skip, None):
                raw = dict(zip(col_names, row))
                mapped: dict = {
                    entail_field: raw.get(db_col)
                    for entail_field, db_col in self.column_map.items()
                }
                mapped["_row_id"] = raw.get(id_col, raw.get("id"))
                mapped["_raw"] = raw
                rows.append(mapped)
        return rows

    def write_result(
        self,
        row_id,
        result,
        strategy: str | None = None,
        seed: int | None = None,
    ) -> None:
        """Write a single EntailmentResult to the result table.

        Args:
            row_id: Source row identifier (stored as a string).
            result: EntailmentResult from Auditor.score().
            strategy: Optional sampling strategy label.
            seed: Optional sampling seed.
        """
        import sqlalchemy as sa

        sql = sa.text(
            f"INSERT INTO {self.result_table} "
            "(source_row_id, scored_at, iqs, groundedness, completeness, "
            "relevance, consistency, confidence, flags, details, strategy, sample_seed) "
            "VALUES "
            "(:source_row_id, :scored_at, :iqs, :groundedness, :completeness, "
            ":relevance, :consistency, :confidence, :flags, :details, :strategy, :seed)"
        )

        # engine.begin() commits on success and rolls back on error; unlike
        # Connection.commit(), it works on SQLAlchemy 1.4 and 2.0 alike.
        with self._engine.begin() as conn:
            conn.execute(sql, {
                "source_row_id": str(row_id),
                # Naive UTC, matching the timezone-less DateTime column.
                "scored_at": datetime.now(timezone.utc).replace(tzinfo=None),
                "iqs": result.iqs,
                "groundedness": result.groundedness,
                "completeness": result.completeness,
                "relevance": result.relevance,
                "consistency": result.consistency,
                "confidence": result.confidence,
                "flags": json.dumps(result.flags),
                "details": json.dumps(result.details, default=str),
                "strategy": strategy,
                "seed": seed,
            })

    def score_all(
        self, auditor, where: str | None = None, *, params: dict | None = None
    ) -> dict:
        """Score all responses in the source table, batch by batch.

        Batches are ordered by the id column (``column_map["id"]``, or an
        ``id`` column if the table has one) so that LIMIT/OFFSET pages do not
        overlap or skip rows. Without such a column, a warning is logged and
        order is left to the database.

        Args:
            auditor: Auditor instance.
            where: Optional SQL WHERE clause to filter rows.
            params: Bind parameters referenced as ``:name`` in ``where``.

        Returns:
            Dict with total_scored, mean_iqs, flag_counts.
        """
        from collections import Counter

        import numpy as np

        order_by = self._pagination_column()
        if order_by is None:
            logger.warning(
                "No id column found in %s; batch pagination order is not "
                "guaranteed, so rows may be skipped or scored twice. Map "
                "column_map['id'] to a unique column to avoid this.",
                self.source_table,
            )

        offset = 0
        total_scored = 0
        all_iqs: list[float] = []
        flag_counts: Counter = Counter()

        while True:
            rows = self.fetch(
                limit=self.batch_size,
                where=where,
                offset=offset,
                order_by=order_by,
                params=params,
            )
            if not rows:
                break
            for row in rows:
                context = self._parse_context(row.get("context"))
                result = auditor.score(
                    query=str(row["query"]),
                    response=str(row["response"]),
                    context=context,
                )
                self.write_result(row["_row_id"], result)
                all_iqs.append(result.iqs)
                flag_counts.update(result.flags)
                total_scored += 1
            offset += self.batch_size
            logger.info("Scored %d responses so far...", total_scored)
            if len(rows) < self.batch_size:
                break  # short page: nothing left, skip the extra round trip

        return {
            "total_scored": total_scored,
            "mean_iqs": float(np.mean(all_iqs)) if all_iqs else 0.0,
            "flag_counts": dict(flag_counts),
        }

    def score_where(
        self, auditor, where: str, *, params: dict | None = None
    ) -> dict:
        """Score responses matching a SQL WHERE clause.

        Args:
            auditor: Auditor instance.
            where: SQL WHERE clause (without the WHERE keyword).
            params: Bind parameters referenced as ``:name`` in ``where``.

        Returns:
            Dict with total_scored, mean_iqs, flag_counts.
        """
        return self.score_all(auditor, where=where, params=params)

    def score_sampled(
        self,
        auditor,
        strategy: str = "random",
        sample_size: int | None = None,
        sample_pct: float | None = None,
        seed: int | None = 42,
        where: str | None = None,
        *,
        params: dict | None = None,
    ):
        """Score a sampled subset of responses from the database.

        Fetches all matching rows (ordered by the id column when one is
        available, so a fixed seed selects the same rows across runs),
        applies sampling, scores the sample, and writes results back with the
        sampling strategy label.

        Args:
            auditor: Auditor instance.
            strategy: Sampling strategy ("random", "percentage", "confidence").
            sample_size: For "random" strategy.
            sample_pct: For "percentage" strategy.
            seed: Random seed.
            where: Optional SQL WHERE filter.
            params: Bind parameters referenced as ``:name`` in ``where``.

        Returns:
            SamplingResult.
        """
        from ..sampling import sample_and_score

        rows = self.fetch(
            where=where, order_by=self._pagination_column(), params=params
        )
        items = [
            {
                "query": str(row["query"]),
                "response": str(row["response"]),
                "context": self._parse_context(row.get("context")),
                "_row_id": row["_row_id"],
            }
            for row in rows
        ]

        result = sample_and_score(
            auditor=auditor,
            items=items,
            strategy=strategy,
            sample_size=sample_size,
            sample_pct=sample_pct,
            seed=seed,
        )

        for si in result.scored_items:
            row_id = si["item"].get("_row_id")
            if row_id is not None:
                self.write_result(row_id, si["result"], strategy=strategy, seed=seed)

        return result

    def score_incremental(
        self, auditor, cursor_column: str = "created_at"
    ) -> dict:
        """Score only new responses since the last scored row.

        Finds the maximum cursor_column value among source rows that have
        already been scored, then scores all source rows where
        cursor_column > that watermark. If nothing has been scored yet, every
        row is scored.

        The watermark is passed as a bind parameter. Values containing quotes,
        and non-string types such as timestamps or integers, are therefore
        compared correctly instead of being spliced into the SQL text.

        Args:
            auditor: Auditor instance.
            cursor_column: Column in source_table to use as the cursor.
                Interpolated verbatim - trusted input only.

        Returns:
            Dict with total_scored, mean_iqs, flag_counts.
        """
        import sqlalchemy as sa

        id_col = self.column_map.get("id", "id")
        # NOTE(review): CAST(... AS TEXT) is not accepted by every backend
        # (MySQL expects CHAR) - confirm for the dialects you target.
        watermark_sql = sa.text(
            f"SELECT MAX(src.{cursor_column}) FROM {self.source_table} src "
            f"WHERE CAST(src.{id_col} AS TEXT) IN "
            f"(SELECT source_row_id FROM {self.result_table})"
        )
        with self._engine.connect() as conn:
            max_cursor = conn.execute(watermark_sql).scalar()

        # Test for None rather than truthiness: a watermark of 0 or "" is a
        # real value and must not trigger a full rescore.
        if max_cursor is None:
            return self.score_all(auditor)
        # NOTE(review): `>` skips unscored rows whose cursor equals the
        # watermark; a unique, monotonic cursor column avoids that.
        return self.score_all(
            auditor,
            where=f"{cursor_column} > :scroot_cursor",
            params={"scroot_cursor": max_cursor},
        )
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Context assembly for groundedness scoring.
|
|
2
|
+
|
|
3
|
+
Public exports: ContextBuilder, ContextPayload, ContextEntry.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .builder import ContextBuilder
|
|
7
|
+
from .payload import ContextEntry, ContextPayload
|
|
8
|
+
|
|
9
|
+
# Public API of this package: the names re-exported above, and what a
# star-import of the package pulls in.
__all__ = ["ContextBuilder", "ContextPayload", "ContextEntry"]
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Type adapters - extract plain text from common document/result types.
|
|
2
|
+
|
|
3
|
+
Adapters are intentionally defensive: they never raise, only return
|
|
4
|
+
None when extraction fails. The caller logs a warning and skips the
|
|
5
|
+
chunk so the client's pipeline is never crashed by an exotic type.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
# Dict keys checked, in this order, for a plain-text value.
_TEXT_KEYS = ('text', 'content', 'page_content', 'body', 'chunk')


def _flatten_documents(documents: list) -> str | None:
    """Join the string entries of a ChromaDB-style ``documents`` value.

    Accepts both the nested list-of-lists shape (one inner list per query)
    and a flat list; non-string entries are skipped. Returns None when no
    string is found.
    """
    flat = [
        doc
        for group in documents
        for doc in (group if isinstance(group, list) else [group])
        if isinstance(doc, str)
    ]
    return '\n\n'.join(flat) if flat else None


def extract_text(chunk: Any, source_hint: str = "") -> str | None:
    """Extract plain text from any common document type.

    Handles, in order: plain strings, LangChain ``Document``
    (``.page_content``), dicts (a known text key, then a ChromaDB-style
    ``'documents'`` list, then a Pinecone-style ``'metadata'['text']``),
    LlamaIndex ``TextNode``/``NodeWithScore`` (``.text`` / ``.node.text``),
    ChromaDB ``QueryResult``-style objects (``.documents`` list of lists),
    and Pinecone ``ScoredVector`` (``.metadata['text']``).

    Args:
        chunk: The object to extract text from.
        source_hint: Optional label, reserved for diagnostics (unused).

    Returns:
        The extracted text, or None if extraction fails (caller logs
        and skips - never raises).
    """
    try:
        if isinstance(chunk, str):
            return chunk

        # LangChain Document
        page_content = getattr(chunk, 'page_content', None)
        if isinstance(page_content, str):
            return page_content

        if isinstance(chunk, dict):
            # Dict with known text keys
            for key in _TEXT_KEYS:
                if key in chunk and isinstance(chunk[key], str):
                    return chunk[key]
            # ChromaDB-style dict result: {'documents': [[...]]}
            docs = chunk.get('documents')
            if isinstance(docs, list):
                joined = _flatten_documents(docs)
                if joined is not None:
                    return joined
            # Pinecone match serialised as a dict: {'metadata': {'text': ...}}
            metadata = chunk.get('metadata')
            if isinstance(metadata, dict) and isinstance(metadata.get('text'), str):
                return metadata['text']
            return None

        # LlamaIndex TextNode / NodeWithScore
        text = getattr(chunk, 'text', None)
        if isinstance(text, str):
            return text
        node = getattr(chunk, 'node', None)
        if node is not None:
            node_text = getattr(node, 'text', None)
            if isinstance(node_text, str):
                return node_text

        # ChromaDB QueryResult object (list of lists)
        documents = getattr(chunk, 'documents', None)
        if isinstance(documents, list):
            return _flatten_documents(documents)

        # Pinecone ScoredVector with metadata
        metadata = getattr(chunk, 'metadata', None)
        if isinstance(metadata, dict) and isinstance(metadata.get('text'), str):
            return metadata['text']

        return None
    except Exception:
        # Deliberately defensive: exotic objects must never crash the caller.
        return None
|