keep-skill 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keep/providers/llm.py ADDED
@@ -0,0 +1,371 @@
+ """
+ Summarization and tagging providers using LLMs.
+ """
+
+ import json
+ import os
+ from typing import Any
+
+ from .base import SummarizationProvider, TaggingProvider, get_registry
+
+
+ # -----------------------------------------------------------------------------
+ # Summarization Providers
+ # -----------------------------------------------------------------------------
+
+ class AnthropicSummarization:
+     """
+     Summarization provider using Anthropic's Claude API.
+
+     Requires: ANTHROPIC_API_KEY environment variable.
+     Reading the key from OpenClaw config is planned but not yet implemented.
+     """
+
+     SYSTEM_PROMPT = """You are a precise summarization assistant.
+ Create a concise summary of the provided document that captures:
+ - The main purpose or topic
+ - Key points or functionality
+ - Important details that would help someone decide if this document is relevant
+
+ Be factual and specific. Do not include phrases like "This document" - just state the content directly."""
+
+     def __init__(
+         self,
+         model: str = "claude-3-5-haiku-20241022",
+         api_key: str | None = None,
+         max_tokens: int = 200,
+     ):
+         try:
+             from anthropic import Anthropic
+         except ImportError:
+             raise RuntimeError("AnthropicSummarization requires 'anthropic' library")
+
+         self.model = model
+         self.max_tokens = max_tokens
+
+         # Try environment variable first, then OpenClaw config
+         key = api_key or os.environ.get("ANTHROPIC_API_KEY")
+         if not key:
+             # Try to read from OpenClaw config (OAuth tokens stored separately)
+             # For now, just require explicit API key
+             raise ValueError("ANTHROPIC_API_KEY environment variable required")
+
+         self.client = Anthropic(api_key=key)
+
+     def summarize(self, content: str, *, max_length: int = 500) -> str:
+         """Generate summary using Anthropic Claude."""
+         # Truncate very long content
+         truncated = content[:50000] if len(content) > 50000 else content
+
+         try:
+             response = self.client.messages.create(
+                 model=self.model,
+                 max_tokens=self.max_tokens,
+                 system=self.SYSTEM_PROMPT,
+                 messages=[
+                     {"role": "user", "content": truncated}
+                 ],
+             )
+
+             # Extract text from response
+             if response.content and len(response.content) > 0:
+                 return response.content[0].text
+             return truncated[:max_length]  # Fallback
+         except Exception:
+             # Fallback to truncation on error
+             return truncated[:max_length]
+
+
+ class OpenAISummarization:
+     """
+     Summarization provider using OpenAI's chat API.
+
+     Requires: KEEP_OPENAI_API_KEY or OPENAI_API_KEY environment variable.
+     """
+
+     SYSTEM_PROMPT = """You are a precise summarization assistant.
+ Create a concise summary of the provided document that captures:
+ - The main purpose or topic
+ - Key points or functionality
+ - Important details that would help someone decide if this document is relevant
+
+ Be factual and specific. Do not include phrases like "This document" - just state the content directly."""
+
+     def __init__(
+         self,
+         model: str = "gpt-4o-mini",
+         api_key: str | None = None,
+         max_tokens: int = 200,
+     ):
+         try:
+             from openai import OpenAI
+         except ImportError:
+             raise RuntimeError("OpenAISummarization requires 'openai' library")
+
+         self.model = model
+         self.max_tokens = max_tokens
+
+         key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
+         if not key:
+             raise ValueError("OpenAI API key required")
+
+         self._client = OpenAI(api_key=key)
+
+     def summarize(self, content: str, *, max_length: int = 500) -> str:
+         """Generate a summary using OpenAI."""
+         # Truncate very long content to avoid token limits
+         truncated = content[:50000] if len(content) > 50000 else content
+
+         response = self._client.chat.completions.create(
+             model=self.model,
+             messages=[
+                 {"role": "system", "content": self.SYSTEM_PROMPT},
+                 {"role": "user", "content": truncated},
+             ],
+             max_tokens=self.max_tokens,
+             temperature=0.3,
+         )
+
+         return (response.choices[0].message.content or "").strip()
+
+
+ class OllamaSummarization:
+     """
+     Summarization provider using Ollama's local API.
+     """
+
+     SYSTEM_PROMPT = OpenAISummarization.SYSTEM_PROMPT
+
+     def __init__(
+         self,
+         model: str = "llama3.2",
+         base_url: str = "http://localhost:11434",
+     ):
+         self.model = model
+         self.base_url = base_url.rstrip("/")
+
+     def summarize(self, content: str, *, max_length: int = 500) -> str:
+         """Generate a summary using Ollama."""
+         import requests
+
+         truncated = content[:50000] if len(content) > 50000 else content
+
+         response = requests.post(
+             f"{self.base_url}/api/chat",
+             json={
+                 "model": self.model,
+                 "messages": [
+                     {"role": "system", "content": self.SYSTEM_PROMPT},
+                     {"role": "user", "content": truncated},
+                 ],
+                 "stream": False,
+             },
+         )
+         response.raise_for_status()
+
+         return response.json()["message"]["content"].strip()
+
+
+ class PassthroughSummarization:
+     """
+     Summarization provider that returns the first N characters.
+
+     Useful for testing or when LLM summarization is not needed.
+     """
+
+     def __init__(self, max_chars: int = 500):
+         self.max_chars = max_chars
+
+     def summarize(self, content: str, *, max_length: int = 500) -> str:
+         """Return truncated content as summary."""
+         limit = min(self.max_chars, max_length)
+         if len(content) <= limit:
+             return content
+         return content[:limit].rsplit(" ", 1)[0] + "..."
+
+
+ # -----------------------------------------------------------------------------
+ # Tagging Providers
+ # -----------------------------------------------------------------------------
+
+ class AnthropicTagging:
+     """
+     Tagging provider using Anthropic's Claude API with JSON output.
+     """
+
+     SYSTEM_PROMPT = """Analyze the document and generate relevant tags as a JSON object.
+
+ Generate tags for these categories when applicable:
+ - content_type: The type of content (e.g., "documentation", "code", "article", "config")
+ - language: Programming language if code (e.g., "python", "javascript")
+ - domain: Subject domain (e.g., "authentication", "database", "api", "testing")
+ - framework: Framework or library if relevant (e.g., "react", "django", "fastapi")
+
+ Only include tags that clearly apply. Values should be lowercase.
+
+ Respond with a JSON object only, no explanation."""
+
+     def __init__(
+         self,
+         model: str = "claude-3-5-haiku-20241022",
+         api_key: str | None = None,
+     ):
+         try:
+             from anthropic import Anthropic
+         except ImportError:
+             raise RuntimeError("AnthropicTagging requires 'anthropic' library")
+
+         self.model = model
+
+         key = api_key or os.environ.get("ANTHROPIC_API_KEY")
+         if not key:
+             raise ValueError("ANTHROPIC_API_KEY environment variable required")
+
+         self._client = Anthropic(api_key=key)
+
+     def tag(self, content: str) -> dict[str, str]:
+         """Generate tags using Anthropic Claude."""
+         truncated = content[:20000] if len(content) > 20000 else content
+
+         try:
+             response = self._client.messages.create(
+                 model=self.model,
+                 max_tokens=200,
+                 temperature=0.2,
+                 system=self.SYSTEM_PROMPT,
+                 messages=[
+                     {"role": "user", "content": truncated}
+                 ],
+             )
+
+             # Parse JSON from response
+             if response.content and len(response.content) > 0:
+                 tags = json.loads(response.content[0].text)
+                 return {str(k): str(v) for k, v in tags.items()}
+             return {}
+         except Exception:  # also covers json.JSONDecodeError, which is an Exception subclass
+             return {}
+
+
+ class OpenAITagging:
+     """
+     Tagging provider using OpenAI's chat API with JSON output.
+     """
+
+     SYSTEM_PROMPT = """Analyze the document and generate relevant tags as a JSON object.
+
+ Generate tags for these categories when applicable:
+ - content_type: The type of content (e.g., "documentation", "code", "article", "config")
+ - language: Programming language if code (e.g., "python", "javascript")
+ - domain: Subject domain (e.g., "authentication", "database", "api", "testing")
+ - framework: Framework or library if relevant (e.g., "react", "django", "fastapi")
+
+ Only include tags that clearly apply. Values should be lowercase.
+
+ Respond with a JSON object only, no explanation."""
+
+     def __init__(
+         self,
+         model: str = "gpt-4o-mini",
+         api_key: str | None = None,
+     ):
+         try:
+             from openai import OpenAI
+         except ImportError:
+             raise RuntimeError("OpenAITagging requires 'openai' library")
+
+         self.model = model
+
+         key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
+         if not key:
+             raise ValueError("OpenAI API key required")
+
+         self._client = OpenAI(api_key=key)
+
+     def tag(self, content: str) -> dict[str, str]:
+         """Generate tags using OpenAI."""
+         truncated = content[:20000] if len(content) > 20000 else content
+
+         response = self._client.chat.completions.create(
+             model=self.model,
+             messages=[
+                 {"role": "system", "content": self.SYSTEM_PROMPT},
+                 {"role": "user", "content": truncated},
+             ],
+             response_format={"type": "json_object"},
+             max_tokens=200,
+             temperature=0.2,
+         )
+
+         try:
+             tags = json.loads(response.choices[0].message.content)
+             # Ensure all values are strings
+             return {str(k): str(v) for k, v in tags.items()}
+         except (json.JSONDecodeError, TypeError):  # TypeError if content is None
+             return {}
+
+
+ class OllamaTagging:
+     """
+     Tagging provider using Ollama's local API.
+     """
+
+     SYSTEM_PROMPT = OpenAITagging.SYSTEM_PROMPT
+
+     def __init__(
+         self,
+         model: str = "llama3.2",
+         base_url: str = "http://localhost:11434",
+     ):
+         self.model = model
+         self.base_url = base_url.rstrip("/")
+
+     def tag(self, content: str) -> dict[str, str]:
+         """Generate tags using Ollama."""
+         import requests
+
+         truncated = content[:20000] if len(content) > 20000 else content
+
+         response = requests.post(
+             f"{self.base_url}/api/chat",
+             json={
+                 "model": self.model,
+                 "messages": [
+                     {"role": "system", "content": self.SYSTEM_PROMPT},
+                     {"role": "user", "content": truncated},
+                 ],
+                 "format": "json",
+                 "stream": False,
+             },
+         )
+         response.raise_for_status()
+
+         try:
+             tags = json.loads(response.json()["message"]["content"])
+             return {str(k): str(v) for k, v in tags.items()}
+         except (json.JSONDecodeError, KeyError):
+             return {}
+
+
+ class NoopTagging:
+     """
+     Tagging provider that returns empty tags.
+
+     Useful when tagging is disabled or for testing.
+     """
+
+     def tag(self, content: str) -> dict[str, str]:
+         """Return empty tags."""
+         return {}
+
+
+ # Register providers
+ _registry = get_registry()
+ _registry.register_summarization("anthropic", AnthropicSummarization)
+ _registry.register_summarization("openai", OpenAISummarization)
+ _registry.register_summarization("ollama", OllamaSummarization)
+ _registry.register_summarization("passthrough", PassthroughSummarization)
+ _registry.register_tagging("anthropic", AnthropicTagging)
+ _registry.register_tagging("openai", OpenAITagging)
+ _registry.register_tagging("ollama", OllamaTagging)
+ _registry.register_tagging("noop", NoopTagging)
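The file above registers four summarization backends ("anthropic", "openai", "ollama", "passthrough") and four tagging backends ("anthropic", "openai", "ollama", "noop") with the shared registry. A minimal usage sketch, assuming only the constructors and signatures visible in this diff: the registry's lookup side is not shown here, so the providers are instantiated directly, and the sample document text is hypothetical.

    from keep.providers.llm import PassthroughSummarization, NoopTagging

    # Offline providers: no API key or network access needed.
    summarizer = PassthroughSummarization(max_chars=200)
    tagger = NoopTagging()

    doc = "FastAPI service exposing full-text search over locally stored notes, with a SQLite index."
    print(summarizer.summarize(doc, max_length=120))  # truncated at a word boundary if over the limit
    print(tagger.tag(doc))                            # always {}

The LLM-backed classes follow the same call shape but validate credentials at construction time: the Anthropic providers raise ValueError without ANTHROPIC_API_KEY, the OpenAI providers accept KEEP_OPENAI_API_KEY or OPENAI_API_KEY, and the Ollama providers only need a reachable local server at base_url.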
keep/providers/mlx.py ADDED
@@ -0,0 +1,256 @@
+ """
+ MLX providers for Apple Silicon.
+
+ MLX is Apple's ML framework optimized for Apple Silicon. These providers
+ run entirely locally with no API keys required.
+
+ Requires: pip install mlx-lm mlx
+ """
+
+ import os
+ from typing import Any
+
+ from .base import EmbeddingProvider, SummarizationProvider, get_registry
+
+
+ class MLXEmbedding:
+     """
+     Embedding provider for Apple Silicon.
+
+     Despite the name, models are loaded via sentence-transformers and run on PyTorch's Metal (MPS) backend.
+
+     Requires: pip install mlx sentence-transformers
+     """
+
+     def __init__(self, model: str = "mlx-community/bge-small-en-v1.5"):
+         """
+         Args:
+             model: Model name from mlx-community hub or local path.
+                 Good options:
+                 - mlx-community/bge-small-en-v1.5 (small, fast)
+                 - mlx-community/bge-base-en-v1.5 (balanced)
+                 - mlx-community/bge-large-en-v1.5 (best quality)
+         """
+         try:
+             import mlx.core as mx  # noqa: F401 -- only verifies MLX is installed
+             from sentence_transformers import SentenceTransformer
+         except ImportError:
+             raise RuntimeError(
+                 "MLXEmbedding requires 'mlx' and 'sentence-transformers'. "
+                 "Install with: pip install mlx sentence-transformers"
+             )
+
+         self.model_name = model
+
+         # sentence-transformers runs on PyTorch, not MLX; on Apple Silicon
+         # the "mps" device runs the model on the GPU via Metal.
+         if model.startswith("mlx-community/"):
+             # Request the Metal (MPS) device explicitly for hub models
+             self._model = SentenceTransformer(model, device="mps")
+         else:
+             self._model = SentenceTransformer(model)
+
+         self._dimension: int | None = None
+
+     @property
+     def dimension(self) -> int:
+         """Get embedding dimension from the model."""
+         if self._dimension is None:
+             self._dimension = self._model.get_sentence_embedding_dimension()
+         return self._dimension
+
+     def embed(self, text: str) -> list[float]:
+         """Generate embedding for a single text."""
+         embedding = self._model.encode(text, convert_to_numpy=True)
+         return embedding.tolist()
+
+     def embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for multiple texts."""
+         embeddings = self._model.encode(texts, convert_to_numpy=True)
+         return embeddings.tolist()
+
+
+ class MLXSummarization:
+     """
+     Summarization provider using MLX-LM on Apple Silicon.
+
+     Runs local LLMs optimized for Apple Silicon. No API key required.
+
+     Requires: pip install mlx-lm
+     """
+
+     SYSTEM_PROMPT = """You are a precise summarization assistant.
+ Create a concise summary of the provided document that captures:
+ - The main purpose or topic
+ - Key points or functionality
+ - Important details that would help someone decide if this document is relevant
+
+ Be factual and specific. Do not include phrases like "This document" - just state the content directly.
+ Keep the summary under 200 words."""
+
+     def __init__(
+         self,
+         model: str = "mlx-community/Llama-3.2-3B-Instruct-4bit",
+         max_tokens: int = 300,
+     ):
+         """
+         Args:
+             model: Model name from mlx-community hub or local path.
+                 Good options for summarization:
+                 - mlx-community/Llama-3.2-3B-Instruct-4bit (fast, small)
+                 - mlx-community/Meta-Llama-3.1-8B-Instruct-4bit (better quality)
+                 - mlx-community/Mistral-7B-Instruct-v0.3-4bit (good balance)
+                 - mlx-community/Phi-3.5-mini-instruct-4bit (very fast)
+             max_tokens: Maximum tokens in generated summary
+         """
+         try:
+             from mlx_lm import load
+         except ImportError:
+             raise RuntimeError(
+                 "MLXSummarization requires 'mlx-lm'. "
+                 "Install with: pip install mlx-lm"
+             )
+
+         self.model_name = model
+         self.max_tokens = max_tokens
+
+         # Load model and tokenizer (downloads on first use)
+         self._model, self._tokenizer = load(model)
+
+     def summarize(self, content: str, *, max_length: int = 500) -> str:
+         """Generate a summary using MLX-LM."""
+         from mlx_lm import generate
+
+         # Truncate very long content to fit context window
+         # Most models have 4k-8k context, leave room for prompt and response
+         max_content_chars = 12000
+         truncated = content[:max_content_chars] if len(content) > max_content_chars else content
+
+         # Format as chat (works with instruction-tuned models)
+         if hasattr(self._tokenizer, "apply_chat_template"):
+             messages = [
+                 {"role": "system", "content": self.SYSTEM_PROMPT},
+                 {"role": "user", "content": f"Summarize the following:\n\n{truncated}"},
+             ]
+             prompt = self._tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+         else:
+             # Fallback for models without chat template
+             prompt = f"{self.SYSTEM_PROMPT}\n\nDocument:\n{truncated}\n\nSummary:"
+
+         # Generate
+         response = generate(
+             self._model,
+             self._tokenizer,
+             prompt=prompt,
+             max_tokens=self.max_tokens,
+             verbose=False,
+         )
+
+         return response.strip()
+
+
+ class MLXTagging:
+     """
+     Tagging provider using MLX-LM on Apple Silicon.
+
+     Uses local LLMs to generate structured tags. No API key required.
+
+     Requires: pip install mlx-lm
+     """
+
+     SYSTEM_PROMPT = """Analyze the document and generate relevant tags as a JSON object.
+
+ Generate tags for these categories when applicable:
+ - content_type: The type of content (e.g., "documentation", "code", "article", "config")
+ - language: Programming language if code (e.g., "python", "javascript")
+ - domain: Subject domain (e.g., "authentication", "database", "api", "testing")
+ - framework: Framework or library if relevant (e.g., "react", "django", "fastapi")
+
+ Only include tags that clearly apply. Values should be lowercase.
+ Respond with ONLY a JSON object, no explanation or other text."""
+
+     def __init__(
+         self,
+         model: str = "mlx-community/Llama-3.2-3B-Instruct-4bit",
+         max_tokens: int = 150,
+     ):
+         """
+         Args:
+             model: Model name from mlx-community hub
+             max_tokens: Maximum tokens in generated response
+         """
+         try:
+             from mlx_lm import load
+         except ImportError:
+             raise RuntimeError(
+                 "MLXTagging requires 'mlx-lm'. "
+                 "Install with: pip install mlx-lm"
+             )
+
+         self.model_name = model
+         self.max_tokens = max_tokens
+         self._model, self._tokenizer = load(model)
+
+     def tag(self, content: str) -> dict[str, str]:
+         """Generate tags using MLX-LM."""
+         import json
+         from mlx_lm import generate
+
+         # Truncate content
+         max_content_chars = 8000
+         truncated = content[:max_content_chars] if len(content) > max_content_chars else content
+
+         # Format prompt
+         if hasattr(self._tokenizer, "apply_chat_template"):
+             messages = [
+                 {"role": "system", "content": self.SYSTEM_PROMPT},
+                 {"role": "user", "content": truncated},
+             ]
+             prompt = self._tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+         else:
+             prompt = f"{self.SYSTEM_PROMPT}\n\nDocument:\n{truncated}\n\nJSON:"
+
+         response = generate(
+             self._model,
+             self._tokenizer,
+             prompt=prompt,
+             max_tokens=self.max_tokens,
+             verbose=False,
+         )
+
+         # Parse JSON from response
+         try:
+             # Try to extract JSON from response
+             response = response.strip()
+             # Handle case where model includes markdown code fence
+             if response.startswith("```"):
+                 response = response.split("```")[1]
+                 if response.startswith("json"):
+                     response = response[4:]
+
+             tags = json.loads(response)
+             return {str(k): str(v) for k, v in tags.items()}
+         except (json.JSONDecodeError, IndexError):
+             return {}
+
+
+ def is_apple_silicon() -> bool:
+     """Check if running on Apple Silicon."""
+     import platform
+     return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+
+ # Register providers (only on Apple Silicon)
+ if is_apple_silicon():
+     _registry = get_registry()
+     _registry.register_embedding("mlx", MLXEmbedding)
+     _registry.register_summarization("mlx", MLXSummarization)
+     _registry.register_tagging("mlx", MLXTagging)
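Because the "mlx" names are only registered when is_apple_silicon() returns True, callers should guard the same way. A hedged sketch, assuming the optional dependencies (mlx, mlx-lm, sentence-transformers) are installed and the default mlx-community models can be downloaded on first use; the input text and printed tag values are hypothetical:

    from keep.providers.mlx import is_apple_silicon

    if is_apple_silicon():
        from keep.providers.mlx import MLXEmbedding, MLXSummarization, MLXTagging

        # First construction downloads the default models from the Hugging Face hub.
        embedder = MLXEmbedding()                      # bge-small-en-v1.5
        summarizer = MLXSummarization(max_tokens=200)  # Llama-3.2-3B-Instruct-4bit
        tagger = MLXTagging()

        text = "def fetch_user(session, user_id): ..."   # hypothetical document
        vector = embedder.embed(text)                    # list[float], len == embedder.dimension
        print(summarizer.summarize(text))
        print(tagger.tag(text))                          # e.g. {"content_type": "code", "language": "python"}
    else:
        # On other platforms the MLX classes never register; fall back to a
        # provider from keep/providers/llm.py instead.
        from keep.providers.llm import PassthroughSummarization
        print(PassthroughSummarization().summarize("fallback path"))

Note that MLXTagging returns {} whenever the model's output is not valid JSON, so an empty dict is a normal outcome rather than an error.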