autochunks-0.0.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. autochunk/__init__.py +9 -0
  2. autochunk/__main__.py +5 -0
  3. autochunk/adapters/__init__.py +3 -0
  4. autochunk/adapters/haystack.py +68 -0
  5. autochunk/adapters/langchain.py +81 -0
  6. autochunk/adapters/llamaindex.py +94 -0
  7. autochunk/autochunker.py +606 -0
  8. autochunk/chunkers/__init__.py +100 -0
  9. autochunk/chunkers/agentic.py +184 -0
  10. autochunk/chunkers/base.py +16 -0
  11. autochunk/chunkers/contextual_retrieval.py +151 -0
  12. autochunk/chunkers/fixed_length.py +110 -0
  13. autochunk/chunkers/html_section.py +225 -0
  14. autochunk/chunkers/hybrid_semantic_stat.py +199 -0
  15. autochunk/chunkers/layout_aware.py +192 -0
  16. autochunk/chunkers/parent_child.py +172 -0
  17. autochunk/chunkers/proposition.py +175 -0
  18. autochunk/chunkers/python_ast.py +248 -0
  19. autochunk/chunkers/recursive_character.py +215 -0
  20. autochunk/chunkers/semantic_local.py +140 -0
  21. autochunk/chunkers/sentence_aware.py +102 -0
  22. autochunk/cli.py +135 -0
  23. autochunk/config.py +76 -0
  24. autochunk/embedding/__init__.py +22 -0
  25. autochunk/embedding/adapter.py +14 -0
  26. autochunk/embedding/base.py +33 -0
  27. autochunk/embedding/hashing.py +42 -0
  28. autochunk/embedding/local.py +154 -0
  29. autochunk/embedding/ollama.py +66 -0
  30. autochunk/embedding/openai.py +62 -0
  31. autochunk/embedding/tokenizer.py +9 -0
  32. autochunk/enrichment/__init__.py +0 -0
  33. autochunk/enrichment/contextual.py +29 -0
  34. autochunk/eval/__init__.py +0 -0
  35. autochunk/eval/harness.py +177 -0
  36. autochunk/eval/metrics.py +27 -0
  37. autochunk/eval/ragas_eval.py +234 -0
  38. autochunk/eval/synthetic.py +104 -0
  39. autochunk/quality/__init__.py +31 -0
  40. autochunk/quality/deduplicator.py +326 -0
  41. autochunk/quality/overlap_optimizer.py +402 -0
  42. autochunk/quality/post_processor.py +245 -0
  43. autochunk/quality/scorer.py +459 -0
  44. autochunk/retrieval/__init__.py +0 -0
  45. autochunk/retrieval/in_memory.py +47 -0
  46. autochunk/retrieval/parent_child.py +4 -0
  47. autochunk/storage/__init__.py +0 -0
  48. autochunk/storage/cache.py +34 -0
  49. autochunk/storage/plan.py +40 -0
  50. autochunk/utils/__init__.py +0 -0
  51. autochunk/utils/hashing.py +8 -0
  52. autochunk/utils/io.py +176 -0
  53. autochunk/utils/logger.py +64 -0
  54. autochunk/utils/telemetry.py +44 -0
  55. autochunk/utils/text.py +199 -0
  56. autochunks-0.0.8.dist-info/METADATA +133 -0
  57. autochunks-0.0.8.dist-info/RECORD +61 -0
  58. autochunks-0.0.8.dist-info/WHEEL +5 -0
  59. autochunks-0.0.8.dist-info/entry_points.txt +2 -0
  60. autochunks-0.0.8.dist-info/licenses/LICENSE +15 -0
  61. autochunks-0.0.8.dist-info/top_level.txt +1 -0

autochunk/embedding/__init__.py
@@ -0,0 +1,22 @@
+
+ from .base import BaseEncoder
+ from .local import LocalEncoder
+ from .hashing import HashingEmbedding
+ from .openai import OpenAIEncoder
+ from .ollama import OllamaEncoder
+
+ def get_encoder(provider: str, model_name: str, **kwargs) -> BaseEncoder:
+     """
+     Factory for chunking-aware embeddings.
+     Easily extendable to TEI, OpenAI, etc.
+     """
+     if provider == "local":
+         return LocalEncoder(model_name_or_path=model_name, **kwargs)
+     elif provider == "hashing":
+         return HashingEmbedding(dim=kwargs.get("dim", 256))
+     elif provider == "openai":
+         return OpenAIEncoder(model_name=model_name, **kwargs)
+     elif provider == "ollama":
+         return OllamaEncoder(model_name=model_name, **kwargs)
+     else:
+         raise ValueError(f"Unknown embedding provider: {provider}")

autochunk/embedding/adapter.py
@@ -0,0 +1,14 @@
+
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from typing import Callable, List
+
+ @dataclass
+ class EmbeddingFn:
+     name: str
+     dim: int
+     fn: Callable[[List[str]], List[List[float]]]
+     cost_per_1k_tokens: float = 0.0
+
+     def __call__(self, texts: List[str]):
+         return self.fn(texts)
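
For illustration only (hypothetical snippet, not in the package source shown here): EmbeddingFn can wrap any encoder's embed_batch behind a plain callable from a list of strings to a list of vectors, which is the shape the evaluation harness calls.

from autochunk.embedding import get_encoder
from autochunk.embedding.adapter import EmbeddingFn

encoder = get_encoder("hashing", "unused", dim=64)
embed = EmbeddingFn(
    name=encoder.model_name,   # "deterministic_hashing"
    dim=encoder.dimension,     # 64
    fn=encoder.embed_batch,    # Callable[[List[str]], List[List[float]]]
)
vecs = embed(["hello world"])  # __call__ simply delegates to fn
print(embed.name, embed.dim, len(vecs[0]))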

autochunk/embedding/base.py
@@ -0,0 +1,33 @@
+
+ from __future__ import annotations
+ from abc import ABC, abstractmethod
+ from typing import List, Optional
+
+ class BaseEncoder(ABC):
+     """
+     Interface for all AutoChunks embedding providers.
+     Designed for high-throughput batching and pluggable backends (Local, TEI, OpenAI).
+     """
+
+     @abstractmethod
+     def embed_batch(self, texts: List[str]) -> List[List[float]]:
+         """
+         Embed a list of strings into a list of vectors.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def dimension(self) -> int:
+         """
+         Return the embedding dimension.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def model_name(self) -> str:
+         """
+         Return the model name/ID.
+         """
+         pass

autochunk/embedding/hashing.py
@@ -0,0 +1,42 @@
+
+ from __future__ import annotations
+ from typing import List
+ import hashlib
+ import numpy as np
+
+ from .base import BaseEncoder
+
+ class HashingEmbedding(BaseEncoder):
+     """Deterministic, offline-safe feature hashing embedding.
+     Not semantically meaningful but good for plumbing tests.
+     """
+     def __init__(self, dim: int = 256):
+         self._dim = dim
+
+     @property
+     def dimension(self) -> int:
+         return self._dim
+
+     @property
+     def model_name(self) -> str:
+         return "deterministic_hashing"
+
+     def _tok_hash(self, tok: str) -> int:
+         return int(hashlib.md5(tok.encode('utf-8')).hexdigest(), 16)
+
+     def embed_batch(self, texts: List[str]) -> List[List[float]]:
+         vecs = []
+         D = self._dim
+         for t in texts:
+             v = np.zeros(D, dtype=np.float32)
+             for tok in t.lower().split():
+                 h = self._tok_hash(tok)
+                 idx = h % D
+                 sign = 1.0 if (h >> 1) & 1 else -1.0
+                 v[idx] += sign
+             # L2 normalize
+             norm = np.linalg.norm(v)
+             if norm > 0:
+                 v = v / norm
+             vecs.append(v.tolist())
+         return vecs
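
A quick property check (illustrative, not taken from the package's tests): the hashing vectors are deterministic and, for non-empty input, L2-normalized.

import math
from autochunk.embedding.hashing import HashingEmbedding

enc = HashingEmbedding(dim=32)
v1, v2 = enc.embed_batch(["feature hashing is cheap", "feature hashing is cheap"])
assert v1 == v2  # identical text always maps to the identical vector
norm = math.sqrt(sum(x * x for x in v1))
print(round(norm, 6))  # 1.0 unless every token cancels out, which is practically impossible here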

autochunk/embedding/local.py
@@ -0,0 +1,154 @@
+ from __future__ import annotations
+ from typing import List, Dict, Any, Optional
+ from ..utils.logger import logger
+ from .base import BaseEncoder
+
+ class LocalEncoder(BaseEncoder):
+     """
+     Local Embedding Engine powered by sentence-transformers.
+     Automatically handles GPU/CPU selection and MTEB-aligned pooling logic.
+     """
+
+     def __init__(self, model_name_or_path: str = "BAAI/bge-small-en-v1.5", device: str = None, cache_folder: str = None, trusted_orgs: List[str] = None):
+         try:
+             from sentence_transformers import SentenceTransformer
+             import torch
+         except ImportError:
+             raise ImportError("Please install sentence-transformers: pip install sentence-transformers torch")
+
+         # Device Detection
+         if device is None:
+             if torch.cuda.is_available():
+                 device = "cuda"
+             elif torch.backends.mps.is_available():
+                 device = "mps"
+             else:
+                 device = "cpu"
+
+         from ..utils.logger import logger
+         logger.info(f"Using device [{device.upper()}] for embeddings.")
+
+         if device == "cpu":
+             # Check if we are missing out on CUDA
+             try:
+                 import subprocess
+                 nvidia_smi = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+                 if nvidia_smi.returncode == 0:
+                     logger.warning("Tip: NVIDIA GPU detected hardware-wise, but Torch is using CPU. "
+                                    "Suggest reinstalling torch with CUDA support: https://pytorch.org/get-started/locally/")
+             except Exception:
+                 pass
+
+         # Safety Check: If downloading, ensure it's from a trusted official source
+         if not cache_folder and "/" in model_name_or_path:
+             org = model_name_or_path.split("/")[0]
+             allowed = trusted_orgs or ["ds4sd", "RapidAI", "BAAI", "sentence-transformers"]
+             if org not in allowed:
+                 raise ValueError(f"Security Alert: Attempting to download from untrusted source '{org}'. "
+                                  f"Trusted official orgs are: {allowed}")
+
+         self.name = model_name_or_path
+
+         # Check if we might be downloading
+         import os
+         from ..utils.logger import logger
+
+         # Determine the effective cache path
+         effective_cache = cache_folder or os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
+         model_id_folder = "models--" + model_name_or_path.replace("/", "--")
+         is_cached = os.path.exists(os.path.join(effective_cache, model_id_folder)) or os.path.exists(model_name_or_path)
+
+         if not is_cached:
+             logger.info(f"Network Download: Local model '{model_name_or_path}' not found at {effective_cache}. Starting download from Hugging Face (Official)...")
+         else:
+             logger.info(f"Cache Hit: Using local model at {effective_cache if not os.path.exists(model_name_or_path) else model_name_or_path}")
+
+         # normalize_embeddings=True is standard for cosine similarity retrieval
+         self.model = SentenceTransformer(model_name_or_path, device=device, cache_folder=cache_folder)
+         self._dim = self.model.get_sentence_embedding_dimension()
+
+         # Log the detected capability
+         limit = self.max_seq_length
+         logger.info(f"LocalEncoder: Loaded '{model_name_or_path}' with max sequence length: {limit} tokens (truncation @ ~{int(limit * 4 * 0.95)} chars)")
+
+     def embed_batch(self, texts: List[str]) -> List[List[float]]:
+         # Use internal LRU cache to avoid re-embedding identical text segments
+         # across multiple candidate variations (common during hyperparameter sweeps)
+         if not hasattr(self, "_cache"):
+             self._cache = {}
+
+         results = [None] * len(texts)
+         to_embed_indices = []
+         to_embed_texts = []
+         hits = 0
+
+         # Determine safe truncation limit
+         safe_limit = int(self.max_seq_length * 4 * 0.95)
+
+         for i, t in enumerate(texts):
+             # Check cache first (using full text as key to avoid collisions)
+             if t in self._cache:
+                 results[i] = self._cache[t]
+                 hits += 1
+             else:
+                 to_embed_indices.append(i)
+                 # Truncate strictly for embedding-model safety: oversized inputs
+                 # would otherwise fail at the model's token limit. The truncated
+                 # text is what gets embedded, while the ORIGINAL text is used as
+                 # the cache key below so future calls still hit the cache.
+                 safe_t = t[:safe_limit]
+                 if len(t) > safe_limit:
+                     # Log only once to avoid spam
+                     if not hasattr(self, "_logged_truncation"):
+                         logger.warning(f"LocalEncoder: Auto-truncating input > {safe_limit} chars for model safety.")
+                         self._logged_truncation = True
+
+                 to_embed_texts.append(safe_t)
+
+         if hits > 0:
+             logger.info(f"LocalEncoder: Cache Hits={hits}/{len(texts)}")
+
+         if to_embed_texts:
+             import numpy as np
+             # Efficient batching is handled internally by sentence-transformers
+             embeddings = self.model.encode(to_embed_texts, show_progress_bar=False)
+
+             # Cross-version compatibility: handle both numpy arrays and lists
+             if isinstance(embeddings, np.ndarray):
+                 embeddings = embeddings.tolist()
+
+             for relative_idx, emb in enumerate(embeddings):
+                 # Map back to original global index
+                 original_idx = to_embed_indices[relative_idx]
+
+                 # Get ORIGINAL text for cache key (so future calls hit cache)
+                 original_text = texts[original_idx]
+
+                 self._cache[original_text] = emb
+                 results[original_idx] = emb
+
+             # Optional: Simple cache eviction if it gets too large (> 10k entries)
+             if len(self._cache) > 10000:
+                 # Clear half the cache if it overflows
+                 keys = list(self._cache.keys())
+                 for k in keys[:5000]:
+                     del self._cache[k]
+
+         return results
+
+     @property
+     def dimension(self) -> int:
+         return self._dim
+
+     @property
+     def model_name(self) -> str:
+         return self.name
+
+     @property
+     def max_seq_length(self) -> int:
+         """Returns the maximum token length the model can handle."""
+         if hasattr(self.model, "max_seq_length"):
+             return self.model.max_seq_length
+         return 512  # Safe default for BERT-like models
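
Worked example of the character budget used above (illustrative numbers, mirroring the heuristic in the code): the encoder assumes roughly 4 characters per token and keeps a 5% safety margin, so a typical 512-token model truncates inputs at 1945 characters.

max_seq_length = 512                         # typical BERT-style token limit
safe_limit = int(max_seq_length * 4 * 0.95)  # ~4 chars/token, 5% margin
print(safe_limit)                            # 1945 characters before auto-truncation kicks in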

autochunk/embedding/ollama.py
@@ -0,0 +1,66 @@
+ from __future__ import annotations
+ import requests
+ from typing import List, Optional
+ from ..utils.logger import logger
+ from .base import BaseEncoder
+
+ class OllamaEncoder(BaseEncoder):
+     """
+     Ollama Embedding Provider.
+     Assumes Ollama is running locally at http://localhost:11434
+     """
+
+     def __init__(self, model_name: str = "llama3", base_url: str = "http://localhost:11434"):
+         self.name = model_name
+         self.base_url = base_url.rstrip("/")
+         self._dim = None  # Will be detected on first call if possible
+
+     def embed_batch(self, texts: List[str]) -> List[List[float]]:
+         url = f"{self.base_url}/api/embed"
+
+         embeddings = []
+         # Ollama /api/embed takes one prompt or a list
+         payload = {
+             "model": self.name,
+             "input": texts
+         }
+
+         try:
+             response = requests.post(url, json=payload, timeout=120)
+             response.raise_for_status()
+             data = response.json()
+
+             # Ollama returns "embeddings", which is a list of vectors
+             results = data.get("embeddings", [])
+
+             if not results:
+                 # Fallback to older /api/embeddings if /api/embed is not available or empty
+                 # /api/embeddings is deprecated but sometimes still used
+                 logger.warning("Ollama /api/embed returned no results, trying sequential calls (fallback).")
+                 results = []
+                 for t in texts:
+                     res = requests.post(f"{self.base_url}/api/embeddings", json={"model": self.name, "prompt": t}, timeout=30)
+                     res.raise_for_status()
+                     results.append(res.json()["embedding"])
+
+             if results and self._dim is None:
+                 self._dim = len(results[0])
+
+             return results
+         except Exception as e:
+             logger.error(f"Ollama Embedding failed: {e}")
+             raise
+
+     @property
+     def dimension(self) -> int:
+         if self._dim is None:
+             # Probe the model once to detect the dimension
+             try:
+                 self.embed_batch(["pulsing"])
+             except Exception:
+                 return 4096  # common fallback for llama
+         return self._dim
+
+     @property
+     def model_name(self) -> str:
+         return self.name

autochunk/embedding/openai.py
@@ -0,0 +1,62 @@
+ from __future__ import annotations
+ import os
+ import requests
+ from typing import List, Optional
+ from ..utils.logger import logger
+ from .base import BaseEncoder
+
+ class OpenAIEncoder(BaseEncoder):
+     """
+     OpenAI Embedding Provider.
+     Requires OPENAI_API_KEY environment variable.
+     """
+
+     def __init__(self, model_name: str = "text-embedding-3-small", api_key: Optional[str] = None):
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+         if not self.api_key:
+             logger.warning("OPENAI_API_KEY not found. OpenAI embeddings will fail unless provided.")
+
+         self.name = model_name
+         # Common dimensions as fallback
+         self._dim_map = {
+             "text-embedding-3-small": 1536,
+             "text-embedding-3-large": 3072,
+             "text-embedding-ada-002": 1536
+         }
+         self._dim = self._dim_map.get(model_name, 1536)
+
+     def embed_batch(self, texts: List[str]) -> List[List[float]]:
+         if not self.api_key:
+             raise ValueError("OPENAI_API_KEY is required for OpenAI embeddings.")
+
+         url = "https://api.openai.com/v1/embeddings"
+         headers = {
+             "Content-Type": "application/json",
+             "Authorization": f"Bearer {self.api_key}"
+         }
+
+         # OpenAI supports batching internally
+         payload = {
+             "input": texts,
+             "model": self.name
+         }
+
+         try:
+             response = requests.post(url, headers=headers, json=payload, timeout=60)
+             response.raise_for_status()
+             data = response.json()
+
+             # OpenAI returns data in the same order as input
+             embeddings = [item["embedding"] for item in data["data"]]
+             return embeddings
+         except Exception as e:
+             logger.error(f"OpenAI Embedding failed: {e}")
+             raise
+
+     @property
+     def dimension(self) -> int:
+         return self._dim
+
+     @property
+     def model_name(self) -> str:
+         return self.name

autochunk/embedding/tokenizer.py
@@ -0,0 +1,9 @@
+
+ from dataclasses import dataclass
+ from typing import Callable
+
+ @dataclass
+ class SimpleTokenizer:
+     name: str = "whitespace"
+     def tokens(self, text: str):
+         return text.split()
autochunk/enrichment/__init__.py: file without changes

autochunk/enrichment/contextual.py
@@ -0,0 +1,29 @@
+
+ from __future__ import annotations
+ from typing import List, Dict, Any, Callable
+ from ..utils.logger import logger
+
+ class ContextualEnricher:
+     """
+     Implements Anthropic's 'Contextual Retrieval' concept.
+     Prepends a short summary of the parent document to each chunk.
+     """
+     def __init__(self, summarizer_fn: Callable[[str], str] = None):
+         self.summarizer = summarizer_fn
+
+     def enrich_batch(self, chunks: List[Dict[str, Any]], doc_text: str) -> List[Dict[str, Any]]:
+         if not self.summarizer:
+             logger.warning("No summarizer provided for ContextualEnricher. Skipping.")
+             return chunks
+
+         try:
+             summary = self.summarizer(doc_text)
+             for chunk in chunks:
+                 # Prepend the summary as context
+                 original_text = chunk["text"]
+                 chunk["text"] = f"[Document Summary: {summary}]\n\n{original_text}"
+                 chunk["meta"]["contextual_summary"] = summary
+         except Exception as e:
+             logger.error(f"Error during contextual enrichment: {e}")
+
+         return chunks
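
Illustrative usage with a stand-in summarizer (hypothetical example; a real pipeline would pass an LLM-backed summarizer_fn):

from autochunk.enrichment.contextual import ContextualEnricher

def first_sentence_summary(doc_text: str) -> str:
    # Trivial stand-in: use the first sentence as the "summary".
    return doc_text.split(".")[0].strip()

enricher = ContextualEnricher(summarizer_fn=first_sentence_summary)
chunks = [{"text": "Chunk body here.", "meta": {}}]
doc = "AutoChunks tunes chunking pipelines. It evaluates many strategies."
enriched = enricher.enrich_batch(chunks, doc)
print(enriched[0]["text"])  # "[Document Summary: AutoChunks tunes chunking pipelines]\n\nChunk body here."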
autochunk/eval/__init__.py: file without changes

autochunk/eval/harness.py
@@ -0,0 +1,177 @@
+
+ from __future__ import annotations
+ from typing import List, Dict, Any, Optional, Callable
+ import random, time
+ from ..utils.text import split_sentences, whitespace_tokens
+ from ..utils.hashing import content_hash
+ from ..retrieval.in_memory import InMemoryIndex
+ from ..eval.metrics import mrr_at_k, ndcg_at_k, recall_at_k
+ from ..eval.synthetic import SyntheticQAGenerator
+ from ..utils.logger import logger
+
+ class EvalHarness:
+     def __init__(self, embedding_fn, k: int = 10):
+         self.embedding = embedding_fn
+         self.k = k
+         self.generator = SyntheticQAGenerator()
+
+     def build_synthetic_qa(self, docs: List[Dict], on_progress: Optional[Callable[[str], None]] = None) -> List[Dict]:
+         qa = []
+         rng = random.Random(42)
+         for d in docs:
+             sents = split_sentences(d["text"])[:20]
+             # 1. Add standard paraphrased queries
+             for s in sents[:2]:
+                 query = self.generator.generate_hard_query(s, on_progress)
+                 qa.append({
+                     "id": content_hash(d["id"] + query),
+                     "doc_id": d["id"],
+                     "query": query,
+                     "answer_span": s,
+                 })
+             # 2. Add boundary-crossing queries (Advanced)
+             if len(sents) > 2:
+                 boundary_qa = self.generator.generate_boundary_qa(d["id"], sents[:5], on_progress)
+                 qa.extend(boundary_qa)
+         return qa
+
+     def evaluate(self, chunks: List[Dict], qa: List[Dict]) -> Dict[str, Any]:
+         # Build index
+         # 1. Add distractor/noise chunks to ensure search isn't too trivial
+         noise_chunks = []
+         for i in range(20):
+             noise_chunks.append({
+                 "id": f"noise_{i}",
+                 "doc_id": "noise",
+                 "text": f"This is some random distractor text about something unrelated {i} to increase complexity.",
+                 "meta": {}
+             })
+
+         all_eval_chunks = chunks + noise_chunks
+         logger.info(f"EvalHarness: Encoding {len(all_eval_chunks)} chunks (including {len(noise_chunks)} noise)...")
+
+         # Determine dynamic safety limit
+         model_limit = 512  # Fallback
+
+         # Check if it's Hashing (which has no limit)
+         is_hashing = getattr(self.embedding, "name", "").startswith("hashing") or "HashingEmbedding" in str(type(self.embedding))
+
+         if is_hashing:
+             MAX_CHARS = 1_000_000  # Virtually infinite
+             model_limit = 250_000
+         else:
+             if hasattr(self.embedding, "max_seq_length"):
+                 model_limit = self.embedding.max_seq_length
+             elif hasattr(self.embedding, "__self__") and hasattr(self.embedding.__self__, "max_seq_length"):
+                 # Handle bound methods
+                 model_limit = self.embedding.__self__.max_seq_length
+             MAX_CHARS = int(model_limit * 4 * 0.95)
+
+         has_warned = False
+         def truncate(text: str) -> str:
+             nonlocal has_warned
+             if len(text) > MAX_CHARS:
+                 if not has_warned:
+                     # Only warn if it's NOT hashing (since hashing truncation is rare/impossible with this high limit)
+                     logger.warning(f"EvalHarness: Truncating chunks > {MAX_CHARS} chars to fit embedding model ({model_limit} tokens).")
+                     has_warned = True
+                 return text[:MAX_CHARS]
+             return text
+
+         enc_start = time.time()
+         try:
+             vectors = self.embedding([truncate(c["text"]) for c in all_eval_chunks])
+         except RuntimeError as e:
+             if "expanded size" in str(e) or "512" in str(e):
+                 logger.error(f"EvalHarness: Embedding failed - some chunks exceed model's max token length. Truncating aggressively...")
+                 # Try with more aggressive truncation
+                 vectors = self.embedding([truncate(c["text"])[:1200] for c in all_eval_chunks])
+             else:
+                 raise
+         enc_time = time.time() - enc_start
+         logger.info(f"EvalHarness: Encoding complete in {enc_time:.2f}s")
+
+         index = InMemoryIndex(dim=len(vectors[0]))
+         index.add(vectors, all_eval_chunks)
+
+         mrr, ndcg, recall, covered = 0.0, 0.0, 0.0, 0
+
+         # --- BATCH QUERY EVALUATION ---
+         logger.info(f"EvalHarness: Encoding and searching {len(qa)} queries in batch mode...")
+         # Reuse detection logic
+         def truncate_q(text: str) -> str:
+             return truncate(text)  # Use the same robust logic and warning system
+
+         query_texts = [truncate_q(item["query"]) for item in qa]
+         try:
+             query_vectors = self.embedding(query_texts)
+         except RuntimeError as e:
+             if "expanded size" in str(e):
+                 logger.error(f"EvalHarness: Query embedding failed - text too long. Truncating aggressively...")
+                 query_vectors = self.embedding([truncate_q(q)[:1000] for q in query_texts])
+             else:
+                 raise
+
+         # Batch search (using updated InMemoryIndex with batch support)
+         batch_hits = index.search(query_vectors, top_k=self.k)
+
+         for i, (item, hits) in enumerate(zip(qa, batch_hits)):
+             target_doc = item["doc_id"].lower().replace("\\", "/")
+
+             # --- TOKEN-LEVEL RECALL ---
+             answer_tokens = set(whitespace_tokens(item["answer_span"].lower()))
+             found_tokens = set()
+
+             # --- RANKING & DCG RELEVANCE ---
+             retrieved_rels = []
+             has_perfect_match = False
+
+             for rank, (idx, dist) in enumerate(hits):
+                 c = index.meta[idx]
+                 rel = 0.0
+
+                 # Check Document Match
+                 if c["doc_id"].lower().replace("\\", "/") == item["doc_id"].lower().replace("\\", "/"):
+                     # Normalize whitespace for robust substring matching
+                     chunk_text_norm = " ".join(c["text"].lower().split())
+                     answer_norm = " ".join(item["answer_span"].lower().split())
+
+                     # 1. Full Answer Match (Highest Relevance)
+                     if answer_norm in chunk_text_norm:
+                         rel = 2.0
+                         has_perfect_match = True
+                         found_tokens.update(answer_tokens)
+                     else:
+                         # 2. Token Overlap Match (Partial Relevance)
+                         chunk_tokens = set(chunk_text_norm.split())
+                         overlap = answer_tokens.intersection(chunk_tokens)
+                         if overlap:
+                             rel = 1.0 + (len(overlap) / len(answer_tokens))
+                             found_tokens.update(overlap)
+
+                 retrieved_rels.append(rel)
+
+             # Score Aggregation
+             # Relaxed Coverage: Allow non-exact but high-overlap matches (relevance > 1.5)
+             # This accounts for markdown artifacts, minor cleaning diffs, etc.
+             if has_perfect_match or any(r > 1.5 for r in retrieved_rels):
+                 covered += 1
+
+             # MRR: Binary look (was there a perfect match in top-K?)
+             binary_rels = [1 if r >= 2.0 else 0 for r in retrieved_rels]
+             mrr += mrr_at_k(binary_rels, self.k)
+
+             # nDCG: Uses the graduated relevance (0, 1.X, 2.0)
+             ndcg += ndcg_at_k(retrieved_rels, self.k)
+
+             # Recall: Percentage of total answer tokens covered by all top-K results
+             current_recall = len(found_tokens) / len(answer_tokens) if answer_tokens else 0
+             recall += current_recall
+
+         n = max(1, len(qa))
+         return {
+             "mrr@k": mrr / n,
+             "ndcg@k": ndcg / n,
+             "recall@k": recall / n,
+             "coverage": covered / n,
+         }
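
A hedged end-to-end sketch (my example; the chunk and QA field names are taken from the code above, and it assumes the package installs cleanly and that SyntheticQAGenerator, constructed inside EvalHarness.__init__, needs no extra configuration): evaluating hand-written chunks and QA pairs offline with the hashing embedding.

from autochunk.embedding import get_encoder
from autochunk.eval.harness import EvalHarness

encoder = get_encoder("hashing", "unused", dim=256)
harness = EvalHarness(embedding_fn=encoder.embed_batch, k=5)

# Minimal hand-written corpus and QA set; noise chunks are added internally.
chunks = [
    {"id": "c1", "doc_id": "doc1", "text": "Paris is the capital of France.", "meta": {}},
    {"id": "c2", "doc_id": "doc1", "text": "The Seine flows through the city.", "meta": {}},
]
qa = [
    {"id": "q1", "doc_id": "doc1", "query": "capital of France",
     "answer_span": "Paris is the capital of France."},
]
print(harness.evaluate(chunks, qa))  # {'mrr@k': ..., 'ndcg@k': ..., 'recall@k': ..., 'coverage': ...}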

autochunk/eval/metrics.py
@@ -0,0 +1,27 @@
+
+ from __future__ import annotations
+ from typing import List
+ import math
+
+ def dcg(rels: List[int]) -> float:
+     return sum((2**r - 1) / math.log2(i+2) for i, r in enumerate(rels))
+
+ def ndcg_at_k(rels: List[int], k: int) -> float:
+     rels_k = rels[:k]
+     ideal = sorted(rels_k, reverse=True)
+     denom = dcg(ideal)
+     if denom == 0:
+         return 0.0
+     return dcg(rels_k) / denom
+
+ def mrr_at_k(rels: List[int], k: int) -> float:
+     for i, r in enumerate(rels[:k]):
+         if r > 0:
+             return 1.0 / (i+1)
+     return 0.0
+
+ def recall_at_k(rels: List[int], k: int, total_relevant: int) -> float:
+     if total_relevant == 0:
+         return 0.0
+     hit = sum(1 for r in rels[:k] if r > 0)
+     return hit / total_relevant
+ return hit / total_relevant