ragforge-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ragforge/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ """RAGforge — local-first RAG pipeline with an eval harness."""
2
+
3
+ from importlib.metadata import PackageNotFoundError, version
4
+
5
+ from ragforge.pipeline import Answer, Pipeline, Source
6
+
7
def _resolve_version() -> str:
    """Return the installed distribution's version string.

    The distribution is published as ``ragforge-ml`` (not the import name
    ``ragforge``). When the package is imported from a source tree that was
    never installed there is no distribution metadata, and the placeholder
    ``"0.0.0+local"`` is returned instead.
    """
    try:
        return version("ragforge-ml")
    except PackageNotFoundError:
        return "0.0.0+local"


__version__ = _resolve_version()
11
+
12
+ __all__ = ["Answer", "Pipeline", "Source", "__version__"]
ragforge/cli.py ADDED
@@ -0,0 +1,205 @@
1
+ """Command-line interface — ``ragforge`` / ``rf``.
2
+
3
+ Subcommands:
4
+ ingest load PDF / Markdown documents into a collection
5
+ ask run the full RAG pipeline (retrieve, rerank, generate)
6
+ eval run the eval harness on a JSONL of QA samples
7
+ serve start the FastAPI server
8
+ info print the current collection stats
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ from pathlib import Path
15
+ from typing import Annotated
16
+
17
+ import typer
18
+ from rich.console import Console
19
+ from rich.panel import Panel
20
+ from rich.table import Table
21
+
22
+ app = typer.Typer(
23
+ name="ragforge",
24
+ help="Local-first RAG pipeline with an eval harness.",
25
+ no_args_is_help=True,
26
+ rich_markup_mode="rich",
27
+ )
28
+ console = Console()
29
+
30
+
31
@app.command()
def ingest(
    paths: Annotated[list[Path], typer.Argument(help="Files or directories to index")],
    collection: Annotated[str, typer.Option(help="Qdrant collection name")] = "ragforge",
    store_path: Annotated[Path, typer.Option(help="Qdrant embedded storage dir")] = Path(
        "qdrant_storage"
    ),
    chunk_size: Annotated[int, typer.Option()] = 1024,
    chunk_overlap: Annotated[int, typer.Option()] = 128,
) -> None:
    """Load documents into the vector store."""
    # NOTE: the docstring above is the typer help text for this command.
    from ragforge import Pipeline

    header = f"[bold purple]ragforge ingest[/] [dim]{collection}[/]"
    console.print(Panel.fit(header))

    # Indexing does not rerank, so the pipeline is built with use_reranker=False.
    pipeline = Pipeline.from_defaults(
        collection=collection,
        store_path=store_path,
        use_reranker=False,
    )
    # Chunking parameters are set as attributes on the pipeline before ingest.
    pipeline.chunk_size = chunk_size
    pipeline.chunk_overlap = chunk_overlap

    indexed = pipeline.ingest([str(path) for path in paths])
    console.print(
        f"[green]ok[/] indexed [bold]{indexed}[/] chunks total in collection: {pipeline.store.count()}"
    )
52
+
53
+
54
@app.command()
def ask(
    question: Annotated[str, typer.Argument()],
    collection: Annotated[str, typer.Option()] = "ragforge",
    store_path: Annotated[Path, typer.Option()] = Path("qdrant_storage"),
    model_id: Annotated[
        str, typer.Option(help="HF causal LM to generate the answer")
    ] = "Qwen/Qwen2.5-3B-Instruct",
    k: Annotated[int, typer.Option(help="Top-k after rerank")] = 5,
    max_new_tokens: Annotated[int, typer.Option()] = 256,
    quantize: Annotated[
        str | None, typer.Option(help="Optional turboquant-ml method, e.g. bnb-nf4")
    ] = None,
) -> None:
    """Retrieve, rerank, and generate an answer."""
    # NOTE: the docstring above is the typer help text for this command.
    # The question, the model's answer and source paths are arbitrary text.
    # rich parses "[...]" as markup, so they are escaped to print verbatim:
    # without escaping, "[/note]" in an answer raises MarkupError and "[bold]"
    # restyles the rest of the line.
    from rich.markup import escape

    from ragforge import Pipeline

    console.print(Panel.fit(f"[bold purple]ragforge ask[/] [dim]{escape(question)}[/]"))

    if quantize:
        from ragforge.llm import QuantizedHFLLM

        llm = QuantizedHFLLM(model_id, method=quantize)
    else:
        from ragforge.llm import HFLLM

        llm = HFLLM(model_id)

    rag = Pipeline.from_defaults(collection=collection, store_path=store_path, llm=llm)
    ans = rag.ask(question, top_k=k, max_new_tokens=max_new_tokens)
    console.print(
        f"\n[bold]Answer[/] [dim]({ans.latency_ms:.0f} ms)[/]\n{escape(ans.text)}\n"
    )

    table = Table(title="Sources", show_header=True)
    table.add_column("#", style="dim", width=3)
    table.add_column("score", justify="right")
    table.add_column("source")
    for i, s in enumerate(ans.sources, 1):
        # Fall back to the source id when no path is recorded. str() because
        # Table.add_row needs renderables, and an id may not be a string.
        loc = str(s.metadata.get("path", s.id))
        page = s.metadata.get("page")
        # Compare against None rather than truthiness: a page number of 0 is
        # valid and must still be shown.
        if page is not None:
            loc = f"{loc} p.{page}"
        table.add_row(str(i), f"{s.score:.3f}", escape(loc))
    console.print(table)
96
+
97
+
98
@app.command()
def eval(
    dataset: Annotated[Path, typer.Argument(help="JSONL with {question, ground_truth} per line")],
    collection: Annotated[str, typer.Option()] = "ragforge",
    store_path: Annotated[Path, typer.Option()] = Path("qdrant_storage"),
    model_id: Annotated[str, typer.Option()] = "Qwen/Qwen2.5-3B-Instruct",
    metrics: Annotated[
        str, typer.Option(help="Comma-separated")
    ] = "context_recall,answer_relevance,faithfulness",
    out: Annotated[Path | None, typer.Option(help="Save JSON report here")] = None,
    limit: Annotated[int | None, typer.Option(help="Cap the dataset for quick runs")] = None,
) -> None:
    """Run the eval harness over a JSONL dataset."""
    # NOTE: the docstring above is the typer help text for this command.
    # The function name shadows the builtin `eval` inside this module. It is
    # kept because typer derives the subcommand name from it.
    import time

    from ragforge import Pipeline
    from ragforge.eval import evaluate
    from ragforge.eval.report import EvalReport
    from ragforge.llm import HFLLM

    rows = _read_jsonl(dataset, limit=limit)
    console.print(
        Panel.fit(f"[bold purple]ragforge eval[/] n={len(rows)} metrics={metrics}")
    )

    pipeline = Pipeline.from_defaults(
        collection=collection, store_path=store_path, llm=HFLLM(model_id)
    )

    # Phase 1: answer every question, timing each end-to-end pipeline call.
    generated: list[dict] = []
    timings_ms: list[float] = []
    for row in rows:
        question = row["question"]
        started = time.perf_counter()
        answer = pipeline.ask(question)
        timings_ms.append(1000 * (time.perf_counter() - started))
        generated.append(
            {
                "question": question,
                "answer": answer.text,
                "contexts": [src.text for src in answer.sources],
                "ground_truth": row.get("ground_truth"),
            }
        )

    # Phase 2: score the generated answers and report.
    wanted = [name.strip() for name in metrics.split(",") if name.strip()]
    result = evaluate(generated, encoder=pipeline.encoder, metrics=wanted)
    report = EvalReport(
        n=result["n"],
        means=result["means"],
        per_sample=result["per_sample"],
        latencies_ms=timings_ms,
    )
    console.print(report.as_table())
    if out:
        report.save(out)
        console.print(f"[green]ok[/] saved {out}")
153
+
154
+
155
@app.command()
def serve(
    collection: Annotated[str, typer.Option()] = "ragforge",
    store_path: Annotated[Path, typer.Option()] = Path("qdrant_storage"),
    model_id: Annotated[str, typer.Option()] = "Qwen/Qwen2.5-3B-Instruct",
    host: Annotated[str, typer.Option()] = "127.0.0.1",
    port: Annotated[int, typer.Option()] = 8000,
) -> None:
    """Start the FastAPI server."""
    # NOTE: the docstring above is the typer help text for this command.
    # Heavy imports are deferred so other subcommands don't pay for them.
    import uvicorn

    from ragforge import Pipeline
    from ragforge.llm import HFLLM
    from ragforge.serve import build_app

    pipeline = Pipeline.from_defaults(
        collection=collection,
        store_path=store_path,
        llm=HFLLM(model_id),
    )
    uvicorn.run(build_app(pipeline), host=host, port=port, log_level="info")
174
+
175
+
176
@app.command()
def info(
    collection: Annotated[str, typer.Option()] = "ragforge",
    store_path: Annotated[Path, typer.Option()] = Path("qdrant_storage"),
) -> None:
    """Print the current collection stats."""
    # NOTE: the docstring above is the typer help text for this command.
    from ragforge import Pipeline

    # Only the store and encoder are inspected, so no reranker is loaded.
    pipeline = Pipeline.from_defaults(
        collection=collection, store_path=store_path, use_reranker=False
    )
    console.print(f"collection : [bold]{collection}[/]")
    console.print(f"store : [bold]{store_path}[/]")
    vector_count = pipeline.store.count()
    console.print(f"vectors : [bold]{vector_count}[/]")
    encoder = pipeline.encoder
    console.print(f"encoder : [bold]{encoder.model_id}[/] dim={encoder.dim}")
189
+
190
+
191
+ def _read_jsonl(path: Path, *, limit: int | None) -> list[dict]:
192
+ rows: list[dict] = []
193
+ with path.open(encoding="utf-8") as f:
194
+ for line in f:
195
+ line = line.strip()
196
+ if not line:
197
+ continue
198
+ rows.append(json.loads(line))
199
+ if limit and len(rows) >= limit:
200
+ break
201
+ return rows
202
+
203
+
204
+ if __name__ == "__main__":
205
+ app()
@@ -0,0 +1,5 @@
1
+ """Embedding models — thin wrapper around sentence-transformers."""
2
+
3
+ from ragforge.embed.encoder import Encoder, SentenceTransformerEncoder
4
+
5
+ __all__ = ["Encoder", "SentenceTransformerEncoder"]
@@ -0,0 +1,62 @@
1
+ """Encoder protocol + sentence-transformers backend.
2
+
3
+ The protocol is intentionally small — :meth:`encode` is the only required
4
+ method — so swapping the backend (e.g. for an in-house ONNX encoder, an API
5
+ client, or a cached encoder) is cheap.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Protocol
11
+
12
+ import numpy as np
13
+
14
+
15
class Encoder(Protocol):
    """Structural interface for text embedders.

    Implementations provide ``dim`` (the length of each output vector) and
    ``encode``. Because this is a ``Protocol``, any object with that shape
    qualifies without inheriting from it, so alternative backends stay cheap
    to write.
    """

    # Length of every embedding vector produced by ``encode``.
    dim: int

    def encode(
        self, texts: list[str], *, batch_size: int = 32, normalize: bool = True
    ) -> np.ndarray:
        """Embed ``texts`` into an array with one row per input text.

        ``normalize`` asks for unit-length rows, so that a dot product between
        rows is a cosine similarity.
        """
        ...
23
+
24
+
25
+ class SentenceTransformerEncoder:
26
+ """Default encoder.
27
+
28
+ ``BAAI/bge-small-en-v1.5`` (33 M params, 384-dim) is the default because it
29
+ sits at the Pareto frontier of "small enough to run anywhere" / "good
30
+ enough on MTEB". For multilingual corpora, pass
31
+ ``intfloat/multilingual-e5-small`` instead.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ model_id: str = "BAAI/bge-small-en-v1.5",
37
+ *,
38
+ device: str | None = None,
39
+ cache_folder: str | None = None,
40
+ ) -> None:
41
+ from sentence_transformers import SentenceTransformer
42
+
43
+ self.model_id = model_id
44
+ self._model = SentenceTransformer(model_id, device=device, cache_folder=cache_folder)
45
+ self.dim = self._model.get_sentence_embedding_dimension()
46
+
47
+ def encode(
48
+ self, texts: list[str], *, batch_size: int = 32, normalize: bool = True
49
+ ) -> np.ndarray:
50
+ if not texts:
51
+ return np.zeros((0, self.dim), dtype=np.float32)
52
+ vecs = self._model.encode(
53
+ texts,
54
+ batch_size=batch_size,
55
+ normalize_embeddings=normalize,
56
+ convert_to_numpy=True,
57
+ show_progress_bar=False,
58
+ )
59
+ return vecs.astype(np.float32, copy=False)
60
+
61
+ def __repr__(self) -> str:
62
+ return f"SentenceTransformerEncoder({self.model_id!r}, dim={self.dim})"
@@ -0,0 +1,17 @@
1
+ """Evaluation harness — measure RAG answer quality without an external API."""
2
+
3
+ from ragforge.eval.metrics import (
4
+ answer_relevance,
5
+ context_recall,
6
+ evaluate,
7
+ faithfulness,
8
+ )
9
+ from ragforge.eval.report import EvalReport
10
+
11
+ __all__ = [
12
+ "EvalReport",
13
+ "answer_relevance",
14
+ "context_recall",
15
+ "evaluate",
16
+ "faithfulness",
17
+ ]
@@ -0,0 +1,178 @@
1
+ """RAG quality metrics — embedding-based, pure-Python, no external API.
2
+
3
+ All metrics return a float in ``[0, 1]`` (higher is better) and accept the
4
+ same arguments so the orchestrator can iterate them uniformly:
5
+
6
+ metric(question, answer, contexts, ground_truth=None, *, encoder) -> float
7
+
8
+ The three metrics implemented here are RAGAS-compatible interpretations:
9
+
10
+ - ``context_recall``: of the n-grams in the ground-truth answer, what fraction
11
+ appear in any retrieved context block?
12
+ - ``answer_relevance``: cosine similarity between the question and the
13
+ answer's embedding, with a length-penalty for empty/non-answers.
14
+ - ``faithfulness``: of the claims (sentences) in the answer, what fraction
15
+ have a cosine similarity > τ with at least one context block?
16
+
17
+ These are deliberately *embedding-based* rather than NLI-based: they need no
18
+ extra model, run on CPU, and correlate strongly with judgment scores on
19
+ short-answer QA. For long-form evaluation, plug in a stronger judge model.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ import re
26
+ from collections.abc import Iterable
27
+
28
+ import numpy as np
29
+
30
+ from ragforge.embed import Encoder
31
+
32
logger = logging.getLogger("ragforge.eval")

# Unicode-aware word tokens; _ngrams builds its n-grams from these.
_WORD = re.compile(r"\w+", flags=re.UNICODE)
# Sentence boundary used by faithfulness: the whitespace run that follows
# ".", "!" or "?". The lookbehind keeps the punctuation with its sentence.
_SENT = re.compile(r"(?<=[.!?])\s+")
36
+
37
+
38
def context_recall(
    question: str,
    answer: str,
    contexts: list[str],
    ground_truth: str | None,
    *,
    encoder: Encoder | None = None,
    ngram: int = 1,
) -> float:
    """Fraction of ground-truth word n-grams found in at least one context.

    Matching compares sets of word n-grams (via ``_ngrams``), not raw
    substrings. The gold token ``"cat"`` is therefore not credited by a
    context that merely contains ``"concatenate"``. Each context block is
    tokenized separately, so with ``ngram > 1`` an n-gram never straddles two
    blocks.

    ``question``, ``answer`` and ``encoder`` are unused. They are accepted so
    that every metric has the signature the orchestrator calls.

    Returns:
        A score in ``[0, 1]``. It is 0.0 when there is no ground truth, no
        contexts, or the ground truth yields no n-grams.
    """
    if not ground_truth or not contexts:
        return 0.0
    gold = _ngrams(ground_truth.lower(), ngram)
    if not gold:
        return 0.0
    available: set[tuple[str, ...]] = set()
    for ctx in contexts:
        available |= _ngrams(ctx.lower(), ngram)
    return len(gold & available) / len(gold)
56
+
57
+
58
def answer_relevance(
    question: str,
    answer: str,
    contexts: list[str],
    ground_truth: str | None = None,
    *,
    encoder: Encoder,
) -> float:
    """Score how on-topic ``answer`` is for ``question`` by embedding cosine.

    - A blank (whitespace-only) answer scores 0.0.
    - Otherwise the question/answer cosine is clipped to ``[0, 1]``; a
      negative similarity counts as unrelated.
    - Answers that look like a refusal ("I don't know", ...) are capped at
      0.2, even if they happen to be lexically close to the question.

    ``contexts`` and ``ground_truth`` are unused; they keep the shared metric
    signature.
    """
    if answer.strip() == "":
        return 0.0
    cosine = _cos_pair(encoder, question, answer)
    ceiling = 0.2 if _looks_like_dont_know(answer) else 1.0
    return min(max(0.0, cosine), ceiling)
78
+
79
+
80
def faithfulness(
    question: str,
    answer: str,
    contexts: list[str],
    ground_truth: str | None = None,
    *,
    encoder: Encoder,
    tau: float = 0.55,
) -> float:
    """Fraction of answer sentences supported by at least one context block.

    The answer is split into sentences with ``_SENT``. A sentence is
    "supported" when its cosine similarity with at least one retrieved context
    block is ``>= tau``. The default is tuned for ``bge-small`` on factual
    short-answer QA; raise it for stricter evaluation.

    ``question`` and ``ground_truth`` are unused; they keep the shared metric
    signature.

    Returns:
        A score in ``[0, 1]``, or 0.0 for a blank answer or no contexts.
    """
    if not answer.strip() or not contexts:
        return 0.0
    sentences = [s.strip() for s in _SENT.split(answer.strip()) if s.strip()]
    if not sentences:
        return 0.0
    # Request unit-length vectors explicitly. The dot product below is a
    # cosine only if both sides are normalized, and an Encoder implementation
    # may use a different default for ``normalize``.
    sent_vecs = encoder.encode(sentences, normalize=True)
    ctx_vecs = encoder.encode(contexts, normalize=True)
    sims = sent_vecs @ ctx_vecs.T  # (n_sentences, n_contexts)
    supported = int((sims.max(axis=1) >= tau).sum())
    return supported / len(sentences)
106
+
107
+
108
+ # ---------------------------------------------------------------- orchestrator
109
+
110
+
111
def evaluate(
    samples: list[dict],
    *,
    encoder: Encoder,
    metrics: Iterable[str] = ("context_recall", "answer_relevance", "faithfulness"),
) -> dict:
    """Run multiple metrics over a list of QA samples and return per-metric means.

    Each sample is a dict with at least:
        question: str
        answer: str
        contexts: list[str]
        ground_truth: str | None (required for context_recall)

    Missing keys default to ``""`` / ``[]`` / ``None``.

    Returns:
        ``{"n": len(samples), "means": {metric: mean},
        "per_sample": {metric: [score per sample]}}``. A metric with no
        samples has mean 0.0.

    Raises:
        KeyError: ``metrics`` names an unknown metric. This is checked before
            any sample is scored.
    """
    registry = {
        "context_recall": context_recall,
        "answer_relevance": answer_relevance,
        "faithfulness": faithfulness,
    }
    # Materialize once: ``metrics`` may be a one-shot iterator, and it is
    # iterated again for every sample. dict.fromkeys also drops duplicate
    # names (order preserved) so no metric is scored twice per sample.
    names = list(dict.fromkeys(metrics))
    unknown = [m for m in names if m not in registry]
    if unknown:
        raise KeyError(f"unknown metric(s) {unknown}; available: {sorted(registry)}")

    by_metric: dict[str, list[float]] = {m: [] for m in names}
    for s in samples:
        q = s.get("question", "")
        a = s.get("answer", "")
        ctxs = s.get("contexts", [])
        gt = s.get("ground_truth")
        for m in names:
            score = registry[m](q, a, ctxs, gt, encoder=encoder)
            by_metric[m].append(float(score))

    means = {m: (sum(v) / len(v) if v else 0.0) for m, v in by_metric.items()}
    return {
        "n": len(samples),
        "means": means,
        "per_sample": by_metric,
    }
148
+
149
+
150
+ # ------------------------------------------------------------------ utilities
151
+
152
+
153
def _ngrams(text: str, n: int) -> set[tuple[str, ...]]:
    """Return the distinct word ``n``-grams of ``text``, tokenized with ``_WORD``.

    Text with fewer than ``n`` tokens yields an empty set.
    """
    tokens = _WORD.findall(text)
    last_start = len(tokens) - n
    if last_start < 0:
        return set()
    return {tuple(tokens[i : i + n]) for i in range(last_start + 1)}
158
+
159
+
160
def _cos_pair(encoder: Encoder, a: str, b: str) -> float:
    """Cosine similarity between the embeddings of ``a`` and ``b``.

    Unit-length vectors are requested explicitly, so the plain dot product is
    a cosine whatever the encoder's default for ``normalize`` is.
    """
    vecs = encoder.encode([a, b], normalize=True)
    return float(np.dot(vecs[0], vecs[1]))
163
+
164
+
165
# Lower-cased phrases that mark a refusal / non-answer. They are matched as
# whole words (see _DK_RE), so "no answer" does not fire inside "piano answer".
# Generic preambles such as "based on the provided context" are deliberately
# absent: they usually introduce a real answer, not a refusal.
_DK_PATTERNS = (
    "i don't know",
    "i do not know",
    "not enough information",
    "cannot determine",
    "cannot answer",
    "no answer",
    "n/a",
)
_DK_RE = re.compile(r"\b(?:" + "|".join(re.escape(p) for p in _DK_PATTERNS) + r")\b")


def _looks_like_dont_know(text: str) -> bool:
    """Return True if ``text`` contains a refusal phrase from ``_DK_PATTERNS``.

    Matching is case-insensitive and treats the typographic apostrophe (’)
    like ASCII ``'``, since LLMs often emit "I don’t know".
    """
    t = text.lower().replace("\u2019", "'")
    return _DK_RE.search(t) is not None
@@ -0,0 +1,57 @@
1
+ """Pretty-printable eval report."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import statistics
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+
10
+
11
+ @dataclass
12
+ class EvalReport:
13
+ n: int
14
+ means: dict[str, float]
15
+ per_sample: dict[str, list[float]]
16
+ latencies_ms: list[float] = field(default_factory=list)
17
+ extras: dict = field(default_factory=dict)
18
+
19
+ @property
20
+ def p50_ms(self) -> float:
21
+ return statistics.median(self.latencies_ms) if self.latencies_ms else 0.0
22
+
23
+ @property
24
+ def p95_ms(self) -> float:
25
+ if not self.latencies_ms:
26
+ return 0.0
27
+ s = sorted(self.latencies_ms)
28
+ return s[int(0.95 * (len(s) - 1))]
29
+
30
+ def as_table(self) -> str:
31
+ lines = [
32
+ "+----------------+--------+",
33
+ "| metric | mean |",
34
+ "+----------------+--------+",
35
+ ]
36
+ for name, val in self.means.items():
37
+ lines.append(f"| {name:<14s} | {val:.3f} |")
38
+ lines.append("+----------------+--------+")
39
+ if self.latencies_ms:
40
+ lines.append(f"n={self.n} · p50={self.p50_ms:.0f}ms · p95={self.p95_ms:.0f}ms")
41
+ else:
42
+ lines.append(f"n={self.n}")
43
+ return "\n".join(lines)
44
+
45
+ def save(self, path: str | Path) -> None:
46
+ Path(path).write_text(
47
+ json.dumps(
48
+ {
49
+ "n": self.n,
50
+ "means": self.means,
51
+ "per_sample": self.per_sample,
52
+ "latencies_ms": self.latencies_ms,
53
+ "extras": self.extras,
54
+ },
55
+ indent=2,
56
+ )
57
+ )
@@ -0,0 +1,41 @@
1
+ """Document ingestion — load files into a stream of (text, metadata) pairs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterator
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from ragforge.ingest.markdown import load_markdown
10
+ from ragforge.ingest.pdf import load_pdf
11
+
12
+ Loader = dict[str, Any]
13
+
14
+
15
+ def iter_documents(paths: list[str | Path]) -> Iterator[tuple[str, dict]]:
16
+ """Dispatch each path to the right loader based on extension.
17
+
18
+ Directories are walked recursively. Unsupported file types are skipped
19
+ silently — extending the registry is a one-line change.
20
+ """
21
+ for raw in paths:
22
+ p = Path(raw)
23
+ if p.is_dir():
24
+ for child in sorted(p.rglob("*")):
25
+ if child.is_file():
26
+ yield from _load_one(child)
27
+ elif p.is_file():
28
+ yield from _load_one(p)
29
+
30
+
31
def _load_one(path: Path) -> Iterator[tuple[str, dict]]:
    """Yield ``(text, metadata)`` pairs for a single file, chosen by suffix.

    The suffix is compared case-insensitively. PDF and Markdown go to their
    dedicated loaders. ``.txt`` / ``.rst`` are read whole (undecodable bytes
    dropped) with only ``{"path": ...}`` as metadata. Any other suffix yields
    nothing.
    """
    ext = path.suffix.lower()
    if ext in {".txt", ".rst"}:
        yield path.read_text(encoding="utf-8", errors="ignore"), {"path": str(path)}
        return
    if ext in {".md", ".markdown"}:
        yield from load_markdown(path)
        return
    if ext == ".pdf":
        yield from load_pdf(path)
39
+
40
+
41
+ __all__ = ["iter_documents", "load_markdown", "load_pdf"]
@@ -0,0 +1,102 @@
1
+ """Recursive character splitter with overlap — a faithful, dependency-free port
2
+ of the recipe that LangChain popularized.
3
+
4
+ Why character-based rather than token-based?
5
+ Embedding models truncate at a token budget, but tokenizers vary. A
6
+ well-tuned character size (≈4 chars/token for English) is portable across
7
+ encoders and avoids re-tokenizing during chunking. For mixed-language
8
+ corpora, plug in a token-aware splitter; the interface is one function.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from collections.abc import Iterable
14
+ from dataclasses import dataclass
15
+
16
+ _DEFAULT_SEPARATORS = ("\n\n", "\n", ". ", " ", "")
17
+
18
+
19
@dataclass
class Chunk:
    """One window of source text together with its metadata.

    ``split`` fills ``metadata`` with the caller's document metadata plus the
    chunk's index (``chunk``) and the total count (``n_chunks``).
    """

    text: str  # the chunk's character window
    metadata: dict  # document metadata extended with chunk position info
23
+
24
+
25
def split(
    text: str,
    metadata: dict | None = None,
    *,
    size: int = 1024,
    overlap: int = 128,
    separators: Iterable[str] = _DEFAULT_SEPARATORS,
) -> list[Chunk]:
    """Recursively split ``text`` into windows of ~``size`` chars.

    Each chunk overlaps the previous one by ``overlap`` chars, so a query that
    falls on a boundary still hits a chunk with full context. The recursion
    tries successively finer separators (paragraph → line → sentence → word →
    char) so that cuts land on a semantically reasonable boundary when
    possible. An empty ``separators`` falls back to a hard character split.

    Empty or whitespace-only pieces (from empty input, or runs of separators)
    are dropped: they carry nothing to retrieve but would still cost an index
    entry.

    Returns:
        One ``Chunk`` per window. Each metadata dict is a copy of
        ``metadata`` plus ``chunk`` (0-based index) and ``n_chunks``.

    Raises:
        ValueError: ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    metadata = metadata or {}
    seps = list(separators) or [""]
    raw_chunks = [c for c in _recursive_split(text, seps, size) if c.strip()]
    merged = _merge_with_overlap(raw_chunks, size=size, overlap=overlap)
    return [
        Chunk(text=c, metadata={**metadata, "chunk": i, "n_chunks": len(merged)})
        for i, c in enumerate(merged)
    ]
48
+
49
+
50
def split_documents(
    docs: Iterable[tuple[str, dict]],
    *,
    size: int = 1024,
    overlap: int = 128,
) -> list[Chunk]:
    """Chunk every ``(text, metadata)`` pair in ``docs`` with ``split``.

    Returns one flat list in document order. The ``chunk`` / ``n_chunks``
    numbering restarts for each document, because each ``split`` call
    numbers its own chunks.
    """
    return [
        piece
        for text, meta in docs
        for piece in split(text, meta, size=size, overlap=overlap)
    ]
60
+
61
+
62
def _recursive_split(text: str, separators: list[str], size: int) -> list[str]:
    """Cut ``text`` into pieces of at most ``size`` chars.

    Text that already fits is returned as-is. Otherwise it is split on
    ``separators[0]``, and consecutive fragments are greedily re-joined with
    that separator while they fit. A fragment that is too long on its own is
    recursed into with the remaining, finer separators. The empty separator
    means a hard cut every ``size`` chars. The separator between two emitted
    pieces is not kept.
    """
    if len(text) <= size:
        return [text]
    head, finer = separators[0], (separators[1:] or [""])
    if not head:
        return [text[start : start + size] for start in range(0, len(text), size)]

    pieces: list[str] = []
    current = ""
    for fragment in text.split(head):
        joined = current + head + fragment if current else fragment
        if len(joined) > size:
            if current:
                pieces.append(current)
            oversized = len(fragment) > size
            if oversized:
                pieces.extend(_recursive_split(fragment, finer, size))
            current = "" if oversized else fragment
        else:
            current = joined
    if current:
        pieces.append(current)
    return pieces
87
+
88
+
89
def _merge_with_overlap(chunks: list[str], *, size: int, overlap: int) -> list[str]:
    """Prefix each chunk with the last ``overlap`` chars of the chunk before it.

    Successive chunks therefore share ``overlap`` characters, so a span cut at
    a boundary stays retrievable from one chunk. A chunk that already starts
    with that tail is left alone. ``overlap <= 0`` returns ``chunks`` unchanged.

    Chunks are not concatenated with their neighbours. The separator that
    ``_recursive_split`` consumed between them is unknown here, so
    concatenation would glue text together ("hello." + "World" ->
    "hello.World"). ``size`` is accepted only for call compatibility.
    """
    if overlap <= 0 or not chunks:
        return chunks
    out = [chunks[0]]
    for c in chunks[1:]:
        tail = out[-1][-overlap:]
        out.append(c if c.startswith(tail) else tail + c)
    return out