langroid 0.10.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ from .chat_document import (
  ChatDocument,
  )
  from .chat_agent import ChatAgentConfig, ChatAgent
- from .tool_message import ToolMessage, FinalResultTool
+ from .tool_message import ToolMessage
  from .task import Task
 
  from . import base
@@ -29,7 +29,6 @@ __all__ = [
  "ChatAgent",
  "ChatAgentConfig",
  "ToolMessage",
- "FinalResultTool",
  "Task",
  "base",
  "chat_document",
langroid/agent/base.py CHANGED
@@ -18,7 +18,10 @@ from typing import (
  Set,
  Tuple,
  Type,
+ TypeVar,
  cast,
+ get_args,
+ get_origin,
  no_type_check,
  )
 
@@ -46,7 +49,6 @@ from langroid.parsing.parse_json import extract_top_level_json
  from langroid.parsing.parser import Parser, ParsingConfig
  from langroid.prompts.prompts_config import PromptsConfig
  from langroid.pydantic_v1 import (
- BaseModel,
  BaseSettings,
  Field,
  ValidationError,
@@ -56,6 +58,7 @@ from langroid.utils.configuration import settings
  from langroid.utils.constants import DONE, NO_ANSWER, PASS, PASS_TO, SEND_TO
  from langroid.utils.object_registry import ObjectRegistry
  from langroid.utils.output import status
+ from langroid.utils.types import from_string, to_string
  from langroid.vector_store.base import VectorStore, VectorStoreConfig
 
  ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
@@ -63,6 +66,8 @@ console = Console(quiet=settings.quiet)
 
  logger = logging.getLogger(__name__)
 
+ T = TypeVar("T")
+
 
  class AgentConfig(BaseSettings):
  """
@@ -78,6 +83,7 @@ class AgentConfig(BaseSettings):
  prompts: Optional[PromptsConfig] = PromptsConfig()
  show_stats: bool = True # show token usage/cost stats?
  add_to_registry: bool = True # register agent in ObjectRegistry?
+ respond_tools_only: bool = False # respond only to tool messages (not plain text)?
 
  @validator("name")
  def check_name_alphanum(cls, v: str) -> str:
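Note: `respond_tools_only` is a new `AgentConfig` flag (default `False`). A minimal sketch of setting it, with the agent name and `llm=None` assumed here only to keep the example self-contained (not from this diff):

```python
import langroid as lr

# Sketch: an agent meant to react only to recognized tool messages,
# staying silent on plain-text input.
config = lr.ChatAgentConfig(
    name="ToolOnlyAgent",  # illustrative name
    llm=None,              # assumption: no LLM needed for this sketch
    respond_tools_only=True,
)
agent = lr.ChatAgent(config)
```

The flag is consumed by the new `Agent.can_respond()` method further down in base.py.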
@@ -150,7 +156,6 @@ class Agent(ABC):
  show_start_response=noop_fn,
  )
  Agent.init_state(self)
- self.init_state()
 
  def init_state(self) -> None:
  """Initialize all state vars. Called by Task.run() if restart is True"""
@@ -342,6 +347,7 @@ class Agent(ABC):
  def create_agent_response(
  self,
  content: str | None = None,
+ content_any: Any = None,
  tool_messages: List[ToolMessage] = [],
  oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
  oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
@@ -353,6 +359,7 @@ class Agent(ABC):
  return self.response_template(
  Entity.AGENT,
  content=content,
+ content_any=content_any,
  tool_messages=tool_messages,
  oai_tool_calls=oai_tool_calls,
  oai_tool_choice=oai_tool_choice,
@@ -544,6 +551,7 @@ class Agent(ABC):
  self,
  e: Entity,
  content: str | None = None,
+ content_any: Any = None,
  tool_messages: List[ToolMessage] = [],
  oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
  oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
@@ -554,6 +562,7 @@ class Agent(ABC):
  """Template for response from entity `e`."""
  return ChatDocument(
  content=content or "",
+ content_any=content_any,
  tool_messages=tool_messages,
  oai_tool_calls=oai_tool_calls,
  oai_tool_id2result=oai_tool_id2result,
@@ -567,6 +576,7 @@ class Agent(ABC):
  def create_user_response(
  self,
  content: str | None = None,
+ content_any: Any = None,
  tool_messages: List[ToolMessage] = [],
  oai_tool_calls: List[OpenAIToolCall] | None = None,
  oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
@@ -578,6 +588,7 @@ class Agent(ABC):
  return self.response_template(
  e=Entity.USER,
  content=content,
+ content_any=content_any,
  tool_messages=tool_messages,
  oai_tool_calls=oai_tool_calls,
  oai_tool_choice=oai_tool_choice,
@@ -678,9 +689,26 @@ class Agent(ABC):
 
  return True
 
+ def can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
+ """
+ Whether the agent can respond to a message.
+ Used in Task.py to skip a sub-task when we know it would not respond.
+ Args:
+ message (str|ChatDocument): message or ChatDocument object to respond to.
+ """
+ tools = self.get_tool_messages(message)
+ if len(tools) == 0 and self.config.respond_tools_only:
+ return False
+ if message is not None and self.has_only_unhandled_tools(message):
+ # The message has tools that are NOT enabled to be handled by this agent,
+ # which means the agent cannot respond to it.
+ return False
+ return True
+
  def create_llm_response(
  self,
  content: str | None = None,
+ content_any: Any = None,
  tool_messages: List[ToolMessage] = [],
  oai_tool_calls: None | List[OpenAIToolCall] = None,
  oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
@@ -692,6 +720,7 @@ class Agent(ABC):
  return self.response_template(
  Entity.LLM,
  content=content,
+ content_any=content_any,
  tool_messages=tool_messages,
  oai_tool_calls=oai_tool_calls,
  oai_tool_choice=oai_tool_choice,
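A rough illustration of how the new `can_respond()` check behaves together with `respond_tools_only`; the tool class and values below are hypothetical, and `llm=None` is assumed only to avoid needing an API key:

```python
import langroid as lr
from langroid.agent.tool_message import ToolMessage


class SquareTool(ToolMessage):
    # Hypothetical tool, defined only for this sketch.
    request: str = "square"
    purpose: str = "To compute the square of a <number>"
    number: int

    def handle(self) -> str:
        return str(self.number * self.number)


agent = lr.ChatAgent(lr.ChatAgentConfig(llm=None, respond_tools_only=True))
agent.enable_message(SquareTool)

# No tool in the message, and respond_tools_only=True => expected False
print(agent.can_respond("hello there"))

# A message carrying an enabled tool => expected True
print(agent.can_respond('{"request": "square", "number": 4}'))
```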
@@ -857,6 +886,8 @@ class Agent(ABC):
  Does the msg have at least one tool, and ALL tools are
  disabled for handling by this agent?
  """
+ if msg is None:
+ return False
  tools = self.get_tool_messages(msg, all_tools=True)
  if len(tools) == 0:
  return False
@@ -864,7 +895,7 @@
 
  def get_tool_messages(
  self,
- msg: str | ChatDocument,
+ msg: str | ChatDocument | None,
  all_tools: bool = False,
  ) -> List[ToolMessage]:
  """
@@ -875,6 +906,9 @@
  - otherwise, return only the tools handled by this agent.
  """
 
+ if msg is None:
+ return []
+
  if isinstance(msg, str):
  json_tools = self.get_json_tool_messages(msg)
  if all_tools:
@@ -1071,7 +1105,7 @@
  fallback_result = self.handle_message_fallback(msg)
  if fallback_result is None:
  return None
- return self._process_handle_message_result(
+ return self.to_ChatDocument(
  fallback_result,
  chat_doc=msg if isinstance(msg, ChatDocument) else None,
  )
@@ -1110,7 +1144,13 @@
  results = [err_str for _ in tools]
  else:
  results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
+ # if there's a solitary ChatDocument|str result, return it as is
+ if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
+ return results[0]
+ # extract content from ChatDocument results so we have all str|None
+ results = [r.content if isinstance(r, ChatDocument) else r for r in results]
 
+ # now all results are str|None
  tool_names = [t.default_value("request") for t in tools]
  if has_ids:
  id2result = OrderedDict(
@@ -1133,35 +1173,16 @@
  (name, r) for name, r in zip(tool_names, results) if r is not None
  ]
  if len(name_results_list) == 0:
- return None # self.handle_message_fallback(msg)
+ return None
+
  # there was a non-None result
- chat_doc_results = [
- r for _, r in name_results_list if isinstance(r, ChatDocument)
- ]
- if len(chat_doc_results) > 1:
- logger.warning(
- """There were multiple ChatDocument results from tools,
- which is unexpected. The first one will be returned, and the others
- will be ignored.
- """
- )
- if len(chat_doc_results) > 0:
- return chat_doc_results[0]
 
  if has_ids and len(id2result) > 1:
  # if there are multiple OpenAI Tool results, return them as a dict
  return id2result
 
- if len(name_results_list) == 1 and isinstance(name_results_list[0][1], str):
- # single str result -- return it as is
- return name_results_list[0][1]
-
  # multi-results: prepend the tool name to each result
- str_results = [
- f"Result from {name}: {r}"
- for name, r in name_results_list
- if isinstance(r, str)
- ]
+ str_results = [f"Result from {name}: {r}" for name, r in name_results_list]
  final = "\n\n".join(str_results)
  return final
 
@@ -1261,20 +1282,41 @@
  raise ve
  return message
 
- def _process_handle_message_result(
+ def to_ChatDocument(
  self,
  msg: Any,
  orig_tool_name: str | None = None,
  chat_doc: Optional[ChatDocument] = None,
- ) -> None | str | ChatDocument:
+ author_entity: Entity = Entity.AGENT,
+ ) -> Optional[ChatDocument]:
  """
- Process result of agent_response or tool handler, or handle_message_fallback.
+ Convert result of a responder (agent_response or llm_response, or task.run()),
+ or tool handler, or handle_message_fallback,
+ to a ChatDocument, to enabling handling by other
+ responders/tasks in a task loop possibly involving multiple agents.
+
+ Args:
+ msg (Any): The result of a responder or tool handler or task.run()
+ orig_tool_name (str): The original tool name that generated the response,
+ if any.
+ chat_doc (ChatDocument): The original ChatDocument object that `msg`
+ is a response to.
+ author_entity (Entity): The intended author of the result ChatDocument
  """
- if isinstance(msg, ToolMessage):
+ if msg is None or isinstance(msg, ChatDocument):
+ return msg
+
+ is_agent_author = author_entity == Entity.AGENT
+
+ if isinstance(msg, str):
+ return self.response_template(author_entity, content=msg, content_any=msg)
+ elif isinstance(msg, ToolMessage):
  # result is a ToolMessage, so...
  result_tool_name = msg.default_value("request")
- if result_tool_name in self.llm_tools_handled and (
- orig_tool_name is None or orig_tool_name != result_tool_name
+ if (
+ is_agent_author
+ and result_tool_name in self.llm_tools_handled
+ and (orig_tool_name is None or orig_tool_name != result_tool_name)
  ):
  # TODO: do we need to remove the tool message from the chat_doc?
  # if (chat_doc is not None and
@@ -1282,30 +1324,73 @@
  # chat_doc.tool_messages.remove(msg)
  # if we can handle it, do so
  result = self.handle_tool_message(msg, chat_doc=chat_doc)
+ if result is not None and isinstance(result, ChatDocument):
+ return result
  else:
  # else wrap it in an agent response and return it so
  # orchestrator can find a respondent
- result = self.create_agent_response(tool_messages=[msg])
- elif isinstance(msg, (ChatDocument, str)):
- result = msg
- elif isinstance(msg, BaseModel):
- result = msg.json()
+ return self.response_template(author_entity, tool_messages=[msg])
  else:
- # last resort: use json.dumps() or str() to make it a str
- try:
- result = json.dumps(msg)
- except Exception:
- try:
- result = str(msg)
- except Exception as e:
- logger.error(
- f"""
- Error converting msg handler result to str: {e}",
- """,
- exc_info=True,
- )
- result = None
- return result
+ result = to_string(msg)
+
+ return (
+ None
+ if result is None
+ else self.response_template(author_entity, content=result, content_any=msg)
+ )
+
+ def from_ChatDocument(self, msg: ChatDocument, output_type: Type[T]) -> Optional[T]:
+ """
+ Extract a desired output_type from a ChatDocument object.
+ We use this fallback order:
+ - if `msg.content_any` exists and matches the output_type, return it
+ - if `msg.content` exists and output_type is str return it
+ - if output_type is a ToolMessage, return the first tool in `msg.tool_messages`
+ - if output_type is a list of ToolMessage,
+ return all tools in `msg.tool_messages`
+ - search for a tool in `msg.tool_messages` that has a field of output_type,
+ and if found, return that field value
+ - return None if all the above fail
+ """
+ content = msg.content
+ if output_type is str and content != "":
+ return cast(T, content)
+ content_any = msg.content_any
+ if content_any is not None and isinstance(content_any, output_type):
+ return cast(T, content_any)
+
+ tools = self.get_tool_messages(msg, all_tools=True)
+
+ if get_origin(output_type) is list:
+ list_element_type = get_args(output_type)[0]
+ if issubclass(list_element_type, ToolMessage):
+ # list_element_type is a subclass of ToolMessage:
+ # We output a list of objects derived from list_element_type
+ return cast(
+ T,
+ [t for t in tools if isinstance(t, list_element_type)],
+ )
+ elif get_origin(output_type) is None and issubclass(output_type, ToolMessage):
+ # output_type is a subclass of ToolMessage:
+ # return the first tool that has this specific output_type
+ for tool in tools:
+ if isinstance(tool, output_type):
+ return cast(T, tool)
+ return None
+ elif get_origin(output_type) is None and output_type in (str, int, float, bool):
+ # attempt to get the output_type from the content,
+ # if it's a primitive type
+ primitive_value = from_string(content, output_type) # type: ignore
+ if primitive_value is not None:
+ return cast(T, primitive_value)
+
+ # then search for output_type as a field in a tool
+ for tool in tools:
+ value = tool.get_value_of_type(output_type)
+ if value is not None:
+ return cast(T, value)
+
+ return None
 
  def handle_tool_message(
  self,
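The renamed `to_ChatDocument()` and the new `from_ChatDocument()` form a pair: responder results of arbitrary type are wrapped into a `ChatDocument`, and a desired type can be pulled back out using the fallback order documented above. A minimal sketch of the extraction side (hypothetical tool; `llm=None` assumed only for self-containment):

```python
import langroid as lr
from langroid.agent.tool_message import ToolMessage


class CityTemperatureTool(ToolMessage):
    # Hypothetical tool used only for this sketch.
    request: str = "city_temperature"
    purpose: str = "To report the <temperature> in a <city>"
    city: str
    temperature: float


agent = lr.ChatAgent(lr.ChatAgentConfig(llm=None))
agent.enable_message(CityTemperatureTool)

doc = agent.create_agent_response(
    content="23.5",
    tool_messages=[CityTemperatureTool(city="Paris", temperature=23.5)],
)

# Expected, following the fallback order in the docstring above:
print(agent.from_ChatDocument(doc, str))                  # "23.5" (from content)
print(agent.from_ChatDocument(doc, float))                # 23.5 (primitive parsed from content)
print(agent.from_ChatDocument(doc, CityTemperatureTool))  # the tool instance
```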
@@ -1336,9 +1421,7 @@
  maybe_result = handler_method(tool, chat_doc=chat_doc)
  else:
  maybe_result = handler_method(tool)
- result = self._process_handle_message_result(
- maybe_result, tool_name, chat_doc
- )
+ result = self.to_ChatDocument(maybe_result, tool_name, chat_doc)
  except Exception as e:
  # raise the error here since we are sure it's
  # not a pydantic validation error,
@@ -118,7 +118,7 @@ class ChatAgent(Agent):
  self.config: ChatAgentConfig = config
  self.config._set_fn_or_tools(self._fn_call_available())
  self.message_history: List[LLMMessage] = []
- self.tool_instructions_added: bool = False
+ self.init_state()
  # An agent's "task" is defined by a system msg and an optional user msg;
  # These are "priming" messages that kick off the agent's conversation.
  self.system_message: str = self.config.system_message
@@ -161,6 +161,7 @@ class ChatAgent(Agent):
  DoneTool,
  ForwardTool,
  PassTool,
+ ResultTool,
  SendTool,
  )
 
@@ -171,6 +172,16 @@ class ChatAgent(Agent):
  self.enable_message(DonePassTool, use=False, handle=True)
  self.enable_message(SendTool, use=False, handle=True)
  self.enable_message(AgentSendTool, use=False, handle=True)
+ self.enable_message(ResultTool, use=False, handle=True)
+
+ def init_state(self) -> None:
+ """
+ Initialize the state of the agent. Just conversation state here,
+ but subclasses can override this to initialize other state.
+ """
+ super().init_state()
+ self.clear_history(0)
+ self.clear_dialog()
 
  @staticmethod
  def from_id(id: str) -> "ChatAgent":
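Since `ChatAgent.__init__` now calls `init_state()`, which clears the message history and dialog, subclasses that keep extra state should extend it via `super().init_state()` (as the LanceRAG agents at the bottom of this diff now do). A sketch of the pattern with a hypothetical subclass:

```python
import langroid as lr


class CounterAgent(lr.ChatAgent):
    # Hypothetical subclass, for illustration only.
    def init_state(self) -> None:
        super().init_state()             # resets conversation state (history, dialog)
        self.num_requests_seen: int = 0  # custom state re-initialized alongside it
```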
@@ -303,8 +314,7 @@ class ChatAgent(Agent):
  usable_tool_classes: List[Type[ToolMessage]] = [
  t
  for t in list(self.llm_tools_map.values())
- if not t._handle_only
- and t.default_value("request") in self.llm_tools_usable
+ if t.default_value("request") in self.llm_tools_usable
  ]
 
  if len(usable_tool_classes) == 0:
@@ -513,6 +523,13 @@ class ChatAgent(Agent):
  tools = self._get_tool_list(message_class)
  if message_class is not None:
  request = message_class.default_value("request")
+ if request == "":
+ raise ValueError(
+ f"""
+ ToolMessage class {message_class} must have a non-empty
+ 'request' field if it is to be enabled as a tool.
+ """
+ )
  llm_function = message_class.llm_function_schema(defaults=include_defaults)
  self.llm_functions_map[request] = llm_function
  if force:
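`enable_message()` now rejects tool classes whose `request` default is empty, so every tool definition needs a concrete `request` string; for example (hypothetical tool):

```python
from langroid.agent.tool_message import ToolMessage


class PingTool(ToolMessage):
    # A non-empty `request` default is now required; leaving it as ""
    # makes enable_message() raise the ValueError shown above.
    request: str = "ping"
    purpose: str = "To check whether the agent is reachable"
```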
@@ -531,8 +548,21 @@ class ChatAgent(Agent):
  self.llm_functions_handled.discard(t)
 
  if use:
- self.llm_tools_usable.add(t)
- self.llm_functions_usable.add(t)
+ tool_class = self.llm_tools_map[t]
+ if tool_class._allow_llm_use:
+ self.llm_tools_usable.add(t)
+ self.llm_functions_usable.add(t)
+ else:
+ logger.warning(
+ f"""
+ ToolMessage class {tool_class} does not allow LLM use,
+ because `_allow_llm_use=False` either in the Tool or a
+ parent class of this tool;
+ so not enabling LLM use for this tool!
+ If you intended an LLM to use this tool,
+ set `_allow_llm_use=True` when you define the tool.
+ """
+ )
  else:
  self.llm_tools_usable.discard(t)
  self.llm_functions_usable.discard(t)
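Tools can now opt out of LLM generation entirely via `_allow_llm_use`: enabling them for handling still works, but requesting `use=True` only triggers the warning above. A sketch of a handler-only tool (hypothetical names; `llm=None` assumed):

```python
import langroid as lr
from langroid.agent.tool_message import ToolMessage


class InternalAuditTool(ToolMessage):
    # Hypothetical handler-only tool: the agent may process it, but the
    # LLM should never be prompted to generate it.
    _allow_llm_use = False
    request: str = "internal_audit"
    purpose: str = "To record an internal audit note"
    note: str


agent = lr.ChatAgent(lr.ChatAgentConfig(llm=None))
# handle=True is honored; use=True would only log the warning above.
agent.enable_message(InternalAuditTool, use=False, handle=True)
```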
@@ -22,6 +22,7 @@ from langroid.parsing.parse_json import extract_top_level_json, top_level_json_f
  from langroid.pydantic_v1 import BaseModel, Extra
  from langroid.utils.object_registry import ObjectRegistry
  from langroid.utils.output.printing import shorten_text
+ from langroid.utils.types import to_string
 
 
  class ChatDocAttachment(BaseModel):
@@ -115,6 +116,7 @@ class ChatDocument(Document):
  attachment (None | ChatDocAttachment): Any additional data attached.
  """
 
+ content_any: Any = None # to hold arbitrary data returned by responders
  oai_tool_calls: Optional[List[OpenAIToolCall]] = None
  oai_tool_id2result: Optional[OrderedDict[str, str]] = None
  oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto"
@@ -281,6 +283,7 @@ class ChatDocument(Document):
  ChatDocument._clean_fn_call(oai_tc.function)
  return ChatDocument(
  content=message,
+ content_any=message,
  oai_tool_calls=response.oai_tool_calls,
  function_call=response.function_call,
  metadata=ChatDocMetaData(
@@ -303,6 +306,7 @@ class ChatDocument(Document):
  message = msg # retain the whole msg in this case
  return ChatDocument(
  content=message,
+ content_any=message,
  metadata=ChatDocMetaData(
  source=Entity.USER,
  sender=Entity.USER,
@@ -335,7 +339,7 @@
  tool_id = "" # for OpenAI Assistant
  chat_document_id: str = ""
  if isinstance(message, ChatDocument):
- content = message.content
+ content = message.content or to_string(message.content_any) or ""
  fun_call = message.function_call
  oai_tool_calls = message.oai_tool_calls
  if message.metadata.sender == Entity.USER and fun_call is not None:
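`ChatDocument.content_any` carries an arbitrary Python value alongside the string `content`, and the hunk above stringifies it as a fallback when `content` is empty during conversion to an LLM message. A small sketch (values are illustrative; `llm=None` assumed):

```python
import langroid as lr

agent = lr.ChatAgent(lr.ChatAgentConfig(llm=None))

# A responder result that keeps a structured payload next to its text form.
doc = agent.create_agent_response(content="", content_any={"scores": [0.9, 0.7]})

print(doc.content_any)                     # {'scores': [0.9, 0.7]}
print(agent.from_ChatDocument(doc, dict))  # same dict, recovered by type
```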
@@ -169,6 +169,7 @@ class QueryPlanCritic(ChatAgent):
  self.enable_message(AgentDoneTool, use=False, handle=True)
 
  def init_state(self) -> None:
+ super().init_state()
  self.expecting_feedback_tool = False
 
  def query_plan_answer(self, msg: QueryPlanAnswerTool) -> str:
@@ -144,6 +144,7 @@ class LanceQueryPlanAgent(ChatAgent):
  self.enable_message(AgentDoneTool, use=False, handle=True)
 
  def init_state(self) -> None:
+ super().init_state()
  self.curr_query_plan: QueryPlan | None = None
  self.expecting_query_plan: bool = False
  # how many times re-trying query plan in response to feedback: