langroid 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langroid/agent/base.py CHANGED
@@ -5,6 +5,7 @@ import json
5
5
  import logging
6
6
  import re
7
7
  from abc import ABC
8
+ from collections import OrderedDict
8
9
  from contextlib import ExitStack
9
10
  from types import SimpleNamespace
10
11
  from typing import (
@@ -31,11 +32,13 @@ from langroid.agent.tool_message import ToolMessage
31
32
  from langroid.language_models.base import (
32
33
  LanguageModel,
33
34
  LLMConfig,
35
+ LLMFunctionCall,
34
36
  LLMMessage,
35
37
  LLMResponse,
36
38
  LLMTokenUsage,
37
39
  OpenAIToolCall,
38
40
  StreamingIfAllowed,
41
+ ToolChoiceTypes,
39
42
  )
40
43
  from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
41
44
  from langroid.mytypes import Entity
@@ -44,11 +47,12 @@ from langroid.parsing.parser import Parser, ParsingConfig
44
47
  from langroid.prompts.prompts_config import PromptsConfig
45
48
  from langroid.pydantic_v1 import BaseSettings, Field, ValidationError, validator
46
49
  from langroid.utils.configuration import settings
47
- from langroid.utils.constants import NO_ANSWER
50
+ from langroid.utils.constants import DONE, NO_ANSWER, PASS, PASS_TO, SEND_TO
48
51
  from langroid.utils.object_registry import ObjectRegistry
49
52
  from langroid.utils.output import status
50
53
  from langroid.vector_store.base import VectorStore, VectorStoreConfig
51
54
 
55
+ ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
52
56
  console = Console(quiet=settings.quiet)
53
57
 
54
58
  logger = logging.getLogger(__name__)
@@ -108,9 +112,8 @@ class Agent(ABC):
108
112
  self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
109
113
  self.llm_tools_handled: Set[str] = set()
110
114
  self.llm_tools_usable: Set[str] = set()
115
+ self.llm_tools_known: Set[str] = set() # all known tools, handled/used or not
111
116
  self.interactive: bool | None = None
112
- self.total_llm_token_cost = 0.0
113
- self.total_llm_token_usage = 0
114
117
  self.token_stats_str = ""
115
118
  self.default_human_response: Optional[str] = None
116
119
  self._indent = ""
@@ -140,6 +143,13 @@ class Agent(ABC):
140
143
  show_error_message=noop_fn,
141
144
  show_start_response=noop_fn,
142
145
  )
146
+ Agent.init_state(self)
147
+ self.init_state()
148
+
149
+ def init_state(self) -> None:
150
+ """Initialize all state vars. Called by Task.run() if restart is True"""
151
+ self.total_llm_token_cost = 0.0
152
+ self.total_llm_token_usage = 0
143
153
 
144
154
  @staticmethod
145
155
  def from_id(id: str) -> "Agent":
@@ -239,7 +249,7 @@ class Agent(ABC):
239
249
  ):
240
250
  """
241
251
  If the message class has a `handle` method,
242
- and does NOT have a method with the same name as the tool,
252
+ and agent does NOT have a method with the same name as the tool,
243
253
  then we create a method for the agent whose name
244
254
  is the value of `tool`, and whose body is the `handle` method.
245
255
  This removes a separate step of having to define this method
@@ -247,13 +257,25 @@ class Agent(ABC):
247
257
  in one place, i.e. in the message class.
248
258
  See `tests/main/test_stateless_tool_messages.py` for an example.
249
259
  """
250
- setattr(self, tool, lambda obj: obj.handle())
260
+ has_chat_doc_arg = (
261
+ len(inspect.signature(message_class.handle).parameters) > 1
262
+ )
263
+ if has_chat_doc_arg:
264
+ setattr(self, tool, lambda obj, chat_doc: obj.handle(chat_doc))
265
+ else:
266
+ setattr(self, tool, lambda obj: obj.handle())
251
267
  elif (
252
268
  hasattr(message_class, "response")
253
269
  and inspect.isfunction(message_class.response)
254
270
  and not hasattr(self, tool)
255
271
  ):
256
- setattr(self, tool, lambda obj: obj.response(self))
272
+ has_chat_doc_arg = (
273
+ len(inspect.signature(message_class.response).parameters) > 2
274
+ )
275
+ if has_chat_doc_arg:
276
+ setattr(self, tool, lambda obj, chat_doc: obj.response(self, chat_doc))
277
+ else:
278
+ setattr(self, tool, lambda obj: obj.response(self))
257
279
 
258
280
  if hasattr(message_class, "handle_message_fallback") and (
259
281
  inspect.isfunction(message_class.handle_message_fallback)
@@ -311,9 +333,27 @@ class Agent(ABC):
311
333
  ]
312
334
  return "\n\n".join(sample_convo)
313
335
 
314
- def create_agent_response(self, content: str | None = None) -> ChatDocument:
336
+ def create_agent_response(
337
+ self,
338
+ content: str | None = None,
339
+ tool_messages: List[ToolMessage] = [],
340
+ oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
341
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
342
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
343
+ function_call: LLMFunctionCall | None = None,
344
+ recipient: str = "",
345
+ ) -> ChatDocument:
315
346
  """Template for agent_response."""
316
- return self._response_template(Entity.AGENT, content)
347
+ return self.response_template(
348
+ Entity.AGENT,
349
+ content=content,
350
+ tool_messages=tool_messages,
351
+ oai_tool_calls=oai_tool_calls,
352
+ oai_tool_choice=oai_tool_choice,
353
+ oai_tool_id2result=oai_tool_id2result,
354
+ function_call=function_call,
355
+ recipient=recipient,
356
+ )
317
357
 
318
358
  async def agent_response_async(
319
359
  self,
@@ -366,7 +406,7 @@ class Agent(ABC):
366
406
  # set sender_name to name of the function_call
367
407
  sender_name = msg.function_call.name
368
408
 
369
- results_str, id2result, oai_tool_id = self._process_tool_results(
409
+ results_str, id2result, oai_tool_id = self.process_tool_results(
370
410
  results if isinstance(results, str) else "",
371
411
  id2result=None if isinstance(results, str) else results,
372
412
  tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
@@ -384,21 +424,29 @@ class Agent(ABC):
384
424
  ),
385
425
  )
386
426
 
387
- def _process_tool_results(
427
+ def process_tool_results(
388
428
  self,
389
429
  results: str,
390
- id2result: Dict[str, str] | None,
430
+ id2result: OrderedDict[str, str] | None,
391
431
  tool_calls: List[OpenAIToolCall] | None = None,
392
432
  ) -> Tuple[str, Dict[str, str] | None, str | None]:
393
433
  """
394
434
  Process results from a response, based on whether
395
- they are results of OpenAI tool-calls from THIS agent.
435
+ they are results of OpenAI tool-calls from THIS agent, so that
436
+ we can construct an appropriate LLMMessage that contains tool results.
437
+
438
+ Args:
439
+ results (str): A possible string result from handling tool(s)
440
+ id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
441
+ if there are multiple tool results.
442
+ tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
443
+ results are a response to.
396
444
 
397
445
  Return:
398
446
  - str: The response string
399
447
  - Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
400
448
  multiple tool results.
401
- - str|None: tool_id if we there was a single tool result
449
+ - str|None: tool_id if there was a single tool result
402
450
 
403
451
  """
404
452
  id2result_ = copy.deepcopy(id2result) if id2result is not None else None
@@ -409,11 +457,11 @@ class Agent(ABC):
409
457
  # in this case ignore id2result
410
458
  assert (
411
459
  id2result is None
412
- ), "id2result should be None when results is not empty!"
460
+ ), "id2result should be None when results string is non-empty!"
413
461
  results_str = results
414
462
  if len(self.oai_tool_calls) > 0:
415
- # There may be multiple tool-calls, but we only have one result,
416
- # so we return it as a map of the first tool-call id to the result.
463
+ # We only have one result, so in case there is a
464
+ # "pending" OpenAI tool-call, we expect no more than 1 such.
417
465
  assert (
418
466
  len(self.oai_tool_calls) == 1
419
467
  ), "There are multiple pending tool-calls, but only one result!"
@@ -422,70 +470,115 @@ class Agent(ABC):
422
470
  # can properly set the `tool_call_id` field of the LLMMessage.
423
471
  oai_tool_id = self.oai_tool_calls[0].id
424
472
  elif id2result is not None and id2result_ is not None: # appease mypy
425
- assert (
426
- tool_calls is not None
427
- ), "tool_calls cannot be None when id2result is not None!"
428
- # This must be an OpenAI tool id -> result map;
429
- # However some ids may not correspond to the tool-calls in the list of
430
- # pending tool-calls (self.oai_tool_calls). Such results are concatenated
431
- # into a simple string, to store in the ChatDocument.content,
432
- # and the rest
433
- # (i.e. those that DO correspond to tools in self.oai_tool_calls)
434
- # are stored as a dict in ChatDocument.oa_tool_id2result.
435
-
436
- # OAI tools from THIS agent, awaiting response
437
- pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
438
- # tool_calls that the results are a response to
439
- # (but these may have been sent from another agent, hence may not be in
440
- # self.oai_tool_calls)
441
- parent_tool_id2name = {
442
- tc.id: tc.function.name
443
- for tc in tool_calls or []
444
- if tc.function is not None
445
- }
446
-
447
- # (id, result) for result NOT corresponding to self.oai_tool_calls,
448
- # i.e. these are results of EXTERNAL tool-calls from another agent.
449
- external_tool_id_results = []
450
-
451
- for tc_id, result in id2result.items():
452
- if tc_id not in pending_tool_ids:
453
- external_tool_id_results.append((tc_id, result))
454
- id2result_.pop(tc_id)
455
- if len(external_tool_id_results) == 0:
473
+ if len(id2result_) == len(self.oai_tool_calls):
474
+ # if the number of pending tool calls equals the number of results,
475
+ # then ignore the ids in id2result, and use the results in order,
476
+ # which is preserved since id2result is an OrderedDict.
477
+ assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
456
478
  results_str = ""
457
- elif len(external_tool_id_results) == 1:
458
- results_str = external_tool_id_results[0][1]
459
- else:
460
- results_str = "\n\n".join(
461
- [
462
- f"Result from tool/function {parent_tool_id2name[id]}: {result}"
463
- for id, result in external_tool_id_results
464
- ]
479
+ id2result_ = OrderedDict(
480
+ zip(
481
+ [tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
482
+ )
465
483
  )
484
+ else:
485
+ assert (
486
+ tool_calls is not None
487
+ ), "tool_calls cannot be None when id2result is not None!"
488
+ # This must be an OpenAI tool id -> result map;
489
+ # However some ids may not correspond to the tool-calls in the list of
490
+ # pending tool-calls (self.oai_tool_calls).
491
+ # Such results are concatenated into a simple string, to store in the
492
+ # ChatDocument.content, and the rest
493
+ # (i.e. those that DO correspond to tools in self.oai_tool_calls)
494
+ # are stored as a dict in ChatDocument.oai_tool_id2result.
495
+
496
+ # OAI tools from THIS agent, awaiting response
497
+ pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
498
+ # tool_calls that the results are a response to
499
+ # (but these may have been sent from another agent, hence may not be in
500
+ # self.oai_tool_calls)
501
+ parent_tool_id2name = {
502
+ tc.id: tc.function.name
503
+ for tc in tool_calls or []
504
+ if tc.function is not None
505
+ }
506
+
507
+ # (id, result) for result NOT corresponding to self.oai_tool_calls,
508
+ # i.e. these are results of EXTERNAL tool-calls from another agent.
509
+ external_tool_id_results = []
510
+
511
+ for tc_id, result in id2result.items():
512
+ if tc_id not in pending_tool_ids:
513
+ external_tool_id_results.append((tc_id, result))
514
+ id2result_.pop(tc_id)
515
+ if len(external_tool_id_results) == 0:
516
+ results_str = ""
517
+ elif len(external_tool_id_results) == 1:
518
+ results_str = external_tool_id_results[0][1]
519
+ else:
520
+ results_str = "\n\n".join(
521
+ [
522
+ f"Result from tool/function "
523
+ f"{parent_tool_id2name[id]}: {result}"
524
+ for id, result in external_tool_id_results
525
+ ]
526
+ )
466
527
 
467
- if len(id2result_) == 0:
468
- id2result_ = None
469
- elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
470
- results_str = list(id2result_.values())[0]
471
- oai_tool_id = list(id2result_.keys())[0]
472
- id2result_ = None
528
+ if len(id2result_) == 0:
529
+ id2result_ = None
530
+ elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
531
+ results_str = list(id2result_.values())[0]
532
+ oai_tool_id = list(id2result_.keys())[0]
533
+ id2result_ = None
473
534
 
474
535
  return results_str, id2result_, oai_tool_id
475
536
 
476
- def _response_template(self, e: Entity, content: str | None = None) -> ChatDocument:
537
+ def response_template(
538
+ self,
539
+ e: Entity,
540
+ content: str | None = None,
541
+ tool_messages: List[ToolMessage] = [],
542
+ oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
543
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
544
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
545
+ function_call: LLMFunctionCall | None = None,
546
+ recipient: str = "",
547
+ ) -> ChatDocument:
477
548
  """Template for response from entity `e`."""
478
549
  return ChatDocument(
479
550
  content=content or "",
480
- tool_messages=[],
551
+ tool_messages=tool_messages,
552
+ oai_tool_calls=oai_tool_calls,
553
+ oai_tool_id2result=oai_tool_id2result,
554
+ function_call=function_call,
555
+ oai_tool_choice=oai_tool_choice,
481
556
  metadata=ChatDocMetaData(
482
- source=e, sender=e, sender_name=self.config.name, tool_ids=[]
557
+ source=e, sender=e, sender_name=self.config.name, recipient=recipient
483
558
  ),
484
559
  )
485
560
 
486
- def create_user_response(self, content: str | None = None) -> ChatDocument:
561
+ def create_user_response(
562
+ self,
563
+ content: str | None = None,
564
+ tool_messages: List[ToolMessage] = [],
565
+ oai_tool_calls: List[OpenAIToolCall] | None = None,
566
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
567
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
568
+ function_call: LLMFunctionCall | None = None,
569
+ recipient: str = "",
570
+ ) -> ChatDocument:
487
571
  """Template for user_response."""
488
- return self._response_template(Entity.USER, content)
572
+ return self.response_template(
573
+ e=Entity.USER,
574
+ content=content,
575
+ tool_messages=tool_messages,
576
+ oai_tool_calls=oai_tool_calls,
577
+ oai_tool_choice=oai_tool_choice,
578
+ oai_tool_id2result=oai_tool_id2result,
579
+ function_call=function_call,
580
+ recipient=recipient,
581
+ )
489
582
 
490
583
  async def user_response_async(
491
584
  self,
@@ -579,9 +672,27 @@ class Agent(ABC):
579
672
 
580
673
  return True
581
674
 
582
- def create_llm_response(self, content: str | None = None) -> ChatDocument:
675
+ def create_llm_response(
676
+ self,
677
+ content: str | None = None,
678
+ tool_messages: List[ToolMessage] = [],
679
+ oai_tool_calls: None | List[OpenAIToolCall] = None,
680
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
681
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
682
+ function_call: LLMFunctionCall | None = None,
683
+ recipient: str = "",
684
+ ) -> ChatDocument:
583
685
  """Template for llm_response."""
584
- return self._response_template(Entity.LLM, content)
686
+ return self.response_template(
687
+ Entity.LLM,
688
+ content=content,
689
+ tool_messages=tool_messages,
690
+ oai_tool_calls=oai_tool_calls,
691
+ oai_tool_choice=oai_tool_choice,
692
+ oai_tool_id2result=oai_tool_id2result,
693
+ function_call=function_call,
694
+ recipient=recipient,
695
+ )
585
696
 
586
697
  @no_type_check
587
698
  async def llm_response_async(
@@ -725,12 +836,60 @@ class Agent(ABC):
725
836
  return True
726
837
  return False
727
838
 
728
- def get_tool_messages(self, msg: str | ChatDocument) -> List[ToolMessage]:
839
+ def _tool_recipient_match(self, tool: ToolMessage) -> bool:
840
+ """Is tool is handled by this agent
841
+ and an explicit `recipient` field doesn't preclude this agent from handling it?
842
+ """
843
+ if tool.default_value("request") not in self.llm_tools_handled:
844
+ return False
845
+ if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
846
+ return tool.recipient == "" or tool.recipient == self.config.name
847
+ return True
848
+
849
+ def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
850
+ """
851
+ Does the msg have at least one tool, and ALL tools are
852
+ disabled for handling by this agent?
853
+ """
854
+ tools = self.get_tool_messages(msg, all_tools=True)
855
+ if len(tools) == 0:
856
+ return False
857
+ return all(not self._tool_recipient_match(t) for t in tools)
858
+
859
+ def get_tool_messages(
860
+ self,
861
+ msg: str | ChatDocument,
862
+ all_tools: bool = False,
863
+ ) -> List[ToolMessage]:
864
+ """
865
+ Get ToolMessages recognized in msg, handle-able by this agent.
866
+ If all_tools is True:
867
+ - return all tools, i.e. any tool in self.llm_tools_known,
868
+ whether it is handled by this agent or not;
869
+ - otherwise, return only the tools handled by this agent.
870
+ """
871
+
729
872
  if isinstance(msg, str):
730
- return self.get_json_tool_messages(msg)
873
+ json_tools = self.get_json_tool_messages(msg)
874
+ if all_tools:
875
+ return json_tools
876
+ else:
877
+ return [
878
+ t
879
+ for t in json_tools
880
+ if self._tool_recipient_match(t) and t.default_value("request")
881
+ ]
882
+
883
+ if all_tools and len(msg.all_tool_messages) > 0:
884
+ # We've already identified all_tool_messages in the msg;
885
+ # return the corresponding ToolMessage objects
886
+ return msg.all_tool_messages
731
887
  if len(msg.tool_messages) > 0:
732
- # We've already found tool_messages
733
- # (either via OpenAI Fn-call or Langroid-native ToolMessage)
888
+ # We've already found tool_messages,
889
+ # (either via OpenAI Fn-call or Langroid-native ToolMessage);
890
+ # or they were added by an agent_response.
891
+ # note these could be from a forwarded msg from another agent,
892
+ # so return ONLY the messages THIS agent to enabled to handle.
734
893
  return msg.tool_messages
735
894
  assert isinstance(msg, ChatDocument)
736
895
  if (
@@ -740,19 +899,32 @@ class Agent(ABC):
740
899
  ):
741
900
 
742
901
  tools = self.get_json_tool_messages(msg.content)
743
- msg.tool_messages = tools
744
- return tools
902
+ msg.all_tool_messages = tools
903
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
904
+ msg.tool_messages = my_tools
905
+
906
+ if all_tools:
907
+ return tools
908
+ else:
909
+ return my_tools
745
910
 
746
911
  # otherwise, we look for `tool_calls` (possibly multiple)
747
912
  tools = self.get_oai_tool_calls_classes(msg)
748
- msg.tool_messages = tools
913
+ msg.all_tool_messages = tools
914
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
915
+ msg.tool_messages = my_tools
749
916
 
750
917
  if len(tools) == 0:
751
918
  # otherwise, we look for a `function_call`
752
919
  fun_call_cls = self.get_function_call_class(msg)
753
920
  tools = [fun_call_cls] if fun_call_cls is not None else []
754
- msg.tool_messages = tools
755
- return tools
921
+ msg.all_tool_messages = tools
922
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
923
+ msg.tool_messages = my_tools
924
+ if all_tools:
925
+ return tools
926
+ else:
927
+ return my_tools
756
928
 
757
929
  def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
758
930
  """
@@ -853,7 +1025,7 @@ class Agent(ABC):
853
1025
 
854
1026
  def handle_message(
855
1027
  self, msg: str | ChatDocument
856
- ) -> None | str | Dict[str, str] | ChatDocument:
1028
+ ) -> None | str | OrderedDict[str, str] | ChatDocument:
857
1029
  """
858
1030
  Handle a "tool" message either a string containing one or more
859
1031
  valid "tool" JSON substrings, or a
@@ -889,16 +1061,63 @@ class Agent(ABC):
889
1061
  # for this agent.
890
1062
  return None
891
1063
  if len(tools) == 0:
892
- return self.handle_message_fallback(msg)
1064
+ fallback_result = self.handle_message_fallback(msg)
1065
+ if fallback_result is not None and isinstance(fallback_result, ToolMessage):
1066
+ return self.create_agent_response(tool_messages=[fallback_result])
1067
+ return fallback_result
893
1068
  has_ids = all([t.id != "" for t in tools])
894
- results = [self.handle_tool_message(t) for t in tools]
1069
+ chat_doc = msg if isinstance(msg, ChatDocument) else None
1070
+
1071
+ # check whether there are multiple orchestration-tools (e.g. DoneTool etc),
1072
+ # in which case set result to error-string since we don't yet support
1073
+ # multi-tools with one or more orch tools.
1074
+ from langroid.agent.tools.orchestration import (
1075
+ AgentDoneTool,
1076
+ AgentSendTool,
1077
+ DonePassTool,
1078
+ DoneTool,
1079
+ ForwardTool,
1080
+ PassTool,
1081
+ SendTool,
1082
+ )
1083
+ from langroid.agent.tools.recipient_tool import RecipientTool
1084
+
1085
+ ORCHESTRATION_TOOLS = (
1086
+ AgentDoneTool,
1087
+ DoneTool,
1088
+ PassTool,
1089
+ DonePassTool,
1090
+ ForwardTool,
1091
+ RecipientTool,
1092
+ SendTool,
1093
+ AgentSendTool,
1094
+ )
1095
+
1096
+ has_orch = any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools)
1097
+ results: List[str | ChatDocument | None]
1098
+ if has_orch and len(tools) > 1:
1099
+ err_str = "ERROR: Use ONE tool at a time!"
1100
+ results = [err_str for _ in tools]
1101
+ else:
1102
+ results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
1103
+
895
1104
  tool_names = [t.default_value("request") for t in tools]
896
1105
  if has_ids:
897
- id2result = {
898
- t.id: r
1106
+ id2result = OrderedDict(
1107
+ (t.id, r)
899
1108
  for t, r in zip(tools, results)
900
1109
  if r is not None and isinstance(r, str)
901
- }
1110
+ )
1111
+ result_values = list(id2result.values())
1112
+ if len(id2result) > 1 and any(
1113
+ orch_str in r
1114
+ for r in result_values
1115
+ for orch_str in ORCHESTRATION_STRINGS
1116
+ ):
1117
+ # Cannot support multi-tool results containing orchestration strings!
1118
+ # Replace results with err string to force LLM to retry
1119
+ err_str = "ERROR: Please use ONE tool at a time!"
1120
+ id2result = OrderedDict((id, err_str) for id in id2result.keys())
902
1121
 
903
1122
  name_results_list = [
904
1123
  (name, r) for name, r in zip(tool_names, results) if r is not None
@@ -940,19 +1159,28 @@ class Agent(ABC):
940
1159
  self, msg: str | ChatDocument
941
1160
  ) -> str | ChatDocument | None:
942
1161
  """
943
- Fallback method to handle possible "tool" msg if no other method applies
944
- or if an error is thrown.
945
- This method can be overridden by subclasses.
1162
+ Fallback method for the "no-tools" scenario.
1163
+ This method can be overridden by subclasses, e.g.,
1164
+ to create a "reminder" message when a tool is expected but the LLM "forgot"
1165
+ to generate one.
946
1166
 
947
1167
  Args:
948
1168
  msg (str | ChatDocument): The input msg to handle
949
1169
  Returns:
950
1170
  str: The result of the handler method in string form so it can
951
- be sent back to the LLM.
1171
+ be handled by LLM
952
1172
  """
953
1173
  return None
954
1174
 
955
1175
  def _get_one_tool_message(self, json_str: str) -> Optional[ToolMessage]:
1176
+ """
1177
+ Parse the json str into ANY ToolMessage KNOWN to agent --
1178
+ This includes non-used/handled tools, i.e. any tool in self.llm_tools_known.
1179
+ The exception to this is below where we try our best to infer the tool
1180
+ when the LLM has "forgotten" to include the "request" field in the JSON --
1181
+ in this case we ONLY look at the possible set of HANDLED tools, i.e.
1182
+ self.llm_tools_handled.
1183
+ """
956
1184
  json_data = json.loads(json_str)
957
1185
  # check if the json_data contains a "properties" field
958
1186
  # which further contains the actual tool-call
@@ -976,9 +1204,8 @@ class Agent(ABC):
976
1204
  if isinstance(properties, dict):
977
1205
  json_data = properties
978
1206
  request = json_data.get("request")
979
-
980
1207
  if request is None:
981
- handled = [self.llm_tools_map[r] for r in self.llm_tools_handled]
1208
+ possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
982
1209
  default_keys = set(ToolMessage.__fields__.keys())
983
1210
  request_keys = set(json_data.keys())
984
1211
 
@@ -1002,7 +1229,7 @@ class Agent(ABC):
1002
1229
  candidate_tools = list(
1003
1230
  filter(
1004
1231
  lambda t: t is not None,
1005
- map(maybe_parse, handled),
1232
+ map(maybe_parse, possible),
1006
1233
  )
1007
1234
  )
1008
1235
 
@@ -1013,7 +1240,7 @@ class Agent(ABC):
1013
1240
  else:
1014
1241
  return None
1015
1242
 
1016
- if not isinstance(request, str) or request not in self.llm_tools_handled:
1243
+ if not isinstance(request, str) or request not in self.llm_tools_known:
1017
1244
  return None
1018
1245
 
1019
1246
  message_class = self.llm_tools_map.get(request)
@@ -1027,11 +1254,18 @@ class Agent(ABC):
1027
1254
  raise ve
1028
1255
  return message
1029
1256
 
1030
- def handle_tool_message(self, tool: ToolMessage) -> None | str | ChatDocument:
1257
+ def handle_tool_message(
1258
+ self,
1259
+ tool: ToolMessage,
1260
+ chat_doc: Optional[ChatDocument] = None,
1261
+ ) -> None | str | ChatDocument:
1031
1262
  """
1032
1263
  Respond to a tool request from the LLM, in the form of an ToolMessage object.
1033
1264
  Args:
1034
1265
  tool: ToolMessage object representing the tool request.
1266
+ chat_doc: Optional ChatDocument object containing the tool request.
1267
+ This is passed to the tool-handler method only if it has a `chat_doc`
1268
+ argument.
1035
1269
 
1036
1270
  Returns:
1037
1271
 
@@ -1040,9 +1274,34 @@ class Agent(ABC):
1040
1274
  handler_method = getattr(self, tool_name, None)
1041
1275
  if handler_method is None:
1042
1276
  return None
1043
-
1277
+ has_chat_doc_arg = (
1278
+ chat_doc is not None
1279
+ and "chat_doc" in inspect.signature(handler_method).parameters
1280
+ )
1044
1281
  try:
1045
- result = handler_method(tool)
1282
+ if has_chat_doc_arg:
1283
+ maybe_result = handler_method(tool, chat_doc=chat_doc)
1284
+ else:
1285
+ maybe_result = handler_method(tool)
1286
+ if isinstance(maybe_result, ToolMessage):
1287
+ # result is a ToolMessage, so...
1288
+ result_tool_name = maybe_result.default_value("request")
1289
+ if (
1290
+ result_tool_name in self.llm_tools_handled
1291
+ and tool_name != result_tool_name
1292
+ ):
1293
+ # TODO: do we need to remove the tool message from the chat_doc?
1294
+ # if (chat_doc is not None and
1295
+ # maybe_result in chat_doc.tool_messages):
1296
+ # chat_doc.tool_messages.remove(maybe_result)
1297
+ # if we can handle it, do so
1298
+ result = self.handle_tool_message(maybe_result, chat_doc=chat_doc)
1299
+ else:
1300
+ # else wrap it in an agent response and return it so
1301
+ # orchestrator can find a respondent
1302
+ result = self.create_agent_response(tool_messages=[maybe_result])
1303
+ else:
1304
+ result = maybe_result
1046
1305
  except Exception as e:
1047
1306
  # raise the error here since we are sure it's
1048
1307
  # not a pydantic validation error,