brainify 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainify/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from brainify.client import Brain, Brainify
2
+ from brainify.config import AIConfig, BrainifyConfig, SearchConfig
3
+ from brainify.exceptions import (
4
+ AuthenticationError,
5
+ BrainifyError,
6
+ ConfigurationError,
7
+ ProviderError,
8
+ RateLimitError,
9
+ SearchError,
10
+ )
11
+ from brainify.providers import SUPPORTED_PROVIDERS
12
+ from brainify.utils.detector import SmartSearchDetector
13
+
14
# Package metadata.
__version__ = "0.1.0"
__author__ = "Brainify"
__license__ = "MIT"

# Public API re-exported from the submodules imported above; this is the set of
# names that ``from brainify import *`` binds. ``Brainify`` is an alias of
# ``Brain`` (both come from brainify.client).
__all__ = [
    "Brain",
    "Brainify",
    "BrainifyConfig",
    "AIConfig",
    "SearchConfig",
    "SmartSearchDetector",
    "BrainifyError",
    "ProviderError",
    "SearchError",
    "ConfigurationError",
    "RateLimitError",
    "AuthenticationError",
    "SUPPORTED_PROVIDERS",
]
brainify/client.py ADDED
@@ -0,0 +1,267 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Optional
4
+
5
+ from brainify.config import BrainifyConfig
6
+ from brainify.providers.base import BaseProvider, Message
7
+ from brainify.providers.factory import ProviderFactory
8
+ from brainify.utils.context import ContextBuilder
9
+ from brainify.utils.detector import SmartSearchDetector
10
+ from brainify.utils.logger import get_logger
11
+ from brainify.utils.prompts import PromptBuilder
12
+
13
logger = get_logger(__name__)

# Upper bound on the number of entries in ``Brain._decision_cache`` (a
# per-instance memo of LLM search/skip classifier verdicts keyed by the
# normalised query). When full, the oldest-inserted entry is evicted first.
_DECISION_CACHE_MAX = 256
16
+
17
+
18
class Brain:
    """High-level chat client that can ground answers in web-search results.

    A ``Brain`` owns one AI provider (created from ``config.ai`` via
    ``ProviderFactory``), one ``ContextBuilder`` (created from
    ``config.search``), an in-memory conversation history, and a bounded
    per-instance cache of LLM "does this query need a search?" verdicts.

    Whether a call searches the web is decided by ``_resolve_search_flag``:

    1. An explicit ``web_search=`` argument always wins.
    2. Otherwise ``config.search.enabled`` forces a search.
    3. Otherwise ``config.search.auto_search`` asks the provider to classify
       the query, falling back to ``SmartSearchDetector`` keyword heuristics.
    4. Otherwise no search is performed.

    Because ``enabled`` is checked first, ``auto_search`` only has an effect
    when ``enabled`` is False.
    """

    def __init__(self, config: BrainifyConfig):
        """Create the provider and context builder described by *config*."""
        self._config = config
        self._provider: BaseProvider = ProviderFactory.create(config.ai)
        self._context_builder: ContextBuilder = ContextBuilder(config.search)
        self._history: List[Message] = []
        # Normalised query -> search decision. Dicts keep insertion order,
        # which _store_decision relies on to evict the oldest entry first.
        self._decision_cache: dict[str, bool] = {}
        logger.info(
            "Brain initialised — provider=%s model=%s enabled=%s auto_search=%s",
            config.ai.provider,
            config.ai.model,
            config.search.enabled,
            config.search.auto_search,
        )

    # ------------------------------------------------------------------
    # Chat API
    # ------------------------------------------------------------------

    def chat(
        self,
        query: str,
        *,
        web_search: Optional[bool] = None,
        clear_history: bool = False,
    ) -> str:
        """Send *query* (with prior history) to the provider and return its reply.

        Args:
            query: The user's message.
            web_search: Force search on (True) or off (False); None lets the
                configuration decide.
            clear_history: Drop the conversation history before sending.

        Returns:
            The provider's response text. The query/response pair is appended
            to the history only after the provider call succeeds.
        """
        if clear_history:
            self.clear_history()

        use_search = self._resolve_search_flag(web_search, query)
        system_prompt, _ = self._prepare_system(query, use_search)
        messages = self._build_messages(query, system_prompt)
        response = self._provider.complete(messages)

        self._record_turn(query, response)
        return response

    def ask(self, query: str, *, web_search: Optional[bool] = None) -> str:
        """One-shot question: like :meth:`chat` but with history cleared first."""
        return self.chat(query, web_search=web_search, clear_history=True)

    async def achat(
        self,
        query: str,
        *,
        web_search: Optional[bool] = None,
        clear_history: bool = False,
    ) -> str:
        """Async counterpart of :meth:`chat` (same arguments and history rules)."""
        if clear_history:
            self.clear_history()

        use_search = await self._aresolve_search_flag(web_search, query)
        system_prompt, _ = await self._aprepare_system(query, use_search)
        messages = self._build_messages(query, system_prompt)
        response = await self._provider.acomplete(messages)

        self._record_turn(query, response)
        return response

    async def aask(self, query: str, *, web_search: Optional[bool] = None) -> str:
        """Async counterpart of :meth:`ask`."""
        return await self.achat(query, web_search=web_search, clear_history=True)

    # ------------------------------------------------------------------
    # Introspection / state management
    # ------------------------------------------------------------------

    def search_needed(self, query: str) -> bool:
        """Return the ``SmartSearchDetector`` verdict for *query* (no LLM call)."""
        return SmartSearchDetector.needs_search(query)

    def search_explain(self, query: str) -> dict:
        """Return ``SmartSearchDetector.explain`` output for *query*."""
        return SmartSearchDetector.explain(query)

    def clear_history(self) -> None:
        """Forget all conversation turns."""
        self._history.clear()
        logger.debug("Conversation history cleared.")

    def clear_decision_cache(self) -> None:
        """Forget all cached LLM search/skip decisions."""
        self._decision_cache.clear()
        logger.debug("LLM decision cache cleared.")

    @property
    def history(self) -> List[Message]:
        """A shallow copy of the conversation history."""
        return list(self._history)

    def add_to_history(self, role: str, content: str) -> None:
        """Append a message to the history (``Message`` validates *role*)."""
        self._history.append(Message(role, content))

    def clear_search_cache(self) -> None:
        """Clear the context builder's search cache."""
        self._context_builder.clear_cache()
        logger.debug("Search cache cleared.")

    @property
    def search_cache_size(self) -> int:
        """Number of entries reported by the context builder's cache."""
        return self._context_builder.cache_size

    @property
    def config(self) -> BrainifyConfig:
        """The configuration this instance was built with."""
        return self._config

    @property
    def provider_name(self) -> str:
        """Configured provider name (e.g. ``"openai"``)."""
        return self._config.ai.provider

    @property
    def model(self) -> str:
        """Configured model name."""
        return self._config.ai.model

    @property
    def web_search_enabled(self) -> bool:
        """Value of ``config.search.enabled``."""
        return self._config.search.enabled

    @property
    def auto_search_enabled(self) -> bool:
        """Value of ``config.search.auto_search``."""
        return self._config.search.auto_search

    # ------------------------------------------------------------------
    # Search decision
    # ------------------------------------------------------------------

    def _resolve_search_flag(self, override: Optional[bool], query: str) -> bool:
        """Decide whether to search for *query* (see the class docstring)."""
        if override is not None:
            return override
        if self._config.search.enabled:
            return True
        if self._config.search.auto_search:
            return self._llm_search_decision(query)
        return False

    async def _aresolve_search_flag(self, override: Optional[bool], query: str) -> bool:
        """Async counterpart of :meth:`_resolve_search_flag`."""
        if override is not None:
            return override
        if self._config.search.enabled:
            return True
        if self._config.search.auto_search:
            return await self._allm_search_decision(query)
        return False

    def _llm_search_decision(self, query: str) -> bool:
        """Ask the provider whether *query* needs a search, memoising the answer."""
        cache_key, cached = self._lookup_decision(query)
        if cached is not None:
            return cached

        try:
            raw = self._provider.complete(self._classifier_messages(query))
            decision = self._parse_classifier_response(raw, query)
        except Exception as exc:
            # Best effort: any classifier failure degrades to keyword heuristics.
            decision = self._classifier_fallback(exc, query)

        self._store_decision(cache_key, decision)
        return decision

    async def _allm_search_decision(self, query: str) -> bool:
        """Async counterpart of :meth:`_llm_search_decision`."""
        cache_key, cached = self._lookup_decision(query)
        if cached is not None:
            return cached

        try:
            raw = await self._provider.acomplete(self._classifier_messages(query))
            decision = self._parse_classifier_response(raw, query)
        except Exception as exc:
            decision = self._classifier_fallback(exc, query)

        self._store_decision(cache_key, decision)
        return decision

    def _lookup_decision(self, query: str) -> tuple[str, Optional[bool]]:
        """Return ``(cache_key, cached_decision_or_None)`` for *query*.

        The key is case- and surrounding-whitespace-insensitive.
        """
        cache_key = query.strip().lower()
        # Cached values are always bools, so None unambiguously means "miss".
        cached = self._decision_cache.get(cache_key)
        if cached is not None:
            logger.debug("Decision cache hit: %s for %r", "SEARCH" if cached else "SKIP", query[:60])
        return cache_key, cached

    @staticmethod
    def _classifier_messages(query: str) -> List[Message]:
        """Build the system+user message pair for the search classifier."""
        sys_prompt, usr_prompt = PromptBuilder.classifier_prompt(query)
        return [
            Message("system", sys_prompt),
            Message("user", usr_prompt),
        ]

    @staticmethod
    def _classifier_fallback(exc: Exception, query: str) -> bool:
        """Log a classifier failure and return the keyword-heuristic decision."""
        logger.warning("LLM classifier error (%s) — keyword fallback for: %r", exc, query[:60])
        return SmartSearchDetector.needs_search(query)

    def _parse_classifier_response(self, raw: str, query: str) -> bool:
        """Map a YES/NO classifier reply to a bool; ambiguous replies use keywords."""
        answer = raw.strip().upper()
        if answer.startswith("YES"):
            logger.info("LLM decision → SEARCH | query: %r", query[:80])
            return True
        if answer.startswith("NO"):
            logger.info("LLM decision → SKIP | query: %r", query[:80])
            return False
        logger.warning(
            "Ambiguous LLM classifier response %r — keyword fallback for: %r",
            raw[:30],
            query[:60],
        )
        return SmartSearchDetector.needs_search(query)

    def _store_decision(self, key: str, decision: bool) -> None:
        """Cache *decision* under *key*, evicting oldest entries at capacity."""
        # The emptiness check stops a non-positive limit from calling
        # next() on an empty iterator (StopIteration).
        while self._decision_cache and len(self._decision_cache) >= _DECISION_CACHE_MAX:
            del self._decision_cache[next(iter(self._decision_cache))]
        self._decision_cache[key] = decision

    # ------------------------------------------------------------------
    # Prompt assembly
    # ------------------------------------------------------------------

    def _prepare_system(self, query: str, use_search: bool) -> tuple[str, list]:
        """Return ``(system_prompt, search_results)``, searching if requested.

        Search failures are logged and the call proceeds without context.
        """
        search_context = ""
        search_results: list = []

        if use_search:
            try:
                search_context, search_results = self._context_builder.build(query)
            except Exception as exc:
                logger.warning("Web search failed, continuing without it: %s", exc)

        return self._compose_system_prompt(use_search, search_context), search_results

    async def _aprepare_system(self, query: str, use_search: bool) -> tuple[str, list]:
        """Async counterpart of :meth:`_prepare_system`."""
        search_context = ""
        search_results: list = []

        if use_search:
            try:
                search_context, search_results = await self._context_builder.abuild(query)
            except Exception as exc:
                logger.warning("Async web search failed: %s", exc)

        return self._compose_system_prompt(use_search, search_context), search_results

    def _compose_system_prompt(self, use_search: bool, search_context: str) -> str:
        """Build the system prompt, noting when a requested search yielded nothing."""
        has_context = bool(search_context)
        system_prompt = PromptBuilder.build(
            web_search_active=use_search and has_context,
            user_system_prompt=self._config.ai.system_prompt,
            search_context=search_context if has_context else None,
        )
        if use_search and not has_context:
            system_prompt += PromptBuilder.no_search_fallback_note()
        return system_prompt

    def _build_messages(self, query: str, system_prompt: str) -> List[Message]:
        """Return ``[system, *history, user(query)]`` for the provider."""
        messages: List[Message] = [Message("system", system_prompt)]
        messages.extend(self._history)
        messages.append(Message("user", query))
        return messages

    def _record_turn(self, query: str, response: str) -> None:
        """Append a completed user/assistant exchange to the history."""
        self._history.append(Message("user", query))
        self._history.append(Message("assistant", response))

    def __repr__(self) -> str:
        return (
            f"Brain(provider={self.provider_name!r}, "
            f"model={self.model!r}, "
            f"enabled={self.web_search_enabled}, "
            f"auto_search={self.auto_search_enabled})"
        )
265
+
266
+
267
# Alias: ``Brainify`` is the very same class object as ``Brain``.
Brainify = Brain
brainify/config.py ADDED
@@ -0,0 +1,82 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from brainify.exceptions import ConfigurationError
7
+
8
+
9
@dataclass
class AIConfig:
    """Settings for the AI provider used by a ``Brain``.

    ``provider`` is normalised (lower-cased, surrounding whitespace stripped)
    and must be one of the names listed in ``__post_init__``. ``api_key`` is
    required for every provider except ``ollama`` and ``custom``. Generation
    parameters are range-checked here, so a bad value fails when the config is
    built rather than on the first API request.

    Raises:
        ConfigurationError: If any field is invalid.
    """

    provider: str
    api_key: str = ""
    model: str = ""
    temperature: float = 0.7
    max_tokens: int = 2048
    top_p: float = 1.0
    top_k: Optional[int] = None
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    system_prompt: Optional[str] = None
    base_url: Optional[str] = None
    timeout: int = 60  # passed straight to the provider SDK client
    extra_headers: Dict[str, str] = field(default_factory=dict)
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        self.provider = self.provider.lower().strip()
        valid_providers = {
            "openai", "anthropic", "gemini", "groq",
            "mistral", "cohere", "together", "ollama", "custom",
        }
        if self.provider not in valid_providers:
            raise ConfigurationError(
                f"Unknown provider '{self.provider}'. "
                f"Supported: {', '.join(sorted(valid_providers))}"
            )
        # Local/self-hosted backends may not need a key.
        if not self.api_key and self.provider not in {"ollama", "custom"}:
            raise ConfigurationError(
                f"api_key is required for provider '{self.provider}'."
            )
        if not (0.0 <= self.temperature <= 2.0):
            raise ConfigurationError("temperature must be between 0.0 and 2.0.")
        if self.max_tokens < 1:
            raise ConfigurationError("max_tokens must be >= 1.")
        # top_p is a nucleus-sampling probability mass.
        if not (0.0 <= self.top_p <= 1.0):
            raise ConfigurationError("top_p must be between 0.0 and 1.0.")
        if self.top_k is not None and self.top_k < 1:
            raise ConfigurationError("top_k must be >= 1 when set.")
        if self.timeout <= 0:
            raise ConfigurationError("timeout must be > 0.")

    @property
    def is_openai_compatible(self) -> bool:
        """True for providers served through an OpenAI-style API."""
        return self.provider in {"openai", "groq", "together", "custom", "ollama"}
49
+
50
+
51
@dataclass
class SearchConfig:
    """Web-search settings consumed by the context builder.

    ``enabled`` forces a search on every call. ``auto_search`` lets the model
    decide, but a ``Brain`` only consults it when ``enabled`` is False.

    Raises:
        ConfigurationError: If any validated field is out of range.
    """

    enabled: bool = True
    auto_search: bool = False
    max_results: int = 5
    max_pages_to_read: int = 3
    max_content_length: int = 6000
    search_timeout: int = 15
    cache_ttl: int = 3600
    safe_search: bool = True
    region: str = "wt-wt"
    time_filter: Optional[str] = None  # one of "d", "w", "m", "y" when set
    excluded_domains: List[str] = field(default_factory=list)
    include_snippets: bool = True
    scrape_concurrency: int = 3
    user_agent: Optional[str] = None

    def __post_init__(self):
        if self.max_results < 1 or self.max_results > 20:
            raise ConfigurationError("max_results must be between 1 and 20.")
        if self.max_pages_to_read < 0:
            raise ConfigurationError("max_pages_to_read must be >= 0.")
        if self.time_filter and self.time_filter not in {"d", "w", "m", "y"}:
            raise ConfigurationError(
                "time_filter must be one of: 'd', 'w', 'm', 'y'."
            )
        # A concurrency limit below 1 would allow no scraping work at all.
        if self.scrape_concurrency < 1:
            raise ConfigurationError("scrape_concurrency must be >= 1.")
        if self.search_timeout <= 0:
            raise ConfigurationError("search_timeout must be > 0.")
77
+
78
+
79
@dataclass
class BrainifyConfig:
    """Top-level configuration: provider settings plus web-search settings."""

    # Required provider/model settings.
    ai: AIConfig
    # Each instance gets its own default SearchConfig() via default_factory.
    search: SearchConfig = field(default_factory=SearchConfig)
brainify/exceptions.py ADDED
@@ -0,0 +1,70 @@
1
+ """Custom exceptions for Brainify."""
2
+
3
+
4
class BrainifyError(Exception):
    """Root of the Brainify exception hierarchy.

    Attributes:
        message: Human-readable description (also the exception's ``str()``).
        details: Extra structured context; a fresh empty dict when none (or an
            empty mapping) is supplied.
    """

    def __init__(self, message: str, details: dict = None):
        super().__init__(message)
        self.message = message
        self.details = details if details else {}

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}(message={self.message!r})"
14
+
15
+
16
class ConfigurationError(BrainifyError):
    """Raised when configuration is invalid or missing.

    Takes the same ``(message, details=None)`` arguments as ``BrainifyError``.
    """
18
+
19
+
20
class ProviderError(BrainifyError):
    """Error reported by, or raised while talking to, an AI provider.

    Attributes:
        provider: Name of the failing provider, when known.
        status_code: Associated HTTP status code, when known.
    """

    def __init__(self, message: str, provider: str = None, status_code: int = None, details: dict = None):
        super().__init__(message, details)
        self.provider, self.status_code = provider, status_code
27
+
28
+
29
class AuthenticationError(ProviderError):
    """Raised when API key authentication fails.

    Takes the same arguments as ``ProviderError`` (``message``, ``provider``,
    ``status_code``, ``details``).
    """
31
+
32
+
33
class RateLimitError(ProviderError):
    """Raised when the API rate limit is exceeded.

    Attributes:
        retry_after: Seconds the provider asked the caller to wait, if known.
        status_code: Defaults to 429 (HTTP "Too Many Requests"); previously
            this was always None because it was never forwarded.
    """

    def __init__(
        self,
        message: str,
        provider: str = None,
        retry_after: float = None,
        details: dict = None,
        status_code: int = 429,
    ):
        # status_code is appended last so existing positional callers still work.
        super().__init__(message, provider=provider, status_code=status_code, details=details)
        self.retry_after = retry_after
39
+
40
+
41
class SearchError(BrainifyError):
    """Raised when a web search fails.

    Takes the same ``(message, details=None)`` arguments as ``BrainifyError``.
    """
43
+
44
+
45
class ScrapingError(BrainifyError):
    """Raised when web page scraping fails.

    Takes the same ``(message, details=None)`` arguments as ``BrainifyError``.
    """
47
+
48
+
49
class ModelNotSupportedError(ConfigurationError):
    """Raised when the configured model is not accepted by the provider.

    Attributes:
        model: The rejected model name.
        provider: The provider that rejected it.
    """

    def __init__(self, model: str, provider: str):
        text = f"Model '{model}' is not supported by provider '{provider}'. "
        text += "Check the provider documentation for supported models."
        super().__init__(text)
        self.model = model
        self.provider = provider
59
+
60
+
61
class ProviderNotInstalledError(ConfigurationError):
    """Raised when the SDK package a provider needs cannot be imported.

    Attributes:
        provider: Provider whose SDK is missing.
        package: Name of the missing package.
    """

    def __init__(self, provider: str, package: str):
        head = f"Provider '{provider}' requires the '{package}' package. "
        hint = f"Install it with: pip install brainify[{provider}]"
        super().__init__(head + hint)
        self.provider = provider
        self.package = package
@@ -0,0 +1,17 @@
1
+ """AI provider implementations."""
2
+
3
+ from brainify.providers.factory import ProviderFactory
4
+
5
# Provider identifiers Brainify accepts. NOTE: AIConfig.__post_init__ in
# brainify/config.py validates against its own hard-coded ``valid_providers``
# set, so the two lists must be kept in sync by hand.
SUPPORTED_PROVIDERS = [
    "openai",
    "anthropic",
    "gemini",
    "groq",
    "mistral",
    "cohere",
    "together",
    "ollama",
    "custom",
]

__all__ = ["ProviderFactory", "SUPPORTED_PROVIDERS"]
@@ -0,0 +1,120 @@
1
+ """Anthropic provider (Claude 3.5, Claude 3, …)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List
6
+
7
+ from brainify.config import AIConfig
8
+ from brainify.exceptions import (
9
+ AuthenticationError,
10
+ ProviderError,
11
+ RateLimitError,
12
+ ProviderNotInstalledError,
13
+ )
14
+ from brainify.providers.base import BaseProvider, Message
15
+ from brainify.utils.logger import get_logger
16
+
17
# Module logger from brainify's logging helper; used for request debug output.
logger = get_logger(__name__)
18
+
19
+
20
class AnthropicProvider(BaseProvider):
    """
    Adapter for the Anthropic Messages API.

    Supports: claude-3-5-sonnet-20241022, claude-3-5-haiku-20241022,
    claude-3-opus-20240229, claude-3-haiku-20240307, etc.

    Install: pip install anthropic

    The sync and async paths share client and request construction. Both
    honour ``base_url``, ``extra_headers``, ``top_p``, ``top_k`` and
    ``extra_params``, and both translate SDK errors into Brainify exceptions.
    """

    def __init__(self, config: AIConfig):
        super().__init__(config)
        self._client = self._build_client()

    def _client_kwargs(self) -> dict:
        """Constructor kwargs shared by the sync and async SDK clients."""
        kwargs = {"api_key": self.config.api_key, "timeout": self.config.timeout}
        if self.config.base_url:
            kwargs["base_url"] = self.config.base_url
        if self.config.extra_headers:
            kwargs["default_headers"] = self.config.extra_headers
        return kwargs

    def _build_client(self):
        """Create the synchronous SDK client.

        Raises:
            ProviderNotInstalledError: If the ``anthropic`` package is missing.
        """
        try:
            from anthropic import Anthropic
        except ImportError as exc:
            raise ProviderNotInstalledError("anthropic", "anthropic") from exc
        return Anthropic(**self._client_kwargs())

    def _build_params(self, messages: List[Message]) -> dict:
        """Translate Brainify messages into ``messages.create`` keyword args.

        The Messages API takes the system prompt as a separate ``system``
        parameter rather than as a chat turn, so system messages are pulled out
        of the list. If several are present, they are joined with blank lines
        instead of all but the last being silently dropped.
        """
        system_parts: List[str] = []
        chat_messages = []
        for m in messages:
            if m.role == "system":
                system_parts.append(m.content)
            else:
                chat_messages.append(m.to_dict())

        params = {
            "model": self.config.model,
            "messages": chat_messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
        }
        # Only send top_p when it differs from the neutral default. This matches
        # BaseProvider._get_common_params, and some models reject requests that
        # set both temperature and top_p.
        if self.config.top_p != 1.0:
            params["top_p"] = self.config.top_p
        system_text = "\n\n".join(part for part in system_parts if part)
        if system_text:
            params["system"] = system_text
        if self.config.top_k is not None:
            params["top_k"] = self.config.top_k
        # extra_params is applied last so callers can override anything above.
        params.update(self.config.extra_params)
        return params

    @staticmethod
    def _extract_text(response) -> str:
        """Concatenate the text of every text-bearing content block.

        Non-text blocks (for example tool-use blocks) carry no ``text``
        attribute and are skipped. Reading only ``content[0].text`` would fail
        whenever the first block is not text.
        """
        blocks = getattr(response, "content", None) or []
        return "".join(
            block.text for block in blocks if isinstance(getattr(block, "text", None), str)
        )

    def complete(self, messages: List[Message]) -> str:
        """Send *messages* synchronously and return the reply text.

        Raises:
            AuthenticationError, RateLimitError, ProviderError: See
                :meth:`_handle_error`.
        """
        try:
            params = self._build_params(messages)
            logger.debug("Anthropic request: model=%s", self.config.model)
            response = self._client.messages.create(**params)
            return self._extract_text(response)
        except Exception as exc:
            self._handle_error(exc)

    async def acomplete(self, messages: List[Message]) -> str:
        """Async counterpart of :meth:`complete`.

        A fresh ``AsyncAnthropic`` client is opened per call and closed by
        ``async with``. SDK errors are translated exactly as in the sync path.
        """
        try:
            from anthropic import AsyncAnthropic
        except ImportError as exc:
            raise ProviderNotInstalledError("anthropic", "anthropic") from exc

        try:
            params = self._build_params(messages)
            logger.debug("Anthropic async request: model=%s", self.config.model)
            async with AsyncAnthropic(**self._client_kwargs()) as client:
                response = await client.messages.create(**params)
            return self._extract_text(response)
        except Exception as exc:
            self._handle_error(exc)

    def _handle_error(self, exc: Exception):
        """Re-raise *exc* as the matching Brainify exception (never returns)."""
        try:
            from anthropic import AuthenticationError as AAuth
            from anthropic import RateLimitError as ARate
            from anthropic import APIError
        except ImportError:
            raise ProviderError(str(exc), provider="anthropic") from exc

        # Only status-bearing SDK errors expose status_code; others yield None.
        status = getattr(exc, "status_code", None)
        if isinstance(exc, AAuth):
            raise AuthenticationError(
                "Invalid Anthropic API key.", provider="anthropic", status_code=status
            ) from exc
        if isinstance(exc, ARate):
            raise RateLimitError("Anthropic rate limit exceeded.", provider="anthropic") from exc
        if isinstance(exc, APIError):
            raise ProviderError(
                f"Anthropic API error: {exc}", provider="anthropic", status_code=status
            ) from exc
        raise ProviderError(f"Anthropic unexpected error: {exc}", provider="anthropic") from exc
@@ -0,0 +1,102 @@
1
+ """Abstract base class for all AI providers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from abc import ABC, abstractmethod
7
+ from typing import List, Optional
8
+
9
+ from brainify.config import AIConfig
10
+ from brainify.utils.logger import get_logger
11
+
12
# Module logger; not currently referenced by the classes in this module.
logger = get_logger(__name__)
13
+
14
+
15
class Message:
    """One chat turn: a role ("system", "user" or "assistant") plus its text."""

    def __init__(self, role: str, content: str):
        allowed_roles = {"system", "user", "assistant"}
        if role not in allowed_roles:
            raise ValueError(f"Invalid role: {role!r}. Must be system/user/assistant.")
        self.role = role
        self.content = content

    def to_dict(self) -> dict:
        """Return the ``{"role": ..., "content": ...}`` wire form."""
        return dict(role=self.role, content=self.content)

    def __repr__(self) -> str:
        max_preview = 60
        shown = self.content
        if len(shown) > max_preview:
            shown = shown[:max_preview] + "…"
        return f"Message(role={self.role!r}, content={shown!r})"
30
+
31
+
32
class BaseProvider(ABC):
    """
    Abstract base for all Brainify AI provider adapters.

    Subclasses must implement :meth:`complete` (sync) and may override
    :meth:`acomplete` (async) for native async support.
    """

    def __init__(self, config: AIConfig):
        self.config = config
        self.provider_name: str = config.provider

    @abstractmethod
    def complete(self, messages: List[Message]) -> str:
        """
        Send messages to the AI and return the response string.

        Args:
            messages: Ordered list of chat messages (system, user, assistant…).

        Returns:
            The AI's response text.

        Raises:
            ProviderError: On API errors.
            AuthenticationError: On 401/403 responses.
            RateLimitError: On 429 responses.
        """

    async def acomplete(self, messages: List[Message]) -> str:
        """
        Async version of :meth:`complete`.

        Default implementation runs the sync :meth:`complete` in the running
        event loop's default executor (a thread pool), so the loop is not
        blocked. Override in subclasses for native async support.
        """
        # Inside a coroutine, get_running_loop() is the supported way to get
        # the loop; get_event_loop() is discouraged for this use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.complete, messages)

    def _build_messages(
        self,
        user_message: str,
        system_prompt: Optional[str] = None,
        history: Optional[List[Message]] = None,
    ) -> List[Message]:
        """Build ``[system?, *history, user]``; an empty system prompt is omitted."""
        messages: List[Message] = []
        if system_prompt:
            messages.append(Message("system", system_prompt))
        if history:
            messages.extend(history)
        messages.append(Message("user", user_message))
        return messages

    def _get_common_params(self) -> dict:
        """Return common generation parameters from config.

        ``top_p`` is included only when it differs from 1.0. ``extra_params``
        is merged last and may override any of the other keys.
        """
        params = {
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
        }
        if self.config.top_p != 1.0:
            params["top_p"] = self.config.top_p
        params.update(self.config.extra_params)
        return params

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(model={self.config.model!r})"