langroid 0.33.12__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langroid/agent/batch.py CHANGED
@@ -1,7 +1,19 @@
1
1
  import asyncio
2
2
  import copy
3
3
  import inspect
4
- from typing import Any, Callable, Coroutine, Iterable, List, Optional, TypeVar, cast
4
+ import warnings
5
+ from enum import Enum
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Coroutine,
10
+ Iterable,
11
+ List,
12
+ Optional,
13
+ TypeVar,
14
+ Union,
15
+ cast,
16
+ )
5
17
 
6
18
  from dotenv import load_dotenv
7
19
 
@@ -21,6 +33,238 @@ T = TypeVar("T")
21
33
  U = TypeVar("U")
22
34
 
23
35
 
36
+ class ExceptionHandling(str, Enum):
37
+ """Enum for exception handling options."""
38
+
39
+ RAISE = "raise"
40
+ RETURN_NONE = "return_none"
41
+ RETURN_EXCEPTION = "return_exception"
42
+
43
+
44
+ def _convert_exception_handling(
45
+ handle_exceptions: Union[bool, ExceptionHandling]
46
+ ) -> ExceptionHandling:
47
+ """Convert legacy boolean handle_exceptions to ExceptionHandling enum."""
48
+ if isinstance(handle_exceptions, ExceptionHandling):
49
+ return handle_exceptions
50
+
51
+ if isinstance(handle_exceptions, bool):
52
+ warnings.warn(
53
+ "Boolean handle_exceptions is deprecated. "
54
+ "Use ExceptionHandling enum instead: "
55
+ "RAISE, RETURN_NONE, or RETURN_EXCEPTION.",
56
+ DeprecationWarning,
57
+ stacklevel=2,
58
+ )
59
+ return (
60
+ ExceptionHandling.RETURN_NONE
61
+ if handle_exceptions
62
+ else ExceptionHandling.RAISE
63
+ )
64
+
65
+ raise TypeError(
66
+ "handle_exceptions must be bool or ExceptionHandling, "
67
+ f"not {type(handle_exceptions)}"
68
+ )
69
+
70
+
71
+ async def _process_batch_async(
72
+ inputs: Iterable[str | ChatDocument],
73
+ do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
74
+ start_idx: int = 0,
75
+ stop_on_first_result: bool = False,
76
+ sequential: bool = False,
77
+ handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
78
+ output_map: Callable[[Any], Any] = lambda x: x,
79
+ ) -> List[Optional[ChatDocument] | BaseException]:
80
+ """
81
+ Unified batch processing logic for both agent methods and tasks.
82
+
83
+ Args:
84
+ inputs: Iterable of inputs to process
85
+ do_task: Task execution function that takes (input, index) and returns result
86
+ start_idx: Starting index for the batch
87
+ stop_on_first_result: Whether to stop after first valid result
88
+ sequential: Whether to process sequentially
89
+ handle_exceptions: How to handle exceptions:
90
+ - RAISE or False: Let exceptions propagate
91
+ - RETURN_NONE or True: Convert exceptions to None in results
92
+ - RETURN_EXCEPTION: Include exception objects in results
93
+ Boolean values are deprecated and will be removed in a future version.
94
+ output_map: Function to map results to final output format
95
+ """
96
+ exception_handling = _convert_exception_handling(handle_exceptions)
97
+
98
+ def handle_error(e: BaseException) -> Any:
99
+ """Handle exceptions based on exception_handling."""
100
+ match exception_handling:
101
+ case ExceptionHandling.RAISE:
102
+ raise e
103
+ case ExceptionHandling.RETURN_NONE:
104
+ return None
105
+ case ExceptionHandling.RETURN_EXCEPTION:
106
+ return e
107
+
108
+ if stop_on_first_result:
109
+ results: List[Optional[ChatDocument] | BaseException] = []
110
+ pending: set[asyncio.Task[Any]] = set()
111
+ # Create task-to-index mapping
112
+ task_indices: dict[asyncio.Task[Any], int] = {}
113
+ try:
114
+ tasks = [
115
+ asyncio.create_task(do_task(input, i + start_idx))
116
+ for i, input in enumerate(inputs)
117
+ ]
118
+ task_indices = {task: i for i, task in enumerate(tasks)}
119
+ results = [None] * len(tasks)
120
+
121
+ done, pending = await asyncio.wait(
122
+ tasks, return_when=asyncio.FIRST_COMPLETED
123
+ )
124
+
125
+ # Process completed tasks
126
+ for task in done:
127
+ index = task_indices[task]
128
+ try:
129
+ result = await task
130
+ results[index] = output_map(result)
131
+ except BaseException as e:
132
+ results[index] = handle_error(e)
133
+
134
+ if any(r is not None for r in results):
135
+ return results
136
+ finally:
137
+ for task in pending:
138
+ task.cancel()
139
+ try:
140
+ await asyncio.gather(*pending, return_exceptions=True)
141
+ except BaseException as e:
142
+ handle_error(e)
143
+ return results
144
+
145
+ elif sequential:
146
+ results = []
147
+ for i, input in enumerate(inputs):
148
+ try:
149
+ result = await do_task(input, i + start_idx)
150
+ results.append(output_map(result))
151
+ except BaseException as e:
152
+ results.append(handle_error(e))
153
+ return results
154
+
155
+ # Parallel execution
156
+ else:
157
+ try:
158
+ return_exceptions = exception_handling != ExceptionHandling.RAISE
159
+ with quiet_mode(), SuppressLoggerWarnings():
160
+ results_with_exceptions = cast(
161
+ list[Optional[ChatDocument | BaseException]],
162
+ await asyncio.gather(
163
+ *(
164
+ do_task(input, i + start_idx)
165
+ for i, input in enumerate(inputs)
166
+ ),
167
+ return_exceptions=return_exceptions,
168
+ ),
169
+ )
170
+
171
+ if exception_handling == ExceptionHandling.RETURN_NONE:
172
+ results = [
173
+ None if isinstance(r, BaseException) else r
174
+ for r in results_with_exceptions
175
+ ]
176
+ else: # ExceptionHandling.RETURN_EXCEPTION
177
+ results = results_with_exceptions
178
+ except BaseException as e:
179
+ results = [handle_error(e) for _ in inputs]
180
+
181
+ return [output_map(r) for r in results]
182
+
183
+
184
+ def run_batched_tasks(
185
+ inputs: List[str | ChatDocument],
186
+ do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
187
+ batch_size: Optional[int],
188
+ stop_on_first_result: bool,
189
+ sequential: bool,
190
+ handle_exceptions: Union[bool, ExceptionHandling],
191
+ output_map: Callable[[Any], Any],
192
+ message_template: str,
193
+ message: Optional[str] = None,
194
+ ) -> List[Any]:
195
+ """
196
+ Common batch processing logic for both agent methods and tasks.
197
+
198
+ Args:
199
+ inputs: List of inputs to process
200
+ do_task: Task execution function
201
+ batch_size: Size of batches, if None process all at once
202
+ stop_on_first_result: Whether to stop after first valid result
203
+ sequential: Whether to process sequentially
204
+ handle_exceptions: How to handle exceptions:
205
+ - RAISE or False: Let exceptions propagate
206
+ - RETURN_NONE or True: Convert exceptions to None in results
207
+ - RETURN_EXCEPTION: Include exception objects in results
208
+ Boolean values are deprecated and will be removed in a future version.
209
+ output_map: Function to map results
210
+ message_template: Template for status message
211
+ message: Optional override for status message
212
+ """
213
+
214
+ async def run_all_batched_tasks(
215
+ inputs: List[str | ChatDocument],
216
+ batch_size: int | None,
217
+ ) -> List[Any]:
218
+ """Extra wrap to run asyncio.run one single time and not once per loop
219
+
220
+ Args:
221
+ inputs (List[str | ChatDocument]): inputs to process
222
+ batch_size (int | None): batch size
223
+
224
+ Returns:
225
+ List[Any]: results
226
+ """
227
+ results: List[Any] = []
228
+ if batch_size is None:
229
+ msg = message or message_template.format(total=len(inputs))
230
+ with status(msg), SuppressLoggerWarnings():
231
+ results = await _process_batch_async(
232
+ inputs,
233
+ do_task,
234
+ stop_on_first_result=stop_on_first_result,
235
+ sequential=sequential,
236
+ handle_exceptions=handle_exceptions,
237
+ output_map=output_map,
238
+ )
239
+ else:
240
+ batches = batched(inputs, batch_size)
241
+ for batch in batches:
242
+ start_idx = len(results)
243
+ complete_str = f", {start_idx} complete" if start_idx > 0 else ""
244
+ msg = (
245
+ message or message_template.format(total=len(inputs)) + complete_str
246
+ )
247
+
248
+ if stop_on_first_result and any(r is not None for r in results):
249
+ results.extend([None] * len(batch))
250
+ else:
251
+ with status(msg), SuppressLoggerWarnings():
252
+ results.extend(
253
+ await _process_batch_async(
254
+ batch,
255
+ do_task,
256
+ start_idx=start_idx,
257
+ stop_on_first_result=stop_on_first_result,
258
+ sequential=sequential,
259
+ handle_exceptions=handle_exceptions,
260
+ output_map=output_map,
261
+ )
262
+ )
263
+ return results
264
+
265
+ return asyncio.run(run_all_batched_tasks(inputs, batch_size))
266
+
267
+
24
268
  def run_batch_task_gen(
25
269
  gen_task: Callable[[int], Task],
26
270
  items: list[T],
@@ -31,7 +275,7 @@ def run_batch_task_gen(
31
275
  batch_size: Optional[int] = None,
32
276
  turns: int = -1,
33
277
  message: Optional[str] = None,
34
- handle_exceptions: bool = False,
278
+ handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
35
279
  max_cost: float = 0.0,
36
280
  max_tokens: int = 0,
37
281
  ) -> list[Optional[U]]:
@@ -58,7 +302,11 @@ def run_batch_task_gen(
58
302
  if None, unbatched
59
303
  turns (int): number of turns to run, -1 for infinite
60
304
  message (Optional[str]): optionally overrides the console status messages
61
- handle_exceptions: bool: Whether to replace exceptions with outputs of None
305
+ handle_exceptions: How to handle exceptions:
306
+ - RAISE or False: Let exceptions propagate
307
+ - RETURN_NONE or True: Convert exceptions to None in results
308
+ - RETURN_EXCEPTION: Include exception objects in results
309
+ Boolean values are deprecated and will be removed in a future version.
62
310
  max_cost: float: maximum cost to run the task (default 0.0 for unlimited)
63
311
  max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)
64
312
 
@@ -72,113 +320,33 @@ def run_batch_task_gen(
72
320
  async def _do_task(
73
321
  input: str | ChatDocument,
74
322
  i: int,
75
- return_idx: Optional[int] = None,
76
323
  ) -> BaseException | Optional[ChatDocument] | tuple[int, Optional[ChatDocument]]:
77
324
  task_i = gen_task(i)
78
325
  if task_i.agent.llm is not None:
79
326
  task_i.agent.llm.set_stream(False)
80
327
  task_i.agent.config.show_stats = False
328
+
81
329
  try:
82
330
  result = await task_i.run_async(
83
331
  input, turns=turns, max_cost=max_cost, max_tokens=max_tokens
84
332
  )
85
- if return_idx is not None:
86
- return return_idx, result
87
- else:
88
- return result
89
333
  except asyncio.CancelledError as e:
90
334
  task_i.kill()
91
- if handle_exceptions:
92
- return e
93
- else:
94
- raise e
95
- except BaseException as e:
96
- if handle_exceptions:
97
- return e
98
- else:
99
- raise e
100
-
101
- async def _do_all(
102
- inputs: Iterable[str | ChatDocument], start_idx: int = 0
103
- ) -> list[Optional[U]]:
104
- results: list[Optional[ChatDocument]] = []
105
- if stop_on_first_result:
106
- outputs: list[Optional[U]] = [None] * len(list(inputs))
107
- tasks = set(
108
- asyncio.create_task(_do_task(input, i + start_idx, return_idx=i))
109
- for i, input in enumerate(inputs)
110
- )
111
- while tasks:
112
- try:
113
- done, tasks = await asyncio.wait(
114
- tasks, return_when=asyncio.FIRST_COMPLETED
115
- )
116
- for task in done:
117
- idx_result = task.result()
118
- if not isinstance(idx_result, tuple):
119
- continue
120
- index, output = idx_result
121
- outputs[index] = output_map(output)
122
-
123
- if any(r is not None for r in outputs):
124
- return outputs
125
- finally:
126
- # Cancel all remaining tasks
127
- for task in tasks:
128
- task.cancel()
129
- # Wait for cancellations to complete
130
- try:
131
- await asyncio.gather(*tasks, return_exceptions=True)
132
- except BaseException as e:
133
- if not handle_exceptions:
134
- raise e
135
- return outputs
136
- elif sequential:
137
- for i, input in enumerate(inputs):
138
- result: Optional[ChatDocument] | BaseException = await _do_task(
139
- input, i + start_idx
140
- ) # type: ignore
141
-
142
- if isinstance(result, BaseException):
143
- result = None
144
-
145
- results.append(result)
146
- else:
147
- results_with_exceptions = cast(
148
- list[Optional[ChatDocument | BaseException]],
149
- await asyncio.gather(
150
- *(_do_task(input, i + start_idx) for i, input in enumerate(inputs)),
151
- ),
152
- )
153
-
154
- results = [
155
- r if not isinstance(r, BaseException) else None
156
- for r in results_with_exceptions
157
- ]
158
-
159
- return list(map(output_map, results))
160
-
161
- results: List[Optional[U]] = []
162
- if batch_size is None:
163
- msg = message or f"[bold green]Running {len(items)} tasks:"
164
-
165
- with status(msg), SuppressLoggerWarnings():
166
- results = asyncio.run(_do_all(inputs))
167
- else:
168
- batches = batched(inputs, batch_size)
169
-
170
- for batch in batches:
171
- start_idx = len(results)
172
- complete_str = f", {start_idx} complete" if start_idx > 0 else ""
173
- msg = message or f"[bold green]Running {len(items)} tasks{complete_str}:"
174
-
175
- if stop_on_first_result and any(r is not None for r in results):
176
- results.extend([None] * len(batch))
177
- else:
178
- with status(msg), SuppressLoggerWarnings():
179
- results.extend(asyncio.run(_do_all(batch, start_idx=start_idx)))
180
-
181
- return results
335
+ # exception will be handled by the caller
336
+ raise e
337
+ return result
338
+
339
+ return run_batched_tasks(
340
+ inputs=inputs,
341
+ do_task=_do_task,
342
+ batch_size=batch_size,
343
+ stop_on_first_result=stop_on_first_result,
344
+ sequential=sequential,
345
+ handle_exceptions=handle_exceptions,
346
+ output_map=output_map,
347
+ message_template="[bold green]Running {total} tasks:",
348
+ message=message,
349
+ )
182
350
 
183
351
 
184
352
  def run_batch_tasks(
@@ -242,6 +410,8 @@ def run_batch_agent_method(
242
410
  output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
243
411
  sequential: bool = True,
244
412
  stop_on_first_result: bool = False,
413
+ handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
414
+ batch_size: Optional[int] = None,
245
415
  ) -> List[Any]:
246
416
  """
247
417
  Run the `method` on copies of `agent`, async/concurrently one per
@@ -265,6 +435,14 @@ def run_batch_agent_method(
265
435
  to final result
266
436
  sequential (bool): whether to run sequentially
267
437
  (e.g. some APIs such as ooba don't support concurrent requests)
438
+ stop_on_first_result (bool): whether to stop after the first valid
439
+ handle_exceptions: How to handle exceptions:
440
+ - RAISE or False: Let exceptions propagate
441
+ - RETURN_NONE or True: Convert exceptions to None in results
442
+ - RETURN_EXCEPTION: Include exception objects in results
443
+ Boolean values are deprecated and will be removed in a future version.
444
+ batch_size (Optional[int]): The number of items to process in each batch.
445
+ If None, process all items at once.
268
446
  Returns:
269
447
  List[Any]: list of final results
270
448
  """
@@ -288,43 +466,18 @@ def run_batch_agent_method(
288
466
  if method_i is None:
289
467
  raise ValueError(f"Agent {agent_name} has no method {method_name}")
290
468
  result = await method_i(input)
291
- return output_map(result)
469
+ return result
292
470
 
293
- async def _do_all() -> List[Any]:
294
- if stop_on_first_result:
295
- tasks = [
296
- asyncio.create_task(_do_task(input, i))
297
- for i, input in enumerate(inputs)
298
- ]
299
- results = [None] * len(tasks)
300
- try:
301
- done, pending = await asyncio.wait(
302
- tasks, return_when=asyncio.FIRST_COMPLETED
303
- )
304
- for task in done:
305
- index = tasks.index(task)
306
- results[index] = await task
307
- finally:
308
- for task in pending:
309
- task.cancel()
310
- await asyncio.gather(*pending, return_exceptions=True)
311
- return results
312
- elif sequential:
313
- results = []
314
- for i, input in enumerate(inputs):
315
- result = await _do_task(input, i)
316
- results.append(result)
317
- return results
318
- with quiet_mode(), SuppressLoggerWarnings():
319
- return await asyncio.gather(
320
- *(_do_task(input, i) for i, input in enumerate(inputs))
321
- )
322
-
323
- n = len(items)
324
- with status(f"[bold green]Running {n} copies of {agent_name}..."):
325
- results = asyncio.run(_do_all())
326
-
327
- return results
471
+ return run_batched_tasks(
472
+ inputs=inputs,
473
+ do_task=_do_task,
474
+ batch_size=batch_size,
475
+ stop_on_first_result=stop_on_first_result,
476
+ sequential=sequential,
477
+ handle_exceptions=handle_exceptions,
478
+ output_map=output_map,
479
+ message_template=f"[bold green]Running {{total}} copies of {agent_name}...",
480
+ )
328
481
 
329
482
 
330
483
  def llm_response_batch(
@@ -334,6 +487,7 @@ def llm_response_batch(
334
487
  output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
335
488
  sequential: bool = True,
336
489
  stop_on_first_result: bool = False,
490
+ batch_size: Optional[int] = None,
337
491
  ) -> List[Any]:
338
492
  return run_batch_agent_method(
339
493
  agent,
@@ -343,6 +497,7 @@ def llm_response_batch(
343
497
  output_map=output_map,
344
498
  sequential=sequential,
345
499
  stop_on_first_result=stop_on_first_result,
500
+ batch_size=batch_size,
346
501
  )
347
502
 
348
503
 
@@ -353,6 +508,7 @@ def agent_response_batch(
353
508
  output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
354
509
  sequential: bool = True,
355
510
  stop_on_first_result: bool = False,
511
+ batch_size: Optional[int] = None,
356
512
  ) -> List[Any]:
357
513
  return run_batch_agent_method(
358
514
  agent,
@@ -362,6 +518,7 @@ def agent_response_batch(
362
518
  output_map=output_map,
363
519
  sequential=sequential,
364
520
  stop_on_first_result=stop_on_first_result,
521
+ batch_size=batch_size,
365
522
  )
366
523
 
367
524
 
langroid/agent/chat_agent.py CHANGED
@@ -1845,14 +1845,16 @@ class ChatAgent(Agent):
1845
1845
  self.update_last_message(message, role=Role.USER)
1846
1846
  return answer_doc
1847
1847
 
1848
- def llm_response_forget(self, message: str) -> ChatDocument:
1848
+ def llm_response_forget(
1849
+ self, message: Optional[str | ChatDocument] = None
1850
+ ) -> ChatDocument:
1849
1851
  """
1850
1852
  LLM Response to single message, and restore message_history.
1851
1853
  In effect a "one-off" message & response that leaves agent
1852
1854
  message history state intact.
1853
1855
 
1854
1856
  Args:
1855
- message (str): user message
1857
+ message (str|ChatDocument): message to respond to.
1856
1858
 
1857
1859
  Returns:
1858
1860
  A Document object with the response.
@@ -1879,7 +1881,9 @@ class ChatAgent(Agent):
1879
1881
 
1880
1882
  return response
1881
1883
 
1882
- async def llm_response_forget_async(self, message: str) -> ChatDocument:
1884
+ async def llm_response_forget_async(
1885
+ self, message: Optional[str | ChatDocument] = None
1886
+ ) -> ChatDocument:
1883
1887
  """
1884
1888
  Async version of `llm_response_forget`. See there for details.
1885
1889
  """
langroid/agent/special/doc_chat_agent.py CHANGED
@@ -1,3 +1,4 @@
1
+ # # langroid/agent/special/doc_chat_agent.py
1
2
  """
2
3
  Agent that supports asking queries about a set of documents, using
3
4
  retrieval-augmented generation (RAG).
@@ -16,14 +17,14 @@ pip install "langroid[hf-embeddings]"
16
17
  import logging
17
18
  from collections import OrderedDict
18
19
  from functools import cache
19
- from typing import Any, Dict, List, Optional, Set, Tuple, no_type_check
20
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, no_type_check
20
21
 
21
22
  import nest_asyncio
22
23
  import numpy as np
23
24
  import pandas as pd
24
25
  from rich.prompt import Prompt
25
26
 
26
- from langroid.agent.batch import run_batch_tasks
27
+ from langroid.agent.batch import run_batch_agent_method, run_batch_tasks
27
28
  from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
28
29
  from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
29
30
  from langroid.agent.special.relevance_extractor_agent import (
@@ -81,6 +82,8 @@ DEFAULT_DOC_CHAT_SYSTEM_MESSAGE = """
81
82
  You are a helpful assistant, helping me understand a collection of documents.
82
83
  """
83
84
 
85
+ CHUNK_ENRICHMENT_DELIMITER = "<##-##-##>"
86
+
84
87
  has_sentence_transformers = False
85
88
  try:
86
89
  from sentence_transformers import SentenceTransformer # noqa: F401
@@ -102,6 +105,12 @@ oai_embed_config = OpenAIEmbeddingsConfig(
102
105
  )
103
106
 
104
107
 
108
+ class ChunkEnrichmentAgentConfig(ChatAgentConfig):
109
+ batch_size: int = 50
110
+ delimiter: str = CHUNK_ENRICHMENT_DELIMITER
111
+ enrichment_prompt_fn: Callable[[str], str] = lambda x: x
112
+
113
+
105
114
  class DocChatAgentConfig(ChatAgentConfig):
106
115
  system_message: str = DEFAULT_DOC_CHAT_SYSTEM_MESSAGE
107
116
  user_message: str = DEFAULT_DOC_CHAT_INSTRUCTIONS
@@ -126,6 +135,12 @@ class DocChatAgentConfig(ChatAgentConfig):
126
135
  # https://arxiv.org/pdf/2212.10496.pdf
127
136
  # It is False by default; its benefits depends on the context.
128
137
  hypothetical_answer: bool = False
138
+ # Optional config for chunk enrichment agent, e.g. to enrich
139
+ # chunks with hypothetical questions, or keywords to increase
140
+ # the "semantic surface area" of the chunks, which may help
141
+ # improve retrieval.
142
+ chunk_enrichment_config: Optional[ChunkEnrichmentAgentConfig] = None
143
+
129
144
  n_query_rephrases: int = 0
130
145
  n_neighbor_chunks: int = 0 # how many neighbors on either side of match to retrieve
131
146
  n_fuzzy_neighbor_words: int = 100 # num neighbor words to retrieve for fuzzy match
@@ -404,6 +419,8 @@ class DocChatAgent(ChatAgent):
404
419
  d.metadata.is_chunk = True
405
420
  if self.vecdb is None:
406
421
  raise ValueError("VecDB not set")
422
+ if self.config.chunk_enrichment_config is not None:
423
+ docs = self.enrich_chunks(docs)
407
424
 
408
425
  # If any additional fields need to be added to content,
409
426
  # add them as key=value pairs for all docs, before batching.
@@ -860,6 +877,72 @@ class DocChatAgent(ChatAgent):
860
877
  ).content
861
878
  return answer
862
879
 
880
+ def enrich_chunks(self, docs: List[Document]) -> List[Document]:
881
+ """
882
+ Enrich chunks using Agent configured with self.config.chunk_enrichment_config.
883
+
884
+ We assume that the system message of the agent is set in such a way
885
+ that when we run
886
+ ```
887
+ prompt = self.config.chunk_enrichment_config.enrichment_prompt_fn(text)
888
+ result = await agent.llm_response_forget_async(prompt)
889
+ ```
890
+
891
+ then `result.content` will contain the augmentation to the text.
892
+
893
+ Args:
894
+ docs: List of document chunks to enrich
895
+
896
+ Returns:
897
+ List[Document]: Documents (chunks) enriched with additional text,
898
+ separated by a delimiter.
899
+ """
900
+ if self.config.chunk_enrichment_config is None:
901
+ return docs
902
+ enrichment_config = self.config.chunk_enrichment_config
903
+ agent = ChatAgent(enrichment_config)
904
+ if agent.llm is None:
905
+ raise ValueError("LLM not set")
906
+
907
+ with status("[cyan]Augmenting chunks..."):
908
+ # Process chunks in parallel using run_batch_agent_method
909
+ questions_batch = run_batch_agent_method(
910
+ agent=agent,
911
+ method=agent.llm_response_forget_async,
912
+ items=docs,
913
+ input_map=lambda doc: (
914
+ enrichment_config.enrichment_prompt_fn(doc.content)
915
+ ),
916
+ output_map=lambda response: response.content if response else "",
917
+ sequential=False,
918
+ batch_size=enrichment_config.batch_size,
919
+ )
920
+
921
+ # Combine original content with generated questions
922
+ augmented_docs = []
923
+ for doc, enrichment in zip(docs, questions_batch):
924
+ if not enrichment:
925
+ augmented_docs.append(doc)
926
+ continue
927
+
928
+ # Combine original content with questions in a structured way
929
+ combined_content = f"""
930
+ {doc.content}
931
+
932
+ {enrichment_config.delimiter}
933
+ {enrichment}
934
+ """.strip()
935
+
936
+ new_doc = doc.copy(
937
+ update={
938
+ "content": combined_content,
939
+ "metadata": doc.metadata.copy(update={"has_enrichment": True}),
940
+ }
941
+ )
942
+ augmented_docs.append(new_doc)
943
+
944
+ return augmented_docs
945
+
863
946
  def llm_rephrase_query(self, query: str) -> List[str]:
864
947
  if self.llm is None:
865
948
  raise ValueError("LLM not set")
@@ -1143,20 +1226,22 @@ class DocChatAgent(ChatAgent):
1143
1226
  id2_rank_semantic = {d.id(): i for i, (d, _) in enumerate(docs_and_scores)}
1144
1227
  id2doc = {d.id(): d for d, _ in docs_and_scores}
1145
1228
  # make sure we get unique docs
1146
- passages = [id2doc[id] for id, _ in id2_rank_semantic.items()]
1229
+ passages = [id2doc[id] for id in id2_rank_semantic.keys()]
1147
1230
 
1148
1231
  id2_rank_bm25 = {}
1149
1232
  if self.config.use_bm25_search:
1150
1233
  # TODO: Add score threshold in config
1151
1234
  docs_scores = self.get_similar_chunks_bm25(query, retrieval_multiple)
1235
+ id2doc.update({d.id(): d for d, _ in docs_scores})
1152
1236
  if self.config.cross_encoder_reranking_model == "":
1153
1237
  # only if we're not re-ranking with a cross-encoder,
1154
1238
  # we collect these ranks for Reciprocal Rank Fusion down below.
1155
1239
  docs_scores = sorted(docs_scores, key=lambda x: x[1], reverse=True)
1156
1240
  id2_rank_bm25 = {d.id(): i for i, (d, _) in enumerate(docs_scores)}
1157
- id2doc.update({d.id(): d for d, _ in docs_scores})
1158
1241
  else:
1159
1242
  passages += [d for (d, _) in docs_scores]
1243
+ # eliminate duplicate ids
1244
+ passages = [id2doc[id] for id in id2doc.keys()]
1160
1245
 
1161
1246
  id2_rank_fuzzy = {}
1162
1247
  if self.config.use_fuzzy_match:
@@ -1174,6 +1259,8 @@ class DocChatAgent(ChatAgent):
1174
1259
  id2doc.update({d.id(): d for d, _ in fuzzy_match_doc_scores})
1175
1260
  else:
1176
1261
  passages += [d for (d, _) in fuzzy_match_doc_scores]
1262
+ # eliminate duplicate ids
1263
+ passages = [id2doc[id] for id in id2doc.keys()]
1177
1264
 
1178
1265
  if (
1179
1266
  self.config.cross_encoder_reranking_model == ""
@@ -1301,7 +1388,6 @@ class DocChatAgent(ChatAgent):
1301
1388
  if self.config.n_query_rephrases > 0:
1302
1389
  rephrases = self.llm_rephrase_query(query)
1303
1390
  proxies += rephrases
1304
-
1305
1391
  passages = self.get_relevant_chunks(query, proxies) # no LLM involved
1306
1392
 
1307
1393
  if len(passages) == 0:
@@ -1315,6 +1401,29 @@ class DocChatAgent(ChatAgent):
1315
1401
 
1316
1402
  return query, extracts
1317
1403
 
1404
+ def remove_chunk_enrichments(self, passages: List[Document]) -> List[Document]:
1405
+ """Remove any enrichments (like hypothetical questions, or keywords)
1406
+ from documents.
1407
+ Only cleans if enrichment was enabled in config.
1408
+
1409
+ Args:
1410
+ passages: List of documents to clean
1411
+
1412
+ Returns:
1413
+ List of documents with only original content
1414
+ """
1415
+ if self.config.chunk_enrichment_config is None:
1416
+ return passages
1417
+ delimiter = self.config.chunk_enrichment_config.delimiter
1418
+ return [
1419
+ (
1420
+ doc.copy(update={"content": doc.content.split(delimiter)[0].strip()})
1421
+ if doc.content and getattr(doc.metadata, "has_enrichment", False)
1422
+ else doc
1423
+ )
1424
+ for doc in passages
1425
+ ]
1426
+
1318
1427
  def get_verbatim_extracts(
1319
1428
  self,
1320
1429
  query: str,
@@ -1330,6 +1439,8 @@ class DocChatAgent(ChatAgent):
1330
1439
  Returns:
1331
1440
  List[Document]: list of Documents containing extracts and metadata.
1332
1441
  """
1442
+ passages = self.remove_chunk_enrichments(passages)
1443
+
1333
1444
  agent_cfg = self.config.relevance_extractor_config
1334
1445
  if agent_cfg is None:
1335
1446
  # no relevance extraction: simply return passages
langroid/mytypes.py CHANGED
@@ -75,6 +75,17 @@ class Document(BaseModel):
75
75
  def id(self) -> str:
76
76
  return self.metadata.id
77
77
 
78
+ @staticmethod
79
+ def from_string(
80
+ content: str,
81
+ source: str = "context",
82
+ is_chunk: bool = True,
83
+ ) -> "Document":
84
+ return Document(
85
+ content=content,
86
+ metadata=DocMetaData(source=source, is_chunk=is_chunk),
87
+ )
88
+
78
89
  def __str__(self) -> str:
79
90
  return dedent(
80
91
  f"""
langroid/parsing/spider.py CHANGED
@@ -7,9 +7,9 @@ try:
7
7
  from pydispatch import dispatcher
8
8
  from scrapy import signals
9
9
  from scrapy.crawler import CrawlerRunner
10
- from scrapy.http import Response
11
- from scrapy.linkextractors import LinkExtractor
12
- from scrapy.spiders import CrawlSpider, Rule
10
+ from scrapy.http.response.text import TextResponse
11
+ from scrapy.linkextractors.lxmlhtml import LxmlLinkExtractor
12
+ from scrapy.spiders import CrawlSpider, Rule # type: ignore
13
13
  from twisted.internet import defer, reactor
14
14
  except ImportError:
15
15
  raise LangroidImportError("scrapy", "scrapy")
@@ -21,7 +21,7 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
21
21
 
22
22
  custom_settings = {"DEPTH_LIMIT": 1, "CLOSESPIDER_ITEMCOUNT": 20}
23
23
 
24
- rules = (Rule(LinkExtractor(), callback="parse_item", follow=True),)
24
+ rules = (Rule(LxmlLinkExtractor(), callback="parse_item", follow=True),)
25
25
 
26
26
  def __init__(self, start_url: str, k: int = 20, *args, **kwargs): # type: ignore
27
27
  """Initialize the spider with start_url and k.
@@ -36,13 +36,13 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
36
36
  self.k = k
37
37
  self.visited_urls: Set[str] = set()
38
38
 
39
- def parse_item(self, response: Response): # type: ignore
39
+ def parse_item(self, response: TextResponse): # type: ignore
40
40
  """Extracts URLs that are within the same domain.
41
41
 
42
42
  Args:
43
43
  response: The scrapy response object.
44
44
  """
45
- for link in LinkExtractor(allow_domains=self.allowed_domains).extract_links(
45
+ for link in LxmlLinkExtractor(allow_domains=self.allowed_domains).extract_links(
46
46
  response
47
47
  ):
48
48
  if len(self.visited_urls) < self.k:
langroid-0.33.12.dist-info/METADATA → langroid-0.34.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langroid
3
- Version: 0.33.12
3
+ Version: 0.34.0
4
4
  Summary: Harness LLMs with Multi-Agent Programming
5
5
  Author-email: Prasad Chalasani <pchalasani@gmail.com>
6
6
  License: MIT
langroid-0.33.12.dist-info/RECORD → langroid-0.34.0.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
1
1
  langroid/__init__.py,sha256=z_fCOLQJPOw3LLRPBlFB5-2HyCjpPgQa4m4iY5Fvb8Y,1800
2
2
  langroid/exceptions.py,sha256=gp6ku4ZLdXXCUQIwUNVFojJNGTzGnkevi2PLvG7HOhc,2555
3
- langroid/mytypes.py,sha256=ptAFxEAtiwmIfUnGisNotTe8wT9LKBf22lOfPgZoQIY,2368
3
+ langroid/mytypes.py,sha256=h1eMq1ZwTLVezObPfCseWNWbEOzP7mAKu2XoS63W1cM,2647
4
4
  langroid/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  langroid/agent/__init__.py,sha256=ll0Cubd2DZ-fsCMl7e10hf9ZjFGKzphfBco396IKITY,786
6
6
  langroid/agent/base.py,sha256=oThlrYygKDu1-bKjAfygldJ511gMKT8Z0qCrD52DdDM,77834
7
- langroid/agent/batch.py,sha256=qK3ph6VNj_1sOhfXCZY4r6gh035DglDKU751p8BU0tY,14665
8
- langroid/agent/chat_agent.py,sha256=cxamUgqQkr6_W3mqCPz3L7rJnXIkD4hemR7X7uhlBvI,82095
7
+ langroid/agent/batch.py,sha256=vi1r5i1-vN80WfqHDSwjEym_KfGsqPGUtwktmiK1nuk,20635
8
+ langroid/agent/chat_agent.py,sha256=A-7Iiiw7jsoJNlWerljM29BidkiIbjPOQIkGZpZHmt0,82210
9
9
  langroid/agent/chat_document.py,sha256=xPUMGzR83rn4iAEXIw2jy5LQ6YJ6Y0TiZ78XRQeDnJQ,17778
10
10
  langroid/agent/openai_assistant.py,sha256=JkAcs02bIrgPNVvUWVR06VCthc5-ulla2QMBzux_q6o,34340
11
11
  langroid/agent/task.py,sha256=XrXUbSoiFasvpIsZPn_cBpdWaTCKljJPRimtLMrSZrs,90347
@@ -14,7 +14,7 @@ langroid/agent/xml_tool_message.py,sha256=6SshYZJKIfi4mkE-gIoSwjkEYekQ8GwcSiCv7a
14
14
  langroid/agent/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  langroid/agent/callbacks/chainlit.py,sha256=RH8qUXaZE5o2WQz3WJQ1SdFtASGlxWCA6_HYz_3meDQ,20822
16
16
  langroid/agent/special/__init__.py,sha256=gik_Xtm_zV7U9s30Mn8UX3Gyuy4jTjQe9zjiE3HWmEo,1273
17
- langroid/agent/special/doc_chat_agent.py,sha256=zw2MvdCWRPH93d73PKh27KFiQ8sUCFPxAfLDdkxvdZQ,59301
17
+ langroid/agent/special/doc_chat_agent.py,sha256=tI16jVavTSOen9OUoRTl5heDTeTBhWsxW17XU9ZcEko,63563
18
18
  langroid/agent/special/lance_doc_chat_agent.py,sha256=s8xoRs0gGaFtDYFUSIRchsgDVbS5Q3C2b2mr3V1Fd-Q,10419
19
19
  langroid/agent/special/lance_tools.py,sha256=qS8x4wi8mrqfbYV2ztFzrcxyhHQ0ZWOc-zkYiH7awj0,2105
20
20
  langroid/agent/special/relevance_extractor_agent.py,sha256=zIx8GUdVo1aGW6ASla0NPQjYYIpmriK_TYMijqAx3F8,4796
@@ -85,7 +85,7 @@ langroid/parsing/parser.py,sha256=bTG5TO2CEwGdLf9979j9_dFntKX5FloGF8vhts6ObU0,11
85
85
  langroid/parsing/repo_loader.py,sha256=3GjvPJS6Vf5L6gV2zOU8s-Tf1oq_fZm-IB_RL_7CTsY,29373
86
86
  langroid/parsing/routing.py,sha256=-FcnlqldzL4ZoxuDwXjQPNHgBe9F9-F4R6q7b_z9CvI,1232
87
87
  langroid/parsing/search.py,sha256=0i_r0ESb5HEQfagA2g7_uMQyxYPADWVbdcN9ixZhS4E,8992
88
- langroid/parsing/spider.py,sha256=Y6y7b86Y2k770LdhxgjVlImBxuuy1V9n8-XQ3QPaG5s,3199
88
+ langroid/parsing/spider.py,sha256=hAVM6wxh1pQ0EN4tI5wMBtAjIk0T-xnpi-ZUzWybhos,3258
89
89
  langroid/parsing/table_loader.py,sha256=qNM4obT_0Y4tjrxNBCNUYjKQ9oETCZ7FbolKBTcz-GM,3410
90
90
  langroid/parsing/url_loader.py,sha256=JK48KktLRDBfjrt4nsUfy92M6yGdEeicAqOum2MdULM,4656
91
91
  langroid/parsing/urls.py,sha256=XjpaV5onG7gKQ5iQeFTzHSw5P08Aqw0g-rMUu61lR6s,7988
@@ -121,7 +121,7 @@ langroid/vector_store/lancedb.py,sha256=b3_vWkTjG8mweZ7ZNlUD-NjmQP_rLBZfyKWcxt2v
121
121
  langroid/vector_store/meilisearch.py,sha256=6frB7GFWeWmeKzRfLZIvzRjllniZ1cYj3HmhHQICXLs,11663
122
122
  langroid/vector_store/momento.py,sha256=UNHGT6jXuQtqY9f6MdqGU14bVnS0zHgIJUa30ULpUJo,10474
123
123
  langroid/vector_store/qdrantdb.py,sha256=HRLCt-FG8y4718omwpFaQZnWeYxPj0XCwS4tjokI1sU,18116
124
- langroid-0.33.12.dist-info/METADATA,sha256=tUh6elP7kcHfPS2dUXBE-gZ6vLoBfFzYsmg7nr2oCrg,59016
125
- langroid-0.33.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
126
- langroid-0.33.12.dist-info/licenses/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
127
- langroid-0.33.12.dist-info/RECORD,,
124
+ langroid-0.34.0.dist-info/METADATA,sha256=fo7ULfjnWFED6Cag8aUFjOaPqEatQKBXEz-Z_rFyHnk,59015
125
+ langroid-0.34.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
126
+ langroid-0.34.0.dist-info/licenses/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
127
+ langroid-0.34.0.dist-info/RECORD,,