langroid 0.10.2__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -161,6 +161,7 @@ class ChatAgent(Agent):
161
161
  DoneTool,
162
162
  ForwardTool,
163
163
  PassTool,
164
+ ResultTool,
164
165
  SendTool,
165
166
  )
166
167
 
@@ -171,6 +172,7 @@ class ChatAgent(Agent):
171
172
  self.enable_message(DonePassTool, use=False, handle=True)
172
173
  self.enable_message(SendTool, use=False, handle=True)
173
174
  self.enable_message(AgentSendTool, use=False, handle=True)
175
+ self.enable_message(ResultTool, use=False, handle=True)
174
176
 
175
177
  def init_state(self) -> None:
176
178
  """
@@ -312,8 +314,7 @@ class ChatAgent(Agent):
312
314
  usable_tool_classes: List[Type[ToolMessage]] = [
313
315
  t
314
316
  for t in list(self.llm_tools_map.values())
315
- if not t._handle_only
316
- and t.default_value("request") in self.llm_tools_usable
317
+ if t.default_value("request") in self.llm_tools_usable
317
318
  ]
318
319
 
319
320
  if len(usable_tool_classes) == 0:
@@ -522,6 +523,13 @@ class ChatAgent(Agent):
522
523
  tools = self._get_tool_list(message_class)
523
524
  if message_class is not None:
524
525
  request = message_class.default_value("request")
526
+ if request == "":
527
+ raise ValueError(
528
+ f"""
529
+ ToolMessage class {message_class} must have a non-empty
530
+ 'request' field if it is to be enabled as a tool.
531
+ """
532
+ )
525
533
  llm_function = message_class.llm_function_schema(defaults=include_defaults)
526
534
  self.llm_functions_map[request] = llm_function
527
535
  if force:
@@ -540,8 +548,21 @@ class ChatAgent(Agent):
540
548
  self.llm_functions_handled.discard(t)
541
549
 
542
550
  if use:
543
- self.llm_tools_usable.add(t)
544
- self.llm_functions_usable.add(t)
551
+ tool_class = self.llm_tools_map[t]
552
+ if tool_class._allow_llm_use:
553
+ self.llm_tools_usable.add(t)
554
+ self.llm_functions_usable.add(t)
555
+ else:
556
+ logger.warning(
557
+ f"""
558
+ ToolMessage class {tool_class} does not allow LLM use,
559
+ because `_allow_llm_use=False` either in the Tool or a
560
+ parent class of this tool;
561
+ so not enabling LLM use for this tool!
562
+ If you intended an LLM to use this tool,
563
+ set `_allow_llm_use=True` when you define the tool.
564
+ """
565
+ )
545
566
  else:
546
567
  self.llm_tools_usable.discard(t)
547
568
  self.llm_functions_usable.discard(t)
@@ -22,6 +22,7 @@ from langroid.parsing.parse_json import extract_top_level_json, top_level_json_f
22
22
  from langroid.pydantic_v1 import BaseModel, Extra
23
23
  from langroid.utils.object_registry import ObjectRegistry
24
24
  from langroid.utils.output.printing import shorten_text
25
+ from langroid.utils.types import to_string
25
26
 
26
27
 
27
28
  class ChatDocAttachment(BaseModel):
@@ -115,6 +116,7 @@ class ChatDocument(Document):
115
116
  attachment (None | ChatDocAttachment): Any additional data attached.
116
117
  """
117
118
 
119
+ content_any: Any = None # to hold arbitrary data returned by responders
118
120
  oai_tool_calls: Optional[List[OpenAIToolCall]] = None
119
121
  oai_tool_id2result: Optional[OrderedDict[str, str]] = None
120
122
  oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto"
@@ -281,6 +283,7 @@ class ChatDocument(Document):
281
283
  ChatDocument._clean_fn_call(oai_tc.function)
282
284
  return ChatDocument(
283
285
  content=message,
286
+ content_any=message,
284
287
  oai_tool_calls=response.oai_tool_calls,
285
288
  function_call=response.function_call,
286
289
  metadata=ChatDocMetaData(
@@ -303,6 +306,7 @@ class ChatDocument(Document):
303
306
  message = msg # retain the whole msg in this case
304
307
  return ChatDocument(
305
308
  content=message,
309
+ content_any=message,
306
310
  metadata=ChatDocMetaData(
307
311
  source=Entity.USER,
308
312
  sender=Entity.USER,
@@ -335,7 +339,7 @@ class ChatDocument(Document):
335
339
  tool_id = "" # for OpenAI Assistant
336
340
  chat_document_id: str = ""
337
341
  if isinstance(message, ChatDocument):
338
- content = message.content
342
+ content = message.content or to_string(message.content_any) or ""
339
343
  fun_call = message.function_call
340
344
  oai_tool_calls = message.oai_tool_calls
341
345
  if message.metadata.sender == Entity.USER and fun_call is not None:
@@ -1253,12 +1253,12 @@ class DocChatAgent(ChatAgent):
1253
1253
  interactive=False,
1254
1254
  )
1255
1255
 
1256
- extracts = run_batch_tasks(
1256
+ extracts: list[str] = run_batch_tasks(
1257
1257
  task,
1258
1258
  passages,
1259
1259
  input_map=lambda msg: msg.content,
1260
1260
  output_map=lambda ans: ans.content if ans is not None else NO_ANSWER,
1261
- )
1261
+ ) # type: ignore
1262
1262
 
1263
1263
  # Caution: Retain ALL other fields in the Documents (which could be
1264
1264
  # other than just `content` and `metadata`), while simply replacing
langroid/agent/task.py CHANGED
@@ -18,7 +18,9 @@ from typing import (
18
18
  Optional,
19
19
  Tuple,
20
20
  Type,
21
+ TypeVar,
21
22
  cast,
23
+ overload,
22
24
  )
23
25
 
24
26
  import numpy as np
@@ -33,8 +35,8 @@ from langroid.agent.chat_document import (
33
35
  ChatDocument,
34
36
  StatusCode,
35
37
  )
36
- from langroid.agent.tool_message import FinalResultTool, ToolMessage
37
- from langroid.agent.tools.orchestration import AgentDoneTool, DoneTool
38
+ from langroid.agent.tool_message import ToolMessage
39
+ from langroid.agent.tools.orchestration import AgentDoneTool, DoneTool, FinalResultTool
38
40
  from langroid.cachedb.redis_cachedb import RedisCache, RedisCacheConfig
39
41
  from langroid.exceptions import InfiniteLoopException
40
42
  from langroid.mytypes import Entity
@@ -53,11 +55,14 @@ from langroid.utils.constants import (
53
55
  from langroid.utils.logging import RichFileLogger, setup_file_logger
54
56
  from langroid.utils.object_registry import scheduled_cleanup
55
57
  from langroid.utils.system import hash
58
+ from langroid.utils.types import to_string
56
59
 
57
60
  logger = logging.getLogger(__name__)
58
61
 
59
62
  Responder = Entity | Type["Task"]
60
63
 
64
+ T = TypeVar("T")
65
+
61
66
 
62
67
  def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
63
68
  pass
@@ -153,6 +158,7 @@ class Task:
153
158
  erase_substeps: bool = False,
154
159
  allow_null_result: bool = False,
155
160
  max_stalled_steps: int = 5,
161
+ default_return_type: Optional[type] = None,
156
162
  done_if_no_response: List[Responder] = [],
157
163
  done_if_response: List[Responder] = [],
158
164
  config: TaskConfig = TaskConfig(),
@@ -190,6 +196,8 @@ class Task:
190
196
  default_human_response (str|None): default response from user; useful for
191
197
  testing, to avoid interactive input from user.
192
198
  [Instead of this, setting `interactive` usually suffices]
199
+ default_return_type: if not None, extracts a value of this type from the
200
+ result of self.run()
193
201
  interactive (bool): if true, wait for human input after each non-human
194
202
  response (prevents infinite loop of non-human responses).
195
203
  Default is true. If false, then `default_human_response` is set to ""
@@ -298,6 +306,7 @@ class Task:
298
306
  self.agent.interactive = interactive
299
307
  self.only_user_quits_root = only_user_quits_root
300
308
  self.message_history_idx = -1
309
+ self.default_return_type = default_return_type
301
310
 
302
311
  # set to True if we want to collapse multi-turn conversation with sub-tasks into
303
312
  # just the first outgoing message and last incoming message.
@@ -582,16 +591,50 @@ class Task:
582
591
  for t in self.sub_tasks:
583
592
  t.reset_all_sub_tasks()
584
593
 
594
+ def __getitem__(self, return_type: type) -> Task:
595
+ """Returns a (shallow) copy of `self` with a default return type."""
596
+ clone = copy.copy(self)
597
+ clone.default_return_type = return_type
598
+ return clone
599
+
600
+ @overload
601
+ def run( # noqa
602
+ self,
603
+ msg: Any = None,
604
+ *,
605
+ turns: int = -1,
606
+ caller: None | Task = None,
607
+ max_cost: float = 0,
608
+ max_tokens: int = 0,
609
+ session_id: str = "",
610
+ allow_restart: bool = True,
611
+ ) -> Optional[ChatDocument]: ... # noqa
612
+
613
+ @overload
614
+ def run( # noqa
615
+ self,
616
+ msg: Any = None,
617
+ *,
618
+ turns: int = -1,
619
+ caller: None | Task = None,
620
+ max_cost: float = 0,
621
+ max_tokens: int = 0,
622
+ session_id: str = "",
623
+ allow_restart: bool = True,
624
+ return_type: Type[T],
625
+ ) -> Optional[T]: ... # noqa
626
+
585
627
  def run(
586
628
  self,
587
- msg: Optional[str | ChatDocument] = None,
629
+ msg: Any = None,
588
630
  turns: int = -1,
589
631
  caller: None | Task = None,
590
632
  max_cost: float = 0,
591
633
  max_tokens: int = 0,
592
634
  session_id: str = "",
593
635
  allow_restart: bool = True,
594
- ) -> Optional[ChatDocument]:
636
+ return_type: Optional[Type[T]] = None,
637
+ ) -> Optional[ChatDocument | T]:
595
638
  """Synchronous version of `run_async()`.
596
639
  See `run_async()` for details."""
597
640
  if allow_restart and (
@@ -614,19 +657,18 @@ class Task:
614
657
  self._init_message_counter()
615
658
  self.history.clear()
616
659
 
617
- assert (
618
- msg is None or isinstance(msg, str) or isinstance(msg, ChatDocument)
619
- ), f"msg arg in Task.run() must be None, str, or ChatDocument, not {type(msg)}"
660
+ msg_input = self.agent.to_ChatDocument(msg, author_entity=Entity.USER)
620
661
 
621
662
  if (
622
- isinstance(msg, ChatDocument)
623
- and msg.metadata.recipient != ""
624
- and msg.metadata.recipient != self.name
663
+ isinstance(msg_input, ChatDocument)
664
+ and msg_input.metadata.recipient != ""
665
+ and msg_input.metadata.recipient != self.name
625
666
  ):
626
667
  # this task is not the intended recipient so return None
627
668
  return None
669
+
628
670
  self._pre_run_loop(
629
- msg=msg,
671
+ msg=msg_input,
630
672
  caller=caller,
631
673
  is_async=False,
632
674
  )
@@ -677,24 +719,60 @@ class Task:
677
719
 
678
720
  final_result = self.result(status)
679
721
  self._post_run_loop()
722
+ if final_result is None:
723
+ return None
724
+
725
+ if return_type is None:
726
+ return_type = self.default_return_type
727
+
728
+ if return_type is not None and return_type != ChatDocument:
729
+ return self.agent.from_ChatDocument(final_result, return_type)
680
730
  return final_result
681
731
 
732
+ @overload
733
+ async def run_async( # noqa
734
+ self,
735
+ msg: Any = None,
736
+ *,
737
+ turns: int = -1,
738
+ caller: None | Task = None,
739
+ max_cost: float = 0,
740
+ max_tokens: int = 0,
741
+ session_id: str = "",
742
+ allow_restart: bool = True,
743
+ ) -> Optional[ChatDocument]: ... # noqa
744
+
745
+ @overload
746
+ async def run_async( # noqa
747
+ self,
748
+ msg: Any = None,
749
+ *,
750
+ turns: int = -1,
751
+ caller: None | Task = None,
752
+ max_cost: float = 0,
753
+ max_tokens: int = 0,
754
+ session_id: str = "",
755
+ allow_restart: bool = True,
756
+ return_type: Type[T],
757
+ ) -> Optional[T]: ... # noqa
758
+
682
759
  async def run_async(
683
760
  self,
684
- msg: Optional[str | ChatDocument] = None,
761
+ msg: Any = None,
685
762
  turns: int = -1,
686
763
  caller: None | Task = None,
687
764
  max_cost: float = 0,
688
765
  max_tokens: int = 0,
689
766
  session_id: str = "",
690
767
  allow_restart: bool = True,
691
- ) -> Optional[ChatDocument]:
768
+ return_type: Optional[Type[T]] = None,
769
+ ) -> Optional[ChatDocument | T]:
692
770
  """
693
771
  Loop over `step()` until task is considered done or `turns` is reached.
694
772
  Runs asynchronously.
695
773
 
696
774
  Args:
697
- msg (str|ChatDocument): initial *user-role* message to process; if None,
775
+ msg (Any): initial *user-role* message to process; if None,
698
776
  the LLM will respond to its initial `self.task_messages`
699
777
  which set up and kick off the overall task.
700
778
  The agent tries to achieve this goal by looping
@@ -710,6 +788,7 @@ class Task:
710
788
  max_tokens (int): max tokens allowed for the task (default 0 -> no limit)
711
789
  session_id (str): session id for the task
712
790
  allow_restart (bool): whether to allow restarting the task
791
+ return_type (Optional[Type[T]]): desired final result type
713
792
 
714
793
  Returns:
715
794
  Optional[ChatDocument]: valid result of the task.
@@ -740,17 +819,20 @@ class Task:
740
819
  self._init_message_counter()
741
820
  self.history.clear()
742
821
 
822
+ msg_input = self.agent.to_ChatDocument(msg, author_entity=Entity.USER)
823
+
743
824
  if (
744
- isinstance(msg, ChatDocument)
745
- and msg.metadata.recipient != ""
746
- and msg.metadata.recipient != self.name
825
+ isinstance(msg_input, ChatDocument)
826
+ and msg_input.metadata.recipient != ""
827
+ and msg_input.metadata.recipient != self.name
747
828
  ):
748
829
  # this task is not the intended recipient so return None
749
830
  return None
831
+
750
832
  self._pre_run_loop(
751
- msg=msg,
833
+ msg=msg_input,
752
834
  caller=caller,
753
- is_async=True,
835
+ is_async=False,
754
836
  )
755
837
  # self.turns overrides if it is > 0 and turns not set (i.e. = -1)
756
838
  turns = self.turns if turns < 0 else turns
@@ -800,6 +882,14 @@ class Task:
800
882
 
801
883
  final_result = self.result(status)
802
884
  self._post_run_loop()
885
+ if final_result is None:
886
+ return None
887
+
888
+ if return_type is None:
889
+ return_type = self.default_return_type
890
+
891
+ if return_type is not None and return_type != ChatDocument:
892
+ return self.agent.from_ChatDocument(final_result, return_type)
803
893
  return final_result
804
894
 
805
895
  def _pre_run_loop(
@@ -910,9 +1000,10 @@ class Task:
910
1000
  and not self.human_tried
911
1001
  and not self.agent.has_tool_message_attempt(self.pending_message)
912
1002
  ):
913
- # When in interactive mode,
914
1003
  # Give human first chance if they haven't been tried in last step,
915
1004
  # and the msg is not a tool-call attempt;
1005
+ # (When `interactive=False`, human is allowed to respond only
1006
+ # if explicitly addressed)
916
1007
  # This ensures human gets a chance to respond,
917
1008
  # other than to a LLM tool-call.
918
1009
  # When there's a tool msg attempt we want the
@@ -1246,7 +1337,13 @@ class Task:
1246
1337
  else:
1247
1338
  response_fn = self._entity_responder_map[cast(Entity, e)]
1248
1339
  result = response_fn(self.pending_message)
1249
- return self._process_result_routing(result, e)
1340
+
1341
+ result_chat_doc = self.agent.to_ChatDocument(
1342
+ result,
1343
+ chat_doc=self.pending_message,
1344
+ author_entity=e if isinstance(e, Entity) else Entity.USER,
1345
+ )
1346
+ return self._process_result_routing(result_chat_doc, e)
1250
1347
 
1251
1348
  def _process_result_routing(
1252
1349
  self, result: ChatDocument | None, e: Responder
@@ -1364,7 +1461,13 @@ class Task:
1364
1461
  else:
1365
1462
  response_fn = self._entity_responder_async_map[cast(Entity, e)]
1366
1463
  result = await response_fn(self.pending_message)
1367
- return self._process_result_routing(result, e)
1464
+
1465
+ result_chat_doc = self.agent.to_ChatDocument(
1466
+ result,
1467
+ chat_doc=self.pending_message,
1468
+ author_entity=e if isinstance(e, Entity) else Entity.USER,
1469
+ )
1470
+ return self._process_result_routing(result_chat_doc, e)
1368
1471
 
1369
1472
  def result(self, status: StatusCode | None = None) -> ChatDocument | None:
1370
1473
  """
@@ -1386,6 +1489,7 @@ class Task:
1386
1489
  result_msg = self.pending_message
1387
1490
 
1388
1491
  content = result_msg.content if result_msg else ""
1492
+ content_any = result_msg.content_any if result_msg else None
1389
1493
  if DONE in content:
1390
1494
  # assuming it is of the form "DONE: <content>"
1391
1495
  content = content.replace(DONE, "").strip()
@@ -1398,11 +1502,13 @@ class Task:
1398
1502
  for t in tool_messages:
1399
1503
  if isinstance(t, FinalResultTool):
1400
1504
  content = ""
1505
+ content_any = None
1401
1506
  tool_messages = [t] # pass it on to parent so it also quits
1402
1507
  break
1403
1508
  elif isinstance(t, (AgentDoneTool, DoneTool)):
1404
1509
  # there shouldn't be multiple tools like this; just take the first
1405
- content = t.content
1510
+ content = to_string(t.content)
1511
+ content_any = t.content
1406
1512
  if isinstance(t, AgentDoneTool):
1407
1513
  tool_messages = t.tools
1408
1514
  break
@@ -1420,6 +1526,7 @@ class Task:
1420
1526
  # since to the "parent" task, this result is equivalent to a response from USER
1421
1527
  result_doc = ChatDocument(
1422
1528
  content=content,
1529
+ content_any=content_any,
1423
1530
  oai_tool_calls=oai_tool_calls,
1424
1531
  oai_tool_id2result=oai_tool_id2result,
1425
1532
  function_call=fun_call,
@@ -1778,9 +1885,7 @@ class Task:
1778
1885
 
1779
1886
  if self.pending_message is None:
1780
1887
  return True
1781
- if isinstance(e, Task) and e.agent.has_only_unhandled_tools(
1782
- self.pending_message
1783
- ):
1888
+ if isinstance(e, Task) and not e.agent.can_respond(self.pending_message):
1784
1889
  return False
1785
1890
 
1786
1891
  if self._recipient_mismatch(e):
@@ -15,11 +15,12 @@ from typing import Any, Dict, List, Tuple, Type
15
15
  from docstring_parser import parse
16
16
 
17
17
  from langroid.language_models.base import LLMFunctionSpec
18
- from langroid.pydantic_v1 import BaseModel, ConfigDict, Extra
18
+ from langroid.pydantic_v1 import BaseModel, Extra
19
19
  from langroid.utils.pydantic_utils import (
20
20
  _recursive_purge_dict_key,
21
21
  generate_simple_schema,
22
22
  )
23
+ from langroid.utils.types import is_instance_of
23
24
 
24
25
 
25
26
  class ToolMessage(ABC, BaseModel):
@@ -41,19 +42,18 @@ class ToolMessage(ABC, BaseModel):
41
42
  purpose: str
42
43
  id: str = "" # placeholder for OpenAI-API tool_call_id
43
44
 
44
- model_config = ConfigDict(extra=Extra.allow)
45
+ _allow_llm_use: bool = True # allow an LLM to use (i.e. generate) this tool?
45
46
 
46
- _handle_only: bool = False # only allow handling, but not use (LLM-generation)?
47
+ # model_config = ConfigDict(extra=Extra.allow)
47
48
 
48
49
  class Config:
49
- # only HANDLING allowed, NOT "use" (i.e LLM generation)
50
- handle_only: bool = False
50
+ extra = Extra.allow
51
51
  arbitrary_types_allowed = False
52
52
  validate_all = True
53
53
  validate_assignment = True
54
54
  # do not include these fields in the generated schema
55
55
  # since we don't require the LLM to specify them
56
- schema_extra = {"exclude": {"purpose", "id", "model_config"}}
56
+ schema_extra = {"exclude": {"purpose", "id"}}
57
57
 
58
58
  @classmethod
59
59
  def instructions(cls) -> str:
@@ -123,6 +123,15 @@ class ToolMessage(ABC, BaseModel):
123
123
  def dict_example(self) -> Dict[str, Any]:
124
124
  return self.dict(exclude=self.Config.schema_extra["exclude"])
125
125
 
126
+ def get_value_of_type(self, target_type: Type[Any]) -> Any:
127
+ """Try to find a value of a desired type in the fields of the ToolMessage."""
128
+ ignore_fields = self.Config.schema_extra["exclude"].union(["request"])
129
+ for field_name in set(self.dict().keys()) - ignore_fields:
130
+ value = getattr(self, field_name)
131
+ if is_instance_of(value, target_type):
132
+ return value
133
+ return None
134
+
126
135
  @classmethod
127
136
  def default_value(cls, f: str) -> Any:
128
137
  """
@@ -273,40 +282,3 @@ class ToolMessage(ABC, BaseModel):
273
282
  exclude=list(cls.Config.schema_extra["exclude"]),
274
283
  )
275
284
  return schema
276
-
277
-
278
- class FinalResultTool(ToolMessage):
279
- """Class to use as a wrapper for sending arbitrary results from an Agent's
280
- agent_response or tool handlers, to:
281
- (a) trigger completion of the current task as well as all parent tasks, and
282
- (b) be returned as the final result of the root task, i.e. this tool would appear
283
- in the final ChatDocument's `tool_messages` list.
284
- See test_tool_handlers_and_results in test_tool_messages.py, and
285
- examples/basic/tool-extract-short-example.py.
286
-
287
- Note:
288
- - when defining a tool handler or agent_response, you can directly return
289
- FinalResultTool(field1 = val1, ...),
290
- where the values can be arbitrary data structures, including nested
291
- Pydantic objs, or you can define a subclass of FinalResultTool with the
292
- fields you want to return.
293
- - This is a special ToolMessage that is NOT meant to be used or handled
294
- by an agent.
295
- """
296
-
297
- request: str = ""
298
- purpose: str = "Ignored; Wrapper for a structured message"
299
- id: str = "" # placeholder for OpenAI-API tool_call_id
300
-
301
- _handle_only: bool = False # only allow handling, but not use (LLM-generation)?
302
-
303
- class Config:
304
- extra = Extra.allow
305
- # only HANDLING allowed, NOT "use" (i.e LLM generation)
306
- handle_only: bool = False
307
- arbitrary_types_allowed = False
308
- validate_all = True
309
- validate_assignment = True
310
- # do not include these fields in the generated schema
311
- # since we don't require the LLM to specify them
312
- schema_extra = {"exclude": {"purpose", "id"}}
@@ -13,6 +13,8 @@ from .orchestration import (
13
13
  SendTool,
14
14
  AgentSendTool,
15
15
  DonePassTool,
16
+ ResultTool,
17
+ FinalResultTool,
16
18
  )
17
19
 
18
20
  __all__ = [
@@ -31,4 +33,6 @@ __all__ = [
31
33
  "PassTool",
32
34
  "SendTool",
33
35
  "AgentSendTool",
36
+ "ResultTool",
37
+ "FinalResultTool",
34
38
  ]