agent-brain-rag 1.2.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/METADATA +55 -18
- agent_brain_rag-3.0.0.dist-info/RECORD +56 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/WHEEL +1 -1
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-3.0.0.dist-info}/entry_points.txt +0 -1
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +146 -45
- agent_brain_server/api/routers/__init__.py +2 -0
- agent_brain_server/api/routers/health.py +85 -21
- agent_brain_server/api/routers/index.py +108 -36
- agent_brain_server/api/routers/jobs.py +111 -0
- agent_brain_server/config/provider_config.py +352 -0
- agent_brain_server/config/settings.py +22 -5
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/bm25_index.py +15 -2
- agent_brain_server/indexing/document_loader.py +45 -4
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/job_queue/__init__.py +11 -0
- agent_brain_server/job_queue/job_service.py +317 -0
- agent_brain_server/job_queue/job_store.py +427 -0
- agent_brain_server/job_queue/job_worker.py +434 -0
- agent_brain_server/locking.py +101 -8
- agent_brain_server/models/__init__.py +28 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +30 -3
- agent_brain_server/models/job.py +289 -0
- agent_brain_server/models/query.py +16 -3
- agent_brain_server/project_root.py +1 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/runtime.py +2 -2
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +5 -3
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""Provider configuration models and YAML loader.
|
|
2
|
+
|
|
3
|
+
This module provides Pydantic models for embedding and summarization
|
|
4
|
+
provider configuration, and functions to load configuration from YAML files.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import os
|
|
9
|
+
from functools import lru_cache
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any, Optional
|
|
12
|
+
|
|
13
|
+
import yaml
|
|
14
|
+
from pydantic import BaseModel, Field, field_validator
|
|
15
|
+
|
|
16
|
+
from agent_brain_server.providers.base import (
|
|
17
|
+
EmbeddingProviderType,
|
|
18
|
+
SummarizationProviderType,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Module-level logger used by the config discovery and loading functions in this module.
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class EmbeddingConfig(BaseModel):
    """Settings that select and authenticate the embedding backend.

    NOTE(review): ``api_key_env`` defaults to ``"OPENAI_API_KEY"`` regardless
    of ``provider``. A provider other than OpenAI or Ollama that is configured
    without an explicit ``api_key`` or ``api_key_env`` will therefore read the
    OpenAI key - confirm this is intended.
    """

    provider: EmbeddingProviderType = Field(
        description="Embedding provider to use",
        default=EmbeddingProviderType.OPENAI,
    )
    model: str = Field(
        description="Model name for embeddings",
        default="text-embedding-3-large",
    )
    api_key: Optional[str] = Field(
        description="API key (alternative to api_key_env for local config files)",
        default=None,
    )
    api_key_env: Optional[str] = Field(
        description="Environment variable name containing API key",
        default="OPENAI_API_KEY",
    )
    base_url: Optional[str] = Field(
        description="Custom base URL (for Ollama or compatible APIs)",
        default=None,
    )
    params: dict[str, Any] = Field(
        description="Provider-specific parameters",
        default_factory=dict,
    )

    # Pydantic option: validated enum fields are stored as the enum's value.
    model_config = {"use_enum_values": True}

    @field_validator("provider", mode="before")
    @classmethod
    def validate_provider(cls, v: Any) -> EmbeddingProviderType:
        """Coerce a raw ``provider`` value to EmbeddingProviderType.

        Strings are lower-cased before the enum lookup. An unknown value makes
        the enum constructor raise ValueError, which pydantic reports as a
        validation error.
        """
        if isinstance(v, str):
            normalized = v.lower()
            return EmbeddingProviderType(normalized)
        already_enum = isinstance(v, EmbeddingProviderType)
        return v if already_enum else EmbeddingProviderType(v)

    def get_api_key(self) -> Optional[str]:
        """Return the credential for the configured embedding provider.

        Ollama never needs a key, so it always yields None. For any other
        provider a non-empty ``api_key`` takes precedence; otherwise the
        environment variable named by ``api_key_env`` (if set) is read.

        Returns:
            The API key, or None when none is required or none was found.
        """
        if self.provider == EmbeddingProviderType.OLLAMA:
            return None
        if self.api_key:
            return self.api_key
        env_name = self.api_key_env
        return os.getenv(env_name) if env_name else None

    def get_base_url(self) -> Optional[str]:
        """Return the endpoint URL to use for embeddings.

        A non-empty ``base_url`` wins. Otherwise Ollama falls back to
        ``http://localhost:11434/v1`` and every other provider gets None.
        """
        if self.base_url:
            return self.base_url
        is_ollama = self.provider == EmbeddingProviderType.OLLAMA
        return "http://localhost:11434/v1" if is_ollama else None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class SummarizationConfig(BaseModel):
    """Settings that select and authenticate the summarization backend.

    NOTE(review): ``api_key_env`` defaults to ``"ANTHROPIC_API_KEY"``
    regardless of ``provider``. Choosing a provider other than Anthropic or
    Ollama without an explicit ``api_key`` or ``api_key_env`` will therefore
    read the Anthropic key - confirm this is intended.
    """

    provider: SummarizationProviderType = Field(
        description="Summarization provider to use",
        default=SummarizationProviderType.ANTHROPIC,
    )
    model: str = Field(
        description="Model name for summarization",
        default="claude-haiku-4-5-20251001",
    )
    api_key: Optional[str] = Field(
        description="API key (alternative to api_key_env for local config files)",
        default=None,
    )
    api_key_env: Optional[str] = Field(
        description="Environment variable name containing API key",
        default="ANTHROPIC_API_KEY",
    )
    base_url: Optional[str] = Field(
        description="Custom base URL (for Grok or Ollama)",
        default=None,
    )
    params: dict[str, Any] = Field(
        description="Provider-specific parameters (max_tokens, temperature)",
        default_factory=dict,
    )

    # Pydantic option: validated enum fields are stored as the enum's value.
    model_config = {"use_enum_values": True}

    @field_validator("provider", mode="before")
    @classmethod
    def validate_provider(cls, v: Any) -> SummarizationProviderType:
        """Coerce a raw ``provider`` value to SummarizationProviderType.

        Strings are lower-cased before the enum lookup. An unknown value makes
        the enum constructor raise ValueError, which pydantic reports as a
        validation error.
        """
        if isinstance(v, str):
            normalized = v.lower()
            return SummarizationProviderType(normalized)
        already_enum = isinstance(v, SummarizationProviderType)
        return v if already_enum else SummarizationProviderType(v)

    def get_api_key(self) -> Optional[str]:
        """Return the credential for the configured summarization provider.

        Ollama never needs a key, so it always yields None. For any other
        provider a non-empty ``api_key`` takes precedence; otherwise the
        environment variable named by ``api_key_env`` (if set) is read.

        Returns:
            The API key, or None when none is required or none was found.
        """
        if self.provider == SummarizationProviderType.OLLAMA:
            return None
        if self.api_key:
            return self.api_key
        env_name = self.api_key_env
        return os.getenv(env_name) if env_name else None

    def get_base_url(self) -> Optional[str]:
        """Return the endpoint URL to use for summarization.

        A non-empty ``base_url`` wins. Otherwise Ollama falls back to
        ``http://localhost:11434/v1``, Grok to ``https://api.x.ai/v1``, and
        every other provider gets None.
        """
        if self.base_url:
            return self.base_url
        # Checked in order with ==, so a provider stored as a plain value
        # string (see use_enum_values) still matches its enum member.
        fallbacks = (
            (SummarizationProviderType.OLLAMA, "http://localhost:11434/v1"),
            (SummarizationProviderType.GROK, "https://api.x.ai/v1"),
        )
        return next(
            (url for kind, url in fallbacks if self.provider == kind),
            None,
        )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class ProviderSettings(BaseModel):
    """Root provider configuration holding one embedding and one summarization section.

    Either section may be omitted from the input. It is then built from that
    section model's own defaults.
    """

    embedding: EmbeddingConfig = Field(
        description="Embedding provider configuration",
        default_factory=EmbeddingConfig,
    )
    summarization: SummarizationConfig = Field(
        description="Summarization provider configuration",
        default_factory=SummarizationConfig,
    )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _find_config_file() -> Optional[Path]:
    """Find the provider configuration file in standard locations.

    Search order (the first candidate that is an existing regular file wins):
    1. Path given by the AGENT_BRAIN_CONFIG environment variable (a leading
       ``~`` is expanded)
    2. config.yaml in the state directory, if AGENT_BRAIN_STATE_DIR (or the
       legacy DOC_SERVE_STATE_DIR) is set
    3. config.yaml in the current working directory
    4. .claude/agent-brain/config.yaml in the CWD or any ancestor directory,
       up to and including the filesystem root
    5. ~/.agent-brain/config.yaml
    6. ~/.config/agent-brain/config.yaml

    Candidates that exist but are not regular files (e.g. a directory named
    ``config.yaml``) are skipped rather than returned, because they could not
    be read as YAML.

    Returns:
        Path to the config file, or None if no candidate was found.
    """
    # 1. Explicit override. Expand "~" ourselves: values loaded from .env
    #    files or quoted strings may not have been expanded by a shell.
    env_config = os.getenv("AGENT_BRAIN_CONFIG")
    if env_config:
        override = Path(env_config).expanduser()
        if override.is_file():
            logger.debug("Found config via AGENT_BRAIN_CONFIG: %s", override)
            return override
        logger.warning(
            "AGENT_BRAIN_CONFIG points to non-existent file: %s", env_config
        )

    # 2. State directory (new env var takes precedence over the legacy one).
    state_dir = os.getenv("AGENT_BRAIN_STATE_DIR") or os.getenv("DOC_SERVE_STATE_DIR")
    if state_dir:
        state_config = Path(state_dir) / "config.yaml"
        if state_config.is_file():
            logger.debug("Found config in state directory: %s", state_config)
            return state_config

    # 3. Current working directory.
    cwd = Path.cwd()
    cwd_config = cwd / "config.yaml"
    if cwd_config.is_file():
        logger.debug("Found config in current directory: %s", cwd_config)
        return cwd_config

    # 4. Walk up from the CWD. ``cwd.parents`` ends with the filesystem root,
    #    so unlike a ``while current != root`` loop the root is checked too.
    for directory in (cwd, *cwd.parents):
        claude_config = directory / ".claude" / "agent-brain" / "config.yaml"
        if claude_config.is_file():
            logger.debug("Found config walking up from CWD: %s", claude_config)
            return claude_config

    # 5. Per-user config in the home directory.
    home_config = Path.home() / ".agent-brain" / "config.yaml"
    if home_config.is_file():
        logger.debug("Found config in home directory: %s", home_config)
        return home_config

    # 6. XDG-style location under ~/.config.
    xdg_config = Path.home() / ".config" / "agent-brain" / "config.yaml"
    if xdg_config.is_file():
        logger.debug("Found config in XDG config directory: %s", xdg_config)
        return xdg_config

    return None
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _load_yaml_config(path: Path) -> dict[str, Any]:
    """Load a YAML provider configuration file and check its top-level shape.

    Args:
        path: Path to YAML config file.

    Returns:
        The parsed top-level mapping. An empty document (or any other falsy
        top-level value) yields ``{}`` so the caller falls back to defaults.

    Raises:
        ConfigurationError: If the file cannot be read or decoded, is not
            valid YAML, or its top level is not a mapping.
    """
    # Imported lazily, as in the original module layout.
    from agent_brain_server.providers.exceptions import ConfigurationError

    try:
        # Decode as UTF-8 explicitly; the platform default encoding (e.g.
        # cp1252 on Windows) would mis-decode non-ASCII config content.
        with open(path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ConfigurationError(
            f"Failed to parse config file {path}: {e}",
            "config",
        ) from e
    except (OSError, UnicodeDecodeError) as e:
        raise ConfigurationError(
            f"Failed to read config file {path}: {e}",
            "config",
        ) from e

    if not config:
        return {}
    # load_provider_settings() expands the result with ProviderSettings(**...),
    # which fails with an opaque TypeError for a list or scalar top level.
    if not isinstance(config, dict):
        raise ConfigurationError(
            f"Config file {path} must contain a mapping at the top level, "
            f"got {type(config).__name__}",
            "config",
        )
    return config
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
@lru_cache
def load_provider_settings() -> ProviderSettings:
    """Build the provider settings for this process, memoised after first use.

    The first call locates a config file via ``_find_config_file()``. If one
    is found, its YAML is parsed and validated into ProviderSettings.
    Otherwise the model defaults are used (OpenAI embeddings, Anthropic
    summarization). The active providers and models are then logged. Later
    calls return the cached instance until the lru_cache is cleared (see
    ``clear_settings_cache()``).

    Returns:
        Validated ProviderSettings instance
    """
    config_path = _find_config_file()

    if config_path is None:
        logger.info("No config file found, using default providers")
        settings = ProviderSettings()
    else:
        logger.info(f"Loading provider config from {config_path}")
        settings = ProviderSettings(**_load_yaml_config(config_path))

    embedding = settings.embedding
    summarization = settings.summarization
    logger.info(
        f"Active embedding provider: {embedding.provider} "
        f"(model: {embedding.model})"
    )
    logger.info(
        f"Active summarization provider: {summarization.provider} "
        f"(model: {summarization.model})"
    )
    return settings
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def clear_settings_cache() -> None:
    """Forget the memoised ProviderSettings (intended for tests).

    After this, the next ``load_provider_settings()`` call repeats config-file
    discovery and parsing instead of returning the cached instance.
    """
    invalidate = load_provider_settings.cache_clear
    invalidate()
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def validate_provider_config(settings: ProviderSettings) -> list[str]:
    """Check that each configured provider has the credentials it needs.

    Ollama providers are skipped because they need no API key. For every
    other provider, ``get_api_key()`` must yield a non-empty value. If it does
    not, one error message is produced naming the environment variable to
    set. The fallback names are OPENAI_API_KEY for embeddings and
    ANTHROPIC_API_KEY for summarization when ``api_key_env`` is unset.

    Args:
        settings: Provider settings to validate

    Returns:
        List of validation error messages, embedding first (empty if valid)
    """
    # (section, provider that needs no key, fallback env var, message noun)
    sections = (
        (settings.embedding, EmbeddingProviderType.OLLAMA, "OPENAI_API_KEY", "embeddings"),
        (
            settings.summarization,
            SummarizationProviderType.OLLAMA,
            "ANTHROPIC_API_KEY",
            "summarization",
        ),
    )

    problems: list[str] = []
    for section, keyless, fallback_env, noun in sections:
        if section.provider == keyless or section.get_api_key():
            continue
        env_name = section.api_key_env or fallback_env
        problems.append(
            f"Missing API key for {section.provider} {noun}. "
            f"Set {env_name} environment variable."
        )
    return problems
|
|
@@ -26,12 +26,12 @@ class Settings(BaseSettings):
|
|
|
26
26
|
|
|
27
27
|
# Anthropic Configuration
|
|
28
28
|
ANTHROPIC_API_KEY: str = ""
|
|
29
|
-
CLAUDE_MODEL: str = "claude-
|
|
29
|
+
CLAUDE_MODEL: str = "claude-haiku-4-5-20251001" # Claude 4.5 Haiku (latest)
|
|
30
30
|
|
|
31
31
|
# Chroma Configuration
|
|
32
32
|
CHROMA_PERSIST_DIR: str = "./chroma_db"
|
|
33
33
|
BM25_INDEX_PATH: str = "./bm25_index"
|
|
34
|
-
COLLECTION_NAME: str = "
|
|
34
|
+
COLLECTION_NAME: str = "agent_brain_collection"
|
|
35
35
|
|
|
36
36
|
# Chunking Configuration
|
|
37
37
|
DEFAULT_CHUNK_SIZE: int = 512
|
|
@@ -48,14 +48,31 @@ class Settings(BaseSettings):
|
|
|
48
48
|
EMBEDDING_BATCH_SIZE: int = 100
|
|
49
49
|
|
|
50
50
|
# Multi-instance Configuration
|
|
51
|
-
|
|
52
|
-
|
|
51
|
+
AGENT_BRAIN_STATE_DIR: Optional[str] = None # Override state directory
|
|
52
|
+
AGENT_BRAIN_MODE: str = "project" # "project" or "shared"
|
|
53
|
+
|
|
54
|
+
# GraphRAG Configuration (Feature 113)
|
|
55
|
+
ENABLE_GRAPH_INDEX: bool = False # Master switch for graph indexing
|
|
56
|
+
GRAPH_STORE_TYPE: str = "simple" # "simple" (in-memory) or "kuzu" (persistent)
|
|
57
|
+
GRAPH_INDEX_PATH: str = "./graph_index" # Path for graph persistence
|
|
58
|
+
GRAPH_EXTRACTION_MODEL: str = "claude-haiku-4-5" # Model for entity extraction
|
|
59
|
+
GRAPH_MAX_TRIPLETS_PER_CHUNK: int = 10 # Max triplets per document chunk
|
|
60
|
+
GRAPH_USE_CODE_METADATA: bool = True # Use AST metadata for code entities
|
|
61
|
+
GRAPH_USE_LLM_EXTRACTION: bool = True # Use LLM for additional extraction
|
|
62
|
+
GRAPH_TRAVERSAL_DEPTH: int = 2 # Depth for graph traversal in queries
|
|
63
|
+
GRAPH_RRF_K: int = 60 # Reciprocal Rank Fusion constant for multi-retrieval
|
|
64
|
+
|
|
65
|
+
# Job Queue Configuration (Feature 115)
|
|
66
|
+
AGENT_BRAIN_MAX_QUEUE: int = 100 # Max pending jobs in queue
|
|
67
|
+
AGENT_BRAIN_JOB_TIMEOUT: int = 7200 # Job timeout in seconds (2 hours)
|
|
68
|
+
AGENT_BRAIN_MAX_RETRIES: int = 3 # Max retries for failed jobs
|
|
69
|
+
AGENT_BRAIN_CHECKPOINT_INTERVAL: int = 50 # Progress checkpoint every N files
|
|
53
70
|
|
|
54
71
|
model_config = SettingsConfigDict(
|
|
55
72
|
env_file=[
|
|
56
73
|
".env", # Current directory
|
|
57
74
|
Path(__file__).parent.parent.parent / ".env", # Project root
|
|
58
|
-
Path(__file__).parent.parent / ".env", #
|
|
75
|
+
Path(__file__).parent.parent / ".env", # agent-brain-server directory
|
|
59
76
|
],
|
|
60
77
|
env_file_encoding="utf-8",
|
|
61
78
|
case_sensitive=True,
|
|
@@ -7,6 +7,18 @@ from agent_brain_server.indexing.embedding import (
|
|
|
7
7
|
EmbeddingGenerator,
|
|
8
8
|
get_embedding_generator,
|
|
9
9
|
)
|
|
10
|
+
from agent_brain_server.indexing.graph_extractors import (
|
|
11
|
+
CodeMetadataExtractor,
|
|
12
|
+
LLMEntityExtractor,
|
|
13
|
+
get_code_extractor,
|
|
14
|
+
get_llm_extractor,
|
|
15
|
+
reset_extractors,
|
|
16
|
+
)
|
|
17
|
+
from agent_brain_server.indexing.graph_index import (
|
|
18
|
+
GraphIndexManager,
|
|
19
|
+
get_graph_index_manager,
|
|
20
|
+
reset_graph_index_manager,
|
|
21
|
+
)
|
|
10
22
|
|
|
11
23
|
__all__ = [
|
|
12
24
|
"DocumentLoader",
|
|
@@ -16,4 +28,13 @@ __all__ = [
|
|
|
16
28
|
"get_embedding_generator",
|
|
17
29
|
"BM25IndexManager",
|
|
18
30
|
"get_bm25_manager",
|
|
31
|
+
# Graph indexing (Feature 113)
|
|
32
|
+
"LLMEntityExtractor",
|
|
33
|
+
"CodeMetadataExtractor",
|
|
34
|
+
"get_llm_extractor",
|
|
35
|
+
"get_code_extractor",
|
|
36
|
+
"reset_extractors",
|
|
37
|
+
"GraphIndexManager",
|
|
38
|
+
"get_graph_index_manager",
|
|
39
|
+
"reset_graph_index_manager",
|
|
19
40
|
]
|
|
@@ -89,10 +89,23 @@ class BM25IndexManager:
|
|
|
89
89
|
if not self._retriever:
|
|
90
90
|
raise RuntimeError("BM25 index not initialized")
|
|
91
91
|
|
|
92
|
-
#
|
|
93
|
-
self._retriever.
|
|
92
|
+
# Cap top_k to corpus size to avoid bm25s "k larger than available scores" error
|
|
93
|
+
corpus_size = len(self._retriever.corpus) if self._retriever.corpus else 0
|
|
94
|
+
if corpus_size > 0:
|
|
95
|
+
effective_top_k = min(top_k, corpus_size)
|
|
96
|
+
else:
|
|
97
|
+
effective_top_k = top_k
|
|
98
|
+
|
|
99
|
+
self._retriever.similarity_top_k = effective_top_k
|
|
94
100
|
return self._retriever
|
|
95
101
|
|
|
102
|
+
@property
|
|
103
|
+
def corpus_size(self) -> int:
|
|
104
|
+
"""Get the number of documents in the BM25 index."""
|
|
105
|
+
if not self._retriever or not self._retriever.corpus:
|
|
106
|
+
return 0
|
|
107
|
+
return len(self._retriever.corpus)
|
|
108
|
+
|
|
96
109
|
async def search_with_filters(
|
|
97
110
|
self,
|
|
98
111
|
query: str,
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Document loading from various file formats using LlamaIndex."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
import re
|
|
5
6
|
from dataclasses import dataclass, field
|
|
@@ -272,9 +273,30 @@ class DocumentLoader:
|
|
|
272
273
|
|
|
273
274
|
SUPPORTED_EXTENSIONS: set[str] = DOCUMENT_EXTENSIONS | CODE_EXTENSIONS
|
|
274
275
|
|
|
276
|
+
# Default directories to exclude from indexing
|
|
277
|
+
DEFAULT_EXCLUDE_PATTERNS: list[str] = [
|
|
278
|
+
"**/node_modules/**",
|
|
279
|
+
"**/__pycache__/**",
|
|
280
|
+
"**/.venv/**",
|
|
281
|
+
"**/venv/**",
|
|
282
|
+
"**/.git/**",
|
|
283
|
+
"**/dist/**",
|
|
284
|
+
"**/build/**",
|
|
285
|
+
"**/target/**",
|
|
286
|
+
"**/.next/**",
|
|
287
|
+
"**/.nuxt/**",
|
|
288
|
+
"**/coverage/**",
|
|
289
|
+
"**/.pytest_cache/**",
|
|
290
|
+
"**/.mypy_cache/**",
|
|
291
|
+
"**/.tox/**",
|
|
292
|
+
"**/egg-info/**",
|
|
293
|
+
"**/*.egg-info/**",
|
|
294
|
+
]
|
|
295
|
+
|
|
275
296
|
def __init__(
|
|
276
297
|
self,
|
|
277
298
|
supported_extensions: Optional[set[str]] = None,
|
|
299
|
+
exclude_patterns: Optional[list[str]] = None,
|
|
278
300
|
):
|
|
279
301
|
"""
|
|
280
302
|
Initialize the document loader.
|
|
@@ -282,8 +304,15 @@ class DocumentLoader:
|
|
|
282
304
|
Args:
|
|
283
305
|
supported_extensions: Set of file extensions to load.
|
|
284
306
|
Defaults to SUPPORTED_EXTENSIONS.
|
|
307
|
+
exclude_patterns: List of glob patterns to exclude.
|
|
308
|
+
Defaults to DEFAULT_EXCLUDE_PATTERNS.
|
|
285
309
|
"""
|
|
286
310
|
self.extensions = supported_extensions or self.SUPPORTED_EXTENSIONS
|
|
311
|
+
self.exclude_patterns = (
|
|
312
|
+
exclude_patterns
|
|
313
|
+
if exclude_patterns is not None
|
|
314
|
+
else self.DEFAULT_EXCLUDE_PATTERNS
|
|
315
|
+
)
|
|
287
316
|
|
|
288
317
|
async def load_from_folder(
|
|
289
318
|
self,
|
|
@@ -313,16 +342,24 @@ class DocumentLoader:
|
|
|
313
342
|
raise ValueError(f"Path is not a directory: {folder_path}")
|
|
314
343
|
|
|
315
344
|
logger.info(f"Loading documents from: {folder_path} (recursive={recursive})")
|
|
345
|
+
if self.exclude_patterns:
|
|
346
|
+
logger.info(
|
|
347
|
+
f"Excluding patterns: {self.exclude_patterns[:3]}... "
|
|
348
|
+
f"({len(self.exclude_patterns)} total)"
|
|
349
|
+
)
|
|
316
350
|
|
|
317
351
|
# Use LlamaIndex's SimpleDirectoryReader
|
|
352
|
+
# Run in thread pool to avoid blocking the event loop
|
|
318
353
|
try:
|
|
319
354
|
reader = SimpleDirectoryReader(
|
|
320
355
|
input_dir=str(path),
|
|
321
356
|
recursive=recursive,
|
|
322
357
|
required_exts=list(self.extensions),
|
|
358
|
+
exclude=self.exclude_patterns,
|
|
323
359
|
filename_as_id=True,
|
|
324
360
|
)
|
|
325
|
-
|
|
361
|
+
# reader.load_data() is blocking I/O - run in thread pool
|
|
362
|
+
llama_documents: list[Document] = await asyncio.to_thread(reader.load_data)
|
|
326
363
|
except Exception as e:
|
|
327
364
|
logger.error(f"Failed to load documents: {e}")
|
|
328
365
|
raise
|
|
@@ -398,7 +435,8 @@ class DocumentLoader:
|
|
|
398
435
|
input_files=[str(path)],
|
|
399
436
|
filename_as_id=True,
|
|
400
437
|
)
|
|
401
|
-
|
|
438
|
+
# Run in thread pool to avoid blocking the event loop
|
|
439
|
+
docs = await asyncio.to_thread(reader.load_data)
|
|
402
440
|
|
|
403
441
|
if not docs:
|
|
404
442
|
raise ValueError(f"No content loaded from file: {file_path}")
|
|
@@ -456,8 +494,11 @@ class DocumentLoader:
|
|
|
456
494
|
# Use only document extensions
|
|
457
495
|
effective_extensions = self.DOCUMENT_EXTENSIONS
|
|
458
496
|
|
|
459
|
-
# Create a temporary loader with the effective extensions
|
|
460
|
-
temp_loader = DocumentLoader(
|
|
497
|
+
# Create a temporary loader with the effective extensions and exclude patterns
|
|
498
|
+
temp_loader = DocumentLoader(
|
|
499
|
+
supported_extensions=effective_extensions,
|
|
500
|
+
exclude_patterns=self.exclude_patterns,
|
|
501
|
+
)
|
|
461
502
|
|
|
462
503
|
# Load files using the configured extensions
|
|
463
504
|
loaded_docs = await temp_loader.load_from_folder(folder_path, recursive)
|