langroid 0.1.100__py3-none-any.whl → 0.1.102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +9 -8
- langroid/agent/batch.py +6 -4
- langroid/agent/chat_agent.py +9 -7
- langroid/agent/special/doc_chat_agent.py +100 -4
- langroid/agent/special/relevance_extractor_agent.py +11 -5
- langroid/agent/special/retriever_agent.py +1 -1
- langroid/agent/task.py +13 -10
- langroid/mytypes.py +10 -4
- langroid/parsing/document_parser.py +1 -0
- langroid/parsing/parser.py +62 -31
- langroid/parsing/search.py +54 -49
- langroid/parsing/utils.py +26 -0
- langroid/utils/algorithms/graph.py +49 -0
- langroid/utils/configuration.py +30 -1
- langroid/utils/output/printing.py +31 -1
- langroid/utils/pydantic_utils.py +3 -1
- langroid/vector_store/base.py +157 -1
- langroid/vector_store/chromadb.py +12 -19
- langroid/vector_store/meilisearch.py +1 -0
- langroid/vector_store/momento.py +1 -0
- langroid/vector_store/qdrantdb.py +10 -4
- {langroid-0.1.100.dist-info → langroid-0.1.102.dist-info}/METADATA +1 -1
- {langroid-0.1.100.dist-info → langroid-0.1.102.dist-info}/RECORD +25 -24
- {langroid-0.1.100.dist-info → langroid-0.1.102.dist-info}/LICENSE +0 -0
- {langroid-0.1.100.dist-info → langroid-0.1.102.dist-info}/WHEEL +0 -0
langroid/parsing/search.py
CHANGED
@@ -7,7 +7,6 @@ See tests for examples: tests/main/test_string_search.py
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
import difflib
|
10
|
-
import re
|
11
10
|
from typing import List, Tuple
|
12
11
|
|
13
12
|
from nltk.corpus import stopwords
|
@@ -24,6 +23,7 @@ from .utils import download_nltk_resource
|
|
24
23
|
def find_fuzzy_matches_in_docs(
|
25
24
|
query: str,
|
26
25
|
docs: List[Document],
|
26
|
+
docs_clean: List[Document],
|
27
27
|
k: int,
|
28
28
|
words_before: int | None = None,
|
29
29
|
words_after: int | None = None,
|
@@ -49,45 +49,45 @@ def find_fuzzy_matches_in_docs(
|
|
49
49
|
return []
|
50
50
|
best_matches = process.extract(
|
51
51
|
query,
|
52
|
-
[d.content for d in
|
52
|
+
[d.content for d in docs_clean],
|
53
53
|
limit=k,
|
54
54
|
scorer=fuzz.partial_ratio,
|
55
55
|
)
|
56
56
|
|
57
57
|
real_matches = [m for m, score in best_matches if score > 50]
|
58
|
-
|
59
|
-
|
60
|
-
for
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
words_in_text = doc.content.split()
|
65
|
-
first_word_idx = next(
|
66
|
-
(
|
67
|
-
i
|
68
|
-
for i, word in enumerate(words_in_text)
|
69
|
-
if word.startswith(words[0])
|
70
|
-
),
|
71
|
-
-1,
|
72
|
-
)
|
73
|
-
if words_before is None:
|
74
|
-
words_before = len(words_in_text)
|
75
|
-
if words_after is None:
|
76
|
-
words_after = len(words_in_text)
|
77
|
-
if first_word_idx != -1:
|
78
|
-
start_idx = max(0, first_word_idx - words_before)
|
79
|
-
end_idx = min(
|
80
|
-
len(words_in_text),
|
81
|
-
first_word_idx + len(words) + words_after,
|
82
|
-
)
|
83
|
-
doc_match = Document(
|
84
|
-
content=" ".join(words_in_text[start_idx:end_idx]),
|
85
|
-
metadata=doc.metadata,
|
86
|
-
)
|
87
|
-
results.append(doc_match)
|
58
|
+
# find the original docs that corresponding to the matches
|
59
|
+
orig_doc_matches = []
|
60
|
+
for i, m in enumerate(real_matches):
|
61
|
+
for j, doc_clean in enumerate(docs_clean):
|
62
|
+
if m in doc_clean.content:
|
63
|
+
orig_doc_matches.append(docs[j])
|
88
64
|
break
|
65
|
+
if words_after is None and words_before is None:
|
66
|
+
return orig_doc_matches
|
67
|
+
|
68
|
+
contextual_matches = []
|
69
|
+
for match in orig_doc_matches:
|
70
|
+
choice_text = match.content
|
71
|
+
contexts = []
|
72
|
+
while choice_text != "":
|
73
|
+
context, start_pos, end_pos = get_context(
|
74
|
+
query, choice_text, words_before, words_after
|
75
|
+
)
|
76
|
+
if context == "" or end_pos == 0:
|
77
|
+
break
|
78
|
+
contexts.append(context)
|
79
|
+
words = choice_text.split()
|
80
|
+
end_pos = min(end_pos, len(words))
|
81
|
+
choice_text = " ".join(words[end_pos:])
|
82
|
+
if len(contexts) > 0:
|
83
|
+
contextual_matches.append(
|
84
|
+
Document(
|
85
|
+
content=" ... ".join(contexts),
|
86
|
+
metadata=match.metadata,
|
87
|
+
)
|
88
|
+
)
|
89
89
|
|
90
|
-
return
|
90
|
+
return contextual_matches
|
91
91
|
|
92
92
|
|
93
93
|
def preprocess_text(text: str) -> str:
|
@@ -171,7 +171,7 @@ def get_context(
|
|
171
171
|
text: str,
|
172
172
|
words_before: int | None = 100,
|
173
173
|
words_after: int | None = 100,
|
174
|
-
) -> str:
|
174
|
+
) -> Tuple[str, int, int]:
|
175
175
|
"""
|
176
176
|
Returns a portion of text containing the best approximate match of the query,
|
177
177
|
including b words before and a words after the match.
|
@@ -185,7 +185,9 @@ def get_context(
|
|
185
185
|
Returns:
|
186
186
|
str: A string containing b words before, the match, and a words after
|
187
187
|
the best approximate match position of the query in the text. If no
|
188
|
-
match is found, returns
|
188
|
+
match is found, returns empty string.
|
189
|
+
int: The start position of the match in the text.
|
190
|
+
int: The end position of the match in the text.
|
189
191
|
|
190
192
|
Example:
|
191
193
|
>>> get_context("apple", "The quick brown fox jumps over the apple.", 3, 2)
|
@@ -193,26 +195,29 @@ def get_context(
|
|
193
195
|
"""
|
194
196
|
if words_after is None and words_before is None:
|
195
197
|
# return entire text since we're not asked to return a bounded context
|
196
|
-
return text
|
198
|
+
return text, 0, 0
|
199
|
+
|
200
|
+
# make sure there is a good enough fu
|
201
|
+
if fuzz.partial_ratio(query, text) < 70:
|
202
|
+
return "", 0, 0
|
197
203
|
|
198
204
|
sequence_matcher = difflib.SequenceMatcher(None, text, query)
|
199
205
|
match = sequence_matcher.find_longest_match(0, len(text), 0, len(query))
|
200
206
|
|
201
207
|
if match.size == 0:
|
202
|
-
return "
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
)
|
208
|
+
return "", 0, 0
|
209
|
+
|
210
|
+
segments = text.split()
|
211
|
+
n_segs = len(segments)
|
212
|
+
|
213
|
+
start_segment_pos = len(text[: match.a].split())
|
214
|
+
|
215
|
+
words_before = words_before or n_segs
|
216
|
+
words_after = words_after or n_segs
|
217
|
+
start_pos = max(0, start_segment_pos - words_before)
|
218
|
+
end_pos = min(len(segments), start_segment_pos + words_after + len(query.split()))
|
214
219
|
|
215
|
-
return " ".join(
|
220
|
+
return " ".join(segments[start_pos:end_pos]), start_pos, end_pos
|
216
221
|
|
217
222
|
|
218
223
|
def eliminate_near_duplicates(passages: List[str], threshold: float = 0.8) -> List[str]:
|
langroid/parsing/utils.py
CHANGED
@@ -165,6 +165,32 @@ def parse_number_range_list(specs: str) -> List[int]:
|
|
165
165
|
return sorted(list(spec_indices))
|
166
166
|
|
167
167
|
|
168
|
+
def strip_k(s: str, k: int = 2) -> str:
    """
    Strip any leading and trailing whitespaces from the input text beyond length k.
    This is useful for removing leading/trailing whitespaces from a text while
    preserving paragraph structure.

    At most the last `k` characters of the leading whitespace run and the first
    `k` characters of the trailing whitespace run are kept. If `s` is empty or
    consists only of whitespace, its first `k` characters are returned.

    Args:
        s (str): The input text.
        k (int): The number of leading and trailing whitespaces to retain.
            Negative values are treated as 0.

    Returns:
        str: The text with leading and trailing whitespaces removed beyond length k.
    """
    k = max(k, 0)

    if not s.strip():
        # All-whitespace (or empty): the "leading" and "trailing" runs are the
        # same characters, so slicing both ends would double-count them.
        return s[:k]

    # Count leading and trailing whitespaces
    leading_count = len(s) - len(s.lstrip())
    trailing_count = len(s) - len(s.rstrip())

    # Drop everything except the k whitespace chars adjacent to the text
    start = leading_count - min(leading_count, k)
    end = len(s) - (trailing_count - min(trailing_count, k))
    return s[start:end]
|
192
|
+
|
193
|
+
|
168
194
|
def clean_whitespace(text: str) -> str:
|
169
195
|
"""Remove extra whitespace from the input text, while preserving
|
170
196
|
paragraph structure.
|
@@ -0,0 +1,49 @@
|
|
1
|
+
"""
|
2
|
+
Graph algos.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import List, no_type_check
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
|
10
|
+
@no_type_check
def topological_sort(order: np.ndarray) -> List[int]:
    """
    Given a directed adjacency matrix, return a topological sort of the nodes.
    order[i,j] = -1 means there is an edge from i to j.
    order[i,j] = 0 means there is no edge from i to j.
    order[i,j] = 1 means there is an edge from j to i.

    Only the -1 entries are consulted. A +1 at [i, j] is assumed to be
    mirrored by a -1 at [j, i].

    Uses Kahn's algorithm: nodes with in-degree 0 are emitted first (starting
    in index order), and ties are resolved first-in-first-out.

    Args:
        order (np.ndarray): The (n, n) adjacency matrix.

    Returns:
        List[int]: The topological sort of the nodes.

    Raises:
        AssertionError: If the graph contains a cycle.
    """
    from collections import deque  # local import keeps module imports unchanged

    n = order.shape[0]
    is_edge = order == -1  # is_edge[i, j] is True iff there is an edge i -> j

    # in-degree of node j = number of edges into j = column sum of is_edge
    in_degree = [int(count) for count in is_edge.sum(axis=0)]

    # Initialize the queue with nodes of in-degree 0
    queue = deque(i for i in range(n) if in_degree[i] == 0)
    result: List[int] = []

    while queue:
        node = queue.popleft()
        result.append(node)
        for succ in range(n):
            if is_edge[node, succ]:
                in_degree[succ] -= 1
                if in_degree[succ] == 0:
                    queue.append(succ)

    if len(result) != n:
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # the exception type is kept as AssertionError for existing callers.
        raise AssertionError("Cycle detected")
    return result
|
langroid/utils/configuration.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
|
+
import copy
|
1
2
|
import os
|
2
|
-
from
|
3
|
+
from contextlib import contextmanager
|
4
|
+
from typing import Iterator, List
|
3
5
|
|
4
6
|
from dotenv import find_dotenv, load_dotenv
|
5
7
|
from pydantic import BaseSettings
|
@@ -17,6 +19,7 @@ class Settings(BaseSettings):
|
|
17
19
|
gpt3_5: bool = True # use GPT-3.5?
|
18
20
|
nofunc: bool = False # use model without function_call? (i.e. gpt-4)
|
19
21
|
chat_model: str = "" # language model name, e.g. litellm/ollama/llama2
|
22
|
+
quiet: bool = False # quiet mode (i.e. suppress all output)?
|
20
23
|
|
21
24
|
class Config:
|
22
25
|
extra = "forbid"
|
@@ -55,6 +58,32 @@ def set_global(key_vals: Settings) -> None:
|
|
55
58
|
settings.__dict__.update(key_vals.__dict__)
|
56
59
|
|
57
60
|
|
61
|
+
@contextmanager
def temporary_settings(temp_settings: Settings) -> Iterator[None]:
    """Apply `temp_settings` to the global settings for the duration of a block.

    A deep copy of the current global `settings` is taken first. On exit
    (normal or via an exception), its field values are written back into
    `settings.__dict__`. The same `settings` object is restored in place
    rather than being replaced.
    """
    snapshot = copy.deepcopy(settings)
    set_global(temp_settings)
    try:
        yield
    finally:
        # restore in place: other modules may hold a reference to `settings`
        restored_fields = snapshot.__dict__
        settings.__dict__.update(restored_fields)
|
72
|
+
|
73
|
+
|
74
|
+
@contextmanager
def quiet_mode() -> Iterator[None]:
    """Temporarily set quiet=True in global settings and restore afterward.

    Only the `quiet` field of the global `settings` is changed; every other
    setting keeps its current value. Going through
    `set_global(Settings(quiet=True))` would be wrong: it copies a fresh
    `Settings()` instance's `__dict__` over the globals, resetting every field
    to its default, and only `quiet` would be restored on exit.

    The previous value of `quiet` is restored on exit, including when the
    block raises.
    """
    previous_quiet = settings.quiet
    settings.quiet = True
    try:
        yield
    finally:
        settings.quiet = previous_quiet
|
85
|
+
|
86
|
+
|
58
87
|
def set_env(settings: BaseSettings) -> None:
|
59
88
|
"""
|
60
89
|
Set environment variables from a BaseSettings instance
|
@@ -1,5 +1,6 @@
|
|
1
1
|
import sys
|
2
|
-
from
|
2
|
+
from contextlib import contextmanager
|
3
|
+
from typing import Any, Iterator, Optional
|
3
4
|
|
4
5
|
from rich import print as rprint
|
5
6
|
from rich.text import Text
|
@@ -46,3 +47,32 @@ class PrintColored:
|
|
46
47
|
|
47
48
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
    """Leave the `with` block by printing `Colors().RESET`.

    `Colors().RESET` is presumably the terminal color-reset sequence. The
    exception arguments are ignored and None is returned, so any exception
    raised inside the block still propagates.
    """
    reset_sequence = Colors().RESET
    print(reset_sequence)
|
50
|
+
|
51
|
+
|
52
|
+
@contextmanager
def silence_stdout() -> Iterator[None]:
    """
    Temporarily silence all output to stdout and from rich.print.

    This context manager replaces `sys.stdout` with a handle to the platform's
    null device (`os.devnull`: /dev/null on UNIX-like systems, NUL on Windows).
    Anything written to `sys.stdout` during the block is therefore discarded.
    That includes the built-in print function, and rich.print when it writes
    to the current `sys.stdout`. Once the block exits, normally or via an
    exception, stdout is restored and the null-device handle is closed.

    Example:
        with silence_stdout():
            print("This won't be printed")
            rich.print("This also won't be printed")

    Note:
        This suppresses both standard print functions and the rich library outputs.
    """
    import os  # local import: os.devnull names the null device portably

    # utf-8 so that printing non-ASCII text never raises UnicodeEncodeError
    # (e.g. on Windows, where the default locale encoding may be cp1252).
    with open(os.devnull, "w", encoding="utf-8") as devnull:
        original_stdout = sys.stdout
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = original_stdout
|
langroid/utils/pydantic_utils.py
CHANGED
@@ -79,7 +79,9 @@ def flatten_pydantic_model(
|
|
79
79
|
current_model, current_prefix = models_to_process.pop()
|
80
80
|
|
81
81
|
for name, field in current_model.__fields__.items():
|
82
|
-
if
|
82
|
+
if isinstance(field.outer_type_, type) and issubclass(
|
83
|
+
field.outer_type_, BaseModel
|
84
|
+
):
|
83
85
|
new_prefix = (
|
84
86
|
f"{current_prefix}{name}__" if current_prefix else f"{name}__"
|
85
87
|
)
|
langroid/vector_store/base.py
CHANGED
@@ -1,12 +1,16 @@
|
|
1
|
+
import copy
|
1
2
|
import logging
|
2
3
|
from abc import ABC, abstractmethod
|
3
|
-
from
|
4
|
+
from math import ceil
|
5
|
+
from typing import Dict, List, Optional, Sequence, Tuple
|
4
6
|
|
7
|
+
import numpy as np
|
5
8
|
from pydantic import BaseSettings
|
6
9
|
|
7
10
|
from langroid.embedding_models.base import EmbeddingModelsConfig
|
8
11
|
from langroid.embedding_models.models import OpenAIEmbeddingsConfig
|
9
12
|
from langroid.mytypes import Document
|
13
|
+
from langroid.utils.algorithms.graph import topological_sort
|
10
14
|
from langroid.utils.configuration import settings
|
11
15
|
from langroid.utils.output.printing import print_long_text
|
12
16
|
|
@@ -130,8 +134,160 @@ class VectorStore(ABC):
|
|
130
134
|
k: int = 1,
|
131
135
|
where: Optional[str] = None,
|
132
136
|
) -> List[Tuple[Document, float]]:
|
137
|
+
"""
|
138
|
+
Find k most similar texts to the given text, in terms of vector distance metric
|
139
|
+
(e.g., cosine similarity).
|
140
|
+
|
141
|
+
Args:
|
142
|
+
text (str): The text to find similar texts for.
|
143
|
+
k (int, optional): Number of similar texts to retrieve. Defaults to 1.
|
144
|
+
where (Optional[str], optional): Where clause to filter the search.
|
145
|
+
|
146
|
+
Returns:
|
147
|
+
List[Tuple[Document,float]]: List of (Document, score) tuples.
|
148
|
+
|
149
|
+
"""
|
133
150
|
pass
|
134
151
|
|
152
|
+
def add_context_window(
    self, docs_scores: List[Tuple[Document, float]], neighbors: int = 0
) -> List[Tuple[Document, float]]:
    """
    In each doc's metadata, there may be a window_ids field indicating
    the ids of the chunks around the current chunk.
    These window_ids may overlap, so we
    - gather connected-components of overlapping windows,
    - split each component into roughly equal parts,
    - create a new document for each part, preserving metadata.

    We may have stored a longer set of window_ids than we need.
    We just want `neighbors` on each side of the center of window_ids.

    Args:
        docs_scores (List[Tuple[Document, float]]): (document, match score)
            pairs to add context windows to.
        neighbors (int, optional): Number of neighbors on "each side" of match to
            retrieve. Defaults to 0, in which case the input is returned as-is.
            "Each side" here means before and after the match,
            in the original text.

    Returns:
        List[Tuple[Document, float]]: List of (Document, score) tuples.
            - Each windowed document is scored with the highest score among the
              input documents whose stored window contains one of its ids.
            - Input documents whose own id is not in their `window_ids` are
              returned unchanged, after the windowed documents.
            - If no input doc is a chunk, or none has a usable window,
              `docs_scores` is returned unchanged.
    """
    # We return a larger context around each match, i.e.
    # a window of `neighbors` on each side of the match.
    if neighbors == 0:
        return docs_scores
    if not any(d.metadata.is_chunk for d, _ in docs_scores):
        return docs_scores

    window_ids_list: List[List[str]] = []
    id2metadata = {}
    # id -> highest score of a doc whose stored window contains it
    id2max_score: Dict[int | str, float] = {}
    unwindowed: List[Tuple[Document, float]] = []
    for doc, score in docs_scores:
        window_ids = doc.metadata.window_ids or []
        doc_id = doc.id()
        if doc_id not in window_ids:
            # No window around this doc: list.index() below would raise,
            # so pass the doc through untouched instead.
            unwindowed.append((doc, score))
            continue
        for wid in window_ids:
            id2metadata[wid] = doc.metadata
            # -inf (not 0) as the start value so negative scores are kept
            id2max_score[wid] = max(id2max_score.get(wid, float("-inf")), score)
        chunk_idx = window_ids.index(doc_id)
        lo = max(0, chunk_idx - neighbors)
        hi = min(len(window_ids), chunk_idx + neighbors + 1)
        window_ids_list.append(window_ids[lo:hi])

    if not window_ids_list:
        return docs_scores

    # window_ids could be from different docs,
    # and they may overlap, so we first remove overlaps
    window_ids_list = self.remove_overlaps(window_ids_list)
    windowed: List[Tuple[Document, float]] = []
    for w in window_ids_list:
        metadata = copy.deepcopy(id2metadata[w[0]])
        metadata.window_ids = w
        document = Document(
            content=" ".join(d.content for d in self.get_documents_by_ids(w)),
            metadata=metadata,
        )
        # make a fresh id since content is in general different
        document.metadata.id = document.hash_id(document.content)
        windowed.append((document, max(id2max_score[wid] for wid in w)))
    return windowed + unwindowed
|
221
|
+
|
222
|
+
@staticmethod
def remove_overlaps(windows: List[List[str]]) -> List[List[str]]:
    """
    Given a collection of windows, where each window is a sequence of ids,
    identify groups of overlapping windows, and for each overlapping k-group,
    split the ids into k roughly equal sequences.

    Grouping and splitting work as follows:
    - Two windows are in the same group when a chain of pairwise overlaps
      links them (connected components of the overlap graph).
    - Within a group, windows are put in order and their ids are concatenated
      and de-duplicated, keeping the first occurrence.
    - The result is split into the fewest near-equal parts, each no longer
      than the longest input window.

    Args:
        windows (List[List[str]]): List of windows, where each window is a
            sequence of ids.

    Returns:
        List[List[str]]: List of windows, where each window is a sequence of ids,
            and no two windows overlap. Empty input gives an empty list.
    """
    if not windows:
        return []
    n = len(windows)

    # id -> {window index -> position of the id within that window}
    id2win2pos: Dict[str, Dict[int, int]] = {}
    for win_idx, window in enumerate(windows):
        for pos, id_ in enumerate(window):
            id2win2pos.setdefault(id_, {})[win_idx] = pos

    # relation between windows: order[i, j] = -1 if window i is before
    # window j, 1 if after, 0 if they share no id
    window_sets = [set(w) for w in windows]
    order = np.zeros((n, n), dtype=np.int8)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            common = window_sets[i] & window_sets[j]
            if not common:
                continue
            id_ = next(iter(common))  # any common id
            # shared id sits further into window i => window i comes first
            order[i, j] = -1 if id2win2pos[id_][i] > id2win2pos[id_][j] else 1

    # Connected components via union-find. A window that overlaps windows in
    # two different partial groups must merge those groups, otherwise the
    # same id can end up in two output windows.
    parent = list(range(n))

    def _root(x: int) -> int:
        # follow parent links to the representative, with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for i in range(n):
        for j in range(i + 1, n):
            if order[i, j] != 0:
                ri, rj = _root(i), _root(j)
                if ri != rj:
                    parent[max(ri, rj)] = min(ri, rj)

    groups: Dict[int, List[int]] = {}
    for i in range(n):
        groups.setdefault(_root(i), []).append(i)

    # split each group into roughly equal parts
    max_window_len = max(1, max(len(w) for w in windows))
    new_windows: List[List[str]] = []
    for members in groups.values():
        idx = np.array(members)
        # total ordering among windows in the group (topological sort)
        ordered = idx[topological_sort(order[idx][:, idx])]
        flattened = [id_ for w_idx in ordered for id_ in windows[w_idx]]
        deduped = list(dict.fromkeys(flattened))
        # split into k parts where k is the smallest integer such that
        # each part has length <= max_window_len
        k = max(1, ceil(len(deduped) / max_window_len))
        # Same part sizes as np.array_split (the first len % k parts get one
        # extra id), but ids keep their original Python type.
        base, extra = divmod(len(deduped), k)
        start = 0
        for part in range(k):
            size = base + (1 if part < extra else 0)
            new_windows.append(deduped[start : start + size])
            start += size
    return new_windows
|
290
|
+
|
135
291
|
@abstractmethod
|
136
292
|
def get_all_documents(self) -> List[Document]:
|
137
293
|
"""
|
@@ -109,14 +109,17 @@ class ChromaDB(VectorStore):
|
|
109
109
|
if documents is None:
|
110
110
|
return
|
111
111
|
contents: List[str] = [document.content for document in documents]
|
112
|
-
metadatas
|
113
|
-
|
114
|
-
|
112
|
+
# convert metadatas to dicts so chroma can handle them
|
113
|
+
metadata_dicts: List[dict[str, Any]] = [d.metadata.dict() for d in documents]
|
114
|
+
for m in metadata_dicts:
|
115
|
+
# chroma does not handle non-atomic types in metadata
|
116
|
+
m["window_ids"] = ",".join(m["window_ids"])
|
117
|
+
|
115
118
|
ids = [str(d.id()) for d in documents]
|
116
119
|
self.collection.add(
|
117
120
|
# embedding_models=embedding_models,
|
118
121
|
documents=contents,
|
119
|
-
metadatas=
|
122
|
+
metadatas=metadata_dicts,
|
120
123
|
ids=ids,
|
121
124
|
)
|
122
125
|
|
@@ -145,7 +148,8 @@ class ChromaDB(VectorStore):
|
|
145
148
|
include=["documents", "distances", "metadatas"],
|
146
149
|
)
|
147
150
|
docs = self._docs_from_results(results)
|
148
|
-
|
151
|
+
# chroma distances are 1 - cosine.
|
152
|
+
scores = [1 - s for s in results["distances"][0]]
|
149
153
|
return list(zip(docs, scores))
|
150
154
|
|
151
155
|
def _docs_from_results(self, results: Dict[str, Any]) -> List[Document]:
|
@@ -164,22 +168,11 @@ class ChromaDB(VectorStore):
|
|
164
168
|
for i, c in enumerate(contents):
|
165
169
|
print_long_text("red", "italic red", f"MATCH-{i}", c)
|
166
170
|
metadatas = results["metadatas"][0]
|
171
|
+
for m in metadatas:
|
172
|
+
# restore the stringified list of window_ids into the original List[str]
|
173
|
+
m["window_ids"] = m["window_ids"].split(",")
|
167
174
|
docs = [
|
168
175
|
Document(content=d, metadata=DocMetaData(**m))
|
169
176
|
for d, m in zip(contents, metadatas)
|
170
177
|
]
|
171
178
|
return docs
|
172
|
-
|
173
|
-
|
174
|
-
# Example usage and testing
|
175
|
-
# chroma_db = ChromaDB.from_documents(
|
176
|
-
# collection_name="all-my-documents",
|
177
|
-
# documents=["doc1000101", "doc288822"],
|
178
|
-
# metadatas=[{"style": "style1"}, {"style": "style2"}],
|
179
|
-
# ids=["uri9", "uri10"]
|
180
|
-
# )
|
181
|
-
# results = chroma_db.query(
|
182
|
-
# query_texts=["This is a query document"],
|
183
|
-
# n_results=2
|
184
|
-
# )
|
185
|
-
# print(results)
|
@@ -263,6 +263,7 @@ class MeiliSearch(VectorStore):
|
|
263
263
|
text: str,
|
264
264
|
k: int = 20,
|
265
265
|
where: Optional[str] = None,
|
266
|
+
neighbors: int = 0, # ignored
|
266
267
|
) -> List[Tuple[Document, float]]:
|
267
268
|
filter = [] if where is None else where
|
268
269
|
if self.config.collection_name is None:
|
langroid/vector_store/momento.py
CHANGED
@@ -222,6 +222,7 @@ class MomentoVI(VectorStore):
|
|
222
222
|
text: str,
|
223
223
|
k: int = 1,
|
224
224
|
where: Optional[str] = None,
|
225
|
+
neighbors: int = 0, # ignored
|
225
226
|
) -> List[Tuple[Document, float]]:
|
226
227
|
if self.config.collection_name is None:
|
227
228
|
raise ValueError("No collection name set, cannot search")
|
@@ -244,7 +244,11 @@ class QdrantDB(VectorStore):
|
|
244
244
|
with_vectors=False,
|
245
245
|
with_payload=True,
|
246
246
|
)
|
247
|
-
|
247
|
+
# Note the records may NOT be in the order of the ids,
|
248
|
+
# so we re-order them here.
|
249
|
+
id2payload = {record.id: record.payload for record in records}
|
250
|
+
ordered_payloads = [id2payload[id] for id in _ids]
|
251
|
+
docs = [Document(**payload) for payload in ordered_payloads] # type: ignore
|
248
252
|
return docs
|
249
253
|
|
250
254
|
def similar_texts_with_scores(
|
@@ -252,6 +256,7 @@ class QdrantDB(VectorStore):
|
|
252
256
|
text: str,
|
253
257
|
k: int = 1,
|
254
258
|
where: Optional[str] = None,
|
259
|
+
neighbors: int = 0,
|
255
260
|
) -> List[Tuple[Document, float]]:
|
256
261
|
embedding = self.embedding_fn([text])[0]
|
257
262
|
# TODO filter may not work yet
|
@@ -268,7 +273,7 @@ class QdrantDB(VectorStore):
|
|
268
273
|
exact=False, # use Apx NN, not exact NN
|
269
274
|
),
|
270
275
|
)
|
271
|
-
scores = [match.score for match in search_result]
|
276
|
+
scores = [match.score for match in search_result if match is not None]
|
272
277
|
docs = [
|
273
278
|
Document(**(match.payload)) # type: ignore
|
274
279
|
for match in search_result
|
@@ -277,8 +282,9 @@ class QdrantDB(VectorStore):
|
|
277
282
|
if len(docs) == 0:
|
278
283
|
logger.warning(f"No matches found for {text}")
|
279
284
|
return []
|
280
|
-
if settings.debug:
|
281
|
-
logger.info(f"Found {len(docs)} matches, max score: {max(scores)}")
|
282
285
|
doc_score_pairs = list(zip(docs, scores))
|
286
|
+
max_score = max(ds[1] for ds in doc_score_pairs)
|
287
|
+
if settings.debug:
|
288
|
+
logger.info(f"Found {len(doc_score_pairs)} matches, max score: {max_score}")
|
283
289
|
self.show_if_debug(doc_score_pairs)
|
284
290
|
return doc_score_pairs
|