headroom-ai 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/integrations/langchain/providers.py
@@ -0,0 +1,200 @@
+"""Provider detection for LangChain models.
+
+This module provides automatic provider detection from LangChain chat models
+without requiring explicit provider imports. It uses duck-typing based on
+class paths to identify the appropriate Headroom provider.
+
+Example:
+    from langchain_anthropic import ChatAnthropic
+    from headroom.integrations.langchain import get_headroom_provider
+
+    model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+    provider = get_headroom_provider(model)  # Returns AnthropicProvider
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from headroom.providers.base import Provider
+
+logger = logging.getLogger(__name__)
+
+# Provider detection patterns
+# Maps provider name to list of class path patterns to match
+PROVIDER_PATTERNS: dict[str, list[str]] = {
+    "openai": [
+        "langchain_openai.ChatOpenAI",
+        "langchain_openai.chat_models.ChatOpenAI",
+        "langchain_community.chat_models.ChatOpenAI",
+        "langchain.chat_models.ChatOpenAI",
+        "ChatOpenAI",
+    ],
+    "anthropic": [
+        "langchain_anthropic.ChatAnthropic",
+        "langchain_anthropic.chat_models.ChatAnthropic",
+        "langchain_community.chat_models.ChatAnthropic",
+        "langchain.chat_models.ChatAnthropic",
+        "ChatAnthropic",
+    ],
+    "google": [
+        "langchain_google_genai.ChatGoogleGenerativeAI",
+        "langchain_google_genai.chat_models.ChatGoogleGenerativeAI",
+        "langchain_community.chat_models.ChatGoogleGenerativeAI",
+        "ChatGoogleGenerativeAI",
+        # Also match Vertex AI
+        "langchain_google_vertexai.ChatVertexAI",
+        "ChatVertexAI",
+    ],
+    "cohere": [
+        "langchain_cohere.ChatCohere",
+        "langchain_community.chat_models.ChatCohere",
+        "ChatCohere",
+    ],
+    "mistral": [
+        "langchain_mistralai.ChatMistralAI",
+        "langchain_community.chat_models.ChatMistralAI",
+        "ChatMistralAI",
+    ],
+}
+
+# Model name patterns for fallback detection
+MODEL_NAME_PATTERNS: dict[str, list[str]] = {
+    "anthropic": ["claude", "anthropic"],
+    "openai": ["gpt", "o1", "o3", "davinci", "turbo"],
+    "google": ["gemini", "palm", "bison"],
+    "cohere": ["command", "cohere"],
+    "mistral": ["mistral", "mixtral"],
+}
+
+
+def detect_provider(model: Any) -> str:
+    """Detect provider name from a LangChain model using duck-typing.
+
+    Detection strategy:
+    1. Check class module and name against known patterns
+    2. Check model_name attribute against known model patterns
+    3. Fall back to "openai" as safe default
+
+    Args:
+        model: Any LangChain chat model instance
+
+    Returns:
+        Provider name string: "openai", "anthropic", "google", "cohere", "mistral"
+
+    Example:
+        >>> from langchain_anthropic import ChatAnthropic
+        >>> model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+        >>> detect_provider(model)
+        'anthropic'
+    """
+    # Strategy 1: Check class path
+    class_module = getattr(model.__class__, "__module__", "")
+    class_name = model.__class__.__name__
+    class_path = f"{class_module}.{class_name}"
+
+    for provider_name, patterns in PROVIDER_PATTERNS.items():
+        for pattern in patterns:
+            if pattern in class_path or class_name == pattern.split(".")[-1]:
+                logger.debug(f"Detected provider '{provider_name}' from class path: {class_path}")
+                return provider_name
+
+    # Strategy 2: Check model_name attribute
+    model_name = _get_model_name(model)
+    if model_name:
+        model_name_lower = model_name.lower()
+        for provider_name, name_patterns in MODEL_NAME_PATTERNS.items():
+            for pattern in name_patterns:
+                if pattern in model_name_lower:
+                    logger.debug(
+                        f"Detected provider '{provider_name}' from model name: {model_name}"
+                    )
+                    return provider_name
+
+    # Strategy 3: Fall back to OpenAI (most common, safe default)
+    logger.debug(f"Could not detect provider for {class_path}, falling back to 'openai'")
+    return "openai"
+
+
+def _get_model_name(model: Any) -> str | None:
+    """Extract model name from a LangChain model.
+
+    Tries common attribute names used by different LangChain models.
+    """
+    # Try common attribute names
+    for attr in ["model_name", "model", "model_id", "_model_name"]:
+        value = getattr(model, attr, None)
+        if isinstance(value, str):
+            return value
+
+    return None
+
+
+def get_headroom_provider(model: Any) -> Provider:
+    """Get appropriate Headroom Provider instance for a LangChain model.
+
+    This function automatically detects the provider from the model type
+    and returns a configured Headroom provider for accurate token counting
+    and context limit detection.
+
+    Args:
+        model: Any LangChain chat model instance
+
+    Returns:
+        Configured Headroom Provider instance
+
+    Example:
+        >>> from langchain_anthropic import ChatAnthropic
+        >>> model = ChatAnthropic(model="claude-3-5-sonnet-20241022")
+        >>> provider = get_headroom_provider(model)
+        >>> provider.name
+        'anthropic'
+    """
+    # Import providers lazily to avoid circular imports
+    from headroom.providers import (
+        AnthropicProvider,
+        GoogleProvider,
+        OpenAIProvider,
+    )
+
+    provider_name = detect_provider(model)
+
+    if provider_name == "anthropic":
+        return AnthropicProvider()
+    elif provider_name == "google":
+        return GoogleProvider()
+    # Cohere and Mistral fall back to OpenAI-compatible for now
+    # TODO: Add dedicated providers when needed
+
+    # Default to OpenAI
+    return OpenAIProvider()
+
+
+def get_model_name_from_langchain(model: Any) -> str:
+    """Extract the model name string from a LangChain model.
+
+    Useful for getting the model identifier for token counting
+    and context limit lookup.
+
+    Args:
+        model: Any LangChain chat model instance
+
+    Returns:
+        Model name string (e.g., "gpt-4o", "claude-3-5-sonnet-20241022")
+    """
+    name = _get_model_name(model)
+    if name:
+        return name
+
+    # Try to infer from class name
+    class_name = model.__class__.__name__
+    if "GPT" in class_name or "OpenAI" in class_name:
+        return "gpt-4o"  # Safe default for OpenAI
+    elif "Anthropic" in class_name or "Claude" in class_name:
+        return "claude-3-5-sonnet-20241022"  # Safe default for Anthropic
+    elif "Google" in class_name or "Gemini" in class_name:
+        return "gemini-1.5-pro"  # Safe default for Google
+
+    return "gpt-4o"  # Ultimate fallback
headroom/integrations/langchain/retriever.py
@@ -0,0 +1,371 @@
+"""Retriever integration for LangChain with intelligent document compression.
+
+This module provides HeadroomDocumentCompressor, a LangChain BaseDocumentCompressor
+that reduces retrieved documents based on relevance scoring while preserving
+the most important information.
+
+Example:
+    from langchain.retrievers import ContextualCompressionRetriever
+    from langchain_community.vectorstores import Chroma
+    from headroom.integrations import HeadroomDocumentCompressor
+
+    # Create vector store retriever
+    vectorstore = Chroma.from_documents(documents, embeddings)
+    base_retriever = vectorstore.as_retriever(search_kwargs={"k": 50})
+
+    # Wrap with Headroom compression
+    compressor = HeadroomDocumentCompressor(max_documents=10)
+    retriever = ContextualCompressionRetriever(
+        base_compressor=compressor,
+        base_retriever=base_retriever,
+    )
+
+    # Retrieve - automatically keeps most relevant documents
+    docs = retriever.invoke("What is the capital of France?")
+"""
+
+from __future__ import annotations
+
+import logging
+import re
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Any
+
+# LangChain imports - these are optional dependencies
+try:
+    from langchain_core.callbacks import Callbacks
+    from langchain_core.documents import Document
+
+    # BaseDocumentCompressor location varies by langchain version
+    try:
+        from langchain.retrievers.document_compressors import BaseDocumentCompressor
+    except ImportError:
+        try:
+            from langchain_core.documents.compressors import BaseDocumentCompressor
+        except ImportError:
+            # Fallback: create a minimal base class
+            class BaseDocumentCompressor:  # type: ignore[no-redef]
+                """Minimal base class for document compression."""
+
+                def compress_documents(
+                    self, documents: Sequence[Any], query: str, callbacks: Any = None
+                ) -> Sequence[Any]:
+                    raise NotImplementedError
+
+    LANGCHAIN_AVAILABLE = True
+except ImportError:
+    LANGCHAIN_AVAILABLE = False
+    BaseDocumentCompressor = object  # type: ignore[misc,assignment]
+    Document = object  # type: ignore[misc,assignment]
+    Callbacks = None  # type: ignore[misc,assignment]
+
+logger = logging.getLogger(__name__)
+
+
+def _check_langchain_available() -> None:
+    """Raise ImportError if LangChain is not installed."""
+    if not LANGCHAIN_AVAILABLE:
+        raise ImportError(
+            "LangChain is required for this integration. "
+            "Install with: pip install headroom[langchain] "
+            "or: pip install langchain-core"
+        )
+
+
+@dataclass
+class CompressionMetrics:
+    """Metrics from document compression."""
+
+    documents_before: int
+    documents_after: int
+    documents_removed: int
+    relevance_scores: list[float]
+
+
+class HeadroomDocumentCompressor(BaseDocumentCompressor):
+    """Compresses retrieved documents based on relevance to query.
+
+    Uses BM25-style relevance scoring to keep only the most relevant
+    documents from a larger retrieval set. This allows you to retrieve
+    many documents initially (for recall) and then compress down to
+    the most relevant ones (for precision).
+
+    Works with LangChain's ContextualCompressionRetriever pattern.
+
+    Example:
+        from langchain.retrievers import ContextualCompressionRetriever
+        from headroom.integrations import HeadroomDocumentCompressor
+
+        compressor = HeadroomDocumentCompressor(
+            max_documents=10,
+            min_relevance=0.3,
+        )
+
+        retriever = ContextualCompressionRetriever(
+            base_compressor=compressor,
+            base_retriever=base_retriever,  # Any retriever
+        )
+
+        # Retrieves top 10 most relevant docs
+        docs = retriever.invoke("What is Python?")
+
+    Attributes:
+        max_documents: Maximum documents to return
+        min_relevance: Minimum relevance score (0-1) to include
+        prefer_diverse: Whether to prefer diverse results
+    """
+
+    max_documents: int = 10
+    min_relevance: float = 0.0
+    prefer_diverse: bool = False
+
+    def __init__(
+        self,
+        max_documents: int = 10,
+        min_relevance: float = 0.0,
+        prefer_diverse: bool = False,
+        **kwargs: Any,
+    ):
+        """Initialize HeadroomDocumentCompressor.
+
+        Args:
+            max_documents: Maximum number of documents to return. Default 10.
+            min_relevance: Minimum relevance score (0-1) for a document to
+                be included. Default 0.0 (no minimum).
+            prefer_diverse: If True, use MMR-style selection to prefer
+                diverse results over pure relevance. Default False.
+            **kwargs: Additional arguments for BaseDocumentCompressor.
+        """
+        _check_langchain_available()
+
+        super().__init__(**kwargs)
+        self.max_documents = max_documents
+        self.min_relevance = min_relevance
+        self.prefer_diverse = prefer_diverse
+        self._last_metrics: CompressionMetrics | None = None
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Callbacks = None,
+    ) -> Sequence[Document]:
+        """Compress documents based on relevance to query.
+
+        Args:
+            documents: Documents to compress.
+            query: Query to score relevance against.
+            callbacks: LangChain callbacks (unused).
+
+        Returns:
+            Compressed list of most relevant documents.
+        """
+        if not documents:
+            self._last_metrics = CompressionMetrics(
+                documents_before=0,
+                documents_after=0,
+                documents_removed=0,
+                relevance_scores=[],
+            )
+            return []
+
+        if len(documents) <= self.max_documents:
+            # No compression needed
+            scores = [self._score_document(doc, query) for doc in documents]
+            self._last_metrics = CompressionMetrics(
+                documents_before=len(documents),
+                documents_after=len(documents),
+                documents_removed=0,
+                relevance_scores=scores,
+            )
+            return list(documents)
+
+        # Score all documents
+        scored = [(doc, self._score_document(doc, query)) for doc in documents]
+
+        if self.prefer_diverse:
+            # Use MMR-style selection for diversity
+            selected = self._select_diverse(scored, query)
+        else:
+            # Sort by relevance score
+            scored.sort(key=lambda x: x[1], reverse=True)
+            selected = scored[: self.max_documents]
+
+        # Filter by minimum relevance
+        if self.min_relevance > 0:
+            selected = [(doc, score) for doc, score in selected if score >= self.min_relevance]
+
+        # Track metrics
+        final_docs = [doc for doc, _ in selected]
+        final_scores = [score for _, score in selected]
+
+        self._last_metrics = CompressionMetrics(
+            documents_before=len(documents),
+            documents_after=len(final_docs),
+            documents_removed=len(documents) - len(final_docs),
+            relevance_scores=final_scores,
+        )
+
+        logger.info(
+            f"HeadroomDocumentCompressor: {len(documents)} -> {len(final_docs)} documents "
+            f"(avg relevance: {sum(final_scores) / len(final_scores) if final_scores else 0:.2f})"
+        )
+
+        return final_docs
+
+    def _score_document(self, doc: Document, query: str) -> float:
+        """Score a document's relevance to the query using BM25-style scoring.
+
+        Args:
+            doc: Document to score.
+            query: Query to compare against.
+
+        Returns:
+            Relevance score between 0 and 1.
+        """
+        content = doc.page_content.lower()
+        query_lower = query.lower()
+
+        # Tokenize
+        query_terms = self._tokenize(query_lower)
+        doc_terms = self._tokenize(content)
+
+        if not query_terms or not doc_terms:
+            return 0.0
+
+        # BM25-style scoring
+        k1 = 1.5
+        b = 0.75
+        avg_dl = 100  # Assume average document length
+
+        doc_len = len(doc_terms)
+        term_freqs: dict[str, int] = {}
+        for term in doc_terms:
+            term_freqs[term] = term_freqs.get(term, 0) + 1
+
+        score = 0.0
+        for term in query_terms:
+            if term in term_freqs:
+                tf = term_freqs[term]
+                # Simplified BM25 (without IDF since we don't have corpus stats)
+                numerator = tf * (k1 + 1)
+                denominator = tf + k1 * (1 - b + b * (doc_len / avg_dl))
+                score += numerator / denominator
+
+        # Normalize to 0-1 range
+        max_possible = len(query_terms) * (k1 + 1)
+        normalized = score / max_possible if max_possible > 0 else 0.0
+
+        # Boost for exact phrase matches
+        if query_lower in content:
+            normalized = min(1.0, normalized + 0.3)
+
+        return min(1.0, normalized)
+
+    def _tokenize(self, text: str) -> list[str]:
+        """Tokenize text into terms.
+
+        Args:
+            text: Text to tokenize.
+
+        Returns:
+            List of tokens.
+        """
+        # Simple tokenization: split on non-alphanumeric, filter short terms
+        tokens = re.findall(r"\b\w+\b", text)
+        return [t for t in tokens if len(t) > 1]
+
+    def _select_diverse(
+        self, scored_docs: list[tuple[Document, float]], query: str
+    ) -> list[tuple[Document, float]]:
+        """Select diverse documents using MMR-style approach.
+
+        Balances relevance with diversity to avoid redundant results.
+
+        Args:
+            scored_docs: List of (document, relevance_score) tuples.
+            query: Original query.
+
+        Returns:
+            Selected documents with diversity considered.
+        """
+        if not scored_docs:
+            return []
+
+        # Sort by initial relevance
+        scored_docs = sorted(scored_docs, key=lambda x: x[1], reverse=True)
+
+        # Start with most relevant
+        selected = [scored_docs[0]]
+        remaining = scored_docs[1:]
+
+        lambda_param = 0.5  # Balance between relevance and diversity
+
+        while len(selected) < self.max_documents and remaining:
+            best_score = -1.0
+            best_idx = 0
+
+            for i, (doc, rel_score) in enumerate(remaining):
+                # Calculate max similarity to already selected docs
+                max_sim = max(self._document_similarity(doc, sel_doc) for sel_doc, _ in selected)
+
+                # MMR score: lambda * relevance - (1-lambda) * max_similarity
+                mmr_score = lambda_param * rel_score - (1 - lambda_param) * max_sim
+
+                if mmr_score > best_score:
+                    best_score = mmr_score
+                    best_idx = i
+
+            selected.append(remaining[best_idx])
+            remaining.pop(best_idx)
+
+        return selected
+
+    def _document_similarity(self, doc1: Document, doc2: Document) -> float:
+        """Calculate similarity between two documents.
+
+        Uses Jaccard similarity on terms for simplicity.
+
+        Args:
+            doc1: First document.
+            doc2: Second document.
+
+        Returns:
+            Similarity score between 0 and 1.
+        """
+        terms1 = set(self._tokenize(doc1.page_content.lower()))
+        terms2 = set(self._tokenize(doc2.page_content.lower()))
+
+        if not terms1 or not terms2:
+            return 0.0
+
+        intersection = len(terms1 & terms2)
+        union = len(terms1 | terms2)
+
+        return intersection / union if union > 0 else 0.0
+
+    @property
+    def last_metrics(self) -> CompressionMetrics | None:
+        """Get metrics from the last compression operation."""
+        return self._last_metrics
+
+    def get_compression_stats(self) -> dict[str, Any]:
+        """Get statistics from the last compression.
+
+        Returns:
+            Dictionary with compression metrics, or empty if no compression yet.
+        """
+        if self._last_metrics is None:
+            return {}
+
+        return {
+            "documents_before": self._last_metrics.documents_before,
+            "documents_after": self._last_metrics.documents_after,
+            "documents_removed": self._last_metrics.documents_removed,
+            "average_relevance": (
+                sum(self._last_metrics.relevance_scores) / len(self._last_metrics.relevance_scores)
+                if self._last_metrics.relevance_scores
+                else 0.0
+            ),
+        }
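
Beyond the ContextualCompressionRetriever pattern shown in the module docstring, the compressor can also be driven directly, since compress_documents takes plain documents and a query. A minimal sketch, assuming langchain-core is installed; the sample documents and expected output are illustrative:

    from langchain_core.documents import Document
    from headroom.integrations import HeadroomDocumentCompressor

    docs = [
        Document(page_content="Python is a programming language created by Guido van Rossum."),
        Document(page_content="The capital of France is Paris."),
        Document(page_content="Python emphasizes readability and ships a large standard library."),
    ]

    # Three docs exceed max_documents=2, so BM25-style scoring kicks in.
    compressor = HeadroomDocumentCompressor(max_documents=2, min_relevance=0.1)
    kept = compressor.compress_documents(docs, query="What is Python?")

    print([d.page_content[:30] for d in kept])
    print(compressor.get_compression_stats())
    # e.g. {'documents_before': 3, 'documents_after': 2, 'documents_removed': 1, ...}

Setting prefer_diverse=True would route selection through the MMR-style path instead, trading some raw relevance for lower Jaccard overlap between the kept documents.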