agent-brain-rag 1.2.0-py3-none-any.whl → 3.0.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/METADATA +55 -18
- agent_brain_rag-3.0.0.dist-info/RECORD +56 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/WHEEL +1 -1
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/entry_points.txt +0 -1
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +146 -45
- agent_brain_server/api/routers/__init__.py +2 -0
- agent_brain_server/api/routers/health.py +85 -21
- agent_brain_server/api/routers/index.py +108 -36
- agent_brain_server/api/routers/jobs.py +111 -0
- agent_brain_server/config/provider_config.py +352 -0
- agent_brain_server/config/settings.py +22 -5
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/bm25_index.py +15 -2
- agent_brain_server/indexing/document_loader.py +45 -4
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/job_queue/__init__.py +11 -0
- agent_brain_server/job_queue/job_service.py +317 -0
- agent_brain_server/job_queue/job_store.py +427 -0
- agent_brain_server/job_queue/job_worker.py +434 -0
- agent_brain_server/locking.py +101 -8
- agent_brain_server/models/__init__.py +28 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +30 -3
- agent_brain_server/models/job.py +289 -0
- agent_brain_server/models/query.py +16 -3
- agent_brain_server/project_root.py +1 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/runtime.py +2 -2
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +5 -3
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
agent_brain_server/providers/embedding/openai.py
@@ -0,0 +1,118 @@
+"""OpenAI embedding provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseEmbeddingProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import EmbeddingConfig
+
+logger = logging.getLogger(__name__)
+
+# Model dimension mappings for OpenAI embedding models
+OPENAI_MODEL_DIMENSIONS: dict[str, int] = {
+    "text-embedding-3-large": 3072,
+    "text-embedding-3-small": 1536,
+    "text-embedding-ada-002": 1536,
+}
+
+
+class OpenAIEmbeddingProvider(BaseEmbeddingProvider):
+    """OpenAI embedding provider using text-embedding models.
+
+    Supports:
+    - text-embedding-3-large (3072 dimensions, highest quality)
+    - text-embedding-3-small (1536 dimensions, faster)
+    - text-embedding-ada-002 (1536 dimensions, legacy)
+    """
+
+    def __init__(self, config: "EmbeddingConfig") -> None:
+        """Initialize OpenAI embedding provider.
+
+        Args:
+            config: Embedding configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        batch_size = config.params.get("batch_size", 100)
+        super().__init__(model=config.model, batch_size=batch_size)
+
+        self._client = AsyncOpenAI(api_key=api_key)
+        self._dimensions_override = config.params.get("dimensions")
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "OpenAI"
+
+    def get_dimensions(self) -> int:
+        """Get embedding dimensions for current model.
+
+        Returns:
+            Number of dimensions in embedding vector
+        """
+        if self._dimensions_override:
+            return int(self._dimensions_override)
+        return OPENAI_MODEL_DIMENSIONS.get(self._model, 3072)
+
+    async def embed_text(self, text: str) -> list[float]:
+        """Generate embedding for single text.
+
+        Args:
+            text: Text to embed
+
+        Returns:
+            Embedding vector as list of floats
+
+        Raises:
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embeddings.create(
+                model=self._model,
+                input=text,
+            )
+            return response.data[0].embedding
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate embedding: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Generate embeddings for a batch of texts.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embeddings.create(
+                model=self._model,
+                input=texts,
+            )
+            return [item.embedding for item in response.data]
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate batch embeddings: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
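For orientation, a minimal usage sketch of this provider. StubEmbeddingConfig is a hypothetical stand-in that exposes only the attributes the provider reads above (get_api_key, api_key_env, model, params); the real EmbeddingConfig lives in agent_brain_server/config/provider_config.py, which this diff adds but does not show here.

import asyncio
import os
from dataclasses import dataclass, field
from typing import Any, Optional

from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider


@dataclass
class StubEmbeddingConfig:
    # Hypothetical stand-in for EmbeddingConfig; only the attributes the
    # provider actually reads are modeled here.
    model: str = "text-embedding-3-small"
    api_key_env: str = "OPENAI_API_KEY"
    params: dict[str, Any] = field(default_factory=dict)

    def get_api_key(self) -> Optional[str]:
        return os.environ.get(self.api_key_env)


async def main() -> None:
    # Raises AuthenticationError immediately if OPENAI_API_KEY is unset.
    provider = OpenAIEmbeddingProvider(StubEmbeddingConfig())
    vector = await provider.embed_text("hello world")
    print(provider.provider_name, provider.get_dimensions(), len(vector))


asyncio.run(main())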
agent_brain_server/providers/exceptions.py
@@ -0,0 +1,95 @@
+"""Exception hierarchy for provider errors."""
+
+from typing import Optional
+
+
+class ProviderError(Exception):
+    """Base exception for provider errors."""
+
+    def __init__(
+        self, message: str, provider: str, cause: Optional[Exception] = None
+    ) -> None:
+        self.provider = provider
+        self.cause = cause
+        super().__init__(f"[{provider}] {message}")
+
+
+class ConfigurationError(ProviderError):
+    """Raised when provider configuration is invalid."""
+
+    pass
+
+
+class AuthenticationError(ProviderError):
+    """Raised when API key is missing or invalid."""
+
+    pass
+
+
+class ProviderNotFoundError(ProviderError):
+    """Raised when requested provider type is not registered."""
+
+    pass
+
+
+class ProviderMismatchError(ProviderError):
+    """Raised when current provider doesn't match indexed data."""
+
+    def __init__(
+        self,
+        current_provider: str,
+        current_model: str,
+        indexed_provider: str,
+        indexed_model: str,
+    ) -> None:
+        message = (
+            f"Provider mismatch: index was created with "
+            f"{indexed_provider}/{indexed_model}, "
+            f"but current config uses {current_provider}/{current_model}. "
+            f"Re-index with --force to update."
+        )
+        super().__init__(message, current_provider)
+        self.current_model = current_model
+        self.indexed_provider = indexed_provider
+        self.indexed_model = indexed_model
+
+
+class RateLimitError(ProviderError):
+    """Raised when provider rate limit is hit."""
+
+    def __init__(self, provider: str, retry_after: Optional[int] = None) -> None:
+        self.retry_after = retry_after
+        message = "Rate limit exceeded"
+        if retry_after:
+            message += f", retry after {retry_after}s"
+        super().__init__(message, provider)
+
+
+class ModelNotFoundError(ProviderError):
+    """Raised when specified model is not available."""
+
+    def __init__(
+        self, provider: str, model: str, available_models: Optional[list[str]] = None
+    ) -> None:
+        self.model = model
+        self.available_models = available_models or []
+        if available_models:
+            message = (
+                f"Model '{model}' not found. "
+                f"Available: {', '.join(available_models[:5])}"
+            )
+        else:
+            message = f"Model '{model}' not found"
+        super().__init__(message, provider)
+
+
+class OllamaConnectionError(ProviderError):
+    """Raised when Ollama is not running or unreachable."""
+
+    def __init__(self, base_url: str, cause: Optional[Exception] = None) -> None:
+        message = (
+            f"Cannot connect to Ollama at {base_url}. "
+            "Ensure Ollama is running with 'ollama serve' command."
+        )
+        super().__init__(message, "ollama", cause)
+        self.base_url = base_url
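All of these share the ProviderError base, so callers can catch broadly and branch on the structured attributes set in the constructors above; a minimal sketch (the describe helper is illustrative, not part of the package):

from agent_brain_server.providers.exceptions import (
    ProviderError,
    ProviderMismatchError,
    RateLimitError,
)


def describe(err: ProviderError) -> str:
    # Branch on subclass-specific attributes rather than parsing the message.
    if isinstance(err, ProviderMismatchError):
        return f"re-index needed: index built with {err.indexed_provider}/{err.indexed_model}"
    if isinstance(err, RateLimitError):
        return f"rate limited; retry after {err.retry_after or 'unknown'}s"
    return f"unrecoverable error from {err.provider}: {err}"


print(describe(RateLimitError("OpenAI", retry_after=30)))
# -> rate limited; retry after 30s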
agent_brain_server/providers/factory.py
@@ -0,0 +1,157 @@
+"""Provider factory and registry for dynamic provider instantiation."""
+
+import logging
+from typing import TYPE_CHECKING, Any, cast
+
+from agent_brain_server.providers.exceptions import ProviderNotFoundError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import (
+        EmbeddingConfig,
+        SummarizationConfig,
+    )
+    from agent_brain_server.providers.base import (
+        EmbeddingProvider,
+        SummarizationProvider,
+    )
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderRegistry:
+    """Registry for provider implementations.
+
+    Allows dynamic registration of providers and lazy instantiation.
+    Implements singleton pattern for provider instance caching.
+    """
+
+    _embedding_providers: dict[str, type[Any]] = {}
+    _summarization_providers: dict[str, type[Any]] = {}
+    _instances: dict[str, Any] = {}
+
+    @classmethod
+    def register_embedding_provider(
+        cls,
+        provider_type: str,
+        provider_class: type["EmbeddingProvider"],
+    ) -> None:
+        """Register an embedding provider class.
+
+        Args:
+            provider_type: Provider identifier (e.g., 'openai', 'ollama')
+            provider_class: Provider class implementing EmbeddingProvider protocol
+        """
+        cls._embedding_providers[provider_type] = provider_class
+        logger.debug(f"Registered embedding provider: {provider_type}")
+
+    @classmethod
+    def register_summarization_provider(
+        cls,
+        provider_type: str,
+        provider_class: type["SummarizationProvider"],
+    ) -> None:
+        """Register a summarization provider class.
+
+        Args:
+            provider_type: Provider identifier (e.g., 'anthropic', 'openai')
+            provider_class: Provider class implementing SummarizationProvider protocol
+        """
+        cls._summarization_providers[provider_type] = provider_class
+        logger.debug(f"Registered summarization provider: {provider_type}")
+
+    @classmethod
+    def get_embedding_provider(cls, config: "EmbeddingConfig") -> "EmbeddingProvider":
+        """Get or create embedding provider instance.
+
+        Args:
+            config: Embedding provider configuration
+
+        Returns:
+            Configured EmbeddingProvider instance
+
+        Raises:
+            ProviderNotFoundError: If provider type is not registered
+        """
+        # Get provider type as string value
+        provider_type = (
+            config.provider.value
+            if hasattr(config.provider, "value")
+            else str(config.provider)
+        )
+        cache_key = f"embed:{provider_type}:{config.model}"
+
+        if cache_key not in cls._instances:
+            provider_class = cls._embedding_providers.get(provider_type)
+            if not provider_class:
+                available = list(cls._embedding_providers.keys())
+                raise ProviderNotFoundError(
+                    f"Unknown embedding provider: {provider_type}. "
+                    f"Available: {', '.join(available)}",
+                    provider_type,
+                )
+            cls._instances[cache_key] = provider_class(config)
+            logger.info(
+                f"Created {provider_type} embedding provider with model {config.model}"
+            )
+
+        from agent_brain_server.providers.base import EmbeddingProvider
+
+        return cast(EmbeddingProvider, cls._instances[cache_key])
+
+    @classmethod
+    def get_summarization_provider(
+        cls, config: "SummarizationConfig"
+    ) -> "SummarizationProvider":
+        """Get or create summarization provider instance.
+
+        Args:
+            config: Summarization provider configuration
+
+        Returns:
+            Configured SummarizationProvider instance
+
+        Raises:
+            ProviderNotFoundError: If provider type is not registered
+        """
+        # Get provider type as string value
+        provider_type = (
+            config.provider.value
+            if hasattr(config.provider, "value")
+            else str(config.provider)
+        )
+        cache_key = f"summ:{provider_type}:{config.model}"
+
+        if cache_key not in cls._instances:
+            provider_class = cls._summarization_providers.get(provider_type)
+            if not provider_class:
+                available = list(cls._summarization_providers.keys())
+                raise ProviderNotFoundError(
+                    f"Unknown summarization provider: {provider_type}. "
+                    f"Available: {', '.join(available)}",
+                    provider_type,
+                )
+            cls._instances[cache_key] = provider_class(config)
+            logger.info(
+                f"Created {provider_type} summarization provider "
+                f"with model {config.model}"
+            )
+
+        from agent_brain_server.providers.base import SummarizationProvider
+
+        return cast(SummarizationProvider, cls._instances[cache_key])
+
+    @classmethod
+    def clear_cache(cls) -> None:
+        """Clear provider instance cache (for testing)."""
+        cls._instances.clear()
+        logger.debug("Cleared provider instance cache")
+
+    @classmethod
+    def get_available_embedding_providers(cls) -> list[str]:
+        """Get list of registered embedding provider types."""
+        return list(cls._embedding_providers.keys())
+
+    @classmethod
+    def get_available_summarization_providers(cls) -> list[str]:
+        """Get list of registered summarization provider types."""
+        return list(cls._summarization_providers.keys())
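A minimal sketch of the register/get/cache flow, using a dummy provider class and a stand-in config (both illustrative, not part of the package):

from dataclasses import dataclass

from agent_brain_server.providers.factory import ProviderRegistry


class DummyEmbeddingProvider:
    # Illustrative provider; the registry only requires a class that accepts
    # the config in its constructor.
    def __init__(self, config) -> None:
        self.model = config.model


@dataclass
class StubConfig:
    provider: str = "dummy"  # plain string: hasattr(.., "value") fails, str() is used
    model: str = "dummy-embed-1"


ProviderRegistry.register_embedding_provider("dummy", DummyEmbeddingProvider)
first = ProviderRegistry.get_embedding_provider(StubConfig())
second = ProviderRegistry.get_embedding_provider(StubConfig())
assert first is second  # cached under the key "embed:dummy:dummy-embed-1"
ProviderRegistry.clear_cache()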
agent_brain_server/providers/summarization/__init__.py
@@ -0,0 +1,41 @@
+"""Summarization providers for Agent Brain.
+
+This module provides summarization/LLM implementations for:
+- Anthropic (Claude 4.5 Haiku, Sonnet, Opus)
+- OpenAI (GPT-5, GPT-5 Mini)
+- Gemini (gemini-3-flash, gemini-3-pro)
+- Grok (grok-4, grok-4-fast)
+- Ollama (llama4:scout, mistral-small3.2, qwen3-coder, gemma3)
+"""
+
+from agent_brain_server.providers.factory import ProviderRegistry
+from agent_brain_server.providers.summarization.anthropic import (
+    AnthropicSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.gemini import (
+    GeminiSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.grok import GrokSummarizationProvider
+from agent_brain_server.providers.summarization.ollama import (
+    OllamaSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.openai import (
+    OpenAISummarizationProvider,
+)
+
+# Register summarization providers
+ProviderRegistry.register_summarization_provider(
+    "anthropic", AnthropicSummarizationProvider
+)
+ProviderRegistry.register_summarization_provider("openai", OpenAISummarizationProvider)
+ProviderRegistry.register_summarization_provider("gemini", GeminiSummarizationProvider)
+ProviderRegistry.register_summarization_provider("grok", GrokSummarizationProvider)
+ProviderRegistry.register_summarization_provider("ollama", OllamaSummarizationProvider)
+
+__all__ = [
+    "AnthropicSummarizationProvider",
+    "OpenAISummarizationProvider",
+    "GeminiSummarizationProvider",
+    "GrokSummarizationProvider",
+    "OllamaSummarizationProvider",
+]
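Note that registration happens at import time via the module-level calls above, so the registry is populated as a side effect of importing the subpackage; a minimal sketch, assuming the provider SDK dependencies (anthropic, openai, google-generativeai) are installed:

# Importing the subpackage populates the registry as a side effect.
import agent_brain_server.providers.summarization  # noqa: F401

from agent_brain_server.providers.factory import ProviderRegistry

print(ProviderRegistry.get_available_summarization_providers())
# expected, in registration order: ['anthropic', 'openai', 'gemini', 'grok', 'ollama']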
agent_brain_server/providers/summarization/anthropic.py
@@ -0,0 +1,87 @@
+"""Anthropic (Claude) summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from anthropic import AsyncAnthropic
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicSummarizationProvider(BaseSummarizationProvider):
+    """Anthropic (Claude) summarization provider.
+
+    Supports:
+    - claude-haiku-4-5-20251001 (fast, cost-effective)
+    - claude-sonnet-4-5-20250514 (balanced)
+    - claude-opus-4-5-20251101 (highest quality)
+    - And other Claude models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Anthropic summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        self._client = AsyncAnthropic(api_key=api_key)
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Anthropic"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Claude.
+
+        Args:
+            prompt: The prompt to send to Claude
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.messages.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            return response.content[0].text  # type: ignore[union-attr]
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
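Like the other provider constructors in this diff, this one fails fast when the key is absent; a minimal sketch using a hypothetical stand-in config whose get_api_key() returns None:

from typing import Any, Optional

from agent_brain_server.providers.exceptions import AuthenticationError
from agent_brain_server.providers.summarization.anthropic import (
    AnthropicSummarizationProvider,
)


class NoKeyConfig:
    # Hypothetical stand-in; get_api_key() returning None triggers the
    # AuthenticationError branch in __init__ above.
    model = "claude-haiku-4-5-20251001"
    api_key_env = "ANTHROPIC_API_KEY"
    params: dict[str, Any] = {}

    def get_api_key(self) -> Optional[str]:
        return None


try:
    AnthropicSummarizationProvider(NoKeyConfig())
except AuthenticationError as err:
    print(err)  # [Anthropic] Missing API key. Set ANTHROPIC_API_KEY environment variable.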
agent_brain_server/providers/summarization/gemini.py
@@ -0,0 +1,96 @@
+"""Google Gemini summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+import google.generativeai as genai
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class GeminiSummarizationProvider(BaseSummarizationProvider):
+    """Google Gemini summarization provider.
+
+    Supports:
+    - gemini-3-flash (fast, cost-effective)
+    - gemini-3-pro (highest quality)
+    - And other Gemini models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Gemini summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get(
+            "max_output_tokens", config.params.get("max_tokens", 300)
+        )
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+        top_p = config.params.get("top_p", 0.95)
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Configure Gemini with API key
+        genai.configure(api_key=api_key)  # type: ignore[attr-defined]
+
+        # Create model with generation config
+        generation_config = genai.types.GenerationConfig(
+            max_output_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+        self._model_instance = genai.GenerativeModel(  # type: ignore[attr-defined]
+            model_name=config.model,
+            generation_config=generation_config,
+        )
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Gemini"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Gemini.
+
+        Args:
+            prompt: The prompt to send to Gemini
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            # Use async generation
+            response = await self._model_instance.generate_content_async(prompt)
+            return str(response.text)
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
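The constructor prefers the Gemini-style max_output_tokens key and falls back to the generic max_tokens used by the other providers; a quick illustration of that precedence:

params = {"max_tokens": 300, "max_output_tokens": 512}
max_tokens = params.get("max_output_tokens", params.get("max_tokens", 300))
assert max_tokens == 512  # the Gemini-specific key wins when both are present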
agent_brain_server/providers/summarization/grok.py
@@ -0,0 +1,95 @@
+"""xAI Grok summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class GrokSummarizationProvider(BaseSummarizationProvider):
+    """xAI Grok summarization provider.
+
+    Uses OpenAI-compatible API at https://api.x.ai/v1
+
+    Supports:
+    - grok-4 (most capable, with reasoning)
+    - grok-4-fast (fast variant)
+    - grok-3 (previous generation)
+    - And other Grok models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Grok summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Grok uses OpenAI-compatible API
+        base_url = config.get_base_url() or "https://api.x.ai/v1"
+        self._client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+        )
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Grok"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Grok.
+
+        Args:
+            prompt: The prompt to send to Grok
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.chat.completions.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            content = response.choices[0].message.content
+            return content if content else ""
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
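Because the client is just AsyncOpenAI pointed at api.x.ai, usage mirrors the OpenAI provider; a minimal end-to-end sketch with a hypothetical stand-in config (requires XAI_API_KEY and network access):

import asyncio
import os
from typing import Any, Optional

from agent_brain_server.providers.summarization.grok import GrokSummarizationProvider


class StubSummarizationConfig:
    # Hypothetical stand-in exposing only what the provider reads above.
    model = "grok-4-fast"
    api_key_env = "XAI_API_KEY"
    params: dict[str, Any] = {"max_tokens": 120, "temperature": 0.0}

    def get_api_key(self) -> Optional[str]:
        return os.environ.get(self.api_key_env)

    def get_base_url(self) -> Optional[str]:
        return None  # fall back to the default https://api.x.ai/v1


async def main() -> None:
    provider = GrokSummarizationProvider(StubSummarizationConfig())
    print(await provider.generate("Summarize: RAG combines retrieval with generation."))


asyncio.run(main())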