langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/agent/special/doc_chat_agent.py

@@ -12,20 +12,30 @@ langroid with the [hf-embeddings] extra, e.g.:
 pip install "langroid[hf-embeddings]"
 
 """
+
 import logging
 from contextlib import ExitStack
-from typing import List, Optional, Tuple, no_type_check
+from functools import cache
+from typing import Any, Dict, List, Optional, Set, Tuple, no_type_check
 
-from rich import print
-from rich.console import Console
+import nest_asyncio
+import numpy as np
+import pandas as pd
+from rich.prompt import Prompt
 
-from langroid.agent.base import Agent
+from langroid.agent.batch import run_batch_tasks
 from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
 from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
+from langroid.agent.special.relevance_extractor_agent import (
+    RelevanceExtractorAgent,
+    RelevanceExtractorAgentConfig,
+)
+from langroid.agent.task import Task
 from langroid.embedding_models.models import OpenAIEmbeddingsConfig
 from langroid.language_models.base import StreamingIfAllowed
 from langroid.language_models.openai_gpt import OpenAIChatModel, OpenAIGPTConfig
 from langroid.mytypes import DocMetaData, Document, Entity
+from langroid.parsing.document_parser import DocumentType
 from langroid.parsing.parser import Parser, ParsingConfig, PdfParsingConfig, Splitter
 from langroid.parsing.repo_loader import RepoLoader
 from langroid.parsing.search import (
@@ -33,20 +43,26 @@ from langroid.parsing.search import (
     find_fuzzy_matches_in_docs,
     preprocess_text,
 )
+from langroid.parsing.table_loader import describe_dataframe
 from langroid.parsing.url_loader import URLLoader
-from langroid.parsing.urls import get_urls_and_paths
+from langroid.parsing.urls import get_list_from_user, get_urls_paths_bytes_indices
 from langroid.parsing.utils import batched
 from langroid.prompts.prompts_config import PromptsConfig
 from langroid.prompts.templates import SUMMARY_ANSWER_PROMPT_GPT4
 from langroid.utils.configuration import settings
 from langroid.utils.constants import NO_ANSWER
-from langroid.utils.output.printing import show_if_debug
-from langroid.vector_store.base import VectorStoreConfig
-from langroid.vector_store.qdrantdb import QdrantDBConfig
+from langroid.utils.output import show_if_debug, status
+from langroid.utils.pydantic_utils import dataframe_to_documents, extract_fields
+from langroid.vector_store.base import VectorStore, VectorStoreConfig
+from langroid.vector_store.lancedb import LanceDBConfig
 
-logger = logging.getLogger(__name__)
 
-console = Console()
+@cache
+def apply_nest_asyncio() -> None:
+    nest_asyncio.apply()
+
+
+logger = logging.getLogger(__name__)
 
 DEFAULT_DOC_CHAT_INSTRUCTIONS = """
 Your task is to answer questions about various documents.
@@ -58,25 +74,29 @@ DEFAULT_DOC_CHAT_SYSTEM_MESSAGE = """
 You are a helpful assistant, helping me understand a collection of documents.
 """
 
+has_sentence_transformers = False
+try:
+    from sentence_transformer import SentenceTransformer  # noqa: F401
 
-class DocChatAgentConfig(ChatAgentConfig):
-    """
-    Attributes:
-        max_context_tokens (int): threshold to use for various steps, e.g.
-            if we are able to fit the current stage of doc processing into
-            this many tokens, we skip additional compression steps, and
-            use the current docs as-is in the context
-        conversation_mode (bool): if True, we will accumulate message history,
-            and pass entire history to LLM at each round.
-            If False, each request to LLM will consist only of the
-            initial task messages plus the current query.
-    """
+    has_sentence_transformers = True
+except ImportError:
+    pass
 
+
+class DocChatAgentConfig(ChatAgentConfig):
     system_message: str = DEFAULT_DOC_CHAT_SYSTEM_MESSAGE
     user_message: str = DEFAULT_DOC_CHAT_INSTRUCTIONS
     summarize_prompt: str = SUMMARY_ANSWER_PROMPT_GPT4
-    max_context_tokens: int = 1000
-    conversation_mode: bool = True
+    # extra fields to include in content as key=value pairs
+    # (helps retrieval for table-like data)
+    add_fields_to_content: List[str] = []
+    filter_fields: List[str] = []  # fields usable in filter
+    retrieve_only: bool = False  # only retr relevant extracts, don't gen summary answer
+    extraction_granularity: int = 1  # granularity (in sentences) for relev extraction
+    filter: str | None = (
+        None  # filter condition for various lexical/semantic search fns
+    )
+    conversation_mode: bool = True  # accumulate message history?
     # In assistant mode, DocChatAgent receives questions from another Agent,
     # and those will already be in stand-alone form, so in this mode
     # there is no need to convert them to stand-alone form.
@@ -88,14 +108,26 @@ class DocChatAgentConfig(ChatAgentConfig):
     # It is False by default; its benefits depends on the context.
     hypothetical_answer: bool = False
     n_query_rephrases: int = 0
+    n_neighbor_chunks: int = 0  # how many neighbors on either side of match to retrieve
+    n_fuzzy_neighbor_words: int = 100  # num neighbor words to retrieve for fuzzy match
     use_fuzzy_match: bool = True
     use_bm25_search: bool = True
-    cross_encoder_reranking_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+    cross_encoder_reranking_model: str = (
+        "cross-encoder/ms-marco-MiniLM-L-6-v2" if has_sentence_transformers else ""
+    )
+    rerank_diversity: bool = True  # rerank to maximize diversity?
+    rerank_periphery: bool = True  # rerank to avoid Lost In the Middle effect?
     embed_batch_size: int = 500  # get embedding of at most this many at a time
     cache: bool = True  # cache results
     debug: bool = False
     stream: bool = True  # allow streaming where needed
-    doc_paths: List[str] = []
+    split: bool = True  # use chunking
+    relevance_extractor_config: None | RelevanceExtractorAgentConfig = (
+        RelevanceExtractorAgentConfig(
+            llm=None  # use the parent's llm unless explicitly set here
+        )
+    )
+    doc_paths: List[str | bytes] = []
     default_paths: List[str] = [
         "https://news.ycombinator.com/item?id=35629033",
         "https://www.newyorker.com/tech/annals-of-technology/chatgpt-is-a-blurry-jpeg-of-the-web",
@@ -115,11 +147,12 @@ class DocChatAgentConfig(ChatAgentConfig):
         min_chunk_chars=200,
         discard_chunk_chars=5,  # discard chunks with fewer than this many chars
         n_similar_docs=3,
+        n_neighbor_ids=0,  # num chunk IDs to store on either side of each chunk
         pdf=PdfParsingConfig(
             # NOTE: PDF parsing is extremely challenging, and each library
             # has its own strengths and weaknesses.
             # Try one that works for your use case.
-            # or "haystack", "unstructured", "pdfplumber", "fitz", "pypdf"
+            # or "unstructured", "pdfplumber", "fitz", "pypdf"
             library="pdfplumber",
         ),
     )
@@ -136,10 +169,11 @@ class DocChatAgentConfig(ChatAgentConfig):
         dims=1536,
     )
 
-    vecdb: VectorStoreConfig = QdrantDBConfig(
-        collection_name=None,
-        storage_path=".qdrant/data/",
-        embedding=hf_embed_config,
+    vecdb: VectorStoreConfig = LanceDBConfig(
+        collection_name="doc-chat-lancedb",
+        replace_collection=True,
+        storage_path=".lancedb/data/",
+        embedding=hf_embed_config if has_sentence_transformers else oai_embed_config,
     )
     llm: OpenAIGPTConfig = OpenAIGPTConfig(
         type="openai",
@@ -163,14 +197,40 @@ class DocChatAgent(ChatAgent):
     ):
         super().__init__(config)
         self.config: DocChatAgentConfig = config
-        self.original_docs: None | List[Document] = None
+        self.original_docs: List[Document] = []
         self.original_docs_length = 0
-        self.chunked_docs: None | List[Document] = None
-        self.chunked_docs_clean: None | List[Document] = None
+        self.from_dataframe = False
+        self.df_description = ""
+        self.chunked_docs: List[Document] = []
+        self.chunked_docs_clean: List[Document] = []
         self.response: None | Document = None
         if len(config.doc_paths) > 0:
             self.ingest()
 
+    def clear(self) -> None:
+        """Clear the document collection and the specific collection in vecdb"""
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        self.original_docs = []
+        self.original_docs_length = 0
+        self.chunked_docs = []
+        self.chunked_docs_clean = []
+        collection_name = self.vecdb.config.collection_name
+        if collection_name is None:
+            return
+        try:
+            # Note we may have used a vecdb with a config.collection_name
+            # different from the agent's config.vecdb.collection_name!!
+            self.vecdb.delete_collection(collection_name)
+            self.vecdb = VectorStore.create(self.vecdb.config)
+        except Exception as e:
+            logger.warning(
+                f"""
+                Error while deleting collection {collection_name}:
+                {e}
+                """
+            )
+
     def ingest(self) -> None:
         """
         Chunk + embed + store docs specified by self.config.doc_paths
@@ -187,59 +247,316 @@ class DocChatAgent(ChatAgent):
             # do keyword and other non-vector searches
             if self.vecdb is None:
                 raise ValueError("VecDB not set")
-            self.chunked_docs = self.vecdb.get_all_documents()
-            self.chunked_docs_clean = [
-                Document(content=preprocess_text(d.content), metadata=d.metadata)
-                for d in self.chunked_docs
-            ]
+            self.setup_documents(filter=self.config.filter)
             return
-        urls, paths = get_urls_and_paths(self.config.doc_paths)
+        self.ingest_doc_paths(self.config.doc_paths)  # type: ignore
+
+    def ingest_doc_paths(
+        self,
+        paths: str | bytes | List[str | bytes],
+        metadata: (
+            List[Dict[str, Any]] | Dict[str, Any] | DocMetaData | List[DocMetaData]
+        ) = [],
+        doc_type: str | DocumentType | None = None,
+    ) -> List[Document]:
+        """Split, ingest docs from specified paths,
+        do not add these to config.doc_paths.
+
+        Args:
+            paths: document paths, urls or byte-content of docs.
+                The bytes option is intended to support cases where a document
+                has already been read in as bytes (e.g. from an API or a database),
+                and we want to avoid having to write it to a temporary file
+                just to read it back in.
+            metadata: List of metadata dicts, one for each path.
+                If a single dict is passed in, it is used for all paths.
+            doc_type: DocumentType to use for parsing, if known.
+                MUST apply to all docs if specified.
+                This is especially useful when the `paths` are of bytes type,
+                to help with document type detection.
+        Returns:
+            List of Document objects
+        """
+        if isinstance(paths, str) or isinstance(paths, bytes):
+            paths = [paths]
+        all_paths = paths
+        paths_meta: Dict[int, Any] = {}
+        urls_meta: Dict[int, Any] = {}
+        idxs = range(len(all_paths))
+        url_idxs, path_idxs, bytes_idxs = get_urls_paths_bytes_indices(all_paths)
+        urls = [all_paths[i] for i in url_idxs]
+        paths = [all_paths[i] for i in path_idxs]
+        bytes_list = [all_paths[i] for i in bytes_idxs]
+        path_idxs.extend(bytes_idxs)
+        paths.extend(bytes_list)
+        if (isinstance(metadata, list) and len(metadata) > 0) or not isinstance(
+            metadata, list
+        ):
+            if isinstance(metadata, list):
+                idx2meta = {
+                    p: (
+                        m
+                        if isinstance(m, dict)
+                        else (isinstance(m, DocMetaData) and m.dict())
+                    )  # appease mypy
+                    for p, m in zip(idxs, metadata)
+                }
+            elif isinstance(metadata, dict):
+                idx2meta = {p: metadata for p in idxs}
+            else:
+                idx2meta = {p: metadata.dict() for p in idxs}
+            urls_meta = {u: idx2meta[u] for u in url_idxs}
+            paths_meta = {p: idx2meta[p] for p in path_idxs}
         docs: List[Document] = []
         parser = Parser(self.config.parsing)
         if len(urls) > 0:
-            loader = URLLoader(urls=urls, parser=parser)
-            docs = loader.load()
-        if len(paths) > 0:
-            for p in paths:
-                path_docs = RepoLoader.get_documents(p, parser=parser)
+            for ui in url_idxs:
+                meta = urls_meta.get(ui, {})
+                loader = URLLoader(urls=[all_paths[ui]], parser=parser)  # type: ignore
+                url_docs = loader.load()
+                # update metadata of each doc with meta
+                for d in url_docs:
+                    d.metadata = d.metadata.copy(update=meta)
+                docs.extend(url_docs)
+        if len(paths) > 0:  # paths OR bytes are handled similarly
+            for pi in path_idxs:
+                meta = paths_meta.get(pi, {})
+                p = all_paths[pi]
+                path_docs = RepoLoader.get_documents(
+                    p,
+                    parser=parser,
+                    doc_type=doc_type,
+                )
+                # update metadata of each doc with meta
+                for d in path_docs:
+                    d.metadata = d.metadata.copy(update=meta)
                 docs.extend(path_docs)
         n_docs = len(docs)
-        n_splits = self.ingest_docs(docs)
+        n_splits = self.ingest_docs(docs, split=self.config.split)
         if n_docs == 0:
-            return
+            return []
         n_urls = len(urls)
         n_paths = len(paths)
         print(
             f"""
         [green]I have processed the following {n_urls} URLs
-        and {n_paths} paths into {n_splits} parts:
+        and {n_paths} docs into {n_splits} parts:
         """.strip()
         )
-        print("\n".join(urls))
-        print("\n".join(paths))
+        path_reps = [p if isinstance(p, str) else "bytes" for p in paths]
+        print("\n".join([u for u in urls if isinstance(u, str)]))  # appease mypy
+        print("\n".join(path_reps))
+        return docs
 
-    def ingest_docs(self, docs: List[Document]) -> int:
+    def ingest_docs(
+        self,
+        docs: List[Document],
+        split: bool = True,
+        metadata: (
+            List[Dict[str, Any]] | Dict[str, Any] | DocMetaData | List[DocMetaData]
+        ) = [],
+    ) -> int:
         """
         Chunk docs into pieces, map each chunk to vec-embedding, store in vec-db
+
+        Args:
+            docs: List of Document objects
+            split: Whether to split docs into chunks. Default is True.
+                If False, docs are treated as "chunks" and are not split.
+            metadata: List of metadata dicts, one for each doc, to augment
+                whatever metadata is already in the doc.
+                [ASSUME no conflicting keys between the two metadata dicts.]
+                If a single dict is passed in, it is used for all docs.
         """
-        self.original_docs = docs
+        if isinstance(metadata, list) and len(metadata) > 0:
+            for d, m in zip(docs, metadata):
+                d.metadata = d.metadata.copy(
+                    update=m if isinstance(m, dict) else m.dict()  # type: ignore
+                )
+        elif isinstance(metadata, dict):
+            for d in docs:
+                d.metadata = d.metadata.copy(update=metadata)
+        elif isinstance(metadata, DocMetaData):
+            for d in docs:
+                d.metadata = d.metadata.copy(update=metadata.dict())
+
+        self.original_docs.extend(docs)
         if self.parser is None:
             raise ValueError("Parser not set")
-        docs = self.parser.split(docs)
-        self.chunked_docs = docs
-        self.chunked_docs_clean = [
-            Document(content=preprocess_text(d.content), metadata=d.metadata)
-            for d in self.chunked_docs
-        ]
+        for d in docs:
+            if d.metadata.id in [None, ""]:
+                d.metadata.id = d._unique_hash_id()
+        if split:
+            docs = self.parser.split(docs)
+        else:
+            # treat each doc as a chunk
+            for d in docs:
+                d.metadata.is_chunk = True
         if self.vecdb is None:
             raise ValueError("VecDB not set")
+
+        # If any additional fields need to be added to content,
+        # add them as key=value pairs for all docs, before batching.
+        # This helps retrieval for table-like data.
+        # Note we need to do this at stage so that the embeddings
+        # are computed on the full content with these additional fields.
+        if len(self.config.add_fields_to_content) > 0:
+            fields = [
+                f for f in extract_fields(docs[0], self.config.add_fields_to_content)
+            ]
+            if len(fields) > 0:
+                for d in docs:
+                    key_vals = extract_fields(d, fields)
+                    d.content = (
+                        ",".join(f"{k}={v}" for k, v in key_vals.items())
+                        + ",content="
+                        + d.content
+                    )
+        docs = docs[: self.config.parsing.max_chunks]
         # add embeddings in batches, to stay under limit of embeddings API
         batches = list(batched(docs, self.config.embed_batch_size))
         for batch in batches:
             self.vecdb.add_documents(batch)
         self.original_docs_length = self.doc_length(docs)
+        self.setup_documents(docs, filter=self.config.filter)
         return len(docs)
 
+    @staticmethod
+    def document_compatible_dataframe(
+        df: pd.DataFrame,
+        content: str = "content",
+        metadata: List[str] = [],
+    ) -> Tuple[pd.DataFrame, List[str]]:
+        """
+        Convert dataframe so it is compatible with Document class:
+        - has "content" column
+        - has an "id" column to be used as Document.metadata.id
+
+        Args:
+            df: dataframe to convert
+            content: name of content column
+            metadata: list of metadata column names
+
+        Returns:
+            Tuple[pd.DataFrame, List[str]]: dataframe, metadata
+                - dataframe: dataframe with "content" column and "id" column
+                - metadata: list of metadata column names, including "id"
+        """
+        if content not in df.columns:
+            raise ValueError(
+                f"""
+                Content column {content} not in dataframe,
+                so we cannot ingest into the DocChatAgent.
+                Please specify the `content` parameter as a suitable
+                text-based column in the dataframe.
+                """
+            )
+        if content != "content":
+            # rename content column to "content", leave existing column intact
+            df = df.rename(columns={content: "content"}, inplace=False)
+
+        actual_metadata = metadata.copy()
+        if "id" not in df.columns:
+            docs = dataframe_to_documents(df, content="content", metadata=metadata)
+            ids = [str(d.id()) for d in docs]
+            df["id"] = ids
+
+        if "id" not in actual_metadata:
+            actual_metadata += ["id"]
+
+        return df, actual_metadata
+
+    def ingest_dataframe(
+        self,
+        df: pd.DataFrame,
+        content: str = "content",
+        metadata: List[str] = [],
+    ) -> int:
+        """
+        Ingest a dataframe into vecdb.
+        """
+        self.from_dataframe = True
+        self.df_description = describe_dataframe(
+            df, filter_fields=self.config.filter_fields, n_vals=5
+        )
+        df, metadata = DocChatAgent.document_compatible_dataframe(df, content, metadata)
+        docs = dataframe_to_documents(df, content="content", metadata=metadata)
+        # When ingesting a dataframe we will no longer do any chunking,
+        # so we mark each doc as a chunk.
+        # TODO - revisit this since we may still want to chunk large text columns
+        for d in docs:
+            d.metadata.is_chunk = True
+        return self.ingest_docs(docs)
+
+    def set_filter(self, filter: str) -> None:
+        self.config.filter = filter
+        self.setup_documents(filter=filter)
+
+    def setup_documents(
+        self,
+        docs: List[Document] = [],
+        filter: str | None = None,
+    ) -> None:
+        """
+        Setup `self.chunked_docs` and `self.chunked_docs_clean`
+        based on possible filter.
+        These will be used in various non-vector-based search functions,
+        e.g. self.get_similar_chunks_bm25(), self.get_fuzzy_matches(), etc.
+
+        Args:
+            docs: List of Document objects. This is empty when we are calling this
+                method after initial doc ingestion.
+            filter: Filter condition for various lexical/semantic search fns.
+        """
+        if filter is None and len(docs) > 0:
+            # no filter, so just use the docs passed in
+            self.chunked_docs.extend(docs)
+        else:
+            if self.vecdb is None:
+                raise ValueError("VecDB not set")
+            self.chunked_docs = self.vecdb.get_all_documents(where=filter or "")
+
+        self.chunked_docs_clean = [
+            Document(content=preprocess_text(d.content), metadata=d.metadata)
+            for d in self.chunked_docs
+        ]
+
+    def get_field_values(self, fields: list[str]) -> Dict[str, str]:
+        """Get string-listing of possible values of each filterable field,
+        e.g.
+        {
+            "genre": "crime, drama, mystery, ... (10 more)",
+            "certificate": "R, PG-13, PG, R",
+        }
+        """
+        field_values: Dict[str, Set[str]] = {}
+        # make empty set for each field
+        for f in fields:
+            field_values[f] = set()
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        # get all documents and accumulate possible values of each field until 10
+        docs = self.vecdb.get_all_documents()  # only works for vecdbs that support this
+        for d in docs:
+            # extract fields from d
+            doc_field_vals = extract_fields(d, fields)
+            for field, val in doc_field_vals.items():
+                field_values[field].add(val)
+        # For each field make a string showing list of possible values,
+        # truncate to 20 values, and if there are more, indicate how many
+        # more there are, e.g. Genre: crime, drama, mystery, ... (20 more)
+        field_values_list = {}
+        for f in fields:
+            vals = list(field_values[f])
+            n = len(vals)
+            remaining = n - 20
+            vals = vals[:20]
+            if n > 20:
+                vals.append(f"(...{remaining} more)")
+            # make a string of the values, ensure they are strings
+            field_values_list[f] = ", ".join(str(v) for v in vals)
+        return field_values_list
+
     def doc_length(self, docs: List[Document]) -> int:
         """
         Calc token-length of a list of docs
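A rough sketch of the dataframe-ingestion path added above (document_compatible_dataframe / ingest_dataframe), together with the new add_fields_to_content and filter_fields config options; the column names and values below are made up for illustration.

# Sketch only: hypothetical dataframe and columns.
import pandas as pd

from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig

df = pd.DataFrame(
    {
        "description": ["A slow-burn crime drama.", "A light romantic comedy."],
        "genre": ["crime", "comedy"],
        "year": [1994, 2001],
    }
)
config = DocChatAgentConfig(
    add_fields_to_content=["genre", "year"],  # folded into content as key=value pairs
    filter_fields=["genre", "year"],  # fields usable in filter conditions
)
agent = DocChatAgent(config)
n_chunks = agent.ingest_dataframe(df, content="description", metadata=["genre", "year"])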
@@ -252,7 +569,78 @@ class DocChatAgent(ChatAgent):
             raise ValueError("Parser not set")
         return self.parser.num_tokens(self.doc_string(docs))
 
-    @no_type_check
+    def user_docs_ingest_dialog(self) -> None:
+        """
+        Ask user to select doc-collection, enter filenames/urls, and ingest into vecdb.
+        """
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        n_deletes = self.vecdb.clear_empty_collections()
+        collections = self.vecdb.list_collections()
+        collection_name = "NEW"
+        is_new_collection = False
+        replace_collection = False
+        if len(collections) > 0:
+            n = len(collections)
+            delete_str = (
+                f"(deleted {n_deletes} empty collections)" if n_deletes > 0 else ""
+            )
+            print(f"Found {n} collections: {delete_str}")
+            for i, option in enumerate(collections, start=1):
+                print(f"{i}. {option}")
+            while True:
+                choice = Prompt.ask(
+                    f"Enter 1-{n} to select a collection, "
+                    "or hit ENTER to create a NEW collection, "
+                    "or -1 to DELETE ALL COLLECTIONS",
+                    default="0",
+                )
+                try:
+                    if -1 <= int(choice) <= n:
+                        break
+                except Exception:
+                    pass
+
+            if choice == "-1":
+                confirm = Prompt.ask(
+                    "Are you sure you want to delete all collections?",
+                    choices=["y", "n"],
+                    default="n",
+                )
+                if confirm == "y":
+                    self.vecdb.clear_all_collections(really=True)
+                    collection_name = "NEW"
+
+            if int(choice) > 0:
+                collection_name = collections[int(choice) - 1]
+                print(f"Using collection {collection_name}")
+                choice = Prompt.ask(
+                    "Would you like to replace this collection?",
+                    choices=["y", "n"],
+                    default="n",
+                )
+                replace_collection = choice == "y"
+
+        if collection_name == "NEW":
+            is_new_collection = True
+            collection_name = Prompt.ask(
+                "What would you like to name the NEW collection?",
+                default="doc-chat",
+            )
+
+        self.vecdb.set_collection(collection_name, replace=replace_collection)
+
+        default_urls_str = (
+            " (or leave empty for default URLs)" if is_new_collection else ""
+        )
+        print(f"[blue]Enter some URLs or file/dir paths below {default_urls_str}")
+        inputs = get_list_from_user()
+        if len(inputs) == 0:
+            if is_new_collection:
+                inputs = self.config.default_paths
+        self.config.doc_paths = inputs  # type: ignore
+        self.ingest()
+
     def llm_response(
         self,
         query: None | str | ChatDocument = None,
@@ -269,10 +657,55 @@ class DocChatAgent(ChatAgent):
             query_str = query_str[1:] if query_str is not None else None
             if self.llm is None:
                 raise ValueError("LLM not set")
-            with StreamingIfAllowed(self.llm):
+            with StreamingIfAllowed(self.llm, self.llm.get_stream()):
                 response = super().llm_response(query_str)
             if query_str is not None:
-                self.update_dialog(query_str, response.content)
+                self.update_dialog(
+                    query_str, "" if response is None else response.content
+                )
+            return response
+        if query_str == "":
+            return None
+        elif query_str == "?" and self.response is not None:
+            return self.justify_response()
+        elif (query_str.startswith(("summar", "?")) and self.response is None) or (
+            query_str == "??"
+        ):
+            return self.summarize_docs()
+        else:
+            self.callbacks.show_start_response(entity="llm")
+            response = self.answer_from_docs(query_str)
+            return ChatDocument(
+                content=response.content,
+                metadata=ChatDocMetaData(
+                    source=response.metadata.source,
+                    sender=Entity.LLM,
+                ),
+            )
+
+    async def llm_response_async(
+        self,
+        query: None | str | ChatDocument = None,
+    ) -> Optional[ChatDocument]:
+        apply_nest_asyncio()
+        if not self.llm_can_respond(query):
+            return None
+        query_str: str | None
+        if isinstance(query, ChatDocument):
+            query_str = query.content
+        else:
+            query_str = query
+        if query_str is None or query_str.startswith("!"):
+            # direct query to LLM
+            query_str = query_str[1:] if query_str is not None else None
+            if self.llm is None:
+                raise ValueError("LLM not set")
+            with StreamingIfAllowed(self.llm, self.llm.get_stream()):
+                response = await super().llm_response_async(query_str)
+            if query_str is not None:
+                self.update_dialog(
+                    query_str, "" if response is None else response.content
+                )
             return response
         if query_str == "":
             return None
@@ -283,6 +716,7 @@ class DocChatAgent(ChatAgent):
         ):
             return self.summarize_docs()
         else:
+            self.callbacks.show_start_response(entity="llm")
             response = self.answer_from_docs(query_str)
             return ChatDocument(
                 content=response.content,
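The llm_response_async method above is the new async counterpart of llm_response; a minimal calling sketch (the URL and query are illustrative assumptions):

# Sketch only: async entry point added in this version.
import asyncio

from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig

async def main() -> None:
    agent = DocChatAgent(
        DocChatAgentConfig(doc_paths=["https://news.ycombinator.com/item?id=35629033"])
    )
    response = await agent.llm_response_async("Summarize the discussion.")
    print(None if response is None else response.content)

asyncio.run(main())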
@@ -314,7 +748,9 @@ class DocChatAgent(ChatAgent):
             ]
         )
 
-    def get_summary_answer(self, question: str, passages: List[Document]) -> Document:
+    def get_summary_answer(
+        self, question: str, passages: List[Document]
+    ) -> ChatDocument:
         """
         Given a question and a list of (possibly) doc snippets,
         generate an answer if possible
@@ -342,9 +778,6 @@ class DocChatAgent(ChatAgent):
         # 2 new LLMMessage objects:
         # one for `final_prompt`, and one for the LLM response
 
-        # TODO need to "forget" last two messages in message_history
-        # if we are not in conversation mode
-
         if self.config.conversation_mode:
             # respond with temporary context
             answer_doc = super()._llm_response_temp_context(question, final_prompt)
@@ -353,16 +786,23 @@ class DocChatAgent(ChatAgent):
 
         final_answer = answer_doc.content.strip()
         show_if_debug(final_answer, "SUMMARIZE_RESPONSE= ")
-        parts = final_answer.split("SOURCE:", maxsplit=1)
-        if len(parts) > 1:
-            content = parts[0].strip()
-            sources = parts[1].strip()
-        else:
+
+        if final_answer.startswith("SOURCE"):
+            # sometimes SOURCE may be shown first,
+            # in this case just use final_answer as-is for both content and source
             content = final_answer
-            sources = ""
-        return Document(
+            sources = final_answer
+        else:
+            parts = final_answer.split("SOURCE:", maxsplit=1)
+            if len(parts) > 1:
+                content = parts[0].strip()
+                sources = parts[1].strip()
+            else:
+                content = final_answer
+                sources = ""
+        return ChatDocument(
             content=content,
-            metadata=DocMetaData(
+            metadata=ChatDocMetaData(
                 source="SOURCE: " + sources,
                 sender=Entity.LLM,
                 cached=getattr(answer_doc.metadata, "cached", False),
@@ -372,7 +812,7 @@ class DocChatAgent(ChatAgent):
     def llm_hypothetical_answer(self, query: str) -> str:
         if self.llm is None:
             raise ValueError("LLM not set")
-        with console.status("[cyan]LLM generating hypothetical answer..."):
+        with status("[cyan]LLM generating hypothetical answer..."):
            with StreamingIfAllowed(self.llm, False):
                 # TODO: provide an easy way to
                 # Adjust this prompt depending on context.
@@ -392,7 +832,7 @@ class DocChatAgent(ChatAgent):
     def llm_rephrase_query(self, query: str) -> List[str]:
         if self.llm is None:
             raise ValueError("LLM not set")
-        with console.status("[cyan]LLM generating rephrases of query..."):
+        with status("[cyan]LLM generating rephrases of query..."):
             with StreamingIfAllowed(self.llm, False):
                 rephrases = self.llm_response_forget(
                     f"""
@@ -408,11 +848,13 @@ class DocChatAgent(ChatAgent):
     ) -> List[Tuple[Document, float]]:
         # find similar docs using bm25 similarity:
         # these may sometimes be more likely to contain a relevant verbatim extract
-        with console.status("[cyan]Searching for similar chunks using bm25..."):
-            if self.chunked_docs is None:
-                raise ValueError("No chunked docs")
-            if self.chunked_docs_clean is None:
-                raise ValueError("No cleaned chunked docs")
+        with status("[cyan]Searching for similar chunks using bm25..."):
+            if self.chunked_docs is None or len(self.chunked_docs) == 0:
+                logger.warning("No chunked docs; cannot use bm25-similarity")
+                return []
+            if self.chunked_docs_clean is None or len(self.chunked_docs_clean) == 0:
+                logger.warning("No cleaned chunked docs; cannot use bm25-similarity")
+                return []
             docs_scores = find_closest_matches_with_bm25(
                 self.chunked_docs,
                 self.chunked_docs_clean,  # already pre-processed!
@@ -424,24 +866,27 @@ class DocChatAgent(ChatAgent):
     def get_fuzzy_matches(self, query: str, multiple: int) -> List[Document]:
         # find similar docs using fuzzy matching:
         # these may sometimes be more likely to contain a relevant verbatim extract
-        with console.status("[cyan]Finding fuzzy matches in chunks..."):
+        with status("[cyan]Finding fuzzy matches in chunks..."):
             if self.chunked_docs is None:
-                raise ValueError("No chunked docs")
+                logger.warning("No chunked docs; cannot use fuzzy matching")
+                return []
+            if self.chunked_docs_clean is None:
+                logger.warning("No cleaned chunked docs; cannot use fuzzy-search")
+                return []
             fuzzy_match_docs = find_fuzzy_matches_in_docs(
                 query,
                 self.chunked_docs,
+                self.chunked_docs_clean,
                 k=self.config.parsing.n_similar_docs * multiple,
-                words_before=1000,
-                words_after=1000,
+                words_before=self.config.n_fuzzy_neighbor_words,
+                words_after=self.config.n_fuzzy_neighbor_words,
             )
         return fuzzy_match_docs
 
     def rerank_with_cross_encoder(
         self, query: str, passages: List[Document]
     ) -> List[Document]:
-        with console.status("[cyan]Re-ranking retrieved chunks using cross-encoder..."):
-            if self.chunked_docs is None:
-                raise ValueError("No chunked docs")
+        with status("[cyan]Re-ranking retrieved chunks using cross-encoder..."):
             try:
                 from sentence_transformers import CrossEncoder
             except ImportError:
@@ -455,6 +900,8 @@ class DocChatAgent(ChatAgent):
 
             model = CrossEncoder(self.config.cross_encoder_reranking_model)
             scores = model.predict([(query, p.content) for p in passages])
+            # Convert to [0,1] so we might could use a cutoff later.
+            scores = 1.0 / (1 + np.exp(-np.array(scores)))
             # get top k scoring passages
             sorted_pairs = sorted(
                 zip(scores, passages),
@@ -466,66 +913,187 @@ class DocChatAgent(ChatAgent):
             ]
         return passages
 
-    @no_type_check
-    def get_relevant_extracts(self, query: str) -> Tuple[str, List[Document]]:
+    def rerank_with_diversity(self, passages: List[Document]) -> List[Document]:
         """
-        Get list of extracts from doc-chunks relevant to answering a query.
-        These are the stages (some optional based on config):
-        - use LLM to convert query to stand-alone query
-        - optionally rephrase query to use below
-        - optionally generate hypothetical answer (HyDE) to use below.
-        - get relevant doc-chunks, via:
-            - vector-embedding distance, from vecdb
-            - bm25-ranking (keyword similarity)
-            - fuzzy matching (keyword similarity)
-        - re-ranking of doc-chunks using cross-encoder, pick top k
-        - use LLM to get relevant extracts from doc-chunks
+        Rerank a list of items in such a way that each successive item is least similar
+        (on average) to the earlier items.
 
         Args:
-            query (str): query to search for
+            query (str): The query for which the passages are relevant.
+            passages (List[Document]): A list of Documents to be reranked.
 
         Returns:
-            query (str): stand-alone version of input query
-            List[Document]: list of relevant extracts
+            List[Documents]: A reranked list of Documents.
+        """
 
+        if self.vecdb is None:
+            logger.warning("No vecdb; cannot use rerank_with_diversity")
+            return passages
+        emb_model = self.vecdb.embedding_model
+        emb_fn = emb_model.embedding_fn()
+        embs = emb_fn([p.content for p in passages])
+        embs_arr = [np.array(e) for e in embs]
+        indices = list(range(len(passages)))
+
+        # Helper function to compute average similarity to
+        # items in the current result list.
+        def avg_similarity_to_result(i: int, result: List[int]) -> float:
+            return sum(  # type: ignore
+                (embs_arr[i] @ embs_arr[j])
+                / (np.linalg.norm(embs_arr[i]) * np.linalg.norm(embs_arr[j]))
+                for j in result
+            ) / len(result)
+
+        # copy passages to items
+        result = [indices.pop(0)]  # Start with the first item.
+
+        while indices:
+            # Find the item that has the least average similarity
+            # to items in the result list.
+            least_similar_item = min(
+                indices, key=lambda i: avg_similarity_to_result(i, result)
+            )
+            result.append(least_similar_item)
+            indices.remove(least_similar_item)
+
+        # return passages in order of result list
+        return [passages[i] for i in result]
+
+    def rerank_to_periphery(self, passages: List[Document]) -> List[Document]:
         """
-        if len(self.dialog) > 0 and not self.config.assistant_mode:
-            # Regardless of whether we are in conversation mode or not,
-            # for relevant doc/chunk extraction, we must convert the query
-            # to a standalone query to get more relevant results.
-            with console.status("[cyan]Converting to stand-alone query...[/cyan]"):
-                with StreamingIfAllowed(self.llm, False):
-                    query = self.llm.followup_to_standalone(self.dialog, query)
-            print(f"[orange2]New query: {query}")
+        Rerank to avoid Lost In the Middle (LIM) problem,
+        where LLMs pay more attention to items at the ends of a list,
+        rather than the middle. So we re-rank to make the best passages
+        appear at the periphery of the list.
+        https://arxiv.org/abs/2307.03172
 
+        Example reranking:
+        1 2 3 4 5 6 7 8 9 ==> 1 3 5 7 9 8 6 4 2
+
+        Args:
+            passages (List[Document]): A list of Documents to be reranked.
+
+        Returns:
+            List[Documents]: A reranked list of Documents.
+
+        """
+        # Splitting items into odds and evens based on index, not value
+        odds = passages[::2]
+        evens = passages[1::2][::-1]
+
+        # Merging them back together
+        return odds + evens
+
+    def add_context_window(
+        self,
+        docs_scores: List[Tuple[Document, float]],
+    ) -> List[Tuple[Document, float]]:
+        """
+        In each doc's metadata, there may be a window_ids field indicating
+        the ids of the chunks around the current chunk. We use these stored
+        window_ids to retrieve the desired number
+        (self.config.n_neighbor_chunks) of neighbors
+        on either side of the current chunk.
+
+        Args:
+            docs_scores (List[Tuple[Document, float]]): List of pairs of documents
+                to add context windows to together with their match scores.
+
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+        """
+        if self.vecdb is None or self.config.n_neighbor_chunks == 0:
+            return docs_scores
+        if len(docs_scores) == 0:
+            return []
+        if set(docs_scores[0][0].__fields__) != {"content", "metadata"}:
+            # Do not add context window when there are other fields besides just
+            # content and metadata, since we do not know how to set those other fields
+            # for newly created docs with combined content.
+            return docs_scores
+        return self.vecdb.add_context_window(docs_scores, self.config.n_neighbor_chunks)
+
+    def get_semantic_search_results(
+        self,
+        query: str,
+        k: int = 10,
+    ) -> List[Tuple[Document, float]]:
+        """
+        Get semantic search results from vecdb.
+        Args:
+            query (str): query to search for
+            k (int): number of results to return
+        Returns:
+            List[Tuple[Document, float]]: List of (Document, score) tuples.
+        """
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
+        # Note: for dynamic filtering based on a query, users can
+        # use the `temp_update` context-manager to pass in a `filter` to self.config,
+        # e.g.:
+        # with temp_update(self.config, {"filter": "metadata.source=='source1'"}):
+        #     docs_scores = self.get_semantic_search_results(query, k=k)
+        # This avoids having pass the `filter` argument to every function call
+        # upstream of this one.
+        # The `temp_update` context manager is defined in
+        # `langroid/utils/pydantic_utils.py`
+        return self.vecdb.similar_texts_with_scores(
+            query,
+            k=k,
+            where=self.config.filter,
+        )
+
+    def get_relevant_chunks(
+        self, query: str, query_proxies: List[str] = []
+    ) -> List[Document]:
+        """
+        The retrieval stage in RAG: get doc-chunks that are most "relevant"
+        to the query (and possibly any proxy queries), from the document-store,
+        which currently is the vector store,
+        but in theory could be any document store, or even web-search.
+        This stage does NOT involve an LLM, and the retrieved chunks
+        could either be pre-chunked text (from the initial pre-processing stage
+        where chunks were stored in the vector store), or they could be
+        dynamically retrieved based on a window around a lexical match.
+
+        These are the steps (some optional based on config):
+        - semantic search based on vector-embedding distance, from vecdb
+        - lexical search using bm25-ranking (keyword similarity)
+        - fuzzy matching (keyword similarity)
+        - re-ranking of doc-chunks by relevance to query, using cross-encoder,
+            and pick top k
+
+        Args:
+            query: original query (assumed to be in stand-alone form)
+            query_proxies: possible rephrases, or hypothetical answer to query
+                (e.g. for HyDE-type retrieval)
+
+        Returns:
+
+        """
         # if we are using cross-encoder reranking, we can retrieve more docs
         # during retrieval, and leave it to the cross-encoder re-ranking
         # to whittle down to self.config.parsing.n_similar_docs
         retrieval_multiple = 1 if self.config.cross_encoder_reranking_model == "" else 3
-        queries = [query]
-        if self.config.hypothetical_answer:
-            answer = self.llm_hypothetical_answer(query)
-            queries = [query, answer]
 
-        if self.config.n_query_rephrases > 0:
-            rephrases = self.llm_rephrase_query(query)
-            queries += rephrases
+        if self.vecdb is None:
+            raise ValueError("VecDB not set")
 
-        with console.status("[cyan]Searching VecDB for relevant doc passages..."):
-            docs_and_scores = []
-            for q in queries:
-                docs_and_scores += self.vecdb.similar_texts_with_scores(
+        with status("[cyan]Searching VecDB for relevant doc passages..."):
+            docs_and_scores: List[Tuple[Document, float]] = []
+            for q in [query] + query_proxies:
+                docs_and_scores += self.get_semantic_search_results(
                     q,
                     k=self.config.parsing.n_similar_docs * retrieval_multiple,
                 )
             # keep only docs with unique d.id()
             id2doc_score = {d.id(): (d, s) for d, s in docs_and_scores}
             docs_and_scores = list(id2doc_score.values())
-
-        passages = [
-            Document(content=d.content, metadata=d.metadata)
-            for (d, _) in docs_and_scores
-        ]
+        passages = [d for (d, _) in docs_and_scores]
+        # passages = [
+        #     Document(content=d.content, metadata=d.metadata)
+        #     for (d, _) in docs_and_scores
+        # ]
 
         if self.config.use_bm25_search:
             docs_scores = self.get_similar_chunks_bm25(query, retrieval_multiple)
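The rerank_to_periphery ordering above is easy to verify in isolation; this standalone sketch (not part of the langroid API) reproduces the example from its docstring:

# Standalone illustration of the odds/evens "periphery" reordering.
def to_periphery(items: list) -> list:
    odds = items[::2]  # best-ranked items, taken first
    evens = items[1::2][::-1]  # remaining items reversed, so the weakest land in the middle
    return odds + evens

print(to_periphery([1, 2, 3, 4, 5, 6, 7, 8, 9]))  # -> [1, 3, 5, 7, 9, 8, 6, 4, 2]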
@@ -539,25 +1107,136 @@ class DocChatAgent(ChatAgent):
             id2passage = {p.id(): p for p in passages}
             passages = list(id2passage.values())
 
+        if len(passages) == 0:
+            return []
+
+        passages_scores = [(p, 0.0) for p in passages]
+        passages_scores = self.add_context_window(passages_scores)
+        passages = [p for p, _ in passages_scores]
         # now passages can potentially have a lot of doc chunks,
-        # so we re-rank them using a cross-encoder scoring model
+        # so we re-rank them using a cross-encoder scoring model,
+        # and pick top k where k = config.parsing.n_similar_docs
         # https://www.sbert.net/examples/applications/retrieve_rerank
         if self.config.cross_encoder_reranking_model != "":
             passages = self.rerank_with_cross_encoder(query, passages)
 
+        if self.config.rerank_diversity:
+            # reorder to increase diversity among top docs
+            passages = self.rerank_with_diversity(passages)
+
+        if self.config.rerank_periphery:
+            # reorder so most important docs are at periphery
+            # (see Lost In the Middle issue).
+            passages = self.rerank_to_periphery(passages)
+
+        return passages
+
+    @no_type_check
+    def get_relevant_extracts(self, query: str) -> Tuple[str, List[Document]]:
+        """
+        Get list of (verbatim) extracts from doc-chunks relevant to answering a query.
+
+        These are the stages (some optional based on config):
+        - use LLM to convert query to stand-alone query
+        - optionally use LLM to rephrase query to use below
+        - optionally use LLM to generate hypothetical answer (HyDE) to use below.
+        - get_relevant_chunks(): get doc-chunks relevant to query and proxies
+        - use LLM to get relevant extracts from doc-chunks
+
+        Args:
+            query (str): query to search for
+
+        Returns:
+            query (str): stand-alone version of input query
+            List[Document]: list of relevant extracts
+
+        """
+        if len(self.dialog) > 0 and not self.config.assistant_mode:
+            # Regardless of whether we are in conversation mode or not,
+            # for relevant doc/chunk extraction, we must convert the query
+            # to a standalone query to get more relevant results.
+            with status("[cyan]Converting to stand-alone query...[/cyan]"):
+                with StreamingIfAllowed(self.llm, False):
+                    query = self.llm.followup_to_standalone(self.dialog, query)
+            print(f"[orange2]New query: {query}")
+
+        proxies = []
+        if self.config.hypothetical_answer:
+            answer = self.llm_hypothetical_answer(query)
+            proxies = [answer]
+
+        if self.config.n_query_rephrases > 0:
+            rephrases = self.llm_rephrase_query(query)
+            proxies += rephrases
+
+        passages = self.get_relevant_chunks(query, proxies)  # no LLM involved
+
         if len(passages) == 0:
             return query, []
 
-        with console.status("[cyan]LLM Extracting verbatim passages..."):
+        with status("[cyan]LLM Extracting verbatim passages..."):
             with StreamingIfAllowed(self.llm, False):
                 # these are async calls, one per passage; turn off streaming
-                extracts = self.llm.get_verbatim_extracts(query, passages)
+                extracts = self.get_verbatim_extracts(query, passages)
             extracts = [e for e in extracts if e.content != NO_ANSWER]
 
         return query, extracts
 
-    @no_type_check
-    def answer_from_docs(self, query: str) -> Document:
+    def get_verbatim_extracts(
+        self,
+        query: str,
+        passages: List[Document],
+    ) -> List[Document]:
+        """
+        Run RelevanceExtractorAgent in async/concurrent mode on passages,
+        to extract portions relevant to answering query, from each passage.
+        Args:
+            query (str): query to answer
+            passages (List[Documents]): list of passages to extract from
+
+        Returns:
+            List[Document]: list of Documents containing extracts and metadata.
+        """
+        agent_cfg = self.config.relevance_extractor_config
+        if agent_cfg is None:
+            # no relevance extraction: simply return passages
+            return passages
+        if agent_cfg.llm is None:
+            # Use main DocChatAgent's LLM if not provided explicitly:
+            # this reduces setup burden on the user
+            agent_cfg.llm = self.config.llm
+        agent_cfg.query = query
+        agent_cfg.segment_length = self.config.extraction_granularity
+        agent_cfg.llm.stream = False  # disable streaming for concurrent calls
+
+        agent = RelevanceExtractorAgent(agent_cfg)
+        task = Task(
+            agent,
+            name="Relevance-Extractor",
+            interactive=False,
+        )
+
+        extracts = run_batch_tasks(
+            task,
+            passages,
+            input_map=lambda msg: msg.content,
+            output_map=lambda ans: ans.content if ans is not None else NO_ANSWER,
+        )
+
+        # Caution: Retain ALL other fields in the Documents (which could be
+        # other than just `content` and `metadata`), while simply replacing
+        # `content` with the extracted portions
+        passage_extracts = []
+        for p, e in zip(passages, extracts):
+            if e == NO_ANSWER or len(e) == 0:
+                continue
+            p_copy = p.copy()
+            p_copy.content = e
+            passage_extracts.append(p_copy)
+
+        return passage_extracts
+
+    def answer_from_docs(self, query: str) -> ChatDocument:
         """
         Answer query based on relevant docs from the VecDB
 
@@ -567,24 +1246,38 @@ class DocChatAgent(ChatAgent):
         Returns:
             Document: answer
         """
-        response = Document(
+        response = ChatDocument(
             content=NO_ANSWER,
-            metadata=DocMetaData(
+            metadata=ChatDocMetaData(
                 source="None",
+                sender=Entity.LLM,
             ),
         )
         # query may be updated to a stand-alone version
         query, extracts = self.get_relevant_extracts(query)
         if len(extracts) == 0:
             return response
+        if self.llm is None:
+            raise ValueError("LLM not set")
+        if self.config.retrieve_only:
+            # only return extracts, skip LLM-based summary answer
+            meta = dict(
+                sender=Entity.LLM,
+            )
+            # copy metadata from first doc, unclear what to do here.
+            meta.update(extracts[0].metadata)
+            return ChatDocument(
+                content="\n\n".join([e.content for e in extracts]),
+                metadata=ChatDocMetaData(**meta),
+            )
         with ExitStack() as stack:
             # conditionally use Streaming or rich console context
             cm = (
                 StreamingIfAllowed(self.llm)
                 if settings.stream
-                else (console.status("LLM Generating final answer..."))
+                else (status("LLM Generating final answer..."))
             )
-            stack.enter_context(cm)
+            stack.enter_context(cm)  # type: ignore
             response = self.get_summary_answer(query, extracts)
 
         self.update_dialog(query, response.content)
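A hedged sketch of the new retrieval-only mode used by answer_from_docs above; the document path and query are hypothetical.

# Sketch only: return verbatim extracts without an LLM-generated summary answer.
from langroid.agent.special.doc_chat_agent import DocChatAgent, DocChatAgentConfig

config = DocChatAgentConfig(
    retrieve_only=True,  # skip the final summary answer
    extraction_granularity=1,  # relevance extraction at sentence granularity
    doc_paths=["report.pdf"],  # hypothetical local file
)
agent = DocChatAgent(config)
result = agent.llm_response("What were the quarterly revenue figures?")
print(None if result is None else result.content)  # extracts joined with blank lines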
@@ -598,7 +1291,7 @@ class DocChatAgent(ChatAgent):
         """Summarize all docs"""
         if self.llm is None:
             raise ValueError("LLM not set")
-        if self.original_docs is None:
+        if len(self.original_docs) == 0:
             logger.warning(
                 """
                 No docs to summarize! Perhaps you are re-using a previously
@@ -627,19 +1320,22 @@ class DocChatAgent(ChatAgent):
         )
         prompt = f"""
         {instruction}
+
+        FULL TEXT:
         {full_text}
         """.strip()
         with StreamingIfAllowed(self.llm):
-            summary = Agent.llm_response(self, prompt)
-        return summary  # type: ignore
+            summary = ChatAgent.llm_response(self, prompt)
+        return summary
 
-    def justify_response(self) -> None:
+    def justify_response(self) -> ChatDocument | None:
         """Show evidence for last response"""
         if self.response is None:
             print("[magenta]No response yet")
-            return
+            return None
         source = self.response.metadata.source
         if len(source) > 0:
             print("[magenta]" + source)
         else:
             print("[magenta]No source found")
+        return None