autochunks-0.0.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autochunk/__init__.py +9 -0
- autochunk/__main__.py +5 -0
- autochunk/adapters/__init__.py +3 -0
- autochunk/adapters/haystack.py +68 -0
- autochunk/adapters/langchain.py +81 -0
- autochunk/adapters/llamaindex.py +94 -0
- autochunk/autochunker.py +606 -0
- autochunk/chunkers/__init__.py +100 -0
- autochunk/chunkers/agentic.py +184 -0
- autochunk/chunkers/base.py +16 -0
- autochunk/chunkers/contextual_retrieval.py +151 -0
- autochunk/chunkers/fixed_length.py +110 -0
- autochunk/chunkers/html_section.py +225 -0
- autochunk/chunkers/hybrid_semantic_stat.py +199 -0
- autochunk/chunkers/layout_aware.py +192 -0
- autochunk/chunkers/parent_child.py +172 -0
- autochunk/chunkers/proposition.py +175 -0
- autochunk/chunkers/python_ast.py +248 -0
- autochunk/chunkers/recursive_character.py +215 -0
- autochunk/chunkers/semantic_local.py +140 -0
- autochunk/chunkers/sentence_aware.py +102 -0
- autochunk/cli.py +135 -0
- autochunk/config.py +76 -0
- autochunk/embedding/__init__.py +22 -0
- autochunk/embedding/adapter.py +14 -0
- autochunk/embedding/base.py +33 -0
- autochunk/embedding/hashing.py +42 -0
- autochunk/embedding/local.py +154 -0
- autochunk/embedding/ollama.py +66 -0
- autochunk/embedding/openai.py +62 -0
- autochunk/embedding/tokenizer.py +9 -0
- autochunk/enrichment/__init__.py +0 -0
- autochunk/enrichment/contextual.py +29 -0
- autochunk/eval/__init__.py +0 -0
- autochunk/eval/harness.py +177 -0
- autochunk/eval/metrics.py +27 -0
- autochunk/eval/ragas_eval.py +234 -0
- autochunk/eval/synthetic.py +104 -0
- autochunk/quality/__init__.py +31 -0
- autochunk/quality/deduplicator.py +326 -0
- autochunk/quality/overlap_optimizer.py +402 -0
- autochunk/quality/post_processor.py +245 -0
- autochunk/quality/scorer.py +459 -0
- autochunk/retrieval/__init__.py +0 -0
- autochunk/retrieval/in_memory.py +47 -0
- autochunk/retrieval/parent_child.py +4 -0
- autochunk/storage/__init__.py +0 -0
- autochunk/storage/cache.py +34 -0
- autochunk/storage/plan.py +40 -0
- autochunk/utils/__init__.py +0 -0
- autochunk/utils/hashing.py +8 -0
- autochunk/utils/io.py +176 -0
- autochunk/utils/logger.py +64 -0
- autochunk/utils/telemetry.py +44 -0
- autochunk/utils/text.py +199 -0
- autochunks-0.0.8.dist-info/METADATA +133 -0
- autochunks-0.0.8.dist-info/RECORD +61 -0
- autochunks-0.0.8.dist-info/WHEEL +5 -0
- autochunks-0.0.8.dist-info/entry_points.txt +2 -0
- autochunks-0.0.8.dist-info/licenses/LICENSE +15 -0
- autochunks-0.0.8.dist-info/top_level.txt +1 -0
autochunk/embedding/__init__.py
@@ -0,0 +1,22 @@

from .base import BaseEncoder
from .local import LocalEncoder
from .hashing import HashingEmbedding
from .openai import OpenAIEncoder
from .ollama import OllamaEncoder

def get_encoder(provider: str, model_name: str, **kwargs) -> BaseEncoder:
    """
    Factory for chunking-aware embeddings.
    Easily extendable to TEI, OpenAI, etc.
    """
    if provider == "local":
        return LocalEncoder(model_name_or_path=model_name, **kwargs)
    elif provider == "hashing":
        return HashingEmbedding(dim=kwargs.get("dim", 256))
    elif provider == "openai":
        return OpenAIEncoder(model_name=model_name, **kwargs)
    elif provider == "ollama":
        return OllamaEncoder(model_name=model_name, **kwargs)
    else:
        raise ValueError(f"Unknown embedding provider: {provider}")

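For orientation, a minimal usage sketch of the factory above (illustrative only, not part of the wheel; it assumes the package is installed and importable as autochunk). The hashing provider is used because it needs no model download or network access:

# Sketch: build an encoder via the factory and embed a couple of strings.
from autochunk.embedding import get_encoder

enc = get_encoder("hashing", model_name="unused", dim=64)  # model_name is ignored by the hashing provider
vectors = enc.embed_batch(["hello world", "chunking strategies"])
print(enc.dimension, len(vectors), len(vectors[0]))  # 64 2 64
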
autochunk/embedding/adapter.py
@@ -0,0 +1,14 @@

from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class EmbeddingFn:
    name: str
    dim: int
    fn: Callable[[List[str]], List[List[float]]]
    cost_per_1k_tokens: float = 0.0

    def __call__(self, texts: List[str]):
        return self.fn(texts)

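EmbeddingFn is a thin callable wrapper around a batch-embedding function. The exact wiring used elsewhere in the package is not shown in this diff, but a plausible sketch (illustrative, not part of the wheel) is to wrap a BaseEncoder's embed_batch so downstream code can just call the object:

# Sketch: adapt an encoder into the EmbeddingFn callable.
from autochunk.embedding import get_encoder
from autochunk.embedding.adapter import EmbeddingFn

enc = get_encoder("hashing", model_name="n/a", dim=128)
embed = EmbeddingFn(
    name=enc.model_name,        # "deterministic_hashing"
    dim=enc.dimension,          # 128
    fn=enc.embed_batch,         # List[str] -> List[List[float]]
    cost_per_1k_tokens=0.0,     # local/offline providers are free
)
vecs = embed(["the call is forwarded to fn"])
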
autochunk/embedding/base.py
@@ -0,0 +1,33 @@

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional

class BaseEncoder(ABC):
    """
    Interface for all AutoChunks embedding providers.
    Designed for high-throughput batching and pluggable backends (Local, TEI, OpenAI).
    """

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of strings into a list of vectors.
        """
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        """
        Return the embedding dimension.
        """
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """
        Return the model name/ID.
        """
        pass

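The contract is small: a custom backend only has to provide the three members above. A minimal illustrative subclass (a stub for testing, not part of the package):

# Sketch: a trivial BaseEncoder implementation.
from typing import List
from autochunk.embedding.base import BaseEncoder

class ConstantEncoder(BaseEncoder):
    """Returns the same zero vector for every input; useful only as a stub."""
    def __init__(self, dim: int = 8):
        self._dim = dim

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        return [[0.0] * self._dim for _ in texts]

    @property
    def dimension(self) -> int:
        return self._dim

    @property
    def model_name(self) -> str:
        return "constant-stub"
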
autochunk/embedding/hashing.py
@@ -0,0 +1,42 @@

from __future__ import annotations
from typing import List
import hashlib
import numpy as np

from .base import BaseEncoder

class HashingEmbedding(BaseEncoder):
    """Deterministic, offline-safe feature hashing embedding.
    Not semantically meaningful but good for plumbing tests.
    """
    def __init__(self, dim: int = 256):
        self._dim = dim

    @property
    def dimension(self) -> int:
        return self._dim

    @property
    def model_name(self) -> str:
        return "deterministic_hashing"

    def _tok_hash(self, tok: str) -> int:
        return int(hashlib.md5(tok.encode('utf-8')).hexdigest(), 16)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        vecs = []
        D = self._dim
        for t in texts:
            v = np.zeros(D, dtype=np.float32)
            for tok in t.lower().split():
                h = self._tok_hash(tok)
                idx = h % D
                sign = 1.0 if (h >> 1) & 1 else -1.0
                v[idx] += sign
            # L2 normalize
            norm = np.linalg.norm(v)
            if norm > 0:
                v = v / norm
            vecs.append(v.tolist())
        return vecs

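A short sketch of what the feature-hashing trick above provides in practice (illustrative, not part of the wheel): identical inputs always produce identical, L2-normalised vectors, with no model files or network access, which is exactly what plumbing tests need:

# Sketch: determinism and normalisation of HashingEmbedding.
from autochunk.embedding.hashing import HashingEmbedding

enc = HashingEmbedding(dim=32)
v1, v2 = enc.embed_batch(["the same text", "the same text"])
assert v1 == v2                                        # deterministic across calls and processes
assert abs(sum(x * x for x in v1) - 1.0) < 1e-5        # unit length for non-empty input
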
autochunk/embedding/local.py
@@ -0,0 +1,154 @@

from __future__ import annotations
from typing import List, Dict, Any, Optional
from ..utils.logger import logger
from .base import BaseEncoder

class LocalEncoder(BaseEncoder):
    """
    Local Embedding Engine powered by sentence-transformers.
    Automatically handles GPU/CPU selection and MTEB-aligned pooling logic.
    """

    def __init__(self, model_name_or_path: str = "BAAI/bge-small-en-v1.5", device: str = None, cache_folder: str = None, trusted_orgs: List[str] = None):
        try:
            from sentence_transformers import SentenceTransformer
            import torch
        except ImportError:
            raise ImportError("Please install sentence-transformers: pip install sentence-transformers torch")

        # Device Detection
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        from ..utils.logger import logger
        logger.info(f"Using device [{device.upper()}] for embeddings.")

        if device == "cpu":
            # Check if we are missing out on CUDA
            try:
                import subprocess
                nvidia_smi = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                if nvidia_smi.returncode == 0:
                    logger.warning("Tip: NVIDIA GPU detected hardware-wise, but Torch is using CPU. "
                                   "Suggest reinstalling torch with CUDA support: https://pytorch.org/get-started/locally/")
            except Exception:
                pass

        # Safety Check: If downloading, ensure it's from a trusted official source
        if not cache_folder and "/" in model_name_or_path:
            org = model_name_or_path.split("/")[0]
            allowed = trusted_orgs or ["ds4sd", "RapidAI", "BAAI", "sentence-transformers"]
            if org not in allowed:
                raise ValueError(f"Security Alert: Attempting to download from untrusted source '{org}'. "
                                 f"Trusted official orgs are: {allowed}")

        self.name = model_name_or_path

        # Check if we might be downloading
        import os
        from ..utils.logger import logger

        # Determine the effective cache path
        effective_cache = cache_folder or os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
        model_id_folder = "models--" + model_name_or_path.replace("/", "--")
        is_cached = os.path.exists(os.path.join(effective_cache, model_id_folder)) or os.path.exists(model_name_or_path)

        if not is_cached:
            logger.info(f"Network Download: Local model '{model_name_or_path}' not found at {effective_cache}. Starting download from Hugging Face (Official)...")
        else:
            logger.info(f"Cache Hit: Using local model at {effective_cache if not os.path.exists(model_name_or_path) else model_name_or_path}")

        # normalize_embeddings=True is standard for cosine similarity retrieval
        self.model = SentenceTransformer(model_name_or_path, device=device, cache_folder=cache_folder)
        self._dim = self.model.get_sentence_embedding_dimension()

        # Log the detected capability
        limit = self.max_seq_length
        logger.info(f"LocalEncoder: Loaded '{model_name_or_path}' with max sequence length: {limit} tokens (truncation @ ~{int(limit * 4 * 0.95)} chars)")

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        # Use internal LRU cache to avoid re-embedding identical text segments
        # across multiple candidate variations (common during hyperparameter sweeps)
        if not hasattr(self, "_cache"):
            self._cache = {}

        results = [None] * len(texts)
        to_embed_indices = []
        to_embed_texts = []
        hits = 0

        # Determine safe truncation limit
        safe_limit = int(self.max_seq_length * 4 * 0.95)

        for i, t in enumerate(texts):
            # Check cache first (using full text as key to avoid collisions)
            if t in self._cache:
                results[i] = self._cache[t]
                hits += 1
            else:
                to_embed_indices.append(i)
                # Truncate strictly for embedding-model safety. The original
                # (untruncated) text is used as the cache key below, so repeat
                # calls still hit the cache; only the embedded input is truncated.
                safe_t = t[:safe_limit]
                if len(t) > safe_limit:
                    # Log only once to avoid spam
                    if not hasattr(self, "_logged_truncation"):
                        logger.warning(f"LocalEncoder: Auto-truncating input > {safe_limit} chars for model safety.")
                        self._logged_truncation = True

                to_embed_texts.append(safe_t)

        if hits > 0:
            logger.info(f"LocalEncoder: Cache Hits={hits}/{len(texts)}")

        if to_embed_texts:
            import numpy as np
            # Efficient batching is handled internally by sentence-transformers
            embeddings = self.model.encode(to_embed_texts, show_progress_bar=False)

            # Cross-version compatibility: handle both numpy arrays and lists
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()

            for relative_idx, emb in enumerate(embeddings):
                # Map back to original global index
                original_idx = to_embed_indices[relative_idx]

                # Get ORIGINAL text for cache key (so future calls hit cache)
                original_text = texts[original_idx]

                self._cache[original_text] = emb
                results[original_idx] = emb

        # Optional: Simple cache eviction if it gets too large (> 10k entries)
        if len(self._cache) > 10000:
            # Clear half the cache if it overflows
            keys = list(self._cache.keys())
            for k in keys[:5000]:
                del self._cache[k]

        return results

    @property
    def dimension(self) -> int:
        return self._dim

    @property
    def model_name(self) -> str:
        return self.name

    @property
    def max_seq_length(self) -> int:
        """Returns the maximum token length the model can handle."""
        if hasattr(self.model, "max_seq_length"):
            return self.model.max_seq_length
        return 512  # Safe default for BERT-like models

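A usage sketch for the local provider (illustrative, not part of the wheel; it assumes sentence-transformers and torch are installed and that the default BAAI model is either cached locally or downloadable). Note that the in-memory cache pays off across calls, not within a single batch:

# Sketch: LocalEncoder usage and cache behaviour.
from autochunk.embedding.local import LocalEncoder

enc = LocalEncoder()  # defaults to "BAAI/bge-small-en-v1.5"
vecs1 = enc.embed_batch(["first chunk", "second chunk"])
vecs2 = enc.embed_batch(["first chunk"])  # served from the in-memory cache on this second call
print(enc.dimension, enc.max_seq_length)
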
autochunk/embedding/ollama.py
@@ -0,0 +1,66 @@

from __future__ import annotations
import requests
from typing import List, Optional
from ..utils.logger import logger
from .base import BaseEncoder

class OllamaEncoder(BaseEncoder):
    """
    Ollama Embedding Provider.
    Assumes Ollama is running locally at http://localhost:11434
    """

    def __init__(self, model_name: str = "llama3", base_url: str = "http://localhost:11434"):
        self.name = model_name
        self.base_url = base_url.rstrip("/")
        self._dim = None  # Will be detected on first call if possible

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        url = f"{self.base_url}/api/embed"

        embeddings = []
        # Ollama /api/embed takes one prompt or a list
        payload = {
            "model": self.name,
            "input": texts
        }

        try:
            response = requests.post(url, json=payload, timeout=120)
            response.raise_for_status()
            data = response.json()

            # Ollama returns "embeddings" which is a list of vectors
            results = data.get("embeddings", [])

            if not results:
                # Fallback to older /api/embeddings if /api/embed is not available or empty
                # /api/embeddings is deprecated but sometimes still used
                logger.warning("Ollama /api/embed returned no results, trying sequential calls (fallback).")
                results = []
                for t in texts:
                    res = requests.post(f"{self.base_url}/api/embeddings", json={"model": self.name, "prompt": t}, timeout=30)
                    res.raise_for_status()
                    results.append(res.json()["embedding"])

            if results and self._dim is None:
                self._dim = len(results[0])

            return results
        except Exception as e:
            logger.error(f"Ollama Embedding failed: {e}")
            raise

    @property
    def dimension(self) -> int:
        if self._dim is None:
            # Try to pulse the model to get dimension
            try:
                self.embed_batch(["pulsing"])
            except:
                return 4096  # common fallback for llama
        return self._dim

    @property
    def model_name(self) -> str:
        return self.name

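A usage sketch for the Ollama provider (illustrative, not part of the wheel; it assumes an Ollama server is reachable at http://localhost:11434 and that the named model has already been pulled):

# Sketch: OllamaEncoder usage; any pulled model name can be passed instead of the default.
from autochunk.embedding.ollama import OllamaEncoder

enc = OllamaEncoder()  # defaults to model_name="llama3" and the local server URL
vecs = enc.embed_batch(["chunk one", "chunk two"])
print(len(vecs), enc.dimension)  # dimension is detected from the first response
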
autochunk/embedding/openai.py
@@ -0,0 +1,62 @@

from __future__ import annotations
import os
import requests
from typing import List, Optional
from ..utils.logger import logger
from .base import BaseEncoder

class OpenAIEncoder(BaseEncoder):
    """
    OpenAI Embedding Provider.
    Requires OPENAI_API_KEY environment variable.
    """

    def __init__(self, model_name: str = "text-embedding-3-small", api_key: Optional[str] = None):
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            logger.warning("OPENAI_API_KEY not found. OpenAI embeddings will fail unless provided.")

        self.name = model_name
        # Common dimensions as fallback
        self._dim_map = {
            "text-embedding-3-small": 1536,
            "text-embedding-3-large": 3072,
            "text-embedding-ada-002": 1536
        }
        self._dim = self._dim_map.get(model_name, 1536)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        if not self.api_key:
            raise ValueError("OPENAI_API_KEY is required for OpenAI embeddings.")

        url = "https://api.openai.com/v1/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }

        # OpenAI supports batching internally
        payload = {
            "input": texts,
            "model": self.name
        }

        try:
            response = requests.post(url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            data = response.json()

            # OpenAI returns data in the same order as input
            embeddings = [item["embedding"] for item in data["data"]]
            return embeddings
        except Exception as e:
            logger.error(f"OpenAI Embedding failed: {e}")
            raise

    @property
    def dimension(self) -> int:
        return self._dim

    @property
    def model_name(self) -> str:
        return self.name

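A usage sketch for the OpenAI provider (illustrative, not part of the wheel; it requires a valid OPENAI_API_KEY in the environment and incurs API cost):

# Sketch: OpenAIEncoder usage.
from autochunk.embedding.openai import OpenAIEncoder

enc = OpenAIEncoder(model_name="text-embedding-3-small")
vecs = enc.embed_batch(["chunk one", "chunk two"])
print(enc.dimension)  # 1536 for text-embedding-3-small (per the fallback map above)
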
File without changes
autochunk/enrichment/contextual.py
@@ -0,0 +1,29 @@

from __future__ import annotations
from typing import List, Dict, Any, Callable
from ..utils.logger import logger

class ContextualEnricher:
    """
    Implements Anthropic's 'Contextual Retrieval' concept.
    Prepends a short summary of the parent document to each chunk.
    """
    def __init__(self, summarizer_fn: Callable[[str], str] = None):
        self.summarizer = summarizer_fn

    def enrich_batch(self, chunks: List[Dict[str, Any]], doc_text: str) -> List[Dict[str, Any]]:
        if not self.summarizer:
            logger.warning("No summarizer provided for ContextualEnricher. Skipping.")
            return chunks

        try:
            summary = self.summarizer(doc_text)
            for chunk in chunks:
                # Prepend the summary as context
                original_text = chunk["text"]
                chunk["text"] = f"[Document Summary: {summary}]\n\n{original_text}"
                chunk["meta"]["contextual_summary"] = summary
        except Exception as e:
            logger.error(f"Error during contextual enrichment: {e}")

        return chunks

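A sketch of plugging a summariser into ContextualEnricher (illustrative, not part of the wheel). The first-200-characters "summary" is a stand-in; in practice summarizer_fn would typically be an LLM call. Chunks need a "meta" dict, as the code above writes into it:

# Sketch: contextual enrichment with a trivial summariser.
from autochunk.enrichment.contextual import ContextualEnricher

def naive_summarizer(doc_text: str) -> str:
    return doc_text[:200]  # placeholder for a real summarisation call

enricher = ContextualEnricher(summarizer_fn=naive_summarizer)
chunks = [{"text": "Payments are settled daily.", "meta": {}}]
enriched = enricher.enrich_batch(chunks, doc_text="ACME billing handbook. Payments are settled daily.")
print(enriched[0]["text"][:20])  # starts with "[Document Summary: "
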
File without changes
autochunk/eval/harness.py
@@ -0,0 +1,177 @@

from __future__ import annotations
from typing import List, Dict, Any, Optional, Callable
import random, time
from ..utils.text import split_sentences, whitespace_tokens
from ..utils.hashing import content_hash
from ..retrieval.in_memory import InMemoryIndex
from ..eval.metrics import mrr_at_k, ndcg_at_k, recall_at_k
from ..eval.synthetic import SyntheticQAGenerator
from ..utils.logger import logger

class EvalHarness:
    def __init__(self, embedding_fn, k: int = 10):
        self.embedding = embedding_fn
        self.k = k
        self.generator = SyntheticQAGenerator()

    def build_synthetic_qa(self, docs: List[Dict], on_progress: Optional[Callable[[str], None]] = None) -> List[Dict]:
        qa = []
        rng = random.Random(42)
        for d in docs:
            sents = split_sentences(d["text"])[:20]
            # 1. Add standard paraphrased queries
            for s in sents[:2]:
                query = self.generator.generate_hard_query(s, on_progress)
                qa.append({
                    "id": content_hash(d["id"] + query),
                    "doc_id": d["id"],
                    "query": query,
                    "answer_span": s,
                })
            # 2. Add boundary-crossing queries (Advanced)
            if len(sents) > 2:
                boundary_qa = self.generator.generate_boundary_qa(d["id"], sents[:5], on_progress)
                qa.extend(boundary_qa)
        return qa

    def evaluate(self, chunks: List[Dict], qa: List[Dict]) -> Dict[str, Any]:
        # Build index
        # 1. Add distractor/noise chunks to ensure search isn't too trivial
        noise_chunks = []
        for i in range(20):
            noise_chunks.append({
                "id": f"noise_{i}",
                "doc_id": "noise",
                "text": f"This is some random distractor text about something unrelated {i} to increase complexity.",
                "meta": {}
            })

        all_eval_chunks = chunks + noise_chunks
        logger.info(f"EvalHarness: Encoding {len(all_eval_chunks)} chunks (including {len(noise_chunks)} noise)...")

        # Determine dynamic safety limit
        model_limit = 512  # Fallback

        # Check if it's Hashing (which has no limit)
        is_hashing = getattr(self.embedding, "name", "").startswith("hashing") or "HashingEmbedding" in str(type(self.embedding))

        if is_hashing:
            MAX_CHARS = 1_000_000  # Virtually infinite
            model_limit = 250_000
        else:
            if hasattr(self.embedding, "max_seq_length"):
                model_limit = self.embedding.max_seq_length
            elif hasattr(self.embedding, "__self__") and hasattr(self.embedding.__self__, "max_seq_length"):
                # Handle bound methods
                model_limit = self.embedding.__self__.max_seq_length
            MAX_CHARS = int(model_limit * 4 * 0.95)

        has_warned = False
        def truncate(text: str) -> str:
            nonlocal has_warned
            if len(text) > MAX_CHARS:
                if not has_warned:
                    # Only warn if it's NOT hashing (since hashing truncation is rare/impossible with this high limit)
                    logger.warning(f"EvalHarness: Truncating chunks > {MAX_CHARS} chars to fit embedding model ({model_limit} tokens).")
                    has_warned = True
                return text[:MAX_CHARS]
            return text

        enc_start = time.time()
        try:
            vectors = self.embedding([truncate(c["text"]) for c in all_eval_chunks])
        except RuntimeError as e:
            if "expanded size" in str(e) or "512" in str(e):
                logger.error(f"EvalHarness: Embedding failed - some chunks exceed model's max token length. Truncating aggressively...")
                # Try with more aggressive truncation
                vectors = self.embedding([truncate(c["text"])[:1200] for c in all_eval_chunks])
            else:
                raise
        enc_time = time.time() - enc_start
        logger.info(f"EvalHarness: Encoding complete in {enc_time:.2f}s")

        index = InMemoryIndex(dim=len(vectors[0]))
        index.add(vectors, all_eval_chunks)

        mrr, ndcg, recall, covered = 0.0, 0.0, 0.0, 0

        # --- BATCH QUERY EVALUATION ---
        logger.info(f"EvalHarness: Encoding and searching {len(qa)} queries in batch mode...")
        # Reuse detection logic
        def truncate_q(text: str) -> str:
            return truncate(text)  # Use the same robust logic and warning system

        query_texts = [truncate_q(item["query"]) for item in qa]
        try:
            query_vectors = self.embedding(query_texts)
        except RuntimeError as e:
            if "expanded size" in str(e):
                logger.error(f"EvalHarness: Query embedding failed - text too long. Truncating aggressively...")
                query_vectors = self.embedding([truncate_q(q)[:1000] for q in query_texts])
            else:
                raise

        # Batch search (using updated InMemoryIndex with batch support)
        batch_hits = index.search(query_vectors, top_k=self.k)

        for i, (item, hits) in enumerate(zip(qa, batch_hits)):
            target_doc = item["doc_id"].lower().replace("\\", "/")

            # --- TOKEN-LEVEL RECALL ---
            answer_tokens = set(whitespace_tokens(item["answer_span"].lower()))
            found_tokens = set()

            # --- RANKING & DCG RELEVANCE ---
            retrieved_rels = []
            has_perfect_match = False

            for rank, (idx, dist) in enumerate(hits):
                c = index.meta[idx]
                rel = 0.0

                # Check Document Match
                if c["doc_id"].lower().replace("\\", "/") == target_doc:
                    # Normalize whitespace for robust substring matching
                    chunk_text_norm = " ".join(c["text"].lower().split())
                    answer_norm = " ".join(item["answer_span"].lower().split())

                    # 1. Full Answer Match (Highest Relevance)
                    if answer_norm in chunk_text_norm:
                        rel = 2.0
                        has_perfect_match = True
                        found_tokens.update(answer_tokens)
                    else:
                        # 2. Token Overlap Match (Partial Relevance)
                        chunk_tokens = set(chunk_text_norm.split())
                        overlap = answer_tokens.intersection(chunk_tokens)
                        if overlap:
                            rel = 1.0 + (len(overlap) / len(answer_tokens))
                            found_tokens.update(overlap)

                retrieved_rels.append(rel)

            # Score Aggregation
            # Relaxed Coverage: Allow non-exact but high-overlap matches (relevance > 1.5)
            # This accounts for markdown artifacts, minor cleaning diffs, etc.
            if has_perfect_match or any(r > 1.5 for r in retrieved_rels):
                covered += 1

            # MRR: Binary look (was there a perfect match in top-K?)
            binary_rels = [1 if r >= 2.0 else 0 for r in retrieved_rels]
            mrr += mrr_at_k(binary_rels, self.k)

            # nDCG: Uses the graduated relevance (0, 1.X, 2.0)
            ndcg += ndcg_at_k(retrieved_rels, self.k)

            # Recall: Percentage of total answer tokens covered by all top-K results
            current_recall = len(found_tokens) / len(answer_tokens) if answer_tokens else 0
            recall += current_recall

        n = max(1, len(qa))
        return {
            "mrr@k": mrr / n,
            "ndcg@k": ndcg / n,
            "recall@k": recall / n,
            "coverage": covered / n,
        }

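An end-to-end sketch of the harness with the offline hashing encoder (illustrative only; it depends on other modules in the wheel, such as InMemoryIndex and SyntheticQAGenerator, whose internals are not shown in this diff). The chunk and QA dicts follow the keys the evaluate() loop above reads ("id", "doc_id", "text", "meta", "query", "answer_span"), and build_synthetic_qa is skipped so nothing here relies on the synthetic-QA generator:

# Sketch: evaluating hand-built chunks against hand-built QA pairs.
from autochunk.embedding.hashing import HashingEmbedding
from autochunk.eval.harness import EvalHarness

enc = HashingEmbedding(dim=128)
harness = EvalHarness(embedding_fn=enc.embed_batch, k=5)

chunks = [
    {"id": "c1", "doc_id": "doc-1", "text": "Invoices are settled within 30 days.", "meta": {}},
    {"id": "c2", "doc_id": "doc-1", "text": "Refunds require a signed request form.", "meta": {}},
]
qa = [
    {"id": "q1", "doc_id": "doc-1",
     "query": "How quickly are invoices settled?",
     "answer_span": "Invoices are settled within 30 days."},
]

scores = harness.evaluate(chunks, qa)
print(scores)  # {"mrr@k": ..., "ndcg@k": ..., "recall@k": ..., "coverage": ...}
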
autochunk/eval/metrics.py
@@ -0,0 +1,27 @@

from __future__ import annotations
from typing import List
import math

def dcg(rels: List[int]) -> float:
    return sum((2**r - 1) / math.log2(i+2) for i, r in enumerate(rels))

def ndcg_at_k(rels: List[int], k: int) -> float:
    rels_k = rels[:k]
    ideal = sorted(rels_k, reverse=True)
    denom = dcg(ideal)
    if denom == 0:
        return 0.0
    return dcg(rels_k) / denom

def mrr_at_k(rels: List[int], k: int) -> float:
    for i, r in enumerate(rels[:k]):
        if r > 0:
            return 1.0 / (i+1)
    return 0.0

def recall_at_k(rels: List[int], k: int, total_relevant: int) -> float:
    if total_relevant == 0:
        return 0.0
    hit = sum(1 for r in rels[:k] if r > 0)
    return hit / total_relevant
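
A small worked example of the graded-relevance helpers above (illustrative, not part of the wheel):

# Sketch: metrics on a toy top-3 result list.
from autochunk.eval.metrics import mrr_at_k, ndcg_at_k, recall_at_k

rels = [0, 2, 1]  # retrieved relevances: miss, exact-answer chunk, partial-overlap chunk
print(mrr_at_k(rels, 3))                         # 0.5   -- first relevant hit is at rank 2
print(round(ndcg_at_k(rels, 3), 3))              # ~0.659 -- the ideal ordering would be [2, 1, 0]
print(recall_at_k(rels, 3, total_relevant=2))    # 1.0   -- both relevant items appear in the top 3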