agent-brain-rag 1.2.0__py3-none-any.whl → 2.0.0__py3-none-any.whl
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/METADATA +54 -16
- agent_brain_rag-2.0.0.dist-info/RECORD +50 -0
- agent_brain_server/__init__.py +1 -1
- agent_brain_server/api/main.py +30 -2
- agent_brain_server/api/routers/health.py +1 -0
- agent_brain_server/config/provider_config.py +308 -0
- agent_brain_server/config/settings.py +12 -1
- agent_brain_server/indexing/__init__.py +21 -0
- agent_brain_server/indexing/embedding.py +86 -135
- agent_brain_server/indexing/graph_extractors.py +582 -0
- agent_brain_server/indexing/graph_index.py +536 -0
- agent_brain_server/models/__init__.py +9 -0
- agent_brain_server/models/graph.py +253 -0
- agent_brain_server/models/health.py +15 -3
- agent_brain_server/models/query.py +14 -1
- agent_brain_server/providers/__init__.py +64 -0
- agent_brain_server/providers/base.py +251 -0
- agent_brain_server/providers/embedding/__init__.py +23 -0
- agent_brain_server/providers/embedding/cohere.py +163 -0
- agent_brain_server/providers/embedding/ollama.py +150 -0
- agent_brain_server/providers/embedding/openai.py +118 -0
- agent_brain_server/providers/exceptions.py +95 -0
- agent_brain_server/providers/factory.py +157 -0
- agent_brain_server/providers/summarization/__init__.py +41 -0
- agent_brain_server/providers/summarization/anthropic.py +87 -0
- agent_brain_server/providers/summarization/gemini.py +96 -0
- agent_brain_server/providers/summarization/grok.py +95 -0
- agent_brain_server/providers/summarization/ollama.py +114 -0
- agent_brain_server/providers/summarization/openai.py +87 -0
- agent_brain_server/services/indexing_service.py +39 -0
- agent_brain_server/services/query_service.py +203 -0
- agent_brain_server/storage/__init__.py +18 -2
- agent_brain_server/storage/graph_store.py +519 -0
- agent_brain_server/storage/vector_store.py +35 -0
- agent_brain_server/storage_paths.py +2 -0
- agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/WHEEL +0 -0
- {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/entry_points.txt +0 -0
agent_brain_server/models/graph.py (new file)
@@ -0,0 +1,253 @@
+"""Models for GraphRAG feature (Feature 113).
+
+Defines Pydantic models for graph entities, relationships, and status.
+All models are configured with frozen=True for immutability.
+"""
+
+from datetime import datetime
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class GraphTriple(BaseModel):
+    """Represents a subject-predicate-object triple in the knowledge graph.
+
+    Triples are the fundamental unit of knowledge representation in GraphRAG.
+    They capture relationships between entities extracted from documents.
+
+    Attributes:
+        subject: The subject entity (e.g., "FastAPI").
+        subject_type: Optional type classification (e.g., "Framework").
+        predicate: The relationship type (e.g., "uses").
+        object: The object entity (e.g., "Pydantic").
+        object_type: Optional type classification (e.g., "Library").
+        source_chunk_id: Optional ID of the source document chunk.
+    """
+
+    model_config = ConfigDict(
+        frozen=True,
+        json_schema_extra={
+            "examples": [
+                {
+                    "subject": "FastAPI",
+                    "subject_type": "Framework",
+                    "predicate": "uses",
+                    "object": "Pydantic",
+                    "object_type": "Library",
+                    "source_chunk_id": "chunk_abc123",
+                },
+                {
+                    "subject": "UserController",
+                    "subject_type": "Class",
+                    "predicate": "calls",
+                    "object": "authenticate_user",
+                    "object_type": "Function",
+                    "source_chunk_id": "chunk_def456",
+                },
+            ]
+        },
+    )
+
+    subject: str = Field(
+        ...,
+        min_length=1,
+        description="Subject entity in the triple",
+    )
+    subject_type: Optional[str] = Field(
+        default=None,
+        description="Type classification for subject entity",
+    )
+    predicate: str = Field(
+        ...,
+        min_length=1,
+        description="Relationship type connecting subject to object",
+    )
+    object: str = Field(
+        ...,
+        min_length=1,
+        description="Object entity in the triple",
+    )
+    object_type: Optional[str] = Field(
+        default=None,
+        description="Type classification for object entity",
+    )
+    source_chunk_id: Optional[str] = Field(
+        default=None,
+        description="ID of the source document chunk",
+    )
+
+
+class GraphEntity(BaseModel):
+    """Represents an entity node in the knowledge graph.
+
+    Entities are the nodes in the graph, representing concepts,
+    code elements, or other named items extracted from documents.
+
+    Attributes:
+        name: Unique name/identifier of the entity.
+        entity_type: Classification type (e.g., "Class", "Function", "Concept").
+        description: Optional description of the entity.
+        source_chunk_ids: List of source chunk IDs where entity appears.
+        properties: Additional metadata properties.
+    """
+
+    model_config = ConfigDict(
+        frozen=True,
+        json_schema_extra={
+            "examples": [
+                {
+                    "name": "VectorStoreManager",
+                    "entity_type": "Class",
+                    "description": "Manages Chroma vector store operations",
+                    "source_chunk_ids": ["chunk_001", "chunk_002"],
+                    "properties": {"module": "storage.vector_store"},
+                },
+            ]
+        },
+    )
+
+    name: str = Field(
+        ...,
+        min_length=1,
+        description="Unique name/identifier of the entity",
+    )
+    entity_type: Optional[str] = Field(
+        default=None,
+        description="Classification type for the entity",
+    )
+    description: Optional[str] = Field(
+        default=None,
+        description="Description of the entity",
+    )
+    source_chunk_ids: list[str] = Field(
+        default_factory=list,
+        description="List of source chunk IDs where entity appears",
+    )
+    properties: dict[str, str] = Field(
+        default_factory=dict,
+        description="Additional metadata properties",
+    )
+
+
+class GraphIndexStatus(BaseModel):
+    """Status of the graph index.
+
+    Provides information about the graph index state,
+    including whether it's enabled, initialized, and statistics.
+
+    Attributes:
+        enabled: Whether graph indexing is enabled.
+        initialized: Whether the graph store is initialized.
+        entity_count: Number of entities in the graph.
+        relationship_count: Number of relationships in the graph.
+        last_updated: Timestamp of last graph update.
+        store_type: Type of graph store backend.
+    """
+
+    model_config = ConfigDict(
+        frozen=True,
+        json_schema_extra={
+            "examples": [
+                {
+                    "enabled": True,
+                    "initialized": True,
+                    "entity_count": 150,
+                    "relationship_count": 320,
+                    "last_updated": "2024-12-15T10:30:00Z",
+                    "store_type": "simple",
+                },
+                {
+                    "enabled": False,
+                    "initialized": False,
+                    "entity_count": 0,
+                    "relationship_count": 0,
+                    "last_updated": None,
+                    "store_type": "simple",
+                },
+            ]
+        },
+    )
+
+    enabled: bool = Field(
+        default=False,
+        description="Whether graph indexing is enabled",
+    )
+    initialized: bool = Field(
+        default=False,
+        description="Whether the graph store is initialized",
+    )
+    entity_count: int = Field(
+        default=0,
+        ge=0,
+        description="Number of entities in the graph",
+    )
+    relationship_count: int = Field(
+        default=0,
+        ge=0,
+        description="Number of relationships in the graph",
+    )
+    last_updated: Optional[datetime] = Field(
+        default=None,
+        description="Timestamp of last graph update",
+    )
+    store_type: str = Field(
+        default="simple",
+        description="Type of graph store backend (simple or kuzu)",
+    )
+
+
+class GraphQueryContext(BaseModel):
+    """Context information from graph-based retrieval.
+
+    Contains additional context extracted from the knowledge graph
+    during query processing.
+
+    Attributes:
+        related_entities: List of related entity names.
+        relationship_paths: List of relationship paths as strings.
+        subgraph_triplets: Relevant triplets from the graph.
+        graph_score: Score from graph-based retrieval.
+    """
+
+    model_config = ConfigDict(
+        frozen=True,
+        json_schema_extra={
+            "examples": [
+                {
+                    "related_entities": ["FastAPI", "Pydantic", "Uvicorn"],
+                    "relationship_paths": [
+                        "FastAPI -> uses -> Pydantic",
+                        "FastAPI -> runs_on -> Uvicorn",
+                    ],
+                    "subgraph_triplets": [
+                        {
+                            "subject": "FastAPI",
+                            "predicate": "uses",
+                            "object": "Pydantic",
+                        },
+                    ],
+                    "graph_score": 0.85,
+                },
+            ]
+        },
+    )
+
+    related_entities: list[str] = Field(
+        default_factory=list,
+        description="List of related entity names",
+    )
+    relationship_paths: list[str] = Field(
+        default_factory=list,
+        description="Relationship paths as formatted strings",
+    )
+    subgraph_triplets: list[GraphTriple] = Field(
+        default_factory=list,
+        description="Relevant triplets from the graph",
+    )
+    graph_score: float = Field(
+        default=0.0,
+        ge=0.0,
+        le=1.0,
+        description="Score from graph-based retrieval",
+    )
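For orientation, a minimal usage sketch of the models above (not part of the release; standard pydantic v2 semantics assumed): frozen=True makes instances immutable, and GraphQueryContext nests GraphTriple directly.

from agent_brain_server.models.graph import GraphQueryContext, GraphTriple

triple = GraphTriple(subject="FastAPI", predicate="uses", object="Pydantic")

# frozen=True: attribute assignment raises a pydantic ValidationError.
try:
    triple.subject = "Flask"
except Exception as exc:
    print(f"immutable: {exc}")

# GraphQueryContext aggregates graph-retrieval output around such triples.
ctx = GraphQueryContext(
    related_entities=["FastAPI", "Pydantic"],
    relationship_paths=["FastAPI -> uses -> Pydantic"],
    subgraph_triplets=[triple],
    graph_score=0.85,
)
print(ctx.graph_score)  # 0.85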
agent_brain_server/models/health.py
@@ -1,7 +1,7 @@
 """Health status models."""
 
 from datetime import datetime, timezone
-from typing import Literal, Optional
+from typing import Any, Literal, Optional
 
 from pydantic import BaseModel, Field
 
@@ -22,7 +22,7 @@ class HealthStatus(BaseModel):
         description="Timestamp of the health check",
     )
     version: str = Field(
-        default="1.2.0",
+        default="2.0.0",
         description="Server version",
     )
     mode: Optional[str] = Field(
@@ -49,7 +49,7 @@ class HealthStatus(BaseModel):
                 "status": "healthy",
                 "message": "Server is running and ready for queries",
                 "timestamp": "2024-12-15T10:30:00Z",
-                "version": "1.2.0",
+                "version": "2.0.0",
             }
         ]
     }
@@ -105,6 +105,11 @@ class IndexingStatus(BaseModel):
         default_factory=list,
         description="List of folders that have been indexed",
    )
+    # Graph index status (Feature 113)
+    graph_index: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="Graph index status with entity_count, relationship_count, etc.",
+    )
 
     model_config = {
         "json_schema_extra": {
@@ -120,6 +125,13 @@ class IndexingStatus(BaseModel):
                 "last_indexed_at": "2024-12-15T10:30:00Z",
                 "indexed_folders": ["/path/to/docs"],
                 "supported_languages": ["python", "typescript", "java"],
+                "graph_index": {
+                    "enabled": True,
+                    "initialized": True,
+                    "entity_count": 120,
+                    "relationship_count": 250,
+                    "store_type": "simple",
+                },
             }
         ]
     }
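The new graph_index field is a plain dict rather than a typed model; judging by the matching key names it is presumably populated from GraphIndexStatus (models/graph.py). A hedged sketch of that pairing (the actual wiring lives in the service layer, which is not shown in this diff):

from agent_brain_server.models.graph import GraphIndexStatus

status = GraphIndexStatus(
    enabled=True,
    initialized=True,
    entity_count=120,
    relationship_count=250,
)
# model_dump(mode="json") yields a JSON-safe dict suitable for the
# Optional[dict[str, Any]] graph_index field above.
graph_index = status.model_dump(mode="json")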
agent_brain_server/models/query.py
@@ -14,6 +14,8 @@ class QueryMode(str, Enum):
     VECTOR = "vector"
     BM25 = "bm25"
     HYBRID = "hybrid"
+    GRAPH = "graph"  # Graph-only retrieval (Feature 113)
+    MULTI = "multi"  # Multi-retrieval: vector + BM25 + graph with RRF (Feature 113)
 
 
 class QueryRequest(BaseModel):
@@ -39,7 +41,7 @@ class QueryRequest(BaseModel):
     )
     mode: QueryMode = Field(
         default=QueryMode.HYBRID,
-        description="Retrieval mode (vector, bm25, hybrid)",
+        description="Retrieval mode (vector, bm25, hybrid, graph, multi)",
     )
     alpha: float = Field(
         default=0.5,
@@ -131,6 +133,17 @@ class QueryResult(BaseModel):
         default=None, description="Programming language for code files"
     )
 
+    # GraphRAG fields (Feature 113)
+    graph_score: float | None = Field(
+        default=None, description="Score from graph-based retrieval"
+    )
+    related_entities: list[str] | None = Field(
+        default=None, description="Related entities from knowledge graph"
+    )
+    relationship_path: list[str] | None = Field(
+        default=None, description="Relationship paths in the graph"
+    )
+
     # Additional metadata
     metadata: dict[str, Any] = Field(
         default_factory=dict, description="Additional metadata"
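Because QueryMode subclasses str, the two new members round-trip cleanly through JSON request payloads; a quick sketch:

from agent_brain_server.models.query import QueryMode

assert QueryMode.GRAPH.value == "graph"
assert QueryMode.MULTI.value == "multi"

# A str-valued Enum can be constructed from its raw string, so API clients
# may send plain "graph" / "multi" in the mode field:
assert QueryMode("multi") is QueryMode.MULTI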
agent_brain_server/providers/__init__.py (new file)
@@ -0,0 +1,64 @@
+"""Pluggable model providers for Agent Brain.
+
+This package provides abstractions for embedding and summarization providers,
+allowing configuration-driven selection between OpenAI, Ollama, Cohere (embeddings)
+and Anthropic, OpenAI, Gemini, Grok, Ollama (summarization) providers.
+"""
+
+from agent_brain_server.providers.base import (
+    BaseEmbeddingProvider,
+    BaseSummarizationProvider,
+    EmbeddingProvider,
+    EmbeddingProviderType,
+    SummarizationProvider,
+    SummarizationProviderType,
+)
+from agent_brain_server.providers.exceptions import (
+    AuthenticationError,
+    ConfigurationError,
+    ModelNotFoundError,
+    ProviderError,
+    ProviderMismatchError,
+    ProviderNotFoundError,
+    RateLimitError,
+)
+from agent_brain_server.providers.factory import ProviderRegistry
+
+__all__ = [
+    # Protocols
+    "EmbeddingProvider",
+    "SummarizationProvider",
+    # Base classes
+    "BaseEmbeddingProvider",
+    "BaseSummarizationProvider",
+    # Enums
+    "EmbeddingProviderType",
+    "SummarizationProviderType",
+    # Factory
+    "ProviderRegistry",
+    # Exceptions
+    "ProviderError",
+    "ConfigurationError",
+    "AuthenticationError",
+    "ProviderNotFoundError",
+    "ProviderMismatchError",
+    "RateLimitError",
+    "ModelNotFoundError",
+]
+
+
+def _register_providers() -> None:
+    """Register all built-in providers with the registry."""
+    # Import providers to trigger registration
+    from agent_brain_server.providers import (
+        embedding,  # noqa: F401
+        summarization,  # noqa: F401
+    )
+
+    # Silence unused import warnings
+    _ = embedding
+    _ = summarization
+
+
+# Auto-register providers on import
+_register_providers()
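Note the import side effect: _register_providers() runs at the bottom of the module, so merely importing the package populates the registry. A sketch of the contract a caller relies on (the lookup side of ProviderRegistry lives in factory.py, which is not shown in this diff):

# Importing the package executes _register_providers(), which imports the
# embedding and summarization subpackages; their __init__ modules call
# ProviderRegistry.register_*_provider(...) at import time (see the
# embedding/__init__.py hunk below).
import agent_brain_server.providers  # noqa: F401  (registration side effect)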
agent_brain_server/providers/base.py (new file)
@@ -0,0 +1,251 @@
+"""Base protocols and classes for pluggable providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from collections.abc import Awaitable, Callable
+from enum import Enum
+from typing import Optional, Protocol, runtime_checkable
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingProviderType(str, Enum):
+    """Supported embedding providers."""
+
+    OPENAI = "openai"
+    OLLAMA = "ollama"
+    COHERE = "cohere"
+
+
+class SummarizationProviderType(str, Enum):
+    """Supported summarization providers."""
+
+    ANTHROPIC = "anthropic"
+    OPENAI = "openai"
+    GEMINI = "gemini"
+    GROK = "grok"
+    OLLAMA = "ollama"
+
+
+@runtime_checkable
+class EmbeddingProvider(Protocol):
+    """Protocol for embedding providers.
+
+    All embedding providers must implement this interface to be usable
+    by the Agent Brain indexing and query systems.
+    """
+
+    async def embed_text(self, text: str) -> list[float]:
+        """Generate embedding for a single text.
+
+        Args:
+            text: Text to embed.
+
+        Returns:
+            Embedding vector as list of floats.
+
+        Raises:
+            ProviderError: If embedding generation fails.
+        """
+        ...
+
+    async def embed_texts(
+        self,
+        texts: list[str],
+        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
+    ) -> list[list[float]]:
+        """Generate embeddings for multiple texts.
+
+        Args:
+            texts: List of texts to embed.
+            progress_callback: Optional callback(processed, total) for progress.
+
+        Returns:
+            List of embedding vectors, one per input text.
+
+        Raises:
+            ProviderError: If embedding generation fails.
+        """
+        ...
+
+    def get_dimensions(self) -> int:
+        """Get the embedding vector dimensions for the current model.
+
+        Returns:
+            Number of dimensions in the embedding vector.
+        """
+        ...
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        ...
+
+
+@runtime_checkable
+class SummarizationProvider(Protocol):
+    """Protocol for summarization/LLM providers.
+
+    All summarization providers must implement this interface to be usable
+    by the Agent Brain code summarization system.
+    """
+
+    async def summarize(self, text: str) -> str:
+        """Generate a summary of the given text.
+
+        Args:
+            text: Text to summarize (typically source code).
+
+        Returns:
+            Natural language summary of the text.
+
+        Raises:
+            ProviderError: If summarization fails.
+        """
+        ...
+
+    async def generate(self, prompt: str) -> str:
+        """Generate text based on a prompt (generic LLM call).
+
+        Args:
+            prompt: The prompt to send to the LLM.
+
+        Returns:
+            Generated text response.
+
+        Raises:
+            ProviderError: If generation fails.
+        """
+        ...
+
+    @property
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        ...
+
+
+class BaseEmbeddingProvider(ABC):
+    """Base class for embedding providers with common functionality."""
+
+    def __init__(self, model: str, batch_size: int = 100) -> None:
+        self._model = model
+        self._batch_size = batch_size
+        logger.info(
+            f"Initialized {self.provider_name} embedding provider with model {model}"
+        )
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        return self._model
+
+    async def embed_texts(
+        self,
+        texts: list[str],
+        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
+    ) -> list[list[float]]:
+        """Default batch implementation using _embed_batch."""
+        if not texts:
+            return []
+
+        all_embeddings: list[list[float]] = []
+
+        for i in range(0, len(texts), self._batch_size):
+            batch = texts[i : i + self._batch_size]
+            batch_embeddings = await self._embed_batch(batch)
+            all_embeddings.extend(batch_embeddings)
+
+            if progress_callback:
+                await progress_callback(
+                    min(i + self._batch_size, len(texts)),
+                    len(texts),
+                )
+
+            logger.debug(
+                f"Generated embeddings for batch {i // self._batch_size + 1} "
+                f"({len(batch)} texts)"
+            )
+
+        return all_embeddings
+
+    @abstractmethod
+    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Provider-specific batch embedding implementation."""
+        ...
+
+    @abstractmethod
+    async def embed_text(self, text: str) -> list[float]:
+        """Provider-specific single text embedding."""
+        ...
+
+    @abstractmethod
+    def get_dimensions(self) -> int:
+        """Provider-specific dimension lookup."""
+        ...
+
+    @property
+    @abstractmethod
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
+
+
+class BaseSummarizationProvider(ABC):
+    """Base class for summarization providers with common functionality."""
+
+    DEFAULT_PROMPT_TEMPLATE = (
+        "You are an expert software engineer analyzing source code. "
+        "Provide a concise 1-2 sentence summary of what this code does. "
+        "Focus on the functionality, purpose, and behavior. "
+        "Be specific about inputs, outputs, and side effects. "
+        "Ignore implementation details and focus on what the code accomplishes.\n\n"
+        "Code to summarize:\n{code}\n\n"
+        "Summary:"
+    )
+
+    def __init__(
+        self,
+        model: str,
+        max_tokens: int = 300,
+        temperature: float = 0.1,
+        prompt_template: Optional[str] = None,
+    ) -> None:
+        self._model = model
+        self._max_tokens = max_tokens
+        self._temperature = temperature
+        self._prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE
+        logger.info(
+            f"Initialized {self.provider_name} summarization provider "
+            f"with model {model}"
+        )
+
+    @property
+    def model_name(self) -> str:
+        """Model identifier being used."""
+        return self._model
+
+    async def summarize(self, text: str) -> str:
+        """Generate summary using the prompt template."""
+        prompt = self._prompt_template.format(code=text)
+        return await self.generate(prompt)
+
+    @abstractmethod
+    async def generate(self, prompt: str) -> str:
+        """Provider-specific text generation."""
+        ...
+
+    @property
+    @abstractmethod
+    def provider_name(self) -> str:
+        """Human-readable provider name for logging."""
+        ...
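To make the template-method split concrete, here is a hedged sketch of the smallest subclass BaseEmbeddingProvider accepts: the inherited embed_texts() handles batching, progress callbacks, and logging, while the subclass supplies only the abstract pieces. FakeEmbeddingProvider and its two-dimensional "embedding" are illustrative, not part of the package.

import asyncio

from agent_brain_server.providers.base import BaseEmbeddingProvider


class FakeEmbeddingProvider(BaseEmbeddingProvider):
    """Deterministic stand-in, e.g. for tests; not a real provider."""

    async def embed_text(self, text: str) -> list[float]:
        return [float(len(text)), 1.0]

    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        # Called once per batch by the inherited embed_texts().
        return [await self.embed_text(t) for t in texts]

    def get_dimensions(self) -> int:
        return 2

    @property
    def provider_name(self) -> str:
        return "fake"


async def main() -> None:
    provider = FakeEmbeddingProvider(model="fake-model", batch_size=2)
    vectors = await provider.embed_texts(["a", "bb", "ccc"])
    print(vectors)  # [[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]]


asyncio.run(main())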
agent_brain_server/providers/embedding/__init__.py (new file)
@@ -0,0 +1,23 @@
+"""Embedding providers for Agent Brain.
+
+This module provides embedding implementations for:
+- OpenAI (text-embedding-3-large, text-embedding-3-small, text-embedding-ada-002)
+- Ollama (nomic-embed-text, mxbai-embed-large, etc.)
+- Cohere (embed-english-v3, embed-multilingual-v3, etc.)
+"""
+
+from agent_brain_server.providers.embedding.cohere import CohereEmbeddingProvider
+from agent_brain_server.providers.embedding.ollama import OllamaEmbeddingProvider
+from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider
+from agent_brain_server.providers.factory import ProviderRegistry
+
+# Register embedding providers
+ProviderRegistry.register_embedding_provider("openai", OpenAIEmbeddingProvider)
+ProviderRegistry.register_embedding_provider("ollama", OllamaEmbeddingProvider)
+ProviderRegistry.register_embedding_provider("cohere", CohereEmbeddingProvider)
+
+__all__ = [
+    "OpenAIEmbeddingProvider",
+    "OllamaEmbeddingProvider",
+    "CohereEmbeddingProvider",
+]
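Since registration is an ordinary class-level call, out-of-tree providers can hook into the same mechanism. A hedged sketch (EchoEmbeddingProvider and the "echo" name are hypothetical; only register_embedding_provider itself appears in this diff):

from agent_brain_server.providers.base import BaseEmbeddingProvider
from agent_brain_server.providers.factory import ProviderRegistry


class EchoEmbeddingProvider(BaseEmbeddingProvider):
    """Hypothetical out-of-tree provider; emits a 1-d length 'embedding'."""

    async def embed_text(self, text: str) -> list[float]:
        return [float(len(text))]

    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        return [await self.embed_text(t) for t in texts]

    def get_dimensions(self) -> int:
        return 1

    @property
    def provider_name(self) -> str:
        return "echo"


# Same call the built-in registrations above use.
ProviderRegistry.register_embedding_provider("echo", EchoEmbeddingProvider)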