langroid 0.1.139__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. langroid/__init__.py +70 -0
  2. langroid/agent/__init__.py +22 -0
  3. langroid/agent/base.py +120 -33
  4. langroid/agent/batch.py +134 -35
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +608 -0
  7. langroid/agent/chat_agent.py +164 -100
  8. langroid/agent/chat_document.py +19 -2
  9. langroid/agent/openai_assistant.py +20 -10
  10. langroid/agent/special/__init__.py +33 -10
  11. langroid/agent/special/doc_chat_agent.py +521 -108
  12. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  13. langroid/agent/special/lance_rag/__init__.py +9 -0
  14. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  15. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  16. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  17. langroid/agent/special/lance_tools.py +44 -0
  18. langroid/agent/special/neo4j/__init__.py +0 -0
  19. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  20. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  21. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  22. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  23. langroid/agent/special/relevance_extractor_agent.py +23 -7
  24. langroid/agent/special/retriever_agent.py +29 -174
  25. langroid/agent/special/sql/__init__.py +7 -0
  26. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  27. langroid/agent/special/sql/utils/__init__.py +11 -0
  28. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  29. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  30. langroid/agent/special/table_chat_agent.py +43 -9
  31. langroid/agent/task.py +423 -114
  32. langroid/agent/tool_message.py +67 -10
  33. langroid/agent/tools/__init__.py +8 -0
  34. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  35. langroid/agent/tools/google_search_tool.py +11 -0
  36. langroid/agent/tools/metaphor_search_tool.py +67 -0
  37. langroid/agent/tools/recipient_tool.py +6 -24
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/cachedb/__init__.py +6 -0
  40. langroid/embedding_models/__init__.py +24 -0
  41. langroid/embedding_models/base.py +9 -1
  42. langroid/embedding_models/models.py +117 -17
  43. langroid/embedding_models/protoc/embeddings.proto +19 -0
  44. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  45. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  46. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  47. langroid/embedding_models/remote_embeds.py +153 -0
  48. langroid/language_models/__init__.py +22 -0
  49. langroid/language_models/azure_openai.py +47 -4
  50. langroid/language_models/base.py +26 -10
  51. langroid/language_models/config.py +5 -0
  52. langroid/language_models/openai_gpt.py +407 -121
  53. langroid/language_models/prompt_formatter/__init__.py +9 -0
  54. langroid/language_models/prompt_formatter/base.py +4 -6
  55. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  56. langroid/language_models/utils.py +10 -9
  57. langroid/mytypes.py +10 -4
  58. langroid/parsing/__init__.py +33 -1
  59. langroid/parsing/document_parser.py +259 -63
  60. langroid/parsing/image_text.py +32 -0
  61. langroid/parsing/parse_json.py +143 -0
  62. langroid/parsing/parser.py +20 -7
  63. langroid/parsing/repo_loader.py +108 -46
  64. langroid/parsing/search.py +8 -0
  65. langroid/parsing/table_loader.py +44 -0
  66. langroid/parsing/url_loader.py +59 -13
  67. langroid/parsing/urls.py +18 -9
  68. langroid/parsing/utils.py +130 -9
  69. langroid/parsing/web_search.py +73 -0
  70. langroid/prompts/__init__.py +7 -0
  71. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  72. langroid/prompts/prompts_config.py +1 -1
  73. langroid/utils/__init__.py +10 -0
  74. langroid/utils/algorithms/__init__.py +3 -0
  75. langroid/utils/configuration.py +0 -1
  76. langroid/utils/constants.py +4 -0
  77. langroid/utils/logging.py +2 -5
  78. langroid/utils/output/__init__.py +15 -2
  79. langroid/utils/output/status.py +33 -0
  80. langroid/utils/pandas_utils.py +30 -0
  81. langroid/utils/pydantic_utils.py +446 -4
  82. langroid/utils/system.py +36 -1
  83. langroid/vector_store/__init__.py +34 -2
  84. langroid/vector_store/base.py +33 -2
  85. langroid/vector_store/chromadb.py +42 -13
  86. langroid/vector_store/lancedb.py +226 -60
  87. langroid/vector_store/meilisearch.py +7 -6
  88. langroid/vector_store/momento.py +3 -2
  89. langroid/vector_store/qdrantdb.py +82 -11
  90. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/METADATA +190 -129
  91. langroid-0.1.219.dist-info/RECORD +127 -0
  92. langroid/agent/special/recipient_validator_agent.py +0 -157
  93. langroid/parsing/json.py +0 -64
  94. langroid/utils/web/selenium_login.py +0 -36
  95. langroid-0.1.139.dist-info/RECORD +0 -103
  96. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
  97. {langroid-0.1.139.dist-info → langroid-0.1.219.dist-info}/WHEEL +0 -0

langroid/vector_store/chromadb.py

@@ -1,8 +1,7 @@
+import json
 import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-import chromadb
-
 from langroid.embedding_models.base import (
     EmbeddingModel,
     EmbeddingModelsConfig,
@@ -25,8 +24,19 @@ class ChromaDBConfig(VectorStoreConfig):
 
 
 class ChromaDB(VectorStore):
-    def __init__(self, config: ChromaDBConfig):
+    def __init__(self, config: ChromaDBConfig = ChromaDBConfig()):
         super().__init__(config)
+        try:
+            import chromadb
+        except ImportError:
+            raise ImportError(
+                """
+                ChromaDB is not installed by default with Langroid.
+                If you want to use it, please install it with the `chromadb` extra, e.g.
+                pip install "langroid[chromadb]"
+                or an equivalent command.
+                """
+            )
         self.config = config
         emb_model = EmbeddingModel.create(config.embedding)
         self.embedding_fn = emb_model.embedding_fn()
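
Note: with the default-config constructor added here (and to the other stores below), a ChromaDB instance can be built without arguments. A usage sketch, not itself part of the diff; `collection_name` is assumed settable on the config, as the shared VectorStoreConfig usage elsewhere in this diff suggests:

    from langroid.vector_store.chromadb import ChromaDB, ChromaDBConfig

    store = ChromaDB()  # falls back to ChromaDBConfig() defaults
    # or, with an explicit config:
    store = ChromaDB(ChromaDBConfig(collection_name="demo"))
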
@@ -114,7 +124,9 @@ class ChromaDB(VectorStore):
             return
         contents: List[str] = [document.content for document in documents]
         # convert metadatas to dicts so chroma can handle them
-        metadata_dicts: List[dict[str, Any]] = [d.metadata.dict() for d in documents]
+        metadata_dicts: List[dict[str, Any]] = [
+            d.metadata.dict_bool_int() for d in documents
+        ]
         for m in metadata_dicts:
             # chroma does not handle non-atomic types in metadata
             m["window_ids"] = ",".join(m["window_ids"])
@@ -127,29 +139,43 @@ class ChromaDB(VectorStore):
             ids=ids,
         )
 
-    def get_all_documents(self) -> List[Document]:
-        results = self.collection.get(include=["documents", "metadatas"])
+    def get_all_documents(self, where: str = "") -> List[Document]:
+        filter = json.loads(where) if where else None
+        results = self.collection.get(
+            include=["documents", "metadatas"],
+            where=filter,
+        )
         results["documents"] = [results["documents"]]
         results["metadatas"] = [results["metadatas"]]
         return self._docs_from_results(results)
 
     def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
-        results = self.collection.get(ids=ids, include=["documents", "metadatas"])
-        results["documents"] = [results["documents"]]
-        results["metadatas"] = [results["metadatas"]]
-        return self._docs_from_results(results)
+        # get them one by one since chroma mangles the order of the results
+        # when fetched from a list of ids.
+        results = [
+            self.collection.get(ids=[id], include=["documents", "metadatas"])
+            for id in ids
+        ]
+        final_results = {}
+        final_results["documents"] = [[r["documents"][0] for r in results]]
+        final_results["metadatas"] = [[r["metadatas"][0] for r in results]]
+        return self._docs_from_results(final_results)
 
     def delete_collection(self, collection_name: str) -> None:
-        self.client.delete_collection(name=collection_name)
+        try:
+            self.client.delete_collection(name=collection_name)
+        except Exception:
+            pass
 
     def similar_texts_with_scores(
         self, text: str, k: int = 1, where: Optional[str] = None
    ) -> List[Tuple[Document, float]]:
         n = self.collection.count()
+        filter = json.loads(where) if where else None
         results = self.collection.query(
             query_texts=[text],
             n_results=min(n, k),
-            where=where,
+            where=filter,
             include=["documents", "distances", "metadatas"],
         )
         docs = self._docs_from_results(results)
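
Note: both `get_all_documents` and `similar_texts_with_scores` now take the metadata filter as a JSON string, decoded with `json.loads` before being handed to Chroma's `where` argument. A usage sketch (the `category` field is illustrative, not from this diff):

    # the filter is a JSON-encoded Chroma `where` clause
    docs = store.get_all_documents(where='{"category": "news"}')
    hits = store.similar_texts_with_scores(
        "latest headlines", k=5, where='{"category": "news"}'
    )
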
@@ -175,7 +201,10 @@ class ChromaDB(VectorStore):
         metadatas = results["metadatas"][0]
         for m in metadatas:
             # restore the stringified list of window_ids into the original List[str]
-            m["window_ids"] = m["window_ids"].split(",")
+            if m["window_ids"].strip() == "":
+                m["window_ids"] = []
+            else:
+                m["window_ids"] = m["window_ids"].split(",")
         docs = [
             Document(content=d, metadata=DocMetaData(**m))
             for d, m in zip(contents, metadatas)
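
Note: this mirrors ingestion above, where `window_ids` was comma-joined because Chroma metadata values must be atomic; the new empty-string branch matters because `"".split(",")` yields `[""]`, not `[]`. The round trip in isolation:

    # ingest: Chroma metadata must be atomic, so join the list into one string
    stored = ",".join(["w1", "w2", "w3"])  # "w1,w2,w3"

    # retrieve: split back, mapping the empty string to [] rather than [""]
    restored = [] if stored.strip() == "" else stored.split(",")
    assert restored == ["w1", "w2", "w3"]
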
langroid/vector_store/lancedb.py

@@ -2,9 +2,11 @@ import logging
 from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Type
 
 import lancedb
+import pandas as pd
 from dotenv import load_dotenv
 from lancedb.pydantic import LanceModel, Vector
-from pydantic import BaseModel, create_model
+from lancedb.query import LanceVectorQueryBuilder
+from pydantic import BaseModel, ValidationError, create_model
 
 from langroid.embedding_models.base import (
     EmbeddingModel,
@@ -14,6 +16,10 @@ from langroid.embedding_models.models import OpenAIEmbeddingsConfig
 from langroid.mytypes import Document, EmbeddingFunction
 from langroid.utils.configuration import settings
 from langroid.utils.pydantic_utils import (
+    dataframe_to_document_model,
+    dataframe_to_documents,
+    extend_document_class,
+    extra_metadata,
     flatten_pydantic_instance,
     flatten_pydantic_model,
     nested_dict_from_flat,
@@ -29,11 +35,14 @@ class LanceDBConfig(VectorStoreConfig):
     storage_path: str = ".lancedb/data"
     embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
     distance: str = "cosine"
+    # document_class is used to store in lancedb with right schema,
+    # and also to retrieve the right type of Documents when searching.
     document_class: Type[Document] = Document
+    flatten: bool = False  # flatten Document class into LanceSchema ?
 
 
 class LanceDB(VectorStore):
-    def __init__(self, config: LanceDBConfig):
+    def __init__(self, config: LanceDBConfig = LanceDBConfig()):
         super().__init__(config)
         self.config: LanceDBConfig = config
         emb_model = EmbeddingModel.create(config.embedding)
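
Note: a configuration sketch using the newly documented `document_class` and `flatten` fields. The `MyDoc` subclass is illustrative, following the pattern recommended by the warning in `_maybe_set_doc_class_schema` further below:

    from langroid.mytypes import DocMetaData, Document

    class MyDocMeta(DocMetaData):
        author: str = ""  # extra metadata field, declared up front

    class MyDoc(Document):
        metadata: MyDocMeta

    config = LanceDBConfig(
        collection_name="my-docs",
        document_class=MyDoc,  # store and retrieve documents as MyDoc
        flatten=False,  # keep the nested Document structure
    )
    store = LanceDB(config)
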
@@ -41,8 +50,10 @@ class LanceDB(VectorStore):
         self.embedding_dim = emb_model.embedding_dims
         self.host = config.host
         self.port = config.port
-        self.schema = self._create_lance_schema(self.config.document_class)
-        self.flat_schema = self._create_flat_lance_schema(self.config.document_class)
+        self.is_from_dataframe = False  # were docs ingested from a dataframe?
+        self.df_metadata_columns: List[str] = []  # metadata columns from dataframe
+        self._setup_schemas(config.document_class)
+
         load_dotenv()
         if self.config.cloud:
             logger.warning(
@@ -75,6 +86,15 @@ class LanceDB(VectorStore):
             config.collection_name, replace=config.replace_collection
         )
 
+    def _setup_schemas(self, doc_cls: Type[Document] | None) -> None:
+        doc_cls = doc_cls or self.config.document_class
+        self.unflattened_schema = self._create_lance_schema(doc_cls)
+        self.schema = (
+            self._create_flat_lance_schema(doc_cls)
+            if self.config.flatten
+            else self.unflattened_schema
+        )
+
     def clear_empty_collections(self) -> int:
         coll_names = self.list_collections()
         n_deletes = 0
@@ -119,7 +139,7 @@ class LanceDB(VectorStore):
         Args:
             empty (bool, optional): Whether to include empty collections.
         """
-        colls = self.client.table_names()
+        colls = self.client.table_names(limit=None)
         if len(colls) == 0:
             return []
         if empty:  # include empty tbls
@@ -134,7 +154,7 @@ class LanceDB(VectorStore):
         - Vector field that has dims equal to
           the embedding dimension of the embedding model, and a data field of type
           DocClass.
-        - payload of type `doc_cls`
+        - other fields from doc_cls
 
         Args:
             doc_cls (Type[Document]): A Pydantic model which should be a subclass of
@@ -152,13 +172,20 @@ class LanceDB(VectorStore):
 
         n = self.embedding_dim
 
-        NewModel = create_model(
-            "NewModel",
-            __base__=LanceModel,
-            id=(str, ...),
-            vector=(Vector(n), ...),
-            payload=(doc_cls, ...),
+        # Prepare fields for the new model
+        fields = {"id": (str, ...), "vector": (Vector(n), ...)}
+
+        sorted_fields = dict(
+            sorted(doc_cls.__fields__.items(), key=lambda item: item[0])
         )
+        # Add both statically and dynamically defined fields from doc_cls
+        for field_name, field in sorted_fields.items():
+            fields[field_name] = (field.outer_type_, field.default)
+
+        # Create the new model with dynamic fields
+        NewModel = create_model(
+            "NewModel", __base__=LanceModel, **fields
+        )  # type: ignore
         return NewModel  # type: ignore
 
     def _create_flat_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
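
Note: the Lance schema is now assembled field-by-field with pydantic's `create_model`, instead of nesting the whole document under a `payload` field as before. The same technique in isolation (a sketch using the pydantic v1 attributes `outer_type_` and `default` seen above):

    from pydantic import BaseModel, create_model

    class Doc(BaseModel):
        content: str = ""
        score: float = 0.0

    # (type, default) tuples; `...` marks a required field
    fields = {"id": (str, ...)}
    for name, field in sorted(Doc.__fields__.items()):
        fields[name] = (field.outer_type_, field.default)

    Flat = create_model("Flat", **fields)
    print(Flat(id="1", content="hi", score=0.5))
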
@@ -190,76 +217,218 @@ class LanceDB(VectorStore):
             return
         else:
             logger.warning("Recreating fresh collection")
-        tbl = self.client.create_table(
-            collection_name, schema=self.flat_schema, mode="overwrite"
-        )
+        self.client.create_table(collection_name, schema=self.schema, mode="overwrite")
         if settings.debug:
             level = logger.getEffectiveLevel()
             logger.setLevel(logging.INFO)
-            logger.info(tbl.schema)
             logger.setLevel(level)
 
+    def _maybe_set_doc_class_schema(self, doc: Document) -> None:
+        """
+        Set the config.document_class and self.schema based on doc if needed
+        Args:
+            doc: an instance of Document, to be added to a collection
+        """
+        extra_metadata_fields = extra_metadata(doc, self.config.document_class)
+        if len(extra_metadata_fields) > 0:
+            logger.warning(
+                f"""
+                Added documents contain extra metadata fields:
+                {extra_metadata_fields}
+                which were not present in the original config.document_class.
+                Trying to change document_class and corresponding schemas.
+                Overriding LanceDBConfig.document_class with an auto-generated
+                Pydantic class that includes these extra fields.
+                If this fails, or you see odd results, it is recommended that you
+                define a subclass of Document, with metadata of class derived from
+                DocMetaData, with extra fields defined via
+                `Field(..., description="...")` declarations,
+                and set this document class as the value of the
+                LanceDBConfig.document_class attribute.
+                """
+            )
+
+            doc_cls = extend_document_class(doc)
+            self.config.document_class = doc_cls
+            self._setup_schemas(doc_cls)
+
     def add_documents(self, documents: Sequence[Document]) -> None:
         super().maybe_add_ids(documents)
         colls = self.list_collections(empty=True)
         if len(documents) == 0:
             return
         embedding_vecs = self.embedding_fn([doc.content for doc in documents])
-        if self.config.collection_name is None:
+        coll_name = self.config.collection_name
+        if coll_name is None:
             raise ValueError("No collection name set, cannot ingest docs")
-        if self.config.collection_name not in colls:
-            self.create_collection(self.config.collection_name, replace=True)
+        self._maybe_set_doc_class_schema(documents[0])
+        if (
+            coll_name not in colls
+            or self.client.open_table(coll_name).head(1).shape[0] == 0
+        ):
+            # collection either doesn't exist or is empty, so replace it,
+            self.create_collection(coll_name, replace=True)
+
         ids = [str(d.id()) for d in documents]
         # don't insert all at once, batch in chunks of b,
         # else we get an API error
         b = self.config.batch_size
 
-        def make_batches() -> Generator[List[Dict[str, Any]], None, None]:
+        def make_batches() -> Generator[List[BaseModel], None, None]:
             for i in range(0, len(ids), b):
-                yield [
-                    flatten_pydantic_instance(
-                        self.schema(
-                            id=ids[i],
-                            vector=embedding_vecs[i],
-                            payload=doc,
-                        )
+                batch = [
+                    self.unflattened_schema(
+                        id=ids[i + j],
+                        vector=embedding_vecs[i + j],
+                        **doc.dict(),
                     )
-                    for i, doc in enumerate(documents[i : i + b])
+                    for j, doc in enumerate(documents[i : i + b])
                 ]
+                if self.config.flatten:
+                    batch = [
+                        flatten_pydantic_instance(instance)  # type: ignore
+                        for instance in batch
+                    ]
+                yield batch
 
         tbl = self.client.open_table(self.config.collection_name)
-        tbl.add(make_batches())
+        try:
+            tbl.add(make_batches())
+        except Exception as e:
+            logger.error(
+                f"""
+                Error adding documents to LanceDB: {e}
+                POSSIBLE REMEDY: Delete the LancdDB storage directory
+                {self.config.storage_path} and try again.
+                """
+            )
+
+    def add_dataframe(
+        self,
+        df: pd.DataFrame,
+        content: str = "content",
+        metadata: List[str] = [],
+    ) -> None:
+        """
+        Add a dataframe to the collection.
+        Args:
+            df (pd.DataFrame): A dataframe
+            content (str): The name of the column in the dataframe that contains the
+                text content to be embedded using the embedding model.
+            metadata (List[str]): A list of column names in the dataframe that contain
+                metadata to be stored in the database. Defaults to [].
+        """
+        self.is_from_dataframe = True
+        actual_metadata = metadata.copy()
+        self.df_metadata_columns = actual_metadata  # could be updated below
+        # get content column
+        content_values = df[content].values.tolist()
+        embedding_vecs = self.embedding_fn(content_values)
+
+        # add vector column
+        df["vector"] = embedding_vecs
+        if content != "content":
+            # rename content column to "content", leave existing column intact
+            df = df.rename(columns={content: "content"}, inplace=False)
+
+        if "id" not in df.columns:
+            docs = dataframe_to_documents(df, content="content", metadata=metadata)
+            ids = [str(d.id()) for d in docs]
+            df["id"] = ids
+
+        if "id" not in actual_metadata:
+            actual_metadata += ["id"]
+
+        colls = self.list_collections(empty=True)
+        coll_name = self.config.collection_name
+        if (
+            coll_name not in colls
+            or self.client.open_table(coll_name).head(1).shape[0] == 0
+        ):
+            # collection either doesn't exist or is empty, so replace it
+            # and set new schema from df
+            self.client.create_table(
+                self.config.collection_name,
+                data=df,
+                mode="overwrite",
+            )
+            doc_cls = dataframe_to_document_model(
+                df,
+                content=content,
+                metadata=actual_metadata,
+                exclude=["vector"],
+            )
+            self.config.document_class = doc_cls  # type: ignore
+            self._setup_schemas(doc_cls)  # type: ignore
+        else:
+            # collection exists and is not empty, so append to it
+            tbl = self.client.open_table(self.config.collection_name)
+            tbl.add(df)
 
     def delete_collection(self, collection_name: str) -> None:
-        self.client.drop_table(collection_name)
+        self.client.drop_table(collection_name, ignore_missing=True)
+
+    def _lance_result_to_docs(self, result: LanceVectorQueryBuilder) -> List[Document]:
+        if self.is_from_dataframe:
+            df = result.to_pandas()
+            return dataframe_to_documents(
+                df,
+                content="content",
+                metadata=self.df_metadata_columns,
+                doc_cls=self.config.document_class,
+            )
+        else:
+            records = result.to_arrow().to_pylist()
+            return self._records_to_docs(records)
 
-    def get_all_documents(self) -> List[Document]:
+    def _records_to_docs(self, records: List[Dict[str, Any]]) -> List[Document]:
+        if self.config.flatten:
+            docs = [
+                self.unflattened_schema(**nested_dict_from_flat(rec)) for rec in records
+            ]
+        else:
+            try:
+                docs = [self.schema(**rec) for rec in records]
+            except ValidationError as e:
+                raise ValueError(
+                    f"""
+                    Error validating LanceDB result: {e}
+                    HINT: This could happen when you're re-using an
+                    existing LanceDB store with a different schema.
+                    Try deleting your local lancedb storage at `{self.config.storage_path}`
+                    re-ingesting your documents and/or replacing the collections.
+                    """
+                )
+
+        doc_cls = self.config.document_class
+        doc_cls_field_names = doc_cls.__fields__.keys()
+        return [
+            doc_cls(
+                **{
+                    field_name: getattr(doc, field_name)
+                    for field_name in doc_cls_field_names
+                }
+            )
+            for doc in docs
+        ]
+
+    def get_all_documents(self, where: str = "") -> List[Document]:
         if self.config.collection_name is None:
             raise ValueError("No collection name set, cannot retrieve docs")
         tbl = self.client.open_table(self.config.collection_name)
-        records = tbl.search(None).to_arrow().to_pylist()
-        docs = [
-            self.config.document_class(
-                **(nested_dict_from_flat(rec, sub_dict="payload"))
-            )
-            for rec in records
-        ]
-        return docs
+        pre_result = tbl.search(None).where(where or None).limit(None)
+        return self._lance_result_to_docs(pre_result)
 
     def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
         if self.config.collection_name is None:
             raise ValueError("No collection name set, cannot retrieve docs")
         _ids = [str(id) for id in ids]
         tbl = self.client.open_table(self.config.collection_name)
-        records = [
-            tbl.search().where(f"id == '{_id}'").to_arrow().to_pylist()[0]
-            for _id in _ids
-        ]
-        doc_cls = self.config.document_class
-        docs = [
-            doc_cls(**(nested_dict_from_flat(rec, sub_dict="payload")))
-            for rec in records
-        ]
+        docs = []
+        for _id in _ids:
+            results = self._lance_result_to_docs(tbl.search().where(f"id == '{_id}'"))
+            if len(results) > 0:
+                docs.append(results[0])
         return docs
 
     def similar_texts_with_scores(
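
Note: `add_dataframe` is new in this version; it embeds one DataFrame column and stores selected other columns as metadata, inferring a document model from the frame when the collection is fresh. A usage sketch (column names are illustrative; the `where` string is assumed to use LanceDB's SQL-like filter syntax, as in the `id ==` filters above):

    import pandas as pd

    df = pd.DataFrame(
        {
            "text": ["first doc", "second doc"],
            "author": ["alice", "bob"],
        }
    )
    store.add_dataframe(df, content="text", metadata=["author"])
    docs = store.get_all_documents(where="author = 'alice'")
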
@@ -270,23 +439,20 @@ class LanceDB(VectorStore):
     ) -> List[Tuple[Document, float]]:
         embedding = self.embedding_fn([text])[0]
         tbl = self.client.open_table(self.config.collection_name)
-        records = (
+        result = (
             tbl.search(embedding)
             .metric(self.config.distance)
-            .where(where)
+            .where(where, prefilter=True)
             .limit(k)
-            .to_arrow()
-            .to_pylist()
         )
-
+        docs = self._lance_result_to_docs(result)
         # note _distance is 1 - cosine
-        scores = [1 - rec["_distance"] for rec in records]
-        docs = [
-            self.config.document_class(
-                **(nested_dict_from_flat(rec, sub_dict="payload"))
-            )
-            for rec in records
-        ]
+        if self.is_from_dataframe:
+            scores = [
+                1 - rec["_distance"] for rec in result.to_pandas().to_dict("records")
+            ]
+        else:
+            scores = [1 - rec["_distance"] for rec in result.to_arrow().to_pylist()]
         if len(docs) == 0:
             logger.warning(f"No matches found for {text}")
             return []
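
Note: as the comment says, LanceDB's `_distance` here is cosine distance, so similarity is recovered as `1 - _distance` in both branches. A worked check with illustrative values:

    records = [{"_distance": 0.12}, {"_distance": 0.45}]
    scores = [1 - r["_distance"] for r in records]  # [0.88, 0.55]
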
langroid/vector_store/meilisearch.py

@@ -32,7 +32,7 @@ class MeiliSearchConfig(VectorStoreConfig):
 
 
 class MeiliSearch(VectorStore):
-    def __init__(self, config: MeiliSearchConfig):
+    def __init__(self, config: MeiliSearchConfig = MeiliSearchConfig()):
         super().__init__(config)
         self.config: MeiliSearchConfig = config
         self.host = config.host
@@ -165,7 +165,7 @@ class MeiliSearch(VectorStore):
         async with self.client() as client:
             index = client.index(collection_name)
             await index.add_documents_in_batches(
-                documents=documents,  # type: ignore
+                documents=documents,
                 batch_size=self.config.batch_size,
                 primary_key=self.config.primary_key,
             )
@@ -198,18 +198,19 @@ class MeiliSearch(VectorStore):
         except ValueError:
             return id
 
-    async def _async_get_documents(self) -> DocumentsInfo:
+    async def _async_get_documents(self, where: str = "") -> DocumentsInfo:
         if self.config.collection_name is None:
             raise ValueError("No collection name set, cannot retrieve docs")
+        filter = [] if where is None else where
         async with self.client() as client:
             index = client.index(self.config.collection_name)
-            documents = await index.get_documents(limit=10_000)
+            documents = await index.get_documents(limit=10_000, filter=filter)
         return documents
 
-    def get_all_documents(self) -> List[Document]:
+    def get_all_documents(self, where: str = "") -> List[Document]:
         if self.config.collection_name is None:
             raise ValueError("No collection name set, cannot retrieve docs")
-        docs = asyncio.run(self._async_get_documents())
+        docs = asyncio.run(self._async_get_documents(where))
         if docs is None:
             return []
         doc_results = docs.results
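
Note: `get_all_documents` remains a synchronous wrapper that threads the new `where` argument through to the async client via `asyncio.run`. The pattern in isolation (a sketch with a stand-in coroutine, not the Meilisearch API):

    import asyncio

    async def _fetch_documents(where: str = "") -> list:
        # stand-in for an async client call like index.get_documents(...)
        return ["doc1", "doc2"]

    def fetch_documents(where: str = "") -> list:
        # drive the coroutine to completion from synchronous code
        return asyncio.run(_fetch_documents(where))

    print(fetch_documents())
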
langroid/vector_store/momento.py

@@ -2,6 +2,7 @@
 Momento Vector Index.
 https://docs.momentohq.com/vector-index/develop/api-reference
 """
+
 import logging
 import os
 from typing import List, Optional, Sequence, Tuple, no_type_check
@@ -44,7 +45,7 @@ class MomentoVIConfig(VectorStoreConfig):
 
 
 class MomentoVI(VectorStore):
-    def __init__(self, config: MomentoVIConfig):
+    def __init__(self, config: MomentoVIConfig = MomentoVIConfig()):
         super().__init__(config)
         self.config: MomentoVIConfig = config
         emb_model = EmbeddingModel.create(config.embedding)
@@ -201,7 +202,7 @@ class MomentoVI(VectorStore):
         except ValueError:
             return id
 
-    def get_all_documents(self) -> List[Document]:
+    def get_all_documents(self, where: str = "") -> List[Document]:
         raise NotImplementedError(
             """
             MomentoVI does not support get_all_documents().