benchmax 0.1.2.dev34__py3-none-any.whl → 0.1.2.dev35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmax/cli/corpus.py +2 -2
- benchmax/cli/launch.py +1 -7
- benchmax/cli/scaffold/skills/launch-run/SKILL.md +0 -6
- benchmax/envs/postgres_search/search_env.py +14 -3
- benchmax/envs/telestich/example.py +2 -3
- benchmax/platform/client.py +2 -7
- benchmax/platform/training_run.py +0 -1
- benchmax/rag/corpus/postgres/client.py +237 -12
- benchmax/rag/corpus/postgres/exceptions.py +2 -2
- benchmax/rag/corpus/postgres/source.py +81 -24
- benchmax/rag/qa_generation/batch_processor.py +138 -12
- benchmax/rag/qa_generation/filters/grounding_llm.py +117 -34
- benchmax/rag/qa_generation/filters/hop_count_validity.py +116 -31
- benchmax/rag/qa_generation/filters/retrieval_llm.py +131 -44
- benchmax/rag/qa_generation/generators/direct_llm.py +123 -43
- benchmax/rag/qa_generation/metadata_linker.py +179 -10
- benchmax/rag/qa_generation/pipeline.py +297 -205
- benchmax/rag/qa_generation/pipeline_config.py +89 -0
- benchmax/rag/qa_generation/search_agent_linker.py +59 -6
- benchmax/rag/qa_generation/wiki_chunk_linker.py +34 -6
- {benchmax-0.1.2.dev34.dist-info → benchmax-0.1.2.dev35.dist-info}/METADATA +2 -2
- {benchmax-0.1.2.dev34.dist-info → benchmax-0.1.2.dev35.dist-info}/RECORD +26 -26
- {benchmax-0.1.2.dev34.dist-info → benchmax-0.1.2.dev35.dist-info}/WHEEL +0 -0
- {benchmax-0.1.2.dev34.dist-info → benchmax-0.1.2.dev35.dist-info}/entry_points.txt +0 -0
- {benchmax-0.1.2.dev34.dist-info → benchmax-0.1.2.dev35.dist-info}/licenses/LICENSE +0 -0
- {benchmax-0.1.2.dev34.dist-info → benchmax-0.1.2.dev35.dist-info}/top_level.txt +0 -0
benchmax/cli/corpus.py
CHANGED
|
@@ -110,7 +110,7 @@ def _cmd_corpus_list(args: argparse.Namespace) -> int:
|
|
|
110
110
|
if not corpora:
|
|
111
111
|
print("No corpora yet. Create one: castform corpus ingest <folder>")
|
|
112
112
|
return 0
|
|
113
|
-
print(f"{len(corpora)}/
|
|
113
|
+
print(f"{len(corpora)}/20 corpora:")
|
|
114
114
|
for c in corpora:
|
|
115
115
|
print(f" {c.name} (id: {c.id})")
|
|
116
116
|
return 0
|
|
@@ -206,7 +206,7 @@ def register(sub: argparse._SubParsersAction) -> None:
|
|
|
206
206
|
p_ing.add_argument("--json", action="store_true", help="Emit raw JSON")
|
|
207
207
|
p_ing.set_defaults(func=_cmd_corpus_ingest)
|
|
208
208
|
|
|
209
|
-
p_ls = corpus_sub.add_parser("list", help="List your corpora (and the
|
|
209
|
+
p_ls = corpus_sub.add_parser("list", help="List your corpora (and the 20-corpus cap)")
|
|
210
210
|
p_ls.add_argument("--json", action="store_true", help="Emit raw JSON")
|
|
211
211
|
p_ls.set_defaults(func=_cmd_corpus_list)
|
|
212
212
|
|
benchmax/cli/launch.py
CHANGED
|
@@ -126,7 +126,7 @@ def _cmd_launch(args: argparse.Namespace) -> int:
|
|
|
126
126
|
)
|
|
127
127
|
return 1
|
|
128
128
|
reply = input(
|
|
129
|
-
f"Launch '{run_name}'
|
|
129
|
+
f"Launch '{run_name}' — incurs GPU cost. Continue? [y/N] "
|
|
130
130
|
)
|
|
131
131
|
if reply.strip().lower() not in ("y", "yes"):
|
|
132
132
|
print("Aborted.")
|
|
@@ -177,7 +177,6 @@ def _cmd_launch(args: argparse.Namespace) -> int:
|
|
|
177
177
|
with warnings.catch_warnings(record=True) as caught:
|
|
178
178
|
warnings.simplefilter("always")
|
|
179
179
|
run_id = client.launch_training_run(
|
|
180
|
-
training_run_type=args.type,
|
|
181
180
|
name=run_name,
|
|
182
181
|
launcher_args=launcher_args or None,
|
|
183
182
|
**dataclasses.asdict(uploaded),
|
|
@@ -211,11 +210,6 @@ def register(sub: argparse._SubParsersAction) -> None:
|
|
|
211
210
|
)
|
|
212
211
|
p.add_argument("--eval", default="eval_dataset.jsonl", help="Eval dataset (jsonl)")
|
|
213
212
|
p.add_argument("--name", help="Run name (default: the env class name)")
|
|
214
|
-
p.add_argument(
|
|
215
|
-
"--type",
|
|
216
|
-
default="simple",
|
|
217
|
-
help="Training run type: simple (GPU) or simple-cpu (smoke)",
|
|
218
|
-
)
|
|
219
213
|
p.add_argument(
|
|
220
214
|
"--env-arg", action="append", metavar="KEY=VALUE", help="Env constructor arg"
|
|
221
215
|
)
|
|
@@ -66,9 +66,3 @@ truncated in training — keep `MAX_SEARCH_CALLS` ≤ 8 (see design-environment'
|
|
|
66
66
|
Server-controlled fields — `save`, `load`, `global_batch_size`, the eval mirrors —
|
|
67
67
|
are **not settable**: the launch handler fills them in and rejects caller input
|
|
68
68
|
that carries them. (`rollout_batch_size` is derived too, not a launch arg.)
|
|
69
|
-
|
|
70
|
-
### Run types
|
|
71
|
-
|
|
72
|
-
`--type simple` (default) is the GPU training pool. `--type simple-cpu` is a
|
|
73
|
-
CPU-only smoke pool (cheap) for exercising the launch lifecycle without GPU.
|
|
74
|
-
(`simple-r5` from older docs is not implemented.)
|
|
@@ -23,7 +23,6 @@ from benchmax.envs.example_id import make_example
|
|
|
23
23
|
from benchmax.envs.reward_helpers import (
|
|
24
24
|
clip01,
|
|
25
25
|
count_search_calls,
|
|
26
|
-
extract_answer_block,
|
|
27
26
|
extract_completion_text,
|
|
28
27
|
search_within_budget,
|
|
29
28
|
)
|
|
@@ -40,6 +39,19 @@ logger = logging.getLogger(__name__)
|
|
|
40
39
|
|
|
41
40
|
_CITATION_RE = re.compile(r"\[Source:\s*([^\]]+)\]", re.IGNORECASE)
|
|
42
41
|
|
|
42
|
+
_ANSWER_TAG_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _extract_answer_block(text: str) -> str:
|
|
46
|
+
"""Extract content from <answer> tags; return "" if no answer tag is present.
|
|
47
|
+
|
|
48
|
+
Strict variant of ``reward_helpers.extract_answer_block``: a completion
|
|
49
|
+
without an explicit ``<answer>`` block scores as no answer rather than
|
|
50
|
+
falling back to the full text.
|
|
51
|
+
"""
|
|
52
|
+
match = _ANSWER_TAG_RE.search(text or "")
|
|
53
|
+
return match.group(1).strip() if match else ""
|
|
54
|
+
|
|
43
55
|
# Match Python-style `{name}` placeholders with word-char names only —
|
|
44
56
|
# leaves JSON-like literals (e.g. `{"answer": "X"}`) and unknown keys
|
|
45
57
|
# untouched, so a user-edited SYSTEM_PROMPT_TEMPLATE that contains JSON
|
|
@@ -70,7 +82,6 @@ _CORRECTNESS_RUBRIC = Rubric(
|
|
|
70
82
|
type="positive",
|
|
71
83
|
score_map={
|
|
72
84
|
0: "Provided answer is missing or incorrect.",
|
|
73
|
-
0.5: "Partially correct — captures some facts but missing key details.",
|
|
74
85
|
1: "Fully correct and factually consistent.",
|
|
75
86
|
},
|
|
76
87
|
)
|
|
@@ -286,7 +297,7 @@ tags. Cite your sources inline using [Source: <source_id>] next to each claim.
|
|
|
286
297
|
return zeros
|
|
287
298
|
|
|
288
299
|
t = task or {}
|
|
289
|
-
answer =
|
|
300
|
+
answer = _extract_answer_block(text)
|
|
290
301
|
prompt = str(t.get("question") or t.get("prompt") or "")
|
|
291
302
|
gt_str = str(t.get("ground_truth") or "")
|
|
292
303
|
reference_chunks = t.get("reference_chunks", [])
|
|
@@ -642,12 +642,11 @@ if __name__ == "__main__":
|
|
|
642
642
|
):
|
|
643
643
|
print(f" {label:<14}: {path}")
|
|
644
644
|
|
|
645
|
-
# 4. Launch the training run.
|
|
646
|
-
#
|
|
645
|
+
# 4. Launch the training run. The model arg selects the trainer YAML/pool
|
|
646
|
+
# server-side (Qwen3.5-4B→gpu4, Qwen3.5-35B-A3B→gpu8).
|
|
647
647
|
print(f"\nLaunching training run (model={MODEL}) ...")
|
|
648
648
|
with TrainerClient(api_key="", base_url=BASE_URL) as trainer:
|
|
649
649
|
run_id = trainer.launch_training_run(
|
|
650
|
-
training_run_type="simple",
|
|
651
650
|
env_cls_path=uploaded.env_cls_path,
|
|
652
651
|
env_metadata_path=uploaded.env_metadata_path,
|
|
653
652
|
train_dataset_path=uploaded.train_dataset_path,
|
benchmax/platform/client.py
CHANGED
|
@@ -360,7 +360,6 @@ class TrainerClient:
|
|
|
360
360
|
Example:
|
|
361
361
|
client = TrainerClient(api_key="sk_...", base_url="http://localhost:3000")
|
|
362
362
|
run_id = client.launch_training_run(
|
|
363
|
-
training_run_type="simple",
|
|
364
363
|
env_cls_path="envs/run-abc/abc123/env-cls.pkl",
|
|
365
364
|
env_metadata_path="envs/run-abc/abc123/env-metadata.json",
|
|
366
365
|
train_dataset_path="datasets/run-abc/def456/train.jsonl",
|
|
@@ -413,7 +412,6 @@ class TrainerClient:
|
|
|
413
412
|
|
|
414
413
|
def launch_training_run(
|
|
415
414
|
self,
|
|
416
|
-
training_run_type: str,
|
|
417
415
|
env_cls_path: str,
|
|
418
416
|
env_metadata_path: str,
|
|
419
417
|
train_dataset_path: str,
|
|
@@ -421,12 +419,9 @@ class TrainerClient:
|
|
|
421
419
|
name: str | None = None,
|
|
422
420
|
launcher_args: dict[str, Any] | None = None,
|
|
423
421
|
) -> str:
|
|
424
|
-
"""Launch a new training run
|
|
422
|
+
"""Launch a new training run.
|
|
425
423
|
|
|
426
424
|
Args:
|
|
427
|
-
training_run_type: Job template selector. ``"simple"`` (GPU pool —
|
|
428
|
-
gpu4 for 4B, gpu8 for 35B) or ``"simple-cpu"`` (CPU-only smoke
|
|
429
|
-
pool, no GPU).
|
|
430
425
|
env_cls_path: Path to the environment class pickle (.pkl file)
|
|
431
426
|
env_metadata_path: Path to the environment metadata JSON file
|
|
432
427
|
train_dataset_path: Path to the training dataset
|
|
@@ -453,7 +448,7 @@ class TrainerClient:
|
|
|
453
448
|
response = self._http_client.post(
|
|
454
449
|
"/v1/train/runs/launch",
|
|
455
450
|
json={
|
|
456
|
-
"type":
|
|
451
|
+
"type": "simple",
|
|
457
452
|
"name": name,
|
|
458
453
|
"args": args,
|
|
459
454
|
},
|
|
@@ -2,7 +2,9 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import asyncio
|
|
5
6
|
import logging
|
|
7
|
+
import threading
|
|
6
8
|
import time
|
|
7
9
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
8
10
|
from dataclasses import dataclass, field
|
|
@@ -15,7 +17,6 @@ from benchmax.platform.credentials import TokenProvider, platform_bearer
|
|
|
15
17
|
|
|
16
18
|
from .exceptions import (
|
|
17
19
|
AuthenticationError,
|
|
18
|
-
ChunkLimitError,
|
|
19
20
|
CorpusAPIError,
|
|
20
21
|
CorpusLimitError,
|
|
21
22
|
CorpusNotFoundError,
|
|
@@ -52,11 +53,38 @@ class CorpusClient:
|
|
|
52
53
|
max_retries: int = 5
|
|
53
54
|
retry_backoff_seconds: float = 0.5
|
|
54
55
|
token_provider: TokenProvider = platform_bearer
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
56
|
+
# Enable HTTP/2 multiplexing on the async client. Safe there (one client
|
|
57
|
+
# bound to one event loop), unlike the shared sync client across threads.
|
|
58
|
+
async_http2: bool = True
|
|
59
|
+
# HTTP clients are created lazily, one per thread (see ``_http_client``).
|
|
60
|
+
# httpx.Client's connection pool is not safe to share across threads at high
|
|
61
|
+
# parallelism: the QA-gen work queue hits this client from every batch
|
|
62
|
+
# thread, which raced the shared pool's sockets into
|
|
63
|
+
# ``ReadError: [Errno 9] Bad file descriptor``. A client-per-thread avoids
|
|
64
|
+
# the shared-pool race entirely.
|
|
65
|
+
_local: threading.local = field(
|
|
66
|
+
init=False, repr=False, default_factory=threading.local
|
|
67
|
+
)
|
|
68
|
+
_client_override: httpx.Client | None = field(
|
|
69
|
+
init=False, repr=False, default=None
|
|
70
|
+
)
|
|
71
|
+
_client_registry: list[httpx.Client] = field(
|
|
72
|
+
init=False, repr=False, default_factory=list
|
|
73
|
+
)
|
|
74
|
+
_registry_lock: threading.Lock = field(
|
|
75
|
+
init=False, repr=False, default_factory=threading.Lock
|
|
76
|
+
)
|
|
77
|
+
# Single async client, lazily bound to the running event loop (rebuilt if the
|
|
78
|
+
# loop changes — asyncio.run() mints a fresh loop per Pipeline.run()).
|
|
79
|
+
_async_client: httpx.AsyncClient | None = field(
|
|
80
|
+
init=False, repr=False, default=None
|
|
81
|
+
)
|
|
82
|
+
_async_client_loop: Any = field(init=False, repr=False, default=None)
|
|
83
|
+
|
|
84
|
+
def _build_http_client(self) -> httpx.Client:
|
|
85
|
+
"""Create a new HTTP client and register it for ``close()``.
|
|
86
|
+
|
|
87
|
+
Auth is attached per request in ``_request`` — not baked here."""
|
|
60
88
|
timeout_config = httpx.Timeout(
|
|
61
89
|
timeout=self.timeout,
|
|
62
90
|
connect=self.timeout,
|
|
@@ -64,11 +92,37 @@ class CorpusClient:
|
|
|
64
92
|
write=self.timeout,
|
|
65
93
|
pool=self.timeout,
|
|
66
94
|
)
|
|
67
|
-
|
|
95
|
+
client = httpx.Client(
|
|
68
96
|
base_url=self.base_url,
|
|
69
97
|
headers={"Content-Type": "application/json"},
|
|
70
98
|
timeout=timeout_config,
|
|
71
99
|
)
|
|
100
|
+
with self._registry_lock:
|
|
101
|
+
self._client_registry.append(client)
|
|
102
|
+
return client
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def _http_client(self) -> httpx.Client:
|
|
106
|
+
"""The HTTP client for the calling thread.
|
|
107
|
+
|
|
108
|
+
Returns an explicitly installed override if present (e.g. a profiling
|
|
109
|
+
harness swapping in an HTTP/2 client); otherwise a lazily-created
|
|
110
|
+
per-thread client so concurrent batch threads never share a pool."""
|
|
111
|
+
if self._client_override is not None:
|
|
112
|
+
return self._client_override
|
|
113
|
+
client = getattr(self._local, "client", None)
|
|
114
|
+
if client is None:
|
|
115
|
+
client = self._build_http_client()
|
|
116
|
+
self._local.client = client
|
|
117
|
+
return client
|
|
118
|
+
|
|
119
|
+
@_http_client.setter
|
|
120
|
+
def _http_client(self, value: httpx.Client) -> None:
|
|
121
|
+
"""Install a single client shared across all threads. Intended for
|
|
122
|
+
single-threaded or multiplexed (HTTP/2) setups, not pool-per-thread."""
|
|
123
|
+
self._client_override = value
|
|
124
|
+
with self._registry_lock:
|
|
125
|
+
self._client_registry.append(value)
|
|
72
126
|
|
|
73
127
|
def _request(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
|
|
74
128
|
"""Execute an HTTP request with retry/backoff for transient network failures.
|
|
@@ -156,6 +210,97 @@ class CorpusClient:
|
|
|
156
210
|
pass
|
|
157
211
|
return self.retry_backoff_seconds * (2 ** (attempt - 1))
|
|
158
212
|
|
|
213
|
+
def _get_async_client(self) -> httpx.AsyncClient:
|
|
214
|
+
"""The async client bound to the running event loop.
|
|
215
|
+
|
|
216
|
+
Rebuilt when the loop changes (``asyncio.run`` mints a fresh loop per
|
|
217
|
+
``Pipeline.run``) or the client was closed. Creation does not ``await``,
|
|
218
|
+
so on a single event loop the check-then-build is race-free."""
|
|
219
|
+
loop = asyncio.get_running_loop()
|
|
220
|
+
client = self._async_client
|
|
221
|
+
if client is None or client.is_closed or self._async_client_loop is not loop:
|
|
222
|
+
timeout_config = httpx.Timeout(
|
|
223
|
+
timeout=self.timeout,
|
|
224
|
+
connect=self.timeout,
|
|
225
|
+
read=self.timeout,
|
|
226
|
+
write=self.timeout,
|
|
227
|
+
pool=self.timeout,
|
|
228
|
+
)
|
|
229
|
+
client = httpx.AsyncClient(
|
|
230
|
+
base_url=self.base_url,
|
|
231
|
+
headers={"Content-Type": "application/json"},
|
|
232
|
+
timeout=timeout_config,
|
|
233
|
+
http2=self.async_http2,
|
|
234
|
+
)
|
|
235
|
+
self._async_client = client
|
|
236
|
+
self._async_client_loop = loop
|
|
237
|
+
return client
|
|
238
|
+
|
|
239
|
+
async def _arequest(self, method: str, path: str, **kwargs: Any) -> httpx.Response:
|
|
240
|
+
"""Async twin of ``_request`` — same retry/backoff and 429 handling, with
|
|
241
|
+
``await asyncio.sleep`` instead of ``time.sleep`` so the loop stays free."""
|
|
242
|
+
try:
|
|
243
|
+
bearer = self.token_provider()
|
|
244
|
+
except RuntimeError as exc:
|
|
245
|
+
# The seam (platform_bearer) raises when no credential resolves; surface
|
|
246
|
+
# it as an auth error so callers catch it like any other Corpora failure.
|
|
247
|
+
raise AuthenticationError(
|
|
248
|
+
f"No Castform platform credential available for the Corpora API: {exc}"
|
|
249
|
+
) from exc
|
|
250
|
+
headers = {
|
|
251
|
+
**kwargs.pop("headers", {}),
|
|
252
|
+
"Authorization": f"Bearer {bearer}",
|
|
253
|
+
}
|
|
254
|
+
retries = max(1, int(self.max_retries))
|
|
255
|
+
client = self._get_async_client()
|
|
256
|
+
attempt = 1
|
|
257
|
+
while True:
|
|
258
|
+
try:
|
|
259
|
+
response = await client.request(method, path, headers=headers, **kwargs)
|
|
260
|
+
except (httpx.ConnectTimeout, httpx.ConnectError, httpx.ReadTimeout) as exc:
|
|
261
|
+
if attempt >= retries:
|
|
262
|
+
raise CorpusAPIError(
|
|
263
|
+
(
|
|
264
|
+
"Corpora API request failed after retries due to a network timeout/error. "
|
|
265
|
+
f"method={method} path={path} base_url={self.base_url} "
|
|
266
|
+
f"attempts={retries} last_error={exc!s}"
|
|
267
|
+
),
|
|
268
|
+
status_code=503,
|
|
269
|
+
) from exc
|
|
270
|
+
delay = self.retry_backoff_seconds * (2 ** (attempt - 1))
|
|
271
|
+
logger.warning(
|
|
272
|
+
"Corpora API request attempt %s/%s failed (%s). Retrying in %.2fs. "
|
|
273
|
+
"method=%s path=%s base_url=%s",
|
|
274
|
+
attempt,
|
|
275
|
+
retries,
|
|
276
|
+
type(exc).__name__,
|
|
277
|
+
delay,
|
|
278
|
+
method,
|
|
279
|
+
path,
|
|
280
|
+
self.base_url,
|
|
281
|
+
)
|
|
282
|
+
await asyncio.sleep(delay)
|
|
283
|
+
attempt += 1
|
|
284
|
+
continue
|
|
285
|
+
|
|
286
|
+
if response.status_code == 429 and attempt < retries:
|
|
287
|
+
delay = self._retry_after_delay(response, attempt)
|
|
288
|
+
logger.warning(
|
|
289
|
+
"Corpora API rate-limited (429) on attempt %s/%s. Retrying in %.2fs. "
|
|
290
|
+
"method=%s path=%s base_url=%s",
|
|
291
|
+
attempt,
|
|
292
|
+
retries,
|
|
293
|
+
delay,
|
|
294
|
+
method,
|
|
295
|
+
path,
|
|
296
|
+
self.base_url,
|
|
297
|
+
)
|
|
298
|
+
await asyncio.sleep(delay)
|
|
299
|
+
attempt += 1
|
|
300
|
+
continue
|
|
301
|
+
|
|
302
|
+
return response
|
|
303
|
+
|
|
159
304
|
def __enter__(self) -> "CorpusClient":
|
|
160
305
|
return self
|
|
161
306
|
|
|
@@ -163,8 +308,28 @@ class CorpusClient:
|
|
|
163
308
|
self.close()
|
|
164
309
|
|
|
165
310
|
def close(self) -> None:
|
|
166
|
-
"""Close
|
|
167
|
-
|
|
311
|
+
"""Close every HTTP client this instance created (one per thread, plus
|
|
312
|
+
any installed override)."""
|
|
313
|
+
with self._registry_lock:
|
|
314
|
+
clients = list(self._client_registry)
|
|
315
|
+
self._client_registry.clear()
|
|
316
|
+
self._client_override = None
|
|
317
|
+
for client in clients:
|
|
318
|
+
try:
|
|
319
|
+
client.close()
|
|
320
|
+
except Exception: # noqa: BLE001 — best-effort cleanup
|
|
321
|
+
logger.debug("Error closing corpus HTTP client", exc_info=True)
|
|
322
|
+
|
|
323
|
+
async def aclose(self) -> None:
|
|
324
|
+
"""Close the async client. Call from within its event loop."""
|
|
325
|
+
client = self._async_client
|
|
326
|
+
self._async_client = None
|
|
327
|
+
self._async_client_loop = None
|
|
328
|
+
if client is not None and not client.is_closed:
|
|
329
|
+
try:
|
|
330
|
+
await client.aclose()
|
|
331
|
+
except Exception: # noqa: BLE001 — best-effort cleanup
|
|
332
|
+
logger.debug("Error closing corpus async client", exc_info=True)
|
|
168
333
|
|
|
169
334
|
def _handle_response_errors(self, response: httpx.Response) -> None:
|
|
170
335
|
"""Convert HTTP errors to appropriate exceptions."""
|
|
@@ -181,7 +346,7 @@ class CorpusClient:
|
|
|
181
346
|
raise AuthenticationError(message)
|
|
182
347
|
|
|
183
348
|
if response.status_code == 400:
|
|
184
|
-
if "Maximum of
|
|
349
|
+
if "Maximum of 20 corpora" in message:
|
|
185
350
|
raise CorpusLimitError()
|
|
186
351
|
if "Chunk limit exceeded" in message:
|
|
187
352
|
raise CorpusAPIError(message, 400)
|
|
@@ -204,7 +369,7 @@ class CorpusClient:
|
|
|
204
369
|
Corpus object with id, name, timestamps
|
|
205
370
|
|
|
206
371
|
Raises:
|
|
207
|
-
CorpusLimitError: If max
|
|
372
|
+
CorpusLimitError: If max 20 corpora limit reached
|
|
208
373
|
AuthenticationError: If API key is invalid
|
|
209
374
|
"""
|
|
210
375
|
response = self._request("POST", "/v1/corpora", json={"name": name})
|
|
@@ -306,7 +471,7 @@ class CorpusClient:
|
|
|
306
471
|
print(f" ID: {corpus.id}")
|
|
307
472
|
print(f" Created: {corpus.created_at}")
|
|
308
473
|
|
|
309
|
-
print(
|
|
474
|
+
print("\n 0. Cancel operation")
|
|
310
475
|
print()
|
|
311
476
|
|
|
312
477
|
while True:
|
|
@@ -572,3 +737,63 @@ class CorpusClient:
|
|
|
572
737
|
matched.append((local_chunk, corpus_chunk.score or 0.0))
|
|
573
738
|
|
|
574
739
|
return matched
|
|
740
|
+
|
|
741
|
+
async def asearch(
|
|
742
|
+
self,
|
|
743
|
+
corpus_id: str,
|
|
744
|
+
query: str,
|
|
745
|
+
limit: int = 10,
|
|
746
|
+
offset: int = 0,
|
|
747
|
+
metadata: dict[str, Any] | None = None,
|
|
748
|
+
filters: dict[str, Any] | None = None,
|
|
749
|
+
) -> SearchResult:
|
|
750
|
+
"""Async twin of ``search``. Same payload + response shape, async I/O."""
|
|
751
|
+
payload: dict[str, Any] = {"query": query, "limit": limit, "offset": offset}
|
|
752
|
+
if metadata:
|
|
753
|
+
payload["metadata"] = metadata
|
|
754
|
+
if filters:
|
|
755
|
+
payload["filters"] = filters
|
|
756
|
+
|
|
757
|
+
response = await self._arequest(
|
|
758
|
+
"POST", f"/v1/corpora/{corpus_id}/search", json=payload
|
|
759
|
+
)
|
|
760
|
+
self._handle_response_errors(response)
|
|
761
|
+
|
|
762
|
+
data = response.json()
|
|
763
|
+
results = [
|
|
764
|
+
CorpusChunk(
|
|
765
|
+
id=r["id"],
|
|
766
|
+
content=r["content"],
|
|
767
|
+
metadata=r.get("metadata") or {},
|
|
768
|
+
score=r.get("score"),
|
|
769
|
+
)
|
|
770
|
+
for r in data.get("results", [])
|
|
771
|
+
]
|
|
772
|
+
|
|
773
|
+
return SearchResult(results=results, total=data.get("total", 0), query=query)
|
|
774
|
+
|
|
775
|
+
async def asearch_with_chunks(
|
|
776
|
+
self,
|
|
777
|
+
corpus_id: str,
|
|
778
|
+
query: str,
|
|
779
|
+
collection: ChunkCollection,
|
|
780
|
+
limit: int = 10,
|
|
781
|
+
metadata: dict[str, Any] | None = None,
|
|
782
|
+
filters: dict[str, Any] | None = None,
|
|
783
|
+
) -> list[tuple[Chunk, float]]:
|
|
784
|
+
"""Async twin of ``search_with_chunks``."""
|
|
785
|
+
result = await self.asearch(
|
|
786
|
+
corpus_id=corpus_id,
|
|
787
|
+
query=query,
|
|
788
|
+
limit=limit,
|
|
789
|
+
metadata=metadata,
|
|
790
|
+
filters=filters,
|
|
791
|
+
)
|
|
792
|
+
|
|
793
|
+
matched: list[tuple[Chunk, float]] = []
|
|
794
|
+
for corpus_chunk in result.results:
|
|
795
|
+
local_chunk = collection.get_chunk_by_hash(corpus_chunk.id)
|
|
796
|
+
if local_chunk:
|
|
797
|
+
matched.append((local_chunk, corpus_chunk.score or 0.0))
|
|
798
|
+
|
|
799
|
+
return matched
|
|
@@ -24,10 +24,10 @@ class AuthenticationError(CorpusAPIError):
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class CorpusLimitError(CorpusAPIError):
|
|
27
|
-
"""Maximum corpus limit (
|
|
27
|
+
"""Maximum corpus limit (20) reached."""
|
|
28
28
|
|
|
29
29
|
def __init__(self, existing_corpora: list[Corpus] | None = None):
|
|
30
|
-
super().__init__("Maximum of
|
|
30
|
+
super().__init__("Maximum of 20 corpora per user reached", 400)
|
|
31
31
|
self.existing_corpora = existing_corpora or []
|
|
32
32
|
|
|
33
33
|
|
|
@@ -329,36 +329,93 @@ class PostgresChunkSource:
|
|
|
329
329
|
collection=self.collection,
|
|
330
330
|
limit=top_k,
|
|
331
331
|
)
|
|
332
|
+
self._accumulate_related(related_map, source, query, matched_chunks, top_k)
|
|
332
333
|
|
|
333
|
-
|
|
334
|
-
|
|
334
|
+
return self._sorted_related(related_map)
|
|
335
|
+
|
|
336
|
+
async def asearch_related(
|
|
337
|
+
self,
|
|
338
|
+
source: Chunk,
|
|
339
|
+
queries: list[str],
|
|
340
|
+
top_k: int = 5,
|
|
341
|
+
mode: SearchMode | None = None,
|
|
342
|
+
hybrid: HybridOptions | None = None,
|
|
343
|
+
) -> list[dict]:
|
|
344
|
+
"""Async twin of ``search_related`` — identical dedup/neighbor-skip/scoring,
|
|
345
|
+
async corpus I/O. Queries run sequentially for parity with the sync path;
|
|
346
|
+
cross-batch search concurrency comes from the async work queue."""
|
|
347
|
+
if hybrid is not None:
|
|
348
|
+
warnings.warn(
|
|
349
|
+
"PostgresChunkSource does not support hybrid search; 'hybrid' parameter is ignored.",
|
|
350
|
+
stacklevel=2,
|
|
351
|
+
)
|
|
352
|
+
if mode is not None and mode != "lexical":
|
|
353
|
+
warnings.warn(
|
|
354
|
+
f"PostgresChunkSource only supports 'lexical' mode; '{mode}' will be ignored.",
|
|
355
|
+
stacklevel=2,
|
|
356
|
+
)
|
|
357
|
+
self._assert_ready()
|
|
358
|
+
related_map: dict[str, dict] = {}
|
|
359
|
+
|
|
360
|
+
for query in queries:
|
|
361
|
+
matched_chunks = await self._client.asearch_with_chunks(
|
|
362
|
+
corpus_id=self._corpus.id,
|
|
363
|
+
query=query,
|
|
364
|
+
collection=self.collection,
|
|
365
|
+
limit=top_k,
|
|
366
|
+
)
|
|
367
|
+
self._accumulate_related(related_map, source, query, matched_chunks, top_k)
|
|
368
|
+
|
|
369
|
+
return self._sorted_related(related_map)
|
|
370
|
+
|
|
371
|
+
async def aclose(self) -> None:
|
|
372
|
+
"""Close the underlying corpus client's async transport (best-effort).
|
|
373
|
+
Call from within the event loop that used it."""
|
|
374
|
+
await self._client.aclose()
|
|
375
|
+
|
|
376
|
+
@staticmethod
|
|
377
|
+
def _accumulate_related(
|
|
378
|
+
related_map: dict[str, dict],
|
|
379
|
+
source: Chunk,
|
|
380
|
+
query: str,
|
|
381
|
+
matched_chunks: list[tuple[Chunk, float]],
|
|
382
|
+
top_k: int,
|
|
383
|
+
) -> None:
|
|
384
|
+
"""Merge one query's results into ``related_map``: skip the source chunk
|
|
385
|
+
and its same-file neighbors, dedup by hash, aggregate queries + max score."""
|
|
386
|
+
for result_chunk, score in matched_chunks[:top_k]:
|
|
387
|
+
if result_chunk.hash == source.hash:
|
|
388
|
+
continue
|
|
389
|
+
|
|
390
|
+
is_same_file = result_chunk.get_metadata("file") == source.get_metadata(
|
|
391
|
+
"file"
|
|
392
|
+
)
|
|
393
|
+
if is_same_file:
|
|
394
|
+
index_diff = abs(
|
|
395
|
+
result_chunk.get_metadata("index", 0)
|
|
396
|
+
- source.get_metadata("index", 0)
|
|
397
|
+
)
|
|
398
|
+
if index_diff <= 1:
|
|
335
399
|
continue
|
|
336
400
|
|
|
337
|
-
|
|
338
|
-
|
|
401
|
+
if result_chunk.hash not in related_map:
|
|
402
|
+
related_map[result_chunk.hash] = {
|
|
403
|
+
"chunk": result_chunk,
|
|
404
|
+
"queries": [],
|
|
405
|
+
"same_file": is_same_file,
|
|
406
|
+
"max_score": score,
|
|
407
|
+
}
|
|
408
|
+
else:
|
|
409
|
+
related_map[result_chunk.hash]["max_score"] = max(
|
|
410
|
+
related_map[result_chunk.hash]["max_score"], score
|
|
339
411
|
)
|
|
340
|
-
if is_same_file:
|
|
341
|
-
index_diff = abs(
|
|
342
|
-
result_chunk.get_metadata("index", 0)
|
|
343
|
-
- source.get_metadata("index", 0)
|
|
344
|
-
)
|
|
345
|
-
if index_diff <= 1:
|
|
346
|
-
continue
|
|
347
|
-
|
|
348
|
-
if result_chunk.hash not in related_map:
|
|
349
|
-
related_map[result_chunk.hash] = {
|
|
350
|
-
"chunk": result_chunk,
|
|
351
|
-
"queries": [],
|
|
352
|
-
"same_file": is_same_file,
|
|
353
|
-
"max_score": score,
|
|
354
|
-
}
|
|
355
|
-
else:
|
|
356
|
-
related_map[result_chunk.hash]["max_score"] = max(
|
|
357
|
-
related_map[result_chunk.hash]["max_score"], score
|
|
358
|
-
)
|
|
359
412
|
|
|
360
|
-
|
|
413
|
+
related_map[result_chunk.hash]["queries"].append(query)
|
|
361
414
|
|
|
415
|
+
@staticmethod
|
|
416
|
+
def _sorted_related(related_map: dict[str, dict]) -> list[dict]:
|
|
417
|
+
"""Sort related chunks: most matching queries first, cross-file before
|
|
418
|
+
same-file, then max BM25 score — all descending."""
|
|
362
419
|
return sorted(
|
|
363
420
|
related_map.values(),
|
|
364
421
|
key=lambda x: (len(x["queries"]), not x["same_file"], x["max_score"]),
|