keep-skill 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl
- keep/__init__.py +3 -6
- keep/api.py +793 -141
- keep/cli.py +467 -129
- keep/config.py +172 -41
- keep/context.py +1 -125
- keep/document_store.py +569 -0
- keep/errors.py +33 -0
- keep/indexing.py +1 -1
- keep/logging_config.py +34 -3
- keep/paths.py +81 -17
- keep/pending_summaries.py +46 -40
- keep/providers/embedding_cache.py +53 -46
- keep/providers/embeddings.py +43 -13
- keep/providers/mlx.py +23 -21
- keep/store.py +58 -14
- {keep_skill-0.1.0.dist-info → keep_skill-0.2.0.dist-info}/METADATA +29 -15
- keep_skill-0.2.0.dist-info/RECORD +28 -0
- keep_skill-0.1.0.dist-info/RECORD +0 -26
- {keep_skill-0.1.0.dist-info → keep_skill-0.2.0.dist-info}/WHEEL +0 -0
- {keep_skill-0.1.0.dist-info → keep_skill-0.2.0.dist-info}/entry_points.txt +0 -0
- {keep_skill-0.1.0.dist-info → keep_skill-0.2.0.dist-info}/licenses/LICENSE +0 -0
keep/paths.py
CHANGED

@@ -1,11 +1,29 @@
 """
 Utility functions for locating paths.
+
+Config and store discovery follows this priority:
+
+Config discovery:
+1. KEEP_CONFIG envvar (path to .keep/ directory)
+2. Tree-walk from cwd up to ~, find first .keep/keep.toml
+3. ~/.keep/ default
+
+Store resolution:
+1. --store CLI option (passed to Keeper)
+2. KEEP_STORE_PATH envvar
+3. store.path in config file
+4. Config directory itself (backwards compat)
 """
 
+from __future__ import annotations
+
 import os
 import warnings
 from pathlib import Path
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .config import StoreConfig
 
 
 def find_git_root(start_path: Optional[Path] = None) -> Optional[Path]:
@@ -35,14 +53,67 @@ def find_git_root(start_path: Optional[Path] = None) -> Optional[Path]:
     return None
 
 
-def get_default_store_path() -> Path:
+def find_config_dir(start_path: Optional[Path] = None) -> Path:
+    """
+    Find config directory by tree-walking from start_path up to home.
+
+    Looks for .keep/keep.toml at each directory level, stopping at home.
+    Returns the .keep/ directory containing the config, or ~/.keep/ if none found.
+
+    Args:
+        start_path: Path to start searching from. Defaults to cwd.
+
+    Returns:
+        Path to the .keep/ config directory.
+    """
+    if start_path is None:
+        start_path = Path.cwd()
+
+    home = Path.home()
+    current = start_path.resolve()
+
+    while True:
+        candidate = current / ".keep" / "keep.toml"
+        if candidate.exists():
+            return current / ".keep"
+
+        # Stop at home or filesystem root
+        if current == home or current == current.parent:
+            break
+        current = current.parent
+
+    # Default: ~/.keep/
+    return home / ".keep"
+
+
+def get_config_dir() -> Path:
+    """
+    Get the config directory.
+
+    Priority:
+    1. KEEP_CONFIG environment variable
+    2. Tree-walk from cwd up to ~ (find_config_dir)
+
+    Returns:
+        Path to the .keep/ config directory.
+    """
+    env = os.environ.get("KEEP_CONFIG")
+    if env:
+        return Path(env).expanduser().resolve()
+    return find_config_dir()
+
+
+def get_default_store_path(config: Optional[StoreConfig] = None) -> Path:
     """
     Get the default store path.
 
     Priority:
     1. KEEP_STORE_PATH environment variable
-    2. .keep/ in git root
-    3. ~/.keep/ fallback
+    2. store.path setting in config file
+    3. Config directory itself (backwards compat)
+
+    Args:
+        config: Optional loaded config to check for store.path setting.
 
     Returns:
         Path to the store directory (may not exist yet).
@@ -51,17 +122,10 @@ def get_default_store_path() -> Path:
     env_path = os.environ.get("KEEP_STORE_PATH")
     if env_path:
         return Path(env_path).resolve()
-
-    # Try to find git root
-    git_root = find_git_root()
-    if git_root:
-        return git_root / ".keep"
 
-    # …
-    …
-    )
-    return home / ".keep"
+    # Check config for explicit store.path
+    if config and config.store_path:
+        return Path(config.store_path).expanduser().resolve()
+
+    # Default: config directory is also the store
+    return get_config_dir()
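Taken together, the new helpers make path resolution deterministic end to end. Below is a minimal sketch (not from the package) of how a caller would exercise the two priority chains; the directory layout in the comments is hypothetical:

```python
import os
from pathlib import Path

from keep.paths import find_config_dir, get_config_dir, get_default_store_path

# Hypothetical layout:
#   ~/projects/app/.keep/keep.toml   <- first keep.toml on the walk up from src/
#   ~/projects/app/src/              <- starting directory
config_dir = find_config_dir(Path.home() / "projects" / "app" / "src")
print(config_dir)  # ~/projects/app/.keep if the layout exists, else ~/.keep

# KEEP_CONFIG short-circuits the tree-walk entirely:
os.environ["KEEP_CONFIG"] = "/tmp/elsewhere/.keep"
print(get_config_dir())  # /tmp/elsewhere/.keep (expanded and resolved)

# Store resolution: KEEP_STORE_PATH beats store.path in keep.toml, which
# beats the config directory itself (the 0.1.x behaviour, kept for compat).
del os.environ["KEEP_CONFIG"]
os.environ.pop("KEEP_STORE_PATH", None)
print(get_default_store_path())  # no config passed -> the config dir itself
```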
keep/pending_summaries.py
CHANGED

@@ -6,6 +6,7 @@ This enables fast indexing with lazy summarization.
 """
 
 import sqlite3
+import threading
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from pathlib import Path
@@ -37,6 +38,7 @@ class PendingSummaryQueue:
         """
         self._queue_path = queue_path
         self._conn: Optional[sqlite3.Connection] = None
+        self._lock = threading.Lock()
         self._init_db()
 
     def _init_db(self) -> None:
@@ -66,12 +68,13 @@ class PendingSummaryQueue:
         If the same id+collection already exists, replaces it (newer content wins).
         """
         now = datetime.now(timezone.utc).isoformat()
-        self._conn.execute("""
-            INSERT OR REPLACE INTO pending_summaries
-            (id, collection, content, queued_at, attempts)
-            VALUES (?, ?, ?, ?, 0)
-        """, (id, collection, content, now))
-        self._conn.commit()
+        with self._lock:
+            self._conn.execute("""
+                INSERT OR REPLACE INTO pending_summaries
+                (id, collection, content, queued_at, attempts)
+                VALUES (?, ?, ?, ?, 0)
+            """, (id, collection, content, now))
+            self._conn.commit()
 
     def dequeue(self, limit: int = 10) -> list[PendingSummary]:
         """
@@ -80,42 +83,44 @@ class PendingSummaryQueue:
         Items are returned but not removed - call `complete()` after successful processing.
         Increments attempt counter on each dequeue.
         """
-        cursor = self._conn.execute("""
-            SELECT id, collection, content, queued_at, attempts
-            FROM pending_summaries
-            ORDER BY queued_at ASC
-            LIMIT ?
-        """, (limit,))
-
-        items = []
-        for row in cursor.fetchall():
-            items.append(PendingSummary(
-                id=row[0],
-                collection=row[1],
-                content=row[2],
-                queued_at=row[3],
-                attempts=row[4],
-            ))
-
-        # Increment attempt counters
-        if items:
-            ids = [(item.id, item.collection) for item in items]
-            self._conn.executemany("""
-                UPDATE pending_summaries
-                SET attempts = attempts + 1
-                WHERE id = ? AND collection = ?
-            """, ids)
-            self._conn.commit()
+        with self._lock:
+            cursor = self._conn.execute("""
+                SELECT id, collection, content, queued_at, attempts
+                FROM pending_summaries
+                ORDER BY queued_at ASC
+                LIMIT ?
+            """, (limit,))
+
+            items = []
+            for row in cursor.fetchall():
+                items.append(PendingSummary(
+                    id=row[0],
+                    collection=row[1],
+                    content=row[2],
+                    queued_at=row[3],
+                    attempts=row[4],
+                ))
+
+            # Increment attempt counters
+            if items:
+                ids = [(item.id, item.collection) for item in items]
+                self._conn.executemany("""
+                    UPDATE pending_summaries
+                    SET attempts = attempts + 1
+                    WHERE id = ? AND collection = ?
+                """, ids)
+                self._conn.commit()
 
         return items
 
     def complete(self, id: str, collection: str) -> None:
         """Remove an item from the queue after successful processing."""
-        self._conn.execute("""
-            DELETE FROM pending_summaries
-            WHERE id = ? AND collection = ?
-        """, (id, collection))
-        self._conn.commit()
+        with self._lock:
+            self._conn.execute("""
+                DELETE FROM pending_summaries
+                WHERE id = ? AND collection = ?
+            """, (id, collection))
+            self._conn.commit()
 
     def count(self) -> int:
         """Get count of pending items."""
@@ -143,9 +148,10 @@ class PendingSummaryQueue:
 
     def clear(self) -> int:
         """Clear all pending items. Returns count of items cleared."""
-        count = self.count()
-        self._conn.execute("DELETE FROM pending_summaries")
-        self._conn.commit()
+        with self._lock:
+            count = self.count()
+            self._conn.execute("DELETE FROM pending_summaries")
+            self._conn.commit()
        return count
 
     def close(self) -> None:
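Every statement now runs inside one `threading.Lock`, which is what lets a single SQLite connection be shared by the indexer and the background summarizer. A minimal sketch of the same pattern, assuming (as the lock implies) a connection opened with `check_same_thread=False`; the class and table here are illustrative, not the package's:

```python
import sqlite3
import threading

class LockedTable:
    """Toy version of the lock-around-execute-plus-commit pattern above."""

    def __init__(self, path: str = ":memory:") -> None:
        # sqlite3 refuses cross-thread use of a connection unless
        # check_same_thread=False; the lock then serializes all access.
        self._conn = sqlite3.connect(path, check_same_thread=False)
        self._lock = threading.Lock()
        with self._lock:
            self._conn.execute("CREATE TABLE IF NOT EXISTS items (id TEXT PRIMARY KEY)")
            self._conn.commit()

    def put(self, id: str) -> None:
        with self._lock:  # execute + commit form one critical section
            self._conn.execute("INSERT OR REPLACE INTO items (id) VALUES (?)", (id,))
            self._conn.commit()

table = LockedTable()
workers = [threading.Thread(target=table.put, args=(str(i),)) for i in range(8)]
for w in workers:
    w.start()
for w in workers:
    w.join()
```

Note the queue's plain `Lock` means its locked methods must not call one another; `EmbeddingCache` below picks `RLock` for exactly that reason.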
keep/providers/embedding_cache.py
CHANGED

@@ -8,6 +8,7 @@ avoiding redundant embedding calls for unchanged content.
 import hashlib
 import json
 import sqlite3
+import threading
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Optional
@@ -32,6 +33,7 @@ class EmbeddingCache:
         self._cache_path = cache_path
         self._max_entries = max_entries
         self._conn: Optional[sqlite3.Connection] = None
+        self._lock = threading.RLock()
         self._init_db()
 
     def _init_db(self) -> None:
@@ -62,28 +64,30 @@ class EmbeddingCache:
     def get(self, model_name: str, content: str) -> Optional[list[float]]:
         """
         Get cached embedding if it exists.
-
+
         Updates last_accessed timestamp on hit.
         """
         content_hash = self._hash_key(model_name, content)
-
-        cursor = self._conn.execute(
-            "SELECT embedding FROM embedding_cache WHERE content_hash = ?",
-            (content_hash,)
-        )
-        row = cursor.fetchone()
-        if row is not None:
-            # Update last_accessed
-            now = datetime.now(timezone.utc).isoformat()
-            self._conn.execute(
-                "UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
-                (now, content_hash)
+
+        with self._lock:
+            cursor = self._conn.execute(
+                "SELECT embedding FROM embedding_cache WHERE content_hash = ?",
+                (content_hash,)
             )
-            self._conn.commit()
-
-            # Deserialize embedding
-            return json.loads(row[0])
-
+            row = cursor.fetchone()
+
+            if row is not None:
+                # Update last_accessed
+                now = datetime.now(timezone.utc).isoformat()
+                self._conn.execute(
+                    "UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
+                    (now, content_hash)
+                )
+                self._conn.commit()
+
+                # Deserialize embedding
+                return json.loads(row[0])
+
         return None
 
     def put(
@@ -94,40 +98,42 @@ class EmbeddingCache:
     ) -> None:
         """
         Cache an embedding.
-
+
         Evicts oldest entries if cache exceeds max_entries.
         """
         content_hash = self._hash_key(model_name, content)
         now = datetime.now(timezone.utc).isoformat()
         embedding_blob = json.dumps(embedding)
-
-        self._conn.execute("""
-            INSERT OR REPLACE INTO embedding_cache
-            (content_hash, model_name, embedding, dimension, created_at, last_accessed)
-            VALUES (?, ?, ?, ?, ?, ?)
-        """, (content_hash, model_name, embedding_blob, len(embedding), now, now))
-        self._conn.commit()
-
-        # Evict old entries if needed
-        self._maybe_evict()
+
+        with self._lock:
+            self._conn.execute("""
+                INSERT OR REPLACE INTO embedding_cache
+                (content_hash, model_name, embedding, dimension, created_at, last_accessed)
+                VALUES (?, ?, ?, ?, ?, ?)
+            """, (content_hash, model_name, embedding_blob, len(embedding), now, now))
+            self._conn.commit()
+
+            # Evict old entries if needed
+            self._maybe_evict()
 
     def _maybe_evict(self) -> None:
         """Evict oldest entries if cache exceeds max size."""
-        cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
-        count = cursor.fetchone()[0]
-
-        if count > self._max_entries:
-            # Delete oldest 10% by last_accessed
-            evict_count = max(1, count // 10)
-            self._conn.execute("""
-                DELETE FROM embedding_cache
-                WHERE content_hash IN (
-                    SELECT content_hash FROM embedding_cache
-                    ORDER BY last_accessed ASC
-                    LIMIT ?
-                )
-            """, (evict_count,))
-            self._conn.commit()
+        with self._lock:
+            cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
+            count = cursor.fetchone()[0]
+
+            if count > self._max_entries:
+                # Delete oldest 10% by last_accessed
+                evict_count = max(1, count // 10)
+                self._conn.execute("""
+                    DELETE FROM embedding_cache
+                    WHERE content_hash IN (
+                        SELECT content_hash FROM embedding_cache
+                        ORDER BY last_accessed ASC
+                        LIMIT ?
+                    )
+                """, (evict_count,))
+                self._conn.commit()
 
     def stats(self) -> dict:
         """Get cache statistics."""
@@ -145,8 +151,9 @@ class EmbeddingCache:
 
     def clear(self) -> None:
         """Clear all cached embeddings."""
-        self._conn.execute("DELETE FROM embedding_cache")
-        self._conn.commit()
+        with self._lock:
+            self._conn.execute("DELETE FROM embedding_cache")
+            self._conn.commit()
 
     def close(self) -> None:
         """Close the database connection."""
keep/providers/embeddings.py
CHANGED

@@ -11,12 +11,12 @@ from .base import EmbeddingProvider, get_registry
 class SentenceTransformerEmbedding:
     """
     Embedding provider using sentence-transformers library.
-
+
     Runs locally, no API key required. Good default for getting started.
-
+
     Requires: pip install sentence-transformers
     """
-
+
     def __init__(self, model: str = "all-MiniLM-L6-v2"):
         """
         Args:
@@ -29,9 +29,21 @@ class SentenceTransformerEmbedding:
                 "SentenceTransformerEmbedding requires 'sentence-transformers' library. "
                 "Install with: pip install sentence-transformers"
             )
-
+
         self.model_name = model
-        self._model = SentenceTransformer(model)
+
+        # Check if model is already cached locally to avoid network calls
+        # Expand short model names (e.g. "all-MiniLM-L6-v2" -> "sentence-transformers/all-MiniLM-L6-v2")
+        local_only = False
+        try:
+            from huggingface_hub import try_to_load_from_cache
+            repo_id = model if "/" in model else f"sentence-transformers/{model}"
+            cached = try_to_load_from_cache(repo_id, "config.json")
+            local_only = cached is not None
+        except ImportError:
+            pass
+
+        self._model = SentenceTransformer(model, local_files_only=local_only)
 
     @property
     def dimension(self) -> int:
@@ -83,8 +95,9 @@ class OpenAIEmbedding:
             )
 
         self.model_name = model
-        self._dimension = self.MODEL_DIMENSIONS[model]
-
+        # Use lookup table if available, otherwise detect lazily from first embedding
+        self._dimension = self.MODEL_DIMENSIONS.get(model)
+
         # Resolve API key
         key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
         if not key:
@@ -96,16 +109,24 @@ class OpenAIEmbedding:
 
     @property
     def dimension(self) -> int:
-        """Get embedding dimension for the model."""
+        """Get embedding dimension for the model (detected lazily if unknown)."""
+        if self._dimension is None:
+            # Unknown model: detect from first embedding
+            test_embedding = self.embed("dimension test")
+            self._dimension = len(test_embedding)
         return self._dimension
-
+
     def embed(self, text: str) -> list[float]:
         """Generate embedding for a single text."""
         response = self._client.embeddings.create(
             model=self.model_name,
             input=text,
         )
-        return response.data[0].embedding
+        embedding = response.data[0].embedding
+        # Cache dimension if not yet known
+        if self._dimension is None:
+            self._dimension = len(embedding)
+        return embedding
 
     def embed_batch(self, texts: list[str]) -> list[list[float]]:
         """Generate embeddings for multiple texts."""
@@ -152,7 +173,8 @@ class GeminiEmbedding:
             )
 
         self.model_name = model
-        self._dimension = self.MODEL_DIMENSIONS[model]
+        # Use lookup table if available, otherwise detect lazily from first embedding
+        self._dimension = self.MODEL_DIMENSIONS.get(model)
 
         # Resolve API key
         key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
@@ -165,7 +187,11 @@ class GeminiEmbedding:
 
     @property
     def dimension(self) -> int:
-        """Get embedding dimension for the model."""
+        """Get embedding dimension for the model (detected lazily if unknown)."""
+        if self._dimension is None:
+            # Unknown model: detect from first embedding
+            test_embedding = self.embed("dimension test")
+            self._dimension = len(test_embedding)
         return self._dimension
 
     def embed(self, text: str) -> list[float]:
@@ -174,7 +200,11 @@ class GeminiEmbedding:
             model=self.model_name,
             contents=text,
         )
-        return list(result.embeddings[0].values)
+        embedding = list(result.embeddings[0].values)
+        # Cache dimension if not yet known
+        if self._dimension is None:
+            self._dimension = len(embedding)
+        return embedding
 
     def embed_batch(self, texts: list[str]) -> list[list[float]]:
         """Generate embeddings for multiple texts."""
keep/providers/mlx.py
CHANGED

@@ -15,21 +15,18 @@ from .base import EmbeddingProvider, SummarizationProvider, get_registry
 
 class MLXEmbedding:
     """
-    Embedding provider using …
-
-    Uses sentence-transformer …
-
+    Embedding provider using MPS (Metal) acceleration on Apple Silicon.
+
+    Uses sentence-transformer models with GPU acceleration via Metal Performance Shaders.
+
     Requires: pip install mlx sentence-transformers
     """
-
-    def __init__(self, model: str = "…
+
+    def __init__(self, model: str = "all-MiniLM-L6-v2"):
         """
         Args:
-            model: Model name from …
-
-                - mlx-community/bge-small-en-v1.5 (small, fast)
-                - mlx-community/bge-base-en-v1.5 (balanced)
-                - mlx-community/bge-large-en-v1.5 (best quality)
+            model: Model name from sentence-transformers hub.
+                Default: all-MiniLM-L6-v2 (384 dims, fast, no auth required)
         """
         try:
             import mlx.core as mx
@@ -39,17 +36,22 @@ class MLXEmbedding:
                 "MLXEmbedding requires 'mlx' and 'sentence-transformers'. "
                 "Install with: pip install mlx sentence-transformers"
             )
-
+
         self.model_name = model
-
-        # …
-        …
+
+        # Check if model is already cached locally to avoid network calls
+        local_only = False
+        try:
+            from huggingface_hub import try_to_load_from_cache
+            repo_id = model if "/" in model else f"sentence-transformers/{model}"
+            cached = try_to_load_from_cache(repo_id, "config.json")
+            local_only = cached is not None
+        except ImportError:
+            pass
+
+        # Use MPS (Metal) for GPU acceleration on Apple Silicon
+        self._model = SentenceTransformer(model, device="mps", local_files_only=local_only)
+
         self._dimension: int | None = None
 
     @property