agent-brain-rag 1.2.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/METADATA +54 -16
  2. agent_brain_rag-2.0.0.dist-info/RECORD +50 -0
  3. agent_brain_server/__init__.py +1 -1
  4. agent_brain_server/api/main.py +30 -2
  5. agent_brain_server/api/routers/health.py +1 -0
  6. agent_brain_server/config/provider_config.py +308 -0
  7. agent_brain_server/config/settings.py +12 -1
  8. agent_brain_server/indexing/__init__.py +21 -0
  9. agent_brain_server/indexing/embedding.py +86 -135
  10. agent_brain_server/indexing/graph_extractors.py +582 -0
  11. agent_brain_server/indexing/graph_index.py +536 -0
  12. agent_brain_server/models/__init__.py +9 -0
  13. agent_brain_server/models/graph.py +253 -0
  14. agent_brain_server/models/health.py +15 -3
  15. agent_brain_server/models/query.py +14 -1
  16. agent_brain_server/providers/__init__.py +64 -0
  17. agent_brain_server/providers/base.py +251 -0
  18. agent_brain_server/providers/embedding/__init__.py +23 -0
  19. agent_brain_server/providers/embedding/cohere.py +163 -0
  20. agent_brain_server/providers/embedding/ollama.py +150 -0
  21. agent_brain_server/providers/embedding/openai.py +118 -0
  22. agent_brain_server/providers/exceptions.py +95 -0
  23. agent_brain_server/providers/factory.py +157 -0
  24. agent_brain_server/providers/summarization/__init__.py +41 -0
  25. agent_brain_server/providers/summarization/anthropic.py +87 -0
  26. agent_brain_server/providers/summarization/gemini.py +96 -0
  27. agent_brain_server/providers/summarization/grok.py +95 -0
  28. agent_brain_server/providers/summarization/ollama.py +114 -0
  29. agent_brain_server/providers/summarization/openai.py +87 -0
  30. agent_brain_server/services/indexing_service.py +39 -0
  31. agent_brain_server/services/query_service.py +203 -0
  32. agent_brain_server/storage/__init__.py +18 -2
  33. agent_brain_server/storage/graph_store.py +519 -0
  34. agent_brain_server/storage/vector_store.py +35 -0
  35. agent_brain_server/storage_paths.py +2 -0
  36. agent_brain_rag-1.2.0.dist-info/RECORD +0 -31
  37. {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/WHEEL +0 -0
  38. {agent_brain_rag-1.2.0.dist-info → agent_brain_rag-2.0.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,253 @@
1
+ """Models for GraphRAG feature (Feature 113).
2
+
3
+ Defines Pydantic models for graph entities, relationships, and status.
4
+ All models are configured with frozen=True for immutability.
5
+ """
6
+
7
+ from datetime import datetime
8
+ from typing import Optional
9
+
10
+ from pydantic import BaseModel, ConfigDict, Field
11
+
12
+
13
class GraphTriple(BaseModel):
    """One subject-predicate-object fact stored in the knowledge graph.

    GraphRAG represents knowledge extracted from documents as triples that
    link two named entities through a relationship label. Entity types and
    the originating chunk ID are optional annotations. Instances are frozen:
    attributes cannot be reassigned after construction.

    Attributes:
        subject: Subject entity name, e.g. ``"FastAPI"``; must be non-empty.
        subject_type: Optional subject classification, e.g. ``"Framework"``.
        predicate: Relationship label, e.g. ``"uses"``; must be non-empty.
        object: Object entity name, e.g. ``"Pydantic"``; must be non-empty.
        object_type: Optional object classification, e.g. ``"Library"``.
        source_chunk_id: Optional ID of the chunk the triple was taken from.
    """

    # Required fields: Field() without a default is required; min_length=1
    # rejects empty strings.
    subject: str = Field(
        description="Subject entity in the triple",
        min_length=1,
    )
    subject_type: Optional[str] = Field(
        None,
        description="Type classification for subject entity",
    )
    predicate: str = Field(
        description="Relationship type connecting subject to object",
        min_length=1,
    )
    # NOTE: this field name shadows the builtin ``object`` inside the class body.
    object: str = Field(
        description="Object entity in the triple",
        min_length=1,
    )
    object_type: Optional[str] = Field(
        None,
        description="Type classification for object entity",
    )
    source_chunk_id: Optional[str] = Field(
        None,
        description="ID of the source document chunk",
    )

    model_config = ConfigDict(
        frozen=True,
        json_schema_extra={
            "examples": [
                {
                    "subject": "FastAPI",
                    "subject_type": "Framework",
                    "predicate": "uses",
                    "object": "Pydantic",
                    "object_type": "Library",
                    "source_chunk_id": "chunk_abc123",
                },
                {
                    "subject": "UserController",
                    "subject_type": "Class",
                    "predicate": "calls",
                    "object": "authenticate_user",
                    "object_type": "Function",
                    "source_chunk_id": "chunk_def456",
                },
            ]
        },
    )
79
+
80
+
81
class GraphEntity(BaseModel):
    """A node of the knowledge graph: a named concept or code element.

    Entities are extracted from documents and may appear in several chunks;
    every chunk they were seen in is listed in ``source_chunk_ids``.

    The model is declared ``frozen=True``, which prevents reassigning its
    attributes. It does not make the list in ``source_chunk_ids`` or the dict
    in ``properties`` immutable: their contents can still be changed in place.
    NOTE(review): for the same reason, hashing an instance presumably fails
    (pydantic hashes field values and lists/dicts are unhashable) - confirm
    before using entities as set members or dict keys.

    Attributes:
        name: Unique, non-empty name/identifier of the entity.
        entity_type: Optional classification, e.g. ``"Class"``, ``"Function"``
            or ``"Concept"``.
        description: Optional free-text description.
        source_chunk_ids: IDs of the chunks where the entity appears.
        properties: Extra string-to-string metadata.
    """

    name: str = Field(
        description="Unique name/identifier of the entity",
        min_length=1,
    )
    entity_type: Optional[str] = Field(
        None,
        description="Classification type for the entity",
    )
    description: Optional[str] = Field(
        None,
        description="Description of the entity",
    )
    # default_factory gives every instance its own fresh container.
    source_chunk_ids: list[str] = Field(
        description="List of source chunk IDs where entity appears",
        default_factory=list,
    )
    properties: dict[str, str] = Field(
        description="Additional metadata properties",
        default_factory=dict,
    )

    model_config = ConfigDict(
        frozen=True,
        json_schema_extra={
            "examples": [
                {
                    "name": "VectorStoreManager",
                    "entity_type": "Class",
                    "description": "Manages Chroma vector store operations",
                    "source_chunk_ids": ["chunk_001", "chunk_002"],
                    "properties": {"module": "storage.vector_store"},
                },
            ]
        },
    )
131
+
132
+
133
class GraphIndexStatus(BaseModel):
    """Snapshot of the graph index state.

    Reports whether graph indexing is switched on, whether the backing store
    has been initialized, how large the graph is, and which store backend is
    in use. Every field has a default, so ``GraphIndexStatus()`` describes a
    disabled, empty index.

    Attributes:
        enabled: True when graph indexing is enabled.
        initialized: True once the graph store has been initialized.
        entity_count: Number of entities (nodes); never negative.
        relationship_count: Number of relationships (edges); never negative.
        last_updated: When the graph was last updated, or None.
        store_type: Graph store backend name (``"simple"`` by default; the
            field description also mentions ``"kuzu"``).
    """

    enabled: bool = Field(
        False,
        description="Whether graph indexing is enabled",
    )
    initialized: bool = Field(
        False,
        description="Whether the graph store is initialized",
    )
    # ge=0 rejects negative counts at validation time.
    entity_count: int = Field(
        0,
        description="Number of entities in the graph",
        ge=0,
    )
    relationship_count: int = Field(
        0,
        description="Number of relationships in the graph",
        ge=0,
    )
    last_updated: Optional[datetime] = Field(
        None,
        description="Timestamp of last graph update",
    )
    store_type: str = Field(
        "simple",
        description="Type of graph store backend (simple or kuzu)",
    )

    model_config = ConfigDict(
        frozen=True,
        json_schema_extra={
            "examples": [
                {
                    "enabled": True,
                    "initialized": True,
                    "entity_count": 150,
                    "relationship_count": 320,
                    "last_updated": "2024-12-15T10:30:00Z",
                    "store_type": "simple",
                },
                {
                    "enabled": False,
                    "initialized": False,
                    "entity_count": 0,
                    "relationship_count": 0,
                    "last_updated": None,
                    "store_type": "simple",
                },
            ]
        },
    )
198
+
199
+
200
class GraphQueryContext(BaseModel):
    """Extra context gathered from the knowledge graph while answering a query.

    All collection fields default to empty lists and the score to 0.0, so
    ``GraphQueryContext()`` represents "no graph context". The model is
    frozen, but (as with any frozen pydantic model holding lists) the list
    contents themselves remain mutable.

    Attributes:
        related_entities: Names of entities related to the query.
        relationship_paths: Relationship paths rendered as strings, e.g.
            ``"FastAPI -> uses -> Pydantic"``.
        subgraph_triplets: Relevant ``GraphTriple`` objects from the graph.
        graph_score: Graph-retrieval score, constrained to ``[0.0, 1.0]``.
    """

    related_entities: list[str] = Field(
        description="List of related entity names",
        default_factory=list,
    )
    relationship_paths: list[str] = Field(
        description="Relationship paths as formatted strings",
        default_factory=list,
    )
    subgraph_triplets: list[GraphTriple] = Field(
        description="Relevant triplets from the graph",
        default_factory=list,
    )
    # Bounded to the closed unit interval by ge/le validation.
    graph_score: float = Field(
        0.0,
        description="Score from graph-based retrieval",
        ge=0.0,
        le=1.0,
    )

    model_config = ConfigDict(
        frozen=True,
        json_schema_extra={
            "examples": [
                {
                    "related_entities": ["FastAPI", "Pydantic", "Uvicorn"],
                    "relationship_paths": [
                        "FastAPI -> uses -> Pydantic",
                        "FastAPI -> runs_on -> Uvicorn",
                    ],
                    "subgraph_triplets": [
                        {
                            "subject": "FastAPI",
                            "predicate": "uses",
                            "object": "Pydantic",
                        },
                    ],
                    "graph_score": 0.85,
                },
            ]
        },
    )
@@ -1,7 +1,7 @@
1
1
  """Health status models."""
2
2
 
3
3
  from datetime import datetime, timezone
4
- from typing import Literal, Optional
4
+ from typing import Any, Literal, Optional
5
5
 
6
6
  from pydantic import BaseModel, Field
7
7
 
@@ -22,7 +22,7 @@ class HealthStatus(BaseModel):
22
22
  description="Timestamp of the health check",
23
23
  )
24
24
  version: str = Field(
25
- default="1.2.0",
25
+ default="2.0.0",
26
26
  description="Server version",
27
27
  )
28
28
  mode: Optional[str] = Field(
@@ -49,7 +49,7 @@ class HealthStatus(BaseModel):
49
49
  "status": "healthy",
50
50
  "message": "Server is running and ready for queries",
51
51
  "timestamp": "2024-12-15T10:30:00Z",
52
- "version": "1.2.0",
52
+ "version": "2.0.0",
53
53
  }
54
54
  ]
55
55
  }
@@ -105,6 +105,11 @@ class IndexingStatus(BaseModel):
105
105
  default_factory=list,
106
106
  description="List of folders that have been indexed",
107
107
  )
108
+ # Graph index status (Feature 113)
109
+ graph_index: Optional[dict[str, Any]] = Field(
110
+ default=None,
111
+ description="Graph index status with entity_count, relationship_count, etc.",
112
+ )
108
113
 
109
114
  model_config = {
110
115
  "json_schema_extra": {
@@ -120,6 +125,13 @@ class IndexingStatus(BaseModel):
120
125
  "last_indexed_at": "2024-12-15T10:30:00Z",
121
126
  "indexed_folders": ["/path/to/docs"],
122
127
  "supported_languages": ["python", "typescript", "java"],
128
+ "graph_index": {
129
+ "enabled": True,
130
+ "initialized": True,
131
+ "entity_count": 120,
132
+ "relationship_count": 250,
133
+ "store_type": "simple",
134
+ },
123
135
  }
124
136
  ]
125
137
  }
@@ -14,6 +14,8 @@ class QueryMode(str, Enum):
14
14
  VECTOR = "vector"
15
15
  BM25 = "bm25"
16
16
  HYBRID = "hybrid"
17
+ GRAPH = "graph" # Graph-only retrieval (Feature 113)
18
+ MULTI = "multi" # Multi-retrieval: vector + BM25 + graph with RRF (Feature 113)
17
19
 
18
20
 
19
21
  class QueryRequest(BaseModel):
@@ -39,7 +41,7 @@ class QueryRequest(BaseModel):
39
41
  )
40
42
  mode: QueryMode = Field(
41
43
  default=QueryMode.HYBRID,
42
- description="Retrieval mode (vector, bm25, hybrid)",
44
+ description="Retrieval mode (vector, bm25, hybrid, graph, multi)",
43
45
  )
44
46
  alpha: float = Field(
45
47
  default=0.5,
@@ -131,6 +133,17 @@ class QueryResult(BaseModel):
131
133
  default=None, description="Programming language for code files"
132
134
  )
133
135
 
136
+ # GraphRAG fields (Feature 113)
137
+ graph_score: float | None = Field(
138
+ default=None, description="Score from graph-based retrieval"
139
+ )
140
+ related_entities: list[str] | None = Field(
141
+ default=None, description="Related entities from knowledge graph"
142
+ )
143
+ relationship_path: list[str] | None = Field(
144
+ default=None, description="Relationship paths in the graph"
145
+ )
146
+
134
147
  # Additional metadata
135
148
  metadata: dict[str, Any] = Field(
136
149
  default_factory=dict, description="Additional metadata"
@@ -0,0 +1,64 @@
1
+ """Pluggable model providers for Agent Brain.
2
+
3
+ This package provides abstractions for embedding and summarization providers,
4
+ allowing configuration-driven selection between OpenAI, Ollama, Cohere (embeddings)
5
+ and Anthropic, OpenAI, Gemini, Grok, Ollama (summarization) providers.
6
+ """
7
+
8
+ from agent_brain_server.providers.base import (
9
+ BaseEmbeddingProvider,
10
+ BaseSummarizationProvider,
11
+ EmbeddingProvider,
12
+ EmbeddingProviderType,
13
+ SummarizationProvider,
14
+ SummarizationProviderType,
15
+ )
16
+ from agent_brain_server.providers.exceptions import (
17
+ AuthenticationError,
18
+ ConfigurationError,
19
+ ModelNotFoundError,
20
+ ProviderError,
21
+ ProviderMismatchError,
22
+ ProviderNotFoundError,
23
+ RateLimitError,
24
+ )
25
+ from agent_brain_server.providers.factory import ProviderRegistry
26
+
27
__all__ = [
    # Structural protocols
    "EmbeddingProvider",
    "SummarizationProvider",
    # Abstract base classes with shared behaviour
    "BaseEmbeddingProvider",
    "BaseSummarizationProvider",
    # Provider identifier enums
    "EmbeddingProviderType",
    "SummarizationProviderType",
    # Registry / factory
    "ProviderRegistry",
    # Exception hierarchy
    "ProviderError",
    "ConfigurationError",
    "AuthenticationError",
    "ProviderNotFoundError",
    "ProviderMismatchError",
    "RateLimitError",
    "ModelNotFoundError",
]


def _register_providers() -> None:
    """Import the built-in provider subpackages so they register themselves.

    Importing ``providers.embedding`` runs its module-level
    ``ProviderRegistry.register_embedding_provider`` calls;
    ``providers.summarization`` presumably registers its providers the same
    way (its source is not shown here - confirm). The imports themselves are
    the whole mechanism, so nothing needs to be kept from them.
    """
    import agent_brain_server.providers.embedding  # noqa: F401
    import agent_brain_server.providers.summarization  # noqa: F401


# Populate the registry as a side effect of importing this package.
_register_providers()
@@ -0,0 +1,251 @@
1
+ """Base protocols and classes for pluggable providers."""
2
+
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Awaitable, Callable
6
+ from enum import Enum
7
+ from typing import Optional, Protocol, runtime_checkable
8
+
9
logger = logging.getLogger(__name__)


class EmbeddingProviderType(str, Enum):
    """Identifiers of the embedding backends that can be configured.

    The ``str`` mixin makes each member compare equal to its string value,
    e.g. ``EmbeddingProviderType.OPENAI == "openai"``.
    """

    OPENAI = "openai"
    OLLAMA = "ollama"
    COHERE = "cohere"


class SummarizationProviderType(str, Enum):
    """Identifiers of the summarization (LLM) backends that can be configured.

    Like ``EmbeddingProviderType``, members compare equal to their string
    values thanks to the ``str`` mixin.
    """

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GEMINI = "gemini"
    GROK = "grok"
    OLLAMA = "ollama"
28
+
29
+
30
@runtime_checkable
class EmbeddingProvider(Protocol):
    """Interface the indexing and query systems expect from an embedder.

    This is a structural (duck-typed) protocol: any object exposing these
    members qualifies, without inheriting from it. Being
    ``runtime_checkable``, it supports ``isinstance`` checks, which only
    verify that the members are present.
    """

    async def embed_text(self, text: str) -> list[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: Input to embed.

        Returns:
            The vector as a list of floats.

        Raises:
            ProviderError: When the backend cannot produce the embedding.
        """
        ...

    async def embed_texts(
        self,
        texts: list[str],
        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
    ) -> list[list[float]]:
        """Return one embedding vector per entry of ``texts``.

        Args:
            texts: Inputs to embed.
            progress_callback: Optional coroutine function awaited as
                ``progress_callback(processed, total)`` to report progress.

        Returns:
            One vector per input text.

        Raises:
            ProviderError: When the backend cannot produce the embeddings.
        """
        ...

    def get_dimensions(self) -> int:
        """Return the length of the vectors produced by the configured model."""
        ...

    @property
    def provider_name(self) -> str:
        """Human-readable backend name, used in log messages."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier of the model in use."""
        ...
88
+
89
+
90
@runtime_checkable
class SummarizationProvider(Protocol):
    """Interface the code-summarization system expects from an LLM backend.

    Structural protocol: implementations need not inherit from it. Because it
    is ``runtime_checkable``, ``isinstance`` can be used to check that an
    object exposes all of these members.
    """

    async def summarize(self, text: str) -> str:
        """Return a natural-language summary of ``text``.

        Args:
            text: Content to summarize, typically source code.

        Returns:
            The summary.

        Raises:
            ProviderError: When summarization fails.
        """
        ...

    async def generate(self, prompt: str) -> str:
        """Send an arbitrary ``prompt`` to the LLM and return its reply.

        Args:
            prompt: Full prompt text.

        Returns:
            The generated response text.

        Raises:
            ProviderError: When generation fails.
        """
        ...

    @property
    def provider_name(self) -> str:
        """Human-readable backend name, used in log messages."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier of the model in use."""
        ...
135
+
136
+
137
class BaseEmbeddingProvider(ABC):
    """Common functionality shared by concrete embedding providers.

    Subclasses implement ``_embed_batch``, ``embed_text``, ``get_dimensions``
    and ``provider_name``. This class stores the model name and batch size and
    builds ``embed_texts`` on top of ``_embed_batch``, splitting the input
    into batches and reporting progress after each one.
    """

    def __init__(self, model: str, batch_size: int = 100) -> None:
        """Store the provider configuration.

        Args:
            model: Model identifier used by the provider.
            batch_size: Maximum number of texts per ``_embed_batch`` call;
                must be at least 1.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # Validate before anything else. A batch size of 0 used to fail only
        # later, inside range() in embed_texts. A negative one produced an
        # empty range there, so every call silently returned no embeddings.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self._model = model
        self._batch_size = batch_size
        # provider_name is the subclass's (abstract) property.
        logger.info(
            "Initialized %s embedding provider with model %s",
            self.provider_name,
            model,
        )

    @property
    def model_name(self) -> str:
        """Model identifier being used."""
        return self._model

    async def embed_texts(
        self,
        texts: list[str],
        progress_callback: Optional[Callable[[int, int], Awaitable[None]]] = None,
    ) -> list[list[float]]:
        """Embed ``texts`` in consecutive batches of at most ``batch_size``.

        Args:
            texts: Texts to embed; an empty list returns ``[]`` immediately.
            progress_callback: If given, awaited after every batch with
                ``(texts_processed_so_far, total_texts)``.

        Returns:
            The ``_embed_batch`` results concatenated in batch order. This
            assumes each ``_embed_batch`` call returns one vector per input,
            in input order.
        """
        if not texts:
            return []

        total = len(texts)
        size = self._batch_size
        all_embeddings: list[list[float]] = []

        for batch_number, start in enumerate(range(0, total, size), start=1):
            batch = texts[start : start + size]
            all_embeddings.extend(await self._embed_batch(batch))

            if progress_callback is not None:
                # The final batch may be short, so clamp to the total.
                await progress_callback(min(start + size, total), total)

            logger.debug(
                "Generated embeddings for batch %d (%d texts)",
                batch_number,
                len(batch),
            )

        return all_embeddings

    @abstractmethod
    async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Provider-specific batch embedding implementation."""
        ...

    @abstractmethod
    async def embed_text(self, text: str) -> list[float]:
        """Provider-specific single text embedding."""
        ...

    @abstractmethod
    def get_dimensions(self) -> int:
        """Provider-specific dimension lookup."""
        ...

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Human-readable provider name for logging."""
        ...
201
+
202
+
203
class BaseSummarizationProvider(ABC):
    """Common functionality shared by concrete summarization/LLM providers.

    Subclasses implement ``generate`` and ``provider_name``. This class stores
    the generation settings and implements ``summarize`` by filling the
    prompt template's ``{code}`` placeholder and delegating to ``generate``.
    """

    DEFAULT_PROMPT_TEMPLATE = (
        "You are an expert software engineer analyzing source code. "
        "Provide a concise 1-2 sentence summary of what this code does. "
        "Focus on the functionality, purpose, and behavior. "
        "Be specific about inputs, outputs, and side effects. "
        "Ignore implementation details and focus on what the code accomplishes.\n\n"
        "Code to summarize:\n{code}\n\n"
        "Summary:"
    )

    def __init__(
        self,
        model: str,
        max_tokens: int = 300,
        temperature: float = 0.1,
        prompt_template: Optional[str] = None,
    ) -> None:
        """Store the provider configuration.

        Args:
            model: Model identifier used by the provider.
            max_tokens: Stored as ``_max_tokens``; presumably the generation
                length limit applied by subclasses (TODO confirm).
            temperature: Stored as ``_temperature``; presumably the sampling
                temperature applied by subclasses (TODO confirm).
            prompt_template: ``str.format`` template whose only placeholder
                is ``{code}``. None or an empty string selects
                ``DEFAULT_PROMPT_TEMPLATE``.

        Raises:
            ValueError: If the effective prompt template is malformed, lacks
                a ``{code}`` placeholder, or uses any other placeholder.
        """
        # Falsy templates (None, "") deliberately fall back to the default.
        template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE
        self._validate_prompt_template(template)
        self._model = model
        self._max_tokens = max_tokens
        self._temperature = temperature
        self._prompt_template = template
        logger.info(
            "Initialized %s summarization provider with model %s",
            self.provider_name,
            model,
        )

    @staticmethod
    def _validate_prompt_template(template: str) -> None:
        """Check that ``template.format(code=...)`` works and embeds the code.

        Without this check, a template lacking ``{code}`` would silently drop
        the text from every prompt. Any other placeholder would make
        ``summarize`` raise KeyError/IndexError on every call.

        Raises:
            ValueError: If the template is malformed or its placeholders are
                not exactly ``{code}``.
        """
        import string

        try:
            parsed = list(string.Formatter().parse(template))
        except ValueError as err:
            raise ValueError(f"Malformed prompt template: {err}") from err

        names: set[str] = set()
        for _literal, field_name, _spec, _conversion in parsed:
            if field_name is None:  # literal-only segment
                continue
            # "{code.attr}" / "{code[0]}" still resolve against the code kwarg.
            names.add(field_name.split(".", 1)[0].split("[", 1)[0])

        if "code" not in names:
            raise ValueError("Prompt template must contain a '{code}' placeholder")
        unexpected = names - {"code"}
        if unexpected:
            raise ValueError(
                "Prompt template may only use the '{code}' placeholder; "
                f"found {sorted(unexpected)}"
            )

    @property
    def model_name(self) -> str:
        """Model identifier being used."""
        return self._model

    async def summarize(self, text: str) -> str:
        """Summarize ``text`` by rendering the prompt template and calling generate.

        Braces inside ``text`` are safe: ``str.format`` does not re-interpret
        substituted values.
        """
        prompt = self._prompt_template.format(code=text)
        return await self.generate(prompt)

    @abstractmethod
    async def generate(self, prompt: str) -> str:
        """Provider-specific text generation."""
        ...

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Human-readable provider name for logging."""
        ...
@@ -0,0 +1,23 @@
1
+ """Embedding providers for Agent Brain.
2
+
3
+ This module provides embedding implementations for:
4
+ - OpenAI (text-embedding-3-large, text-embedding-3-small, text-embedding-ada-002)
5
+ - Ollama (nomic-embed-text, mxbai-embed-large, etc.)
6
+ - Cohere (embed-english-v3, embed-multilingual-v3, etc.)
7
+ """
8
+
9
+ from agent_brain_server.providers.embedding.cohere import CohereEmbeddingProvider
10
+ from agent_brain_server.providers.embedding.ollama import OllamaEmbeddingProvider
11
+ from agent_brain_server.providers.embedding.openai import OpenAIEmbeddingProvider
12
+ from agent_brain_server.providers.factory import ProviderRegistry
13
+
14
# Register each built-in embedding provider under its configuration name,
# in the order openai, ollama, cohere.
for _name, _provider_cls in (
    ("openai", OpenAIEmbeddingProvider),
    ("ollama", OllamaEmbeddingProvider),
    ("cohere", CohereEmbeddingProvider),
):
    ProviderRegistry.register_embedding_provider(_name, _provider_cls)
# Keep the module namespace free of loop temporaries.
del _name, _provider_cls

__all__ = [
    "OpenAIEmbeddingProvider",
    "OllamaEmbeddingProvider",
    "CohereEmbeddingProvider",
]
+ ]