langroid 0.33.13__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/batch.py +292 -135
- langroid/agent/chat_agent.py +7 -3
- langroid/agent/special/doc_chat_agent.py +110 -3
- langroid/mytypes.py +1 -0
- langroid/parsing/spider.py +6 -6
- {langroid-0.33.13.dist-info → langroid-0.34.0.dist-info}/METADATA +1 -1
- {langroid-0.33.13.dist-info → langroid-0.34.0.dist-info}/RECORD +9 -9
- {langroid-0.33.13.dist-info → langroid-0.34.0.dist-info}/WHEEL +0 -0
- {langroid-0.33.13.dist-info → langroid-0.34.0.dist-info}/licenses/LICENSE +0 -0
langroid/agent/batch.py
CHANGED
@@ -1,7 +1,19 @@
|
|
1
1
|
import asyncio
|
2
2
|
import copy
|
3
3
|
import inspect
|
4
|
-
|
4
|
+
import warnings
|
5
|
+
from enum import Enum
|
6
|
+
from typing import (
|
7
|
+
Any,
|
8
|
+
Callable,
|
9
|
+
Coroutine,
|
10
|
+
Iterable,
|
11
|
+
List,
|
12
|
+
Optional,
|
13
|
+
TypeVar,
|
14
|
+
Union,
|
15
|
+
cast,
|
16
|
+
)
|
5
17
|
|
6
18
|
from dotenv import load_dotenv
|
7
19
|
|
@@ -21,6 +33,238 @@ T = TypeVar("T")
|
|
21
33
|
U = TypeVar("U")
|
22
34
|
|
23
35
|
|
36
|
+
class ExceptionHandling(str, Enum):
    """Policy for what to do when a batch item raises an exception.

    The class mixes in ``str``, so every member is also a plain string that
    compares equal to its lowercase value.
    """

    # Let the exception propagate to the caller.
    RAISE = "raise"
    # Use ``None`` as the result of the failing item.
    RETURN_NONE = "return_none"
    # Use the exception object itself as the result of the failing item.
    RETURN_EXCEPTION = "return_exception"
|
42
|
+
|
43
|
+
|
44
|
+
def _convert_exception_handling(
    handle_exceptions: Union[bool, ExceptionHandling]
) -> ExceptionHandling:
    """Normalize ``handle_exceptions`` to an ``ExceptionHandling`` member.

    Args:
        handle_exceptions: either an ``ExceptionHandling`` member, which is
            returned unchanged, or a legacy boolean, where ``True`` means
            ``RETURN_NONE`` and ``False`` means ``RAISE``.

    Returns:
        The matching ``ExceptionHandling`` member.

    Raises:
        TypeError: if the value is neither a bool nor an ``ExceptionHandling``.

    A ``DeprecationWarning`` is emitted whenever a boolean is passed.
    """
    if isinstance(handle_exceptions, bool):
        # Legacy path: still honoured, but flagged so callers migrate.
        warnings.warn(
            "Boolean handle_exceptions is deprecated. "
            "Use ExceptionHandling enum instead: "
            "RAISE, RETURN_NONE, or RETURN_EXCEPTION.",
            DeprecationWarning,
            stacklevel=2,
        )
        if handle_exceptions:
            return ExceptionHandling.RETURN_NONE
        return ExceptionHandling.RAISE

    if not isinstance(handle_exceptions, ExceptionHandling):
        raise TypeError(
            "handle_exceptions must be bool or ExceptionHandling, "
            f"not {type(handle_exceptions)}"
        )

    return handle_exceptions
|
69
|
+
|
70
|
+
|
71
|
+
async def _process_batch_async(
|
72
|
+
inputs: Iterable[str | ChatDocument],
|
73
|
+
do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
|
74
|
+
start_idx: int = 0,
|
75
|
+
stop_on_first_result: bool = False,
|
76
|
+
sequential: bool = False,
|
77
|
+
handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
|
78
|
+
output_map: Callable[[Any], Any] = lambda x: x,
|
79
|
+
) -> List[Optional[ChatDocument] | BaseException]:
|
80
|
+
"""
|
81
|
+
Unified batch processing logic for both agent methods and tasks.
|
82
|
+
|
83
|
+
Args:
|
84
|
+
inputs: Iterable of inputs to process
|
85
|
+
do_task: Task execution function that takes (input, index) and returns result
|
86
|
+
start_idx: Starting index for the batch
|
87
|
+
stop_on_first_result: Whether to stop after first valid result
|
88
|
+
sequential: Whether to process sequentially
|
89
|
+
handle_exceptions: How to handle exceptions:
|
90
|
+
- RAISE or False: Let exceptions propagate
|
91
|
+
- RETURN_NONE or True: Convert exceptions to None in results
|
92
|
+
- RETURN_EXCEPTION: Include exception objects in results
|
93
|
+
Boolean values are deprecated and will be removed in a future version.
|
94
|
+
output_map: Function to map results to final output format
|
95
|
+
"""
|
96
|
+
exception_handling = _convert_exception_handling(handle_exceptions)
|
97
|
+
|
98
|
+
def handle_error(e: BaseException) -> Any:
|
99
|
+
"""Handle exceptions based on exception_handling."""
|
100
|
+
match exception_handling:
|
101
|
+
case ExceptionHandling.RAISE:
|
102
|
+
raise e
|
103
|
+
case ExceptionHandling.RETURN_NONE:
|
104
|
+
return None
|
105
|
+
case ExceptionHandling.RETURN_EXCEPTION:
|
106
|
+
return e
|
107
|
+
|
108
|
+
if stop_on_first_result:
|
109
|
+
results: List[Optional[ChatDocument] | BaseException] = []
|
110
|
+
pending: set[asyncio.Task[Any]] = set()
|
111
|
+
# Create task-to-index mapping
|
112
|
+
task_indices: dict[asyncio.Task[Any], int] = {}
|
113
|
+
try:
|
114
|
+
tasks = [
|
115
|
+
asyncio.create_task(do_task(input, i + start_idx))
|
116
|
+
for i, input in enumerate(inputs)
|
117
|
+
]
|
118
|
+
task_indices = {task: i for i, task in enumerate(tasks)}
|
119
|
+
results = [None] * len(tasks)
|
120
|
+
|
121
|
+
done, pending = await asyncio.wait(
|
122
|
+
tasks, return_when=asyncio.FIRST_COMPLETED
|
123
|
+
)
|
124
|
+
|
125
|
+
# Process completed tasks
|
126
|
+
for task in done:
|
127
|
+
index = task_indices[task]
|
128
|
+
try:
|
129
|
+
result = await task
|
130
|
+
results[index] = output_map(result)
|
131
|
+
except BaseException as e:
|
132
|
+
results[index] = handle_error(e)
|
133
|
+
|
134
|
+
if any(r is not None for r in results):
|
135
|
+
return results
|
136
|
+
finally:
|
137
|
+
for task in pending:
|
138
|
+
task.cancel()
|
139
|
+
try:
|
140
|
+
await asyncio.gather(*pending, return_exceptions=True)
|
141
|
+
except BaseException as e:
|
142
|
+
handle_error(e)
|
143
|
+
return results
|
144
|
+
|
145
|
+
elif sequential:
|
146
|
+
results = []
|
147
|
+
for i, input in enumerate(inputs):
|
148
|
+
try:
|
149
|
+
result = await do_task(input, i + start_idx)
|
150
|
+
results.append(output_map(result))
|
151
|
+
except BaseException as e:
|
152
|
+
results.append(handle_error(e))
|
153
|
+
return results
|
154
|
+
|
155
|
+
# Parallel execution
|
156
|
+
else:
|
157
|
+
try:
|
158
|
+
return_exceptions = exception_handling != ExceptionHandling.RAISE
|
159
|
+
with quiet_mode(), SuppressLoggerWarnings():
|
160
|
+
results_with_exceptions = cast(
|
161
|
+
list[Optional[ChatDocument | BaseException]],
|
162
|
+
await asyncio.gather(
|
163
|
+
*(
|
164
|
+
do_task(input, i + start_idx)
|
165
|
+
for i, input in enumerate(inputs)
|
166
|
+
),
|
167
|
+
return_exceptions=return_exceptions,
|
168
|
+
),
|
169
|
+
)
|
170
|
+
|
171
|
+
if exception_handling == ExceptionHandling.RETURN_NONE:
|
172
|
+
results = [
|
173
|
+
None if isinstance(r, BaseException) else r
|
174
|
+
for r in results_with_exceptions
|
175
|
+
]
|
176
|
+
else: # ExceptionHandling.RETURN_EXCEPTION
|
177
|
+
results = results_with_exceptions
|
178
|
+
except BaseException as e:
|
179
|
+
results = [handle_error(e) for _ in inputs]
|
180
|
+
|
181
|
+
return [output_map(r) for r in results]
|
182
|
+
|
183
|
+
|
184
|
+
def run_batched_tasks(
    inputs: List[str | ChatDocument],
    do_task: Callable[[str | ChatDocument, int], Coroutine[Any, Any, Any]],
    batch_size: Optional[int],
    stop_on_first_result: bool,
    sequential: bool,
    handle_exceptions: Union[bool, ExceptionHandling],
    output_map: Callable[[Any], Any],
    message_template: str,
    message: Optional[str] = None,
) -> List[Any]:
    """
    Shared driver that runs `do_task` over `inputs`, optionally in batches.

    All batches run inside a single `asyncio.run` call.

    Args:
        inputs: List of inputs to process
        do_task: Task execution function, called with (input, index)
        batch_size: Size of batches, if None process all at once
        stop_on_first_result: Whether to stop after first valid result; once
            any earlier entry is non-None, remaining batches are not run and
            are filled with None
        sequential: Whether to process sequentially
        handle_exceptions: How to handle exceptions:
            - RAISE or False: Let exceptions propagate
            - RETURN_NONE or True: Convert exceptions to None in results
            - RETURN_EXCEPTION: Include exception objects in results
            Boolean values are deprecated and will be removed in a future version.
        output_map: Function to map results
        message_template: Template for status message; formatted with
            `total`, the number of inputs
        message: Optional override for status message

    Returns:
        List[Any]: one entry per input, in input order
    """

    async def run_all_batched_tasks(
        inputs: List[str | ChatDocument],
        batch_size: int | None,
    ) -> List[Any]:
        """Run every batch within one event loop.

        Args:
            inputs (List[str | ChatDocument]): inputs to process
            batch_size (int | None): batch size

        Returns:
            List[Any]: results
        """

        def status_text(suffix: str = "") -> str:
            # An explicit `message` wins outright; the template is only
            # formatted (and the suffix only appended) when it is absent.
            if message:
                return message
            return message_template.format(total=len(inputs)) + suffix

        if batch_size is None:
            with status(status_text()), SuppressLoggerWarnings():
                return await _process_batch_async(
                    inputs,
                    do_task,
                    stop_on_first_result=stop_on_first_result,
                    sequential=sequential,
                    handle_exceptions=handle_exceptions,
                    output_map=output_map,
                )

        collected: List[Any] = []
        for batch in batched(inputs, batch_size):
            n_done = len(collected)

            if stop_on_first_result and any(r is not None for r in collected):
                # A result was already found: pad instead of running the batch.
                collected.extend([None] * len(batch))
                continue

            suffix = f", {n_done} complete" if n_done > 0 else ""
            with status(status_text(suffix)), SuppressLoggerWarnings():
                batch_results = await _process_batch_async(
                    batch,
                    do_task,
                    start_idx=n_done,
                    stop_on_first_result=stop_on_first_result,
                    sequential=sequential,
                    handle_exceptions=handle_exceptions,
                    output_map=output_map,
                )
                collected.extend(batch_results)
        return collected

    return asyncio.run(run_all_batched_tasks(inputs, batch_size))
|
266
|
+
|
267
|
+
|
24
268
|
def run_batch_task_gen(
|
25
269
|
gen_task: Callable[[int], Task],
|
26
270
|
items: list[T],
|
@@ -31,7 +275,7 @@ def run_batch_task_gen(
|
|
31
275
|
batch_size: Optional[int] = None,
|
32
276
|
turns: int = -1,
|
33
277
|
message: Optional[str] = None,
|
34
|
-
handle_exceptions: bool =
|
278
|
+
handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
|
35
279
|
max_cost: float = 0.0,
|
36
280
|
max_tokens: int = 0,
|
37
281
|
) -> list[Optional[U]]:
|
@@ -58,7 +302,11 @@ def run_batch_task_gen(
|
|
58
302
|
if None, unbatched
|
59
303
|
turns (int): number of turns to run, -1 for infinite
|
60
304
|
message (Optional[str]): optionally overrides the console status messages
|
61
|
-
handle_exceptions:
|
305
|
+
handle_exceptions: How to handle exceptions:
|
306
|
+
- RAISE or False: Let exceptions propagate
|
307
|
+
- RETURN_NONE or True: Convert exceptions to None in results
|
308
|
+
- RETURN_EXCEPTION: Include exception objects in results
|
309
|
+
Boolean values are deprecated and will be removed in a future version.
|
62
310
|
max_cost: float: maximum cost to run the task (default 0.0 for unlimited)
|
63
311
|
max_tokens: int: maximum token usage (in and out) (default 0 for unlimited)
|
64
312
|
|
@@ -72,113 +320,33 @@ def run_batch_task_gen(
|
|
72
320
|
async def _do_task(
|
73
321
|
input: str | ChatDocument,
|
74
322
|
i: int,
|
75
|
-
return_idx: Optional[int] = None,
|
76
323
|
) -> BaseException | Optional[ChatDocument] | tuple[int, Optional[ChatDocument]]:
|
77
324
|
task_i = gen_task(i)
|
78
325
|
if task_i.agent.llm is not None:
|
79
326
|
task_i.agent.llm.set_stream(False)
|
80
327
|
task_i.agent.config.show_stats = False
|
328
|
+
|
81
329
|
try:
|
82
330
|
result = await task_i.run_async(
|
83
331
|
input, turns=turns, max_cost=max_cost, max_tokens=max_tokens
|
84
332
|
)
|
85
|
-
if return_idx is not None:
|
86
|
-
return return_idx, result
|
87
|
-
else:
|
88
|
-
return result
|
89
333
|
except asyncio.CancelledError as e:
|
90
334
|
task_i.kill()
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
outputs: list[Optional[U]] = [None] * len(list(inputs))
|
107
|
-
tasks = set(
|
108
|
-
asyncio.create_task(_do_task(input, i + start_idx, return_idx=i))
|
109
|
-
for i, input in enumerate(inputs)
|
110
|
-
)
|
111
|
-
while tasks:
|
112
|
-
try:
|
113
|
-
done, tasks = await asyncio.wait(
|
114
|
-
tasks, return_when=asyncio.FIRST_COMPLETED
|
115
|
-
)
|
116
|
-
for task in done:
|
117
|
-
idx_result = task.result()
|
118
|
-
if not isinstance(idx_result, tuple):
|
119
|
-
continue
|
120
|
-
index, output = idx_result
|
121
|
-
outputs[index] = output_map(output)
|
122
|
-
|
123
|
-
if any(r is not None for r in outputs):
|
124
|
-
return outputs
|
125
|
-
finally:
|
126
|
-
# Cancel all remaining tasks
|
127
|
-
for task in tasks:
|
128
|
-
task.cancel()
|
129
|
-
# Wait for cancellations to complete
|
130
|
-
try:
|
131
|
-
await asyncio.gather(*tasks, return_exceptions=True)
|
132
|
-
except BaseException as e:
|
133
|
-
if not handle_exceptions:
|
134
|
-
raise e
|
135
|
-
return outputs
|
136
|
-
elif sequential:
|
137
|
-
for i, input in enumerate(inputs):
|
138
|
-
result: Optional[ChatDocument] | BaseException = await _do_task(
|
139
|
-
input, i + start_idx
|
140
|
-
) # type: ignore
|
141
|
-
|
142
|
-
if isinstance(result, BaseException):
|
143
|
-
result = None
|
144
|
-
|
145
|
-
results.append(result)
|
146
|
-
else:
|
147
|
-
results_with_exceptions = cast(
|
148
|
-
list[Optional[ChatDocument | BaseException]],
|
149
|
-
await asyncio.gather(
|
150
|
-
*(_do_task(input, i + start_idx) for i, input in enumerate(inputs)),
|
151
|
-
),
|
152
|
-
)
|
153
|
-
|
154
|
-
results = [
|
155
|
-
r if not isinstance(r, BaseException) else None
|
156
|
-
for r in results_with_exceptions
|
157
|
-
]
|
158
|
-
|
159
|
-
return list(map(output_map, results))
|
160
|
-
|
161
|
-
results: List[Optional[U]] = []
|
162
|
-
if batch_size is None:
|
163
|
-
msg = message or f"[bold green]Running {len(items)} tasks:"
|
164
|
-
|
165
|
-
with status(msg), SuppressLoggerWarnings():
|
166
|
-
results = asyncio.run(_do_all(inputs))
|
167
|
-
else:
|
168
|
-
batches = batched(inputs, batch_size)
|
169
|
-
|
170
|
-
for batch in batches:
|
171
|
-
start_idx = len(results)
|
172
|
-
complete_str = f", {start_idx} complete" if start_idx > 0 else ""
|
173
|
-
msg = message or f"[bold green]Running {len(items)} tasks{complete_str}:"
|
174
|
-
|
175
|
-
if stop_on_first_result and any(r is not None for r in results):
|
176
|
-
results.extend([None] * len(batch))
|
177
|
-
else:
|
178
|
-
with status(msg), SuppressLoggerWarnings():
|
179
|
-
results.extend(asyncio.run(_do_all(batch, start_idx=start_idx)))
|
180
|
-
|
181
|
-
return results
|
335
|
+
# exception will be handled by the caller
|
336
|
+
raise e
|
337
|
+
return result
|
338
|
+
|
339
|
+
return run_batched_tasks(
|
340
|
+
inputs=inputs,
|
341
|
+
do_task=_do_task,
|
342
|
+
batch_size=batch_size,
|
343
|
+
stop_on_first_result=stop_on_first_result,
|
344
|
+
sequential=sequential,
|
345
|
+
handle_exceptions=handle_exceptions,
|
346
|
+
output_map=output_map,
|
347
|
+
message_template="[bold green]Running {total} tasks:",
|
348
|
+
message=message,
|
349
|
+
)
|
182
350
|
|
183
351
|
|
184
352
|
def run_batch_tasks(
|
@@ -242,6 +410,8 @@ def run_batch_agent_method(
|
|
242
410
|
output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
|
243
411
|
sequential: bool = True,
|
244
412
|
stop_on_first_result: bool = False,
|
413
|
+
handle_exceptions: Union[bool, ExceptionHandling] = ExceptionHandling.RAISE,
|
414
|
+
batch_size: Optional[int] = None,
|
245
415
|
) -> List[Any]:
|
246
416
|
"""
|
247
417
|
Run the `method` on copies of `agent`, async/concurrently one per
|
@@ -265,6 +435,14 @@ def run_batch_agent_method(
|
|
265
435
|
to final result
|
266
436
|
sequential (bool): whether to run sequentially
|
267
437
|
(e.g. some APIs such as ooba don't support concurrent requests)
|
438
|
+
stop_on_first_result (bool): whether to stop after the first valid result
|
439
|
+
handle_exceptions: How to handle exceptions:
|
440
|
+
- RAISE or False: Let exceptions propagate
|
441
|
+
- RETURN_NONE or True: Convert exceptions to None in results
|
442
|
+
- RETURN_EXCEPTION: Include exception objects in results
|
443
|
+
Boolean values are deprecated and will be removed in a future version.
|
444
|
+
batch_size (Optional[int]): The number of items to process in each batch.
|
445
|
+
If None, process all items at once.
|
268
446
|
Returns:
|
269
447
|
List[Any]: list of final results
|
270
448
|
"""
|
@@ -288,43 +466,18 @@ def run_batch_agent_method(
|
|
288
466
|
if method_i is None:
|
289
467
|
raise ValueError(f"Agent {agent_name} has no method {method_name}")
|
290
468
|
result = await method_i(input)
|
291
|
-
return
|
469
|
+
return result
|
292
470
|
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
)
|
304
|
-
for task in done:
|
305
|
-
index = tasks.index(task)
|
306
|
-
results[index] = await task
|
307
|
-
finally:
|
308
|
-
for task in pending:
|
309
|
-
task.cancel()
|
310
|
-
await asyncio.gather(*pending, return_exceptions=True)
|
311
|
-
return results
|
312
|
-
elif sequential:
|
313
|
-
results = []
|
314
|
-
for i, input in enumerate(inputs):
|
315
|
-
result = await _do_task(input, i)
|
316
|
-
results.append(result)
|
317
|
-
return results
|
318
|
-
with quiet_mode(), SuppressLoggerWarnings():
|
319
|
-
return await asyncio.gather(
|
320
|
-
*(_do_task(input, i) for i, input in enumerate(inputs))
|
321
|
-
)
|
322
|
-
|
323
|
-
n = len(items)
|
324
|
-
with status(f"[bold green]Running {n} copies of {agent_name}..."):
|
325
|
-
results = asyncio.run(_do_all())
|
326
|
-
|
327
|
-
return results
|
471
|
+
return run_batched_tasks(
|
472
|
+
inputs=inputs,
|
473
|
+
do_task=_do_task,
|
474
|
+
batch_size=batch_size,
|
475
|
+
stop_on_first_result=stop_on_first_result,
|
476
|
+
sequential=sequential,
|
477
|
+
handle_exceptions=handle_exceptions,
|
478
|
+
output_map=output_map,
|
479
|
+
message_template=f"[bold green]Running {{total}} copies of {agent_name}...",
|
480
|
+
)
|
328
481
|
|
329
482
|
|
330
483
|
def llm_response_batch(
|
@@ -334,6 +487,7 @@ def llm_response_batch(
|
|
334
487
|
output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
|
335
488
|
sequential: bool = True,
|
336
489
|
stop_on_first_result: bool = False,
|
490
|
+
batch_size: Optional[int] = None,
|
337
491
|
) -> List[Any]:
|
338
492
|
return run_batch_agent_method(
|
339
493
|
agent,
|
@@ -343,6 +497,7 @@ def llm_response_batch(
|
|
343
497
|
output_map=output_map,
|
344
498
|
sequential=sequential,
|
345
499
|
stop_on_first_result=stop_on_first_result,
|
500
|
+
batch_size=batch_size,
|
346
501
|
)
|
347
502
|
|
348
503
|
|
@@ -353,6 +508,7 @@ def agent_response_batch(
|
|
353
508
|
output_map: Callable[[ChatDocument | None], Any] = lambda x: x,
|
354
509
|
sequential: bool = True,
|
355
510
|
stop_on_first_result: bool = False,
|
511
|
+
batch_size: Optional[int] = None,
|
356
512
|
) -> List[Any]:
|
357
513
|
return run_batch_agent_method(
|
358
514
|
agent,
|
@@ -362,6 +518,7 @@ def agent_response_batch(
|
|
362
518
|
output_map=output_map,
|
363
519
|
sequential=sequential,
|
364
520
|
stop_on_first_result=stop_on_first_result,
|
521
|
+
batch_size=batch_size,
|
365
522
|
)
|
366
523
|
|
367
524
|
|
langroid/agent/chat_agent.py
CHANGED
@@ -1845,14 +1845,16 @@ class ChatAgent(Agent):
|
|
1845
1845
|
self.update_last_message(message, role=Role.USER)
|
1846
1846
|
return answer_doc
|
1847
1847
|
|
1848
|
-
def llm_response_forget(
|
1848
|
+
def llm_response_forget(
|
1849
|
+
self, message: Optional[str | ChatDocument] = None
|
1850
|
+
) -> ChatDocument:
|
1849
1851
|
"""
|
1850
1852
|
LLM Response to single message, and restore message_history.
|
1851
1853
|
In effect a "one-off" message & response that leaves agent
|
1852
1854
|
message history state intact.
|
1853
1855
|
|
1854
1856
|
Args:
|
1855
|
-
message (str):
|
1857
|
+
message (str|ChatDocument): message to respond to.
|
1856
1858
|
|
1857
1859
|
Returns:
|
1858
1860
|
A Document object with the response.
|
@@ -1879,7 +1881,9 @@ class ChatAgent(Agent):
|
|
1879
1881
|
|
1880
1882
|
return response
|
1881
1883
|
|
1882
|
-
async def llm_response_forget_async(
|
1884
|
+
async def llm_response_forget_async(
|
1885
|
+
self, message: Optional[str | ChatDocument] = None
|
1886
|
+
) -> ChatDocument:
|
1883
1887
|
"""
|
1884
1888
|
Async version of `llm_response_forget`. See there for details.
|
1885
1889
|
"""
|
@@ -1,3 +1,4 @@
|
|
1
|
+
# # langroid/agent/special/doc_chat_agent.py
|
1
2
|
"""
|
2
3
|
Agent that supports asking queries about a set of documents, using
|
3
4
|
retrieval-augmented generation (RAG).
|
@@ -16,14 +17,14 @@ pip install "langroid[hf-embeddings]"
|
|
16
17
|
import logging
|
17
18
|
from collections import OrderedDict
|
18
19
|
from functools import cache
|
19
|
-
from typing import Any, Dict, List, Optional, Set, Tuple, no_type_check
|
20
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, no_type_check
|
20
21
|
|
21
22
|
import nest_asyncio
|
22
23
|
import numpy as np
|
23
24
|
import pandas as pd
|
24
25
|
from rich.prompt import Prompt
|
25
26
|
|
26
|
-
from langroid.agent.batch import run_batch_tasks
|
27
|
+
from langroid.agent.batch import run_batch_agent_method, run_batch_tasks
|
27
28
|
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
|
28
29
|
from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
|
29
30
|
from langroid.agent.special.relevance_extractor_agent import (
|
@@ -81,6 +82,8 @@ DEFAULT_DOC_CHAT_SYSTEM_MESSAGE = """
|
|
81
82
|
You are a helpful assistant, helping me understand a collection of documents.
|
82
83
|
"""
|
83
84
|
|
85
|
+
CHUNK_ENRICHMENT_DELIMITER = "<##-##-##>"
|
86
|
+
|
84
87
|
has_sentence_transformers = False
|
85
88
|
try:
|
86
89
|
from sentence_transformers import SentenceTransformer # noqa: F401
|
@@ -102,6 +105,12 @@ oai_embed_config = OpenAIEmbeddingsConfig(
|
|
102
105
|
)
|
103
106
|
|
104
107
|
|
108
|
+
class ChunkEnrichmentAgentConfig(ChatAgentConfig):
    """Configuration for the agent used to enrich document chunks."""

    # Number of chunks sent per batch of enrichment requests.
    batch_size: int = 50
    # Marker placed between a chunk's original text and its enrichment.
    delimiter: str = CHUNK_ENRICHMENT_DELIMITER
    # Builds the enrichment prompt from a chunk's text; by default the text
    # itself is used as the prompt.
    enrichment_prompt_fn: Callable[[str], str] = lambda text: text
|
112
|
+
|
113
|
+
|
105
114
|
class DocChatAgentConfig(ChatAgentConfig):
|
106
115
|
system_message: str = DEFAULT_DOC_CHAT_SYSTEM_MESSAGE
|
107
116
|
user_message: str = DEFAULT_DOC_CHAT_INSTRUCTIONS
|
@@ -126,6 +135,12 @@ class DocChatAgentConfig(ChatAgentConfig):
|
|
126
135
|
# https://arxiv.org/pdf/2212.10496.pdf
|
127
136
|
# It is False by default; its benefits depends on the context.
|
128
137
|
hypothetical_answer: bool = False
|
138
|
+
# Optional config for chunk enrichment agent, e.g. to enrich
|
139
|
+
# chunks with hypothetical questions, or keywords to increase
|
140
|
+
# the "semantic surface area" of the chunks, which may help
|
141
|
+
# improve retrieval.
|
142
|
+
chunk_enrichment_config: Optional[ChunkEnrichmentAgentConfig] = None
|
143
|
+
|
129
144
|
n_query_rephrases: int = 0
|
130
145
|
n_neighbor_chunks: int = 0 # how many neighbors on either side of match to retrieve
|
131
146
|
n_fuzzy_neighbor_words: int = 100 # num neighbor words to retrieve for fuzzy match
|
@@ -404,6 +419,8 @@ class DocChatAgent(ChatAgent):
|
|
404
419
|
d.metadata.is_chunk = True
|
405
420
|
if self.vecdb is None:
|
406
421
|
raise ValueError("VecDB not set")
|
422
|
+
if self.config.chunk_enrichment_config is not None:
|
423
|
+
docs = self.enrich_chunks(docs)
|
407
424
|
|
408
425
|
# If any additional fields need to be added to content,
|
409
426
|
# add them as key=value pairs for all docs, before batching.
|
@@ -860,6 +877,72 @@ class DocChatAgent(ChatAgent):
|
|
860
877
|
).content
|
861
878
|
return answer
|
862
879
|
|
880
|
+
def enrich_chunks(self, docs: List[Document]) -> List[Document]:
    """
    Append LLM-generated enrichment text to each chunk.

    The enrichment agent is built from `self.config.chunk_enrichment_config`.
    Its system message is assumed to be set up so that running

    ```
    prompt = self.config.chunk_enrichment_config.enrichment_prompt_fn(text)
    result = await agent.llm_response_forget_async(prompt)
    ```

    leaves the augmentation for `text` in `result.content`.

    Args:
        docs: List of document chunks to enrich

    Returns:
        List[Document]: the chunks in the same order. A chunk with a
            non-empty enrichment is replaced by a copy whose content is the
            original text, the configured delimiter, and the enrichment, and
            whose metadata has `has_enrichment=True`. Other chunks are
            returned as they are. If no enrichment config is set, `docs` is
            returned unchanged.

    Raises:
        ValueError: if the enrichment agent has no LLM.
    """
    enrichment_config = self.config.chunk_enrichment_config
    if enrichment_config is None:
        return docs

    agent = ChatAgent(enrichment_config)
    if agent.llm is None:
        raise ValueError("LLM not set")

    def build_prompt(doc: Any) -> str:
        return enrichment_config.enrichment_prompt_fn(doc.content)

    def response_text(response: Any) -> str:
        # A missing response is treated as "no enrichment".
        return response.content if response else ""

    with status("[cyan]Augmenting chunks..."):
        # One LLM call per chunk, run concurrently in batches.
        enrichments = run_batch_agent_method(
            agent=agent,
            method=agent.llm_response_forget_async,
            items=docs,
            input_map=build_prompt,
            output_map=response_text,
            sequential=False,
            batch_size=enrichment_config.batch_size,
        )

    def with_enrichment(doc: Any, enrichment: str) -> Any:
        if not enrichment:
            return doc
        # Original text first, then the delimiter, then the enrichment.
        combined_content = (
            f"{doc.content}\n\n{enrichment_config.delimiter}\n{enrichment}"
        ).strip()
        flagged_metadata = doc.metadata.copy(update={"has_enrichment": True})
        return doc.copy(
            update={"content": combined_content, "metadata": flagged_metadata}
        )

    return [
        with_enrichment(doc, enrichment)
        for doc, enrichment in zip(docs, enrichments)
    ]
|
945
|
+
|
863
946
|
def llm_rephrase_query(self, query: str) -> List[str]:
|
864
947
|
if self.llm is None:
|
865
948
|
raise ValueError("LLM not set")
|
@@ -1305,7 +1388,6 @@ class DocChatAgent(ChatAgent):
|
|
1305
1388
|
if self.config.n_query_rephrases > 0:
|
1306
1389
|
rephrases = self.llm_rephrase_query(query)
|
1307
1390
|
proxies += rephrases
|
1308
|
-
|
1309
1391
|
passages = self.get_relevant_chunks(query, proxies) # no LLM involved
|
1310
1392
|
|
1311
1393
|
if len(passages) == 0:
|
@@ -1319,6 +1401,29 @@ class DocChatAgent(ChatAgent):
|
|
1319
1401
|
|
1320
1402
|
return query, extracts
|
1321
1403
|
|
1404
|
+
def remove_chunk_enrichments(self, passages: List[Document]) -> List[Document]:
    """Strip enrichment text (e.g. hypothetical questions or keywords)
    from documents.

    Nothing is changed unless chunk enrichment is configured. A document is
    cleaned only when it has content and its metadata carries a truthy
    `has_enrichment` attribute.

    Args:
        passages: List of documents to clean

    Returns:
        List of documents in the same order; each cleaned document is a copy
        whose content is the text before the first delimiter, stripped of
        surrounding whitespace.
    """
    enrichment_config = self.config.chunk_enrichment_config
    if enrichment_config is None:
        return passages

    delimiter = enrichment_config.delimiter
    cleaned: List[Document] = []
    for doc in passages:
        is_enriched = doc.content and getattr(doc.metadata, "has_enrichment", False)
        if is_enriched:
            original_text = doc.content.split(delimiter)[0].strip()
            cleaned.append(doc.copy(update={"content": original_text}))
        else:
            cleaned.append(doc)
    return cleaned
|
1426
|
+
|
1322
1427
|
def get_verbatim_extracts(
|
1323
1428
|
self,
|
1324
1429
|
query: str,
|
@@ -1334,6 +1439,8 @@ class DocChatAgent(ChatAgent):
|
|
1334
1439
|
Returns:
|
1335
1440
|
List[Document]: list of Documents containing extracts and metadata.
|
1336
1441
|
"""
|
1442
|
+
passages = self.remove_chunk_enrichments(passages)
|
1443
|
+
|
1337
1444
|
agent_cfg = self.config.relevance_extractor_config
|
1338
1445
|
if agent_cfg is None:
|
1339
1446
|
# no relevance extraction: simply return passages
|
langroid/mytypes.py
CHANGED
langroid/parsing/spider.py
CHANGED
@@ -7,9 +7,9 @@ try:
|
|
7
7
|
from pydispatch import dispatcher
|
8
8
|
from scrapy import signals
|
9
9
|
from scrapy.crawler import CrawlerRunner
|
10
|
-
from scrapy.http import
|
11
|
-
from scrapy.linkextractors import
|
12
|
-
from scrapy.spiders import CrawlSpider, Rule
|
10
|
+
from scrapy.http.response.text import TextResponse
|
11
|
+
from scrapy.linkextractors.lxmlhtml import LxmlLinkExtractor
|
12
|
+
from scrapy.spiders import CrawlSpider, Rule # type: ignore
|
13
13
|
from twisted.internet import defer, reactor
|
14
14
|
except ImportError:
|
15
15
|
raise LangroidImportError("scrapy", "scrapy")
|
@@ -21,7 +21,7 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
|
|
21
21
|
|
22
22
|
custom_settings = {"DEPTH_LIMIT": 1, "CLOSESPIDER_ITEMCOUNT": 20}
|
23
23
|
|
24
|
-
rules = (Rule(
|
24
|
+
rules = (Rule(LxmlLinkExtractor(), callback="parse_item", follow=True),)
|
25
25
|
|
26
26
|
def __init__(self, start_url: str, k: int = 20, *args, **kwargs): # type: ignore
|
27
27
|
"""Initialize the spider with start_url and k.
|
@@ -36,13 +36,13 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
|
|
36
36
|
self.k = k
|
37
37
|
self.visited_urls: Set[str] = set()
|
38
38
|
|
39
|
-
def parse_item(self, response:
|
39
|
+
def parse_item(self, response: TextResponse): # type: ignore
|
40
40
|
"""Extracts URLs that are within the same domain.
|
41
41
|
|
42
42
|
Args:
|
43
43
|
response: The scrapy response object.
|
44
44
|
"""
|
45
|
-
for link in
|
45
|
+
for link in LxmlLinkExtractor(allow_domains=self.allowed_domains).extract_links(
|
46
46
|
response
|
47
47
|
):
|
48
48
|
if len(self.visited_urls) < self.k:
|
@@ -1,11 +1,11 @@
|
|
1
1
|
langroid/__init__.py,sha256=z_fCOLQJPOw3LLRPBlFB5-2HyCjpPgQa4m4iY5Fvb8Y,1800
|
2
2
|
langroid/exceptions.py,sha256=gp6ku4ZLdXXCUQIwUNVFojJNGTzGnkevi2PLvG7HOhc,2555
|
3
|
-
langroid/mytypes.py,sha256=
|
3
|
+
langroid/mytypes.py,sha256=h1eMq1ZwTLVezObPfCseWNWbEOzP7mAKu2XoS63W1cM,2647
|
4
4
|
langroid/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
5
|
langroid/agent/__init__.py,sha256=ll0Cubd2DZ-fsCMl7e10hf9ZjFGKzphfBco396IKITY,786
|
6
6
|
langroid/agent/base.py,sha256=oThlrYygKDu1-bKjAfygldJ511gMKT8Z0qCrD52DdDM,77834
|
7
|
-
langroid/agent/batch.py,sha256=
|
8
|
-
langroid/agent/chat_agent.py,sha256=
|
7
|
+
langroid/agent/batch.py,sha256=vi1r5i1-vN80WfqHDSwjEym_KfGsqPGUtwktmiK1nuk,20635
|
8
|
+
langroid/agent/chat_agent.py,sha256=A-7Iiiw7jsoJNlWerljM29BidkiIbjPOQIkGZpZHmt0,82210
|
9
9
|
langroid/agent/chat_document.py,sha256=xPUMGzR83rn4iAEXIw2jy5LQ6YJ6Y0TiZ78XRQeDnJQ,17778
|
10
10
|
langroid/agent/openai_assistant.py,sha256=JkAcs02bIrgPNVvUWVR06VCthc5-ulla2QMBzux_q6o,34340
|
11
11
|
langroid/agent/task.py,sha256=XrXUbSoiFasvpIsZPn_cBpdWaTCKljJPRimtLMrSZrs,90347
|
@@ -14,7 +14,7 @@ langroid/agent/xml_tool_message.py,sha256=6SshYZJKIfi4mkE-gIoSwjkEYekQ8GwcSiCv7a
|
|
14
14
|
langroid/agent/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
15
|
langroid/agent/callbacks/chainlit.py,sha256=RH8qUXaZE5o2WQz3WJQ1SdFtASGlxWCA6_HYz_3meDQ,20822
|
16
16
|
langroid/agent/special/__init__.py,sha256=gik_Xtm_zV7U9s30Mn8UX3Gyuy4jTjQe9zjiE3HWmEo,1273
|
17
|
-
langroid/agent/special/doc_chat_agent.py,sha256=
|
17
|
+
langroid/agent/special/doc_chat_agent.py,sha256=tI16jVavTSOen9OUoRTl5heDTeTBhWsxW17XU9ZcEko,63563
|
18
18
|
langroid/agent/special/lance_doc_chat_agent.py,sha256=s8xoRs0gGaFtDYFUSIRchsgDVbS5Q3C2b2mr3V1Fd-Q,10419
|
19
19
|
langroid/agent/special/lance_tools.py,sha256=qS8x4wi8mrqfbYV2ztFzrcxyhHQ0ZWOc-zkYiH7awj0,2105
|
20
20
|
langroid/agent/special/relevance_extractor_agent.py,sha256=zIx8GUdVo1aGW6ASla0NPQjYYIpmriK_TYMijqAx3F8,4796
|
@@ -85,7 +85,7 @@ langroid/parsing/parser.py,sha256=bTG5TO2CEwGdLf9979j9_dFntKX5FloGF8vhts6ObU0,11
|
|
85
85
|
langroid/parsing/repo_loader.py,sha256=3GjvPJS6Vf5L6gV2zOU8s-Tf1oq_fZm-IB_RL_7CTsY,29373
|
86
86
|
langroid/parsing/routing.py,sha256=-FcnlqldzL4ZoxuDwXjQPNHgBe9F9-F4R6q7b_z9CvI,1232
|
87
87
|
langroid/parsing/search.py,sha256=0i_r0ESb5HEQfagA2g7_uMQyxYPADWVbdcN9ixZhS4E,8992
|
88
|
-
langroid/parsing/spider.py,sha256=
|
88
|
+
langroid/parsing/spider.py,sha256=hAVM6wxh1pQ0EN4tI5wMBtAjIk0T-xnpi-ZUzWybhos,3258
|
89
89
|
langroid/parsing/table_loader.py,sha256=qNM4obT_0Y4tjrxNBCNUYjKQ9oETCZ7FbolKBTcz-GM,3410
|
90
90
|
langroid/parsing/url_loader.py,sha256=JK48KktLRDBfjrt4nsUfy92M6yGdEeicAqOum2MdULM,4656
|
91
91
|
langroid/parsing/urls.py,sha256=XjpaV5onG7gKQ5iQeFTzHSw5P08Aqw0g-rMUu61lR6s,7988
|
@@ -121,7 +121,7 @@ langroid/vector_store/lancedb.py,sha256=b3_vWkTjG8mweZ7ZNlUD-NjmQP_rLBZfyKWcxt2v
|
|
121
121
|
langroid/vector_store/meilisearch.py,sha256=6frB7GFWeWmeKzRfLZIvzRjllniZ1cYj3HmhHQICXLs,11663
|
122
122
|
langroid/vector_store/momento.py,sha256=UNHGT6jXuQtqY9f6MdqGU14bVnS0zHgIJUa30ULpUJo,10474
|
123
123
|
langroid/vector_store/qdrantdb.py,sha256=HRLCt-FG8y4718omwpFaQZnWeYxPj0XCwS4tjokI1sU,18116
|
124
|
-
langroid-0.
|
125
|
-
langroid-0.
|
126
|
-
langroid-0.
|
127
|
-
langroid-0.
|
124
|
+
langroid-0.34.0.dist-info/METADATA,sha256=fo7ULfjnWFED6Cag8aUFjOaPqEatQKBXEz-Z_rFyHnk,59015
|
125
|
+
langroid-0.34.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
126
|
+
langroid-0.34.0.dist-info/licenses/LICENSE,sha256=EgVbvA6VSYgUlvC3RvPKehSg7MFaxWDsFuzLOsPPfJg,1065
|
127
|
+
langroid-0.34.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|