langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -0,0 +1,262 @@
|
|
1
|
+
"""
|
2
|
+
Momento Vector Index.
|
3
|
+
https://docs.momentohq.com/vector-index/develop/api-reference
|
4
|
+
"""
|
5
|
+
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
from typing import List, Optional, Sequence, Tuple, no_type_check
|
9
|
+
|
10
|
+
import momento.responses.vector_index as mvi_response
|
11
|
+
from dotenv import load_dotenv
|
12
|
+
from momento import (
|
13
|
+
# PreviewVectorIndexClientAsync,
|
14
|
+
CredentialProvider,
|
15
|
+
PreviewVectorIndexClient,
|
16
|
+
VectorIndexConfigurations,
|
17
|
+
)
|
18
|
+
from momento.requests.vector_index import (
|
19
|
+
ALL_METADATA,
|
20
|
+
Item,
|
21
|
+
SimilarityMetric,
|
22
|
+
)
|
23
|
+
|
24
|
+
from langroid.embedding_models.base import (
|
25
|
+
EmbeddingModel,
|
26
|
+
EmbeddingModelsConfig,
|
27
|
+
)
|
28
|
+
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
29
|
+
from langroid.mytypes import Document, EmbeddingFunction
|
30
|
+
from langroid.utils.configuration import settings
|
31
|
+
from langroid.utils.pydantic_utils import (
|
32
|
+
flatten_pydantic_instance,
|
33
|
+
nested_dict_from_flat,
|
34
|
+
)
|
35
|
+
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
36
|
+
|
37
|
+
logger = logging.getLogger(__name__)
|
38
|
+
|
39
|
+
|
40
|
+
class MomentoVIConfig(VectorStoreConfig):
    """Settings for the Momento Vector Index store (`MomentoVI`)."""

    # Only the hosted service is supported; MomentoVI raises
    # NotImplementedError when this is False.
    cloud: bool = True
    # Index to use; None defers index creation until a name is supplied.
    collection_name: Optional[str] = "temp"
    # Embedding model used to vectorize documents and queries.
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
    # Similarity metric the index is created with.
    distance: SimilarityMetric = SimilarityMetric.COSINE_SIMILARITY
|
45
|
+
|
46
|
+
|
47
|
+
class MomentoVI(VectorStore):
    """
    Vector store backed by the hosted Momento Vector Index (MVI) service.

    Each langroid "collection" maps to one Momento index. Documents are stored
    as MVI items whose metadata is the flattened `Document` with every value
    converted to str (see `add_documents`). Only the hosted service is
    supported; the API key is read from the `MOMENTO_API_KEY` environment
    variable after loading any `.env` file.
    """

    def __init__(self, config: Optional[MomentoVIConfig] = None):
        """
        Args:
            config: store configuration. When omitted, a fresh
                `MomentoVIConfig()` is built for this instance. A shared default
                object would leak state between instances, because
                `create_collection` mutates `config.collection_name`.

        Raises:
            ValueError: if `config.cloud` is True and `MOMENTO_API_KEY` is unset.
            NotImplementedError: if `config.cloud` is False.
        """
        if config is None:
            config = MomentoVIConfig()
        super().__init__(config)
        self.config: MomentoVIConfig = config
        emb_model = EmbeddingModel.create(config.embedding)
        self.embedding_fn: EmbeddingFunction = emb_model.embedding_fn()
        self.embedding_dim = emb_model.embedding_dims
        self.host = config.host
        self.port = config.port
        load_dotenv()
        api_key = os.getenv("MOMENTO_API_KEY")
        if not config.cloud:
            raise NotImplementedError("MomentoVI local not available yet")
        if api_key is None:
            raise ValueError(
                """MOMENTO_API_KEY env variable must be set to use the
                MomentoVI hosted service. Please set this in your .env file.
                """
            )
        self.client = PreviewVectorIndexClient(
            configuration=VectorIndexConfigurations.Default.latest(),
            credential_provider=CredentialProvider.from_string(api_key),
        )

        # Note: Only create collection if a non-null collection name is provided.
        # This is useful to delay creation of vecdb until we have a suitable
        # collection name (e.g. we could get it from the url or folder path).
        if config.collection_name is not None:
            self.create_collection(
                config.collection_name, replace=config.replace_collection
            )

    def clear_empty_collections(self) -> int:
        """
        No-op: index sizes are not queried here, so nothing is deleted.

        Returns:
            Always 0.
        """
        logger.warning(
            """
            Momento VI does not yet have a way to easily get size of indices,
            so clear_empty_collections is not deleting any indices.
            """
        )
        return 0

    def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
        """
        Delete every index whose name starts with `prefix`.

        Args:
            really: safety switch; nothing is deleted unless this is True.
            prefix: only indexes whose names start with this are deleted
                ("" matches every index).

        Returns:
            Number of indexes for which deletion was requested. Failures are
            logged by `delete_collection` rather than raised.
        """
        if not really:
            logger.warning("Not deleting all collections, set really=True to confirm")
            return 0
        # list_collections cannot filter by size, so this is every index.
        coll_names = [
            name
            for name in self.list_collections(empty=True)
            if name.startswith(prefix)
        ]
        if len(coll_names) == 0:
            logger.warning(f"No collections found with prefix {prefix}")
            return 0
        for name in coll_names:
            self.delete_collection(name)
        logger.warning(
            f"""
            Deleted {len(coll_names)} indices from Momento VI
            """
        )
        return len(coll_names)

    def list_collections(self, empty: bool = False) -> List[str]:
        """
        List the names of all indexes returned by Momento's `list_indexes`.

        Args:
            empty (bool, optional): accepted for interface compatibility but
                ignored. Index sizes are not available here, so empty indexes
                are always included.

        Returns:
            List of index names.

        Raises:
            ValueError: on an error or unexpected response from Momento.
        """
        response = self.client.list_indexes()
        if isinstance(response, mvi_response.ListIndexes.Success):
            return [ind.name for ind in response.indexes]
        elif isinstance(response, mvi_response.ListIndexes.Error):
            raise ValueError(f"Error listing collections: {response.message}")
        else:
            raise ValueError(f"Unexpected response: {response}")

    def _create_index(self, collection_name: str):  # type: ignore
        """Issue a single create_index call for `collection_name`."""
        return self.client.create_index(
            index_name=collection_name,
            num_dimensions=self.embedding_dim,
            similarity_metric=self.config.distance,
        )

    def create_collection(self, collection_name: str, replace: bool = False) -> None:
        """
        Create an index with the given name and make it the active collection.

        Optionally replaces an existing index if `replace` is True.

        Args:
            collection_name (str): Name of the collection to create.
            replace (bool): If True and an index with this name already
                exists, delete it and create a fresh, empty one. If False
                (default), an existing index is kept as-is.

        Raises:
            ValueError: on an error or unexpected response from Momento.
        """
        self.config.collection_name = collection_name
        response = self._create_index(collection_name)
        if replace and isinstance(
            response, mvi_response.CreateIndex.IndexAlreadyExists
        ):
            logger.warning(f"Replacing existing collection {collection_name}")
            self.delete_collection(collection_name)
            response = self._create_index(collection_name)

        if isinstance(response, mvi_response.CreateIndex.Success):
            logger.info(f"Created collection {collection_name}")
            if settings.debug:
                # Temporarily raise verbosity so this message shows in debug mode.
                level = logger.getEffectiveLevel()
                logger.setLevel(logging.INFO)
                try:
                    logger.info(f"Collection {collection_name} created")
                finally:
                    logger.setLevel(level)
        elif isinstance(response, mvi_response.CreateIndex.IndexAlreadyExists):
            logger.warning(f"Collection {collection_name} already exists")
        elif isinstance(response, mvi_response.CreateIndex.Error):
            raise ValueError(
                f"Error creating collection {collection_name}: {response.message}"
            )
        else:
            raise ValueError(f"Unexpected response: {response}")

    def add_documents(self, documents: Sequence[Document]) -> None:
        """
        Embed `documents` and upsert them into the active index, in batches.

        The index is created if it does not exist yet; existing contents are
        kept. Each item's id is the document id and its metadata is the
        flattened document with all values forced to str.

        Raises:
            ValueError: if no collection name is set, or Momento rejects a batch.
        """
        super().maybe_add_ids(documents)
        if len(documents) == 0:
            return
        # Check before embedding so we don't pay for embeddings we can't store.
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot ingest docs")
        embedding_vecs = self.embedding_fn([doc.content for doc in documents])

        # Create only if missing: replacing here would wipe earlier additions.
        self.create_collection(self.config.collection_name, replace=False)

        items = [
            Item(
                id=str(d.id()),
                vector=embedding_vecs[i],
                # force all values to str since Momento requires it
                metadata=flatten_pydantic_instance(d, force_str=True),
            )
            for i, d in enumerate(documents)
        ]

        # don't insert all at once, batch in chunks of b,
        # else we get an API error
        b = self.config.batch_size
        for i in range(0, len(items), b):
            response = self.client.upsert_item_batch(
                index_name=self.config.collection_name,
                items=items[i : i + b],
            )
            if isinstance(response, mvi_response.UpsertItemBatch.Success):
                continue
            elif isinstance(response, mvi_response.UpsertItemBatch.Error):
                raise ValueError(f"Error adding documents: {response.message}")
            else:
                raise ValueError(f"Unexpected response: {response}")

    def delete_collection(self, collection_name: str) -> None:
        """Delete the named index; failures are logged, not raised."""
        delete_response = self.client.delete_index(collection_name)
        if isinstance(delete_response, mvi_response.DeleteIndex.Success):
            logger.warning(f"Deleted index {collection_name}")
        elif isinstance(delete_response, mvi_response.DeleteIndex.Error):
            logger.error(
                f"Error while deleting index {collection_name}: "
                f" {delete_response.message}"
            )
        else:
            logger.error(f"Unexpected response: {delete_response}")

    def _to_int_or_uuid(self, id: str) -> int | str:
        """Return `id` as an int if it parses as one, else unchanged."""
        try:
            return int(id)
        except ValueError:
            return id

    def get_all_documents(self, where: str = "") -> List[Document]:
        """Not supported by this store; always raises NotImplementedError."""
        raise NotImplementedError(
            """
            MomentoVI does not support get_all_documents().
            Please use a different vector database, e.g. qdrant or chromadb.
            """
        )

    def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
        """Not supported by this store; always raises NotImplementedError."""
        raise NotImplementedError(
            """
            MomentoVI does not support get_documents_by_ids.
            Please use a different vector database, e.g. qdrant or chromadb.
            """
        )

    @no_type_check
    def similar_texts_with_scores(
        self,
        text: str,
        k: int = 1,
        where: Optional[str] = None,
        neighbors: int = 0,  # ignored
    ) -> List[Tuple[Document, float]]:
        """
        Return up to `k` documents most similar to `text`, with their scores.

        Args:
            text: query text; embedded with this store's embedding model.
            k: maximum number of hits to request (`top_k`).
            where: ignored; no metadata filter is applied.
            neighbors: ignored.

        Returns:
            (document, score) pairs in the order Momento returns the hits.
            Returns an empty list on an error or unexpected response, or when
            nothing matches.

        Raises:
            ValueError: if no collection name is set.
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot search")
        embedding = self.embedding_fn([text])[0]
        response = self.client.search(
            index_name=self.config.collection_name,
            query_vector=embedding,
            top_k=k,
            metadata_fields=ALL_METADATA,
        )

        if isinstance(response, mvi_response.Search.Error):
            logger.warning(
                f"Error while searching on index {self.config.collection_name}:"
                f" {response.message}"
            )
            return []
        elif not isinstance(response, mvi_response.Search.Success):
            logger.warning(f"Unexpected response: {response}")
            return []

        # Each hit carries its similarity in `hit.score`. Its metadata holds
        # only the flattened Document written by add_documents, so there is no
        # "distance" key there. Building the pairs in one pass keeps every
        # score aligned with its document.
        doc_score_pairs = [
            (Document.parse_obj(nested_dict_from_flat(hit.metadata)), hit.score)
            for hit in response.hits
            if hit is not None
        ]
        if len(doc_score_pairs) == 0:
            logger.warning(f"No matches found for {text}")
            return []
        if settings.debug:
            max_score = max(score for _, score in doc_score_pairs)
            logger.info(
                f"Found {len(doc_score_pairs)} matches, max score: {max_score}"
            )
        self.show_if_debug(doc_score_pairs)
        return doc_score_pairs
|
@@ -1,8 +1,10 @@
|
|
1
|
+
import hashlib
|
2
|
+
import json
|
1
3
|
import logging
|
2
4
|
import os
|
3
|
-
|
5
|
+
import uuid
|
6
|
+
from typing import List, Optional, Sequence, Tuple, TypeVar
|
4
7
|
|
5
|
-
from chromadb.api.types import EmbeddingFunction
|
6
8
|
from dotenv import load_dotenv
|
7
9
|
from qdrant_client import QdrantClient
|
8
10
|
from qdrant_client.conversions.common_types import ScoredPoint
|
@@ -20,23 +22,50 @@ from langroid.embedding_models.base import (
|
|
20
22
|
EmbeddingModelsConfig,
|
21
23
|
)
|
22
24
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
23
|
-
from langroid.mytypes import Document
|
25
|
+
from langroid.mytypes import Document, EmbeddingFunction
|
24
26
|
from langroid.utils.configuration import settings
|
25
27
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
26
28
|
|
27
29
|
logger = logging.getLogger(__name__)
|
28
30
|
|
29
31
|
|
32
|
+
T = TypeVar("T")


def from_optional(x: Optional[T], default: T) -> T:
    """
    Return `x`, or `default` when `x` is None.

    Only None triggers the fallback; other falsy values such as 0, "" or []
    are returned unchanged.
    """
    return default if x is None else x
|
40
|
+
|
41
|
+
|
42
|
+
def is_valid_uuid(uuid_to_test: str) -> bool:
    """
    Check whether `uuid_to_test` is usable as-is as a point id.

    Accepted forms:
      - the canonical string form of a UUID (exactly what `str(uuid.UUID(...))`
        produces: lowercase, hyphenated);
      - the decimal string of an integer in the unsigned 64-bit range
        [0, 2**64 - 1].

    A string that parses as a UUID but is not in canonical form is rejected
    outright; the integer check is not attempted for it.
    """
    try:
        parsed = uuid.UUID(uuid_to_test)
    except Exception:
        parsed = None

    if parsed is not None:
        return str(parsed) == uuid_to_test

    try:
        as_int = int(uuid_to_test)
    except ValueError:
        return False
    return 0 <= as_int <= 2**64 - 1
|
57
|
+
|
58
|
+
|
30
59
|
class QdrantDBConfig(VectorStoreConfig):
|
31
60
|
cloud: bool = True
|
32
|
-
collection_name: str | None =
|
61
|
+
collection_name: str | None = "temp"
|
33
62
|
storage_path: str = ".qdrant/data"
|
34
63
|
embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
|
35
64
|
distance: str = Distance.COSINE
|
36
65
|
|
37
66
|
|
38
67
|
class QdrantDB(VectorStore):
|
39
|
-
def __init__(self, config: QdrantDBConfig):
|
68
|
+
def __init__(self, config: QdrantDBConfig = QdrantDBConfig()):
|
40
69
|
super().__init__(config)
|
41
70
|
self.config = config
|
42
71
|
emb_model = EmbeddingModel.create(config.embedding)
|
@@ -113,8 +142,10 @@ class QdrantDB(VectorStore):
|
|
113
142
|
n_non_empty_deletes = 0
|
114
143
|
for name in coll_names:
|
115
144
|
info = self.client.get_collection(collection_name=name)
|
116
|
-
|
117
|
-
|
145
|
+
points_count = from_optional(info.points_count, 0)
|
146
|
+
|
147
|
+
n_empty_deletes += points_count == 0
|
148
|
+
n_non_empty_deletes += points_count > 0
|
118
149
|
self.client.delete_collection(collection_name=name)
|
119
150
|
logger.warning(
|
120
151
|
f"""
|
@@ -135,11 +166,21 @@ class QdrantDB(VectorStore):
|
|
135
166
|
colls = list(self.client.get_collections())[0][1]
|
136
167
|
if empty:
|
137
168
|
return [coll.name for coll in colls]
|
138
|
-
counts = [
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
169
|
+
counts = []
|
170
|
+
for coll in colls:
|
171
|
+
try:
|
172
|
+
counts.append(
|
173
|
+
from_optional(
|
174
|
+
self.client.get_collection(
|
175
|
+
collection_name=coll.name
|
176
|
+
).points_count,
|
177
|
+
0,
|
178
|
+
)
|
179
|
+
)
|
180
|
+
except Exception:
|
181
|
+
logger.warning(f"Error getting collection {coll.name}")
|
182
|
+
counts.append(0)
|
183
|
+
return [coll.name for coll, count in zip(colls, counts) if (count or 0) > 0]
|
143
184
|
|
144
185
|
def create_collection(self, collection_name: str, replace: bool = False) -> None:
|
145
186
|
"""
|
@@ -154,7 +195,10 @@ class QdrantDB(VectorStore):
|
|
154
195
|
collections = self.list_collections()
|
155
196
|
if collection_name in collections:
|
156
197
|
coll = self.client.get_collection(collection_name=collection_name)
|
157
|
-
if
|
198
|
+
if (
|
199
|
+
coll.status == CollectionStatus.GREEN
|
200
|
+
and from_optional(coll.points_count, 0) > 0
|
201
|
+
):
|
158
202
|
logger.warning(f"Non-empty Collection {collection_name} already exists")
|
159
203
|
if not replace:
|
160
204
|
logger.warning("Not replacing collection")
|
@@ -178,9 +222,15 @@ class QdrantDB(VectorStore):
|
|
178
222
|
logger.setLevel(level)
|
179
223
|
|
180
224
|
def add_documents(self, documents: Sequence[Document]) -> None:
|
225
|
+
# Add id to metadata if not already present
|
226
|
+
super().maybe_add_ids(documents)
|
227
|
+
# Fix the ids due to qdrant finickiness
|
228
|
+
for doc in documents:
|
229
|
+
doc.metadata.id = str(self._to_int_or_uuid(doc.metadata.id))
|
181
230
|
colls = self.list_collections(empty=True)
|
182
231
|
if len(documents) == 0:
|
183
232
|
return
|
233
|
+
document_dicts = [doc.dict() for doc in documents]
|
184
234
|
embedding_vecs = self.embedding_fn([doc.content for doc in documents])
|
185
235
|
if self.config.collection_name is None:
|
186
236
|
raise ValueError("No collection name set, cannot ingest docs")
|
@@ -196,7 +246,7 @@ class QdrantDB(VectorStore):
|
|
196
246
|
points=Batch(
|
197
247
|
ids=ids[i : i + b],
|
198
248
|
vectors=embedding_vecs[i : i + b],
|
199
|
-
payloads=
|
249
|
+
payloads=document_dicts[i : i + b],
|
200
250
|
),
|
201
251
|
)
|
202
252
|
|
@@ -205,19 +255,42 @@ class QdrantDB(VectorStore):
|
|
205
255
|
|
206
256
|
def _to_int_or_uuid(self, id: str) -> int | str:
|
207
257
|
try:
|
208
|
-
|
258
|
+
int_val = int(id)
|
259
|
+
if is_valid_uuid(id):
|
260
|
+
return int_val
|
209
261
|
except ValueError:
|
262
|
+
pass
|
263
|
+
|
264
|
+
# If doc_id is already a valid UUID, return it as is
|
265
|
+
if isinstance(id, str) and is_valid_uuid(id):
|
210
266
|
return id
|
211
267
|
|
212
|
-
|
268
|
+
# Otherwise, generate a UUID from the doc_id
|
269
|
+
# Convert doc_id to string if it's not already
|
270
|
+
id_str = str(id)
|
271
|
+
|
272
|
+
# Hash the document ID using SHA-1
|
273
|
+
hash_object = hashlib.sha1(id_str.encode())
|
274
|
+
hash_digest = hash_object.hexdigest()
|
275
|
+
|
276
|
+
# Truncate or manipulate the hash to fit into a UUID (128 bits)
|
277
|
+
uuid_str = hash_digest[:32]
|
278
|
+
|
279
|
+
# Format this string into a UUID format
|
280
|
+
formatted_uuid = uuid.UUID(uuid_str)
|
281
|
+
|
282
|
+
return str(formatted_uuid)
|
283
|
+
|
284
|
+
def get_all_documents(self, where: str = "") -> List[Document]:
|
213
285
|
if self.config.collection_name is None:
|
214
286
|
raise ValueError("No collection name set, cannot retrieve docs")
|
215
287
|
docs = []
|
216
288
|
offset = 0
|
289
|
+
filter = Filter() if where == "" else Filter.parse_obj(json.loads(where))
|
217
290
|
while True:
|
218
291
|
results, next_page_offset = self.client.scroll(
|
219
292
|
collection_name=self.config.collection_name,
|
220
|
-
scroll_filter=
|
293
|
+
scroll_filter=filter,
|
221
294
|
offset=offset,
|
222
295
|
limit=10_000, # try getting all at once, if not we keep paging
|
223
296
|
with_payload=True,
|
@@ -239,7 +312,11 @@ class QdrantDB(VectorStore):
|
|
239
312
|
with_vectors=False,
|
240
313
|
with_payload=True,
|
241
314
|
)
|
242
|
-
|
315
|
+
# Note the records may NOT be in the order of the ids,
|
316
|
+
# so we re-order them here.
|
317
|
+
id2payload = {record.id: record.payload for record in records}
|
318
|
+
ordered_payloads = [id2payload[id] for id in _ids]
|
319
|
+
docs = [Document(**payload) for payload in ordered_payloads] # type: ignore
|
243
320
|
return docs
|
244
321
|
|
245
322
|
def similar_texts_with_scores(
|
@@ -247,10 +324,14 @@ class QdrantDB(VectorStore):
|
|
247
324
|
text: str,
|
248
325
|
k: int = 1,
|
249
326
|
where: Optional[str] = None,
|
327
|
+
neighbors: int = 0,
|
250
328
|
) -> List[Tuple[Document, float]]:
|
251
329
|
embedding = self.embedding_fn([text])[0]
|
252
330
|
# TODO filter may not work yet
|
253
|
-
|
331
|
+
if where is None or where == "":
|
332
|
+
filter = Filter()
|
333
|
+
else:
|
334
|
+
filter = Filter.parse_obj(json.loads(where))
|
254
335
|
if self.config.collection_name is None:
|
255
336
|
raise ValueError("No collection name set, cannot search")
|
256
337
|
search_result: List[ScoredPoint] = self.client.search(
|
@@ -263,7 +344,7 @@ class QdrantDB(VectorStore):
|
|
263
344
|
exact=False, # use Apx NN, not exact NN
|
264
345
|
),
|
265
346
|
)
|
266
|
-
scores = [match.score for match in search_result]
|
347
|
+
scores = [match.score for match in search_result if match is not None]
|
267
348
|
docs = [
|
268
349
|
Document(**(match.payload)) # type: ignore
|
269
350
|
for match in search_result
|
@@ -272,8 +353,9 @@ class QdrantDB(VectorStore):
|
|
272
353
|
if len(docs) == 0:
|
273
354
|
logger.warning(f"No matches found for {text}")
|
274
355
|
return []
|
275
|
-
if settings.debug:
|
276
|
-
logger.info(f"Found {len(docs)} matches, max score: {max(scores)}")
|
277
356
|
doc_score_pairs = list(zip(docs, scores))
|
357
|
+
max_score = max(ds[1] for ds in doc_score_pairs)
|
358
|
+
if settings.debug:
|
359
|
+
logger.info(f"Found {len(doc_score_pairs)} matches, max score: {max_score}")
|
278
360
|
self.show_if_debug(doc_score_pairs)
|
279
361
|
return doc_score_pairs
|