agent-brain-rag 1.2.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/METADATA +54 -16
- agent_brain_rag-2.0.0.dist-info/RECORD +50 -0
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +30 -2
- agent_brain_server/api/routers/health.py +1 -0
- agent_brain_server/config/provider_config.py +308 -0
- agent_brain_server/config/settings.py +12 -1
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/models/__init__.py +9 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +15 -3
- agent_brain_server/models/query.py +14 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +2 -0
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/WHEEL +0 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""Cohere embedding provider implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import cohere
|
|
7
|
+
|
|
8
|
+
from agent_brain_server.providers.base import BaseEmbeddingProvider
|
|
9
|
+
from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from agent_brain_server.config.provider_config import EmbeddingConfig
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)

# Model dimension mappings for Cohere embedding models.
# Values are the default output vector length for each model.
COHERE_MODEL_DIMENSIONS: dict[str, int] = {
    # Embed v4 (1536 is the default output size; smaller sizes are opt-in).
    "embed-v4.0": 1536,
    # Embed v3 family
    "embed-english-v3.0": 1024,
    "embed-english-light-v3.0": 384,
    "embed-multilingual-v3.0": 1024,
    "embed-multilingual-light-v3.0": 384,
    # Legacy v2 family
    "embed-english-v2.0": 4096,
    "embed-english-light-v2.0": 1024,
    "embed-multilingual-v2.0": 768,
}

# Fallback used for models not listed above.
DEFAULT_COHERE_DIMENSIONS = 1024
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CohereEmbeddingProvider(BaseEmbeddingProvider):
    """Cohere embedding provider using Cohere's embedding models.

    Supports:
    - embed-english-v3.0 (1024 dimensions, best for English)
    - embed-english-light-v3.0 (384 dimensions, faster)
    - embed-multilingual-v3.0 (1024 dimensions, 100+ languages)
    - embed-multilingual-light-v3.0 (384 dimensions, faster multilingual)

    Cohere embeddings support different input types for optimal performance:
    - search_document: For indexing documents to be searched
    - search_query: For search queries
    - classification: For classification tasks
    - clustering: For clustering tasks

    Recognised ``config.params`` keys: ``batch_size`` (default 96),
    ``input_type`` (default ``"search_document"``) and ``truncate``
    (default ``"END"``).
    """

    def __init__(self, config: "EmbeddingConfig") -> None:
        """Initialize Cohere embedding provider.

        Args:
            config: Embedding configuration

        Raises:
            AuthenticationError: If API key is not available
        """
        api_key = config.get_api_key()
        if not api_key:
            # provider_name is a constant property, so it is safe to read
            # before the base class has been initialised.
            raise AuthenticationError(
                f"Missing API key. Set {config.api_key_env} environment variable.",
                self.provider_name,
            )

        batch_size = config.params.get("batch_size", 96)  # Cohere limit
        super().__init__(model=config.model, batch_size=batch_size)

        self._client = cohere.AsyncClientV2(api_key=api_key)
        self._input_type = config.params.get("input_type", "search_document")
        self._truncate = config.params.get("truncate", "END")

    @property
    def provider_name(self) -> str:
        """Human-readable provider name."""
        return "Cohere"

    def get_dimensions(self) -> int:
        """Get embedding dimensions for current model.

        Returns:
            Number of dimensions in embedding vector
        """
        return COHERE_MODEL_DIMENSIONS.get(self._model, DEFAULT_COHERE_DIMENSIONS)

    async def _call_embed_api(self, texts: list[str]) -> list[list[float]]:
        """Send one embed request and return the float vectors.

        Float embeddings are requested explicitly because the response is
        read from ``embeddings.float_``.

        Raises:
            ProviderError: If the response carries no float embeddings.
        """
        response = await self._client.embed(
            texts=texts,
            model=self._model,
            input_type=self._input_type,
            embedding_types=["float"],
            truncate=self._truncate,
        )
        vectors = response.embeddings.float_
        if vectors is None:
            raise ProviderError(
                "No embeddings returned from Cohere",
                self.provider_name,
            )
        return [list(vector) for vector in vectors]

    async def embed_text(self, text: str) -> list[float]:
        """Generate embedding for single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector as list of floats

        Raises:
            ProviderError: If embedding generation fails
        """
        try:
            return (await self._call_embed_api([text]))[0]
        except ProviderError:
            # Already in the provider hierarchy; don't wrap it a second time.
            raise
        except Exception as e:
            raise ProviderError(
                f"Failed to generate embedding: {e}",
                self.provider_name,
                cause=e,
            ) from e

    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a batch of texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors (empty for an empty batch, without
            calling the API)

        Raises:
            ProviderError: If embedding generation fails
        """
        if not texts:
            return []
        try:
            return await self._call_embed_api(texts)
        except ProviderError:
            # Already in the provider hierarchy; don't wrap it a second time.
            raise
        except Exception as e:
            raise ProviderError(
                f"Failed to generate batch embeddings: {e}",
                self.provider_name,
                cause=e,
            ) from e

    def set_input_type(self, input_type: str) -> None:
        """Set the input type for embeddings.

        NOTE(review): this mutates the instance; if the same instance is
        shared (e.g. via a provider cache), the change affects every user.

        Args:
            input_type: One of 'search_document', 'search_query',
                'classification', or 'clustering'

        Raises:
            ValueError: If input_type is not one of the values above.
        """
        valid_types = [
            "search_document",
            "search_query",
            "classification",
            "clustering",
        ]
        if input_type not in valid_types:
            raise ValueError(f"Invalid input_type. Must be one of: {valid_types}")
        self._input_type = input_type
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Ollama embedding provider implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from openai import AsyncOpenAI
|
|
7
|
+
|
|
8
|
+
from agent_brain_server.providers.base import BaseEmbeddingProvider
|
|
9
|
+
from agent_brain_server.providers.exceptions import (
|
|
10
|
+
OllamaConnectionError,
|
|
11
|
+
ProviderError,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from agent_brain_server.config.provider_config import EmbeddingConfig
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)

# Vector size assumed for models missing from the table below
# (the same size as nomic-embed-text).
DEFAULT_OLLAMA_DIMENSIONS = 768

# Known output sizes for common embedding models in the Ollama library,
# keyed by the bare model name.
OLLAMA_MODEL_DIMENSIONS: dict[str, int] = {
    "nomic-embed-text": 768,
    "mxbai-embed-large": 1024,
    "all-minilm": 384,
    "snowflake-arctic-embed": 1024,
    "bge-m3": 1024,
    "bge-large": 1024,
}
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OllamaEmbeddingProvider(BaseEmbeddingProvider):
    """Ollama embedding provider using local models.

    Uses the OpenAI-compatible API endpoint provided by Ollama.

    Supports:
    - nomic-embed-text (768 dimensions, general purpose)
    - mxbai-embed-large (1024 dimensions, multilingual)
    - all-minilm (384 dimensions, lightweight)
    - snowflake-arctic-embed (1024 dimensions, high quality)
    - And any other embedding model available in Ollama

    A model name ending in ``:latest`` is treated like the bare name when
    looking up known dimensions.
    """

    def __init__(self, config: "EmbeddingConfig") -> None:
        """Initialize Ollama embedding provider.

        Args:
            config: Embedding configuration. Recognised ``params`` keys are
                ``batch_size`` (default 100), ``num_ctx`` (default 2048) and
                ``num_threads`` (default None).

        Note:
            Ollama does not require an API key as it runs locally.
        """
        batch_size = config.params.get("batch_size", 100)
        super().__init__(model=config.model, batch_size=batch_size)

        # Ollama uses OpenAI-compatible API
        base_url = config.get_base_url() or "http://localhost:11434/v1"
        self._base_url = base_url
        self._client = AsyncOpenAI(
            api_key="ollama",  # Placeholder; Ollama doesn't need a real key
            base_url=base_url,
        )

        # Optional parameters.
        # NOTE(review): neither value is sent with the requests made by this
        # class; confirm whether other code reads them before relying on them.
        self._num_ctx = config.params.get("num_ctx", 2048)
        self._num_threads = config.params.get("num_threads")

    @property
    def provider_name(self) -> str:
        """Human-readable provider name."""
        return "Ollama"

    def get_dimensions(self) -> int:
        """Get embedding dimensions for current model.

        ``name:latest`` resolves to the same entry as ``name``. Unknown
        models fall back to ``DEFAULT_OLLAMA_DIMENSIONS``.

        Returns:
            Number of dimensions in embedding vector
        """
        model = self._model
        if isinstance(model, str):
            # The ":latest" tag is the same model as the untagged name.
            model = model.removesuffix(":latest")
        return OLLAMA_MODEL_DIMENSIONS.get(model, DEFAULT_OLLAMA_DIMENSIONS)

    @staticmethod
    def _looks_like_connection_error(exc: Exception) -> bool:
        """Return True when ``exc`` means the Ollama server was unreachable."""
        from openai import APIConnectionError, APIStatusError, APITimeoutError

        if isinstance(exc, APIStatusError):
            # The server answered with an HTTP error, so it is reachable.
            return False
        if isinstance(exc, APIConnectionError):
            # Timeouts subclass APIConnectionError but mean "slow", not "down".
            return not isinstance(exc, APITimeoutError)
        # Fallback for exceptions that don't come from the openai client.
        message = str(exc).lower()
        return "connection" in message or "refused" in message

    def _to_provider_error(self, exc: Exception, action: str) -> ProviderError:
        """Map a client exception onto the provider exception hierarchy."""
        if self._looks_like_connection_error(exc):
            return OllamaConnectionError(self._base_url, cause=exc)
        return ProviderError(
            f"{action}: {exc}",
            self.provider_name,
            cause=exc,
        )

    async def embed_text(self, text: str) -> list[float]:
        """Generate embedding for single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector as list of floats

        Raises:
            OllamaConnectionError: If Ollama is not running
            ProviderError: If embedding generation fails
        """
        try:
            response = await self._client.embeddings.create(
                model=self._model,
                input=text,
            )
            return response.data[0].embedding
        except Exception as e:
            raise self._to_provider_error(e, "Failed to generate embedding") from e

    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a batch of texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors in input order (empty for an empty
            batch, without calling the server)

        Raises:
            OllamaConnectionError: If Ollama is not running
            ProviderError: If embedding generation fails
        """
        if not texts:
            return []
        try:
            response = await self._client.embeddings.create(
                model=self._model,
                input=texts,
            )
            # Each item records the index of its input; sort on it rather
            # than relying on response order.
            ordered = sorted(response.data, key=lambda item: item.index)
            return [item.embedding for item in ordered]
        except Exception as e:
            raise self._to_provider_error(
                e, "Failed to generate batch embeddings"
            ) from e

    async def health_check(self) -> bool:
        """Check if Ollama is running and accessible.

        Returns:
            True if Ollama is healthy, False otherwise
        """
        try:
            # Try to list models to verify connection
            await self._client.models.list()
            return True
        except Exception as e:
            logger.warning("Ollama health check failed: %s", e)
            return False
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""OpenAI embedding provider implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from openai import AsyncOpenAI
|
|
7
|
+
|
|
8
|
+
from agent_brain_server.providers.base import BaseEmbeddingProvider
|
|
9
|
+
from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from agent_brain_server.config.provider_config import EmbeddingConfig
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)

# Native output vector length of each supported OpenAI embedding model.
OPENAI_MODEL_DIMENSIONS: dict[str, int] = {
    # Third-generation models
    "text-embedding-3-large": 3072,
    "text-embedding-3-small": 1536,
    # Legacy model
    "text-embedding-ada-002": 1536,
}
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OpenAIEmbeddingProvider(BaseEmbeddingProvider):
    """OpenAI embedding provider using text-embedding models.

    Supports:
    - text-embedding-3-large (3072 dimensions, highest quality)
    - text-embedding-3-small (1536 dimensions, faster)
    - text-embedding-ada-002 (1536 dimensions, legacy)

    Recognised ``config.params`` keys: ``batch_size`` (default 100) and
    ``dimensions``. When ``dimensions`` is set, it is sent with every
    request and also reported by ``get_dimensions()``, so the two agree.
    """

    def __init__(self, config: "EmbeddingConfig") -> None:
        """Initialize OpenAI embedding provider.

        Args:
            config: Embedding configuration

        Raises:
            AuthenticationError: If API key is not available
        """
        api_key = config.get_api_key()
        if not api_key:
            # provider_name is a constant property, so it is safe to read
            # before the base class has been initialised.
            raise AuthenticationError(
                f"Missing API key. Set {config.api_key_env} environment variable.",
                self.provider_name,
            )

        batch_size = config.params.get("batch_size", 100)
        super().__init__(model=config.model, batch_size=batch_size)

        self._client = AsyncOpenAI(api_key=api_key)
        # Requested output size; must be forwarded to the API (see
        # _create_embeddings) or the vectors won't match get_dimensions().
        self._dimensions_override = config.params.get("dimensions")

    @property
    def provider_name(self) -> str:
        """Human-readable provider name."""
        return "OpenAI"

    def get_dimensions(self) -> int:
        """Get embedding dimensions for current model.

        Returns:
            The configured ``dimensions`` override when set, otherwise the
            model's native size (3072 for models not in the table).
        """
        if self._dimensions_override:
            return int(self._dimensions_override)
        return OPENAI_MODEL_DIMENSIONS.get(self._model, 3072)

    async def _create_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Call the embeddings endpoint and return vectors in input order."""
        if self._dimensions_override:
            response = await self._client.embeddings.create(
                model=self._model,
                input=texts,
                dimensions=int(self._dimensions_override),
            )
        else:
            response = await self._client.embeddings.create(
                model=self._model,
                input=texts,
            )
        # Each item records the index of its input; sort on it rather than
        # relying on response order.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]

    async def embed_text(self, text: str) -> list[float]:
        """Generate embedding for single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector as list of floats

        Raises:
            ProviderError: If embedding generation fails
        """
        try:
            return (await self._create_embeddings([text]))[0]
        except Exception as e:
            raise ProviderError(
                f"Failed to generate embedding: {e}",
                self.provider_name,
                cause=e,
            ) from e

    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a batch of texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors in input order (empty for an empty
            batch, without calling the API)

        Raises:
            ProviderError: If embedding generation fails
        """
        if not texts:
            return []
        try:
            return await self._create_embeddings(texts)
        except Exception as e:
            raise ProviderError(
                f"Failed to generate batch embeddings: {e}",
                self.provider_name,
                cause=e,
            ) from e
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Exception hierarchy for provider errors."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ProviderError(Exception):
    """Root of the provider exception hierarchy.

    The rendered message is the provider label in square brackets followed
    by the message, e.g. ``"[OpenAI] Failed to generate embedding: ..."``.

    Attributes:
        provider: Label of the provider that raised the error.
        cause: Underlying exception, or None.
    """

    def __init__(
        self,
        message: str,
        provider: str,
        cause: Optional[Exception] = None,
    ) -> None:
        rendered = f"[{provider}] {message}"
        super().__init__(rendered)
        self.provider = provider
        self.cause = cause
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ConfigurationError(ProviderError):
    """Raised when provider configuration is invalid.

    Takes the same arguments as ProviderError: ``(message, provider,
    cause=None)``.
    """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AuthenticationError(ProviderError):
    """Raised when API key is missing or invalid.

    Takes the same arguments as ProviderError: ``(message, provider,
    cause=None)``.
    """
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ProviderNotFoundError(ProviderError):
    """Raised when requested provider type is not registered.

    Takes the same arguments as ProviderError: ``(message, provider,
    cause=None)``.
    """
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ProviderMismatchError(ProviderError):
    """Raised when current provider doesn't match indexed data.

    Attributes:
        current_provider: Provider from the active configuration (also
            stored as ``provider``).
        current_model: Model from the active configuration.
        indexed_provider: Provider recorded for the existing index.
        indexed_model: Model recorded for the existing index.
    """

    def __init__(
        self,
        current_provider: str,
        current_model: str,
        indexed_provider: str,
        indexed_model: str,
    ) -> None:
        # Only the interpolated fragments need to be f-strings.
        message = (
            "Provider mismatch: index was created with "
            f"{indexed_provider}/{indexed_model}, "
            f"but current config uses {current_provider}/{current_model}. "
            "Re-index with --force to update."
        )
        super().__init__(message, current_provider)
        self.current_provider = current_provider
        self.current_model = current_model
        self.indexed_provider = indexed_provider
        self.indexed_model = indexed_model
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class RateLimitError(ProviderError):
    """Raised when provider rate limit is hit.

    Attributes:
        retry_after: Suggested wait in seconds, or None when unknown.
    """

    def __init__(
        self,
        provider: str,
        retry_after: Optional[int] = None,
        cause: Optional[Exception] = None,
    ) -> None:
        """Create the error.

        Args:
            provider: Provider label.
            retry_after: Suggested wait in seconds. ``0`` is a valid value
                and is included in the message.
            cause: Optional underlying exception.
        """
        self.retry_after = retry_after
        message = "Rate limit exceeded"
        if retry_after is not None:
            message += f", retry after {retry_after}s"
        super().__init__(message, provider, cause)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ModelNotFoundError(ProviderError):
    """Raised when specified model is not available.

    Attributes:
        model: The requested model name.
        available_models: Copy of the known models (empty when unknown).
    """

    # How many available model names to include in the message.
    _MAX_LISTED_MODELS = 5

    def __init__(
        self, provider: str, model: str, available_models: Optional[list[str]] = None
    ) -> None:
        self.model = model
        self.available_models = list(available_models) if available_models else []
        if self.available_models:
            limit = self._MAX_LISTED_MODELS
            listed = ", ".join(self.available_models[:limit])
            hidden = len(self.available_models) - limit
            if hidden > 0:
                # Make the truncation visible instead of implying the list
                # is complete.
                listed += f" (and {hidden} more)"
            message = f"Model '{model}' not found. Available: {listed}"
        else:
            message = f"Model '{model}' not found"
        super().__init__(message, provider)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class OllamaConnectionError(ProviderError):
    """Raised when the Ollama server is not running or cannot be reached.

    The provider label is ``"ollama"``.

    Attributes:
        base_url: The Ollama URL that was tried.
    """

    def __init__(self, base_url: str, cause: Optional[Exception] = None) -> None:
        hint = "Ensure Ollama is running with 'ollama serve' command."
        super().__init__(f"Cannot connect to Ollama at {base_url}. " + hint, "ollama", cause)
        self.base_url = base_url
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""Provider factory and registry for dynamic provider instantiation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
5
|
+
|
|
6
|
+
from agent_brain_server.providers.exceptions import ProviderNotFoundError
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from agent_brain_server.config.provider_config import (
|
|
10
|
+
EmbeddingConfig,
|
|
11
|
+
SummarizationConfig,
|
|
12
|
+
)
|
|
13
|
+
from agent_brain_server.providers.base import (
|
|
14
|
+
EmbeddingProvider,
|
|
15
|
+
SummarizationProvider,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)


class ProviderRegistry:
    """Registry for provider implementations.

    Provider classes are registered under a string key (e.g. ``"openai"``)
    and instantiated lazily on first request. Instances are cached, so
    repeated requests for the same kind, provider type and model return the
    same object.

    NOTE(review): the cache key is built only from the provider type and
    model. Two configs that differ in other fields (params, base URL, key)
    share whichever instance was created first; call ``clear_cache()`` if
    configs change at runtime.
    """

    _embedding_providers: dict[str, type[Any]] = {}
    _summarization_providers: dict[str, type[Any]] = {}
    _instances: dict[str, Any] = {}

    @classmethod
    def register_embedding_provider(
        cls,
        provider_type: str,
        provider_class: type["EmbeddingProvider"],
    ) -> None:
        """Register an embedding provider class.

        Args:
            provider_type: Provider identifier (e.g., 'openai', 'ollama')
            provider_class: Provider class implementing EmbeddingProvider protocol
        """
        cls._embedding_providers[provider_type] = provider_class
        logger.debug("Registered embedding provider: %s", provider_type)

    @classmethod
    def register_summarization_provider(
        cls,
        provider_type: str,
        provider_class: type["SummarizationProvider"],
    ) -> None:
        """Register a summarization provider class.

        Args:
            provider_type: Provider identifier (e.g., 'anthropic', 'openai')
            provider_class: Provider class implementing SummarizationProvider protocol
        """
        cls._summarization_providers[provider_type] = provider_class
        logger.debug("Registered summarization provider: %s", provider_type)

    @staticmethod
    def _provider_key(provider: Any) -> str:
        """Turn a provider identifier (enum member or plain value) into a key."""
        return provider.value if hasattr(provider, "value") else str(provider)

    @classmethod
    def _get_or_create(
        cls,
        config: Any,
        registry: dict[str, type[Any]],
        cache_prefix: str,
        kind: str,
    ) -> Any:
        """Return the cached provider for ``config``, creating it on first use.

        Args:
            config: Provider config exposing ``provider`` and ``model``.
            registry: Mapping of provider type to provider class.
            cache_prefix: Namespace for the cache key (``embed``/``summ``).
            kind: Human-readable kind used in messages.

        Raises:
            ProviderNotFoundError: If the provider type is not registered.
        """
        provider_type = cls._provider_key(config.provider)
        cache_key = f"{cache_prefix}:{provider_type}:{config.model}"

        if cache_key not in cls._instances:
            provider_class = registry.get(provider_type)
            if provider_class is None:
                available = ", ".join(registry) or "none registered"
                raise ProviderNotFoundError(
                    f"Unknown {kind} provider: {provider_type}. "
                    f"Available: {available}",
                    provider_type,
                )
            cls._instances[cache_key] = provider_class(config)
            logger.info(
                "Created %s %s provider with model %s",
                provider_type,
                kind,
                config.model,
            )

        return cls._instances[cache_key]

    @classmethod
    def get_embedding_provider(cls, config: "EmbeddingConfig") -> "EmbeddingProvider":
        """Get or create embedding provider instance.

        Args:
            config: Embedding provider configuration

        Returns:
            Configured EmbeddingProvider instance

        Raises:
            ProviderNotFoundError: If provider type is not registered
        """
        instance = cls._get_or_create(
            config, cls._embedding_providers, "embed", "embedding"
        )
        # String form of cast: no runtime import of the provider base module.
        return cast("EmbeddingProvider", instance)

    @classmethod
    def get_summarization_provider(
        cls, config: "SummarizationConfig"
    ) -> "SummarizationProvider":
        """Get or create summarization provider instance.

        Args:
            config: Summarization provider configuration

        Returns:
            Configured SummarizationProvider instance

        Raises:
            ProviderNotFoundError: If provider type is not registered
        """
        instance = cls._get_or_create(
            config, cls._summarization_providers, "summ", "summarization"
        )
        # String form of cast: no runtime import of the provider base module.
        return cast("SummarizationProvider", instance)

    @classmethod
    def clear_cache(cls) -> None:
        """Clear provider instance cache (for testing)."""
        cls._instances.clear()
        logger.debug("Cleared provider instance cache")

    @classmethod
    def get_available_embedding_providers(cls) -> list[str]:
        """Get list of registered embedding provider types."""
        return list(cls._embedding_providers.keys())

    @classmethod
    def get_available_summarization_providers(cls) -> list[str]:
        """Get list of registered summarization provider types."""
        return list(cls._summarization_providers.keys())
|