langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. langroid/__init__.py +106 -0
  2. langroid/agent/__init__.py +41 -0
  3. langroid/agent/base.py +1983 -0
  4. langroid/agent/batch.py +398 -0
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +598 -0
  7. langroid/agent/chat_agent.py +1899 -0
  8. langroid/agent/chat_document.py +454 -0
  9. langroid/agent/openai_assistant.py +882 -0
  10. langroid/agent/special/__init__.py +59 -0
  11. langroid/agent/special/arangodb/__init__.py +0 -0
  12. langroid/agent/special/arangodb/arangodb_agent.py +656 -0
  13. langroid/agent/special/arangodb/system_messages.py +186 -0
  14. langroid/agent/special/arangodb/tools.py +107 -0
  15. langroid/agent/special/arangodb/utils.py +36 -0
  16. langroid/agent/special/doc_chat_agent.py +1466 -0
  17. langroid/agent/special/lance_doc_chat_agent.py +262 -0
  18. langroid/agent/special/lance_rag/__init__.py +9 -0
  19. langroid/agent/special/lance_rag/critic_agent.py +198 -0
  20. langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
  21. langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
  22. langroid/agent/special/lance_tools.py +61 -0
  23. langroid/agent/special/neo4j/__init__.py +0 -0
  24. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  25. langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
  26. langroid/agent/special/neo4j/system_messages.py +120 -0
  27. langroid/agent/special/neo4j/tools.py +32 -0
  28. langroid/agent/special/relevance_extractor_agent.py +127 -0
  29. langroid/agent/special/retriever_agent.py +56 -0
  30. langroid/agent/special/sql/__init__.py +17 -0
  31. langroid/agent/special/sql/sql_chat_agent.py +654 -0
  32. langroid/agent/special/sql/utils/__init__.py +21 -0
  33. langroid/agent/special/sql/utils/description_extractors.py +190 -0
  34. langroid/agent/special/sql/utils/populate_metadata.py +85 -0
  35. langroid/agent/special/sql/utils/system_message.py +35 -0
  36. langroid/agent/special/sql/utils/tools.py +64 -0
  37. langroid/agent/special/table_chat_agent.py +263 -0
  38. langroid/agent/task.py +2095 -0
  39. langroid/agent/tool_message.py +393 -0
  40. langroid/agent/tools/__init__.py +38 -0
  41. langroid/agent/tools/duckduckgo_search_tool.py +50 -0
  42. langroid/agent/tools/file_tools.py +234 -0
  43. langroid/agent/tools/google_search_tool.py +39 -0
  44. langroid/agent/tools/metaphor_search_tool.py +68 -0
  45. langroid/agent/tools/orchestration.py +303 -0
  46. langroid/agent/tools/recipient_tool.py +235 -0
  47. langroid/agent/tools/retrieval_tool.py +32 -0
  48. langroid/agent/tools/rewind_tool.py +137 -0
  49. langroid/agent/tools/segment_extract_tool.py +41 -0
  50. langroid/agent/xml_tool_message.py +382 -0
  51. langroid/cachedb/__init__.py +17 -0
  52. langroid/cachedb/base.py +58 -0
  53. langroid/cachedb/momento_cachedb.py +108 -0
  54. langroid/cachedb/redis_cachedb.py +153 -0
  55. langroid/embedding_models/__init__.py +39 -0
  56. langroid/embedding_models/base.py +74 -0
  57. langroid/embedding_models/models.py +461 -0
  58. langroid/embedding_models/protoc/__init__.py +0 -0
  59. langroid/embedding_models/protoc/embeddings.proto +19 -0
  60. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  61. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  62. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  63. langroid/embedding_models/remote_embeds.py +153 -0
  64. langroid/exceptions.py +71 -0
  65. langroid/language_models/__init__.py +53 -0
  66. langroid/language_models/azure_openai.py +153 -0
  67. langroid/language_models/base.py +678 -0
  68. langroid/language_models/config.py +18 -0
  69. langroid/language_models/mock_lm.py +124 -0
  70. langroid/language_models/openai_gpt.py +1964 -0
  71. langroid/language_models/prompt_formatter/__init__.py +16 -0
  72. langroid/language_models/prompt_formatter/base.py +40 -0
  73. langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
  74. langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
  75. langroid/language_models/utils.py +151 -0
  76. langroid/mytypes.py +84 -0
  77. langroid/parsing/__init__.py +52 -0
  78. langroid/parsing/agent_chats.py +38 -0
  79. langroid/parsing/code_parser.py +121 -0
  80. langroid/parsing/document_parser.py +718 -0
  81. langroid/parsing/para_sentence_split.py +62 -0
  82. langroid/parsing/parse_json.py +155 -0
  83. langroid/parsing/parser.py +313 -0
  84. langroid/parsing/repo_loader.py +790 -0
  85. langroid/parsing/routing.py +36 -0
  86. langroid/parsing/search.py +275 -0
  87. langroid/parsing/spider.py +102 -0
  88. langroid/parsing/table_loader.py +94 -0
  89. langroid/parsing/url_loader.py +111 -0
  90. langroid/parsing/urls.py +273 -0
  91. langroid/parsing/utils.py +373 -0
  92. langroid/parsing/web_search.py +156 -0
  93. langroid/prompts/__init__.py +9 -0
  94. langroid/prompts/dialog.py +17 -0
  95. langroid/prompts/prompts_config.py +5 -0
  96. langroid/prompts/templates.py +141 -0
  97. langroid/pydantic_v1/__init__.py +10 -0
  98. langroid/pydantic_v1/main.py +4 -0
  99. langroid/utils/__init__.py +19 -0
  100. langroid/utils/algorithms/__init__.py +3 -0
  101. langroid/utils/algorithms/graph.py +103 -0
  102. langroid/utils/configuration.py +98 -0
  103. langroid/utils/constants.py +30 -0
  104. langroid/utils/git_utils.py +252 -0
  105. langroid/utils/globals.py +49 -0
  106. langroid/utils/logging.py +135 -0
  107. langroid/utils/object_registry.py +66 -0
  108. langroid/utils/output/__init__.py +20 -0
  109. langroid/utils/output/citations.py +41 -0
  110. langroid/utils/output/printing.py +99 -0
  111. langroid/utils/output/status.py +40 -0
  112. langroid/utils/pandas_utils.py +30 -0
  113. langroid/utils/pydantic_utils.py +602 -0
  114. langroid/utils/system.py +286 -0
  115. langroid/utils/types.py +93 -0
  116. langroid/vector_store/__init__.py +50 -0
  117. langroid/vector_store/base.py +359 -0
  118. langroid/vector_store/chromadb.py +214 -0
  119. langroid/vector_store/lancedb.py +406 -0
  120. langroid/vector_store/meilisearch.py +299 -0
  121. langroid/vector_store/momento.py +278 -0
  122. langroid/vector_store/qdrantdb.py +468 -0
  123. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
  124. langroid-0.33.7.dist-info/RECORD +127 -0
  125. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
  126. langroid-0.33.6.dist-info/RECORD +0 -7
  127. langroid-0.33.6.dist-info/entry_points.txt +0 -4
  128. pyproject.toml +0 -356
  129. {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,359 @@
1
+ import copy
2
+ import logging
3
+ from abc import ABC, abstractmethod
4
+ from typing import Dict, List, Optional, Sequence, Tuple, Type
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from langroid.embedding_models.base import EmbeddingModel, EmbeddingModelsConfig
10
+ from langroid.embedding_models.models import OpenAIEmbeddingsConfig
11
+ from langroid.mytypes import DocMetaData, Document
12
+ from langroid.pydantic_v1 import BaseSettings
13
+ from langroid.utils.algorithms.graph import components, topological_sort
14
+ from langroid.utils.configuration import settings
15
+ from langroid.utils.object_registry import ObjectRegistry
16
+ from langroid.utils.output.printing import print_long_text
17
+ from langroid.utils.pandas_utils import stringify
18
+ from langroid.utils.pydantic_utils import flatten_dict
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class VectorStoreConfig(BaseSettings):
    """
    Settings shared by all vector-store implementations.

    `VectorStore.__init__` uses `embedding_model` when it is set, and otherwise
    creates an embedding model from the `embedding` config.
    """

    type: str = ""  # deprecated, keeping it for backward compatibility
    # name of the collection to operate on; may be None (no collection selected)
    collection_name: str | None = "temp"
    replace_collection: bool = False  # replace collection if it already exists
    # NOTE(review): the default path looks Qdrant-specific; presumably
    # subclasses override it for other backends -- confirm.
    storage_path: str = ".qdrant/data"
    cloud: bool = False
    batch_size: int = 200
    # config used to create the embedding model when `embedding_model` is None
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig(
        model_type="openai",
    )
    # ready-made embedding model; when not None it is used instead of `embedding`
    embedding_model: Optional[EmbeddingModel] = None
    timeout: int = 60
    host: str = "127.0.0.1"
    port: int = 6333
    # used when parsing search results back as Document objects
    document_class: Type[Document] = Document
    metadata_class: Type[DocMetaData] = DocMetaData
    # compose_file: str = "langroid/vector_store/docker-compose-qdrant.yml"
41
+
42
+
43
class VectorStore(ABC):
    """
    Abstract base class for a vector store.

    Concrete subclasses implement collection management, document ingestion
    and similarity search; this base class supplies the shared helpers
    (`set_collection`, `compute_from_docs`, `maybe_add_ids`,
    `add_context_window`, `remove_overlaps`, `show_if_debug`).
    """

    def __init__(self, config: VectorStoreConfig):
        """
        Store the config and resolve the embedding model.

        Args:
            config (VectorStoreConfig): vector-store settings. If
                `config.embedding_model` is None, the model is created from
                `config.embedding`; otherwise the given model is used as-is.
        """
        self.config = config
        if config.embedding_model is None:
            self.embedding_model = EmbeddingModel.create(config.embedding)
        else:
            self.embedding_model = config.embedding_model

    @staticmethod
    def create(config: VectorStoreConfig) -> Optional["VectorStore"]:
        """
        Create the vector store matching the concrete type of `config`.

        Args:
            config (VectorStoreConfig): a backend-specific config instance.

        Returns:
            Optional[VectorStore]: the vector store, or None (after logging a
                warning) when the config type is not recognized.
        """
        # imported here rather than at module level since these modules
        # themselves import from this one
        from langroid.vector_store.chromadb import ChromaDB, ChromaDBConfig
        from langroid.vector_store.lancedb import LanceDB, LanceDBConfig
        from langroid.vector_store.meilisearch import MeiliSearch, MeiliSearchConfig
        from langroid.vector_store.momento import MomentoVI, MomentoVIConfig
        from langroid.vector_store.qdrantdb import QdrantDB, QdrantDBConfig

        if isinstance(config, QdrantDBConfig):
            return QdrantDB(config)
        elif isinstance(config, ChromaDBConfig):
            return ChromaDB(config)
        elif isinstance(config, MomentoVIConfig):
            return MomentoVI(config)
        elif isinstance(config, LanceDBConfig):
            return LanceDB(config)
        elif isinstance(config, MeiliSearchConfig):
            return MeiliSearch(config)

        else:
            logger.warning(
                f"""
                Unknown vector store config: {config.__repr_name__()},
                so skipping vector store creation!
                If you intended to use a vector-store, please set a specific
                vector-store in your script, typically in the `vecdb` field of a
                `ChatAgentConfig`, otherwise set it to None.
                """
            )
            return None

    @abstractmethod
    def clear_empty_collections(self) -> int:
        """Clear all empty collections in the vector store.
        Returns the number of collections deleted.
        """
        pass

    @abstractmethod
    def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
        """
        Clear all collections in the vector store.

        Args:
            really (bool, optional): Whether to really clear all collections.
                Defaults to False.
            prefix (str, optional): Prefix of collections to clear.
        Returns:
            int: Number of collections deleted.
        """
        pass

    @abstractmethod
    def list_collections(self, empty: bool = False) -> List[str]:
        """List all collections in the vector store
        (only non empty collections if empty=False).
        """
        pass

    def set_collection(self, collection_name: str, replace: bool = False) -> None:
        """
        Set the current collection to the given collection name.
        Args:
            collection_name (str): Name of the collection.
            replace (bool, optional): Whether to replace the collection if it
                already exists. Defaults to False.
        """

        self.config.collection_name = collection_name
        self.config.replace_collection = replace
        # the collection is only (re)created here when a replacement is requested
        if replace:
            self.create_collection(collection_name, replace=True)

    @abstractmethod
    def create_collection(self, collection_name: str, replace: bool = False) -> None:
        """Create a collection with the given name.
        Args:
            collection_name (str): Name of the collection.
            replace (bool, optional): Whether to replace the
                collection if it already exists. Defaults to False.
        """
        pass

    @abstractmethod
    def add_documents(self, documents: Sequence[Document]) -> None:
        """Add the given documents to the current collection.
        Args:
            documents (Sequence[Document]): Documents to ingest.
        """
        pass

    def compute_from_docs(self, docs: List[Document], calc: str) -> str:
        """Compute a result on a set of documents,
        using a dataframe calc string like `df.groupby('state')['income'].mean()`.

        Args:
            docs (List[Document]): Documents to flatten into dataframe rows.
            calc (str): A single pandas expression over the dataframe `df`.

        Returns:
            str: The stringified result, or an error message describing why
                the expression could not be evaluated.
        """
        # convert each doc to a dict, using dotted paths for nested fields
        dicts = [flatten_dict(doc.dict(by_alias=True)) for doc in docs]
        df = pd.DataFrame(dicts)

        # NOTE(review): `calc` may invoke arbitrary DataFrame methods on `df`
        # (the docstring example itself does), so pd.eval is NOT a sandbox;
        # do not pass expressions from untrusted sources without vetting them.
        try:
            result = pd.eval(  # safer than eval but limited to single expression
                calc,
                engine="python",
                parser="pandas",
                local_dict={"df": df},
            )
        except Exception as e:
            # return error message so LLM can fix the calc string if needed
            err = f"""
            Error encountered in pandas eval: {str(e)}
            """
            if isinstance(e, KeyError) and "not in index" in str(e):
                # Pd.eval sometimes fails on a perfectly valid exprn like
                # df.loc[..., 'column'] with a KeyError.
                err += """
                Maybe try a different way, e.g.
                instead of df.loc[..., 'column'], try df.loc[...]['column']
                """
            return err
        return stringify(result)

    def maybe_add_ids(self, documents: Sequence[Document]) -> None:
        """Add ids to metadata if absent, since some
        vecdbs don't like having blank ids."""
        for d in documents:
            if d.metadata.id in [None, ""]:
                d.metadata.id = ObjectRegistry.new_id()

    @abstractmethod
    def similar_texts_with_scores(
        self,
        text: str,
        k: int = 1,
        where: Optional[str] = None,
    ) -> List[Tuple[Document, float]]:
        """
        Find k most similar texts to the given text, in terms of vector distance metric
        (e.g., cosine similarity).

        Args:
            text (str): The text to find similar texts for.
            k (int, optional): Number of similar texts to retrieve. Defaults to 1.
            where (Optional[str], optional): Where clause to filter the search.

        Returns:
            List[Tuple[Document,float]]: List of (Document, score) tuples.

        """
        pass

    def add_context_window(
        self, docs_scores: List[Tuple[Document, float]], neighbors: int = 0
    ) -> List[Tuple[Document, float]]:
        """
        In each doc's metadata, there may be a window_ids field indicating
        the ids of the chunks around the current chunk.
        These window_ids may overlap, so we
        - coalesce each overlapping groups into a single window (maintaining ordering),
        - create a new document for each part, preserving metadata,

        We may have stored a longer set of window_ids than we need during chunking.
        Now, we just want `neighbors` on each side of the center of the window_ids list.

        Args:
            docs_scores (List[Tuple[Document, float]]): List of pairs of documents
                to add context windows to together with their match scores.
            neighbors (int, optional): Number of neighbors on "each side" of match to
                retrieve. Defaults to 0.
                "Each side" here means before and after the match,
                in the original text.

        Returns:
            List[Tuple[Document, float]]: List of (Document, score) tuples.
                The input is returned unchanged when `neighbors` is 0 or when
                none of the documents is a chunk.

        Raises:
            ValueError: if a document has non-empty `window_ids` that do not
                contain the document's own id.
        """
        # We return a larger context around each match, i.e.
        # a window of `neighbors` on each side of the match.
        if neighbors == 0:
            return docs_scores
        if not any(d.metadata.is_chunk for d, _ in docs_scores):
            return docs_scores
        window_ids_list = []
        id2metadata = {}
        # id -> highest score of a doc it appears in
        id2max_score: Dict[int | str, float] = {}
        for d, score in docs_scores:
            window_ids = d.metadata.window_ids
            if len(window_ids) == 0:
                window_ids = [d.id()]
            for wid in window_ids:
                id2metadata[wid] = d.metadata
                # default to the current score (not 0), so that negative
                # scores are not silently floored at 0
                id2max_score[wid] = max(id2max_score.get(wid, score), score)
            n = len(window_ids)
            chunk_idx = window_ids.index(d.id())
            neighbor_ids = window_ids[
                max(0, chunk_idx - neighbors) : min(n, chunk_idx + neighbors + 1)
            ]
            window_ids_list += [neighbor_ids]

        # window_ids could be from different docs,
        # and they may overlap, so we coalesce overlapping groups into
        # separate windows.
        window_ids_list = self.remove_overlaps(window_ids_list)
        final_docs = []
        final_scores = []
        for w in window_ids_list:
            # deep-copy so that editing window_ids/id below does not touch
            # the metadata of the original matched document
            metadata = copy.deepcopy(id2metadata[w[0]])
            metadata.window_ids = w
            document = Document(
                content=" ".join([d.content for d in self.get_documents_by_ids(w)]),
                metadata=metadata,
            )
            # make a fresh id since content is in general different
            document.metadata.id = ObjectRegistry.new_id()
            final_docs += [document]
            final_scores += [max(id2max_score[wid] for wid in w)]
        return list(zip(final_docs, final_scores))

    @staticmethod
    def remove_overlaps(windows: List[List[str]]) -> List[List[str]]:
        """
        Given a collection of windows, where each window is a sequence of ids,
        identify groups of overlapping windows, and for each overlapping group,
        order the chunk-ids using topological sort so they appear in the original
        order in the text.

        Args:
            windows (List[List[str]]): List of windows, where each window is a
                sequence of ids.

        Returns:
            List[List[str]]: List of windows, where each window is a sequence of
                ids, and no two windows overlap.
        """
        ids = {wid for w in windows for wid in w}
        # id -> {win -> # pos}
        id2win2pos: Dict[str, Dict[int, int]] = {wid: {} for wid in ids}

        for i, w in enumerate(windows):
            for j, wid in enumerate(w):
                id2win2pos[wid][i] = j

        n = len(windows)
        # relation between windows:
        order = np.zeros((n, n), dtype=np.int8)
        for i, w in enumerate(windows):
            w_ids = set(w)
            for j, x in enumerate(windows):
                if i == j:
                    continue
                common = w_ids.intersection(x)
                if not common:
                    continue
                wid = next(iter(common))  # any common id
                # a shared id sitting further along in window i than in
                # window j means that window i starts earlier in the text
                if id2win2pos[wid][i] > id2win2pos[wid][j]:
                    order[i, j] = -1  # win i is before win j
                else:
                    order[i, j] = 1  # win i is after win j

        # find groups of windows that overlap, like connected components in a graph
        groups = components(np.abs(order))

        # order the chunk-ids in each group using topological sort
        new_windows = []
        for g in groups:
            # find total ordering among windows in group based on order matrix
            # (this is a topological sort)
            _g = np.array(g)
            order_matrix = order[_g][:, _g]
            ordered_window_indices = topological_sort(order_matrix)
            ordered_window_ids = [windows[i] for i in _g[ordered_window_indices]]
            flattened = [wid for w in ordered_window_ids for wid in w]
            # dict.fromkeys drops duplicates while keeping first-seen order
            flattened_deduped = list(dict.fromkeys(flattened))
            # Note we are not going to split these, and instead we'll return
            # larger windows from concatenating the connected groups.
            # This ensures context is retained for LLM q/a
            new_windows += [flattened_deduped]

        return new_windows

    @abstractmethod
    def get_all_documents(self, where: str = "") -> List[Document]:
        """
        Get all documents in the current collection, possibly filtered by `where`.
        """
        pass

    @abstractmethod
    def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
        """
        Get documents by their ids.
        Args:
            ids (List[str]): List of document ids.

        Returns:
            List[Document]: List of documents
        """
        pass

    @abstractmethod
    def delete_collection(self, collection_name: str) -> None:
        """Delete the collection with the given name.
        Args:
            collection_name (str): Name of the collection to delete.
        """
        pass

    def show_if_debug(self, doc_score_pairs: List[Tuple[Document, float]]) -> None:
        """Print the content of each matched document when `settings.debug` is on.
        Args:
            doc_score_pairs (List[Tuple[Document, float]]): (Document, score)
                pairs; only the document content is printed.
        """
        if settings.debug:
            for i, (d, _) in enumerate(doc_score_pairs):
                print_long_text("red", "italic red", f"\nMATCH-{i}\n", d.content)
@@ -0,0 +1,214 @@
1
+ import json
2
+ import logging
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
4
+
5
+ from langroid.embedding_models.base import (
6
+ EmbeddingModelsConfig,
7
+ )
8
+ from langroid.embedding_models.models import OpenAIEmbeddingsConfig
9
+ from langroid.exceptions import LangroidImportError
10
+ from langroid.mytypes import Document
11
+ from langroid.utils.configuration import settings
12
+ from langroid.utils.output.printing import print_long_text
13
+ from langroid.vector_store.base import VectorStore, VectorStoreConfig
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class ChromaDBConfig(VectorStoreConfig):
    """Configuration for the `ChromaDB` vector store."""

    # collection created by `ChromaDB.__init__` unless set to None
    collection_name: str = "temp"
    # handed to the chromadb client settings as `persist_directory`
    storage_path: str = ".chroma/data"
    embedding: EmbeddingModelsConfig = OpenAIEmbeddingsConfig()
    # NOTE(review): host/port are not read by the `ChromaDB` class in this
    # module -- presumably kept only for parity with the base config; confirm.
    host: str = "127.0.0.1"
    port: int = 6333
24
+
25
+
26
class ChromaDB(VectorStore):
    """
    Vector store backed by a `chromadb` client.

    Document metadata is stored as flat dicts; the `window_ids` list is
    serialized as a comma-separated string on ingestion and restored to a
    list when documents are read back.
    """

    def __init__(self, config: Optional[ChromaDBConfig] = None):
        """
        Create the chromadb client and, if a collection name is configured,
        create (or get) that collection.

        Args:
            config (Optional[ChromaDBConfig]): store settings. When omitted, a
                fresh `ChromaDBConfig()` is created for this instance, so the
                config (which `create_collection` mutates) is never shared
                between instances.

        Raises:
            LangroidImportError: if the `chromadb` package is not installed.
        """
        if config is None:
            config = ChromaDBConfig()
        super().__init__(config)
        try:
            import chromadb
        except ImportError:
            raise LangroidImportError("chromadb", "chromadb")
        self.config = config
        self.embedding_fn = self.embedding_model.embedding_fn()
        self.client = chromadb.Client(
            chromadb.config.Settings(
                # chroma_db_impl="duckdb+parquet",
                # is_persistent=bool(config.storage_path),
                persist_directory=config.storage_path,
            )
        )
        if self.config.collection_name is not None:
            self.create_collection(
                self.config.collection_name,
                replace=self.config.replace_collection,
            )

    def clear_all_collections(self, really: bool = False, prefix: str = "") -> int:
        """Clear all collections in the vector store with the given prefix.

        Args:
            really (bool, optional): Must be True for anything to be deleted.
            prefix (str, optional): Only collections whose name starts with
                this prefix are deleted.

        Returns:
            int: Number of collections deleted.
        """

        if not really:
            logger.warning("Not deleting all collections, set really=True to confirm")
            return 0
        coll = [c for c in self.client.list_collections() if c.name.startswith(prefix)]
        if len(coll) == 0:
            logger.warning(f"No collections found with prefix {prefix}")
            return 0
        n_empty_deletes = 0
        n_non_empty_deletes = 0
        for c in coll:
            # query the size once per collection
            if c.count() == 0:
                n_empty_deletes += 1
            else:
                n_non_empty_deletes += 1
            self.client.delete_collection(name=c.name)
        logger.warning(
            f"""
            Deleted {n_empty_deletes} empty collections and
            {n_non_empty_deletes} non-empty collections.
            """
        )
        return n_empty_deletes + n_non_empty_deletes

    def clear_empty_collections(self) -> int:
        """Delete every collection that contains no items.

        Returns:
            int: Number of collections deleted.
        """
        colls = self.client.list_collections()
        n_deletes = 0
        for coll in colls:
            if coll.count() == 0:
                n_deletes += 1
                self.client.delete_collection(name=coll.name)
        return n_deletes

    def list_collections(self, empty: bool = False) -> List[str]:
        """
        List collections in the vector store.
        Args:
            empty (bool, optional): Whether to include empty collections.
        Returns:
            List[str]: Names of all collections if `empty` is True,
                otherwise only the names of non-empty collections.
        """
        colls = self.client.list_collections()
        if empty:
            return [coll.name for coll in colls]
        return [coll.name for coll in colls if coll.count() > 0]

    def create_collection(self, collection_name: str, replace: bool = False) -> None:
        """
        Create a collection in the vector store, optionally replacing an existing
        collection if `replace` is True.
        Args:
            collection_name (str): Name of the collection to create or replace.
            replace (bool, optional): Whether to replace an existing collection.
                Defaults to False.

        """
        self.config.collection_name = collection_name
        if collection_name in self.list_collections(empty=True) and replace:
            logger.warning(f"Replacing existing collection {collection_name}")
            self.client.delete_collection(collection_name)
        # without `replace`, an existing collection of this name is re-used
        self.collection = self.client.create_collection(
            name=self.config.collection_name,
            embedding_function=self.embedding_fn,
            get_or_create=not replace,
        )

    def add_documents(self, documents: Sequence[Document]) -> None:
        """
        Add documents (content, metadata and ids) to the current collection,
        creating the collection first if it does not exist yet.

        Args:
            documents (Sequence[Document]): Documents to ingest. Nothing is
                done when this is None or empty.

        Raises:
            ValueError: if no collection name is set in the config.
        """
        # this guard must come first: maybe_add_ids iterates over `documents`
        if not documents:
            return
        super().maybe_add_ids(documents)
        contents: List[str] = [document.content for document in documents]
        # convert metadatas to dicts so chroma can handle them
        metadata_dicts: List[dict[str, Any]] = [
            d.metadata.dict_bool_int() for d in documents
        ]
        for m in metadata_dicts:
            # chroma does not handle non-atomic types in metadata
            m["window_ids"] = ",".join(m["window_ids"])

        ids = [str(d.id()) for d in documents]

        colls = self.list_collections(empty=True)
        if self.config.collection_name is None:
            raise ValueError("No collection name set, cannot ingest docs")
        if self.config.collection_name not in colls:
            self.create_collection(self.config.collection_name, replace=True)

        self.collection.add(
            # embedding_models=embedding_models,
            documents=contents,
            metadatas=metadata_dicts,
            ids=ids,
        )

    def get_all_documents(self, where: str = "") -> List[Document]:
        """
        Get all documents in the current collection, possibly filtered by `where`.

        Args:
            where (str, optional): JSON string that is parsed and passed to
                chroma as the `where` filter; empty string means no filter.

        Returns:
            List[Document]: The matching documents.
        """
        where_filter = json.loads(where) if where else None
        results = self.collection.get(
            include=["documents", "metadatas"],
            where=where_filter,
        )
        # wrap in an outer list to match the nested shape of query results,
        # which is what _docs_from_results expects
        results["documents"] = [results["documents"]]
        results["metadatas"] = [results["metadatas"]]
        return self._docs_from_results(results)

    def get_documents_by_ids(self, ids: List[str]) -> List[Document]:
        """
        Get documents by their ids, in the same order as `ids`.

        Args:
            ids (List[str]): List of document ids.

        Returns:
            List[Document]: List of documents.
        """
        # get them one by one since chroma mangles the order of the results
        # when fetched from a list of ids.
        results = [
            self.collection.get(ids=[doc_id], include=["documents", "metadatas"])
            for doc_id in ids
        ]
        # NOTE(review): presumably an id that is absent from the collection
        # yields empty lists, in which case the [0] below raises IndexError
        # -- TODO confirm and decide whether missing ids should be skipped.
        final_results = {}
        final_results["documents"] = [[r["documents"][0] for r in results]]
        final_results["metadatas"] = [[r["metadatas"][0] for r in results]]
        return self._docs_from_results(final_results)

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete the named collection on a best-effort basis: a failure
        is logged as a warning instead of being raised.

        Args:
            collection_name (str): Name of the collection to delete.
        """
        try:
            self.client.delete_collection(name=collection_name)
        except Exception as e:
            # keep the best-effort contract, but do not hide the failure
            logger.warning(f"Could not delete collection {collection_name}: {e}")

    def similar_texts_with_scores(
        self, text: str, k: int = 1, where: Optional[str] = None
    ) -> List[Tuple[Document, float]]:
        """
        Find the (at most) k texts most similar to the given text.

        Args:
            text (str): The text to find similar texts for.
            k (int, optional): Number of similar texts to retrieve. Defaults to 1.
            where (Optional[str], optional): JSON string that is parsed and
                passed to chroma as the `where` filter.

        Returns:
            List[Tuple[Document,float]]: List of (Document, score) tuples,
                where score is `1 - distance`. Empty when the collection is
                empty or `k` is not positive.
        """
        # never request more results than there are items in the collection
        n_results = min(self.collection.count(), k)
        if n_results <= 0:
            # nothing to retrieve; avoid querying chroma for zero results
            return []
        where_filter = json.loads(where) if where else None
        results = self.collection.query(
            query_texts=[text],
            n_results=n_results,
            where=where_filter,
            include=["documents", "distances", "metadatas"],
        )
        docs = self._docs_from_results(results)
        # chroma distances are 1 - cosine.
        scores = [1 - s for s in results["distances"][0]]
        return list(zip(docs, scores))

    def _docs_from_results(self, results: Dict[str, Any]) -> List[Document]:
        """
        Helper function to convert results from ChromaDB to a list of Documents
        Args:
            results (dict): results from ChromaDB; `results["documents"][0]`
                and `results["metadatas"][0]` hold the contents and metadata
                dicts. The metadata dicts are modified in place.

        Returns:
            List[Document]: list of Documents
        """
        if len(results["documents"][0]) == 0:
            return []
        contents = results["documents"][0]
        if settings.debug:
            for i, c in enumerate(contents):
                print_long_text("red", "italic red", f"MATCH-{i}", c)
        metadatas = results["metadatas"][0]
        for m in metadatas:
            # restore the stringified list of window_ids into the original List[str];
            # a missing/empty value is treated as "no window ids"
            raw_window_ids = m.get("window_ids") or ""
            if raw_window_ids.strip() == "":
                m["window_ids"] = []
            else:
                m["window_ids"] = raw_window_ids.split(",")
        docs = [
            self.config.document_class(
                content=d, metadata=self.config.metadata_class(**m)
            )
            for d, m in zip(contents, metadatas)
        ]
        return docs