langroid 0.33.6__py3-none-any.whl → 0.33.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. langroid/__init__.py +106 -0
  2. langroid/agent/__init__.py +41 -0
  3. langroid/agent/base.py +1983 -0
  4. langroid/agent/batch.py +398 -0
  5. langroid/agent/callbacks/__init__.py +0 -0
  6. langroid/agent/callbacks/chainlit.py +598 -0
  7. langroid/agent/chat_agent.py +1899 -0
  8. langroid/agent/chat_document.py +454 -0
  9. langroid/agent/openai_assistant.py +882 -0
  10. langroid/agent/special/__init__.py +59 -0
  11. langroid/agent/special/arangodb/__init__.py +0 -0
  12. langroid/agent/special/arangodb/arangodb_agent.py +656 -0
  13. langroid/agent/special/arangodb/system_messages.py +186 -0
  14. langroid/agent/special/arangodb/tools.py +107 -0
  15. langroid/agent/special/arangodb/utils.py +36 -0
  16. langroid/agent/special/doc_chat_agent.py +1466 -0
  17. langroid/agent/special/lance_doc_chat_agent.py +262 -0
  18. langroid/agent/special/lance_rag/__init__.py +9 -0
  19. langroid/agent/special/lance_rag/critic_agent.py +198 -0
  20. langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
  21. langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
  22. langroid/agent/special/lance_tools.py +61 -0
  23. langroid/agent/special/neo4j/__init__.py +0 -0
  24. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  25. langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
  26. langroid/agent/special/neo4j/system_messages.py +120 -0
  27. langroid/agent/special/neo4j/tools.py +32 -0
  28. langroid/agent/special/relevance_extractor_agent.py +127 -0
  29. langroid/agent/special/retriever_agent.py +56 -0
  30. langroid/agent/special/sql/__init__.py +17 -0
  31. langroid/agent/special/sql/sql_chat_agent.py +654 -0
  32. langroid/agent/special/sql/utils/__init__.py +21 -0
  33. langroid/agent/special/sql/utils/description_extractors.py +190 -0
  34. langroid/agent/special/sql/utils/populate_metadata.py +85 -0
  35. langroid/agent/special/sql/utils/system_message.py +35 -0
  36. langroid/agent/special/sql/utils/tools.py +64 -0
  37. langroid/agent/special/table_chat_agent.py +263 -0
  38. langroid/agent/task.py +2099 -0
  39. langroid/agent/tool_message.py +393 -0
  40. langroid/agent/tools/__init__.py +38 -0
  41. langroid/agent/tools/duckduckgo_search_tool.py +50 -0
  42. langroid/agent/tools/file_tools.py +234 -0
  43. langroid/agent/tools/google_search_tool.py +39 -0
  44. langroid/agent/tools/metaphor_search_tool.py +68 -0
  45. langroid/agent/tools/orchestration.py +303 -0
  46. langroid/agent/tools/recipient_tool.py +235 -0
  47. langroid/agent/tools/retrieval_tool.py +32 -0
  48. langroid/agent/tools/rewind_tool.py +137 -0
  49. langroid/agent/tools/segment_extract_tool.py +41 -0
  50. langroid/agent/xml_tool_message.py +382 -0
  51. langroid/cachedb/__init__.py +17 -0
  52. langroid/cachedb/base.py +58 -0
  53. langroid/cachedb/momento_cachedb.py +108 -0
  54. langroid/cachedb/redis_cachedb.py +153 -0
  55. langroid/embedding_models/__init__.py +39 -0
  56. langroid/embedding_models/base.py +74 -0
  57. langroid/embedding_models/models.py +461 -0
  58. langroid/embedding_models/protoc/__init__.py +0 -0
  59. langroid/embedding_models/protoc/embeddings.proto +19 -0
  60. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  61. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  62. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  63. langroid/embedding_models/remote_embeds.py +153 -0
  64. langroid/exceptions.py +71 -0
  65. langroid/language_models/__init__.py +53 -0
  66. langroid/language_models/azure_openai.py +153 -0
  67. langroid/language_models/base.py +678 -0
  68. langroid/language_models/config.py +18 -0
  69. langroid/language_models/mock_lm.py +124 -0
  70. langroid/language_models/openai_gpt.py +1964 -0
  71. langroid/language_models/prompt_formatter/__init__.py +16 -0
  72. langroid/language_models/prompt_formatter/base.py +40 -0
  73. langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
  74. langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
  75. langroid/language_models/utils.py +151 -0
  76. langroid/mytypes.py +84 -0
  77. langroid/parsing/__init__.py +52 -0
  78. langroid/parsing/agent_chats.py +38 -0
  79. langroid/parsing/code_parser.py +121 -0
  80. langroid/parsing/document_parser.py +718 -0
  81. langroid/parsing/para_sentence_split.py +62 -0
  82. langroid/parsing/parse_json.py +155 -0
  83. langroid/parsing/parser.py +313 -0
  84. langroid/parsing/repo_loader.py +790 -0
  85. langroid/parsing/routing.py +36 -0
  86. langroid/parsing/search.py +275 -0
  87. langroid/parsing/spider.py +102 -0
  88. langroid/parsing/table_loader.py +94 -0
  89. langroid/parsing/url_loader.py +115 -0
  90. langroid/parsing/urls.py +273 -0
  91. langroid/parsing/utils.py +373 -0
  92. langroid/parsing/web_search.py +156 -0
  93. langroid/prompts/__init__.py +9 -0
  94. langroid/prompts/dialog.py +17 -0
  95. langroid/prompts/prompts_config.py +5 -0
  96. langroid/prompts/templates.py +141 -0
  97. langroid/pydantic_v1/__init__.py +10 -0
  98. langroid/pydantic_v1/main.py +4 -0
  99. langroid/utils/__init__.py +19 -0
  100. langroid/utils/algorithms/__init__.py +3 -0
  101. langroid/utils/algorithms/graph.py +103 -0
  102. langroid/utils/configuration.py +98 -0
  103. langroid/utils/constants.py +30 -0
  104. langroid/utils/git_utils.py +252 -0
  105. langroid/utils/globals.py +49 -0
  106. langroid/utils/logging.py +135 -0
  107. langroid/utils/object_registry.py +66 -0
  108. langroid/utils/output/__init__.py +20 -0
  109. langroid/utils/output/citations.py +41 -0
  110. langroid/utils/output/printing.py +99 -0
  111. langroid/utils/output/status.py +40 -0
  112. langroid/utils/pandas_utils.py +30 -0
  113. langroid/utils/pydantic_utils.py +602 -0
  114. langroid/utils/system.py +286 -0
  115. langroid/utils/types.py +93 -0
  116. langroid/vector_store/__init__.py +50 -0
  117. langroid/vector_store/base.py +359 -0
  118. langroid/vector_store/chromadb.py +214 -0
  119. langroid/vector_store/lancedb.py +406 -0
  120. langroid/vector_store/meilisearch.py +299 -0
  121. langroid/vector_store/momento.py +278 -0
  122. langroid/vector_store/qdrantdb.py +468 -0
  123. {langroid-0.33.6.dist-info → langroid-0.33.8.dist-info}/METADATA +95 -94
  124. langroid-0.33.8.dist-info/RECORD +127 -0
  125. {langroid-0.33.6.dist-info → langroid-0.33.8.dist-info}/WHEEL +1 -1
  126. langroid-0.33.6.dist-info/RECORD +0 -7
  127. langroid-0.33.6.dist-info/entry_points.txt +0 -4
  128. pyproject.toml +0 -356
  129. {langroid-0.33.6.dist-info → langroid-0.33.8.dist-info}/licenses/LICENSE +0 -0
langroid/agent/base.py ADDED
@@ -0,0 +1,1983 @@
1
+ import asyncio
2
+ import copy
3
+ import inspect
4
+ import json
5
+ import logging
6
+ import re
7
+ from abc import ABC
8
+ from collections import OrderedDict
9
+ from contextlib import ExitStack
10
+ from types import SimpleNamespace
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Coroutine,
15
+ Dict,
16
+ List,
17
+ Optional,
18
+ Set,
19
+ Tuple,
20
+ Type,
21
+ TypeVar,
22
+ cast,
23
+ get_args,
24
+ get_origin,
25
+ no_type_check,
26
+ )
27
+
28
+ from rich import print
29
+ from rich.console import Console
30
+ from rich.markup import escape
31
+ from rich.prompt import Prompt
32
+
33
+ from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
34
+ from langroid.agent.tool_message import ToolMessage
35
+ from langroid.agent.xml_tool_message import XMLToolMessage
36
+ from langroid.exceptions import XMLException
37
+ from langroid.language_models.base import (
38
+ LanguageModel,
39
+ LLMConfig,
40
+ LLMFunctionCall,
41
+ LLMMessage,
42
+ LLMResponse,
43
+ LLMTokenUsage,
44
+ OpenAIToolCall,
45
+ StreamingIfAllowed,
46
+ ToolChoiceTypes,
47
+ )
48
+ from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
49
+ from langroid.mytypes import Entity
50
+ from langroid.parsing.parse_json import extract_top_level_json
51
+ from langroid.parsing.parser import Parser, ParsingConfig
52
+ from langroid.prompts.prompts_config import PromptsConfig
53
+ from langroid.pydantic_v1 import (
54
+ BaseSettings,
55
+ Field,
56
+ ValidationError,
57
+ validator,
58
+ )
59
+ from langroid.utils.configuration import settings
60
+ from langroid.utils.constants import (
61
+ DONE,
62
+ NO_ANSWER,
63
+ PASS,
64
+ PASS_TO,
65
+ SEND_TO,
66
+ )
67
+ from langroid.utils.object_registry import ObjectRegistry
68
+ from langroid.utils.output import status
69
+ from langroid.utils.types import from_string, to_string
70
+ from langroid.vector_store.base import VectorStore, VectorStoreConfig
71
+
72
+ ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
73
+ console = Console(quiet=settings.quiet)
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+ T = TypeVar("T")
78
+
79
+
80
class AgentConfig(BaseSettings):
    """
    General config settings for an LLM agent. This is nested, combining configs of
    various components.
    """

    # agent name; restricted by `check_name_alphanum` below
    name: str = "LLM-Agent"
    debug: bool = False
    # config of the vector store; None means the agent has no vector store
    vecdb: Optional[VectorStoreConfig] = None
    # config of the language model
    llm: Optional[LLMConfig] = OpenAIGPTConfig()
    # config of the parser; None means the agent has no parser
    parsing: Optional[ParsingConfig] = ParsingConfig()
    prompts: Optional[PromptsConfig] = PromptsConfig()
    show_stats: bool = True  # show token usage/cost stats?
    add_to_registry: bool = True  # register agent in ObjectRegistry?
    respond_tools_only: bool = False  # respond only to tool messages (not plain text)?
    # allow multiple tool messages in a single response?
    allow_multiple_tools: bool = True
    human_prompt: str = (
        "Human (respond or q, x to exit current level, or hit enter to continue)"
    )

    @validator("name")
    def check_name_alphanum(cls, v: str) -> str:
        """
        Validate that the agent name uses only letters, digits, `_` and `-`.

        Args:
            v (str): the candidate name.

        Returns:
            str: the name, unchanged, when it is valid.

        Raises:
            ValueError: if the name is empty or contains any other character.
        """
        # fullmatch (rather than match with `^...$`) so that a name ending in a
        # newline is rejected: `$` also matches just before a trailing "\n".
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", v):
            raise ValueError(
                "The name must only contain alphanumeric characters, "
                "underscores, or hyphens, with no spaces"
            )
        return v
109
+
110
+
111
def noop_fn(*args: Any, **kwargs: Any) -> None:
    """
    Accept any positional and keyword arguments and do nothing.

    Used as the default value of callbacks that have not been set.
    """
    # Each argument is typed `Any`: annotating `*args` as `List[Any]` would
    # declare every single positional argument to be a list.
    return None
113
+
114
+
115
async def async_noop_fn(*args: Any, **kwargs: Any) -> None:
    """
    Async counterpart of `noop_fn`: accept any arguments and do nothing.
    """
    # Each argument is typed `Any`: annotating `*args` as `List[Any]` would
    # declare every single positional argument to be a list.
    return None
117
+
118
+
119
async def async_lambda_noop_fn() -> Callable[..., Coroutine[Any, Any, None]]:
    """
    Await-able factory that yields the do-nothing coroutine function.

    Returns:
        Callable[..., Coroutine[Any, Any, None]]: `async_noop_fn` itself
            (the function object, not a call of it).
    """
    do_nothing = async_noop_fn
    return do_nothing
121
+
122
+
123
+ class Agent(ABC):
124
+ """
125
+ An Agent is an abstraction that encapsulates mainly two components:
126
+
127
+ - a language model (LLM)
128
+ - a vector store (vecdb)
129
+
130
+ plus associated components such as a parser, and variables that hold
131
+ information about any tool/function-calling messages that have been defined.
132
+ """
133
+
134
+ id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
135
+ # OpenAI tool-calls awaiting response; update when a tool result with Role.TOOL
136
+ # is added to self.message_history
137
+ oai_tool_calls: List[OpenAIToolCall] = []
138
+ # Index of ALL tool calls generated by the agent
139
+ oai_tool_id2call: Dict[str, OpenAIToolCall] = {}
140
+
141
def __init__(self, config: Optional[AgentConfig] = None):
    """
    Build the agent's components and bookkeeping state from `config`.

    Args:
        config (Optional[AgentConfig]): settings for this agent. When omitted,
            a fresh `AgentConfig()` is created for this instance.
    """
    # A default of `AgentConfig()` in the signature would be evaluated once and
    # shared by all agents created without a config; since the config is
    # mutated below (token_encoding_model), each agent gets its own instead.
    if config is None:
        config = AgentConfig()
    self.config = config
    self.lock = asyncio.Lock()  # for async access to update self.llm.usage_cost
    self.dialog: List[Tuple[str, str]] = []  # seq of LLM (prompt, response) tuples
    self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
    self.llm_tools_handled: Set[str] = set()
    self.llm_tools_usable: Set[str] = set()
    self.llm_tools_known: Set[str] = set()  # all known tools, handled/used or not
    # Indicates which tool-names are allowed to be inferred when
    # the LLM "forgets" to include the request field in its
    # tool-call.
    self.enabled_requests_for_inference: Optional[Set[str]] = (
        None  # If None, we allow all
    )
    self.interactive: bool = True  # may be modified by Task wrapper
    self.token_stats_str = ""
    self.default_human_response: Optional[str] = None
    self._indent = ""
    self.llm = LanguageModel.create(config.llm)
    self.vecdb = VectorStore.create(config.vecdb) if config.vecdb else None
    self.tool_error = False
    if config.parsing is not None and self.config.llm is not None:
        # token_encoding_model is used to obtain the tokenizer,
        # so in case it's an OpenAI model, we ensure that the tokenizer
        # corresponding to the model is used.
        if isinstance(self.llm, OpenAIGPT) and self.llm.is_openai_chat_model():
            config.parsing.token_encoding_model = self.llm.config.chat_model
    self.parser: Optional[Parser] = (
        Parser(config.parsing) if config.parsing else None
    )
    if config.add_to_registry:
        ObjectRegistry.register_object(self)

    # Default callbacks are no-ops; the two `get_user_response*` hooks default
    # to None.
    self.callbacks = SimpleNamespace(
        start_llm_stream=lambda: noop_fn,
        start_llm_stream_async=async_lambda_noop_fn,
        cancel_llm_stream=noop_fn,
        finish_llm_stream=noop_fn,
        show_llm_response=noop_fn,
        show_agent_response=noop_fn,
        get_user_response=None,
        get_user_response_async=None,
        get_last_step=noop_fn,
        set_parent_agent=noop_fn,
        show_error_message=noop_fn,
        show_start_response=noop_fn,
    )
    # Explicitly call this class's `init_state`, not a subclass override.
    Agent.init_state(self)
189
+
190
def init_state(self) -> None:
    """
    Reset the agent's state variables: the running LLM token cost and the
    running LLM token usage both go back to zero.
    Called by Task.run() if restart is True.
    """
    self.total_llm_token_cost, self.total_llm_token_usage = 0.0, 0
194
+
195
@staticmethod
def from_id(id: str) -> "Agent":
    """
    Fetch the object registered under `id` in the ObjectRegistry.

    The result is only `cast` to Agent for type-checkers; no runtime check
    of its type is made here.
    """
    registered = ObjectRegistry.get(id)
    return cast(Agent, registered)
198
+
199
@staticmethod
def delete_id(id: str) -> None:
    """Remove the entry registered under `id` from the ObjectRegistry."""
    ObjectRegistry.remove(id)
    return None
202
+
203
def entity_responders(
    self,
) -> List[
    Tuple[Entity, Callable[[None | str | ChatDocument], None | ChatDocument]]
]:
    """
    Ordered (entity, response_method) pairs: AGENT, then LLM, then USER.
    This sequence is used in a `Task` to respond to the current pending
    message. See `Task.step()` for details.

    Returns:
        Sequence of (entity, response_method) pairs.
    """
    entities = (Entity.AGENT, Entity.LLM, Entity.USER)
    responders = (self.agent_response, self.llm_response, self.user_response)
    return list(zip(entities, responders))
220
+
221
def entity_responders_async(
    self,
) -> List[
    Tuple[
        Entity,
        Callable[
            [None | str | ChatDocument], Coroutine[Any, Any, None | ChatDocument]
        ],
    ]
]:
    """
    Async version of `entity_responders`: the same entity order
    (AGENT, LLM, USER), paired with the `*_async` response methods.
    """
    entities = (Entity.AGENT, Entity.LLM, Entity.USER)
    responders = (
        self.agent_response_async,
        self.llm_response_async,
        self.user_response_async,
    )
    return list(zip(entities, responders))
239
+
240
@property
def indent(self) -> str:
    """Indentation to print before any responses from the agent's entities."""
    current = self._indent
    return current

@indent.setter
def indent(self, value: str) -> None:
    """Store `value` as the indentation string returned by `indent`."""
    self._indent = value
248
+
249
def update_dialog(self, prompt: str, output: str) -> None:
    """Append the (prompt, output) pair to the end of `self.dialog`."""
    exchange = (prompt, output)
    self.dialog.append(exchange)
251
+
252
def get_dialog(self) -> List[Tuple[str, str]]:
    """Return `self.dialog` itself (not a copy)."""
    history = self.dialog
    return history
254
+
255
def clear_dialog(self) -> None:
    """
    Point `self.dialog` at a new empty list.

    The previous list object is left untouched, so references obtained
    earlier keep their contents.
    """
    self.dialog = list()
257
+
258
def _get_tool_list(
    self, message_class: Optional[Type[ToolMessage]] = None
) -> List[str]:
    """
    Return tool names, registering `message_class` as a side effect.

    With no `message_class`, simply list every tool name in
    `self.llm_tools_map`. Otherwise record the class in `self.llm_tools_map`
    under its tool name (the default value of its `request` field), attach
    handler attributes to this agent for any handler-style methods the class
    defines (see comments below), and return a one-element list holding
    that tool name.

    Args:
        message_class (Optional[Type[ToolMessage]]): The message class whose tool
            name is to be returned; Optional, default is None.
            if None, return a list of all known tool names.

    Returns:
        List[str]: List of tool names: either just the tool name corresponding
            to the message class, or all known tool names
            (when `message_class` is None).

    Raises:
        ValueError: if `message_class` is not a subclass of ToolMessage.
    """
    if message_class is None:
        return list(self.llm_tools_map.keys())

    if not issubclass(message_class, ToolMessage):
        raise ValueError("message_class must be a subclass of ToolMessage")

    def is_plain_function(attr: str) -> bool:
        # True when the class exposes `attr` as a plain Python function
        return hasattr(message_class, attr) and inspect.isfunction(
            getattr(message_class, attr)
        )

    def param_count(attr: str) -> int:
        # number of parameters of the class-level function (so `self` counts)
        return len(inspect.signature(getattr(message_class, attr)).parameters)

    tool = message_class.default_value("request")

    # if tool has handler method explicitly defined - use it,
    # otherwise use the tool name as the handler
    handler_name = getattr(message_class, "_handler", tool)

    self.llm_tools_map[tool] = message_class

    if is_plain_function("handle") and not hasattr(self, handler_name):
        # If the message class has a `handle` method,
        # and agent does NOT have a tool handler method,
        # then we create a method for the agent whose name
        # is the value of `handler_name`, and whose body is the `handle` method.
        # This removes a separate step of having to define this method
        # for the agent, and also keeps the tool definition AND handling
        # in one place, i.e. in the message class.
        # See `tests/main/test_stateless_tool_messages.py` for an example.
        if param_count("handle") > 1:  # more than `self` => takes chat_doc
            setattr(self, handler_name, lambda obj, chat_doc: obj.handle(chat_doc))
        else:
            setattr(self, handler_name, lambda obj: obj.handle())
    elif is_plain_function("response") and not hasattr(self, handler_name):
        # `response` also receives the agent, so chat_doc is the 3rd parameter
        if param_count("response") > 2:
            setattr(
                self, handler_name, lambda obj, chat_doc: obj.response(self, chat_doc)
            )
        else:
            setattr(self, handler_name, lambda obj: obj.response(self))

    if is_plain_function("handle_message_fallback"):
        setattr(
            self,
            "handle_message_fallback",
            lambda msg: message_class.handle_message_fallback(self, msg),
        )

    async_handler_name = f"{handler_name}_async"
    if is_plain_function("handle_async") and not hasattr(self, async_handler_name):
        if param_count("handle_async") > 1:

            @no_type_check
            async def handler(obj, chat_doc):
                return await obj.handle_async(chat_doc)

        else:

            @no_type_check
            async def handler(obj):
                return await obj.handle_async()

        setattr(self, async_handler_name, handler)
    elif is_plain_function("response_async") and not hasattr(
        self, async_handler_name
    ):
        if param_count("response_async") > 2:

            @no_type_check
            async def handler(obj, chat_doc):
                return await obj.response_async(self, chat_doc)

        else:

            @no_type_check
            async def handler(obj):
                return await obj.response_async(self)

        setattr(self, async_handler_name, handler)

    return [tool]
389
+
390
def enable_message_handling(
    self, message_class: Optional[Type[ToolMessage]] = None
) -> None:
    """
    Enable an agent to RESPOND (i.e. handle) a "tool" message of a specific type
    from LLM. Also "registers" (i.e. adds) the `message_class` to the
    `self.llm_tools_map` dict (this happens inside `_get_tool_list`).

    Args:
        message_class (Optional[Type[ToolMessage]]): The message class to enable;
            Optional; if None, all known message classes are enabled for handling.
    """
    tool_names = self._get_tool_list(message_class)
    self.llm_tools_handled.update(tool_names)
405
+
406
def disable_message_handling(
    self,
    message_class: Optional[Type[ToolMessage]] = None,
) -> None:
    """
    Stop this Agent from handling a message class, by dropping its tool name
    from `self.llm_tools_handled`. Names not in the set are ignored.

    Args:
        message_class (Optional[Type[ToolMessage]]): The message class to disable.
            If None, all message classes are disabled.
    """
    tool_names = self._get_tool_list(message_class)
    self.llm_tools_handled.difference_update(tool_names)
419
+
420
def sample_multi_round_dialog(self) -> str:
    """
    Build a sample multi-round dialog from the message classes registered in
    `self.llm_tools_map`.

    Returns:
        str: The usage examples of the first (at most) two classes, joined by
            a blank line; empty string when no class is registered.
    """
    # use at most 2 sample conversations, no need to be exhaustive;
    first_two = list(self.llm_tools_map.values())[:2]
    examples = []
    for tool_cls in first_two:
        examples.append(tool_cls().usage_examples(random=True))  # type: ignore
    return "\n\n".join(examples)
434
+
435
def create_agent_response(
    self,
    content: str | None = None,
    content_any: Any = None,
    tool_messages: List[ToolMessage] = [],
    oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
    oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
    oai_tool_id2result: OrderedDict[str, str] | None = None,
    function_call: LLMFunctionCall | None = None,
    recipient: str = "",
) -> ChatDocument:
    """
    Template for agent_response: forwards every argument unchanged to
    `response_template`, with the entity fixed to `Entity.AGENT`.
    """
    fields = dict(
        content=content,
        content_any=content_any,
        tool_messages=tool_messages,
        oai_tool_calls=oai_tool_calls,
        oai_tool_choice=oai_tool_choice,
        oai_tool_id2result=oai_tool_id2result,
        function_call=function_call,
        recipient=recipient,
    )
    return self.response_template(Entity.AGENT, **fields)
458
+
459
def _agent_response_final(
    self,
    msg: Optional[str | ChatDocument],
    results: Optional[str | OrderedDict[str, str] | ChatDocument],
) -> Optional[ChatDocument]:
    """
    Convert results to final response.

    Args:
        msg: the message that was handled (source of tool ids, function-call
            name and OpenAI tool calls when it is a ChatDocument).
        results: the outcome of handling `msg`.

    Returns:
        None when `results` is None; `results` itself (with its
        `metadata.tool_ids` overwritten) when it is already a ChatDocument;
        otherwise a new agent-sourced ChatDocument built from the output of
        `process_tool_results`.
    """
    if results is None:
        return None

    # Text form of the results, for display only.
    if isinstance(results, str):
        results_str = results
    elif isinstance(results, ChatDocument):
        results_str = results.content
    elif isinstance(results, dict):
        results_str = json.dumps(results, indent=2)

    if not settings.quiet:
        console.print(f"[red]{self.indent}", end="")
        print(f"[red]Agent: {escape(results_str)}")
        looks_like_json = len(extract_top_level_json(results_str)) > 0
        language = "json" if looks_like_json else "text"
        self.callbacks.show_agent_response(content=results_str, language=language)

    def inherited_tool_ids() -> Any:
        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
        return [] if msg is None or isinstance(msg, str) else msg.metadata.tool_ids

    if isinstance(results, ChatDocument):
        results.metadata.tool_ids = inherited_tool_ids()
        return results

    msg_is_doc = isinstance(msg, ChatDocument)
    if msg_is_doc and msg.function_call is not None:
        # if result was from handling an LLM `function_call`,
        # set sender_name to name of the function_call
        sender_name = msg.function_call.name
    else:
        sender_name = self.config.name

    plain_text = isinstance(results, str)
    results_str, id2result, oai_tool_id = self.process_tool_results(
        results if plain_text else "",
        id2result=None if plain_text else results,
        tool_calls=(msg.oai_tool_calls if msg_is_doc else None),
    )
    metadata = ChatDocMetaData(
        source=Entity.AGENT,
        sender=Entity.AGENT,
        sender_name=sender_name,
        oai_tool_id=oai_tool_id,
        tool_ids=inherited_tool_ids(),
    )
    return ChatDocument(
        content=results_str,
        oai_tool_id2result=id2result,
        metadata=metadata,
    )
514
+
515
async def agent_response_async(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """
    Asynch version of `agent_response`. See there for details.
    Awaits `handle_message_async` and packages its result with
    `_agent_response_final`; returns None when `msg` is None.
    """
    if msg is not None:
        handled = await self.handle_message_async(msg)
        return self._agent_response_final(msg, handled)
    return None
528
+
529
def agent_response(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """
    Response from the "agent itself", typically (but not only)
    used to handle LLM's "tool message" or `function_call`
    (e.g. OpenAI `function_call`).

    Args:
        msg (str|ChatDocument): the input to respond to: if msg is a string,
            and it contains a valid JSON-structured "tool message", or
            if msg is a ChatDocument, and it contains a `function_call`.

    Returns:
        Optional[ChatDocument]: the response, packaged as a ChatDocument;
            None when `msg` is None.
    """
    if msg is not None:
        handled = self.handle_message(msg)
        return self._agent_response_final(msg, handled)
    return None
551
+
552
def process_tool_results(
    self,
    results: str,
    id2result: OrderedDict[str, str] | None,
    tool_calls: List[OpenAIToolCall] | None = None,
) -> Tuple[str, Dict[str, str] | None, str | None]:
    """
    Process results from a response, based on whether
    they are results of OpenAI tool-calls from THIS agent, so that
    we can construct an appropriate LLMMessage that contains tool results.

    Args:
        results (str): A possible string result from handling tool(s)
        id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
            if there are multiple tool results.
        tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
            results are a response to.

    Returns:
        - str: The response string
        - Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
            multiple tool results.
        - str|None: tool_id if there was a single tool result

    Raises:
        AssertionError: if `results` is non-empty while `id2result` is given,
            if a single string result faces more than one pending tool-call,
            if the result count equals the pending count but is not > 1, or
            if `tool_calls` is None when it is needed to resolve `id2result`.
    """
    # Work on a copy so that the caller's dict is never mutated.
    id2result_ = copy.deepcopy(id2result) if id2result is not None else None
    results_str = ""
    oai_tool_id = None

    if results != "":
        # in this case ignore id2result
        assert (
            id2result is None
        ), "id2result should be None when results string is non-empty!"
        results_str = results
        if len(self.oai_tool_calls) > 0:
            # We only have one result, so in case there is a
            # "pending" OpenAI tool-call, we expect no more than 1 such.
            assert (
                len(self.oai_tool_calls) == 1
            ), "There are multiple pending tool-calls, but only one result!"
            # We record the tool_id of the tool-call that
            # the result is a response to, so that ChatDocument.to_LLMMessage
            # can properly set the `tool_call_id` field of the LLMMessage.
            oai_tool_id = self.oai_tool_calls[0].id
    elif id2result is not None and id2result_ is not None:  # appease mypy
        if len(id2result_) == len(self.oai_tool_calls):
            # if the number of pending tool calls equals the number of results,
            # then ignore the ids in id2result, and use the results in order,
            # which is preserved since id2result is an OrderedDict.
            assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
            results_str = ""
            id2result_ = OrderedDict(
                zip(
                    [tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
                )
            )
        else:
            assert (
                tool_calls is not None
            ), "tool_calls cannot be None when id2result is not None!"
            # This must be an OpenAI tool id -> result map;
            # However some ids may not correspond to the tool-calls in the list of
            # pending tool-calls (self.oai_tool_calls).
            # Such results are concatenated into a simple string, to store in the
            # ChatDocument.content, and the rest
            # (i.e. those that DO correspond to tools in self.oai_tool_calls)
            # are stored as a dict in ChatDocument.oai_tool_id2result.

            # OAI tools from THIS agent, awaiting response
            pending_tool_ids = {tc.id for tc in self.oai_tool_calls}
            # tool_calls that the results are a response to
            # (but these may have been sent from another agent, hence may not be in
            # self.oai_tool_calls)
            parent_tool_id2name = {
                tc.id: tc.function.name
                for tc in tool_calls or []  # `or []` guards runs with asserts off
                if tc.function is not None
            }

            # (id, result) for result NOT corresponding to self.oai_tool_calls,
            # i.e. these are results of EXTERNAL tool-calls from another agent.
            external_tool_id_results = [
                (tc_id, result)
                for tc_id, result in id2result.items()
                if tc_id not in pending_tool_ids
            ]
            for tc_id, _ in external_tool_id_results:
                id2result_.pop(tc_id)

            if len(external_tool_id_results) == 0:
                results_str = ""
            elif len(external_tool_id_results) == 1:
                results_str = external_tool_id_results[0][1]
            else:
                # An id may be missing from parent_tool_id2name (it is absent
                # from `tool_calls`, or its call has no `function`); label the
                # result with the id itself rather than raising KeyError.
                results_str = "\n\n".join(
                    [
                        f"Result from tool/function "
                        f"{parent_tool_id2name.get(tc_id, tc_id)}: {result}"
                        for tc_id, result in external_tool_id_results
                    ]
                )

            if len(id2result_) == 0:
                id2result_ = None
            elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
                # a single result for one of our own pending tool-calls:
                # report it as (string, tool id) instead of a dict
                results_str = list(id2result_.values())[0]
                oai_tool_id = list(id2result_.keys())[0]
                id2result_ = None

    return results_str, id2result_, oai_tool_id
661
+
662
def response_template(
    self,
    e: Entity,
    content: str | None = None,
    content_any: Any = None,
    tool_messages: List[ToolMessage] = [],
    oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
    oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
    oai_tool_id2result: OrderedDict[str, str] | None = None,
    function_call: LLMFunctionCall | None = None,
    recipient: str = "",
) -> ChatDocument:
    """
    Template for response from entity `e`.

    Builds a ChatDocument whose metadata names `e` as both source and sender,
    uses this agent's configured name as `sender_name`, and carries
    `recipient`. A None (or otherwise falsy) `content` becomes "".
    """
    text = content or ""
    metadata = ChatDocMetaData(
        source=e, sender=e, sender_name=self.config.name, recipient=recipient
    )
    return ChatDocument(
        content=text,
        content_any=content_any,
        tool_messages=tool_messages,
        oai_tool_calls=oai_tool_calls,
        oai_tool_id2result=oai_tool_id2result,
        function_call=function_call,
        oai_tool_choice=oai_tool_choice,
        metadata=metadata,
    )
687
+
688
def create_user_response(
    self,
    content: str | None = None,
    content_any: Any = None,
    tool_messages: List[ToolMessage] = [],
    oai_tool_calls: List[OpenAIToolCall] | None = None,
    oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
    oai_tool_id2result: OrderedDict[str, str] | None = None,
    function_call: LLMFunctionCall | None = None,
    recipient: str = "",
) -> ChatDocument:
    """
    Template for user_response: forwards every argument unchanged to
    `response_template`, with the entity fixed to `Entity.USER`.
    """
    fields = dict(
        content=content,
        content_any=content_any,
        tool_messages=tool_messages,
        oai_tool_calls=oai_tool_calls,
        oai_tool_choice=oai_tool_choice,
        oai_tool_id2result=oai_tool_id2result,
        function_call=function_call,
        recipient=recipient,
    )
    return self.response_template(e=Entity.USER, **fields)
711
+
712
def user_can_respond(self, msg: Optional[str | ChatDocument] = None) -> bool:
    """
    Whether the user can respond to a message.

    Args:
        msg (str|ChatDocument): the message to respond to.

    Returns:
        bool: False only when the agent is not interactive AND `msg` is not a
            ChatDocument whose recipient is the user; True otherwise.
    """
    # When msg explicitly addressed to user, this means an actual human response
    # is being sought.
    addressed_to_user = (
        isinstance(msg, ChatDocument) and msg.metadata.recipient == Entity.USER
    )
    return bool(self.interactive or addressed_to_user)
732
+
733
+ def _user_response_final(
734
+ self, msg: Optional[str | ChatDocument], user_msg: str
735
+ ) -> Optional[ChatDocument]:
736
+ """
737
+ Convert user_msg to final response.
738
+ """
739
+ if not user_msg:
740
+ need_human_response = (
741
+ isinstance(msg, ChatDocument) and msg.metadata.recipient == Entity.USER
742
+ )
743
+ user_msg = (
744
+ (self.default_human_response or "null") if need_human_response else ""
745
+ )
746
+ user_msg = user_msg.strip()
747
+
748
+ tool_ids = []
749
+ if msg is not None and isinstance(msg, ChatDocument):
750
+ tool_ids = msg.metadata.tool_ids
751
+
752
+ # only return non-None result if user_msg not empty
753
+ if not user_msg:
754
+ return None
755
+ else:
756
+ if user_msg.startswith("SYSTEM"):
757
+ user_msg = user_msg.replace("SYSTEM", "").strip()
758
+ source = Entity.SYSTEM
759
+ sender = Entity.SYSTEM
760
+ else:
761
+ source = Entity.USER
762
+ sender = Entity.USER
763
+ return ChatDocument(
764
+ content=user_msg,
765
+ metadata=ChatDocMetaData(
766
+ source=source,
767
+ sender=sender,
768
+ # preserve trail of tool_ids for OpenAI Assistant fn-calls
769
+ tool_ids=tool_ids,
770
+ ),
771
+ )
772
+
773
async def user_response_async(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """
    Asynch version of `user_response`. See there for details.

    The user text is taken, in order of preference, from:
    `default_human_response`, the async user-response callback (when it
    is set and is not the no-op placeholder), the sync user-response
    callback, or an interactive prompt.
    """
    if not self.user_can_respond(msg):
        return None

    if self.default_human_response is not None:
        return self._user_response_final(msg, self.default_human_response)

    has_async_callback = (
        self.callbacks.get_user_response_async is not None
        and self.callbacks.get_user_response_async is not async_noop_fn
    )
    if has_async_callback:
        user_msg = await self.callbacks.get_user_response_async(prompt="")
    elif self.callbacks.get_user_response is not None:
        user_msg = self.callbacks.get_user_response(prompt="")
    else:
        prompt_text = (
            f"[blue]{self.indent}" + self.config.human_prompt + f"\n{self.indent}"
        )
        user_msg = Prompt.ask(prompt_text)

    return self._user_response_final(msg, user_msg)
+
802
def user_response(
    self,
    msg: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """
    Get user response to current message. Could allow (human) user to intervene
    with an actual answer, or quit using "q" or "x"

    Args:
        msg (str|ChatDocument): the string to respond to.

    Returns:
        (str) User response, packaged as a ChatDocument; None when the user
        cannot respond or the response text is empty.
    """
    if not self.user_can_respond(msg):
        return None

    if self.default_human_response is not None:
        return self._user_response_final(msg, self.default_human_response)

    if self.callbacks.get_user_response is not None:
        # Ask with an empty prompt: the user has already seen the
        # conversation so far. A non-empty prompt can still be useful when
        # the Agent uses a tool that requires user input, or in other
        # scenarios.
        user_msg = self.callbacks.get_user_response(prompt="")
    else:
        prompt_text = (
            f"[blue]{self.indent}" + self.config.human_prompt + f"\n{self.indent}"
        )
        user_msg = Prompt.ask(prompt_text)

    return self._user_response_final(msg, user_msg)
+
839
@no_type_check
def llm_can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
    """
    Whether the LLM can respond to a message.

    Args:
        message (str|ChatDocument): message or ChatDocument object to respond to.

    Returns:
        bool: False when the agent has no LLM, or when `message` contains
            at least one tool message recognized via
            `try_get_tool_messages`; True otherwise.
    """
    if self.llm is None:
        return False
    if message is None:
        return True
    # A valid "tool" message (either JSON or via `function_call`) is not
    # something the LLM responds to.
    return not self.try_get_tool_messages(message)
+
859
def can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
    """
    Whether the agent can respond to a message.
    Used in Task.py to skip a sub-task when we know it would not respond.

    Args:
        message (str|ChatDocument): message or ChatDocument object to respond to.

    Returns:
        bool: False when the agent is configured to respond to tools only
            and the message has none, or when every tool in the message is
            one this agent is not enabled to handle; True otherwise.
    """
    found_tools = self.try_get_tool_messages(message)
    if not found_tools and self.config.respond_tools_only:
        return False
    # If the message has tools and NONE of them are enabled for handling by
    # this agent, the agent cannot respond to it.
    blocked = message is not None and self.has_only_unhandled_tools(message)
    return not blocked
+
875
def create_llm_response(
    self,
    content: str | None = None,
    content_any: Any = None,
    tool_messages: List[ToolMessage] = [],
    oai_tool_calls: None | List[OpenAIToolCall] = None,
    oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
    oai_tool_id2result: OrderedDict[str, str] | None = None,
    function_call: LLMFunctionCall | None = None,
    recipient: str = "",
) -> ChatDocument:
    """
    Build a ChatDocument attributed to `Entity.LLM`.

    All arguments are forwarded unchanged to `response_template`.
    """
    passthrough = dict(
        content=content,
        content_any=content_any,
        tool_messages=tool_messages,
        oai_tool_calls=oai_tool_calls,
        oai_tool_choice=oai_tool_choice,
        oai_tool_id2result=oai_tool_id2result,
        function_call=function_call,
        recipient=recipient,
    )
    return self.response_template(Entity.LLM, **passthrough)
+
899
@no_type_check
async def llm_response_async(
    self,
    message: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """
    Asynch version of `llm_response`. See there for details.

    Args:
        message (str|ChatDocument): prompt string, or ChatDocument object

    Returns:
        Response from LLM, packaged as a ChatDocument, or None when there
        is no message or the LLM cannot respond to it.

    Raises:
        ValueError: if, after shrinking the output budget to fit the
            completion context length, fewer than
            `config.llm.min_output_tokens` tokens remain.
    """
    if message is None or not self.llm_can_respond(message):
        return None

    if isinstance(message, ChatDocument):
        prompt = message.content
    else:
        prompt = message

    output_len = self.config.llm.max_output_tokens
    if self.num_tokens(prompt) + output_len > self.llm.completion_context_length():
        # Shrink the output budget so that prompt + output fits the context.
        output_len = self.llm.completion_context_length() - self.num_tokens(prompt)
        if output_len < self.config.llm.min_output_tokens:
            raise ValueError(
                """
                Token-length of Prompt + Output is longer than the
                completion context length of the LLM!
                """
            )
        else:
            logger.warning(
                f"""
                Requested output length has been shortened to {output_len}
                so that the total length of Prompt + Output is less than
                the completion context length of the LLM.
                """
            )

    with StreamingIfAllowed(self.llm, self.llm.get_stream()):
        response = await self.llm.agenerate(prompt, output_len)

    # Parenthesized on purpose: `and` binds tighter than `or`, so without
    # the parentheses a non-streamed response was printed even in quiet mode.
    if (not self.llm.get_stream() or response.cached) and not settings.quiet:
        # We would have already displayed the msg "live" ONLY if
        # streaming was enabled, AND we did not find a cached response.
        # If we are here, it means the response has not yet been displayed.
        cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
        print(cached + "[green]" + escape(response.message))
    async with self.lock:
        self.update_token_usage(
            response,
            prompt,
            self.llm.get_stream(),
            chat=False,  # i.e. it's a completion model not chat model
            print_response_stats=self.config.show_stats and not settings.quiet,
        )
    cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
    # Preserve trail of tool_ids for OpenAI Assistant fn-calls
    cdoc.metadata.tool_ids = (
        [] if isinstance(message, str) else message.metadata.tool_ids
    )
    return cdoc
+
958
@no_type_check
def llm_response(
    self,
    message: Optional[str | ChatDocument] = None,
) -> Optional[ChatDocument]:
    """
    LLM response to a prompt.

    Args:
        message (str|ChatDocument): prompt string, or ChatDocument object

    Returns:
        Response from LLM, packaged as a ChatDocument, or None when there
        is no message or the LLM cannot respond to it.

    Raises:
        ValueError: if, after shrinking the output budget to fit the
            completion context length, fewer than
            `config.llm.min_output_tokens` tokens remain.
    """
    if message is None or not self.llm_can_respond(message):
        return None

    if isinstance(message, ChatDocument):
        prompt = message.content
    else:
        prompt = message

    with ExitStack() as stack:  # for conditionally using rich spinner
        if not self.llm.get_stream():
            # show rich spinner only if not streaming!
            cm = status("LLM responding to message...")
            stack.enter_context(cm)
        output_len = self.config.llm.max_output_tokens
        if (
            self.num_tokens(prompt) + output_len
            > self.llm.completion_context_length()
        ):
            # Shrink the output budget so that prompt + output fits the context.
            output_len = self.llm.completion_context_length() - self.num_tokens(
                prompt
            )
            if output_len < self.config.llm.min_output_tokens:
                raise ValueError(
                    """
                    Token-length of Prompt + Output is longer than the
                    completion context length of the LLM!
                    """
                )
            else:
                logger.warning(
                    f"""
                    Requested output length has been shortened to {output_len}
                    so that the total length of Prompt + Output is less than
                    the completion context length of the LLM.
                    """
                )
        if self.llm.get_stream() and not settings.quiet:
            console.print(f"[green]{self.indent}", end="")
        response = self.llm.generate(prompt, output_len)

    # Parenthesized on purpose: `and` binds tighter than `or`, so without
    # the parentheses a non-streamed response was printed even in quiet mode.
    if (not self.llm.get_stream() or response.cached) and not settings.quiet:
        # we would have already displayed the msg "live" ONLY if
        # streaming was enabled, AND we did not find a cached response
        # If we are here, it means the response has not yet been displayed.
        cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
        console.print(f"[green]{self.indent}", end="")
        print(cached + "[green]" + escape(response.message))
    self.update_token_usage(
        response,
        prompt,
        self.llm.get_stream(),
        chat=False,  # i.e. it's a completion model not chat model
        print_response_stats=self.config.show_stats and not settings.quiet,
    )
    cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
    # Preserve trail of tool_ids for OpenAI Assistant fn-calls
    cdoc.metadata.tool_ids = (
        [] if isinstance(message, str) else message.metadata.tool_ids
    )
    return cdoc
+
1032
def has_tool_message_attempt(self, msg: str | ChatDocument | None) -> bool:
    """
    Check whether msg contains a Tool/fn-call attempt (by the LLM).

    CAUTION: This uses self.get_tool_messages(msg) which as a side-effect
    may update msg.tool_messages when msg is a ChatDocument, if there are
    any tools in msg.

    Args:
        msg (str|ChatDocument|None): the message to inspect.

    Returns:
        bool: True if at least one tool was extracted, or if extraction
            raised ValidationError/XMLException (a malformed attempt still
            counts as an attempt); False if `msg` is None or has no tools.
    """
    if msg is None:
        return False
    try:
        tools = self.get_tool_messages(msg)
    except (ValidationError, XMLException):
        # there is a tool/fn-call attempt but had a validation error,
        # so we still consider this a tool message "attempt"
        return True
    return len(tools) > 0
+
1051
+ def _tool_recipient_match(self, tool: ToolMessage) -> bool:
1052
+ """Is tool is handled by this agent
1053
+ and an explicit `recipient` field doesn't preclude this agent from handling it?
1054
+ """
1055
+ if tool.default_value("request") not in self.llm_tools_handled:
1056
+ return False
1057
+ if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
1058
+ return tool.recipient == "" or tool.recipient == self.config.name
1059
+ return True
1060
+
1061
def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
    """
    Does the msg have at least one tool, and ALL tools are
    disabled for handling by this agent?

    Returns:
        bool: False when `msg` is None or contains no recognized tools;
            otherwise True only if no tool passes `_tool_recipient_match`.
    """
    if msg is None:
        return False
    every_tool = self.try_get_tool_messages(msg, all_tools=True)
    if not every_tool:
        return False
    return not any(self._tool_recipient_match(t) for t in every_tool)
+
1073
def try_get_tool_messages(
    self,
    msg: str | ChatDocument | None,
    all_tools: bool = False,
) -> List[ToolMessage]:
    """
    Non-raising variant of `get_tool_messages`.

    Returns:
        List[ToolMessage]: whatever `get_tool_messages(msg, all_tools)`
            returns, or an empty list if it raises ValidationError or
            XMLException.
    """
    try:
        found = self.get_tool_messages(msg, all_tools)
    except (ValidationError, XMLException):
        found = []
    return found
+
1083
def get_tool_messages(
    self,
    msg: str | ChatDocument | None,
    all_tools: bool = False,
) -> List[ToolMessage]:
    """
    Get ToolMessages recognized in msg, handle-able by this agent.

    NOTE: as a side-effect, when msg is a ChatDocument that has no cached
    `tool_messages` yet, this stores the freshly extracted tools on it:
    `msg.all_tool_messages` gets every recognized tool and
    `msg.tool_messages` gets only those passing `_tool_recipient_match`.

    Args:
        msg (str|ChatDocument): the message to extract tools from.
        all_tools (bool):
            - if True, return all tools,
                i.e. any recognized tool in self.llm_tools_known,
                whether it is handled by this agent or not;
            - otherwise, return only the tools handled by this agent.

    Returns:
        List[ToolMessage]: list of ToolMessage objects
    """
    if msg is None:
        return []

    if isinstance(msg, str):
        parsed = self.get_formatted_tool_messages(msg)
        if all_tools:
            return parsed
        return [
            t
            for t in parsed
            if self._tool_recipient_match(t) and t.default_value("request")
        ]

    if all_tools and len(msg.all_tool_messages) > 0:
        # all_tool_messages were already identified in the msg; reuse them
        return msg.all_tool_messages
    if len(msg.tool_messages) > 0:
        # tool_messages were already found (either via OpenAI Fn-call or
        # Langroid-native ToolMessage), or were added by an agent_response.
        # These could come from a msg forwarded by another agent, so (unless
        # all_tools is set) return ONLY those THIS agent is enabled to handle.
        if all_tools:
            return msg.tool_messages
        return [t for t in msg.tool_messages if self._tool_recipient_match(t)]
    assert isinstance(msg, ChatDocument)

    def record(found: List[ToolMessage]) -> List[ToolMessage]:
        # Cache `found` on msg and return the subset this agent handles.
        msg.all_tool_messages = found
        handled = [t for t in found if self._tool_recipient_match(t)]
        msg.tool_messages = handled
        return handled

    if (
        msg.content != ""
        and msg.oai_tool_calls is None
        and msg.function_call is None
    ):
        # no API-level tool/fn calls: look for tools formatted in the text
        tools = self.get_formatted_tool_messages(msg.content)
        my_tools = record(tools)
        return tools if all_tools else my_tools

    # otherwise, we look for `tool_calls` (possibly multiple)
    tools = self.get_oai_tool_calls_classes(msg)
    my_tools = record(tools)

    if len(tools) == 0:
        # otherwise, we look for a `function_call`
        fun_call_cls = self.get_function_call_class(msg)
        tools = [fun_call_cls] if fun_call_cls is not None else []
        my_tools = record(tools)

    return tools if all_tools else my_tools
+
1171
def get_formatted_tool_messages(self, input_str: str) -> List[ToolMessage]:
    """
    Returns ToolMessage objects (tools) corresponding to
    tool-formatted substrings, if any.
    ASSUMPTION - These tools are either ALL JSON-based, or ALL XML-based
    (i.e. not a mix of both).
    Terminology: a "formatted tool msg" is one which the LLM generates as
    part of its raw string output, rather than within a JSON object
    in the API response (i.e. this method does not extract tools/fns returned
    by OpenAI's tools/fns API or similar APIs).

    Side effect: resets `self.tool_error` to False on entry, and again
    whenever at least one candidate parses into a tool.

    Args:
        input_str (str): input string, typically a message sent by an LLM

    Returns:
        List[ToolMessage]: list of ToolMessage objects
    """
    self.tool_error = False
    # XML candidates take priority; fall back to top-level JSON only when
    # no XML candidate is present.
    candidates = XMLToolMessage.find_candidates(input_str)
    is_json = len(candidates) == 0
    if is_json:
        candidates = extract_top_level_json(input_str)
        if len(candidates) == 0:
            return []

    parsed = []
    for candidate in candidates:
        tool = self._get_one_tool_message(candidate, is_json)
        if tool is not None:
            parsed.append(tool)
    # If any tool is correctly formed we do not set the flag
    if parsed:
        self.tool_error = False
    return parsed
+
1204
def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
    """
    From ChatDocument (constructed from an LLM Response), get the `ToolMessage`
    corresponding to the `function_call` if it exists.

    Side effects: sets `self.tool_error` (True for a name that is not a
    known tool, False once a handled tool name is found), and adds a
    "request" key to the function-call's arguments dict before parsing.

    Returns:
        Optional[ToolMessage]: the parsed tool, or None when there is no
            `function_call` or its name is not handled by this agent.
    """
    call = msg.function_call
    if call is None:
        return None
    tool_name = call.name
    tool_msg = call.arguments or {}
    if tool_name not in self.llm_tools_handled:
        logger.warning(
            f"""
            The function_call '{tool_name}' is not handled
            by the agent named '{self.config.name}'!
            If you intended this agent to handle this function_call,
            either the fn-call name is incorrectly generated by the LLM,
            (in which case you may need to adjust your LLM instructions),
            or you need to enable this agent to handle this fn-call.
            """
        )
        if tool_name not in self.all_llm_tools_known:
            self.tool_error = True
        return None
    self.tool_error = False
    tool_class = self.llm_tools_map[tool_name]
    tool_msg["request"] = tool_name
    return tool_class.parse_obj(tool_msg)
+
1233
def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
    """
    From ChatDocument (constructed from an LLM Response), get
    a list of ToolMessages corresponding to the `tool_calls`, if any.

    Tool calls without a `function`, or whose name is not handled by this
    agent, are skipped (the latter with a warning). Each returned tool's
    `id` is set to the tool-call id ("" if missing).

    Side effect: `self.tool_error` is set to True when no tool call was
    accepted, False otherwise (only when `oai_tool_calls` is not None).
    """
    if msg.oai_tool_calls is None:
        return []
    tools = []
    any_accepted = False
    for tc in msg.oai_tool_calls:
        fn = tc.function
        if fn is None:
            continue
        tool_name = fn.name
        tool_msg = fn.arguments or {}
        if tool_name not in self.llm_tools_handled:
            logger.warning(
                f"""
                The tool_call '{tool_name}' is not handled
                by the agent named '{self.config.name}'!
                If you intended this agent to handle this function_call,
                either the fn-call name is incorrectly generated by the LLM,
                (in which case you may need to adjust your LLM instructions),
                or you need to enable this agent to handle this fn-call.
                """
            )
            continue
        any_accepted = True
        tool_class = self.llm_tools_map[tool_name]
        tool_msg["request"] = tool_name
        tool = tool_class.parse_obj(tool_msg)
        tool.id = tc.id or ""
        tools.append(tool)
    # When no tool is valid, set the recovery flag
    self.tool_error = not any_accepted
    return tools
+
1270
def tool_validation_error(self, ve: ValidationError) -> str:
    """
    Handle a validation error raised when parsing a tool message,
    when there is a legit tool name used, but it has missing/bad fields.

    Args:
        ve (ValidationError): The exception raised; its `model` is used to
            look up the tool's default `request` name.

    Returns:
        str: The error message to send back to the LLM
    """
    tool_name = cast(ToolMessage, ve.model).default_value("request")
    error_lines = []
    for err in ve.errors():
        if "loc" in err:
            error_lines.append(f"{err['loc']}: {err['msg']}")
    bad_field_errors = "\n".join(error_lines)
    return f"""
    There were one or more errors in your attempt to use the
    TOOL or function_call named '{tool_name}':
    {bad_field_errors}
    Please write your message again, correcting the errors.
    """
+
1292
def _get_multiple_orch_tool_errs(
    self, tools: List[ToolMessage]
) -> List[str | ChatDocument | None]:
    """
    Return error results if the message mixes an orchestration tool with
    other tools.

    Returns:
        One error string per tool when `tools` has more than one entry and
        at least one is an orchestration tool (e.g. DoneTool); otherwise
        an empty list.
    """
    # Multi-tool messages containing one or more orchestration tools are
    # not yet supported, so each tool's result becomes an error string.
    from langroid.agent.tools.orchestration import (
        AgentDoneTool,
        AgentSendTool,
        DonePassTool,
        DoneTool,
        ForwardTool,
        PassTool,
        SendTool,
    )
    from langroid.agent.tools.recipient_tool import RecipientTool

    ORCHESTRATION_TOOLS = (
        AgentDoneTool,
        DoneTool,
        PassTool,
        DonePassTool,
        ForwardTool,
        RecipientTool,
        SendTool,
        AgentSendTool,
    )

    if len(tools) > 1 and any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools):
        err_str = "ERROR: Use ONE tool at a time!"
        return [err_str] * len(tools)

    return []
+
1330
+ def _handle_message_final(
1331
+ self, tools: List[ToolMessage], results: List[str | ChatDocument | None]
1332
+ ) -> None | str | OrderedDict[str, str] | ChatDocument:
1333
+ """
1334
+ Convert results to final response
1335
+ """
1336
+ # extract content from ChatDocument results so we have all str|None
1337
+ results = [r.content if isinstance(r, ChatDocument) else r for r in results]
1338
+
1339
+ tool_names = [t.default_value("request") for t in tools]
1340
+
1341
+ has_ids = all([t.id != "" for t in tools])
1342
+ if has_ids:
1343
+ id2result = OrderedDict(
1344
+ (t.id, r)
1345
+ for t, r in zip(tools, results)
1346
+ if r is not None and isinstance(r, str)
1347
+ )
1348
+ result_values = list(id2result.values())
1349
+ if len(id2result) > 1 and any(
1350
+ orch_str in r
1351
+ for r in result_values
1352
+ for orch_str in ORCHESTRATION_STRINGS
1353
+ ):
1354
+ # Cannot support multi-tool results containing orchestration strings!
1355
+ # Replace results with err string to force LLM to retry
1356
+ err_str = "ERROR: Please use ONE tool at a time!"
1357
+ id2result = OrderedDict((id, err_str) for id in id2result.keys())
1358
+
1359
+ name_results_list = [
1360
+ (name, r) for name, r in zip(tool_names, results) if r is not None
1361
+ ]
1362
+ if len(name_results_list) == 0:
1363
+ return None
1364
+
1365
+ # there was a non-None result
1366
+
1367
+ if has_ids and len(id2result) > 1:
1368
+ # if there are multiple OpenAI Tool results, return them as a dict
1369
+ return id2result
1370
+
1371
+ # multi-results: prepend the tool name to each result
1372
+ str_results = [f"Result from {name}: {r}" for name, r in name_results_list]
1373
+ final = "\n\n".join(str_results)
1374
+ return final
1375
+
1376
async def handle_message_async(
    self, msg: str | ChatDocument
) -> None | str | OrderedDict[str, str] | ChatDocument:
    """
    Asynch version of `handle_message`. See there for details.

    Tools are handled one after another (each handler is awaited before
    the next one starts).
    """
    try:
        tools = self.get_tool_messages(msg)
        tools = [t for t in tools if self._tool_recipient_match(t)]
    except ValidationError as ve:
        # correct tool name but bad fields
        return self.tool_validation_error(ve)
    except XMLException as xe:  # from XMLToolMessage parsing
        return str(xe)
    except ValueError:
        # Invalid tool name. Return None rather than an "invalid tool name"
        # string: a string would be considered a valid result in the task
        # loop and treated as a response to the tool message, even though
        # the tool was not intended for this agent.
        return None

    if len(tools) > 1 and not self.config.allow_multiple_tools:
        return self.to_ChatDocument("ERROR: Use ONE tool at a time!")

    chat_doc = msg if isinstance(msg, ChatDocument) else None
    if len(tools) == 0:
        fallback_result = self.handle_message_fallback(msg)
        if fallback_result is None:
            return None
        return self.to_ChatDocument(fallback_result, chat_doc=chat_doc)

    results = self._get_multiple_orch_tool_errs(tools)
    if not results:
        results = []
        for t in tools:
            results.append(await self.handle_tool_message_async(t, chat_doc=chat_doc))
        # if there's a solitary ChatDocument|str result, return it as is
        if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
            return results[0]

    return self._handle_message_final(tools, results)
+
1421
def handle_message(
    self, msg: str | ChatDocument
) -> None | str | OrderedDict[str, str] | ChatDocument:
    """
    Handle a "tool" message either a string containing one or more
    valid "tool" JSON substrings,  or a
    ChatDocument containing a `function_call` attribute.
    Handle with the corresponding handler method, and return
    the results as a combined string.

    Args:
        msg (str | ChatDocument): The string or ChatDocument to handle

    Returns:
        The result of the handler method can be:
         - None if no tools successfully handled, or no tools present
         - str if langroid-native JSON tools were handled, and results concatenated,
             OR there's a SINGLE OpenAI tool-call.
            (We do this so the common scenario of a single tool/fn-call
             has a simple behavior).
         - Dict[str, str] if multiple OpenAI tool-calls were handled
             (dict is an id->result map)
         - ChatDocument if a handler returned a ChatDocument, intended to be the
             final response of the `agent_response` method.
    """
    try:
        tools = self.get_tool_messages(msg)
        tools = [t for t in tools if self._tool_recipient_match(t)]
    except ValidationError as ve:
        # correct tool name but bad fields
        return self.tool_validation_error(ve)
    except XMLException as xe:  # from XMLToolMessage parsing
        return str(xe)
    except ValueError:
        # Invalid tool name. Return None rather than an "invalid tool name"
        # string: a string would be considered a valid result in the task
        # loop and treated as a response to the tool message, even though
        # the tool was not intended for this agent.
        return None

    if len(tools) > 1 and not self.config.allow_multiple_tools:
        return self.to_ChatDocument("ERROR: Use ONE tool at a time!")

    chat_doc = msg if isinstance(msg, ChatDocument) else None
    if len(tools) == 0:
        fallback_result = self.handle_message_fallback(msg)
        if fallback_result is None:
            return None
        return self.to_ChatDocument(fallback_result, chat_doc=chat_doc)

    results = self._get_multiple_orch_tool_errs(tools)
    if not results:
        results = []
        for t in tools:
            results.append(self.handle_tool_message(t, chat_doc=chat_doc))
        # if there's a solitary ChatDocument|str result, return it as is
        if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
            return results[0]

    return self._handle_message_final(tools, results)
+
1482
@property
def all_llm_tools_known(self) -> set[str]:
    """
    All known tools; this may extend self.llm_tools_known.

    This implementation returns `self.llm_tools_known` itself (the same
    set object, not a copy).
    """
    known = self.llm_tools_known
    return known
+
1487
def handle_message_fallback(self, msg: str | ChatDocument) -> Any:
    """
    Hook for the "no-tools" scenario.

    Subclasses can override this, e.g. to create a "reminder" message when
    a tool is expected but the LLM "forgot" to generate one. This default
    implementation does nothing.

    Args:
        msg (str | ChatDocument): The input msg to handle

    Returns:
        Any: The result of the handler method; always None here.
    """
    return None
+
1501
+ def _get_one_tool_message(
1502
+ self, tool_candidate_str: str, is_json: bool = True
1503
+ ) -> Optional[ToolMessage]:
1504
+ """
1505
+ Parse the tool_candidate_str into ANY ToolMessage KNOWN to agent --
1506
+ This includes non-used/handled tools, i.e. any tool in self.all_llm_tools_known.
1507
+ The exception to this is below where we try our best to infer the tool
1508
+ when the LLM has "forgotten" to include the "request" field in the tool str ---
1509
+ in this case we ONLY look at the possible set of HANDLED tools, i.e.
1510
+ self.llm_tools_handled.
1511
+ """
1512
+ if is_json:
1513
+ maybe_tool_dict = json.loads(tool_candidate_str)
1514
+ else:
1515
+ try:
1516
+ maybe_tool_dict = XMLToolMessage.extract_field_values(
1517
+ tool_candidate_str
1518
+ )
1519
+ except Exception as e:
1520
+ from langroid.exceptions import XMLException
1521
+
1522
+ raise XMLException(f"Error extracting XML fields:\n {str(e)}")
1523
+ # check if the maybe_tool_dict contains a "properties" field
1524
+ # which further contains the actual tool-call
1525
+ # (some weak LLMs do this). E.g. gpt-4o sometimes generates this:
1526
+ # TOOL: {
1527
+ # "type": "object",
1528
+ # "properties": {
1529
+ # "request": "square",
1530
+ # "number": 9
1531
+ # },
1532
+ # "required": [
1533
+ # "number",
1534
+ # "request"
1535
+ # ]
1536
+ # }
1537
+
1538
+ if not isinstance(maybe_tool_dict, dict):
1539
+ self.tool_error = True
1540
+ return None
1541
+
1542
+ properties = maybe_tool_dict.get("properties")
1543
+ if isinstance(properties, dict):
1544
+ maybe_tool_dict = properties
1545
+ request = maybe_tool_dict.get("request")
1546
+ if request is None:
1547
+ if self.enabled_requests_for_inference is None:
1548
+ possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
1549
+ else:
1550
+ allowable = self.enabled_requests_for_inference.intersection(
1551
+ self.llm_tools_handled
1552
+ )
1553
+ possible = [self.llm_tools_map[r] for r in allowable]
1554
+
1555
+ default_keys = set(ToolMessage.__fields__.keys())
1556
+ request_keys = set(maybe_tool_dict.keys())
1557
+
1558
+ def maybe_parse(tool: type[ToolMessage]) -> Optional[ToolMessage]:
1559
+ all_keys = set(tool.__fields__.keys())
1560
+ non_inherited_keys = all_keys.difference(default_keys)
1561
+ # If the request has any keys not valid for the tool and
1562
+ # does not specify some key specific to the type
1563
+ # (e.g. not just `purpose`), the LLM must explicitly specify `request`
1564
+ if not (
1565
+ request_keys.issubset(all_keys)
1566
+ and len(request_keys.intersection(non_inherited_keys)) > 0
1567
+ ):
1568
+ return None
1569
+
1570
+ try:
1571
+ return tool.parse_obj(maybe_tool_dict)
1572
+ except ValidationError:
1573
+ return None
1574
+
1575
+ candidate_tools = list(
1576
+ filter(
1577
+ lambda t: t is not None,
1578
+ map(maybe_parse, possible),
1579
+ )
1580
+ )
1581
+
1582
+ # If only one valid candidate exists, we infer
1583
+ # "request" to be the only possible value
1584
+ if len(candidate_tools) == 1:
1585
+ return candidate_tools[0]
1586
+ else:
1587
+ self.tool_error = True
1588
+ return None
1589
+
1590
+ if not isinstance(request, str) or request not in self.all_llm_tools_known:
1591
+ self.tool_error = True
1592
+ return None
1593
+
1594
+ message_class = self.llm_tools_map.get(request)
1595
+ if message_class is None:
1596
+ logger.warning(f"No message class found for request '{request}'")
1597
+ self.tool_error = True
1598
+ return None
1599
+
1600
+ try:
1601
+ message = message_class.parse_obj(maybe_tool_dict)
1602
+ except ValidationError as ve:
1603
+ self.tool_error = True
1604
+ raise ve
1605
+ return message
1606
+
1607
def to_ChatDocument(
    self,
    msg: Any,
    orig_tool_name: str | None = None,
    chat_doc: Optional[ChatDocument] = None,
    author_entity: Entity = Entity.AGENT,
) -> Optional[ChatDocument]:
    """
    Convert result of a responder (agent_response or llm_response, or task.run()),
    or tool handler, or handle_message_fallback,
    to a ChatDocument, to enable handling by other
    responders/tasks in a task loop possibly involving multiple agents.

    Args:
        msg (Any): The result of a responder or tool handler or task.run()
        orig_tool_name (str): The original tool name that generated the response,
            if any.
        chat_doc (ChatDocument): The original ChatDocument object that `msg`
            is a response to.
        author_entity (Entity): The intended author of the result ChatDocument

    Returns:
        Optional[ChatDocument]: `msg` unchanged when it is None or already a
            ChatDocument; otherwise a ChatDocument built via
            `self.response_template`, or None when nothing could be produced
            (a handled tool whose handler returned None, or `to_string`
            returning None).
    """
    # Nothing to convert: pass None / ChatDocument straight through.
    if msg is None or isinstance(msg, ChatDocument):
        return msg

    is_agent_author = author_entity == Entity.AGENT

    if isinstance(msg, str):
        # Plain text: keep it both as the content and as the typed payload.
        return self.response_template(author_entity, content=msg, content_any=msg)
    elif isinstance(msg, ToolMessage):
        # result is a ToolMessage, so...
        result_tool_name = msg.default_value("request")
        # Handle the tool right here only when the result is authored by the
        # agent, the tool is one this agent handles, and it is not the same
        # tool that produced this result.
        # NOTE(review): the name check presumably prevents a handler that
        # returns its own tool type from being re-dispatched endlessly.
        if (
            is_agent_author
            and result_tool_name in self.llm_tools_handled
            and (orig_tool_name is None or orig_tool_name != result_tool_name)
        ):
            # TODO: do we need to remove the tool message from the chat_doc?
            # if (chat_doc is not None and
            #     msg in chat_doc.tool_messages):
            #    chat_doc.tool_messages.remove(msg)
            # if we can handle it, do so
            result = self.handle_tool_message(msg, chat_doc=chat_doc)
            if result is not None and isinstance(result, ChatDocument):
                return result
            # Otherwise fall through to the final return below with
            # `result` being None or a non-ChatDocument value.
        else:
            # else wrap it in an agent response and return it so
            # orchestrator can find a respondent
            return self.response_template(author_entity, tool_messages=[msg])
    else:
        # Any other object: use its string rendering as the content.
        result = to_string(msg)

    # Shared tail for the fall-through cases: the original `msg` object is
    # preserved as `content_any` alongside the textual `result`.
    return (
        None
        if result is None
        else self.response_template(author_entity, content=result, content_any=msg)
    )
1663
+
1664
def from_ChatDocument(self, msg: ChatDocument, output_type: Type[T]) -> Optional[T]:
    """
    Extract a desired output_type from a ChatDocument object.
    We use this fallback order:
    - if output_type is str and `msg.content` is non-empty, return the content
    - if `msg.content_any` exists and is an instance of output_type, return it
    - if output_type is a list of a ToolMessage subclass,
        return all tools of that class found in `msg` (possibly an empty list)
    - if output_type is a ToolMessage subclass, return the first tool of that
        class found in `msg`, or None if there is none
    - if output_type is a primitive (str, int, float, bool), try to parse it
        from `msg.content`
    - search for a tool in `msg` that has a field of output_type,
        and if found, return that field value
    - return None if all the above fail

    Args:
        msg (ChatDocument): the document to extract from
        output_type (Type[T]): the desired type; may be a plain class or a
            parameterized list type such as `List[MyTool]`

    Returns:
        Optional[T]: the extracted value, or None if nothing matched
    """
    content = msg.content
    if output_type is str and content != "":
        return cast(T, content)

    content_any = msg.content_any
    if content_any is not None:
        try:
            is_match = isinstance(content_any, output_type)
        except TypeError:
            # Parameterized generics (e.g. List[MyTool]) and some typing
            # constructs are rejected by isinstance(); treat them as
            # "no direct match" so the list/tool handling below can run.
            is_match = False
        if is_match:
            return cast(T, content_any)

    tools = self.try_get_tool_messages(msg, all_tools=True)
    origin = get_origin(output_type)

    if origin is list:
        # Bare `List` has no type arguments; guard before indexing.
        type_args = get_args(output_type)
        list_element_type = type_args[0] if type_args else None
        if (
            isinstance(list_element_type, type)
            and get_origin(list_element_type) is None
            and issubclass(list_element_type, ToolMessage)
        ):
            # list_element_type is a subclass of ToolMessage:
            # We output a list of objects derived from list_element_type
            return cast(
                T,
                [t for t in tools if isinstance(t, list_element_type)],
            )
    elif (
        origin is None
        and isinstance(output_type, type)  # issubclass() requires a class
        and issubclass(output_type, ToolMessage)
    ):
        # output_type is a subclass of ToolMessage:
        # return the first tool that has this specific output_type
        for tool in tools:
            if isinstance(tool, output_type):
                return cast(T, tool)
        return None
    elif origin is None and output_type in (str, int, float, bool):
        # attempt to get the output_type from the content,
        # if it's a primitive type
        primitive_value = from_string(content, output_type)  # type: ignore
        if primitive_value is not None:
            return cast(T, primitive_value)

    # then search for output_type as a field in a tool
    for tool in tools:
        value = tool.get_value_of_type(output_type)
        if value is not None:
            return cast(T, value)
    return None
1715
+
1716
+ def _maybe_truncate_result(
1717
+ self, result: str | ChatDocument | None, max_tokens: int | None
1718
+ ) -> str | ChatDocument | None:
1719
+ """
1720
+ Truncate the result string to `max_tokens` tokens.
1721
+ """
1722
+ if result is None or max_tokens is None:
1723
+ return result
1724
+ result_str = result.content if isinstance(result, ChatDocument) else result
1725
+ num_tokens = (
1726
+ self.parser.num_tokens(result_str)
1727
+ if self.parser is not None
1728
+ else len(result_str) / 4.0
1729
+ )
1730
+ if num_tokens <= max_tokens:
1731
+ return result
1732
+ truncate_warning = f"""
1733
+ The TOOL result was large, so it was truncated to {max_tokens} tokens.
1734
+ To get the full result, the TOOL must be called again.
1735
+ """
1736
+ if isinstance(result, str):
1737
+ return (
1738
+ self.parser.truncate_tokens(result, max_tokens)
1739
+ if self.parser is not None
1740
+ else result[: max_tokens * 4] # approx truncate
1741
+ ) + truncate_warning
1742
+ elif isinstance(result, ChatDocument):
1743
+ result.content = (
1744
+ self.parser.truncate_tokens(result.content, max_tokens)
1745
+ if self.parser is not None
1746
+ else result.content[: max_tokens * 4] # approx truncate
1747
+ ) + truncate_warning
1748
+ return result
1749
+
1750
async def handle_tool_message_async(
    self,
    tool: ToolMessage,
    chat_doc: Optional[ChatDocument] = None,
) -> None | str | ChatDocument:
    """
    Async counterpart of `handle_tool_message`.

    Looks up a method on this agent named `<handler>_async`, where `<handler>`
    is the tool's `_handler` attribute if it has one, else the default value
    of the tool's `request` field. If there is no such async method, the call
    is delegated to the synchronous `handle_tool_message`.

    Args:
        tool: ToolMessage object representing the tool request.
        chat_doc: Optional ChatDocument object containing the tool request.
            It is passed to the handler only if it is not None and the
            handler has a `chat_doc` parameter.

    Returns:
        The handler's result after conversion by `to_ChatDocument` and
        truncation to `tool._max_result_tokens`, or whatever
        `handle_tool_message` returns in the fallback case.
    """
    request_name = tool.default_value("request")
    method_name = getattr(tool, "_handler", request_name)
    async_handler = getattr(self, method_name + "_async", None)
    if async_handler is None:
        return self.handle_tool_message(tool, chat_doc=chat_doc)

    handler_kwargs = {}
    if chat_doc is not None:
        if "chat_doc" in inspect.signature(async_handler).parameters:
            handler_kwargs["chat_doc"] = chat_doc

    # Errors raised by the handler propagate unchanged; pydantic validation
    # errors are dealt with in `handle_message`, not here.
    raw_result = await async_handler(tool, **handler_kwargs)
    converted = self.to_ChatDocument(raw_result, request_name, chat_doc)
    limit = tool._max_result_tokens
    return self._maybe_truncate_result(converted, limit)  # type: ignore
1784
+
1785
def handle_tool_message(
    self,
    tool: ToolMessage,
    chat_doc: Optional[ChatDocument] = None,
) -> None | str | ChatDocument:
    """
    Respond to a tool request from the LLM, in the form of a ToolMessage object.

    The handler is the method on this agent named by the tool's `_handler`
    attribute if it has one, else by the default value of the tool's
    `request` field.

    Args:
        tool: ToolMessage object representing the tool request.
        chat_doc: Optional ChatDocument object containing the tool request.
            This is passed to the tool-handler method only if it is not None
            and the handler has a `chat_doc` argument.

    Returns:
        None if this agent has no matching handler method; otherwise the
        handler's result after conversion by `to_ChatDocument` and truncation
        to `tool._max_result_tokens`.
    """
    request_name = tool.default_value("request")
    method_name = getattr(tool, "_handler", request_name)
    handler = getattr(self, method_name, None)
    if handler is None:
        return None

    handler_kwargs = {}
    if chat_doc is not None:
        if "chat_doc" in inspect.signature(handler).parameters:
            handler_kwargs["chat_doc"] = chat_doc

    # Errors raised by the handler propagate unchanged; pydantic validation
    # errors are dealt with in `handle_message`, not here.
    raw_result = handler(tool, **handler_kwargs)
    converted = self.to_ChatDocument(raw_result, request_name, chat_doc)
    limit = tool._max_result_tokens
    return self._maybe_truncate_result(converted, limit)  # type: ignore
1827
+
1828
def num_tokens(self, prompt: str | List[LLMMessage]) -> int:
    """
    Count the tokens in a prompt using this agent's parser.

    Args:
        prompt: either a single string, or a list of messages; for a list,
            each message contributes the tokens of its `content` plus the
            tokens of the string form of its `function_call` (if any).

    Returns:
        int: total number of tokens

    Raises:
        ValueError: if the agent has no parser
    """
    counter = self.parser
    if counter is None:
        raise ValueError("Parser must be set, to count tokens")
    if isinstance(prompt, str):
        return counter.num_tokens(prompt)

    total = 0
    for message in prompt:
        content_tokens = counter.num_tokens(message.content)
        call_tokens = counter.num_tokens(str(message.function_call or ""))
        total += content_tokens + call_tokens
    return total
1841
+
1842
+ def _get_response_stats(
1843
+ self, chat_length: int, tot_cost: float, response: LLMResponse
1844
+ ) -> str:
1845
+ """
1846
+ Get LLM response stats as a string
1847
+
1848
+ Args:
1849
+ chat_length (int): number of messages in the chat
1850
+ tot_cost (float): total cost of the chat so far
1851
+ response (LLMResponse): LLMResponse object
1852
+ """
1853
+
1854
+ if self.config.llm is None:
1855
+ logger.warning("LLM config is None, cannot get response stats")
1856
+ return ""
1857
+ if response.usage:
1858
+ in_tokens = response.usage.prompt_tokens
1859
+ out_tokens = response.usage.completion_tokens
1860
+ llm_response_cost = format(response.usage.cost, ".4f")
1861
+ cumul_cost = format(tot_cost, ".4f")
1862
+ assert isinstance(self.llm, LanguageModel)
1863
+ context_length = self.llm.chat_context_length()
1864
+ max_out = self.config.llm.max_output_tokens
1865
+
1866
+ llm_model = (
1867
+ "no-LLM" if self.config.llm is None else self.llm.config.chat_model
1868
+ )
1869
+ # tot cost across all LLMs, agents
1870
+ all_cost = format(self.llm.tot_tokens_cost()[1], ".4f")
1871
+ return (
1872
+ f"[bold]Stats:[/bold] [magenta]N_MSG={chat_length}, "
1873
+ f"TOKENS: in={in_tokens}, out={out_tokens}, "
1874
+ f"max={max_out}, ctx={context_length}, "
1875
+ f"COST: now=${llm_response_cost}, cumul=${cumul_cost}, "
1876
+ f"tot=${all_cost} "
1877
+ f"[bold]({llm_model})[/bold][/magenta]"
1878
+ )
1879
+ return ""
1880
+
1881
def update_token_usage(
    self,
    response: LLMResponse,
    prompt: str | List[LLMMessage],
    stream: bool,
    chat: bool = True,
    print_response_stats: bool = True,
) -> None:
    """
    Update `response.usage` (token usage and cost fields) and the agent's
    running totals.

    For a streamed response the usage is computed here: prompt and completion
    tokens are counted locally and priced, or all set to zero when the
    response came from the cache. For a non-streamed response,
    `response.usage` is used as it is.

    Args:
        response (LLMResponse): LLMResponse object
        prompt (str | List[LLMMessage]): prompt or list of LLMMessage objects
        stream (bool): whether the response was streamed, i.e. whether
            `response.usage` must be computed here
        chat (bool): whether this is a chat model or a completion model
        print_response_stats (bool): whether to print the response stats
    """
    if response is None or self.llm is None:
        return

    # Note: a non-streamed response is expected to already carry
    # `response.usage`, so it is only (re)built in the stream case.
    if stream:
        if response.cached:
            # usage, cost = 0 when response is from cache
            n_prompt, n_completion, cost = 0, 0, 0.0
        else:
            n_prompt = self.num_tokens(prompt)
            n_completion = self.num_tokens(response.message)
            fn_call = response.function_call
            if fn_call is not None:
                n_completion += self.num_tokens(str(fn_call))
            cost = self.compute_token_cost(n_prompt, n_completion)
        response.usage = LLMTokenUsage(
            prompt_tokens=n_prompt,
            completion_tokens=n_completion,
            cost=cost,
        )

    usage = response.usage
    if usage is None:
        return

    # update total counters
    self.total_llm_token_cost += usage.cost
    self.total_llm_token_usage += usage.total_tokens
    self.llm.update_usage_cost(
        chat,
        usage.prompt_tokens,
        usage.completion_tokens,
        usage.cost,
    )

    if isinstance(prompt, str):
        chat_length = 1
    else:
        chat_length = len(prompt)
    self.token_stats_str = self._get_response_stats(
        chat_length, self.total_llm_token_cost, response
    )
    if print_response_stats:
        print(self.indent + self.token_stats_str)
1942
+
1943
def compute_token_cost(self, prompt: int, completion: int) -> float:
    """
    Price a number of prompt and completion tokens with the agent's LLM.

    Args:
        prompt (int): number of prompt tokens
        completion (int): number of completion tokens

    Returns:
        float: `(rates[0] * prompt + rates[1] * completion) / 1000`, where
            `rates` is what the LLM's `chat_cost()` returns
            (presumably prices per 1000 tokens -- confirm against the LLM class).
    """
    rates = cast(LanguageModel, self.llm).chat_cost()
    prompt_part = rates[0] * prompt
    completion_part = rates[1] * completion
    return (prompt_part + completion_part) / 1000
1946
+
1947
def ask_agent(
    self,
    agent: "Agent",
    request: str,
    no_answer: str = NO_ANSWER,
    user_confirm: bool = True,
) -> Optional[str]:
    """
    Send a request to another agent, possibly after confirming with the user.
    This is not currently used, since we rely on the task loop and
    `RecipientTool` to address requests to other agents. It is generally best to
    avoid using this method.

    Args:
        agent (Agent): agent to ask
        request (str): request to send
        no_answer (str): expected response when agent does not know the answer
        user_confirm (bool): whether to gate the request with a human confirmation

    Returns:
        Optional[str]: "<AgentClass> says: <answer text>", or None when the
            user declines, the agent gives no response, or the answer text
            equals `no_answer`.
    """
    agent_type = type(agent).__name__
    if user_confirm:
        user_response = Prompt.ask(
            f"""[magenta]Here is the request or message:
                {request}
                Should I forward this to {agent_type}?""",
            default="y",
            choices=["y", "n"],
        )
        if user_response not in ["y", "yes"]:
            return None

    answer = agent.llm_response(request)
    if answer is None:
        # no response at all: do not report the literal text "None"
        return None

    # Compare and report the answer's text: when the response is a
    # ChatDocument its text lives in `.content`, and comparing the object
    # itself with the `no_answer` string would not match.
    if isinstance(answer, ChatDocument):
        answer_text = answer.content
    else:
        answer_text = str(answer)
    if answer_text == no_answer:
        return None
    return (f"{agent_type} says: " + answer_text).strip()