langroid 0.33.13__py3-none-any.whl → 0.34.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langroid/agent/batch.py CHANGED
@@ -1,7 +1,19 @@
1
1
  import asyncio
2
2
  import copy
3
3
  import inspect
4
- from typing import Any, Callable, Coroutine, Iterable, List, Optional, TypeVar, cast
4
+ import warnings
5
+ from enum import Enum
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ Coroutine,
10
+ Iterable,
11
+ List,
12
+ Optional,
13
+ TypeVar,
14
+ Union,
15
+ cast,
16
+ )
5
17
 
6
18
  from dotenv import load_dotenv
7
19
 
@@ -21,6 +33,238 @@ T = TypeVar("T")
21
33
  U = TypeVar("U")
22
34
 
23
35
 
36
+ class ExceptionHandling(str, Enum):
37
+ """Enum for exception handling options."""
38
+
39
+ RAISE = "raise"
40
+ RETURN_NONE = "return_none"
41
+ RETURN_EXCEPTION = "return_exception"
42
+
43
+
44
+ def _convert_exception_handling(
45
+ handle_exceptions: Union[bool, ExceptionHandling]
46
+ ) -> ExceptionHandling:
47
+ """Convert legacy boolean handle_exceptions to ExceptionHandling enum."""
48
+ if isinstance(handle_exceptions, ExceptionHandling):
49
+ return handle_exceptions
50
+
51
+ if isinstance(handle_exceptions, bool):
52
+ warnings.warn(
53
+ "Boolean handle_exceptions is deprecated. "
54
+ "Use ExceptionHandling enum instead: "
55
+ "RAISE, RETURN_NONE, or RETURN_EXCEPTION.",
56
+ DeprecationWarning,
57
+ stacklevel=2,
58
+ )
59
+ return (
60
+ ExceptionHandling.RETURN_NONE
61
+ if handle_exceptions
62
+ else ExceptionHandling.RAISE
63
+ )
64
+
65
+ raise TypeError(
66
+ "handle_exceptions must be bool or ExceptionHandling, "
67
+ f"not {type(handle_exceptions)}"
68
+ )
69
+
70
+
71
async def _process_batch_async(
    inputs: Iterable[str | ChatDocument],
    do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
    start_idx: int = 0,
    stop_on_first_result: bool = False,
    sequential: bool = False,
    handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
    output_map: Callable[[Any], Any] = lambda x: x,
) -> List[Optional[ChatDocument] | BaseException]:
    """
    Unified batch processing logic for both agent methods and tasks.

    Returns exactly one slot per input, in input order. `output_map` is
    applied only to values that `do_task` actually returned. When an input
    fails, its slot holds `None` (RETURN_NONE) or the exception object
    (RETURN_EXCEPTION) as-is, so `output_map` never receives an exception.

    Args:
        inputs: Iterable of inputs to process (consumed exactly once).
        do_task: Task execution function that takes (input, index) and returns
            an awaitable result; index is `start_idx` + position in `inputs`.
        start_idx: Starting index for the batch.
        stop_on_first_result: Run all inputs concurrently and return as soon as
            a *successful* input's mapped result is not None. Failed inputs do
            not count as results. Inputs still running are cancelled and
            their slots stay None. Takes precedence over `sequential`.
        sequential: Whether to process inputs one after another.
        handle_exceptions: How to handle exceptions:
            - RAISE or False: Let exceptions propagate
            - RETURN_NONE or True: Convert exceptions to None in results
            - RETURN_EXCEPTION: Include exception objects in results
            Boolean values are deprecated and will be removed in a future version.
        output_map: Function to map results to final output format.

    Returns:
        A list with one entry per input.

    Raises:
        Exceptions from `do_task` when handle_exceptions is RAISE.
        Cancellation of the calling coroutine and non-`Exception`
        BaseExceptions (e.g. KeyboardInterrupt) are never converted.
    """
    exception_handling = _convert_exception_handling(handle_exceptions)
    # Materialize once: `inputs` may be a one-shot iterator, and the
    # stop-on-first path needs its length up front.
    items = list(inputs)
    results: List[Optional[ChatDocument] | BaseException]

    def handle_error(e: BaseException) -> Any:
        """Return the result slot for a failed input, or re-raise (RAISE)."""
        if exception_handling == ExceptionHandling.RAISE:
            raise e
        if exception_handling == ExceptionHandling.RETURN_NONE:
            return None
        return e  # ExceptionHandling.RETURN_EXCEPTION

    if stop_on_first_result:
        results = [None] * len(items)
        tasks = [
            asyncio.create_task(do_task(item, i + start_idx))
            for i, item in enumerate(items)
        ]
        task_indices = {task: i for i, task in enumerate(tasks)}
        pending: set[asyncio.Task[Any]] = set(tasks)
        try:
            # Keep waiting until some task yields a usable result or every task
            # has finished. A first completion that returned None or failed must
            # not end the search. The `while` guard also avoids calling
            # asyncio.wait with an empty set (ValueError) on empty input.
            while pending:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED
                )
                found = False
                for task in done:
                    index = task_indices[task]
                    try:
                        value = output_map(task.result())
                    except (Exception, asyncio.CancelledError) as e:
                        results[index] = handle_error(e)
                    else:
                        results[index] = value
                        found = found or value is not None
                if found:
                    break
        finally:
            # Cancel whatever is still running (early stop or error) and wait
            # for the cancellations, so no task outlives this call.
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)
        return results

    if sequential:
        results = []
        for i, item in enumerate(items):
            try:
                results.append(output_map(await do_task(item, i + start_idx)))
            except Exception as e:
                # Only ordinary errors are converted. A CancelledError here means
                # this coroutine itself is being cancelled, so it propagates.
                results.append(handle_error(e))
        return results

    # Parallel execution. Under RAISE, gather propagates the first exception as
    # soon as it occurs; otherwise exceptions come back in place of results.
    return_exceptions = exception_handling != ExceptionHandling.RAISE
    with quiet_mode(), SuppressLoggerWarnings():
        outcomes = await asyncio.gather(
            *(do_task(item, i + start_idx) for i, item in enumerate(items)),
            return_exceptions=return_exceptions,
        )
    return [
        handle_error(outcome)
        if return_exceptions and isinstance(outcome, BaseException)
        else output_map(outcome)
        for outcome in outcomes
    ]
182
+
183
+
184
def run_batched_tasks(
    inputs: List[str | ChatDocument],
    do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
    batch_size: Optional[int],
    stop_on_first_result: bool,
    sequential: bool,
    handle_exceptions: Union[bool, ExceptionHandling],
    output_map: Callable[[Any], Any],
    message_template: str,
    message: Optional[str] = None,
) -> List[Any]:
    """
    Common batch processing logic for both agent methods and tasks.

    All batches run inside a single event loop (one `asyncio.run` call).

    Args:
        inputs: List of inputs to process
        do_task: Task execution function
        batch_size: Size of batches, if None process all at once
        stop_on_first_result: Whether to stop after first valid result; once
            a batch produces a result that is neither None nor an exception,
            later batches are skipped and their slots filled with None.
        sequential: Whether to process sequentially
        handle_exceptions: How to handle exceptions:
            - RAISE or False: Let exceptions propagate
            - RETURN_NONE or True: Convert exceptions to None in results
            - RETURN_EXCEPTION: Include exception objects in results
            Boolean values are deprecated and will be removed in a future version.
        output_map: Function to map results
        message_template: Template for status message; `{total}` is replaced
            by the number of inputs. From the second batch on, ", N complete"
            is inserted before a trailing ":" (or appended if there is none).
        message: Optional override for status message (used verbatim)

    Returns:
        One entry per input, in input order.
    """
    total = len(inputs)

    def status_message(n_complete: int) -> str:
        """Console status text for a batch that starts after `n_complete` inputs."""
        if message:
            return message
        text = message_template.format(total=total)
        if n_complete <= 0:
            return text
        progress = f", {n_complete} complete"
        # Produce "Running 10 tasks, 5 complete:" rather than
        # "Running 10 tasks:, 5 complete".
        if text.endswith(":"):
            return f"{text[:-1]}{progress}:"
        return text + progress

    def has_usable_result(results: List[Any]) -> bool:
        """True if any slot holds a real result (not None, not an exception)."""
        return any(
            r is not None and not isinstance(r, BaseException) for r in results
        )

    async def run_all_batched_tasks() -> List[Any]:
        """Process every batch in order inside one event loop."""
        if batch_size is None:
            with status(status_message(0)), SuppressLoggerWarnings():
                return await _process_batch_async(
                    inputs,
                    do_task,
                    stop_on_first_result=stop_on_first_result,
                    sequential=sequential,
                    handle_exceptions=handle_exceptions,
                    output_map=output_map,
                )

        results: List[Any] = []
        for batch in batched(inputs, batch_size):
            start_idx = len(results)
            if stop_on_first_result and has_usable_result(results):
                # An earlier batch already produced a result: skip this one,
                # keeping one (None) slot per skipped input.
                results.extend([None] * len(batch))
                continue
            with status(status_message(start_idx)), SuppressLoggerWarnings():
                results.extend(
                    await _process_batch_async(
                        batch,
                        do_task,
                        start_idx=start_idx,
                        stop_on_first_result=stop_on_first_result,
                        sequential=sequential,
                        handle_exceptions=handle_exceptions,
                        output_map=output_map,
                    )
                )
        return results

    return asyncio.run(run_all_batched_tasks())
266
+
267
+
24
268
  def run_batch_task_gen(
25
269
  gen_task: Callable[[int], Task],
26
270
  items: list[T],
@@ -31,7 +275,7 @@ def run_batch_task_gen(
31
275
  batch_size: Optional[int] = None,
32
276
  turns: int = -1,
33
277
  message: Optional[str] = None,
34
- handle_exceptions: bool = False,
278
+ handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
35
279
  max_cost: float = 0.0,
36
280
  max_tokens: int = 0,
37
281
  ) -> list[Optional[U]]:
@@ -58,7 +302,11 @@ def run_batch_task_gen(
58
302
  if None, unbatched
59
303
  turns (int): number of turns to run, -1 for infinite
60
304
  message (Optional[str]): optionally overrides the console status messages
61
- handle_exceptions: bool: Whether to replace exceptions with outputs of None
305
+ handle_exceptions: How to handle exceptions:
306
+ - RAISE or False: Let exceptions propagate
307
+ - RETURN_NONE or True: Convert exceptions to None in results
308
+ - RETURN_EXCEPTION: Include exception objects in results
309
+ Boolean values are deprecated and will be removed in a future version.
62
310
  max_cost: float: maximum cost to run the task (default 0.0 for unlimited)
63
311
  max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)
64
312
 
@@ -72,113 +320,33 @@ def run_batch_task_gen(
72
320
  async def _do_task(
73
321
  input: str | ChatDocument,
74
322
  i: int,
75
- return_idx: Optional[int] = None,
76
323
  ) -> BaseException | Optional[ChatDocument] | tuple[int, Optional[ChatDocument]]:
77
324
  task_i = gen_task(i)
78
325
  if task_i.agent.llm is not None:
79
326
  task_i.agent.llm.set_stream(False)
80
327
  task_i.agent.config.show_stats = False
328
+
81
329
  try:
82
330
  result = await task_i.run_async(
83
331
  input, turns=turns, max_cost=max_cost, max_tokens=max_tokens
84
332
  )
85
- if return_idx is not None:
86
- return return_idx, result
87
- else:
88
- return result
89
333
  except asyncio.CancelledError as e:
90
334
  task_i.kill()
91
- if handle_exceptions:
92
- return e
93
- else:
94
- raise e
95
- except BaseException as e:
96
- if handle_exceptions:
97
- return e
98
- else:
99
- raise e
100
-
101
- async def _do_all(
102
- inputs: Iterable[str | ChatDocument], start_idx: int = 0
103
- ) -> list[Optional[U]]:
104
- results: list[Optional[ChatDocument]] = []
105
- if stop_on_first_result:
106
- outputs: list[Optional[U]] = [None] * len(list(inputs))
107
- tasks = set(
108
- asyncio.create_task(_do_task(input, i + start_idx, return_idx=i))
109
- for i, input in enumerate(inputs)
110
- )
111
- while tasks:
112
- try:
113
- done, tasks = await asyncio.wait(
114
- tasks, return_when=asyncio.FIRST_COMPLETED
115
- )
116
- for task in done:
117
- idx_result = task.result()
118
- if not isinstance(idx_result, tuple):
119
- continue
120
- index, output = idx_result
121
- outputs[index] = output_map(output)
122
-
123
- if any(r is not None for r in outputs):
124
- return outputs
125
- finally:
126
- # Cancel all remaining tasks
127
- for task in tasks:
128
- task.cancel()
129
- # Wait for cancellations to complete
130
- try:
131
- await asyncio.gather(*tasks, return_exceptions=True)
132
- except BaseException as e:
133
- if not handle_exceptions:
134
- raise e
135
- return outputs
136
- elif sequential:
137
- for i, input in enumerate(inputs):
138
- result: Optional[ChatDocument] | BaseException = await _do_task(
139
- input, i + start_idx
140
- ) # type: ignore
141
-
142
- if isinstance(result, BaseException):
143
- result = None
144
-
145
- results.append(result)
146
- else:
147
- results_with_exceptions = cast(
148
- list[Optional[ChatDocument | BaseException]],
149
- await asyncio.gather(
150
- *(_do_task(input, i + start_idx) for i, input in enumerate(inputs)),
151
- ),
152
- )
153
-
154
- results = [
155
- r if not isinstance(r, BaseException) else None
156
- for r in results_with_exceptions
157
- ]
158
-
159
- return list(map(output_map, results))
160
-
161
- results: List[Optional[U]] = []
162
- if batch_size is None:
163
- msg = message or f"[bold green]Running {len(items)} tasks:"
164
-
165
- with status(msg), SuppressLoggerWarnings():
166
- results = asyncio.run(_do_all(inputs))
167
- else:
168
- batches = batched(inputs, batch_size)
169
-
170
- for batch in batches:
171
- start_idx = len(results)
172
- complete_str = f", {start_idx} complete" if start_idx > 0 else ""
173
- msg = message or f"[bold green]Running {len(items)} tasks{complete_str}:"
174
-
175
- if stop_on_first_result and any(r is not None for r in results):
176
- results.extend([None] * len(batch))
177
- else:
178
- with status(msg), SuppressLoggerWarnings():
179
- results.extend(asyncio.run(_do_all(batch, start_idx=start_idx)))
180
-
181
- return results
335
+ # exception will be handled by the caller
336
+ raise e
337
+ return result
338
+
339
+ return run_batched_tasks(
340
+ inputs=inputs,
341
+ do_task=_do_task,
342
+ batch_size=batch_size,
343
+ stop_on_first_result=stop_on_first_result,
344
+ sequential=sequential,
345
+ handle_exceptions=handle_exceptions,
346
+ output_map=output_map,
347
+ message_template="[bold green]Running {total} tasks:",
348
+ message=message,
349
+ )
182
350
 
183
351
 
184
352
  def run_batch_tasks(
@@ -242,6 +410,8 @@ def run_batch_agent_method(
242
410
  output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
243
411
  sequential: bool = True,
244
412
  stop_on_first_result: bool = False,
413
+ handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
414
+ batch_size: Optional[int] = None,
245
415
  ) -> List[Any]:
246
416
  """
247
417
  Run the `method` on copies of `agent`, async/concurrently one per
@@ -265,6 +435,14 @@ def run_batch_agent_method(
265
435
  to final result
266
436
  sequential (bool): whether to run sequentially
267
437
  (e.g. some APIs such as ooba don't support concurrent requests)
438
+ stop_on_first_result (bool): whether to stop after the first valid
439
+ handle_exceptions: How to handle exceptions:
440
+ - RAISE or False: Let exceptions propagate
441
+ - RETURN_NONE or True: Convert exceptions to None in results
442
+ - RETURN_EXCEPTION: Include exception objects in results
443
+ Boolean values are deprecated and will be removed in a future version.
444
+ batch_size (Optional[int]): The number of items to process in each batch.
445
+ If None, process all items at once.
268
446
  Returns:
269
447
  List[Any]: list of final results
270
448
  """
@@ -288,43 +466,18 @@ def run_batch_agent_method(
288
466
  if method_i is None:
289
467
  raise ValueError(f"Agent {agent_name} has no method {method_name}")
290
468
  result = await method_i(input)
291
- return output_map(result)
469
+ return result
292
470
 
293
- async def _do_all() -> List[Any]:
294
- if stop_on_first_result:
295
- tasks = [
296
- asyncio.create_task(_do_task(input, i))
297
- for i, input in enumerate(inputs)
298
- ]
299
- results = [None] * len(tasks)
300
- try:
301
- done, pending = await asyncio.wait(
302
- tasks, return_when=asyncio.FIRST_COMPLETED
303
- )
304
- for task in done:
305
- index = tasks.index(task)
306
- results[index] = await task
307
- finally:
308
- for task in pending:
309
- task.cancel()
310
- await asyncio.gather(*pending, return_exceptions=True)
311
- return results
312
- elif sequential:
313
- results = []
314
- for i, input in enumerate(inputs):
315
- result = await _do_task(input, i)
316
- results.append(result)
317
- return results
318
- with quiet_mode(), SuppressLoggerWarnings():
319
- return await asyncio.gather(
320
- *(_do_task(input, i) for i, input in enumerate(inputs))
321
- )
322
-
323
- n = len(items)
324
- with status(f"[bold green]Running {n} copies of {agent_name}..."):
325
- results = asyncio.run(_do_all())
326
-
327
- return results
471
+ return run_batched_tasks(
472
+ inputs=inputs,
473
+ do_task=_do_task,
474
+ batch_size=batch_size,
475
+ stop_on_first_result=stop_on_first_result,
476
+ sequential=sequential,
477
+ handle_exceptions=handle_exceptions,
478
+ output_map=output_map,
479
+ message_template=f"[bold green]Running {{total}} copies of {agent_name}...",
480
+ )
328
481
 
329
482
 
330
483
  def llm_response_batch(
@@ -334,6 +487,7 @@ def llm_response_batch(
334
487
  output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
335
488
  sequential: bool = True,
336
489
  stop_on_first_result: bool = False,
490
+ batch_size: Optional[int] = None,
337
491
  ) -> List[Any]:
338
492
  return run_batch_agent_method(
339
493
  agent,
@@ -343,6 +497,7 @@ def llm_response_batch(
343
497
  output_map=output_map,
344
498
  sequential=sequential,
345
499
  stop_on_first_result=stop_on_first_result,
500
+ batch_size=batch_size,
346
501
  )
347
502
 
348
503
 
@@ -353,6 +508,7 @@ def agent_response_batch(
353
508
  output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
354
509
  sequential: bool = True,
355
510
  stop_on_first_result: bool = False,
511
+ batch_size: Optional[int] = None,
356
512
  ) -> List[Any]:
357
513
  return run_batch_agent_method(
358
514
  agent,
@@ -362,6 +518,7 @@ def agent_response_batch(
362
518
  output_map=output_map,
363
519
  sequential=sequential,
364
520
  stop_on_first_result=stop_on_first_result,
521
+ batch_size=batch_size,
365
522
  )
366
523
 
367
524
 
@@ -1845,14 +1845,16 @@ class ChatAgent(Agent):
1845
1845
  self.update_last_message(message, role=Role.USER)
1846
1846
  return answer_doc
1847
1847
 
1848
- def llm_response_forget(self, message: str) -> ChatDocument:
1848
+ def llm_response_forget(
1849
+ self, message: Optional[str | ChatDocument] = None
1850
+ ) -> ChatDocument:
1849
1851
  """
1850
1852
  LLM Response to single message, and restore message_history.
1851
1853
  In effect a "one-off" message & response that leaves agent
1852
1854
  message history state intact.
1853
1855
 
1854
1856
  Args:
1855
- message (str): user message
1857
+ message (str|ChatDocument): message to respond to.
1856
1858
 
1857
1859
  Returns:
1858
1860
  A Document object with the response.
@@ -1879,7 +1881,9 @@ class ChatAgent(Agent):
1879
1881
 
1880
1882
  return response
1881
1883
 
1882
- async def llm_response_forget_async(self, message: str) -> ChatDocument:
1884
+ async def llm_response_forget_async(
1885
+ self, message: Optional[str | ChatDocument] = None
1886
+ ) -> ChatDocument:
1883
1887
  """
1884
1888
  Async version of `llm_response_forget`. See there for details.
1885
1889
  """
@@ -1,3 +1,4 @@
1
+ # # langroid/agent/special/doc_chat_agent.py
1
2
  """
2
3
  Agent that supports asking queries about a set of documents, using
3
4
  retrieval-augmented generation (RAG).
@@ -16,14 +17,14 @@ pip install "langroid[hf-embeddings]"
16
17
  import logging
17
18
  from collections import OrderedDict
18
19
  from functools import cache
19
- from typing import Any, Dict, List, Optional, Set, Tuple, no_type_check
20
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, no_type_check
20
21
 
21
22
  import nest_asyncio
22
23
  import numpy as np
23
24
  import pandas as pd
24
25
  from rich.prompt import Prompt
25
26
 
26
- from langroid.agent.batch import run_batch_tasks
27
+ from langroid.agent.batch import run_batch_agent_method, run_batch_tasks
27
28
  from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
28
29
  from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
29
30
  from langroid.agent.special.relevance_extractor_agent import (
@@ -71,16 +72,17 @@ def apply_nest_asyncio() -> None:
71
72
 
72
73
  logger = logging.getLogger(__name__)
73
74
 
74
- DEFAULT_DOC_CHAT_INSTRUCTIONS = """
75
- Your task is to answer questions about various documents.
76
- You will be given various passages from these documents, and asked to answer questions
77
- about them, or summarize them into coherent answers.
78
- """
79
75
 
80
76
  DEFAULT_DOC_CHAT_SYSTEM_MESSAGE = """
81
77
  You are a helpful assistant, helping me understand a collection of documents.
78
+
79
+ Your TASK is to answer questions about various documents.
80
+ You will be given various passages from these documents, and asked to answer questions
81
+ about them, or summarize them into coherent answers.
82
82
  """
83
83
 
84
+ CHUNK_ENRICHMENT_DELIMITER = "<##-##-##>"
85
+
84
86
  has_sentence_transformers = False
85
87
  try:
86
88
  from sentence_transformers import SentenceTransformer # noqa: F401
@@ -102,9 +104,14 @@ oai_embed_config = OpenAIEmbeddingsConfig(
102
104
  )
103
105
 
104
106
 
107
class ChunkEnrichmentAgentConfig(ChatAgentConfig):
    """Settings for the helper agent that enriches document chunks.

    Adds three fields on top of the inherited ChatAgentConfig settings.
    """

    # Number of chunks handed to the batch runner per batch.
    batch_size: int = 50
    # Marker placed between a chunk's original text and its enrichment,
    # so the enrichment can later be split off again.
    delimiter: str = CHUNK_ENRICHMENT_DELIMITER
    # Turns a chunk's text into the prompt sent to the enrichment agent;
    # by default the text is sent unchanged.
    enrichment_prompt_fn: Callable[[str], str] = lambda text: text
111
+
112
+
105
113
  class DocChatAgentConfig(ChatAgentConfig):
106
114
  system_message: str = DEFAULT_DOC_CHAT_SYSTEM_MESSAGE
107
- user_message: str = DEFAULT_DOC_CHAT_INSTRUCTIONS
108
115
  summarize_prompt: str = SUMMARY_ANSWER_PROMPT_GPT4
109
116
  # extra fields to include in content as key=value pairs
110
117
  # (helps retrieval for table-like data)
@@ -126,6 +133,12 @@ class DocChatAgentConfig(ChatAgentConfig):
126
133
  # https://arxiv.org/pdf/2212.10496.pdf
127
134
  # It is False by default; its benefits depends on the context.
128
135
  hypothetical_answer: bool = False
136
+ # Optional config for chunk enrichment agent, e.g. to enrich
137
+ # chunks with hypothetical questions, or keywords to increase
138
+ # the "semantic surface area" of the chunks, which may help
139
+ # improve retrieval.
140
+ chunk_enrichment_config: Optional[ChunkEnrichmentAgentConfig] = None
141
+
129
142
  n_query_rephrases: int = 0
130
143
  n_neighbor_chunks: int = 0 # how many neighbors on either side of match to retrieve
131
144
  n_fuzzy_neighbor_words: int = 100 # num neighbor words to retrieve for fuzzy match
@@ -404,6 +417,8 @@ class DocChatAgent(ChatAgent):
404
417
  d.metadata.is_chunk = True
405
418
  if self.vecdb is None:
406
419
  raise ValueError("VecDB not set")
420
+ if self.config.chunk_enrichment_config is not None:
421
+ docs = self.enrich_chunks(docs)
407
422
 
408
423
  # If any additional fields need to be added to content,
409
424
  # add them as key=value pairs for all docs, before batching.
@@ -860,6 +875,72 @@ class DocChatAgent(ChatAgent):
860
875
  ).content
861
876
  return answer
862
877
 
878
def enrich_chunks(self, docs: List[Document]) -> List[Document]:
    """
    Enrich chunks using Agent configured with self.config.chunk_enrichment_config.

    We assume that the system message of the agent is set in such a way
    that when we run
    ```
    prompt = self.config.chunk_enrichment_config.enrichment_prompt_fn(text)
    result = await agent.llm_response_forget_async(prompt)
    ```

    then `result.content` will contain the augmentation to the text.

    Args:
        docs: List of document chunks to enrich

    Returns:
        List[Document]: Documents (chunks) enriched with additional text,
            separated by a delimiter. A chunk whose enrichment is empty is
            returned unchanged; enriched chunks are copies whose metadata
            has `has_enrichment=True`.

    Raises:
        ValueError: if the enrichment agent has no LLM.
    """
    if self.config.chunk_enrichment_config is None:
        return docs
    enrichment_config = self.config.chunk_enrichment_config
    agent = ChatAgent(enrichment_config)
    if agent.llm is None:
        raise ValueError("LLM not set")

    with status("[cyan]Augmenting chunks..."):
        # One one-off LLM call per chunk, run concurrently in batches.
        enrichments = run_batch_agent_method(
            agent=agent,
            method=agent.llm_response_forget_async,
            items=docs,
            input_map=lambda doc: (
                enrichment_config.enrichment_prompt_fn(doc.content)
            ),
            output_map=lambda response: response.content if response else "",
            sequential=False,
            batch_size=enrichment_config.batch_size,
        )

    delimiter = enrichment_config.delimiter
    augmented_docs: List[Document] = []
    for doc, enrichment in zip(docs, enrichments):
        if not enrichment:
            # Nothing to add: keep the chunk as-is.
            augmented_docs.append(doc)
            continue

        # Build the combined text explicitly instead of from an indented
        # triple-quoted literal, so source-code indentation can never leak
        # into the stored chunk content. Layout: original, blank line,
        # delimiter line, enrichment.
        combined_content = f"{doc.content}\n\n{delimiter}\n{enrichment}".strip()

        augmented_docs.append(
            doc.copy(
                update={
                    "content": combined_content,
                    "metadata": doc.metadata.copy(update={"has_enrichment": True}),
                }
            )
        )

    return augmented_docs
943
+
863
944
  def llm_rephrase_query(self, query: str) -> List[str]:
864
945
  if self.llm is None:
865
946
  raise ValueError("LLM not set")
@@ -1305,7 +1386,6 @@ class DocChatAgent(ChatAgent):
1305
1386
  if self.config.n_query_rephrases > 0:
1306
1387
  rephrases = self.llm_rephrase_query(query)
1307
1388
  proxies += rephrases
1308
-
1309
1389
  passages = self.get_relevant_chunks(query, proxies) # no LLM involved
1310
1390
 
1311
1391
  if len(passages) == 0:
@@ -1319,6 +1399,29 @@ class DocChatAgent(ChatAgent):
1319
1399
 
1320
1400
  return query, extracts
1321
1401
 
1402
def remove_chunk_enrichments(self, passages: List[Document]) -> List[Document]:
    """Strip enrichment text (e.g. hypothetical questions, keywords) from chunks.

    Does nothing unless a chunk-enrichment config is set. A passage is
    cleaned only if it has non-empty content and its metadata carries a
    truthy `has_enrichment`; its content is cut at the first occurrence of
    the configured delimiter and stripped. Other passages pass through
    unchanged.

    Args:
        passages: List of documents to clean

    Returns:
        List of documents with only original content
    """
    enrichment_config = self.config.chunk_enrichment_config
    if enrichment_config is None:
        return passages
    delimiter = enrichment_config.delimiter

    cleaned: List[Document] = []
    for doc in passages:
        is_enriched = bool(doc.content) and getattr(
            doc.metadata, "has_enrichment", False
        )
        if not is_enriched:
            cleaned.append(doc)
            continue
        # Text before the first delimiter is the original chunk.
        original_text, _, _ = doc.content.partition(delimiter)
        cleaned.append(doc.copy(update={"content": original_text.strip()}))
    return cleaned
1424
+
1322
1425
  def get_verbatim_extracts(
1323
1426
  self,
1324
1427
  query: str,
@@ -1334,6 +1437,8 @@ class DocChatAgent(ChatAgent):
1334
1437
  Returns:
1335
1438
  List[Document]: list of Documents containing extracts and metadata.
1336
1439
  """
1440
+ passages = self.remove_chunk_enrichments(passages)
1441
+
1337
1442
  agent_cfg = self.config.relevance_extractor_config
1338
1443
  if agent_cfg is None:
1339
1444
  # no relevance extraction: simply return passages
langroid/mytypes.py CHANGED
@@ -75,6 +75,7 @@ class Document(BaseModel):
75
75
  def id(self) -> str:
76
76
  return self.metadata.id
77
77
 
78
+ @staticmethod
78
79
  def from_string(
79
80
  content: str,
80
81
  source: str = "context",
@@ -1,4 +1,5 @@
1
1
  import logging
2
+ import re
2
3
  from enum import Enum
3
4
  from typing import Dict, List, Literal
4
5
 
@@ -250,12 +251,12 @@ class Parser:
250
251
  continue
251
252
 
252
253
  # Find the last period or punctuation mark in the chunk
253
- last_punctuation = max(
254
- chunk_text.rfind("."),
255
- chunk_text.rfind("?"),
256
- chunk_text.rfind("!"),
257
- chunk_text.rfind("\n"),
258
- )
254
+ punctuation_matches = [
255
+ (m.start(), m.group())
256
+ for m in re.finditer(r"(?:[.!?][\s\n]|\n)", chunk_text)
257
+ ]
258
+
259
+ last_punctuation = max([pos for pos, _ in punctuation_matches] + [-1])
259
260
 
260
261
  # If there is a punctuation mark, and the last punctuation index is
261
262
  # after MIN_CHUNK_SIZE_CHARS
@@ -268,7 +269,7 @@ class Parser:
268
269
 
269
270
  # Remove any newline characters and strip any leading or
270
271
  # trailing whitespace
271
- chunk_text_to_append = chunk_text.replace("\n", " ").strip()
272
+ chunk_text_to_append = re.sub(r"\n{2,}", "\n", chunk_text).strip()
272
273
 
273
274
  if len(chunk_text_to_append) > self.config.discard_chunk_chars:
274
275
  # Append the chunk text to the list of chunks
@@ -7,9 +7,9 @@ try:
7
7
  from pydispatch import dispatcher
8
8
  from scrapy import signals
9
9
  from scrapy.crawler import CrawlerRunner
10
- from scrapy.http import Response
11
- from scrapy.linkextractors import LinkExtractor
12
- from scrapy.spiders import CrawlSpider, Rule
10
+ from scrapy.http.response.text import TextResponse
11
+ from scrapy.linkextractors.lxmlhtml import LxmlLinkExtractor
12
+ from scrapy.spiders import CrawlSpider, Rule # type: ignore
13
13
  from twisted.internet import defer, reactor
14
14
  except ImportError:
15
15
  raise LangroidImportError("scrapy", "scrapy")
@@ -21,7 +21,7 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
21
21
 
22
22
  custom_settings = {"DEPTH_LIMIT": 1, "CLOSESPIDER_ITEMCOUNT": 20}
23
23
 
24
- rules = (Rule(LinkExtractor(), callback="parse_item", follow=True),)
24
+ rules = (Rule(LxmlLinkExtractor(), callback="parse_item", follow=True),)
25
25
 
26
26
  def __init__(self, start_url: str, k: int = 20, *args, **kwargs): # type: ignore
27
27
  """Initialize the spider with start_url and k.
@@ -36,13 +36,13 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
36
36
  self.k = k
37
37
  self.visited_urls: Set[str] = set()
38
38
 
39
- def parse_item(self, response: Response): # type: ignore
39
+ def parse_item(self, response: TextResponse): # type: ignore
40
40
  """Extracts URLs that are within the same domain.
41
41
 
42
42
  Args:
43
43
  response: The scrapy response object.
44
44
  """
45
- for link in LinkExtractor(allow_domains=self.allowed_domains).extract_links(
45
+ for link in LxmlLinkExtractor(allow_domains=self.allowed_domains).extract_links(
46
46
  response
47
47
  ):
48
48
  if len(self.visited_urls) < self.k:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: langroid
3
- Version: 0.33.13
3
+ Version: 0.34.1
4
4
  Summary: Harness LLMs with Multi-Agent Programming
5
5
  Author-email: Prasad Chalasani <pchalasani@gmail.com>
6
6
  License: MIT
@@ -287,7 +287,11 @@ teacher_task.run()
287
287
  <details>
288
288
  <summary> <b>Click to expand</b></summary>
289
289
 
290
+ - **Jan 2025:**
291
+ - [0.33.0](https://github.com/langroid/langroid/releases/tag/0.33.3) Move from Poetry to uv!
292
+ - [0.32.0](https://github.com/langroid/langroid/releases/tag/0.32.0) DeepSeek v3 support.
290
293
  - **Dec 2024:**
294
+ - [0.31.0](https://github.com/langroid/langroid/releases/tag/0.31.0) Azure OpenAI Embeddings
291
295
  - [0.30.0](https://github.com/langroid/langroid/releases/tag/0.30.0) Llama-cpp embeddings.
292
296
  - [0.29.0](https://github.com/langroid/langroid/releases/tag/0.29.0) Custom Azure OpenAI Client
293
297
  - [0.28.0](https://github.com/langroid/langroid/releases/tag/0.28.0) `ToolMessage`: `_handler` field to override
@@ -1,11 +1,11 @@
1
1
  langroid/__init__.py,sha256=z_fCOLQJPOw3LLRPBlFB5-2HyCjpPgQa4m4iY5Fvb8Y,1800
2
2
  langroid/exceptions.py,sha256=gp6ku4ZLdXXCUQIwUNVFojJNGTzGnkevi2PLvG7HOhc,2555
3
- langroid/mytypes.py,sha256=vrU4eb5In8fEYwC881T4NByIaGdxiMhmY5dLxGEkyZY,2629
3
+ langroid/mytypes.py,sha256=h1eMq1ZwTLVezObPfCseWNWbEOzP7mAKu2XoS63W1cM,2647
4
4
  langroid/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  langroid/agent/__init__.py,sha256=ll0Cubd2DZ-fsCMl7e10hf9ZjFGKzphfBco396IKITY,786
6
6
  langroid/agent/base.py,sha256=oThlrYygKDu1-bKjAfygldJ511gMKT8Z0qCrD52DdDM,77834
7
- langroid/agent/batch.py,sha256=qK3ph6VNj_1sOhfXCZY4r6gh035DglDKU751p8BU0tY,14665
8
- langroid/agent/chat_agent.py,sha256=cxamUgqQkr6_W3mqCPz3L7rJnXIkD4hemR7X7uhlBvI,82095
7
+ langroid/agent/batch.py,sha256=vi1r5i1-vN80WfqHDSwjEym_KfGsqPGUtwktmiK1nuk,20635
8
+ langroid/agent/chat_agent.py,sha256=A-7Iiiw7jsoJNlWerljM29BidkiIbjPOQIkGZpZHmt0,82210
9
9
  langroid/agent/chat_document.py,sha256=xPUMGzR83rn4iAEXIw2jy5LQ6YJ6Y0TiZ78XRQeDnJQ,17778
10
10
  langroid/agent/openai_assistant.py,sha256=JkAcs02bIrgPNVvUWVR06VCthc5-ulla2QMBzux_q6o,34340
11
11
  langroid/agent/task.py,sha256=XrXUbSoiFasvpIsZPn_cBpdWaTCKljJPRimtLMrSZrs,90347
@@ -14,7 +14,7 @@ langroid/agent/xml_tool_message.py,sha256=6SshYZJKIfi4mkE-gIoSwjkEYekQ8GwcSiCv7a
14
14
  langroid/agent/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
15
  langroid/agent/callbacks/chainlit.py,sha256=RH8qUXaZE5o2WQz3WJQ1SdFtASGlxWCA6_HYz_3meDQ,20822
16
16
  langroid/agent/special/__init__.py,sha256=gik_Xtm_zV7U9s30Mn8UX3Gyuy4jTjQe9zjiE3HWmEo,1273
17
- langroid/agent/special/doc_chat_agent.py,sha256=xXkrFb5ty3NoKfgC0CZqolxmUP1j_sXqywQ_VNxOrOE,59505
17
+ langroid/agent/special/doc_chat_agent.py,sha256=EPRKhchk1nuExn4_Lbuu5hzGkBP06PvJAK1edoISQHc,63470
18
18
  langroid/agent/special/lance_doc_chat_agent.py,sha256=s8xoRs0gGaFtDYFUSIRchsgDVbS5Q3C2b2mr3V1Fd-Q,10419
19
19
  langroid/agent/special/lance_tools.py,sha256=qS8x4wi8mrqfbYV2ztFzrcxyhHQ0ZWOc-zkYiH7awj0,2105
20
20
  langroid/agent/special/relevance_extractor_agent.py,sha256=zIx8GUdVo1aGW6ASla0NPQjYYIpmriK_TYMijqAx3F8,4796
@@ -81,11 +81,11 @@ langroid/parsing/code_parser.py,sha256=AOxb3xbYpTBPP3goOm5dKfJdh5hS_2BhLVCEkifWZ
81
81
  langroid/parsing/document_parser.py,sha256=9xUOyrVNBAS9cpCvCptr2XK4Kq47W574i8zzGEoXc3c,24933
82
82
  langroid/parsing/para_sentence_split.py,sha256=AJBzZojP3zpB-_IMiiHismhqcvkrVBQ3ZINoQyx_bE4,2000
83
83
  langroid/parsing/parse_json.py,sha256=aADo38bAHQhC8on4aWZZzVzSDy-dK35vRLZsFI2ewh8,4756
84
- langroid/parsing/parser.py,sha256=bTG5TO2CEwGdLf9979j9_dFntKX5FloGF8vhts6ObU0,11978
84
+ langroid/parsing/parser.py,sha256=N0jr1Zl_f_rx-8YMmSQftPHquqSQfec-3s7JAhhEe6I,12032
85
85
  langroid/parsing/repo_loader.py,sha256=3GjvPJS6Vf5L6gV2zOU8s-Tf1oq_fZm-IB_RL_7CTsY,29373
86
86
  langroid/parsing/routing.py,sha256=-FcnlqldzL4ZoxuDwXjQPNHgBe9F9-F4R6q7b_z9CvI,1232
87
87
  langroid/parsing/search.py,sha256=0i_r0ESb5HEQfagA2g7_uMQyxYPADWVbdcN9ixZhS4E,8992
88
- langroid/parsing/spider.py,sha256=Y6y7b86Y2k770LdhxgjVlImBxuuy1V9n8-XQ3QPaG5s,3199
88
+ langroid/parsing/spider.py,sha256=hAVM6wxh1pQ0EN4tI5wMBtAjIk0T-xnpi-ZUzWybhos,3258
89
89
  langroid/parsing/table_loader.py,sha256=qNM4obT_0Y4tjrxNBCNUYjKQ9oETCZ7FbolKBTcz-GM,3410
90
90
  langroid/parsing/url_loader.py,sha256=JK48KktLRDBfjrt4nsUfy92M6yGdEeicAqOum2MdULM,4656
91
91
  langroid/parsing/urls.py,sha256=XjpaV5onG7gKQ5iQeFTzHSw5P08Aqw0g-rMUu61lR6s,7988
@@ -121,7 +121,7 @@ langroid/vector_store/lancedb.py,sha256=b3_vWkTjG8mweZ7ZNlUD-NjmQP_rLBZfyKWcxt2v
121
121
  langroid/vector_store/meilisearch.py,sha256=6frB7GFWeWmeKzRfLZIvzRjllniZ1cYj3HmhHQICXLs,11663
122
122
  langroid/vector_store/momento.py,sha256=UNHGT6jXuQtqY9f6MdqGU14bVnS0zHgIJUa30ULpUJo,10474
123
123
  langroid/vector_store/qdrantdb.py,sha256=HRLCt-FG8y4718omwpFaQZnWeYxPj0XCwS4tjokI1sU,18116
124
- langroid-0.33.13.dist-info/METADATA,sha256=6D80UFs_rmlz89V6OEStdRBVECC-B8iLtBLTghgRg5w,59016
125
- langroid-0.33.13.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
126
- langroid-0.33.13.dist-info/licenses/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
127
- langroid-0.33.13.dist-info/RECORD,,
124
+ langroid-0.34.1.dist-info/METADATA,sha256=_sKdjcKkVQBqsbSe4p79aSmHTjZ1ClwX9GXO1Fn15y0,59313
125
+ langroid-0.34.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
126
+ langroid-0.34.1.dist-info/licenses/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
127
+ langroid-0.34.1.dist-info/RECORD,,