tokenpack_rag-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tokenpack/__init__.py ADDED
@@ -0,0 +1,28 @@
+"""Context-window budget optimization for retrieval-augmented LLM workflows."""
+
+from tokenpack.chunk_profiles import ChunkSizeConfig, resolve_chunk_size_config
+from tokenpack.chunking import SemanticThresholdChunker, StructureAwareChunker
+from tokenpack.compression import CompressionConfig, CompressionResult, compress_chunks
+from tokenpack.dataset import GoldRecord
+from tokenpack.embeddings import DEFAULT_EMBEDDING_MODEL, SentenceTransformerEmbedder, make_embedder
+from tokenpack.reranking import CrossEncoderReranker, apply_reranker, blend_reranker_scores
+from tokenpack.selectors import select_chunks
+
+__all__ = [
+    "CrossEncoderReranker",
+    "DEFAULT_EMBEDDING_MODEL",
+    "GoldRecord",
+    "SemanticThresholdChunker",
+    "SentenceTransformerEmbedder",
+    "StructureAwareChunker",
+    "ChunkSizeConfig",
+    "CompressionConfig",
+    "CompressionResult",
+    "compress_chunks",
+    "apply_reranker",
+    "blend_reranker_scores",
+    "make_embedder",
+    "resolve_chunk_size_config",
+    "select_chunks",
+]
+
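
For orientation, the sketch below shows how the API exported above might be wired together for a single query, mirroring the flow used in tokenpack/benchmark.py further down. It is a minimal sketch, not a confirmed recipe: ChunkIndex construction and the make_embedder() signature are not part of this diff, and the budget value is an arbitrary placeholder.

from tokenpack import make_embedder, select_chunks
from tokenpack.scoring import score_chunks  # used by the benchmark module; not re-exported in __all__


def build_context(index, query: str, budget: int = 6000) -> str:
    # index: a prebuilt ChunkIndex (its construction is not shown in this diff).
    embedder = make_embedder()  # assumption: defaults resolve to DEFAULT_EMBEDDING_MODEL
    query_embedding = embedder.embed([query])[0]
    scored = score_chunks(query_embedding, index.chunks, index.embeddings, query_text=query)
    result = select_chunks(
        scored,
        strategy="knapsack",
        budget=budget,          # placeholder prompt budget
        candidate_pool=250,
        embeddings=index.embeddings,
        coverage_query=query,
    )
    return "\n\n".join(item.chunk.text for item in result.selected)
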
tokenpack/benchmark.py ADDED
@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+import json
+import time
+from pathlib import Path
+
+from tokenpack.dataset import GoldRecord, propose_gold_records
+from tokenpack.embeddings import Embedder, cosine
+from tokenpack.index import ChunkIndex
+from tokenpack.models import ScoredChunk, SelectionResult
+from tokenpack.scoring import DEFAULT_SCORING_PROFILE, score_chunks
+from tokenpack.selectors import select_chunks
+
+
+STRATEGIES = [
+    "document-prefix",
+    "full-document",
+    "top-k",
+    "production-rag",
+    "budget-top-k",
+    "greedy-value",
+    "greedy-density",
+    "mmr",
+    "knapsack",
+    "knapsack-redundancy",
+    "knapsack-coverage",
+]
+
+
+def synthetic_queries(index: ChunkIndex, sample_size: int = 12) -> list[dict]:
+    return [
+        {
+            "query": record.query,
+            "evidence_chunk_id": record.evidence_chunk_ids[0],
+            "source_path": record.source_path,
+        }
+        for record in propose_gold_records(index, sample_size=sample_size)
+    ]
+
+
+def run_benchmark(
+    index: ChunkIndex,
+    embedder: Embedder,
+    budget: int,
+    reserve_output: int,
+    sample_size: int = 12,
+    candidate_pool: int = 250,
+    scoring: str = DEFAULT_SCORING_PROFILE,
+) -> dict:
+    """Developer smoke benchmark using auto-proposed single-evidence queries."""
+
+    records = propose_gold_records(index, sample_size=sample_size)
+    payload = run_gold_benchmark(
+        index=index,
+        embedder=embedder,
+        records=records,
+        budgets=[budget],
+        reserve_output=reserve_output,
+        candidate_pool=candidate_pool,
+        scoring=scoring,
+    )
+    return payload["budgets"][0] | {
+        "mode": "smoke",
+        "scoring": scoring,
+        "query_count": len(records),
+        "queries": payload["budgets"][0]["queries"],
+    }
+
+
+def run_gold_benchmark(
+    index: ChunkIndex,
+    embedder: Embedder,
+    records: list[GoldRecord],
+    budgets: list[int],
+    reserve_output: int,
+    candidate_pool: int = 250,
+    strategies: list[str] | None = None,
+    redundancy_penalty: float = 0.35,
+    scoring: str = DEFAULT_SCORING_PROFILE,
+) -> dict:
+    strategy_names = strategies or STRATEGIES
+    budget_runs = [
+        _run_gold_for_budget(
+            index=index,
+            embedder=embedder,
+            records=records,
+            budget=budget,
+            reserve_output=reserve_output,
+            candidate_pool=candidate_pool,
+            strategies=strategy_names,
+            redundancy_penalty=redundancy_penalty,
+            scoring=scoring,
+        )
+        for budget in budgets
+    ]
+    return {
+        "mode": "gold",
+        "scoring": scoring,
+        "query_count": len(records),
+        "reserve_output": reserve_output,
+        "strategies": strategy_names,
+        "budgets": budget_runs,
+    }
+
+
+def save_benchmark(payload: dict, path: str | Path) -> None:
+    target = Path(path)
+    target.parent.mkdir(parents=True, exist_ok=True)
+    target.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
+
+
+def _run_gold_for_budget(
+    index: ChunkIndex,
+    embedder: Embedder,
+    records: list[GoldRecord],
+    budget: int,
+    reserve_output: int,
+    candidate_pool: int,
+    strategies: list[str],
+    redundancy_penalty: float,
+    scoring: str,
+) -> dict:
+    effective_budget = max(0, budget - reserve_output)
+    totals = {
+        strategy: {
+            "recall": 0.0,
+            "precision": 0.0,
+            "coverage": 0.0,
+            "used_tokens": 0,
+            "budget_utilization": 0.0,
+            "total_value": 0.0,
+            "redundancy": 0.0,
+            "latency": 0.0,
+            "runs": 0,
+        }
+        for strategy in strategies
+    }
+    per_query = []
+    for record in records:
+        query_embedding = embedder.embed([record.query])[0]
+        scored = score_chunks(
+            query_embedding,
+            index.chunks,
+            index.embeddings,
+            scoring=scoring,
+            query_text=record.query,
+        )
+        scored_redundant = score_chunks(
+            query_embedding,
+            index.chunks,
+            index.embeddings,
+            redundancy_penalty=redundancy_penalty,
+            scoring=scoring,
+            query_text=record.query,
+        )
+        query_result = {
+            "query": record.query,
+            "evidence_chunk_ids": record.evidence_chunk_ids,
+            "strategies": {},
+        }
+        for strategy in strategies:
+            source_scores = scored_redundant if strategy == "knapsack-redundancy" else scored
+            started = time.perf_counter()
+            result = select_chunks(
+                source_scores,
+                strategy=strategy,
+                budget=effective_budget,
+                candidate_pool=candidate_pool,
+                embeddings=index.embeddings,
+                coverage_query=record.query,
+            )
+            elapsed = time.perf_counter() - started
+            metrics = _selection_metrics(result, record.evidence_chunk_ids, effective_budget)
+            metrics["latency_seconds"] = elapsed
+            _accumulate(totals[strategy], metrics, result)
+            query_result["strategies"][strategy] = metrics | {
+                "selected_count": len(result.selected),
+                "selected_chunk_ids": [item.chunk.id for item in result.selected],
+            }
+        per_query.append(query_result)
+
+    summary = {strategy: _summarize(values) for strategy, values in totals.items()}
+    return {
+        "budget": budget,
+        "reserve_output": reserve_output,
+        "effective_budget": effective_budget,
+        "query_count": len(records),
+        "summary": summary,
+        "queries": per_query,
+    }
+
+
+def _selection_metrics(result: SelectionResult, evidence_ids: list[str], budget: int) -> dict:
+    selected_ids = {item.chunk.id for item in result.selected}
+    evidence_set = set(evidence_ids)
+    matched = selected_ids & evidence_set
+    recall = len(matched) / max(1, len(evidence_set))
+    precision = len(matched) / max(1, len(selected_ids))
+    return {
+        "evidence_recall_at_budget": recall,
+        "evidence_precision": precision,
+        "coverage_ratio": 1.0 if evidence_set and evidence_set.issubset(selected_ids) else 0.0,
+        "used_tokens": result.used_tokens,
+        "budget_utilization": result.used_tokens / max(1, budget),
+        "over_budget": result.used_tokens > budget,
+        "over_budget_tokens": max(0, result.used_tokens - budget),
+        "total_value": result.total_value,
+        "value_density": result.total_value / max(1, result.used_tokens),
+        "redundancy_score": redundancy_score(result.selected),
+    }
+
+
+def redundancy_score(selected: list[ScoredChunk]) -> float:
+    embeddings = [item.embedding for item in selected if item.embedding is not None]
+    if len(embeddings) < 2:
+        return 0.0
+    total = 0.0
+    comparisons = 0
+    for left_index, left in enumerate(embeddings):
+        for right in embeddings[left_index + 1 :]:
+            total += max(0.0, cosine(left, right))
+            comparisons += 1
+    return total / max(1, comparisons)
+
+
+def _accumulate(total: dict, metrics: dict, result: SelectionResult) -> None:
+    total["recall"] += metrics["evidence_recall_at_budget"]
+    total["precision"] += metrics["evidence_precision"]
+    total["coverage"] += metrics["coverage_ratio"]
+    total["used_tokens"] += result.used_tokens
+    total["budget_utilization"] += metrics["budget_utilization"]
+    total["over_budget"] = total.get("over_budget", 0) + int(metrics["over_budget"])
+    total["over_budget_tokens"] = total.get("over_budget_tokens", 0) + metrics["over_budget_tokens"]
+    total["total_value"] += result.total_value
+    total["redundancy"] += metrics["redundancy_score"]
+    total["latency"] += metrics["latency_seconds"]
+    total["runs"] += 1
+
+
+def _summarize(values: dict) -> dict:
+    runs = max(1, values["runs"])
+    avg_tokens = values["used_tokens"] / runs
+    avg_value = values["total_value"] / runs
+    return {
+        "evidence_recall_at_budget": values["recall"] / runs,
+        "evidence_precision": values["precision"] / runs,
+        "coverage_ratio": values["coverage"] / runs,
+        "avg_used_tokens": avg_tokens,
+        "budget_utilization": values["budget_utilization"] / runs,
+        "over_budget_rate": values.get("over_budget", 0) / runs,
+        "avg_over_budget_tokens": values.get("over_budget_tokens", 0) / runs,
+        "avg_total_value": avg_value,
+        "value_density": avg_value / max(1.0, avg_tokens),
+        "redundancy_score": values["redundancy"] / runs,
+        "latency_seconds": values["latency"] / runs,
+    }
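
The redundancy_score metric above is the mean pairwise cosine similarity, clipped at zero, over the embeddings of the selected chunks: 0.0 means the selection carries no overlapping content, while values near 1.0 mean the budget is being spent on near-duplicates. A small self-contained sketch of the same computation, with a local cosine helper standing in for tokenpack.embeddings.cosine (whose implementation is not shown in this diff):

import math


def cosine(a: list[float], b: list[float]) -> float:
    # Stand-in for tokenpack.embeddings.cosine; assumes plain list-of-float vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mean_pairwise_redundancy(embeddings: list[list[float]]) -> float:
    # Mirrors redundancy_score: average of max(0, cosine) over all unordered pairs.
    if len(embeddings) < 2:
        return 0.0
    total, comparisons = 0.0, 0
    for i, left in enumerate(embeddings):
        for right in embeddings[i + 1:]:
            total += max(0.0, cosine(left, right))
            comparisons += 1
    return total / comparisons


# Two near-duplicate vectors plus one orthogonal vector: the duplicate pair
# dominates the average, so the score lands well above zero (about 0.37 here).
print(mean_pairwise_redundancy([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]))
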
tokenpack/chunk_profiles.py ADDED
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True, slots=True)
+class ChunkSizeConfig:
+    target_tokens: int
+    min_tokens: int
+    max_tokens: int
+
+
+CHUNK_SIZE_PRESETS: dict[str, ChunkSizeConfig] = {
+    "default": ChunkSizeConfig(target_tokens=650, min_tokens=120, max_tokens=900),
+    "low-budget": ChunkSizeConfig(target_tokens=250, min_tokens=40, max_tokens=320),
+}
+
+
+def resolve_chunk_size_config(
+    preset: str,
+    target_tokens: int,
+    min_tokens: int,
+    max_tokens: int,
+) -> ChunkSizeConfig:
+    if preset == "manual":
+        return ChunkSizeConfig(
+            target_tokens=target_tokens,
+            min_tokens=min_tokens,
+            max_tokens=max_tokens,
+        )
+    try:
+        return CHUNK_SIZE_PRESETS[preset]
+    except KeyError as exc:
+        choices = ", ".join(["manual", *CHUNK_SIZE_PRESETS])
+        raise ValueError(f"Unknown chunk size preset: {preset}. Choose one of: {choices}.") from exc
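
A quick illustration of how resolve_chunk_size_config behaves, using only values defined above; the zero arguments in the first call are placeholders meant to show they are ignored for named presets:

from tokenpack.chunk_profiles import resolve_chunk_size_config

# Named preset: the explicit token arguments are ignored in favor of the preset.
low = resolve_chunk_size_config("low-budget", target_tokens=0, min_tokens=0, max_tokens=0)
assert (low.target_tokens, low.min_tokens, low.max_tokens) == (250, 40, 320)

# "manual" passes the explicit arguments through unchanged.
manual = resolve_chunk_size_config("manual", target_tokens=400, min_tokens=80, max_tokens=512)
assert manual.max_tokens == 512

# Any other preset name raises ValueError listing the choices: manual, default, low-budget.
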
tokenpack/chunking.py ADDED
@@ -0,0 +1,340 @@
+from __future__ import annotations
+
+import hashlib
+import re
+from collections.abc import Sequence
+from typing import Any
+
+from tokenpack.embeddings import cosine
+from tokenpack.models import Chunk, TextBlock
+from tokenpack.tokenization import TokenCounter
+
+SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")
+
+
+class _ChunkGroupBase:
+    """Shared token-bounded grouping helpers for structure-aware chunkers."""
+
+    def __init__(
+        self,
+        target_tokens: int = 650,
+        min_tokens: int = 120,
+        max_tokens: int = 900,
+        token_counter: TokenCounter | None = None,
+    ) -> None:
+        if min_tokens > target_tokens or target_tokens > max_tokens:
+            raise ValueError("Expected min_tokens <= target_tokens <= max_tokens.")
+        self.target_tokens = target_tokens
+        self.min_tokens = min_tokens
+        self.max_tokens = max_tokens
+        self.token_counter = token_counter or TokenCounter()
+
+    def _split_large_block(self, block_id: int, block: TextBlock) -> list[Chunk]:
+        units = self._split_block_units(block)
+        chunks: list[Chunk] = []
+        current_units: list[str] = []
+        current_tokens = 0
+        split_offset = 0
+
+        def flush() -> None:
+            nonlocal current_units, current_tokens, split_offset
+            if not current_units:
+                return
+            separator = "\n" if block.metadata.get("content_type") == "code" else " "
+            text = separator.join(current_units).strip()
+            if text:
+                chunks.append(
+                    self._make_chunk(
+                        [(block_id, block, self.token_counter.count(text))],
+                        text_override=text,
+                        suffix=f"split-{split_offset}",
+                        char_start=block.char_start,
+                        char_end=block.char_start + len(text),
+                    )
+                )
+            split_offset += 1
+            current_units = []
+            current_tokens = 0
+
+        for unit in units:
+            unit_tokens = max(1, self.token_counter.count(unit))
+            if unit_tokens > self.max_tokens:
+                flush()
+                for piece in self._split_oversized_unit(unit):
+                    chunks.append(
+                        self._make_chunk(
+                            [(block_id, block, self.token_counter.count(piece))],
+                            text_override=piece,
+                            suffix=f"split-{split_offset}",
+                            char_start=block.char_start,
+                            char_end=block.char_start + len(piece),
+                        )
+                    )
+                    split_offset += 1
+                continue
+            if current_units and current_tokens + unit_tokens > self.max_tokens:
+                flush()
+            current_units.append(unit)
+            current_tokens += unit_tokens
+            if current_tokens >= self.target_tokens:
+                flush()
+        flush()
+        return chunks
+
+    def _split_block_units(self, block: TextBlock) -> list[str]:
+        text = block.text.strip()
+        if not text:
+            return []
+        if block.metadata.get("content_type") == "code":
+            units = [line.rstrip() for line in text.splitlines() if line.strip()]
+            return units or [text]
+        paragraphs = [item.strip() for item in re.split(r"\n\s*\n+", text) if item.strip()]
+        units: list[str] = []
+        for paragraph in paragraphs or [text]:
+            sentences = [sentence.strip() for sentence in SENTENCE_RE.split(paragraph) if sentence.strip()]
+            units.extend(sentences or [paragraph])
+        return units
+
+    def _split_oversized_unit(self, unit: str) -> list[str]:
+        words = unit.split()
+        pieces: list[str] = []
+        current: list[str] = []
+        for word in words:
+            current.append(word)
+            if self.token_counter.count(" ".join(current)) >= self.target_tokens:
+                pieces.append(" ".join(current))
+                current = []
+        if current:
+            pieces.append(" ".join(current))
+        return pieces or [unit]
+
+    def _flush(self, items: list[tuple[int, TextBlock, int]]) -> list[Chunk]:
+        if not items:
+            return []
+        return [self._make_chunk(items)]
+
+    def _make_chunk(
+        self,
+        items: list[tuple[int, TextBlock, int]],
+        text_override: str | None = None,
+        suffix: str = "",
+        char_start: int | None = None,
+        char_end: int | None = None,
+    ) -> Chunk:
+        blocks = [item[1] for item in items]
+        text = text_override if text_override is not None else "\n\n".join(block.text for block in blocks)
+        token_count = max(1, self.token_counter.count(text))
+        start_page = next((block.page for block in blocks if block.page is not None), None)
+        end_page = next((block.page for block in reversed(blocks) if block.page is not None), start_page)
+        digest_input = (
+            f"{blocks[0].source_path}:{blocks[0].paragraph_index}:"
+            f"{blocks[-1].paragraph_index}:{suffix}:{text[:80]}"
+        )
+        digest = hashlib.sha1(digest_input.encode("utf-8", errors="replace")).hexdigest()[:12]
+        metadata = self._chunk_metadata(blocks)
+        return Chunk(
+            id=f"chunk-{digest}",
+            text=text,
+            source_path=blocks[0].source_path,
+            document_index=blocks[0].document_index,
+            start_page=start_page,
+            end_page=end_page,
+            start_paragraph=blocks[0].paragraph_index,
+            end_paragraph=blocks[-1].paragraph_index,
+            char_start=blocks[0].char_start if char_start is None else char_start,
+            char_end=blocks[-1].char_end if char_end is None else char_end,
+            token_count=token_count,
+            block_ids=[item[0] for item in items],
+            metadata=metadata,
+        )
+
+    def _chunk_metadata(self, blocks: list[TextBlock]) -> dict[str, Any]:
+        metadata: dict[str, Any] = {"bbox": [block.bbox for block in blocks if block.bbox is not None]}
+        all_keys = {key for block in blocks for key in block.metadata}
+        for key in all_keys:
+            values = [block.metadata.get(key) for block in blocks if key in block.metadata]
+            unique_values = {repr(value): value for value in values}
+            if len(unique_values) == 1:
+                metadata[key] = values[0]
+            elif key == "content_type":
+                metadata[key] = "mixed"
+        start_lines = [block.metadata.get("start_line") for block in blocks if isinstance(block.metadata.get("start_line"), int)]
+        end_lines = [block.metadata.get("end_line") for block in blocks if isinstance(block.metadata.get("end_line"), int)]
+        if start_lines:
+            metadata["start_line"] = min(start_lines)
+        if end_lines:
+            metadata["end_line"] = max(end_lines)
+        return metadata
+
+
+class StructureAwareChunker(_ChunkGroupBase):
+    """Chunk documents and code using structural metadata plus semantic drift."""
+
+    def __init__(
+        self,
+        target_tokens: int = 650,
+        min_tokens: int = 120,
+        max_tokens: int = 900,
+        token_counter: TokenCounter | None = None,
+        block_embeddings: Sequence[list[float]] | None = None,
+        semantic_threshold: float = 0.35,
+    ) -> None:
+        super().__init__(
+            target_tokens=target_tokens,
+            min_tokens=min_tokens,
+            max_tokens=max_tokens,
+            token_counter=token_counter,
+        )
+        self.block_embeddings = block_embeddings
+        self.semantic_threshold = semantic_threshold
+
+    def chunk(self, blocks: Sequence[TextBlock]) -> list[Chunk]:
+        if self.block_embeddings is not None and len(self.block_embeddings) != len(blocks):
+            raise ValueError("StructureAwareChunker requires one embedding per text block when semantic boundaries are enabled.")
+
+        chunks: list[Chunk] = []
+        current: list[tuple[int, TextBlock, int]] = []
+        current_tokens = 0
+
+        def flush() -> None:
+            nonlocal current, current_tokens
+            chunks.extend(self._flush(current))
+            current = []
+            current_tokens = 0
+
+        for block_id, block in enumerate(blocks):
+            block_tokens = max(1, self.token_counter.count(block.text))
+            content_type = block.metadata.get("content_type", "document")
+            is_symbol = bool(block.metadata.get("symbol_name") or block.metadata.get("symbol_kind"))
+
+            if content_type == "code" and is_symbol:
+                flush()
+                if block_tokens > self.max_tokens:
+                    chunks.extend(self._split_large_block(block_id, block))
+                else:
+                    chunks.extend(self._flush([(block_id, block, block_tokens)]))
+                continue
+
+            if block_tokens > self.max_tokens:
+                flush()
+                chunks.extend(self._split_large_block(block_id, block))
+                continue
+
+            if current and self._structural_boundary(current[-1][1], block):
+                flush()
+
+            would_exceed = current_tokens + block_tokens > self.max_tokens
+            topic_shift = self._semantic_boundary(block_id, block, current)
+            good_enough = current_tokens >= self.min_tokens
+            if current and good_enough and (would_exceed or topic_shift):
+                flush()
+
+            current.append((block_id, block, block_tokens))
+            current_tokens += block_tokens
+
+            if current_tokens >= self.target_tokens:
+                flush()
+
+        flush()
+        for chunk in chunks:
+            chunk.metadata["chunker"] = "structure-aware"
+            if self.block_embeddings is not None:
+                chunk.metadata["semantic_threshold"] = self.semantic_threshold
+        return chunks
+
+    def _structural_boundary(self, previous: TextBlock, current: TextBlock) -> bool:
+        if previous.source_path != current.source_path:
+            return True
+        previous_type = previous.metadata.get("content_type", "document")
+        current_type = current.metadata.get("content_type", "document")
+        if previous_type != current_type:
+            return True
+        if previous.metadata.get("section_hint") != current.metadata.get("section_hint"):
+            return bool(previous.metadata.get("section_hint") or current.metadata.get("section_hint"))
+        return False
+
+    def _semantic_boundary(self, block_id: int, block: TextBlock, current: list[tuple[int, TextBlock, int]]) -> bool:
+        if self.block_embeddings is None or not current:
+            return False
+        previous_id, previous, _ = current[-1]
+        if previous.source_path != block.source_path:
+            return False
+        if previous.metadata.get("content_type", "document") != "document":
+            return False
+        if block.metadata.get("content_type", "document") != "document":
+            return False
+        similarity = cosine(self.block_embeddings[previous_id], self.block_embeddings[block_id])
+        return similarity < self.semantic_threshold
+
+
+class SemanticThresholdChunker(_ChunkGroupBase):
+    """Start a new chunk when adjacent block embeddings indicate topic drift."""
+
+    def __init__(
+        self,
+        block_embeddings: Sequence[list[float]],
+        similarity_threshold: float = 0.35,
+        target_tokens: int = 650,
+        min_tokens: int = 120,
+        max_tokens: int = 900,
+        token_counter: TokenCounter | None = None,
+    ) -> None:
+        super().__init__(
+            target_tokens=target_tokens,
+            min_tokens=min_tokens,
+            max_tokens=max_tokens,
+            token_counter=token_counter,
+        )
+        self.block_embeddings = block_embeddings
+        self.similarity_threshold = similarity_threshold
+
+    def chunk(self, blocks: Sequence[TextBlock]) -> list[Chunk]:
+        if len(self.block_embeddings) != len(blocks):
+            raise ValueError("SemanticThresholdChunker requires one embedding per text block.")
+
+        chunks: list[Chunk] = []
+        current: list[tuple[int, TextBlock, int]] = []
+        current_tokens = 0
+
+        for block_id, block in enumerate(blocks):
+            block_tokens = max(1, self.token_counter.count(block.text))
+            if block_tokens > self.max_tokens:
+                chunks.extend(self._flush(current))
+                current = []
+                current_tokens = 0
+                chunks.extend(self._split_large_block(block_id, block))
+                continue
+
+            starts_new_doc = current and block.source_path != current[-1][1].source_path
+            would_exceed = current_tokens + block_tokens > self.max_tokens
+            topic_shift = self._topic_shift(block_id, current)
+            good_enough = current_tokens >= self.min_tokens
+
+            if current and (starts_new_doc or (would_exceed and good_enough) or (topic_shift and good_enough)):
+                chunks.extend(self._flush(current))
+                current = []
+                current_tokens = 0
+
+            current.append((block_id, block, block_tokens))
+            current_tokens += block_tokens
+
+            if current_tokens >= self.target_tokens:
+                chunks.extend(self._flush(current))
+                current = []
+                current_tokens = 0
+
+        chunks.extend(self._flush(current))
+        for chunk in chunks:
+            chunk.metadata["chunker"] = "semantic-threshold"
+            chunk.metadata["similarity_threshold"] = self.similarity_threshold
+        return chunks
+
+    def _topic_shift(self, block_id: int, current: list[tuple[int, TextBlock, int]]) -> bool:
+        if not current:
+            return False
+        previous_id = current[-1][0]
+        if current[-1][1].source_path != current[0][1].source_path:
+            return False
+        similarity = cosine(self.block_embeddings[previous_id], self.block_embeddings[block_id])
+        return similarity < self.similarity_threshold
+
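
Stripped of the structural and oversize handling, the core boundary rule shared by both chunkers is: accumulate blocks in order, cut a chunk when the running token total reaches target_tokens, and also cut early when the cosine similarity between adjacent block embeddings falls below the threshold while the current chunk already holds at least min_tokens. A compressed, self-contained sketch of that rule over plain strings, using a hypothetical whitespace token count in place of TokenCounter and precomputed per-block vectors in place of the sentence-transformer embeddings:

import math


def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def semantic_chunks(
    blocks: list[str],
    embeddings: list[list[float]],
    threshold: float = 0.35,
    target_tokens: int = 650,
    min_tokens: int = 120,
) -> list[str]:
    # Hypothetical simplification: whitespace tokens instead of TokenCounter.count(),
    # and no max_tokens / oversize-block splitting.
    chunks: list[str] = []
    current: list[int] = []
    current_tokens = 0

    def flush() -> None:
        nonlocal current, current_tokens
        if current:
            chunks.append(" ".join(blocks[i] for i in current))
        current, current_tokens = [], 0

    for i, block in enumerate(blocks):
        block_tokens = max(1, len(block.split()))
        topic_shift = bool(current) and cosine(embeddings[current[-1]], embeddings[i]) < threshold
        if current and topic_shift and current_tokens >= min_tokens:
            flush()          # topic drift: close the chunk before adding this block
        current.append(i)
        current_tokens += block_tokens
        if current_tokens >= target_tokens:
            flush()          # budget reached: close the chunk after adding this block
    flush()
    return chunks
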