aetherroute 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherroute/__init__.py +8 -0
- aetherroute/adapters/__init__.py +8 -0
- aetherroute/adapters/prompt.py +129 -0
- aetherroute/adapters/token_counter.py +44 -0
- aetherroute/cache/__init__.py +3 -0
- aetherroute/cache/semantic.py +153 -0
- aetherroute/config.py +65 -0
- aetherroute/context/__init__.py +3 -0
- aetherroute/context/curator.py +164 -0
- aetherroute/cost/__init__.py +3 -0
- aetherroute/cost/governor.py +119 -0
- aetherroute/observability/__init__.py +3 -0
- aetherroute/observability/dashboard.py +18 -0
- aetherroute/observability/logger.py +79 -0
- aetherroute/observability/report.py +237 -0
- aetherroute/orchestrator.py +350 -0
- aetherroute/providers/__init__.py +15 -0
- aetherroute/providers/anthropic.py +131 -0
- aetherroute/providers/base.py +128 -0
- aetherroute/providers/mistral.py +142 -0
- aetherroute/providers/ollama.py +108 -0
- aetherroute/providers/openai.py +122 -0
- aetherroute/providers/registry.py +120 -0
- aetherroute/py.typed +1 -0
- aetherroute/router/__init__.py +4 -0
- aetherroute/router/classifier.py +64 -0
- aetherroute/router/engine.py +250 -0
- aetherroute/security/__init__.py +9 -0
- aetherroute/security/permission.py +50 -0
- aetherroute/security/sanitizer.py +64 -0
- aetherroute/validation/__init__.py +9 -0
- aetherroute/validation/consistency.py +109 -0
- aetherroute/validation/validator.py +112 -0
- aetherroute-0.1.0.dist-info/METADATA +324 -0
- aetherroute-0.1.0.dist-info/RECORD +39 -0
- aetherroute-0.1.0.dist-info/WHEEL +5 -0
- aetherroute-0.1.0.dist-info/entry_points.txt +2 -0
- aetherroute-0.1.0.dist-info/licenses/LICENSE +21 -0
- aetherroute-0.1.0.dist-info/top_level.txt +1 -0
aetherroute/__init__.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import logging
|
|
3
|
+
from typing import List, Dict, Any, Optional
|
|
4
|
+
from aetherroute.adapters.token_counter import count_messages_tokens, count_tokens
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger("aetherroute.adapters.prompt")
|
|
7
|
+
|
|
8
|
+
class PromptAdapter:
    """
    Normalizes templates and message histories for specific LLM providers.
    Trims history to fit within a provider's max context limit.
    """

    @staticmethod
    def render_template(template: str, variables: Dict[str, Any]) -> str:
        """
        Renders variables into double curly-braces template placeholders (e.g. {{variable}}).

        Whitespace inside the braces is ignored, so ``{{ name }}`` and ``{{name}}``
        are equivalent. Placeholders without a matching key are left untouched.

        Substitution happens in a single pass and values are inserted literally:
        backslashes in a value are not treated as regex escapes, and placeholder
        syntax inside a value is not expanded again by another variable.

        Args:
            template: Text containing ``{{key}}`` placeholders.
            variables: Mapping of placeholder name to value; values are passed
                through ``str()``.

        Returns:
            The template with every known placeholder replaced.
        """
        if not variables:
            return template

        # One alternation over all keys lets the template be scanned exactly once,
        # so inserted values are never re-scanned for further placeholders.
        alternation = "|".join(re.escape(key) for key in variables)
        placeholder = re.compile(r"\{\{\s*(" + alternation + r")\s*\}\}")

        # A callable replacement is used because a plain replacement string would
        # have its backslashes and group references interpreted by ``re``.
        return placeholder.sub(lambda match: str(variables[match.group(1)]), template)

    @staticmethod
    def _coerce_role(msg: Dict[str, str]) -> str:
        """Returns the message role, mapping anything but user/assistant to "user"."""
        role = msg.get("role", "user")
        return role if role in ("user", "assistant") else "user"

    @classmethod
    def _merge_for_anthropic(cls, chat_msgs: List[Dict[str, str]]) -> List[Dict[str, str]]:
        """
        Builds a strictly alternating user/assistant history that starts with user.

        Consecutive messages with the same role are merged (joined by a blank line);
        a leading assistant message is relabelled as user so its content is kept.
        """
        groups: List[List[Any]] = []  # each entry is [role, [content, ...]]
        for msg in chat_msgs:
            role = cls._coerce_role(msg)
            if not groups and role == "assistant":
                # Cannot start with assistant; convert instead of dropping context.
                role = "user"
            content = msg.get("content", "")
            if groups and groups[-1][0] == role:
                groups[-1][1].append(content)
            else:
                groups.append([role, [content]])
        return [{"role": role, "content": "\n\n".join(parts)} for role, parts in groups]

    @classmethod
    def normalize_messages(
        cls,
        messages: List[Dict[str, str]],
        provider_name: str,
        max_context: int,
        reserved_output_tokens: int = 2048
    ) -> List[Dict[str, str]]:
        """
        Normalizes role types, alternates roles if necessary (Anthropic),
        and trims historical messages to fit the provider's max context budget.

        The input list and its dicts are never modified; the result is built from
        copies.

        Args:
            messages: Chat history as dicts with "role" and "content" keys.
            provider_name: Provider identifier; "anthropic" enables role merging
                and alternation, every other value uses the generic path.
            max_context: Context size of the target model, in tokens.
            reserved_output_tokens: Tokens held back for the model's reply.

        Returns:
            System messages first, followed by the (possibly pruned and truncated)
            user/assistant messages. If the system messages alone exceed the budget,
            only the truncated system messages are returned.
        """
        if not messages:
            return []

        # 1. Separate system messages from the rest. System messages are copied
        #    because they may be truncated below and the caller's dicts must stay intact.
        system_msgs = [dict(m) for m in messages if m.get("role") == "system"]
        other_msgs = [m for m in messages if m.get("role") != "system"]

        # 2. Provider specific adjustments
        is_anthropic = provider_name == "anthropic"
        if is_anthropic:
            normalized_others = cls._merge_for_anthropic(other_msgs)
        else:
            # OpenAI, Mistral, Ollama accept standard system/user/assistant messages.
            normalized_others = [
                {"role": cls._coerce_role(msg), "content": msg.get("content", "")}
                for msg in other_msgs
            ]

        # 3. Context window pruning (sliding window over chat history).
        #    A reservation larger than the context would give a negative budget,
        #    which would turn the character slices below into "cut from the end".
        allowed_tokens = max(0, max_context - reserved_output_tokens)

        # The counter adds fixed overhead even for an empty list, so only treat
        # the system prompt as oversized when there actually is one.
        if system_msgs and count_messages_tokens(system_msgs) > allowed_tokens:
            # System message itself exceeds the allowed size! Prune characters as fallback.
            logger.warning("System message exceeds context budget. Truncating system message.")
            for s_msg in system_msgs:
                s_msg["content"] = s_msg.get("content", "")[:allowed_tokens * 4]
            return system_msgs

        # Prune older user/assistant messages until we fit
        final_messages = system_msgs + normalized_others
        total_tokens = count_messages_tokens(final_messages)

        while total_tokens > allowed_tokens and len(normalized_others) > 1:
            # Remove the oldest history message
            removed = normalized_others.pop(0)
            logger.info(f"Context limit reached. Pruning message with role: {removed['role']}")
            # Anthropic histories must start with a user turn, so an assistant reply
            # orphaned by the removal above is dropped together with it.
            if (
                is_anthropic
                and len(normalized_others) > 1
                and normalized_others[0]["role"] == "assistant"
            ):
                removed = normalized_others.pop(0)
                logger.info(f"Context limit reached. Pruning message with role: {removed['role']}")
            final_messages = system_msgs + normalized_others
            total_tokens = count_messages_tokens(final_messages)

        if is_anthropic and normalized_others and normalized_others[0]["role"] == "assistant":
            # Only a lone assistant turn survived pruning; relabel it (same policy
            # as for a leading assistant message) rather than send an invalid history.
            normalized_others[0]["role"] = "user"

        # If it still exceeds (e.g. only one huge query message remaining),
        # truncate the text content of that message.
        if total_tokens > allowed_tokens and normalized_others:
            msg_to_truncate = normalized_others[0]
            content = msg_to_truncate["content"]
            current_tokens = count_tokens(content)
            excess_tokens = total_tokens - allowed_tokens
            # Always keep at least ~100 tokens of the message.
            allowed_msg_tokens = max(100, current_tokens - excess_tokens)

            # Simple truncate by character approximation (about 4 chars per token).
            char_limit = allowed_msg_tokens * 4
            if len(content) > char_limit:
                # Only mark the message as truncated when something was cut.
                msg_to_truncate["content"] = content[:char_limit] + "... [truncated]"
            final_messages = system_msgs + normalized_others

        return final_messages
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Dict, Any
|
|
3
|
+
|
|
4
|
+
logger = logging.getLogger("aetherroute.adapters.token_counter")
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
import tiktoken
|
|
8
|
+
_TIKTOKEN_AVAILABLE = True
|
|
9
|
+
except ImportError:
|
|
10
|
+
_TIKTOKEN_AVAILABLE = False
|
|
11
|
+
logger.warning("tiktoken library not found. Falling back to rough character estimation (1 token ≈ 4 characters).")
|
|
12
|
+
|
|
13
|
+
def count_tokens(text: str, model_name: str = "gpt-4o-mini") -> int:
    """
    Returns the token count of a given string.
    Uses tiktoken if available, otherwise falls back to character division.

    Args:
        text: Text to measure. Empty or falsy input counts as 0 tokens.
        model_name: Model whose tiktoken encoding should be used; models unknown
            to tiktoken fall back to the "cl100k_base" encoding.

    Returns:
        The token count, or a rough estimate (at least 1) when tiktoken is
        missing or fails.
    """
    if not text:
        return 0

    if _TIKTOKEN_AVAILABLE:
        try:
            # Try to get the encoding for the model
            try:
                encoding = tiktoken.encoding_for_model(model_name)
            except KeyError:
                encoding = tiktoken.get_encoding("cl100k_base")
            # By default tiktoken raises ValueError when the text contains a
            # special-token literal such as "<|endoftext|>". Arbitrary user text
            # may contain one, so encode those sequences as ordinary text
            # instead of silently dropping to the character estimate.
            return len(encoding.encode(text, disallowed_special=()))
        except Exception as e:
            # Deliberate best effort: counting must never break a request.
            logger.debug(f"Error counting tokens with tiktoken: {e}. Falling back to character estimation.")

    # Fallback: 1 token is roughly 4 characters on average for English text
    return max(1, len(text) // 4)
|
|
34
|
+
|
|
35
|
+
def count_messages_tokens(messages: List[Dict[str, str]], model_name: str = "gpt-4o-mini") -> int:
    """
    Estimates the total token count of a message list.

    Each message contributes the token count of its "content" (a missing key
    counts as empty) plus 4 tokens of per-message overhead for role and
    structure; 3 more tokens are added once for the conversation as a whole.
    """
    per_message_totals = (
        count_tokens(entry.get("content", ""), model_name) + 4
        for entry in messages
    )
    # The trailing 3 is the overall conversation overhead.
    return sum(per_message_totals) + 3
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import Dict, Any, Optional, List, Tuple
|
|
5
|
+
from aetherroute.context.curator import ContextCurator
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger("aetherroute.cache.semantic")
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import redis.asyncio as aioredis
|
|
11
|
+
_REDIS_AVAILABLE = True
|
|
12
|
+
except ImportError:
|
|
13
|
+
_REDIS_AVAILABLE = False
|
|
14
|
+
logger.warning("redis-py not installed. Caching will run strictly in-memory.")
|
|
15
|
+
|
|
16
|
+
class SemanticCache:
    """
    Semantic and exact-match cache for LLM completions.
    Tries to connect to Redis, and automatically falls back to an in-memory database.
    Matches queries fuzzy-style using the similarity score from
    ContextCurator.calculate_tfidf_similarity.
    """

    # Prefix of the Redis keys holding cached responses; the query text follows it.
    _EXACT_PREFIX = "cache:exact:"
    # Number of keys deleted per DELETE call when clearing Redis.
    _DELETE_BATCH = 500

    def __init__(self, config: Any):
        """
        Reads cache settings from ``config.cache``.

        Missing attributes fall back to defaults: Redis at localhost:6379,
        similarity threshold 0.85, TTL 3600 seconds. No connection is opened
        here; call ``connect()`` to enable Redis.
        """
        self.config = config
        self.redis_url = getattr(config.cache, "redis_url", "redis://localhost:6379")
        self.threshold = getattr(config.cache, "semantic_threshold", 0.85)
        self.ttl = getattr(config.cache, "ttl_seconds", 3600)
        self.client: Optional[Any] = None
        self.use_redis = False

        # Local in-memory cache fallback: maps query -> (response_dict, expiry_timestamp)
        self._in_memory_cache: Dict[str, Tuple[Dict[str, Any], float]] = {}

    @staticmethod
    async def _close_client(client: Optional[Any]) -> None:
        """Closes a Redis client, logging (not raising) any error."""
        if client is None:
            return
        # Newer redis-py exposes aclose(); older releases only have close().
        closer = getattr(client, "aclose", None) or getattr(client, "close", None)
        if closer is None:
            return
        try:
            await closer()
        except Exception as e:
            logger.debug(f"Error while closing Redis client: {e}")

    async def connect(self) -> None:
        """Attempts to connect to Redis. Falls back to in-memory on failure."""
        if _REDIS_AVAILABLE:
            client = None
            try:
                client = aioredis.from_url(self.redis_url, decode_responses=True)
                # Test connection
                await asyncio.wait_for(client.ping(), timeout=2.0)
            except Exception as e:
                logger.warning(f"Failed to connect to Redis ({e}). Falling back to In-Memory Cache.")
                # Release the half-open client instead of leaking its connection pool.
                await self._close_client(client)
            else:
                self.client = client
                self.use_redis = True
                logger.info(f"Connected to Redis semantic cache at {self.redis_url}")
                return

        self.use_redis = False
        self.client = None

    async def _get_from_redis(self, cleaned_query: str) -> Optional[Dict[str, Any]]:
        """Looks up a query in Redis: exact key first, then the best fuzzy match."""
        prefix = self._EXACT_PREFIX

        # 1. Check exact match
        exact_val = await self.client.get(prefix + cleaned_query)
        if exact_val:
            logger.info("SemanticCache: Exact match found in Redis.")
            return json.loads(exact_val)

        # 2. Check semantic match. SCAN walks the keyspace incrementally,
        #    whereas KEYS blocks the server while it lists every key at once.
        candidates: List[Tuple[float, str]] = []
        async for key in self.client.scan_iter(match=prefix + "*"):
            original_query = key[len(prefix):]
            similarity = ContextCurator.calculate_tfidf_similarity(cleaned_query, original_query)
            if similarity >= self.threshold:
                candidates.append((similarity, key))

        # Try the most similar query first; an entry may have expired since the scan.
        candidates.sort(key=lambda item: item[0], reverse=True)
        for similarity, key in candidates:
            val = await self.client.get(key)
            if val:
                logger.info(f"SemanticCache: Fuzzy match found in Redis (similarity={similarity:.4f}).")
                return json.loads(val)
        return None

    def _get_from_memory(self, cleaned_query: str) -> Optional[Dict[str, Any]]:
        """Looks up a query in the in-memory cache and purges expired entries it meets."""
        import time
        now = time.time()

        # 1. Exact match in memory
        entry = self._in_memory_cache.get(cleaned_query)
        if entry is not None:
            val, expiry = entry
            if now < expiry:
                logger.info("SemanticCache: Exact match found in memory.")
                return val
            del self._in_memory_cache[cleaned_query]

        # 2. Semantic match in memory: keep the best-scoring live entry.
        best: Optional[Tuple[float, Dict[str, Any]]] = None
        for orig_query, (val, expiry) in list(self._in_memory_cache.items()):
            if now >= expiry:
                del self._in_memory_cache[orig_query]
                continue
            similarity = ContextCurator.calculate_tfidf_similarity(cleaned_query, orig_query)
            if similarity >= self.threshold and (best is None or similarity > best[0]):
                best = (similarity, val)

        if best is not None:
            logger.info(f"SemanticCache: Fuzzy match found in memory (similarity={best[0]:.4f}).")
            return best[1]
        return None

    async def get(self, query: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a cached response if an exact or semantically similar query is found.

        Redis is consulted first when enabled; on a Redis miss or error the
        in-memory cache is searched. A Redis error disables Redis for this
        instance. Returns None when nothing matches.
        """
        # Reconnect if Redis is enabled but the client was closed.
        if self.use_redis and self.client is None:
            await self.connect()

        cleaned_query = query.strip()

        if self.use_redis and self.client:
            try:
                hit = await self._get_from_redis(cleaned_query)
            except Exception as e:
                logger.error(f"Redis cache read error: {e}. Falling back to local search.")
                self.use_redis = False
            else:
                if hit is not None:
                    return hit

        # In-memory search (either fallback or primary)
        return self._get_from_memory(cleaned_query)

    async def set(self, query: str, response: Dict[str, Any], ttl: Optional[int] = None) -> None:
        """
        Caches a query response.

        Args:
            query: The query text; surrounding whitespace is stripped.
            response: JSON-serializable response to store.
            ttl: Lifetime in seconds; defaults to the configured TTL.
        """
        if self.use_redis and self.client is None:
            await self.connect()

        cleaned_query = query.strip()
        cache_ttl = ttl if ttl is not None else self.ttl
        # Serialized up front so a non-serializable response fails the same way
        # whichever backend ends up storing it.
        response_str = json.dumps(response)

        if self.use_redis and self.client:
            try:
                exact_key = self._EXACT_PREFIX + cleaned_query
                await self.client.set(exact_key, response_str, ex=cache_ttl)
                logger.debug(f"Cached query in Redis with TTL={cache_ttl}")
                return
            except Exception as e:
                logger.error(f"Redis cache write error: {e}. Writing to in-memory fallback.")
                self.use_redis = False

        # Write to memory fallback
        import time
        expiry = time.time() + cache_ttl
        self._in_memory_cache[cleaned_query] = (response, expiry)
        logger.debug(f"Cached query in-memory with TTL={cache_ttl}")

    async def clear(self) -> None:
        """Clears all cached entries (in memory, and every "cache:*" key in Redis)."""
        self._in_memory_cache.clear()
        if self.use_redis and self.client:
            try:
                # Delete in bounded batches while scanning rather than loading
                # the whole key list with a blocking KEYS call.
                batch: List[str] = []
                async for key in self.client.scan_iter(match="cache:*"):
                    batch.append(key)
                    if len(batch) >= self._DELETE_BATCH:
                        await self.client.delete(*batch)
                        batch = []
                if batch:
                    await self.client.delete(*batch)
            except Exception as e:
                logger.error(f"Failed to clear Redis cache: {e}")

    async def close(self) -> None:
        """Closes Redis connection if open."""
        client, self.client = self.client, None
        await self._close_client(client)
|
aetherroute/config.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import yaml
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict, Any, List, Optional
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
class ModelConfig(BaseModel):
    """Settings for one model entry (see ``ProviderConfig.models``). All fields are required."""
    name: str  # model identifier
    max_context: int  # context limit; presumably in tokens (TODO confirm against callers)
    cost_per_1k_input: float  # cost per 1k input units; currency is not shown here
    cost_per_1k_output: float  # cost per 1k output units; currency is not shown here
    latency_sla: float  # latency target; unit is not shown here (TODO confirm)
|
|
13
|
+
|
|
14
|
+
class ProviderConfig(BaseModel):
    """Settings for one provider: optional credentials/endpoint and its list of models."""
    api_key_env: Optional[str] = None  # presumably the name of the env var holding the API key (TODO confirm)
    base_url: Optional[str] = None  # optional endpoint URL; defaults to None
    models: List[ModelConfig]  # required list of model entries
|
|
18
|
+
|
|
19
|
+
class TaskWeights(BaseModel):
    """Three required float weights (accuracy, cost, latency) used per task in ``RoutingConfig.task_weights``."""
    accuracy: float
    cost: float
    latency: float
|
|
23
|
+
|
|
24
|
+
class RoutingConfig(BaseModel):
    """Routing settings: default model, cost ceilings, history DB path and per-task weights. All fields are required."""
    default_model: str  # name of the default model
    cost_ceiling_request: float  # cost ceiling per request; currency is not shown here
    cost_ceiling_session: float  # cost ceiling per session; currency is not shown here
    historical_db_path: str  # path of the historical database (format not shown here)
    task_weights: Dict[str, TaskWeights]  # weights keyed by a string, presumably the task type (TODO confirm)
|
|
30
|
+
|
|
31
|
+
class CacheConfig(BaseModel):
    """Cache settings. All three fields are required by this model."""
    redis_url: str  # Redis connection URL
    semantic_threshold: float  # similarity threshold for fuzzy cache matches
    ttl_seconds: int  # cache entry lifetime in seconds
|
|
35
|
+
|
|
36
|
+
class SecurityConfig(BaseModel):
    """Security settings: prompt-injection handling and a nested permissions mapping. Both fields are required."""
    prompt_injection_action: str  # action name; the accepted values are not shown here
    permissions: Dict[str, Dict[str, List[str]]]  # two-level string mapping to string lists; key meaning not shown here (TODO confirm)
|
|
39
|
+
|
|
40
|
+
class AetherRouteConfig(BaseModel):
    """Top-level configuration model; ``load_config`` builds it from the parsed YAML mapping."""
    providers: Dict[str, ProviderConfig]  # provider settings keyed by a string, presumably the provider name
    routing: RoutingConfig
    cache: CacheConfig
    security: SecurityConfig
|
|
45
|
+
|
|
46
|
+
def load_config(config_path: Optional[str] = None) -> AetherRouteConfig:
    """
    Load configuration from a YAML file. Looks in default locations if path not provided.

    Args:
        config_path: Path to the YAML file. When None, "config.yaml" is looked up
            in the current working directory and then in the parent directory
            of this package.

    Returns:
        The validated AetherRouteConfig.

    Raises:
        FileNotFoundError: If no configuration file exists at the given path or
            in any default location.
        TypeError: If the YAML document is not a mapping at the top level.
        pydantic.ValidationError: If the mapping fails validation; an empty file
            is reported this way, as missing sections.
    """
    if config_path is None:
        # Check current working directory, then package root parent
        paths_to_check = [
            Path.cwd() / "config.yaml",
            Path(__file__).parent.parent / "config.yaml",
        ]
        found = next((p for p in paths_to_check if p.is_file()), None)
        if found is None:
            checked = ", ".join(str(p) for p in paths_to_check)
            raise FileNotFoundError(
                f"Configuration file not found. Checked default locations: {checked}. "
                "Please provide a valid config_path."
            )
        config_path = str(found)
    elif not Path(config_path).is_file():
        # is_file() also rejects directories, which open() could not read anyway.
        raise FileNotFoundError(
            f"Configuration file not found: {config_path!r}. Please provide a valid config_path."
        )

    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    if data is None:
        # An empty file parses to None; validate it as an empty mapping so the
        # error names the missing sections.
        data = {}
    if not isinstance(data, dict):
        raise TypeError(
            f"Configuration file {config_path!r} must contain a mapping at the top level, "
            f"got {type(data).__name__}."
        )

    return AetherRouteConfig(**data)
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import math
|
|
3
|
+
import logging
|
|
4
|
+
from typing import List, Dict, Any, Optional
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from aetherroute.adapters.token_counter import count_tokens, count_messages_tokens
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger("aetherroute.context.curator")
|
|
9
|
+
|
|
10
|
+
class ContextCurator:
    """
    Ranks context documents by term-frequency cosine relevance to the query, prunes
    context to fit token budgets, and uses a cheap provider/model to summarize older
    conversation history.
    """

    def __init__(self, config: Any):
        """Stores the configuration object; nothing else is initialised here."""
        self.config = config

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """
        Simple lowercase word tokenization.

        A token is a run of letters or digits in any script, so accented and
        non-Latin words are kept whole; underscores and punctuation separate tokens.
        """
        # [^\W_] is "word character except underscore", i.e. Unicode alphanumerics.
        return re.findall(r"[^\W_]+", text.lower())

    @classmethod
    def calculate_tfidf_similarity(cls, query: str, document: str) -> float:
        """
        Calculates cosine similarity between the term-frequency vectors of query and document.

        Despite the name, no inverse-document-frequency weighting is applied: only
        the two texts are compared. Cosine similarity is unaffected by vector
        scale, so raw term counts are used in place of normalised frequencies.

        Returns:
            A value between 0.0 and 1.0; 0.0 when either text has no tokens or
            the texts share none.
        """
        q_counts = Counter(cls._tokenize(query))
        d_counts = Counter(cls._tokenize(document))

        if not q_counts or not d_counts:
            return 0.0

        dot_product = sum(
            count * d_counts[word]
            for word, count in q_counts.items()
            if word in d_counts
        )
        if dot_product == 0:
            return 0.0

        q_sq = sum(count * count for count in q_counts.values())
        d_sq = sum(count * count for count in d_counts.values())

        # Integer arithmetic up to the single sqrt keeps identical texts at
        # exactly 1.0; min() guards against rounding slightly above 1.
        return min(1.0, dot_product / math.sqrt(q_sq * d_sq))

    def rank_and_prune_documents(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        max_tokens: int,
        model_name: str = "gpt-4o-mini"
    ) -> List[Dict[str, Any]]:
        """
        Ranks documents based on query similarity and includes as many high-scoring
        documents as possible within the specified max_tokens limit.

        A document that does not fit is skipped, and lower-ranked documents are
        still considered for the remaining budget.

        Args:
            query: The user's search query.
            documents: List of dicts, e.g. [{"id": 1, "text": "content..."}, ...]
            max_tokens: Limit on how many tokens of context we can include.
            model_name: Model used to compute token count.

        Returns:
            Pruned and sorted list of relevant documents.
        """
        scored_docs = [
            (doc, self.calculate_tfidf_similarity(query, doc.get("text", "")))
            for doc in documents
        ]

        # Sort by score descending (stable, so ties keep their input order)
        scored_docs.sort(key=lambda pair: pair[1], reverse=True)

        selected_docs = []
        accumulated_tokens = 0

        for doc, score in scored_docs:
            doc_tokens = count_tokens(doc.get("text", ""), model_name)
            if accumulated_tokens + doc_tokens <= max_tokens:
                selected_docs.append(doc)
                accumulated_tokens += doc_tokens
            else:
                logger.info(f"Pruning document (score={score:.4f}) to stay under token budget of {max_tokens}.")

        return selected_docs

    async def summarize_history(
        self,
        messages: List[Dict[str, str]],
        cheap_provider: Any,
        cheap_model: str,
        keep_last_n: int = 3
    ) -> List[Dict[str, str]]:
        """
        Compresses older messages in a conversation history by summarizing them
        using a cheap provider model, keeping the last N messages as active context.

        Args:
            messages: Full chat history, system messages included.
            cheap_provider: Object with an async ``generate(messages=..., model=...,
                options=...)`` method returning a mapping with a "text" entry.
            cheap_model: Model name passed to the provider.
            keep_last_n: Number of most recent non-system messages kept verbatim;
                negative values are treated as 0.

        Returns:
            The system messages, one system message holding the summary, then the
            kept recent messages. The original ``messages`` list is returned
            unchanged when there is too little history to summarize, when the
            provider fails, or when it returns no summary text.
        """
        keep_last_n = max(0, keep_last_n)

        system_msgs = [m for m in messages if m.get("role") == "system"]
        chat_msgs = [m for m in messages if m.get("role") != "system"]

        # If there aren't enough messages to warrant summarization, return as is
        if len(chat_msgs) <= keep_last_n + 1:
            return messages

        # Split into old history and fresh messages. An explicit index is used
        # because slicing with -0 would put every message in the "fresh" part.
        split_at = len(chat_msgs) - keep_last_n
        old_history = chat_msgs[:split_at]
        fresh_messages = chat_msgs[split_at:]

        # Build a text compilation of the old history
        history_text = "\n".join(
            f"{msg.get('role', 'user').upper()}: {msg.get('content', '')}"
            for msg in old_history
        )

        # Prepare the summarization query
        summarize_prompt = [
            {
                "role": "system",
                "content": "You are a helpful assistant. Summarize the following chat conversation highlights in a concise, dense paragraph. Focus on user requests, key decisions made, and assistant answers."
            },
            {
                "role": "user",
                "content": f"Conversation to summarize:\n\n{history_text}"
            }
        ]

        logger.info(f"Summarizing {len(old_history)} messages using cheap model '{cheap_model}'...")

        try:
            # Generate the summary using the cheap provider
            summary_response = await cheap_provider.generate(
                messages=summarize_prompt,
                model=cheap_model,
                options={"max_tokens": 500, "temperature": 0.3}
            )
            summary_text = summary_response.get("text")
            if not summary_text:
                # Keep the real history rather than replace it with a placeholder.
                logger.error("Failed to summarize older context: provider returned no text. Returning original messages.")
                return messages

            # Form the new message history
            new_history = system_msgs.copy()
            new_history.append({
                "role": "system",
                "content": f"Summary of earlier conversation highlights: {summary_text}"
            })
            new_history.extend(fresh_messages)

            logger.info("Successfully summarized older chat history.")
            return new_history
        except Exception as e:
            # Best effort: summarization must never lose the conversation.
            logger.error(f"Failed to summarize older context: {e}. Returning original messages.")
            return messages
|