agent-brain-rag 1.2.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/METADATA +54 -16
- agent_brain_rag-2.0.0.dist-info/RECORD +50 -0
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +30 -2
- agent_brain_server/api/routers/health.py +1 -0
- agent_brain_server/config/provider_config.py +308 -0
- agent_brain_server/config/settings.py +12 -1
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/models/__init__.py +9 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +15 -3
- agent_brain_server/models/query.py +14 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +2 -0
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/WHEEL +0 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/entry_points.txt +0 -0

agent_brain_server/providers/summarization/__init__.py

@@ -0,0 +1,41 @@
+"""Summarization providers for Agent Brain.
+
+This module provides summarization/LLM implementations for:
+- Anthropic (Claude 4.5 Haiku, Sonnet, Opus)
+- OpenAI (GPT-5, GPT-5 Mini)
+- Gemini (gemini-3-flash, gemini-3-pro)
+- Grok (grok-4, grok-4-fast)
+- Ollama (llama4:scout, mistral-small3.2, qwen3-coder, gemma3)
+"""
+
+from agent_brain_server.providers.factory import ProviderRegistry
+from agent_brain_server.providers.summarization.anthropic import (
+    AnthropicSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.gemini import (
+    GeminiSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.grok import GrokSummarizationProvider
+from agent_brain_server.providers.summarization.ollama import (
+    OllamaSummarizationProvider,
+)
+from agent_brain_server.providers.summarization.openai import (
+    OpenAISummarizationProvider,
+)
+
+# Register summarization providers
+ProviderRegistry.register_summarization_provider(
+    "anthropic", AnthropicSummarizationProvider
+)
+ProviderRegistry.register_summarization_provider("openai", OpenAISummarizationProvider)
+ProviderRegistry.register_summarization_provider("gemini", GeminiSummarizationProvider)
+ProviderRegistry.register_summarization_provider("grok", GrokSummarizationProvider)
+ProviderRegistry.register_summarization_provider("ollama", OllamaSummarizationProvider)
+
+__all__ = [
+    "AnthropicSummarizationProvider",
+    "OpenAISummarizationProvider",
+    "GeminiSummarizationProvider",
+    "GrokSummarizationProvider",
+    "OllamaSummarizationProvider",
+]
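The registration calls above are the extension point for summarization backends. Below is a minimal sketch of how a third-party backend could plug into the same registry; only `ProviderRegistry.register_summarization_provider` and the `BaseSummarizationProvider` constructor/abstract members visible in this diff are taken from the package, and the `EchoSummarizationProvider` class and `"echo"` name are purely hypothetical.

```python
# Hypothetical example: registering a custom summarization provider.
from agent_brain_server.providers.base import BaseSummarizationProvider
from agent_brain_server.providers.factory import ProviderRegistry


class EchoSummarizationProvider(BaseSummarizationProvider):
    """Toy provider that returns the prompt unchanged (offline/testing use)."""

    def __init__(self, config) -> None:
        super().__init__(
            model=config.model,
            max_tokens=config.params.get("max_tokens", 300),
            temperature=config.params.get("temperature", 0.1),
            prompt_template=config.params.get("prompt_template"),
        )

    @property
    def provider_name(self) -> str:
        return "Echo"

    async def generate(self, prompt: str) -> str:
        return prompt


# Register under a new name alongside the built-in providers.
ProviderRegistry.register_summarization_provider("echo", EchoSummarizationProvider)
```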

agent_brain_server/providers/summarization/anthropic.py

@@ -0,0 +1,87 @@
+"""Anthropic (Claude) summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from anthropic import AsyncAnthropic
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class AnthropicSummarizationProvider(BaseSummarizationProvider):
+    """Anthropic (Claude) summarization provider.
+
+    Supports:
+    - claude-haiku-4-5-20251001 (fast, cost-effective)
+    - claude-sonnet-4-5-20250514 (balanced)
+    - claude-opus-4-5-20251101 (highest quality)
+    - And other Claude models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Anthropic summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        self._client = AsyncAnthropic(api_key=api_key)
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Anthropic"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Claude.
+
+        Args:
+            prompt: The prompt to send to Claude
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.messages.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            return response.content[0].text  # type: ignore[union-attr]
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
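The constructor above only reads a handful of attributes from its config (`model`, `params`, `api_key_env`, `get_api_key()`), so a duck-typed stand-in is enough to exercise it without building a real `SummarizationConfig`. A hedged sketch; the stub object and placeholder key are illustrative and not part of the package.

```python
# Hypothetical test harness for the Anthropic provider using a stub config.
from types import SimpleNamespace

from agent_brain_server.providers.summarization.anthropic import (
    AnthropicSummarizationProvider,
)

stub_config = SimpleNamespace(
    model="claude-haiku-4-5-20251001",
    api_key_env="ANTHROPIC_API_KEY",
    params={"max_tokens": 300, "temperature": 0.1},
    get_api_key=lambda: "sk-ant-...",  # placeholder key for the sketch
)

provider = AnthropicSummarizationProvider(stub_config)  # type: ignore[arg-type]
print(provider.provider_name)  # "Anthropic"
# asyncio.run(provider.generate("Summarize: ...")) would call the Claude API.
```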

agent_brain_server/providers/summarization/gemini.py

@@ -0,0 +1,96 @@
+"""Google Gemini summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+import google.generativeai as genai
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class GeminiSummarizationProvider(BaseSummarizationProvider):
+    """Google Gemini summarization provider.
+
+    Supports:
+    - gemini-3-flash (fast, cost-effective)
+    - gemini-3-pro (highest quality)
+    - And other Gemini models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Gemini summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get(
+            "max_output_tokens", config.params.get("max_tokens", 300)
+        )
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+        top_p = config.params.get("top_p", 0.95)
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Configure Gemini with API key
+        genai.configure(api_key=api_key)  # type: ignore[attr-defined]
+
+        # Create model with generation config
+        generation_config = genai.types.GenerationConfig(
+            max_output_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+        )
+        self._model_instance = genai.GenerativeModel(  # type: ignore[attr-defined]
+            model_name=config.model,
+            generation_config=generation_config,
+        )
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Gemini"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Gemini.
+
+        Args:
+            prompt: The prompt to send to Gemini
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            # Use async generation
+            response = await self._model_instance.generate_content_async(prompt)
+            return str(response.text)
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e

agent_brain_server/providers/summarization/grok.py

@@ -0,0 +1,95 @@
+"""xAI Grok summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class GrokSummarizationProvider(BaseSummarizationProvider):
+    """xAI Grok summarization provider.
+
+    Uses OpenAI-compatible API at https://api.x.ai/v1
+
+    Supports:
+    - grok-4 (most capable, with reasoning)
+    - grok-4-fast (fast variant)
+    - grok-3 (previous generation)
+    - And other Grok models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Grok summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Grok uses OpenAI-compatible API
+        base_url = config.get_base_url() or "https://api.x.ai/v1"
+        self._client = AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+        )
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Grok"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Grok.
+
+        Args:
+            prompt: The prompt to send to Grok
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.chat.completions.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            content = response.choices[0].message.content
+            return content if content else ""
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e

agent_brain_server/providers/summarization/ollama.py

@@ -0,0 +1,114 @@
+"""Ollama summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import (
+    OllamaConnectionError,
+    ProviderError,
+)
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaSummarizationProvider(BaseSummarizationProvider):
+    """Ollama summarization provider using local models.
+
+    Uses OpenAI-compatible API endpoint provided by Ollama.
+
+    Supports:
+    - llama4:scout (Meta's Llama 4 Scout - lightweight, fast)
+    - mistral-small3.2 (Mistral Small 3.2 - balanced)
+    - qwen3-coder (Alibaba Qwen 3 Coder - code-focused)
+    - gemma3 (Google Gemma 3 - efficient)
+    - deepseek-coder-v3 (DeepSeek Coder V3)
+    - And any other chat model available in Ollama
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize Ollama summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Note:
+            Ollama does not require an API key as it runs locally.
+        """
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        # Ollama uses OpenAI-compatible API
+        base_url = config.get_base_url() or "http://localhost:11434/v1"
+        self._base_url = base_url
+        self._client = AsyncOpenAI(
+            api_key="ollama",  # Ollama doesn't need real key
+            base_url=base_url,
+        )
+
+        # Optional parameters
+        self._num_ctx = config.params.get("num_ctx", 4096)
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "Ollama"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using Ollama.
+
+        Args:
+            prompt: The prompt to send to Ollama
+
+        Returns:
+            Generated text response
+
+        Raises:
+            OllamaConnectionError: If Ollama is not running
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.chat.completions.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            content = response.choices[0].message.content
+            return content if content else ""
+        except Exception as e:
+            if "connection" in str(e).lower() or "refused" in str(e).lower():
+                raise OllamaConnectionError(self._base_url, cause=e) from e
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e
+
+    async def health_check(self) -> bool:
+        """Check if Ollama is running and accessible.
+
+        Returns:
+            True if Ollama is healthy, False otherwise
+        """
+        try:
+            # Try to list models to verify connection
+            await self._client.models.list()
+            return True
+        except Exception as e:
+            logger.warning(f"Ollama health check failed: {e}")
+            return False
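The `health_check` above simply lists models on Ollama's OpenAI-compatible endpoint. The same probe can be run standalone to verify a local Ollama install before enabling this provider; a minimal sketch, mirroring the default base URL and dummy `"ollama"` key used in the diff.

```python
# Standalone sketch of the same health probe the provider performs.
import asyncio

from openai import AsyncOpenAI


async def ollama_is_up(base_url: str = "http://localhost:11434/v1") -> bool:
    client = AsyncOpenAI(api_key="ollama", base_url=base_url)
    try:
        await client.models.list()
        return True
    except Exception:
        return False


if __name__ == "__main__":
    print(asyncio.run(ollama_is_up()))
```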

agent_brain_server/providers/summarization/openai.py

@@ -0,0 +1,87 @@
+"""OpenAI (GPT) summarization provider implementation."""
+
+import logging
+from typing import TYPE_CHECKING
+
+from openai import AsyncOpenAI
+
+from agent_brain_server.providers.base import BaseSummarizationProvider
+from agent_brain_server.providers.exceptions import AuthenticationError, ProviderError
+
+if TYPE_CHECKING:
+    from agent_brain_server.config.provider_config import SummarizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAISummarizationProvider(BaseSummarizationProvider):
+    """OpenAI (GPT) summarization provider.
+
+    Supports:
+    - gpt-5 (most capable)
+    - gpt-5-mini (fast, cost-effective)
+    - And other OpenAI chat models
+    """
+
+    def __init__(self, config: "SummarizationConfig") -> None:
+        """Initialize OpenAI summarization provider.
+
+        Args:
+            config: Summarization configuration
+
+        Raises:
+            AuthenticationError: If API key is not available
+        """
+        api_key = config.get_api_key()
+        if not api_key:
+            raise AuthenticationError(
+                f"Missing API key. Set {config.api_key_env} environment variable.",
+                self.provider_name,
+            )
+
+        max_tokens = config.params.get("max_tokens", 300)
+        temperature = config.params.get("temperature", 0.1)
+        prompt_template = config.params.get("prompt_template")
+
+        super().__init__(
+            model=config.model,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            prompt_template=prompt_template,
+        )
+
+        self._client = AsyncOpenAI(api_key=api_key)
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name."""
+        return "OpenAI"
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on prompt using GPT.
+
+        Args:
+            prompt: The prompt to send to GPT
+
+        Returns:
+            Generated text response
+
+        Raises:
+            ProviderError: If generation fails
+        """
+        try:
+            response = await self._client.chat.completions.create(
+                model=self._model,
+                max_tokens=self._max_tokens,
+                temperature=self._temperature,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            # Extract text from response
+            content = response.choices[0].message.content
+            return content if content else ""
+        except Exception as e:
+            raise ProviderError(
+                f"Failed to generate text: {e}",
+                self.provider_name,
+                cause=e,
+            ) from e

agent_brain_server/services/indexing_service.py

@@ -10,6 +10,7 @@ from typing import Any, Callable, Optional, Union
 
 from llama_index.core.schema import TextNode
 
+from agent_brain_server.config import settings
 from agent_brain_server.indexing import (
     BM25IndexManager,
     ContextAwareChunker,
@@ -18,6 +19,10 @@ from agent_brain_server.indexing import (
     get_bm25_manager,
 )
 from agent_brain_server.indexing.chunking import CodeChunk, CodeChunker, TextChunk
+from agent_brain_server.indexing.graph_index import (
+    GraphIndexManager,
+    get_graph_index_manager,
+)
 from agent_brain_server.models import IndexingState, IndexingStatusEnum, IndexRequest
 from agent_brain_server.storage import VectorStoreManager, get_vector_store
 
@@ -43,6 +48,7 @@ class IndexingService:
         chunker: Optional[ContextAwareChunker] = None,
         embedding_generator: Optional[EmbeddingGenerator] = None,
         bm25_manager: Optional[BM25IndexManager] = None,
+        graph_index_manager: Optional[GraphIndexManager] = None,
     ):
         """
         Initialize the indexing service.
@@ -53,12 +59,14 @@ class IndexingService:
            chunker: Text chunker instance.
            embedding_generator: Embedding generator instance.
            bm25_manager: BM25 index manager instance.
+           graph_index_manager: Graph index manager instance (Feature 113).
        """
        self.vector_store = vector_store or get_vector_store()
        self.document_loader = document_loader or DocumentLoader()
        self.chunker = chunker or ContextAwareChunker()
        self.embedding_generator = embedding_generator or EmbeddingGenerator()
        self.bm25_manager = bm25_manager or get_bm25_manager()
+       self.graph_index_manager = graph_index_manager or get_graph_index_manager()
 
        # Internal state
        self._state = IndexingState(
@@ -382,6 +390,21 @@ class IndexingService:
        ]
        self.bm25_manager.build_index(nodes)
 
+       # Step 6: Build graph index if enabled (Feature 113)
+       if settings.ENABLE_GRAPH_INDEX:
+           if progress_callback:
+               await progress_callback(97, 100, "Building graph index...")
+
+           def graph_progress(current: int, total: int, message: str) -> None:
+               # Synchronous callback wrapper
+               logger.debug(f"Graph indexing: {message}")
+
+           triplet_count = self.graph_index_manager.build_from_documents(
+               chunks,
+               progress_callback=graph_progress,
+           )
+           logger.info(f"Graph index built with {triplet_count} triplets")
+
        # Mark as completed
        self._state.status = IndexingStatusEnum.COMPLETED
        self._state.completed_at = datetime.now(timezone.utc)
@@ -424,6 +447,9 @@ class IndexingService:
        total_code_chunks = self._total_code_chunks
        supported_languages = sorted(self._supported_languages)
 
+       # Get graph index status (Feature 113)
+       graph_status = self.graph_index_manager.get_status()
+
        return {
            "status": self._state.status.value,
            "is_indexing": self._state.is_indexing,
@@ -446,6 +472,14 @@ class IndexingService:
            ),
            "error": self._state.error,
            "indexed_folders": sorted(self._indexed_folders),
+           # Graph index status (Feature 113)
+           "graph_index": {
+               "enabled": graph_status.enabled,
+               "initialized": graph_status.initialized,
+               "entity_count": graph_status.entity_count,
+               "relationship_count": graph_status.relationship_count,
+               "store_type": graph_status.store_type,
+           },
        }
 
    async def reset(self) -> None:
@@ -453,6 +487,8 @@ class IndexingService:
        async with self._lock:
            await self.vector_store.reset()
            self.bm25_manager.reset()
+           # Clear graph index (Feature 113)
+           self.graph_index_manager.clear()
            self._state = IndexingState(
                current_job_id="",
                folder_path="",
@@ -461,6 +497,9 @@ class IndexingService:
                error=None,
            )
            self._indexed_folders.clear()
+           self._total_doc_chunks = 0
+           self._total_code_chunks = 0
+           self._supported_languages.clear()
            logger.info("Indexing service reset")
 
 
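The status payload now carries a `graph_index` block with the keys shown in the hunk above. A short, illustrative consumer of that block follows; `status` stands for whatever dict the service's status method returns, and only the `graph_index` keys themselves come from this diff.

```python
# Illustrative consumer of the new "graph_index" block in the indexing status dict.
def describe_graph_index(status: dict) -> str:
    gi = status.get("graph_index", {})
    if not gi.get("enabled"):
        return "graph index disabled"
    if not gi.get("initialized"):
        return "graph index enabled but not built yet"
    return (
        f"{gi['entity_count']} entities / {gi['relationship_count']} relationships "
        f"in a {gi['store_type']} store"
    )
```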