langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -0,0 +1,463 @@
1
+ import logging
2
+ from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Type
3
+
4
+ import lancedb
5
+ import pandas as pd
6
+ from dotenv import load_dotenv
7
+ from lancedb.pydantic import LanceModel, Vector
8
+ from lancedb.query import LanceVectorQueryBuilder
9
+ from pydantic import BaseModel, ValidationError, create_model
10
+
11
+ from langroid.embedding_models.base import (
12
+ EmbeddingModel,
13
+ EmbeddingModelsConfig,
14
+ )
15
+ from langroid.embedding_models.models import OpenAIEmbeddingsConfig
16
+ from langroid.mytypes import Document, EmbeddingFunction
17
+ from langroid.utils.configuration import settings
18
+ from langroid.utils.pydantic_utils import (
19
+ dataframe_to_document_model,
20
+ dataframe_to_documents,
21
+ extend_document_class,
22
+ extra_metadata,
23
+ flatten_pydantic_instance,
24
+ flatten_pydantic_model,
25
+ nested_dict_from_flat,
26
+ )
27
+ from langroid.vector_store.base import VectorStore, VectorStoreConfig
28
+
29
# Logger for this module, obtained under the module's own dotted name.
logger: logging.Logger = logging.getLogger(__name__)
30
+
31
+
32
class LanceDBConfig(VectorStoreConfig):
    """
    Settings for the LanceDB-backed vector store.

    Field order and defaults are significant to the pydantic model and are kept
    stable.
    """

    # Request LanceDB Cloud; cloud is not available yet and a warning is logged.
    cloud: bool = False
    # Table (collection) name; None postpones creating any table until a
    # suitable name is known.
    collection_name: str | None = "temp"
    # Local directory holding the LanceDB database.
    storage_path: str = ".lancedb/data"
    # Embedding model used to vectorize document content.
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
    # Metric name handed to LanceDB vector search.
    distance: str = "cosine"
    # Document subclass that fixes the table schema at ingestion time and the
    # type of Documents produced by searches.
    document_class: Type[Document] = Document
    # When True, nested Document fields are flattened into top-level columns
    # of the Lance schema.
    flatten: bool = False
+
43
+
44
class LanceDB(VectorStore):
    """
    Vector store backed by a local LanceDB database.

    Each collection is a LanceDB table. Its schema is derived from
    `config.document_class`: an `id` column, a `vector` column sized to the
    embedding model's dimension, and the document class's own fields
    (flattened into top-level columns when `config.flatten` is True).
    Documents can be ingested as `Document` objects (`add_documents`) or from a
    pandas DataFrame (`add_dataframe`).
    """

    def __init__(self, config: LanceDBConfig | None = None):
        """
        Args:
            config: store configuration. When omitted, a fresh default
                `LanceDBConfig` is created for this instance. This class
                mutates its config (`collection_name`, `document_class`,
                `cloud`), so a default instance must never be shared between
                LanceDB objects.
        """
        if config is None:
            config = LanceDBConfig()
        super().__init__(config)
        self.config: LanceDBConfig = config
        emb_model = EmbeddingModel.create(config.embedding)
        self.embedding_fn: EmbeddingFunction = emb_model.embedding_fn()
        self.embedding_dim = emb_model.embedding_dims
        self.host = config.host
        self.port = config.port
        self.is_from_dataframe = False  # were docs ingested from a dataframe?
        self.df_metadata_columns: List[str] = []  # metadata columns from dataframe
        self._setup_schemas(config.document_class)

        load_dotenv()
        if config.cloud:
            logger.warning(
                "LanceDB Cloud is not available yet. Switching to local storage."
            )
            config.cloud = False

        # Cloud is unsupported, so we always connect to local storage -- also
        # when cloud was requested -- ensuring `self.client` is always set.
        try:
            self.client = lancedb.connect(
                uri=config.storage_path,
            )
        except Exception as e:
            # Best-effort fallback: if the configured location cannot be
            # opened, use a sibling directory instead.
            new_storage_path = config.storage_path + ".new"
            logger.warning(
                f"""
                Error connecting to local LanceDB at {config.storage_path}:
                {e}
                Switching to {new_storage_path}
                """
            )
            self.client = lancedb.connect(
                uri=new_storage_path,
            )

        # Note: Only create collection if a non-null collection name is provided.
        # This is useful to delay creation of vecdb until we have a suitable
        # collection name (e.g. we could get it from the url or folder path).
        if config.collection_name is not None:
            self.create_collection(
                config.collection_name, replace=config.replace_collection
            )

    def _setup_schemas(self, doc_cls: Type[Document] | None) -> None:
        """
        Build the Lance schemas for `doc_cls` (or the configured document class).

        Sets `self.unflattened_schema` (nested) and `self.schema`. The latter is
        the schema actually used for tables: the flat variant when
        `config.flatten` is True, otherwise the nested one.
        """
        doc_cls = doc_cls or self.config.document_class
        self.unflattened_schema = self._create_lance_schema(doc_cls)
        self.schema = (
            self._create_flat_lance_schema(doc_cls)
            if self.config.flatten
            else self.unflattened_schema
        )

    def clear_empty_collections(self) -> int:
        """
        Drop every collection that contains no rows.

        Returns:
            Number of collections dropped.
        """
        n_deletes = 0
        # empty=True is essential: the default listing omits empty tables,
        # which are exactly the ones this method must find.
        for name in self.list_collections(empty=True):
            if self.client.open_table(name).head(1).shape[0] == 0:
                self.client.drop_table(name)
                n_deletes += 1
        return n_deletes

    def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
        """
        Drop all collections (empty or not) whose names start with `prefix`.

        Args:
            really: safety switch; nothing is deleted unless this is True.
            prefix: only collections whose name starts with this are dropped;
                the default "" matches every collection.

        Returns:
            Number of collections dropped.
        """
        if not really:
            logger.warning("Not deleting all collections, set really=True to confirm")
            return 0
        coll_names = [
            c for c in self.list_collections(empty=True) if c.startswith(prefix)
        ]
        if len(coll_names) == 0:
            logger.warning(f"No collections found with prefix {prefix}")
            return 0
        n_empty_deletes = 0
        n_non_empty_deletes = 0
        for name in coll_names:
            nr = self.client.open_table(name).head(1).shape[0]
            n_empty_deletes += nr == 0
            n_non_empty_deletes += nr > 0
            self.client.drop_table(name)
        logger.warning(
            f"""
            Deleted {n_empty_deletes} empty collections and
            {n_non_empty_deletes} non-empty collections.
            """
        )
        return n_empty_deletes + n_non_empty_deletes

    def list_collections(self, empty: bool = False) -> List[str]:
        """
        List collection (table) names in the database.

        Args:
            empty (bool, optional): Whether to include empty collections.
                By default only collections with at least one row are listed.

        Returns:
            List of collection names.
        """
        colls = self.client.table_names(limit=None)
        if len(colls) == 0:
            return []
        if empty:  # include empty tbls
            return colls  # type: ignore
        counts = [self.client.open_table(coll).head(1).shape[0] for coll in colls]
        return [coll for coll, count in zip(colls, counts) if count > 0]

    def _create_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
        """
        Create a subclass of LanceModel with fields:
            - id (str)
            - vector: a Vector field whose dimension equals the embedding
              dimension of the embedding model
            - every field of doc_cls, in sorted-name order

        Args:
            doc_cls (Type[Document]): A Pydantic model which should be a subclass
                of Document; its fields are copied into the new model.

        Returns:
            Type[BaseModel]: A new Pydantic model subclassing from LanceModel.

        Raises:
            ValueError: If `doc_cls` is not a subclass of Document.
        """
        if not issubclass(doc_cls, Document):
            raise ValueError("DocClass must be a subclass of Document")

        n = self.embedding_dim

        # Prepare fields for the new model
        fields = {"id": (str, ...), "vector": (Vector(n), ...)}

        # Add both statically and dynamically defined fields from doc_cls,
        # sorted by name so the resulting column order is deterministic.
        for field_name, field in sorted(
            doc_cls.__fields__.items(), key=lambda item: item[0]
        ):
            fields[field_name] = (field.outer_type_, field.default)

        # Create the new model with dynamic fields
        NewModel = create_model(
            "NewModel", __base__=LanceModel, **fields
        )  # type: ignore
        return NewModel  # type: ignore

    def _create_flat_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
        """
        Flat version of the lance_schema, as nested Pydantic schemas are not yet
        supported by LanceDB.
        """
        lance_model = self._create_lance_schema(doc_cls)
        FlatModel = flatten_pydantic_model(lance_model, base_model=LanceModel)
        return FlatModel

    def create_collection(self, collection_name: str, replace: bool = False) -> None:
        """
        Create a collection with the given name, optionally replacing an existing
        non-empty collection if `replace` is True. The name also becomes the
        current `config.collection_name`. An existing *empty* collection is
        always recreated.

        Args:
            collection_name (str): Name of the collection to create.
            replace (bool): Whether to replace an existing non-empty collection
                with the same name. Defaults to False.
        """
        self.config.collection_name = collection_name
        # list_collections() (without empty=True) only returns tables that have
        # rows, so membership here already means the collection is non-empty.
        if collection_name in self.list_collections():
            logger.warning(f"Non-empty Collection {collection_name} already exists")
            if not replace:
                logger.warning("Not replacing collection")
                return
            logger.warning("Recreating fresh collection")
        self.client.create_table(collection_name, schema=self.schema, mode="overwrite")

    def _maybe_set_doc_class_schema(self, doc: Document) -> None:
        """
        Set the config.document_class and self.schema based on doc if needed,
        i.e. when `doc` has metadata fields absent from the configured class.

        Args:
            doc: an instance of Document, to be added to a collection
        """
        extra_metadata_fields = extra_metadata(doc, self.config.document_class)
        if len(extra_metadata_fields) == 0:
            return
        logger.warning(
            f"""
            Added documents contain extra metadata fields:
            {extra_metadata_fields}
            which were not present in the original config.document_class.
            Trying to change document_class and corresponding schemas.
            Overriding LanceDBConfig.document_class with an auto-generated
            Pydantic class that includes these extra fields.
            If this fails, or you see odd results, it is recommended that you
            define a subclass of Document, with metadata of class derived from
            DocMetaData, with extra fields defined via
            `Field(..., description="...")` declarations,
            and set this document class as the value of the
            LanceDBConfig.document_class attribute.
            """
        )

        doc_cls = extend_document_class(doc)
        self.config.document_class = doc_cls
        self._setup_schemas(doc_cls)

    def add_documents(self, documents: Sequence[Document]) -> None:
        """
        Embed `documents` and add them to the current collection.

        The collection is (re)created first if it does not exist or is empty.
        Insertion failures are logged rather than raised.

        Args:
            documents: documents to ingest; ids are assigned where missing.

        Raises:
            ValueError: if no collection name is configured.
        """
        super().maybe_add_ids(documents)
        if len(documents) == 0:
            return
        coll_name = self.config.collection_name
        if coll_name is None:
            # Checked before embedding, so no embedding work is wasted.
            raise ValueError("No collection name set, cannot ingest docs")
        colls = self.list_collections(empty=True)
        embedding_vecs = self.embedding_fn([doc.content for doc in documents])
        self._maybe_set_doc_class_schema(documents[0])
        if (
            coll_name not in colls
            or self.client.open_table(coll_name).head(1).shape[0] == 0
        ):
            # collection either doesn't exist or is empty, so replace it,
            # using the (possibly updated) schema
            self.create_collection(coll_name, replace=True)

        ids = [str(d.id()) for d in documents]
        # don't insert all at once, batch in chunks of b,
        # else we get an API error
        b = self.config.batch_size

        def make_batches() -> Generator[List[BaseModel], None, None]:
            """Yield schema instances for consecutive chunks of `b` documents."""
            for start in range(0, len(ids), b):
                stop = start + b
                batch = [
                    self.unflattened_schema(
                        id=doc_id,
                        vector=vec,
                        **doc.dict(),
                    )
                    for doc_id, vec, doc in zip(
                        ids[start:stop],
                        embedding_vecs[start:stop],
                        documents[start:stop],
                    )
                ]
                if self.config.flatten:
                    batch = [
                        flatten_pydantic_instance(instance)  # type: ignore
                        for instance in batch
                    ]
                yield batch

        tbl = self.client.open_table(coll_name)
        try:
            tbl.add(make_batches())
        except Exception as e:
            logger.error(
                f"""
                Error adding documents to LanceDB: {e}
                POSSIBLE REMEDY: Delete the LanceDB storage directory
                {self.config.storage_path} and try again.
                """
            )

    def add_dataframe(
        self,
        df: pd.DataFrame,
        content: str = "content",
        metadata: Optional[List[str]] = None,
    ) -> None:
        """
        Add a dataframe to the collection.

        The caller's dataframe is not modified; a copy is extended with
        "vector" (and, if absent, "id") columns and stored. If the collection
        does not exist or is empty, it is recreated from the dataframe and
        `config.document_class` is replaced by a model derived from it.

        Args:
            df (pd.DataFrame): A dataframe
            content (str): The name of the column in the dataframe that contains the
                text content to be embedded using the embedding model.
            metadata (List[str] | None): A list of column names in the dataframe
                that contain metadata to be stored in the database.
                Defaults to no metadata columns.

        Raises:
            ValueError: if no collection name is configured.
        """
        coll_name = self.config.collection_name
        if coll_name is None:
            raise ValueError("No collection name set, cannot ingest dataframe")
        self.is_from_dataframe = True
        actual_metadata = list(metadata) if metadata is not None else []
        # df_metadata_columns deliberately aliases actual_metadata, so the
        # "id" column appended below is recorded here as well.
        self.df_metadata_columns = actual_metadata
        # Work on a copy so the caller's dataframe does not gain the
        # "vector"/"id" columns added below.
        df = df.copy()
        # get content column
        content_values = df[content].values.tolist()
        embedding_vecs = self.embedding_fn(content_values)

        # add vector column
        df["vector"] = embedding_vecs
        if content != "content":
            # expose the content column under the canonical name "content"
            df = df.rename(columns={content: "content"})

        if "id" not in df.columns:
            docs = dataframe_to_documents(
                df, content="content", metadata=actual_metadata
            )
            df["id"] = [str(d.id()) for d in docs]

        if "id" not in actual_metadata:
            actual_metadata.append("id")

        colls = self.list_collections(empty=True)
        if (
            coll_name not in colls
            or self.client.open_table(coll_name).head(1).shape[0] == 0
        ):
            # collection either doesn't exist or is empty, so replace it
            # and set new schema from df
            self.client.create_table(
                coll_name,
                data=df,
                mode="overwrite",
            )
            doc_cls = dataframe_to_document_model(
                df,
                # after the rename above, the content column is "content"
                content="content",
                metadata=actual_metadata,
                exclude=["vector"],
            )
            self.config.document_class = doc_cls  # type: ignore
            self._setup_schemas(doc_cls)  # type: ignore
        else:
            # collection exists and is not empty, so append to it
            tbl = self.client.open_table(coll_name)
            tbl.add(df)

    def delete_collection(self, collection_name: str) -> None:
        """Drop the named collection; a missing collection is not an error."""
        self.client.drop_table(collection_name, ignore_missing=True)

    def _dataframe_to_docs(self, df: pd.DataFrame) -> List[Document]:
        """Convert a result dataframe (dataframe-ingestion mode) to Documents."""
        return dataframe_to_documents(
            df,
            content="content",
            metadata=self.df_metadata_columns,
            doc_cls=self.config.document_class,
        )

    def _lance_result_to_docs(self, result: LanceVectorQueryBuilder) -> List[Document]:
        """
        Execute a LanceDB query and convert its rows to Documents, using the
        dataframe path when the data was ingested via `add_dataframe`.
        """
        if self.is_from_dataframe:
            return self._dataframe_to_docs(result.to_pandas())
        records = result.to_arrow().to_pylist()
        return self._records_to_docs(records)

    def _records_to_docs(self, records: List[Dict[str, Any]]) -> List[Document]:
        """
        Convert raw LanceDB rows into instances of `config.document_class`.

        Raises:
            ValueError: if rows do not validate against the current schema
                (e.g. a store created with a different schema).
        """
        if self.config.flatten:
            docs = [
                self.unflattened_schema(**nested_dict_from_flat(rec)) for rec in records
            ]
        else:
            try:
                docs = [self.schema(**rec) for rec in records]
            except ValidationError as e:
                raise ValueError(
                    f"""
                Error validating LanceDB result: {e}
                HINT: This could happen when you're re-using an
                existing LanceDB store with a different schema.
                Try deleting your local lancedb storage at `{self.config.storage_path}`
                re-ingesting your documents and/or replacing the collections.
                """
                ) from e

        doc_cls = self.config.document_class
        doc_cls_field_names = list(doc_cls.__fields__.keys())
        # Keep only the document-class fields (drops id/vector etc.).
        return [
            doc_cls(
                **{
                    field_name: getattr(doc, field_name)
                    for field_name in doc_cls_field_names
                }
            )
            for doc in docs
        ]

    def get_all_documents(self, where: str = "") -> List[Document]:
        """
        Return all documents in the current collection, optionally filtered by
        the SQL `where` clause (empty string means no filter).
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot retrieve docs")
        tbl = self.client.open_table(self.config.collection_name)
        pre_result = tbl.search(None).where(where or None).limit(None)
        return self._lance_result_to_docs(pre_result)

    def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
        """
        Return the documents with the given ids, in the order of `ids`;
        ids with no match are skipped.
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot retrieve docs")
        tbl = self.client.open_table(self.config.collection_name)
        docs = []
        for doc_id in (str(i) for i in ids):
            # Double single quotes (SQL escaping) so an id cannot terminate
            # the string literal and inject filter syntax.
            safe_id = doc_id.replace("'", "''")
            results = self._lance_result_to_docs(
                tbl.search().where(f"id == '{safe_id}'")
            )
            if len(results) > 0:
                docs.append(results[0])
        return docs

    def similar_texts_with_scores(
        self,
        text: str,
        k: int = 1,
        where: Optional[str] = None,
    ) -> List[Tuple[Document, float]]:
        """
        Return up to `k` (Document, score) pairs most similar to `text`.

        The score is `1 - _distance`; with the default "cosine" metric this is
        the cosine similarity. `where` is an optional SQL filter applied before
        the vector search.

        Raises:
            ValueError: if no collection name is configured.
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot retrieve docs")
        embedding = self.embedding_fn([text])[0]
        tbl = self.client.open_table(self.config.collection_name)
        result = (
            tbl.search(embedding)
            .metric(self.config.distance)
            .where(where, prefilter=True)
            .limit(k)
        )
        # Execute the query once, so documents and scores come from the same rows.
        if self.is_from_dataframe:
            df = result.to_pandas()
            scores = [1 - rec["_distance"] for rec in df.to_dict("records")]
            docs = self._dataframe_to_docs(df)
        else:
            records = result.to_arrow().to_pylist()
            scores = [1 - rec["_distance"] for rec in records]
            docs = self._records_to_docs(records)
        if len(docs) == 0:
            logger.warning(f"No matches found for {text}")
            return []
        if settings.debug:
            logger.info(f"Found {len(docs)} matches, max score: {max(scores)}")
        doc_score_pairs = list(zip(docs, scores))
        self.show_if_debug(doc_score_pairs)
        return doc_score_pairs
@@ -32,7 +32,7 @@ class MeiliSearchConfig(VectorStoreConfig):
32
32
 
33
33
 
34
34
  class MeiliSearch(VectorStore):
35
- def __init__(self, config: MeiliSearchConfig):
35
+ def __init__(self, config: MeiliSearchConfig = MeiliSearchConfig()):
36
36
  super().__init__(config)
37
37
  self.config: MeiliSearchConfig = config
38
38
  self.host = config.host
@@ -165,12 +165,13 @@ class MeiliSearch(VectorStore):
165
165
  async with self.client() as client:
166
166
  index = client.index(collection_name)
167
167
  await index.add_documents_in_batches(
168
- documents=documents, # type: ignore
168
+ documents=documents,
169
169
  batch_size=self.config.batch_size,
170
170
  primary_key=self.config.primary_key,
171
171
  )
172
172
 
173
173
  def add_documents(self, documents: Sequence[Document]) -> None:
174
+ super().maybe_add_ids(documents)
174
175
  if len(documents) == 0:
175
176
  return
176
177
  colls = self._list_all_collections()
@@ -197,18 +198,19 @@ class MeiliSearch(VectorStore):
197
198
  except ValueError:
198
199
  return id
199
200
 
200
async def _async_get_documents(self, where: str = "") -> DocumentsInfo:
    """
    Fetch up to 10,000 documents from the current Meilisearch collection.

    Args:
        where: optional Meilisearch filter expression; an empty string or
            None means no filtering.

    Returns:
        The documents page returned by the index's `get_documents`.

    Raises:
        ValueError: if no collection name is configured.
    """
    if self.config.collection_name is None:
        raise ValueError("No collection name set, cannot retrieve docs")
    # Normalize both "" (the default) and None to the same empty filter;
    # the name `filter_` avoids shadowing the builtin `filter`.
    filter_ = where or []
    async with self.client() as client:
        index = client.index(self.config.collection_name)
        documents = await index.get_documents(limit=10_000, filter=filter_)
        return documents
207
209
 
208
- def get_all_documents(self) -> List[Document]:
210
+ def get_all_documents(self, where: str = "") -> List[Document]:
209
211
  if self.config.collection_name is None:
210
212
  raise ValueError("No collection name set, cannot retrieve docs")
211
- docs = asyncio.run(self._async_get_documents())
213
+ docs = asyncio.run(self._async_get_documents(where))
212
214
  if docs is None:
213
215
  return []
214
216
  doc_results = docs.results
@@ -226,7 +228,7 @@ class MeiliSearch(VectorStore):
226
228
  async with self.client() as client:
227
229
  index = client.index(self.config.collection_name)
228
230
  documents = await asyncio.gather(*[index.get_document(id) for id in ids])
229
- return documents # type: ignore
231
+ return documents
230
232
 
231
233
  def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
232
234
  if self.config.collection_name is None:
@@ -263,6 +265,7 @@ class MeiliSearch(VectorStore):
263
265
  text: str,
264
266
  k: int = 20,
265
267
  where: Optional[str] = None,
268
+ neighbors: int = 0, # ignored
266
269
  ) -> List[Tuple[Document, float]]:
267
270
  filter = [] if where is None else where
268
271
  if self.config.collection_name is None: