claude-mpm 5.6.10__py3-none-any.whl → 5.6.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- claude_mpm/VERSION +1 -1
- claude_mpm/cli/commands/commander.py +173 -3
- claude_mpm/cli/parsers/commander_parser.py +41 -8
- claude_mpm/cli/startup.py +104 -1
- claude_mpm/cli/startup_display.py +2 -1
- claude_mpm/commander/__init__.py +6 -0
- claude_mpm/commander/adapters/__init__.py +32 -3
- claude_mpm/commander/adapters/auggie.py +260 -0
- claude_mpm/commander/adapters/base.py +98 -1
- claude_mpm/commander/adapters/claude_code.py +32 -1
- claude_mpm/commander/adapters/codex.py +237 -0
- claude_mpm/commander/adapters/example_usage.py +310 -0
- claude_mpm/commander/adapters/mpm.py +389 -0
- claude_mpm/commander/adapters/registry.py +204 -0
- claude_mpm/commander/api/app.py +32 -16
- claude_mpm/commander/api/routes/messages.py +11 -11
- claude_mpm/commander/api/routes/projects.py +20 -20
- claude_mpm/commander/api/routes/sessions.py +19 -21
- claude_mpm/commander/api/routes/work.py +86 -50
- claude_mpm/commander/api/schemas.py +4 -0
- claude_mpm/commander/chat/cli.py +4 -0
- claude_mpm/commander/core/__init__.py +10 -0
- claude_mpm/commander/core/block_manager.py +325 -0
- claude_mpm/commander/core/response_manager.py +323 -0
- claude_mpm/commander/daemon.py +206 -10
- claude_mpm/commander/env_loader.py +59 -0
- claude_mpm/commander/memory/__init__.py +45 -0
- claude_mpm/commander/memory/compression.py +347 -0
- claude_mpm/commander/memory/embeddings.py +230 -0
- claude_mpm/commander/memory/entities.py +310 -0
- claude_mpm/commander/memory/example_usage.py +290 -0
- claude_mpm/commander/memory/integration.py +325 -0
- claude_mpm/commander/memory/search.py +381 -0
- claude_mpm/commander/memory/store.py +657 -0
- claude_mpm/commander/registry.py +10 -4
- claude_mpm/commander/runtime/monitor.py +32 -2
- claude_mpm/commander/work/executor.py +38 -20
- claude_mpm/commander/workflow/event_handler.py +25 -3
- claude_mpm/core/claude_runner.py +143 -0
- claude_mpm/core/output_style_manager.py +34 -7
- claude_mpm/hooks/claude_hooks/__pycache__/event_handlers.cpython-311.pyc +0 -0
- claude_mpm/hooks/claude_hooks/__pycache__/event_handlers.cpython-314.pyc +0 -0
- claude_mpm/hooks/claude_hooks/__pycache__/installer.cpython-311.pyc +0 -0
- claude_mpm/hooks/claude_hooks/auto_pause_handler.py +0 -0
- claude_mpm/hooks/claude_hooks/event_handlers.py +22 -0
- claude_mpm/hooks/claude_hooks/hook_handler.py +0 -0
- claude_mpm/hooks/claude_hooks/memory_integration.py +0 -0
- claude_mpm/hooks/claude_hooks/response_tracking.py +0 -0
- claude_mpm/hooks/claude_hooks/services/__pycache__/connection_manager.cpython-311.pyc +0 -0
- claude_mpm/hooks/templates/pre_tool_use_template.py +0 -0
- claude_mpm/scripts/start_activity_logging.py +0 -0
- claude_mpm/skills/__init__.py +2 -1
- claude_mpm/skills/bundled/pm/mpm-session-pause/SKILL.md +170 -0
- claude_mpm/skills/registry.py +295 -90
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/METADATA +5 -3
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/RECORD +55 -36
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/WHEEL +0 -0
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/entry_points.txt +0 -0
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/licenses/LICENSE +0 -0
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/licenses/LICENSE-FAQ.md +0 -0
- {claude_mpm-5.6.10.dist-info → claude_mpm-5.6.17.dist-info}/top_level.txt +0 -0
claude_mpm/commander/memory/__init__.py
@@ -0,0 +1,45 @@
+"""Conversation memory system for Commander.
+
+This module provides semantic search, storage, and context compression
+for all Claude Code instance conversations.
+
+Key Components:
+- ConversationStore: CRUD operations for conversations
+- EmbeddingService: Generate vector embeddings
+- SemanticSearch: Query conversations semantically
+- ContextCompressor: Summarize conversations for context
+- EntityExtractor: Extract files, functions, errors
+
+Example:
+    >>> from claude_mpm.commander.memory import (
+    ...     ConversationStore,
+    ...     EmbeddingService,
+    ...     SemanticSearch,
+    ...     ContextCompressor,
+    ... )
+    >>> store = ConversationStore()
+    >>> embeddings = EmbeddingService()
+    >>> search = SemanticSearch(store, embeddings)
+    >>> results = await search.search("how did we fix the login bug?")
+"""
+
+from .compression import ContextCompressor
+from .embeddings import EmbeddingService
+from .entities import Entity, EntityExtractor, EntityType
+from .integration import MemoryIntegration
+from .search import SearchResult, SemanticSearch
+from .store import Conversation, ConversationMessage, ConversationStore
+
+__all__ = [
+    "ContextCompressor",
+    "Conversation",
+    "ConversationMessage",
+    "ConversationStore",
+    "EmbeddingService",
+    "Entity",
+    "EntityExtractor",
+    "EntityType",
+    "MemoryIntegration",
+    "SearchResult",
+    "SemanticSearch",
+]
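For orientation, here is a minimal sketch of how the new memory exports compose, using only constructors and calls visible in this diff. The no-argument `OpenRouterClient()` construction and treating the search result as a sized list are assumptions, not confirmed API.

```python
import asyncio

# Import path for OpenRouterClient matches the one used in compression.py below.
from claude_mpm.commander.llm.openrouter_client import OpenRouterClient
from claude_mpm.commander.memory import (
    ContextCompressor,
    ConversationStore,
    EmbeddingService,
    SemanticSearch,
)


async def main() -> None:
    # Wiring from the module docstring: store + embeddings + semantic search.
    store = ConversationStore()
    embeddings = EmbeddingService()  # defaults to local sentence-transformers
    search = SemanticSearch(store, embeddings)
    results = await search.search("how did we fix the login bug?")
    print(f"{len(results)} matching conversations")  # assumes a sized result list

    # Compress prior conversations into an LLM-ready context string.
    client = OpenRouterClient()  # assumption: constructor args are not shown here
    compressor = ContextCompressor(client)
    context = await compressor.compress_for_context([], max_tokens=4000)
    print(f"context: {len(context)} chars")


asyncio.run(main())
```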
claude_mpm/commander/memory/compression.py
@@ -0,0 +1,347 @@
+"""Context compression and conversation summarization.
+
+Compresses long conversations into concise summaries for efficient context
+loading when resuming sessions or searching past work.
+"""
+
+import logging
+from typing import List, Optional
+
+from ..llm.openrouter_client import OpenRouterClient
+from .store import Conversation, ConversationMessage
+
+logger = logging.getLogger(__name__)
+
+
+class ContextCompressor:
+    """Compress conversations into summaries for context loading.
+
+    Uses cheap LLM (mistral-small) to generate summaries of conversations
+    and compress multiple conversations into context strings.
+
+    Attributes:
+        client: OpenRouterClient for LLM requests
+        summary_threshold: Minimum messages to trigger summarization
+        max_context_tokens: Maximum tokens for compressed context
+
+    Example:
+        >>> compressor = ContextCompressor(client)
+        >>> summary = await compressor.summarize(messages)
+        >>> context = await compressor.compress_for_context(
+        ...     conversations,
+        ...     max_tokens=4000
+        ... )
+    """
+
+    def __init__(
+        self,
+        client: OpenRouterClient,
+        summary_threshold: int = 10,
+        max_context_tokens: int = 4000,
+    ):
+        """Initialize context compressor.
+
+        Args:
+            client: OpenRouterClient for LLM requests
+            summary_threshold: Minimum messages to summarize
+            max_context_tokens: Maximum tokens for context string
+        """
+        self.client = client
+        self.summary_threshold = summary_threshold
+        self.max_context_tokens = max_context_tokens
+
+        logger.info(
+            "ContextCompressor initialized (threshold: %d msgs, max_tokens: %d)",
+            summary_threshold,
+            max_context_tokens,
+        )
+
+    async def summarize(
+        self,
+        messages: List[ConversationMessage],
+        focus: Optional[str] = None,
+    ) -> str:
+        """Generate summary of conversation messages.
+
+        Args:
+            messages: List of messages to summarize
+            focus: Optional focus area (e.g., "bug fixes", "API changes")
+
+        Returns:
+            Concise summary (2-4 sentences)
+
+        Example:
+            >>> summary = await compressor.summarize(messages)
+            >>> print(summary)
+            "Fixed login authentication bug in src/auth.py by updating token validation..."
+        """
+        if len(messages) < 2:
+            # Too short to summarize
+            return messages[0].content if messages else ""
+
+        # Build conversation text
+        conversation_text = self._format_messages(messages)
+
+        # Build summarization prompt
+        if focus:
+            prompt = f"""Summarize the following conversation, focusing on: {focus}
+
+Conversation:
+{conversation_text}
+
+Provide a concise summary (2-4 sentences) that captures:
+1. What was the main task or problem
+2. What actions were taken
+3. What was the outcome or current status
+4. Any important files, functions, or errors mentioned
+
+Summary:"""
+        else:
+            prompt = f"""Summarize the following conversation in 2-4 sentences.
+
+Conversation:
+{conversation_text}
+
+Focus on:
+1. What was the main task or problem
+2. What actions were taken
+3. What was the outcome or current status
+
+Summary:"""
+
+        messages_for_llm = [{"role": "user", "content": prompt}]
+
+        system = (
+            "You are a technical summarization assistant. "
+            "Provide clear, concise summaries of development conversations. "
+            "Focus on actionable information and key outcomes."
+        )
+
+        summary = await self.client.chat(messages_for_llm, system=system)
+        logger.debug(
+            "Generated summary (%d chars) from %d messages", len(summary), len(messages)
+        )
+
+        return summary.strip()
+
+    async def compress_for_context(
+        self,
+        conversations: List[Conversation],
+        max_tokens: Optional[int] = None,
+        prioritize_recent: bool = True,
+    ) -> str:
+        """Compress multiple conversations into context string.
+
+        Prioritizes recent conversations and uses summaries for older ones
+        to fit within token budget.
+
+        Args:
+            conversations: List of conversations to compress
+            max_tokens: Maximum tokens (default: self.max_context_tokens)
+            prioritize_recent: Whether to prioritize recent conversations
+
+        Returns:
+            Compressed context string ready for LLM input
+
+        Example:
+            >>> context = await compressor.compress_for_context(
+            ...     conversations,
+            ...     max_tokens=4000
+            ... )
+            >>> print(f"Context: {len(context)} chars")
+        """
+        if max_tokens is None:
+            max_tokens = self.max_context_tokens
+
+        # Sort by recency if prioritizing
+        if prioritize_recent:
+            conversations = sorted(
+                conversations, key=lambda c: c.updated_at, reverse=True
+            )
+
+        # Build context incrementally
+        context_parts = []
+        current_tokens = 0
+
+        for conv in conversations:
+            # Use summary if available, else generate one
+            if conv.summary:
+                summary_text = conv.summary
+            elif len(conv.messages) >= self.summary_threshold:
+                # Generate summary on-the-fly
+                summary_text = await self.summarize(conv.messages)
+            else:
+                # Use full conversation for short ones
+                summary_text = conv.get_full_text()
+
+            # Format conversation section
+            section = self._format_conversation_section(conv, summary_text)
+            section_tokens = len(section) // 4  # Rough approximation
+
+            # Check if adding this would exceed budget
+            if current_tokens + section_tokens > max_tokens:
+                # Try to fit summary only
+                short_summary = summary_text.split(". ")[0] + "."
+                short_section = self._format_conversation_section(conv, short_summary)
+                short_tokens = len(short_section) // 4
+
+                if current_tokens + short_tokens <= max_tokens:
+                    context_parts.append(short_section)
+                    current_tokens += short_tokens
+                else:
+                    # Can't fit any more, stop
+                    break
+            else:
+                context_parts.append(section)
+                current_tokens += section_tokens
+
+        context = "\n\n---\n\n".join(context_parts)
+
+        logger.info(
+            "Compressed %d conversations into context (%d chars, ~%d tokens)",
+            len(context_parts),
+            len(context),
+            current_tokens,
+        )
+
+        return context
+
+    def needs_summarization(self, messages: List[ConversationMessage]) -> bool:
+        """Check if conversation needs summarization.
+
+        Args:
+            messages: List of messages to check
+
+        Returns:
+            True if message count exceeds threshold
+
+        Example:
+            >>> if compressor.needs_summarization(messages):
+            ...     summary = await compressor.summarize(messages)
+        """
+        return len(messages) >= self.summary_threshold
+
+    def _format_messages(
+        self,
+        messages: List[ConversationMessage],
+        max_messages: Optional[int] = None,
+    ) -> str:
+        """Format messages as text for summarization.
+
+        Args:
+            messages: Messages to format
+            max_messages: Maximum messages to include
+
+        Returns:
+            Formatted conversation text
+        """
+        if max_messages:
+            messages = messages[:max_messages]
+
+        lines = []
+        for msg in messages:
+            # Format: ROLE: content
+            lines.append(f"{msg.role.upper()}: {msg.content}")
+
+        return "\n\n".join(lines)
+
+    def _format_conversation_section(
+        self, conversation: Conversation, summary: str
+    ) -> str:
+        """Format conversation section for context string.
+
+        Args:
+            conversation: Conversation to format
+            summary: Summary or full text
+
+        Returns:
+            Formatted section with metadata
+        """
+        # Format timestamp
+        timestamp = conversation.updated_at.strftime("%Y-%m-%d %H:%M")
+
+        # Build section
+        return f"""## Conversation: {conversation.id}
+**Project:** {conversation.project_id}
+**Instance:** {conversation.instance_name}
+**Updated:** {timestamp}
+**Messages:** {conversation.message_count}
+
+{summary}"""
+
+    async def auto_summarize_conversation(
+        self, conversation: Conversation
+    ) -> Optional[str]:
+        """Automatically summarize conversation if needed.
+
+        Checks if conversation needs summarization and generates one if so.
+        Updates the conversation's summary field but does NOT save to store.
+
+        Args:
+            conversation: Conversation to summarize
+
+        Returns:
+            Summary if generated, None if not needed
+
+        Example:
+            >>> summary = await compressor.auto_summarize_conversation(conv)
+            >>> if summary:
+            ...     conv.summary = summary
+            ...     await store.save(conv)
+        """
+        if not self.needs_summarization(conversation.messages):
+            logger.debug(
+                "Conversation %s too short to summarize (%d messages)",
+                conversation.id,
+                len(conversation.messages),
+            )
+            return None
+
+        if conversation.summary:
+            logger.debug("Conversation %s already has summary", conversation.id)
+            return conversation.summary
+
+        # Generate summary
+        summary = await self.summarize(conversation.messages)
+        logger.info("Auto-generated summary for conversation %s", conversation.id)
+
+        return summary
+
+    async def update_summary_if_stale(
+        self,
+        conversation: Conversation,
+        message_threshold: int = 5,
+    ) -> Optional[str]:
+        """Update summary if conversation has grown significantly.
+
+        Args:
+            conversation: Conversation to check
+            message_threshold: New messages required to trigger update
+
+        Returns:
+            Updated summary if regenerated, None otherwise
+
+        Example:
+            >>> updated = await compressor.update_summary_if_stale(conv)
+            >>> if updated:
+            ...     conv.summary = updated
+            ...     await store.save(conv)
+        """
+        if not conversation.summary:
+            # No existing summary, generate one
+            return await self.auto_summarize_conversation(conversation)
+
+        # Check if conversation has grown significantly
+        # (Simple heuristic: if more than threshold messages since last summarization)
+        # In practice, you'd track when summary was generated
+        if len(conversation.messages) < self.summary_threshold + message_threshold:
+            return None
+
+        # Regenerate summary
+        logger.info(
+            "Regenerating stale summary for conversation %s (%d messages)",
+            conversation.id,
+            len(conversation.messages),
+        )
+
+        return await self.summarize(conversation.messages)
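One detail of `compress_for_context` worth calling out: tokens are approximated as `len(text) // 4`, and sections are packed greedily, degrading an overflowing section to its first sentence before giving up entirely. A self-contained sketch of just that heuristic (independent of the classes in this diff):

```python
def pack_sections(sections: list[str], max_tokens: int) -> list[str]:
    """Greedily pack sections under a budget, using the same ~4 chars/token estimate."""
    packed: list[str] = []
    used = 0
    for text in sections:
        cost = len(text) // 4
        if used + cost > max_tokens:
            # Mirror the fallback: keep only the first sentence of the section.
            short = text.split(". ")[0] + "."
            if used + len(short) // 4 > max_tokens:
                break  # nothing more fits
            packed.append(short)
            used += len(short) // 4
        else:
            packed.append(text)
            used += cost
    return packed


print(pack_sections(
    ["First task was done. It touched auth.py.", "Second task pending. Needs review."],
    max_tokens=16,
))
# -> ['First task was done. It touched auth.py.', 'Second task pending.']
```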
claude_mpm/commander/memory/embeddings.py
@@ -0,0 +1,230 @@
+"""Embedding service for semantic search.
+
+Generates vector embeddings using sentence-transformers (local) or
+OpenAI API (cloud). Defaults to local model for zero-cost operation.
+"""
+
+import asyncio
+import logging
+from typing import List, Literal, Optional
+
+logger = logging.getLogger(__name__)
+
+EmbeddingProvider = Literal["sentence-transformers", "openai"]
+
+
+class EmbeddingService:
+    """Generate vector embeddings for semantic search.
+
+    Supports multiple providers:
+    - sentence-transformers: Local, free, good quality (default)
+    - openai: Cloud API, best quality, costs money
+
+    Attributes:
+        provider: Embedding provider to use
+        model: Model name for the provider
+        dimension: Embedding vector dimension
+
+    Example:
+        >>> embeddings = EmbeddingService(provider="sentence-transformers")
+        >>> vector = await embeddings.embed("Fix the login bug")
+        >>> len(vector)
+        384
+    """
+
+    def __init__(
+        self,
+        provider: EmbeddingProvider = "sentence-transformers",
+        model: Optional[str] = None,
+    ):
+        """Initialize embedding service.
+
+        Args:
+            provider: Embedding provider ('sentence-transformers' or 'openai')
+            model: Model name (provider-specific default if None)
+        """
+        self.provider = provider
+        self._encoder = None
+        self._client = None
+
+        if provider == "sentence-transformers":
+            self.model = model or "all-MiniLM-L6-v2"
+            self.dimension = 384
+            self._init_sentence_transformers()
+        elif provider == "openai":
+            self.model = model or "text-embedding-3-small"
+            self.dimension = 1536
+            self._init_openai()
+        else:
+            raise ValueError(f"Unknown provider: {provider}")
+
+        logger.info(
+            "EmbeddingService initialized (provider: %s, model: %s, dim: %d)",
+            provider,
+            self.model,
+            self.dimension,
+        )
+
+    def _init_sentence_transformers(self) -> None:
+        """Initialize sentence-transformers encoder.
+
+        Lazy loads on first use to avoid startup delay.
+        """
+        # Lazy import to avoid dependency if not used
+        try:
+            from sentence_transformers import SentenceTransformer
+
+            self._encoder = SentenceTransformer(self.model)
+            logger.info("Loaded sentence-transformers model: %s", self.model)
+        except ImportError:
+            logger.error(
+                "sentence-transformers not installed. "
+                "Install with: pip install sentence-transformers"
+            )
+            raise
+
+    def _init_openai(self) -> None:
+        """Initialize OpenAI client.
+
+        Requires OPENAI_API_KEY environment variable.
+        """
+        try:
+            from openai import AsyncOpenAI
+
+            self._client = AsyncOpenAI()
+            logger.info("Initialized OpenAI client")
+        except ImportError:
+            logger.error("openai not installed. Install with: pip install openai")
+            raise
+
+    async def embed(self, text: str) -> List[float]:
+        """Generate embedding for text.
+
+        Args:
+            text: Text to embed
+
+        Returns:
+            Embedding vector as list of floats
+
+        Example:
+            >>> vector = await embeddings.embed("Fix the login bug")
+            >>> len(vector)
+            384
+        """
+        if self.provider == "sentence-transformers":
+            return await self._embed_sentence_transformers(text)
+        if self.provider == "openai":
+            return await self._embed_openai(text)
+        raise ValueError(f"Unknown provider: {self.provider}")
+
+    async def embed_batch(self, texts: List[str]) -> List[List[float]]:
+        """Generate embeddings for multiple texts.
+
+        More efficient than calling embed() in a loop.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            List of embedding vectors
+
+        Example:
+            >>> vectors = await embeddings.embed_batch([
+            ...     "Fix the login bug",
+            ...     "Update the README"
+            ... ])
+            >>> len(vectors)
+            2
+        """
+        if self.provider == "sentence-transformers":
+            return await self._embed_batch_sentence_transformers(texts)
+        if self.provider == "openai":
+            return await self._embed_batch_openai(texts)
+        raise ValueError(f"Unknown provider: {self.provider}")
+
+    async def _embed_sentence_transformers(self, text: str) -> List[float]:
+        """Generate embedding using sentence-transformers.
+
+        Runs in executor to avoid blocking event loop.
+        """
+        if self._encoder is None:
+            raise RuntimeError("Encoder not initialized")
+
+        # Run encoding in executor (CPU-bound operation)
+        loop = asyncio.get_event_loop()
+        embedding = await loop.run_in_executor(
+            None, lambda: self._encoder.encode(text, convert_to_numpy=True)
+        )
+
+        return embedding.tolist()
+
+    async def _embed_batch_sentence_transformers(
+        self, texts: List[str]
+    ) -> List[List[float]]:
+        """Generate batch embeddings using sentence-transformers."""
+        if self._encoder is None:
+            raise RuntimeError("Encoder not initialized")
+
+        # Run batch encoding in executor
+        loop = asyncio.get_event_loop()
+        embeddings = await loop.run_in_executor(
+            None,
+            lambda: self._encoder.encode(
+                texts, convert_to_numpy=True, show_progress_bar=False
+            ),
+        )
+
+        return [emb.tolist() for emb in embeddings]
+
+    async def _embed_openai(self, text: str) -> List[float]:
+        """Generate embedding using OpenAI API."""
+        if self._client is None:
+            raise RuntimeError("OpenAI client not initialized")
+
+        response = await self._client.embeddings.create(
+            model=self.model,
+            input=text,
+        )
+
+        return response.data[0].embedding
+
+    async def _embed_batch_openai(self, texts: List[str]) -> List[List[float]]:
+        """Generate batch embeddings using OpenAI API."""
+        if self._client is None:
+            raise RuntimeError("OpenAI client not initialized")
+
+        response = await self._client.embeddings.create(
+            model=self.model,
+            input=texts,
+        )
+
+        return [item.embedding for item in response.data]
+
+    def cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
+        """Calculate cosine similarity between two vectors.
+
+        Args:
+            vec1: First vector
+            vec2: Second vector
+
+        Returns:
+            Similarity score in range [-1, 1] (1 = identical, -1 = opposite)
+
+        Example:
+            >>> sim = embeddings.cosine_similarity(vec1, vec2)
+            >>> print(f"Similarity: {sim:.3f}")
+        """
+        import math
+
+        # Dot product
+        dot_product = sum(a * b for a, b in zip(vec1, vec2))
+
+        # Magnitudes
+        mag1 = math.sqrt(sum(a * a for a in vec1))
+        mag2 = math.sqrt(sum(b * b for b in vec2))
+
+        # Avoid division by zero
+        if mag1 == 0 or mag2 == 0:
+            return 0.0
+
+        return dot_product / (mag1 * mag2)
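A hedged usage sketch for the service above, calling only methods defined in this file; it assumes the sentence-transformers dependency is installed (the all-MiniLM-L6-v2 model downloads on first use):

```python
import asyncio

from claude_mpm.commander.memory import EmbeddingService


async def main() -> None:
    embeddings = EmbeddingService(provider="sentence-transformers")

    query_vec = await embeddings.embed("how did we fix the login bug?")
    candidates = [
        "Fixed token validation in src/auth.py",
        "Updated the README badges",
    ]
    vectors = await embeddings.embed_batch(candidates)

    # Rank candidates by cosine similarity to the query.
    scored = [
        (embeddings.cosine_similarity(query_vec, vec), text)
        for vec, text in zip(vectors, candidates)
    ]
    for score, text in sorted(scored, reverse=True):
        print(f"{score:.3f}  {text}")


asyncio.run(main())
```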