langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -0,0 +1,262 @@
1
+ """
2
+ Momento Vector Index.
3
+ https://docs.momentohq.com/vector-index/develop/api-reference
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ from typing import List, Optional, Sequence, Tuple, no_type_check
9
+
10
+ import momento.responses.vector_index as mvi_response
11
+ from dotenv import load_dotenv
12
+ from momento import (
13
+ # PreviewVectorIndexClientAsync,
14
+ CredentialProvider,
15
+ PreviewVectorIndexClient,
16
+ VectorIndexConfigurations,
17
+ )
18
+ from momento.requests.vector_index import (
19
+ ALL_METADATA,
20
+ Item,
21
+ SimilarityMetric,
22
+ )
23
+
24
+ from langroid.embedding_models.base import (
25
+ EmbeddingModel,
26
+ EmbeddingModelsConfig,
27
+ )
28
+ from langroid.embedding_models.models import OpenAIEmbeddingsConfig
29
+ from langroid.mytypes import Document, EmbeddingFunction
30
+ from langroid.utils.configuration import settings
31
+ from langroid.utils.pydantic_utils import (
32
+ flatten_pydantic_instance,
33
+ nested_dict_from_flat,
34
+ )
35
+ from langroid.vector_store.base import VectorStore, VectorStoreConfig
36
+
37
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
38
+
39
+
40
class MomentoVIConfig(VectorStoreConfig):
    """
    Configuration for the `MomentoVI` vector store.

    Only the hosted (cloud) Momento service is currently supported.
    """

    # Use the hosted Momento service; MomentoVI raises NotImplementedError
    # when this is False.
    cloud: bool = True
    # Name of the Momento index to use; None delays index creation.
    collection_name: Optional[str] = "temp"
    # Embedding model used to vectorize document contents and queries.
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
    # Similarity metric passed to Momento when a new index is created.
    distance: SimilarityMetric = SimilarityMetric.COSINE_SIMILARITY
45
+
46
+
47
class MomentoVI(VectorStore):
    """
    Vector store backed by the hosted Momento Vector Index (MVI) service.

    Each langroid collection corresponds to one Momento index. Only the
    cloud-hosted service is supported; the API key is read from the
    MOMENTO_API_KEY environment variable after loading a .env file.
    """

    def __init__(self, config: Optional[MomentoVIConfig] = None):
        """
        Args:
            config: store configuration. Defaults to a fresh `MomentoVIConfig()`
                per instance. A shared default instance would be mutated by
                `create_collection` and leak collection names across stores.

        Raises:
            NotImplementedError: if `config.cloud` is False.
            ValueError: if `config.cloud` is True and MOMENTO_API_KEY is unset.
        """
        if config is None:
            config = MomentoVIConfig()
        super().__init__(config)
        self.config: MomentoVIConfig = config
        emb_model = EmbeddingModel.create(config.embedding)
        self.embedding_fn: EmbeddingFunction = emb_model.embedding_fn()
        self.embedding_dim = emb_model.embedding_dims
        self.host = config.host
        self.port = config.port
        load_dotenv()
        api_key = os.getenv("MOMENTO_API_KEY")
        if not config.cloud:
            raise NotImplementedError("MomentoVI local not available yet")
        if api_key is None:
            raise ValueError(
                "MOMENTO_API_KEY env variable must be set to use the "
                "MomentoVI hosted service. Please set this in your .env file."
            )
        self.client = PreviewVectorIndexClient(
            configuration=VectorIndexConfigurations.Default.latest(),
            credential_provider=CredentialProvider.from_string(api_key),
        )

        # Only create a collection if a non-null name is provided. This lets
        # callers delay creation until a suitable name is known (e.g. derived
        # from a url or folder path).
        if config.collection_name is not None:
            self.create_collection(
                config.collection_name, replace=config.replace_collection
            )

    def clear_empty_collections(self) -> int:
        """
        Log a warning and delete nothing.

        Momento VI does not expose index sizes, so empty indices cannot be
        identified.

        Returns:
            Always 0 (the number of indices deleted).
        """
        logger.warning(
            """
            Momento VI does not yet have a way to easily get size of indices,
            so clear_empty_collections is not deleting any indices.
            """
        )
        return 0

    def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
        """
        Clear all collections with the given prefix.

        Args:
            really: must be True to actually delete anything (safety guard).
            prefix: only indices whose names start with this are deleted;
                the default "" matches every index.

        Returns:
            Number of indices for which deletion was requested (0 if `really`
            is False or nothing matches).
        """
        if not really:
            logger.warning("Not deleting all collections, set really=True to confirm")
            return 0
        coll_names = [
            name
            for name in self.list_collections(empty=False)
            if name.startswith(prefix)
        ]
        if len(coll_names) == 0:
            logger.warning(f"No collections found with prefix {prefix}")
            return 0
        for name in coll_names:
            self.delete_collection(name)
        logger.warning(
            f"""
            Deleted {len(coll_names)} indices from Momento VI
            """
        )
        return len(coll_names)

    def list_collections(self, empty: bool = False) -> List[str]:
        """
        List the names of all Momento indices (collections).

        Args:
            empty (bool, optional): Whether to include empty collections.
                Currently ignored. Momento VI does not expose index sizes, so
                every index is returned regardless of this flag.

        Returns:
            List of collection names.

        Raises:
            ValueError: if Momento returns an error or an unexpected response.
        """
        response = self.client.list_indexes()
        if isinstance(response, mvi_response.ListIndexes.Success):
            return [ind.name for ind in response.indexes]
        elif isinstance(response, mvi_response.ListIndexes.Error):
            raise ValueError(f"Error listing collections: {response.message}")
        else:
            raise ValueError(f"Unexpected response: {response}")

    def _create_index(self, collection_name: str) -> object:
        """Issue a single Momento `create_index` call using this store's settings."""
        return self.client.create_index(
            index_name=collection_name,
            num_dimensions=self.embedding_dim,
            similarity_metric=self.config.distance,
        )

    def create_collection(self, collection_name: str, replace: bool = False) -> None:
        """
        Create a collection with the given name and make it the current one.

        Args:
            collection_name (str): Name of the collection to create.
            replace (bool): Whether to replace an existing collection with the
                same name (delete and recreate it). Defaults to False, in which
                case an existing collection is kept as-is.

        Raises:
            ValueError: if Momento returns an error or an unexpected response.
        """
        self.config.collection_name = collection_name
        response = self._create_index(collection_name)
        if replace and isinstance(
            response, mvi_response.CreateIndex.IndexAlreadyExists
        ):
            # Momento's create_index has no "overwrite" option, so honor
            # `replace` by dropping the existing index and creating it again.
            logger.warning(f"Replacing existing collection {collection_name}")
            self.delete_collection(collection_name)
            response = self._create_index(collection_name)

        if isinstance(response, mvi_response.CreateIndex.Success):
            logger.info(f"Created collection {collection_name}")
            if settings.debug:
                # Temporarily raise verbosity so the message is visible.
                level = logger.getEffectiveLevel()
                logger.setLevel(logging.INFO)
                logger.info(f"Collection {collection_name} created")
                logger.setLevel(level)
        elif isinstance(response, mvi_response.CreateIndex.IndexAlreadyExists):
            logger.warning(f"Collection {collection_name} already exists")
        elif isinstance(response, mvi_response.CreateIndex.Error):
            raise ValueError(
                f"Error creating collection {collection_name}: {response.message}"
            )
        else:
            raise ValueError(f"Unexpected response: {response}")

    def add_documents(self, documents: Sequence[Document]) -> None:
        """
        Embed `documents` and upsert them into the current collection.

        Ids are added to documents that lack one. The index is created if it
        does not exist yet. An existing index is never replaced here, so
        repeated calls accumulate documents instead of wiping earlier ones.

        Raises:
            ValueError: if no collection name is set, if the number of
                embeddings does not match the number of documents, or if
                Momento rejects an upsert batch or returns an unexpected
                response.
        """
        super().maybe_add_ids(documents)
        if len(documents) == 0:
            return
        collection_name = self.config.collection_name
        # Validate before paying for the embedding call.
        if collection_name is None:
            raise ValueError("No collection name set, cannot ingest docs")
        if collection_name not in self.list_collections(empty=True):
            self.create_collection(collection_name, replace=False)

        embedding_vecs = self.embedding_fn([doc.content for doc in documents])
        items = [
            Item(
                id=str(doc.id()),
                vector=vec,
                # force all values to str since Momento requires it
                metadata=flatten_pydantic_instance(doc, force_str=True),
            )
            for doc, vec in zip(documents, embedding_vecs, strict=True)
        ]

        # don't insert all at once, batch in chunks of b,
        # else we get an API error
        b = self.config.batch_size
        for start in range(0, len(items), b):
            response = self.client.upsert_item_batch(
                index_name=collection_name,
                items=items[start : start + b],
            )
            if isinstance(response, mvi_response.UpsertItemBatch.Error):
                raise ValueError(f"Error adding documents: {response.message}")
            if not isinstance(response, mvi_response.UpsertItemBatch.Success):
                raise ValueError(f"Unexpected response: {response}")

    def delete_collection(self, collection_name: str) -> None:
        """Delete the Momento index `collection_name`; failures are logged, not raised."""
        delete_response = self.client.delete_index(collection_name)
        if isinstance(delete_response, mvi_response.DeleteIndex.Success):
            logger.warning(f"Deleted index {collection_name}")
        elif isinstance(delete_response, mvi_response.DeleteIndex.Error):
            logger.error(
                f"Error while deleting index {collection_name}: "
                f" {delete_response.message}"
            )

    def _to_int_or_uuid(self, id: str) -> int | str:
        """Return `id` as an int if it parses as one, else unchanged."""
        try:
            return int(id)
        except ValueError:
            return id

    def get_all_documents(self, where: str = "") -> List[Document]:
        """Not supported by Momento VI; always raises NotImplementedError."""
        raise NotImplementedError(
            """
            MomentoVI does not support get_all_documents().
            Please use a different vector database, e.g. qdrant or chromadb.
            """
        )

    def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
        """Not supported by Momento VI; always raises NotImplementedError."""
        raise NotImplementedError(
            """
            MomentoVI does not support get_documents_by_ids.
            Please use a different vector database, e.g. qdrant or chromadb.
            """
        )

    @staticmethod
    def _hit_score(hit: object) -> float:
        """
        Return the similarity score attached to a Momento search hit.

        Current Momento SDKs expose it as `score`. The `distance` fallback is
        for older preview SDKs (NOTE(review): confirm against the pinned
        momento version).
        """
        score = getattr(hit, "score", None)
        if score is None:
            score = getattr(hit, "distance")
        return float(score)

    @no_type_check
    def similar_texts_with_scores(
        self,
        text: str,
        k: int = 1,
        where: Optional[str] = None,
        neighbors: int = 0,  # ignored
    ) -> List[Tuple[Document, float]]:
        """
        Return up to `k` documents most similar to `text`, with their scores.

        Args:
            text: query text; embedded with this store's embedding function.
            k: maximum number of hits requested from Momento.
            where: ignored (filtering is not implemented for Momento VI).
            neighbors: ignored.

        Returns:
            (document, score) pairs in the order Momento returns them; empty on
            search errors, unexpected responses, or when there are no hits.

        Raises:
            ValueError: if no collection name is set.
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot search")
        embedding = self.embedding_fn([text])[0]
        response = self.client.search(
            index_name=self.config.collection_name,
            query_vector=embedding,
            top_k=k,
            metadata_fields=ALL_METADATA,
        )

        if isinstance(response, mvi_response.Search.Error):
            logger.warning(
                f"Error while searching on index {self.config.collection_name}:"
                f" {response.message}"
            )
            return []
        elif not isinstance(response, mvi_response.Search.Success):
            logger.warning(f"Unexpected response: {response}")
            return []

        # Build each pair from a single hit so documents and scores can never
        # be misaligned. The metadata holds only the flattened Document stored
        # by add_documents; the score comes from the hit itself.
        doc_score_pairs = [
            (
                Document.parse_obj(nested_dict_from_flat(hit.metadata)),
                self._hit_score(hit),
            )
            for hit in response.hits
            if hit is not None
        ]
        if len(doc_score_pairs) == 0:
            logger.warning(f"No matches found for {text}")
            return []
        if settings.debug:
            max_score = max(score for _, score in doc_score_pairs)
            logger.info(f"Found {len(doc_score_pairs)} matches, max score: {max_score}")
        self.show_if_debug(doc_score_pairs)
        return doc_score_pairs
@@ -1,8 +1,10 @@
1
+ import hashlib
2
+ import json
1
3
  import logging
2
4
  import os
3
- from typing import List, Optional, Sequence, Tuple
5
+ import uuid
6
+ from typing import List, Optional, Sequence, Tuple, TypeVar
4
7
 
5
- from chromadb.api.types import EmbeddingFunction
6
8
  from dotenv import load_dotenv
7
9
  from qdrant_client import QdrantClient
8
10
  from qdrant_client.conversions.common_types import ScoredPoint
@@ -20,23 +22,50 @@ from langroid.embedding_models.base import (
20
22
  EmbeddingModelsConfig,
21
23
  )
22
24
  from langroid.embedding_models.models import OpenAIEmbeddingsConfig
23
- from langroid.mytypes import Document
25
+ from langroid.mytypes import Document, EmbeddingFunction
24
26
  from langroid.utils.configuration import settings
25
27
  from langroid.vector_store.base import VectorStore, VectorStoreConfig
26
28
 
27
29
  logger = logging.getLogger(__name__)
28
30
 
29
31
 
32
# Generic type variable shared by the value and default of `from_optional`.
T = TypeVar("T")


def from_optional(x: Optional[T], default: T) -> T:
    """
    Unwrap an optional value.

    Returns `default` only when `x` is None; other falsy values such as 0 or
    "" are returned unchanged.
    """
    return default if x is None else x
40
+
41
+
42
def is_valid_uuid(uuid_to_test: str) -> bool:
    """
    Check whether a value can be used as-is as a Qdrant point id.

    Returns True if `uuid_to_test` is either:
    - a UUID string in canonical form (lowercase and hyphenated, exactly as
      produced by `str(uuid.UUID(...))`), or
    - a value that `int()` parses to an unsigned 64-bit integer,
      i.e. in [0, 2**64 - 1].

    Anything else, including None, returns False instead of raising.
    """
    try:
        return str(uuid.UUID(uuid_to_test)) == uuid_to_test
    except (AttributeError, TypeError, ValueError):
        # Not parseable as a UUID (bad string, or not a string at all);
        # fall through to the integer check.
        pass
    try:
        int_value = int(uuid_to_test)
    except (TypeError, ValueError):
        return False
    return 0 <= int_value <= 2**64 - 1
57
+
58
+
30
59
  class QdrantDBConfig(VectorStoreConfig):
31
60
  cloud: bool = True
32
- collection_name: str | None = None
61
+ collection_name: str | None = "temp"
33
62
  storage_path: str = ".qdrant/data"
34
63
  embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
35
64
  distance: str = Distance.COSINE
36
65
 
37
66
 
38
67
  class QdrantDB(VectorStore):
39
- def __init__(self, config: QdrantDBConfig):
68
+ def __init__(self, config: QdrantDBConfig = QdrantDBConfig()):
40
69
  super().__init__(config)
41
70
  self.config = config
42
71
  emb_model = EmbeddingModel.create(config.embedding)
@@ -113,8 +142,10 @@ class QdrantDB(VectorStore):
113
142
  n_non_empty_deletes = 0
114
143
  for name in coll_names:
115
144
  info = self.client.get_collection(collection_name=name)
116
- n_empty_deletes += info.points_count == 0
117
- n_non_empty_deletes += info.points_count > 0
145
+ points_count = from_optional(info.points_count, 0)
146
+
147
+ n_empty_deletes += points_count == 0
148
+ n_non_empty_deletes += points_count > 0
118
149
  self.client.delete_collection(collection_name=name)
119
150
  logger.warning(
120
151
  f"""
@@ -135,11 +166,21 @@ class QdrantDB(VectorStore):
135
166
  colls = list(self.client.get_collections())[0][1]
136
167
  if empty:
137
168
  return [coll.name for coll in colls]
138
- counts = [
139
- self.client.get_collection(collection_name=coll.name).points_count
140
- for coll in colls
141
- ]
142
- return [coll.name for coll, count in zip(colls, counts) if count > 0]
169
+ counts = []
170
+ for coll in colls:
171
+ try:
172
+ counts.append(
173
+ from_optional(
174
+ self.client.get_collection(
175
+ collection_name=coll.name
176
+ ).points_count,
177
+ 0,
178
+ )
179
+ )
180
+ except Exception:
181
+ logger.warning(f"Error getting collection {coll.name}")
182
+ counts.append(0)
183
+ return [coll.name for coll, count in zip(colls, counts) if (count or 0) > 0]
143
184
 
144
185
  def create_collection(self, collection_name: str, replace: bool = False) -> None:
145
186
  """
@@ -154,7 +195,10 @@ class QdrantDB(VectorStore):
154
195
  collections = self.list_collections()
155
196
  if collection_name in collections:
156
197
  coll = self.client.get_collection(collection_name=collection_name)
157
- if coll.status == CollectionStatus.GREEN and coll.points_count > 0:
198
+ if (
199
+ coll.status == CollectionStatus.GREEN
200
+ and from_optional(coll.points_count, 0) > 0
201
+ ):
158
202
  logger.warning(f"Non-empty Collection {collection_name} already exists")
159
203
  if not replace:
160
204
  logger.warning("Not replacing collection")
@@ -178,9 +222,15 @@ class QdrantDB(VectorStore):
178
222
  logger.setLevel(level)
179
223
 
180
224
  def add_documents(self, documents: Sequence[Document]) -> None:
225
+ # Add id to metadata if not already present
226
+ super().maybe_add_ids(documents)
227
+ # Fix the ids due to qdrant finickiness
228
+ for doc in documents:
229
+ doc.metadata.id = str(self._to_int_or_uuid(doc.metadata.id))
181
230
  colls = self.list_collections(empty=True)
182
231
  if len(documents) == 0:
183
232
  return
233
+ document_dicts = [doc.dict() for doc in documents]
184
234
  embedding_vecs = self.embedding_fn([doc.content for doc in documents])
185
235
  if self.config.collection_name is None:
186
236
  raise ValueError("No collection name set, cannot ingest docs")
@@ -196,7 +246,7 @@ class QdrantDB(VectorStore):
196
246
  points=Batch(
197
247
  ids=ids[i : i + b],
198
248
  vectors=embedding_vecs[i : i + b],
199
- payloads=documents[i : i + b],
249
+ payloads=document_dicts[i : i + b],
200
250
  ),
201
251
  )
202
252
 
@@ -205,19 +255,42 @@ class QdrantDB(VectorStore):
205
255
 
206
256
  def _to_int_or_uuid(self, id: str) -> int | str:
207
257
  try:
208
- return int(id)
258
+ int_val = int(id)
259
+ if is_valid_uuid(id):
260
+ return int_val
209
261
  except ValueError:
262
+ pass
263
+
264
+ # If doc_id is already a valid UUID, return it as is
265
+ if isinstance(id, str) and is_valid_uuid(id):
210
266
  return id
211
267
 
212
- def get_all_documents(self) -> List[Document]:
268
+ # Otherwise, generate a UUID from the doc_id
269
+ # Convert doc_id to string if it's not already
270
+ id_str = str(id)
271
+
272
+ # Hash the document ID using SHA-1
273
+ hash_object = hashlib.sha1(id_str.encode())
274
+ hash_digest = hash_object.hexdigest()
275
+
276
+ # Truncate or manipulate the hash to fit into a UUID (128 bits)
277
+ uuid_str = hash_digest[:32]
278
+
279
+ # Format this string into a UUID format
280
+ formatted_uuid = uuid.UUID(uuid_str)
281
+
282
+ return str(formatted_uuid)
283
+
284
+ def get_all_documents(self, where: str = "") -> List[Document]:
213
285
  if self.config.collection_name is None:
214
286
  raise ValueError("No collection name set, cannot retrieve docs")
215
287
  docs = []
216
288
  offset = 0
289
+ filter = Filter() if where == "" else Filter.parse_obj(json.loads(where))
217
290
  while True:
218
291
  results, next_page_offset = self.client.scroll(
219
292
  collection_name=self.config.collection_name,
220
- scroll_filter=None,
293
+ scroll_filter=filter,
221
294
  offset=offset,
222
295
  limit=10_000, # try getting all at once, if not we keep paging
223
296
  with_payload=True,
@@ -239,7 +312,11 @@ class QdrantDB(VectorStore):
239
312
  with_vectors=False,
240
313
  with_payload=True,
241
314
  )
242
- docs = [Document(**record.payload) for record in records] # type: ignore
315
+ # Note the records may NOT be in the order of the ids,
316
+ # so we re-order them here.
317
+ id2payload = {record.id: record.payload for record in records}
318
+ ordered_payloads = [id2payload[id] for id in _ids]
319
+ docs = [Document(**payload) for payload in ordered_payloads] # type: ignore
243
320
  return docs
244
321
 
245
322
  def similar_texts_with_scores(
@@ -247,10 +324,14 @@ class QdrantDB(VectorStore):
247
324
  text: str,
248
325
  k: int = 1,
249
326
  where: Optional[str] = None,
327
+ neighbors: int = 0,
250
328
  ) -> List[Tuple[Document, float]]:
251
329
  embedding = self.embedding_fn([text])[0]
252
330
  # TODO filter may not work yet
253
- filter = Filter() if where is None else Filter.from_json(where) # type: ignore
331
+ if where is None or where == "":
332
+ filter = Filter()
333
+ else:
334
+ filter = Filter.parse_obj(json.loads(where))
254
335
  if self.config.collection_name is None:
255
336
  raise ValueError("No collection name set, cannot search")
256
337
  search_result: List[ScoredPoint] = self.client.search(
@@ -263,7 +344,7 @@ class QdrantDB(VectorStore):
263
344
  exact=False, # use Apx NN, not exact NN
264
345
  ),
265
346
  )
266
- scores = [match.score for match in search_result]
347
+ scores = [match.score for match in search_result if match is not None]
267
348
  docs = [
268
349
  Document(**(match.payload)) # type: ignore
269
350
  for match in search_result
@@ -272,8 +353,9 @@ class QdrantDB(VectorStore):
272
353
  if len(docs) == 0:
273
354
  logger.warning(f"No matches found for {text}")
274
355
  return []
275
- if settings.debug:
276
- logger.info(f"Found {len(docs)} matches, max score: {max(scores)}")
277
356
  doc_score_pairs = list(zip(docs, scores))
357
+ max_score = max(ds[1] for ds in doc_score_pairs)
358
+ if settings.debug:
359
+ logger.info(f"Found {len(doc_score_pairs)} matches, max score: {max_score}")
278
360
  self.show_if_debug(doc_score_pairs)
279
361
  return doc_score_pairs