brainify 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainify/__init__.py +32 -0
- brainify/client.py +267 -0
- brainify/config.py +82 -0
- brainify/exceptions.py +70 -0
- brainify/providers/__init__.py +17 -0
- brainify/providers/anthropic.py +120 -0
- brainify/providers/base.py +102 -0
- brainify/providers/cohere.py +90 -0
- brainify/providers/custom.py +58 -0
- brainify/providers/factory.py +76 -0
- brainify/providers/gemini.py +95 -0
- brainify/providers/groq.py +73 -0
- brainify/providers/mistral.py +72 -0
- brainify/providers/ollama.py +38 -0
- brainify/providers/openai.py +122 -0
- brainify/providers/together.py +34 -0
- brainify/py.typed +0 -0
- brainify/search/__init__.py +0 -0
- brainify/search/engine.py +145 -0
- brainify/search/processor.py +94 -0
- brainify/search/scraper.py +223 -0
- brainify/utils/__init__.py +14 -0
- brainify/utils/cache.py +91 -0
- brainify/utils/context.py +104 -0
- brainify/utils/detector.py +136 -0
- brainify/utils/logger.py +55 -0
- brainify/utils/prompts.py +89 -0
- brainify-0.1.0.dist-info/METADATA +5 -0
- brainify-0.1.0.dist-info/RECORD +32 -0
- brainify-0.1.0.dist-info/WHEEL +5 -0
- brainify-0.1.0.dist-info/licenses/LICENSE +21 -0
- brainify-0.1.0.dist-info/top_level.txt +1 -0
brainify/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""Brainify package root.

Re-exports the public API: the :class:`Brain` client (and its ``Brainify``
alias), the configuration dataclasses, the exception hierarchy, the
keyword search detector, and the list of supported provider names.
"""

from brainify.client import Brain, Brainify
from brainify.config import AIConfig, BrainifyConfig, SearchConfig
from brainify.exceptions import (
    AuthenticationError,
    BrainifyError,
    ConfigurationError,
    ProviderError,
    RateLimitError,
    SearchError,
)
from brainify.providers import SUPPORTED_PROVIDERS
from brainify.utils.detector import SmartSearchDetector

# Package metadata.
__version__ = "0.1.0"
__author__ = "Brainify"
__license__ = "MIT"

# Names exported by ``from brainify import *``.
__all__ = [
    "Brain",
    "Brainify",
    "BrainifyConfig",
    "AIConfig",
    "SearchConfig",
    "SmartSearchDetector",
    "BrainifyError",
    "ProviderError",
    "SearchError",
    "ConfigurationError",
    "RateLimitError",
    "AuthenticationError",
    "SUPPORTED_PROVIDERS",
]
|
brainify/client.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import List, Optional
|
|
4
|
+
|
|
5
|
+
from brainify.config import BrainifyConfig
|
|
6
|
+
from brainify.providers.base import BaseProvider, Message
|
|
7
|
+
from brainify.providers.factory import ProviderFactory
|
|
8
|
+
from brainify.utils.context import ContextBuilder
|
|
9
|
+
from brainify.utils.detector import SmartSearchDetector
|
|
10
|
+
from brainify.utils.logger import get_logger
|
|
11
|
+
from brainify.utils.prompts import PromptBuilder
|
|
12
|
+
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
15
|
+
_DECISION_CACHE_MAX = 256
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class Brain:
    """High-level chat client pairing an AI provider with optional web search.

    Search activation order: an explicit per-call ``web_search`` override
    wins; otherwise the global ``search.enabled`` flag; otherwise, when
    ``search.auto_search`` is set, an LLM classifier decides per query,
    with a keyword-heuristic fallback on errors or ambiguous replies.
    """

    def __init__(self, config: BrainifyConfig):
        self._config = config
        # The provider adapter is resolved once from the AI config.
        self._provider: BaseProvider = ProviderFactory.create(config.ai)
        self._context_builder: ContextBuilder = ContextBuilder(config.search)
        # Running conversation, shared by chat()/achat() until cleared.
        self._history: List[Message] = []
        # FIFO-bounded cache of per-query "should I search?" decisions.
        self._decision_cache: dict[str, bool] = {}
        logger.info(
            "Brain initialised — provider=%s model=%s enabled=%s auto_search=%s",
            config.ai.provider,
            config.ai.model,
            config.search.enabled,
            config.search.auto_search,
        )

    # ------------------------------------------------------------------ #
    # Public chat API
    # ------------------------------------------------------------------ #

    def chat(
        self,
        query: str,
        *,
        web_search: Optional[bool] = None,
        clear_history: bool = False,
    ) -> str:
        """Send *query* (with history) to the provider and return the reply.

        Args:
            query: The user's message.
            web_search: Force search on/off for this call; ``None``
                defers to configuration.
            clear_history: Drop all prior turns before sending.
        """
        if clear_history:
            self.clear_history()

        use_search = self._resolve_search_flag(web_search, query)
        system_prompt, _ = self._prepare_system(query, use_search)
        messages = self._build_messages(query, system_prompt)
        response = self._provider.complete(messages)

        self._record_exchange(query, response)
        return response

    def ask(self, query: str, *, web_search: Optional[bool] = None) -> str:
        """One-shot question: :meth:`chat` with a fresh history."""
        return self.chat(query, web_search=web_search, clear_history=True)

    async def achat(
        self,
        query: str,
        *,
        web_search: Optional[bool] = None,
        clear_history: bool = False,
    ) -> str:
        """Async counterpart of :meth:`chat`."""
        if clear_history:
            self.clear_history()

        use_search = await self._aresolve_search_flag(web_search, query)
        system_prompt, _ = await self._aprepare_system(query, use_search)
        messages = self._build_messages(query, system_prompt)
        response = await self._provider.acomplete(messages)

        self._record_exchange(query, response)
        return response

    async def aask(self, query: str, *, web_search: Optional[bool] = None) -> str:
        """Async counterpart of :meth:`ask`."""
        return await self.achat(query, web_search=web_search, clear_history=True)

    # ------------------------------------------------------------------ #
    # Introspection and maintenance
    # ------------------------------------------------------------------ #

    def search_needed(self, query: str) -> bool:
        """Keyword heuristic: would *query* benefit from a web search?"""
        return SmartSearchDetector.needs_search(query)

    def search_explain(self, query: str) -> dict:
        """Explain the keyword heuristic's verdict for *query*."""
        return SmartSearchDetector.explain(query)

    def clear_history(self) -> None:
        """Forget all previous conversation turns."""
        self._history.clear()
        logger.debug("Conversation history cleared.")

    def clear_decision_cache(self) -> None:
        """Forget cached LLM search/skip decisions."""
        self._decision_cache.clear()
        logger.debug("LLM decision cache cleared.")

    @property
    def history(self) -> List[Message]:
        """A copy of the conversation history (mutations don't leak back)."""
        return list(self._history)

    def add_to_history(self, role: str, content: str) -> None:
        """Manually append a message to the conversation history."""
        self._history.append(Message(role, content))

    def clear_search_cache(self) -> None:
        """Drop cached web-search results held by the context builder."""
        self._context_builder.clear_cache()
        logger.debug("Search cache cleared.")

    @property
    def search_cache_size(self) -> int:
        """Number of entries in the web-search result cache."""
        return self._context_builder.cache_size

    @property
    def config(self) -> BrainifyConfig:
        """The configuration this Brain was built with."""
        return self._config

    @property
    def provider_name(self) -> str:
        """Configured provider identifier (e.g. ``"openai"``)."""
        return self._config.ai.provider

    @property
    def model(self) -> str:
        """Configured model name."""
        return self._config.ai.model

    @property
    def web_search_enabled(self) -> bool:
        """Whether web search is globally enabled."""
        return self._config.search.enabled

    @property
    def auto_search_enabled(self) -> bool:
        """Whether the LLM classifier may enable search per query."""
        return self._config.search.auto_search

    # ------------------------------------------------------------------ #
    # Internal: search decision
    # ------------------------------------------------------------------ #

    def _resolve_search_flag(self, override: Optional[bool], query: str) -> bool:
        """Decide whether this call should run a web search (sync path)."""
        if override is not None:
            return override
        if self._config.search.enabled:
            return True
        if self._config.search.auto_search:
            return self._llm_search_decision(query)
        return False

    async def _aresolve_search_flag(self, override: Optional[bool], query: str) -> bool:
        """Async twin of :meth:`_resolve_search_flag`."""
        if override is not None:
            return override
        if self._config.search.enabled:
            return True
        if self._config.search.auto_search:
            return await self._allm_search_decision(query)
        return False

    def _cached_decision(self, cache_key: str, query: str) -> Optional[bool]:
        """Return a cached search decision, or ``None`` on a cache miss."""
        if cache_key not in self._decision_cache:
            return None
        cached = self._decision_cache[cache_key]
        logger.debug("Decision cache hit: %s for %r", "SEARCH" if cached else "SKIP", query[:60])
        return cached

    @staticmethod
    def _classifier_messages(query: str) -> List[Message]:
        """Build the system/user message pair for the search classifier."""
        sys_prompt, usr_prompt = PromptBuilder.classifier_prompt(query)
        return [Message("system", sys_prompt), Message("user", usr_prompt)]

    def _llm_search_decision(self, query: str) -> bool:
        """Ask the LLM whether to search; keyword fallback on any error."""
        cache_key = query.strip().lower()
        cached = self._cached_decision(cache_key, query)
        if cached is not None:
            return cached

        try:
            raw = self._provider.complete(self._classifier_messages(query))
            decision = self._parse_classifier_response(raw, query)
        except Exception as exc:
            logger.warning("LLM classifier error (%s) — keyword fallback for: %r", exc, query[:60])
            decision = SmartSearchDetector.needs_search(query)

        self._store_decision(cache_key, decision)
        return decision

    async def _allm_search_decision(self, query: str) -> bool:
        """Async twin of :meth:`_llm_search_decision`."""
        cache_key = query.strip().lower()
        cached = self._cached_decision(cache_key, query)
        if cached is not None:
            return cached

        try:
            raw = await self._provider.acomplete(self._classifier_messages(query))
            decision = self._parse_classifier_response(raw, query)
        except Exception as exc:
            logger.warning("LLM classifier error (%s) — keyword fallback for: %r", exc, query[:60])
            decision = SmartSearchDetector.needs_search(query)

        self._store_decision(cache_key, decision)
        return decision

    def _parse_classifier_response(self, raw: str, query: str) -> bool:
        """Map a YES/NO classifier reply to a bool; keyword fallback otherwise."""
        answer = raw.strip().upper()
        if answer.startswith("YES"):
            logger.info("LLM decision → SEARCH | query: %r", query[:80])
            return True
        if answer.startswith("NO"):
            logger.info("LLM decision → SKIP | query: %r", query[:80])
            return False
        logger.warning(
            "Ambiguous LLM classifier response %r — keyword fallback for: %r",
            raw[:30],
            query[:60],
        )
        return SmartSearchDetector.needs_search(query)

    def _store_decision(self, key: str, decision: bool) -> None:
        """Cache a decision, evicting the oldest entry when full (FIFO)."""
        if len(self._decision_cache) >= _DECISION_CACHE_MAX:
            oldest = next(iter(self._decision_cache))
            del self._decision_cache[oldest]
        self._decision_cache[key] = decision

    # ------------------------------------------------------------------ #
    # Internal: prompt assembly
    # ------------------------------------------------------------------ #

    def _compose_system_prompt(self, use_search: bool, search_context: str) -> str:
        """Assemble the system prompt; note when a wanted search found nothing."""
        has_context = bool(search_context)
        system_prompt = PromptBuilder.build(
            web_search_active=use_search and has_context,
            user_system_prompt=self._config.ai.system_prompt,
            search_context=search_context if has_context else None,
        )
        if use_search and not has_context:
            system_prompt += PromptBuilder.no_search_fallback_note()
        return system_prompt

    def _prepare_system(self, query: str, use_search: bool):
        """Fetch search context (best-effort) and build the system prompt."""
        search_context = ""
        search_results = []

        if use_search:
            try:
                search_context, search_results = self._context_builder.build(query)
            except Exception as exc:
                # Search failures are non-fatal: answer without context.
                logger.warning("Web search failed, continuing without it: %s", exc)

        return self._compose_system_prompt(use_search, search_context), search_results

    async def _aprepare_system(self, query: str, use_search: bool):
        """Async twin of :meth:`_prepare_system`."""
        search_context = ""
        search_results = []

        if use_search:
            try:
                search_context, search_results = await self._context_builder.abuild(query)
            except Exception as exc:
                logger.warning("Async web search failed: %s", exc)

        return self._compose_system_prompt(use_search, search_context), search_results

    def _build_messages(self, query: str, system_prompt: str) -> List[Message]:
        """Compose system prompt + history + the new user message."""
        messages: List[Message] = [Message("system", system_prompt)]
        messages.extend(self._history)
        messages.append(Message("user", query))
        return messages

    def _record_exchange(self, query: str, response: str) -> None:
        """Append a completed user/assistant turn to the history."""
        self._history.append(Message("user", query))
        self._history.append(Message("assistant", response))

    def __repr__(self) -> str:
        return (
            f"Brain(provider={self.provider_name!r}, "
            f"model={self.model!r}, "
            f"enabled={self.web_search_enabled}, "
            f"auto_search={self.auto_search_enabled})"
        )
# Public alias: the package exposes the client under both names; ``Brainify``
# simply points at :class:`Brain`.
Brainify = Brain
brainify/config.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from brainify.exceptions import ConfigurationError
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class AIConfig:
    """Settings for a single AI provider/model combination.

    Validation happens in ``__post_init__``; invalid values raise
    :class:`ConfigurationError`.
    """

    provider: str
    api_key: str = ""
    model: str = ""
    temperature: float = 0.7
    max_tokens: int = 2048
    top_p: float = 1.0
    top_k: Optional[int] = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    system_prompt: Optional[str] = None
    base_url: Optional[str] = None
    timeout: int = 60
    extra_headers: Dict[str, str] = field(default_factory=dict)
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Normalise the provider name and validate core parameters."""
        # Provider names are matched case-insensitively, ignoring whitespace.
        self.provider = self.provider.lower().strip()
        known = frozenset((
            "openai", "anthropic", "gemini", "groq",
            "mistral", "cohere", "together", "ollama", "custom",
        ))
        if self.provider not in known:
            supported = ", ".join(sorted(known))
            raise ConfigurationError(
                f"Unknown provider '{self.provider}'. Supported: {supported}"
            )
        # Local (ollama) and custom endpoints may legitimately be key-less.
        if not self.api_key and self.provider not in ("ollama", "custom"):
            raise ConfigurationError(
                f"api_key is required for provider '{self.provider}'."
            )
        if self.temperature < 0.0 or self.temperature > 2.0:
            raise ConfigurationError("temperature must be between 0.0 and 2.0.")
        if self.max_tokens < 1:
            raise ConfigurationError("max_tokens must be >= 1.")

    @property
    def is_openai_compatible(self) -> bool:
        """True for providers that speak the OpenAI chat-completions API."""
        return self.provider in ("openai", "groq", "together", "custom", "ollama")
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class SearchConfig:
    """Settings controlling web search, scraping and result caching."""

    enabled: bool = True
    auto_search: bool = False
    max_results: int = 5
    max_pages_to_read: int = 3
    max_content_length: int = 6000
    search_timeout: int = 15
    cache_ttl: int = 3600
    safe_search: bool = True
    region: str = "wt-wt"
    time_filter: Optional[str] = None
    excluded_domains: List[str] = field(default_factory=list)
    include_snippets: bool = True
    scrape_concurrency: int = 3
    user_agent: Optional[str] = None

    def __post_init__(self):
        """Validate result counts and the optional time filter."""
        if not 1 <= self.max_results <= 20:
            raise ConfigurationError("max_results must be between 1 and 20.")
        if self.max_pages_to_read < 0:
            raise ConfigurationError("max_pages_to_read must be >= 0.")
        # A falsy time_filter (None or "") means "no filter" and is allowed.
        if self.time_filter and self.time_filter not in ("d", "w", "m", "y"):
            raise ConfigurationError(
                "time_filter must be one of: 'd', 'w', 'm', 'y'."
            )
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@dataclass
class BrainifyConfig:
    """Top-level configuration bundling AI and search settings."""

    # Provider/model settings (required; validated by AIConfig itself).
    ai: AIConfig
    # Web-search behaviour; defaults to a fresh SearchConfig() when omitted.
    search: SearchConfig = field(default_factory=SearchConfig)
brainify/exceptions.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Custom exceptions for Brainify."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class BrainifyError(Exception):
    """Base exception for all Brainify errors.

    Attributes:
        message: Human-readable description of the failure.
        details: Optional structured context about the error (always a
            dict; an empty one when no details were supplied).
    """

    def __init__(self, message: str, details: "dict | None" = None):
        super().__init__(message)
        # Keep the message accessible as an attribute as well as via args.
        self.message = message
        # Normalise a missing mapping so callers can always treat
        # ``details`` as a dict.
        self.details = details or {}

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(message={self.message!r})"
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ConfigurationError(BrainifyError):
    """Raised when configuration is invalid or missing.

    Typical causes: an unknown provider name, a missing API key, or
    out-of-range generation/search parameters.
    """
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ProviderError(BrainifyError):
    """Raised when an AI provider returns an error.

    Attributes:
        provider: Name of the provider that failed (e.g. "anthropic"),
            when known.
        status_code: HTTP status code from the provider, when known.
    """

    def __init__(self, message: str, provider: "str | None" = None, status_code: "int | None" = None, details: "dict | None" = None):
        super().__init__(message, details)
        self.provider = provider
        self.status_code = status_code
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class AuthenticationError(ProviderError):
    """Raised when API key authentication fails (e.g. HTTP 401/403)."""
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RateLimitError(ProviderError):
    """Raised when the API rate limit is exceeded (e.g. HTTP 429).

    Attributes:
        retry_after: Delay before retrying, when the provider supplies
            one (presumably seconds — confirm against each provider);
            otherwise None.
    """

    def __init__(self, message: str, provider: "str | None" = None, retry_after: "float | None" = None, details: "dict | None" = None):
        super().__init__(message, provider=provider, details=details)
        self.retry_after = retry_after
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SearchError(BrainifyError):
    """Raised when a web search fails.

    Distinct from :class:`ScrapingError`, which covers page-content
    retrieval rather than the search itself.
    """
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ScrapingError(BrainifyError):
    """Raised when web page scraping fails."""
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ModelNotSupportedError(ConfigurationError):
    """Raised when the model is not supported by the provider.

    Attributes:
        model: The rejected model identifier.
        provider: The provider that rejected it.
    """

    def __init__(self, model: str, provider: str):
        message = (
            f"Model '{model}' is not supported by provider '{provider}'. "
            "Check the provider documentation for supported models."
        )
        super().__init__(message)
        self.model = model
        self.provider = provider
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ProviderNotInstalledError(ConfigurationError):
    """Raised when the required provider package is not installed.

    Attributes:
        provider: The provider whose SDK is missing.
        package: The pip package name that must be installed.
    """

    def __init__(self, provider: str, package: str):
        hint = (
            f"Provider '{provider}' requires the '{package}' package. "
            f"Install it with: pip install brainify[{provider}]"
        )
        super().__init__(hint)
        self.provider = provider
        self.package = package
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""AI provider implementations."""
|
|
2
|
+
|
|
3
|
+
from brainify.providers.factory import ProviderFactory
|
|
4
|
+
|
|
5
|
+
SUPPORTED_PROVIDERS = [
|
|
6
|
+
"openai",
|
|
7
|
+
"anthropic",
|
|
8
|
+
"gemini",
|
|
9
|
+
"groq",
|
|
10
|
+
"mistral",
|
|
11
|
+
"cohere",
|
|
12
|
+
"together",
|
|
13
|
+
"ollama",
|
|
14
|
+
"custom",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
__all__ = ["ProviderFactory", "SUPPORTED_PROVIDERS"]
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Anthropic provider (Claude 3.5, Claude 3, …)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
from brainify.config import AIConfig
|
|
8
|
+
from brainify.exceptions import (
|
|
9
|
+
AuthenticationError,
|
|
10
|
+
ProviderError,
|
|
11
|
+
RateLimitError,
|
|
12
|
+
ProviderNotInstalledError,
|
|
13
|
+
)
|
|
14
|
+
from brainify.providers.base import BaseProvider, Message
|
|
15
|
+
from brainify.utils.logger import get_logger
|
|
16
|
+
|
|
17
|
+
logger = get_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AnthropicProvider(BaseProvider):
    """
    Adapter for the Anthropic Messages API.

    Supports: claude-3-5-sonnet-20241022, claude-3-5-haiku-20241022,
    claude-3-opus-20240229, claude-3-haiku-20240307, etc.

    Install: pip install anthropic
    """

    def __init__(self, config: AIConfig):
        super().__init__(config)
        self._client = self._build_client()

    def _client_kwargs(self) -> dict:
        """Constructor kwargs shared by the sync and async Anthropic clients."""
        kwargs = {"api_key": self.config.api_key, "timeout": self.config.timeout}
        if self.config.base_url:
            kwargs["base_url"] = self.config.base_url
        if self.config.extra_headers:
            kwargs["default_headers"] = self.config.extra_headers
        return kwargs

    def _build_client(self):
        """Create the sync client; raise a clear error if the SDK is absent."""
        try:
            from anthropic import Anthropic
        except ImportError as exc:
            raise ProviderNotInstalledError("anthropic", "anthropic") from exc

        return Anthropic(**self._client_kwargs())

    @staticmethod
    def _split_system(messages: List[Message]):
        """Separate the system prompt: Anthropic takes it outside the turn list."""
        system_text = ""
        chat_messages = []
        for m in messages:
            if m.role == "system":
                system_text = m.content
            else:
                chat_messages.append(m.to_dict())
        return system_text, chat_messages

    def _request_params(self, system_text: str, chat_messages: list) -> dict:
        """Build the messages.create() payload from config + conversation."""
        params = {
            "model": self.config.model,
            "messages": chat_messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "top_p": self.config.top_p,
        }
        if system_text:
            params["system"] = system_text
        if self.config.top_k is not None:
            params["top_k"] = self.config.top_k
        # extra_params last so callers can override anything above.
        params.update(self.config.extra_params)
        return params

    def complete(self, messages: List[Message]) -> str:
        """Send *messages* synchronously and return the response text."""
        system_text, chat_messages = self._split_system(messages)

        try:
            params = self._request_params(system_text, chat_messages)
            logger.debug("Anthropic request: model=%s", self.config.model)
            response = self._client.messages.create(**params)
            return response.content[0].text if response.content else ""
        except Exception as exc:
            # _handle_error always raises a Brainify exception type.
            self._handle_error(exc)

    async def acomplete(self, messages: List[Message]) -> str:
        """Async version of :meth:`complete`.

        Uses the same client kwargs (base_url, extra_headers), the same
        request parameters (top_p/top_k/extra_params) and the same error
        translation as the sync path.
        """
        try:
            from anthropic import AsyncAnthropic
        except ImportError as exc:
            raise ProviderNotInstalledError("anthropic", "anthropic") from exc

        system_text, chat_messages = self._split_system(messages)

        try:
            async with AsyncAnthropic(**self._client_kwargs()) as client:
                params = self._request_params(system_text, chat_messages)
                logger.debug("Anthropic request: model=%s", self.config.model)
                response = await client.messages.create(**params)
                return response.content[0].text if response.content else ""
        except Exception as exc:
            self._handle_error(exc)

    def _handle_error(self, exc: Exception):
        """Translate Anthropic SDK exceptions into Brainify exceptions.

        Always raises; never returns.
        """
        try:
            from anthropic import AuthenticationError as AAuth
            from anthropic import RateLimitError as ARate
            from anthropic import APIError
        except ImportError:
            raise ProviderError(str(exc), provider="anthropic") from exc

        if isinstance(exc, AAuth):
            raise AuthenticationError("Invalid Anthropic API key.", provider="anthropic") from exc
        if isinstance(exc, ARate):
            raise RateLimitError("Anthropic rate limit exceeded.", provider="anthropic") from exc
        if isinstance(exc, APIError):
            raise ProviderError(f"Anthropic API error: {exc}", provider="anthropic") from exc
        raise ProviderError(f"Anthropic unexpected error: {exc}", provider="anthropic") from exc
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Abstract base class for all AI providers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import List, Optional
|
|
8
|
+
|
|
9
|
+
from brainify.config import AIConfig
|
|
10
|
+
from brainify.utils.logger import get_logger
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Message:
    """A single chat message exchanged with a provider."""

    def __init__(self, role: str, content: str):
        # Only the three OpenAI-style roles are accepted.
        if role != "system" and role != "user" and role != "assistant":
            raise ValueError(f"Invalid role: {role!r}. Must be system/user/assistant.")
        self.role = role
        self.content = content

    def to_dict(self) -> dict:
        """Serialise to the ``{"role", "content"}`` mapping providers expect."""
        return dict(role=self.role, content=self.content)

    def __repr__(self) -> str:
        if len(self.content) > 60:
            preview = self.content[:60] + "…"
        else:
            preview = self.content
        return f"Message(role={self.role!r}, content={preview!r})"
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class BaseProvider(ABC):
    """
    Abstract base for all Brainify AI provider adapters.

    Subclasses must implement :meth:`complete` (sync) and may override
    :meth:`acomplete` (async) for native async support.
    """

    def __init__(self, config: AIConfig):
        self.config = config
        # Convenience mirror of the configured provider identifier.
        self.provider_name: str = config.provider

    @abstractmethod
    def complete(self, messages: List[Message]) -> str:
        """
        Send messages to the AI and return the response string.

        Args:
            messages: Ordered list of chat messages (system, user, assistant…).

        Returns:
            The AI's response text.

        Raises:
            ProviderError: On API errors.
            AuthenticationError: On 401/403 responses.
            RateLimitError: On 429 responses.
        """

    async def acomplete(self, messages: List[Message]) -> str:
        """
        Async version of :meth:`complete`.

        Default implementation runs the sync version in a thread pool.
        Override in subclasses for native async support.
        """
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.complete, messages)

    def _build_messages(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        history: Optional[List[Message]] = None,
    ) -> List[Message]:
        """Build a full message list from components."""
        messages: List[Message] = []
        if system_prompt:
            messages.append(Message("system", system_prompt))
        if history:
            messages.extend(history)
        messages.append(Message("user", user_message))
        return messages

    def _get_common_params(self) -> dict:
        """Return common generation parameters from config."""
        params = {
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
        }
        # Only forward top_p when it deviates from the neutral default.
        if self.config.top_p != 1.0:
            params["top_p"] = self.config.top_p
        # extra_params last so callers can override anything above.
        params.update(self.config.extra_params)
        return params

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self.config.model!r})"