lm-deluge 0.0.67__py3-none-any.whl → 0.0.88__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lm-deluge has been flagged as possibly problematic.
- lm_deluge/__init__.py +25 -2
- lm_deluge/api_requests/anthropic.py +92 -17
- lm_deluge/api_requests/base.py +47 -11
- lm_deluge/api_requests/bedrock.py +7 -4
- lm_deluge/api_requests/chat_reasoning.py +4 -0
- lm_deluge/api_requests/gemini.py +138 -18
- lm_deluge/api_requests/openai.py +114 -21
- lm_deluge/client.py +282 -49
- lm_deluge/config.py +15 -3
- lm_deluge/mock_openai.py +643 -0
- lm_deluge/models/__init__.py +12 -1
- lm_deluge/models/anthropic.py +17 -2
- lm_deluge/models/arcee.py +16 -0
- lm_deluge/models/deepseek.py +36 -4
- lm_deluge/models/google.py +29 -0
- lm_deluge/models/grok.py +24 -0
- lm_deluge/models/kimi.py +36 -0
- lm_deluge/models/minimax.py +10 -0
- lm_deluge/models/openai.py +100 -0
- lm_deluge/models/openrouter.py +86 -8
- lm_deluge/models/together.py +11 -0
- lm_deluge/models/zai.py +1 -0
- lm_deluge/pipelines/gepa/__init__.py +95 -0
- lm_deluge/pipelines/gepa/core.py +354 -0
- lm_deluge/pipelines/gepa/docs/samples.py +696 -0
- lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py +140 -0
- lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py +261 -0
- lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py +300 -0
- lm_deluge/pipelines/gepa/examples/04_batch_classification.py +271 -0
- lm_deluge/pipelines/gepa/examples/simple_qa.py +129 -0
- lm_deluge/pipelines/gepa/optimizer.py +435 -0
- lm_deluge/pipelines/gepa/proposer.py +235 -0
- lm_deluge/pipelines/gepa/util.py +165 -0
- lm_deluge/{llm_tools → pipelines}/score.py +2 -2
- lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
- lm_deluge/prompt.py +224 -40
- lm_deluge/request_context.py +7 -2
- lm_deluge/tool/__init__.py +1118 -0
- lm_deluge/tool/builtin/anthropic/__init__.py +300 -0
- lm_deluge/tool/builtin/gemini.py +59 -0
- lm_deluge/tool/builtin/openai.py +74 -0
- lm_deluge/tool/cua/__init__.py +173 -0
- lm_deluge/tool/cua/actions.py +148 -0
- lm_deluge/tool/cua/base.py +27 -0
- lm_deluge/tool/cua/batch.py +215 -0
- lm_deluge/tool/cua/converters.py +466 -0
- lm_deluge/tool/cua/kernel.py +702 -0
- lm_deluge/tool/cua/trycua.py +989 -0
- lm_deluge/tool/prefab/__init__.py +45 -0
- lm_deluge/tool/prefab/batch_tool.py +156 -0
- lm_deluge/tool/prefab/docs.py +1119 -0
- lm_deluge/tool/prefab/email.py +294 -0
- lm_deluge/tool/prefab/filesystem.py +1711 -0
- lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
- lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
- lm_deluge/tool/prefab/memory.py +458 -0
- lm_deluge/tool/prefab/otc/__init__.py +165 -0
- lm_deluge/tool/prefab/otc/executor.py +281 -0
- lm_deluge/tool/prefab/otc/parse.py +188 -0
- lm_deluge/tool/prefab/random.py +212 -0
- lm_deluge/tool/prefab/rlm/__init__.py +296 -0
- lm_deluge/tool/prefab/rlm/executor.py +349 -0
- lm_deluge/tool/prefab/rlm/parse.py +144 -0
- lm_deluge/tool/prefab/sandbox.py +1621 -0
- lm_deluge/tool/prefab/sheets.py +385 -0
- lm_deluge/tool/prefab/subagents.py +233 -0
- lm_deluge/tool/prefab/todos.py +342 -0
- lm_deluge/tool/prefab/tool_search.py +169 -0
- lm_deluge/tool/prefab/web_search.py +199 -0
- lm_deluge/tracker.py +16 -13
- lm_deluge/util/schema.py +412 -0
- lm_deluge/warnings.py +8 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.88.dist-info}/METADATA +22 -9
- lm_deluge-0.0.88.dist-info/RECORD +117 -0
- lm_deluge/built_in_tools/anthropic/__init__.py +0 -128
- lm_deluge/built_in_tools/openai.py +0 -28
- lm_deluge/presets/cerebras.py +0 -17
- lm_deluge/presets/meta.py +0 -13
- lm_deluge/tool.py +0 -849
- lm_deluge-0.0.67.dist-info/RECORD +0 -72
- lm_deluge/{llm_tools → pipelines}/__init__.py +1 -1
- /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/anthropic/bash.py +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/anthropic/computer_use.py +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/anthropic/editor.py +0 -0
- /lm_deluge/{built_in_tools → tool/builtin}/base.py +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.88.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.88.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.88.dist-info}/top_level.txt +0 -0
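The headline change in this release is structural: llm_tools becomes pipelines, built_in_tools moves under tool/builtin, and the single-module tool.py is replaced by a tool/ package with builtin, cua, and prefab subpackages. A rough sketch of the import-path migration implied by the renames above (editor's illustration only; exact public names should be verified against 0.0.88):

# Import-path migration sketch (illustrative, derived from the renames above)
# 0.0.67:  from lm_deluge.llm_tools import classify, extract, locate, ocr
# 0.0.88:  from lm_deluge.pipelines import classify, extract, locate, ocr
#
# 0.0.67:  from lm_deluge.built_in_tools.anthropic import ...
# 0.0.88:  from lm_deluge.tool.builtin.anthropic import ...
#
# tool.py is now the tool/ package; this dotted path is unchanged:
from lm_deluge.tool import Tool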
lm_deluge/tool/prefab/full_text_search/__init__.py
@@ -0,0 +1,285 @@
"""Full text search prefab tool using Tantivy."""

import json
import tempfile
from pathlib import Path
from typing import Annotated, Any

from lm_deluge.tool import Tool

from .tantivy_index import SearchResult, TantivySearch


class FullTextSearchManager:
    """
    Full-text search tools using Tantivy.

    Provides two tools:
    - search: Search the corpus and get document IDs + previews
    - fetch: Get the full contents of specific documents by ID

    Args:
        corpus: List of document dicts to index. Each dict must have an "id" field.
        search_fields: List of field names to search. If None, searches all fields.
        preview_fields: Fields to include in search result previews.
        index_path: Path to store the Tantivy index. If None, uses a temp directory.
        search_tool_name: Name for the search tool (default: "search")
        fetch_tool_name: Name for the fetch tool (default: "fetch")
        max_results: Maximum number of search results to return (default: 10)
        include_fields: Fields to include in the index (searchable). If None, includes all.
        exclude_fields: Fields to exclude from the index (not searchable).

    Example:
        ```python
        corpus = [
            {"id": "1", "title": "Hello World", "content": "This is a test document."},
            {"id": "2", "title": "Another Doc", "content": "More content here."},
        ]
        manager = FullTextSearchManager(
            corpus=corpus,
            search_fields=["title", "content"],
            preview_fields=["title"],
        )
        tools = manager.get_tools()
        ```
    """

    def __init__(
        self,
        corpus: list[dict[str, Any]],
        *,
        search_fields: list[str] | None = None,
        preview_fields: list[str] | None = None,
        index_path: str | Path | None = None,
        search_tool_name: str = "search",
        fetch_tool_name: str = "fetch",
        max_results: int = 10,
        include_fields: list[str] | None = None,
        exclude_fields: list[str] | None = None,
        deduplicate_by: str | None = None,
    ):
        # Initialize _temp_dir early to avoid __del__ issues
        self._temp_dir: str | None = None

        self.corpus = corpus
        self.search_fields = search_fields
        self.preview_fields = preview_fields
        self.search_tool_name = search_tool_name
        self.fetch_tool_name = fetch_tool_name
        self.max_results = max_results
        self._tools: list[Tool] | None = None

        # Validate corpus
        if not corpus:
            raise ValueError("Corpus cannot be empty")

        # Ensure all documents have an id field
        for i, doc in enumerate(corpus):
            if "id" not in doc:
                raise ValueError(f"Document at index {i} is missing 'id' field")

        # Set up index path
        if index_path is None:
            self._temp_dir = tempfile.mkdtemp(prefix="tantivy_")
            self._index_path = Path(self._temp_dir)
        else:
            self._temp_dir = None
            self._index_path = Path(index_path)

        # Determine search fields from corpus if not provided
        if search_fields is None:
            # Use all string fields except 'id'
            sample = corpus[0]
            self.search_fields = [
                k for k, v in sample.items() if k != "id" and isinstance(v, str)
            ]
        else:
            self.search_fields = search_fields

        # Initialize Tantivy index
        self._index = TantivySearch(
            index_path=str(self._index_path),
            include_fields=include_fields,
            exclude_fields=exclude_fields,
        )

        # Build the index
        self._index.build_index(
            records=corpus,
            deduplicate_by=deduplicate_by,
        )

        # Cache documents for efficient fetch
        self._index.cache_documents_for_fetch(corpus, id_column="id")

    def _format_preview(self, result: SearchResult) -> dict[str, Any]:
        """Format a search result for preview."""
        preview: dict[str, Any] = {
            "id": result.id,
            "score": round(result.score, 4),
        }

        # Add preview fields
        if self.preview_fields:
            for field in self.preview_fields:
                if field in result.content:
                    value = result.content[field]
                    # Truncate long values
                    if isinstance(value, str) and len(value) > 200:
                        value = value[:200] + "..."
                    preview[field] = value
        else:
            # Include all fields with truncation
            for field, value in result.content.items():
                if field == "id":
                    continue
                if isinstance(value, str) and len(value) > 200:
                    value = value[:200] + "..."
                preview[field] = value

        return preview

    def _search(
        self,
        query: Annotated[str, "Search query to find relevant documents"],
        limit: Annotated[int, "Maximum number of results to return"] = 10,
    ) -> str:
        """
        Search the corpus for documents matching the query.

        Returns a list of document previews with IDs and scores.
        Use the fetch tool to get full document contents.
        """
        try:
            # Use the search fields
            assert self.search_fields is not None
            results = self._index.search(
                queries=[query],
                fields=self.search_fields,
                limit=min(limit, self.max_results),
                escape=True,
            )

            previews = [self._format_preview(r) for r in results]

            return json.dumps(
                {
                    "status": "success",
                    "query": query,
                    "num_results": len(previews),
                    "results": previews,
                },
                indent=2,
            )

        except Exception as e:
            return json.dumps({"status": "error", "error": str(e)})

    def _fetch(
        self,
        document_ids: Annotated[
            list[str], "List of document IDs to fetch full contents for"
        ],
    ) -> str:
        """
        Fetch the full contents of documents by their IDs.

        Use search first to find relevant document IDs.
        """
        try:
            documents, found_ids, missing_ids = self._index.fetch(document_ids)

            result: dict[str, Any] = {
                "status": "success",
                "found": len(found_ids),
                "missing": len(missing_ids),
                "documents": documents,
            }

            if missing_ids:
                result["missing_ids"] = missing_ids

            return json.dumps(result, indent=2)

        except Exception as e:
            return json.dumps({"status": "error", "error": str(e)})

    def get_tools(self) -> list[Tool]:
        """Return the search and fetch tools."""
        if self._tools is not None:
            return self._tools

        search_tool = Tool.from_function(self._search, name=self.search_tool_name)
        fetch_tool = Tool.from_function(self._fetch, name=self.fetch_tool_name)

        # Update descriptions for clarity
        search_tool = search_tool.model_copy(
            update={
                "description": (
                    "Search the document corpus for relevant results. "
                    "Returns document IDs, relevance scores, and previews. "
                    "Use the fetch tool to get full document contents."
                )
            }
        )

        fetch_tool = fetch_tool.model_copy(
            update={
                "description": (
                    "Fetch the full contents of documents by their IDs. "
                    "Use after searching to get complete document text."
                )
            }
        )

        self._tools = [search_tool, fetch_tool]
        return self._tools

    def search(
        self, query: str, limit: int = 10, fields: list[str] | None = None
    ) -> list[SearchResult]:
        """
        Direct search method for programmatic use.

        Args:
            query: Search query string
            limit: Maximum number of results
            fields: Fields to search (defaults to self.search_fields)

        Returns:
            List of SearchResult objects
        """
        search_fields = fields or self.search_fields
        assert search_fields is not None
        return self._index.search(
            queries=[query],
            fields=search_fields,
            limit=min(limit, self.max_results),
            escape=True,
        )

    def fetch(self, document_ids: list[str]) -> list[dict[str, Any]]:
        """
        Direct fetch method for programmatic use.

        Args:
            document_ids: List of document IDs to fetch

        Returns:
            List of document dicts
        """
        documents, _, _ = self._index.fetch(document_ids)
        return documents

    def __del__(self):
        """Clean up temp directory if used."""
        if self._temp_dir is not None:
            import shutil

            try:
                shutil.rmtree(self._temp_dir, ignore_errors=True)
            except Exception:
                pass


__all__ = ["FullTextSearchManager", "SearchResult"]
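Taken together, the manager exposes the same index two ways: as LLM-callable tools via get_tools(), and as plain Python methods for programmatic use. A minimal end-to-end sketch, using only what the file above defines (the corpus values echo the docstring example):

# Editor's usage sketch (not part of the diff)
corpus = [
    {"id": "1", "title": "Hello World", "content": "This is a test document."},
    {"id": "2", "title": "Another Doc", "content": "More content here."},
]
manager = FullTextSearchManager(corpus=corpus, preview_fields=["title"])

tools = manager.get_tools()                 # [search, fetch] Tool objects for an agent loop
hits = manager.search("hello", limit=5)     # direct search -> list[SearchResult]
docs = manager.fetch([h.id for h in hits])  # full documents by ID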
lm_deluge/tool/prefab/full_text_search/tantivy_index.py
@@ -0,0 +1,396 @@
import os
import re
import shutil
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, Set, TypedDict

from lenlp import normalizer
from tantivy import Document, Index, SchemaBuilder
from tqdm.auto import tqdm

# Pattern to match field-like constructs (X:Y)
FIELDLIKE = re.compile(r'(?<!\\)\b([^\s:"]+)\s*:\s*([^\s")]+)')


class SearchResultContent(TypedDict):
    title: str
    url: str
    description: str
    keywords: str  # comma-delimited :)
    content: str


@dataclass
class SearchResult:
    id: str
    score: float
    content: SearchResultContent | dict[str, Any]  # for our purposes should be SRC


class SearchIndex(ABC):
    @abstractmethod
    def build_index(
        self,
        records: list[dict],
        deduplicate_by: str | None = None,
    ):
        raise NotImplementedError

    @abstractmethod
    def search(
        self, queries: list[str], fields: list[str], limit: int = 8, escape: bool = True
    ) -> list[SearchResult]:
        raise NotImplementedError


def deduplicate_exact(records: list[dict], deduplicate_by: str) -> list[dict]:
    seen = set()
    deduplicated = []
    for record in tqdm(records):
        key = normalizer.normalize(record[deduplicate_by])
        if key not in seen:
            deduplicated.append(record)
            seen.add(key)
    print(
        f"Deduplication removed {len(records) - len(deduplicated)} of {len(records)} records."
    )
    return deduplicated


def _fix_unbalanced_quotes(s: str) -> str:
    """Keep balanced quotes; if an odd number of quotes, drop the strays.

    Handles both double quotes (") and single quotes (') since Tantivy
    treats both as phrase delimiters.
    """
    # Handle double quotes
    if s.count('"') % 2 != 0:
        s = s.replace('"', " ")

    # Handle single quotes
    if s.count("'") % 2 != 0:
        s = s.replace("'", " ")

    return s


def _quote_nonfield_colons(
    q: str, known_fields: Set[str], json_prefixes: Set[str]
) -> str:
    """Quote X:Y patterns when X is not a known field name.

    This prevents strings like "3:12" from being interpreted as field queries.
    Preserves legitimate field queries like "title:roof".
    """

    def repl(m: re.Match) -> str:
        raw_prefix = m.group(1)
        # Handle escaped dots (a\.b)
        norm = raw_prefix.replace(r"\.", ".")
        top = norm.split(".", 1)[0]

        if norm in known_fields or top in json_prefixes:
            return m.group(0)  # Genuine field query; keep as-is
        return f'"{m.group(0)}"'  # Protect accidental field-like token

    return FIELDLIKE.sub(repl, q)


def _strip_syntax_to_words(q: str) -> str:
    """Last-resort sanitizer: remove operators but keep words/numbers.

    This is used as a fallback when all else fails to ensure we never crash.
    """
    q = re.sub(r'([+\-!(){}\[\]^"~*?:\\\/]|&&|\|\|)', " ", q)
    q = re.sub(r"\s+", " ", q).strip()
    return q or "*"  # Never return empty (match all if nothing left)


def sanitize_query_for_tantivy(
    query: str, known_fields: Iterable[str], json_field_prefixes: Iterable[str] = ()
) -> str:
    """Context-aware query sanitization for Tantivy.

    Uses heuristics to detect whether special characters are intended as
    Tantivy syntax or just regular text:
    - Parentheses: only syntax if AND/OR/&&/|| present
    - Square brackets: only syntax if TO or IN present
    - Curly braces: only syntax if TO present

    Always quotes X:Y patterns when X is not a known field, and fixes
    unbalanced quotes.
    """
    # Step 1: Detect if special syntax is actually being used
    has_boolean_ops = bool(re.search(r"\b(AND|OR)\b|&&|\|\|", query))
    has_range_keyword = bool(re.search(r"\bTO\b", query))
    has_set_keyword = bool(re.search(r"\bIN\b", query))

    # Step 2: Remove syntax chars that are likely just regular text
    if not has_boolean_ops:
        # No boolean operators, so parentheses aren't for grouping
        query = query.replace("(", " ").replace(")", " ")

    if not has_range_keyword:
        # No ranges, so curly braces aren't for exclusive bounds
        query = query.replace("{", " ").replace("}", " ")

    if not (has_range_keyword or has_set_keyword):
        # No ranges or sets, so square brackets aren't query syntax
        query = query.replace("[", " ").replace("]", " ")

    # Step 3: Fix unbalanced quotes (both " and ')
    query = _fix_unbalanced_quotes(query)

    # Step 4: Quote non-field colons (e.g., "3:12")
    query = _quote_nonfield_colons(query, set(known_fields), set(json_field_prefixes))

    return query

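# Illustrative sanitizer behavior (editor's sketch, not part of the diff):
#   sanitize_query_for_tantivy('roof repair (cheap)', {"title"})
#       -> 'roof repair  cheap '     # parens dropped: no AND/OR, so not grouping
#   sanitize_query_for_tantivy('title:roof AND (cheap OR free)', {"title"})
#       -> unchanged                 # known field plus boolean ops: real syntax, kept
#   sanitize_query_for_tantivy('meeting at 3:12', {"title"})
#       -> 'meeting at "3:12"'       # unknown-field colon quoted, not parsed as a field
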
class TantivySearch(SearchIndex):
    def __init__(
        self,
        index_path: str,
        include_fields: list[str] | None = None,
        exclude_fields: list[str] | None = None,
    ):
        self.index_path = Path(index_path)
        self.schema = None
        self.index = None
        self.include_fields = include_fields
        self.exclude_fields = exclude_fields
        self.searchable_fields: set[str] | None = (
            None  # Will be set during schema creation
        )
        self.id_to_record: dict[str, dict] = {}  # Cache for efficient document fetching

    def create_schema(self, sample_record: dict):
        schema_builder = SchemaBuilder()

        # Determine which fields should be searchable
        all_fields = set(sample_record.keys()) - {"id"}

        if self.include_fields is not None and self.exclude_fields is not None:
            # Both provided: searchable = include - exclude
            self.searchable_fields = set(self.include_fields) - set(self.exclude_fields)
        elif self.include_fields is not None:
            # Only include provided: use only those
            self.searchable_fields = set(self.include_fields)
        elif self.exclude_fields is not None:
            # Only exclude provided: use all except those
            self.searchable_fields = all_fields - set(self.exclude_fields)
        else:
            # Neither provided: all fields searchable
            # Set explicitly so we have field names for query sanitization
            self.searchable_fields = all_fields

        # Add all fields as text fields (they're all indexed, but we'll control search access)
        for key in sample_record.keys():
            if key == "id":
                continue
            schema_builder.add_text_field(key, stored=True, tokenizer_name="en_stem")
        schema_builder.add_text_field(
            "id", stored=True, tokenizer_name="raw", index_option="basic"
        )
        self.schema = schema_builder.build()

    @classmethod
    def from_index(cls, index_path: str):
        index = cls(index_path)
        index.load_index()
        return index

    def load_index(self):
        if not self.index_path.exists():
            raise FileNotFoundError(f"Index not found at {self.index_path}")
        self.index = Index.open(str(self.index_path))

    def _get_schema_field_names(self) -> set[str]:
        """Get field names for query sanitization."""
        # searchable_fields is always set after create_schema() is called
        return self.searchable_fields if self.searchable_fields is not None else set()

    def build_index(
        self,
        records: list[dict],
        deduplicate_by: str | None = None,
    ):
        if not records:
            raise ValueError("No records to index")

        if not self.schema:
            self.create_schema(records[0])
        assert self.schema is not None, "Schema not created"

        # Create index
        if self.index_path.exists():
            shutil.rmtree(self.index_path)
        os.makedirs(self.index_path, exist_ok=True)
        self.index = Index(self.schema, str(self.index_path))

        # Deduplicate if requested
        if deduplicate_by is not None:
            records = deduplicate_exact(records, deduplicate_by)

        # Index documents
        writer = self.index.writer()
        for record in tqdm(records):
            writer.add_document(Document(**{k: [str(v)] for k, v in record.items()}))
        writer.commit()
        writer.wait_merging_threads()

    def search(
        self, queries: list[str], fields: list[str], limit: int = 8, escape: bool = True
    ) -> list[SearchResult]:
        assert self.index is not None, "Index not built"

        # Validate that requested fields are searchable
        if self.searchable_fields is not None:
            non_searchable = set(fields) - self.searchable_fields
            if non_searchable:
                raise ValueError(
                    f"Cannot search non-searchable fields: {sorted(non_searchable)}. "
                    f"Searchable fields are: {sorted(self.searchable_fields)}"
                )

        self.index.reload()
        searcher = self.index.searcher()
        results = []

        # Get field names for query sanitization
        known_fields = self._get_schema_field_names()

        for query in queries:
            if escape:
                # Three-tier fallback strategy
                # Tier 1: Try sanitized query with context-aware cleaning
                sanitized = sanitize_query_for_tantivy(query, known_fields)
                try:
                    query_obj = self.index.parse_query(sanitized, fields)
                except ValueError:
                    # Tier 2: Bag-of-words fallback - strip all operators
                    bag_of_words = _strip_syntax_to_words(sanitized)
                    try:
                        query_obj = self.index.parse_query(bag_of_words, fields)
                    except ValueError:
                        # This should never happen, but if it does, match nothing
                        query_obj = self.index.parse_query("", fields)
            else:
                # User disabled escaping - use query as-is
                try:
                    query_obj = self.index.parse_query(query, fields)
                except Exception as e:
                    print(f"Error parsing query: {e}")
                    query_obj = None

            if query_obj:
                hits = searcher.search(query_obj, limit=limit).hits
            else:
                hits = []

            for score, doc_address in hits:
                doc = searcher.doc(doc_address)
                content = {k: v[0] for k, v in doc.to_dict().items()}
                results.append(
                    SearchResult(
                        id=content["id"],
                        score=score,
                        content=content,
                    )
                )

        # Rank fusion using reciprocal rank fusion
        doc_scores = {}
        for rank, result in enumerate(results, 1):
            doc_key = str(result.id)
            if doc_key not in doc_scores:
                doc_scores[doc_key] = 0
            doc_scores[doc_key] += 1 / (60 + rank)

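        # Worked example (editor's note): each appearance of a document in the
        # combined hit list adds 1/(60 + rank), rank being its position in that
        # list. A doc surfacing at positions 1 and 9 scores 1/61 + 1/69
        # ≈ 0.0164 + 0.0145 ≈ 0.0309, beating a doc seen once at position 1
        # (≈ 0.0164): documents that recur across queries win the fusion.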
        # Sort and deduplicate
        unique_results = {}
        for result in results:
            doc_key = str(result.id)
            if doc_key not in unique_results:
                unique_results[doc_key] = result

        sorted_results = sorted(
            unique_results.values(),
            key=lambda x: doc_scores[str(x.id)],
            reverse=True,
        )

        return sorted_results[:limit]

    def cache_documents_for_fetch(
        self, documents: list[dict], id_column: str = "id"
    ) -> None:
        """Cache documents for efficient retrieval by ID.

        This builds an in-memory lookup table for O(1) document fetching.
        Should be called after build_index() with the original document data.

        Args:
            documents: List of document records (typically original unfiltered data)
            id_column: Name of the ID field in the documents (default: "id")
        """
        self.id_to_record = {}
        for record in documents:
            doc_id = str(record.get(id_column, ""))
            self.id_to_record[doc_id] = record

    def fetch(self, document_ids: list[str]) -> tuple[list[dict], list[str], list[str]]:
        """Fetch documents by ID using cached lookup.

        Returns full original documents from the cache built by cache_documents_for_fetch().
        If cache is empty, falls back to querying the Tantivy index (slower, returns only indexed fields).

        Args:
            document_ids: List of document IDs to retrieve

        Returns:
            Tuple of (documents, found_ids, missing_ids) where:
            - documents: List of document dicts with all fields
            - found_ids: List of IDs that were successfully found
            - missing_ids: List of IDs that were not found
        """
        results = []
        found_ids = []
        missing_ids = []

        # Use cached lookup if available
        if self.id_to_record:
            for doc_id in document_ids:
                if doc_id in self.id_to_record:
                    results.append(self.id_to_record[doc_id])
                    found_ids.append(doc_id)
                else:
                    missing_ids.append(doc_id)
            return results, found_ids, missing_ids

        # Fallback: query Tantivy index (slower, returns only indexed fields)
        assert self.index is not None, "Index not built and no cached documents"

        self.index.reload()
        searcher = self.index.searcher()

        for doc_id in document_ids:
            query_str = f'id:"{doc_id}"'
            try:
                query = self.index.parse_query(query_str, ["id"])
                hits = searcher.search(query, limit=1).hits

                if hits:
                    _, doc_address = hits[0]
                    doc = searcher.doc(doc_address)
                    content = {k: v[0] for k, v in doc.to_dict().items()}
                    results.append(content)
                    found_ids.append(doc_id)
                else:
                    missing_ids.append(doc_id)
            except Exception:
                missing_ids.append(doc_id)

        return results, found_ids, missing_ids
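For completeness, TantivySearch can also be driven directly, mirroring the calls FullTextSearchManager makes in its constructor. A minimal sketch, assuming the tantivy and lenlp dependencies are installed and the index path is writable:

# Editor's usage sketch (not part of the diff)
records = [{"id": "1", "title": "roof repair"}, {"id": "2", "title": "gutter cleaning"}]

index = TantivySearch(index_path="/tmp/tantivy_demo")
index.build_index(records=records)        # builds the schema from the first record
index.cache_documents_for_fetch(records)  # enables O(1) fetch by ID

results = index.search(queries=["roof"], fields=["title"], limit=5)
docs, found, missing = index.fetch(["1", "99"])  # found == ["1"], missing == ["99"]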