claude-sql 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,145 @@
1
+ """c-TF-IDF per cluster -- in-house (no bertopic).
2
+
3
+ Reads ``clusters.parquet`` (from cluster_worker) and the text messages
4
+ view to build one pseudo-document per cluster, then runs a sklearn
5
+ ``CountVectorizer`` and computes the c-TF-IDF weights used by BERTopic.
6
+ Writes top-N terms per cluster to ``cluster_terms.parquet``.
7
+
8
+ Public API
9
+ ----------
10
+ run_terms(con, settings, *, force=False) -> dict[str, int]
11
+ Compute ``cluster_terms.parquet``. Returns ``{"clusters": K, "terms": N}``.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import time
17
+ from typing import TYPE_CHECKING
18
+
19
+ import numpy as np
20
+ import polars as pl
21
+ from loguru import logger
22
+
23
+ if TYPE_CHECKING:
24
+ import duckdb
25
+
26
+ from claude_sql.config import Settings
27
+
28
+
29
def run_terms(
    con: duckdb.DuckDBPyConnection,
    settings: Settings,
    *,
    force: bool = False,
) -> dict[str, int]:
    """Compute c-TF-IDF top terms per cluster and write the parquet output.

    Parameters
    ----------
    con
        Open DuckDB connection with ``messages_text`` registered.
    settings
        Runtime settings (``clusters_parquet_path``, ``cluster_terms_parquet_path``,
        and the ``tfidf_*`` hyperparameters).
    force
        If True, recompute even when the output parquet already exists.

    Returns
    -------
    dict[str, int]
        ``{"clusters": K, "terms": N}`` where K is the number of clusters
        processed and N is the total row count written to parquet. If no
        clustered message text is found, an empty parquet (normal schema) is
        written and ``{"clusters": 0, "terms": 0}`` is returned.

    Raises
    ------
    FileNotFoundError
        If the clusters parquet is missing or smaller than 16 bytes.
    """
    out = settings.cluster_terms_parquet_path
    clusters_path = settings.clusters_parquet_path

    # A truncated/placeholder clusters file is treated the same as a missing one.
    if not clusters_path.exists() or clusters_path.stat().st_size < 16:
        raise FileNotFoundError(
            f"Clusters parquet missing at {clusters_path}. Run `claude-sql cluster` first."
        )
    if out.exists() and out.stat().st_size > 16 and not force:
        df = pl.read_parquet(out)
        logger.info("cluster_terms parquet already exists at {}", out)
        return {"clusters": int(df["cluster_id"].n_unique()), "terms": len(df)}

    # Deferred import: only needed when actually recomputing.
    from sklearn.feature_extraction.text import CountVectorizer

    t0 = time.monotonic()
    cluster_ids, docs = _load_cluster_docs(con, str(clusters_path))
    logger.info("Built {} cluster pseudo-docs", len(docs))

    rows: list[tuple[int, str, float, int]] = []
    if docs:
        cv = CountVectorizer(
            min_df=settings.tfidf_min_df,
            max_df=settings.tfidf_max_df,
            ngram_range=(settings.tfidf_ngram_min, settings.tfidf_ngram_max),
            lowercase=True,
            strip_accents="unicode",
        )
        # One pass over the pseudo-docs -> dense (n_clusters, vocab) float32 counts.
        counts = cv.fit_transform(docs).toarray().astype(np.float32)
        vocab = cv.get_feature_names_out()
        logger.info("Vocabulary size: {} terms", len(vocab))
        rows = _top_term_rows(
            _ctfidf(counts), cluster_ids, vocab, settings.tfidf_top_n_terms
        )
    else:
        # No rows with cluster_id >= 0 joined to non-null text; CountVectorizer
        # would otherwise fail with an "empty vocabulary" error.
        logger.warning("No clustered message text found; writing empty {}", out)

    out_df = pl.DataFrame(
        rows,
        schema={
            "cluster_id": pl.Int32,
            "term": pl.Utf8,
            "weight": pl.Float32,
            "rank": pl.Int32,
        },
        orient="row",
    )
    out.parent.mkdir(parents=True, exist_ok=True)
    out_df.write_parquet(out)
    logger.info(
        "Wrote {} term-rows across {} clusters in {:.1f}s",
        len(out_df),
        len(cluster_ids),
        time.monotonic() - t0,
    )
    return {"clusters": len(cluster_ids), "terms": len(out_df)}


def _load_cluster_docs(
    con: duckdb.DuckDBPyConnection, clusters_path: str
) -> tuple[list[int], list[str]]:
    """Join cluster labels to message text; return (cluster_ids, pseudo_docs) sorted by id."""
    t0 = time.monotonic()
    # ``mt.uuid`` is a DuckDB UUID while the parquet column is VARCHAR -- cast to match.
    sql = """
        SELECT c.cluster_id,
               mt.text_content
        FROM read_parquet(?) c
        JOIN messages_text mt
          ON CAST(mt.uuid AS VARCHAR) = c.uuid
        WHERE c.cluster_id >= 0
          AND mt.text_content IS NOT NULL
    """
    df = con.execute(sql, [clusters_path]).pl()
    logger.info(
        "Joined {} rows clusters x messages_text in {:.1f}s",
        len(df),
        time.monotonic() - t0,
    )
    # One newline-joined pseudo-document per cluster, ordered by cluster_id.
    per_cluster = (
        df.group_by("cluster_id")
        .agg(pl.col("text_content").str.join("\n").alias("doc"))
        .sort("cluster_id")
    )
    return per_cluster["cluster_id"].to_list(), per_cluster["doc"].to_list()


def _ctfidf(counts: np.ndarray) -> np.ndarray:
    """Return BERTopic-style c-TF-IDF weights for an ``(n_clusters, vocab)`` count matrix.

    ``W[c, t] = (tf[c, t] / n_c) * log(1 + A / f_t)`` where ``n_c`` is the token
    count of cluster ``c``, ``f_t`` is the total count of term ``t`` across all
    clusters, and ``A`` is the average token count per cluster.
    """
    # Accumulate sums in float64; float32 loses integer precision past 2**24.
    row_sum = counts.sum(axis=1, keepdims=True, dtype=np.float64)
    row_sum[row_sum == 0] = 1.0  # avoid 0/0 for clusters that produced no terms
    col_sum = counts.sum(axis=0, dtype=np.float64)
    avg_tokens = col_sum.sum() / max(counts.shape[0], 1)
    idf = np.log1p(avg_tokens / np.maximum(col_sum, 1e-9))
    # Keep the (n_clusters, vocab) product in float32 to bound memory.
    return (counts / row_sum.astype(np.float32)) * idf.astype(np.float32)


def _top_term_rows(
    weights: np.ndarray,
    cluster_ids: list[int],
    vocab: np.ndarray,
    top_n: int,
) -> list[tuple[int, str, float, int]]:
    """Return ``(cluster_id, term, weight, rank)`` rows for each cluster's top-N positive terms."""
    limit = max(top_n, 0)
    rows: list[tuple[int, str, float, int]] = []
    for k, cid in enumerate(cluster_ids):
        # Stable sort of negated weights: descending, ties broken by vocab index,
        # so repeated runs emit identical rankings.
        order = np.argsort(-weights[k], kind="stable")[:limit]
        for rank, i in enumerate(order, start=1):
            w = float(weights[k, i])
            if w <= 0:
                break  # sorted descending, so every remaining weight is <= 0 too
            rows.append((int(cid), str(vocab[i]), w, rank))
    return rows
@@ -0,0 +1,190 @@
1
+ """Ungrounded-claim detector v0 — entity spotting over tool outputs.
2
+
3
+ For each assistant turn, extract *factual claims about internal systems*
4
+ (file paths, dotted function calls, CLI flags, ``claude-sql`` subcommands,
5
+ env vars, work-item IDs and UUIDs) and check whether those entities appear in the
6
+ same session's tool-call outputs. An assertion that names an entity
7
+ never seen in tool output is flagged as potentially ungrounded.
8
+
9
+ This is v0: fast and conservative -- regex extraction plus exact substring matching. A later pass
10
+ can add Nova 2 Lite for claim-phrase extraction when the regex misses
11
+ semantic claims.
12
+
13
+ Schema written to parquet::
14
+
15
+ session_id STRING
16
+ turn_idx INT64
17
+ claim_entity STRING # the specific entity name the agent asserted
18
+ claim_kind STRING # path | function | flag | env_var | id | cli
19
+ grounded BOOLEAN # True if seen in tool output
20
+ tool_output_hits INT64 # count of exact matches in tool results
21
+ freeze_sha STRING
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import re
27
+ from collections.abc import Iterable
28
+ from dataclasses import dataclass
29
+ from pathlib import Path
30
+
31
+ import polars as pl
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Entity extractors
35
+ # ---------------------------------------------------------------------------
36
+
37
# Absolute-looking Unix paths: a leading '/', one or more 'segment/' parts, then a
# final segment -- so at least two '/'. Segments use [A-Za-z0-9_.-]. The lookbehind
# only rejects a preceding ASCII letter, so the tail of a relative path such as
# ``v2/api/x`` can still match (as ``/api/x``).
_PATH_RE = re.compile(r"(?<![A-Za-z])(/(?:[A-Za-z0-9_.-]+/)+[A-Za-z0-9_.-]+)")
# Dotted call references followed by '(' (e.g. ``handler.run(``, ``os.path.join(``);
# the dotted name is captured without the paren. The first segment must start with a
# lowercase letter or '_' and at least one '.attr' is required, so a bare ``foo()``
# is NOT matched.
_FUNCTION_RE = re.compile(r"\b([a-z_][a-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)+)\s*\(")
# Long CLI flags: '--' + a lowercase letter + at least two of [a-z0-9-] (e.g. ``--force``).
# No left boundary is enforced, so '--' inside a longer token (``a--bcd``) also matches.
_FLAG_RE = re.compile(r"(--[a-z][a-z0-9-]{2,})")
# Upper-snake env vars such as CLAUDE_SQL_CONCURRENCY or SOME_ENV_VAR. Needs an
# underscore preceded by at least five [A-Z0-9_] characters (the first an uppercase
# letter), so short-prefix names like AWS_REGION or HOME_DIR are not matched.
_ENV_VAR_RE = re.compile(r"\b([A-Z][A-Z0-9_]{4,}_[A-Z0-9_]+)\b")
# Work-item IDs (``wi_`` + 12 lowercase hex chars) and lowercase hyphenated UUIDs.
# Slack thread timestamps are not matched.
_ID_RE = re.compile(
    r"\b(wi_[0-9a-f]{12}|[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})\b"
)
# ``claude-sql <verb>`` subcommands; only the verb is captured.
_CLI_RE = re.compile(r"\bclaude-sql\s+([a-z][a-z0-9-]+)\b")
50
+
51
+
52
@dataclass(frozen=True)
class Claim:
    """An entity named by the assistant, tagged with the extractor that found it.

    Frozen, so instances are immutable and hashable (usable in sets/dict keys).
    """

    # Exact matched text, e.g. "/etc/hosts" or "CLAUDE_SQL_CONCURRENCY".
    entity: str
    # Extractor category: "path", "function", "flag", "env_var", "id" or "cli".
    kind: str
58
+
59
+
60
def extract_claims(text: str) -> list[Claim]:
    """Extract entity claims from one piece of assistant text.

    Each extractor regex is applied to the whole text in a fixed order --
    cli, path, function, flag, env_var, id -- and every captured entity
    becomes a :class:`Claim`. A repeated ``(entity, kind)`` pair is kept only
    once, at its first position. The order only fixes the order of the
    returned list: de-duplication is per kind, so text hit by two extractors
    (e.g. a UUID inside a path) yields both a path claim and an id claim.
    """
    extractors = (
        (_CLI_RE, "cli"),
        (_PATH_RE, "path"),
        (_FUNCTION_RE, "function"),
        (_FLAG_RE, "flag"),
        (_ENV_VAR_RE, "env_var"),
        (_ID_RE, "id"),
    )
    # dicts keep insertion order, so this acts as an ordered de-duplicating set.
    unique: dict[tuple[str, str], Claim] = {}
    for pattern, kind in extractors:
        for match in pattern.finditer(text):
            entity = match.group(1)
            if (entity, kind) not in unique:
                unique[(entity, kind)] = Claim(entity=entity, kind=kind)
    return list(unique.values())
88
+
89
+
90
+ # ---------------------------------------------------------------------------
91
+ # Grounding check
92
+ # ---------------------------------------------------------------------------
93
+
94
+
95
def count_in(haystack: str, needle: str) -> int:
    """Return how many non-overlapping times ``needle`` occurs in ``haystack``.

    Plain substring search (no regex), so needles containing regex
    metacharacters need no escaping. An empty ``needle`` counts as 0 hits
    instead of ``str.count``'s ``len(haystack) + 1``.
    """
    return haystack.count(needle) if needle else 0
100
+
101
+
102
def check_claims(claims: Iterable[Claim], tool_output_text: str) -> list[dict]:
    """Score each claim against the tool output.

    Returns one dict per claim, in input order, with keys ``claim_entity``,
    ``claim_kind``, ``grounded`` (True when there is at least one hit) and
    ``tool_output_hits`` (non-overlapping substring matches of the entity in
    ``tool_output_text``).
    """

    def score(claim: Claim) -> dict:
        n_hits = count_in(tool_output_text, claim.entity)
        return {
            "claim_entity": claim.entity,
            "claim_kind": claim.kind,
            "grounded": n_hits > 0,
            "tool_output_hits": n_hits,
        }

    return [score(claim) for claim in claims]
116
+
117
+
118
+ # ---------------------------------------------------------------------------
119
+ # Batch over sessions
120
+ # ---------------------------------------------------------------------------
121
+
122
+
123
@dataclass(frozen=True)
class Turn:
    """A single assistant turn plus the tool output it is checked against.

    Frozen, so instances are immutable and hashable.
    """

    # Session the turn belongs to.
    session_id: str
    # Index of the turn within its session, as supplied by the caller.
    turn_idx: int
    # Assistant message text that claims are extracted from.
    assistant_text: str
    # Tool-result text the claims are matched against -- intended to be the
    # concatenation of ToolResult content from this session.
    tool_output_text: str
131
+
132
+
133
def detect(turns: list[Turn], freeze_sha: str) -> pl.DataFrame:
    """Run the detector over a batch of turns; return a parquet-shaped frame.

    Each turn's ``assistant_text`` is scanned with ``extract_claims`` and the
    claims are checked against that turn's ``tool_output_text``. The result has
    one row per (turn, claim), stamped with ``freeze_sha``.

    Column dtypes are pinned explicitly so an empty batch and a non-empty one
    share the same schema (and, e.g., an all-``None`` ``freeze_sha`` still yields
    a String column rather than an inferred Null column).
    """
    schema = {
        "session_id": pl.String,
        "turn_idx": pl.Int64,
        "claim_entity": pl.String,
        "claim_kind": pl.String,
        "grounded": pl.Boolean,
        "tool_output_hits": pl.Int64,
        "freeze_sha": pl.String,
    }
    rows: list[dict] = []
    for t in turns:
        checked = check_claims(extract_claims(t.assistant_text), t.tool_output_text)
        rows.extend(
            {
                "session_id": t.session_id,
                "turn_idx": t.turn_idx,
                "claim_entity": row["claim_entity"],
                "claim_kind": row["claim_kind"],
                "grounded": row["grounded"],
                "tool_output_hits": row["tool_output_hits"],
                "freeze_sha": freeze_sha,
            }
            for row in checked
        )
    return pl.DataFrame(rows, schema=schema)
165
+
166
+
167
def summarize(df: pl.DataFrame) -> pl.DataFrame:
    """Per-session rollup: ungrounded-claim count + rate.

    Columns: ``session_id``, ``n_claims`` (rows for the session),
    ``n_ungrounded`` (rows with ``grounded`` False) and ``ungrounded_rate``
    (their fraction). Counts are Int64 in both the empty and non-empty case.
    Sorted by ``ungrounded_rate`` descending, ties broken by ``session_id``.
    """
    if df.height == 0:
        return pl.DataFrame(
            schema={
                "session_id": pl.String,
                "n_claims": pl.Int64,
                "n_ungrounded": pl.Int64,
                "ungrounded_rate": pl.Float64,
            }
        )
    return (
        df.group_by("session_id")
        .agg(
            # pl.len() and boolean sums come back as unsigned ints; cast to match
            # the Int64 schema used for the empty case above.
            pl.len().cast(pl.Int64).alias("n_claims"),
            (~pl.col("grounded")).sum().cast(pl.Int64).alias("n_ungrounded"),
            (1.0 - pl.col("grounded").cast(pl.Float64).mean()).alias("ungrounded_rate"),
        )
        # group_by output order is not guaranteed; the secondary key makes ties stable.
        .sort(["ungrounded_rate", "session_id"], descending=[True, False])
    )
187
+
188
+
189
def to_parquet(df: pl.DataFrame, path: Path | str) -> None:
    """Write ``df`` to ``path`` as parquet, creating missing parent directories.

    ``path`` may be a :class:`~pathlib.Path` or a string.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    df.write_parquet(target)