headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/integrations/langchain/providers.py (new file)
@@ -0,0 +1,200 @@
"""Provider detection for LangChain models.

This module provides automatic provider detection from LangChain chat models
without requiring explicit provider imports. It uses duck-typing based on
class paths to identify the appropriate Headroom provider.

Example:
    from langchain_anthropic import ChatAnthropic
    from headroom.integrations.langchain import get_headroom_provider

    model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
    provider = get_headroom_provider(model)  # Returns AnthropicProvider
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from headroom.providers.base import Provider

logger = logging.getLogger(__name__)

# Provider detection patterns
# Maps provider name to list of class path patterns to match
PROVIDER_PATTERNS: dict[str, list[str]] = {
    "openai": [
        "langchain_openai.ChatOpenAI",
        "langchain_openai.chat_models.ChatOpenAI",
        "langchain_community.chat_models.ChatOpenAI",
        "langchain.chat_models.ChatOpenAI",
        "ChatOpenAI",
    ],
    "anthropic": [
        "langchain_anthropic.ChatAnthropic",
        "langchain_anthropic.chat_models.ChatAnthropic",
        "langchain_community.chat_models.ChatAnthropic",
        "langchain.chat_models.ChatAnthropic",
        "ChatAnthropic",
    ],
    "google": [
        "langchain_google_genai.ChatGoogleGenerativeAI",
        "langchain_google_genai.chat_models.ChatGoogleGenerativeAI",
        "langchain_community.chat_models.ChatGoogleGenerativeAI",
        "ChatGoogleGenerativeAI",
        # Also match Vertex AI
        "langchain_google_vertexai.ChatVertexAI",
        "ChatVertexAI",
    ],
    "cohere": [
        "langchain_cohere.ChatCohere",
        "langchain_community.chat_models.ChatCohere",
        "ChatCohere",
    ],
    "mistral": [
        "langchain_mistralai.ChatMistralAI",
        "langchain_community.chat_models.ChatMistralAI",
        "ChatMistralAI",
    ],
}

# Model name patterns for fallback detection
MODEL_NAME_PATTERNS: dict[str, list[str]] = {
    "anthropic": ["claude", "anthropic"],
    "openai": ["gpt", "o1", "o3", "davinci", "turbo"],
    "google": ["gemini", "palm", "bison"],
    "cohere": ["command", "cohere"],
    "mistral": ["mistral", "mixtral"],
}


def detect_provider(model: Any) -> str:
    """Detect provider name from a LangChain model using duck-typing.

    Detection strategy:
    1. Check class module and name against known patterns
    2. Check model_name attribute against known model patterns
    3. Fall back to "openai" as safe default

    Args:
        model: Any LangChain chat model instance

    Returns:
        Provider name string: "openai", "anthropic", "google", "cohere", "mistral"

    Example:
        >>> from langchain_anthropic import ChatAnthropic
        >>> model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
        >>> detect_provider(model)
        'anthropic'
    """
    # Strategy 1: Check class path
    class_module = getattr(model.__class__, "__module__", "")
    class_name = model.__class__.__name__
    class_path = f"{class_module}.{class_name}"

    for provider_name, patterns in PROVIDER_PATTERNS.items():
        for pattern in patterns:
            if pattern in class_path or class_name == pattern.split(".")[-1]:
                logger.debug(f"Detected provider '{provider_name}' from class path: {class_path}")
                return provider_name

    # Strategy 2: Check model_name attribute
    model_name = _get_model_name(model)
    if model_name:
        model_name_lower = model_name.lower()
        for provider_name, name_patterns in MODEL_NAME_PATTERNS.items():
            for pattern in name_patterns:
                if pattern in model_name_lower:
                    logger.debug(
                        f"Detected provider '{provider_name}' from model name: {model_name}"
                    )
                    return provider_name

    # Strategy 3: Fall back to OpenAI (most common, safe default)
    logger.debug(f"Could not detect provider for {class_path}, falling back to 'openai'")
    return "openai"


def _get_model_name(model: Any) -> str | None:
    """Extract model name from a LangChain model.

    Tries common attribute names used by different LangChain models.
    """
    # Try common attribute names
    for attr in ["model_name", "model", "model_id", "_model_name"]:
        value = getattr(model, attr, None)
        if isinstance(value, str):
            return value

    return None


def get_headroom_provider(model: Any) -> Provider:
    """Get appropriate Headroom Provider instance for a LangChain model.

    This function automatically detects the provider from the model type
    and returns a configured Headroom provider for accurate token counting
    and context limit detection.

    Args:
        model: Any LangChain chat model instance

    Returns:
        Configured Headroom Provider instance

    Example:
        >>> from langchain_anthropic import ChatAnthropic
        >>> model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
        >>> provider = get_headroom_provider(model)
        >>> provider.name
        'anthropic'
    """
    # Import providers lazily to avoid circular imports
    from headroom.providers import (
        AnthropicProvider,
        GoogleProvider,
        OpenAIProvider,
    )

    provider_name = detect_provider(model)

    if provider_name == "anthropic":
        return AnthropicProvider()
    elif provider_name == "google":
        return GoogleProvider()
    # Cohere and Mistral fall back to OpenAI-compatible for now
    # TODO: Add dedicated providers when needed

    # Default to OpenAI
    return OpenAIProvider()


def get_model_name_from_langchain(model: Any) -> str:
    """Extract the model name string from a LangChain model.

    Useful for getting the model identifier for token counting
    and context limit lookup.

    Args:
        model: Any LangChain chat model instance

    Returns:
        Model name string (e.g., "gpt-4o", "claude-3-5-sonnet-20241022")
    """
    name = _get_model_name(model)
    if name:
        return name

    # Try to infer from class name
    class_name = model.__class__.__name__
    if "GPT" in class_name or "OpenAI" in class_name:
        return "gpt-4o"  # Safe default for OpenAI
    elif "Anthropic" in class_name or "Claude" in class_name:
        return "claude-3-5-sonnet-20241022"  # Safe default for Anthropic
    elif "Google" in class_name or "Gemini" in class_name:
        return "gemini-1.5-pro"  # Safe default for Google

    return "gpt-4o"  # Ultimate fallback
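Since detection is plain duck-typing on class and model names, it can be exercised without any LangChain packages installed. The following is a minimal sketch; the stub classes are hypothetical stand-ins for real LangChain models and only illustrate how detect_provider() and get_model_name_from_langchain() fall through their three strategies.

# Minimal sketch: hypothetical stub classes (not real LangChain models) used to
# exercise the detection strategies implemented in providers.py above.
from headroom.integrations.langchain.providers import (
    detect_provider,
    get_model_name_from_langchain,
)


class ChatAnthropic:
    """Stub whose class name matches the "ChatAnthropic" pattern (strategy 1)."""

    model_name = "claude-3-5-sonnet-20241022"


class CustomGeminiWrapper:
    """Stub with an unknown class name but a recognizable model name (strategy 2)."""

    model = "gemini-1.5-pro"


class UnknownModel:
    """Stub that matches nothing and falls back to "openai" (strategy 3)."""


print(detect_provider(ChatAnthropic()))        # 'anthropic'  (class-name match)
print(detect_provider(CustomGeminiWrapper()))  # 'google'     (model-name match)
print(detect_provider(UnknownModel()))         # 'openai'     (fallback)
print(get_model_name_from_langchain(UnknownModel()))  # 'gpt-4o' (ultimate fallback)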
headroom/integrations/langchain/retriever.py (new file)
@@ -0,0 +1,371 @@
"""Retriever integration for LangChain with intelligent document compression.

This module provides HeadroomDocumentCompressor, a LangChain BaseDocumentCompressor
that reduces retrieved documents based on relevance scoring while preserving
the most important information.

Example:
    from langchain.retrievers import ContextualCompressionRetriever
    from langchain_community.vectorstores import Chroma
    from headroom.integrations import HeadroomDocumentCompressor

    # Create vector store retriever
    vectorstore = Chroma.from_documents(documents, embeddings)
    base_retriever = vectorstore.as_retriever(search_kwargs={"k": 50})

    # Wrap with Headroom compression
    compressor = HeadroomDocumentCompressor(max_documents=10)
    retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=base_retriever,
    )

    # Retrieve - automatically keeps most relevant documents
    docs = retriever.invoke("What is the capital of France?")
"""

from __future__ import annotations

import logging
import re
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

# LangChain imports - these are optional dependencies
try:
    from langchain_core.callbacks import Callbacks
    from langchain_core.documents import Document

    # BaseDocumentCompressor location varies by langchain version
    try:
        from langchain.retrievers.document_compressors import BaseDocumentCompressor
    except ImportError:
        try:
            from langchain_core.documents.compressors import BaseDocumentCompressor
        except ImportError:
            # Fallback: create a minimal base class
            class BaseDocumentCompressor:  # type: ignore[no-redef]
                """Minimal base class for document compression."""

                def compress_documents(
                    self, documents: Sequence[Any], query: str, callbacks: Any = None
                ) -> Sequence[Any]:
                    raise NotImplementedError

    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False
    BaseDocumentCompressor = object  # type: ignore[misc,assignment]
    Document = object  # type: ignore[misc,assignment]
    Callbacks = None  # type: ignore[misc,assignment]

logger = logging.getLogger(__name__)


def _check_langchain_available() -> None:
    """Raise ImportError if LangChain is not installed."""
    if not LANGCHAIN_AVAILABLE:
        raise ImportError(
            "LangChain is required for this integration. "
            "Install with: pip install headroom[langchain] "
            "or: pip install langchain-core"
        )


@dataclass
class CompressionMetrics:
    """Metrics from document compression."""

    documents_before: int
    documents_after: int
    documents_removed: int
    relevance_scores: list[float]


class HeadroomDocumentCompressor(BaseDocumentCompressor):
    """Compresses retrieved documents based on relevance to query.

    Uses BM25-style relevance scoring to keep only the most relevant
    documents from a larger retrieval set. This allows you to retrieve
    many documents initially (for recall) and then compress down to
    the most relevant ones (for precision).

    Works with LangChain's ContextualCompressionRetriever pattern.

    Example:
        from langchain.retrievers import ContextualCompressionRetriever
        from headroom.integrations import HeadroomDocumentCompressor

        compressor = HeadroomDocumentCompressor(
            max_documents=10,
            min_relevance=0.3,
        )

        retriever = ContextualCompressionRetriever(
            base_compressor=compressor,
            base_retriever=base_retriever,  # Any retriever
        )

        # Retrieves top 10 most relevant docs
        docs = retriever.invoke("What is Python?")

    Attributes:
        max_documents: Maximum documents to return
        min_relevance: Minimum relevance score (0-1) to include
        prefer_diverse: Whether to prefer diverse results
    """

    max_documents: int = 10
    min_relevance: float = 0.0
    prefer_diverse: bool = False

    def __init__(
        self,
        max_documents: int = 10,
        min_relevance: float = 0.0,
        prefer_diverse: bool = False,
        **kwargs: Any,
    ):
        """Initialize HeadroomDocumentCompressor.

        Args:
            max_documents: Maximum number of documents to return. Default 10.
            min_relevance: Minimum relevance score (0-1) for a document to
                be included. Default 0.0 (no minimum).
            prefer_diverse: If True, use MMR-style selection to prefer
                diverse results over pure relevance. Default False.
            **kwargs: Additional arguments for BaseDocumentCompressor.
        """
        _check_langchain_available()

        super().__init__(**kwargs)
        self.max_documents = max_documents
        self.min_relevance = min_relevance
        self.prefer_diverse = prefer_diverse
        self._last_metrics: CompressionMetrics | None = None

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Callbacks = None,
    ) -> Sequence[Document]:
        """Compress documents based on relevance to query.

        Args:
            documents: Documents to compress.
            query: Query to score relevance against.
            callbacks: LangChain callbacks (unused).

        Returns:
            Compressed list of most relevant documents.
        """
        if not documents:
            self._last_metrics = CompressionMetrics(
                documents_before=0,
                documents_after=0,
                documents_removed=0,
                relevance_scores=[],
            )
            return []

        if len(documents) <= self.max_documents:
            # No compression needed
            scores = [self._score_document(doc, query) for doc in documents]
            self._last_metrics = CompressionMetrics(
                documents_before=len(documents),
                documents_after=len(documents),
                documents_removed=0,
                relevance_scores=scores,
            )
            return list(documents)

        # Score all documents
        scored = [(doc, self._score_document(doc, query)) for doc in documents]

        if self.prefer_diverse:
            # Use MMR-style selection for diversity
            selected = self._select_diverse(scored, query)
        else:
            # Sort by relevance score
            scored.sort(key=lambda x: x[1], reverse=True)
            selected = scored[: self.max_documents]

        # Filter by minimum relevance
        if self.min_relevance > 0:
            selected = [(doc, score) for doc, score in selected if score >= self.min_relevance]

        # Track metrics
        final_docs = [doc for doc, _ in selected]
        final_scores = [score for _, score in selected]

        self._last_metrics = CompressionMetrics(
            documents_before=len(documents),
            documents_after=len(final_docs),
            documents_removed=len(documents) - len(final_docs),
            relevance_scores=final_scores,
        )

        logger.info(
            f"HeadroomDocumentCompressor: {len(documents)} -> {len(final_docs)} documents "
            f"(avg relevance: {sum(final_scores) / len(final_scores) if final_scores else 0:.2f})"
        )

        return final_docs

    def _score_document(self, doc: Document, query: str) -> float:
        """Score a document's relevance to the query using BM25-style scoring.

        Args:
            doc: Document to score.
            query: Query to compare against.

        Returns:
            Relevance score between 0 and 1.
        """
        content = doc.page_content.lower()
        query_lower = query.lower()

        # Tokenize
        query_terms = self._tokenize(query_lower)
        doc_terms = self._tokenize(content)

        if not query_terms or not doc_terms:
            return 0.0

        # BM25-style scoring
        k1 = 1.5
        b = 0.75
        avg_dl = 100  # Assume average document length

        doc_len = len(doc_terms)
        term_freqs: dict[str, int] = {}
        for term in doc_terms:
            term_freqs[term] = term_freqs.get(term, 0) + 1

        score = 0.0
        for term in query_terms:
            if term in term_freqs:
                tf = term_freqs[term]
                # Simplified BM25 (without IDF since we don't have corpus stats)
                numerator = tf * (k1 + 1)
                denominator = tf + k1 * (1 - b + b * (doc_len / avg_dl))
                score += numerator / denominator

        # Normalize to 0-1 range
        max_possible = len(query_terms) * (k1 + 1)
        normalized = score / max_possible if max_possible > 0 else 0.0

        # Boost for exact phrase matches
        if query_lower in content:
            normalized = min(1.0, normalized + 0.3)

        return min(1.0, normalized)

    def _tokenize(self, text: str) -> list[str]:
        """Tokenize text into terms.

        Args:
            text: Text to tokenize.

        Returns:
            List of tokens.
        """
        # Simple tokenization: split on non-alphanumeric, filter short terms
        tokens = re.findall(r"\b\w+\b", text)
        return [t for t in tokens if len(t) > 1]

    def _select_diverse(
        self, scored_docs: list[tuple[Document, float]], query: str
    ) -> list[tuple[Document, float]]:
        """Select diverse documents using MMR-style approach.

        Balances relevance with diversity to avoid redundant results.

        Args:
            scored_docs: List of (document, relevance_score) tuples.
            query: Original query.

        Returns:
            Selected documents with diversity considered.
        """
        if not scored_docs:
            return []

        # Sort by initial relevance
        scored_docs = sorted(scored_docs, key=lambda x: x[1], reverse=True)

        # Start with most relevant
        selected = [scored_docs[0]]
        remaining = scored_docs[1:]

        lambda_param = 0.5  # Balance between relevance and diversity

        while len(selected) < self.max_documents and remaining:
            best_score = -1.0
            best_idx = 0

            for i, (doc, rel_score) in enumerate(remaining):
                # Calculate max similarity to already selected docs
                max_sim = max(self._document_similarity(doc, sel_doc) for sel_doc, _ in selected)

                # MMR score: lambda * relevance - (1-lambda) * max_similarity
                mmr_score = lambda_param * rel_score - (1 - lambda_param) * max_sim

                if mmr_score > best_score:
                    best_score = mmr_score
                    best_idx = i

            selected.append(remaining[best_idx])
            remaining.pop(best_idx)

        return selected

    def _document_similarity(self, doc1: Document, doc2: Document) -> float:
        """Calculate similarity between two documents.

        Uses Jaccard similarity on terms for simplicity.

        Args:
            doc1: First document.
            doc2: Second document.

        Returns:
            Similarity score between 0 and 1.
        """
        terms1 = set(self._tokenize(doc1.page_content.lower()))
        terms2 = set(self._tokenize(doc2.page_content.lower()))

        if not terms1 or not terms2:
            return 0.0

        intersection = len(terms1 & terms2)
        union = len(terms1 | terms2)

        return intersection / union if union > 0 else 0.0

    @property
    def last_metrics(self) -> CompressionMetrics | None:
        """Get metrics from the last compression operation."""
        return self._last_metrics

    def get_compression_stats(self) -> dict[str, Any]:
        """Get statistics from the last compression.

        Returns:
            Dictionary with compression metrics, or empty if no compression yet.
        """
        if self._last_metrics is None:
            return {}

        return {
            "documents_before": self._last_metrics.documents_before,
            "documents_after": self._last_metrics.documents_after,
            "documents_removed": self._last_metrics.documents_removed,
            "average_relevance": (
                sum(self._last_metrics.relevance_scores) / len(self._last_metrics.relevance_scores)
                if self._last_metrics.relevance_scores
                else 0.0
            ),
        }