keep-skill 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keep/__init__.py +3 -6
- keep/api.py +1052 -145
- keep/cli.py +705 -132
- keep/config.py +172 -41
- keep/context.py +1 -125
- keep/document_store.py +908 -0
- keep/errors.py +33 -0
- keep/indexing.py +1 -1
- keep/logging_config.py +34 -3
- keep/paths.py +81 -17
- keep/pending_summaries.py +52 -40
- keep/providers/embedding_cache.py +59 -46
- keep/providers/embeddings.py +43 -13
- keep/providers/mlx.py +23 -21
- keep/store.py +169 -25
- keep_skill-0.3.0.dist-info/METADATA +218 -0
- keep_skill-0.3.0.dist-info/RECORD +28 -0
- keep_skill-0.1.0.dist-info/METADATA +0 -290
- keep_skill-0.1.0.dist-info/RECORD +0 -26
- {keep_skill-0.1.0.dist-info → keep_skill-0.3.0.dist-info}/WHEEL +0 -0
- {keep_skill-0.1.0.dist-info → keep_skill-0.3.0.dist-info}/entry_points.txt +0 -0
- {keep_skill-0.1.0.dist-info → keep_skill-0.3.0.dist-info}/licenses/LICENSE +0 -0
keep/errors.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Error logging utilities for keep CLI.
|
|
3
|
+
|
|
4
|
+
Logs full stack traces to /tmp for debugging while showing clean messages to users.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import traceback
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
ERROR_LOG_PATH = Path("/tmp/keep-errors.log")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def log_exception(exc: Exception, context: str = "") -> Path:
|
|
15
|
+
"""
|
|
16
|
+
Log exception with full traceback to file.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
exc: The exception that occurred
|
|
20
|
+
context: Optional context string (e.g., command name)
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
Path to the error log file
|
|
24
|
+
"""
|
|
25
|
+
timestamp = datetime.now(timezone.utc).isoformat()
|
|
26
|
+
with open(ERROR_LOG_PATH, "a") as f:
|
|
27
|
+
f.write(f"\n{'='*60}\n")
|
|
28
|
+
f.write(f"[{timestamp}]")
|
|
29
|
+
if context:
|
|
30
|
+
f.write(f" {context}")
|
|
31
|
+
f.write("\n")
|
|
32
|
+
f.write(traceback.format_exc())
|
|
33
|
+
return ERROR_LOG_PATH
|
keep/indexing.py
CHANGED
keep/logging_config.py
CHANGED
|
@@ -57,17 +57,48 @@ def configure_quiet_mode(quiet: bool = True):
|
|
|
57
57
|
def enable_verbose_mode():
|
|
58
58
|
"""Re-enable verbose output for debugging."""
|
|
59
59
|
configure_quiet_mode(quiet=False)
|
|
60
|
-
|
|
60
|
+
|
|
61
61
|
# Restore defaults
|
|
62
62
|
os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
|
|
63
63
|
os.environ.pop("TRANSFORMERS_VERBOSITY", None)
|
|
64
|
-
|
|
64
|
+
|
|
65
65
|
# Re-enable warnings
|
|
66
66
|
warnings.filterwarnings("default")
|
|
67
|
-
|
|
67
|
+
|
|
68
68
|
# Reset logging levels
|
|
69
69
|
import logging
|
|
70
70
|
logging.getLogger("transformers").setLevel(logging.INFO)
|
|
71
71
|
logging.getLogger("sentence_transformers").setLevel(logging.INFO)
|
|
72
72
|
logging.getLogger("mlx").setLevel(logging.INFO)
|
|
73
73
|
logging.getLogger("chromadb").setLevel(logging.INFO)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def enable_debug_mode():
|
|
77
|
+
"""Enable debug-level logging to stderr."""
|
|
78
|
+
import logging
|
|
79
|
+
|
|
80
|
+
# Re-enable warnings
|
|
81
|
+
warnings.filterwarnings("default")
|
|
82
|
+
|
|
83
|
+
# Restore library verbosity
|
|
84
|
+
os.environ.pop("HF_HUB_DISABLE_PROGRESS_BARS", None)
|
|
85
|
+
os.environ.pop("TRANSFORMERS_VERBOSITY", None)
|
|
86
|
+
|
|
87
|
+
# Configure root logger for debug output
|
|
88
|
+
root_logger = logging.getLogger()
|
|
89
|
+
root_logger.setLevel(logging.DEBUG)
|
|
90
|
+
|
|
91
|
+
# Add stderr handler if not already present
|
|
92
|
+
if not any(isinstance(h, logging.StreamHandler) and h.stream == sys.stderr
|
|
93
|
+
for h in root_logger.handlers):
|
|
94
|
+
handler = logging.StreamHandler(sys.stderr)
|
|
95
|
+
handler.setLevel(logging.DEBUG)
|
|
96
|
+
handler.setFormatter(logging.Formatter(
|
|
97
|
+
"%(asctime)s %(levelname)s %(name)s: %(message)s",
|
|
98
|
+
datefmt="%H:%M:%S"
|
|
99
|
+
))
|
|
100
|
+
root_logger.addHandler(handler)
|
|
101
|
+
|
|
102
|
+
# Set library loggers to DEBUG
|
|
103
|
+
for name in ("keep", "transformers", "sentence_transformers", "mlx", "chromadb"):
|
|
104
|
+
logging.getLogger(name).setLevel(logging.DEBUG)
|
keep/paths.py
CHANGED
|
@@ -1,11 +1,29 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Utility functions for locating paths.
|
|
3
|
+
|
|
4
|
+
Config and store discovery follows this priority:
|
|
5
|
+
|
|
6
|
+
Config discovery:
|
|
7
|
+
1. KEEP_CONFIG envvar (path to .keep/ directory)
|
|
8
|
+
2. Tree-walk from cwd up to ~, find first .keep/keep.toml
|
|
9
|
+
3. ~/.keep/ default
|
|
10
|
+
|
|
11
|
+
Store resolution:
|
|
12
|
+
1. --store CLI option (passed to Keeper)
|
|
13
|
+
2. KEEP_STORE_PATH envvar
|
|
14
|
+
3. store.path in config file
|
|
15
|
+
4. Config directory itself (backwards compat)
|
|
3
16
|
"""
|
|
4
17
|
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
5
20
|
import os
|
|
6
21
|
import warnings
|
|
7
22
|
from pathlib import Path
|
|
8
|
-
from typing import Optional
|
|
23
|
+
from typing import TYPE_CHECKING, Optional
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from .config import StoreConfig
|
|
9
27
|
|
|
10
28
|
|
|
11
29
|
def find_git_root(start_path: Optional[Path] = None) -> Optional[Path]:
|
|
@@ -35,14 +53,67 @@ def find_git_root(start_path: Optional[Path] = None) -> Optional[Path]:
|
|
|
35
53
|
return None
|
|
36
54
|
|
|
37
55
|
|
|
38
|
-
def
|
|
56
|
+
def find_config_dir(start_path: Optional[Path] = None) -> Path:
|
|
57
|
+
"""
|
|
58
|
+
Find config directory by tree-walking from start_path up to home.
|
|
59
|
+
|
|
60
|
+
Looks for .keep/keep.toml at each directory level, stopping at home.
|
|
61
|
+
Returns the .keep/ directory containing the config, or ~/.keep/ if none found.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
start_path: Path to start searching from. Defaults to cwd.
|
|
65
|
+
|
|
66
|
+
Returns:
|
|
67
|
+
Path to the .keep/ config directory.
|
|
68
|
+
"""
|
|
69
|
+
if start_path is None:
|
|
70
|
+
start_path = Path.cwd()
|
|
71
|
+
|
|
72
|
+
home = Path.home()
|
|
73
|
+
current = start_path.resolve()
|
|
74
|
+
|
|
75
|
+
while True:
|
|
76
|
+
candidate = current / ".keep" / "keep.toml"
|
|
77
|
+
if candidate.exists():
|
|
78
|
+
return current / ".keep"
|
|
79
|
+
|
|
80
|
+
# Stop at home or filesystem root
|
|
81
|
+
if current == home or current == current.parent:
|
|
82
|
+
break
|
|
83
|
+
current = current.parent
|
|
84
|
+
|
|
85
|
+
# Default: ~/.keep/
|
|
86
|
+
return home / ".keep"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_config_dir() -> Path:
|
|
90
|
+
"""
|
|
91
|
+
Get the config directory.
|
|
92
|
+
|
|
93
|
+
Priority:
|
|
94
|
+
1. KEEP_CONFIG environment variable
|
|
95
|
+
2. Tree-walk from cwd up to ~ (find_config_dir)
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
Path to the .keep/ config directory.
|
|
99
|
+
"""
|
|
100
|
+
env = os.environ.get("KEEP_CONFIG")
|
|
101
|
+
if env:
|
|
102
|
+
return Path(env).expanduser().resolve()
|
|
103
|
+
return find_config_dir()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def get_default_store_path(config: Optional[StoreConfig] = None) -> Path:
|
|
39
107
|
"""
|
|
40
108
|
Get the default store path.
|
|
41
109
|
|
|
42
110
|
Priority:
|
|
43
111
|
1. KEEP_STORE_PATH environment variable
|
|
44
|
-
2. .
|
|
45
|
-
3.
|
|
112
|
+
2. store.path setting in config file
|
|
113
|
+
3. Config directory itself (backwards compat)
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
config: Optional loaded config to check for store.path setting.
|
|
46
117
|
|
|
47
118
|
Returns:
|
|
48
119
|
Path to the store directory (may not exist yet).
|
|
@@ -51,17 +122,10 @@ def get_default_store_path() -> Path:
|
|
|
51
122
|
env_path = os.environ.get("KEEP_STORE_PATH")
|
|
52
123
|
if env_path:
|
|
53
124
|
return Path(env_path).resolve()
|
|
54
|
-
|
|
55
|
-
# Try to find git root
|
|
56
|
-
git_root = find_git_root()
|
|
57
|
-
if git_root:
|
|
58
|
-
return git_root / ".keep"
|
|
59
125
|
|
|
60
|
-
#
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
)
|
|
67
|
-
return home / ".keep"
|
|
126
|
+
# Check config for explicit store.path
|
|
127
|
+
if config and config.store_path:
|
|
128
|
+
return Path(config.store_path).expanduser().resolve()
|
|
129
|
+
|
|
130
|
+
# Default: config directory is also the store
|
|
131
|
+
return get_config_dir()
|
keep/pending_summaries.py
CHANGED
|
@@ -6,6 +6,7 @@ This enables fast indexing with lazy summarization.
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import sqlite3
|
|
9
|
+
import threading
|
|
9
10
|
from dataclasses import dataclass
|
|
10
11
|
from datetime import datetime, timezone
|
|
11
12
|
from pathlib import Path
|
|
@@ -37,12 +38,19 @@ class PendingSummaryQueue:
|
|
|
37
38
|
"""
|
|
38
39
|
self._queue_path = queue_path
|
|
39
40
|
self._conn: Optional[sqlite3.Connection] = None
|
|
41
|
+
self._lock = threading.Lock()
|
|
40
42
|
self._init_db()
|
|
41
43
|
|
|
42
44
|
def _init_db(self) -> None:
|
|
43
45
|
"""Initialize the SQLite database."""
|
|
44
46
|
self._queue_path.parent.mkdir(parents=True, exist_ok=True)
|
|
45
47
|
self._conn = sqlite3.connect(str(self._queue_path), check_same_thread=False)
|
|
48
|
+
|
|
49
|
+
# Enable WAL mode for better concurrent access across processes
|
|
50
|
+
self._conn.execute("PRAGMA journal_mode=WAL")
|
|
51
|
+
# Wait up to 5 seconds for locks instead of failing immediately
|
|
52
|
+
self._conn.execute("PRAGMA busy_timeout=5000")
|
|
53
|
+
|
|
46
54
|
self._conn.execute("""
|
|
47
55
|
CREATE TABLE IF NOT EXISTS pending_summaries (
|
|
48
56
|
id TEXT NOT NULL,
|
|
@@ -66,12 +74,13 @@ class PendingSummaryQueue:
|
|
|
66
74
|
If the same id+collection already exists, replaces it (newer content wins).
|
|
67
75
|
"""
|
|
68
76
|
now = datetime.now(timezone.utc).isoformat()
|
|
69
|
-
self.
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
77
|
+
with self._lock:
|
|
78
|
+
self._conn.execute("""
|
|
79
|
+
INSERT OR REPLACE INTO pending_summaries
|
|
80
|
+
(id, collection, content, queued_at, attempts)
|
|
81
|
+
VALUES (?, ?, ?, ?, 0)
|
|
82
|
+
""", (id, collection, content, now))
|
|
83
|
+
self._conn.commit()
|
|
75
84
|
|
|
76
85
|
def dequeue(self, limit: int = 10) -> list[PendingSummary]:
|
|
77
86
|
"""
|
|
@@ -80,42 +89,44 @@ class PendingSummaryQueue:
|
|
|
80
89
|
Items are returned but not removed - call `complete()` after successful processing.
|
|
81
90
|
Increments attempt counter on each dequeue.
|
|
82
91
|
"""
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
92
|
+
with self._lock:
|
|
93
|
+
cursor = self._conn.execute("""
|
|
94
|
+
SELECT id, collection, content, queued_at, attempts
|
|
95
|
+
FROM pending_summaries
|
|
96
|
+
ORDER BY queued_at ASC
|
|
97
|
+
LIMIT ?
|
|
98
|
+
""", (limit,))
|
|
99
|
+
|
|
100
|
+
items = []
|
|
101
|
+
for row in cursor.fetchall():
|
|
102
|
+
items.append(PendingSummary(
|
|
103
|
+
id=row[0],
|
|
104
|
+
collection=row[1],
|
|
105
|
+
content=row[2],
|
|
106
|
+
queued_at=row[3],
|
|
107
|
+
attempts=row[4],
|
|
108
|
+
))
|
|
109
|
+
|
|
110
|
+
# Increment attempt counters
|
|
111
|
+
if items:
|
|
112
|
+
ids = [(item.id, item.collection) for item in items]
|
|
113
|
+
self._conn.executemany("""
|
|
114
|
+
UPDATE pending_summaries
|
|
115
|
+
SET attempts = attempts + 1
|
|
116
|
+
WHERE id = ? AND collection = ?
|
|
117
|
+
""", ids)
|
|
118
|
+
self._conn.commit()
|
|
109
119
|
|
|
110
120
|
return items
|
|
111
121
|
|
|
112
122
|
def complete(self, id: str, collection: str) -> None:
|
|
113
123
|
"""Remove an item from the queue after successful processing."""
|
|
114
|
-
self.
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
124
|
+
with self._lock:
|
|
125
|
+
self._conn.execute("""
|
|
126
|
+
DELETE FROM pending_summaries
|
|
127
|
+
WHERE id = ? AND collection = ?
|
|
128
|
+
""", (id, collection))
|
|
129
|
+
self._conn.commit()
|
|
119
130
|
|
|
120
131
|
def count(self) -> int:
|
|
121
132
|
"""Get count of pending items."""
|
|
@@ -143,9 +154,10 @@ class PendingSummaryQueue:
|
|
|
143
154
|
|
|
144
155
|
def clear(self) -> int:
|
|
145
156
|
"""Clear all pending items. Returns count of items cleared."""
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
157
|
+
with self._lock:
|
|
158
|
+
count = self.count()
|
|
159
|
+
self._conn.execute("DELETE FROM pending_summaries")
|
|
160
|
+
self._conn.commit()
|
|
149
161
|
return count
|
|
150
162
|
|
|
151
163
|
def close(self) -> None:
|
|
@@ -8,6 +8,7 @@ avoiding redundant embedding calls for unchanged content.
|
|
|
8
8
|
import hashlib
|
|
9
9
|
import json
|
|
10
10
|
import sqlite3
|
|
11
|
+
import threading
|
|
11
12
|
from datetime import datetime, timezone
|
|
12
13
|
from pathlib import Path
|
|
13
14
|
from typing import Optional
|
|
@@ -32,12 +33,19 @@ class EmbeddingCache:
|
|
|
32
33
|
self._cache_path = cache_path
|
|
33
34
|
self._max_entries = max_entries
|
|
34
35
|
self._conn: Optional[sqlite3.Connection] = None
|
|
36
|
+
self._lock = threading.RLock()
|
|
35
37
|
self._init_db()
|
|
36
38
|
|
|
37
39
|
def _init_db(self) -> None:
|
|
38
40
|
"""Initialize the SQLite database."""
|
|
39
41
|
self._cache_path.parent.mkdir(parents=True, exist_ok=True)
|
|
40
42
|
self._conn = sqlite3.connect(str(self._cache_path), check_same_thread=False)
|
|
43
|
+
|
|
44
|
+
# Enable WAL mode for better concurrent access across processes
|
|
45
|
+
self._conn.execute("PRAGMA journal_mode=WAL")
|
|
46
|
+
# Wait up to 5 seconds for locks instead of failing immediately
|
|
47
|
+
self._conn.execute("PRAGMA busy_timeout=5000")
|
|
48
|
+
|
|
41
49
|
self._conn.execute("""
|
|
42
50
|
CREATE TABLE IF NOT EXISTS embedding_cache (
|
|
43
51
|
content_hash TEXT PRIMARY KEY,
|
|
@@ -62,28 +70,30 @@ class EmbeddingCache:
|
|
|
62
70
|
def get(self, model_name: str, content: str) -> Optional[list[float]]:
|
|
63
71
|
"""
|
|
64
72
|
Get cached embedding if it exists.
|
|
65
|
-
|
|
73
|
+
|
|
66
74
|
Updates last_accessed timestamp on hit.
|
|
67
75
|
"""
|
|
68
76
|
content_hash = self._hash_key(model_name, content)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
(
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
if row is not None:
|
|
76
|
-
# Update last_accessed
|
|
77
|
-
now = datetime.now(timezone.utc).isoformat()
|
|
78
|
-
self._conn.execute(
|
|
79
|
-
"UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
|
|
80
|
-
(now, content_hash)
|
|
77
|
+
|
|
78
|
+
with self._lock:
|
|
79
|
+
cursor = self._conn.execute(
|
|
80
|
+
"SELECT embedding FROM embedding_cache WHERE content_hash = ?",
|
|
81
|
+
(content_hash,)
|
|
81
82
|
)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
83
|
+
row = cursor.fetchone()
|
|
84
|
+
|
|
85
|
+
if row is not None:
|
|
86
|
+
# Update last_accessed
|
|
87
|
+
now = datetime.now(timezone.utc).isoformat()
|
|
88
|
+
self._conn.execute(
|
|
89
|
+
"UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
|
|
90
|
+
(now, content_hash)
|
|
91
|
+
)
|
|
92
|
+
self._conn.commit()
|
|
93
|
+
|
|
94
|
+
# Deserialize embedding
|
|
95
|
+
return json.loads(row[0])
|
|
96
|
+
|
|
87
97
|
return None
|
|
88
98
|
|
|
89
99
|
def put(
|
|
@@ -94,40 +104,42 @@ class EmbeddingCache:
|
|
|
94
104
|
) -> None:
|
|
95
105
|
"""
|
|
96
106
|
Cache an embedding.
|
|
97
|
-
|
|
107
|
+
|
|
98
108
|
Evicts oldest entries if cache exceeds max_entries.
|
|
99
109
|
"""
|
|
100
110
|
content_hash = self._hash_key(model_name, content)
|
|
101
111
|
now = datetime.now(timezone.utc).isoformat()
|
|
102
112
|
embedding_blob = json.dumps(embedding)
|
|
103
|
-
|
|
104
|
-
self.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
+
|
|
114
|
+
with self._lock:
|
|
115
|
+
self._conn.execute("""
|
|
116
|
+
INSERT OR REPLACE INTO embedding_cache
|
|
117
|
+
(content_hash, model_name, embedding, dimension, created_at, last_accessed)
|
|
118
|
+
VALUES (?, ?, ?, ?, ?, ?)
|
|
119
|
+
""", (content_hash, model_name, embedding_blob, len(embedding), now, now))
|
|
120
|
+
self._conn.commit()
|
|
121
|
+
|
|
122
|
+
# Evict old entries if needed
|
|
123
|
+
self._maybe_evict()
|
|
113
124
|
|
|
114
125
|
def _maybe_evict(self) -> None:
|
|
115
126
|
"""Evict oldest entries if cache exceeds max size."""
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
127
|
+
with self._lock:
|
|
128
|
+
cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
|
|
129
|
+
count = cursor.fetchone()[0]
|
|
130
|
+
|
|
131
|
+
if count > self._max_entries:
|
|
132
|
+
# Delete oldest 10% by last_accessed
|
|
133
|
+
evict_count = max(1, count // 10)
|
|
134
|
+
self._conn.execute("""
|
|
135
|
+
DELETE FROM embedding_cache
|
|
136
|
+
WHERE content_hash IN (
|
|
137
|
+
SELECT content_hash FROM embedding_cache
|
|
138
|
+
ORDER BY last_accessed ASC
|
|
139
|
+
LIMIT ?
|
|
140
|
+
)
|
|
141
|
+
""", (evict_count,))
|
|
142
|
+
self._conn.commit()
|
|
131
143
|
|
|
132
144
|
def stats(self) -> dict:
|
|
133
145
|
"""Get cache statistics."""
|
|
@@ -145,8 +157,9 @@ class EmbeddingCache:
|
|
|
145
157
|
|
|
146
158
|
def clear(self) -> None:
|
|
147
159
|
"""Clear all cached embeddings."""
|
|
148
|
-
self.
|
|
149
|
-
|
|
160
|
+
with self._lock:
|
|
161
|
+
self._conn.execute("DELETE FROM embedding_cache")
|
|
162
|
+
self._conn.commit()
|
|
150
163
|
|
|
151
164
|
def close(self) -> None:
|
|
152
165
|
"""Close the database connection."""
|
keep/providers/embeddings.py
CHANGED
|
@@ -11,12 +11,12 @@ from .base import EmbeddingProvider, get_registry
|
|
|
11
11
|
class SentenceTransformerEmbedding:
|
|
12
12
|
"""
|
|
13
13
|
Embedding provider using sentence-transformers library.
|
|
14
|
-
|
|
14
|
+
|
|
15
15
|
Runs locally, no API key required. Good default for getting started.
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
Requires: pip install sentence-transformers
|
|
18
18
|
"""
|
|
19
|
-
|
|
19
|
+
|
|
20
20
|
def __init__(self, model: str = "all-MiniLM-L6-v2"):
|
|
21
21
|
"""
|
|
22
22
|
Args:
|
|
@@ -29,9 +29,21 @@ class SentenceTransformerEmbedding:
|
|
|
29
29
|
"SentenceTransformerEmbedding requires 'sentence-transformers' library. "
|
|
30
30
|
"Install with: pip install sentence-transformers"
|
|
31
31
|
)
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
self.model_name = model
|
|
34
|
-
|
|
34
|
+
|
|
35
|
+
# Check if model is already cached locally to avoid network calls
|
|
36
|
+
# Expand short model names (e.g. "all-MiniLM-L6-v2" -> "sentence-transformers/all-MiniLM-L6-v2")
|
|
37
|
+
local_only = False
|
|
38
|
+
try:
|
|
39
|
+
from huggingface_hub import try_to_load_from_cache
|
|
40
|
+
repo_id = model if "/" in model else f"sentence-transformers/{model}"
|
|
41
|
+
cached = try_to_load_from_cache(repo_id, "config.json")
|
|
42
|
+
local_only = cached is not None
|
|
43
|
+
except ImportError:
|
|
44
|
+
pass
|
|
45
|
+
|
|
46
|
+
self._model = SentenceTransformer(model, local_files_only=local_only)
|
|
35
47
|
|
|
36
48
|
@property
|
|
37
49
|
def dimension(self) -> int:
|
|
@@ -83,8 +95,9 @@ class OpenAIEmbedding:
|
|
|
83
95
|
)
|
|
84
96
|
|
|
85
97
|
self.model_name = model
|
|
86
|
-
|
|
87
|
-
|
|
98
|
+
# Use lookup table if available, otherwise detect lazily from first embedding
|
|
99
|
+
self._dimension = self.MODEL_DIMENSIONS.get(model)
|
|
100
|
+
|
|
88
101
|
# Resolve API key
|
|
89
102
|
key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
|
|
90
103
|
if not key:
|
|
@@ -96,16 +109,24 @@ class OpenAIEmbedding:
|
|
|
96
109
|
|
|
97
110
|
@property
|
|
98
111
|
def dimension(self) -> int:
|
|
99
|
-
"""Get embedding dimension for the model."""
|
|
112
|
+
"""Get embedding dimension for the model (detected lazily if unknown)."""
|
|
113
|
+
if self._dimension is None:
|
|
114
|
+
# Unknown model: detect from first embedding
|
|
115
|
+
test_embedding = self.embed("dimension test")
|
|
116
|
+
self._dimension = len(test_embedding)
|
|
100
117
|
return self._dimension
|
|
101
|
-
|
|
118
|
+
|
|
102
119
|
def embed(self, text: str) -> list[float]:
|
|
103
120
|
"""Generate embedding for a single text."""
|
|
104
121
|
response = self._client.embeddings.create(
|
|
105
122
|
model=self.model_name,
|
|
106
123
|
input=text,
|
|
107
124
|
)
|
|
108
|
-
|
|
125
|
+
embedding = response.data[0].embedding
|
|
126
|
+
# Cache dimension if not yet known
|
|
127
|
+
if self._dimension is None:
|
|
128
|
+
self._dimension = len(embedding)
|
|
129
|
+
return embedding
|
|
109
130
|
|
|
110
131
|
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
|
111
132
|
"""Generate embeddings for multiple texts."""
|
|
@@ -152,7 +173,8 @@ class GeminiEmbedding:
|
|
|
152
173
|
)
|
|
153
174
|
|
|
154
175
|
self.model_name = model
|
|
155
|
-
|
|
176
|
+
# Use lookup table if available, otherwise detect lazily from first embedding
|
|
177
|
+
self._dimension = self.MODEL_DIMENSIONS.get(model)
|
|
156
178
|
|
|
157
179
|
# Resolve API key
|
|
158
180
|
key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
|
|
@@ -165,7 +187,11 @@ class GeminiEmbedding:
|
|
|
165
187
|
|
|
166
188
|
@property
|
|
167
189
|
def dimension(self) -> int:
|
|
168
|
-
"""Get embedding dimension for the model."""
|
|
190
|
+
"""Get embedding dimension for the model (detected lazily if unknown)."""
|
|
191
|
+
if self._dimension is None:
|
|
192
|
+
# Unknown model: detect from first embedding
|
|
193
|
+
test_embedding = self.embed("dimension test")
|
|
194
|
+
self._dimension = len(test_embedding)
|
|
169
195
|
return self._dimension
|
|
170
196
|
|
|
171
197
|
def embed(self, text: str) -> list[float]:
|
|
@@ -174,7 +200,11 @@ class GeminiEmbedding:
|
|
|
174
200
|
model=self.model_name,
|
|
175
201
|
contents=text,
|
|
176
202
|
)
|
|
177
|
-
|
|
203
|
+
embedding = list(result.embeddings[0].values)
|
|
204
|
+
# Cache dimension if not yet known
|
|
205
|
+
if self._dimension is None:
|
|
206
|
+
self._dimension = len(embedding)
|
|
207
|
+
return embedding
|
|
178
208
|
|
|
179
209
|
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
|
180
210
|
"""Generate embeddings for multiple texts."""
|