langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
@@ -12,20 +12,30 @@ langroid with the [hf-embeddings] extra, e.g.:
 pip install "langroid[hf-embeddings]"

 """
+
 import logging
 from contextlib import ExitStack
-from
+from functools import cache
+from typing import Any, Dict, List, Optional, Set, Tuple, no_type_check

-
-
+import nest_asyncio
+import numpy as np
+import pandas as pd
+from rich.prompt import Prompt

-from langroid.agent.
+from langroid.agent.batch import run_batch_tasks
 from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
 from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
+from langroid.agent.special.relevance_extractor_agent import (
+    RelevanceExtractorAgent,
+    RelevanceExtractorAgentConfig,
+)
+from langroid.agent.task import Task
 from langroid.embedding_models.models import OpenAIEmbeddingsConfig
 from langroid.language_models.base import StreamingIfAllowed
 from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
 from langroid.mytypes import DocMetaData, Document, Entity
+from langroid.parsing.document_parser import DocumentType
 from langroid.parsing.parser import Parser, ParsingConfig, PdfParsingConfig, Splitter
 from langroid.parsing.repo_loader import RepoLoader
 from langroid.parsing.search import (
@@ -33,20 +43,26 @@ from langroid.parsing.search import (
     find_fuzzy_matches_in_docs,
     preprocess_text,
 )
+from langroid.parsing.table_loader import describe_dataframe
 from langroid.parsing.url_loader import URLLoader
-from langroid.parsing.urls import
+from langroid.parsing.urls import get_list_from_user, get_urls_paths_bytes_indices
 from langroid.parsing.utils import batched
 from langroid.prompts.prompts_config import PromptsConfig
 from langroid.prompts.templates import SUMMARY_ANSWER_PROMPT_GPT4
 from langroid.utils.configuration import settings
 from langroid.utils.constants import NO_ANSWER
-from langroid.utils.output
-from langroid.
-from langroid.vector_store.
+from langroid.utils.output import show_if_debug, status
+from langroid.utils.pydantic_utils import dataframe_to_documents, extract_fields
+from langroid.vector_store.base import VectorStore, VectorStoreConfig
+from langroid.vector_store.lancedb import LanceDBConfig

-logger = logging.getLogger(__name__)

-
+@cache
+def apply_nest_asyncio() -> None:
+    nest_asyncio.apply()
+
+
+logger = logging.getLogger(__name__)

 DEFAULT_DOC_CHAT_INSTRUCTIONS = """
 Your task is to answer questions about various documents.
@@ -58,25 +74,29 @@ DEFAULT_DOC_CHAT_SYSTEM_MESSAGE = """
 You are a helpful assistant, helping me understand a collection of documents.
 """

+has_sentence_transformers = False
+try:
+    from sentence_transformer import SentenceTransformer  # noqa: F401

-
-
-
-    max_context_tokens (int): threshold to use for various steps, e.g.
-        if we are able to fit the current stage of doc processing into
-        this many tokens, we skip additional compression steps, and
-        use the current docs as-is in the context
-    conversation_mode (bool): if True, we will accumulate message history,
-        and pass entire history to LLM at each round.
-        If False, each request to LLM will consist only of the
-        initial task messages plus the current query.
-    """
+    has_sentence_transformers = True
+except ImportError:
+    pass

+
+class DocChatAgentConfig(ChatAgentConfig):
     system_message: str = DEFAULT_DOC_CHAT_SYSTEM_MESSAGE
     user_message: str = DEFAULT_DOC_CHAT_INSTRUCTIONS
     summarize_prompt: str = SUMMARY_ANSWER_PROMPT_GPT4
-
-
+    # extra fields to include in content as key=value pairs
+    # (helps retrieval for table-like data)
+    add_fields_to_content: List[str] = []
+    filter_fields: List[str] = []  # fields usable in filter
+    retrieve_only: bool = False  # only retr relevant extracts, don't gen summary answer
+    extraction_granularity: int = 1  # granularity (in sentences) for relev extraction
+    filter: str | None = (
+        None  # filter condition for various lexical/semantic search fns
+    )
+    conversation_mode: bool = True  # accumulate message history?
     # In assistant mode, DocChatAgent receives questions from another Agent,
     # and those will already be in stand-alone form, so in this mode
     # there is no need to convert them to stand-alone form.
@@ -88,14 +108,26 @@ class DocChatAgentConfig(ChatAgentConfig):
     # It is False by default; its benefits depends on the context.
     hypothetical_answer: bool = False
     n_query_rephrases: int = 0
+    n_neighbor_chunks: int = 0  # how many neighbors on either side of match to retrieve
+    n_fuzzy_neighbor_words: int = 100  # num neighbor words to retrieve for fuzzy match
     use_fuzzy_match: bool = True
     use_bm25_search: bool = True
-    cross_encoder_reranking_model: str =
+    cross_encoder_reranking_model: str = (
+        "cross-encoder/ms-marco-MiniLM-L-6-v2" if has_sentence_transformers else ""
+    )
+    rerank_diversity: bool = True  # rerank to maximize diversity?
+    rerank_periphery: bool = True  # rerank to avoid Lost In the Middle effect?
     embed_batch_size: int = 500  # get embedding of at most this many at a time
     cache: bool = True  # cache results
     debug: bool = False
     stream: bool = True  # allow streaming where needed
-
+    split: bool = True  # use chunking
+    relevance_extractor_config: None | RelevanceExtractorAgentConfig = (
+        RelevanceExtractorAgentConfig(
+            llm=None  # use the parent's llm unless explicitly set here
+        )
+    )
+    doc_paths: List[str | bytes] = []
     default_paths: List[str] = [
         "https://news.ycombinator.com/item?id=35629033",
         "https://www.newyorker.com/tech/annals-of-technology/chatgpt-is-a-blurry-jpeg-of-the-web",
@@ -115,11 +147,12 @@ class DocChatAgentConfig(ChatAgentConfig):
         min_chunk_chars=200,
         discard_chunk_chars=5,  # discard chunks with fewer than this many chars
         n_similar_docs=3,
+        n_neighbor_ids=0,  # num chunk IDs to store on either side of each chunk
         pdf=PdfParsingConfig(
             # NOTE: PDF parsing is extremely challenging, and each library
             # has its own strengths and weaknesses.
             # Try one that works for your use case.
-            # or "
+            # or "unstructured", "pdfplumber", "fitz", "pypdf"
             library="pdfplumber",
         ),
     )
@@ -136,10 +169,11 @@ class DocChatAgentConfig(ChatAgentConfig):
         dims=1536,
     )

-    vecdb: VectorStoreConfig =
-        collection_name=
-
-
+    vecdb: VectorStoreConfig = LanceDBConfig(
+        collection_name="doc-chat-lancedb",
+        replace_collection=True,
+        storage_path=".lancedb/data/",
+        embedding=hf_embed_config if has_sentence_transformers else oai_embed_config,
     )
     llm: OpenAIGPTConfig = OpenAIGPTConfig(
         type="openai",
@@ -163,14 +197,40 @@ class DocChatAgent(ChatAgent):
     ):
         super().__init__(config)
         self.config: DocChatAgentConfig = config
-        self.original_docs:
+        self.original_docs: List[Document] = []
         self.original_docs_length = 0
-        self.
-        self.
+        self.from_dataframe = False
+        self.df_description = ""
+        self.chunked_docs: List[Document] = []
+        self.chunked_docs_clean: List[Document] = []
         self.response: None | Document = None
         if len(config.doc_paths) > 0:
             self.ingest()

+    def clear(self) -> None:
+        """Clear the document collection and the specific collection in vecdb"""
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        self.original_docs = []
+        self.original_docs_length = 0
+        self.chunked_docs = []
+        self.chunked_docs_clean = []
+        collection_name = self.vecdb.config.collection_name
+        if collection_name is None:
+            return
+        try:
+            # Note we may have used a vecdb with a config.collection_name
+            # different from the agent's config.vecdb.collection_name!!
+            self.vecdb.delete_collection(collection_name)
+            self.vecdb = VectorStore.create(self.vecdb.config)
+        except Exception as e:
+            logger.warning(
+                f"""
+                Error while deleting collection {collection_name}:
+                {e}
+                """
+            )
+
     def ingest(self) -> None:
         """
         Chunk + embed + store docs specified by self.config.doc_paths
@@ -187,59 +247,316 @@ class DocChatAgent(ChatAgent):
         # do keyword and other non-vector searches
         if self.vecdb is None:
             raise ValueError("VecDB not set")
-            self.
-            self.chunked_docs_clean = [
-                Document(content=preprocess_text(d.content), metadata=d.metadata)
-                for d in self.chunked_docs
-            ]
+            self.setup_documents(filter=self.config.filter)
             return
-
+        self.ingest_doc_paths(self.config.doc_paths)  # type: ignore
+
+    def ingest_doc_paths(
+        self,
+        paths: str | bytes | List[str | bytes],
+        metadata: (
+            List[Dict[str, Any]] | Dict[str, Any] | DocMetaData | List[DocMetaData]
+        ) = [],
+        doc_type: str | DocumentType | None = None,
+    ) -> List[Document]:
+        """Split, ingest docs from specified paths,
+        do not add these to config.doc_paths.
+
+        Args:
+            paths: document paths, urls or byte-content of docs.
+                The bytes option is intended to support cases where a document
+                has already been read in as bytes (e.g. from an API or a database),
+                and we want to avoid having to write it to a temporary file
+                just to read it back in.
+            metadata: List of metadata dicts, one for each path.
+                If a single dict is passed in, it is used for all paths.
+            doc_type: DocumentType to use for parsing, if known.
+                MUST apply to all docs if specified.
+                This is especially useful when the `paths` are of bytes type,
+                to help with document type detection.
+        Returns:
+            List of Document objects
+        """
+        if isinstance(paths, str) or isinstance(paths, bytes):
+            paths = [paths]
+        all_paths = paths
+        paths_meta: Dict[int, Any] = {}
+        urls_meta: Dict[int, Any] = {}
+        idxs = range(len(all_paths))
+        url_idxs, path_idxs, bytes_idxs = get_urls_paths_bytes_indices(all_paths)
+        urls = [all_paths[i] for i in url_idxs]
+        paths = [all_paths[i] for i in path_idxs]
+        bytes_list = [all_paths[i] for i in bytes_idxs]
+        path_idxs.extend(bytes_idxs)
+        paths.extend(bytes_list)
+        if (isinstance(metadata, list) and len(metadata) > 0) or not isinstance(
+            metadata, list
+        ):
+            if isinstance(metadata, list):
+                idx2meta = {
+                    p: (
+                        m
+                        if isinstance(m, dict)
+                        else (isinstance(m, DocMetaData) and m.dict())
+                    )  # appease mypy
+                    for p, m in zip(idxs, metadata)
+                }
+            elif isinstance(metadata, dict):
+                idx2meta = {p: metadata for p in idxs}
+            else:
+                idx2meta = {p: metadata.dict() for p in idxs}
+            urls_meta = {u: idx2meta[u] for u in url_idxs}
+            paths_meta = {p: idx2meta[p] for p in path_idxs}
         docs: List[Document] = []
         parser = Parser(self.config.parsing)
         if len(urls) > 0:
-
-
-
-
-
+            for ui in url_idxs:
+                meta = urls_meta.get(ui, {})
+                loader = URLLoader(urls=[all_paths[ui]], parser=parser)  # type: ignore
+                url_docs = loader.load()
+                # update metadata of each doc with meta
+                for d in url_docs:
+                    d.metadata = d.metadata.copy(update=meta)
+                docs.extend(url_docs)
+        if len(paths) > 0:  # paths OR bytes are handled similarly
+            for pi in path_idxs:
+                meta = paths_meta.get(pi, {})
+                p = all_paths[pi]
+                path_docs = RepoLoader.get_documents(
+                    p,
+                    parser=parser,
+                    doc_type=doc_type,
+                )
+                # update metadata of each doc with meta
+                for d in path_docs:
+                    d.metadata = d.metadata.copy(update=meta)
                 docs.extend(path_docs)
         n_docs = len(docs)
-        n_splits = self.ingest_docs(docs)
+        n_splits = self.ingest_docs(docs, split=self.config.split)
         if n_docs == 0:
-            return
+            return []
         n_urls = len(urls)
         n_paths = len(paths)
         print(
             f"""
         [green]I have processed the following {n_urls} URLs
-        and {n_paths}
+        and {n_paths} docs into {n_splits} parts:
         """.strip()
         )
-
-        print("\n".join(
+        path_reps = [p if isinstance(p, str) else "bytes" for p in paths]
+        print("\n".join([u for u in urls if isinstance(u, str)]))  # appease mypy
+        print("\n".join(path_reps))
+        return docs

-    def ingest_docs(
+    def ingest_docs(
+        self,
+        docs: List[Document],
+        split: bool = True,
+        metadata: (
+            List[Dict[str, Any]] | Dict[str, Any] | DocMetaData | List[DocMetaData]
+        ) = [],
+    ) -> int:
         """
         Chunk docs into pieces, map each chunk to vec-embedding, store in vec-db
+
+        Args:
+            docs: List of Document objects
+            split: Whether to split docs into chunks. Default is True.
+                If False, docs are treated as "chunks" and are not split.
+            metadata: List of metadata dicts, one for each doc, to augment
+                whatever metadata is already in the doc.
+                [ASSUME no conflicting keys between the two metadata dicts.]
+                If a single dict is passed in, it is used for all docs.
         """
-
+        if isinstance(metadata, list) and len(metadata) > 0:
+            for d, m in zip(docs, metadata):
+                d.metadata = d.metadata.copy(
+                    update=m if isinstance(m, dict) else m.dict()  # type: ignore
+                )
+        elif isinstance(metadata, dict):
+            for d in docs:
+                d.metadata = d.metadata.copy(update=metadata)
+        elif isinstance(metadata, DocMetaData):
+            for d in docs:
+                d.metadata = d.metadata.copy(update=metadata.dict())
+
+        self.original_docs.extend(docs)
         if self.parser is None:
             raise ValueError("Parser not set")
-
-
-
-
-
-
+        for d in docs:
+            if d.metadata.id in [None, ""]:
+                d.metadata.id = d._unique_hash_id()
+        if split:
+            docs = self.parser.split(docs)
+        else:
+            # treat each doc as a chunk
+            for d in docs:
+                d.metadata.is_chunk = True
         if self.vecdb is None:
             raise ValueError("VecDB not set")
+
+        # If any additional fields need to be added to content,
+        # add them as key=value pairs for all docs, before batching.
+        # This helps retrieval for table-like data.
+        # Note we need to do this at stage so that the embeddings
+        # are computed on the full content with these additional fields.
+        if len(self.config.add_fields_to_content) > 0:
+            fields = [
+                f for f in extract_fields(docs[0], self.config.add_fields_to_content)
+            ]
+            if len(fields) > 0:
+                for d in docs:
+                    key_vals = extract_fields(d, fields)
+                    d.content = (
+                        ",".join(f"{k}={v}" for k, v in key_vals.items())
+                        + ",content="
+                        + d.content
+                    )
+        docs = docs[: self.config.parsing.max_chunks]
         # add embeddings in batches, to stay under limit of embeddings API
         batches = list(batched(docs, self.config.embed_batch_size))
         for batch in batches:
             self.vecdb.add_documents(batch)
         self.original_docs_length = self.doc_length(docs)
+        self.setup_documents(docs, filter=self.config.filter)
         return len(docs)

+    @staticmethod
+    def document_compatible_dataframe(
+        df: pd.DataFrame,
+        content: str = "content",
+        metadata: List[str] = [],
+    ) -> Tuple[pd.DataFrame, List[str]]:
+        """
+        Convert dataframe so it is compatible with Document class:
+        - has "content" column
+        - has an "id" column to be used as Document.metadata.id
+
+        Args:
+            df: dataframe to convert
+            content: name of content column
+            metadata: list of metadata column names
+
+        Returns:
+            Tuple[pd.DataFrame, List[str]]: dataframe, metadata
+                - dataframe: dataframe with "content" column and "id" column
+                - metadata: list of metadata column names, including "id"
+        """
+        if content not in df.columns:
+            raise ValueError(
+                f"""
+                Content column {content} not in dataframe,
+                so we cannot ingest into the DocChatAgent.
+                Please specify the `content` parameter as a suitable
+                text-based column in the dataframe.
+                """
+            )
+        if content != "content":
+            # rename content column to "content", leave existing column intact
+            df = df.rename(columns={content: "content"}, inplace=False)
+
+        actual_metadata = metadata.copy()
+        if "id" not in df.columns:
+            docs = dataframe_to_documents(df, content="content", metadata=metadata)
+            ids = [str(d.id()) for d in docs]
+            df["id"] = ids
+
+        if "id" not in actual_metadata:
+            actual_metadata += ["id"]
+
+        return df, actual_metadata
+
+    def ingest_dataframe(
+        self,
+        df: pd.DataFrame,
+        content: str = "content",
+        metadata: List[str] = [],
+    ) -> int:
+        """
+        Ingest a dataframe into vecdb.
+        """
+        self.from_dataframe = True
+        self.df_description = describe_dataframe(
+            df, filter_fields=self.config.filter_fields, n_vals=5
+        )
+        df, metadata = DocChatAgent.document_compatible_dataframe(df, content, metadata)
+        docs = dataframe_to_documents(df, content="content", metadata=metadata)
+        # When ingesting a dataframe we will no longer do any chunking,
+        # so we mark each doc as a chunk.
+        # TODO - revisit this since we may still want to chunk large text columns
+        for d in docs:
+            d.metadata.is_chunk = True
+        return self.ingest_docs(docs)
+
+    def set_filter(self, filter: str) -> None:
+        self.config.filter = filter
+        self.setup_documents(filter=filter)
+
+    def setup_documents(
+        self,
+        docs: List[Document] = [],
+        filter: str | None = None,
+    ) -> None:
+        """
+        Setup `self.chunked_docs` and `self.chunked_docs_clean`
+        based on possible filter.
+        These will be used in various non-vector-based search functions,
+        e.g. self.get_similar_chunks_bm25(), self.get_fuzzy_matches(), etc.
+
+        Args:
+            docs: List of Document objects. This is empty when we are calling this
+                method after initial doc ingestion.
+            filter: Filter condition for various lexical/semantic search fns.
+        """
+        if filter is None and len(docs) > 0:
+            # no filter, so just use the docs passed in
+            self.chunked_docs.extend(docs)
+        else:
+            if self.vecdb is None:
+                raise ValueError("VecDB not set")
+            self.chunked_docs = self.vecdb.get_all_documents(where=filter or "")
+
+        self.chunked_docs_clean = [
+            Document(content=preprocess_text(d.content), metadata=d.metadata)
+            for d in self.chunked_docs
+        ]
+
+    def get_field_values(self, fields: list[str]) -> Dict[str, str]:
+        """Get string-listing of possible values of each filterable field,
+        e.g.
+        {
+            "genre": "crime, drama, mystery, ... (10 more)",
+            "certificate": "R, PG-13, PG, R",
+        }
+        """
+        field_values: Dict[str, Set[str]] = {}
+        # make empty set for each field
+        for f in fields:
+            field_values[f] = set()
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        # get all documents and accumulate possible values of each field until 10
+        docs = self.vecdb.get_all_documents()  # only works for vecdbs that support this
+        for d in docs:
+            # extract fields from d
+            doc_field_vals = extract_fields(d, fields)
+            for field, val in doc_field_vals.items():
+                field_values[field].add(val)
+        # For each field make a string showing list of possible values,
+        # truncate to 20 values, and if there are more, indicate how many
+        # more there are, e.g. Genre: crime, drama, mystery, ... (20 more)
+        field_values_list = {}
+        for f in fields:
+            vals = list(field_values[f])
+            n = len(vals)
+            remaining = n - 20
+            vals = vals[:20]
+            if n > 20:
+                vals.append(f"(...{remaining} more)")
+            # make a string of the values, ensure they are strings
+            field_values_list[f] = ", ".join(str(v) for v in vals)
+        return field_values_list
+
     def doc_length(self, docs: List[Document]) -> int:
         """
         Calc token-length of a list of docs
@@ -252,7 +569,78 @@ class DocChatAgent(ChatAgent):
             raise ValueError("Parser not set")
         return self.parser.num_tokens(self.doc_string(docs))

-
+    def user_docs_ingest_dialog(self) -> None:
+        """
+        Ask user to select doc-collection, enter filenames/urls, and ingest into vecdb.
+        """
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        n_deletes = self.vecdb.clear_empty_collections()
+        collections = self.vecdb.list_collections()
+        collection_name = "NEW"
+        is_new_collection = False
+        replace_collection = False
+        if len(collections) > 0:
+            n = len(collections)
+            delete_str = (
+                f"(deleted {n_deletes} empty collections)" if n_deletes > 0 else ""
+            )
+            print(f"Found {n} collections: {delete_str}")
+            for i, option in enumerate(collections, start=1):
+                print(f"{i}. {option}")
+            while True:
+                choice = Prompt.ask(
+                    f"Enter 1-{n} to select a collection, "
+                    "or hit ENTER to create a NEW collection, "
+                    "or -1 to DELETE ALL COLLECTIONS",
+                    default="0",
+                )
+                try:
+                    if -1 <= int(choice) <= n:
+                        break
+                except Exception:
+                    pass
+
+            if choice == "-1":
+                confirm = Prompt.ask(
+                    "Are you sure you want to delete all collections?",
+                    choices=["y", "n"],
+                    default="n",
+                )
+                if confirm == "y":
+                    self.vecdb.clear_all_collections(really=True)
+                    collection_name = "NEW"
+
+            if int(choice) > 0:
+                collection_name = collections[int(choice) - 1]
+                print(f"Using collection {collection_name}")
+                choice = Prompt.ask(
+                    "Would you like to replace this collection?",
+                    choices=["y", "n"],
+                    default="n",
+                )
+                replace_collection = choice == "y"
+
+        if collection_name == "NEW":
+            is_new_collection = True
+            collection_name = Prompt.ask(
+                "What would you like to name the NEW collection?",
+                default="doc-chat",
+            )
+
+        self.vecdb.set_collection(collection_name, replace=replace_collection)
+
+        default_urls_str = (
+            " (or leave empty for default URLs)" if is_new_collection else ""
+        )
+        print(f"[blue]Enter some URLs or file/dir paths below {default_urls_str}")
+        inputs = get_list_from_user()
+        if len(inputs) == 0:
+            if is_new_collection:
+                inputs = self.config.default_paths
+        self.config.doc_paths = inputs  # type: ignore
+        self.ingest()
+
     def llm_response(
         self,
         query: None | str | ChatDocument = None,
@@ -269,10 +657,55 @@ class DocChatAgent(ChatAgent):
             query_str = query_str[1:] if query_str is not None else None
             if self.llm is None:
                 raise ValueError("LLM not set")
-            with StreamingIfAllowed(self.llm):
+            with StreamingIfAllowed(self.llm, self.llm.get_stream()):
                 response = super().llm_response(query_str)
             if query_str is not None:
-                self.update_dialog(
+                self.update_dialog(
+                    query_str, "" if response is None else response.content
+                )
+            return response
+        if query_str == "":
+            return None
+        elif query_str == "?" and self.response is not None:
+            return self.justify_response()
+        elif (query_str.startswith(("summar", "?")) and self.response is None) or (
+            query_str == "??"
+        ):
+            return self.summarize_docs()
+        else:
+            self.callbacks.show_start_response(entity="llm")
+            response = self.answer_from_docs(query_str)
+            return ChatDocument(
+                content=response.content,
+                metadata=ChatDocMetaData(
+                    source=response.metadata.source,
+                    sender=Entity.LLM,
+                ),
+            )
+
+    async def llm_response_async(
+        self,
+        query: None | str | ChatDocument = None,
+    ) -> Optional[ChatDocument]:
+        apply_nest_asyncio()
+        if not self.llm_can_respond(query):
+            return None
+        query_str: str | None
+        if isinstance(query, ChatDocument):
+            query_str = query.content
+        else:
+            query_str = query
+        if query_str is None or query_str.startswith("!"):
+            # direct query to LLM
+            query_str = query_str[1:] if query_str is not None else None
+            if self.llm is None:
+                raise ValueError("LLM not set")
+            with StreamingIfAllowed(self.llm, self.llm.get_stream()):
+                response = await super().llm_response_async(query_str)
+            if query_str is not None:
+                self.update_dialog(
+                    query_str, "" if response is None else response.content
+                )
             return response
         if query_str == "":
             return None
@@ -283,6 +716,7 @@ class DocChatAgent(ChatAgent):
         ):
             return self.summarize_docs()
         else:
+            self.callbacks.show_start_response(entity="llm")
             response = self.answer_from_docs(query_str)
             return ChatDocument(
                 content=response.content,
@@ -314,7 +748,9 @@ class DocChatAgent(ChatAgent):
             ]
         )

-    def get_summary_answer(
+    def get_summary_answer(
+        self, question: str, passages: List[Document]
+    ) -> ChatDocument:
         """
         Given a question and a list of (possibly) doc snippets,
         generate an answer if possible
@@ -342,9 +778,6 @@ class DocChatAgent(ChatAgent):
         # 2 new LLMMessage objects:
         # one for `final_prompt`, and one for the LLM response

-        # TODO need to "forget" last two messages in message_history
-        # if we are not in conversation mode
-
         if self.config.conversation_mode:
             # respond with temporary context
             answer_doc = super()._llm_response_temp_context(question, final_prompt)
@@ -353,16 +786,23 @@ class DocChatAgent(ChatAgent):

         final_answer = answer_doc.content.strip()
         show_if_debug(final_answer, "SUMMARIZE_RESPONSE= ")
-
-        if
-
-
-        else:
+
+        if final_answer.startswith("SOURCE"):
+            # sometimes SOURCE may be shown first,
+            # in this case just use final_answer as-is for both content and source
             content = final_answer
-            sources =
-
+            sources = final_answer
+        else:
+            parts = final_answer.split("SOURCE:", maxsplit=1)
+            if len(parts) > 1:
+                content = parts[0].strip()
+                sources = parts[1].strip()
+            else:
+                content = final_answer
+                sources = ""
+        return ChatDocument(
             content=content,
-            metadata=
+            metadata=ChatDocMetaData(
                 source="SOURCE: " + sources,
                 sender=Entity.LLM,
                 cached=getattr(answer_doc.metadata, "cached", False),
@@ -372,7 +812,7 @@ class DocChatAgent(ChatAgent):
     def llm_hypothetical_answer(self, query: str) -> str:
         if self.llm is None:
             raise ValueError("LLM not set")
-        with
+        with status("[cyan]LLM generating hypothetical answer..."):
             with StreamingIfAllowed(self.llm, False):
                 # TODO: provide an easy way to
                 # Adjust this prompt depending on context.
@@ -392,7 +832,7 @@ class DocChatAgent(ChatAgent):
    def llm_rephrase_query(self, query: str) -> List[str]:
         if self.llm is None:
             raise ValueError("LLM not set")
-        with
+        with status("[cyan]LLM generating rephrases of query..."):
             with StreamingIfAllowed(self.llm, False):
                 rephrases = self.llm_response_forget(
                     f"""
@@ -408,11 +848,13 @@ class DocChatAgent(ChatAgent):
     ) -> List[Tuple[Document, float]]:
         # find similar docs using bm25 similarity:
         # these may sometimes be more likely to contain a relevant verbatim extract
-        with
-            if self.chunked_docs is None:
-
-
-
+        with status("[cyan]Searching for similar chunks using bm25..."):
+            if self.chunked_docs is None or len(self.chunked_docs) == 0:
+                logger.warning("No chunked docs; cannot use bm25-similarity")
+                return []
+            if self.chunked_docs_clean is None or len(self.chunked_docs_clean) == 0:
+                logger.warning("No cleaned chunked docs; cannot use bm25-similarity")
+                return []
             docs_scores = find_closest_matches_with_bm25(
                 self.chunked_docs,
                 self.chunked_docs_clean,  # already pre-processed!
@@ -424,24 +866,27 @@ class DocChatAgent(ChatAgent):
     def get_fuzzy_matches(self, query: str, multiple: int) -> List[Document]:
         # find similar docs using fuzzy matching:
         # these may sometimes be more likely to contain a relevant verbatim extract
-        with
+        with status("[cyan]Finding fuzzy matches in chunks..."):
             if self.chunked_docs is None:
-
+                logger.warning("No chunked docs; cannot use fuzzy matching")
+                return []
+            if self.chunked_docs_clean is None:
+                logger.warning("No cleaned chunked docs; cannot use fuzzy-search")
+                return []
             fuzzy_match_docs = find_fuzzy_matches_in_docs(
                 query,
                 self.chunked_docs,
+                self.chunked_docs_clean,
                 k=self.config.parsing.n_similar_docs * multiple,
-                words_before=
-                words_after=
+                words_before=self.config.n_fuzzy_neighbor_words,
+                words_after=self.config.n_fuzzy_neighbor_words,
             )
             return fuzzy_match_docs

     def rerank_with_cross_encoder(
         self, query: str, passages: List[Document]
     ) -> List[Document]:
-        with
-            if self.chunked_docs is None:
-                raise ValueError("No chunked docs")
+        with status("[cyan]Re-ranking retrieved chunks using cross-encoder..."):
             try:
                 from sentence_transformers import CrossEncoder
             except ImportError:
@@ -455,6 +900,8 @@ class DocChatAgent(ChatAgent):

             model = CrossEncoder(self.config.cross_encoder_reranking_model)
             scores = model.predict([(query, p.content) for p in passages])
+            # Convert to [0,1] so we might could use a cutoff later.
+            scores = 1.0 / (1 + np.exp(-np.array(scores)))
             # get top k scoring passages
             sorted_pairs = sorted(
                 zip(scores, passages),
@@ -466,66 +913,187 @@ class DocChatAgent(ChatAgent):
             ]
         return passages

-
-    def get_relevant_extracts(self, query: str) -> Tuple[str, List[Document]]:
+    def rerank_with_diversity(self, passages: List[Document]) -> List[Document]:
         """
-
-
-        - use LLM to convert query to stand-alone query
-        - optionally rephrase query to use below
-        - optionally generate hypothetical answer (HyDE) to use below.
-        - get relevant doc-chunks, via:
-        - vector-embedding distance, from vecdb
-        - bm25-ranking (keyword similarity)
-        - fuzzy matching (keyword similarity)
-        - re-ranking of doc-chunks using cross-encoder, pick top k
-        - use LLM to get relevant extracts from doc-chunks
+        Rerank a list of items in such a way that each successive item is least similar
+        (on average) to the earlier items.

         Args:
-
+            query (str): The query for which the passages are relevant.
+            passages (List[Document]): A list of Documents to be reranked.

         Returns:
-
-
+            List[Documents]: A reranked list of Documents.
+        """

+        if self.vecdb is None:
+            logger.warning("No vecdb; cannot use rerank_with_diversity")
+            return passages
+        emb_model = self.vecdb.embedding_model
+        emb_fn = emb_model.embedding_fn()
+        embs = emb_fn([p.content for p in passages])
+        embs_arr = [np.array(e) for e in embs]
+        indices = list(range(len(passages)))
+
+        # Helper function to compute average similarity to
+        # items in the current result list.
+        def avg_similarity_to_result(i: int, result: List[int]) -> float:
+            return sum(  # type: ignore
+                (embs_arr[i] @ embs_arr[j])
+                / (np.linalg.norm(embs_arr[i]) * np.linalg.norm(embs_arr[j]))
+                for j in result
+            ) / len(result)
+
+        # copy passages to items
+        result = [indices.pop(0)]  # Start with the first item.
+
+        while indices:
+            # Find the item that has the least average similarity
+            # to items in the result list.
+            least_similar_item = min(
+                indices, key=lambda i: avg_similarity_to_result(i, result)
+            )
+            result.append(least_similar_item)
+            indices.remove(least_similar_item)
+
+        # return passages in order of result list
+        return [passages[i] for i in result]
+
+    def rerank_to_periphery(self, passages: List[Document]) -> List[Document]:
         """
-
-
-
-
-
-            with StreamingIfAllowed(self.llm, False):
-                query = self.llm.followup_to_standalone(self.dialog, query)
-            print(f"[orange2]New query: {query}")
+        Rerank to avoid Lost In the Middle (LIM) problem,
+        where LLMs pay more attention to items at the ends of a list,
+        rather than the middle. So we re-rank to make the best passages
+        appear at the periphery of the list.
+        https://arxiv.org/abs/2307.03172

+        Example reranking:
+        1 2 3 4 5 6 7 8 9 ==> 1 3 5 7 9 8 6 4 2
+
+        Args:
+            passages (List[Document]): A list of Documents to be reranked.
+
+        Returns:
+            List[Documents]: A reranked list of Documents.
+
+        """
+        # Splitting items into odds and evens based on index, not value
+        odds = passages[::2]
+        evens = passages[1::2][::-1]
+
+        # Merging them back together
+        return odds + evens
+
+    def add_context_window(
+        self,
+        docs_scores: List[Tuple[Document, float]],
+    ) -> List[Tuple[Document, float]]:
+        """
+        In each doc's metadata, there may be a window_ids field indicating
+        the ids of the chunks around the current chunk. We use these stored
+        window_ids to retrieve the desired number
+        (self.config.n_neighbor_chunks) of neighbors
+        on either side of the current chunk.
+
+        Args:
+            docs_scores (List[Tuple[Document, float]]): List of pairs of documents
+                to add context windows to together with their match scores.
+
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+        """
+        if self.vecdb is None or self.config.n_neighbor_chunks == 0:
+            return docs_scores
+        if len(docs_scores) == 0:
+            return []
+        if set(docs_scores[0][0].__fields__) != {"content", "metadata"}:
+            # Do not add context window when there are other fields besides just
+            # content and metadata, since we do not know how to set those other fields
+            # for newly created docs with combined content.
+            return docs_scores
+        return self.vecdb.add_context_window(docs_scores, self.config.n_neighbor_chunks)
+
+    def get_semantic_search_results(
+        self,
+        query: str,
+        k: int = 10,
+    ) -> List[Tuple[Document, float]]:
+        """
+        Get semantic search results from vecdb.
+        Args:
+            query (str): query to search for
+            k (int): number of results to return
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+        """
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        # Note: for dynamic filtering based on a query, users can
+        # use the `temp_update` context-manager to pass in a `filter` to self.config,
+        # e.g.:
+        # with temp_update(self.config, {"filter": "metadata.source=='source1'"}):
+        #     docs_scores = self.get_semantic_search_results(query, k=k)
+        # This avoids having pass the `filter` argument to every function call
+        # upstream of this one.
+        # The `temp_update` context manager is defined in
+        # `langroid/utils/pydantic_utils.py`
+        return self.vecdb.similar_texts_with_scores(
+            query,
+            k=k,
+            where=self.config.filter,
+        )
+
+    def get_relevant_chunks(
+        self, query: str, query_proxies: List[str] = []
+    ) -> List[Document]:
+        """
+        The retrieval stage in RAG: get doc-chunks that are most "relevant"
+        to the query (and possibly any proxy queries), from the document-store,
+        which currently is the vector store,
+        but in theory could be any document store, or even web-search.
+        This stage does NOT involve an LLM, and the retrieved chunks
+        could either be pre-chunked text (from the initial pre-processing stage
+        where chunks were stored in the vector store), or they could be
+        dynamically retrieved based on a window around a lexical match.
+
+        These are the steps (some optional based on config):
+        - semantic search based on vector-embedding distance, from vecdb
+        - lexical search using bm25-ranking (keyword similarity)
+        - fuzzy matching (keyword similarity)
+        - re-ranking of doc-chunks by relevance to query, using cross-encoder,
+            and pick top k
+
+        Args:
+            query: original query (assumed to be in stand-alone form)
+            query_proxies: possible rephrases, or hypothetical answer to query
+                (e.g. for HyDE-type retrieval)
+
+        Returns:
+
+        """
         # if we are using cross-encoder reranking, we can retrieve more docs
         # during retrieval, and leave it to the cross-encoder re-ranking
         # to whittle down to self.config.parsing.n_similar_docs
         retrieval_multiple = 1 if self.config.cross_encoder_reranking_model == "" else 3
-        queries = [query]
-        if self.config.hypothetical_answer:
-            answer = self.llm_hypothetical_answer(query)
-            queries = [query, answer]

-        if self.
-
-            queries += rephrases
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")

-        with
-            docs_and_scores = []
-            for q in
-                docs_and_scores += self.
+        with status("[cyan]Searching VecDB for relevant doc passages..."):
+            docs_and_scores: List[Tuple[Document, float]] = []
+            for q in [query] + query_proxies:
+                docs_and_scores += self.get_semantic_search_results(
                     q,
                     k=self.config.parsing.n_similar_docs * retrieval_multiple,
                 )
         # keep only docs with unique d.id()
         id2doc_score = {d.id(): (d, s) for d, s in docs_and_scores}
         docs_and_scores = list(id2doc_score.values())
-
-        passages = [
-
-
-        ]
+        passages = [d for (d, _) in docs_and_scores]
+        # passages = [
+        #     Document(content=d.content, metadata=d.metadata)
+        #     for (d, _) in docs_and_scores
+        # ]

         if self.config.use_bm25_search:
             docs_scores = self.get_similar_chunks_bm25(query, retrieval_multiple)
@@ -539,25 +1107,136 @@ class DocChatAgent(ChatAgent):
         id2passage = {p.id(): p for p in passages}
         passages = list(id2passage.values())

+        if len(passages) == 0:
+            return []
+
+        passages_scores = [(p, 0.0) for p in passages]
+        passages_scores = self.add_context_window(passages_scores)
+        passages = [p for p, _ in passages_scores]
         # now passages can potentially have a lot of doc chunks,
-        # so we re-rank them using a cross-encoder scoring model
+        # so we re-rank them using a cross-encoder scoring model,
+        # and pick top k where k = config.parsing.n_similar_docs
         # https://www.sbert.net/examples/applications/retrieve_rerank
         if self.config.cross_encoder_reranking_model != "":
             passages = self.rerank_with_cross_encoder(query, passages)

+        if self.config.rerank_diversity:
+            # reorder to increase diversity among top docs
+            passages = self.rerank_with_diversity(passages)
+
+        if self.config.rerank_periphery:
+            # reorder so most important docs are at periphery
+            # (see Lost In the Middle issue).
+            passages = self.rerank_to_periphery(passages)
+
+        return passages
+
+    @no_type_check
+    def get_relevant_extracts(self, query: str) -> Tuple[str, List[Document]]:
+        """
+        Get list of (verbatim) extracts from doc-chunks relevant to answering a query.
+
+        These are the stages (some optional based on config):
+        - use LLM to convert query to stand-alone query
+        - optionally use LLM to rephrase query to use below
+        - optionally use LLM to generate hypothetical answer (HyDE) to use below.
+        - get_relevant_chunks(): get doc-chunks relevant to query and proxies
+        - use LLM to get relevant extracts from doc-chunks
+
+        Args:
+            query (str): query to search for
+
+        Returns:
+            query (str): stand-alone version of input query
+            List[Document]: list of relevant extracts
+
+        """
+        if len(self.dialog) > 0 and not self.config.assistant_mode:
+            # Regardless of whether we are in conversation mode or not,
+            # for relevant doc/chunk extraction, we must convert the query
+            # to a standalone query to get more relevant results.
+            with status("[cyan]Converting to stand-alone query...[/cyan]"):
+                with StreamingIfAllowed(self.llm, False):
+                    query = self.llm.followup_to_standalone(self.dialog, query)
+            print(f"[orange2]New query: {query}")
+
+        proxies = []
+        if self.config.hypothetical_answer:
+            answer = self.llm_hypothetical_answer(query)
+            proxies = [answer]
+
+        if self.config.n_query_rephrases > 0:
+            rephrases = self.llm_rephrase_query(query)
+            proxies += rephrases
+
+        passages = self.get_relevant_chunks(query, proxies)  # no LLM involved
+
         if len(passages) == 0:
             return query, []

-        with
+        with status("[cyan]LLM Extracting verbatim passages..."):
             with StreamingIfAllowed(self.llm, False):
                 # these are async calls, one per passage; turn off streaming
-                extracts = self.
+                extracts = self.get_verbatim_extracts(query, passages)
                 extracts = [e for e in extracts if e.content != NO_ANSWER]

         return query, extracts

-
-
+    def get_verbatim_extracts(
+        self,
+        query: str,
+        passages: List[Document],
+    ) -> List[Document]:
+        """
+        Run RelevanceExtractorAgent in async/concurrent mode on passages,
+        to extract portions relevant to answering query, from each passage.
+        Args:
+            query (str): query to answer
+            passages (List[Documents]): list of passages to extract from
+
+        Returns:
+            List[Document]: list of Documents containing extracts and metadata.
+        """
+        agent_cfg = self.config.relevance_extractor_config
+        if agent_cfg is None:
+            # no relevance extraction: simply return passages
+            return passages
+        if agent_cfg.llm is None:
+            # Use main DocChatAgent's LLM if not provided explicitly:
+            # this reduces setup burden on the user
+            agent_cfg.llm = self.config.llm
+        agent_cfg.query = query
+        agent_cfg.segment_length = self.config.extraction_granularity
+        agent_cfg.llm.stream = False  # disable streaming for concurrent calls
+
+        agent = RelevanceExtractorAgent(agent_cfg)
+        task = Task(
+            agent,
+            name="Relevance-Extractor",
+            interactive=False,
+        )
+
+        extracts = run_batch_tasks(
+            task,
+            passages,
+            input_map=lambda msg: msg.content,
+            output_map=lambda ans: ans.content if ans is not None else NO_ANSWER,
+        )
+
+        # Caution: Retain ALL other fields in the Documents (which could be
+        # other than just `content` and `metadata`), while simply replacing
+        # `content` with the extracted portions
+        passage_extracts = []
+        for p, e in zip(passages, extracts):
+            if e == NO_ANSWER or len(e) == 0:
+                continue
+            p_copy = p.copy()
+            p_copy.content = e
+            passage_extracts.append(p_copy)
+
+        return passage_extracts
+
+    def answer_from_docs(self, query: str) -> ChatDocument:
         """
         Answer query based on relevant docs from the VecDB

@@ -567,24 +1246,38 @@ class DocChatAgent(ChatAgent):
         Returns:
             Document: answer
         """
-        response =
+        response = ChatDocument(
             content=NO_ANSWER,
-            metadata=
+            metadata=ChatDocMetaData(
                 source="None",
+                sender=Entity.LLM,
             ),
         )
         # query may be updated to a stand-alone version
         query, extracts = self.get_relevant_extracts(query)
         if len(extracts) == 0:
             return response
+        if self.llm is None:
+            raise ValueError("LLM not set")
+        if self.config.retrieve_only:
+            # only return extracts, skip LLM-based summary answer
+            meta = dict(
+                sender=Entity.LLM,
+            )
+            # copy metadata from first doc, unclear what to do here.
+            meta.update(extracts[0].metadata)
+            return ChatDocument(
+                content="\n\n".join([e.content for e in extracts]),
+                metadata=ChatDocMetaData(**meta),
+            )
         with ExitStack() as stack:
             # conditionally use Streaming or rich console context
             cm = (
                 StreamingIfAllowed(self.llm)
                 if settings.stream
-                else (
+                else (status("LLM Generating final answer..."))
             )
-            stack.enter_context(cm)
+            stack.enter_context(cm)  # type: ignore
             response = self.get_summary_answer(query, extracts)

         self.update_dialog(query, response.content)
@@ -598,7 +1291,7 @@ class DocChatAgent(ChatAgent):
         """Summarize all docs"""
         if self.llm is None:
             raise ValueError("LLM not set")
-        if self.original_docs
+        if len(self.original_docs) == 0:
             logger.warning(
                 """
                 No docs to summarize! Perhaps you are re-using a previously
@@ -627,19 +1320,22 @@ class DocChatAgent(ChatAgent):
         )
         prompt = f"""
         {instruction}
+
+        FULL TEXT:
         {full_text}
         """.strip()
         with StreamingIfAllowed(self.llm):
-            summary =
-            return summary
+            summary = ChatAgent.llm_response(self, prompt)
+            return summary

-    def justify_response(self) -> None:
+    def justify_response(self) -> ChatDocument | None:
         """Show evidence for last response"""
         if self.response is None:
             print("[magenta]No response yet")
-            return
+            return None
         source = self.response.metadata.source
         if len(source) > 0:
             print("[magenta]" + source)
         else:
             print("[magenta]No source found")
+            return None