agent-brain-rag 1.1.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agent_brain_rag-1.1.0.dist-info → agent_brain_rag-2.0.0.dist-info}/METADATA +68 -27
- agent_brain_rag-2.0.0.dist-info/RECORD +50 -0
- agent_brain_rag-2.0.0.dist-info/entry_points.txt +4 -0
- {doc_serve_server → agent_brain_server}/__init__.py +1 -1
- {doc_serve_server → agent_brain_server}/api/main.py +90 -26
- {doc_serve_server → agent_brain_server}/api/routers/health.py +4 -2
- {doc_serve_server → agent_brain_server}/api/routers/index.py +1 -1
- {doc_serve_server → agent_brain_server}/api/routers/query.py +3 -3
- agent_brain_server/config/provider_config.py +308 -0
- {doc_serve_server → agent_brain_server}/config/settings.py +12 -1
- agent_brain_server/indexing/__init__.py +40 -0
- {doc_serve_server → agent_brain_server}/indexing/bm25_index.py +1 -1
- {doc_serve_server → agent_brain_server}/indexing/chunking.py +1 -1
- agent_brain_server/indexing/embedding.py +225 -0
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- {doc_serve_server → agent_brain_server}/models/__init__.py +9 -0
- agent_brain_server/models/graph.py +253 -0
- {doc_serve_server → agent_brain_server}/models/health.py +15 -3
- {doc_serve_server → agent_brain_server}/models/query.py +14 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- {doc_serve_server → agent_brain_server}/services/indexing_service.py +43 -4
- {doc_serve_server → agent_brain_server}/services/query_service.py +212 -4
- agent_brain_server/storage/__init__.py +21 -0
- agent_brain_server/storage/graph_store.py +519 -0
- {doc_serve_server → agent_brain_server}/storage/vector_store.py +36 -1
- {doc_serve_server → agent_brain_server}/storage_paths.py +2 -0
- agent_brain_rag-1.1.0.dist-info/RECORD +0 -31
- agent_brain_rag-1.1.0.dist-info/entry_points.txt +0 -3
- doc_serve_server/indexing/__init__.py +0 -19
- doc_serve_server/indexing/embedding.py +0 -274
- doc_serve_server/storage/__init__.py +0 -5
- {agent_brain_rag-1.1.0.dist-info → agent_brain_rag-2.0.0.dist-info}/WHEEL +0 -0
- {doc_serve_server → agent_brain_server}/api/__init__.py +0 -0
- {doc_serve_server → agent_brain_server}/api/routers/__init__.py +0 -0
- {doc_serve_server → agent_brain_server}/config/__init__.py +0 -0
- {doc_serve_server → agent_brain_server}/indexing/document_loader.py +0 -0
- {doc_serve_server → agent_brain_server}/locking.py +0 -0
- {doc_serve_server → agent_brain_server}/models/index.py +0 -0
- {doc_serve_server → agent_brain_server}/project_root.py +0 -0
- {doc_serve_server → agent_brain_server}/runtime.py +0 -0
- {doc_serve_server → agent_brain_server}/services/__init__.py +0 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""Provider configuration models and YAML loader.
|
|
2
|
+
|
|
3
|
+
This module provides Pydantic models for embedding and summarization
|
|
4
|
+
provider configuration, and functions to load configuration from YAML files.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
from functools import lru_cache
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
import yaml
|
|
14
|
+
from pydantic import BaseModel, Field, field_validator
|
|
15
|
+
|
|
16
|
+
from agent_brain_server.providers.base import (
|
|
17
|
+
EmbeddingProviderType,
|
|
18
|
+
SummarizationProviderType,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EmbeddingConfig(BaseModel):
    """Configuration for the embedding provider.

    Holds the provider selection, model name, and credential/endpoint
    settings used to construct an embedding client.
    """

    provider: EmbeddingProviderType = Field(
        default=EmbeddingProviderType.OPENAI,
        description="Embedding provider to use",
    )
    model: str = Field(
        default="text-embedding-3-large",
        description="Model name for embeddings",
    )
    api_key_env: Optional[str] = Field(
        default="OPENAI_API_KEY",
        description="Environment variable name containing API key",
    )
    base_url: Optional[str] = Field(
        default=None,
        description="Custom base URL (for Ollama or compatible APIs)",
    )
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Provider-specific parameters",
    )

    model_config = {"use_enum_values": True}

    @field_validator("provider", mode="before")
    @classmethod
    def validate_provider(cls, v: Any) -> EmbeddingProviderType:
        """Coerce a raw value (string or enum) into an EmbeddingProviderType."""
        # Strings are normalized to lowercase before enum lookup.
        if isinstance(v, str):
            return EmbeddingProviderType(v.lower())
        return v if isinstance(v, EmbeddingProviderType) else EmbeddingProviderType(v)

    def get_api_key(self) -> Optional[str]:
        """Resolve the API key from the configured environment variable.

        Returns:
            The key value, or None when unset or not required.
        """
        # Ollama runs locally and requires no credential.
        if self.provider == EmbeddingProviderType.OLLAMA:
            return None
        return os.getenv(self.api_key_env) if self.api_key_env else None

    def get_base_url(self) -> Optional[str]:
        """Return the explicit base URL, or a provider-specific default.

        Returns:
            Base URL for the provider, or None when the client default applies.
        """
        if self.base_url:
            return self.base_url
        # Ollama exposes an OpenAI-compatible API on this local endpoint.
        return (
            "http://localhost:11434/v1"
            if self.provider == EmbeddingProviderType.OLLAMA
            else None
        )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class SummarizationConfig(BaseModel):
    """Configuration for the summarization provider.

    Holds the provider selection, model name, and credential/endpoint
    settings used to construct a summarization client.
    """

    provider: SummarizationProviderType = Field(
        default=SummarizationProviderType.ANTHROPIC,
        description="Summarization provider to use",
    )
    model: str = Field(
        default="claude-haiku-4-5-20251001",
        description="Model name for summarization",
    )
    api_key_env: Optional[str] = Field(
        default="ANTHROPIC_API_KEY",
        description="Environment variable name containing API key",
    )
    base_url: Optional[str] = Field(
        default=None,
        description="Custom base URL (for Grok or Ollama)",
    )
    params: dict[str, Any] = Field(
        default_factory=dict,
        description="Provider-specific parameters (max_tokens, temperature)",
    )

    model_config = {"use_enum_values": True}

    @field_validator("provider", mode="before")
    @classmethod
    def validate_provider(cls, v: Any) -> SummarizationProviderType:
        """Coerce a raw value (string or enum) into a SummarizationProviderType."""
        # Strings are normalized to lowercase before enum lookup.
        if isinstance(v, str):
            return SummarizationProviderType(v.lower())
        return (
            v
            if isinstance(v, SummarizationProviderType)
            else SummarizationProviderType(v)
        )

    def get_api_key(self) -> Optional[str]:
        """Resolve the API key from the configured environment variable.

        Returns:
            The key value, or None when unset or not required.
        """
        # Ollama runs locally and requires no credential.
        if self.provider == SummarizationProviderType.OLLAMA:
            return None
        return os.getenv(self.api_key_env) if self.api_key_env else None

    def get_base_url(self) -> Optional[str]:
        """Return the explicit base URL, or a provider-specific default.

        Returns:
            Base URL for the provider, or None when the client default applies.
        """
        if self.base_url:
            return self.base_url
        # Known OpenAI-compatible endpoints for alternative providers.
        if self.provider == SummarizationProviderType.OLLAMA:
            return "http://localhost:11434/v1"
        return (
            "https://api.x.ai/v1"
            if self.provider == SummarizationProviderType.GROK
            else None
        )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class ProviderSettings(BaseModel):
    """Top-level provider configuration.

    Aggregates the embedding and summarization provider configs; both
    default to fully-populated instances (OpenAI embeddings, Anthropic
    summarization), so a missing or partial config.yaml still yields a
    usable settings object.
    """

    # Embedding provider section (defaults to OpenAI text-embedding-3-large).
    embedding: EmbeddingConfig = Field(
        default_factory=EmbeddingConfig,
        description="Embedding provider configuration",
    )
    # Summarization provider section (defaults to Anthropic Claude Haiku).
    summarization: SummarizationConfig = Field(
        default_factory=SummarizationConfig,
        description="Summarization provider configuration",
    )
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _find_config_file() -> Optional[Path]:
    """Locate config.yaml by walking the standard search order.

    Search order:
    1. Path named by the DOC_SERVE_CONFIG environment variable
    2. config.yaml in the current working directory
    3. config.yaml inside DOC_SERVE_STATE_DIR (when that variable is set)
    4. .claude/doc-serve/config.yaml under the current directory

    Returns:
        The first existing path, or None when no config file is found.
    """
    # 1. Explicit override wins; a dangling path is warned about, not fatal.
    override = os.getenv("DOC_SERVE_CONFIG")
    if override:
        candidate = Path(override)
        if candidate.exists():
            return candidate
        logger.warning(f"DOC_SERVE_CONFIG points to non-existent file: {override}")

    # 2. Current working directory.
    local_config = Path.cwd() / "config.yaml"
    if local_config.exists():
        return local_config

    # 3. State directory, when configured.
    state_dir = os.getenv("DOC_SERVE_STATE_DIR")
    if state_dir:
        state_config = Path(state_dir) / "config.yaml"
        if state_config.exists():
            return state_config

    # 4. Project-root convention: .claude/doc-serve/config.yaml.
    claude_config = Path.cwd() / ".claude" / "doc-serve" / "config.yaml"
    if claude_config.exists():
        return claude_config

    return None
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _load_yaml_config(path: Path) -> dict[str, Any]:
    """Parse a YAML configuration file into a dictionary.

    Args:
        path: Path to the YAML config file.

    Returns:
        Parsed mapping; an empty dict when the file is empty.

    Raises:
        ConfigurationError: If the file cannot be read or parsed.
    """
    # Local import — keeps providers out of this module's import-time graph.
    from agent_brain_server.providers.exceptions import ConfigurationError

    try:
        with open(path) as fh:
            parsed = yaml.safe_load(fh)
    except yaml.YAMLError as exc:
        raise ConfigurationError(
            f"Failed to parse config file {path}: {exc}",
            "config",
        ) from exc
    except OSError as exc:
        raise ConfigurationError(
            f"Failed to read config file {path}: {exc}",
            "config",
        ) from exc
    return parsed if parsed else {}
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@lru_cache
def load_provider_settings() -> ProviderSettings:
    """Load and cache provider settings from YAML config or defaults.

    Searches the standard config locations, validates the YAML against the
    Pydantic models, and falls back to the built-in defaults (OpenAI
    embeddings + Anthropic summarization) when no file is found. The
    result is memoized for the life of the process; clear_settings_cache
    resets it.

    Returns:
        Validated ProviderSettings instance.
    """
    config_path = _find_config_file()

    if config_path is None:
        logger.info("No config file found, using default providers")
        settings = ProviderSettings()
    else:
        logger.info(f"Loading provider config from {config_path}")
        settings = ProviderSettings(**_load_yaml_config(config_path))

    # Surface the active selection in the logs so misconfiguration is visible.
    logger.info(
        f"Active embedding provider: {settings.embedding.provider} "
        f"(model: {settings.embedding.model})"
    )
    logger.info(
        f"Active summarization provider: {settings.summarization.provider} "
        f"(model: {settings.summarization.model})"
    )

    return settings
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def clear_settings_cache() -> None:
    """Clear the cached provider settings (for testing).

    load_provider_settings is memoized with lru_cache; dropping the cache
    forces the next call to re-read configuration from disk.
    """
    load_provider_settings.cache_clear()
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def validate_provider_config(settings: ProviderSettings) -> list[str]:
    """Validate provider configuration and return a list of error messages.

    Checks that an API key is resolvable for every provider that requires
    one (Ollama is exempt — it runs locally without credentials).

    Args:
        settings: Provider settings to validate.

    Returns:
        List of validation error messages (empty if valid).
    """
    errors: list[str] = []

    # (config section, key required?, fallback env var name, message noun)
    checks = [
        (
            settings.embedding,
            settings.embedding.provider != EmbeddingProviderType.OLLAMA,
            "OPENAI_API_KEY",
            "embeddings",
        ),
        (
            settings.summarization,
            settings.summarization.provider != SummarizationProviderType.OLLAMA,
            "ANTHROPIC_API_KEY",
            "summarization",
        ),
    ]

    for config, needs_key, fallback_env, noun in checks:
        if needs_key and not config.get_api_key():
            env_var = config.api_key_env or fallback_env
            errors.append(
                f"Missing API key for {config.provider} {noun}. "
                f"Set {env_var} environment variable."
            )

    return errors
|
|
@@ -26,7 +26,7 @@ class Settings(BaseSettings):
|
|
|
26
26
|
|
|
27
27
|
# Anthropic Configuration
|
|
28
28
|
ANTHROPIC_API_KEY: str = ""
|
|
29
|
-
CLAUDE_MODEL: str = "claude-
|
|
29
|
+
CLAUDE_MODEL: str = "claude-haiku-4-5-20251001" # Claude 4.5 Haiku (latest)
|
|
30
30
|
|
|
31
31
|
# Chroma Configuration
|
|
32
32
|
CHROMA_PERSIST_DIR: str = "./chroma_db"
|
|
@@ -51,6 +51,17 @@ class Settings(BaseSettings):
|
|
|
51
51
|
DOC_SERVE_STATE_DIR: Optional[str] = None # Override state directory
|
|
52
52
|
DOC_SERVE_MODE: str = "project" # "project" or "shared"
|
|
53
53
|
|
|
54
|
+
# GraphRAG Configuration (Feature 113)
|
|
55
|
+
ENABLE_GRAPH_INDEX: bool = False # Master switch for graph indexing
|
|
56
|
+
GRAPH_STORE_TYPE: str = "simple" # "simple" (in-memory) or "kuzu" (persistent)
|
|
57
|
+
GRAPH_INDEX_PATH: str = "./graph_index" # Path for graph persistence
|
|
58
|
+
GRAPH_EXTRACTION_MODEL: str = "claude-haiku-4-5" # Model for entity extraction
|
|
59
|
+
GRAPH_MAX_TRIPLETS_PER_CHUNK: int = 10 # Max triplets per document chunk
|
|
60
|
+
GRAPH_USE_CODE_METADATA: bool = True # Use AST metadata for code entities
|
|
61
|
+
GRAPH_USE_LLM_EXTRACTION: bool = True # Use LLM for additional extraction
|
|
62
|
+
GRAPH_TRAVERSAL_DEPTH: int = 2 # Depth for graph traversal in queries
|
|
63
|
+
GRAPH_RRF_K: int = 60 # Reciprocal Rank Fusion constant for multi-retrieval
|
|
64
|
+
|
|
54
65
|
model_config = SettingsConfigDict(
|
|
55
66
|
env_file=[
|
|
56
67
|
".env", # Current directory
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Indexing pipeline components for document processing."""
|
|
2
|
+
|
|
3
|
+
from agent_brain_server.indexing.bm25_index import BM25IndexManager, get_bm25_manager
|
|
4
|
+
from agent_brain_server.indexing.chunking import CodeChunker, ContextAwareChunker
|
|
5
|
+
from agent_brain_server.indexing.document_loader import DocumentLoader
|
|
6
|
+
from agent_brain_server.indexing.embedding import (
|
|
7
|
+
EmbeddingGenerator,
|
|
8
|
+
get_embedding_generator,
|
|
9
|
+
)
|
|
10
|
+
from agent_brain_server.indexing.graph_extractors import (
|
|
11
|
+
CodeMetadataExtractor,
|
|
12
|
+
LLMEntityExtractor,
|
|
13
|
+
get_code_extractor,
|
|
14
|
+
get_llm_extractor,
|
|
15
|
+
reset_extractors,
|
|
16
|
+
)
|
|
17
|
+
from agent_brain_server.indexing.graph_index import (
|
|
18
|
+
GraphIndexManager,
|
|
19
|
+
get_graph_index_manager,
|
|
20
|
+
reset_graph_index_manager,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"DocumentLoader",
|
|
25
|
+
"ContextAwareChunker",
|
|
26
|
+
"CodeChunker",
|
|
27
|
+
"EmbeddingGenerator",
|
|
28
|
+
"get_embedding_generator",
|
|
29
|
+
"BM25IndexManager",
|
|
30
|
+
"get_bm25_manager",
|
|
31
|
+
# Graph indexing (Feature 113)
|
|
32
|
+
"LLMEntityExtractor",
|
|
33
|
+
"CodeMetadataExtractor",
|
|
34
|
+
"get_llm_extractor",
|
|
35
|
+
"get_code_extractor",
|
|
36
|
+
"reset_extractors",
|
|
37
|
+
"GraphIndexManager",
|
|
38
|
+
"get_graph_index_manager",
|
|
39
|
+
"reset_graph_index_manager",
|
|
40
|
+
]
|
|
@@ -8,7 +8,7 @@ from typing import Optional
|
|
|
8
8
|
from llama_index.core.schema import BaseNode, NodeWithScore
|
|
9
9
|
from llama_index.retrievers.bm25 import BM25Retriever
|
|
10
10
|
|
|
11
|
-
from
|
|
11
|
+
from agent_brain_server.config import settings
|
|
12
12
|
|
|
13
13
|
logger = logging.getLogger(__name__)
|
|
14
14
|
|
|
@@ -13,7 +13,7 @@ import tree_sitter
|
|
|
13
13
|
import tree_sitter_language_pack as tslp
|
|
14
14
|
from llama_index.core.node_parser import CodeSplitter, SentenceSplitter
|
|
15
15
|
|
|
16
|
-
from
|
|
16
|
+
from agent_brain_server.config import settings
|
|
17
17
|
|
|
18
18
|
from .document_loader import LoadedDocument
|
|
19
19
|
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""Embedding generation using pluggable providers.
|
|
2
|
+
|
|
3
|
+
This module provides embedding and summarization functionality using
|
|
4
|
+
the configurable provider system. Providers are selected based on
|
|
5
|
+
config.yaml or environment defaults.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
from collections.abc import Awaitable, Callable
|
|
11
|
+
from typing import TYPE_CHECKING, Optional
|
|
12
|
+
|
|
13
|
+
from agent_brain_server.config.provider_config import load_provider_settings
|
|
14
|
+
from agent_brain_server.providers.factory import ProviderRegistry
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from agent_brain_server.providers.base import (
|
|
18
|
+
EmbeddingProvider,
|
|
19
|
+
SummarizationProvider,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
from .chunking import TextChunk
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class EmbeddingGenerator:
    """Generates embeddings and summaries using pluggable providers.

    Thin facade over the configured embedding and summarization providers:
    delegates all embedding calls, and wraps summarization with a
    heuristic fallback that mines docstrings/comments when the provider
    fails or returns a degenerate result.
    """

    def __init__(
        self,
        embedding_provider: Optional["EmbeddingProvider"] = None,
        summarization_provider: Optional["SummarizationProvider"] = None,
    ):
        """Initialize the embedding generator.

        Args:
            embedding_provider: Pre-built embedding provider; when None,
                one is created from the loaded configuration.
            summarization_provider: Pre-built summarization provider; when
                None, one is created from the loaded configuration.
        """
        # Configuration is loaded (and cached) even when both providers are
        # injected, matching the provider-registry construction path.
        cfg = load_provider_settings()

        self._embedding_provider = (
            embedding_provider
            if embedding_provider is not None
            else ProviderRegistry.get_embedding_provider(cfg.embedding)
        )
        self._summarization_provider = (
            summarization_provider
            if summarization_provider is not None
            else ProviderRegistry.get_summarization_provider(cfg.summarization)
        )

        logger.info(
            f"EmbeddingGenerator initialized with "
            f"{self._embedding_provider.provider_name} embeddings "
            f"({self._embedding_provider.model_name}) and "
            f"{self._summarization_provider.provider_name} summarization "
            f"({self._summarization_provider.model_name})"
        )

    @property
    def model(self) -> str:
        """Name of the active embedding model."""
        return self._embedding_provider.model_name

    @property
    def embedding_provider(self) -> "EmbeddingProvider":
        """The active embedding provider."""
        return self._embedding_provider

    @property
    def summarization_provider(self) -> "SummarizationProvider":
        """The active summarization provider."""
        return self._summarization_provider

    async def embed_text(self, text: str) -> list[float]:
        """Generate an embedding for a single text.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as a list of floats.
        """
        return await self._embedding_provider.embed_text(text)

    async def embed_texts(
        self,
        texts: list[str],
        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
    ) -> list[list[float]]:
        """Generate embeddings for multiple texts.

        Args:
            texts: Texts to embed.
            progress_callback: Optional async callback(processed, total).

        Returns:
            One embedding vector per input text.
        """
        return await self._embedding_provider.embed_texts(texts, progress_callback)

    async def embed_chunks(
        self,
        chunks: list[TextChunk],
        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
    ) -> list[list[float]]:
        """Generate embeddings for a list of text chunks.

        Args:
            chunks: TextChunk objects whose `.text` is embedded.
            progress_callback: Optional async callback(processed, total).

        Returns:
            One embedding vector per chunk, in order.
        """
        return await self.embed_texts([c.text for c in chunks], progress_callback)

    async def embed_query(self, query: str) -> list[float]:
        """Generate an embedding for a search query.

        Convenience wrapper around embed_text.

        Args:
            query: The search query text.

        Returns:
            Query embedding vector.
        """
        return await self.embed_text(query)

    def get_embedding_dimensions(self) -> int:
        """Return the embedding dimensionality of the current model."""
        return self._embedding_provider.get_dimensions()

    async def generate_summary(self, code_text: str) -> str:
        """Generate a natural language summary of code.

        Falls back to docstring/comment extraction when the provider
        errors out or returns an empty/too-short result.

        Args:
            code_text: The source code to summarize.

        Returns:
            Natural language summary of the code's functionality.
        """
        try:
            summary = await self._summarization_provider.summarize(code_text)
            # Reject degenerate provider output (empty or trivially short).
            if summary and len(summary) > 10:
                return summary
            logger.warning(
                f"{self._summarization_provider.provider_name} "
                "returned empty or too short summary"
            )
            return self._extract_fallback_summary(code_text)
        except Exception as e:
            logger.error(f"Failed to generate code summary: {e}")
            return self._extract_fallback_summary(code_text)

    def _extract_fallback_summary(self, code_text: str) -> str:
        """Mine docstrings or comments for a summary when the LLM fails.

        Args:
            code_text: Source code to analyze.

        Returns:
            Extracted summary, or an empty string when nothing usable exists.
        """
        # Prefer a Python triple-quoted docstring, truncated to ~200 chars.
        doc = re.search(r'""".*?"""', code_text, re.DOTALL)
        if doc:
            body = doc.group(0)[3:-3]
            if len(body) > 10:
                if len(body) > 200:
                    return body[:200] + "..."
                return body

        # Next, a hash comment mentioning a function/class/method/def.
        note = re.search(
            r"#.*(?:function|class|method|def)", code_text, re.IGNORECASE
        )
        if note:
            return note.group(0).strip("#").strip()

        # Last resort: the first line, if it looks like a comment.
        opening = code_text.strip().split("\n")[0].strip()
        if opening.startswith(("#", "//", "/*")):
            return opening.lstrip("#/*").strip()

        return ""
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# Singleton instance
|
|
211
|
+
_embedding_generator: Optional[EmbeddingGenerator] = None
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def get_embedding_generator() -> EmbeddingGenerator:
    """Get the global embedding generator instance.

    Lazily constructs the singleton on first call; later calls return the
    same instance so provider clients are built only once per process.
    """
    global _embedding_generator
    if _embedding_generator is None:
        _embedding_generator = EmbeddingGenerator()
    return _embedding_generator
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def reset_embedding_generator() -> None:
    """Reset the global embedding generator (for testing).

    Drops the cached singleton so the next get_embedding_generator() call
    rebuilds providers from the current configuration.
    """
    global _embedding_generator
    _embedding_generator = None
|