keep_skill-0.1.0-py3-none-any.whl → keep_skill-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keep/paths.py CHANGED
@@ -1,11 +1,29 @@
 """
 Utility functions for locating paths.
+
+Config and store discovery follows this priority:
+
+Config discovery:
+1. KEEP_CONFIG envvar (path to .keep/ directory)
+2. Tree-walk from cwd up to ~, find first .keep/keep.toml
+3. ~/.keep/ default
+
+Store resolution:
+1. --store CLI option (passed to Keeper)
+2. KEEP_STORE_PATH envvar
+3. store.path in config file
+4. Config directory itself (backwards compat)
 """
 
+from __future__ import annotations
+
 import os
 import warnings
 from pathlib import Path
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .config import StoreConfig
 
 
 def find_git_root(start_path: Optional[Path] = None) -> Optional[Path]:
@@ -35,14 +53,67 @@ def find_git_root(start_path: Optional[Path] = None) -> Optional[Path]:
     return None
 
 
-def get_default_store_path() -> Path:
+def find_config_dir(start_path: Optional[Path] = None) -> Path:
+    """
+    Find config directory by tree-walking from start_path up to home.
+
+    Looks for .keep/keep.toml at each directory level, stopping at home.
+    Returns the .keep/ directory containing the config, or ~/.keep/ if none found.
+
+    Args:
+        start_path: Path to start searching from. Defaults to cwd.
+
+    Returns:
+        Path to the .keep/ config directory.
+    """
+    if start_path is None:
+        start_path = Path.cwd()
+
+    home = Path.home()
+    current = start_path.resolve()
+
+    while True:
+        candidate = current / ".keep" / "keep.toml"
+        if candidate.exists():
+            return current / ".keep"
+
+        # Stop at home or filesystem root
+        if current == home or current == current.parent:
+            break
+        current = current.parent
+
+    # Default: ~/.keep/
+    return home / ".keep"
+
+
+def get_config_dir() -> Path:
+    """
+    Get the config directory.
+
+    Priority:
+    1. KEEP_CONFIG environment variable
+    2. Tree-walk from cwd up to ~ (find_config_dir)
+
+    Returns:
+        Path to the .keep/ config directory.
+    """
+    env = os.environ.get("KEEP_CONFIG")
+    if env:
+        return Path(env).expanduser().resolve()
+    return find_config_dir()
+
+
+def get_default_store_path(config: Optional[StoreConfig] = None) -> Path:
     """
     Get the default store path.
 
     Priority:
     1. KEEP_STORE_PATH environment variable
-    2. .keep/ directory at git repository root
-    3. ~/.keep/ in user's home directory (if not in a repo)
+    2. store.path setting in config file
+    3. Config directory itself (backwards compat)
+
+    Args:
+        config: Optional loaded config to check for store.path setting.
 
     Returns:
         Path to the store directory (may not exist yet).
@@ -51,17 +122,10 @@ def get_default_store_path() -> Path:
     env_path = os.environ.get("KEEP_STORE_PATH")
     if env_path:
         return Path(env_path).resolve()
-
-    # Try to find git root
-    git_root = find_git_root()
-    if git_root:
-        return git_root / ".keep"
 
-    # Fall back to home directory with warning
-    home = Path.home()
-    warnings.warn(
-        f"Not in a git repository. Using {home / '.keep'} for storage. "
-        f"Set KEEP_STORE_PATH to specify a different location.",
-        stacklevel=2,
-    )
-    return home / ".keep"
+    # Check config for explicit store.path
+    if config and config.store_path:
+        return Path(config.store_path).expanduser().resolve()
+
+    # Default: config directory is also the store
+    return get_config_dir()
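
The net effect of these changes can be exercised directly; a minimal sketch, assuming the functions live at keep.paths as shown (the /tmp path is illustrative):

    import os
    from pathlib import Path

    from keep.paths import get_default_store_path

    # 1. KEEP_STORE_PATH wins outright, regardless of any config.
    os.environ["KEEP_STORE_PATH"] = "/tmp/keep-store"
    assert get_default_store_path() == Path("/tmp/keep-store").resolve()

    # 2. Without the envvar, an explicit store.path on the loaded config is used:
    #    get_default_store_path(config) -> Path(config.store_path).expanduser().resolve()
    del os.environ["KEEP_STORE_PATH"]

    # 3. Otherwise the config directory itself doubles as the store, which
    #    preserves the pre-0.2.0 layout where .keep/ held everything.
    print(get_default_store_path())  # e.g. ~/.keep on a machine with no project config

Note that the store path no longer consults git-root discovery: find_git_root remains in the module, but resolution is now anchored to a .keep/keep.toml found by walking up from the working directory.
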
keep/pending_summaries.py CHANGED
@@ -6,6 +6,7 @@ This enables fast indexing with lazy summarization.
 """
 
 import sqlite3
+import threading
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from pathlib import Path
@@ -37,6 +38,7 @@ class PendingSummaryQueue:
         """
         self._queue_path = queue_path
         self._conn: Optional[sqlite3.Connection] = None
+        self._lock = threading.Lock()
         self._init_db()
 
     def _init_db(self) -> None:
@@ -66,12 +68,13 @@
         If the same id+collection already exists, replaces it (newer content wins).
         """
         now = datetime.now(timezone.utc).isoformat()
-        self._conn.execute("""
-            INSERT OR REPLACE INTO pending_summaries
-            (id, collection, content, queued_at, attempts)
-            VALUES (?, ?, ?, ?, 0)
-        """, (id, collection, content, now))
-        self._conn.commit()
+        with self._lock:
+            self._conn.execute("""
+                INSERT OR REPLACE INTO pending_summaries
+                (id, collection, content, queued_at, attempts)
+                VALUES (?, ?, ?, ?, 0)
+            """, (id, collection, content, now))
+            self._conn.commit()
 
     def dequeue(self, limit: int = 10) -> list[PendingSummary]:
         """
@@ -80,42 +83,44 @@
         Items are returned but not removed - call `complete()` after successful processing.
         Increments attempt counter on each dequeue.
         """
-        cursor = self._conn.execute("""
-            SELECT id, collection, content, queued_at, attempts
-            FROM pending_summaries
-            ORDER BY queued_at ASC
-            LIMIT ?
-        """, (limit,))
-
-        items = []
-        for row in cursor.fetchall():
-            items.append(PendingSummary(
-                id=row[0],
-                collection=row[1],
-                content=row[2],
-                queued_at=row[3],
-                attempts=row[4],
-            ))
-
-        # Increment attempt counters
-        if items:
-            ids = [(item.id, item.collection) for item in items]
-            self._conn.executemany("""
-                UPDATE pending_summaries
-                SET attempts = attempts + 1
-                WHERE id = ? AND collection = ?
-            """, ids)
-            self._conn.commit()
+        with self._lock:
+            cursor = self._conn.execute("""
+                SELECT id, collection, content, queued_at, attempts
+                FROM pending_summaries
+                ORDER BY queued_at ASC
+                LIMIT ?
+            """, (limit,))
+
+            items = []
+            for row in cursor.fetchall():
+                items.append(PendingSummary(
+                    id=row[0],
+                    collection=row[1],
+                    content=row[2],
+                    queued_at=row[3],
+                    attempts=row[4],
+                ))
+
+            # Increment attempt counters
+            if items:
+                ids = [(item.id, item.collection) for item in items]
+                self._conn.executemany("""
+                    UPDATE pending_summaries
+                    SET attempts = attempts + 1
+                    WHERE id = ? AND collection = ?
+                """, ids)
+                self._conn.commit()
 
         return items
 
     def complete(self, id: str, collection: str) -> None:
         """Remove an item from the queue after successful processing."""
-        self._conn.execute("""
-            DELETE FROM pending_summaries
-            WHERE id = ? AND collection = ?
-        """, (id, collection))
-        self._conn.commit()
+        with self._lock:
+            self._conn.execute("""
+                DELETE FROM pending_summaries
+                WHERE id = ? AND collection = ?
+            """, (id, collection))
+            self._conn.commit()
 
     def count(self) -> int:
         """Get count of pending items."""
@@ -143,9 +148,10 @@
 
     def clear(self) -> int:
        """Clear all pending items. Returns count of items cleared."""
-        count = self.count()
-        self._conn.execute("DELETE FROM pending_summaries")
-        self._conn.commit()
+        with self._lock:
+            count = self.count()
+            self._conn.execute("DELETE FROM pending_summaries")
+            self._conn.commit()
        return count
 
    def close(self) -> None:
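
All mutating queue operations now serialize on a single threading.Lock. The shape of the pattern, reduced to a standalone sketch (class name and schema hypothetical, and assuming the shared connection is opened with check_same_thread=False, which this diff does not show):

    import sqlite3
    import threading

    class LockedTable:
        """One shared connection; every execute+commit pair guarded by one lock."""

        def __init__(self, path: str) -> None:
            # check_same_thread=False allows cross-thread use of the connection;
            # the lock is what actually makes that use safe.
            self._conn = sqlite3.connect(path, check_same_thread=False)
            self._lock = threading.Lock()
            self._conn.execute("CREATE TABLE IF NOT EXISTS items (id TEXT PRIMARY KEY)")
            self._conn.commit()

        def add(self, item_id: str) -> None:
            # Holding the lock across execute and commit keeps the pair atomic
            # with respect to other threads sharing this connection.
            with self._lock:
                self._conn.execute("INSERT OR REPLACE INTO items (id) VALUES (?)", (item_id,))
                self._conn.commit()

    table = LockedTable(":memory:")
    threads = [threading.Thread(target=table.add, args=(f"item-{i}",)) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

One caveat visible in the diff: clear() calls self.count() while already holding the lock, which is safe only because count() is left unlocked; if count() ever acquired the same non-reentrant Lock, clear() would deadlock. (The hunks that follow belong to the package's embedding-cache module, class EmbeddingCache; its file header did not survive in this diff.)
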
@@ -8,6 +8,7 @@ avoiding redundant embedding calls for unchanged content.
 import hashlib
 import json
 import sqlite3
+import threading
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Optional
@@ -32,6 +33,7 @@ class EmbeddingCache:
         self._cache_path = cache_path
         self._max_entries = max_entries
         self._conn: Optional[sqlite3.Connection] = None
+        self._lock = threading.RLock()
         self._init_db()
 
     def _init_db(self) -> None:
@@ -62,28 +64,30 @@
     def get(self, model_name: str, content: str) -> Optional[list[float]]:
         """
         Get cached embedding if it exists.
-
+
         Updates last_accessed timestamp on hit.
         """
         content_hash = self._hash_key(model_name, content)
-        cursor = self._conn.execute(
-            "SELECT embedding FROM embedding_cache WHERE content_hash = ?",
-            (content_hash,)
-        )
-        row = cursor.fetchone()
-
-        if row is not None:
-            # Update last_accessed
-            now = datetime.now(timezone.utc).isoformat()
-            self._conn.execute(
-                "UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
-                (now, content_hash)
-            )
-            self._conn.commit()
-
-            # Deserialize embedding
-            return json.loads(row[0])
-
+
+        with self._lock:
+            cursor = self._conn.execute(
+                "SELECT embedding FROM embedding_cache WHERE content_hash = ?",
+                (content_hash,)
+            )
+            row = cursor.fetchone()
+
+            if row is not None:
+                # Update last_accessed
+                now = datetime.now(timezone.utc).isoformat()
+                self._conn.execute(
+                    "UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
+                    (now, content_hash)
+                )
+                self._conn.commit()
+
+                # Deserialize embedding
+                return json.loads(row[0])
+
         return None
 
     def put(
@@ -94,40 +98,42 @@
     ) -> None:
         """
         Cache an embedding.
-
+
         Evicts oldest entries if cache exceeds max_entries.
         """
         content_hash = self._hash_key(model_name, content)
         now = datetime.now(timezone.utc).isoformat()
         embedding_blob = json.dumps(embedding)
-
-        self._conn.execute("""
-            INSERT OR REPLACE INTO embedding_cache
-            (content_hash, model_name, embedding, dimension, created_at, last_accessed)
-            VALUES (?, ?, ?, ?, ?, ?)
-        """, (content_hash, model_name, embedding_blob, len(embedding), now, now))
-        self._conn.commit()
-
-        # Evict old entries if needed
-        self._maybe_evict()
+
+        with self._lock:
+            self._conn.execute("""
+                INSERT OR REPLACE INTO embedding_cache
+                (content_hash, model_name, embedding, dimension, created_at, last_accessed)
+                VALUES (?, ?, ?, ?, ?, ?)
+            """, (content_hash, model_name, embedding_blob, len(embedding), now, now))
+            self._conn.commit()
+
+            # Evict old entries if needed
+            self._maybe_evict()
 
     def _maybe_evict(self) -> None:
         """Evict oldest entries if cache exceeds max size."""
-        cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
-        count = cursor.fetchone()[0]
-
-        if count > self._max_entries:
-            # Delete oldest 10% by last_accessed
-            evict_count = max(1, count // 10)
-            self._conn.execute("""
-                DELETE FROM embedding_cache
-                WHERE content_hash IN (
-                    SELECT content_hash FROM embedding_cache
-                    ORDER BY last_accessed ASC
-                    LIMIT ?
-                )
-            """, (evict_count,))
-            self._conn.commit()
+        with self._lock:
+            cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
+            count = cursor.fetchone()[0]
+
+            if count > self._max_entries:
+                # Delete oldest 10% by last_accessed
+                evict_count = max(1, count // 10)
+                self._conn.execute("""
+                    DELETE FROM embedding_cache
+                    WHERE content_hash IN (
+                        SELECT content_hash FROM embedding_cache
+                        ORDER BY last_accessed ASC
+                        LIMIT ?
+                    )
+                """, (evict_count,))
+                self._conn.commit()
 
     def stats(self) -> dict:
         """Get cache statistics."""
@@ -145,8 +151,9 @@
 
     def clear(self) -> None:
         """Clear all cached embeddings."""
-        self._conn.execute("DELETE FROM embedding_cache")
-        self._conn.commit()
+        with self._lock:
+            self._conn.execute("DELETE FROM embedding_cache")
+            self._conn.commit()
 
     def close(self) -> None:
         """Close the database connection."""
@@ -11,12 +11,12 @@ from .base import EmbeddingProvider, get_registry
 
 class SentenceTransformerEmbedding:
     """
     Embedding provider using sentence-transformers library.
-
+
     Runs locally, no API key required. Good default for getting started.
-
+
     Requires: pip install sentence-transformers
     """
-
+
     def __init__(self, model: str = "all-MiniLM-L6-v2"):
         """
         Args:
@@ -29,9 +29,21 @@ class SentenceTransformerEmbedding:
             "SentenceTransformerEmbedding requires 'sentence-transformers' library. "
             "Install with: pip install sentence-transformers"
         )
-
+
         self.model_name = model
-        self._model = SentenceTransformer(model)
+
+        # Check if model is already cached locally to avoid network calls
+        # Expand short model names (e.g. "all-MiniLM-L6-v2" -> "sentence-transformers/all-MiniLM-L6-v2")
+        local_only = False
+        try:
+            from huggingface_hub import try_to_load_from_cache
+            repo_id = model if "/" in model else f"sentence-transformers/{model}"
+            cached = try_to_load_from_cache(repo_id, "config.json")
+            local_only = cached is not None
+        except ImportError:
+            pass
+
+        self._model = SentenceTransformer(model, local_files_only=local_only)
 
     @property
     def dimension(self) -> int:
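
try_to_load_from_cache is a real huggingface_hub helper: it returns a filesystem path when the file is already in the local HF cache and None when it is not (a special sentinel marks a cached 404). The check can be tried standalone:

    from huggingface_hub import try_to_load_from_cache

    model = "all-MiniLM-L6-v2"
    # Short names resolve under the sentence-transformers org, as in the diff.
    repo_id = model if "/" in model else f"sentence-transformers/{model}"

    cached = try_to_load_from_cache(repo_id, "config.json")
    print("can load offline:", cached is not None)

One consequence worth knowing: because local_files_only=True is passed whenever config.json is cached, a model whose config is cached but whose weights are not would fail to load rather than fetch the missing files.
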
@@ -83,8 +95,9 @@ class OpenAIEmbedding:
         )
 
         self.model_name = model
-        self._dimension = self.MODEL_DIMENSIONS.get(model, 1536)
-
+        # Use lookup table if available, otherwise detect lazily from first embedding
+        self._dimension = self.MODEL_DIMENSIONS.get(model)
+
         # Resolve API key
         key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
         if not key:
@@ -96,16 +109,24 @@
 
     @property
     def dimension(self) -> int:
-        """Get embedding dimension for the model."""
+        """Get embedding dimension for the model (detected lazily if unknown)."""
+        if self._dimension is None:
+            # Unknown model: detect from first embedding
+            test_embedding = self.embed("dimension test")
+            self._dimension = len(test_embedding)
         return self._dimension
-
+
     def embed(self, text: str) -> list[float]:
         """Generate embedding for a single text."""
         response = self._client.embeddings.create(
             model=self.model_name,
             input=text,
         )
-        return response.data[0].embedding
+        embedding = response.data[0].embedding
+        # Cache dimension if not yet known
+        if self._dimension is None:
+            self._dimension = len(embedding)
+        return embedding
 
     def embed_batch(self, texts: list[str]) -> list[list[float]]:
@@ -152,7 +173,8 @@ class GeminiEmbedding:
         )
 
         self.model_name = model
-        self._dimension = self.MODEL_DIMENSIONS.get(model, 768)
+        # Use lookup table if available, otherwise detect lazily from first embedding
+        self._dimension = self.MODEL_DIMENSIONS.get(model)
 
         # Resolve API key
         key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
@@ -165,7 +187,11 @@
 
     @property
     def dimension(self) -> int:
-        """Get embedding dimension for the model."""
+        """Get embedding dimension for the model (detected lazily if unknown)."""
+        if self._dimension is None:
+            # Unknown model: detect from first embedding
+            test_embedding = self.embed("dimension test")
+            self._dimension = len(test_embedding)
         return self._dimension
 
     def embed(self, text: str) -> list[float]:
@@ -174,7 +200,11 @@
             model=self.model_name,
             contents=text,
         )
-        return list(result.embeddings[0].values)
+        embedding = list(result.embeddings[0].values)
+        # Cache dimension if not yet known
+        if self._dimension is None:
+            self._dimension = len(embedding)
+        return embedding
 
     def embed_batch(self, texts: list[str]) -> list[list[float]]:
         """Generate embeddings for multiple texts."""
keep/providers/mlx.py CHANGED
@@ -15,21 +15,18 @@ from .base import EmbeddingProvider, SummarizationProvider, get_registry
 
 class MLXEmbedding:
     """
-    Embedding provider using MLX on Apple Silicon.
-
-    Uses sentence-transformer compatible models converted to MLX format.
-
+    Embedding provider using MPS (Metal) acceleration on Apple Silicon.
+
+    Uses sentence-transformer models with GPU acceleration via Metal Performance Shaders.
+
     Requires: pip install mlx sentence-transformers
     """
-
-    def __init__(self, model: str = "mlx-community/bge-small-en-v1.5"):
+
+    def __init__(self, model: str = "all-MiniLM-L6-v2"):
         """
         Args:
-            model: Model name from mlx-community hub or local path.
-                Good options:
-                - mlx-community/bge-small-en-v1.5 (small, fast)
-                - mlx-community/bge-base-en-v1.5 (balanced)
-                - mlx-community/bge-large-en-v1.5 (best quality)
+            model: Model name from sentence-transformers hub.
+                Default: all-MiniLM-L6-v2 (384 dims, fast, no auth required)
         """
         try:
             import mlx.core as mx
@@ -39,17 +36,22 @@
             "MLXEmbedding requires 'mlx' and 'sentence-transformers'. "
             "Install with: pip install mlx sentence-transformers"
         )
-
+
         self.model_name = model
-
-        # sentence-transformers can use MLX backend on Apple Silicon
-        # For MLX-specific models, we use the direct approach
-        if model.startswith("mlx-community/"):
-            # Use sentence-transformers which auto-detects MLX
-            self._model = SentenceTransformer(model, device="mps")
-        else:
-            self._model = SentenceTransformer(model)
-
+
+        # Check if model is already cached locally to avoid network calls
+        local_only = False
+        try:
+            from huggingface_hub import try_to_load_from_cache
+            repo_id = model if "/" in model else f"sentence-transformers/{model}"
+            cached = try_to_load_from_cache(repo_id, "config.json")
+            local_only = cached is not None
+        except ImportError:
+            pass
+
+        # Use MPS (Metal) for GPU acceleration on Apple Silicon
+        self._model = SentenceTransformer(model, device="mps", local_files_only=local_only)
+
         self._dimension: int | None = None
 
     @property
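
Since device="mps" is now passed unconditionally, the constructor assumes Apple Silicon; on other hardware SentenceTransformer would fail to place the model. A defensive caller can gate on availability first (a sketch using PyTorch's standard check, not code from this package):

    import torch
    from sentence_transformers import SentenceTransformer

    # MPS is only present on Apple-Silicon builds of PyTorch.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
    print(model.encode("hello").shape)  # (384,) for all-MiniLM-L6-v2

Also worth noting: after this change the class no longer loads MLX-converted models at all; mlx.core is imported only as an availability check, and the actual inference is plain sentence-transformers on the Metal backend.
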