langroid 0.6.7__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. langroid/agent/base.py +499 -55
  2. langroid/agent/callbacks/chainlit.py +1 -1
  3. langroid/agent/chat_agent.py +191 -37
  4. langroid/agent/chat_document.py +142 -29
  5. langroid/agent/openai_assistant.py +20 -4
  6. langroid/agent/special/lance_doc_chat_agent.py +25 -18
  7. langroid/agent/special/lance_rag/critic_agent.py +37 -5
  8. langroid/agent/special/lance_rag/query_planner_agent.py +102 -63
  9. langroid/agent/special/lance_tools.py +10 -2
  10. langroid/agent/special/sql/sql_chat_agent.py +69 -13
  11. langroid/agent/task.py +179 -43
  12. langroid/agent/tool_message.py +19 -7
  13. langroid/agent/tools/__init__.py +5 -0
  14. langroid/agent/tools/orchestration.py +216 -0
  15. langroid/agent/tools/recipient_tool.py +6 -11
  16. langroid/agent/tools/rewind_tool.py +1 -1
  17. langroid/agent/typed_task.py +19 -0
  18. langroid/language_models/.chainlit/config.toml +121 -0
  19. langroid/language_models/.chainlit/translations/en-US.json +231 -0
  20. langroid/language_models/base.py +114 -12
  21. langroid/language_models/mock_lm.py +10 -1
  22. langroid/language_models/openai_gpt.py +260 -36
  23. langroid/mytypes.py +0 -1
  24. langroid/parsing/parse_json.py +19 -2
  25. langroid/utils/pydantic_utils.py +19 -0
  26. langroid/vector_store/base.py +3 -1
  27. langroid/vector_store/lancedb.py +2 -0
  28. {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/METADATA +4 -1
  29. {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/RECORD +32 -33
  30. pyproject.toml +2 -1
  31. langroid/agent/special/lance_rag_new/__init__.py +0 -9
  32. langroid/agent/special/lance_rag_new/critic_agent.py +0 -171
  33. langroid/agent/special/lance_rag_new/lance_rag_task.py +0 -144
  34. langroid/agent/special/lance_rag_new/query_planner_agent.py +0 -222
  35. langroid/agent/team.py +0 -1758
  36. {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/LICENSE +0 -0
  37. {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/WHEEL +0 -0
langroid/agent/base.py CHANGED
@@ -1,9 +1,11 @@
1
1
  import asyncio
2
+ import copy
2
3
  import inspect
3
4
  import json
4
5
  import logging
5
6
  import re
6
7
  from abc import ABC
8
+ from collections import OrderedDict
7
9
  from contextlib import ExitStack
8
10
  from types import SimpleNamespace
9
11
  from typing import (
@@ -30,10 +32,13 @@ from langroid.agent.tool_message import ToolMessage
30
32
  from langroid.language_models.base import (
31
33
  LanguageModel,
32
34
  LLMConfig,
35
+ LLMFunctionCall,
33
36
  LLMMessage,
34
37
  LLMResponse,
35
38
  LLMTokenUsage,
39
+ OpenAIToolCall,
36
40
  StreamingIfAllowed,
41
+ ToolChoiceTypes,
37
42
  )
38
43
  from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
39
44
  from langroid.mytypes import Entity
@@ -42,11 +47,12 @@ from langroid.parsing.parser import Parser, ParsingConfig
42
47
  from langroid.prompts.prompts_config import PromptsConfig
43
48
  from langroid.pydantic_v1 import BaseSettings, Field, ValidationError, validator
44
49
  from langroid.utils.configuration import settings
45
- from langroid.utils.constants import NO_ANSWER
50
+ from langroid.utils.constants import DONE, NO_ANSWER, PASS, PASS_TO, SEND_TO
46
51
  from langroid.utils.object_registry import ObjectRegistry
47
52
  from langroid.utils.output import status
48
53
  from langroid.vector_store.base import VectorStore, VectorStoreConfig
49
54
 
55
+ ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
50
56
  console = Console(quiet=settings.quiet)
51
57
 
52
58
  logger = logging.getLogger(__name__)
@@ -93,6 +99,11 @@ class Agent(ABC):
93
99
  """
94
100
 
95
101
  id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
102
+ # OpenAI tool-calls awaiting response; update when a tool result with Role.TOOL
103
+ # is added to self.message_history
104
+ oai_tool_calls: List[OpenAIToolCall] = []
105
+ # Index of ALL tool calls generated by the agent
106
+ oai_tool_id2call: Dict[str, OpenAIToolCall] = {}
96
107
 
97
108
  def __init__(self, config: AgentConfig = AgentConfig()):
98
109
  self.config = config
@@ -101,9 +112,8 @@ class Agent(ABC):
101
112
  self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
102
113
  self.llm_tools_handled: Set[str] = set()
103
114
  self.llm_tools_usable: Set[str] = set()
115
+ self.llm_tools_known: Set[str] = set() # all known tools, handled/used or not
104
116
  self.interactive: bool | None = None
105
- self.total_llm_token_cost = 0.0
106
- self.total_llm_token_usage = 0
107
117
  self.token_stats_str = ""
108
118
  self.default_human_response: Optional[str] = None
109
119
  self._indent = ""
@@ -133,6 +143,13 @@ class Agent(ABC):
133
143
  show_error_message=noop_fn,
134
144
  show_start_response=noop_fn,
135
145
  )
146
+ Agent.init_state(self)
147
+ self.init_state()
148
+
149
+ def init_state(self) -> None:
150
+ """Initialize all state vars. Called by Task.run() if restart is True"""
151
+ self.total_llm_token_cost = 0.0
152
+ self.total_llm_token_usage = 0
136
153
 
137
154
  @staticmethod
138
155
  def from_id(id: str) -> "Agent":
@@ -232,7 +249,7 @@ class Agent(ABC):
232
249
  ):
233
250
  """
234
251
  If the message class has a `handle` method,
235
- and does NOT have a method with the same name as the tool,
252
+ and agent does NOT have a method with the same name as the tool,
236
253
  then we create a method for the agent whose name
237
254
  is the value of `tool`, and whose body is the `handle` method.
238
255
  This removes a separate step of having to define this method
@@ -240,13 +257,25 @@ class Agent(ABC):
240
257
  in one place, i.e. in the message class.
241
258
  See `tests/main/test_stateless_tool_messages.py` for an example.
242
259
  """
243
- setattr(self, tool, lambda obj: obj.handle())
260
+ has_chat_doc_arg = (
261
+ len(inspect.signature(message_class.handle).parameters) > 1
262
+ )
263
+ if has_chat_doc_arg:
264
+ setattr(self, tool, lambda obj, chat_doc: obj.handle(chat_doc))
265
+ else:
266
+ setattr(self, tool, lambda obj: obj.handle())
244
267
  elif (
245
268
  hasattr(message_class, "response")
246
269
  and inspect.isfunction(message_class.response)
247
270
  and not hasattr(self, tool)
248
271
  ):
249
- setattr(self, tool, lambda obj: obj.response(self))
272
+ has_chat_doc_arg = (
273
+ len(inspect.signature(message_class.response).parameters) > 2
274
+ )
275
+ if has_chat_doc_arg:
276
+ setattr(self, tool, lambda obj, chat_doc: obj.response(self, chat_doc))
277
+ else:
278
+ setattr(self, tool, lambda obj: obj.response(self))
250
279
 
251
280
  if hasattr(message_class, "handle_message_fallback") and (
252
281
  inspect.isfunction(message_class.handle_message_fallback)
@@ -304,9 +333,27 @@ class Agent(ABC):
304
333
  ]
305
334
  return "\n\n".join(sample_convo)
306
335
 
307
- def create_agent_response(self, content: str | None = None) -> ChatDocument:
336
+ def create_agent_response(
337
+ self,
338
+ content: str | None = None,
339
+ tool_messages: List[ToolMessage] = [],
340
+ oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
341
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
342
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
343
+ function_call: LLMFunctionCall | None = None,
344
+ recipient: str = "",
345
+ ) -> ChatDocument:
308
346
  """Template for agent_response."""
309
- return self._response_template(Entity.AGENT, content)
347
+ return self.response_template(
348
+ Entity.AGENT,
349
+ content=content,
350
+ tool_messages=tool_messages,
351
+ oai_tool_calls=oai_tool_calls,
352
+ oai_tool_choice=oai_tool_choice,
353
+ oai_tool_id2result=oai_tool_id2result,
354
+ function_call=function_call,
355
+ recipient=recipient,
356
+ )
310
357
 
311
358
  async def agent_response_async(
312
359
  self,
@@ -343,43 +390,195 @@ class Agent(ABC):
343
390
  )
344
391
  return results
345
392
  if not settings.quiet:
393
+ results_str = (
394
+ results if isinstance(results, str) else json.dumps(results, indent=2)
395
+ )
346
396
  console.print(f"[red]{self.indent}", end="")
347
- print(f"[red]Agent: {escape(results)}")
348
- maybe_json = len(extract_top_level_json(results)) > 0
397
+ print(f"[red]Agent: {escape(results_str)}")
398
+ maybe_json = len(extract_top_level_json(results_str)) > 0
349
399
  self.callbacks.show_agent_response(
350
- content=results,
400
+ content=results_str,
351
401
  language="json" if maybe_json else "text",
352
402
  )
353
403
  sender_name = self.config.name
354
404
  if isinstance(msg, ChatDocument) and msg.function_call is not None:
355
405
  # if result was from handling an LLM `function_call`,
356
- # set sender_name to "request", i.e. name of the function_call
406
+ # set sender_name to name of the function_call
357
407
  sender_name = msg.function_call.name
358
408
 
409
+ results_str, id2result, oai_tool_id = self.process_tool_results(
410
+ results if isinstance(results, str) else "",
411
+ id2result=None if isinstance(results, str) else results,
412
+ tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
413
+ )
359
414
  return ChatDocument(
360
- content=results,
415
+ content=results_str,
416
+ oai_tool_id2result=id2result,
361
417
  metadata=ChatDocMetaData(
362
418
  source=Entity.AGENT,
363
419
  sender=Entity.AGENT,
364
420
  sender_name=sender_name,
421
+ oai_tool_id=oai_tool_id,
365
422
  # preserve trail of tool_ids for OpenAI Assistant fn-calls
366
423
  tool_ids=[] if isinstance(msg, str) else msg.metadata.tool_ids,
367
424
  ),
368
425
  )
369
426
 
370
- def _response_template(self, e: Entity, content: str | None = None) -> ChatDocument:
427
+ def process_tool_results(
428
+ self,
429
+ results: str,
430
+ id2result: OrderedDict[str, str] | None,
431
+ tool_calls: List[OpenAIToolCall] | None = None,
432
+ ) -> Tuple[str, Dict[str, str] | None, str | None]:
433
+ """
434
+ Process results from a response, based on whether
435
+ they are results of OpenAI tool-calls from THIS agent, so that
436
+ we can construct an appropriate LLMMessage that contains tool results.
437
+
438
+ Args:
439
+ results (str): A possible string result from handling tool(s)
440
+ id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
441
+ if there are multiple tool results.
442
+ tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
443
+ results are a response to.
444
+
445
+ Return:
446
+ - str: The response string
447
+ - Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
448
+ multiple tool results.
449
+ - str|None: tool_id if there was a single tool result
450
+
451
+ """
452
+ id2result_ = copy.deepcopy(id2result) if id2result is not None else None
453
+ results_str = ""
454
+ oai_tool_id = None
455
+
456
+ if results != "":
457
+ # in this case ignore id2result
458
+ assert (
459
+ id2result is None
460
+ ), "id2result should be None when results string is non-empty!"
461
+ results_str = results
462
+ if len(self.oai_tool_calls) > 0:
463
+ # We only have one result, so in case there is a
464
+ # "pending" OpenAI tool-call, we expect no more than 1 such.
465
+ assert (
466
+ len(self.oai_tool_calls) == 1
467
+ ), "There are multiple pending tool-calls, but only one result!"
468
+ # We record the tool_id of the tool-call that
469
+ # the result is a response to, so that ChatDocument.to_LLMMessage
470
+ # can properly set the `tool_call_id` field of the LLMMessage.
471
+ oai_tool_id = self.oai_tool_calls[0].id
472
+ elif id2result is not None and id2result_ is not None: # appease mypy
473
+ if len(id2result_) == len(self.oai_tool_calls):
474
+ # if the number of pending tool calls equals the number of results,
475
+ # then ignore the ids in id2result, and use the results in order,
476
+ # which is preserved since id2result is an OrderedDict.
477
+ assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
478
+ results_str = ""
479
+ id2result_ = OrderedDict(
480
+ zip(
481
+ [tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
482
+ )
483
+ )
484
+ else:
485
+ assert (
486
+ tool_calls is not None
487
+ ), "tool_calls cannot be None when id2result is not None!"
488
+ # This must be an OpenAI tool id -> result map;
489
+ # However some ids may not correspond to the tool-calls in the list of
490
+ # pending tool-calls (self.oai_tool_calls).
491
+ # Such results are concatenated into a simple string, to store in the
492
+ # ChatDocument.content, and the rest
493
+ # (i.e. those that DO correspond to tools in self.oai_tool_calls)
494
+ # are stored as a dict in ChatDocument.oai_tool_id2result.
495
+
496
+ # OAI tools from THIS agent, awaiting response
497
+ pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
498
+ # tool_calls that the results are a response to
499
+ # (but these may have been sent from another agent, hence may not be in
500
+ # self.oai_tool_calls)
501
+ parent_tool_id2name = {
502
+ tc.id: tc.function.name
503
+ for tc in tool_calls or []
504
+ if tc.function is not None
505
+ }
506
+
507
+ # (id, result) for result NOT corresponding to self.oai_tool_calls,
508
+ # i.e. these are results of EXTERNAL tool-calls from another agent.
509
+ external_tool_id_results = []
510
+
511
+ for tc_id, result in id2result.items():
512
+ if tc_id not in pending_tool_ids:
513
+ external_tool_id_results.append((tc_id, result))
514
+ id2result_.pop(tc_id)
515
+ if len(external_tool_id_results) == 0:
516
+ results_str = ""
517
+ elif len(external_tool_id_results) == 1:
518
+ results_str = external_tool_id_results[0][1]
519
+ else:
520
+ results_str = "\n\n".join(
521
+ [
522
+ f"Result from tool/function "
523
+ f"{parent_tool_id2name[id]}: {result}"
524
+ for id, result in external_tool_id_results
525
+ ]
526
+ )
527
+
528
+ if len(id2result_) == 0:
529
+ id2result_ = None
530
+ elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
531
+ results_str = list(id2result_.values())[0]
532
+ oai_tool_id = list(id2result_.keys())[0]
533
+ id2result_ = None
534
+
535
+ return results_str, id2result_, oai_tool_id
536
+
537
+ def response_template(
538
+ self,
539
+ e: Entity,
540
+ content: str | None = None,
541
+ tool_messages: List[ToolMessage] = [],
542
+ oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
543
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
544
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
545
+ function_call: LLMFunctionCall | None = None,
546
+ recipient: str = "",
547
+ ) -> ChatDocument:
371
548
  """Template for response from entity `e`."""
372
549
  return ChatDocument(
373
550
  content=content or "",
374
- tool_messages=[],
551
+ tool_messages=tool_messages,
552
+ oai_tool_calls=oai_tool_calls,
553
+ oai_tool_id2result=oai_tool_id2result,
554
+ function_call=function_call,
555
+ oai_tool_choice=oai_tool_choice,
375
556
  metadata=ChatDocMetaData(
376
- source=e, sender=e, sender_name=self.config.name, tool_ids=[]
557
+ source=e, sender=e, sender_name=self.config.name, recipient=recipient
377
558
  ),
378
559
  )
379
560
 
380
- def create_user_response(self, content: str | None = None) -> ChatDocument:
561
+ def create_user_response(
562
+ self,
563
+ content: str | None = None,
564
+ tool_messages: List[ToolMessage] = [],
565
+ oai_tool_calls: List[OpenAIToolCall] | None = None,
566
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
567
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
568
+ function_call: LLMFunctionCall | None = None,
569
+ recipient: str = "",
570
+ ) -> ChatDocument:
381
571
  """Template for user_response."""
382
- return self._response_template(Entity.USER, content)
572
+ return self.response_template(
573
+ e=Entity.USER,
574
+ content=content,
575
+ tool_messages=tool_messages,
576
+ oai_tool_calls=oai_tool_calls,
577
+ oai_tool_choice=oai_tool_choice,
578
+ oai_tool_id2result=oai_tool_id2result,
579
+ function_call=function_call,
580
+ recipient=recipient,
581
+ )
383
582
 
384
583
  async def user_response_async(
385
584
  self,
@@ -473,9 +672,27 @@ class Agent(ABC):
473
672
 
474
673
  return True
475
674
 
476
- def create_llm_response(self, content: str | None = None) -> ChatDocument:
675
+ def create_llm_response(
676
+ self,
677
+ content: str | None = None,
678
+ tool_messages: List[ToolMessage] = [],
679
+ oai_tool_calls: None | List[OpenAIToolCall] = None,
680
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
681
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
682
+ function_call: LLMFunctionCall | None = None,
683
+ recipient: str = "",
684
+ ) -> ChatDocument:
477
685
  """Template for llm_response."""
478
- return self._response_template(Entity.LLM, content)
686
+ return self.response_template(
687
+ Entity.LLM,
688
+ content=content,
689
+ tool_messages=tool_messages,
690
+ oai_tool_calls=oai_tool_calls,
691
+ oai_tool_choice=oai_tool_choice,
692
+ oai_tool_id2result=oai_tool_id2result,
693
+ function_call=function_call,
694
+ recipient=recipient,
695
+ )
479
696
 
480
697
  @no_type_check
481
698
  async def llm_response_async(
@@ -619,25 +836,95 @@ class Agent(ABC):
619
836
  return True
620
837
  return False
621
838
 
622
- def get_tool_messages(self, msg: str | ChatDocument) -> List[ToolMessage]:
839
+ def _tool_recipient_match(self, tool: ToolMessage) -> bool:
840
+ """Is tool is handled by this agent
841
+ and an explicit `recipient` field doesn't preclude this agent from handling it?
842
+ """
843
+ if tool.default_value("request") not in self.llm_tools_handled:
844
+ return False
845
+ if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
846
+ return tool.recipient == "" or tool.recipient == self.config.name
847
+ return True
848
+
849
+ def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
850
+ """
851
+ Does the msg have at least one tool, and ALL tools are
852
+ disabled for handling by this agent?
853
+ """
854
+ tools = self.get_tool_messages(msg, all_tools=True)
855
+ if len(tools) == 0:
856
+ return False
857
+ return all(not self._tool_recipient_match(t) for t in tools)
858
+
859
+ def get_tool_messages(
860
+ self,
861
+ msg: str | ChatDocument,
862
+ all_tools: bool = False,
863
+ ) -> List[ToolMessage]:
864
+ """
865
+ Get ToolMessages recognized in msg, handle-able by this agent.
866
+ If all_tools is True:
867
+ - return all tools, i.e. any tool in self.llm_tools_known,
868
+ whether it is handled by this agent or not;
869
+ - otherwise, return only the tools handled by this agent.
870
+ """
871
+
623
872
  if isinstance(msg, str):
624
- return self.get_json_tool_messages(msg)
873
+ json_tools = self.get_json_tool_messages(msg)
874
+ if all_tools:
875
+ return json_tools
876
+ else:
877
+ return [
878
+ t
879
+ for t in json_tools
880
+ if self._tool_recipient_match(t) and t.default_value("request")
881
+ ]
882
+
883
+ if all_tools and len(msg.all_tool_messages) > 0:
884
+ # We've already identified all_tool_messages in the msg;
885
+ # return the corresponding ToolMessage objects
886
+ return msg.all_tool_messages
625
887
  if len(msg.tool_messages) > 0:
626
- # We've already found tool_messages
627
- # (either via OpenAI Fn-call or Langroid-native ToolMessage)
888
+ # We've already found tool_messages,
889
+ # (either via OpenAI Fn-call or Langroid-native ToolMessage);
890
+ # or they were added by an agent_response.
891
+ # note these could be from a forwarded msg from another agent,
892
+ # so return ONLY the messages THIS agent to enabled to handle.
628
893
  return msg.tool_messages
629
894
  assert isinstance(msg, ChatDocument)
630
- # when `content` is non-empty, we assume there will be no `function_call`
631
- if msg.content != "":
895
+ if (
896
+ msg.content != ""
897
+ and msg.oai_tool_calls is None
898
+ and msg.function_call is None
899
+ ):
900
+
632
901
  tools = self.get_json_tool_messages(msg.content)
633
- msg.tool_messages = tools
634
- return tools
902
+ msg.all_tool_messages = tools
903
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
904
+ msg.tool_messages = my_tools
635
905
 
636
- # otherwise, we look for a `function_call`
637
- fun_call_cls = self.get_function_call_class(msg)
638
- tools = [fun_call_cls] if fun_call_cls is not None else []
639
- msg.tool_messages = tools
640
- return tools
906
+ if all_tools:
907
+ return tools
908
+ else:
909
+ return my_tools
910
+
911
+ # otherwise, we look for `tool_calls` (possibly multiple)
912
+ tools = self.get_oai_tool_calls_classes(msg)
913
+ msg.all_tool_messages = tools
914
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
915
+ msg.tool_messages = my_tools
916
+
917
+ if len(tools) == 0:
918
+ # otherwise, we look for a `function_call`
919
+ fun_call_cls = self.get_function_call_class(msg)
920
+ tools = [fun_call_cls] if fun_call_cls is not None else []
921
+ msg.all_tool_messages = tools
922
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
923
+ msg.tool_messages = my_tools
924
+ if all_tools:
925
+ return tools
926
+ else:
927
+ return my_tools
641
928
 
642
929
  def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
643
930
  """
@@ -656,6 +943,10 @@ class Agent(ABC):
656
943
  return [r for r in results if r is not None]
657
944
 
658
945
  def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
946
+ """
947
+ From ChatDocument (constructed from an LLM Response), get the `ToolMessage`
948
+ corresponding to the `function_call` if it exists.
949
+ """
659
950
  if msg.function_call is None:
660
951
  return None
661
952
  tool_name = msg.function_call.name
@@ -677,6 +968,39 @@ class Agent(ABC):
677
968
  tool = tool_class.parse_obj(tool_msg)
678
969
  return tool
679
970
 
971
+ def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
972
+ """
973
+ From ChatDocument (constructed from an LLM Response), get
974
+ a list of ToolMessages corresponding to the `tool_calls`, if any.
975
+ """
976
+
977
+ if msg.oai_tool_calls is None:
978
+ return []
979
+ tools = []
980
+ for tc in msg.oai_tool_calls:
981
+ if tc.function is None:
982
+ continue
983
+ tool_name = tc.function.name
984
+ tool_msg = tc.function.arguments or {}
985
+ if tool_name not in self.llm_tools_handled:
986
+ logger.warning(
987
+ f"""
988
+ The tool_call '{tool_name}' is not handled
989
+ by the agent named '{self.config.name}'!
990
+ If you intended this agent to handle this function_call,
991
+ either the fn-call name is incorrectly generated by the LLM,
992
+ (in which case you may need to adjust your LLM instructions),
993
+ or you need to enable this agent to handle this fn-call.
994
+ """
995
+ )
996
+ continue
997
+ tool_class = self.llm_tools_map[tool_name]
998
+ tool_msg.update(dict(request=tool_name))
999
+ tool = tool_class.parse_obj(tool_msg)
1000
+ tool.id = tc.id or ""
1001
+ tools.append(tool)
1002
+ return tools
1003
+
680
1004
  def tool_validation_error(self, ve: ValidationError) -> str:
681
1005
  """
682
1006
  Handle a validation error raised when parsing a tool message,
@@ -699,7 +1023,9 @@ class Agent(ABC):
699
1023
  Please write your message again, correcting the errors.
700
1024
  """
701
1025
 
702
- def handle_message(self, msg: str | ChatDocument) -> None | str | ChatDocument:
1026
+ def handle_message(
1027
+ self, msg: str | ChatDocument
1028
+ ) -> None | str | OrderedDict[str, str] | ChatDocument:
703
1029
  """
704
1030
  Handle a "tool" message either a string containing one or more
705
1031
  valid "tool" JSON substrings, or a
@@ -711,9 +1037,16 @@ class Agent(ABC):
711
1037
  msg (str | ChatDocument): The string or ChatDocument to handle
712
1038
 
713
1039
  Returns:
714
- Optional[Str]: The result of the handler method in string form so it can
715
- be sent back to the LLM, or None if `msg` was not successfully
716
- handled by a method.
1040
+ The result of the handler method can be:
1041
+ - None if no tools successfully handled, or no tools present
1042
+ - str if langroid-native JSON tools were handled, and results concatenated,
1043
+ OR there's a SINGLE OpenAI tool-call.
1044
+ (We do this so the common scenario of a single tool/fn-call
1045
+ has a simple behavior).
1046
+ - Dict[str, str] if multiple OpenAI tool-calls were handled
1047
+ (dict is an id->result map)
1048
+ - ChatDocument if a handler returned a ChatDocument, intended to be the
1049
+ final response of the `agent_response` method.
717
1050
  """
718
1051
  try:
719
1052
  tools = self.get_tool_messages(msg)
@@ -728,15 +1061,73 @@ class Agent(ABC):
728
1061
  # for this agent.
729
1062
  return None
730
1063
  if len(tools) == 0:
731
- return self.handle_message_fallback(msg)
1064
+ fallback_result = self.handle_message_fallback(msg)
1065
+ if fallback_result is not None and isinstance(fallback_result, ToolMessage):
1066
+ return self.create_agent_response(tool_messages=[fallback_result])
1067
+ return fallback_result
1068
+ has_ids = all([t.id != "" for t in tools])
1069
+ chat_doc = msg if isinstance(msg, ChatDocument) else None
1070
+
1071
+ # check whether there are multiple orchestration-tools (e.g. DoneTool etc),
1072
+ # in which case set result to error-string since we don't yet support
1073
+ # multi-tools with one or more orch tools.
1074
+ from langroid.agent.tools.orchestration import (
1075
+ AgentDoneTool,
1076
+ AgentSendTool,
1077
+ DonePassTool,
1078
+ DoneTool,
1079
+ ForwardTool,
1080
+ PassTool,
1081
+ SendTool,
1082
+ )
1083
+ from langroid.agent.tools.recipient_tool import RecipientTool
1084
+
1085
+ ORCHESTRATION_TOOLS = (
1086
+ AgentDoneTool,
1087
+ DoneTool,
1088
+ PassTool,
1089
+ DonePassTool,
1090
+ ForwardTool,
1091
+ RecipientTool,
1092
+ SendTool,
1093
+ AgentSendTool,
1094
+ )
732
1095
 
733
- results = [self.handle_tool_message(t) for t in tools]
1096
+ has_orch = any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools)
1097
+ results: List[str | ChatDocument | None]
1098
+ if has_orch and len(tools) > 1:
1099
+ err_str = "ERROR: Use ONE tool at a time!"
1100
+ results = [err_str for _ in tools]
1101
+ else:
1102
+ results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
1103
+
1104
+ tool_names = [t.default_value("request") for t in tools]
1105
+ if has_ids:
1106
+ id2result = OrderedDict(
1107
+ (t.id, r)
1108
+ for t, r in zip(tools, results)
1109
+ if r is not None and isinstance(r, str)
1110
+ )
1111
+ result_values = list(id2result.values())
1112
+ if len(id2result) > 1 and any(
1113
+ orch_str in r
1114
+ for r in result_values
1115
+ for orch_str in ORCHESTRATION_STRINGS
1116
+ ):
1117
+ # Cannot support multi-tool results containing orchestration strings!
1118
+ # Replace results with err string to force LLM to retry
1119
+ err_str = "ERROR: Please use ONE tool at a time!"
1120
+ id2result = OrderedDict((id, err_str) for id in id2result.keys())
734
1121
 
735
- results_list = [r for r in results if r is not None]
736
- if len(results_list) == 0:
1122
+ name_results_list = [
1123
+ (name, r) for name, r in zip(tool_names, results) if r is not None
1124
+ ]
1125
+ if len(name_results_list) == 0:
737
1126
  return None # self.handle_message_fallback(msg)
738
1127
  # there was a non-None result
739
- chat_doc_results = [r for r in results_list if isinstance(r, ChatDocument)]
1128
+ chat_doc_results = [
1129
+ r for _, r in name_results_list if isinstance(r, ChatDocument)
1130
+ ]
740
1131
  if len(chat_doc_results) > 1:
741
1132
  logger.warning(
742
1133
  """There were multiple ChatDocument results from tools,
@@ -747,27 +1138,49 @@ class Agent(ABC):
747
1138
  if len(chat_doc_results) > 0:
748
1139
  return chat_doc_results[0]
749
1140
 
750
- str_doc_results = [r for r in results_list if isinstance(r, str)]
751
- final = "\n".join(str_doc_results)
1141
+ if has_ids and len(id2result) > 1:
1142
+ # if there are multiple OpenAI Tool results, return them as a dict
1143
+ return id2result
1144
+
1145
+ if len(name_results_list) == 1 and isinstance(name_results_list[0][1], str):
1146
+ # single str result -- return it as is
1147
+ return name_results_list[0][1]
1148
+
1149
+ # multi-results: prepend the tool name to each result
1150
+ str_results = [
1151
+ f"Result from {name}: {r}"
1152
+ for name, r in name_results_list
1153
+ if isinstance(r, str)
1154
+ ]
1155
+ final = "\n\n".join(str_results)
752
1156
  return final
753
1157
 
754
1158
  def handle_message_fallback(
755
1159
  self, msg: str | ChatDocument
756
1160
  ) -> str | ChatDocument | None:
757
1161
  """
758
- Fallback method to handle possible "tool" msg if no other method applies
759
- or if an error is thrown.
760
- This method can be overridden by subclasses.
1162
+ Fallback method for the "no-tools" scenario.
1163
+ This method can be overridden by subclasses, e.g.,
1164
+ to create a "reminder" message when a tool is expected but the LLM "forgot"
1165
+ to generate one.
761
1166
 
762
1167
  Args:
763
1168
  msg (str | ChatDocument): The input msg to handle
764
1169
  Returns:
765
1170
  str: The result of the handler method in string form so it can
766
- be sent back to the LLM.
1171
+ be handled by LLM
767
1172
  """
768
1173
  return None
769
1174
 
770
1175
  def _get_one_tool_message(self, json_str: str) -> Optional[ToolMessage]:
1176
+ """
1177
+ Parse the json str into ANY ToolMessage KNOWN to agent --
1178
+ This includes non-used/handled tools, i.e. any tool in self.llm_tools_known.
1179
+ The exception to this is below where we try our best to infer the tool
1180
+ when the LLM has "forgotten" to include the "request" field in the JSON --
1181
+ in this case we ONLY look at the possible set of HANDLED tools, i.e.
1182
+ self.llm_tools_handled.
1183
+ """
771
1184
  json_data = json.loads(json_str)
772
1185
  # check if the json_data contains a "properties" field
773
1186
  # which further contains the actual tool-call
@@ -791,9 +1204,8 @@ class Agent(ABC):
791
1204
  if isinstance(properties, dict):
792
1205
  json_data = properties
793
1206
  request = json_data.get("request")
794
-
795
1207
  if request is None:
796
- handled = [self.llm_tools_map[r] for r in self.llm_tools_handled]
1208
+ possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
797
1209
  default_keys = set(ToolMessage.__fields__.keys())
798
1210
  request_keys = set(json_data.keys())
799
1211
 
@@ -817,7 +1229,7 @@ class Agent(ABC):
817
1229
  candidate_tools = list(
818
1230
  filter(
819
1231
  lambda t: t is not None,
820
- map(maybe_parse, handled),
1232
+ map(maybe_parse, possible),
821
1233
  )
822
1234
  )
823
1235
 
@@ -828,7 +1240,7 @@ class Agent(ABC):
828
1240
  else:
829
1241
  return None
830
1242
 
831
- if not isinstance(request, str) or request not in self.llm_tools_handled:
1243
+ if not isinstance(request, str) or request not in self.llm_tools_known:
832
1244
  return None
833
1245
 
834
1246
  message_class = self.llm_tools_map.get(request)
@@ -842,11 +1254,18 @@ class Agent(ABC):
842
1254
  raise ve
843
1255
  return message
844
1256
 
845
- def handle_tool_message(self, tool: ToolMessage) -> None | str | ChatDocument:
1257
+ def handle_tool_message(
1258
+ self,
1259
+ tool: ToolMessage,
1260
+ chat_doc: Optional[ChatDocument] = None,
1261
+ ) -> None | str | ChatDocument:
846
1262
  """
847
1263
  Respond to a tool request from the LLM, in the form of an ToolMessage object.
848
1264
  Args:
849
1265
  tool: ToolMessage object representing the tool request.
1266
+ chat_doc: Optional ChatDocument object containing the tool request.
1267
+ This is passed to the tool-handler method only if it has a `chat_doc`
1268
+ argument.
850
1269
 
851
1270
  Returns:
852
1271
 
@@ -855,9 +1274,34 @@ class Agent(ABC):
855
1274
  handler_method = getattr(self, tool_name, None)
856
1275
  if handler_method is None:
857
1276
  return None
858
-
1277
+ has_chat_doc_arg = (
1278
+ chat_doc is not None
1279
+ and "chat_doc" in inspect.signature(handler_method).parameters
1280
+ )
859
1281
  try:
860
- result = handler_method(tool)
1282
+ if has_chat_doc_arg:
1283
+ maybe_result = handler_method(tool, chat_doc=chat_doc)
1284
+ else:
1285
+ maybe_result = handler_method(tool)
1286
+ if isinstance(maybe_result, ToolMessage):
1287
+ # result is a ToolMessage, so...
1288
+ result_tool_name = maybe_result.default_value("request")
1289
+ if (
1290
+ result_tool_name in self.llm_tools_handled
1291
+ and tool_name != result_tool_name
1292
+ ):
1293
+ # TODO: do we need to remove the tool message from the chat_doc?
1294
+ # if (chat_doc is not None and
1295
+ # maybe_result in chat_doc.tool_messages):
1296
+ # chat_doc.tool_messages.remove(maybe_result)
1297
+ # if we can handle it, do so
1298
+ result = self.handle_tool_message(maybe_result, chat_doc=chat_doc)
1299
+ else:
1300
+ # else wrap it in an agent response and return it so
1301
+ # orchestrator can find a respondent
1302
+ result = self.create_agent_response(tool_messages=[maybe_result])
1303
+ else:
1304
+ result = maybe_result
861
1305
  except Exception as e:
862
1306
  # raise the error here since we are sure it's
863
1307
  # not a pydantic validation error,