langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/utils/system.py CHANGED
@@ -1,15 +1,55 @@
+import getpass
+import hashlib
+import importlib
 import inspect
 import logging
 import shutil
+import socket
+import traceback
+from typing import Any
 
 logger = logging.getLogger(__name__)
 
 DELETION_ALLOWED_PATHS = [
     ".qdrant",
     ".chroma",
+    ".lancedb",
 ]
 
 
+class LazyLoad:
+    """Lazy loading of modules or classes."""
+
+    def __init__(self, import_path: str) -> None:
+        self.import_path = import_path
+        self._target = None
+        self._is_target_loaded = False
+
+    def _load_target(self) -> None:
+        if not self._is_target_loaded:
+            try:
+                # Attempt to import as a module
+                self._target = importlib.import_module(self.import_path)  # type: ignore
+            except ImportError:
+                # If module import fails, attempt to import as a
+                # class or function from a module
+                module_path, attr_name = self.import_path.rsplit(".", 1)
+                module = importlib.import_module(module_path)
+                self._target = getattr(module, attr_name)
+            self._is_target_loaded = True
+
+    def __getattr__(self, name: str) -> Any:
+        self._load_target()
+        return getattr(self._target, name)
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        self._load_target()
+        if callable(self._target):
+            return self._target(*args, **kwargs)
+        else:
+            raise TypeError(f"{self.import_path!r} object is not callable")
+
+
 def rmdir(path: str) -> bool:
     """
     Remove a directory recursively.
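For orientation, a brief usage sketch of the new LazyLoad helper (illustrative only, not part of the diff; numpy and pathlib.Path are stand-in targets). The import is deferred until the first attribute access or call:

from langroid.utils.system import LazyLoad

np = LazyLoad("numpy")           # nothing is imported yet
arr = np.array([1, 2, 3])        # first attribute access imports numpy

path_cls = LazyLoad("pathlib.Path")  # module import fails, so it falls back to
p = path_cls("/tmp")                 # attribute-of-module lookup, then __call__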
@@ -55,3 +95,61 @@ def caller_name() -> str:
         return ""
 
     return caller_frame.f_code.co_name
+
+
+def friendly_error(e: Exception, msg: str = "An error occurred.") -> str:
+    tb = traceback.format_exc()
+    original_error_message: str = str(e)
+    full_error_message: str = (
+        f"{msg}\nOriginal error: {original_error_message}\nTraceback:\n{tb}"
+    )
+    return full_error_message
+
+
+def generate_user_id(org: str = "") -> str:
+    """
+    Generate a unique user ID based on the username and machine name.
+    Returns: str: hex SHA-256 digest of "<org>_<username>@<machine>".
+    """
+    # Get the username
+    username = getpass.getuser()
+
+    # Get the machine's name
+    machine_name = socket.gethostname()
+
+    org_pfx = f"{org}_" if org else ""
+
+    # Create a consistent unique ID based on the username and machine name
+    unique_string = f"{org_pfx}{username}@{machine_name}"
+
+    # Generate a SHA-256 hash of the unique string
+    user_id = hashlib.sha256(unique_string.encode()).hexdigest()
+
+    return user_id
+
+
+def update_hash(hash: str | None = None, s: str = "") -> str:
+    """
+    Takes a SHA-256 hash string and a new string, updates the hash with the
+    new string, and returns the updated hash string.
+
+    Args:
+        hash (str): A SHA-256 hash string.
+        s (str): A new string to update the hash with.
+
+    Returns:
+        The updated hash in hexadecimal format.
+    """
+    # Create a new hash object if no hash is provided
+    if hash is None:
+        hash_obj = hashlib.sha256()
+    else:
+        # Convert the hexadecimal hash string to a byte object
+        hash_bytes = bytes.fromhex(hash)
+        hash_obj = hashlib.sha256(hash_bytes)
+
+    # Update the hash with the new string
+    hash_obj.update(s.encode("utf-8"))
+
+    # Return the updated hash in hexadecimal format
+    return hash_obj.hexdigest()
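A quick sketch of how these helpers compose (illustrative; not from the diff). Note that update_hash chains state by re-hashing the previous digest bytes together with the new string:

from langroid.utils.system import generate_user_id, update_hash

uid = generate_user_id(org="acme")       # sha256("acme_<user>@<host>") as hex
h1 = update_hash(s="event-1")            # fresh sha256 of "event-1"
h2 = update_hash(h1, "event-2")          # sha256(bytes.fromhex(h1) + b"event-2")
assert h2 == update_hash(h1, "event-2")  # deterministic for the same inputs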
langroid/vector_store/__init__.py ADDED
@@ -0,0 +1,40 @@
+from . import base
+
+from . import qdrantdb
+from . import meilisearch
+from . import lancedb
+
+from .base import VectorStoreConfig, VectorStore
+from .qdrantdb import QdrantDBConfig, QdrantDB
+from .meilisearch import MeiliSearch, MeiliSearchConfig
+from .lancedb import LanceDB, LanceDBConfig
+
+has_chromadb = False
+try:
+    from . import chromadb
+    from .chromadb import ChromaDBConfig, ChromaDB
+
+    chromadb  # silence linters
+    ChromaDB
+    ChromaDBConfig
+    has_chromadb = True
+except ImportError:
+    pass
+
+__all__ = [
+    "base",
+    "VectorStore",
+    "VectorStoreConfig",
+    "qdrantdb",
+    "meilisearch",
+    "lancedb",
+    "QdrantDBConfig",
+    "QdrantDB",
+    "MeiliSearch",
+    "MeiliSearchConfig",
+    "LanceDB",
+    "LanceDBConfig",
+]
+
+if has_chromadb:
+    __all__.extend(["chromadb", "ChromaDBConfig", "ChromaDB"])
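Illustratively, downstream code can rely on the always-exported stores and feature-gate chroma on the optional extra (a sketch, assuming the lancedb dependency is installed and default config values suffice):

import langroid.vector_store as vs

store = vs.LanceDB(vs.LanceDBConfig(collection_name="demo"))

if "ChromaDB" in vs.__all__:  # present only when the chromadb extra is installed
    chroma_store = vs.ChromaDB(vs.ChromaDBConfig())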
langroid/vector_store/base.py CHANGED
@@ -1,21 +1,26 @@
+import copy
 import logging
 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
 
+import numpy as np
+import pandas as pd
 from pydantic import BaseSettings
 
-from langroid.embedding_models.base import EmbeddingModelsConfig
+from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
 from langroid.embedding_models.models import OpenAIEmbeddingsConfig
 from langroid.mytypes import Document
+from langroid.utils.algorithms.graph import components, topological_sort
 from langroid.utils.configuration import settings
 from langroid.utils.output.printing import print_long_text
+from langroid.utils.pandas_utils import stringify
 
 logger = logging.getLogger(__name__)
 
 
 class VectorStoreConfig(BaseSettings):
-    type: str = "qdrant"  # deprecated, keeping it for backward compatibility
-    collection_name: str | None = None
+    type: str = ""  # deprecated, keeping it for backward compatibility
+    collection_name: str | None = "temp"
     replace_collection: bool = False  # replace collection if it already exists
     storage_path: str = ".qdrant/data"
     cloud: bool = False
@@ -36,16 +41,27 @@ class VectorStore(ABC):
 
     def __init__(self, config: VectorStoreConfig):
         self.config = config
+        self.embedding_model = EmbeddingModel.create(config.embedding)
 
     @staticmethod
    def create(config: VectorStoreConfig) -> Optional["VectorStore"]:
        from langroid.vector_store.chromadb import ChromaDB, ChromaDBConfig
+        from langroid.vector_store.lancedb import LanceDB, LanceDBConfig
+        from langroid.vector_store.meilisearch import MeiliSearch, MeiliSearchConfig
+        from langroid.vector_store.momento import MomentoVI, MomentoVIConfig
        from langroid.vector_store.qdrantdb import QdrantDB, QdrantDBConfig
 
        if isinstance(config, QdrantDBConfig):
            return QdrantDB(config)
        elif isinstance(config, ChromaDBConfig):
            return ChromaDB(config)
+        elif isinstance(config, MomentoVIConfig):
+            return MomentoVI(config)
+        elif isinstance(config, LanceDBConfig):
+            return LanceDB(config)
+        elif isinstance(config, MeiliSearchConfig):
+            return MeiliSearch(config)
+
        else:
            logger.warning(
                f"""
@@ -113,6 +129,42 @@ class VectorStore(ABC):
     def add_documents(self, documents: Sequence[Document]) -> None:
         pass
 
+    def compute_from_docs(self, docs: List[Document], calc: str) -> str:
+        """Compute a result on a set of documents,
+        using a dataframe calc string like `df.groupby('state')['income'].mean()`.
+        """
+        dicts = [doc.dict() for doc in docs]
+        df = pd.DataFrame(dicts)
+
+        try:
+            result = pd.eval(  # safer than eval but limited to a single expression
+                calc,
+                engine="python",
+                parser="pandas",
+                local_dict={"df": df},
+            )
+        except Exception as e:
+            # return the error message so the LLM can fix the calc string if needed
+            err = f"""
+            Error encountered in pandas eval: {str(e)}
+            """
+            if isinstance(e, KeyError) and "not in index" in str(e):
+                # pd.eval sometimes fails with a KeyError on a perfectly valid
+                # expression like df.loc[..., 'column'].
+                err += """
+                Maybe try a different way, e.g.
+                instead of df.loc[..., 'column'], try df.loc[...]['column']
+                """
+            return err
+        return stringify(result)
+
+    def maybe_add_ids(self, documents: Sequence[Document]) -> None:
+        """Add ids to metadata if absent, since some
+        vecdbs don't like having blank ids."""
+        for d in documents:
+            if d.metadata.id in [None, ""]:
+                d.metadata.id = d._unique_hash_id()
+
     @abstractmethod
     def similar_texts_with_scores(
         self,
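A hypothetical call, reusing the docstring's own calc string (`store` and the 'state'/'income' fields are stand-ins):

docs = store.get_all_documents()
result = store.compute_from_docs(docs, "df.groupby('state')['income'].mean()")
# result is a stringified answer on success, or an error message the LLM can
# use to repair the calc string and retry.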
@@ -120,12 +172,157 @@ class VectorStore(ABC):
         k: int = 1,
         where: Optional[str] = None,
     ) -> List[Tuple[Document, float]]:
+        """
+        Find the k texts most similar to the given text, in terms of the vector
+        distance metric (e.g., cosine similarity).
+
+        Args:
+            text (str): The text to find similar texts for.
+            k (int, optional): Number of similar texts to retrieve. Defaults to 1.
+            where (Optional[str], optional): Where clause to filter the search.
+
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+
+        """
         pass
 
+    def add_context_window(
+        self, docs_scores: List[Tuple[Document, float]], neighbors: int = 0
+    ) -> List[Tuple[Document, float]]:
+        """
+        In each doc's metadata, there may be a window_ids field indicating
+        the ids of the chunks around the current chunk.
+        These window_ids may overlap, so we:
+        - coalesce each overlapping group into a single window (maintaining ordering),
+        - create a new document for each such window, preserving metadata.
+
+        We may have stored a longer set of window_ids than we need during chunking.
+        Now, we just want `neighbors` on each side of the center of the
+        window_ids list.
+
+        Args:
+            docs_scores (List[Tuple[Document, float]]): List of pairs of documents
+                to add context windows to, together with their match scores.
+            neighbors (int, optional): Number of neighbors on "each side" of the
+                match to retrieve. Defaults to 0.
+                "Each side" here means before and after the match,
+                in the original text.
+
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+        """
+        # We return a larger context around each match, i.e.
+        # a window of `neighbors` on each side of the match.
+        docs = [d for d, s in docs_scores]
+        scores = [s for d, s in docs_scores]
+        if neighbors == 0:
+            return docs_scores
+        doc_chunks = [d for d in docs if d.metadata.is_chunk]
+        if len(doc_chunks) == 0:
+            return docs_scores
+        window_ids_list = []
+        id2metadata = {}
+        # id -> highest score of a doc it appears in
+        id2max_score: Dict[int | str, float] = {}
+        for i, d in enumerate(docs):
+            window_ids = d.metadata.window_ids
+            if len(window_ids) == 0:
+                window_ids = [d.id()]
+            id2metadata.update({id: d.metadata for id in window_ids})
+
+            id2max_score.update(
+                {id: max(id2max_score.get(id, 0), scores[i]) for id in window_ids}
+            )
+            n = len(window_ids)
+            chunk_idx = window_ids.index(d.id())
+            neighbor_ids = window_ids[
+                max(0, chunk_idx - neighbors) : min(n, chunk_idx + neighbors + 1)
+            ]
+            window_ids_list += [neighbor_ids]
+
+        # window_ids could be from different docs,
+        # and they may overlap, so we coalesce overlapping groups into
+        # separate windows.
+        window_ids_list = self.remove_overlaps(window_ids_list)
+        final_docs = []
+        final_scores = []
+        for w in window_ids_list:
+            metadata = copy.deepcopy(id2metadata[w[0]])
+            metadata.window_ids = w
+            document = Document(
+                content=" ".join([d.content for d in self.get_documents_by_ids(w)]),
+                metadata=metadata,
+            )
+            # make a fresh id since the content is in general different
+            document.metadata.id = document.hash_id(document.content)
+            final_docs += [document]
+            final_scores += [max(id2max_score[id] for id in w)]
+        return list(zip(final_docs, final_scores))
+
+    @staticmethod
+    def remove_overlaps(windows: List[List[str]]) -> List[List[str]]:
+        """
+        Given a collection of windows, where each window is a sequence of ids,
+        identify groups of overlapping windows, and for each overlapping group,
+        order the chunk-ids using topological sort so they appear in the original
+        order in the text.
+
+        Args:
+            windows (List[List[str]]): List of windows, where each window is a
+                sequence of ids.
+
+        Returns:
+            List[List[str]]: List of windows, where each window is a sequence of
+                ids, and no two windows overlap.
+        """
+        ids = set(id for w in windows for id in w)
+        # id -> {win -> pos}
+        id2win2pos: Dict[str, Dict[int, int]] = {id: {} for id in ids}
+
+        for i, w in enumerate(windows):
+            for j, id in enumerate(w):
+                id2win2pos[id][i] = j
+
+        n = len(windows)
+        # pairwise order relation between windows:
+        order = np.zeros((n, n), dtype=np.int8)
+        for i, w in enumerate(windows):
+            for j, x in enumerate(windows):
+                if i == j:
+                    continue
+                if len(set(w).intersection(x)) == 0:
+                    continue
+                id = list(set(w).intersection(x))[0]  # any common id
+                if id2win2pos[id][i] > id2win2pos[id][j]:
+                    order[i, j] = -1  # win i is before win j
+                else:
+                    order[i, j] = 1  # win i is after win j
+
+        # find groups of windows that overlap, like connected components in a graph
+        groups = components(np.abs(order))
+
+        # order the chunk-ids in each group using topological sort
+        new_windows = []
+        for g in groups:
+            # find a total ordering among the windows in the group based on the
+            # order matrix (this is a topological sort)
+            _g = np.array(g)
+            order_matrix = order[_g][:, _g]
+            ordered_window_indices = topological_sort(order_matrix)
+            ordered_window_ids = [windows[i] for i in _g[ordered_window_indices]]
+            flattened = [id for w in ordered_window_ids for id in w]
+            flattened_deduped = list(dict.fromkeys(flattened))
+            # Note we are not going to split these; instead we return
+            # larger windows formed by concatenating the connected groups.
+            # This ensures context is retained for LLM Q/A.
+            new_windows += [flattened_deduped]
+
+        return new_windows
+
     @abstractmethod
-    def get_all_documents(self) -> List[Document]:
+    def get_all_documents(self, where: str = "") -> List[Document]:
         """
-        Get all documents in the current collection.
+        Get all documents in the current collection, possibly filtered by `where`.
         """
         pass
 
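To make remove_overlaps concrete, a small worked example (illustrative ids): windows 0 and 1 share "c", so they form one connected component and are coalesced in original text order, while the disjoint window stays separate.

from langroid.vector_store.base import VectorStore

windows = [["a", "b", "c"], ["c", "d"], ["x", "y"]]
print(VectorStore.remove_overlaps(windows))
# expected: [['a', 'b', 'c', 'd'], ['x', 'y']]  (group order may vary)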
langroid/vector_store/chromadb.py CHANGED
@@ -1,8 +1,7 @@
+import json
 import logging
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
-import chromadb
-
 from langroid.embedding_models.base import (
     EmbeddingModel,
     EmbeddingModelsConfig,
@@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 
 class ChromaDBConfig(VectorStoreConfig):
-    collection_name: str = "chroma-langroid"
+    collection_name: str = "temp"
     storage_path: str = ".chroma/data"
     embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
     host: str = "127.0.0.1"
@@ -25,8 +24,19 @@ class ChromaDBConfig(VectorStoreConfig):
 
 
 class ChromaDB(VectorStore):
-    def __init__(self, config: ChromaDBConfig):
+    def __init__(self, config: ChromaDBConfig = ChromaDBConfig()):
         super().__init__(config)
+        try:
+            import chromadb
+        except ImportError:
+            raise ImportError(
+                """
+                ChromaDB is not installed by default with Langroid.
+                If you want to use it, please install it with the `chromadb` extra, e.g.
+                pip install "langroid[chromadb]"
+                or an equivalent command.
+                """
+            )
         self.config = config
         emb_model = EmbeddingModel.create(config.embedding)
         self.embedding_fn = emb_model.embedding_fn()
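With this guard, importing the module no longer requires chromadb; the failure is deferred to instantiation, with an actionable message. A sketch of the behavior when the extra is missing:

from langroid.vector_store.chromadb import ChromaDB, ChromaDBConfig

try:
    store = ChromaDB(ChromaDBConfig(collection_name="demo"))
except ImportError as e:
    print(e)  # suggests: pip install "langroid[chromadb]"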
@@ -99,53 +109,78 @@ class ChromaDB(VectorStore):
 
         """
         self.config.collection_name = collection_name
+        if collection_name in self.list_collections(empty=True) and replace:
+            logger.warning(f"Replacing existing collection {collection_name}")
+            self.client.delete_collection(collection_name)
         self.collection = self.client.create_collection(
             name=self.config.collection_name,
             embedding_function=self.embedding_fn,
             get_or_create=not replace,
         )
 
-    def add_documents(self, documents: Optional[Sequence[Document]] = None) -> None:
+    def add_documents(self, documents: Sequence[Document]) -> None:
+        super().maybe_add_ids(documents)
         if documents is None:
             return
         contents: List[str] = [document.content for document in documents]
-        metadatas: List[dict[str, Any]] = [
-            document.metadata.dict() for document in documents
+        # convert metadatas to dicts so chroma can handle them
+        metadata_dicts: List[dict[str, Any]] = [
+            d.metadata.dict_bool_int() for d in documents
         ]
+        for m in metadata_dicts:
+            # chroma does not handle non-atomic types in metadata
+            m["window_ids"] = ",".join(m["window_ids"])
+
         ids = [str(d.id()) for d in documents]
         self.collection.add(
             # embedding_models=embedding_models,
             documents=contents,
-            metadatas=metadatas,
+            metadatas=metadata_dicts,
             ids=ids,
         )
 
-    def get_all_documents(self) -> List[Document]:
-        results = self.collection.get(include=["documents", "metadatas"])
+    def get_all_documents(self, where: str = "") -> List[Document]:
+        filter = json.loads(where) if where else None
+        results = self.collection.get(
+            include=["documents", "metadatas"],
+            where=filter,
+        )
         results["documents"] = [results["documents"]]
         results["metadatas"] = [results["metadatas"]]
         return self._docs_from_results(results)
 
     def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
-        results = self.collection.get(ids=ids, include=["documents", "metadatas"])
-        results["documents"] = [results["documents"]]
-        results["metadatas"] = [results["metadatas"]]
-        return self._docs_from_results(results)
+        # fetch ids one at a time, since chroma mangles the order of the results
+        # when they are fetched from a list of ids.
+        results = [
+            self.collection.get(ids=[id], include=["documents", "metadatas"])
+            for id in ids
+        ]
+        final_results = {}
+        final_results["documents"] = [[r["documents"][0] for r in results]]
+        final_results["metadatas"] = [[r["metadatas"][0] for r in results]]
+        return self._docs_from_results(final_results)
 
     def delete_collection(self, collection_name: str) -> None:
-        self.client.delete_collection(name=collection_name)
+        try:
+            self.client.delete_collection(name=collection_name)
+        except Exception:
+            pass
 
     def similar_texts_with_scores(
         self, text: str, k: int = 1, where: Optional[str] = None
     ) -> List[Tuple[Document, float]]:
+        n = self.collection.count()
+        filter = json.loads(where) if where else None
         results = self.collection.query(
             query_texts=[text],
-            n_results=k,
-            where=where,
+            n_results=min(n, k),
+            where=filter,
             include=["documents", "distances", "metadatas"],
         )
         docs = self._docs_from_results(results)
-        scores = results["distances"][0]
+        # chroma returns distances = 1 - cosine similarity; convert back to similarity
+        scores = [1 - s for s in results["distances"][0]]
         return list(zip(docs, scores))
 
     def _docs_from_results(self, results: Dict[str, Any]) -> List[Document]:
@@ -164,22 +199,14 @@ class ChromaDB(VectorStore):
        for i, c in enumerate(contents):
            print_long_text("red", "italic red", f"MATCH-{i}", c)
        metadatas = results["metadatas"][0]
+        for m in metadatas:
+            # restore the stringified list of window_ids to the original List[str]
+            if m["window_ids"].strip() == "":
+                m["window_ids"] = []
+            else:
+                m["window_ids"] = m["window_ids"].split(",")
        docs = [
            Document(content=d, metadata=DocMetaData(**m))
            for d, m in zip(contents, metadatas)
        ]
        return docs
-
-
-# Example usage and testing
-# chroma_db = ChromaDB.from_documents(
-#     collection_name="all-my-documents",
-#     documents=["doc1000101", "doc288822"],
-#     metadatas=[{"style": "style1"}, {"style": "style2"}],
-#     ids=["uri9", "uri10"]
-# )
-# results = chroma_db.query(
-#     query_texts=["This is a query document"],
-#     n_results=2
-# )
-# print(results)
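The join/split pair above is a simple round-trip for the comma-stringified window_ids (a sketch with illustrative ids; it assumes the ids themselves contain no commas):

window_ids = ["id1", "id2", "id3"]
flat = ",".join(window_ids)                         # stored in chroma: "id1,id2,id3"
restored = flat.split(",") if flat.strip() else []  # mirrors _docs_from_results
assert restored == window_ids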