flashlite 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects changes between versions as they appear in those registries.
- flashlite/__init__.py +169 -0
- flashlite/cache/__init__.py +14 -0
- flashlite/cache/base.py +194 -0
- flashlite/cache/disk.py +285 -0
- flashlite/cache/memory.py +157 -0
- flashlite/client.py +671 -0
- flashlite/config.py +154 -0
- flashlite/conversation/__init__.py +30 -0
- flashlite/conversation/context.py +319 -0
- flashlite/conversation/manager.py +385 -0
- flashlite/conversation/multi_agent.py +378 -0
- flashlite/core/__init__.py +13 -0
- flashlite/core/completion.py +145 -0
- flashlite/core/messages.py +130 -0
- flashlite/middleware/__init__.py +18 -0
- flashlite/middleware/base.py +90 -0
- flashlite/middleware/cache.py +121 -0
- flashlite/middleware/logging.py +159 -0
- flashlite/middleware/rate_limit.py +211 -0
- flashlite/middleware/retry.py +149 -0
- flashlite/observability/__init__.py +34 -0
- flashlite/observability/callbacks.py +155 -0
- flashlite/observability/inspect_compat.py +266 -0
- flashlite/observability/logging.py +293 -0
- flashlite/observability/metrics.py +221 -0
- flashlite/py.typed +0 -0
- flashlite/structured/__init__.py +31 -0
- flashlite/structured/outputs.py +189 -0
- flashlite/structured/schema.py +165 -0
- flashlite/templating/__init__.py +11 -0
- flashlite/templating/engine.py +217 -0
- flashlite/templating/filters.py +143 -0
- flashlite/templating/registry.py +165 -0
- flashlite/tools/__init__.py +74 -0
- flashlite/tools/definitions.py +382 -0
- flashlite/tools/execution.py +353 -0
- flashlite/types.py +233 -0
- flashlite-0.1.0.dist-info/METADATA +173 -0
- flashlite-0.1.0.dist-info/RECORD +41 -0
- flashlite-0.1.0.dist-info/WHEEL +4 -0
- flashlite-0.1.0.dist-info/licenses/LICENSE.md +21 -0
flashlite/config.py
ADDED
@@ -0,0 +1,154 @@
"""Configuration management and environment loading."""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from dotenv import load_dotenv

from .types import ConfigError, RateLimitConfig, RetryConfig


@dataclass
class FlashliteConfig:
    """Main configuration for Flashlite client."""

    # Default model to use if not specified per-request
    default_model: str | None = None

    # Default completion parameters
    default_temperature: float | None = None
    default_max_tokens: int | None = None

    # Retry configuration
    retry: RetryConfig = field(default_factory=RetryConfig)

    # Rate limiting configuration
    rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)

    # Template directory
    template_dir: Path | str | None = None

    # Logging
    log_requests: bool = False
    log_level: str = "INFO"

    # Default kwargs to pass to all completions
    default_kwargs: dict[str, Any] = field(default_factory=dict)

    # Timeout in seconds
    timeout: float = 600.0

    @classmethod
    def from_env(cls) -> "FlashliteConfig":
        """Create config from environment variables."""
        config = cls()

        # Read config from FLASHLITE_ prefixed env vars
        if model := os.getenv("FLASHLITE_DEFAULT_MODEL"):
            config.default_model = model

        if temp := os.getenv("FLASHLITE_DEFAULT_TEMPERATURE"):
            try:
                config.default_temperature = float(temp)
            except ValueError:
                raise ConfigError(f"Invalid FLASHLITE_DEFAULT_TEMPERATURE: {temp}")

        if max_tokens := os.getenv("FLASHLITE_DEFAULT_MAX_TOKENS"):
            try:
                config.default_max_tokens = int(max_tokens)
            except ValueError:
                raise ConfigError(f"Invalid FLASHLITE_DEFAULT_MAX_TOKENS: {max_tokens}")

        if log_level := os.getenv("FLASHLITE_LOG_LEVEL"):
            config.log_level = log_level.upper()

        if os.getenv("FLASHLITE_LOG_REQUESTS", "").lower() in ("1", "true", "yes"):
            config.log_requests = True

        if template_dir := os.getenv("FLASHLITE_TEMPLATE_DIR"):
            config.template_dir = Path(template_dir)

        if rpm := os.getenv("FLASHLITE_RATE_LIMIT_RPM"):
            try:
                config.rate_limit.requests_per_minute = float(rpm)
            except ValueError:
                raise ConfigError(f"Invalid FLASHLITE_RATE_LIMIT_RPM: {rpm}")

        if tpm := os.getenv("FLASHLITE_RATE_LIMIT_TPM"):
            try:
                config.rate_limit.tokens_per_minute = float(tpm)
            except ValueError:
                raise ConfigError(f"Invalid FLASHLITE_RATE_LIMIT_TPM: {tpm}")

        if timeout := os.getenv("FLASHLITE_TIMEOUT"):
            try:
                config.timeout = float(timeout)
            except ValueError:
                raise ConfigError(f"Invalid FLASHLITE_TIMEOUT: {timeout}")

        return config


def load_env_files(
    env_file: str | Path | None = None,
    env_files: list[str | Path] | None = None,
) -> None:
    """
    Load environment variables from .env files.

    Args:
        env_file: Single env file to load
        env_files: Multiple env files to load (later files override earlier)
    """
    files_to_load: list[Path] = []

    if env_files:
        files_to_load.extend(Path(f) for f in env_files)
    elif env_file:
        files_to_load.append(Path(env_file))
    else:
        # Default: try to load .env from current directory
        default_env = Path(".env")
        if default_env.exists():
            files_to_load.append(default_env)

    # Load files in order (later overrides earlier)
    for file_path in files_to_load:
        if file_path.exists():
            load_dotenv(file_path, override=True)


def validate_api_keys(required_providers: list[str] | None = None) -> dict[str, bool]:
    """
    Check which API keys are configured.

    Args:
        required_providers: If provided, raise error if any are missing

    Returns:
        Dict mapping provider names to whether their key is set
    """
    key_mapping = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "google": "GOOGLE_API_KEY",
        "cohere": "COHERE_API_KEY",
        "azure": "AZURE_API_KEY",
        "bedrock": "AWS_ACCESS_KEY_ID",
    }

    results = {}
    for provider, env_var in key_mapping.items():
        results[provider] = bool(os.getenv(env_var))

    if required_providers:
        missing = [p for p in required_providers if not results.get(p)]
        if missing:
            raise ConfigError(
                f"Missing API keys for providers: {', '.join(missing)}. "
                f"Set the corresponding environment variables."
            )

    return results
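For orientation, a minimal usage sketch of this module, assuming flashlite 0.1.0 is installed; the environment variable values and .env file names below are illustrative, not part of the package:

# Hypothetical usage sketch (not part of the package diff).
import os

from flashlite.config import FlashliteConfig, load_env_files, validate_api_keys

# Load .env files first so from_env() can see the variables (later files override earlier ones).
load_env_files(env_files=[".env", ".env.local"])

# Illustrative values; normally these come from the shell or a .env file.
os.environ.setdefault("FLASHLITE_DEFAULT_MODEL", "gpt-4o")
os.environ.setdefault("FLASHLITE_DEFAULT_TEMPERATURE", "0.2")

config = FlashliteConfig.from_env()
print(config.default_model, config.default_temperature, config.timeout)

# Report which provider keys are set; raises ConfigError if OPENAI_API_KEY is missing.
keys = validate_api_keys(required_providers=["openai"])
print(keys)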
flashlite/conversation/__init__.py
ADDED
@@ -0,0 +1,30 @@
"""Conversation management module for flashlite."""

from .context import (
    ContextLimits,
    ContextManager,
    check_context_fit,
    estimate_messages_tokens,
    estimate_tokens,
    truncate_messages,
)
from .manager import Conversation, ConversationState, Turn
from .multi_agent import Agent, ChatMessage, MultiAgentChat

__all__ = [
    # Conversation management
    "Conversation",
    "ConversationState",
    "Turn",
    # Multi-agent conversations
    "MultiAgentChat",
    "Agent",
    "ChatMessage",
    # Context management
    "ContextManager",
    "ContextLimits",
    "estimate_tokens",
    "estimate_messages_tokens",
    "check_context_fit",
    "truncate_messages",
]
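Because of these re-exports, downstream code can import the conversation utilities from the subpackage rather than from the individual modules; a small sketch, assuming the package is installed:

# Hypothetical import sketch; the names are exactly those re-exported in __all__ above.
from flashlite.conversation import ContextManager, estimate_tokens, truncate_messages

print(estimate_tokens("hello world"))  # character-based heuristic: len(text) // 4 + 1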
flashlite/conversation/context.py
ADDED
@@ -0,0 +1,319 @@
"""Context window management for conversations."""

import logging
from dataclasses import dataclass
from typing import Any

from ..types import Messages

logger = logging.getLogger(__name__)


# Approximate token counts per character for different models
# These are rough estimates - actual tokenization varies
CHARS_PER_TOKEN_ESTIMATE = 4


@dataclass
class ContextLimits:
    """Context window limits for a model."""

    max_tokens: int
    recommended_max: int | None = None  # Leave room for response

    @classmethod
    def for_model(cls, model: str) -> "ContextLimits":
        """Get context limits for a model (approximate)."""
        model_lower = model.lower()

        # GPT-4 variants
        if "gpt-4o" in model_lower:
            return cls(max_tokens=128_000, recommended_max=120_000)
        if "gpt-4-turbo" in model_lower or "gpt-4-1106" in model_lower:
            return cls(max_tokens=128_000, recommended_max=120_000)
        if "gpt-4-32k" in model_lower:
            return cls(max_tokens=32_768, recommended_max=30_000)
        if "gpt-4" in model_lower:
            return cls(max_tokens=8_192, recommended_max=7_000)

        # GPT-3.5 variants
        if "gpt-3.5-turbo-16k" in model_lower:
            return cls(max_tokens=16_384, recommended_max=15_000)
        if "gpt-3.5" in model_lower:
            return cls(max_tokens=16_384, recommended_max=15_000)

        # Claude variants
        if "claude-3" in model_lower or "claude-sonnet-4" in model_lower:
            return cls(max_tokens=200_000, recommended_max=190_000)
        if "claude-2" in model_lower:
            return cls(max_tokens=100_000, recommended_max=95_000)
        if "claude" in model_lower:
            return cls(max_tokens=100_000, recommended_max=95_000)

        # Gemini
        if "gemini-1.5" in model_lower:
            return cls(max_tokens=1_000_000, recommended_max=900_000)
        if "gemini" in model_lower:
            return cls(max_tokens=32_768, recommended_max=30_000)

        # Mistral
        if "mistral-large" in model_lower:
            return cls(max_tokens=128_000, recommended_max=120_000)
        if "mistral" in model_lower:
            return cls(max_tokens=32_768, recommended_max=30_000)

        # Default conservative estimate
        return cls(max_tokens=8_192, recommended_max=7_000)


def estimate_tokens(text: str) -> int:
    """
    Estimate the number of tokens in text.

    This is a rough approximation. For accurate counts, use tiktoken
    or the provider's tokenizer.

    Args:
        text: The text to estimate tokens for

    Returns:
        Estimated token count
    """
    return len(text) // CHARS_PER_TOKEN_ESTIMATE + 1


def estimate_messages_tokens(messages: Messages) -> int:
    """
    Estimate total tokens in a messages list.

    Args:
        messages: List of messages

    Returns:
        Estimated token count
    """
    total = 0
    for msg in messages:
        # Add overhead for message structure
        total += 4  # Approximate overhead per message

        content = msg.get("content", "")
        if isinstance(content, str):
            total += estimate_tokens(content)
        elif isinstance(content, list):
            # Handle multi-part content (e.g., images)
            for part in content:
                if isinstance(part, dict) and "text" in part:
                    total += estimate_tokens(part["text"])
                else:
                    total += 100  # Rough estimate for non-text content

    return total


def check_context_fit(
    messages: Messages,
    model: str,
    max_response_tokens: int = 4096,
) -> tuple[bool, dict[str, Any]]:
    """
    Check if messages fit within the model's context window.

    Args:
        messages: The messages to check
        model: The model name
        max_response_tokens: Expected max tokens in response

    Returns:
        Tuple of (fits, info_dict) where info_dict contains:
        - estimated_tokens: Estimated input tokens
        - max_tokens: Model's max context
        - remaining: Tokens remaining for response
        - warning: Optional warning message
    """
    limits = ContextLimits.for_model(model)
    estimated = estimate_messages_tokens(messages)
    remaining = limits.max_tokens - estimated - max_response_tokens

    info: dict[str, Any] = {
        "estimated_tokens": estimated,
        "max_tokens": limits.max_tokens,
        "remaining": remaining,
    }

    if remaining < 0:
        info["warning"] = (
            f"Messages ({estimated} tokens) + response ({max_response_tokens}) "
            f"exceed context limit ({limits.max_tokens})"
        )
        return False, info

    if limits.recommended_max and estimated > limits.recommended_max:
        info["warning"] = (
            f"Messages ({estimated} tokens) exceed recommended max "
            f"({limits.recommended_max}). Consider truncating."
        )

    return True, info


def truncate_messages(
    messages: Messages,
    max_tokens: int,
    strategy: str = "oldest",
    keep_system: bool = True,
) -> Messages:
    """
    Truncate messages to fit within a token budget.

    Args:
        messages: The messages to truncate
        max_tokens: Maximum tokens to keep
        strategy: Truncation strategy - "oldest" removes oldest messages first
        keep_system: Whether to always keep system messages

    Returns:
        Truncated messages list
    """
    if strategy != "oldest":
        raise ValueError(f"Unknown truncation strategy: {strategy}")

    messages_list = list(messages)
    current_tokens = estimate_messages_tokens(messages_list)

    if current_tokens <= max_tokens:
        return messages_list

    # Separate system messages if we're keeping them
    system_messages = []
    other_messages = []

    for msg in messages_list:
        if keep_system and msg.get("role") == "system":
            system_messages.append(msg)
        else:
            other_messages.append(msg)

    system_tokens = estimate_messages_tokens(system_messages)
    available_tokens = max_tokens - system_tokens

    # Remove oldest messages until we fit
    while other_messages and estimate_messages_tokens(other_messages) > available_tokens:
        other_messages.pop(0)

    result = system_messages + other_messages

    removed_count = len(messages_list) - len(result)
    if removed_count > 0:
        logger.info(
            f"Truncated {removed_count} messages to fit context window "
            f"({current_tokens} -> {estimate_messages_tokens(result)} tokens)"
        )

    return result


class ContextManager:
    """
    Manages context window for a conversation.

    Provides automatic truncation when approaching limits and
    warnings when context is getting full.

    Example:
        ctx = ContextManager(model="gpt-4o", max_response_tokens=4096)

        # Check if messages fit
        fits, info = ctx.check(messages)
        if not fits:
            messages = ctx.truncate(messages)

        # Or use auto mode
        messages = ctx.prepare(messages)  # Automatically truncates if needed
    """

    def __init__(
        self,
        model: str,
        max_response_tokens: int = 4096,
        auto_truncate: bool = True,
        truncation_strategy: str = "oldest",
        keep_system: bool = True,
        warn_threshold: float = 0.8,
    ):
        """
        Initialize context manager.

        Args:
            model: Model name to get context limits for
            max_response_tokens: Expected max tokens in response
            auto_truncate: Whether to automatically truncate when needed
            truncation_strategy: Strategy for truncation ("oldest")
            keep_system: Whether to preserve system messages during truncation
            warn_threshold: Warn when context usage exceeds this ratio (0-1)
        """
        self._model = model
        self._limits = ContextLimits.for_model(model)
        self._max_response_tokens = max_response_tokens
        self._auto_truncate = auto_truncate
        self._truncation_strategy = truncation_strategy
        self._keep_system = keep_system
        self._warn_threshold = warn_threshold

    def check(self, messages: Messages) -> tuple[bool, dict[str, Any]]:
        """Check if messages fit within context limits."""
        return check_context_fit(messages, self._model, self._max_response_tokens)

    def truncate(self, messages: Messages) -> Messages:
        """Truncate messages to fit within limits."""
        max_input = self._limits.max_tokens - self._max_response_tokens
        return truncate_messages(
            messages,
            max_input,
            strategy=self._truncation_strategy,
            keep_system=self._keep_system,
        )

    def prepare(self, messages: Messages) -> Messages:
        """
        Prepare messages for completion.

        Checks fit and truncates if needed (when auto_truncate is enabled).
        Logs warnings when approaching limits.

        Args:
            messages: The messages to prepare

        Returns:
            Messages ready for completion (possibly truncated)
        """
        fits, info = self.check(messages)

        # Check warning threshold
        usage_ratio = info["estimated_tokens"] / self._limits.max_tokens
        if usage_ratio > self._warn_threshold:
            logger.warning(
                f"Context usage at {usage_ratio:.0%} "
                f"({info['estimated_tokens']}/{self._limits.max_tokens} tokens)"
            )

        if not fits:
            if self._auto_truncate:
                logger.warning(f"Context exceeded, truncating: {info.get('warning')}")
                return self.truncate(messages)
            else:
                raise ValueError(
                    f"Messages exceed context limit: {info.get('warning')}"
                )

        return list(messages)

    @property
    def model(self) -> str:
        """The model this manager is configured for."""
        return self._model

    @property
    def max_tokens(self) -> int:
        """Maximum context tokens for the model."""
        return self._limits.max_tokens
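A short usage sketch of the context helpers above, assuming the package is installed; the messages, model name, and token budgets are illustrative only:

# Hypothetical usage sketch (not part of the package diff).
from flashlite.conversation.context import (
    ContextManager,
    check_context_fit,
    estimate_messages_tokens,
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarise the plot of Hamlet in two sentences."},
]

# Heuristic count: len(content) // 4 + 1 per message, plus ~4 tokens of structural overhead each.
print(estimate_messages_tokens(messages))

# One-off check against the model's approximate context window.
fits, info = check_context_fit(messages, model="gpt-4o", max_response_tokens=4096)
print(fits, info["remaining"])

# Or let the manager warn near the threshold and auto-truncate while preparing a request.
ctx = ContextManager(model="gpt-4o", max_response_tokens=4096, warn_threshold=0.8)
prepared = ctx.prepare(messages)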