keep-skill 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keep/__init__.py +53 -0
- keep/__main__.py +8 -0
- keep/api.py +686 -0
- keep/chunking.py +364 -0
- keep/cli.py +503 -0
- keep/config.py +323 -0
- keep/context.py +127 -0
- keep/indexing.py +208 -0
- keep/logging_config.py +73 -0
- keep/paths.py +67 -0
- keep/pending_summaries.py +166 -0
- keep/providers/__init__.py +40 -0
- keep/providers/base.py +416 -0
- keep/providers/documents.py +250 -0
- keep/providers/embedding_cache.py +260 -0
- keep/providers/embeddings.py +245 -0
- keep/providers/llm.py +371 -0
- keep/providers/mlx.py +256 -0
- keep/providers/summarization.py +107 -0
- keep/store.py +403 -0
- keep/types.py +65 -0
- keep_skill-0.1.0.dist-info/METADATA +290 -0
- keep_skill-0.1.0.dist-info/RECORD +26 -0
- keep_skill-0.1.0.dist-info/WHEEL +4 -0
- keep_skill-0.1.0.dist-info/entry_points.txt +2 -0
- keep_skill-0.1.0.dist-info/licenses/LICENSE +21 -0
keep/providers/llm.py
ADDED
@@ -0,0 +1,371 @@
"""
Summarization and tagging providers using LLMs.
"""

import json
import os
from typing import Any

from .base import SummarizationProvider, TaggingProvider, get_registry


# -----------------------------------------------------------------------------
# Summarization Providers
# -----------------------------------------------------------------------------

class AnthropicSummarization:
    """
    Summarization provider using Anthropic's Claude API.

    Requires: ANTHROPIC_API_KEY environment variable.
    Optionally reads from OpenClaw config via OPENCLAW_CONFIG env var.
    """

    SYSTEM_PROMPT = """You are a precise summarization assistant.
Create a concise summary of the provided document that captures:
- The main purpose or topic
- Key points or functionality
- Important details that would help someone decide if this document is relevant

Be factual and specific. Do not include phrases like "This document" - just state the content directly."""

    def __init__(
        self,
        model: str = "claude-3-5-haiku-20241022",
        api_key: str | None = None,
        max_tokens: int = 200,
    ):
        try:
            from anthropic import Anthropic
        except ImportError:
            raise RuntimeError("AnthropicSummarization requires 'anthropic' library")

        self.model = model
        self.max_tokens = max_tokens

        # Try environment variable first, then OpenClaw config
        key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        if not key:
            # Try to read from OpenClaw config (OAuth tokens stored separately)
            # For now, just require explicit API key
            raise ValueError("ANTHROPIC_API_KEY environment variable required")

        self.client = Anthropic(api_key=key)

    def summarize(self, content: str) -> str:
        """Generate summary using Anthropic Claude."""
        # Truncate very long content
        truncated = content[:50000] if len(content) > 50000 else content

        try:
            response = self.client.messages.create(
                model=self.model,
                max_tokens=self.max_tokens,
                system=self.SYSTEM_PROMPT,
                messages=[
                    {"role": "user", "content": truncated}
                ],
            )

            # Extract text from response
            if response.content and len(response.content) > 0:
                return response.content[0].text
            return truncated[:500]  # Fallback
        except Exception:
            # Fallback to truncation on error
            return truncated[:500]


class OpenAISummarization:
    """
    Summarization provider using OpenAI's chat API.

    Requires: KEEP_OPENAI_API_KEY or OPENAI_API_KEY environment variable.
    """

    SYSTEM_PROMPT = """You are a precise summarization assistant.
Create a concise summary of the provided document that captures:
- The main purpose or topic
- Key points or functionality
- Important details that would help someone decide if this document is relevant

Be factual and specific. Do not include phrases like "This document" - just state the content directly."""

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
        max_tokens: int = 200,
    ):
        try:
            from openai import OpenAI
        except ImportError:
            raise RuntimeError("OpenAISummarization requires 'openai' library")

        self.model = model
        self.max_tokens = max_tokens

        key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        if not key:
            raise ValueError("OpenAI API key required")

        self._client = OpenAI(api_key=key)

    def summarize(self, content: str, *, max_length: int = 500) -> str:
        """Generate a summary using OpenAI."""
        # Truncate very long content to avoid token limits
        truncated = content[:50000] if len(content) > 50000 else content

        response = self._client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": truncated},
            ],
            max_tokens=self.max_tokens,
            temperature=0.3,
        )

        return response.choices[0].message.content.strip()


class OllamaSummarization:
    """
    Summarization provider using Ollama's local API.
    """

    SYSTEM_PROMPT = OpenAISummarization.SYSTEM_PROMPT

    def __init__(
        self,
        model: str = "llama3.2",
        base_url: str = "http://localhost:11434",
    ):
        self.model = model
        self.base_url = base_url.rstrip("/")

    def summarize(self, content: str, *, max_length: int = 500) -> str:
        """Generate a summary using Ollama."""
        import requests

        truncated = content[:50000] if len(content) > 50000 else content

        response = requests.post(
            f"{self.base_url}/api/chat",
            json={
                "model": self.model,
                "messages": [
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    {"role": "user", "content": truncated},
                ],
                "stream": False,
            },
        )
        response.raise_for_status()

        return response.json()["message"]["content"].strip()


class PassthroughSummarization:
    """
    Summarization provider that returns the first N characters.

    Useful for testing or when LLM summarization is not needed.
    """

    def __init__(self, max_chars: int = 500):
        self.max_chars = max_chars

    def summarize(self, content: str, *, max_length: int = 500) -> str:
        """Return truncated content as summary."""
        limit = min(self.max_chars, max_length)
        if len(content) <= limit:
            return content
        return content[:limit].rsplit(" ", 1)[0] + "..."


# -----------------------------------------------------------------------------
# Tagging Providers
# -----------------------------------------------------------------------------

class AnthropicTagging:
    """
    Tagging provider using Anthropic's Claude API with JSON output.
    """

    SYSTEM_PROMPT = """Analyze the document and generate relevant tags as a JSON object.

Generate tags for these categories when applicable:
- content_type: The type of content (e.g., "documentation", "code", "article", "config")
- language: Programming language if code (e.g., "python", "javascript")
- domain: Subject domain (e.g., "authentication", "database", "api", "testing")
- framework: Framework or library if relevant (e.g., "react", "django", "fastapi")

Only include tags that clearly apply. Values should be lowercase.

Respond with a JSON object only, no explanation."""

    def __init__(
        self,
        model: str = "claude-3-5-haiku-20241022",
        api_key: str | None = None,
    ):
        try:
            from anthropic import Anthropic
        except ImportError:
            raise RuntimeError("AnthropicTagging requires 'anthropic' library")

        self.model = model

        key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        if not key:
            raise ValueError("ANTHROPIC_API_KEY environment variable required")

        self._client = Anthropic(api_key=key)

    def tag(self, content: str) -> dict[str, str]:
        """Generate tags using Anthropic Claude."""
        truncated = content[:20000] if len(content) > 20000 else content

        try:
            response = self._client.messages.create(
                model=self.model,
                max_tokens=200,
                temperature=0.2,
                system=self.SYSTEM_PROMPT,
                messages=[
                    {"role": "user", "content": truncated}
                ],
            )

            # Parse JSON from response
            if response.content and len(response.content) > 0:
                tags = json.loads(response.content[0].text)
                return {str(k): str(v) for k, v in tags.items()}
            return {}
        except Exception:  # covers json.JSONDecodeError as well as API failures
            return {}


class OpenAITagging:
    """
    Tagging provider using OpenAI's chat API with JSON output.
    """

    SYSTEM_PROMPT = """Analyze the document and generate relevant tags as a JSON object.

Generate tags for these categories when applicable:
- content_type: The type of content (e.g., "documentation", "code", "article", "config")
- language: Programming language if code (e.g., "python", "javascript")
- domain: Subject domain (e.g., "authentication", "database", "api", "testing")
- framework: Framework or library if relevant (e.g., "react", "django", "fastapi")

Only include tags that clearly apply. Values should be lowercase.

Respond with a JSON object only, no explanation."""

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: str | None = None,
    ):
        try:
            from openai import OpenAI
        except ImportError:
            raise RuntimeError("OpenAITagging requires 'openai' library")

        self.model = model

        key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        if not key:
            raise ValueError("OpenAI API key required")

        self._client = OpenAI(api_key=key)

    def tag(self, content: str) -> dict[str, str]:
        """Generate tags using OpenAI."""
        truncated = content[:20000] if len(content) > 20000 else content

        response = self._client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": truncated},
            ],
            response_format={"type": "json_object"},
            max_tokens=200,
            temperature=0.2,
        )

        try:
            tags = json.loads(response.choices[0].message.content)
            # Ensure all values are strings
            return {str(k): str(v) for k, v in tags.items()}
        except json.JSONDecodeError:
            return {}


class OllamaTagging:
    """
    Tagging provider using Ollama's local API.
    """

    SYSTEM_PROMPT = OpenAITagging.SYSTEM_PROMPT

    def __init__(
        self,
        model: str = "llama3.2",
        base_url: str = "http://localhost:11434",
    ):
        self.model = model
        self.base_url = base_url.rstrip("/")

    def tag(self, content: str) -> dict[str, str]:
        """Generate tags using Ollama."""
        import requests

        truncated = content[:20000] if len(content) > 20000 else content

        response = requests.post(
            f"{self.base_url}/api/chat",
            json={
                "model": self.model,
                "messages": [
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    {"role": "user", "content": truncated},
                ],
                "format": "json",
                "stream": False,
            },
        )
        response.raise_for_status()

        try:
            tags = json.loads(response.json()["message"]["content"])
            return {str(k): str(v) for k, v in tags.items()}
        except (json.JSONDecodeError, KeyError):
            return {}


class NoopTagging:
    """
    Tagging provider that returns empty tags.

    Useful when tagging is disabled or for testing.
    """

    def tag(self, content: str) -> dict[str, str]:
        """Return empty tags."""
        return {}


# Register providers
_registry = get_registry()
_registry.register_summarization("anthropic", AnthropicSummarization)
_registry.register_summarization("openai", OpenAISummarization)
_registry.register_summarization("ollama", OllamaSummarization)
_registry.register_summarization("passthrough", PassthroughSummarization)
_registry.register_tagging("anthropic", AnthropicTagging)
_registry.register_tagging("openai", OpenAITagging)
_registry.register_tagging("ollama", OllamaTagging)
_registry.register_tagging("noop", NoopTagging)
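The providers above share one duck-typed call shape: summarize(content) for summarizers and tag(content) for taggers, so they are interchangeable at the call site. A minimal usage sketch, not part of the package; it sticks to PassthroughSummarization and NoopTagging so it runs without credentials or network access:

    # Usage sketch (illustrative, not shipped with the wheel).
    from keep.providers.llm import PassthroughSummarization, NoopTagging

    summarizer = PassthroughSummarization(max_chars=120)
    # Content longer than the limit is cut at a word boundary and gets "..."
    print(summarizer.summarize("word " * 200))

    tagger = NoopTagging()
    print(tagger.tag("def hello(): pass"))  # -> {}

    # An LLM-backed provider is a drop-in replacement with the same call shape;
    # this line assumes OPENAI_API_KEY (or KEEP_OPENAI_API_KEY) is set:
    #   OpenAISummarization(model="gpt-4o-mini").summarize(long_text)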
keep/providers/mlx.py
ADDED
@@ -0,0 +1,256 @@
"""
MLX providers for Apple Silicon.

MLX is Apple's ML framework optimized for Apple Silicon. These providers
run entirely locally with no API keys required.

Requires: pip install mlx-lm mlx
"""

import os
from typing import Any

from .base import EmbeddingProvider, SummarizationProvider, get_registry


class MLXEmbedding:
    """
    Embedding provider using MLX on Apple Silicon.

    Uses sentence-transformer compatible models converted to MLX format.

    Requires: pip install mlx sentence-transformers
    """

    def __init__(self, model: str = "mlx-community/bge-small-en-v1.5"):
        """
        Args:
            model: Model name from mlx-community hub or local path.
                Good options:
                - mlx-community/bge-small-en-v1.5 (small, fast)
                - mlx-community/bge-base-en-v1.5 (balanced)
                - mlx-community/bge-large-en-v1.5 (best quality)
        """
        try:
            import mlx.core as mx
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise RuntimeError(
                "MLXEmbedding requires 'mlx' and 'sentence-transformers'. "
                "Install with: pip install mlx sentence-transformers"
            )

        self.model_name = model

        # sentence-transformers can use MLX backend on Apple Silicon
        # For MLX-specific models, we use the direct approach
        if model.startswith("mlx-community/"):
            # Use sentence-transformers which auto-detects MLX
            self._model = SentenceTransformer(model, device="mps")
        else:
            self._model = SentenceTransformer(model)

        self._dimension: int | None = None

    @property
    def dimension(self) -> int:
        """Get embedding dimension from the model."""
        if self._dimension is None:
            self._dimension = self._model.get_sentence_embedding_dimension()
        return self._dimension

    def embed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        embedding = self._model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts."""
        embeddings = self._model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()


class MLXSummarization:
    """
    Summarization provider using MLX-LM on Apple Silicon.

    Runs local LLMs optimized for Apple Silicon. No API key required.

    Requires: pip install mlx-lm
    """

    SYSTEM_PROMPT = """You are a precise summarization assistant.
Create a concise summary of the provided document that captures:
- The main purpose or topic
- Key points or functionality
- Important details that would help someone decide if this document is relevant

Be factual and specific. Do not include phrases like "This document" - just state the content directly.
Keep the summary under 200 words."""

    def __init__(
        self,
        model: str = "mlx-community/Llama-3.2-3B-Instruct-4bit",
        max_tokens: int = 300,
    ):
        """
        Args:
            model: Model name from mlx-community hub or local path.
                Good options for summarization:
                - mlx-community/Llama-3.2-3B-Instruct-4bit (fast, small)
                - mlx-community/Llama-3.2-8B-Instruct-4bit (better quality)
                - mlx-community/Mistral-7B-Instruct-v0.3-4bit (good balance)
                - mlx-community/Phi-3.5-mini-instruct-4bit (very fast)
            max_tokens: Maximum tokens in generated summary
        """
        try:
            from mlx_lm import load
        except ImportError:
            raise RuntimeError(
                "MLXSummarization requires 'mlx-lm'. "
                "Install with: pip install mlx-lm"
            )

        self.model_name = model
        self.max_tokens = max_tokens

        # Load model and tokenizer (downloads on first use)
        self._model, self._tokenizer = load(model)

    def summarize(self, content: str, *, max_length: int = 500) -> str:
        """Generate a summary using MLX-LM."""
        from mlx_lm import generate

        # Truncate very long content to fit context window
        # Most models have 4k-8k context, leave room for prompt and response
        max_content_chars = 12000
        truncated = content[:max_content_chars] if len(content) > max_content_chars else content

        # Format as chat (works with instruction-tuned models)
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages = [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": f"Summarize the following:\n\n{truncated}"},
            ]
            prompt = self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            # Fallback for models without chat template
            prompt = f"{self.SYSTEM_PROMPT}\n\nDocument:\n{truncated}\n\nSummary:"

        # Generate
        response = generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            verbose=False,
        )

        return response.strip()


class MLXTagging:
    """
    Tagging provider using MLX-LM on Apple Silicon.

    Uses local LLMs to generate structured tags. No API key required.

    Requires: pip install mlx-lm
    """

    SYSTEM_PROMPT = """Analyze the document and generate relevant tags as a JSON object.

Generate tags for these categories when applicable:
- content_type: The type of content (e.g., "documentation", "code", "article", "config")
- language: Programming language if code (e.g., "python", "javascript")
- domain: Subject domain (e.g., "authentication", "database", "api", "testing")
- framework: Framework or library if relevant (e.g., "react", "django", "fastapi")

Only include tags that clearly apply. Values should be lowercase.
Respond with ONLY a JSON object, no explanation or other text."""

    def __init__(
        self,
        model: str = "mlx-community/Llama-3.2-3B-Instruct-4bit",
        max_tokens: int = 150,
    ):
        """
        Args:
            model: Model name from mlx-community hub
            max_tokens: Maximum tokens in generated response
        """
        try:
            from mlx_lm import load
        except ImportError:
            raise RuntimeError(
                "MLXTagging requires 'mlx-lm'. "
                "Install with: pip install mlx-lm"
            )

        self.model_name = model
        self.max_tokens = max_tokens
        self._model, self._tokenizer = load(model)

    def tag(self, content: str) -> dict[str, str]:
        """Generate tags using MLX-LM."""
        import json
        from mlx_lm import generate

        # Truncate content
        max_content_chars = 8000
        truncated = content[:max_content_chars] if len(content) > max_content_chars else content

        # Format prompt
        if hasattr(self._tokenizer, "apply_chat_template"):
            messages = [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": truncated},
            ]
            prompt = self._tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            prompt = f"{self.SYSTEM_PROMPT}\n\nDocument:\n{truncated}\n\nJSON:"

        response = generate(
            self._model,
            self._tokenizer,
            prompt=prompt,
            max_tokens=self.max_tokens,
            verbose=False,
        )

        # Parse JSON from response
        try:
            # Try to extract JSON from response
            response = response.strip()
            # Handle case where model includes markdown code fence
            if response.startswith("```"):
                response = response.split("```")[1]
                if response.startswith("json"):
                    response = response[4:]

            tags = json.loads(response)
            return {str(k): str(v) for k, v in tags.items()}
        except (json.JSONDecodeError, IndexError):
            return {}


def is_apple_silicon() -> bool:
    """Check if running on Apple Silicon."""
    import platform
    return platform.system() == "Darwin" and platform.machine() == "arm64"


# Register providers (only on Apple Silicon)
if is_apple_silicon():
    _registry = get_registry()
    _registry.register_embedding("mlx", MLXEmbedding)
    _registry.register_summarization("mlx", MLXSummarization)
    _registry.register_tagging("mlx", MLXTagging)