stratifyai-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- cli/__init__.py +5 -0
- cli/stratifyai_cli.py +1753 -0
- stratifyai/__init__.py +113 -0
- stratifyai/api_key_helper.py +372 -0
- stratifyai/caching.py +279 -0
- stratifyai/chat/__init__.py +54 -0
- stratifyai/chat/builder.py +366 -0
- stratifyai/chat/stratifyai_anthropic.py +194 -0
- stratifyai/chat/stratifyai_bedrock.py +200 -0
- stratifyai/chat/stratifyai_deepseek.py +194 -0
- stratifyai/chat/stratifyai_google.py +194 -0
- stratifyai/chat/stratifyai_grok.py +194 -0
- stratifyai/chat/stratifyai_groq.py +195 -0
- stratifyai/chat/stratifyai_ollama.py +201 -0
- stratifyai/chat/stratifyai_openai.py +209 -0
- stratifyai/chat/stratifyai_openrouter.py +201 -0
- stratifyai/chunking.py +158 -0
- stratifyai/client.py +292 -0
- stratifyai/config.py +1273 -0
- stratifyai/cost_tracker.py +257 -0
- stratifyai/embeddings.py +245 -0
- stratifyai/exceptions.py +91 -0
- stratifyai/models.py +59 -0
- stratifyai/providers/__init__.py +5 -0
- stratifyai/providers/anthropic.py +330 -0
- stratifyai/providers/base.py +183 -0
- stratifyai/providers/bedrock.py +634 -0
- stratifyai/providers/deepseek.py +39 -0
- stratifyai/providers/google.py +39 -0
- stratifyai/providers/grok.py +39 -0
- stratifyai/providers/groq.py +39 -0
- stratifyai/providers/ollama.py +43 -0
- stratifyai/providers/openai.py +344 -0
- stratifyai/providers/openai_compatible.py +372 -0
- stratifyai/providers/openrouter.py +39 -0
- stratifyai/py.typed +2 -0
- stratifyai/rag.py +381 -0
- stratifyai/retry.py +185 -0
- stratifyai/router.py +643 -0
- stratifyai/summarization.py +179 -0
- stratifyai/utils/__init__.py +11 -0
- stratifyai/utils/bedrock_validator.py +136 -0
- stratifyai/utils/code_extractor.py +327 -0
- stratifyai/utils/csv_extractor.py +197 -0
- stratifyai/utils/file_analyzer.py +192 -0
- stratifyai/utils/json_extractor.py +219 -0
- stratifyai/utils/log_extractor.py +267 -0
- stratifyai/utils/model_selector.py +324 -0
- stratifyai/utils/provider_validator.py +442 -0
- stratifyai/utils/token_counter.py +186 -0
- stratifyai/vectordb.py +344 -0
- stratifyai-0.1.0.dist-info/METADATA +263 -0
- stratifyai-0.1.0.dist-info/RECORD +57 -0
- stratifyai-0.1.0.dist-info/WHEEL +5 -0
- stratifyai-0.1.0.dist-info/entry_points.txt +2 -0
- stratifyai-0.1.0.dist-info/licenses/LICENSE +21 -0
- stratifyai-0.1.0.dist-info/top_level.txt +2 -0
stratifyai/caching.py
ADDED
@@ -0,0 +1,279 @@
+"""Caching utilities for LLM responses."""
+
+import hashlib
+import json
+import threading
+import time
+from dataclasses import dataclass, field
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from .models import ChatResponse
+
+
+@dataclass
+class CacheEntry:
+    """Entry in the response cache."""
+    response: ChatResponse
+    timestamp: float
+    hits: int = 0
+    cost_saved: float = 0.0  # Total cost saved from this entry
+
+
+class ResponseCache:
+    """Thread-safe in-memory cache for LLM responses."""
+
+    def __init__(self, ttl: int = 3600, max_size: int = 1000):
+        """
+        Initialize response cache.
+
+        Args:
+            ttl: Time-to-live for cache entries in seconds
+            max_size: Maximum number of entries to store
+        """
+        self.ttl = ttl
+        self.max_size = max_size
+        self._cache: Dict[str, CacheEntry] = {}
+        self._lock = threading.Lock()
+        self._total_misses: int = 0
+        self._total_cost_saved: float = 0.0
+
+    def get(self, key: str) -> Optional[ChatResponse]:
+        """
+        Get response from cache.
+
+        Args:
+            key: Cache key
+
+        Returns:
+            Cached response if found and not expired, None otherwise
+        """
+        with self._lock:
+            if key not in self._cache:
+                self._total_misses += 1
+                return None
+
+            entry = self._cache[key]
+
+            # Check if expired
+            if time.time() - entry.timestamp > self.ttl:
+                del self._cache[key]
+                self._total_misses += 1
+                return None
+
+            # Update hit count and cost saved
+            entry.hits += 1
+            if hasattr(entry.response, 'usage') and hasattr(entry.response.usage, 'cost_usd'):
+                cost = entry.response.usage.cost_usd
+                entry.cost_saved += cost
+                self._total_cost_saved += cost
+
+            return entry.response
+
+    def set(self, key: str, response: ChatResponse) -> None:
+        """
+        Store response in cache.
+
+        Args:
+            key: Cache key
+            response: Response to cache
+        """
+        with self._lock:
+            # Evict oldest entry if cache is full
+            if len(self._cache) >= self.max_size:
+                oldest_key = min(
+                    self._cache.keys(),
+                    key=lambda k: self._cache[k].timestamp
+                )
+                del self._cache[oldest_key]
+
+            self._cache[key] = CacheEntry(
+                response=response,
+                timestamp=time.time(),
+                hits=0
+            )
+
+    def clear(self) -> None:
+        """Clear all cache entries."""
+        with self._lock:
+            self._cache.clear()
+            self._total_misses = 0
+            self._total_cost_saved = 0.0
+
+    def get_stats(self) -> Dict[str, Any]:
+        """
+        Get cache statistics.
+
+        Returns:
+            Dictionary with cache stats including hits, misses, and cost savings
+        """
+        with self._lock:
+            total_hits = sum(entry.hits for entry in self._cache.values())
+            total_requests = total_hits + self._total_misses
+            hit_rate = (total_hits / total_requests * 100) if total_requests > 0 else 0.0
+
+            return {
+                "size": len(self._cache),
+                "max_size": self.max_size,
+                "total_hits": total_hits,
+                "total_misses": self._total_misses,
+                "total_requests": total_requests,
+                "hit_rate": hit_rate,
+                "total_cost_saved": self._total_cost_saved,
+                "ttl": self.ttl,
+            }
+
+    def get_entries(self) -> list[Dict[str, Any]]:
+        """
+        Get detailed information about cache entries.
+
+        Returns:
+            List of cache entry details
+        """
+        with self._lock:
+            entries = []
+            for key, entry in self._cache.items():
+                age = time.time() - entry.timestamp
+                entries.append({
+                    "key": key[:16] + "...",  # Truncate hash for display
+                    "model": entry.response.model if hasattr(entry.response, 'model') else "unknown",
+                    "provider": entry.response.provider if hasattr(entry.response, 'provider') else "unknown",
+                    "hits": entry.hits,
+                    "cost_saved": entry.cost_saved,
+                    "age_seconds": int(age),
+                    "expires_in": int(self.ttl - age),
+                })
+
+            # Sort by hits (most popular first)
+            entries.sort(key=lambda x: x["hits"], reverse=True)
+            return entries
+
+
+# Global cache instance
+_global_cache = ResponseCache()
+
+
+def generate_cache_key(
+    model: str,
+    messages: list,
+    temperature: float,
+    max_tokens: Optional[int] = None,
+    **kwargs
+) -> str:
+    """
+    Generate a unique cache key from request parameters.
+
+    Args:
+        model: Model name
+        messages: List of messages
+        temperature: Temperature parameter
+        max_tokens: Maximum tokens
+        **kwargs: Additional parameters
+
+    Returns:
+        SHA256 hash of the request parameters
+    """
+    # Convert messages to hashable format
+    messages_str = json.dumps(
+        [{"role": m.role, "content": m.content} if hasattr(m, "role") else m
+         for m in messages],
+        sort_keys=True
+    )
+
+    # Include relevant parameters
+    cache_data = {
+        "model": model,
+        "messages": messages_str,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+    }
+
+    # Add any additional kwargs that affect the response
+    for key in sorted(kwargs.keys()):
+        if key not in ["stream", "extra_params"]:  # Skip non-deterministic params
+            cache_data[key] = kwargs[key]
+
+    # Generate hash
+    cache_str = json.dumps(cache_data, sort_keys=True)
+    return hashlib.sha256(cache_str.encode()).hexdigest()
+
+
+def cache_response(
+    ttl: int = 3600,
+    cache_instance: Optional[ResponseCache] = None
+):
+    """
+    Decorator to cache async LLM responses.
+
+    Args:
+        ttl: Time-to-live for cache entries in seconds
+        cache_instance: Optional cache instance (uses global cache if None)
+
+    Returns:
+        Decorated async function
+
+    Example:
+        @cache_response(ttl=3600)
+        async def chat(self, request: ChatRequest) -> ChatResponse:
+            return await self.provider.chat_completion(request)
+    """
+    cache = cache_instance or _global_cache
+    cache.ttl = ttl
+
+    def decorator(func: Callable) -> Callable:
+        @wraps(func)
+        async def async_wrapper(*args, **kwargs) -> ChatResponse:
+            # Extract request parameters
+            # Handle both ChatRequest object and individual parameters
+            if args and hasattr(args[0], "model"):
+                request = args[0]
+                cache_key = generate_cache_key(
+                    model=request.model,
+                    messages=request.messages,
+                    temperature=request.temperature,
+                    max_tokens=request.max_tokens,
+                )
+            else:
+                cache_key = generate_cache_key(**kwargs)
+
+            # Check cache
+            cached_response = cache.get(cache_key)
+            if cached_response is not None:
+                return cached_response
+
+            # Execute async function
+            response = await func(*args, **kwargs)
+
+            # Cache response (only if not streaming)
+            if not kwargs.get("stream", False):
+                cache.set(cache_key, response)
+
+            return response
+
+        return async_wrapper
+    return decorator
+
+
+def get_cache_stats() -> Dict[str, Any]:
+    """
+    Get statistics from the global cache.
+
+    Returns:
+        Dictionary with cache statistics
+    """
+    return _global_cache.get_stats()
+
+
+def get_cache_entries() -> list[Dict[str, Any]]:
+    """
+    Get detailed information about cache entries.
+
+    Returns:
+        List of cache entry details
+    """
+    return _global_cache.get_entries()
+
+
+def clear_cache() -> None:
+    """Clear the global cache."""
+    _global_cache.clear()
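Usage sketch (not part of the published wheel): a minimal way to exercise this module end to end, keyed by the SHA-256 digest from generate_cache_key(). The wrapped coroutine is a hypothetical stand-in for a provider call; the decorator caches whatever object it returns, dict or ChatResponse alike, and only skips caching when stream=True is passed.

    import asyncio
    from stratifyai.caching import cache_response, get_cache_stats, clear_cache

    @cache_response(ttl=600)
    async def fake_completion(*, model, messages, temperature, max_tokens=None):
        # Hypothetical stand-in for a provider call; no network access.
        return {"echo": messages[-1]["content"]}

    async def main():
        params = dict(
            model="gpt-4.1-mini",
            messages=[{"role": "user", "content": "Hi"}],
            temperature=0.0,
        )
        first = await fake_completion(**params)
        second = await fake_completion(**params)  # identical kwargs -> same cache key
        assert first is second                    # served from the in-memory cache
        print(get_cache_stats())                  # size, hits, misses, hit_rate, cost saved
        clear_cache()

    asyncio.run(main())

Inside the library the decorator is presumably applied to the client's chat path (as its docstring example suggests); in that case the first branch of async_wrapper derives the key from the ChatRequest object instead of keyword arguments.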
stratifyai/chat/__init__.py
ADDED
@@ -0,0 +1,54 @@
+"""Chat package for StratifyAI provider-specific chat interfaces.
+
+This package provides convenient, provider-specific chat functions.
+Model must be specified for each request.
+
+Usage:
+    # Model is always required
+    from stratifyai.chat import openai, anthropic, google
+    response = await openai.chat("Hello!", model="gpt-4.1-mini")
+
+    # Builder pattern (model required first)
+    from stratifyai.chat import anthropic
+    client = (
+        anthropic
+        .with_model("claude-opus-4-5")
+        .with_system("You are a helpful assistant")
+        .with_developer("Use markdown formatting")
+    )
+    response = await client.chat("Hello!")
+
+    # With additional parameters
+    response = await anthropic.chat(
+        "Explain quantum computing",
+        model="claude-sonnet-4-5",
+        temperature=0.5,
+        max_tokens=500,
+    )
+"""
+
+from stratifyai.chat.builder import ChatBuilder
+from stratifyai.chat import (
+    stratifyai_openai as openai,
+    stratifyai_anthropic as anthropic,
+    stratifyai_google as google,
+    stratifyai_deepseek as deepseek,
+    stratifyai_groq as groq,
+    stratifyai_grok as grok,
+    stratifyai_openrouter as openrouter,
+    stratifyai_ollama as ollama,
+    stratifyai_bedrock as bedrock,
+)
+
+__all__ = [
+    "ChatBuilder",
+    "openai",
+    "anthropic",
+    "google",
+    "deepseek",
+    "groq",
+    "grok",
+    "openrouter",
+    "ollama",
+    "bedrock",
+]
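Usage sketch (not part of the published wheel): the docstring examples above use bare await, so a standalone script needs an event loop. A hedged sketch, assuming the relevant provider API key is already configured for the underlying LLMClient:

    import asyncio
    from stratifyai.chat import openai

    async def main():
        response = await openai.chat("Hello!", model="gpt-4.1-mini")
        print(response)

    asyncio.run(main())

For synchronous callers, ChatBuilder in builder.py below also provides chat_sync(), which wraps the same call in asyncio.run().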
stratifyai/chat/builder.py
ADDED
@@ -0,0 +1,366 @@
+"""Builder pattern for configuring chat clients.
+
+Provides a fluent interface for configuring chat parameters before execution.
+
+Example:
+    from stratifyai.chat import anthropic
+
+    # Model is always required
+    response = await anthropic.chat("Hello", model="claude-sonnet-4-5")
+
+    # Builder pattern with chaining (model required)
+    client = (
+        anthropic
+        .with_model("claude-opus-4-5")
+        .with_system("I am a helpful assistant")
+        .with_developer("Use markdown formatting")
+        .with_temperature(0.5)
+        .with_max_tokens(1000)
+    )
+    response = await client.chat("Hello")
+"""
+
+from __future__ import annotations
+
+import asyncio
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, AsyncIterator, Callable, Optional, Union
+
+if TYPE_CHECKING:
+    from stratifyai import LLMClient
+    from stratifyai.models import ChatResponse, Message
+
+
+@dataclass
+class ChatBuilder:
+    """Builder for configuring chat requests with fluent interface.
+
+    Each with_* method returns a new ChatBuilder instance with the
+    updated configuration, allowing for method chaining.
+
+    Model is required - either via with_model() or as a parameter to chat().
+
+    Attributes:
+        provider: The provider name (e.g., "openai", "anthropic")
+        default_temperature: Default temperature setting
+        default_max_tokens: Default max tokens setting
+        _model: Configured model (required before chat)
+        _system: System prompt
+        _developer: Developer/instruction prompt
+        _temperature: Configured temperature (None = use default)
+        _max_tokens: Configured max tokens (None = use default)
+        _client_factory: Factory function to create LLMClient
+    """
+    provider: str
+    default_temperature: float = 0.7
+    default_max_tokens: Optional[int] = None
+    _model: Optional[str] = None
+    _system: Optional[str] = None
+    _developer: Optional[str] = None
+    _temperature: Optional[float] = None
+    _max_tokens: Optional[int] = None
+    _client_factory: Optional[Callable[[], "LLMClient"]] = None
+    _extra_kwargs: dict = field(default_factory=dict)
+
+    def _clone(self, **updates) -> "ChatBuilder":
+        """Create a copy of this builder with updates applied."""
+        return ChatBuilder(
+            provider=self.provider,
+            default_temperature=self.default_temperature,
+            default_max_tokens=self.default_max_tokens,
+            _model=updates.get("_model", self._model),
+            _system=updates.get("_system", self._system),
+            _developer=updates.get("_developer", self._developer),
+            _temperature=updates.get("_temperature", self._temperature),
+            _max_tokens=updates.get("_max_tokens", self._max_tokens),
+            _client_factory=self._client_factory,
+            _extra_kwargs={**self._extra_kwargs, **updates.get("_extra_kwargs", {})},
+        )
+
+    def with_model(self, model: str) -> "ChatBuilder":
+        """Set the model to use (required).
+
+        Args:
+            model: Model name (e.g., "claude-opus-4-5", "gpt-4.1")
+
+        Returns:
+            New ChatBuilder with model configured.
+
+        Example:
+            client = anthropic.with_model("claude-opus-4-5")
+        """
+        return self._clone(_model=model)
+
+    def with_system(self, prompt: str) -> "ChatBuilder":
+        """Set the system prompt.
+
+        The system prompt sets the AI's behavior and context.
+
+        Args:
+            prompt: System prompt text
+
+        Returns:
+            New ChatBuilder with system prompt configured.
+
+        Example:
+            client = anthropic.with_system("You are a helpful coding assistant")
+        """
+        return self._clone(_system=prompt)
+
+    def with_developer(self, instructions: str) -> "ChatBuilder":
+        """Set developer/instruction prompt.
+
+        Developer instructions provide formatting or behavioral guidance.
+        These are prepended to the system prompt if both are set.
+
+        Args:
+            instructions: Developer instructions text
+
+        Returns:
+            New ChatBuilder with developer instructions configured.
+
+        Example:
+            client = anthropic.with_developer("Always use markdown formatting")
+        """
+        return self._clone(_developer=instructions)
+
+    def with_temperature(self, temperature: float) -> "ChatBuilder":
+        """Set the sampling temperature.
+
+        Args:
+            temperature: Temperature value (0.0 = deterministic, 2.0 = creative)
+
+        Returns:
+            New ChatBuilder with temperature configured.
+
+        Example:
+            client = anthropic.with_temperature(0.3)
+        """
+        return self._clone(_temperature=temperature)
+
+    def with_max_tokens(self, max_tokens: int) -> "ChatBuilder":
+        """Set the maximum tokens to generate.
+
+        Args:
+            max_tokens: Maximum number of tokens in the response
+
+        Returns:
+            New ChatBuilder with max_tokens configured.
+
+        Example:
+            client = anthropic.with_max_tokens(500)
+        """
+        return self._clone(_max_tokens=max_tokens)
+
+    def with_options(self, **kwargs) -> "ChatBuilder":
+        """Set additional provider-specific options.
+
+        Args:
+            **kwargs: Additional options passed to the API
+
+        Returns:
+            New ChatBuilder with extra options configured.
+
+        Example:
+            client = anthropic.with_options(top_p=0.9)
+        """
+        return self._clone(_extra_kwargs=kwargs)
+
+    @property
+    def model(self) -> Optional[str]:
+        """Get the configured model (None if not set)."""
+        return self._model
+
+    @property
+    def temperature(self) -> float:
+        """Get the effective temperature (configured or default)."""
+        return self._temperature if self._temperature is not None else self.default_temperature
+
+    @property
+    def max_tokens(self) -> Optional[int]:
+        """Get the effective max_tokens (configured or default)."""
+        return self._max_tokens if self._max_tokens is not None else self.default_max_tokens
+
+    def _get_client(self) -> "LLMClient":
+        """Get or create the LLM client."""
+        if self._client_factory is None:
+            from stratifyai import LLMClient
+            return LLMClient(provider=self.provider)
+        return self._client_factory()
+
+    def _build_system_prompt(self) -> Optional[str]:
+        """Combine developer and system prompts."""
+        if self._developer and self._system:
+            return f"{self._developer}\n\n{self._system}"
+        return self._developer or self._system
+
+    def _build_messages(self, prompt: Union[str, list]) -> list:
+        """Build the messages list from prompt and configured prompts."""
+        from stratifyai.models import Message
+
+        if isinstance(prompt, str):
+            messages = []
+            system_prompt = self._build_system_prompt()
+            if system_prompt:
+                messages.append(Message(role="system", content=system_prompt))
+            messages.append(Message(role="user", content=prompt))
+            return messages
+        else:
+            # If prompt is already a list of messages, prepend system if not present
+            messages = list(prompt)
+            system_prompt = self._build_system_prompt()
+            if system_prompt and (not messages or messages[0].role != "system"):
+                from stratifyai.models import Message
+                messages.insert(0, Message(role="system", content=system_prompt))
+            return messages
+
+    async def chat(
+        self,
+        prompt: Union[str, list],
+        *,
+        model: Optional[str] = None,
+        system: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        stream: bool = False,
+        **kwargs,
+    ) -> Union["ChatResponse", AsyncIterator["ChatResponse"]]:
+        """Send an async chat completion request.
+
+        Args:
+            prompt: User message string or list of Message objects.
+            model: Model to use (required if not set via with_model()).
+            system: Override the configured system prompt.
+            temperature: Override the configured temperature.
+            max_tokens: Override the configured max_tokens.
+            stream: Whether to stream the response.
+            **kwargs: Additional parameters passed to the API.
+
+        Returns:
+            ChatResponse object, or AsyncIterator[ChatResponse] if streaming.
+
+        Raises:
+            ValueError: If no model is specified.
+        """
+        client = self._get_client()
+
+        # Use overrides if provided, otherwise use configured values
+        effective_model = model or self._model
+        if not effective_model:
+            raise ValueError(
+                f"Model is required. Either call with_model() first or pass model parameter.\n"
+                f"Example: {self.provider}.with_model('model-name').chat(...) or "
+                f"{self.provider}.chat(..., model='model-name')"
+            )
+        effective_temp = temperature if temperature is not None else self.temperature
+        effective_max = max_tokens if max_tokens is not None else self.max_tokens
+
+        # Handle system prompt override
+        if system:
+            builder = self.with_system(system)
+            messages = builder._build_messages(prompt)
+        else:
+            messages = self._build_messages(prompt)
+
+        # Merge extra kwargs
+        merged_kwargs = {**self._extra_kwargs, **kwargs}
+
+        return await client.chat(
+            model=effective_model,
+            messages=messages,
+            temperature=effective_temp,
+            max_tokens=effective_max,
+            stream=stream,
+            **merged_kwargs,
+        )
+
+    async def chat_stream(
+        self,
+        prompt: Union[str, list],
+        *,
+        model: Optional[str] = None,
+        system: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> AsyncIterator["ChatResponse"]:
+        """Send an async streaming chat completion request.
+
+        Args:
+            prompt: User message string or list of Message objects.
+            model: Override the configured model.
+            system: Override the configured system prompt.
+            temperature: Override the configured temperature.
+            max_tokens: Override the configured max_tokens.
+            **kwargs: Additional parameters passed to the API.
+
+        Yields:
+            ChatResponse chunks.
+        """
+        return await self.chat(
+            prompt,
+            model=model,
+            system=system,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stream=True,
+            **kwargs,
+        )
+
+    def chat_sync(
+        self,
+        prompt: Union[str, list],
+        *,
+        model: Optional[str] = None,
+        system: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        **kwargs,
+    ) -> "ChatResponse":
+        """Synchronous wrapper for chat().
+
+        Args:
+            prompt: User message string or list of Message objects.
+            model: Override the configured model.
+            system: Override the configured system prompt.
+            temperature: Override the configured temperature.
+            max_tokens: Override the configured max_tokens.
+            **kwargs: Additional parameters passed to the API.
+
+        Returns:
+            ChatResponse object.
+        """
+        return asyncio.run(self.chat(
+            prompt,
+            model=model,
+            system=system,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            stream=False,
+            **kwargs,
+        ))
+
+
+def create_module_builder(
+    provider: str,
+    default_temperature: float = 0.7,
+    default_max_tokens: Optional[int] = None,
+    client_factory: Optional[Callable[[], "LLMClient"]] = None,
+) -> ChatBuilder:
+    """Create a ChatBuilder for a provider module.
+
+    Args:
+        provider: Provider name
+        default_temperature: Default temperature
+        default_max_tokens: Default max tokens
+        client_factory: Optional factory to create shared client
+
+    Returns:
+        Configured ChatBuilder instance
+    """
+    return ChatBuilder(
+        provider=provider,
+        default_temperature=default_temperature,
+        default_max_tokens=default_max_tokens,
+        _client_factory=client_factory,
+    )
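Usage sketch (not part of the published wheel): create_module_builder() is presumably how the stratifyai_* provider modules above expose a pre-configured builder; those files are not shown in this diff, so the wiring below is illustrative, the model name is only an example, and provider credentials are assumed to be available to the LLMClient it creates.

    from stratifyai.chat.builder import create_module_builder

    # Hypothetical provider wiring; the real stratifyai_openai.py may differ.
    openai = create_module_builder(provider="openai", default_temperature=0.7)

    # Each with_* call clones the builder, so configured builders can be shared
    # and further specialized without mutating each other.
    base = openai.with_model("gpt-4.1-mini").with_system("You are terse.")
    creative = base.with_temperature(1.0)
    print(base.temperature, creative.temperature)  # 0.7 (default) vs. 1.0

    response = base.chat_sync("Summarize the builder pattern in one sentence.")
    print(response)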