agent-brain-rag 1.2.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published in their public registry.
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/METADATA +55 -18
- agent_brain_rag-3.0.0.dist-info/RECORD +56 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/WHEEL +1 -1
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/entry_points.txt +0 -1
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +146 -45
- agent_brain_server/api/routers/__init__.py +2 -0
- agent_brain_server/api/routers/health.py +85 -21
- agent_brain_server/api/routers/index.py +108 -36
- agent_brain_server/api/routers/jobs.py +111 -0
- agent_brain_server/config/provider_config.py +352 -0
- agent_brain_server/config/settings.py +22 -5
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/bm25_index.py +15 -2
- agent_brain_server/indexing/document_loader.py +45 -4
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/job_queue/__init__.py +11 -0
- agent_brain_server/job_queue/job_service.py +317 -0
- agent_brain_server/job_queue/job_store.py +427 -0
- agent_brain_server/job_queue/job_worker.py +434 -0
- agent_brain_server/locking.py +101 -8
- agent_brain_server/models/__init__.py +28 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +30 -3
- agent_brain_server/models/job.py +289 -0
- agent_brain_server/models/query.py +16 -3
- agent_brain_server/project_root.py +1 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/runtime.py +2 -2
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +5 -3
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31

--- /dev/null
+++ b/agent_brain_server/providers/__init__.py
@@ -0,0 +1,64 @@
+"""Pluggable model providers for Agent Brain.
+
+This package provides abstractions for embedding and summarization providers,
+allowing configuration-driven selection between OpenAI, Ollama, Cohere (embeddings)
+and Anthropic, OpenAI, Gemini, Grok, Ollama (summarization) providers.
+"""
+
+from agent_brain_server.providers.base import (
+    BaseEmbeddingProvider,
+    BaseSummarizationProvider,
+    EmbeddingProvider,
+    EmbeddingProviderType,
+    SummarizationProvider,
+    SummarizationProviderType,
+)
+from agent_brain_server.providers.exceptions import (
+    AuthenticationError,
+    ConfigurationError,
+    ModelNotFoundError,
+    ProviderError,
+    ProviderMismatchError,
+    ProviderNotFoundError,
+    RateLimitError,
+)
+from agent_brain_server.providers.factory import ProviderRegistry
+
+__all__ = [
+    # Protocols
+    "EmbeddingProvider",
+    "SummarizationProvider",
+    # Base classes
+    "BaseEmbeddingProvider",
+    "BaseSummarizationProvider",
+    # Enums
+    "EmbeddingProviderType",
+    "SummarizationProviderType",
+    # Factory
+    "ProviderRegistry",
+    # Exceptions
+    "ProviderError",
+    "ConfigurationError",
+    "AuthenticationError",
+    "ProviderNotFoundError",
+    "ProviderMismatchError",
+    "RateLimitError",
+    "ModelNotFoundError",
+]
+
+
+def _register_providers() -> None:
+    """Register all built-in providers with the registry."""
+    # Import providers to trigger registration
+    from agent_brain_server.providers import (
+        embedding,  # noqa: F401
+        summarization,  # noqa: F401
+    )
+
+    # Silence unused import warnings
+    _ = embedding
+    _ = summarization
+
+
+# Auto-register providers on import
+_register_providers()
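
The package `__init__` above re-exports the protocols, base classes, enums, registry, and exceptions, and registers the built-in providers as an import side effect. A minimal sketch of consuming that public surface, assuming `RateLimitError` subclasses `ProviderError` (the `exceptions.py` hunk is not shown in this section):

```python
# Sketch only: `provider` is any object satisfying the EmbeddingProvider
# protocol; the exception hierarchy is assumed, not confirmed here.
import asyncio

from agent_brain_server.providers import (
    EmbeddingProvider,
    ProviderError,
    RateLimitError,
)


async def embed_with_retry(provider: EmbeddingProvider, text: str) -> list[float]:
    try:
        return await provider.embed_text(text)
    except RateLimitError:
        # Naive single retry after a fixed backoff; real callers would
        # cap and jitter this.
        await asyncio.sleep(5)
        return await provider.embed_text(text)
    except ProviderError as e:
        raise RuntimeError(f"{provider.provider_name} embedding failed") from e
```
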
--- /dev/null
+++ b/agent_brain_server/providers/base.py
@@ -0,0 +1,251 @@
+"""Base protocols and classes for pluggable providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable
+from enum import Enum
+from typing import Optional, Protocol, runtime_checkable
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingProviderType(str, Enum):
+    """Supported embedding providers."""
+
+    OPENAI = "openai"
+    OLLAMA = "ollama"
+    COHERE = "cohere"
+
+
+class SummarizationProviderType(str, Enum):
+    """Supported summarization providers."""
+
+    ANTHROPIC = "anthropic"
+    OPENAI = "openai"
+    GEMINI = "gemini"
+    GROK = "grok"
+    OLLAMA = "ollama"
+
+
+@runtime_checkable
+class EmbeddingProvider(Protocol):
+    """Protocol for embedding providers.
+
+    All embedding providers must implement this interface to be usable
+    by the Agent Brain indexing and query systems.
+    """
+
+    async def embed_text(self, text: str) -> list[float]:
+        """Generate embedding for a single text.
+
+        Args:
+            text: Text to embed.
+
+        Returns:
+            Embedding vector as list of floats.
+
+        Raises:
+            ProviderError: If embedding generation fails.
+        """
+        ...
+
+    async def embed_texts(
+        self,
+        texts: list[str],
+        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
+    ) -> list[list[float]]:
+        """Generate embeddings for multiple texts.
+
+        Args:
+            texts: List of texts to embed.
+            progress_callback: Optional callback(processed, total) for progress.
+
+        Returns:
+            List of embedding vectors, one per input text.
+
+        Raises:
+            ProviderError: If embedding generation fails.
+        """
+        ...
+
+    def get_dimensions(self) -> int:
+        """Get the embedding vector dimensions for the current model.
+
+        Returns:
+            Number of dimensions in the embedding vector.
+        """
+        ...
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        ...
+
+
+@runtime_checkable
+class SummarizationProvider(Protocol):
+    """Protocol for summarization/LLM providers.
+
+    All summarization providers must implement this interface to be usable
+    by the Agent Brain code summarization system.
+    """
+
+    async def summarize(self, text: str) -> str:
+        """Generate a summary of the given text.
+
+        Args:
+            text: Text to summarize (typically source code).
+
+        Returns:
+            Natural language summary of the text.
+
+        Raises:
+            ProviderError: If summarization fails.
+        """
+        ...
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on a prompt (generic LLM call).
+
+        Args:
+            prompt: The prompt to send to the LLM.
+
+        Returns:
+            Generated text response.
+
+        Raises:
+            ProviderError: If generation fails.
+        """
+        ...
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        ...
+
+
+class BaseEmbeddingProvider(ABC):
+    """Base class for embedding providers with common functionality."""
+
+    def __init__(self, model: str, batch_size: int = 100) -> None:
+        self._model = model
+        self._batch_size = batch_size
+        logger.info(
+            f"Initialized {self.provider_name} embedding provider with model {model}"
+        )
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        return self._model
+
+    async def embed_texts(
+        self,
+        texts: list[str],
+        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
+    ) -> list[list[float]]:
+        """Default batch implementation using _embed_batch."""
+        if not texts:
+            return []
+
+        all_embeddings: list[list[float]] = []
+
+        for i in range(0, len(texts), self._batch_size):
+            batch = texts[i : i + self._batch_size]
+            batch_embeddings = await self._embed_batch(batch)
+            all_embeddings.extend(batch_embeddings)
+
+            if progress_callback:
+                await progress_callback(
+                    min(i + self._batch_size, len(texts)),
+                    len(texts),
+                )
+
+            logger.debug(
+                f"Generated embeddings for batch {i // self._batch_size + 1} "
+                f"({len(batch)} texts)"
+            )
+
+        return all_embeddings
+
+    @abstractmethod
+    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Provider-specific batch embedding implementation."""
+        ...
+
+    @abstractmethod
+    async def embed_text(self, text: str) -> list[float]:
+        """Provider-specific single text embedding."""
+        ...
+
+    @abstractmethod
+    def get_dimensions(self) -> int:
+        """Provider-specific dimension lookup."""
+        ...
+
+    @property
+    @abstractmethod
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
+
+
+class BaseSummarizationProvider(ABC):
+    """Base class for summarization providers with common functionality."""
+
+    DEFAULT_PROMPT_TEMPLATE = (
+        "You are an expert software engineer analyzing source code. "
+        "Provide a concise 1-2 sentence summary of what this code does. "
+        "Focus on the functionality, purpose, and behavior. "
+        "Be specific about inputs, outputs, and side effects. "
+        "Ignore implementation details and focus on what the code accomplishes.\n\n"
+        "Code to summarize:\n{code}\n\n"
+        "Summary:"
+    )
+
+    def __init__(
+        self,
+        model: str,
+        max_tokens: int = 300,
+        temperature: float = 0.1,
+        prompt_template: Optional[str] = None,
+    ) -> None:
+        self._model = model
+        self._max_tokens = max_tokens
+        self._temperature = temperature
+        self._prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE
+        logger.info(
+            f"Initialized {self.provider_name} summarization provider "
+            f"with model {model}"
+        )
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        return self._model
+
+    async def summarize(self, text: str) -> str:
+        """Generate summary using the prompt template."""
+        prompt = self._prompt_template.format(code=text)
+        return await self.generate(prompt)
+
+    @abstractmethod
+    async def generate(self, prompt: str) -> str:
+        """Provider-specific text generation."""
+        ...
+
+    @property
+    @abstractmethod
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
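
`base.py` pairs runtime-checkable `Protocol`s with ABC base classes: subclasses implement only `_embed_batch`, `embed_text`, `get_dimensions`, and `provider_name`, and inherit the batching and progress-callback loop from `embed_texts`. A sketch exercising that loop with a hypothetical `FakeEmbeddingProvider` (not part of the package) that returns deterministic hash-derived vectors:

```python
import asyncio
import hashlib

from agent_brain_server.providers.base import (
    BaseEmbeddingProvider,
    EmbeddingProvider,
)


class FakeEmbeddingProvider(BaseEmbeddingProvider):
    """Hypothetical provider: deterministic vectors, no network calls."""

    DIMS = 8

    @property
    def provider_name(self) -> str:
        return "Fake"

    def get_dimensions(self) -> int:
        return self.DIMS

    async def embed_text(self, text: str) -> list[float]:
        digest = hashlib.sha256(text.encode()).digest()
        return [b / 255.0 for b in digest[: self.DIMS]]

    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        return [await self.embed_text(t) for t in texts]


async def main() -> None:
    provider = FakeEmbeddingProvider(model="fake-8d", batch_size=2)
    # The ABC subclass also satisfies the runtime-checkable protocol.
    assert isinstance(provider, EmbeddingProvider)

    async def on_progress(done: int, total: int) -> None:
        print(f"embedded {done}/{total}")

    # Inherited embed_texts() splits into batches of 2: prints 2/3, then 3/3.
    vectors = await provider.embed_texts(
        ["alpha", "beta", "gamma"], progress_callback=on_progress
    )
    assert all(len(v) == provider.get_dimensions() for v in vectors)


asyncio.run(main())
```
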
--- /dev/null
+++ b/agent_brain_server/providers/embedding/__init__.py
@@ -0,0 +1,23 @@
+"""Embedding providers for Agent Brain.
+
+This module provides embedding implementations for:
+- OpenAI (text-embedding-3-large, text-embedding-3-small, text-embedding-ada-002)
+- Ollama (nomic-embed-text, mxbai-embed-large, etc.)
+- Cohere (embed-english-v3, embed-multilingual-v3, etc.)
+"""
+
+from agent_brain_server.providers.embedding.cohere import CohereEmbeddingProvider
+from agent_brain_server.providers.embedding.ollama import OllamaEmbeddingProvider
+from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider
+from agent_brain_server.providers.factory import ProviderRegistry
+
+# Register embedding providers
+ProviderRegistry.register_embedding_provider("openai", OpenAIEmbeddingProvider)
+ProviderRegistry.register_embedding_provider("ollama", OllamaEmbeddingProvider)
+ProviderRegistry.register_embedding_provider("cohere", CohereEmbeddingProvider)
+
+__all__ = [
+    "OpenAIEmbeddingProvider",
+    "OllamaEmbeddingProvider",
+    "CohereEmbeddingProvider",
+]
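
Registration is keyed by provider name, so the same `ProviderRegistry.register_embedding_provider` call used above for the built-ins also serves out-of-tree providers; how the registry later resolves a name into an instance lives in `factory.py` (+157), which is not shown in this section. A sketch with a hypothetical third-party class:

```python
# "my-provider" and MyEmbeddingProvider are placeholders, not part of
# the package; any BaseEmbeddingProvider subclass would do.
from agent_brain_server.providers.factory import ProviderRegistry

from my_package.embedding import MyEmbeddingProvider  # hypothetical import

ProviderRegistry.register_embedding_provider("my-provider", MyEmbeddingProvider)
```
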
--- /dev/null
+++ b/agent_brain_server/providers/embedding/cohere.py
@@ -0,0 +1,163 @@
+"""Cohere embedding provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+import cohere
+
+from agent_brain_server.providers.base import BaseEmbeddingProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import EmbeddingConfig
+
+logger = logging.getLogger(__name__)
+
+# Model dimension mappings for Cohere embedding models
+COHERE_MODEL_DIMENSIONS: dict[str, int] = {
+    "embed-english-v3.0": 1024,
+    "embed-english-light-v3.0": 384,
+    "embed-multilingual-v3.0": 1024,
+    "embed-multilingual-light-v3.0": 384,
+    "embed-english-v2.0": 4096,
+    "embed-english-light-v2.0": 1024,
+    "embed-multilingual-v2.0": 768,
+}
+
+DEFAULT_COHERE_DIMENSIONS = 1024
+
+
+class CohereEmbeddingProvider(BaseEmbeddingProvider):
+    """Cohere embedding provider using Cohere's embedding models.
+
+    Supports:
+    - embed-english-v3.0 (1024 dimensions, best for English)
+    - embed-english-light-v3.0 (384 dimensions, faster)
+    - embed-multilingual-v3.0 (1024 dimensions, 100+ languages)
+    - embed-multilingual-light-v3.0 (384 dimensions, faster multilingual)
+
+    Cohere embeddings support different input types for optimal performance:
+    - search_document: For indexing documents to be searched
+    - search_query: For search queries
+    - classification: For classification tasks
+    - clustering: For clustering tasks
+    """
+
+    def __init__(self, config: "EmbeddingConfig") -> None:
+        """Initialize Cohere embedding provider.
+
+        Args:
+            config: Embedding configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        batch_size = config.params.get("batch_size", 96)  # Cohere limit
+        super().__init__(model=config.model, batch_size=batch_size)
+
+        self._client = cohere.AsyncClientV2(api_key=api_key)
+        self._input_type = config.params.get("input_type", "search_document")
+        self._truncate = config.params.get("truncate", "END")
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Cohere"
+
+    def get_dimensions(self) -> int:
+        """Get embedding dimensions for current model.
+
+        Returns:
+            Number of dimensions in embedding vector
+        """
+        return COHERE_MODEL_DIMENSIONS.get(self._model, DEFAULT_COHERE_DIMENSIONS)
+
+    async def embed_text(self, text: str) -> list[float]:
+        """Generate embedding for single text.
+
+        Args:
+            text: Text to embed
+
+        Returns:
+            Embedding vector as list of floats
+
+        Raises:
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embed(
+                texts=[text],
+                model=self._model,
+                input_type=self._input_type,
+                truncate=self._truncate,
+            )
+            embeddings = response.embeddings.float_
+            if embeddings is None:
+                raise ProviderError(
+                    "No embeddings returned from Cohere",
+                    self.provider_name,
+                )
+            return list(embeddings[0])
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate embedding: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Generate embeddings for a batch of texts.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embed(
+                texts=texts,
+                model=self._model,
+                input_type=self._input_type,
+                truncate=self._truncate,
+            )
+            embeddings = response.embeddings.float_
+            if embeddings is None:
+                raise ProviderError(
+                    "No embeddings returned from Cohere",
+                    self.provider_name,
+                )
+            return [list(emb) for emb in embeddings]
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate batch embeddings: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    def set_input_type(self, input_type: str) -> None:
+        """Set the input type for embeddings.
+
+        Args:
+            input_type: One of 'search_document', 'search_query',
+                'classification', or 'clustering'
+        """
+        valid_types = [
+            "search_document",
+            "search_query",
+            "classification",
+            "clustering",
+        ]
+        if input_type not in valid_types:
+            raise ValueError(f"Invalid input_type. Must be one of: {valid_types}")
+        self._input_type = input_type
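
Cohere's v3 embeddings are asymmetric: documents and queries should be embedded under different input types, which is why the provider defaults to `search_document` and exposes `set_input_type`. A sketch of the index-then-query flow, assuming `provider` is an already-constructed `CohereEmbeddingProvider`:

```python
from agent_brain_server.providers.embedding.cohere import CohereEmbeddingProvider


async def index_then_query(provider: CohereEmbeddingProvider) -> None:
    # Default input_type is "search_document", correct for indexing.
    doc_vectors = await provider.embed_texts(["def add(a, b): return a + b"])

    # Switch modes before embedding the user's search string.
    provider.set_input_type("search_query")
    query_vector = await provider.embed_text("function that adds two numbers")

    # Restore document mode so any later indexing is not silently degraded.
    provider.set_input_type("search_document")
    print(len(doc_vectors), len(query_vector))
```
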
--- /dev/null
+++ b/agent_brain_server/providers/embedding/ollama.py
@@ -0,0 +1,150 @@
+"""Ollama embedding provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseEmbeddingProvider
+from agent_brain_server.providers.exceptions import (
+    OllamaConnectionError,
+    ProviderError,
+)
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import EmbeddingConfig
+
+logger = logging.getLogger(__name__)
+
+# Model dimension mappings for common Ollama embedding models
+OLLAMA_MODEL_DIMENSIONS: dict[str, int] = {
+    "nomic-embed-text": 768,
+    "mxbai-embed-large": 1024,
+    "all-minilm": 384,
+    "snowflake-arctic-embed": 1024,
+    "bge-m3": 1024,
+    "bge-large": 1024,
+}
+
+DEFAULT_OLLAMA_DIMENSIONS = 768
+
+
+class OllamaEmbeddingProvider(BaseEmbeddingProvider):
+    """Ollama embedding provider using local models.
+
+    Uses OpenAI-compatible API endpoint provided by Ollama.
+
+    Supports:
+    - nomic-embed-text (768 dimensions, general purpose)
+    - mxbai-embed-large (1024 dimensions, multilingual)
+    - all-minilm (384 dimensions, lightweight)
+    - snowflake-arctic-embed (1024 dimensions, high quality)
+    - And any other embedding model available in Ollama
+    """
+
+    def __init__(self, config: "EmbeddingConfig") -> None:
+        """Initialize Ollama embedding provider.
+
+        Args:
+            config: Embedding configuration
+
+        Note:
+            Ollama does not require an API key as it runs locally.
+        """
+        batch_size = config.params.get("batch_size", 100)
+        super().__init__(model=config.model, batch_size=batch_size)
+
+        # Ollama uses OpenAI-compatible API
+        base_url = config.get_base_url() or "http://localhost:11434/v1"
+        self._base_url = base_url
+        self._client = AsyncOpenAI(
+            api_key="ollama",  # Ollama doesn't need real key
+            base_url=base_url,
+        )
+
+        # Optional parameters
+        self._num_ctx = config.params.get("num_ctx", 2048)
+        self._num_threads = config.params.get("num_threads")
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Ollama"
+
+    def get_dimensions(self) -> int:
+        """Get embedding dimensions for current model.
+
+        Returns:
+            Number of dimensions in embedding vector
+        """
+        return OLLAMA_MODEL_DIMENSIONS.get(self._model, DEFAULT_OLLAMA_DIMENSIONS)
+
+    async def embed_text(self, text: str) -> list[float]:
+        """Generate embedding for single text.
+
+        Args:
+            text: Text to embed
+
+        Returns:
+            Embedding vector as list of floats
+
+        Raises:
+            OllamaConnectionError: If Ollama is not running
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embeddings.create(
+                model=self._model,
+                input=text,
+            )
+            return response.data[0].embedding
+        except Exception as e:
+            if "connection" in str(e).lower() or "refused" in str(e).lower():
+                raise OllamaConnectionError(self._base_url, cause=e) from e
+            raise ProviderError(
+                f"Failed to generate embedding: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Generate embeddings for a batch of texts.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            OllamaConnectionError: If Ollama is not running
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embeddings.create(
+                model=self._model,
+                input=texts,
+            )
+            return [item.embedding for item in response.data]
+        except Exception as e:
+            if "connection" in str(e).lower() or "refused" in str(e).lower():
+                raise OllamaConnectionError(self._base_url, cause=e) from e
+            raise ProviderError(
+                f"Failed to generate batch embeddings: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    async def health_check(self) -> bool:
+        """Check if Ollama is running and accessible.
+
+        Returns:
+            True if Ollama is healthy, False otherwise
+        """
+        try:
+            # Try to list models to verify connection
+            await self._client.models.list()
+            return True
+        except Exception as e:
+            logger.warning(f"Ollama health check failed: {e}")
+            return False
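
Because Ollama is a local daemon that may simply not be running, the provider adds `health_check` beyond the protocol's required surface. A sketch gating indexing on it, assuming `provider` is an already-constructed `OllamaEmbeddingProvider` (`ollama serve` and `ollama pull` are the standard Ollama CLI commands):

```python
from agent_brain_server.providers.embedding.ollama import OllamaEmbeddingProvider


async def embed_if_available(
    provider: OllamaEmbeddingProvider, texts: list[str]
) -> list[list[float]]:
    # health_check() lists models against the OpenAI-compatible endpoint,
    # so this fails fast instead of raising OllamaConnectionError mid-index.
    if not await provider.health_check():
        raise RuntimeError(
            "Ollama is unreachable; start it with `ollama serve` and pull "
            "an embedding model, e.g. `ollama pull nomic-embed-text`."
        )
    return await provider.embed_texts(texts)
```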