langroid 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +42 -6
- langroid/agent/chat_agent.py +2 -2
- langroid/agent/special/doc_chat_agent.py +14 -4
- langroid/agent/special/lance_doc_chat_agent.py +25 -28
- langroid/agent/special/lance_rag/critic_agent.py +16 -6
- langroid/agent/special/lance_rag/query_planner_agent.py +8 -4
- langroid/agent/special/lance_tools.py +14 -8
- langroid/agent/team.py +1758 -0
- langroid/agent/tool_message.py +6 -10
- langroid/language_models/azure_openai.py +10 -5
- langroid/language_models/base.py +1 -0
- langroid/language_models/openai_gpt.py +18 -5
- langroid/parsing/parse_json.py +27 -1
- langroid/utils/pydantic_utils.py +0 -50
- langroid/vector_store/base.py +6 -4
- langroid/vector_store/chromadb.py +4 -2
- langroid/vector_store/lancedb.py +40 -172
- langroid/vector_store/momento.py +1 -0
- langroid/vector_store/qdrantdb.py +22 -4
- {langroid-0.3.0.dist-info → langroid-0.5.0.dist-info}/METADATA +4 -2
- {langroid-0.3.0.dist-info → langroid-0.5.0.dist-info}/RECORD +24 -23
- pyproject.toml +2 -2
- {langroid-0.3.0.dist-info → langroid-0.5.0.dist-info}/LICENSE +0 -0
- {langroid-0.3.0.dist-info → langroid-0.5.0.dist-info}/WHEEL +0 -0
langroid/agent/tool_message.py
CHANGED
@@ -35,12 +35,10 @@ class ToolMessage(ABC, BaseModel):
|
|
35
35
|
request (str): name of agent method to map to.
|
36
36
|
purpose (str): purpose of agent method, expressed in general terms.
|
37
37
|
(This is used when auto-generating the tool instruction to the LLM)
|
38
|
-
result (str): example of result of agent method.
|
39
38
|
"""
|
40
39
|
|
41
40
|
request: str
|
42
41
|
purpose: str
|
43
|
-
result: str = ""
|
44
42
|
|
45
43
|
class Config:
|
46
44
|
arbitrary_types_allowed = False
|
@@ -48,7 +46,7 @@ class ToolMessage(ABC, BaseModel):
|
|
48
46
|
validate_assignment = True
|
49
47
|
# do not include these fields in the generated schema
|
50
48
|
# since we don't require the LLM to specify them
|
51
|
-
schema_extra = {"exclude": {"purpose"
|
49
|
+
schema_extra = {"exclude": {"purpose"}}
|
52
50
|
|
53
51
|
@classmethod
|
54
52
|
def instructions(cls) -> str:
|
@@ -110,13 +108,13 @@ class ToolMessage(ABC, BaseModel):
|
|
110
108
|
return "\n\n".join(examples_jsons)
|
111
109
|
|
112
110
|
def to_json(self) -> str:
|
113
|
-
return self.json(indent=4, exclude={"
|
111
|
+
return self.json(indent=4, exclude={"purpose"})
|
114
112
|
|
115
113
|
def json_example(self) -> str:
|
116
|
-
return self.json(indent=4, exclude={"
|
114
|
+
return self.json(indent=4, exclude={"purpose"})
|
117
115
|
|
118
116
|
def dict_example(self) -> Dict[str, Any]:
|
119
|
-
return self.dict(exclude={"
|
117
|
+
return self.dict(exclude={"purpose"})
|
120
118
|
|
121
119
|
@classmethod
|
122
120
|
def default_value(cls, f: str) -> Any:
|
@@ -220,9 +218,7 @@ class ToolMessage(ABC, BaseModel):
|
|
220
218
|
if "description" not in parameters["properties"][name]:
|
221
219
|
parameters["properties"][name]["description"] = description
|
222
220
|
|
223
|
-
excludes =
|
224
|
-
["result", "purpose"] if request else ["request", "result", "purpose"]
|
225
|
-
)
|
221
|
+
excludes = ["purpose"] if request else ["request", "purpose"]
|
226
222
|
# exclude 'excludes' from parameters["properties"]:
|
227
223
|
parameters["properties"] = {
|
228
224
|
field: details
|
@@ -263,5 +259,5 @@ class ToolMessage(ABC, BaseModel):
|
|
263
259
|
Returns:
|
264
260
|
Dict[str, Any]: simplified schema
|
265
261
|
"""
|
266
|
-
schema = generate_simple_schema(cls, exclude=["
|
262
|
+
schema = generate_simple_schema(cls, exclude=["purpose"])
|
267
263
|
return schema
|
@@ -133,12 +133,15 @@ class AzureGPT(OpenAIGPT):
|
|
133
133
|
"""
|
134
134
|
Handles the setting of the GPT-4 model in the configuration.
|
135
135
|
This function checks the `model_version` in the configuration.
|
136
|
-
If the version is not set, it raises a ValueError indicating
|
137
|
-
version needs to be specified in the ``.env``
|
138
|
-
It sets `
|
139
|
-
'
|
136
|
+
If the version is not set, it raises a ValueError indicating
|
137
|
+
that the model version needs to be specified in the ``.env``
|
138
|
+
file. It sets `OpenAIChatMode.GPT4o` if the version is
|
139
|
+
'2024-05-13', `OpenAIChatModel.GPT4_TURBO` if the version is
|
140
|
+
'1106-Preview', otherwise, it defaults to setting
|
141
|
+
`OpenAIChatModel.GPT4`.
|
140
142
|
"""
|
141
143
|
VERSION_1106_PREVIEW = "1106-Preview"
|
144
|
+
VERSION_GPT4o = "2024-05-13"
|
142
145
|
|
143
146
|
if self.config.model_version == "":
|
144
147
|
raise ValueError(
|
@@ -146,7 +149,9 @@ class AzureGPT(OpenAIGPT):
|
|
146
149
|
"Please set it to the chat model version used in your deployment."
|
147
150
|
)
|
148
151
|
|
149
|
-
if self.config.model_version ==
|
152
|
+
if self.config.model_version == VERSION_GPT4o:
|
153
|
+
self.config.chat_model = OpenAIChatModel.GPT4o
|
154
|
+
elif self.config.model_version == VERSION_1106_PREVIEW:
|
150
155
|
self.config.chat_model = OpenAIChatModel.GPT4_TURBO
|
151
156
|
else:
|
152
157
|
self.config.chat_model = OpenAIChatModel.GPT4
|
langroid/language_models/base.py
CHANGED
@@ -234,6 +234,7 @@ class LLMResponse(BaseModel):
|
|
234
234
|
# in this case we ignore message, since all information is in function_call
|
235
235
|
msg = ""
|
236
236
|
args = self.function_call.arguments
|
237
|
+
recipient = ""
|
237
238
|
if isinstance(args, dict):
|
238
239
|
recipient = args.get("recipient", "")
|
239
240
|
return recipient, msg
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import ast
|
2
1
|
import hashlib
|
3
2
|
import json
|
4
3
|
import logging
|
@@ -49,6 +48,7 @@ from langroid.language_models.utils import (
|
|
49
48
|
async_retry_with_exponential_backoff,
|
50
49
|
retry_with_exponential_backoff,
|
51
50
|
)
|
51
|
+
from langroid.parsing.parse_json import parse_imperfect_json
|
52
52
|
from langroid.pydantic_v1 import BaseModel
|
53
53
|
from langroid.utils.configuration import settings
|
54
54
|
from langroid.utils.constants import Colors
|
@@ -797,11 +797,24 @@ class OpenAIGPT(LanguageModel):
|
|
797
797
|
args = {}
|
798
798
|
if has_function and function_args != "":
|
799
799
|
try:
|
800
|
-
|
801
|
-
|
800
|
+
stripped_fn_args = function_args.strip()
|
801
|
+
dict_or_list = parse_imperfect_json(stripped_fn_args)
|
802
|
+
if not isinstance(dict_or_list, dict):
|
803
|
+
raise ValueError(
|
804
|
+
f"""
|
805
|
+
Invalid function args: {stripped_fn_args}
|
806
|
+
parsed as {dict_or_list},
|
807
|
+
which is not a valid dict.
|
808
|
+
"""
|
809
|
+
)
|
810
|
+
args = dict_or_list
|
811
|
+
except (SyntaxError, ValueError) as e:
|
802
812
|
logging.warning(
|
803
|
-
f"
|
804
|
-
|
813
|
+
f"""
|
814
|
+
Parsing OpenAI function args failed: {function_args};
|
815
|
+
treating args as normal message. Error detail:
|
816
|
+
{e}
|
817
|
+
"""
|
805
818
|
)
|
806
819
|
has_function = False
|
807
820
|
completion = completion + function_args
|
langroid/parsing/parse_json.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
|
+
import ast
|
1
2
|
import json
|
2
|
-
from typing import Any, Iterator, List
|
3
|
+
from typing import Any, Dict, Iterator, List, Union
|
3
4
|
|
4
5
|
import yaml
|
5
6
|
from pyparsing import nestedExpr, originalTextFor
|
@@ -73,6 +74,31 @@ def add_quotes(s: str) -> str:
|
|
73
74
|
return s
|
74
75
|
|
75
76
|
|
77
|
+
def parse_imperfect_json(json_string: str) -> Union[Dict[str, Any], List[Any]]:
|
78
|
+
if not json_string.strip():
|
79
|
+
raise ValueError("Empty string is not valid JSON")
|
80
|
+
|
81
|
+
# First, try parsing with ast.literal_eval
|
82
|
+
try:
|
83
|
+
result = ast.literal_eval(json_string)
|
84
|
+
if isinstance(result, (dict, list)):
|
85
|
+
return result
|
86
|
+
except (ValueError, SyntaxError):
|
87
|
+
pass
|
88
|
+
|
89
|
+
# If ast.literal_eval fails or returns non-dict/list, try json.loads
|
90
|
+
try:
|
91
|
+
str = add_quotes(json_string)
|
92
|
+
result = json.loads(str)
|
93
|
+
if isinstance(result, (dict, list)):
|
94
|
+
return result
|
95
|
+
except json.JSONDecodeError:
|
96
|
+
pass
|
97
|
+
|
98
|
+
# If all methods fail, raise ValueError
|
99
|
+
raise ValueError(f"Unable to parse as JSON: {json_string}")
|
100
|
+
|
101
|
+
|
76
102
|
def repair_newlines(s: str) -> str:
|
77
103
|
"""
|
78
104
|
Attempt to load as json, and if it fails, try with newlines replaced by space.
|
langroid/utils/pydantic_utils.py
CHANGED
@@ -9,8 +9,6 @@ from typing import (
|
|
9
9
|
Tuple,
|
10
10
|
Type,
|
11
11
|
TypeVar,
|
12
|
-
get_args,
|
13
|
-
get_origin,
|
14
12
|
no_type_check,
|
15
13
|
)
|
16
14
|
|
@@ -313,54 +311,6 @@ def pydantic_obj_from_flat_dict(
|
|
313
311
|
return model(**nested_data)
|
314
312
|
|
315
313
|
|
316
|
-
def clean_schema(model: Type[BaseModel], excludes: List[str] = []) -> Dict[str, Any]:
|
317
|
-
"""
|
318
|
-
Generate a simple schema for a given Pydantic model,
|
319
|
-
including inherited fields, with an option to exclude certain fields.
|
320
|
-
Handles cases where fields are Lists or other generic types and includes
|
321
|
-
field descriptions if available.
|
322
|
-
|
323
|
-
Args:
|
324
|
-
model (Type[BaseModel]): The Pydantic model class.
|
325
|
-
excludes (List[str]): A list of field names to exclude.
|
326
|
-
|
327
|
-
Returns:
|
328
|
-
Dict[str, Any]: A dictionary representing the simple schema.
|
329
|
-
"""
|
330
|
-
schema = {}
|
331
|
-
|
332
|
-
for field_name, field_info in model.__fields__.items():
|
333
|
-
if field_name in excludes:
|
334
|
-
continue
|
335
|
-
|
336
|
-
field_type = field_info.outer_type_
|
337
|
-
description = field_info.field_info.description or ""
|
338
|
-
|
339
|
-
# Handle generic types like List[...]
|
340
|
-
if get_origin(field_type):
|
341
|
-
inner_types = get_args(field_type)
|
342
|
-
inner_type_names = [
|
343
|
-
t.__name__ if hasattr(t, "__name__") else str(t) for t in inner_types
|
344
|
-
]
|
345
|
-
field_type_str = (
|
346
|
-
f"{get_origin(field_type).__name__}" f'[{", ".join(inner_type_names)}]'
|
347
|
-
)
|
348
|
-
schema[field_name] = {"type": field_type_str, "description": description}
|
349
|
-
elif issubclass(field_type, BaseModel):
|
350
|
-
# Directly use the nested model's schema,
|
351
|
-
# integrating it into the current level
|
352
|
-
nested_schema = clean_schema(field_type, excludes)
|
353
|
-
schema[field_name] = {**nested_schema, "description": description}
|
354
|
-
else:
|
355
|
-
# For basic types, use 'type'
|
356
|
-
schema[field_name] = {
|
357
|
-
"type": field_type.__name__,
|
358
|
-
"description": description,
|
359
|
-
}
|
360
|
-
|
361
|
-
return schema
|
362
|
-
|
363
|
-
|
364
314
|
@contextmanager
|
365
315
|
def temp_update(
|
366
316
|
pydantic_object: BaseModel, updates: Dict[str, Any]
|
langroid/vector_store/base.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
import copy
|
2
2
|
import logging
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
-
from typing import Dict, List, Optional, Sequence, Tuple
|
4
|
+
from typing import Dict, List, Optional, Sequence, Tuple, Type
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
8
|
|
9
9
|
from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
|
10
10
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
11
|
-
from langroid.mytypes import Document
|
11
|
+
from langroid.mytypes import DocMetaData, Document
|
12
12
|
from langroid.pydantic_v1 import BaseSettings
|
13
13
|
from langroid.utils.algorithms.graph import components, topological_sort
|
14
14
|
from langroid.utils.configuration import settings
|
@@ -32,6 +32,9 @@ class VectorStoreConfig(BaseSettings):
|
|
32
32
|
timeout: int = 60
|
33
33
|
host: str = "127.0.0.1"
|
34
34
|
port: int = 6333
|
35
|
+
# used when parsing search results back as Document objects
|
36
|
+
document_class: Type[Document] = Document
|
37
|
+
metadata_class: Type[DocMetaData] = DocMetaData
|
35
38
|
# compose_file: str = "langroid/vector_store/docker-compose-qdrant.yml"
|
36
39
|
|
37
40
|
|
@@ -113,8 +116,7 @@ class VectorStore(ABC):
|
|
113
116
|
"""
|
114
117
|
|
115
118
|
self.config.collection_name = collection_name
|
116
|
-
|
117
|
-
self.create_collection(collection_name, replace=replace)
|
119
|
+
self.config.replace_collection = replace
|
118
120
|
|
119
121
|
@abstractmethod
|
120
122
|
def create_collection(self, collection_name: str, replace: bool = False) -> None:
|
@@ -8,7 +8,7 @@ from langroid.embedding_models.base import (
|
|
8
8
|
)
|
9
9
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
10
10
|
from langroid.exceptions import LangroidImportError
|
11
|
-
from langroid.mytypes import
|
11
|
+
from langroid.mytypes import Document
|
12
12
|
from langroid.utils.configuration import settings
|
13
13
|
from langroid.utils.output.printing import print_long_text
|
14
14
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
@@ -200,7 +200,9 @@ class ChromaDB(VectorStore):
|
|
200
200
|
else:
|
201
201
|
m["window_ids"] = m["window_ids"].split(",")
|
202
202
|
docs = [
|
203
|
-
|
203
|
+
self.config.document_class(
|
204
|
+
content=d, metadata=self.config.metadata_class(**m)
|
205
|
+
)
|
204
206
|
for d, m in zip(contents, metadatas)
|
205
207
|
]
|
206
208
|
return docs
|
langroid/vector_store/lancedb.py
CHANGED
@@ -32,13 +32,7 @@ from langroid.utils.configuration import settings
|
|
32
32
|
from langroid.utils.pydantic_utils import (
|
33
33
|
dataframe_to_document_model,
|
34
34
|
dataframe_to_documents,
|
35
|
-
extend_document_class,
|
36
|
-
extra_metadata,
|
37
|
-
flatten_pydantic_instance,
|
38
|
-
flatten_pydantic_model,
|
39
|
-
nested_dict_from_flat,
|
40
35
|
)
|
41
|
-
from langroid.utils.system import pydantic_major_version
|
42
36
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
43
37
|
|
44
38
|
try:
|
@@ -58,10 +52,6 @@ class LanceDBConfig(VectorStoreConfig):
|
|
58
52
|
storage_path: str = ".lancedb/data"
|
59
53
|
embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
|
60
54
|
distance: str = "cosine"
|
61
|
-
# document_class is used to store in lancedb with right schema,
|
62
|
-
# and also to retrieve the right type of Documents when searching.
|
63
|
-
document_class: Type[Document] = Document
|
64
|
-
flatten: bool = False # flatten Document class into LanceSchema ?
|
65
55
|
|
66
56
|
|
67
57
|
class LanceDB(VectorStore):
|
@@ -78,7 +68,6 @@ class LanceDB(VectorStore):
|
|
78
68
|
self.port = config.port
|
79
69
|
self.is_from_dataframe = False # were docs ingested from a dataframe?
|
80
70
|
self.df_metadata_columns: List[str] = [] # metadata columns from dataframe
|
81
|
-
self._setup_schemas(config.document_class)
|
82
71
|
|
83
72
|
load_dotenv()
|
84
73
|
if self.config.cloud:
|
@@ -104,40 +93,6 @@ class LanceDB(VectorStore):
|
|
104
93
|
uri=new_storage_path,
|
105
94
|
)
|
106
95
|
|
107
|
-
# Note: Only create collection if a non-null collection name is provided.
|
108
|
-
# This is useful to delay creation of vecdb until we have a suitable
|
109
|
-
# collection name (e.g. we could get it from the url or folder path).
|
110
|
-
if config.collection_name is not None:
|
111
|
-
self.create_collection(
|
112
|
-
config.collection_name, replace=config.replace_collection
|
113
|
-
)
|
114
|
-
|
115
|
-
def _setup_schemas(self, doc_cls: Type[Document] | None) -> None:
|
116
|
-
try:
|
117
|
-
doc_cls = doc_cls or self.config.document_class
|
118
|
-
self.unflattened_schema = self._create_lance_schema(doc_cls)
|
119
|
-
self.schema = (
|
120
|
-
self._create_flat_lance_schema(doc_cls)
|
121
|
-
if self.config.flatten
|
122
|
-
else self.unflattened_schema
|
123
|
-
)
|
124
|
-
except (AttributeError, TypeError) as e:
|
125
|
-
pydantic_version = pydantic_major_version()
|
126
|
-
if pydantic_version > 1:
|
127
|
-
raise ValueError(
|
128
|
-
f"""
|
129
|
-
{e}
|
130
|
-
====
|
131
|
-
You are using Pydantic v{pydantic_version},
|
132
|
-
which is not yet compatible with Langroid's LanceDB integration.
|
133
|
-
To use Lancedb with Langroid, please install the
|
134
|
-
latest pydantic 1.x instead of pydantic v2, e.g.
|
135
|
-
pip install "pydantic<2.0.0"
|
136
|
-
"""
|
137
|
-
)
|
138
|
-
else:
|
139
|
-
raise e
|
140
|
-
|
141
96
|
def clear_empty_collections(self) -> int:
|
142
97
|
coll_names = self.list_collections()
|
143
98
|
n_deletes = 0
|
@@ -234,91 +189,8 @@ class LanceDB(VectorStore):
|
|
234
189
|
) # type: ignore
|
235
190
|
return NewModel # type: ignore
|
236
191
|
|
237
|
-
def _create_flat_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
|
238
|
-
"""
|
239
|
-
Flat version of the lance_schema, as nested Pydantic schemas are not yet
|
240
|
-
supported by LanceDB.
|
241
|
-
"""
|
242
|
-
if not has_lancedb:
|
243
|
-
raise LangroidImportError("lancedb", "lancedb")
|
244
|
-
lance_model = self._create_lance_schema(doc_cls)
|
245
|
-
FlatModel = flatten_pydantic_model(lance_model, base_model=LanceModel)
|
246
|
-
return FlatModel
|
247
|
-
|
248
192
|
def create_collection(self, collection_name: str, replace: bool = False) -> None:
|
249
|
-
|
250
|
-
Create a collection with the given name, optionally replacing an existing
|
251
|
-
collection if `replace` is True.
|
252
|
-
Args:
|
253
|
-
collection_name (str): Name of the collection to create.
|
254
|
-
replace (bool): Whether to replace an existing collection
|
255
|
-
with the same name. Defaults to False.
|
256
|
-
"""
|
257
|
-
self.config.collection_name = collection_name
|
258
|
-
collections = self.list_collections()
|
259
|
-
if collection_name in collections:
|
260
|
-
coll = self.client.open_table(collection_name)
|
261
|
-
if coll.head().shape[0] > 0:
|
262
|
-
logger.warning(f"Non-empty Collection {collection_name} already exists")
|
263
|
-
if not replace:
|
264
|
-
logger.warning("Not replacing collection")
|
265
|
-
return
|
266
|
-
else:
|
267
|
-
logger.warning("Recreating fresh collection")
|
268
|
-
try:
|
269
|
-
self.client.create_table(
|
270
|
-
collection_name, schema=self.schema, mode="overwrite"
|
271
|
-
)
|
272
|
-
except (AttributeError, TypeError) as e:
|
273
|
-
pydantic_version = pydantic_major_version()
|
274
|
-
if pydantic_version > 1:
|
275
|
-
raise ValueError(
|
276
|
-
f"""
|
277
|
-
{e}
|
278
|
-
====
|
279
|
-
You are using Pydantic v{pydantic_version},
|
280
|
-
which is not yet compatible with Langroid's LanceDB integration.
|
281
|
-
To use Lancedb with Langroid, please install the
|
282
|
-
latest pydantic 1.x instead of pydantic v2, e.g.
|
283
|
-
pip install "pydantic<2.0.0"
|
284
|
-
"""
|
285
|
-
)
|
286
|
-
else:
|
287
|
-
raise e
|
288
|
-
|
289
|
-
if settings.debug:
|
290
|
-
level = logger.getEffectiveLevel()
|
291
|
-
logger.setLevel(logging.INFO)
|
292
|
-
logger.setLevel(level)
|
293
|
-
|
294
|
-
def _maybe_set_doc_class_schema(self, doc: Document) -> None:
|
295
|
-
"""
|
296
|
-
Set the config.document_class and self.schema based on doc if needed
|
297
|
-
Args:
|
298
|
-
doc: an instance of Document, to be added to a collection
|
299
|
-
"""
|
300
|
-
extra_metadata_fields = extra_metadata(doc, self.config.document_class)
|
301
|
-
if len(extra_metadata_fields) > 0:
|
302
|
-
logger.warning(
|
303
|
-
f"""
|
304
|
-
Added documents contain extra metadata fields:
|
305
|
-
{extra_metadata_fields}
|
306
|
-
which were not present in the original config.document_class.
|
307
|
-
Trying to change document_class and corresponding schemas.
|
308
|
-
Overriding LanceDBConfig.document_class with an auto-generated
|
309
|
-
Pydantic class that includes these extra fields.
|
310
|
-
If this fails, or you see odd results, it is recommended that you
|
311
|
-
define a subclass of Document, with metadata of class derived from
|
312
|
-
DocMetaData, with extra fields defined via
|
313
|
-
`Field(..., description="...")` declarations,
|
314
|
-
and set this document class as the value of the
|
315
|
-
LanceDBConfig.document_class attribute.
|
316
|
-
"""
|
317
|
-
)
|
318
|
-
|
319
|
-
doc_cls = extend_document_class(doc)
|
320
|
-
self.config.document_class = doc_cls
|
321
|
-
self._setup_schemas(doc_cls)
|
193
|
+
self.config.replace_collection = replace
|
322
194
|
|
323
195
|
def add_documents(self, documents: Sequence[Document]) -> None:
|
324
196
|
super().maybe_add_ids(documents)
|
@@ -329,39 +201,52 @@ class LanceDB(VectorStore):
|
|
329
201
|
coll_name = self.config.collection_name
|
330
202
|
if coll_name is None:
|
331
203
|
raise ValueError("No collection name set, cannot ingest docs")
|
332
|
-
self._maybe_set_doc_class_schema(documents[0])
|
204
|
+
# self._maybe_set_doc_class_schema(documents[0])
|
205
|
+
table_exists = False
|
333
206
|
if (
|
334
|
-
coll_name
|
335
|
-
|
207
|
+
coll_name in colls
|
208
|
+
and self.client.open_table(coll_name).head(1).shape[0] > 0
|
336
209
|
):
|
337
|
-
# collection
|
338
|
-
|
210
|
+
# collection exists and is not empty:
|
211
|
+
# if replace_collection is True, we'll overwrite the existing collection,
|
212
|
+
# else we'll append to it.
|
213
|
+
if self.config.replace_collection:
|
214
|
+
self.client.drop_table(coll_name)
|
215
|
+
else:
|
216
|
+
table_exists = True
|
339
217
|
|
340
218
|
ids = [str(d.id()) for d in documents]
|
341
219
|
# don't insert all at once, batch in chunks of b,
|
342
220
|
# else we get an API error
|
343
221
|
b = self.config.batch_size
|
344
222
|
|
345
|
-
def make_batches() -> Generator[List[
|
223
|
+
def make_batches() -> Generator[List[Dict[str, Any]], None, None]:
|
346
224
|
for i in range(0, len(ids), b):
|
347
225
|
batch = [
|
348
|
-
|
226
|
+
dict(
|
349
227
|
id=ids[i + j],
|
350
228
|
vector=embedding_vecs[i + j],
|
351
229
|
**doc.dict(),
|
352
230
|
)
|
353
231
|
for j, doc in enumerate(documents[i : i + b])
|
354
232
|
]
|
355
|
-
if self.config.flatten:
|
356
|
-
batch = [
|
357
|
-
flatten_pydantic_instance(instance) # type: ignore
|
358
|
-
for instance in batch
|
359
|
-
]
|
360
233
|
yield batch
|
361
234
|
|
362
|
-
tbl = self.client.open_table(self.config.collection_name)
|
363
235
|
try:
|
364
|
-
|
236
|
+
if table_exists:
|
237
|
+
tbl = self.client.open_table(coll_name)
|
238
|
+
tbl.add(make_batches())
|
239
|
+
else:
|
240
|
+
batch_gen = make_batches()
|
241
|
+
batch = next(batch_gen)
|
242
|
+
# use first batch to create table...
|
243
|
+
tbl = self.client.create_table(
|
244
|
+
coll_name,
|
245
|
+
data=batch,
|
246
|
+
mode="create",
|
247
|
+
)
|
248
|
+
# ... and add the rest
|
249
|
+
tbl.add(batch_gen)
|
365
250
|
except Exception as e:
|
366
251
|
logger.error(
|
367
252
|
f"""
|
@@ -427,7 +312,6 @@ class LanceDB(VectorStore):
|
|
427
312
|
exclude=["vector"],
|
428
313
|
)
|
429
314
|
self.config.document_class = doc_cls # type: ignore
|
430
|
-
self._setup_schemas(doc_cls) # type: ignore
|
431
315
|
else:
|
432
316
|
# collection exists and is not empty, so append to it
|
433
317
|
tbl = self.client.open_table(self.config.collection_name)
|
@@ -452,35 +336,19 @@ class LanceDB(VectorStore):
|
|
452
336
|
return self._records_to_docs(records)
|
453
337
|
|
454
338
|
def _records_to_docs(self, records: List[Dict[str, Any]]) -> List[Document]:
|
455
|
-
|
456
|
-
docs = [
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
HINT: This could happen when you're re-using an
|
467
|
-
existing LanceDB store with a different schema.
|
468
|
-
Try deleting your local lancedb storage at `{self.config.storage_path}`
|
469
|
-
re-ingesting your documents and/or replacing the collections.
|
470
|
-
"""
|
471
|
-
)
|
472
|
-
|
473
|
-
doc_cls = self.config.document_class
|
474
|
-
doc_cls_field_names = doc_cls.__fields__.keys()
|
475
|
-
return [
|
476
|
-
doc_cls(
|
477
|
-
**{
|
478
|
-
field_name: getattr(doc, field_name)
|
479
|
-
for field_name in doc_cls_field_names
|
480
|
-
}
|
339
|
+
try:
|
340
|
+
docs = [self.config.document_class(**rec) for rec in records]
|
341
|
+
except ValidationError as e:
|
342
|
+
raise ValueError(
|
343
|
+
f"""
|
344
|
+
Error validating LanceDB result: {e}
|
345
|
+
HINT: This could happen when you're re-using an
|
346
|
+
existing LanceDB store with a different schema.
|
347
|
+
Try deleting your local lancedb storage at `{self.config.storage_path}`
|
348
|
+
re-ingesting your documents and/or replacing the collections.
|
349
|
+
"""
|
481
350
|
)
|
482
|
-
|
483
|
-
]
|
351
|
+
return docs
|
484
352
|
|
485
353
|
def get_all_documents(self, where: str = "") -> List[Document]:
|
486
354
|
if self.config.collection_name is None:
|
langroid/vector_store/momento.py
CHANGED
@@ -63,6 +63,7 @@ def is_valid_uuid(uuid_to_test: str) -> bool:
|
|
63
63
|
|
64
64
|
class QdrantDBConfig(VectorStoreConfig):
|
65
65
|
cloud: bool = True
|
66
|
+
docker: bool = False
|
66
67
|
collection_name: str | None = "temp"
|
67
68
|
storage_path: str = ".qdrant/data"
|
68
69
|
embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
|
@@ -102,7 +103,19 @@ class QdrantDB(VectorStore):
|
|
102
103
|
load_dotenv()
|
103
104
|
key = os.getenv("QDRANT_API_KEY")
|
104
105
|
url = os.getenv("QDRANT_API_URL")
|
105
|
-
if config.
|
106
|
+
if config.docker:
|
107
|
+
if url is None:
|
108
|
+
logger.warning(
|
109
|
+
f"""The QDRANT_API_URL env variable must be set to use
|
110
|
+
QdrantDB in local docker mode. Please set this
|
111
|
+
value in your .env file.
|
112
|
+
Switching to local storage at {config.storage_path}
|
113
|
+
"""
|
114
|
+
)
|
115
|
+
config.cloud = False
|
116
|
+
else:
|
117
|
+
config.cloud = True
|
118
|
+
elif config.cloud and None in [key, url]:
|
106
119
|
logger.warning(
|
107
120
|
f"""QDRANT_API_KEY, QDRANT_API_URL env variable must be set to use
|
108
121
|
QdrantDB in cloud mode. Please set these values
|
@@ -111,6 +124,7 @@ class QdrantDB(VectorStore):
|
|
111
124
|
"""
|
112
125
|
)
|
113
126
|
config.cloud = False
|
127
|
+
|
114
128
|
if config.cloud:
|
115
129
|
self.client = QdrantClient(
|
116
130
|
url=url,
|
@@ -366,7 +380,11 @@ class QdrantDB(VectorStore):
|
|
366
380
|
with_payload=True,
|
367
381
|
with_vectors=False,
|
368
382
|
)
|
369
|
-
docs += [
|
383
|
+
docs += [
|
384
|
+
self.config.document_class(**record.payload) # type: ignore
|
385
|
+
for record in results
|
386
|
+
]
|
387
|
+
# ignore
|
370
388
|
if next_page_offset is None:
|
371
389
|
break
|
372
390
|
offset = next_page_offset # type: ignore
|
@@ -385,7 +403,7 @@ class QdrantDB(VectorStore):
|
|
385
403
|
# Note the records may NOT be in the order of the ids,
|
386
404
|
# so we re-order them here.
|
387
405
|
id2payload = {record.id: record.payload for record in records}
|
388
|
-
ordered_payloads = [id2payload[id] for id in _ids]
|
406
|
+
ordered_payloads = [id2payload[id] for id in _ids if id in id2payload]
|
389
407
|
docs = [Document(**payload) for payload in ordered_payloads] # type: ignore
|
390
408
|
return docs
|
391
409
|
|
@@ -437,7 +455,7 @@ class QdrantDB(VectorStore):
|
|
437
455
|
] # 2D list -> 1D list
|
438
456
|
scores = [match.score for match in search_result if match is not None]
|
439
457
|
docs = [
|
440
|
-
|
458
|
+
self.config.document_class(**(match.payload)) # type: ignore
|
441
459
|
for match in search_result
|
442
460
|
if match is not None
|
443
461
|
]
|