ragforge-sdk 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +10 -0
- ragforge/__init__.py +21 -0
- ragforge/cache/__init__.py +4 -0
- ragforge/cache/semantic_cache.py +191 -0
- ragforge/config/__init__.py +8 -0
- ragforge/config/llm_config.py +38 -0
- ragforge/config/model_config.py +45 -0
- ragforge/config/pipeline_config.py +78 -0
- ragforge/dedup/__init__.py +5 -0
- ragforge/dedup/deduplicator.py +122 -0
- ragforge/evaluation/__init__.py +7 -0
- ragforge/evaluation/evaluator.py +206 -0
- ragforge/evaluation/judge.py +136 -0
- ragforge/fusion/__init__.py +7 -0
- ragforge/fusion/adaptive.py +181 -0
- ragforge/fusion/blend.py +135 -0
- ragforge/fusion/rrf.py +139 -0
- ragforge/llm/__init__.py +5 -0
- ragforge/llm/llm_client.py +129 -0
- ragforge/models/__init__.py +7 -0
- ragforge/models/embedding.py +130 -0
- ragforge/models/reranker.py +136 -0
- ragforge/pipeline.py +424 -0
- ragforge/profiler.py +74 -0
- ragforge/protocols.py +158 -0
- ragforge/query/__init__.py +6 -0
- ragforge/query/planner.py +156 -0
- ragforge/retrieval/__init__.py +6 -0
- ragforge/retrieval/bm25.py +93 -0
- ragforge/retrieval/hybrid.py +102 -0
- ragforge/retrieval/vector.py +102 -0
- ragforge/tracing/__init__.py +5 -0
- ragforge/tracing/trace.py +81 -0
- ragforge/type_utils.py +217 -0
- ragforge/utils.py +118 -0
- ragforge_sdk-1.0.0.dist-info/METADATA +826 -0
- ragforge_sdk-1.0.0.dist-info/RECORD +40 -0
- ragforge_sdk-1.0.0.dist-info/WHEEL +5 -0
- ragforge_sdk-1.0.0.dist-info/licenses/LICENSE +201 -0
- ragforge_sdk-1.0.0.dist-info/top_level.txt +2 -0
__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from .ragforge.pipeline import SearchPipeline
|
|
2
|
+
from .ragforge.models import FastembedEmbedder, FastembedReranker
|
|
3
|
+
from .ragforge.retrieval import BM25Retriever, VectorRetriever, HybridRetriever
|
|
4
|
+
from .ragforge.fusion import RRFFusion, PositionAwareBlend, AdaptiveFusion
|
|
5
|
+
from .ragforge.llm import LLMClient
|
|
6
|
+
from .ragforge.query import QueryPlanner
|
|
7
|
+
from .ragforge.evaluation import LLMJudge, Evaluator
|
|
8
|
+
from .ragforge.cache import SemanticCache
|
|
9
|
+
from .ragforge.dedup import Deduplicator
|
|
10
|
+
from .ragforge.config import ModelConfig, PipelineConfig, LLMConfig, QueryTransformStrategy
|
ragforge/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .cache.semantic_cache import SemanticCache
|
|
2
|
+
from .config.llm_config import LLMConfig
|
|
3
|
+
from .config.model_config import ModelConfig
|
|
4
|
+
from .config.pipeline_config import PipelineConfig, QueryTransformStrategy
|
|
5
|
+
from .dedup.deduplicator import Deduplicator
|
|
6
|
+
from .evaluation.evaluator import Evaluator
|
|
7
|
+
from .evaluation.judge import LLMJudge
|
|
8
|
+
from .fusion.adaptive import AdaptiveFusion
|
|
9
|
+
from .fusion.blend import PositionAwareBlend
|
|
10
|
+
from .fusion.rrf import RRFFusion
|
|
11
|
+
from .llm.llm_client import LLMClient
|
|
12
|
+
from .models.embedding import FastembedEmbedder
|
|
13
|
+
from .models.reranker import FastembedReranker
|
|
14
|
+
from .query.planner import QueryPlanner
|
|
15
|
+
from .retrieval.bm25 import BM25Retriever
|
|
16
|
+
from .retrieval.hybrid import HybridRetriever
|
|
17
|
+
from .retrieval.vector import VectorRetriever
|
|
18
|
+
from .tracing.trace import Tracer
|
|
19
|
+
|
|
20
|
+
__version__ = "1.0.0"
|
|
21
|
+
__author__ = "jiangnanboy"
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Semantic Cache
|
|
3
|
+
==============
|
|
4
|
+
Cache search results keyed by query embedding similarity.
|
|
5
|
+
|
|
6
|
+
Features:
|
|
7
|
+
- Embeds queries and finds cached entries by cosine similarity
|
|
8
|
+
- Configurable similarity threshold
|
|
9
|
+
- TTL-based expiration
|
|
10
|
+
- Thread-safe (basic)
|
|
11
|
+
- Uses any ``Embedder`` implementation
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import time
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
20
|
+
|
|
21
|
+
from ..protocols import Embedder
|
|
22
|
+
|
|
23
|
+
class _CacheEntry:
    """Internal cache record: a query, its embedding and its cached results."""

    __slots__ = ("query", "query_embedding", "results", "timestamp")

    def __init__(
        self,
        query: str,
        query_embedding: np.ndarray,
        results: list[tuple[str, float]],
    ) -> None:
        self.query = query
        self.query_embedding = query_embedding
        self.results = results
        # Monotonic clock: TTL arithmetic must not be disturbed by wall-clock
        # adjustments (NTP steps, manual changes).
        self.timestamp = time.monotonic()


class SemanticCache:
    """Semantic cache for RAG search results.

    Results are stored together with the embedding of the query that
    produced them. A lookup embeds the incoming query, compares it against
    every live entry in one cosine-similarity call and returns the results
    of the *most similar* entry when that similarity reaches the threshold.

    All access to the internal state (entries and hit/miss counters) is
    serialized with a lock. The embedder and ``search_fn`` are called
    outside the lock.

    Args:
        embedder:
            Object exposing ``embed(query)``; the return value is converted
            with ``np.asarray``.
        similarity_threshold:
            Minimum cosine similarity (inclusive) for a cache hit
            (default 0.95).
        ttl_seconds:
            Time-to-live for cache entries in seconds (default 3600).
            A value ``<= 0`` disables expiration.
        max_size:
            Maximum number of entries to keep (default 1000). When full,
            the oldest entries are evicted first. A value ``<= 0`` disables
            storing altogether.

    Example::

        from ragforge import SemanticCache, FastembedEmbedder, ModelConfig

        cache = SemanticCache(
            embedder=FastembedEmbedder(ModelConfig()),
            similarity_threshold=0.95,
        )

        # First call: runs the search function and stores the result
        result = cache.get_or_search(
            "apple phone price",
            search_fn=lambda q, d: pipeline.search(q, d),
            documents=docs,
        )

        # Similar query: served from the cache
        result = cache.get_or_search(
            "how much is an apple phone",
            search_fn=lambda q, d: pipeline.search(q, d),
            documents=docs,
        )
    """

    def __init__(
        self,
        embedder: Embedder,
        similarity_threshold: float = 0.95,
        ttl_seconds: float = 3600,
        max_size: int = 1000,
    ) -> None:
        # Local import: the module header does not import ``threading``.
        import threading

        self._embedder = embedder
        self._threshold = similarity_threshold
        self._ttl = ttl_seconds
        self._max_size = max_size
        self._entries: list[_CacheEntry] = []
        self._hits = 0
        self._misses = 0
        self._lock = threading.Lock()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _embed(self, query: str) -> np.ndarray:
        """Embed *query* and return the vector as a NumPy array."""
        return np.asarray(self._embedder.embed(query))

    def _purge_expired(self) -> None:
        """Remove entries older than the TTL. The caller must hold the lock."""
        if self._ttl <= 0:
            return
        cutoff = time.monotonic() - self._ttl
        self._entries[:] = [e for e in self._entries if e.timestamp >= cutoff]

    def _lookup(self, query_emb: np.ndarray) -> list[tuple[str, float]] | None:
        """Return the results of the best matching live entry, or ``None``.

        Updates the hit/miss counters.
        """
        with self._lock:
            self._purge_expired()
            if self._entries:
                # One similarity call against all entries instead of one
                # call per entry.
                matrix = np.vstack(
                    [e.query_embedding.reshape(1, -1) for e in self._entries]
                )
                sims = cosine_similarity(query_emb.reshape(1, -1), matrix)[0]
                best = int(np.argmax(sims))
                if sims[best] >= self._threshold:
                    self._hits += 1
                    return self._entries[best].results
            self._misses += 1
            return None

    def _store(
        self,
        query: str,
        query_emb: np.ndarray,
        results: list[tuple[str, float]],
    ) -> None:
        """Append a new entry, evicting the oldest ones when at capacity."""
        if self._max_size <= 0:
            # Nothing can be kept; avoids popping from an empty list.
            return
        with self._lock:
            self._purge_expired()
            overflow = len(self._entries) - self._max_size + 1
            if overflow > 0:
                del self._entries[:overflow]
            self._entries.append(_CacheEntry(query, query_emb, results))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(
        self, query: str
    ) -> list[tuple[str, float]] | None:
        """Look up a cached result for the query.

        Args:
            query: Search query.

        Returns:
            Cached results of the most similar non-expired entry when its
            similarity is ``>=`` the threshold, ``None`` otherwise.
        """
        return self._lookup(self._embed(query))

    def put(
        self,
        query: str,
        results: list[tuple[str, float]],
    ) -> None:
        """Store a search result in the cache.

        Args:
            query: The query that produced these results.
            results: Pipeline output to cache.
        """
        self._store(query, self._embed(query), results)

    def get_or_search(
        self,
        query: str,
        search_fn: object,
        documents: list[str],
    ) -> list[tuple[str, float]]:
        """Get cached result or execute search function.

        The query is embedded once; the same vector is used for the lookup
        and, on a miss, for storing the fresh results.

        Args:
            query: Search query.
            search_fn:
                Callable ``search_fn(query, documents) -> results``.
            documents: Candidate documents.

        Returns:
            Either cached or fresh search results.
        """
        query_emb = self._embed(query)
        cached = self._lookup(query_emb)
        if cached is not None:
            return cached

        results = search_fn(query, documents)
        self._store(query, query_emb, results)
        return results

    def clear(self) -> int:
        """Clear all cache entries and reset the hit/miss counters.

        Returns:
            Number of entries cleared.
        """
        with self._lock:
            count = len(self._entries)
            self._entries.clear()
            self._hits = 0
            self._misses = 0
            return count

    @property
    def stats(self) -> dict[str, int | float]:
        """Cache statistics: entry count, hits, misses and hit rate."""
        with self._lock:
            total = self._hits + self._misses
            hit_rate = self._hits / total if total > 0 else 0.0
            return {
                "entries": len(self._entries),
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate": round(hit_rate, 4),
            }
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLM Configuration
|
|
3
|
+
=================
|
|
4
|
+
Configuration for LLM backends (DeepSeek, etc.).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class LLMConfig:
    """Configuration for LLM API access.

    Attributes:
        api_key:
            API key for the LLM service. Defaults to the empty string.
            The key is masked in ``repr()`` so it does not end up in logs
            or tracebacks.
        base_url:
            Base URL of the LLM API endpoint.
            Defaults to ``https://api.deepseek.com``.
        model:
            Model name to use.
            Defaults to ``deepseek-v4-flash``.
        temperature:
            Sampling temperature for generation (default 0.1).
        max_tokens:
            Maximum number of tokens in the response (default 1024).
        timeout:
            Request timeout in seconds (default 60).
    """

    api_key: str = ""
    base_url: str = "https://api.deepseek.com"
    model: str = "deepseek-v4-flash"
    temperature: float = 0.1
    max_tokens: int = 1024
    timeout: int = 60

    def __repr__(self) -> str:
        # A user-defined __repr__ is kept by @dataclass; it replaces the
        # generated one, which would print the API key in clear text.
        masked_key = "***" if self.api_key else ""
        return (
            f"{type(self).__name__}("
            f"api_key={masked_key!r}, "
            f"base_url={self.base_url!r}, "
            f"model={self.model!r}, "
            f"temperature={self.temperature!r}, "
            f"max_tokens={self.max_tokens!r}, "
            f"timeout={self.timeout!r})"
        )
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Configuration
|
|
3
|
+
===================
|
|
4
|
+
Dataclass holding model paths and loading parameters.
|
|
5
|
+
|
|
6
|
+
Separation of Concerns:
|
|
7
|
+
- *ModelConfig* → model paths, file names, dimensions
|
|
8
|
+
- *PipelineConfig* → algorithm hyper-parameters
|
|
9
|
+
|
|
10
|
+
This makes it easy to swap models without touching pipeline logic.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class ModelConfig:
    """Paths and loading parameters for the embedding and reranker models.

    Every field has a default, so ``ModelConfig()`` is a valid configuration.

    Attributes:
        embedding_model_path:
            Local directory holding the embedding ONNX model files, or
            ``None`` (default) when no local directory is configured.
            NOTE(review): the original documentation states the model is
            auto-downloaded on first use in that case; the loader is not
            part of this class.
        embedding_onnx_file:
            File name of the embedding ONNX weights (default ``model.onnx``).
        embedding_dim:
            Dimensionality of the embedding vectors (default 384).
        rerank_model_path:
            Local directory holding the reranker ONNX model files, or
            ``None`` (default) when no local directory is configured.
        rerank_onnx_file:
            File name of the reranker ONNX weights
            (default ``model_int8.onnx``).
    """

    # Embedding model settings
    embedding_model_path: str | None = None
    embedding_onnx_file: str = "model.onnx"
    embedding_dim: int = 384

    # Reranker model settings
    rerank_model_path: str | None = None
    rerank_onnx_file: str = "model_int8.onnx"
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pipeline Configuration
|
|
3
|
+
======================
|
|
4
|
+
Dataclass holding algorithm hyper-parameters for the search pipeline.
|
|
5
|
+
|
|
6
|
+
Separation of Concerns:
|
|
7
|
+
- *ModelConfig* → model paths, file names, dimensions
|
|
8
|
+
- *PipelineConfig* → algorithm hyper-parameters
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from enum import Enum
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class QueryTransformStrategy(str, Enum):
    """How the output of a query transform is meant to be used by the pipeline.

    Members are ``str`` instances, so they compare equal to their plain
    string values (``"replace"``, ``"retrieve_and_fuse"``).

    Attributes:
        REPLACE:
            The transformed query takes the place of the original one.
            Documented as the default, original behaviour.
        RETRIEVE_AND_FUSE:
            Multi-query retrieval for higher recall: retrieve with the
            original query *and* the transformed one(s), then fuse every
            result set.

            - A transform that yields a single string leads to two
              retrievals (original + rewritten) whose results are fused.
            - A transform that yields a list of sub-queries leads to N+1
              retrievals (original + each sub-query), fused via RRF.
    """

    REPLACE = "replace"
    RETRIEVE_AND_FUSE = "retrieve_and_fuse"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _default_blend_weights() -> dict[str, dict[str, float]]:
    """Build a fresh default rank-bucket weight table for each config instance."""
    return {
        "top1-3": {"retrieval": 0.75, "reranker": 0.25},
        "top4-10": {"retrieval": 0.60, "reranker": 0.40},
        "top11+": {"retrieval": 0.40, "reranker": 0.60},
    }


@dataclass
class PipelineConfig:
    """Algorithm hyper-parameters of the retrieval and fusion pipeline.

    Attributes:
        rrf_k:
            Reciprocal Rank Fusion constant (default 60). Larger values
            flatten the score differences between ranks.
        top_k_recall:
            Number of candidates kept after fusion, before reranking
            (default 30).
        query_weight:
            Multiplier applied to the RRF base scores (default 2.0).
        bonus_rank1:
            Score bonus for the top-1 document after RRF fusion
            (default 0.05).
        bonus_rank2_3:
            Score bonus for the documents ranked 2-3 after RRF fusion
            (default 0.02).
        query_transform_strategy:
            How query transformation results are applied.
            ``REPLACE`` (default) uses the rewritten query directly;
            ``RETRIEVE_AND_FUSE`` retrieves with both the original and the
            rewritten queries and fuses the results.
        blend_weights:
            Position-aware blending weights. Maps a rank bucket
            (``"top1-3"``, ``"top4-10"``, ``"top11+"``) to a dict with the
            keys ``"retrieval"`` and ``"reranker"``. Each instance gets its
            own copy of the defaults.
    """

    rrf_k: int = 60
    top_k_recall: int = 30
    query_weight: float = 2.0
    bonus_rank1: float = 0.05
    bonus_rank2_3: float = 0.02
    query_transform_strategy: QueryTransformStrategy = QueryTransformStrategy.REPLACE
    blend_weights: dict[str, dict[str, float]] = field(
        default_factory=_default_blend_weights
    )
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Document Deduplicator
|
|
3
|
+
=====================
|
|
4
|
+
Identifies and removes near-duplicate documents using embedding similarity.
|
|
5
|
+
|
|
6
|
+
Features:
|
|
7
|
+
- Embedding-based near-duplicate detection
|
|
8
|
+
- Configurable similarity threshold
|
|
9
|
+
- Returns deduplicated document list
|
|
10
|
+
- Optionally returns duplicate clusters for inspection
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from sklearn.metrics.pairwise import cosine_similarity
|
|
17
|
+
|
|
18
|
+
from ..protocols import Embedder
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Deduplicator:
    """Document deduplicator using embedding similarity.

    Two documents are treated as duplicates when the cosine similarity of
    their embeddings is ``>=`` the threshold. Grouping is greedy: documents
    are scanned in order, and each document that has not yet been claimed
    becomes the anchor for all later unclaimed documents similar to it.

    Args:
        embedder:
            Object exposing ``embed_batch(documents)`` that returns one
            embedding per document.
        threshold:
            Cosine similarity at or above which documents are considered
            duplicates (default 0.95).

    Example::

        from ragforge import Deduplicator, FastembedEmbedder, ModelConfig

        dedup = Deduplicator(
            embedder=FastembedEmbedder(ModelConfig()),
            threshold=0.95,
        )

        documents = ["iPhone price", "iPhone  price", "Huawei phone quote"]
        unique = dedup.deduplicate(documents)
        print(f"{len(documents)} -> {len(unique)} unique")
    """

    def __init__(
        self,
        embedder: Embedder,
        threshold: float = 0.95,
    ) -> None:
        self._embedder = embedder
        self._threshold = threshold

    def _duplicate_mask(self, documents: list[str]) -> np.ndarray:
        """Return a boolean matrix: ``mask[i, j]`` is True when i and j are duplicates."""
        embeddings = self._embedder.embed_batch(documents)
        sim_matrix = cosine_similarity(np.array(embeddings))
        return sim_matrix >= self._threshold

    def deduplicate(
        self,
        documents: list[str],
    ) -> list[str]:
        """Remove near-duplicate documents.

        Keeps the first occurrence of each duplicate cluster.

        Args:
            documents: List of document strings.

        Returns:
            Deduplicated list preserving original order.
        """
        if not documents:
            return []

        is_dup = self._duplicate_mask(documents)
        keep = np.ones(len(documents), dtype=bool)
        for i, row in enumerate(is_dup):
            if keep[i]:
                # Drop every later document that duplicates the kept one.
                keep[i + 1:] &= ~row[i + 1:]

        return [doc for doc, kept in zip(documents, keep) if kept]

    def find_clusters(
        self,
        documents: list[str],
    ) -> list[list[int]]:
        """Find groups of near-duplicate document indices.

        Args:
            documents: List of document strings.

        Returns:
            List of clusters, where each cluster is an ascending list of
            document indices starting with the cluster's anchor.
            Singletons (non-duplicates) are not included in the output.
        """
        if not documents:
            return []

        is_dup = self._duplicate_mask(documents)
        claimed = np.zeros(len(documents), dtype=bool)
        clusters: list[list[int]] = []

        for i, row in enumerate(is_dup):
            if claimed[i]:
                continue
            claimed[i] = True
            # Later, still unclaimed documents similar to anchor i.
            members = np.flatnonzero(row[i + 1:] & ~claimed[i + 1:]) + (i + 1)
            if members.size:
                claimed[members] = True
                # Indices are produced in ascending order; no sort needed.
                clusters.append([i, *members.tolist()])

        return clusters
|