symbolicai-1.0.0-py3-none-any.whl → symbolicai-1.1.1-py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- symai/__init__.py +198 -134
- symai/backend/base.py +51 -51
- symai/backend/engines/drawing/engine_bfl.py +33 -33
- symai/backend/engines/drawing/engine_gpt_image.py +4 -10
- symai/backend/engines/embedding/engine_llama_cpp.py +50 -35
- symai/backend/engines/embedding/engine_openai.py +22 -16
- symai/backend/engines/execute/engine_python.py +16 -16
- symai/backend/engines/files/engine_io.py +51 -49
- symai/backend/engines/imagecaptioning/engine_blip2.py +27 -23
- symai/backend/engines/imagecaptioning/engine_llavacpp_client.py +53 -46
- symai/backend/engines/index/engine_pinecone.py +116 -88
- symai/backend/engines/index/engine_qdrant.py +1011 -0
- symai/backend/engines/index/engine_vectordb.py +78 -52
- symai/backend/engines/lean/engine_lean4.py +65 -25
- symai/backend/engines/neurosymbolic/__init__.py +35 -28
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_chat.py +137 -135
- symai/backend/engines/neurosymbolic/engine_anthropic_claudeX_reasoning.py +145 -152
- symai/backend/engines/neurosymbolic/engine_cerebras.py +328 -0
- symai/backend/engines/neurosymbolic/engine_deepseekX_reasoning.py +75 -49
- symai/backend/engines/neurosymbolic/engine_google_geminiX_reasoning.py +199 -155
- symai/backend/engines/neurosymbolic/engine_groq.py +106 -72
- symai/backend/engines/neurosymbolic/engine_huggingface.py +100 -67
- symai/backend/engines/neurosymbolic/engine_llama_cpp.py +121 -93
- symai/backend/engines/neurosymbolic/engine_openai_gptX_chat.py +213 -132
- symai/backend/engines/neurosymbolic/engine_openai_gptX_reasoning.py +180 -137
- symai/backend/engines/ocr/engine_apilayer.py +18 -20
- symai/backend/engines/output/engine_stdout.py +9 -9
- symai/backend/engines/{webscraping → scrape}/engine_requests.py +25 -11
- symai/backend/engines/search/engine_openai.py +95 -83
- symai/backend/engines/search/engine_parallel.py +665 -0
- symai/backend/engines/search/engine_perplexity.py +40 -41
- symai/backend/engines/search/engine_serpapi.py +33 -28
- symai/backend/engines/speech_to_text/engine_local_whisper.py +37 -27
- symai/backend/engines/symbolic/engine_wolframalpha.py +14 -8
- symai/backend/engines/text_to_speech/engine_openai.py +15 -19
- symai/backend/engines/text_vision/engine_clip.py +34 -28
- symai/backend/engines/userinput/engine_console.py +3 -4
- symai/backend/mixin/__init__.py +4 -0
- symai/backend/mixin/anthropic.py +48 -40
- symai/backend/mixin/cerebras.py +9 -0
- symai/backend/mixin/deepseek.py +4 -5
- symai/backend/mixin/google.py +5 -4
- symai/backend/mixin/groq.py +2 -4
- symai/backend/mixin/openai.py +132 -110
- symai/backend/settings.py +14 -14
- symai/chat.py +164 -94
- symai/collect/dynamic.py +13 -11
- symai/collect/pipeline.py +39 -31
- symai/collect/stats.py +109 -69
- symai/components.py +578 -238
- symai/constraints.py +14 -5
- symai/core.py +1495 -1210
- symai/core_ext.py +55 -50
- symai/endpoints/api.py +113 -58
- symai/extended/api_builder.py +22 -17
- symai/extended/arxiv_pdf_parser.py +13 -5
- symai/extended/bibtex_parser.py +8 -4
- symai/extended/conversation.py +88 -69
- symai/extended/document.py +40 -27
- symai/extended/file_merger.py +45 -7
- symai/extended/graph.py +38 -24
- symai/extended/html_style_template.py +17 -11
- symai/extended/interfaces/blip_2.py +1 -1
- symai/extended/interfaces/clip.py +4 -2
- symai/extended/interfaces/console.py +5 -3
- symai/extended/interfaces/dall_e.py +3 -1
- symai/extended/interfaces/file.py +2 -0
- symai/extended/interfaces/flux.py +3 -1
- symai/extended/interfaces/gpt_image.py +15 -6
- symai/extended/interfaces/input.py +2 -1
- symai/extended/interfaces/llava.py +1 -1
- symai/extended/interfaces/{naive_webscraping.py → naive_scrape.py} +3 -2
- symai/extended/interfaces/naive_vectordb.py +2 -2
- symai/extended/interfaces/ocr.py +4 -2
- symai/extended/interfaces/openai_search.py +2 -0
- symai/extended/interfaces/parallel.py +30 -0
- symai/extended/interfaces/perplexity.py +2 -0
- symai/extended/interfaces/pinecone.py +6 -4
- symai/extended/interfaces/python.py +2 -0
- symai/extended/interfaces/serpapi.py +2 -0
- symai/extended/interfaces/terminal.py +0 -1
- symai/extended/interfaces/tts.py +2 -1
- symai/extended/interfaces/whisper.py +2 -1
- symai/extended/interfaces/wolframalpha.py +1 -0
- symai/extended/metrics/__init__.py +1 -1
- symai/extended/metrics/similarity.py +5 -2
- symai/extended/os_command.py +31 -22
- symai/extended/packages/symdev.py +39 -34
- symai/extended/packages/sympkg.py +30 -27
- symai/extended/packages/symrun.py +46 -35
- symai/extended/repo_cloner.py +10 -9
- symai/extended/seo_query_optimizer.py +15 -12
- symai/extended/solver.py +104 -76
- symai/extended/summarizer.py +8 -7
- symai/extended/taypan_interpreter.py +10 -9
- symai/extended/vectordb.py +28 -15
- symai/formatter/formatter.py +39 -31
- symai/formatter/regex.py +46 -44
- symai/functional.py +184 -86
- symai/imports.py +85 -51
- symai/interfaces.py +1 -1
- symai/memory.py +33 -24
- symai/menu/screen.py +28 -19
- symai/misc/console.py +27 -27
- symai/misc/loader.py +4 -3
- symai/models/base.py +147 -76
- symai/models/errors.py +1 -1
- symai/ops/__init__.py +1 -1
- symai/ops/measures.py +17 -14
- symai/ops/primitives.py +933 -635
- symai/post_processors.py +28 -24
- symai/pre_processors.py +58 -52
- symai/processor.py +15 -9
- symai/prompts.py +714 -649
- symai/server/huggingface_server.py +115 -32
- symai/server/llama_cpp_server.py +14 -6
- symai/server/qdrant_server.py +206 -0
- symai/shell.py +98 -39
- symai/shellsv.py +307 -223
- symai/strategy.py +135 -81
- symai/symbol.py +276 -225
- symai/utils.py +62 -46
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/METADATA +19 -9
- symbolicai-1.1.1.dist-info/RECORD +169 -0
- symbolicai-1.0.0.dist-info/RECORD +0 -163
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/WHEEL +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/entry_points.txt +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {symbolicai-1.0.0.dist-info → symbolicai-1.1.1.dist-info}/top_level.txt +0 -0
symai/backend/engines/index/engine_qdrant.py
@@ -0,0 +1,1011 @@
+import itertools
+import logging
+import tempfile
+import urllib.request
+import uuid
+import warnings
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+from .... import core_ext
+from ....symbol import Result, Symbol
+from ....utils import UserMessage
+from ...base import Engine
+from ...settings import SYMAI_CONFIG, SYMSERVER_CONFIG
+
+warnings.filterwarnings("ignore", module="qdrant_client")
+try:
+    from qdrant_client import QdrantClient
+    from qdrant_client.http import models
+    from qdrant_client.http.models import (
+        Distance,
+        Filter,
+        PointStruct,
+        ScoredPoint,
+        VectorParams,
+    )
+except ImportError:
+    QdrantClient = None
+    models = None
+    Distance = None
+    VectorParams = None
+    PointStruct = None
+    Filter = None
+    ScoredPoint = None
+
+try:
+    from ....components import ChonkieChunker, FileReader
+except ImportError:
+    ChonkieChunker = None
+    FileReader = None
+
+try:
+    from tokenizers import Tokenizer
+except ImportError:
+    Tokenizer = None
+
+logging.getLogger("qdrant_client").setLevel(logging.ERROR)
+
+
+def chunks(iterable, batch_size=100):
+    """A helper function to break an iterable into chunks of size batch_size."""
+    it = iter(iterable)
+    chunk = list(itertools.islice(it, batch_size))
+    while chunk:
+        yield chunk
+        chunk = list(itertools.islice(it, batch_size))
+
+
+class QdrantResult(Result):
+    def __init__(self, res, query: str, embedding: list, **kwargs):
+        super().__init__(res, **kwargs)
+        self.raw = res
+        self._query = query
+        self._value = self._process(res)
+        self._metadata.raw = embedding
+
+    def _process(self, res):
+        if not res:
+            return None
+        try:
+            # Qdrant returns a list of ScoredPoint objects
+            # Convert to format similar to Pinecone for consistency
+            if isinstance(res, list):
+                matches = []
+                for point in res:
+                    match = {
+                        "id": point.id if hasattr(point, "id") else None,
+                        "score": point.score if hasattr(point, "score") else None,
+                        "metadata": point.payload if hasattr(point, "payload") else {},
+                    }
+                    # Extract text from payload if available
+                    if "text" in match["metadata"]:
+                        match["metadata"]["text"] = match["metadata"]["text"]
+                    elif "content" in match["metadata"]:
+                        match["metadata"]["text"] = match["metadata"]["content"]
+                    matches.append(match)
+                return [v["metadata"].get("text", str(v)) for v in matches if "metadata" in v]
+            res = self._to_symbol(res).ast()
+            return [v["metadata"]["text"] for v in res.get("matches", []) if "metadata" in v]
+        except Exception as e:
+            message = [
+                "Sorry, failed to interact with Qdrant index. Please check collection name and try again later:",
+                str(e),
+            ]
+            return [{"metadata": {"text": "\n".join(message)}}]
+
+    def _unpack_matches(self):
+        if not self.value:
+            return
+
+        for i, match_item in enumerate(self.value):
+            if isinstance(match_item, dict):
+                match_text = match_item.get("metadata", {}).get("text", str(match_item))
+            else:
+                match_text = str(match_item)
+            match_text = match_text.strip()
+            if match_text.startswith("# ----[FILE_START]") and "# ----[FILE_END]" in match_text:
+                m = match_text.split("[FILE_CONTENT]:")[-1].strip()
+                splits = m.split("# ----[FILE_END]")
+                assert len(splits) >= 2, f"Invalid file format: {splits}"
+                content_text = splits[0]
+                file_name = ",".join(splits[1:])  # TODO: check why there are multiple file names
+                yield file_name.strip(), content_text.strip()
+            else:
+                yield i + 1, match_text
+
+    def __str__(self):
+        str_view = ""
+        for filename, content_text in self._unpack_matches():
+            # indent each line of the content
+            indented_content = "\n".join([" " + line for line in content_text.split("\n")])
+            str_view += f"* {filename}\n{indented_content}\n\n"
+        return f"""
+[RESULT]
+{"-=-" * 13}
+
+Query: {self._query}
+
+{"-=-" * 13}
+
+Matches:
+
+{str_view}
+{"-=-" * 13}
+"""
+
+    def _repr_html_(self) -> str:
+        # return a nicely styled HTML list results based on retrieved documents
+        doc_str = ""
+        for filename, content in self._unpack_matches():
+            doc_str += f'<li><a href="{filename}"><b>{filename}</a></b><br>{content}</li>\n'
+        return f"<ul>{doc_str}</ul>"
+
+
+class QdrantIndexEngine(Engine):
+    _default_url = "http://localhost:6333"
+    _default_api_key = SYMAI_CONFIG.get("INDEXING_ENGINE_API_KEY", None)
+    _default_index_name = "dataindex"
+    _default_index_dims = 1536
+    _default_index_top_k = 5
+    _default_index_metric = "Cosine"
+    _default_index_values = True
+    _default_index_metadata = True
+    _default_retry_tries = 20
+    _default_retry_delay = 0.5
+    _default_retry_max_delay = -1
+    _default_retry_backoff = 1
+    _default_retry_jitter = 0
+
+    def __init__(
+        self,
+        url: str | None = None,
+        api_key: str | None = _default_api_key,
+        index_name: str = _default_index_name,
+        index_dims: int = _default_index_dims,
+        index_top_k: int = _default_index_top_k,
+        index_metric: str = _default_index_metric,
+        index_values: bool = _default_index_values,
+        index_metadata: bool = _default_index_metadata,
+        tries: int = _default_retry_tries,
+        delay: float = _default_retry_delay,
+        max_delay: int = _default_retry_max_delay,
+        backoff: int = _default_retry_backoff,
+        jitter: int = _default_retry_jitter,
+        chunker_name: str | None = "RecursiveChunker",
+        tokenizer_name: str | None = "gpt2",
+        embedding_model_name: str | None = "minishlab/potion-base-8M",
+    ):
+        super().__init__()
+        self.index_name = index_name
+        self.index_dims = index_dims
+        self.index_top_k = index_top_k
+        self.index_values = index_values
+        self.index_metadata = index_metadata
+        self.index_metric = self._parse_metric(index_metric)
+        # Get URL from SYMSERVER_CONFIG if available, otherwise use provided or default
+        if url:
+            self.url = url
+        elif SYMSERVER_CONFIG.get("url"):
+            self.url = SYMSERVER_CONFIG.get("url")
+        elif (
+            SYMSERVER_CONFIG.get("online")
+            and SYMSERVER_CONFIG.get("--host")
+            and SYMSERVER_CONFIG.get("--port")
+        ):
+            self.url = f"http://{SYMSERVER_CONFIG.get('--host')}:{SYMSERVER_CONFIG.get('--port')}"
+        else:
+            self.url = self._default_url
+        self.api_key = api_key
+        self.tries = tries
+        self.delay = delay
+        self.max_delay = max_delay
+        self.backoff = backoff
+        self.jitter = jitter
+        self.client = None
+        self.name = self.__class__.__name__
+
+        # Initialize chunker and reader for manager functionality
+        self.chunker_name = chunker_name
+        self.tokenizer_name = tokenizer_name
+        self.embedding_model_name = embedding_model_name
+        self.chunker = None
+        self.reader = None
+        self.tokenizer = None
+
+        # Initialize chunker if available
+        if ChonkieChunker:
+            try:
+                self.chunker = ChonkieChunker(
+                    tokenizer_name=tokenizer_name, embedding_model_name=embedding_model_name
+                )
+                if Tokenizer:
+                    self.tokenizer = Tokenizer.from_pretrained(tokenizer_name)
+            except Exception as e:
+                warnings.warn(f"Failed to initialize chunker: {e}")
+
+        # Initialize FileReader
+        if FileReader:
+            try:
+                self.reader = FileReader()
+            except Exception as e:
+                warnings.warn(f"Failed to initialize FileReader: {e}")
+
+    def _parse_metric(self, metric: str) -> Distance:
+        """Convert string metric to Qdrant Distance enum."""
+        if QdrantClient is None:
+            return metric
+        metric_map = {
+            "cosine": Distance.COSINE,
+            "dot": Distance.DOT,
+            "euclidean": Distance.EUCLID,
+        }
+        metric_lower = metric.lower()
+        return metric_map.get(metric_lower, Distance.COSINE)
+
+    def id(self) -> str:
+        # Check if Qdrant is configured (either via server or direct connection)
+        if SYMSERVER_CONFIG.get("online") or self.url:
+            if QdrantClient is None:
+                UserMessage(
+                    "Qdrant client is not installed. Please install it with `pip install qdrant-client`.",
+                    raise_with=ImportError,
+                )
+            return "index"
+        return super().id()  # default to unregistered
+
+    def command(self, *args, **kwargs):
+        super().command(*args, **kwargs)
+        if "INDEXING_ENGINE_API_KEY" in kwargs:
+            self.api_key = kwargs["INDEXING_ENGINE_API_KEY"]
+        if "INDEXING_ENGINE_URL" in kwargs:
+            self.url = kwargs["INDEXING_ENGINE_URL"]
+
+    def _init_client(self):
+        """Initialize Qdrant client if not already initialized."""
+        if self.client is None:
+            if QdrantClient is None:
+                UserMessage(
+                    "Qdrant client is not installed. Please install it with `pip install qdrant-client`.",
+                    raise_with=ImportError,
+                )
+
+            client_kwargs = {"url": self.url}
+            if self.api_key:
+                client_kwargs["api_key"] = self.api_key
+
+            self.client = QdrantClient(**client_kwargs)
+
+    def _create_collection_sync(
+        self,
+        collection_name: str,
+        vector_size: int | None = None,
+        distance: (str | Distance) | None = None,
+        **kwargs,
+    ):
+        """Synchronous collection creation for internal use."""
+        self._init_client()
+
+        vector_size = vector_size or self.index_dims
+        if isinstance(distance, str):
+            distance = self._parse_metric(distance)
+        else:
+            distance = distance or self.index_metric
+
+        if not self.client.collection_exists(collection_name):
+            self.client.create_collection(
+                collection_name=collection_name,
+                vectors_config=VectorParams(size=vector_size, distance=distance),
+                **kwargs,
+            )
+
+    def _init_collection(
+        self, collection_name: str, collection_dims: int, collection_metric: Distance
+    ):
+        """Initialize or create Qdrant collection (legacy method, uses _create_collection_sync)."""
+        self._create_collection_sync(collection_name, collection_dims, collection_metric)
+
+    def _configure_collection(self, **kwargs):
+        collection_name = kwargs.get("index_name", self.index_name)
+        del_ = kwargs.get("index_del", False)
+
+        if self.client is not None and del_:
+            try:
+                self.client.delete_collection(collection_name=collection_name)
+            except Exception as e:
+                warnings.warn(f"Failed to delete collection {collection_name}: {e}")
+
+        get_ = kwargs.get("index_get", False)
+        if get_:
+            # Reinitialize client to refresh collection list
+            self._init_client()
+
+    def _prepare_points_for_upsert(
+        self,
+        embeddings: list | np.ndarray | Any,
+        ids: list[int] | None = None,
+        payloads: list[dict] | None = None,
+    ) -> list[PointStruct]:
+        """Prepare points for upsert from embeddings, ids, and payloads."""
+        points = []
+
+        # Normalize to list
+        if isinstance(embeddings, np.ndarray):
+            embeddings = [embeddings] if embeddings.ndim == 1 else list(embeddings)
+        elif not isinstance(embeddings, list):
+            embeddings = [embeddings]
+
+        for i, vec in enumerate(embeddings):
+            point_id = ids[i] if ids and i < len(ids) else i
+            payload = payloads[i] if payloads and i < len(payloads) else {}
+            points.append(
+                PointStruct(id=point_id, vector=self._normalize_vector(vec), payload=payload)
+            )
+
+        return points
+
+    def forward(self, argument):
+        kwargs = argument.kwargs
+        embedding = argument.prop.prepared_input
+        query = argument.prop.ori_query
+        operation = argument.prop.operation
+        collection_name = argument.prop.index_name if argument.prop.index_name else self.index_name
+        collection_dims = argument.prop.index_dims if argument.prop.index_dims else self.index_dims
+        rsp = None
+
+        # Initialize client
+        self._init_client()
+
+        if collection_name != self.index_name:
+            assert collection_name, "Please set a valid collection name for Qdrant indexing engine."
+            # switch collection
+            self.index_name = collection_name
+            kwargs["index_get"] = True
+            self._configure_collection(**kwargs)
+
+        if operation == "search":
+            # Ensure collection exists - fail fast if it doesn't
+            self._ensure_collection_exists(collection_name)
+            index_top_k = kwargs.get("index_top_k", self.index_top_k)
+            # Use existing _query method
+            rsp = self._query(collection_name, embedding, index_top_k)
+        elif operation == "add":
+            # Create collection if it doesn't exist (only for write operations)
+            self._create_collection_sync(collection_name, collection_dims, self.index_metric)
+            # Use shared point preparation method
+            ids = kwargs.get("ids", None)
+            payloads = kwargs.get("payloads", None)
+            points = self._prepare_points_for_upsert(embedding, ids, payloads)
+
+            # Use existing _upsert method in batches
+            for points_chunk in chunks(points, batch_size=100):
+                self._upsert(collection_name, points_chunk)
+            rsp = None
+        elif operation == "config":
+            # Ensure collection exists - fail fast if it doesn't
+            self._ensure_collection_exists(collection_name)
+            self._configure_collection(**kwargs)
+            rsp = None
+        else:
+            msg = "Invalid operation. Supported operations: search, add, config"
+            raise ValueError(msg)
+
+        metadata = {}
+
+        rsp = QdrantResult(rsp, query, embedding)
+        return [rsp], metadata
+
+    def prepare(self, argument):
+        assert not argument.prop.processed_input, (
+            "Qdrant indexing engine does not support processed_input."
+        )
+        argument.prop.prepared_input = argument.prop.prompt
+
+    def _upsert(self, collection_name: str, points: list[PointStruct]):
+        @core_ext.retry(
+            tries=self.tries,
+            delay=self.delay,
+            max_delay=self.max_delay,
+            backoff=self.backoff,
+            jitter=self.jitter,
+        )
+        def _func():
+            return self.client.upsert(collection_name=collection_name, points=points)
+
+        return _func()
+
+    def _normalize_vector(self, vector: list[float] | np.ndarray) -> list[float]:
+        """Normalize vector to flat list format, handling 2D arrays and nested lists."""
+        if isinstance(vector, np.ndarray):
+            # Flatten if 2D (e.g., shape (1, 1536) -> (1536,))
+            if vector.ndim > 1:
+                vector = vector.flatten()
+            return vector.tolist()
+        if not isinstance(vector, list):
+            vector = list(vector)
+
+        # Handle nested lists that might have slipped through
+        if vector and len(vector) > 0 and isinstance(vector[0], list):
+            # Flatten nested list (e.g., [[1, 2, 3]] -> [1, 2, 3])
+            if len(vector) == 1:
+                vector = vector[0]
+            else:
+                vector = [item for sublist in vector for item in sublist]
+
+        return vector
+
+    def _query(self, collection_name: str, query_vector: list[float], top_k: int, **kwargs):
+        @core_ext.retry(
+            tries=self.tries,
+            delay=self.delay,
+            max_delay=self.max_delay,
+            backoff=self.backoff,
+            jitter=self.jitter,
+        )
+        def _func():
+            query_vector_normalized = self._normalize_vector(query_vector)
+            return self.client.search(
+                collection_name=collection_name,
+                query_vector=query_vector_normalized,
+                limit=top_k,
+                with_payload=True,
+                with_vectors=self.index_values,
+                **kwargs,
+            )
+
+        return _func()
+
+    # ==================== Manager Methods ====================
+
+    def _check_initialization(self):
+        """Check if engine is properly initialized."""
+        if self.client is None:
+            self._init_client()
+        if self.client is None:
+            msg = "Qdrant client not properly initialized."
+            raise RuntimeError(msg)
+
+    def _ensure_collection_exists(self, collection_name: str):
+        """Ensure collection exists, raise error if not."""
+        self._check_initialization()
+        if not self.client.collection_exists(collection_name):
+            msg = f"Collection '{collection_name}' does not exist"
+            raise ValueError(msg)
+
+    # ==================== Collection Management ====================
+
+    async def create_collection(
+        self,
+        collection_name: str,
+        vector_size: int | None = None,
+        distance: (str | Distance) | None = None,
+        **kwargs,
+    ):
+        """
+        Create a new collection in Qdrant.
+
+        Args:
+            collection_name: Name of the collection to create
+            vector_size: Size of the vectors in this collection (defaults to index_dims)
+            distance: Distance metric (COSINE, EUCLIDEAN, or DOT) or string
+            **kwargs: Additional collection configuration parameters
+        """
+        self._check_initialization()
+
+        if self.client.collection_exists(collection_name):
+            warnings.warn(f"Collection '{collection_name}' already exists")
+            return
+
+        # Use shared synchronous method
+        self._create_collection_sync(collection_name, vector_size, distance, **kwargs)
+
+    async def collection_exists(self, collection_name: str) -> bool:
+        """
+        Check if a collection exists.
+
+        Args:
+            collection_name: Name of the collection to check
+
+        Returns:
+            True if collection exists, False otherwise
+        """
+        self._check_initialization()
+        return self.client.collection_exists(collection_name)
+
+    async def list_collections(self) -> list[str]:
+        """
+        List all collections in Qdrant.
+
+        Returns:
+            List of collection names
+        """
+        self._check_initialization()
+        collections = self.client.get_collections().collections
+        return [collection.name for collection in collections]
+
+    async def delete_collection(self, collection_name: str):
+        """
+        Delete a collection from Qdrant.
+
+        Args:
+            collection_name: Name of the collection to delete
+        """
+        self._ensure_collection_exists(collection_name)
+        self.client.delete_collection(collection_name)
+
+    async def get_collection_info(self, collection_name: str) -> dict:
+        """
+        Get information about a collection.
+
+        Args:
+            collection_name: Name of the collection
+
+        Returns:
+            Dictionary containing collection information
+        """
+        self._ensure_collection_exists(collection_name)
+        collection_info = self.client.get_collection(collection_name)
+        # Extract vector config - handle both single vector and named vectors
+        vector_config = collection_info.config.params.vectors
+        if hasattr(vector_config, "size"):
+            # Single vector configuration
+            vectors_info = {
+                "size": vector_config.size,
+                "distance": vector_config.distance,
+            }
+        else:
+            # Named vectors configuration
+            vectors_info = {
+                "named_vectors": {
+                    name: {"size": vec.size, "distance": vec.distance}
+                    for name, vec in vector_config.items()
+                }
+            }
+        return {
+            "name": collection_name,
+            "vectors_count": collection_info.vectors_count,
+            "indexed_vectors_count": collection_info.indexed_vectors_count,
+            "points_count": collection_info.points_count,
+            "config": {"params": {"vectors": vectors_info}},
+        }
+
+    # ==================== Point Operations ====================
+
+    def _upsert_points_sync(
+        self,
+        collection_name: str,
+        points: list[PointStruct] | list[dict],
+        **kwargs,  # noqa: ARG002
+    ):
+        """Synchronous upsert for internal use."""
+        self._ensure_collection_exists(collection_name)
+
+        # Convert dict to PointStruct if needed, and normalize vectors
+        if not points:
+            msg = "Points list cannot be empty"
+            raise ValueError(msg)
+        if isinstance(points[0], dict):
+            points = [
+                PointStruct(
+                    id=point["id"],
+                    vector=self._normalize_vector(point["vector"]),
+                    payload=point.get("payload", {}),
+                )
+                for point in points
+            ]
+        else:
+            # Normalize vectors in existing PointStruct objects
+            points = [
+                PointStruct(
+                    id=point.id,
+                    vector=self._normalize_vector(point.vector),
+                    payload=point.payload,
+                )
+                for point in points
+            ]
+
+        # Upsert in batches using existing _upsert method
+        for points_chunk in chunks(points, batch_size=100):
+            self._upsert(collection_name, points_chunk)
+
+    async def upsert(
+        self,
+        collection_name: str,
+        points: list[PointStruct] | list[dict],
+        **kwargs,
+    ):
+        """
+        Insert or update points in a collection.
+
+        Args:
+            collection_name: Name of the collection
+            points: List of PointStruct objects or dictionaries with id, vector, and optional payload
+            **kwargs: Additional arguments for upsert operation
+        """
+        # Use shared synchronous method
+        self._upsert_points_sync(collection_name, points, **kwargs)
+
+    async def insert(
+        self,
+        collection_name: str,
+        points: list[PointStruct] | list[dict],
+        **kwargs,
+    ):
+        """
+        Insert points into a collection (alias for upsert).
+
+        Args:
+            collection_name: Name of the collection
+            points: List of PointStruct objects or dictionaries with id, vector, and optional payload
+            **kwargs: Additional arguments for insert operation
+        """
+        await self.upsert(collection_name, points, **kwargs)
+
+    async def delete(
+        self,
+        collection_name: str,
+        points_selector: list[int] | int,
+        **kwargs,
+    ):
+        """
+        Delete points from a collection.
+
+        Args:
+            collection_name: Name of the collection
+            points_selector: Point IDs to delete (single ID or list of IDs)
+            **kwargs: Additional arguments for delete operation
+        """
+        self._ensure_collection_exists(collection_name)
+
+        # Convert single ID to list if needed
+        if isinstance(points_selector, int):
+            points_selector = [points_selector]
+
+        self.client.delete(
+            collection_name=collection_name, points_selector=points_selector, **kwargs
+        )
+
+    async def retrieve(
+        self,
+        collection_name: str,
+        ids: list[int] | int,
+        with_payload: bool = True,
+        with_vectors: bool = False,
+        **kwargs,
+    ) -> list[dict]:
+        """
+        Retrieve points by their IDs.
+
+        Args:
+            collection_name: Name of the collection
+            ids: Point IDs to retrieve (single ID or list of IDs)
+            with_payload: Whether to include payload in results
+            with_vectors: Whether to include vectors in results
+            **kwargs: Additional arguments for retrieve operation
+
+        Returns:
+            List of point dictionaries
+        """
+        self._ensure_collection_exists(collection_name)
+
+        # Convert single ID to list if needed
+        if isinstance(ids, int):
+            ids = [ids]
+
+        points = self.client.retrieve(
+            collection_name=collection_name,
+            ids=ids,
+            with_payload=with_payload,
+            with_vectors=with_vectors,
+            **kwargs,
+        )
+
+        # Convert to list of dicts for easier use
+        result = []
+        for point in points:
+            point_dict = {"id": point.id}
+            if with_payload and point.payload:
+                point_dict["payload"] = point.payload
+            if with_vectors and point.vector:
+                point_dict["vector"] = point.vector
+            result.append(point_dict)
+
+        return result
+
+    # ==================== Search Operations ====================
+
+    def _search_sync(
+        self,
+        collection_name: str,
+        query_vector: list[float] | np.ndarray,
+        limit: int = 10,
+        score_threshold: float | None = None,
+        query_filter: Filter | None = None,
+        **kwargs,
+    ) -> list[ScoredPoint]:
+        """Synchronous search for internal use."""
+        self._ensure_collection_exists(collection_name)
+
+        # Build kwargs for search
+        search_kwargs = {"score_threshold": score_threshold, "query_filter": query_filter, **kwargs}
+        # Remove None values
+        search_kwargs = {k: v for k, v in search_kwargs.items() if v is not None}
+
+        # Use _query which handles retry logic and vector normalization
+        return self._query(collection_name, query_vector, limit, **search_kwargs)
+
+    async def search(
+        self,
+        collection_name: str,
+        query_vector: list[float] | np.ndarray,
+        limit: int = 10,
+        score_threshold: float | None = None,
+        query_filter: Filter | None = None,
+        **kwargs,
+    ) -> list[ScoredPoint]:
+        """
+        Search for similar vectors in a collection.
+
+        Args:
+            collection_name: Name of the collection to search
+            query_vector: Query vector to search for
+            limit: Maximum number of results to return
+            score_threshold: Minimum similarity score threshold
+            query_filter: Optional filter to apply to the search
+            **kwargs: Additional search parameters
+
+        Returns:
+            List of ScoredPoint objects containing id, score, and payload
+        """
+        # Use shared synchronous method
+        return self._search_sync(
+            collection_name, query_vector, limit, score_threshold, query_filter, **kwargs
+        )
+
+    # ==================== Document Operations with Chunking ====================
+
+    def _download_and_read_file(self, file_url: str) -> str:
+        """
+        Download file from URL and read it using FileReader.
+
+        Args:
+            file_url: URL to the file to download
+
+        Returns:
+            Text content of the file
+        """
+        if self.reader is None:
+            msg = "FileReader not initialized"
+            raise RuntimeError(msg)
+
+        file_path = Path(file_url)
+        suffix = file_path.suffix
+        with (
+            urllib.request.urlopen(file_url) as f,
+            tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file,
+        ):
+            tmp_file.write(f.read())
+            tmp_file.flush()
+            tmp_file_name = tmp_file.name
+
+        try:
+            content = self.reader(tmp_file_name)
+            return content.value[0] if isinstance(content.value, list) else str(content.value)
+        finally:
+            # Clean up temporary file
+            tmp_path = Path(tmp_file_name)
+            if tmp_path.exists():
+                tmp_path.unlink()
+
+    async def chunk_and_upsert(  # noqa: C901
+        self,
+        collection_name: str,
+        text: str | Symbol | None = None,
+        document_path: str | None = None,
+        document_url: str | None = None,
+        chunker_name: str | None = None,
+        chunker_kwargs: dict | None = None,
+        start_id: int | None = None,
+        metadata: dict | None = None,
+        **upsert_kwargs,
+    ):
+        """
+        Chunk text or documents using ChonkieChunker and upsert the chunks with embeddings into Qdrant.
+
+        Args:
+            collection_name: Name of the collection to upsert into
+            text: Text to chunk (string or Symbol). If None, document_path or document_url must be provided.
+            document_path: Path to a document file to read using FileReader (PDF, etc.)
+            document_url: URL to a document file to download and read using FileReader
+            chunker_name: Name of the chunker to use. If None, uses the instance default chunker_name
+            chunker_kwargs: Additional keyword arguments for the chunker
+            start_id: Starting ID for the chunks (auto-incremented). If None, uses hash-based IDs
+            metadata: Optional metadata to add to all chunk payloads
+            **upsert_kwargs: Additional arguments for upsert operation
+
+        Returns:
+            Number of chunks upserted
+        """
+        self._ensure_collection_exists(collection_name)
+
+        # Validate input: exactly one of text, document_path, or document_url must be provided
+        input_count = sum(x is not None for x in [text, document_path, document_url])
+        if input_count == 0:
+            msg = "One of `text`, `document_path`, or `document_url` must be provided"
+            raise ValueError(msg)
+        if input_count > 1:
+            msg = "Only one of `text`, `document_path`, or `document_url` can be provided"
+            raise ValueError(msg)
+
+        # Get collection info to determine vector size
+        collection_info = await self.get_collection_info(collection_name)
+        vector_config = collection_info["config"]["params"]["vectors"]
+        if "size" in vector_config:
+            vector_size = vector_config["size"]
+        else:
+            # For named vectors, we need to specify which one to use
+            # Default to first named vector or raise error
+            named_vectors = vector_config.get("named_vectors", {})
+            if not named_vectors:
+                msg = "Collection has no vector configuration"
+                raise ValueError(msg)
+            vector_size = next(iter(named_vectors.values()))["size"]
+
+        # Check if chunker is initialized
+        if self.chunker is None:
+            msg = "Chunker not initialized. Please ensure ChonkieChunker is available."
+            raise RuntimeError(msg)
+
+        # Use instance chunker and default chunker_name if not provided
+        chunker_kwargs = chunker_kwargs or {}
+        if chunker_name is None:
+            chunker_name = self.chunker_name
+
+        # Handle document_path: read file using FileReader
+        if document_path is not None:
+            if self.reader is None:
+                msg = "FileReader not initialized. Please ensure FileReader is available."
+                raise RuntimeError(msg)
+            doc_path = Path(document_path)
+            if not doc_path.exists():
+                msg = f"Document file not found: {document_path}"
+                raise FileNotFoundError(msg)
+            content = self.reader(document_path)
+            text = content.value[0] if isinstance(content.value, list) else str(content.value)
+            # Add source to metadata if not already present
+            if metadata is None:
+                metadata = {}
+            if "source" not in metadata:
+                metadata["source"] = doc_path.name
+
+        # Handle document_url: download and read file using FileReader
+        elif document_url is not None:
+            if self.reader is None:
+                msg = "FileReader not initialized. Please ensure FileReader is available."
+                raise RuntimeError(msg)
+            text = self._download_and_read_file(document_url)
+            # Add source to metadata if not already present
+            if metadata is None:
+                metadata = {}
+            if "source" not in metadata:
+                metadata["source"] = document_url
+
+        # Convert text to Symbol if needed
+        text_symbol = Symbol(text) if isinstance(text, str) else text
+
+        # Chunk the text using instance chunker
+        chunks_symbol = self.chunker.forward(
+            text_symbol, chunker_name=chunker_name, **chunker_kwargs
+        )
+        chunks = chunks_symbol.value if hasattr(chunks_symbol, "value") else chunks_symbol
+
+        if not chunks:
+            warnings.warn("No chunks generated from text")
+            return 0
+
+        # Ensure chunks is a list
+        if not isinstance(chunks, list):
+            chunks = [chunks]
+
+        # Generate embeddings and create points
+        points = []
+        current_id = start_id if start_id is not None else 0
+        for chunk_item in chunks:
+            # Clean the chunk text
+            if ChonkieChunker:
+                chunk_text = ChonkieChunker.clean_text(str(chunk_item))
+            else:
+                chunk_text = str(chunk_item)
+
+            if not chunk_text.strip():
+                continue
+
+            # Generate embedding using Symbol's embedding property
+            chunk_symbol = Symbol(chunk_text)
+
+            # Generate embedding - Symbol has embedding property that returns numpy array
+            try:
+                embedding = chunk_symbol.embedding
+            except (AttributeError, Exception) as e:
+                # Fallback: try using Expression's embed method
+                try:
+                    embedding = chunk_symbol.embed()
+                    if hasattr(embedding, "value"):
+                        embedding = embedding.value
+                except Exception as embed_err:
+                    msg = f"Could not generate embedding for chunk. Error: {e}"
+                    raise ValueError(msg) from embed_err
+
+            # Normalize embedding to flat list using existing helper
+            if isinstance(embedding, np.ndarray):
+                # Flatten if 2D (e.g., shape (1, 1536) -> (1536,))
+                if embedding.ndim > 1:
+                    embedding = embedding.flatten()
+                embedding = embedding.tolist()
+            elif isinstance(embedding, list):
+                # Ensure embedding is a flat list (handle nested lists)
+                if embedding and len(embedding) > 0 and isinstance(embedding[0], list):
+                    # Flatten nested list (e.g., [[1, 2, 3]] -> [1, 2, 3])
+                    embedding = (
+                        embedding[0]
+                        if len(embedding) == 1
+                        else [item for sublist in embedding for item in sublist]
+                    )
+            else:
+                # Try to convert to list
+                try:
+                    embedding = list(embedding) if embedding else []
+                except (TypeError, ValueError) as e:
+                    msg = (
+                        f"Could not generate embedding for chunk. "
+                        f"Expected list or array, got type: {type(embedding)}"
+                    )
+                    raise ValueError(msg) from e
+
+            # Truncate or pad embedding to match vector_size
+            original_size = len(embedding)
+            if original_size != vector_size:
+                if original_size > vector_size:
+                    embedding = embedding[:vector_size]
+                else:
+                    embedding = embedding + [0.0] * (vector_size - original_size)
+                warnings.warn(
+                    f"Embedding size ({original_size}) adjusted to match collection vector size ({vector_size})"
+                )
+
+            # Create payload
+            payload = {"text": chunk_text}
+            if metadata:
+                payload.update(metadata)
+
+            # Generate ID
+            if start_id is not None:
+                point_id = current_id
+                current_id += 1
+            else:
+                # Use uuid5 for deterministic, collision-resistant IDs based on content
+                # uuid5 uses SHA-1 internally, providing 160 bits of entropy
+                # Convert to int64 by taking modulo 2**63 to fit in signed 64-bit range
+                namespace_uuid = uuid.NAMESPACE_DNS  # Use DNS namespace for consistency
+                uuid_obj = uuid.uuid5(namespace_uuid, chunk_text)
+                # Convert UUID (128 bits) to int64, ensuring it fits in signed 64-bit range
+                point_id = uuid_obj.int % (2**63)
+
+            points.append(
+                {
+                    "id": point_id,
+                    "vector": embedding,
+                    "payload": payload,
+                }
+            )
+
+        if not points:
+            warnings.warn("No valid points to upsert")
+            return 0
+
+        # Upsert the points using shared synchronous method
+        self._upsert_points_sync(collection_name=collection_name, points=points, **upsert_kwargs)
+
+        return len(points)
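
For orientation, here is a minimal usage sketch of the new QdrantIndexEngine introduced in this release. It is not taken from the package documentation; it only exercises the public async methods visible in the diff above (create_collection, chunk_and_upsert, search) and assumes a Qdrant server reachable at the default http://localhost:6333, the optional qdrant-client dependency installed, a configured symai embedding backend, a hypothetical local file paper.pdf, and a placeholder query vector.

import asyncio

from symai.backend.engines.index.engine_qdrant import QdrantIndexEngine


async def main():
    # Engine defaults target a local Qdrant instance (http://localhost:6333).
    engine = QdrantIndexEngine(index_name="dataindex", index_dims=1536)

    # Create a collection; string metrics are mapped to Qdrant Distance values.
    await engine.create_collection("docs", vector_size=1536, distance="cosine")

    # Read, chunk, embed, and upsert a local document ("paper.pdf" is hypothetical).
    num_chunks = await engine.chunk_and_upsert(
        "docs", document_path="paper.pdf", metadata={"project": "demo"}
    )
    print(f"Upserted {num_chunks} chunks")

    # Query with a precomputed embedding (placeholder vector of zeros here).
    hits = await engine.search("docs", query_vector=[0.0] * 1536, limit=5)
    for hit in hits:
        print(hit.id, hit.score, hit.payload.get("text", ""))


asyncio.run(main())

Note that the public async methods are thin wrappers over synchronous helpers (_create_collection_sync, _upsert_points_sync, _search_sync), with retries applied via the core_ext.retry decorator inside _upsert and _query.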