langroid 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +353 -94
- langroid/agent/chat_agent.py +68 -9
- langroid/agent/chat_document.py +16 -7
- langroid/agent/openai_assistant.py +12 -1
- langroid/agent/special/lance_doc_chat_agent.py +25 -18
- langroid/agent/special/lance_rag/critic_agent.py +37 -5
- langroid/agent/special/lance_rag/query_planner_agent.py +102 -63
- langroid/agent/special/lance_tools.py +10 -2
- langroid/agent/task.py +156 -47
- langroid/agent/tool_message.py +12 -3
- langroid/agent/tools/__init__.py +5 -0
- langroid/agent/tools/orchestration.py +216 -0
- langroid/agent/tools/recipient_tool.py +6 -11
- langroid/agent/typed_task.py +19 -0
- langroid/language_models/base.py +3 -2
- langroid/mytypes.py +0 -1
- langroid/parsing/parse_json.py +19 -2
- langroid/utils/pydantic_utils.py +19 -0
- langroid/vector_store/base.py +3 -1
- langroid/vector_store/lancedb.py +2 -0
- {langroid-0.8.0.dist-info → langroid-0.9.0.dist-info}/METADATA +2 -1
- {langroid-0.8.0.dist-info → langroid-0.9.0.dist-info}/RECORD +25 -28
- pyproject.toml +2 -1
- langroid/agent/special/lance_rag_new/__init__.py +0 -9
- langroid/agent/special/lance_rag_new/critic_agent.py +0 -171
- langroid/agent/special/lance_rag_new/lance_rag_task.py +0 -144
- langroid/agent/special/lance_rag_new/query_planner_agent.py +0 -222
- langroid/agent/team.py +0 -1758
- {langroid-0.8.0.dist-info → langroid-0.9.0.dist-info}/LICENSE +0 -0
- {langroid-0.8.0.dist-info → langroid-0.9.0.dist-info}/WHEEL +0 -0
langroid/agent/base.py
CHANGED
@@ -5,6 +5,7 @@ import json
|
|
5
5
|
import logging
|
6
6
|
import re
|
7
7
|
from abc import ABC
|
8
|
+
from collections import OrderedDict
|
8
9
|
from contextlib import ExitStack
|
9
10
|
from types import SimpleNamespace
|
10
11
|
from typing import (
|
@@ -31,11 +32,13 @@ from langroid.agent.tool_message import ToolMessage
|
|
31
32
|
from langroid.language_models.base import (
|
32
33
|
LanguageModel,
|
33
34
|
LLMConfig,
|
35
|
+
LLMFunctionCall,
|
34
36
|
LLMMessage,
|
35
37
|
LLMResponse,
|
36
38
|
LLMTokenUsage,
|
37
39
|
OpenAIToolCall,
|
38
40
|
StreamingIfAllowed,
|
41
|
+
ToolChoiceTypes,
|
39
42
|
)
|
40
43
|
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
|
41
44
|
from langroid.mytypes import Entity
|
@@ -44,11 +47,12 @@ from langroid.parsing.parser import Parser, ParsingConfig
|
|
44
47
|
from langroid.prompts.prompts_config import PromptsConfig
|
45
48
|
from langroid.pydantic_v1 import BaseSettings, Field, ValidationError, validator
|
46
49
|
from langroid.utils.configuration import settings
|
47
|
-
from langroid.utils.constants import NO_ANSWER
|
50
|
+
from langroid.utils.constants import DONE, NO_ANSWER, PASS, PASS_TO, SEND_TO
|
48
51
|
from langroid.utils.object_registry import ObjectRegistry
|
49
52
|
from langroid.utils.output import status
|
50
53
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
51
54
|
|
55
|
+
ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
|
52
56
|
console = Console(quiet=settings.quiet)
|
53
57
|
|
54
58
|
logger = logging.getLogger(__name__)
|
@@ -108,9 +112,8 @@ class Agent(ABC):
|
|
108
112
|
self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
|
109
113
|
self.llm_tools_handled: Set[str] = set()
|
110
114
|
self.llm_tools_usable: Set[str] = set()
|
115
|
+
self.llm_tools_known: Set[str] = set() # all known tools, handled/used or not
|
111
116
|
self.interactive: bool | None = None
|
112
|
-
self.total_llm_token_cost = 0.0
|
113
|
-
self.total_llm_token_usage = 0
|
114
117
|
self.token_stats_str = ""
|
115
118
|
self.default_human_response: Optional[str] = None
|
116
119
|
self._indent = ""
|
@@ -140,6 +143,13 @@ class Agent(ABC):
|
|
140
143
|
show_error_message=noop_fn,
|
141
144
|
show_start_response=noop_fn,
|
142
145
|
)
|
146
|
+
Agent.init_state(self)
|
147
|
+
self.init_state()
|
148
|
+
|
149
|
+
def init_state(self) -> None:
|
150
|
+
"""Initialize all state vars. Called by Task.run() if restart is True"""
|
151
|
+
self.total_llm_token_cost = 0.0
|
152
|
+
self.total_llm_token_usage = 0
|
143
153
|
|
144
154
|
@staticmethod
|
145
155
|
def from_id(id: str) -> "Agent":
|
@@ -239,7 +249,7 @@ class Agent(ABC):
|
|
239
249
|
):
|
240
250
|
"""
|
241
251
|
If the message class has a `handle` method,
|
242
|
-
and does NOT have a method with the same name as the tool,
|
252
|
+
and agent does NOT have a method with the same name as the tool,
|
243
253
|
then we create a method for the agent whose name
|
244
254
|
is the value of `tool`, and whose body is the `handle` method.
|
245
255
|
This removes a separate step of having to define this method
|
@@ -247,13 +257,25 @@ class Agent(ABC):
|
|
247
257
|
in one place, i.e. in the message class.
|
248
258
|
See `tests/main/test_stateless_tool_messages.py` for an example.
|
249
259
|
"""
|
250
|
-
|
260
|
+
has_chat_doc_arg = (
|
261
|
+
len(inspect.signature(message_class.handle).parameters) > 1
|
262
|
+
)
|
263
|
+
if has_chat_doc_arg:
|
264
|
+
setattr(self, tool, lambda obj, chat_doc: obj.handle(chat_doc))
|
265
|
+
else:
|
266
|
+
setattr(self, tool, lambda obj: obj.handle())
|
251
267
|
elif (
|
252
268
|
hasattr(message_class, "response")
|
253
269
|
and inspect.isfunction(message_class.response)
|
254
270
|
and not hasattr(self, tool)
|
255
271
|
):
|
256
|
-
|
272
|
+
has_chat_doc_arg = (
|
273
|
+
len(inspect.signature(message_class.response).parameters) > 2
|
274
|
+
)
|
275
|
+
if has_chat_doc_arg:
|
276
|
+
setattr(self, tool, lambda obj, chat_doc: obj.response(self, chat_doc))
|
277
|
+
else:
|
278
|
+
setattr(self, tool, lambda obj: obj.response(self))
|
257
279
|
|
258
280
|
if hasattr(message_class, "handle_message_fallback") and (
|
259
281
|
inspect.isfunction(message_class.handle_message_fallback)
|
@@ -311,9 +333,27 @@ class Agent(ABC):
|
|
311
333
|
]
|
312
334
|
return "\n\n".join(sample_convo)
|
313
335
|
|
314
|
-
def create_agent_response(
|
336
|
+
def create_agent_response(
|
337
|
+
self,
|
338
|
+
content: str | None = None,
|
339
|
+
tool_messages: List[ToolMessage] = [],
|
340
|
+
oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
|
341
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
342
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
343
|
+
function_call: LLMFunctionCall | None = None,
|
344
|
+
recipient: str = "",
|
345
|
+
) -> ChatDocument:
|
315
346
|
"""Template for agent_response."""
|
316
|
-
return self.
|
347
|
+
return self.response_template(
|
348
|
+
Entity.AGENT,
|
349
|
+
content=content,
|
350
|
+
tool_messages=tool_messages,
|
351
|
+
oai_tool_calls=oai_tool_calls,
|
352
|
+
oai_tool_choice=oai_tool_choice,
|
353
|
+
oai_tool_id2result=oai_tool_id2result,
|
354
|
+
function_call=function_call,
|
355
|
+
recipient=recipient,
|
356
|
+
)
|
317
357
|
|
318
358
|
async def agent_response_async(
|
319
359
|
self,
|
@@ -366,7 +406,7 @@ class Agent(ABC):
|
|
366
406
|
# set sender_name to name of the function_call
|
367
407
|
sender_name = msg.function_call.name
|
368
408
|
|
369
|
-
results_str, id2result, oai_tool_id = self.
|
409
|
+
results_str, id2result, oai_tool_id = self.process_tool_results(
|
370
410
|
results if isinstance(results, str) else "",
|
371
411
|
id2result=None if isinstance(results, str) else results,
|
372
412
|
tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
|
@@ -384,21 +424,29 @@ class Agent(ABC):
|
|
384
424
|
),
|
385
425
|
)
|
386
426
|
|
387
|
-
def
|
427
|
+
def process_tool_results(
|
388
428
|
self,
|
389
429
|
results: str,
|
390
|
-
id2result:
|
430
|
+
id2result: OrderedDict[str, str] | None,
|
391
431
|
tool_calls: List[OpenAIToolCall] | None = None,
|
392
432
|
) -> Tuple[str, Dict[str, str] | None, str | None]:
|
393
433
|
"""
|
394
434
|
Process results from a response, based on whether
|
395
|
-
they are results of OpenAI tool-calls from THIS agent
|
435
|
+
they are results of OpenAI tool-calls from THIS agent, so that
|
436
|
+
we can construct an appropriate LLMMessage that contains tool results.
|
437
|
+
|
438
|
+
Args:
|
439
|
+
results (str): A possible string result from handling tool(s)
|
440
|
+
id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
|
441
|
+
if there are multiple tool results.
|
442
|
+
tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
|
443
|
+
results are a response to.
|
396
444
|
|
397
445
|
Return:
|
398
446
|
- str: The response string
|
399
447
|
- Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
|
400
448
|
multiple tool results.
|
401
|
-
- str|None: tool_id if
|
449
|
+
- str|None: tool_id if there was a single tool result
|
402
450
|
|
403
451
|
"""
|
404
452
|
id2result_ = copy.deepcopy(id2result) if id2result is not None else None
|
@@ -409,11 +457,11 @@ class Agent(ABC):
|
|
409
457
|
# in this case ignore id2result
|
410
458
|
assert (
|
411
459
|
id2result is None
|
412
|
-
), "id2result should be None when results is
|
460
|
+
), "id2result should be None when results string is non-empty!"
|
413
461
|
results_str = results
|
414
462
|
if len(self.oai_tool_calls) > 0:
|
415
|
-
#
|
416
|
-
#
|
463
|
+
# We only have one result, so in case there is a
|
464
|
+
# "pending" OpenAI tool-call, we expect no more than 1 such.
|
417
465
|
assert (
|
418
466
|
len(self.oai_tool_calls) == 1
|
419
467
|
), "There are multiple pending tool-calls, but only one result!"
|
@@ -422,70 +470,115 @@ class Agent(ABC):
|
|
422
470
|
# can properly set the `tool_call_id` field of the LLMMessage.
|
423
471
|
oai_tool_id = self.oai_tool_calls[0].id
|
424
472
|
elif id2result is not None and id2result_ is not None: # appease mypy
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
# pending tool-calls (self.oai_tool_calls). Such results are concatenated
|
431
|
-
# into a simple string, to store in the ChatDocument.content,
|
432
|
-
# and the rest
|
433
|
-
# (i.e. those that DO correspond to tools in self.oai_tool_calls)
|
434
|
-
# are stored as a dict in ChatDocument.oa_tool_id2result.
|
435
|
-
|
436
|
-
# OAI tools from THIS agent, awaiting response
|
437
|
-
pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
|
438
|
-
# tool_calls that the results are a response to
|
439
|
-
# (but these may have been sent from another agent, hence may not be in
|
440
|
-
# self.oai_tool_calls)
|
441
|
-
parent_tool_id2name = {
|
442
|
-
tc.id: tc.function.name
|
443
|
-
for tc in tool_calls or []
|
444
|
-
if tc.function is not None
|
445
|
-
}
|
446
|
-
|
447
|
-
# (id, result) for result NOT corresponding to self.oai_tool_calls,
|
448
|
-
# i.e. these are results of EXTERNAL tool-calls from another agent.
|
449
|
-
external_tool_id_results = []
|
450
|
-
|
451
|
-
for tc_id, result in id2result.items():
|
452
|
-
if tc_id not in pending_tool_ids:
|
453
|
-
external_tool_id_results.append((tc_id, result))
|
454
|
-
id2result_.pop(tc_id)
|
455
|
-
if len(external_tool_id_results) == 0:
|
473
|
+
if len(id2result_) == len(self.oai_tool_calls):
|
474
|
+
# if the number of pending tool calls equals the number of results,
|
475
|
+
# then ignore the ids in id2result, and use the results in order,
|
476
|
+
# which is preserved since id2result is an OrderedDict.
|
477
|
+
assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
|
456
478
|
results_str = ""
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
[
|
462
|
-
f"Result from tool/function {parent_tool_id2name[id]}: {result}"
|
463
|
-
for id, result in external_tool_id_results
|
464
|
-
]
|
479
|
+
id2result_ = OrderedDict(
|
480
|
+
zip(
|
481
|
+
[tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
|
482
|
+
)
|
465
483
|
)
|
484
|
+
else:
|
485
|
+
assert (
|
486
|
+
tool_calls is not None
|
487
|
+
), "tool_calls cannot be None when id2result is not None!"
|
488
|
+
# This must be an OpenAI tool id -> result map;
|
489
|
+
# However some ids may not correspond to the tool-calls in the list of
|
490
|
+
# pending tool-calls (self.oai_tool_calls).
|
491
|
+
# Such results are concatenated into a simple string, to store in the
|
492
|
+
# ChatDocument.content, and the rest
|
493
|
+
# (i.e. those that DO correspond to tools in self.oai_tool_calls)
|
494
|
+
# are stored as a dict in ChatDocument.oai_tool_id2result.
|
495
|
+
|
496
|
+
# OAI tools from THIS agent, awaiting response
|
497
|
+
pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
|
498
|
+
# tool_calls that the results are a response to
|
499
|
+
# (but these may have been sent from another agent, hence may not be in
|
500
|
+
# self.oai_tool_calls)
|
501
|
+
parent_tool_id2name = {
|
502
|
+
tc.id: tc.function.name
|
503
|
+
for tc in tool_calls or []
|
504
|
+
if tc.function is not None
|
505
|
+
}
|
506
|
+
|
507
|
+
# (id, result) for result NOT corresponding to self.oai_tool_calls,
|
508
|
+
# i.e. these are results of EXTERNAL tool-calls from another agent.
|
509
|
+
external_tool_id_results = []
|
510
|
+
|
511
|
+
for tc_id, result in id2result.items():
|
512
|
+
if tc_id not in pending_tool_ids:
|
513
|
+
external_tool_id_results.append((tc_id, result))
|
514
|
+
id2result_.pop(tc_id)
|
515
|
+
if len(external_tool_id_results) == 0:
|
516
|
+
results_str = ""
|
517
|
+
elif len(external_tool_id_results) == 1:
|
518
|
+
results_str = external_tool_id_results[0][1]
|
519
|
+
else:
|
520
|
+
results_str = "\n\n".join(
|
521
|
+
[
|
522
|
+
f"Result from tool/function "
|
523
|
+
f"{parent_tool_id2name[id]}: {result}"
|
524
|
+
for id, result in external_tool_id_results
|
525
|
+
]
|
526
|
+
)
|
466
527
|
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
528
|
+
if len(id2result_) == 0:
|
529
|
+
id2result_ = None
|
530
|
+
elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
|
531
|
+
results_str = list(id2result_.values())[0]
|
532
|
+
oai_tool_id = list(id2result_.keys())[0]
|
533
|
+
id2result_ = None
|
473
534
|
|
474
535
|
return results_str, id2result_, oai_tool_id
|
475
536
|
|
476
|
-
def
|
537
|
+
def response_template(
|
538
|
+
self,
|
539
|
+
e: Entity,
|
540
|
+
content: str | None = None,
|
541
|
+
tool_messages: List[ToolMessage] = [],
|
542
|
+
oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
|
543
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
544
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
545
|
+
function_call: LLMFunctionCall | None = None,
|
546
|
+
recipient: str = "",
|
547
|
+
) -> ChatDocument:
|
477
548
|
"""Template for response from entity `e`."""
|
478
549
|
return ChatDocument(
|
479
550
|
content=content or "",
|
480
|
-
tool_messages=
|
551
|
+
tool_messages=tool_messages,
|
552
|
+
oai_tool_calls=oai_tool_calls,
|
553
|
+
oai_tool_id2result=oai_tool_id2result,
|
554
|
+
function_call=function_call,
|
555
|
+
oai_tool_choice=oai_tool_choice,
|
481
556
|
metadata=ChatDocMetaData(
|
482
|
-
source=e, sender=e, sender_name=self.config.name,
|
557
|
+
source=e, sender=e, sender_name=self.config.name, recipient=recipient
|
483
558
|
),
|
484
559
|
)
|
485
560
|
|
486
|
-
def create_user_response(
|
561
|
+
def create_user_response(
|
562
|
+
self,
|
563
|
+
content: str | None = None,
|
564
|
+
tool_messages: List[ToolMessage] = [],
|
565
|
+
oai_tool_calls: List[OpenAIToolCall] | None = None,
|
566
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
567
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
568
|
+
function_call: LLMFunctionCall | None = None,
|
569
|
+
recipient: str = "",
|
570
|
+
) -> ChatDocument:
|
487
571
|
"""Template for user_response."""
|
488
|
-
return self.
|
572
|
+
return self.response_template(
|
573
|
+
e=Entity.USER,
|
574
|
+
content=content,
|
575
|
+
tool_messages=tool_messages,
|
576
|
+
oai_tool_calls=oai_tool_calls,
|
577
|
+
oai_tool_choice=oai_tool_choice,
|
578
|
+
oai_tool_id2result=oai_tool_id2result,
|
579
|
+
function_call=function_call,
|
580
|
+
recipient=recipient,
|
581
|
+
)
|
489
582
|
|
490
583
|
async def user_response_async(
|
491
584
|
self,
|
@@ -579,9 +672,27 @@ class Agent(ABC):
|
|
579
672
|
|
580
673
|
return True
|
581
674
|
|
582
|
-
def create_llm_response(
|
675
|
+
def create_llm_response(
|
676
|
+
self,
|
677
|
+
content: str | None = None,
|
678
|
+
tool_messages: List[ToolMessage] = [],
|
679
|
+
oai_tool_calls: None | List[OpenAIToolCall] = None,
|
680
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
681
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
682
|
+
function_call: LLMFunctionCall | None = None,
|
683
|
+
recipient: str = "",
|
684
|
+
) -> ChatDocument:
|
583
685
|
"""Template for llm_response."""
|
584
|
-
return self.
|
686
|
+
return self.response_template(
|
687
|
+
Entity.LLM,
|
688
|
+
content=content,
|
689
|
+
tool_messages=tool_messages,
|
690
|
+
oai_tool_calls=oai_tool_calls,
|
691
|
+
oai_tool_choice=oai_tool_choice,
|
692
|
+
oai_tool_id2result=oai_tool_id2result,
|
693
|
+
function_call=function_call,
|
694
|
+
recipient=recipient,
|
695
|
+
)
|
585
696
|
|
586
697
|
@no_type_check
|
587
698
|
async def llm_response_async(
|
@@ -725,12 +836,60 @@ class Agent(ABC):
|
|
725
836
|
return True
|
726
837
|
return False
|
727
838
|
|
728
|
-
def
|
839
|
+
def _tool_recipient_match(self, tool: ToolMessage) -> bool:
|
840
|
+
"""Is tool is handled by this agent
|
841
|
+
and an explicit `recipient` field doesn't preclude this agent from handling it?
|
842
|
+
"""
|
843
|
+
if tool.default_value("request") not in self.llm_tools_handled:
|
844
|
+
return False
|
845
|
+
if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
|
846
|
+
return tool.recipient == "" or tool.recipient == self.config.name
|
847
|
+
return True
|
848
|
+
|
849
|
+
def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
|
850
|
+
"""
|
851
|
+
Does the msg have at least one tool, and ALL tools are
|
852
|
+
disabled for handling by this agent?
|
853
|
+
"""
|
854
|
+
tools = self.get_tool_messages(msg, all_tools=True)
|
855
|
+
if len(tools) == 0:
|
856
|
+
return False
|
857
|
+
return all(not self._tool_recipient_match(t) for t in tools)
|
858
|
+
|
859
|
+
def get_tool_messages(
|
860
|
+
self,
|
861
|
+
msg: str | ChatDocument,
|
862
|
+
all_tools: bool = False,
|
863
|
+
) -> List[ToolMessage]:
|
864
|
+
"""
|
865
|
+
Get ToolMessages recognized in msg, handle-able by this agent.
|
866
|
+
If all_tools is True:
|
867
|
+
- return all tools, i.e. any tool in self.llm_tools_known,
|
868
|
+
whether it is handled by this agent or not;
|
869
|
+
- otherwise, return only the tools handled by this agent.
|
870
|
+
"""
|
871
|
+
|
729
872
|
if isinstance(msg, str):
|
730
|
-
|
873
|
+
json_tools = self.get_json_tool_messages(msg)
|
874
|
+
if all_tools:
|
875
|
+
return json_tools
|
876
|
+
else:
|
877
|
+
return [
|
878
|
+
t
|
879
|
+
for t in json_tools
|
880
|
+
if self._tool_recipient_match(t) and t.default_value("request")
|
881
|
+
]
|
882
|
+
|
883
|
+
if all_tools and len(msg.all_tool_messages) > 0:
|
884
|
+
# We've already identified all_tool_messages in the msg;
|
885
|
+
# return the corresponding ToolMessage objects
|
886
|
+
return msg.all_tool_messages
|
731
887
|
if len(msg.tool_messages) > 0:
|
732
|
-
# We've already found tool_messages
|
733
|
-
# (either via OpenAI Fn-call or Langroid-native ToolMessage)
|
888
|
+
# We've already found tool_messages,
|
889
|
+
# (either via OpenAI Fn-call or Langroid-native ToolMessage);
|
890
|
+
# or they were added by an agent_response.
|
891
|
+
# note these could be from a forwarded msg from another agent,
|
892
|
+
# so return ONLY the messages THIS agent to enabled to handle.
|
734
893
|
return msg.tool_messages
|
735
894
|
assert isinstance(msg, ChatDocument)
|
736
895
|
if (
|
@@ -740,19 +899,32 @@ class Agent(ABC):
|
|
740
899
|
):
|
741
900
|
|
742
901
|
tools = self.get_json_tool_messages(msg.content)
|
743
|
-
msg.
|
744
|
-
|
902
|
+
msg.all_tool_messages = tools
|
903
|
+
my_tools = [t for t in tools if self._tool_recipient_match(t)]
|
904
|
+
msg.tool_messages = my_tools
|
905
|
+
|
906
|
+
if all_tools:
|
907
|
+
return tools
|
908
|
+
else:
|
909
|
+
return my_tools
|
745
910
|
|
746
911
|
# otherwise, we look for `tool_calls` (possibly multiple)
|
747
912
|
tools = self.get_oai_tool_calls_classes(msg)
|
748
|
-
msg.
|
913
|
+
msg.all_tool_messages = tools
|
914
|
+
my_tools = [t for t in tools if self._tool_recipient_match(t)]
|
915
|
+
msg.tool_messages = my_tools
|
749
916
|
|
750
917
|
if len(tools) == 0:
|
751
918
|
# otherwise, we look for a `function_call`
|
752
919
|
fun_call_cls = self.get_function_call_class(msg)
|
753
920
|
tools = [fun_call_cls] if fun_call_cls is not None else []
|
754
|
-
msg.
|
755
|
-
|
921
|
+
msg.all_tool_messages = tools
|
922
|
+
my_tools = [t for t in tools if self._tool_recipient_match(t)]
|
923
|
+
msg.tool_messages = my_tools
|
924
|
+
if all_tools:
|
925
|
+
return tools
|
926
|
+
else:
|
927
|
+
return my_tools
|
756
928
|
|
757
929
|
def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
|
758
930
|
"""
|
@@ -853,7 +1025,7 @@ class Agent(ABC):
|
|
853
1025
|
|
854
1026
|
def handle_message(
|
855
1027
|
self, msg: str | ChatDocument
|
856
|
-
) -> None | str |
|
1028
|
+
) -> None | str | OrderedDict[str, str] | ChatDocument:
|
857
1029
|
"""
|
858
1030
|
Handle a "tool" message either a string containing one or more
|
859
1031
|
valid "tool" JSON substrings, or a
|
@@ -889,16 +1061,63 @@ class Agent(ABC):
|
|
889
1061
|
# for this agent.
|
890
1062
|
return None
|
891
1063
|
if len(tools) == 0:
|
892
|
-
|
1064
|
+
fallback_result = self.handle_message_fallback(msg)
|
1065
|
+
if fallback_result is not None and isinstance(fallback_result, ToolMessage):
|
1066
|
+
return self.create_agent_response(tool_messages=[fallback_result])
|
1067
|
+
return fallback_result
|
893
1068
|
has_ids = all([t.id != "" for t in tools])
|
894
|
-
|
1069
|
+
chat_doc = msg if isinstance(msg, ChatDocument) else None
|
1070
|
+
|
1071
|
+
# check whether there are multiple orchestration-tools (e.g. DoneTool etc),
|
1072
|
+
# in which case set result to error-string since we don't yet support
|
1073
|
+
# multi-tools with one or more orch tools.
|
1074
|
+
from langroid.agent.tools.orchestration import (
|
1075
|
+
AgentDoneTool,
|
1076
|
+
AgentSendTool,
|
1077
|
+
DonePassTool,
|
1078
|
+
DoneTool,
|
1079
|
+
ForwardTool,
|
1080
|
+
PassTool,
|
1081
|
+
SendTool,
|
1082
|
+
)
|
1083
|
+
from langroid.agent.tools.recipient_tool import RecipientTool
|
1084
|
+
|
1085
|
+
ORCHESTRATION_TOOLS = (
|
1086
|
+
AgentDoneTool,
|
1087
|
+
DoneTool,
|
1088
|
+
PassTool,
|
1089
|
+
DonePassTool,
|
1090
|
+
ForwardTool,
|
1091
|
+
RecipientTool,
|
1092
|
+
SendTool,
|
1093
|
+
AgentSendTool,
|
1094
|
+
)
|
1095
|
+
|
1096
|
+
has_orch = any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools)
|
1097
|
+
results: List[str | ChatDocument | None]
|
1098
|
+
if has_orch and len(tools) > 1:
|
1099
|
+
err_str = "ERROR: Use ONE tool at a time!"
|
1100
|
+
results = [err_str for _ in tools]
|
1101
|
+
else:
|
1102
|
+
results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
|
1103
|
+
|
895
1104
|
tool_names = [t.default_value("request") for t in tools]
|
896
1105
|
if has_ids:
|
897
|
-
id2result =
|
898
|
-
t.id
|
1106
|
+
id2result = OrderedDict(
|
1107
|
+
(t.id, r)
|
899
1108
|
for t, r in zip(tools, results)
|
900
1109
|
if r is not None and isinstance(r, str)
|
901
|
-
|
1110
|
+
)
|
1111
|
+
result_values = list(id2result.values())
|
1112
|
+
if len(id2result) > 1 and any(
|
1113
|
+
orch_str in r
|
1114
|
+
for r in result_values
|
1115
|
+
for orch_str in ORCHESTRATION_STRINGS
|
1116
|
+
):
|
1117
|
+
# Cannot support multi-tool results containing orchestration strings!
|
1118
|
+
# Replace results with err string to force LLM to retry
|
1119
|
+
err_str = "ERROR: Please use ONE tool at a time!"
|
1120
|
+
id2result = OrderedDict((id, err_str) for id in id2result.keys())
|
902
1121
|
|
903
1122
|
name_results_list = [
|
904
1123
|
(name, r) for name, r in zip(tool_names, results) if r is not None
|
@@ -940,19 +1159,28 @@ class Agent(ABC):
|
|
940
1159
|
self, msg: str | ChatDocument
|
941
1160
|
) -> str | ChatDocument | None:
|
942
1161
|
"""
|
943
|
-
Fallback method
|
944
|
-
|
945
|
-
|
1162
|
+
Fallback method for the "no-tools" scenario.
|
1163
|
+
This method can be overridden by subclasses, e.g.,
|
1164
|
+
to create a "reminder" message when a tool is expected but the LLM "forgot"
|
1165
|
+
to generate one.
|
946
1166
|
|
947
1167
|
Args:
|
948
1168
|
msg (str | ChatDocument): The input msg to handle
|
949
1169
|
Returns:
|
950
1170
|
str: The result of the handler method in string form so it can
|
951
|
-
be
|
1171
|
+
be handled by LLM
|
952
1172
|
"""
|
953
1173
|
return None
|
954
1174
|
|
955
1175
|
def _get_one_tool_message(self, json_str: str) -> Optional[ToolMessage]:
|
1176
|
+
"""
|
1177
|
+
Parse the json str into ANY ToolMessage KNOWN to agent --
|
1178
|
+
This includes non-used/handled tools, i.e. any tool in self.llm_tools_known.
|
1179
|
+
The exception to this is below where we try our best to infer the tool
|
1180
|
+
when the LLM has "forgotten" to include the "request" field in the JSON --
|
1181
|
+
in this case we ONLY look at the possible set of HANDLED tools, i.e.
|
1182
|
+
self.llm_tools_handled.
|
1183
|
+
"""
|
956
1184
|
json_data = json.loads(json_str)
|
957
1185
|
# check if the json_data contains a "properties" field
|
958
1186
|
# which further contains the actual tool-call
|
@@ -976,9 +1204,8 @@ class Agent(ABC):
|
|
976
1204
|
if isinstance(properties, dict):
|
977
1205
|
json_data = properties
|
978
1206
|
request = json_data.get("request")
|
979
|
-
|
980
1207
|
if request is None:
|
981
|
-
|
1208
|
+
possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
|
982
1209
|
default_keys = set(ToolMessage.__fields__.keys())
|
983
1210
|
request_keys = set(json_data.keys())
|
984
1211
|
|
@@ -1002,7 +1229,7 @@ class Agent(ABC):
|
|
1002
1229
|
candidate_tools = list(
|
1003
1230
|
filter(
|
1004
1231
|
lambda t: t is not None,
|
1005
|
-
map(maybe_parse,
|
1232
|
+
map(maybe_parse, possible),
|
1006
1233
|
)
|
1007
1234
|
)
|
1008
1235
|
|
@@ -1013,7 +1240,7 @@ class Agent(ABC):
|
|
1013
1240
|
else:
|
1014
1241
|
return None
|
1015
1242
|
|
1016
|
-
if not isinstance(request, str) or request not in self.
|
1243
|
+
if not isinstance(request, str) or request not in self.llm_tools_known:
|
1017
1244
|
return None
|
1018
1245
|
|
1019
1246
|
message_class = self.llm_tools_map.get(request)
|
@@ -1027,11 +1254,18 @@ class Agent(ABC):
|
|
1027
1254
|
raise ve
|
1028
1255
|
return message
|
1029
1256
|
|
1030
|
-
def handle_tool_message(
|
1257
|
+
def handle_tool_message(
|
1258
|
+
self,
|
1259
|
+
tool: ToolMessage,
|
1260
|
+
chat_doc: Optional[ChatDocument] = None,
|
1261
|
+
) -> None | str | ChatDocument:
|
1031
1262
|
"""
|
1032
1263
|
Respond to a tool request from the LLM, in the form of an ToolMessage object.
|
1033
1264
|
Args:
|
1034
1265
|
tool: ToolMessage object representing the tool request.
|
1266
|
+
chat_doc: Optional ChatDocument object containing the tool request.
|
1267
|
+
This is passed to the tool-handler method only if it has a `chat_doc`
|
1268
|
+
argument.
|
1035
1269
|
|
1036
1270
|
Returns:
|
1037
1271
|
|
@@ -1040,9 +1274,34 @@ class Agent(ABC):
|
|
1040
1274
|
handler_method = getattr(self, tool_name, None)
|
1041
1275
|
if handler_method is None:
|
1042
1276
|
return None
|
1043
|
-
|
1277
|
+
has_chat_doc_arg = (
|
1278
|
+
chat_doc is not None
|
1279
|
+
and "chat_doc" in inspect.signature(handler_method).parameters
|
1280
|
+
)
|
1044
1281
|
try:
|
1045
|
-
|
1282
|
+
if has_chat_doc_arg:
|
1283
|
+
maybe_result = handler_method(tool, chat_doc=chat_doc)
|
1284
|
+
else:
|
1285
|
+
maybe_result = handler_method(tool)
|
1286
|
+
if isinstance(maybe_result, ToolMessage):
|
1287
|
+
# result is a ToolMessage, so...
|
1288
|
+
result_tool_name = maybe_result.default_value("request")
|
1289
|
+
if (
|
1290
|
+
result_tool_name in self.llm_tools_handled
|
1291
|
+
and tool_name != result_tool_name
|
1292
|
+
):
|
1293
|
+
# TODO: do we need to remove the tool message from the chat_doc?
|
1294
|
+
# if (chat_doc is not None and
|
1295
|
+
# maybe_result in chat_doc.tool_messages):
|
1296
|
+
# chat_doc.tool_messages.remove(maybe_result)
|
1297
|
+
# if we can handle it, do so
|
1298
|
+
result = self.handle_tool_message(maybe_result, chat_doc=chat_doc)
|
1299
|
+
else:
|
1300
|
+
# else wrap it in an agent response and return it so
|
1301
|
+
# orchestrator can find a respondent
|
1302
|
+
result = self.create_agent_response(tool_messages=[maybe_result])
|
1303
|
+
else:
|
1304
|
+
result = maybe_result
|
1046
1305
|
except Exception as e:
|
1047
1306
|
# raise the error here since we are sure it's
|
1048
1307
|
# not a pydantic validation error,
|