agentkernel-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentkernel/__init__.py +7 -0
- agentkernel/__main__.py +5 -0
- agentkernel/agent.py +311 -0
- agentkernel/approval/__init__.py +23 -0
- agentkernel/approval/base.py +34 -0
- agentkernel/approval/cli.py +129 -0
- agentkernel/approval/policy.py +58 -0
- agentkernel/approval/risk.py +91 -0
- agentkernel/approval/sandbox.py +201 -0
- agentkernel/budget.py +64 -0
- agentkernel/checkpoint.py +50 -0
- agentkernel/cli.py +1482 -0
- agentkernel/config.py +224 -0
- agentkernel/context/__init__.py +17 -0
- agentkernel/context/manager.py +216 -0
- agentkernel/context/truncate.py +35 -0
- agentkernel/cron.py +146 -0
- agentkernel/curation.py +183 -0
- agentkernel/doctor.py +141 -0
- agentkernel/embeddings.py +132 -0
- agentkernel/evaluation.py +186 -0
- agentkernel/improvement.py +133 -0
- agentkernel/insights.py +141 -0
- agentkernel/kanban.py +114 -0
- agentkernel/knowledge.py +383 -0
- agentkernel/loops.py +145 -0
- agentkernel/mcp/__init__.py +23 -0
- agentkernel/mcp/client.py +181 -0
- agentkernel/mcp/config.py +59 -0
- agentkernel/mcp/tools.py +96 -0
- agentkernel/memory.py +1208 -0
- agentkernel/paths.py +73 -0
- agentkernel/plugins.py +76 -0
- agentkernel/profiles.py +70 -0
- agentkernel/progress.py +89 -0
- agentkernel/providers/__init__.py +35 -0
- agentkernel/providers/_http.py +157 -0
- agentkernel/providers/anthropic.py +282 -0
- agentkernel/providers/base.py +38 -0
- agentkernel/providers/credentials.py +65 -0
- agentkernel/providers/local.py +34 -0
- agentkernel/providers/openai.py +260 -0
- agentkernel/redaction.py +77 -0
- agentkernel/semantic_index.py +139 -0
- agentkernel/semantic_memory.py +253 -0
- agentkernel/skills.py +268 -0
- agentkernel/subagent.py +161 -0
- agentkernel/telemetry.py +199 -0
- agentkernel/templates/README.md +35 -0
- agentkernel/templates/SKILL.md +28 -0
- agentkernel/templates/eval-suite.toml +22 -0
- agentkernel/templates/loop.toml +29 -0
- agentkernel/templates/mcp-servers.toml +22 -0
- agentkernel/templates/profile.toml +29 -0
- agentkernel/templates/tool_module.py +64 -0
- agentkernel/tools/__init__.py +5 -0
- agentkernel/tools/base.py +100 -0
- agentkernel/tools/builtin/__init__.py +37 -0
- agentkernel/tools/builtin/checkpoint_tool.py +33 -0
- agentkernel/tools/builtin/clarify.py +60 -0
- agentkernel/tools/builtin/files.py +221 -0
- agentkernel/tools/builtin/kanban_tool.py +100 -0
- agentkernel/tools/builtin/search.py +225 -0
- agentkernel/tools/builtin/shell.py +67 -0
- agentkernel/tools/builtin/todo.py +106 -0
- agentkernel/tui/__init__.py +50 -0
- agentkernel/tui/app.py +594 -0
- agentkernel/types.py +127 -0
- agentkernel/worktree.py +64 -0
- agentkernel_cli-0.1.0.dist-info/METADATA +426 -0
- agentkernel_cli-0.1.0.dist-info/RECORD +74 -0
- agentkernel_cli-0.1.0.dist-info/WHEEL +4 -0
- agentkernel_cli-0.1.0.dist-info/entry_points.txt +2 -0
- agentkernel_cli-0.1.0.dist-info/licenses/LICENSE +201 -0
agentkernel/redaction.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Secret redaction for tool output (design §18.1).
|
|
2
|
+
|
|
3
|
+
Tool results — `bash` stdout, `read_file` contents, web/MCP output — can carry
|
|
4
|
+
API keys and tokens. This module scrubs well-known secret formats from that text
|
|
5
|
+
*before* it enters the context window and the trace, so a leaked credential in a
|
|
6
|
+
command's output doesn't get memorialized in the conversation or logged.
|
|
7
|
+
|
|
8
|
+
It is deliberately conservative: it matches high-signal token shapes (provider
|
|
9
|
+
key prefixes, PEM private-key blocks, `Authorization` headers, and labelled
|
|
10
|
+
`secret = …` assignments) rather than guessing at entropy, to keep false
|
|
11
|
+
positives low. Structured `ToolResult.data` is never touched — only text content.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
_PLACEHOLDER = "[REDACTED]"

# PEM private-key blocks (multi-line) — replace the whole block.
_PRIVATE_KEY = re.compile(
    r"-----BEGIN [A-Z ]*PRIVATE KEY-----.*?-----END [A-Z ]*PRIVATE KEY-----",
    re.DOTALL,
)

# Standalone, high-signal token formats; the whole match is replaced.
#
# Every pattern is anchored on the left with ``\b`` so a prefix is only
# recognized at the start of a word: without it ``sk-`` also fires inside
# ordinary identifiers such as ``task-<hex id>`` or ``disk-<id>``.
#
# The prefixed ``sk-ant-`` / ``sk-proj-`` forms come before the generic ``sk-``
# so they win, and they accept ``_`` and ``-`` in the body: stopping at the
# first separator would redact the head of the key but leave its tail in the
# output.
_TOKEN_PATTERNS = (
    re.compile(r"\bsk-ant-[A-Za-z0-9_\-]{20,}"),
    re.compile(r"\bsk-proj-[A-Za-z0-9_\-]{20,}"),
    re.compile(r"\bsk-[A-Za-z0-9]{20,}"),
    re.compile(r"\bgh[posru]_[A-Za-z0-9]{30,}"),
    re.compile(r"\bgithub_pat_[A-Za-z0-9_]{20,}"),
    re.compile(r"\bglpat-[A-Za-z0-9_\-]{20,}"),
    re.compile(r"\bxox[baprs]-[A-Za-z0-9-]{10,}"),
    re.compile(r"\bAKIA[0-9A-Z]{16}\b"),
    re.compile(r"\bAIza[A-Za-z0-9_\-]{35}\b"),
    re.compile(r"\bhf_[A-Za-z0-9]{30,}"),
    re.compile(r"\b(?:sk|rk)_live_[A-Za-z0-9]{20,}"),
)

# Labelled assignments — keep the label + operator, redact the value. The value
# must be a contiguous 12+ char run (no spaces), which keeps prose like
# "token: the next word" from matching.
#
# The label may be the tail of a longer identifier (``AWS_SECRET_ACCESS_KEY``,
# ``DB_PASSWORD``, ``client_secret``), so the left anchor only rejects a
# preceding letter or digit; ``\b`` would also reject ``_`` and miss all of
# those. An optional quote between label and operator covers quoted keys as in
# JSON (``"api_key": "…"``); it is captured with the operator so it is kept.
_ASSIGNMENT = re.compile(
    r"(?i)(?<![A-Za-z0-9])"
    r"(api[_-]?key|secret|token|password|passwd|pwd|access[_-]?key)"
    r"(['\"]?\s*[=:]\s*)"
    r"['\"]?([A-Za-z0-9._\-/+=]{12,})['\"]?"
)

# `Authorization: <scheme> <credential>` — keep the header prefix (including the
# scheme word, when present), redact the credential. The credential alphabet
# includes `~ + /` and trailing `=` padding so base64 values (e.g. `Basic`)
# are consumed whole rather than cut at the first such character.
_AUTH_HEADER = re.compile(
    r"(?i)(authorization\s*:\s*(?:(?:bearer|basic|token)\s+)?)"
    r"([A-Za-z0-9._~+/\-]{16,}=*)"
)


def redact_secrets(text: str) -> tuple[str, int]:
    """Return ``(scrubbed_text, count)`` with known secret formats replaced.

    Patterns are applied in a fixed order: PEM private-key blocks, standalone
    provider tokens, labelled assignments, then ``Authorization`` headers.
    Whole-match patterns run first, so a token that also sits behind a label
    (``api_key=sk-…``) is replaced once and counted once — the placeholder
    itself contains characters no later pattern accepts.

    Args:
        text: Text to scrub. Empty or ``None``-like input is returned as is.

    Returns:
        The scrubbed text and the number of replacements made (``0`` means the
        text was clean).
    """
    if not text:
        return text, 0

    count = 0
    out = text

    # Whole-match replacements: nothing of the original match is kept.
    for pattern in (_PRIVATE_KEY, *_TOKEN_PATTERNS):
        out, hits = pattern.subn(_PLACEHOLDER, out)
        count += hits

    # Keep ``label`` + operator (groups 1 and 2); drop the value and its quotes.
    out, hits = _ASSIGNMENT.subn(
        lambda m: f"{m.group(1)}{m.group(2)}{_PLACEHOLDER}", out
    )
    count += hits

    # Keep the header name and scheme (group 1); drop the credential.
    out, hits = _AUTH_HEADER.subn(lambda m: f"{m.group(1)}{_PLACEHOLDER}", out)
    count += hits

    return out, count
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""Approximate nearest-neighbor helpers for semantic memory.
|
|
2
|
+
|
|
3
|
+
These use only the Python standard library so the kernel keeps its stdlib-only
|
|
4
|
+
constraint. The default path is a brute-force cosine scan; when scale demands it,
|
|
5
|
+
a small random-projection LSH index can prune the candidate set.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import random
|
|
12
|
+
import sqlite3
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class LSHIndex:
|
|
17
|
+
"""Random-projection locality-sensitive hash index for dense vectors.
|
|
18
|
+
|
|
19
|
+
Each vector is projected onto ``num_bits`` random hyperplanes; the sign of
|
|
20
|
+
each projection becomes one bit of an integer bucket key. Queries fetch the
|
|
21
|
+
exact bucket plus all buckets one bit away, which dramatically improves
|
|
22
|
+
recall without a full linear scan.
|
|
23
|
+
|
|
24
|
+
The hyperplanes are persisted in the same SQLite database as the vectors so
|
|
25
|
+
the index is stable across process restarts.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
def __init__(
|
|
29
|
+
self,
|
|
30
|
+
dim: int,
|
|
31
|
+
num_bits: int,
|
|
32
|
+
conn: Callable[[], sqlite3.Connection],
|
|
33
|
+
*,
|
|
34
|
+
seed: int = 0,
|
|
35
|
+
) -> None:
|
|
36
|
+
self.dim = dim
|
|
37
|
+
self.num_bits = num_bits
|
|
38
|
+
self._conn = conn
|
|
39
|
+
self._seed = seed
|
|
40
|
+
self._hyperplanes = self._ensure_hyperplanes()
|
|
41
|
+
|
|
42
|
+
def _ensure_hyperplanes(self) -> list[list[float]]:
|
|
43
|
+
conn = self._conn()
|
|
44
|
+
conn.execute(
|
|
45
|
+
"CREATE TABLE IF NOT EXISTS lsh_meta (key TEXT PRIMARY KEY, value TEXT)"
|
|
46
|
+
)
|
|
47
|
+
conn.execute(
|
|
48
|
+
"CREATE TABLE IF NOT EXISTS lsh_buckets ("
|
|
49
|
+
"note_id INTEGER PRIMARY KEY, bucket INTEGER NOT NULL"
|
|
50
|
+
")"
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
bits_row = conn.execute(
|
|
54
|
+
'SELECT value FROM lsh_meta WHERE key = "bits"'
|
|
55
|
+
).fetchone()
|
|
56
|
+
seed_row = conn.execute(
|
|
57
|
+
'SELECT value FROM lsh_meta WHERE key = "seed"'
|
|
58
|
+
).fetchone()
|
|
59
|
+
planes_row = conn.execute(
|
|
60
|
+
'SELECT value FROM lsh_meta WHERE key = "hyperplanes"'
|
|
61
|
+
).fetchone()
|
|
62
|
+
|
|
63
|
+
if bits_row and seed_row and planes_row:
|
|
64
|
+
stored_bits = int(bits_row["value"])
|
|
65
|
+
stored_seed = int(seed_row["value"])
|
|
66
|
+
planes = json.loads(planes_row["value"])
|
|
67
|
+
if (
|
|
68
|
+
stored_bits == self.num_bits
|
|
69
|
+
and stored_seed == self._seed
|
|
70
|
+
and len(planes) == self.num_bits
|
|
71
|
+
and all(len(p) == self.dim for p in planes)
|
|
72
|
+
):
|
|
73
|
+
return planes
|
|
74
|
+
|
|
75
|
+
rng = random.Random(self._seed)
|
|
76
|
+
planes = [
|
|
77
|
+
[rng.gauss(0.0, 1.0) for _ in range(self.dim)]
|
|
78
|
+
for _ in range(self.num_bits)
|
|
79
|
+
]
|
|
80
|
+
with conn:
|
|
81
|
+
conn.execute(
|
|
82
|
+
"INSERT OR REPLACE INTO lsh_meta (key, value) VALUES (?, ?)",
|
|
83
|
+
("bits", str(self.num_bits)),
|
|
84
|
+
)
|
|
85
|
+
conn.execute(
|
|
86
|
+
"INSERT OR REPLACE INTO lsh_meta (key, value) VALUES (?, ?)",
|
|
87
|
+
("seed", str(self._seed)),
|
|
88
|
+
)
|
|
89
|
+
conn.execute(
|
|
90
|
+
"INSERT OR REPLACE INTO lsh_meta (key, value) VALUES (?, ?)",
|
|
91
|
+
("hyperplanes", json.dumps(planes)),
|
|
92
|
+
)
|
|
93
|
+
conn.execute("DELETE FROM lsh_buckets")
|
|
94
|
+
return planes
|
|
95
|
+
|
|
96
|
+
def hash(self, vector: list[float]) -> int:
|
|
97
|
+
"""Return the integer bucket for ``vector``."""
|
|
98
|
+
bucket = 0
|
|
99
|
+
for bit, plane in enumerate(self._hyperplanes):
|
|
100
|
+
dot = sum(v * p for v, p in zip(vector, plane, strict=True))
|
|
101
|
+
if dot >= 0:
|
|
102
|
+
bucket |= 1 << bit
|
|
103
|
+
return bucket
|
|
104
|
+
|
|
105
|
+
def query_buckets(self, vector: list[float]) -> list[int]:
|
|
106
|
+
"""Return the query bucket and all one-bit neighbors."""
|
|
107
|
+
base = self.hash(vector)
|
|
108
|
+
buckets = [base]
|
|
109
|
+
for bit in range(self.num_bits):
|
|
110
|
+
buckets.append(base ^ (1 << bit))
|
|
111
|
+
return buckets
|
|
112
|
+
|
|
113
|
+
def upsert(self, note_id: int, vector: list[float]) -> None:
|
|
114
|
+
"""Store/update the bucket for ``note_id``."""
|
|
115
|
+
with self._conn():
|
|
116
|
+
self._conn().execute(
|
|
117
|
+
"INSERT OR REPLACE INTO lsh_buckets (note_id, bucket) VALUES (?, ?)",
|
|
118
|
+
(note_id, self.hash(vector)),
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
def remove(self, note_id: int) -> None:
|
|
122
|
+
with self._conn():
|
|
123
|
+
self._conn().execute(
|
|
124
|
+
"DELETE FROM lsh_buckets WHERE note_id = ?", (note_id,)
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
def candidate_ids(self, buckets: list[int]) -> list[int]:
|
|
128
|
+
"""Return note ids whose bucket is in ``buckets``."""
|
|
129
|
+
if not buckets:
|
|
130
|
+
return []
|
|
131
|
+
placeholders = ",".join("?" for _ in buckets)
|
|
132
|
+
rows = self._conn().execute(
|
|
133
|
+
f"""
|
|
134
|
+
SELECT note_id FROM lsh_buckets
|
|
135
|
+
WHERE bucket IN ({placeholders})
|
|
136
|
+
""",
|
|
137
|
+
tuple(buckets),
|
|
138
|
+
).fetchall()
|
|
139
|
+
return [row["note_id"] for row in rows]
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
"""Semantic note search over SQLite notebooks.
|
|
2
|
+
|
|
3
|
+
This is a thin subclass of ``SqliteNoteStore`` so the JSONL / SQLite split and
|
|
4
|
+
all existing memory-tool behavior stay intact. When an ``EmbeddingProvider`` is
|
|
5
|
+
configured, each note stores a dense vector and recall is re-ranked by cosine
|
|
6
|
+
similarity rather than only token overlap.
|
|
7
|
+
|
|
8
|
+
For large notebooks the optional LSH index in ``semantic_index`` prunes the
|
|
9
|
+
candidate set before the dense comparison, avoiding a full linear scan.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
from collections.abc import Sequence
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
from agentkernel.embeddings import EmbeddingProvider, cosine_similarity
|
|
19
|
+
from agentkernel.memory import MemoryNote, SqliteNoteStore
|
|
20
|
+
from agentkernel.semantic_index import LSHIndex
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SemanticSqliteNoteStore(SqliteNoteStore):
    """SQLite notebook that also stores dense embeddings for semantic ranking.

    Every note written through ``add``/``update`` gets a vector from the
    embedding provider, stored as JSON in ``note_embeddings``. ``search``
    ranks candidate notes by cosine similarity to the query vector. Notes
    created before the provider existed can be backfilled with
    ``reindex_embeddings()``.

    When ``lsh_bits`` is set, an ``LSHIndex`` narrows the candidate set before
    scoring; it is created lazily, once a vector's dimension is known.
    """

    # Largest number of ids bound into a single ``IN (...)`` list. SQLite caps
    # bound parameters per statement (999 by default before 3.32.0), so longer
    # id lists are split into batches of this size.
    _SQL_VARIABLE_BATCH = 900

    def __init__(
        self,
        path: str | Path,
        *,
        embedding_provider: EmbeddingProvider,
        lsh_bits: int | None = None,
    ) -> None:
        """Open the notebook at ``path`` and ensure the embedding table exists.

        Args:
            path: Location of the SQLite notebook, passed to the parent store.
            embedding_provider: Source of dense vectors for notes and queries.
            lsh_bits: Number of LSH hyperplanes; ``None`` or ``0`` disables the
                index and every search scores all notes.
        """
        self._embedding_provider = embedding_provider
        self._lsh_bits = lsh_bits
        self._lsh_index: LSHIndex | None = None
        # Parent creates the notes table and optional FTS5 index.
        super().__init__(path)
        self._ensure_embedding_schema()

    @classmethod
    def _batches(cls, ids: Sequence[int]) -> list[list[int]]:
        """Split ``ids`` into lists small enough to bind in one statement."""
        size = cls._SQL_VARIABLE_BATCH
        return [list(ids[start : start + size]) for start in range(0, len(ids), size)]

    def _ensure_embedding_schema(self) -> None:
        """Create the ``note_embeddings`` table if it does not exist."""
        conn = self._connection()
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS note_embeddings (
                note_id INTEGER PRIMARY KEY,
                embedding_json TEXT NOT NULL
            )
            """
        )
        conn.commit()

    def _ensure_lsh_index(self, sample_vector: list[float] | None = None) -> None:
        """Create the LSH index once we know the vector dimension.

        The dimension comes from ``sample_vector`` when given, otherwise from
        any stored embedding. With neither available the index stays unset and
        a later call retries.
        """
        if self._lsh_index is not None or not self._lsh_bits:
            return

        dim = len(sample_vector) if sample_vector else None
        if dim is None:
            rows = self._connection().execute(
                "SELECT embedding_json FROM note_embeddings LIMIT 1"
            ).fetchall()
            if rows:
                dim = len(json.loads(rows[0]["embedding_json"]))
        if dim is None:
            return  # no embeddings yet; will initialize on first note

        self._lsh_index = LSHIndex(
            dim=dim,
            num_bits=self._lsh_bits,
            conn=self._connection,
            seed=0,
        )
        # Backfill buckets for any existing embeddings (e.g. after a schema
        # change or when opening an older notebook).
        count = self._connection().execute(
            "SELECT COUNT(*) FROM lsh_buckets"
        ).fetchone()[0]
        if count == 0:
            rows = self._connection().execute(
                "SELECT note_id, embedding_json FROM note_embeddings"
            ).fetchall()
            for row in rows:
                self._lsh_index.upsert(row["note_id"], json.loads(row["embedding_json"]))

    def _upsert_embedding(self, note_id: int, text: str) -> None:
        """Embed ``text`` and store the vector (and LSH bucket) for ``note_id``.

        Does nothing when the text is empty, no provider is configured, or the
        provider returns an empty vector.
        """
        if not text or self._embedding_provider is None:
            return
        vector = self._embedding_provider.embed([text])[0]
        if not vector:
            return
        self._ensure_lsh_index(vector)
        conn = self._connection()
        with conn:
            conn.execute(
                """
                INSERT INTO note_embeddings (note_id, embedding_json)
                VALUES (?, ?)
                ON CONFLICT(note_id) DO UPDATE SET
                    embedding_json = excluded.embedding_json
                """,
                (note_id, json.dumps(vector)),
            )
        if self._lsh_index is not None:
            self._lsh_index.upsert(note_id, vector)

    def add(self, text: str, *, tags: Sequence[str] | None = None) -> MemoryNote:
        """Add a note via the parent store, then store its embedding."""
        note = super().add(text, tags=tags)
        self._upsert_embedding(note.note_id, note.text)
        return note

    def update(
        self,
        note_id: int,
        text: str,
        *,
        tags: Sequence[str] | None = None,
    ) -> MemoryNote | None:
        """Update a note via the parent store and refresh its embedding.

        Returns whatever the parent returns; the embedding is only refreshed
        when that is not ``None``.
        """
        note = super().update(note_id, text, tags=tags)
        if note is not None:
            self._upsert_embedding(note.note_id, note.text)
        return note

    def search(self, query: str, *, limit: int = 5) -> list[MemoryNote]:
        """Return up to ``limit`` notes ranked by cosine similarity to ``query``.

        A blank query, or a store without an embedding provider, defers to the
        parent's search. Returned notes are passed to ``_touch``.
        """
        query = query.strip()
        if not query:
            return super().search(query, limit=limit)
        if self._embedding_provider is None:
            return super().search(query, limit=limit)

        query_vec = self._embedding_provider.embed([query])[0]
        candidates = self._candidates(query_vec, limit=limit)
        if not candidates:
            return []

        ids = [note.note_id for note in candidates]
        vectors = self._load_embeddings(ids)

        scored: list[tuple[float, int, MemoryNote]] = []
        for note in candidates:
            vec = vectors.get(note.note_id)
            if vec and query_vec:
                similarity = cosine_similarity(query_vec, vec)
            else:
                # Notes without embeddings fall below any scored note. Cosine
                # similarity can legitimately reach -1.0, so only -inf is
                # guaranteed to sort last.
                similarity = float("-inf") if vectors else 0.0
            scored.append((similarity, note.note_id, note))

        # Highest similarity first; tie-break by note_id for stability.
        scored.sort(key=lambda item: (item[0], item[1]), reverse=True)
        ranked = [note for _, _, note in scored[:limit]]
        for note in ranked:
            self._touch(note)
        return ranked

    def _candidates(
        self, query_vec: list[float], *, limit: int
    ) -> list[MemoryNote]:
        """Return notes to score for ``query_vec``.

        If an LSH index is active, use it to narrow the set; otherwise fall
        back to scanning every note. When the pruned set is too small we also
        fall back to avoid missing neighbors due to hash collisions. The full
        note list is only loaded on those fallback paths, so an effective
        index avoids the full scan.
        """
        # An empty query vector cannot be hashed (its length never matches the
        # hyperplanes), so it takes the unpruned path as well.
        if not self._lsh_bits or not query_vec:
            return self.all()

        self._ensure_lsh_index()
        if self._lsh_index is None:
            return self.all()

        buckets = self._lsh_index.query_buckets(query_vec)
        candidate_ids = self._lsh_index.candidate_ids(buckets)
        # LSH is only a speedup; if the bucket is empty or tiny, scan the table
        # so accuracy does not suffer on small notebooks or unlucky hashes.
        if len(candidate_ids) < max(limit * 2, 8):
            return self.all()

        conn = self._connection()
        notes: list[MemoryNote] = []
        # Ids are sorted before batching so the concatenated batches keep the
        # overall ``ORDER BY note_id`` ordering.
        for batch in self._batches(sorted(candidate_ids)):
            placeholders = ",".join("?" for _ in batch)
            rows = conn.execute(
                f"""
                SELECT * FROM notes
                WHERE note_id IN ({placeholders})
                ORDER BY note_id
                """,
                tuple(batch),
            ).fetchall()
            notes.extend(self._row_to_note(r) for r in rows)
        return notes

    def _load_embeddings(self, note_ids: list[int]) -> dict[int, list[float]]:
        """Return ``{note_id: vector}`` for the ids that have an embedding."""
        if not note_ids:
            return {}
        conn = self._connection()
        loaded: dict[int, list[float]] = {}
        for batch in self._batches(note_ids):
            placeholders = ",".join("?" for _ in batch)
            rows = conn.execute(
                f"""
                SELECT note_id, embedding_json
                FROM note_embeddings
                WHERE note_id IN ({placeholders})
                """,
                tuple(batch),
            ).fetchall()
            for row in rows:
                loaded[row["note_id"]] = json.loads(row["embedding_json"])
        return loaded

    def reindex_embeddings(self) -> int:
        """Compute and store embeddings for all notes that do not have one yet.

        Returns:
            The number of embeddings written; notes for which the provider
            returns an empty vector are skipped.
        """
        if self._embedding_provider is None:
            return 0
        rows = self._connection().execute(
            """
            SELECT n.note_id, n.text
            FROM notes n
            LEFT JOIN note_embeddings e ON n.note_id = e.note_id
            WHERE e.note_id IS NULL AND n.text IS NOT NULL AND n.text != ''
            """
        ).fetchall()
        if not rows:
            return 0
        vectors = self._embedding_provider.embed([r["text"] for r in rows])
        count = 0
        conn = self._connection()
        with conn:
            for row, vec in zip(rows, vectors, strict=True):
                if not vec:
                    continue
                self._ensure_lsh_index(vec)
                conn.execute(
                    "INSERT INTO note_embeddings (note_id, embedding_json) VALUES (?, ?)",
                    (row["note_id"], json.dumps(vec)),
                )
                if self._lsh_index is not None:
                    self._lsh_index.upsert(row["note_id"], vec)
                count += 1
        return count

    def forget(
        self,
        *,
        note_id: int | None = None,
        text_prefix: str | None = None,
    ) -> list[MemoryNote]:
        """Remove notes via the parent store and drop their embeddings.

        Returns the notes the parent reported as removed; their rows in
        ``note_embeddings`` and, when the index is active, ``lsh_buckets`` are
        deleted too.
        """
        removed = super().forget(note_id=note_id, text_prefix=text_prefix)
        if removed:
            ids = [note.note_id for note in removed]
            conn = self._connection()
            with conn:
                for batch in self._batches(ids):
                    placeholders = ",".join("?" for _ in batch)
                    conn.execute(
                        f"DELETE FROM note_embeddings WHERE note_id IN ({placeholders})",
                        tuple(batch),
                    )
            if self._lsh_index is not None:
                for nid in ids:
                    self._lsh_index.remove(nid)
        return removed
|