claude-sql 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- claude_sql/__init__.py +5 -0
- claude_sql/binding.py +740 -0
- claude_sql/blind_handover.py +155 -0
- claude_sql/checkpointer.py +202 -0
- claude_sql/cli.py +2344 -0
- claude_sql/cluster_worker.py +208 -0
- claude_sql/community_worker.py +306 -0
- claude_sql/config.py +380 -0
- claude_sql/embed_worker.py +482 -0
- claude_sql/freeze.py +189 -0
- claude_sql/friction_worker.py +561 -0
- claude_sql/install_source.py +77 -0
- claude_sql/judge_worker.py +459 -0
- claude_sql/judges.py +239 -0
- claude_sql/kappa_worker.py +257 -0
- claude_sql/llm_worker.py +1760 -0
- claude_sql/logging_setup.py +95 -0
- claude_sql/output.py +248 -0
- claude_sql/parquet_shards.py +172 -0
- claude_sql/retry_queue.py +180 -0
- claude_sql/review_sheet_render.py +167 -0
- claude_sql/review_sheet_worker.py +463 -0
- claude_sql/schemas.py +454 -0
- claude_sql/session_text.py +387 -0
- claude_sql/skills_catalog.py +354 -0
- claude_sql/sql_views.py +1751 -0
- claude_sql/terms_worker.py +145 -0
- claude_sql/ungrounded_worker.py +190 -0
- claude_sql-0.4.0.dist-info/METADATA +530 -0
- claude_sql-0.4.0.dist-info/RECORD +32 -0
- claude_sql-0.4.0.dist-info/WHEEL +4 -0
- claude_sql-0.4.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,482 @@
|
|
|
1
|
+
"""Cohere Embed v4 backfill worker for claude-sql.
|
|
2
|
+
|
|
3
|
+
Discovers messages with no embedding yet, invokes ``cohere.embed-v4:0`` on
|
|
4
|
+
Amazon Bedrock in parallel batches (up to 96 texts per call), and appends the
|
|
5
|
+
resulting vectors to a parquet file keyed by message ``uuid``.
|
|
6
|
+
|
|
7
|
+
The worker converts the int8 response to float on insert because DuckDB's VSS
|
|
8
|
+
HNSW index requires ``FLOAT[]`` columns (storage loss of ~4x is accepted in
|
|
9
|
+
v1 — see research notes).
|
|
10
|
+
|
|
11
|
+
Tenacity retries on transient Bedrock errors via the loguru-native
|
|
12
|
+
``loguru_before_sleep`` helper from ``logging_setup`` — keeps every module
|
|
13
|
+
in claude-sql on a single logger (no stdlib ``logging`` imports).
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import json
|
|
20
|
+
import time
|
|
21
|
+
from datetime import UTC, datetime
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
import boto3
|
|
26
|
+
import duckdb
|
|
27
|
+
import polars as pl
|
|
28
|
+
from botocore.config import Config as BotoConfig
|
|
29
|
+
from botocore.exceptions import (
|
|
30
|
+
ClientError,
|
|
31
|
+
ConnectionError as BotoConnectionError,
|
|
32
|
+
EndpointConnectionError,
|
|
33
|
+
ReadTimeoutError,
|
|
34
|
+
SSLError,
|
|
35
|
+
)
|
|
36
|
+
from loguru import logger
|
|
37
|
+
from tenacity import (
|
|
38
|
+
retry,
|
|
39
|
+
retry_if_exception,
|
|
40
|
+
stop_after_attempt,
|
|
41
|
+
wait_exponential,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
from claude_sql.config import Settings
|
|
45
|
+
from claude_sql.logging_setup import loguru_before_sleep
|
|
46
|
+
from claude_sql.parquet_shards import iter_part_files, write_part
|
|
47
|
+
|
|
48
|
+
#: Conservative per-text character cap before sending to Bedrock. The real
#: model limit is 128K tokens per text; this cap keeps total payload below
#: the Bedrock 20 MB body ceiling even with a full batch of 96 large texts.
MAX_CHARS_PER_TEXT = 50_000

#: Bedrock error codes that tenacity should retry.
#: Only throttling and transient model/service failures appear here;
#: permanent errors (validation, access-denied) are deliberately absent so
#: they fail fast instead of burning the retry budget.
_RETRY_CODES: set[str] = {
    "ThrottlingException",
    "ServiceUnavailableException",
    "ModelTimeoutException",
    "ModelErrorException",
}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _is_retryable(exc: BaseException) -> bool:
    """Decide whether a Bedrock failure is transient and worth retrying.

    Two categories retry:

    * Network-layer faults (SSL, connection, endpoint, read-timeout) that
      surface when long-running batches hit flaky TCP connections.
    * ``ClientError`` responses whose error code is listed in
      :data:`_RETRY_CODES` — throttling and transient model failures.

    Everything else (including non-ClientError exceptions) fails fast.
    """
    network_faults = (
        SSLError,
        BotoConnectionError,
        EndpointConnectionError,
        ReadTimeoutError,
    )
    if isinstance(exc, network_faults):
        return True
    if isinstance(exc, ClientError):
        error_code = exc.response.get("Error", {}).get("Code")
        return error_code in _RETRY_CODES
    return False
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def discover_unembedded(
    con: duckdb.DuckDBPyConnection,
    *,
    embeddings_parquet: Path,
    since_days: int | None = None,
    limit: int | None = None,
) -> list[tuple[str, str]]:
    """Return ``(uuid, text)`` pairs that have no embedding yet.

    Parameters
    ----------
    con
        An open DuckDB connection with the ``messages_text`` view registered.
    embeddings_parquet
        Path to the parquet of already-embedded rows; may not exist yet.
    since_days
        If given, only include messages with ``ts >= now() - since_days``.
    limit
        Optional row cap.

    Returns
    -------
    list of (uuid, text) tuples
        Messages needing embedding, in DuckDB's scan order.
    """
    # Treat missing / truncated parquet files (or empty shard directories) as
    # "no embeddings yet" so an aborted previous run doesn't lock discovery
    # into "skip all" via a corrupt index. Sharded directories are scanned
    # via the part-file glob; legacy single files keep their original path.
    parts = iter_part_files(embeddings_parquet)
    # 16 bytes is at/below the size of an empty parquet shell (two "PAR1"
    # magics plus footer length) — smaller files are presumably truncated
    # writes from an aborted run; TODO confirm against write_part's layout.
    parts = [p for p in parts if p.stat().st_size > 16]
    if parts:
        # CREATE VIEW doesn't accept prepared parameters in DuckDB; escape
        # each part path inline as a SQL string literal and pass them all
        # to ``read_parquet`` as a list.
        # (chr(39) is a single quote; doubling it is SQL-literal escaping.)
        path_literals = ", ".join(f"'{str(p).replace(chr(39), chr(39) * 2)}'" for p in parts)
        con.execute(
            "CREATE OR REPLACE TEMP VIEW _embedded AS "
            f"SELECT uuid FROM read_parquet([{path_literals}]);"
        )
        # mt.uuid is typed UUID; parquet uuid column is VARCHAR. Cast to match.
        # The fragment starts with "AND" so it composes with the WHERE clause
        # conjunction built below.
        anti = "AND CAST(mt.uuid AS VARCHAR) NOT IN (SELECT uuid FROM _embedded)"
    else:
        anti = ""

    where = ["mt.text_content IS NOT NULL", "length(mt.text_content) > 0"]
    if since_days is not None:
        # DuckDB refuses to prepare an INTERVAL parameter; inline the coerced int.
        # int() also guards against injection through a non-integer argument.
        where.append(f"mt.ts >= current_timestamp - INTERVAL {int(since_days)} DAY")

    sql = (
        f"SELECT mt.uuid, mt.text_content FROM messages_text mt WHERE {' AND '.join(where)} {anti}"
    )
    if limit is not None:
        sql += f"\nLIMIT {int(limit)}"

    rows = con.execute(sql).fetchall()
    # DuckDB returns UUIDs as uuid.UUID objects; polars wants str for pl.Utf8.
    return [(str(r[0]), r[1]) for r in rows]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _build_bedrock_client(settings: Settings) -> Any:
    """Return a low-level ``bedrock-runtime`` client for ``settings.region``.

    botocore's internal retry layer is switched off (``max_attempts=0``) so
    throttling reaches tenacity immediately instead of being silently
    absorbed by up-to-4 hidden retries; the read timeout is raised to 60s
    to accommodate large embedding batches.

    Parameters
    ----------
    settings
        Application settings providing the target AWS region.

    Returns
    -------
    botocore client
        A low-level ``bedrock-runtime`` client.
    """
    client_config = BotoConfig(
        region_name=settings.region,
        retries={"max_attempts": 0, "mode": "standard"},
        read_timeout=60,
        connect_timeout=10,
    )
    return boto3.client("bedrock-runtime", config=client_config)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@retry(
    # Cohere Embed v4 on Bedrock has a strict TPM bucket that replenishes over
    # tens of seconds; wait up to 60s between attempts and try up to 10 times
    # before surfacing the ThrottlingException.
    stop=stop_after_attempt(10),
    wait=wait_exponential(multiplier=2, min=2, max=60),
    retry=retry_if_exception(_is_retryable),
    before_sleep=loguru_before_sleep("WARNING"),
    reraise=True,
)
def _invoke_bedrock_sync(
    client: Any,
    model_id: str,
    texts: list[str],
    *,
    input_type: str,
    output_dimension: int,
    embedding_type: str,
) -> list[list[int]] | list[list[float]]:
    """Make one synchronous ``invoke_model`` call and return the vectors.

    Parameters
    ----------
    client
        A boto3 ``bedrock-runtime`` client.
    model_id
        Cohere Embed v4 model ID (direct or CRIS profile).
    texts
        Up to 96 strings; each is clipped to ``MAX_CHARS_PER_TEXT``.
    input_type
        Either ``"search_document"`` (corpus) or ``"search_query"``.
    output_dimension
        Target Matryoshka dimension: 256, 512, 1024, or 1536.
    embedding_type
        One of ``"int8"``, ``"float"``, ``"uint8"``, ``"binary"``, ``"ubinary"``.

    Returns
    -------
    list of list of int or float
        Flat list of vectors matching the order of ``texts``.
    """
    body = json.dumps(
        {
            # Client-side clip keeps the request body under the Bedrock size
            # ceiling; "truncate": "RIGHT" below additionally lets the service
            # drop trailing tokens instead of rejecting overlong texts.
            "texts": [t[:MAX_CHARS_PER_TEXT] for t in texts],
            "input_type": input_type,
            "output_dimension": output_dimension,
            "embedding_types": [embedding_type],
            "truncate": "RIGHT",
        }
    )
    resp = client.invoke_model(
        modelId=model_id,
        body=body,
        contentType="application/json",
        accept="application/json",
    )
    # resp["body"] is a streaming body: read it once, then parse the JSON.
    payload = json.loads(resp["body"].read())
    # The response nests one vector list per requested embedding type.
    return payload["embeddings"][embedding_type]
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
async def _embed_one_batch(
    client: Any,
    texts: list[str],
    model_id: str,
    *,
    input_type: str,
    output_dimension: int,
    embedding_type: str,
    sem: asyncio.Semaphore,
) -> list[list[int]] | list[list[float]]:
    """Embed one batch of texts while holding the concurrency semaphore.

    The blocking boto3 call runs on a worker thread via
    ``asyncio.to_thread`` so the event loop stays responsive; ``sem`` caps
    how many Bedrock requests are in flight at once.
    """
    async with sem:
        vectors = await asyncio.to_thread(
            _invoke_bedrock_sync,
            client,
            model_id,
            texts,
            input_type=input_type,
            output_dimension=output_dimension,
            embedding_type=embedding_type,
        )
        return vectors
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
async def embed_documents_async(
    texts: list[str],
    *,
    settings: Settings,
) -> list[list[float]]:
    """Embed corpus documents in parallel and return ``float`` vectors.

    Every batch is sent with ``input_type="search_document"``. The Bedrock
    response (int8 or whatever ``settings.embedding_type`` selects) is cast
    to ``float`` so the downstream parquet / DuckDB ``FLOAT[]`` column is
    directly consumable by the VSS HNSW index.

    Parameters
    ----------
    texts
        Full corpus to embed; this function handles batching and concurrency.
    settings
        Application settings (model, batch size, concurrency, output dim).

    Returns
    -------
    list of list of float
        One vector per input text, same order.
    """
    if not texts:
        return []

    bedrock = _build_bedrock_client(settings)
    size = settings.batch_size
    chunks = [texts[start : start + size] for start in range(0, len(texts), size)]
    gate = asyncio.Semaphore(settings.embed_concurrency)

    logger.info(
        "Embedding {} texts in {} batches (batch_size={}, concurrency={}, model={})",
        len(texts),
        len(chunks),
        size,
        settings.embed_concurrency,
        settings.active_model_id,
    )

    started = time.monotonic()
    per_batch = await asyncio.gather(
        *(
            _embed_one_batch(
                bedrock,
                chunk,
                settings.active_model_id,
                input_type="search_document",
                output_dimension=settings.output_dimension,
                embedding_type=settings.embedding_type,
                sem=gate,
            )
            for chunk in chunks
        )
    )
    duration = time.monotonic() - started

    # Flatten batch results in order and cast each component to float.
    flattened: list[list[float]] = []
    for batch_vectors in per_batch:
        flattened.extend([float(component) for component in vec] for vec in batch_vectors)

    logger.info(
        "Embedded {} vectors across {} batches in {:.2f}s ({:.1f} vec/s)",
        len(flattened),
        len(chunks),
        duration,
        len(flattened) / duration if duration > 0 else 0.0,
    )
    return flattened
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def embed_query(text: str, *, settings: Settings) -> list[float]:
    """Embed a single query string for HNSW nearest-neighbor search.

    Always requests ``embedding_type="float"``, overriding
    ``settings.embedding_type``, because HNSW distance math needs float
    vectors; ``input_type="search_query"`` pairs with the corpus side's
    ``"search_document"``.

    Parameters
    ----------
    text
        The user's natural-language query.
    settings
        Application settings (model, output dim, region).

    Returns
    -------
    list of float
        A single vector of length ``settings.output_dimension``.
    """
    bedrock = _build_bedrock_client(settings)
    response_vectors = _invoke_bedrock_sync(
        bedrock,
        settings.active_model_id,
        [text],
        input_type="search_query",
        output_dimension=settings.output_dimension,
        embedding_type="float",
    )
    query_vector = response_vectors[0]
    return [float(component) for component in query_vector]
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
async def run_backfill(
    *,
    con: duckdb.DuckDBPyConnection,
    settings: Settings,
    since_days: int | None = None,
    limit: int | None = None,
    dry_run: bool = False,
) -> int | dict[str, Any]:
    """Discover unembedded messages, embed them, and append to parquet.

    Parameters
    ----------
    con
        An open DuckDB connection with ``messages_text`` registered.
    settings
        Application settings (model, batch size, concurrency, parquet path).
    since_days
        If given, only consider messages newer than this many days.
    limit
        Optional cap on number of messages to embed this run.
    dry_run
        If true, log the plan and return a plan dict without calling Bedrock.

    Returns
    -------
    int | dict
        Under ``dry_run=True``, a plan dict with ``{pipeline, candidates,
        batches, batch_size, concurrency, model, since_days, limit}``.
        Otherwise, count of newly embedded rows (0 when nothing is pending).
    """
    pending = discover_unembedded(
        con,
        embeddings_parquet=settings.embeddings_parquet_path,
        since_days=since_days,
        limit=limit,
    )
    if not pending:
        logger.info("No unembedded messages found - nothing to do")
        # Even an empty run returns a full plan dict under dry_run so callers
        # can render it uniformly.
        if dry_run:
            return {
                "pipeline": "embed",
                "candidates": 0,
                "batches": 0,
                "batch_size": settings.batch_size,
                "concurrency": settings.embed_concurrency,
                "model": settings.active_model_id,
                "since_days": since_days,
                "limit": limit,
                "dry_run": True,
            }
        return 0

    # Ceiling division: number of Bedrock calls the plan implies.
    n_batches = (len(pending) + settings.batch_size - 1) // settings.batch_size
    logger.info(
        "Backfill plan: {} messages, {} batches, concurrency={}, model={}",
        len(pending),
        n_batches,
        settings.embed_concurrency,
        settings.active_model_id,
    )
    if dry_run:
        logger.info("dry_run=True - skipping Bedrock calls")
        return {
            "pipeline": "embed",
            "candidates": len(pending),
            "batches": n_batches,
            "batch_size": settings.batch_size,
            "concurrency": settings.embed_concurrency,
            "model": settings.active_model_id,
            "since_days": since_days,
            "limit": limit,
            "dry_run": True,
        }

    # Checkpoint every N messages so a throttling-induced timeout doesn't
    # discard work already embedded. Ideally a multiple of batch_size so each
    # chunk splits into full batches; when 256 wins the max() and isn't a
    # multiple, the only cost is one short trailing batch per chunk.
    chunk_size = max(settings.batch_size * 4, 256)
    path = settings.embeddings_parquet_path
    total_t0 = time.monotonic()
    written = 0
    for i in range(0, len(pending), chunk_size):
        slice_ = pending[i : i + chunk_size]
        logger.info(
            "Chunk {}/{}: embedding {} messages",
            i // chunk_size + 1,
            (len(pending) + chunk_size - 1) // chunk_size,
            len(slice_),
        )
        t0 = time.monotonic()
        # pending rows are (uuid, text) tuples; index 1 is the text.
        texts = [p[1] for p in slice_]
        vectors = await embed_documents_async(texts, settings=settings)
        elapsed = time.monotonic() - t0
        logger.info(
            "Chunk done in {:.1f}s ({:.1f} vec/s)",
            elapsed,
            len(vectors) / elapsed if elapsed > 0 else 0.0,
        )

        now = datetime.now(UTC)
        # Polars infers nested list[float] as Object when the batch is small or
        # when rows are handed in as Python lists; force a fixed-size Array so
        # write_parquet succeeds and DuckDB VSS sees FLOAT[dim] on read.
        df = pl.DataFrame(
            {
                "uuid": [p[0] for p in slice_],
                "model": [settings.active_model_id] * len(slice_),
                "dim": [settings.output_dimension] * len(slice_),
                "embedding": vectors,
                "embedded_at": [now] * len(slice_),
            },
            schema={
                "uuid": pl.Utf8,
                "model": pl.Utf8,
                "dim": pl.UInt16,
                "embedding": pl.Array(pl.Float32, settings.output_dimension),
                "embedded_at": pl.Datetime("us", "UTC"),
            },
        )
        # Sharded write: drop a fresh ``part-<ts_ns>.parquet`` into the
        # embeddings directory. Legacy single-file caches still rewrite the
        # whole file (handled inside ``write_part``) so existing installs
        # keep working until they're migrated.
        written_path = write_part(path, df)
        written += len(slice_)
        logger.info("Checkpoint: {} rows -> {}", len(df), written_path)

    total_elapsed = time.monotonic() - total_t0
    logger.info(
        "Backfill complete: {} embeddings in {:.1f}s ({:.1f} vec/s overall)",
        written,
        total_elapsed,
        written / total_elapsed if total_elapsed > 0 else 0.0,
    )
    return written
|
claude_sql/freeze.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
"""Pre-registration: freeze + replay study manifests.
|
|
2
|
+
|
|
3
|
+
The ``freeze`` subcommand hashes the full study spec (rubric YAML,
|
|
4
|
+
judge panel, commit SHA, embed model, session-scoping rule) into a
|
|
5
|
+
deterministic manifest SHA and writes it under ``~/.claude/studies/<sha>/``.
|
|
6
|
+
|
|
7
|
+
The ``replay`` subcommand reads a manifest by SHA and rebuilds a
|
|
8
|
+
``Study`` object so downstream workers (``judge``, ``ungrounded-claim``,
|
|
9
|
+
``kappa``) can execute with the exact locked parameters.
|
|
10
|
+
|
|
11
|
+
This is the IRR study's audit trail. Every parquet the workers write
|
|
12
|
+
carries the manifest SHA in a ``freeze_sha`` column.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import hashlib
|
|
18
|
+
import json
|
|
19
|
+
import os
|
|
20
|
+
import subprocess
|
|
21
|
+
from dataclasses import asdict, dataclass, field
|
|
22
|
+
from datetime import UTC, datetime
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
from claude_sql import judges as judge_catalog
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
class SessionScope:
    """Session-scoping rule for the study.

    Frozen so it can live inside the (hashed, frozen) ``Study`` manifest;
    its ``asdict`` form feeds into ``Study.manifest_sha``.
    """

    # Bounds on session length in turns — presumably inclusive; confirm
    # against the session-scoping worker before relying on this.
    min_turns: int = 10
    max_turns: int = 40
    # NOTE(review): name suggests sessions with gaps longer than this are
    # treated as interrupted — semantics not visible here, verify in callers.
    max_interrupt_minutes: int = 15
    # Scoping strategy label recorded verbatim in the manifest.
    kind: str = "mechanical"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass(frozen=True)
class Study:
    """A pre-registered IRR study specification."""

    # Path the rubric was loaded from — recorded for provenance only; the
    # content hash below is what actually identifies the rubric.
    rubric_path: str
    rubric_content_hash: str
    # Judge panel as human-facing shortnames plus the resolved model IDs.
    panel_shortnames: tuple[str, ...]
    panel_model_ids: tuple[str, ...]
    embed_model_id: str
    # Git HEAD at freeze time; may carry a "-dirty" suffix or be "<no-git>"
    # (see _git_commit_sha).
    commit_sha: str
    session_scope: SessionScope
    seed: int
    # Wall-clock creation time; excluded from manifest_sha so re-freezing
    # identical parameters yields the same SHA.
    created_at_utc: str = field(
        default_factory=lambda: datetime.now(UTC).isoformat(timespec="seconds")
    )

    @property
    def manifest_sha(self) -> str:
        """Deterministic SHA256 over the parameterised fields (excludes created_at)."""
        # rubric_path and panel_shortnames are also excluded: the content
        # hash and resolved model IDs already pin the rubric and the judges.
        payload = {
            "rubric_content_hash": self.rubric_content_hash,
            "panel_model_ids": list(self.panel_model_ids),
            "embed_model_id": self.embed_model_id,
            "commit_sha": self.commit_sha,
            "session_scope": asdict(self.session_scope),
            "seed": self.seed,
        }
        # sort_keys fixes the JSON byte stream regardless of dict insertion
        # order; the digest is truncated to 16 hex chars for usability.
        return hashlib.sha256(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()[:16]

    def to_dict(self) -> dict[str, object]:
        """Return all dataclass fields plus the computed ``manifest_sha``."""
        d = asdict(self)
        d["manifest_sha"] = self.manifest_sha
        return d
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _read_rubric(rubric_path: Path) -> tuple[str, str]:
|
|
74
|
+
"""Read rubric YAML/JSON from disk and return (content, SHA256 hash).
|
|
75
|
+
|
|
76
|
+
Supports YAML (loaded as text, no parsing) and JSON; we hash the
|
|
77
|
+
raw bytes so whitespace differences produce different manifests,
|
|
78
|
+
which is what we want — even a reformatted rubric is a new study.
|
|
79
|
+
"""
|
|
80
|
+
if not rubric_path.exists():
|
|
81
|
+
raise FileNotFoundError(f"rubric not found: {rubric_path}")
|
|
82
|
+
content = rubric_path.read_text(encoding="utf-8")
|
|
83
|
+
h = hashlib.sha256(content.encode("utf-8")).hexdigest()
|
|
84
|
+
return content, h
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _git_commit_sha(repo: Path) -> str:
|
|
88
|
+
"""Return the current ``git rev-parse HEAD`` for ``repo``, or ``"<dirty>"``."""
|
|
89
|
+
try:
|
|
90
|
+
out = subprocess.run(
|
|
91
|
+
["git", "-C", str(repo), "rev-parse", "HEAD"],
|
|
92
|
+
capture_output=True,
|
|
93
|
+
text=True,
|
|
94
|
+
check=True,
|
|
95
|
+
)
|
|
96
|
+
sha = out.stdout.strip()
|
|
97
|
+
# Flag dirty working trees so the commit hash doesn't over-claim reproducibility
|
|
98
|
+
dirty = subprocess.run(
|
|
99
|
+
["git", "-C", str(repo), "status", "--porcelain"],
|
|
100
|
+
capture_output=True,
|
|
101
|
+
text=True,
|
|
102
|
+
check=True,
|
|
103
|
+
)
|
|
104
|
+
return sha if not dirty.stdout.strip() else f"{sha}-dirty"
|
|
105
|
+
except (FileNotFoundError, subprocess.CalledProcessError):
|
|
106
|
+
return "<no-git>"
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _studies_root() -> Path:
|
|
110
|
+
root = Path(os.path.expanduser("~/.claude/studies"))
|
|
111
|
+
root.mkdir(parents=True, exist_ok=True)
|
|
112
|
+
return root
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def freeze(
    rubric_path: Path,
    panel_shortnames: tuple[str, ...],
    embed_model_id: str = "global.cohere.embed-v4:0",
    session_scope: SessionScope | None = None,
    seed: int = 42,
    repo: Path | None = None,
) -> Study:
    """Create a study manifest and persist it under ``~/.claude/studies/<sha>/``.

    Both ``manifest.json`` and a verbatim copy of the rubric are written so
    the manifest directory is self-contained. Returns the frozen ``Study``.
    """
    rubric_content, rubric_hash = _read_rubric(rubric_path)
    resolved_judges = judge_catalog.panel(list(panel_shortnames))
    study = Study(
        rubric_path=str(rubric_path),
        rubric_content_hash=rubric_hash,
        panel_shortnames=tuple(panel_shortnames),
        panel_model_ids=tuple(judge.model_id for judge in resolved_judges),
        embed_model_id=embed_model_id,
        commit_sha=_git_commit_sha(repo if repo is not None else Path.cwd()),
        session_scope=session_scope if session_scope is not None else SessionScope(),
        seed=seed,
    )
    target_dir = _studies_root() / study.manifest_sha
    target_dir.mkdir(parents=True, exist_ok=True)
    manifest_json = json.dumps(study.to_dict(), indent=2, sort_keys=True)
    (target_dir / "manifest.json").write_text(manifest_json, encoding="utf-8")
    (target_dir / "rubric.yaml").write_text(rubric_content, encoding="utf-8")
    return study
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def replay(manifest_sha: str) -> Study:
    """Rebuild a previously-frozen ``Study`` from its persisted manifest."""
    manifest_path = _studies_root() / manifest_sha / "manifest.json"
    if not manifest_path.exists():
        raise FileNotFoundError(f"no manifest at {manifest_path}")
    data = json.loads(manifest_path.read_text(encoding="utf-8"))
    return Study(
        rubric_path=data["rubric_path"],
        rubric_content_hash=data["rubric_content_hash"],
        panel_shortnames=tuple(data["panel_shortnames"]),
        panel_model_ids=tuple(data["panel_model_ids"]),
        embed_model_id=data["embed_model_id"],
        commit_sha=data["commit_sha"],
        # JSON round-trips the scope as a plain dict; rebuild the dataclass.
        session_scope=SessionScope(**data["session_scope"]),
        seed=data["seed"],
        created_at_utc=data["created_at_utc"],
    )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def list_studies() -> list[dict[str, object]]:
    """Summarise every frozen study found under ``~/.claude/studies/``."""
    summaries: list[dict[str, object]] = []
    for study_dir in sorted(_studies_root().iterdir()):
        manifest = study_dir / "manifest.json"
        # Skip stray files and half-written study directories.
        if not manifest.exists():
            continue
        data = json.loads(manifest.read_text(encoding="utf-8"))
        summaries.append(
            {
                "manifest_sha": data["manifest_sha"],
                "created_at_utc": data["created_at_utc"],
                "commit_sha": data["commit_sha"],
                "n_judges": len(data["panel_shortnames"]),
            }
        )
    return summaries
|