agent-brain-rag 1.2.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/METADATA +55 -18
  2. agent_brain_rag-3.0.0.dist-info/RECORD +56 -0
  3. {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/WHEEL +1 -1
  4. {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/entry_points.txt +0 -1
  5. agent_brain_server/__init__.py +1 -1
  6. agent_brain_server/api/main.py +146 -45
  7. agent_brain_server/api/routers/__init__.py +2 -0
  8. agent_brain_server/api/routers/health.py +85 -21
  9. agent_brain_server/api/routers/index.py +108 -36
  10. agent_brain_server/api/routers/jobs.py +111 -0
  11. agent_brain_server/config/provider_config.py +352 -0
  12. agent_brain_server/config/settings.py +22 -5
  13. agent_brain_server/indexing/__init__.py +21 -0
  14. agent_brain_server/indexing/bm25_index.py +15 -2
  15. agent_brain_server/indexing/document_loader.py +45 -4
  16. agent_brain_server/indexing/embedding.py +86 -135
  17. agent_brain_server/indexing/graph_extractors.py +582 -0
  18. agent_brain_server/indexing/graph_index.py +536 -0
  19. agent_brain_server/job_queue/__init__.py +11 -0
  20. agent_brain_server/job_queue/job_service.py +317 -0
  21. agent_brain_server/job_queue/job_store.py +427 -0
  22. agent_brain_server/job_queue/job_worker.py +434 -0
  23. agent_brain_server/locking.py +101 -8
  24. agent_brain_server/models/__init__.py +28 -0
  25. agent_brain_server/models/graph.py +253 -0
  26. agent_brain_server/models/health.py +30 -3
  27. agent_brain_server/models/job.py +289 -0
  28. agent_brain_server/models/query.py +16 -3
  29. agent_brain_server/project_root.py +1 -1
  30. agent_brain_server/providers/__init__.py +64 -0
  31. agent_brain_server/providers/base.py +251 -0
  32. agent_brain_server/providers/embedding/__init__.py +23 -0
  33. agent_brain_server/providers/embedding/cohere.py +163 -0
  34. agent_brain_server/providers/embedding/ollama.py +150 -0
  35. agent_brain_server/providers/embedding/openai.py +118 -0
  36. agent_brain_server/providers/exceptions.py +95 -0
  37. agent_brain_server/providers/factory.py +157 -0
  38. agent_brain_server/providers/summarization/__init__.py +41 -0
  39. agent_brain_server/providers/summarization/anthropic.py +87 -0
  40. agent_brain_server/providers/summarization/gemini.py +96 -0
  41. agent_brain_server/providers/summarization/grok.py +95 -0
  42. agent_brain_server/providers/summarization/ollama.py +114 -0
  43. agent_brain_server/providers/summarization/openai.py +87 -0
  44. agent_brain_server/runtime.py +2 -2
  45. agent_brain_server/services/indexing_service.py +39 -0
  46. agent_brain_server/services/query_service.py +203 -0
  47. agent_brain_server/storage/__init__.py +18 -2
  48. agent_brain_server/storage/graph_store.py +519 -0
  49. agent_brain_server/storage/vector_store.py +35 -0
  50. agent_brain_server/storage_paths.py +5 -3
  51. agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
agent_brain_server/providers/embedding/openai.py
@@ -0,0 +1,118 @@
+"""OpenAI embedding provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseEmbeddingProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import EmbeddingConfig
+
+logger = logging.getLogger(__name__)
+
+# Model dimension mappings for OpenAI embedding models
+OPENAI_MODEL_DIMENSIONS: dict[str, int] = {
+    "text-embedding-3-large": 3072,
+    "text-embedding-3-small": 1536,
+    "text-embedding-ada-002": 1536,
+}
+
+
+class OpenAIEmbeddingProvider(BaseEmbeddingProvider):
+    """OpenAI embedding provider using text-embedding models.
+
+    Supports:
+    - text-embedding-3-large (3072 dimensions, highest quality)
+    - text-embedding-3-small (1536 dimensions, faster)
+    - text-embedding-ada-002 (1536 dimensions, legacy)
+    """
+
+    def __init__(self, config: "EmbeddingConfig") -> None:
+        """Initialize OpenAI embedding provider.
+
+        Args:
+            config: Embedding configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        batch_size = config.params.get("batch_size", 100)
+        super().__init__(model=config.model, batch_size=batch_size)
+
+        self._client = AsyncOpenAI(api_key=api_key)
+        self._dimensions_override = config.params.get("dimensions")
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "OpenAI"
+
+    def get_dimensions(self) -> int:
+        """Get embedding dimensions for current model.
+
+        Returns:
+            Number of dimensions in embedding vector
+        """
+        if self._dimensions_override:
+            return int(self._dimensions_override)
+        return OPENAI_MODEL_DIMENSIONS.get(self._model, 3072)
+
+    async def embed_text(self, text: str) -> list[float]:
+        """Generate embedding for single text.
+
+        Args:
+            text: Text to embed
+
+        Returns:
+            Embedding vector as list of floats
+
+        Raises:
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embeddings.create(
+                model=self._model,
+                input=text,
+            )
+            return response.data[0].embedding
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate embedding: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Generate embeddings for a batch of texts.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            ProviderError: If embedding generation fails
+        """
+        try:
+            response = await self._client.embeddings.create(
+                model=self._model,
+                input=texts,
+            )
+            return [item.embedding for item in response.data]
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate batch embeddings: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
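
The provider takes its settings from an EmbeddingConfig rather than raw keyword arguments. A minimal usage sketch, assuming the config exposes roughly the fields the provider reads above (the constructor keywords here are illustrative, not the package's documented schema) and that OPENAI_API_KEY is set:

import asyncio

from agent_brain_server.config.provider_config import EmbeddingConfig
from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider


async def main() -> None:
    # Field names are assumptions, inferred from what the provider reads:
    # get_api_key(), api_key_env, model, and params.
    config = EmbeddingConfig(
        provider="openai",
        model="text-embedding-3-small",
        api_key_env="OPENAI_API_KEY",
        params={"batch_size": 100},
    )
    provider = OpenAIEmbeddingProvider(config)
    vector = await provider.embed_text("hybrid retrieval with BM25 and vectors")
    print(provider.get_dimensions(), len(vector))  # expect 1536 for 3-small


asyncio.run(main())

Note the fallback in get_dimensions(): an unrecognized model name silently reports 3072, so setting params["dimensions"] is worth doing for non-standard models.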
agent_brain_server/providers/exceptions.py
@@ -0,0 +1,95 @@
+"""Exception hierarchy for provider errors."""
+
+from typing import Optional
+
+
+class ProviderError(Exception):
+    """Base exception for provider errors."""
+
+    def __init__(
+        self, message: str, provider: str, cause: Optional[Exception] = None
+    ) -> None:
+        self.provider = provider
+        self.cause = cause
+        super().__init__(f"[{provider}] {message}")
+
+
+class ConfigurationError(ProviderError):
+    """Raised when provider configuration is invalid."""
+
+    pass
+
+
+class AuthenticationError(ProviderError):
+    """Raised when API key is missing or invalid."""
+
+    pass
+
+
+class ProviderNotFoundError(ProviderError):
+    """Raised when requested provider type is not registered."""
+
+    pass
+
+
+class ProviderMismatchError(ProviderError):
+    """Raised when current provider doesn't match indexed data."""
+
+    def __init__(
+        self,
+        current_provider: str,
+        current_model: str,
+        indexed_provider: str,
+        indexed_model: str,
+    ) -> None:
+        message = (
+            f"Provider mismatch: index was created with "
+            f"{indexed_provider}/{indexed_model}, "
+            f"but current config uses {current_provider}/{current_model}. "
+            f"Re-index with --force to update."
+        )
+        super().__init__(message, current_provider)
+        self.current_model = current_model
+        self.indexed_provider = indexed_provider
+        self.indexed_model = indexed_model
+
+
+class RateLimitError(ProviderError):
+    """Raised when provider rate limit is hit."""
+
+    def __init__(self, provider: str, retry_after: Optional[int] = None) -> None:
+        self.retry_after = retry_after
+        message = "Rate limit exceeded"
+        if retry_after:
+            message += f", retry after {retry_after}s"
+        super().__init__(message, provider)
+
+
+class ModelNotFoundError(ProviderError):
+    """Raised when specified model is not available."""
+
+    def __init__(
+        self, provider: str, model: str, available_models: Optional[list[str]] = None
+    ) -> None:
+        self.model = model
+        self.available_models = available_models or []
+        if available_models:
+            message = (
+                f"Model '{model}' not found. "
+                f"Available: {', '.join(available_models[:5])}"
+            )
+        else:
+            message = f"Model '{model}' not found"
+        super().__init__(message, provider)
+
+
+class OllamaConnectionError(ProviderError):
+    """Raised when Ollama is not running or unreachable."""
+
+    def __init__(self, base_url: str, cause: Optional[Exception] = None) -> None:
+        message = (
+            f"Cannot connect to Ollama at {base_url}. "
+            "Ensure Ollama is running with 'ollama serve' command."
+        )
+        super().__init__(message, "ollama", cause)
+        self.base_url = base_url
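
Since RateLimitError carries an optional retry_after and every ProviderError records the provider name and underlying cause, callers can build a targeted retry loop. A caller-side sketch, assuming a provider that surfaces 429s as RateLimitError (the OpenAI provider above wraps all failures in plain ProviderError, so this only applies where rate limits are raised distinctly); the backoff policy is illustrative, not something the package ships:

import asyncio

from agent_brain_server.providers.exceptions import ProviderError, RateLimitError


async def embed_with_retry(provider, text: str, attempts: int = 3) -> list[float]:
    # Retry only on rate limits; other ProviderErrors propagate immediately.
    for attempt in range(attempts):
        try:
            return await provider.embed_text(text)
        except RateLimitError as e:
            # Honor the server's hint when given, else back off exponentially.
            await asyncio.sleep(e.retry_after or 2**attempt)
    raise ProviderError("rate limit retries exhausted", provider.provider_name)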
agent_brain_server/providers/factory.py
@@ -0,0 +1,157 @@
+"""Provider factory and registry for dynamic provider instantiation."""
+
+import logging
+from typing import TYPE_CHECKING, Any, cast
+
+from agent_brain_server.providers.exceptions import ProviderNotFoundError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import (
+        EmbeddingConfig,
+        SummarizationConfig,
+    )
+    from agent_brain_server.providers.base import (
+        EmbeddingProvider,
+        SummarizationProvider,
+    )
+
+logger = logging.getLogger(__name__)
+
+
+class ProviderRegistry:
+    """Registry for provider implementations.
+
+    Allows dynamic registration of providers and lazy instantiation.
+    Implements singleton pattern for provider instance caching.
+    """
+
+    _embedding_providers: dict[str, type[Any]] = {}
+    _summarization_providers: dict[str, type[Any]] = {}
+    _instances: dict[str, Any] = {}
+
+    @classmethod
+    def register_embedding_provider(
+        cls,
+        provider_type: str,
+        provider_class: type["EmbeddingProvider"],
+    ) -> None:
+        """Register an embedding provider class.
+
+        Args:
+            provider_type: Provider identifier (e.g., 'openai', 'ollama')
+            provider_class: Provider class implementing EmbeddingProvider protocol
+        """
+        cls._embedding_providers[provider_type] = provider_class
+        logger.debug(f"Registered embedding provider: {provider_type}")
+
+    @classmethod
+    def register_summarization_provider(
+        cls,
+        provider_type: str,
+        provider_class: type["SummarizationProvider"],
+    ) -> None:
+        """Register a summarization provider class.
+
+        Args:
+            provider_type: Provider identifier (e.g., 'anthropic', 'openai')
+            provider_class: Provider class implementing SummarizationProvider protocol
+        """
+        cls._summarization_providers[provider_type] = provider_class
+        logger.debug(f"Registered summarization provider: {provider_type}")
+
+    @classmethod
+    def get_embedding_provider(cls, config: "EmbeddingConfig") -> "EmbeddingProvider":
+        """Get or create embedding provider instance.
+
+        Args:
+            config: Embedding provider configuration
+
+        Returns:
+            Configured EmbeddingProvider instance
+
+        Raises:
+            ProviderNotFoundError: If provider type is not registered
+        """
+        # Get provider type as string value
+        provider_type = (
+            config.provider.value
+            if hasattr(config.provider, "value")
+            else str(config.provider)
+        )
+        cache_key = f"embed:{provider_type}:{config.model}"
+
+        if cache_key not in cls._instances:
+            provider_class = cls._embedding_providers.get(provider_type)
+            if not provider_class:
+                available = list(cls._embedding_providers.keys())
+                raise ProviderNotFoundError(
+                    f"Unknown embedding provider: {provider_type}. "
+                    f"Available: {', '.join(available)}",
+                    provider_type,
+                )
+            cls._instances[cache_key] = provider_class(config)
+            logger.info(
+                f"Created {provider_type} embedding provider with model {config.model}"
+            )
+
+        from agent_brain_server.providers.base import EmbeddingProvider
+
+        return cast(EmbeddingProvider, cls._instances[cache_key])
+
+    @classmethod
+    def get_summarization_provider(
+        cls, config: "SummarizationConfig"
+    ) -> "SummarizationProvider":
+        """Get or create summarization provider instance.
+
+        Args:
+            config: Summarization provider configuration
+
+        Returns:
+            Configured SummarizationProvider instance
+
+        Raises:
+            ProviderNotFoundError: If provider type is not registered
+        """
+        # Get provider type as string value
+        provider_type = (
+            config.provider.value
+            if hasattr(config.provider, "value")
+            else str(config.provider)
+        )
+        cache_key = f"summ:{provider_type}:{config.model}"
+
+        if cache_key not in cls._instances:
+            provider_class = cls._summarization_providers.get(provider_type)
+            if not provider_class:
+                available = list(cls._summarization_providers.keys())
+                raise ProviderNotFoundError(
+                    f"Unknown summarization provider: {provider_type}. "
+                    f"Available: {', '.join(available)}",
+                    provider_type,
+                )
+            cls._instances[cache_key] = provider_class(config)
+            logger.info(
+                f"Created {provider_type} summarization provider "
+                f"with model {config.model}"
+            )
+
+        from agent_brain_server.providers.base import SummarizationProvider
+
+        return cast(SummarizationProvider, cls._instances[cache_key])
+
+    @classmethod
+    def clear_cache(cls) -> None:
+        """Clear provider instance cache (for testing)."""
+        cls._instances.clear()
+        logger.debug("Cleared provider instance cache")
+
+    @classmethod
+    def get_available_embedding_providers(cls) -> list[str]:
+        """Get list of registered embedding provider types."""
+        return list(cls._embedding_providers.keys())
+
+    @classmethod
+    def get_available_summarization_providers(cls) -> list[str]:
+        """Get list of registered summarization provider types."""
+        return list(cls._summarization_providers.keys())
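
Registration happens at import time (the summarization package below shows the pattern), after which lookups hit the class-level instance cache keyed by kind, provider, and model. A rough sketch of the flow, assuming the embedding package's __init__ registers its providers on import the same way, and reusing the hypothetical config fields from the embedding example above:

import agent_brain_server.providers.embedding  # noqa: F401  (assumed to register on import)
from agent_brain_server.config.provider_config import EmbeddingConfig
from agent_brain_server.providers.factory import ProviderRegistry

config = EmbeddingConfig(  # hypothetical field names, as in the earlier sketch
    provider="openai",
    model="text-embedding-3-small",
    api_key_env="OPENAI_API_KEY",
    params={},
)

provider = ProviderRegistry.get_embedding_provider(config)
# Same (provider, model) pair -> same cached instance.
assert provider is ProviderRegistry.get_embedding_provider(config)

print(ProviderRegistry.get_available_embedding_providers())

ProviderRegistry.clear_cache()  # tests reset the singleton cache between cases

One caveat visible in the cache key: "embed:{provider}:{model}" ignores the rest of the config, so two configs differing only in params share a single instance.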
agent_brain_server/providers/summarization/__init__.py
@@ -0,0 +1,41 @@
+"""Summarization providers for Agent Brain.
+
+This module provides summarization/LLM implementations for:
+- Anthropic (Claude 4.5 Haiku, Sonnet, Opus)
+- OpenAI (GPT-5, GPT-5 Mini)
+- Gemini (gemini-3-flash, gemini-3-pro)
+- Grok (grok-4, grok-4-fast)
+- Ollama (llama4:scout, mistral-small3.2, qwen3-coder, gemma3)
+"""
+
+from agent_brain_server.providers.factory import ProviderRegistry
+from agent_brain_server.providers.summarization.anthropic import (
+    AnthropicSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.gemini import (
+    GeminiSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.grok import GrokSummarizationProvider
+from agent_brain_server.providers.summarization.ollama import (
+    OllamaSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.openai import (
+    OpenAISummarizationProvider,
+)
+
+# Register summarization providers
+ProviderRegistry.register_summarization_provider(
+    "anthropic", AnthropicSummarizationProvider
+)
+ProviderRegistry.register_summarization_provider("openai", OpenAISummarizationProvider)
+ProviderRegistry.register_summarization_provider("gemini", GeminiSummarizationProvider)
+ProviderRegistry.register_summarization_provider("grok", GrokSummarizationProvider)
+ProviderRegistry.register_summarization_provider("ollama", OllamaSummarizationProvider)
+
+__all__ = [
+    "AnthropicSummarizationProvider",
+    "OpenAISummarizationProvider",
+    "GeminiSummarizationProvider",
+    "GrokSummarizationProvider",
+    "OllamaSummarizationProvider",
+]
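
With these registrations in place, obtaining a summarizer is a single registry call. A minimal sketch, assuming SummarizationConfig has a shape analogous to the embedding config (field names are illustrative) and ANTHROPIC_API_KEY is set:

import asyncio

import agent_brain_server.providers.summarization  # noqa: F401  (registers providers)
from agent_brain_server.config.provider_config import SummarizationConfig
from agent_brain_server.providers.factory import ProviderRegistry


async def main() -> None:
    # Hypothetical field names; the real schema lives in provider_config.py.
    config = SummarizationConfig(
        provider="anthropic",
        model="claude-haiku-4-5-20251001",
        api_key_env="ANTHROPIC_API_KEY",
        params={"max_tokens": 300, "temperature": 0.1},
    )
    summarizer = ProviderRegistry.get_summarization_provider(config)
    print(await summarizer.generate("Summarize: hybrid RAG combines BM25, vectors, and a graph."))


asyncio.run(main())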
agent_brain_server/providers/summarization/anthropic.py
@@ -0,0 +1,87 @@
+"""Anthropic (Claude) summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from anthropic import AsyncAnthropic
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicSummarizationProvider(BaseSummarizationProvider):
+    """Anthropic (Claude) summarization provider.
+
+    Supports:
+    - claude-haiku-4-5-20251001 (fast, cost-effective)
+    - claude-sonnet-4-5-20250514 (balanced)
+    - claude-opus-4-5-20251101 (highest quality)
+    - And other Claude models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Anthropic summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        self._client = AsyncAnthropic(api_key=api_key)
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Anthropic"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Claude.
+
+        Args:
+            prompt: The prompt to send to Claude
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.messages.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            return response.content[0].text  # type: ignore[union-attr]
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
agent_brain_server/providers/summarization/gemini.py
@@ -0,0 +1,96 @@
+"""Google Gemini summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+import google.generativeai as genai
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class GeminiSummarizationProvider(BaseSummarizationProvider):
+    """Google Gemini summarization provider.
+
+    Supports:
+    - gemini-3-flash (fast, cost-effective)
+    - gemini-3-pro (highest quality)
+    - And other Gemini models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Gemini summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get(
+            "max_output_tokens", config.params.get("max_tokens", 300)
+        )
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+        top_p = config.params.get("top_p", 0.95)
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Configure Gemini with API key
+        genai.configure(api_key=api_key)  # type: ignore[attr-defined]
+
+        # Create model with generation config
+        generation_config = genai.types.GenerationConfig(
+            max_output_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+        self._model_instance = genai.GenerativeModel(  # type: ignore[attr-defined]
+            model_name=config.model,
+            generation_config=generation_config,
+        )
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Gemini"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Gemini.
+
+        Args:
+            prompt: The prompt to send to Gemini
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            # Use async generation
+            response = await self._model_instance.generate_content_async(prompt)
+            return str(response.text)
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
agent_brain_server/providers/summarization/grok.py
@@ -0,0 +1,95 @@
+"""xAI Grok summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class GrokSummarizationProvider(BaseSummarizationProvider):
+    """xAI Grok summarization provider.
+
+    Uses OpenAI-compatible API at https://api.x.ai/v1
+
+    Supports:
+    - grok-4 (most capable, with reasoning)
+    - grok-4-fast (fast variant)
+    - grok-3 (previous generation)
+    - And other Grok models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Grok summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Grok uses OpenAI-compatible API
+        base_url = config.get_base_url() or "https://api.x.ai/v1"
+        self._client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+        )
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Grok"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Grok.
+
+        Args:
+            prompt: The prompt to send to Grok
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.chat.completions.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            content = response.choices[0].message.content
+            return content if content else ""
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
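
Note the design choice: Grok needs no dedicated SDK because the provider points the stock openai client at https://api.x.ai/v1. Stripped of the provider scaffolding, the same call looks like the sketch below (the env var name is an assumption; in the package it comes from config.api_key_env):

import asyncio
import os

from openai import AsyncOpenAI


async def main() -> None:
    # Identical client class to the OpenAI providers; only base_url and key differ.
    client = AsyncOpenAI(
        api_key=os.environ["XAI_API_KEY"],  # hypothetical env var name
        base_url="https://api.x.ai/v1",
    )
    response = await client.chat.completions.create(
        model="grok-4-fast",
        max_tokens=300,
        messages=[{"role": "user", "content": "One sentence on vector search."}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())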