langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/utils/system.py
CHANGED
@@ -1,15 +1,55 @@
+import getpass
+import hashlib
+import importlib
 import inspect
 import logging
 import shutil
+import socket
+import traceback
+from typing import Any
 
 logger = logging.getLogger(__name__)
 
 DELETION_ALLOWED_PATHS = [
     ".qdrant",
     ".chroma",
+    ".lancedb",
 ]
 
 
+class LazyLoad:
+    """Lazy loading of modules or classes."""
+
+    def __init__(self, import_path: str) -> None:
+        self.import_path = import_path
+        self._target = None
+        self._is_target_loaded = False
+
+    def _load_target(self) -> None:
+        if not self._is_target_loaded:
+            try:
+                # Attempt to import as a module
+                self._target = importlib.import_module(self.import_path)  # type: ignore
+            except ImportError:
+                # If module import fails, attempt to import as a
+                # class or function from a module
+                module_path, attr_name = self.import_path.rsplit(".", 1)
+                module = importlib.import_module(module_path)
+                self._target = getattr(module, attr_name)
+            self._is_target_loaded = True
+
+    def __getattr__(self, name: str) -> Any:
+        self._load_target()
+        return getattr(self._target, name)
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        self._load_target()
+        if callable(self._target):
+            return self._target(*args, **kwargs)
+        else:
+            raise TypeError(f"{self.import_path!r} object is not callable")
+
+
 def rmdir(path: str) -> bool:
     """
     Remove a directory recursively.
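For orientation, a minimal sketch of how `LazyLoad` behaves; the `numpy` target is just an illustration, any importable module or dotted attribute path works:

```python
from langroid.utils.system import LazyLoad

np = LazyLoad("numpy")         # nothing is imported yet
print(np.sqrt(16.0))           # first attribute access triggers the import -> 4.0

sqrt = LazyLoad("numpy.sqrt")  # not a module, so it falls back to attribute import
print(sqrt(25.0))              # __call__ loads the target, then delegates -> 5.0
```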
@@ -55,3 +95,61 @@ def caller_name() -> str:
         return ""
 
     return caller_frame.f_code.co_name
+
+
+def friendly_error(e: Exception, msg: str = "An error occurred.") -> str:
+    tb = traceback.format_exc()
+    original_error_message: str = str(e)
+    full_error_message: str = (
+        f"{msg}\nOriginal error: {original_error_message}\nTraceback:\n{tb}"
+    )
+    return full_error_message
+
+
+def generate_user_id(org: str = "") -> str:
+    """
+    Generate a unique user ID based on the username and machine name.
+    Returns:
+    """
+    # Get the username
+    username = getpass.getuser()
+
+    # Get the machine's name
+    machine_name = socket.gethostname()
+
+    org_pfx = f"{org}_" if org else ""
+
+    # Create a consistent unique ID based on the username and machine name
+    unique_string = f"{org_pfx}{username}@{machine_name}"
+
+    # Generate a SHA-256 hash of the unique string
+    user_id = hashlib.sha256(unique_string.encode()).hexdigest()
+
+    return user_id
+
+
+def update_hash(hash: str | None = None, s: str = "") -> str:
+    """
+    Takes a SHA256 hash string and a new string, updates the hash with the new string,
+    and returns the updated hash string.
+
+    Args:
+        hash (str): A SHA256 hash string.
+        s (str): A new string to update the hash with.
+
+    Returns:
+        The updated hash in hexadecimal format.
+    """
+    # Create a new hash object if no hash is provided
+    if hash is None:
+        hash_obj = hashlib.sha256()
+    else:
+        # Convert the hexadecimal hash string to a byte object
+        hash_bytes = bytes.fromhex(hash)
+        hash_obj = hashlib.sha256(hash_bytes)
+
+    # Update the hash with the new string
+    hash_obj.update(s.encode("utf-8"))
+
+    # Return the updated hash in hexadecimal format
+    return hash_obj.hexdigest()
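A hedged sketch of what chaining `update_hash` actually computes: each call folds the previous digest back into a fresh SHA-256, so the result is a hash chain rather than a hash of the concatenated strings:

```python
import hashlib

from langroid.utils.system import update_hash

h1 = update_hash(None, "hello")  # sha256("hello")
h2 = update_hash(h1, "world")    # sha256(digest(h1) + "world"), NOT sha256("helloworld")
assert h2 == hashlib.sha256(bytes.fromhex(h1) + b"world").hexdigest()
```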
langroid/vector_store/__init__.py
ADDED
@@ -0,0 +1,40 @@
+from . import base
+
+from . import qdrantdb
+from . import meilisearch
+from . import lancedb
+
+from .base import VectorStoreConfig, VectorStore
+from .qdrantdb import QdrantDBConfig, QdrantDB
+from .meilisearch import MeiliSearch, MeiliSearchConfig
+from .lancedb import LanceDB, LanceDBConfig
+
+has_chromadb = False
+try:
+    from . import chromadb
+    from .chromadb import ChromaDBConfig, ChromaDB
+
+    chromadb  # silence linters
+    ChromaDB
+    ChromaDBConfig
+    has_chromadb = True
+except ImportError:
+    pass
+
+__all__ = [
+    "base",
+    "VectorStore",
+    "VectorStoreConfig",
+    "qdrantdb",
+    "meilisearch",
+    "lancedb",
+    "QdrantDBConfig",
+    "QdrantDB",
+    "MeiliSearch",
+    "MeiliSearchConfig",
+    "LanceDB",
+    "LanceDBConfig",
+]
+
+if has_chromadb:
+    __all__.extend(["chromadb", "ChromaDBConfig", "ChromaDB"])
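Because chromadb is an optional extra, downstream code can branch on the module-level flag; a minimal sketch (the collection name is arbitrary, and qdrant's local default storage is assumed available):

```python
import langroid.vector_store as vs

if vs.has_chromadb:
    store = vs.ChromaDB(vs.ChromaDBConfig(collection_name="demo"))
else:
    # qdrantdb, meilisearch and lancedb are unconditionally importable here
    store = vs.QdrantDB(vs.QdrantDBConfig(collection_name="demo"))
```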
langroid/vector_store/base.py
CHANGED
@@ -1,21 +1,26 @@
+import copy
 import logging
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
 
+import numpy as np
+import pandas as pd
 from pydantic import BaseSettings
 
-from langroid.embedding_models.base import EmbeddingModelsConfig
+from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
 from langroid.embedding_models.models import OpenAIEmbeddingsConfig
 from langroid.mytypes import Document
+from langroid.utils.algorithms.graph import components, topological_sort
 from langroid.utils.configuration import settings
 from langroid.utils.output.printing import print_long_text
+from langroid.utils.pandas_utils import stringify
 
 logger = logging.getLogger(__name__)
 
 
 class VectorStoreConfig(BaseSettings):
-    type: str = "qdrant"
-    collection_name: str | None = None
+    type: str = ""  # deprecated, keeping it for backward compatibility
+    collection_name: str | None = "temp"
     replace_collection: bool = False  # replace collection if it already exists
     storage_path: str = ".qdrant/data"
     cloud: bool = False
@@ -36,16 +41,27 @@ class VectorStore(ABC):
 
     def __init__(self, config: VectorStoreConfig):
         self.config = config
+        self.embedding_model = EmbeddingModel.create(config.embedding)
 
     @staticmethod
     def create(config: VectorStoreConfig) -> Optional["VectorStore"]:
         from langroid.vector_store.chromadb import ChromaDB, ChromaDBConfig
+        from langroid.vector_store.lancedb import LanceDB, LanceDBConfig
+        from langroid.vector_store.meilisearch import MeiliSearch, MeiliSearchConfig
+        from langroid.vector_store.momento import MomentoVI, MomentoVIConfig
         from langroid.vector_store.qdrantdb import QdrantDB, QdrantDBConfig
 
         if isinstance(config, QdrantDBConfig):
             return QdrantDB(config)
         elif isinstance(config, ChromaDBConfig):
             return ChromaDB(config)
+        elif isinstance(config, MomentoVIConfig):
+            return MomentoVI(config)
+        elif isinstance(config, LanceDBConfig):
+            return LanceDB(config)
+        elif isinstance(config, MeiliSearchConfig):
+            return MeiliSearch(config)
+
         else:
             logger.warning(
                 f"""
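For context, a hedged sketch of the dispatch above (collection name arbitrary): the concrete store is selected purely by the config's class, and an unrecognized config type falls into the warning branch, which is why the return type is `Optional["VectorStore"]`:

```python
from langroid.vector_store.base import VectorStore
from langroid.vector_store.lancedb import LanceDBConfig

# isinstance() checks on the config pick the backend; here a LanceDB instance
store = VectorStore.create(LanceDBConfig(collection_name="demo"))
```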
@@ -113,6 +129,42 @@
     def add_documents(self, documents: Sequence[Document]) -> None:
         pass
 
+    def compute_from_docs(self, docs: List[Document], calc: str) -> str:
+        """Compute a result on a set of documents,
+        using a dataframe calc string like `df.groupby('state')['income'].mean()`.
+        """
+        dicts = [doc.dict() for doc in docs]
+        df = pd.DataFrame(dicts)
+
+        try:
+            result = pd.eval(  # safer than eval but limited to single expression
+                calc,
+                engine="python",
+                parser="pandas",
+                local_dict={"df": df},
+            )
+        except Exception as e:
+            # return error message so LLM can fix the calc string if needed
+            err = f"""
+            Error encountered in pandas eval: {str(e)}
+            """
+            if isinstance(e, KeyError) and "not in index" in str(e):
+                # Pd.eval sometimes fails on a perfectly valid exprn like
+                # df.loc[..., 'column'] with a KeyError.
+                err += """
+                Maybe try a different way, e.g.
+                instead of df.loc[..., 'column'], try df.loc[...]['column']
+                """
+            return err
+        return stringify(result)
+
+    def maybe_add_ids(self, documents: Sequence[Document]) -> None:
+        """Add ids to metadata if absent, since some
+        vecdbs don't like having blank ids."""
+        for d in documents:
+            if d.metadata.id in [None, ""]:
+                d.metadata.id = d._unique_hash_id()
+
     @abstractmethod
     def similar_texts_with_scores(
         self,
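To see the `pd.eval` pattern from `compute_from_docs` in isolation, a standalone sketch with a hypothetical dataframe; the calc string is the docstring's own example:

```python
import pandas as pd

df = pd.DataFrame({"state": ["NY", "CA", "NY"], "income": [100, 50, 200]})

# Single-expression evaluation against an explicit local_dict, mirroring
# compute_from_docs: nothing outside `df` is visible to the expression.
result = pd.eval(
    "df.groupby('state')['income'].mean()",
    engine="python",
    parser="pandas",
    local_dict={"df": df},
)
print(result)  # CA -> 50.0, NY -> 150.0
```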
@@ -120,12 +172,157 @@
         k: int = 1,
         where: Optional[str] = None,
     ) -> List[Tuple[Document, float]]:
+        """
+        Find k most similar texts to the given text, in terms of vector distance metric
+        (e.g., cosine similarity).
+
+        Args:
+            text (str): The text to find similar texts for.
+            k (int, optional): Number of similar texts to retrieve. Defaults to 1.
+            where (Optional[str], optional): Where clause to filter the search.
+
+        Returns:
+            List[Tuple[Document,float]]: List of (Document, score) tuples.
+
+        """
         pass
 
+    def add_context_window(
+        self, docs_scores: List[Tuple[Document, float]], neighbors: int = 0
+    ) -> List[Tuple[Document, float]]:
+        """
+        In each doc's metadata, there may be a window_ids field indicating
+        the ids of the chunks around the current chunk.
+        These window_ids may overlap, so we
+        - coalesce each overlapping groups into a single window (maintaining ordering),
+        - create a new document for each part, preserving metadata,
+
+        We may have stored a longer set of window_ids than we need during chunking.
+        Now, we just want `neighbors` on each side of the center of the window_ids list.
+
+        Args:
+            docs_scores (List[Tuple[Document, float]]): List of pairs of documents
+                to add context windows to together with their match scores.
+            neighbors (int, optional): Number of neighbors on "each side" of match to
+                retrieve. Defaults to 0.
+                "Each side" here means before and after the match,
+                in the original text.
+
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+        """
+        # We return a larger context around each match, i.e.
+        # a window of `neighbors` on each side of the match.
+        docs = [d for d, s in docs_scores]
+        scores = [s for d, s in docs_scores]
+        if neighbors == 0:
+            return docs_scores
+        doc_chunks = [d for d in docs if d.metadata.is_chunk]
+        if len(doc_chunks) == 0:
+            return docs_scores
+        window_ids_list = []
+        id2metadata = {}
+        # id -> highest score of a doc it appears in
+        id2max_score: Dict[int | str, float] = {}
+        for i, d in enumerate(docs):
+            window_ids = d.metadata.window_ids
+            if len(window_ids) == 0:
+                window_ids = [d.id()]
+            id2metadata.update({id: d.metadata for id in window_ids})
+
+            id2max_score.update(
+                {id: max(id2max_score.get(id, 0), scores[i]) for id in window_ids}
+            )
+            n = len(window_ids)
+            chunk_idx = window_ids.index(d.id())
+            neighbor_ids = window_ids[
+                max(0, chunk_idx - neighbors) : min(n, chunk_idx + neighbors + 1)
+            ]
+            window_ids_list += [neighbor_ids]
+
+        # window_ids could be from different docs,
+        # and they may overlap, so we coalesce overlapping groups into
+        # separate windows.
+        window_ids_list = self.remove_overlaps(window_ids_list)
+        final_docs = []
+        final_scores = []
+        for w in window_ids_list:
+            metadata = copy.deepcopy(id2metadata[w[0]])
+            metadata.window_ids = w
+            document = Document(
+                content=" ".join([d.content for d in self.get_documents_by_ids(w)]),
+                metadata=metadata,
+            )
+            # make a fresh id since content is in general different
+            document.metadata.id = document.hash_id(document.content)
+            final_docs += [document]
+            final_scores += [max(id2max_score[id] for id in w)]
+        return list(zip(final_docs, final_scores))
+
+    @staticmethod
+    def remove_overlaps(windows: List[List[str]]) -> List[List[str]]:
+        """
+        Given a collection of windows, where each window is a sequence of ids,
+        identify groups of overlapping windows, and for each overlapping group,
+        order the chunk-ids using topological sort so they appear in the original
+        order in the text.
+
+        Args:
+            windows (List[int|str]): List of windows, where each window is a
+                sequence of ids.
+
+        Returns:
+            List[int|str]: List of windows, where each window is a sequence of ids,
+                and no two windows overlap.
+        """
+        ids = set(id for w in windows for id in w)
+        # id -> {win -> # pos}
+        id2win2pos: Dict[str, Dict[int, int]] = {id: {} for id in ids}
+
+        for i, w in enumerate(windows):
+            for j, id in enumerate(w):
+                id2win2pos[id][i] = j
+
+        n = len(windows)
+        # relation between windows:
+        order = np.zeros((n, n), dtype=np.int8)
+        for i, w in enumerate(windows):
+            for j, x in enumerate(windows):
+                if i == j:
+                    continue
+                if len(set(w).intersection(x)) == 0:
+                    continue
+                id = list(set(w).intersection(x))[0]  # any common id
+                if id2win2pos[id][i] > id2win2pos[id][j]:
+                    order[i, j] = -1  # win i is before win j
+                else:
+                    order[i, j] = 1  # win i is after win j
+
+        # find groups of windows that overlap, like connected components in a graph
+        groups = components(np.abs(order))
+
+        # order the chunk-ids in each group using topological sort
+        new_windows = []
+        for g in groups:
+            # find total ordering among windows in group based on order matrix
+            # (this is a topological sort)
+            _g = np.array(g)
+            order_matrix = order[_g][:, _g]
+            ordered_window_indices = topological_sort(order_matrix)
+            ordered_window_ids = [windows[i] for i in _g[ordered_window_indices]]
+            flattened = [id for w in ordered_window_ids for id in w]
+            flattened_deduped = list(dict.fromkeys(flattened))
+            # Note we are not going to split these, and instead we'll return
+            # larger windows from concatenating the connected groups.
+            # This ensures context is retained for LLM q/a
+            new_windows += [flattened_deduped]
+
+        return new_windows
+
     @abstractmethod
-    def get_all_documents(self) -> List[Document]:
+    def get_all_documents(self, where: str = "") -> List[Document]:
         """
-        Get all documents in the current collection
+        Get all documents in the current collection, possibly filtered by `where`.
         """
         pass
 
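An illustrative sketch of `remove_overlaps` above (chunk ids hypothetical): two windows sharing "c3" form one connected group, and the topological sort restores original text order before flattening and de-duplicating:

```python
from langroid.vector_store.base import VectorStore

windows = [["c3", "c4", "c5"], ["c1", "c2", "c3"]]
print(VectorStore.remove_overlaps(windows))
# expected: [['c1', 'c2', 'c3', 'c4', 'c5']]
```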
langroid/vector_store/chromadb.py
CHANGED
@@ -1,8 +1,7 @@
+import json
 import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-import chromadb
-
 from langroid.embedding_models.base import (
     EmbeddingModel,
     EmbeddingModelsConfig,
@@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 
 class ChromaDBConfig(VectorStoreConfig):
-    collection_name: str = "
+    collection_name: str = "temp"
     storage_path: str = ".chroma/data"
     embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
     host: str = "127.0.0.1"
@@ -25,8 +24,19 @@ class ChromaDBConfig(VectorStoreConfig):
 
 
 class ChromaDB(VectorStore):
-    def __init__(self, config: ChromaDBConfig):
+    def __init__(self, config: ChromaDBConfig = ChromaDBConfig()):
         super().__init__(config)
+        try:
+            import chromadb
+        except ImportError:
+            raise ImportError(
+                """
+                ChromaDB is not installed by default with Langroid.
+                If you want to use it, please install it with the `chromadb` extra, e.g.
+                pip install "langroid[chromadb]"
+                or an equivalent command.
+                """
+            )
         self.config = config
         emb_model = EmbeddingModel.create(config.embedding)
         self.embedding_fn = emb_model.embedding_fn()
@@ -99,53 +109,78 @@ class ChromaDB(VectorStore):
 
         """
         self.config.collection_name = collection_name
+        if collection_name in self.list_collections(empty=True) and replace:
+            logger.warning(f"Replacing existing collection {collection_name}")
+            self.client.delete_collection(collection_name)
         self.collection = self.client.create_collection(
             name=self.config.collection_name,
             embedding_function=self.embedding_fn,
             get_or_create=not replace,
         )
 
-    def add_documents(self, documents: Optional[List[Document]] = None) -> None:
+    def add_documents(self, documents: Sequence[Document]) -> None:
+        super().maybe_add_ids(documents)
         if documents is None:
             return
         contents: List[str] = [document.content for document in documents]
-        metadatas = [
-            d.metadata.dict() for d in documents
+        # convert metadatas to dicts so chroma can handle them
+        metadata_dicts: List[dict[str, Any]] = [
+            d.metadata.dict_bool_int() for d in documents
         ]
+        for m in metadata_dicts:
+            # chroma does not handle non-atomic types in metadata
+            m["window_ids"] = ",".join(m["window_ids"])
+
         ids = [str(d.id()) for d in documents]
         self.collection.add(
             # embedding_models=embedding_models,
             documents=contents,
-            metadatas=metadatas,
+            metadatas=metadata_dicts,
             ids=ids,
         )
 
-    def get_all_documents(self) -> List[Document]:
-        results = self.collection.get(include=["documents", "metadatas"])
+    def get_all_documents(self, where: str = "") -> List[Document]:
+        filter = json.loads(where) if where else None
+        results = self.collection.get(
+            include=["documents", "metadatas"],
+            where=filter,
+        )
         results["documents"] = [results["documents"]]
         results["metadatas"] = [results["metadatas"]]
         return self._docs_from_results(results)
 
     def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
-        results = self.collection.get(ids=ids, include=["documents", "metadatas"])
-        results["documents"] = [results["documents"]]
-        results["metadatas"] = [results["metadatas"]]
-        return self._docs_from_results(results)
+        # get them one by one since chroma mangles the order of the results
+        # when fetched from a list of ids.
+        results = [
+            self.collection.get(ids=[id], include=["documents", "metadatas"])
+            for id in ids
+        ]
+        final_results = {}
+        final_results["documents"] = [[r["documents"][0] for r in results]]
+        final_results["metadatas"] = [[r["metadatas"][0] for r in results]]
+        return self._docs_from_results(final_results)
 
     def delete_collection(self, collection_name: str) -> None:
-        self.client.delete_collection(name=collection_name)
+        try:
+            self.client.delete_collection(name=collection_name)
+        except Exception:
+            pass
 
     def similar_texts_with_scores(
         self, text: str, k: int = 1, where: Optional[str] = None
     ) -> List[Tuple[Document, float]]:
+        n = self.collection.count()
+        filter = json.loads(where) if where else None
         results = self.collection.query(
             query_texts=[text],
-            n_results=k,
-            where=where,
+            n_results=min(n, k),
+            where=filter,
             include=["documents", "distances", "metadatas"],
         )
         docs = self._docs_from_results(results)
-        scores = results["distances"][0]
+        # chroma distances are 1 - cosine.
+        scores = [1 - s for s in results["distances"][0]]
         return list(zip(docs, scores))
 
     def _docs_from_results(self, results: Dict[str, Any]) -> List[Document]:
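Since both `get_all_documents` and `similar_texts_with_scores` now run the `where` string through `json.loads`, callers pass a JSON-encoded Chroma metadata filter; a hedged sketch, where `store` is assumed to be a `ChromaDB` instance and `source` a hypothetical metadata field:

```python
docs_scores = store.similar_texts_with_scores(
    "what is the capital of France?",
    k=3,
    where='{"source": "wiki"}',  # decoded to {"source": "wiki"} for chroma
)
```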
@@ -164,22 +199,14 @@ class ChromaDB(VectorStore):
         for i, c in enumerate(contents):
             print_long_text("red", "italic red", f"MATCH-{i}", c)
         metadatas = results["metadatas"][0]
+        for m in metadatas:
+            # restore the stringified list of window_ids into the original List[str]
+            if m["window_ids"].strip() == "":
+                m["window_ids"] = []
+            else:
+                m["window_ids"] = m["window_ids"].split(",")
         docs = [
             Document(content=d, metadata=DocMetaData(**m))
             for d, m in zip(contents, metadatas)
         ]
         return docs
-
-
-# Example usage and testing
-# chroma_db = ChromaDB.from_documents(
-#     collection_name="all-my-documents",
-#     documents=["doc1000101", "doc288822"],
-#     metadatas=[{"style": "style1"}, {"style": "style2"}],
-#     ids=["uri9", "uri10"]
-# )
-# results = chroma_db.query(
-#     query_texts=["This is a query document"],
-#     n_results=2
-# )
-# print(results)
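The split in `_docs_from_results` is the inverse of the join in `add_documents`; the round trip in isolation (ids hypothetical):

```python
window_ids = ["id-1", "id-2", "id-3"]
stored = ",".join(window_ids)          # what add_documents stores in chroma metadata
restored = stored.split(",") if stored.strip() else []
assert restored == window_ids
```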