langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +39 -17
- langroid/agent/base.py-e +2216 -0
- langroid/agent/callbacks/chainlit.py +2 -1
- langroid/agent/chat_agent.py +73 -55
- langroid/agent/chat_agent.py-e +2086 -0
- langroid/agent/chat_document.py +7 -7
- langroid/agent/chat_document.py-e +513 -0
- langroid/agent/openai_assistant.py +9 -9
- langroid/agent/openai_assistant.py-e +882 -0
- langroid/agent/special/arangodb/arangodb_agent.py +10 -18
- langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
- langroid/agent/special/arangodb/tools.py +3 -3
- langroid/agent/special/doc_chat_agent.py +16 -14
- langroid/agent/special/lance_rag/critic_agent.py +2 -2
- langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
- langroid/agent/special/lance_tools.py +6 -5
- langroid/agent/special/lance_tools.py-e +61 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
- langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
- langroid/agent/special/relevance_extractor_agent.py +1 -1
- langroid/agent/special/sql/sql_chat_agent.py +11 -3
- langroid/agent/task.py +9 -87
- langroid/agent/task.py-e +2418 -0
- langroid/agent/tool_message.py +33 -17
- langroid/agent/tool_message.py-e +400 -0
- langroid/agent/tools/file_tools.py +4 -2
- langroid/agent/tools/file_tools.py-e +234 -0
- langroid/agent/tools/mcp/fastmcp_client.py +19 -6
- langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
- langroid/agent/tools/orchestration.py +22 -17
- langroid/agent/tools/orchestration.py-e +301 -0
- langroid/agent/tools/recipient_tool.py +3 -3
- langroid/agent/tools/task_tool.py +22 -16
- langroid/agent/tools/task_tool.py-e +249 -0
- langroid/agent/xml_tool_message.py +90 -35
- langroid/agent/xml_tool_message.py-e +392 -0
- langroid/cachedb/base.py +1 -1
- langroid/embedding_models/base.py +2 -2
- langroid/embedding_models/models.py +3 -7
- langroid/embedding_models/models.py-e +563 -0
- langroid/exceptions.py +4 -1
- langroid/language_models/azure_openai.py +2 -2
- langroid/language_models/azure_openai.py-e +134 -0
- langroid/language_models/base.py +6 -4
- langroid/language_models/base.py-e +812 -0
- langroid/language_models/client_cache.py +64 -0
- langroid/language_models/config.py +2 -4
- langroid/language_models/config.py-e +18 -0
- langroid/language_models/model_info.py +9 -1
- langroid/language_models/model_info.py-e +483 -0
- langroid/language_models/openai_gpt.py +119 -20
- langroid/language_models/openai_gpt.py-e +2280 -0
- langroid/language_models/provider_params.py +3 -22
- langroid/language_models/provider_params.py-e +153 -0
- langroid/mytypes.py +11 -4
- langroid/mytypes.py-e +132 -0
- langroid/parsing/code_parser.py +1 -1
- langroid/parsing/file_attachment.py +1 -1
- langroid/parsing/file_attachment.py-e +246 -0
- langroid/parsing/md_parser.py +14 -4
- langroid/parsing/md_parser.py-e +574 -0
- langroid/parsing/parser.py +22 -7
- langroid/parsing/parser.py-e +410 -0
- langroid/parsing/repo_loader.py +3 -1
- langroid/parsing/repo_loader.py-e +812 -0
- langroid/parsing/search.py +1 -1
- langroid/parsing/url_loader.py +17 -51
- langroid/parsing/url_loader.py-e +683 -0
- langroid/parsing/urls.py +5 -4
- langroid/parsing/urls.py-e +279 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/pydantic_v1/__init__.py +45 -6
- langroid/pydantic_v1/__init__.py-e +36 -0
- langroid/pydantic_v1/main.py +11 -4
- langroid/pydantic_v1/main.py-e +11 -0
- langroid/utils/configuration.py +13 -11
- langroid/utils/configuration.py-e +141 -0
- langroid/utils/constants.py +1 -1
- langroid/utils/constants.py-e +32 -0
- langroid/utils/globals.py +21 -5
- langroid/utils/globals.py-e +49 -0
- langroid/utils/html_logger.py +2 -1
- langroid/utils/html_logger.py-e +825 -0
- langroid/utils/object_registry.py +1 -1
- langroid/utils/object_registry.py-e +66 -0
- langroid/utils/pydantic_utils.py +55 -28
- langroid/utils/pydantic_utils.py-e +602 -0
- langroid/utils/types.py +2 -2
- langroid/utils/types.py-e +113 -0
- langroid/vector_store/base.py +3 -3
- langroid/vector_store/lancedb.py +5 -5
- langroid/vector_store/lancedb.py-e +404 -0
- langroid/vector_store/meilisearch.py +2 -2
- langroid/vector_store/pineconedb.py +4 -4
- langroid/vector_store/pineconedb.py-e +427 -0
- langroid/vector_store/postgres.py +1 -1
- langroid/vector_store/qdrantdb.py +3 -3
- langroid/vector_store/weaviatedb.py +1 -1
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
- langroid-0.59.0b1.dist-info/RECORD +181 -0
- langroid/agent/special/doc_chat_task.py +0 -0
- langroid/mcp/__init__.py +0 -1
- langroid/mcp/server/__init__.py +0 -1
- langroid-0.58.2.dist-info/RECORD +0 -145
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
- {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,113 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
from inspect import signature
|
4
|
+
from typing import Any, Optional, Type, TypeVar, Union, get_args, get_origin
|
5
|
+
|
6
|
+
from pydantic import BaseModel
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
# Scalar types that `from_string` accepts as `output_type`.
PrimitiveType = Union[int, float, bool, str]
|
10
|
+
# Generic type variable used in the `type_hint` annotation of `is_instance_of`.
T = TypeVar("T")
|
11
|
+
|
12
|
+
|
13
|
+
def is_instance_of(obj: Any, type_hint: Type[T] | Any) -> bool:
|
14
|
+
"""
|
15
|
+
Check if an object is an instance of a type hint, e.g.
|
16
|
+
to check whether x is of type `List[ToolMessage]` or type `int`
|
17
|
+
"""
|
18
|
+
if type_hint == Any:
|
19
|
+
return True
|
20
|
+
|
21
|
+
if type_hint is type(obj):
|
22
|
+
return True
|
23
|
+
|
24
|
+
origin = get_origin(type_hint)
|
25
|
+
args = get_args(type_hint)
|
26
|
+
|
27
|
+
if origin is Union:
|
28
|
+
return any(is_instance_of(obj, arg) for arg in args)
|
29
|
+
|
30
|
+
if origin: # e.g. List, Dict, Tuple, Set
|
31
|
+
if isinstance(obj, origin):
|
32
|
+
# check if all items in obj are of the required types
|
33
|
+
if args:
|
34
|
+
if isinstance(obj, (list, tuple, set)):
|
35
|
+
return all(is_instance_of(item, args[0]) for item in obj)
|
36
|
+
if isinstance(obj, dict):
|
37
|
+
return all(
|
38
|
+
is_instance_of(k, args[0]) and is_instance_of(v, args[1])
|
39
|
+
for k, v in obj.items()
|
40
|
+
)
|
41
|
+
return True
|
42
|
+
else:
|
43
|
+
return False
|
44
|
+
|
45
|
+
return isinstance(obj, type_hint)
|
46
|
+
|
47
|
+
|
48
|
+
def to_string(msg: Any) -> str:
    """
    Best-effort conversion of arbitrary msg to str.

    Conversion order: `None` becomes the empty string, a `str` is returned
    unchanged, a pydantic `BaseModel` is serialized with `model_dump_json()`,
    anything else is tried with `json.dumps()` and then `str()`.

    Args:
        msg: the value to convert.

    Returns:
        str: the string form of `msg`, or the empty string if every
            conversion attempt fails (the failure is logged).
    """
    if msg is None:
        return ""
    if isinstance(msg, str):
        return msg
    if isinstance(msg, BaseModel):
        return msg.model_dump_json()
    # last resort: use json.dumps() or str() to make it a str
    try:
        return json.dumps(msg)
    except Exception:
        # Not JSON-serializable; deliberately fall back to str() below.
        try:
            return str(msg)
        except Exception as e:
            logger.error("Error converting msg to str: %s", e, exc_info=True)
            return ""
|
73
|
+
|
74
|
+
|
75
|
+
def from_string(
    s: str,
    output_type: Type[PrimitiveType],
) -> Optional[PrimitiveType]:
    """
    Parse a string into a value of the requested primitive type.

    Args:
        s: the text to parse.
        output_type: one of `int`, `float`, `bool` or `str`.

    Returns:
        The parsed value. `None` is returned when `s` is not a valid
        `int`/`float` literal, or when `output_type` is not one of the four
        supported types. For `bool`, the result is True exactly when `s`
        (case-insensitively) is "true", "yes" or "1".
    """
    if output_type is str:
        return s

    if output_type is bool:
        return s.lower() in ("true", "yes", "1")

    if output_type is int or output_type is float:
        # int and float share the same "parse or give up" handling
        try:
            return output_type(s)
        except ValueError:
            return None

    return None
|
95
|
+
|
96
|
+
|
97
|
+
def is_callable(obj: Any, k: int = 1) -> bool:
    """Check if object is callable and declares exactly k parameters.

    All parameters in the signature are counted, including those with
    defaults and `*args` / `**kwargs`.

    Args:
        obj: Object to check
        k: Required number of parameters in the callable's signature

    Returns:
        bool: True if object is callable with exactly k declared parameters,
            False otherwise (also when no signature can be determined)
    """
    if not callable(obj):
        return False
    try:
        sig = signature(obj)
    except (ValueError, TypeError):
        # inspect.signature raises ValueError when no signature is available
        # and TypeError when the object type is unsupported.
        return False
    return len(sig.parameters) == k
|
langroid/vector_store/base.py
CHANGED
@@ -5,11 +5,11 @@ from typing import Dict, List, Optional, Sequence, Tuple, Type
|
|
5
5
|
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
|
+
from pydantic_settings import BaseSettings
|
8
9
|
|
9
10
|
from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
|
10
11
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
11
12
|
from langroid.mytypes import DocMetaData, Document, EmbeddingFunction
|
12
|
-
from langroid.pydantic_v1 import BaseSettings
|
13
13
|
from langroid.utils.algorithms.graph import components, topological_sort
|
14
14
|
from langroid.utils.configuration import settings
|
15
15
|
from langroid.utils.object_registry import ObjectRegistry
|
@@ -82,7 +82,7 @@ class VectorStore(ABC):
|
|
82
82
|
else:
|
83
83
|
logger.warning(
|
84
84
|
f"""
|
85
|
-
Unknown vector store config: {config.
|
85
|
+
Unknown vector store config: {config.__class__.__name__},
|
86
86
|
so skipping vector store creation!
|
87
87
|
If you intended to use a vector-store, please set a specific
|
88
88
|
vector-store in your script, typically in the `vecdb` field of a
|
@@ -160,7 +160,7 @@ class VectorStore(ABC):
|
|
160
160
|
If full_eval is True, sanitization is bypassed - use only with trusted input!
|
161
161
|
"""
|
162
162
|
# convert each doc to a dict, using dotted paths for nested fields
|
163
|
-
dicts = [flatten_dict(doc.
|
163
|
+
dicts = [flatten_dict(doc.model_dump(by_alias=True)) for doc in docs]
|
164
164
|
df = pd.DataFrame(dicts)
|
165
165
|
|
166
166
|
try:
|
langroid/vector_store/lancedb.py
CHANGED
@@ -15,8 +15,7 @@ from typing import (
|
|
15
15
|
|
16
16
|
import pandas as pd
|
17
17
|
from dotenv import load_dotenv
|
18
|
-
|
19
|
-
from langroid.pydantic_v1 import BaseModel, ValidationError, create_model
|
18
|
+
from pydantic import BaseModel, ValidationError, create_model
|
20
19
|
|
21
20
|
if TYPE_CHECKING:
|
22
21
|
from lancedb.query import LanceVectorQueryBuilder
|
@@ -175,11 +174,12 @@ class LanceDB(VectorStore):
|
|
175
174
|
fields = {"id": (str, ...), "vector": (Vector(n), ...)}
|
176
175
|
|
177
176
|
sorted_fields = dict(
|
178
|
-
sorted(doc_cls.
|
177
|
+
sorted(doc_cls.model_fields.items(), key=lambda item: item[0])
|
179
178
|
)
|
180
179
|
# Add both statically and dynamically defined fields from doc_cls
|
181
180
|
for field_name, field in sorted_fields.items():
|
182
|
-
|
181
|
+
field_type = field.annotation if hasattr(field, "annotation") else field
|
182
|
+
fields[field_name] = (field_type, field.default)
|
183
183
|
|
184
184
|
# Create the new model with dynamic fields
|
185
185
|
NewModel = create_model(
|
@@ -227,7 +227,7 @@ class LanceDB(VectorStore):
|
|
227
227
|
dict(
|
228
228
|
id=ids[i + j],
|
229
229
|
vector=embedding_vecs[i + j],
|
230
|
-
**doc.
|
230
|
+
**doc.model_dump(),
|
231
231
|
)
|
232
232
|
for j, doc in enumerate(documents[i : i + b])
|
233
233
|
]
|
@@ -0,0 +1,404 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import (
|
5
|
+
TYPE_CHECKING,
|
6
|
+
Any,
|
7
|
+
Dict,
|
8
|
+
Generator,
|
9
|
+
List,
|
10
|
+
Optional,
|
11
|
+
Sequence,
|
12
|
+
Tuple,
|
13
|
+
Type,
|
14
|
+
)
|
15
|
+
|
16
|
+
import pandas as pd
|
17
|
+
from dotenv import load_dotenv
|
18
|
+
|
19
|
+
from pydantic import BaseModel, ValidationError, create_model
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
from lancedb.query import LanceVectorQueryBuilder
|
23
|
+
|
24
|
+
from langroid.embedding_models.base import (
|
25
|
+
EmbeddingModelsConfig,
|
26
|
+
)
|
27
|
+
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
28
|
+
from langroid.exceptions import LangroidImportError
|
29
|
+
from langroid.mytypes import Document
|
30
|
+
from langroid.utils.configuration import settings
|
31
|
+
from langroid.utils.pydantic_utils import (
|
32
|
+
dataframe_to_document_model,
|
33
|
+
dataframe_to_documents,
|
34
|
+
)
|
35
|
+
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
36
|
+
|
37
|
+
try:
|
38
|
+
import lancedb
|
39
|
+
from lancedb.pydantic import LanceModel, Vector
|
40
|
+
|
41
|
+
has_lancedb = True
|
42
|
+
except ImportError:
|
43
|
+
has_lancedb = False
|
44
|
+
|
45
|
+
logger = logging.getLogger(__name__)
|
46
|
+
|
47
|
+
|
48
|
+
class LanceDBConfig(VectorStoreConfig):
    """Configuration for the `LanceDB` vector store."""

    # When True, `LanceDB.__init__` logs that LanceDB Cloud is unavailable.
    cloud: bool = False
    # Name of the table ("collection") to read from / write to.
    collection_name: Optional[str] = "temp"
    # URI passed to `lancedb.connect()`.
    storage_path: str = ".lancedb/data"
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
    # Metric name passed to the LanceDB query's `.metric()`.
    distance: str = "cosine"
|
54
|
+
|
55
|
+
|
56
|
+
class LanceDB(VectorStore):
    """
    Vector store backed by a local LanceDB database.

    Each collection is a LanceDB table. Rows written by `add_documents` hold
    an `id`, an embedding `vector`, and the dumped fields of the document;
    rows written by `add_dataframe` are the dataframe's own columns plus
    `vector` (and `id` if it was missing).
    """

    def __init__(self, config: Optional[LanceDBConfig] = None):
        """
        Connect to the LanceDB database at `config.storage_path`.

        Args:
            config: store settings. When omitted, a fresh `LanceDBConfig()`
                is created per instance, so instances never share (and
                mutate) one default config object.

        Raises:
            LangroidImportError: if the `lancedb` package is not installed.
        """
        if config is None:
            config = LanceDBConfig()
        super().__init__(config)
        if not has_lancedb:
            raise LangroidImportError("lancedb", "lancedb")

        self.config: LanceDBConfig = config
        self.host = config.host
        self.port = config.port
        self.is_from_dataframe = False  # were docs ingested from a dataframe?
        self.df_metadata_columns: List[str] = []  # metadata columns from dataframe

        load_dotenv()
        if self.config.cloud:
            logger.warning(
                "LanceDB Cloud is not available yet. Switching to local storage."
            )
            config.cloud = False

        # Always open a local connection: the cloud branch above falls back to
        # local storage, so `self.client` must be set in that case as well.
        try:
            self.client = lancedb.connect(
                uri=config.storage_path,
            )
        except Exception as e:
            new_storage_path = config.storage_path + ".new"
            logger.warning(
                f"""
                Error connecting to local LanceDB at {config.storage_path}:
                {e}
                Switching to {new_storage_path}
                """
            )
            self.client = lancedb.connect(
                uri=new_storage_path,
            )

    def clear_empty_collections(self) -> int:
        """
        Drop every collection that contains no rows.

        Returns:
            int: number of collections dropped.
        """
        n_deletes = 0
        # empty=True is required: the default listing filters out exactly the
        # empty collections this method is supposed to remove.
        for name in self.list_collections(empty=True):
            nr = self.client.open_table(name).head(1).shape[0]
            if nr == 0:
                n_deletes += 1
                self.client.drop_table(name)
        return n_deletes

    def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
        """
        Clear all collections with the given prefix.

        Args:
            really: safety switch; nothing is deleted unless this is True.
            prefix: only collections whose name starts with this are dropped.

        Returns:
            int: number of collections dropped (empty and non-empty).
        """
        if not really:
            logger.warning("Not deleting all collections, set really=True to confirm")
            return 0
        coll_names = [
            c for c in self.list_collections(empty=True) if c.startswith(prefix)
        ]
        if len(coll_names) == 0:
            logger.warning(f"No collections found with prefix {prefix}")
            return 0
        n_empty_deletes = 0
        n_non_empty_deletes = 0
        for name in coll_names:
            nr = self.client.open_table(name).head(1).shape[0]
            n_empty_deletes += nr == 0
            n_non_empty_deletes += nr > 0
            self.client.drop_table(name)
        logger.warning(
            f"""
            Deleted {n_empty_deletes} empty collections and
            {n_non_empty_deletes} non-empty collections.
            """
        )
        return n_empty_deletes + n_non_empty_deletes

    def list_collections(self, empty: bool = False) -> List[str]:
        """
        Returns:
            List of collection names that have at least one vector
            (or all collection names when `empty` is True).

        Args:
            empty (bool, optional): Whether to include empty collections.
        """
        colls = self.client.table_names(limit=None)
        if len(colls) == 0:
            return []
        if empty:  # include empty tbls
            return colls  # type: ignore
        counts = [self.client.open_table(coll).head(1).shape[0] for coll in colls]
        return [coll for coll, count in zip(colls, counts) if count > 0]

    def _create_lance_schema(self, doc_cls: Type[Document]) -> Type[BaseModel]:
        """
        NOTE: NOT USED, but leaving it here as it may be useful.

        Create a subclass of LanceModel with fields:
         - id (str)
         - vector, whose dimension equals `self.embedding_dim`
         - all fields declared on doc_cls (in name-sorted order)

        Args:
            doc_cls (Type[Document]): A Pydantic model which should be a subclass of
                Document, whose fields are copied into the new model.

        Returns:
            Type[BaseModel]: A new Pydantic model subclassing from LanceModel.

        Raises:
            ValueError: If `doc_cls` is not a subclass of Document.
            LangroidImportError: If the `lancedb` package is not installed.
        """
        if not issubclass(doc_cls, Document):
            raise ValueError("DocClass must be a subclass of Document")

        if not has_lancedb:
            raise LangroidImportError("lancedb", "lancedb")

        n = self.embedding_dim

        # Prepare fields for the new model
        fields = {"id": (str, ...), "vector": (Vector(n), ...)}

        # pydantic v2 API: `model_fields` maps names to FieldInfo objects,
        # which expose the declared type as `.annotation`.
        sorted_fields = dict(
            sorted(doc_cls.model_fields.items(), key=lambda item: item[0])
        )
        # Add both statically and dynamically defined fields from doc_cls
        for field_name, field in sorted_fields.items():
            fields[field_name] = (field.annotation, field.default)

        # Create the new model with dynamic fields
        NewModel = create_model(
            "NewModel", __base__=LanceModel, **fields
        )  # type: ignore
        return NewModel  # type: ignore

    def create_collection(self, collection_name: str, replace: bool = False) -> None:
        """
        Select `collection_name` as the current collection.

        No table is created here; `add_documents` / `add_dataframe` create it
        on first ingest.

        Args:
            collection_name: name of the collection to use from now on.
            replace: if True, any existing collection of that name is dropped.
        """
        self.config.replace_collection = replace
        self.config.collection_name = collection_name
        if replace:
            self.delete_collection(collection_name)

    def add_documents(self, documents: Sequence[Document]) -> None:
        """
        Embed `documents` and store them in the current collection.

        Appends to the collection if it already holds rows, unless
        `config.replace_collection` is set, in which case it is dropped and
        re-created. Errors raised by LanceDB while writing are logged, not
        re-raised.

        Args:
            documents: documents to ingest; an empty sequence is a no-op.

        Raises:
            ValueError: if no collection name is configured.
        """
        super().maybe_add_ids(documents)
        if len(documents) == 0:
            return
        coll_name = self.config.collection_name
        if coll_name is None:
            # checked before embedding so no embedding work is wasted
            raise ValueError("No collection name set, cannot ingest docs")
        embedding_vecs = self.embedding_fn([doc.content for doc in documents])
        colls = self.list_collections(empty=True)
        # self._maybe_set_doc_class_schema(documents[0])
        table_exists = False
        if (
            coll_name in colls
            and self.client.open_table(coll_name).head(1).shape[0] > 0
        ):
            # collection exists and is not empty:
            # if replace_collection is True, we'll overwrite the existing collection,
            # else we'll append to it.
            if self.config.replace_collection:
                self.client.drop_table(coll_name)
            else:
                table_exists = True

        ids = [str(d.id()) for d in documents]
        # don't insert all at once, batch in chunks of b,
        # else we get an API error
        b = self.config.batch_size

        def make_batches() -> Generator[List[Dict[str, Any]], None, None]:
            for i in range(0, len(ids), b):
                batch = [
                    dict(
                        id=ids[i + j],
                        vector=embedding_vecs[i + j],
                        **doc.model_dump(),
                    )
                    for j, doc in enumerate(documents[i : i + b])
                ]
                yield batch

        try:
            if table_exists:
                tbl = self.client.open_table(coll_name)
                tbl.add(make_batches())
            else:
                batch_gen = make_batches()
                batch = next(batch_gen)
                # use first batch to create table...
                tbl = self.client.create_table(
                    coll_name,
                    data=batch,
                    mode="create",
                )
                # ... and add the rest
                tbl.add(batch_gen)
        except Exception as e:
            logger.error(
                f"""
                Error adding documents to LanceDB: {e}
                POSSIBLE REMEDY: Delete the LanceDB storage directory
                {self.config.storage_path} and try again.
                """
            )

    def add_dataframe(
        self,
        df: pd.DataFrame,
        content: str = "content",
        metadata: List[str] = [],
    ) -> None:
        """
        Add a dataframe to the collection.

        Note that a `vector` column (and an `id` column, if absent) is
        written into the dataframe.

        Args:
            df (pd.DataFrame): A dataframe
            content (str): The name of the column in the dataframe that contains the
                text content to be embedded using the embedding model.
            metadata (List[str]): A list of column names in the dataframe that contain
                metadata to be stored in the database. Defaults to [].
                The list is copied, never mutated.
        """
        self.is_from_dataframe = True
        # `metadata` is a plain list, so copy it with list(); it has no
        # pydantic `model_copy()` method.
        actual_metadata = list(metadata)
        self.df_metadata_columns = actual_metadata  # could be updated below
        # get content column
        content_values = df[content].values.tolist()
        embedding_vecs = self.embedding_fn(content_values)

        # add vector column
        df["vector"] = embedding_vecs
        if content != "content":
            # rename content column to "content", leave existing column intact
            df = df.rename(columns={content: "content"}, inplace=False)

        if "id" not in df.columns:
            docs = dataframe_to_documents(df, content="content", metadata=metadata)
            ids = [str(d.id()) for d in docs]
            df["id"] = ids

        if "id" not in actual_metadata:
            # in-place extend, so self.df_metadata_columns sees the change too
            actual_metadata += ["id"]

        colls = self.list_collections(empty=True)
        coll_name = self.config.collection_name
        if (
            coll_name not in colls
            or self.client.open_table(coll_name).head(1).shape[0] == 0
        ):
            # collection either doesn't exist or is empty, so replace it
            # and set new schema from df
            self.client.create_table(
                self.config.collection_name,
                data=df,
                mode="overwrite",
            )
            # NOTE(review): `df` has been renamed to use "content" above, yet the
            # caller's original column name is passed here -- confirm this is
            # what dataframe_to_document_model expects.
            doc_cls = dataframe_to_document_model(
                df,
                content=content,
                metadata=actual_metadata,
                exclude=["vector"],
            )
            self.config.document_class = doc_cls  # type: ignore
        else:
            # collection exists and is not empty, so append to it
            tbl = self.client.open_table(self.config.collection_name)
            tbl.add(df)

    def delete_collection(self, collection_name: str) -> None:
        """Drop the named collection; a missing collection is ignored."""
        self.client.drop_table(collection_name, ignore_missing=True)

    def _lance_result_to_docs(
        self, result: "LanceVectorQueryBuilder"
    ) -> List[Document]:
        """
        Materialize a LanceDB query into documents.

        Dataframe-ingested data is read back via pandas and
        `dataframe_to_documents`; otherwise rows are read as dicts and
        validated against `config.document_class`.
        """
        if self.is_from_dataframe:
            df = result.to_pandas()
            return dataframe_to_documents(
                df,
                content="content",
                metadata=self.df_metadata_columns,
                doc_cls=self.config.document_class,
            )
        else:
            records = result.to_arrow().to_pylist()
            return self._records_to_docs(records)

    def _records_to_docs(self, records: List[Dict[str, Any]]) -> List[Document]:
        """
        Build `config.document_class` instances from row dicts.

        Raises:
            ValueError: if a row fails validation (chained to the original
                pydantic `ValidationError`).
        """
        try:
            docs = [self.config.document_class(**rec) for rec in records]
        except ValidationError as e:
            raise ValueError(
                f"""
                Error validating LanceDB result: {e}
                HINT: This could happen when you're re-using an
                existing LanceDB store with a different schema.
                Try deleting your local lancedb storage at `{self.config.storage_path}`
                re-ingesting your documents and/or replacing the collections.
                """
            ) from e
        return docs

    def get_all_documents(self, where: str = "") -> List[Document]:
        """
        Return all documents of the current collection.

        Args:
            where: optional filter expression passed to the query's
                `.where()`; an empty string means no filter.

        Returns:
            The matching documents; an empty list if the collection does not exist.

        Raises:
            ValueError: if no collection name is configured.
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot retrieve docs")
        if self.config.collection_name not in self.list_collections(empty=True):
            return []
        tbl = self.client.open_table(self.config.collection_name)
        pre_result = tbl.search(None).where(where or None).limit(None)
        return self._lance_result_to_docs(pre_result)

    def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
        """
        Fetch documents by id, in the order of `ids`.

        Ids with no match are skipped; for each id at most the first match
        is returned.

        Raises:
            ValueError: if no collection name is configured.
        """
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot retrieve docs")
        tbl = self.client.open_table(self.config.collection_name)
        docs = []
        for doc_id in ids:
            # Double embedded single quotes so an id cannot terminate the
            # SQL string literal of the filter expression.
            quoted_id = str(doc_id).replace("'", "''")
            results = self._lance_result_to_docs(
                tbl.search().where(f"id == '{quoted_id}'")
            )
            if len(results) > 0:
                docs.append(results[0])
        return docs

    def similar_texts_with_scores(
        self,
        text: str,
        k: int = 1,
        where: Optional[str] = None,
    ) -> List[Tuple[Document, float]]:
        """
        Find the `k` stored documents nearest to `text`.

        Args:
            text: query text; it is embedded with the store's embedding function.
            k: maximum number of matches to return.
            where: optional filter expression, applied as a pre-filter.

        Returns:
            List of (document, score) pairs where score is `1 - _distance`;
            an empty list if nothing matched.
        """
        embedding = self.embedding_fn([text])[0]
        tbl = self.client.open_table(self.config.collection_name)
        result = (
            tbl.search(embedding)
            .metric(self.config.distance)
            .where(where, prefilter=True)
            .limit(k)
        )
        docs = self._lance_result_to_docs(result)
        # note _distance is 1 - cosine
        if self.is_from_dataframe:
            scores = [
                1 - rec["_distance"] for rec in result.to_pandas().to_dict("records")
            ]
        else:
            scores = [1 - rec["_distance"] for rec in result.to_arrow().to_pylist()]
        if len(docs) == 0:
            logger.warning(f"No matches found for {text}")
            return []
        if settings.debug:
            logger.info(f"Found {len(docs)} matches, max score: {max(scores)}")
        doc_score_pairs = list(zip(docs, scores))
        self.show_if_debug(doc_score_pairs)
        return doc_score_pairs
|
@@ -33,7 +33,7 @@ class MeiliSearchConfig(VectorStoreConfig):
|
|
33
33
|
cloud: bool = False
|
34
34
|
collection_name: str | None = None
|
35
35
|
primary_key: str = "id"
|
36
|
-
port = 7700
|
36
|
+
port: int = 7700
|
37
37
|
|
38
38
|
|
39
39
|
class MeiliSearch(VectorStore):
|
@@ -193,7 +193,7 @@ class MeiliSearch(VectorStore):
|
|
193
193
|
dict(
|
194
194
|
id=d.id(),
|
195
195
|
content=d.content,
|
196
|
-
metadata=d.metadata.
|
196
|
+
metadata=d.metadata.model_dump(),
|
197
197
|
)
|
198
198
|
for d in documents
|
199
199
|
]
|
@@ -17,11 +17,11 @@ from typing import (
|
|
17
17
|
|
18
18
|
from dotenv import load_dotenv
|
19
19
|
|
20
|
+
# import dataclass
|
21
|
+
from pydantic import BaseModel
|
22
|
+
|
20
23
|
from langroid import LangroidImportError
|
21
24
|
from langroid.mytypes import Document
|
22
|
-
|
23
|
-
# import dataclass
|
24
|
-
from langroid.pydantic_v1 import BaseModel
|
25
25
|
from langroid.utils.configuration import settings
|
26
26
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
27
27
|
|
@@ -246,7 +246,7 @@ class PineconeDB(VectorStore):
|
|
246
246
|
return
|
247
247
|
|
248
248
|
super().maybe_add_ids(documents)
|
249
|
-
document_dicts = [doc.
|
249
|
+
document_dicts = [doc.model_dump() for doc in documents]
|
250
250
|
document_ids = [doc.id() for doc in documents]
|
251
251
|
embedding_vectors = self.embedding_fn([doc.content for doc in documents])
|
252
252
|
vectors = [
|