agent-brain-rag 1.1.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (55)
  1. {agent_brain_rag-1.1.0.dist-info → agent_brain_rag-2.0.0.dist-info}/METADATA +68 -27
  2. agent_brain_rag-2.0.0.dist-info/RECORD +50 -0
  3. agent_brain_rag-2.0.0.dist-info/entry_points.txt +4 -0
  4. {doc_serve_server → agent_brain_server}/__init__.py +1 -1
  5. {doc_serve_server → agent_brain_server}/api/main.py +90 -26
  6. {doc_serve_server → agent_brain_server}/api/routers/health.py +4 -2
  7. {doc_serve_server → agent_brain_server}/api/routers/index.py +1 -1
  8. {doc_serve_server → agent_brain_server}/api/routers/query.py +3 -3
  9. agent_brain_server/config/provider_config.py +308 -0
  10. {doc_serve_server → agent_brain_server}/config/settings.py +12 -1
  11. agent_brain_server/indexing/__init__.py +40 -0
  12. {doc_serve_server → agent_brain_server}/indexing/bm25_index.py +1 -1
  13. {doc_serve_server → agent_brain_server}/indexing/chunking.py +1 -1
  14. agent_brain_server/indexing/embedding.py +225 -0
  15. agent_brain_server/indexing/graph_extractors.py +582 -0
  16. agent_brain_server/indexing/graph_index.py +536 -0
  17. {doc_serve_server → agent_brain_server}/models/__init__.py +9 -0
  18. agent_brain_server/models/graph.py +253 -0
  19. {doc_serve_server → agent_brain_server}/models/health.py +15 -3
  20. {doc_serve_server → agent_brain_server}/models/query.py +14 -1
  21. agent_brain_server/providers/__init__.py +64 -0
  22. agent_brain_server/providers/base.py +251 -0
  23. agent_brain_server/providers/embedding/__init__.py +23 -0
  24. agent_brain_server/providers/embedding/cohere.py +163 -0
  25. agent_brain_server/providers/embedding/ollama.py +150 -0
  26. agent_brain_server/providers/embedding/openai.py +118 -0
  27. agent_brain_server/providers/exceptions.py +95 -0
  28. agent_brain_server/providers/factory.py +157 -0
  29. agent_brain_server/providers/summarization/__init__.py +41 -0
  30. agent_brain_server/providers/summarization/anthropic.py +87 -0
  31. agent_brain_server/providers/summarization/gemini.py +96 -0
  32. agent_brain_server/providers/summarization/grok.py +95 -0
  33. agent_brain_server/providers/summarization/ollama.py +114 -0
  34. agent_brain_server/providers/summarization/openai.py +87 -0
  35. {doc_serve_server → agent_brain_server}/services/indexing_service.py +43 -4
  36. {doc_serve_server → agent_brain_server}/services/query_service.py +212 -4
  37. agent_brain_server/storage/__init__.py +21 -0
  38. agent_brain_server/storage/graph_store.py +519 -0
  39. {doc_serve_server → agent_brain_server}/storage/vector_store.py +36 -1
  40. {doc_serve_server → agent_brain_server}/storage_paths.py +2 -0
  41. agent_brain_rag-1.1.0.dist-info/RECORD +0 -31
  42. agent_brain_rag-1.1.0.dist-info/entry_points.txt +0 -3
  43. doc_serve_server/indexing/__init__.py +0 -19
  44. doc_serve_server/indexing/embedding.py +0 -274
  45. doc_serve_server/storage/__init__.py +0 -5
  46. {agent_brain_rag-1.1.0.dist-info → agent_brain_rag-2.0.0.dist-info}/WHEEL +0 -0
  47. {doc_serve_server → agent_brain_server}/api/__init__.py +0 -0
  48. {doc_serve_server → agent_brain_server}/api/routers/__init__.py +0 -0
  49. {doc_serve_server → agent_brain_server}/config/__init__.py +0 -0
  50. {doc_serve_server → agent_brain_server}/indexing/document_loader.py +0 -0
  51. {doc_serve_server → agent_brain_server}/locking.py +0 -0
  52. {doc_serve_server → agent_brain_server}/models/index.py +0 -0
  53. {doc_serve_server → agent_brain_server}/project_root.py +0 -0
  54. {doc_serve_server → agent_brain_server}/runtime.py +0 -0
  55. {doc_serve_server → agent_brain_server}/services/__init__.py +0 -0
agent_brain_server/providers/embedding/cohere.py
@@ -0,0 +1,163 @@
+ """Cohere embedding provider implementation."""
+
+ import logging
+ from typing import TYPE_CHECKING
+
+ import cohere
+
+ from agent_brain_server.providers.base import BaseEmbeddingProvider
+ from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+ if TYPE_CHECKING:
+     from agent_brain_server.config.provider_config import EmbeddingConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Model dimension mappings for Cohere embedding models
+ COHERE_MODEL_DIMENSIONS: dict[str, int] = {
+     "embed-english-v3.0": 1024,
+     "embed-english-light-v3.0": 384,
+     "embed-multilingual-v3.0": 1024,
+     "embed-multilingual-light-v3.0": 384,
+     "embed-english-v2.0": 4096,
+     "embed-english-light-v2.0": 1024,
+     "embed-multilingual-v2.0": 768,
+ }
+
+ DEFAULT_COHERE_DIMENSIONS = 1024
+
+
+ class CohereEmbeddingProvider(BaseEmbeddingProvider):
+     """Cohere embedding provider using Cohere's embedding models.
+
+     Supports:
+     - embed-english-v3.0 (1024 dimensions, best for English)
+     - embed-english-light-v3.0 (384 dimensions, faster)
+     - embed-multilingual-v3.0 (1024 dimensions, 100+ languages)
+     - embed-multilingual-light-v3.0 (384 dimensions, faster multilingual)
+
+     Cohere embeddings support different input types for optimal performance:
+     - search_document: For indexing documents to be searched
+     - search_query: For search queries
+     - classification: For classification tasks
+     - clustering: For clustering tasks
+     """
+
+     def __init__(self, config: "EmbeddingConfig") -> None:
+         """Initialize Cohere embedding provider.
+
+         Args:
+             config: Embedding configuration
+
+         Raises:
+             AuthenticationError: If API key is not available
+         """
+         api_key = config.get_api_key()
+         if not api_key:
+             raise AuthenticationError(
+                 f"Missing API key. Set {config.api_key_env} environment variable.",
+                 self.provider_name,
+             )
+
+         batch_size = config.params.get("batch_size", 96)  # Cohere limit
+         super().__init__(model=config.model, batch_size=batch_size)
+
+         self._client = cohere.AsyncClientV2(api_key=api_key)
+         self._input_type = config.params.get("input_type", "search_document")
+         self._truncate = config.params.get("truncate", "END")
+
+     @property
+     def provider_name(self) -> str:
+         """Human-readable provider name."""
+         return "Cohere"
+
+     def get_dimensions(self) -> int:
+         """Get embedding dimensions for current model.
+
+         Returns:
+             Number of dimensions in embedding vector
+         """
+         return COHERE_MODEL_DIMENSIONS.get(self._model, DEFAULT_COHERE_DIMENSIONS)
+
+     async def embed_text(self, text: str) -> list[float]:
+         """Generate embedding for single text.
+
+         Args:
+             text: Text to embed
+
+         Returns:
+             Embedding vector as list of floats
+
+         Raises:
+             ProviderError: If embedding generation fails
+         """
+         try:
+             response = await self._client.embed(
+                 texts=[text],
+                 model=self._model,
+                 input_type=self._input_type,
+                 truncate=self._truncate,
+             )
+             embeddings = response.embeddings.float_
+             if embeddings is None:
+                 raise ProviderError(
+                     "No embeddings returned from Cohere",
+                     self.provider_name,
+                 )
+             return list(embeddings[0])
+         except Exception as e:
+             raise ProviderError(
+                 f"Failed to generate embedding: {e}",
+                 self.provider_name,
+                 cause=e,
+             ) from e
+
+     async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for a batch of texts.
+
+         Args:
+             texts: List of texts to embed
+
+         Returns:
+             List of embedding vectors
+
+         Raises:
+             ProviderError: If embedding generation fails
+         """
+         try:
+             response = await self._client.embed(
+                 texts=texts,
+                 model=self._model,
+                 input_type=self._input_type,
+                 truncate=self._truncate,
+             )
+             embeddings = response.embeddings.float_
+             if embeddings is None:
+                 raise ProviderError(
+                     "No embeddings returned from Cohere",
+                     self.provider_name,
+                 )
+             return [list(emb) for emb in embeddings]
+         except Exception as e:
+             raise ProviderError(
+                 f"Failed to generate batch embeddings: {e}",
+                 self.provider_name,
+                 cause=e,
+             ) from e
+
+     def set_input_type(self, input_type: str) -> None:
+         """Set the input type for embeddings.
+
+         Args:
+             input_type: One of 'search_document', 'search_query',
+                 'classification', or 'clustering'
+         """
+         valid_types = [
+             "search_document",
+             "search_query",
+             "classification",
+             "clustering",
+         ]
+         if input_type not in valid_types:
+             raise ValueError(f"Invalid input_type. Must be one of: {valid_types}")
+         self._input_type = input_type
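
For orientation, a minimal usage sketch of the provider above. It is not part of the package: the EmbeddingConfig constructor fields (provider, model, api_key_env, params) are assumed from how cohere.py reads the config, since provider_config.py is not shown in this diff.

import asyncio

from agent_brain_server.config.provider_config import EmbeddingConfig
from agent_brain_server.providers.embedding.cohere import CohereEmbeddingProvider


async def main() -> None:
    # Assumed constructor shape; requires COHERE_API_KEY in the environment.
    config = EmbeddingConfig(
        provider="cohere",
        model="embed-english-v3.0",
        api_key_env="COHERE_API_KEY",
        params={"input_type": "search_document", "batch_size": 96},
    )
    provider = CohereEmbeddingProvider(config)  # raises AuthenticationError if key unset
    provider.set_input_type("search_query")  # switch to query-side embeddings at search time
    vector = await provider.embed_text("how do I re-index the project?")
    assert len(vector) == provider.get_dimensions()  # 1024 for embed-english-v3.0


asyncio.run(main())

The asymmetric input types are the reason set_input_type exists: documents indexed as search_document should be queried with search_query vectors.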
agent_brain_server/providers/embedding/ollama.py
@@ -0,0 +1,150 @@
+ """Ollama embedding provider implementation."""
+
+ import logging
+ from typing import TYPE_CHECKING
+
+ from openai import AsyncOpenAI
+
+ from agent_brain_server.providers.base import BaseEmbeddingProvider
+ from agent_brain_server.providers.exceptions import (
+     OllamaConnectionError,
+     ProviderError,
+ )
+
+ if TYPE_CHECKING:
+     from agent_brain_server.config.provider_config import EmbeddingConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Model dimension mappings for common Ollama embedding models
+ OLLAMA_MODEL_DIMENSIONS: dict[str, int] = {
+     "nomic-embed-text": 768,
+     "mxbai-embed-large": 1024,
+     "all-minilm": 384,
+     "snowflake-arctic-embed": 1024,
+     "bge-m3": 1024,
+     "bge-large": 1024,
+ }
+
+ DEFAULT_OLLAMA_DIMENSIONS = 768
+
+
+ class OllamaEmbeddingProvider(BaseEmbeddingProvider):
+     """Ollama embedding provider using local models.
+
+     Uses OpenAI-compatible API endpoint provided by Ollama.
+
+     Supports:
+     - nomic-embed-text (768 dimensions, general purpose)
+     - mxbai-embed-large (1024 dimensions, multilingual)
+     - all-minilm (384 dimensions, lightweight)
+     - snowflake-arctic-embed (1024 dimensions, high quality)
+     - And any other embedding model available in Ollama
+     """
+
+     def __init__(self, config: "EmbeddingConfig") -> None:
+         """Initialize Ollama embedding provider.
+
+         Args:
+             config: Embedding configuration
+
+         Note:
+             Ollama does not require an API key as it runs locally.
+         """
+         batch_size = config.params.get("batch_size", 100)
+         super().__init__(model=config.model, batch_size=batch_size)
+
+         # Ollama uses OpenAI-compatible API
+         base_url = config.get_base_url() or "http://localhost:11434/v1"
+         self._base_url = base_url
+         self._client = AsyncOpenAI(
+             api_key="ollama",  # Ollama doesn't need real key
+             base_url=base_url,
+         )
+
+         # Optional parameters
+         self._num_ctx = config.params.get("num_ctx", 2048)
+         self._num_threads = config.params.get("num_threads")
+
+     @property
+     def provider_name(self) -> str:
+         """Human-readable provider name."""
+         return "Ollama"
+
+     def get_dimensions(self) -> int:
+         """Get embedding dimensions for current model.
+
+         Returns:
+             Number of dimensions in embedding vector
+         """
+         return OLLAMA_MODEL_DIMENSIONS.get(self._model, DEFAULT_OLLAMA_DIMENSIONS)
+
+     async def embed_text(self, text: str) -> list[float]:
+         """Generate embedding for single text.
+
+         Args:
+             text: Text to embed
+
+         Returns:
+             Embedding vector as list of floats
+
+         Raises:
+             OllamaConnectionError: If Ollama is not running
+             ProviderError: If embedding generation fails
+         """
+         try:
+             response = await self._client.embeddings.create(
+                 model=self._model,
+                 input=text,
+             )
+             return response.data[0].embedding
+         except Exception as e:
+             if "connection" in str(e).lower() or "refused" in str(e).lower():
+                 raise OllamaConnectionError(self._base_url, cause=e) from e
+             raise ProviderError(
+                 f"Failed to generate embedding: {e}",
+                 self.provider_name,
+                 cause=e,
+             ) from e
+
+     async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for a batch of texts.
+
+         Args:
+             texts: List of texts to embed
+
+         Returns:
+             List of embedding vectors
+
+         Raises:
+             OllamaConnectionError: If Ollama is not running
+             ProviderError: If embedding generation fails
+         """
+         try:
+             response = await self._client.embeddings.create(
+                 model=self._model,
+                 input=texts,
+             )
+             return [item.embedding for item in response.data]
+         except Exception as e:
+             if "connection" in str(e).lower() or "refused" in str(e).lower():
+                 raise OllamaConnectionError(self._base_url, cause=e) from e
+             raise ProviderError(
+                 f"Failed to generate batch embeddings: {e}",
+                 self.provider_name,
+                 cause=e,
+             ) from e
+
+     async def health_check(self) -> bool:
+         """Check if Ollama is running and accessible.
+
+         Returns:
+             True if Ollama is healthy, False otherwise
+         """
+         try:
+             # Try to list models to verify connection
+             await self._client.models.list()
+             return True
+         except Exception as e:
+             logger.warning(f"Ollama health check failed: {e}")
+             return False
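
A sketch of the local workflow this provider enables, again with an assumed EmbeddingConfig shape. health_check() is useful as a preflight gate: a dead daemon otherwise only surfaces as an OllamaConnectionError on the first embed call.

import asyncio

from agent_brain_server.config.provider_config import EmbeddingConfig
from agent_brain_server.providers.embedding.ollama import OllamaEmbeddingProvider


async def main() -> None:
    # Assumed constructor shape; no API key is needed because Ollama runs locally.
    config = EmbeddingConfig(provider="ollama", model="nomic-embed-text", params={})
    provider = OllamaEmbeddingProvider(config)  # defaults to http://localhost:11434/v1
    if not await provider.health_check():  # probes the daemon by listing models
        raise SystemExit("Ollama is unreachable; start it with 'ollama serve'")
    vector = await provider.embed_text("first chunk of a document")
    print(len(vector))  # 768 for nomic-embed-text


asyncio.run(main())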
agent_brain_server/providers/embedding/openai.py
@@ -0,0 +1,118 @@
+ """OpenAI embedding provider implementation."""
+
+ import logging
+ from typing import TYPE_CHECKING
+
+ from openai import AsyncOpenAI
+
+ from agent_brain_server.providers.base import BaseEmbeddingProvider
+ from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+ if TYPE_CHECKING:
+     from agent_brain_server.config.provider_config import EmbeddingConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Model dimension mappings for OpenAI embedding models
+ OPENAI_MODEL_DIMENSIONS: dict[str, int] = {
+     "text-embedding-3-large": 3072,
+     "text-embedding-3-small": 1536,
+     "text-embedding-ada-002": 1536,
+ }
+
+
+ class OpenAIEmbeddingProvider(BaseEmbeddingProvider):
+     """OpenAI embedding provider using text-embedding models.
+
+     Supports:
+     - text-embedding-3-large (3072 dimensions, highest quality)
+     - text-embedding-3-small (1536 dimensions, faster)
+     - text-embedding-ada-002 (1536 dimensions, legacy)
+     """
+
+     def __init__(self, config: "EmbeddingConfig") -> None:
+         """Initialize OpenAI embedding provider.
+
+         Args:
+             config: Embedding configuration
+
+         Raises:
+             AuthenticationError: If API key is not available
+         """
+         api_key = config.get_api_key()
+         if not api_key:
+             raise AuthenticationError(
+                 f"Missing API key. Set {config.api_key_env} environment variable.",
+                 self.provider_name,
+             )
+
+         batch_size = config.params.get("batch_size", 100)
+         super().__init__(model=config.model, batch_size=batch_size)
+
+         self._client = AsyncOpenAI(api_key=api_key)
+         self._dimensions_override = config.params.get("dimensions")
+
+     @property
+     def provider_name(self) -> str:
+         """Human-readable provider name."""
+         return "OpenAI"
+
+     def get_dimensions(self) -> int:
+         """Get embedding dimensions for current model.
+
+         Returns:
+             Number of dimensions in embedding vector
+         """
+         if self._dimensions_override:
+             return int(self._dimensions_override)
+         return OPENAI_MODEL_DIMENSIONS.get(self._model, 3072)
+
+     async def embed_text(self, text: str) -> list[float]:
+         """Generate embedding for single text.
+
+         Args:
+             text: Text to embed
+
+         Returns:
+             Embedding vector as list of floats
+
+         Raises:
+             ProviderError: If embedding generation fails
+         """
+         try:
+             response = await self._client.embeddings.create(
+                 model=self._model,
+                 input=text,
+             )
+             return response.data[0].embedding
+         except Exception as e:
+             raise ProviderError(
+                 f"Failed to generate embedding: {e}",
+                 self.provider_name,
+                 cause=e,
+             ) from e
+
+     async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for a batch of texts.
+
+         Args:
+             texts: List of texts to embed
+
+         Returns:
+             List of embedding vectors
+
+         Raises:
+             ProviderError: If embedding generation fails
+         """
+         try:
+             response = await self._client.embeddings.create(
+                 model=self._model,
+                 input=texts,
+             )
+             return [item.embedding for item in response.data]
+         except Exception as e:
+             raise ProviderError(
+                 f"Failed to generate batch embeddings: {e}",
+                 self.provider_name,
+                 cause=e,
+             ) from e
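
One subtlety worth flagging in the code above: params["dimensions"] only changes what get_dimensions() reports; neither embed call forwards a dimensions= argument to the API, so OpenAI still returns vectors at the model's native size. A usage sketch with the same assumed config shape as before:

import asyncio

from agent_brain_server.config.provider_config import EmbeddingConfig
from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider


async def main() -> None:
    # Assumed constructor shape; requires OPENAI_API_KEY in the environment.
    config = EmbeddingConfig(
        provider="openai",
        model="text-embedding-3-small",
        api_key_env="OPENAI_API_KEY",
        params={"batch_size": 100},
    )
    provider = OpenAIEmbeddingProvider(config)
    vector = await provider.embed_text("hybrid search combines BM25 and vectors")
    print(len(vector), provider.get_dimensions())  # both 1536 for text-embedding-3-small


asyncio.run(main())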
agent_brain_server/providers/exceptions.py
@@ -0,0 +1,95 @@
+ """Exception hierarchy for provider errors."""
+
+ from typing import Optional
+
+
+ class ProviderError(Exception):
+     """Base exception for provider errors."""
+
+     def __init__(
+         self, message: str, provider: str, cause: Optional[Exception] = None
+     ) -> None:
+         self.provider = provider
+         self.cause = cause
+         super().__init__(f"[{provider}] {message}")
+
+
+ class ConfigurationError(ProviderError):
+     """Raised when provider configuration is invalid."""
+
+     pass
+
+
+ class AuthenticationError(ProviderError):
+     """Raised when API key is missing or invalid."""
+
+     pass
+
+
+ class ProviderNotFoundError(ProviderError):
+     """Raised when requested provider type is not registered."""
+
+     pass
+
+
+ class ProviderMismatchError(ProviderError):
+     """Raised when current provider doesn't match indexed data."""
+
+     def __init__(
+         self,
+         current_provider: str,
+         current_model: str,
+         indexed_provider: str,
+         indexed_model: str,
+     ) -> None:
+         message = (
+             f"Provider mismatch: index was created with "
+             f"{indexed_provider}/{indexed_model}, "
+             f"but current config uses {current_provider}/{current_model}. "
+             f"Re-index with --force to update."
+         )
+         super().__init__(message, current_provider)
+         self.current_model = current_model
+         self.indexed_provider = indexed_provider
+         self.indexed_model = indexed_model
+
+
+ class RateLimitError(ProviderError):
+     """Raised when provider rate limit is hit."""
+
+     def __init__(self, provider: str, retry_after: Optional[int] = None) -> None:
+         self.retry_after = retry_after
+         message = "Rate limit exceeded"
+         if retry_after:
+             message += f", retry after {retry_after}s"
+         super().__init__(message, provider)
+
+
+ class ModelNotFoundError(ProviderError):
+     """Raised when specified model is not available."""
+
+     def __init__(
+         self, provider: str, model: str, available_models: Optional[list[str]] = None
+     ) -> None:
+         self.model = model
+         self.available_models = available_models or []
+         if available_models:
+             message = (
+                 f"Model '{model}' not found. "
+                 f"Available: {', '.join(available_models[:5])}"
+             )
+         else:
+             message = f"Model '{model}' not found"
+         super().__init__(message, provider)
+
+
+ class OllamaConnectionError(ProviderError):
+     """Raised when Ollama is not running or unreachable."""
+
+     def __init__(self, base_url: str, cause: Optional[Exception] = None) -> None:
+         message = (
+             f"Cannot connect to Ollama at {base_url}. "
+             "Ensure Ollama is running with 'ollama serve' command."
+         )
+         super().__init__(message, "ollama", cause)
+         self.base_url = base_url
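
Every class above derives from ProviderError, so callers can match narrowly first and fall back to the base class; the base constructor already prefixes messages with the provider name. A self-contained sketch using only the classes in this file:

from agent_brain_server.providers.exceptions import (
    ProviderError,
    ProviderMismatchError,
    RateLimitError,
)


def handle(exc: ProviderError) -> str:
    # Narrow matches first; fall back to the formatted "[provider] ..." message.
    if isinstance(exc, ProviderMismatchError):
        return f"re-index: built with {exc.indexed_provider}/{exc.indexed_model}"
    if isinstance(exc, RateLimitError) and exc.retry_after:
        return f"back off for {exc.retry_after}s"
    return str(exc)


print(handle(RateLimitError("openai", retry_after=30)))
# back off for 30s
print(handle(ProviderMismatchError("ollama", "nomic-embed-text", "openai", "text-embedding-3-large")))
# re-index: built with openai/text-embedding-3-large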
agent_brain_server/providers/factory.py
@@ -0,0 +1,157 @@
+ """Provider factory and registry for dynamic provider instantiation."""
+
+ import logging
+ from typing import TYPE_CHECKING, Any, cast
+
+ from agent_brain_server.providers.exceptions import ProviderNotFoundError
+
+ if TYPE_CHECKING:
+     from agent_brain_server.config.provider_config import (
+         EmbeddingConfig,
+         SummarizationConfig,
+     )
+     from agent_brain_server.providers.base import (
+         EmbeddingProvider,
+         SummarizationProvider,
+     )
+
+ logger = logging.getLogger(__name__)
+
+
+ class ProviderRegistry:
+     """Registry for provider implementations.
+
+     Allows dynamic registration of providers and lazy instantiation.
+     Implements singleton pattern for provider instance caching.
+     """
+
+     _embedding_providers: dict[str, type[Any]] = {}
+     _summarization_providers: dict[str, type[Any]] = {}
+     _instances: dict[str, Any] = {}
+
+     @classmethod
+     def register_embedding_provider(
+         cls,
+         provider_type: str,
+         provider_class: type["EmbeddingProvider"],
+     ) -> None:
+         """Register an embedding provider class.
+
+         Args:
+             provider_type: Provider identifier (e.g., 'openai', 'ollama')
+             provider_class: Provider class implementing EmbeddingProvider protocol
+         """
+         cls._embedding_providers[provider_type] = provider_class
+         logger.debug(f"Registered embedding provider: {provider_type}")
+
+     @classmethod
+     def register_summarization_provider(
+         cls,
+         provider_type: str,
+         provider_class: type["SummarizationProvider"],
+     ) -> None:
+         """Register a summarization provider class.
+
+         Args:
+             provider_type: Provider identifier (e.g., 'anthropic', 'openai')
+             provider_class: Provider class implementing SummarizationProvider protocol
+         """
+         cls._summarization_providers[provider_type] = provider_class
+         logger.debug(f"Registered summarization provider: {provider_type}")
+
+     @classmethod
+     def get_embedding_provider(cls, config: "EmbeddingConfig") -> "EmbeddingProvider":
+         """Get or create embedding provider instance.
+
+         Args:
+             config: Embedding provider configuration
+
+         Returns:
+             Configured EmbeddingProvider instance
+
+         Raises:
+             ProviderNotFoundError: If provider type is not registered
+         """
+         # Get provider type as string value
+         provider_type = (
+             config.provider.value
+             if hasattr(config.provider, "value")
+             else str(config.provider)
+         )
+         cache_key = f"embed:{provider_type}:{config.model}"
+
+         if cache_key not in cls._instances:
+             provider_class = cls._embedding_providers.get(provider_type)
+             if not provider_class:
+                 available = list(cls._embedding_providers.keys())
+                 raise ProviderNotFoundError(
+                     f"Unknown embedding provider: {provider_type}. "
+                     f"Available: {', '.join(available)}",
+                     provider_type,
+                 )
+             cls._instances[cache_key] = provider_class(config)
+             logger.info(
+                 f"Created {provider_type} embedding provider with model {config.model}"
+             )
+
+         from agent_brain_server.providers.base import EmbeddingProvider
+
+         return cast(EmbeddingProvider, cls._instances[cache_key])
+
+     @classmethod
+     def get_summarization_provider(
+         cls, config: "SummarizationConfig"
+     ) -> "SummarizationProvider":
+         """Get or create summarization provider instance.
+
+         Args:
+             config: Summarization provider configuration
+
+         Returns:
+             Configured SummarizationProvider instance
+
+         Raises:
+             ProviderNotFoundError: If provider type is not registered
+         """
+         # Get provider type as string value
+         provider_type = (
+             config.provider.value
+             if hasattr(config.provider, "value")
+             else str(config.provider)
+         )
+         cache_key = f"summ:{provider_type}:{config.model}"
+
+         if cache_key not in cls._instances:
+             provider_class = cls._summarization_providers.get(provider_type)
+             if not provider_class:
+                 available = list(cls._summarization_providers.keys())
+                 raise ProviderNotFoundError(
+                     f"Unknown summarization provider: {provider_type}. "
+                     f"Available: {', '.join(available)}",
+                     provider_type,
+                 )
+             cls._instances[cache_key] = provider_class(config)
+             logger.info(
+                 f"Created {provider_type} summarization provider "
+                 f"with model {config.model}"
+             )
+
+         from agent_brain_server.providers.base import SummarizationProvider
+
+         return cast(SummarizationProvider, cls._instances[cache_key])
+
+     @classmethod
+     def clear_cache(cls) -> None:
+         """Clear provider instance cache (for testing)."""
+         cls._instances.clear()
+         logger.debug("Cleared provider instance cache")
+
+     @classmethod
+     def get_available_embedding_providers(cls) -> list[str]:
+         """Get list of registered embedding provider types."""
+         return list(cls._embedding_providers.keys())
+
+     @classmethod
+     def get_available_summarization_providers(cls) -> list[str]:
+         """Get list of registered summarization provider types."""
+         return list(cls._summarization_providers.keys())
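
The registration calls presumably live in providers/__init__.py (+64 lines, not shown in this diff); wired manually, the round-trip looks like the sketch below, with the EmbeddingConfig shape assumed as before.

from agent_brain_server.config.provider_config import EmbeddingConfig
from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider
from agent_brain_server.providers.factory import ProviderRegistry

ProviderRegistry.register_embedding_provider("openai", OpenAIEmbeddingProvider)

config = EmbeddingConfig(
    provider="openai",
    model="text-embedding-3-large",
    api_key_env="OPENAI_API_KEY",  # instantiation requires this env var to be set
    params={},
)
provider = ProviderRegistry.get_embedding_provider(config)  # creates and caches
assert provider is ProviderRegistry.get_embedding_provider(config)  # cache hit
ProviderRegistry.clear_cache()  # tests use this to reset the class-level cache

Because the cache key combines provider type and model, switching models in the config yields a fresh instance rather than silently reusing the old one.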