langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -0,0 +1,463 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, Dict, Generator, List, Optional, Sequence, Tuple, Type
|
3
|
+
|
4
|
+
import lancedb
|
5
|
+
import pandas as pd
|
6
|
+
from dotenv import load_dotenv
|
7
|
+
from lancedb.pydantic import LanceModel, Vector
|
8
|
+
from lancedb.query import LanceVectorQueryBuilder
|
9
|
+
from pydantic import BaseModel, ValidationError, create_model
|
10
|
+
|
11
|
+
from langroid.embedding_models.base import (
|
12
|
+
EmbeddingModel,
|
13
|
+
EmbeddingModelsConfig,
|
14
|
+
)
|
15
|
+
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
16
|
+
from langroid.mytypes import Document, EmbeddingFunction
|
17
|
+
from langroid.utils.configuration import settings
|
18
|
+
from langroid.utils.pydantic_utils import (
|
19
|
+
dataframe_to_document_model,
|
20
|
+
dataframe_to_documents,
|
21
|
+
extend_document_class,
|
22
|
+
extra_metadata,
|
23
|
+
flatten_pydantic_instance,
|
24
|
+
flatten_pydantic_model,
|
25
|
+
nested_dict_from_flat,
|
26
|
+
)
|
27
|
+
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
28
|
+
|
29
|
+
logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
class LanceDBConfig(VectorStoreConfig):
    # Settings consumed by the LanceDB vector store defined below.

    # LanceDB.__init__ logs that the cloud service is not available when this
    # flag is set.
    cloud: bool = False
    # Name of the collection (LanceDB table) to work with; None postpones
    # collection creation until a name is known.
    collection_name: str | None = "temp"
    # URI handed to `lancedb.connect` for local storage.
    storage_path: str = ".lancedb/data"
    # Config of the embedding model used to embed content and queries.
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
    # Metric name passed to the LanceDB query builder's `.metric(...)`.
    distance: str = "cosine"
    # document_class is used to store in lancedb with right schema,
    # and also to retrieve the right type of Documents when searching.
    document_class: Type[Document] = Document
    flatten: bool = False  # flatten Document class into LanceSchema ?
|
42
|
+
|
43
|
+
|
44
|
+
class LanceDB(VectorStore):
|
45
|
+
def __init__(self, config: LanceDBConfig = LanceDBConfig()):
    """
    Build the embedding function and LanceDB schemas, connect to local
    LanceDB storage, and create the configured collection (if any).

    Args:
        config (LanceDBConfig): store settings.
            NOTE: the default instance is created once, when this method is
            defined, and is shared by every call that omits `config`. This
            method mutates it (`cloud` here, `collection_name` via
            `create_collection`).
    """
    super().__init__(config)
    self.config: LanceDBConfig = config
    emb_model = EmbeddingModel.create(config.embedding)
    self.embedding_fn: EmbeddingFunction = emb_model.embedding_fn()
    self.embedding_dim = emb_model.embedding_dims
    self.host = config.host
    self.port = config.port
    self.is_from_dataframe = False  # were docs ingested from a dataframe?
    self.df_metadata_columns: List[str] = []  # metadata columns from dataframe
    self._setup_schemas(config.document_class)

    load_dotenv()
    if self.config.cloud:
        logger.warning(
            "LanceDB Cloud is not available yet. Switching to local storage."
        )
        config.cloud = False

    # Always connect to local storage. Previously the cloud branch skipped
    # the connection entirely, leaving `self.client` undefined even though
    # the warning above promises a switch to local storage.
    try:
        self.client = lancedb.connect(
            uri=config.storage_path,
        )
    except Exception as e:
        # Best-effort fallback: retry with a sibling storage path.
        new_storage_path = config.storage_path + ".new"
        logger.warning(
            f"""
            Error connecting to local LanceDB at {config.storage_path}:
            {e}
            Switching to {new_storage_path}
            """
        )
        self.client = lancedb.connect(
            uri=new_storage_path,
        )

    # Note: Only create collection if a non-null collection name is provided.
    # This is useful to delay creation of vecdb until we have a suitable
    # collection name (e.g. we could get it from the url or folder path).
    if config.collection_name is not None:
        self.create_collection(
            config.collection_name, replace=config.replace_collection
        )
|
88
|
+
|
89
|
+
def _setup_schemas(self, doc_cls: Type[Document] | None) -> None:
    """
    (Re)build `self.unflattened_schema` and `self.schema` for a document class.

    Args:
        doc_cls: document class to derive the schemas from; when falsy,
            `self.config.document_class` is used instead.
    """
    target_cls = doc_cls or self.config.document_class
    nested_schema = self._create_lance_schema(target_cls)
    self.unflattened_schema = nested_schema
    if self.config.flatten:
        # store rows using the flattened variant of the schema
        self.schema = self._create_flat_lance_schema(target_cls)
    else:
        self.schema = nested_schema
|
97
|
+
|
98
|
+
def clear_empty_collections(self) -> int:
    """
    Drop every collection (table) that contains no rows.

    Returns:
        int: number of collections that were dropped.
    """
    n_deletes = 0
    # `empty=True` is essential: by default list_collections() filters OUT
    # empty collections, which are exactly the ones to be removed here.
    for name in self.list_collections(empty=True):
        if self.client.open_table(name).head(1).shape[0] == 0:
            self.client.drop_table(name)
            n_deletes += 1
    return n_deletes
|
107
|
+
|
108
|
+
def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
    """
    Drop all collections (empty or not) whose name starts with `prefix`.

    Args:
        really (bool): safety switch; nothing is deleted unless it is True.
        prefix (str): only collections whose name starts with this are dropped.

    Returns:
        int: number of collections dropped.
    """
    if not really:
        logger.warning("Not deleting all collections, set really=True to confirm")
        return 0

    matching = []
    for coll in self.list_collections(empty=True):
        if coll.startswith(prefix):
            matching.append(coll)
    if len(matching) == 0:
        logger.warning(f"No collections found with prefix {prefix}")
        return 0

    n_empty_deletes = 0
    n_non_empty_deletes = 0
    for coll in matching:
        # peek at a single row to classify the table before dropping it
        row_count = self.client.open_table(coll).head(1).shape[0]
        if row_count == 0:
            n_empty_deletes += 1
        if row_count > 0:
            n_non_empty_deletes += 1
        self.client.drop_table(coll)
    logger.warning(
        f"""
        Deleted {n_empty_deletes} empty collections and
        {n_non_empty_deletes} non-empty collections.
        """
    )
    return n_empty_deletes + n_non_empty_deletes
|
133
|
+
|
134
|
+
def list_collections(self, empty: bool = False) -> List[str]:
    """
    List the names of the collections (tables) known to the LanceDB client.

    Args:
        empty (bool, optional): Whether to include empty collections.
            Defaults to False.

    Returns:
        List[str]: every collection name when `empty` is True; otherwise only
            the names of collections holding at least one row.
    """
    names = self.client.table_names(limit=None)
    if len(names) == 0:
        return []
    if empty:  # caller wants empty tables too
        return names  # type: ignore
    # fetching a single row is enough to tell whether a table is empty
    return [
        name for name in names if self.client.open_table(name).head(1).shape[0] > 0
    ]
|
149
|
+
|
150
|
+
def _create_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
    """
    Build a LanceModel subclass describing the rows stored for `doc_cls`.

    The generated model has:
    - a required `id` field (str),
    - a required `vector` field of type `Vector(self.embedding_dim)`,
    - every field of `doc_cls`, added in alphabetical order, typed with the
      field's outer type and defaulting to the field's default.

    Args:
        doc_cls (Type[Document]): a Pydantic model which should be a subclass
            of Document.

    Returns:
        Type[BaseModel]: a new Pydantic model subclassing LanceModel.

    Raises:
        ValueError: if `doc_cls` is not a subclass of Document.
    """
    if not issubclass(doc_cls, Document):
        raise ValueError("DocClass must be a subclass of Document")

    dim = self.embedding_dim
    fields = {"id": (str, ...), "vector": (Vector(dim), ...)}

    # Covers both statically and dynamically defined fields of doc_cls
    for field_name in sorted(doc_cls.__fields__):
        model_field = doc_cls.__fields__[field_name]
        fields[field_name] = (model_field.outer_type_, model_field.default)

    NewModel = create_model(
        "NewModel", __base__=LanceModel, **fields
    )  # type: ignore
    return NewModel  # type: ignore
|
190
|
+
|
191
|
+
def _create_flat_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
    """
    Build the flattened counterpart of the schema from `_create_lance_schema`.

    Nested Pydantic schemas are not yet supported by LanceDB, so the nested
    model is passed through `flatten_pydantic_model`.
    """
    return flatten_pydantic_model(
        self._create_lance_schema(doc_cls), base_model=LanceModel
    )
|
199
|
+
|
200
|
+
def create_collection(self, collection_name: str, replace: bool = False) -> None:
    """
    Create a collection with the given name, optionally replacing an existing
    collection if `replace` is True.

    Side effect: `self.config.collection_name` is set to `collection_name`,
    even when an existing non-empty collection is kept.

    Args:
        collection_name (str): Name of the collection to create.
        replace (bool): Whether to replace an existing non-empty collection
            with the same name. Defaults to False.
    """
    self.config.collection_name = collection_name
    # list_collections() reports only non-empty collections by default; the
    # explicit row check below is kept as a safeguard.
    if collection_name in self.list_collections():
        coll = self.client.open_table(collection_name)
        if coll.head().shape[0] > 0:
            logger.warning(f"Non-empty Collection {collection_name} already exists")
            if not replace:
                logger.warning("Not replacing collection")
                return
            logger.warning("Recreating fresh collection")
    self.client.create_table(collection_name, schema=self.schema, mode="overwrite")
    if settings.debug:
        # Temporarily raise verbosity so the message below is emitted.
        # Previously the level was toggled with nothing logged in between,
        # which made this block a no-op.
        level = logger.getEffectiveLevel()
        logger.setLevel(logging.INFO)
        logger.info(f"Created LanceDB collection {collection_name}")
        logger.setLevel(level)
|
225
|
+
|
226
|
+
def _maybe_set_doc_class_schema(self, doc: Document) -> None:
    """
    Switch `config.document_class` and the schemas to a class derived from
    `doc` when `doc` has metadata fields the configured class lacks.

    Args:
        doc: an instance of Document, to be added to a collection
    """
    extra_metadata_fields = extra_metadata(doc, self.config.document_class)
    if len(extra_metadata_fields) == 0:
        # configured document class already covers this doc
        return

    logger.warning(
        f"""
        Added documents contain extra metadata fields:
        {extra_metadata_fields}
        which were not present in the original config.document_class.
        Trying to change document_class and corresponding schemas.
        Overriding LanceDBConfig.document_class with an auto-generated
        Pydantic class that includes these extra fields.
        If this fails, or you see odd results, it is recommended that you
        define a subclass of Document, with metadata of class derived from
        DocMetaData, with extra fields defined via
        `Field(..., description="...")` declarations,
        and set this document class as the value of the
        LanceDBConfig.document_class attribute.
        """
    )

    extended_cls = extend_document_class(doc)
    self.config.document_class = extended_cls
    self._setup_schemas(extended_cls)
|
254
|
+
|
255
|
+
def add_documents(self, documents: Sequence[Document]) -> None:
    """
    Embed `documents` and insert them into the configured collection.

    The collection is (re)created with the current schema when it does not
    exist yet or is empty. The first document is used to check whether the
    document class / schema must be extended (see
    `_maybe_set_doc_class_schema`).

    Failures while inserting are logged and NOT re-raised (best effort).

    Args:
        documents (Sequence[Document]): documents to ingest; an empty
            sequence is a no-op.

    Raises:
        ValueError: if documents are given but no collection name is set.
    """
    super().maybe_add_ids(documents)
    if len(documents) == 0:
        return
    # Validate before embedding, so an embedding call is not wasted on an
    # ingest that is bound to fail.
    coll_name = self.config.collection_name
    if coll_name is None:
        raise ValueError("No collection name set, cannot ingest docs")

    embedding_vecs = self.embedding_fn([doc.content for doc in documents])
    self._maybe_set_doc_class_schema(documents[0])
    colls = self.list_collections(empty=True)
    if (
        coll_name not in colls
        or self.client.open_table(coll_name).head(1).shape[0] == 0
    ):
        # collection either doesn't exist or is empty, so replace it,
        self.create_collection(coll_name, replace=True)

    ids = [str(d.id()) for d in documents]
    # don't insert all at once, batch in chunks of b,
    # else we get an API error
    b = self.config.batch_size

    def make_batches() -> Generator[List[BaseModel], None, None]:
        """Yield schema instances for successive chunks of `b` documents."""
        for i in range(0, len(ids), b):
            batch = [
                self.unflattened_schema(
                    id=ids[i + j],
                    vector=embedding_vecs[i + j],
                    **doc.dict(),
                )
                for j, doc in enumerate(documents[i : i + b])
            ]
            if self.config.flatten:
                batch = [
                    flatten_pydantic_instance(instance)  # type: ignore
                    for instance in batch
                ]
            yield batch

    tbl = self.client.open_table(coll_name)
    try:
        tbl.add(make_batches())
    except Exception as e:
        # deliberate best-effort: report the failure, do not crash the caller
        logger.error(
            f"""
            Error adding documents to LanceDB: {e}
            POSSIBLE REMEDY: Delete the LanceDB storage directory
            {self.config.storage_path} and try again.
            """
        )
|
305
|
+
|
306
|
+
def add_dataframe(
    self,
    df: pd.DataFrame,
    content: str = "content",
    metadata: List[str] = [],
) -> None:
    """
    Add a dataframe to the collection.

    If the collection does not exist or is empty, it is (re)created from the
    dataframe, and `config.document_class` and the schemas are regenerated
    to match the dataframe's columns. Otherwise the rows are appended.

    NOTE: a `vector` column is added to the caller's `df` in place. An `id`
    column (if missing) is added in place too, unless the content column had
    to be renamed, in which case it goes on the renamed copy.

    Args:
        df (pd.DataFrame): A dataframe
        content (str): The name of the column in the dataframe that contains the
            text content to be embedded using the embedding model.
        metadata (List[str]): A list of column names in the dataframe that contain
            metadata to be stored in the database. Defaults to [].
            (The default list is never mutated: a copy is taken below.)

    Raises:
        ValueError: if no collection name is set.
    """
    coll_name = self.config.collection_name
    if coll_name is None:
        # same contract as add_documents; checked before any embedding work
        raise ValueError("No collection name set, cannot ingest dataframe")

    self.is_from_dataframe = True
    actual_metadata = metadata.copy()
    self.df_metadata_columns = actual_metadata  # could be updated below
    # get content column
    content_values = df[content].values.tolist()
    embedding_vecs = self.embedding_fn(content_values)

    # add vector column
    df["vector"] = embedding_vecs
    if content != "content":
        # rename content column to "content", leave existing column intact
        df = df.rename(columns={content: "content"}, inplace=False)

    if "id" not in df.columns:
        docs = dataframe_to_documents(df, content="content", metadata=metadata)
        ids = [str(d.id()) for d in docs]
        df["id"] = ids

    if "id" not in actual_metadata:
        # in-place, so self.df_metadata_columns sees the addition as well
        actual_metadata += ["id"]

    colls = self.list_collections(empty=True)
    if (
        coll_name not in colls
        or self.client.open_table(coll_name).head(1).shape[0] == 0
    ):
        # collection either doesn't exist or is empty, so replace it
        # and set new schema from df
        self.client.create_table(
            coll_name,
            data=df,
            mode="overwrite",
        )
        doc_cls = dataframe_to_document_model(
            df,
            # df was renamed above, so its content column is always named
            # "content" here; the caller's original name may no longer exist
            content="content",
            metadata=actual_metadata,
            exclude=["vector"],
        )
        self.config.document_class = doc_cls  # type: ignore
        self._setup_schemas(doc_cls)  # type: ignore
    else:
        # collection exists and is not empty, so append to it
        tbl = self.client.open_table(coll_name)
        tbl.add(df)
|
367
|
+
|
368
|
+
def delete_collection(self, collection_name: str) -> None:
    """
    Drop the table backing `collection_name`.

    `ignore_missing=True` is passed to the client's `drop_table`.
    """
    client = self.client
    client.drop_table(collection_name, ignore_missing=True)
|
370
|
+
|
371
|
+
def _lance_result_to_docs(self, result: LanceVectorQueryBuilder) -> List[Document]:
    """
    Materialize a LanceDB query and convert its rows into Documents.

    Rows of a dataframe-ingested collection are converted through pandas;
    otherwise the Arrow records are handed to `_records_to_docs`.
    """
    if not self.is_from_dataframe:
        return self._records_to_docs(result.to_arrow().to_pylist())

    frame = result.to_pandas()
    return dataframe_to_documents(
        frame,
        content="content",
        metadata=self.df_metadata_columns,
        doc_cls=self.config.document_class,
    )
|
383
|
+
|
384
|
+
def _records_to_docs(self, records: List[Dict[str, Any]]) -> List[Document]:
    """
    Convert raw LanceDB records (dicts) into `config.document_class` instances.

    Each record is first validated against the stored schema (un-flattening
    it when `config.flatten` is set); then only the fields declared on the
    document class are copied into the resulting document.

    Raises:
        ValueError: if, in the non-flattened case, a record does not
            validate against `self.schema`.
    """
    if self.config.flatten:
        rows = [
            self.unflattened_schema(**nested_dict_from_flat(rec)) for rec in records
        ]
    else:
        try:
            rows = [self.schema(**rec) for rec in records]
        except ValidationError as e:
            raise ValueError(
                f"""
                Error validating LanceDB result: {e}
                HINT: This could happen when you're re-using an
                existing LanceDB store with a different schema.
                Try deleting your local lancedb storage at `{self.config.storage_path}`
                re-ingesting your documents and/or replacing the collections.
                """
            )

    doc_cls = self.config.document_class
    wanted = doc_cls.__fields__.keys()
    converted = []
    for row in rows:
        # drop schema-only fields (e.g. id, vector) not declared on doc_cls
        kwargs = {field_name: getattr(row, field_name) for field_name in wanted}
        converted.append(doc_cls(**kwargs))
    return converted
|
414
|
+
|
415
|
+
def get_all_documents(self, where: str = "") -> List[Document]:
    """
    Return the documents of the current collection, optionally filtered.

    Args:
        where (str): filter expression passed to the query's `.where(...)`;
            an empty string is passed on as None.

    Raises:
        ValueError: if no collection name is set.
    """
    coll_name = self.config.collection_name
    if coll_name is None:
        raise ValueError("No collection name set, cannot retrieve docs")
    table = self.client.open_table(coll_name)
    query = table.search(None)
    query = query.where(where or None)
    query = query.limit(None)
    return self._lance_result_to_docs(query)
|
421
|
+
|
422
|
+
def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
    """
    Fetch documents by id, one query per id, preserving the order of `ids`.

    Ids with no matching row are skipped; when several rows match an id,
    only the first one is returned.

    Args:
        ids (List[str]): ids to look up (each is converted with `str`).

    Raises:
        ValueError: if no collection name is set.
    """
    if self.config.collection_name is None:
        raise ValueError("No collection name set, cannot retrieve docs")
    tbl = self.client.open_table(self.config.collection_name)
    docs = []
    for doc_id in ids:
        # Double any embedded single quote (SQL string-literal escaping) so
        # an id containing `'` cannot break or alter the filter expression.
        quoted = str(doc_id).replace("'", "''")
        results = self._lance_result_to_docs(tbl.search().where(f"id == '{quoted}'"))
        if len(results) > 0:
            docs.append(results[0])
    return docs
|
433
|
+
|
434
|
+
def similar_texts_with_scores(
    self,
    text: str,
    k: int = 1,
    where: Optional[str] = None,
) -> List[Tuple[Document, float]]:
    """
    Find the `k` stored documents nearest to `text` by vector search.

    Args:
        text (str): query text; embedded with `self.embedding_fn`.
        k (int): maximum number of matches to return.
        where (Optional[str]): filter expression, applied as a pre-filter.

    Returns:
        List[Tuple[Document, float]]: (document, score) pairs, where
            score = 1 - `_distance` as reported by LanceDB.

    Raises:
        ValueError: if no collection name is set.
    """
    coll_name = self.config.collection_name
    if coll_name is None:
        # consistent with get_all_documents / get_documents_by_ids, and
        # checked first so no embedding call is wasted
        raise ValueError("No collection name set, cannot retrieve docs")
    embedding = self.embedding_fn([text])[0]
    tbl = self.client.open_table(coll_name)
    result = (
        tbl.search(embedding)
        .metric(self.config.distance)
        .where(where, prefilter=True)
        .limit(k)
    )
    docs = self._lance_result_to_docs(result)
    # note _distance is 1 - cosine (for the default "cosine" metric)
    # NOTE(review): `result` is materialized a second time here, after
    # `_lance_result_to_docs` above — presumably this re-runs the query.
    if self.is_from_dataframe:
        scores = [
            1 - rec["_distance"] for rec in result.to_pandas().to_dict("records")
        ]
    else:
        scores = [1 - rec["_distance"] for rec in result.to_arrow().to_pylist()]
    if len(docs) == 0:
        logger.warning(f"No matches found for {text}")
        return []
    if settings.debug:
        logger.info(f"Found {len(docs)} matches, max score: {max(scores)}")
    doc_score_pairs = list(zip(docs, scores))
    self.show_if_debug(doc_score_pairs)
    return doc_score_pairs
|
@@ -32,7 +32,7 @@ class MeiliSearchConfig(VectorStoreConfig):
|
|
32
32
|
|
33
33
|
|
34
34
|
class MeiliSearch(VectorStore):
|
35
|
-
def __init__(self, config: MeiliSearchConfig):
|
35
|
+
def __init__(self, config: MeiliSearchConfig = MeiliSearchConfig()):
|
36
36
|
super().__init__(config)
|
37
37
|
self.config: MeiliSearchConfig = config
|
38
38
|
self.host = config.host
|
@@ -165,12 +165,13 @@ class MeiliSearch(VectorStore):
|
|
165
165
|
async with self.client() as client:
|
166
166
|
index = client.index(collection_name)
|
167
167
|
await index.add_documents_in_batches(
|
168
|
-
documents=documents,
|
168
|
+
documents=documents,
|
169
169
|
batch_size=self.config.batch_size,
|
170
170
|
primary_key=self.config.primary_key,
|
171
171
|
)
|
172
172
|
|
173
173
|
def add_documents(self, documents: Sequence[Document]) -> None:
|
174
|
+
super().maybe_add_ids(documents)
|
174
175
|
if len(documents) == 0:
|
175
176
|
return
|
176
177
|
colls = self._list_all_collections()
|
@@ -197,18 +198,19 @@ class MeiliSearch(VectorStore):
|
|
197
198
|
except ValueError:
|
198
199
|
return id
|
199
200
|
|
200
|
-
async def _async_get_documents(self) -> DocumentsInfo:
|
201
|
+
async def _async_get_documents(self, where: str = "") -> DocumentsInfo:
|
201
202
|
if self.config.collection_name is None:
|
202
203
|
raise ValueError("No collection name set, cannot retrieve docs")
|
204
|
+
filter = [] if where is None else where
|
203
205
|
async with self.client() as client:
|
204
206
|
index = client.index(self.config.collection_name)
|
205
|
-
documents = await index.get_documents(limit=10_000)
|
207
|
+
documents = await index.get_documents(limit=10_000, filter=filter)
|
206
208
|
return documents
|
207
209
|
|
208
|
-
def get_all_documents(self) -> List[Document]:
|
210
|
+
def get_all_documents(self, where: str = "") -> List[Document]:
|
209
211
|
if self.config.collection_name is None:
|
210
212
|
raise ValueError("No collection name set, cannot retrieve docs")
|
211
|
-
docs = asyncio.run(self._async_get_documents())
|
213
|
+
docs = asyncio.run(self._async_get_documents(where))
|
212
214
|
if docs is None:
|
213
215
|
return []
|
214
216
|
doc_results = docs.results
|
@@ -226,7 +228,7 @@ class MeiliSearch(VectorStore):
|
|
226
228
|
async with self.client() as client:
|
227
229
|
index = client.index(self.config.collection_name)
|
228
230
|
documents = await asyncio.gather(*[index.get_document(id) for id in ids])
|
229
|
-
return documents
|
231
|
+
return documents
|
230
232
|
|
231
233
|
def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
|
232
234
|
if self.config.collection_name is None:
|
@@ -263,6 +265,7 @@ class MeiliSearch(VectorStore):
|
|
263
265
|
text: str,
|
264
266
|
k: int = 20,
|
265
267
|
where: Optional[str] = None,
|
268
|
+
neighbors: int = 0, # ignored
|
266
269
|
) -> List[Tuple[Document, float]]:
|
267
270
|
filter = [] if where is None else where
|
268
271
|
if self.config.collection_name is None:
|