langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (106)
  1. langroid/agent/base.py +39 -17
  2. langroid/agent/base.py-e +2216 -0
  3. langroid/agent/callbacks/chainlit.py +2 -1
  4. langroid/agent/chat_agent.py +73 -55
  5. langroid/agent/chat_agent.py-e +2086 -0
  6. langroid/agent/chat_document.py +7 -7
  7. langroid/agent/chat_document.py-e +513 -0
  8. langroid/agent/openai_assistant.py +9 -9
  9. langroid/agent/openai_assistant.py-e +882 -0
  10. langroid/agent/special/arangodb/arangodb_agent.py +10 -18
  11. langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
  12. langroid/agent/special/arangodb/tools.py +3 -3
  13. langroid/agent/special/doc_chat_agent.py +16 -14
  14. langroid/agent/special/lance_rag/critic_agent.py +2 -2
  15. langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
  16. langroid/agent/special/lance_tools.py +6 -5
  17. langroid/agent/special/lance_tools.py-e +61 -0
  18. langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
  20. langroid/agent/special/relevance_extractor_agent.py +1 -1
  21. langroid/agent/special/sql/sql_chat_agent.py +11 -3
  22. langroid/agent/task.py +9 -87
  23. langroid/agent/task.py-e +2418 -0
  24. langroid/agent/tool_message.py +33 -17
  25. langroid/agent/tool_message.py-e +400 -0
  26. langroid/agent/tools/file_tools.py +4 -2
  27. langroid/agent/tools/file_tools.py-e +234 -0
  28. langroid/agent/tools/mcp/fastmcp_client.py +19 -6
  29. langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
  30. langroid/agent/tools/orchestration.py +22 -17
  31. langroid/agent/tools/orchestration.py-e +301 -0
  32. langroid/agent/tools/recipient_tool.py +3 -3
  33. langroid/agent/tools/task_tool.py +22 -16
  34. langroid/agent/tools/task_tool.py-e +249 -0
  35. langroid/agent/xml_tool_message.py +90 -35
  36. langroid/agent/xml_tool_message.py-e +392 -0
  37. langroid/cachedb/base.py +1 -1
  38. langroid/embedding_models/base.py +2 -2
  39. langroid/embedding_models/models.py +3 -7
  40. langroid/embedding_models/models.py-e +563 -0
  41. langroid/exceptions.py +4 -1
  42. langroid/language_models/azure_openai.py +2 -2
  43. langroid/language_models/azure_openai.py-e +134 -0
  44. langroid/language_models/base.py +6 -4
  45. langroid/language_models/base.py-e +812 -0
  46. langroid/language_models/client_cache.py +64 -0
  47. langroid/language_models/config.py +2 -4
  48. langroid/language_models/config.py-e +18 -0
  49. langroid/language_models/model_info.py +9 -1
  50. langroid/language_models/model_info.py-e +483 -0
  51. langroid/language_models/openai_gpt.py +119 -20
  52. langroid/language_models/openai_gpt.py-e +2280 -0
  53. langroid/language_models/provider_params.py +3 -22
  54. langroid/language_models/provider_params.py-e +153 -0
  55. langroid/mytypes.py +11 -4
  56. langroid/mytypes.py-e +132 -0
  57. langroid/parsing/code_parser.py +1 -1
  58. langroid/parsing/file_attachment.py +1 -1
  59. langroid/parsing/file_attachment.py-e +246 -0
  60. langroid/parsing/md_parser.py +14 -4
  61. langroid/parsing/md_parser.py-e +574 -0
  62. langroid/parsing/parser.py +22 -7
  63. langroid/parsing/parser.py-e +410 -0
  64. langroid/parsing/repo_loader.py +3 -1
  65. langroid/parsing/repo_loader.py-e +812 -0
  66. langroid/parsing/search.py +1 -1
  67. langroid/parsing/url_loader.py +17 -51
  68. langroid/parsing/url_loader.py-e +683 -0
  69. langroid/parsing/urls.py +5 -4
  70. langroid/parsing/urls.py-e +279 -0
  71. langroid/prompts/prompts_config.py +1 -1
  72. langroid/pydantic_v1/__init__.py +45 -6
  73. langroid/pydantic_v1/__init__.py-e +36 -0
  74. langroid/pydantic_v1/main.py +11 -4
  75. langroid/pydantic_v1/main.py-e +11 -0
  76. langroid/utils/configuration.py +13 -11
  77. langroid/utils/configuration.py-e +141 -0
  78. langroid/utils/constants.py +1 -1
  79. langroid/utils/constants.py-e +32 -0
  80. langroid/utils/globals.py +21 -5
  81. langroid/utils/globals.py-e +49 -0
  82. langroid/utils/html_logger.py +2 -1
  83. langroid/utils/html_logger.py-e +825 -0
  84. langroid/utils/object_registry.py +1 -1
  85. langroid/utils/object_registry.py-e +66 -0
  86. langroid/utils/pydantic_utils.py +55 -28
  87. langroid/utils/pydantic_utils.py-e +602 -0
  88. langroid/utils/types.py +2 -2
  89. langroid/utils/types.py-e +113 -0
  90. langroid/vector_store/base.py +3 -3
  91. langroid/vector_store/lancedb.py +5 -5
  92. langroid/vector_store/lancedb.py-e +404 -0
  93. langroid/vector_store/meilisearch.py +2 -2
  94. langroid/vector_store/pineconedb.py +4 -4
  95. langroid/vector_store/pineconedb.py-e +427 -0
  96. langroid/vector_store/postgres.py +1 -1
  97. langroid/vector_store/qdrantdb.py +3 -3
  98. langroid/vector_store/weaviatedb.py +1 -1
  99. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
  100. langroid-0.59.0b1.dist-info/RECORD +181 -0
  101. langroid/agent/special/doc_chat_task.py +0 -0
  102. langroid/mcp/__init__.py +0 -1
  103. langroid/mcp/server/__init__.py +0 -1
  104. langroid-0.58.2.dist-info/RECORD +0 -145
  105. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
  106. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
langroid/agent/base.py-e
@@ -0,0 +1,2216 @@
1
+ import asyncio
2
+ import copy
3
+ import inspect
4
+ import json
5
+ import logging
6
+ import re
7
+ from abc import ABC
8
+ from collections import OrderedDict
9
+ from contextlib import ExitStack
10
+ from types import SimpleNamespace
11
+ from typing import (
12
+ Any,
13
+ Callable,
14
+ Coroutine,
15
+ Dict,
16
+ List,
17
+ Optional,
18
+ Set,
19
+ Tuple,
20
+ Type,
21
+ TypeVar,
22
+ cast,
23
+ get_args,
24
+ get_origin,
25
+ no_type_check,
26
+ )
27
+
28
+ from pydantic_settings import BaseSettings
29
+ from rich import print
30
+ from rich.console import Console
31
+ from rich.markup import escape
32
+ from rich.prompt import Prompt
33
+
34
+ from langroid.agent.chat_document import ChatDocMetaData, ChatDocument
35
+ from langroid.agent.tool_message import ToolMessage
36
+ from langroid.agent.xml_tool_message import XMLToolMessage
37
+ from langroid.exceptions import XMLException
38
+ from langroid.language_models.base import (
39
+ LanguageModel,
40
+ LLMConfig,
41
+ LLMFunctionCall,
42
+ LLMMessage,
43
+ LLMResponse,
44
+ LLMTokenUsage,
45
+ OpenAIToolCall,
46
+ StreamingIfAllowed,
47
+ ToolChoiceTypes,
48
+ )
49
+ from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
50
+ from langroid.mytypes import Entity
51
+ from langroid.parsing.file_attachment import FileAttachment
52
+ from langroid.parsing.parse_json import extract_top_level_json
53
+ from langroid.parsing.parser import Parser, ParsingConfig
54
+ from langroid.prompts.prompts_config import PromptsConfig
55
+ from pydantic import (
56
+ Field,
57
+ ValidationError,
58
+ field_validator,
59
+ )
60
+ from langroid.utils.configuration import settings
61
+ from langroid.utils.constants import (
62
+ DONE,
63
+ NO_ANSWER,
64
+ PASS,
65
+ PASS_TO,
66
+ SEND_TO,
67
+ )
68
+ from langroid.utils.object_registry import ObjectRegistry
69
+ from langroid.utils.output import status
70
+ from langroid.utils.types import from_string, to_string
71
+ from langroid.vector_store.base import VectorStore, VectorStoreConfig
72
+
73
+ ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
74
+ console = Console(quiet=settings.quiet)
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+ T = TypeVar("T")
79
+
80
+
81
+ class AgentConfig(BaseSettings):
82
+ """
83
+ General config settings for an LLM agent. This is nested, combining configs of
84
+ various components.
85
+ """
86
+
87
+ name: str = "LLM-Agent"
88
+ debug: bool = False
89
+ vecdb: Optional[VectorStoreConfig] = None
90
+ llm: Optional[LLMConfig] = OpenAIGPTConfig()
91
+ parsing: Optional[ParsingConfig] = ParsingConfig()
92
+ prompts: Optional[PromptsConfig] = PromptsConfig()
93
+ show_stats: bool = True # show token usage/cost stats?
94
+ hide_agent_response: bool = False # hide agent response?
95
+ add_to_registry: bool = True # register agent in ObjectRegistry?
96
+ respond_tools_only: bool = False # respond only to tool messages (not plain text)?
97
+ # allow multiple tool messages in a single response?
98
+ allow_multiple_tools: bool = True
99
+ human_prompt: str = (
100
+ "Human (respond or q, x to exit current level, " "or hit enter to continue)"
101
+ )
102
+
103
+ @field_validator("name")
104
+ @classmethod
105
+ def check_name_alphanum(cls, v: str) -> str:
106
+ if not re.match(r"^[a-zA-Z0-9_-]+$", v):
107
+ raise ValueError(
108
+ "The name must only contain alphanumeric characters, "
109
+ "underscores, or hyphens, with no spaces"
110
+ )
111
+ return v
112
+
113
+
114
+ def noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
115
+ pass
116
+
117
+
118
+ async def async_noop_fn(*args: List[Any], **kwargs: Dict[str, Any]) -> None:
119
+ pass
120
+
121
+
122
+ async def async_lambda_noop_fn() -> Callable[..., Coroutine[Any, Any, None]]:
123
+ return async_noop_fn
124
+
125
+
126
+ class Agent(ABC):
127
+ """
128
+ An Agent is an abstraction that typically (but not necessarily)
129
+ encapsulates an LLM.
130
+ """
131
+
132
+ id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
133
+ # OpenAI tool-calls awaiting response; update when a tool result with Role.TOOL
134
+ # is added to self.message_history
135
+ oai_tool_calls: List[OpenAIToolCall] = []
136
+ # Index of ALL tool calls generated by the agent
137
+ oai_tool_id2call: Dict[str, OpenAIToolCall] = {}
138
+
139
+ def __init__(self, config: AgentConfig = AgentConfig()):
140
+ self.config = config
141
+ self.id = ObjectRegistry.new_id() # Initialize agent ID
142
+ self.lock = asyncio.Lock() # for async access to update self.llm.usage_cost
143
+ self.dialog: List[Tuple[str, str]] = [] # seq of LLM (prompt, response) tuples
144
+ self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
145
+ self.llm_tools_handled: Set[str] = set()
146
+ self.llm_tools_usable: Set[str] = set()
147
+ self.llm_tools_known: Set[str] = set() # all known tools, handled/used or not
148
+ # Indicates which tool-names are allowed to be inferred when
149
+ # the LLM "forgets" to include the request field in its tool-call.
150
+ self.enabled_requests_for_inference: Optional[Set[str]] = (
151
+ None # If None, we allow all
152
+ )
153
+ self.interactive: bool = True # may be modified by Task wrapper
154
+ self.token_stats_str = ""
155
+ self.default_human_response: Optional[str] = None
156
+ self._indent = ""
157
+ self.llm = LanguageModel.create(config.llm)
158
+ self.vecdb = VectorStore.create(config.vecdb) if config.vecdb else None
159
+ self.tool_error = False
160
+ if config.parsing is not None and self.config.llm is not None:
161
+ # token_encoding_model is used to obtain the tokenizer,
162
+ # so in case it's an OpenAI model, we ensure that the tokenizer
163
+ # corresponding to the model is used.
164
+ if isinstance(self.llm, OpenAIGPT) and self.llm.is_openai_chat_model():
165
+ config.parsing.token_encoding_model = self.llm.config.chat_model
166
+ self.parser: Optional[Parser] = (
167
+ Parser(config.parsing) if config.parsing else None
168
+ )
169
+ if config.add_to_registry:
170
+ ObjectRegistry.register_object(self)
171
+
172
+ self.callbacks = SimpleNamespace(
173
+ start_llm_stream=lambda: noop_fn,
174
+ start_llm_stream_async=async_lambda_noop_fn,
175
+ cancel_llm_stream=noop_fn,
176
+ finish_llm_stream=noop_fn,
177
+ show_llm_response=noop_fn,
178
+ show_agent_response=noop_fn,
179
+ get_user_response=None,
180
+ get_user_response_async=None,
181
+ get_last_step=noop_fn,
182
+ set_parent_agent=noop_fn,
183
+ show_error_message=noop_fn,
184
+ show_start_response=noop_fn,
185
+ )
186
+ Agent.init_state(self)
187
+
188
+ def init_state(self) -> None:
189
+ """Initialize all state vars. Called by Task.run() if restart is True"""
190
+ self.total_llm_token_cost = 0.0
191
+ self.total_llm_token_usage = 0
192
+
193
+ @staticmethod
194
+ def from_id(id: str) -> "Agent":
195
+ return cast(Agent, ObjectRegistry.get(id))
196
+
197
+ @staticmethod
198
+ def delete_id(id: str) -> None:
199
+ ObjectRegistry.remove(id)
200
+
201
+ def entity_responders(
202
+ self,
203
+ ) -> List[
204
+ Tuple[Entity, Callable[[None | str | ChatDocument], None | ChatDocument]]
205
+ ]:
206
+ """
207
+ Sequence of (entity, response_method) pairs. This sequence is used
208
+ in a `Task` to respond to the current pending message.
209
+ See `Task.step()` for details.
210
+ Returns:
211
+ Sequence of (entity, response_method) pairs.
212
+ """
213
+ return [
214
+ (Entity.AGENT, self.agent_response),
215
+ (Entity.LLM, self.llm_response),
216
+ (Entity.USER, self.user_response),
217
+ ]
218
+
219
+ def entity_responders_async(
220
+ self,
221
+ ) -> List[
222
+ Tuple[
223
+ Entity,
224
+ Callable[
225
+ [None | str | ChatDocument], Coroutine[Any, Any, None | ChatDocument]
226
+ ],
227
+ ]
228
+ ]:
229
+ """
230
+ Async version of `entity_responders`. See there for details.
231
+ """
232
+ return [
233
+ (Entity.AGENT, self.agent_response_async),
234
+ (Entity.LLM, self.llm_response_async),
235
+ (Entity.USER, self.user_response_async),
236
+ ]
237
+
238
+ @property
239
+ def indent(self) -> str:
240
+ """Indentation to print before any responses from the agent's entities."""
241
+ return self._indent
242
+
243
+ @indent.setter
244
+ def indent(self, value: str) -> None:
245
+ self._indent = value
246
+
247
+ def update_dialog(self, prompt: str, output: str) -> None:
248
+ self.dialog.append((prompt, output))
249
+
250
+ def get_dialog(self) -> List[Tuple[str, str]]:
251
+ return self.dialog
252
+
253
+ def clear_dialog(self) -> None:
254
+ self.dialog = []
255
+
256
+ def _analyze_handler_params(
257
+ self, handler_method: Any
258
+ ) -> Tuple[bool, Optional[str], Optional[str]]:
259
+ """
260
+ Analyze parameters of a handler method to determine their types.
261
+
262
+ Returns:
263
+ Tuple of (has_annotations, agent_param_name, chat_doc_param_name)
264
+ - has_annotations: True if useful type annotations were found
265
+ - agent_param_name: Name of the agent parameter if found
266
+ - chat_doc_param_name: Name of the chat_doc parameter if found
267
+ """
268
+ sig = inspect.signature(handler_method)
269
+ params = list(sig.parameters.values())
270
+ # Remove the first 'self' parameter
271
+ params = params[1:]
272
+ # Don't use name
273
+ # [p for p in params if p.name != "self"]
274
+
275
+ agent_param = None
276
+ chat_doc_param = None
277
+ has_annotations = False
278
+
279
+ for param in params:
280
+ # First try type annotations
281
+ if param.annotation != inspect.Parameter.empty:
282
+ ann_str = str(param.annotation)
283
+ # Check for Agent-like types
284
+ if (
285
+ inspect.isclass(param.annotation)
286
+ and issubclass(param.annotation, Agent)
287
+ ) or (
288
+ not inspect.isclass(param.annotation)
289
+ and (
290
+ "Agent" in ann_str
291
+ or (
292
+ hasattr(param.annotation, "__name__")
293
+ and "Agent" in param.annotation.__name__
294
+ )
295
+ )
296
+ ):
297
+ agent_param = param.name
298
+ has_annotations = True
299
+ # Check for ChatDocument-like types
300
+ elif (
301
+ param.annotation is ChatDocument
302
+ or "ChatDocument" in ann_str
303
+ or "ChatDoc" in ann_str
304
+ ):
305
+ chat_doc_param = param.name
306
+ has_annotations = True
307
+
308
+ # Fallback to parameter names
309
+ elif param.name == "agent":
310
+ agent_param = param.name
311
+ elif param.name == "chat_doc":
312
+ chat_doc_param = param.name
313
+
314
+ return has_annotations, agent_param, chat_doc_param
315
+
316
+ @no_type_check
317
+ def _create_handler_wrapper(
318
+ self,
319
+ handler_method: Any,
320
+ is_async: bool = False,
321
+ ) -> Any:
322
+ """
323
+ Create a wrapper function for a handler method based on its signature.
324
+
325
+ Args:
326
+ message_class: The ToolMessage class
327
+ handler_method: The handle/handle_async method
328
+ is_async: Whether this is for an async handler
329
+
330
+ Returns:
331
+ Appropriate wrapper function
332
+ """
333
+ sig = inspect.signature(handler_method)
334
+ params = list(sig.parameters.values())
335
+ params = params[1:]
336
+ # params = [p for p in params if p.name != "self"]
337
+
338
+ has_annotations, agent_param, chat_doc_param = self._analyze_handler_params(
339
+ handler_method,
340
+ )
341
+
342
+ # Build wrapper based on found parameters
343
+ if len(params) == 0:
344
+ if is_async:
345
+
346
+ async def wrapper(obj: Any) -> Any:
347
+ return await obj.handle_async()
348
+
349
+ else:
350
+
351
+ def wrapper(obj: Any) -> Any:
352
+ return obj.handle()
353
+
354
+ elif agent_param and chat_doc_param:
355
+ # Both parameters present - build wrapper respecting their order
356
+ param_names = [p.name for p in params]
357
+ if param_names.index(agent_param) < param_names.index(chat_doc_param):
358
+ # agent is first parameter
359
+ if is_async:
360
+
361
+ async def wrapper(obj: Any, chat_doc: Any) -> Any:
362
+ return await obj.handle_async(self, chat_doc)
363
+
364
+ else:
365
+
366
+ def wrapper(obj: Any, chat_doc: Any) -> Any:
367
+ return obj.handle(self, chat_doc)
368
+
369
+ else:
370
+ # chat_doc is first parameter
371
+ if is_async:
372
+
373
+ async def wrapper(obj: Any, chat_doc: Any) -> Any:
374
+ return await obj.handle_async(chat_doc, self)
375
+
376
+ else:
377
+
378
+ def wrapper(obj: Any, chat_doc: Any) -> Any:
379
+ return obj.handle(chat_doc, self)
380
+
381
+ elif agent_param and not chat_doc_param:
382
+ # Only agent parameter
383
+ if is_async:
384
+
385
+ async def wrapper(obj: Any) -> Any:
386
+ return await obj.handle_async(self)
387
+
388
+ else:
389
+
390
+ def wrapper(obj: Any) -> Any:
391
+ return obj.handle(self)
392
+
393
+ elif chat_doc_param and not agent_param:
394
+ # Only chat_doc parameter
395
+ if is_async:
396
+
397
+ async def wrapper(obj: Any, chat_doc: Any) -> Any:
398
+ return await obj.handle_async(chat_doc)
399
+
400
+ else:
401
+
402
+ def wrapper(obj: Any, chat_doc: Any) -> Any:
403
+ return obj.handle(chat_doc)
404
+
405
+ else:
406
+ # No recognized parameters - backward compatibility
407
+ # Assume single parameter is chat_doc (legacy behavior)
408
+ if len(params) == 1:
409
+ if is_async:
410
+
411
+ async def wrapper(obj: Any, chat_doc: Any) -> Any:
412
+ return await obj.handle_async(chat_doc)
413
+
414
+ else:
415
+
416
+ def wrapper(obj: Any, chat_doc: Any) -> Any:
417
+ return obj.handle(chat_doc)
418
+
419
+ else:
420
+ # Multiple unrecognized parameters - best guess
421
+ if is_async:
422
+
423
+ async def wrapper(obj: Any, chat_doc: Any) -> Any:
424
+ return await obj.handle_async(chat_doc)
425
+
426
+ else:
427
+
428
+ def wrapper(obj: Any, chat_doc: Any) -> Any:
429
+ return obj.handle(chat_doc)
430
+
431
+ return wrapper
432
+
433
+ def _get_tool_list(
434
+ self, message_class: Optional[Type[ToolMessage]] = None
435
+ ) -> List[str]:
436
+ """
437
+ If `message_class` is None, return a list of all known tool names.
438
+ Otherwise, first add the tool name corresponding to the message class
439
+ (which is the value of the `request` field of the message class),
440
+ to the `self.llm_tools_map` dict, and then return a list
441
+ containing this tool name.
442
+
443
+ Args:
444
+ message_class (Optional[Type[ToolMessage]]): The message class whose tool
445
+ name is to be returned; Optional, default is None.
446
+ if None, return a list of all known tool names.
447
+
448
+ Returns:
449
+ List[str]: List of tool names: either just the tool name corresponding
450
+ to the message class, or all known tool names
451
+ (when `message_class` is None).
452
+
453
+ """
454
+ if message_class is None:
455
+ return list(self.llm_tools_map.keys())
456
+
457
+ if not issubclass(message_class, ToolMessage):
458
+ raise ValueError("message_class must be a subclass of ToolMessage")
459
+ tool = message_class.default_value("request")
460
+
461
+ """
462
+ if tool has handler method explicitly defined - use it,
463
+ otherwise use the tool name as the handler
464
+ """
465
+ if hasattr(message_class, "_handler"):
466
+ handler = getattr(message_class, "_handler", tool)
467
+ else:
468
+ handler = tool
469
+
470
+ self.llm_tools_map[tool] = message_class
471
+ if (
472
+ hasattr(message_class, "handle")
473
+ and inspect.isfunction(message_class.handle)
474
+ and not hasattr(self, handler)
475
+ ):
476
+ """
477
+ If the message class has a `handle` method,
478
+ and agent does NOT have a tool handler method,
479
+ then we create a method for the agent whose name
480
+ is the value of `handler`, and whose body is the `handle` method.
481
+ This removes a separate step of having to define this method
482
+ for the agent, and also keeps the tool definition AND handling
483
+ in one place, i.e. in the message class.
484
+ See `tests/main/test_stateless_tool_messages.py` for an example.
485
+ """
486
+ wrapper = self._create_handler_wrapper(
487
+ message_class.handle,
488
+ is_async=False,
489
+ )
490
+ setattr(self, handler, wrapper)
491
+ elif (
492
+ hasattr(message_class, "response")
493
+ and inspect.isfunction(message_class.response)
494
+ and not hasattr(self, handler)
495
+ ):
496
+ has_chat_doc_arg = (
497
+ len(inspect.signature(message_class.response).parameters) > 2
498
+ )
499
+ if has_chat_doc_arg:
500
+
501
+ def response_wrapper_with_chat_doc(obj: Any, chat_doc: Any) -> Any:
502
+ return obj.response(self, chat_doc)
503
+
504
+ setattr(self, handler, response_wrapper_with_chat_doc)
505
+ else:
506
+
507
+ def response_wrapper_no_chat_doc(obj: Any) -> Any:
508
+ return obj.response(self)
509
+
510
+ setattr(self, handler, response_wrapper_no_chat_doc)
511
+
512
+ if hasattr(message_class, "handle_message_fallback") and (
513
+ inspect.isfunction(message_class.handle_message_fallback)
514
+ ):
515
+ # When a ToolMessage has a `handle_message_fallback` method,
516
+ # we inject it into the agent as a method, overriding the default
517
+ # `handle_message_fallback` method (which does nothing).
518
+ # It's possible multiple tool messages have a `handle_message_fallback`,
519
+ # in which case, the last one inserted will be used.
520
+ def fallback_wrapper(msg: Any) -> Any:
521
+ return message_class.handle_message_fallback(self, msg)
522
+
523
+ setattr(
524
+ self,
525
+ "handle_message_fallback",
526
+ fallback_wrapper,
527
+ )
528
+
529
+ async_handler_name = f"{handler}_async"
530
+ if (
531
+ hasattr(message_class, "handle_async")
532
+ and inspect.isfunction(message_class.handle_async)
533
+ and not hasattr(self, async_handler_name)
534
+ ):
535
+ wrapper = self._create_handler_wrapper(
536
+ message_class.handle_async,
537
+ is_async=True,
538
+ )
539
+ setattr(self, async_handler_name, wrapper)
540
+ elif (
541
+ hasattr(message_class, "response_async")
542
+ and inspect.isfunction(message_class.response_async)
543
+ and not hasattr(self, async_handler_name)
544
+ ):
545
+ has_chat_doc_arg = (
546
+ len(inspect.signature(message_class.response_async).parameters) > 2
547
+ )
548
+
549
+ if has_chat_doc_arg:
550
+
551
+ @no_type_check
552
+ async def handler(obj, chat_doc):
553
+ return await obj.response_async(self, chat_doc)
554
+
555
+ else:
556
+
557
+ @no_type_check
558
+ async def handler(obj):
559
+ return await obj.response_async(self)
560
+
561
+ setattr(self, async_handler_name, handler)
562
+
563
+ return [tool]
564
+
565
+ def enable_message_handling(
566
+ self, message_class: Optional[Type[ToolMessage]] = None
567
+ ) -> None:
568
+ """
569
+ Enable an agent to RESPOND (i.e. handle) a "tool" message of a specific type
570
+ from LLM. Also "registers" (i.e. adds) the `message_class` to the
571
+ `self.llm_tools_map` dict.
572
+
573
+ Args:
574
+ message_class (Optional[Type[ToolMessage]]): The message class to enable;
575
+ Optional; if None, all known message classes are enabled for handling.
576
+
577
+ """
578
+ for t in self._get_tool_list(message_class):
579
+ self.llm_tools_handled.add(t)
580
+
581
+ def disable_message_handling(
582
+ self,
583
+ message_class: Optional[Type[ToolMessage]] = None,
584
+ ) -> None:
585
+ """
586
+ Disable a message class from being handled by this Agent.
587
+
588
+ Args:
589
+ message_class (Optional[Type[ToolMessage]]): The message class to disable.
590
+ If None, all message classes are disabled.
591
+ """
592
+ for t in self._get_tool_list(message_class):
593
+ self.llm_tools_handled.discard(t)
594
+
595
+ def sample_multi_round_dialog(self) -> str:
596
+ """
597
+ Generate a sample multi-round dialog based on enabled message classes.
598
+ Returns:
599
+ str: The sample dialog string.
600
+ """
601
+ enabled_classes: List[Type[ToolMessage]] = list(self.llm_tools_map.values())
602
+ # use at most 2 sample conversations, no need to be exhaustive;
603
+ sample_convo = [
604
+ msg_cls().usage_examples(random=True) # type: ignore
605
+ for i, msg_cls in enumerate(enabled_classes)
606
+ if i < 2
607
+ ]
608
+ return "\n\n".join(sample_convo)
609
+
610
+ def create_agent_response(
611
+ self,
612
+ content: str | None = None,
613
+ files: List[FileAttachment] = [],
614
+ content_any: Any = None,
615
+ tool_messages: List[ToolMessage] = [],
616
+ oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
617
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
618
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
619
+ function_call: LLMFunctionCall | None = None,
620
+ recipient: str = "",
621
+ ) -> ChatDocument:
622
+ """Template for agent_response."""
623
+ return self.response_template(
624
+ Entity.AGENT,
625
+ content=content,
626
+ files=files,
627
+ content_any=content_any,
628
+ tool_messages=tool_messages,
629
+ oai_tool_calls=oai_tool_calls,
630
+ oai_tool_choice=oai_tool_choice,
631
+ oai_tool_id2result=oai_tool_id2result,
632
+ function_call=function_call,
633
+ recipient=recipient,
634
+ )
635
+
636
+ def render_agent_response(
637
+ self,
638
+ results: Optional[str | OrderedDict[str, str] | ChatDocument],
639
+ ) -> None:
640
+ """
641
+ Render the response from the agent, typically from tool-handling.
642
+ Args:
643
+ results: results from tool-handling, which may be a string,
644
+ a dict of tool results, or a ChatDocument.
645
+ """
646
+ if self.config.hide_agent_response or results is None:
647
+ return
648
+ if isinstance(results, str):
649
+ results_str = results
650
+ elif isinstance(results, ChatDocument):
651
+ results_str = results.content
652
+ elif isinstance(results, dict):
653
+ results_str = json.dumps(results, indent=2)
654
+ if not settings.quiet:
655
+ console.print(f"[red]{self.indent}", end="")
656
+ print(f"[red]Agent: {escape(results_str)}")
657
+
658
+ def _agent_response_final(
659
+ self,
660
+ msg: Optional[str | ChatDocument],
661
+ results: Optional[str | OrderedDict[str, str] | ChatDocument],
662
+ ) -> Optional[ChatDocument]:
663
+ """
664
+ Convert results to final response.
665
+ """
666
+ if results is None:
667
+ return None
668
+ if isinstance(results, str):
669
+ results_str = results
670
+ elif isinstance(results, ChatDocument):
671
+ results_str = results.content
672
+ elif isinstance(results, dict):
673
+ results_str = json.dumps(results, indent=2)
674
+ if not settings.quiet:
675
+ self.render_agent_response(results)
676
+ maybe_json = len(extract_top_level_json(results_str)) > 0
677
+ self.callbacks.show_agent_response(
678
+ content=results_str,
679
+ language="json" if maybe_json else "text",
680
+ is_tool=(
681
+ isinstance(results, ChatDocument)
682
+ and self.has_tool_message_attempt(results)
683
+ ),
684
+ )
685
+ if isinstance(results, ChatDocument):
686
+ # Preserve trail of tool_ids for OpenAI Assistant fn-calls
687
+ results.metadata.tool_ids = (
688
+ [] if msg is None or isinstance(msg, str) else msg.metadata.tool_ids
689
+ )
690
+ results.metadata.agent_id = self.id
691
+ return results
692
+ sender_name = self.config.name
693
+ if isinstance(msg, ChatDocument) and msg.function_call is not None:
694
+ # if result was from handling an LLM `function_call`,
695
+ # set sender_name to name of the function_call
696
+ sender_name = msg.function_call.name
697
+
698
+ results_str, id2result, oai_tool_id = self.process_tool_results(
699
+ results if isinstance(results, str) else "",
700
+ id2result=None if isinstance(results, str) else results,
701
+ tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
702
+ )
703
+ return ChatDocument(
704
+ content=results_str,
705
+ oai_tool_id2result=id2result,
706
+ metadata=ChatDocMetaData(
707
+ source=Entity.AGENT,
708
+ sender=Entity.AGENT,
709
+ agent_id=self.id,
710
+ sender_name=sender_name,
711
+ oai_tool_id=oai_tool_id,
712
+ # preserve trail of tool_ids for OpenAI Assistant fn-calls
713
+ tool_ids=(
714
+ [] if msg is None or isinstance(msg, str) else msg.metadata.tool_ids
715
+ ),
716
+ ),
717
+ )
718
+
719
+ async def agent_response_async(
720
+ self,
721
+ msg: Optional[str | ChatDocument] = None,
722
+ ) -> Optional[ChatDocument]:
723
+ """
724
+ Asynch version of `agent_response`. See there for details.
725
+ """
726
+ if msg is None:
727
+ return None
728
+
729
+ results = await self.handle_message_async(msg)
730
+
731
+ return self._agent_response_final(msg, results)
732
+
733
+ def agent_response(
734
+ self,
735
+ msg: Optional[str | ChatDocument] = None,
736
+ ) -> Optional[ChatDocument]:
737
+ """
738
+ Response from the "agent itself", typically (but not only)
739
+ used to handle LLM's "tool message" or `function_call`
740
+ (e.g. OpenAI `function_call`).
741
+ Args:
742
+ msg (str|ChatDocument): the input to respond to: if msg is a string,
743
+ and it contains a valid JSON-structured "tool message", or
744
+ if msg is a ChatDocument, and it contains a `function_call`.
745
+ Returns:
746
+ Optional[ChatDocument]: the response, packaged as a ChatDocument
747
+
748
+ """
749
+ if msg is None:
750
+ return None
751
+
752
+ results = self.handle_message(msg)
753
+
754
+ return self._agent_response_final(msg, results)
755
+
756
+ def process_tool_results(
757
+ self,
758
+ results: str,
759
+ id2result: OrderedDict[str, str] | None,
760
+ tool_calls: List[OpenAIToolCall] | None = None,
761
+ ) -> Tuple[str, Dict[str, str] | None, str | None]:
762
+ """
763
+ Process results from a response, based on whether
764
+ they are results of OpenAI tool-calls from THIS agent, so that
765
+ we can construct an appropriate LLMMessage that contains tool results.
766
+
767
+ Args:
768
+ results (str): A possible string result from handling tool(s)
769
+ id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
770
+ if there are multiple tool results.
771
+ tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
772
+ results are a response to.
773
+
774
+ Return:
775
+ - str: The response string
776
+ - Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
777
+ multiple tool results.
778
+ - str|None: tool_id if there was a single tool result
779
+
780
+ """
781
+ id2result_ = copy.deepcopy(id2result) if id2result is not None else None
782
+ results_str = ""
783
+ oai_tool_id = None
784
+
785
+ if results != "":
786
+ # in this case ignore id2result
787
+ assert (
788
+ id2result is None
789
+ ), "id2result should be None when results string is non-empty!"
790
+ results_str = results
791
+ if len(self.oai_tool_calls) > 0:
792
+ # We only have one result, so in case there is a
793
+ # "pending" OpenAI tool-call, we expect no more than 1 such.
794
+ assert (
795
+ len(self.oai_tool_calls) == 1
796
+ ), "There are multiple pending tool-calls, but only one result!"
797
+ # We record the tool_id of the tool-call that
798
+ # the result is a response to, so that ChatDocument.to_LLMMessage
799
+ # can properly set the `tool_call_id` field of the LLMMessage.
800
+ oai_tool_id = self.oai_tool_calls[0].id
801
+ elif id2result is not None and id2result_ is not None: # appease mypy
802
+ if len(id2result_) == len(self.oai_tool_calls):
803
+ # if the number of pending tool calls equals the number of results,
804
+ # then ignore the ids in id2result, and use the results in order,
805
+ # which is preserved since id2result is an OrderedDict.
806
+ assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
807
+ results_str = ""
808
+ id2result_ = OrderedDict(
809
+ zip(
810
+ [tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
811
+ )
812
+ )
813
+ else:
814
+ assert (
815
+ tool_calls is not None
816
+ ), "tool_calls cannot be None when id2result is not None!"
817
+ # This must be an OpenAI tool id -> result map;
818
+ # However some ids may not correspond to the tool-calls in the list of
819
+ # pending tool-calls (self.oai_tool_calls).
820
+ # Such results are concatenated into a simple string, to store in the
821
+ # ChatDocument.content, and the rest
822
+ # (i.e. those that DO correspond to tools in self.oai_tool_calls)
823
+ # are stored as a dict in ChatDocument.oai_tool_id2result.
824
+
825
+ # OAI tools from THIS agent, awaiting response
826
+ pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
827
+ # tool_calls that the results are a response to
828
+ # (but these may have been sent from another agent, hence may not be in
829
+ # self.oai_tool_calls)
830
+ parent_tool_id2name = {
831
+ tc.id: tc.function.name
832
+ for tc in tool_calls or []
833
+ if tc.function is not None
834
+ }
835
+
836
+ # (id, result) for result NOT corresponding to self.oai_tool_calls,
837
+ # i.e. these are results of EXTERNAL tool-calls from another agent.
838
+ external_tool_id_results = []
839
+
840
+ for tc_id, result in id2result.items():
841
+ if tc_id not in pending_tool_ids:
842
+ external_tool_id_results.append((tc_id, result))
843
+ id2result_.pop(tc_id)
844
+ if len(external_tool_id_results) == 0:
845
+ results_str = ""
846
+ elif len(external_tool_id_results) == 1:
847
+ results_str = external_tool_id_results[0][1]
848
+ else:
849
+ results_str = "\n\n".join(
850
+ [
851
+ f"Result from tool/function "
852
+ f"{parent_tool_id2name[id]}: {result}"
853
+ for id, result in external_tool_id_results
854
+ ]
855
+ )
856
+
857
+ if len(id2result_) == 0:
858
+ id2result_ = None
859
+ elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
860
+ results_str = list(id2result_.values())[0]
861
+ oai_tool_id = list(id2result_.keys())[0]
862
+ id2result_ = None
863
+
864
+ return results_str, id2result_, oai_tool_id
865
+
866
+ def response_template(
867
+ self,
868
+ e: Entity,
869
+ content: str | None = None,
870
+ files: List[FileAttachment] = [],
871
+ content_any: Any = None,
872
+ tool_messages: List[ToolMessage] = [],
873
+ oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
874
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
875
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
876
+ function_call: LLMFunctionCall | None = None,
877
+ recipient: str = "",
878
+ ) -> ChatDocument:
879
+ """Template for response from entity `e`."""
880
+ return ChatDocument(
881
+ content=content or "",
882
+ files=files,
883
+ content_any=content_any,
884
+ tool_messages=tool_messages,
885
+ oai_tool_calls=oai_tool_calls,
886
+ oai_tool_id2result=oai_tool_id2result,
887
+ function_call=function_call,
888
+ oai_tool_choice=oai_tool_choice,
889
+ metadata=ChatDocMetaData(
890
+ source=e, sender=e, sender_name=self.config.name, recipient=recipient
891
+ ),
892
+ )
893
+
894
+ def create_user_response(
895
+ self,
896
+ content: str | None = None,
897
+ files: List[FileAttachment] = [],
898
+ content_any: Any = None,
899
+ tool_messages: List[ToolMessage] = [],
900
+ oai_tool_calls: List[OpenAIToolCall] | None = None,
901
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
902
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
903
+ function_call: LLMFunctionCall | None = None,
904
+ recipient: str = "",
905
+ ) -> ChatDocument:
906
+ """Template for user_response."""
907
+ return self.response_template(
908
+ e=Entity.USER,
909
+ content=content,
910
+ files=files,
911
+ content_any=content_any,
912
+ tool_messages=tool_messages,
913
+ oai_tool_calls=oai_tool_calls,
914
+ oai_tool_choice=oai_tool_choice,
915
+ oai_tool_id2result=oai_tool_id2result,
916
+ function_call=function_call,
917
+ recipient=recipient,
918
+ )
919
+
920
+ def user_can_respond(self, msg: Optional[str | ChatDocument] = None) -> bool:
921
+ """
922
+ Whether the user can respond to a message.
923
+
924
+ Args:
925
+ msg (str|ChatDocument): the string to respond to.
926
+
927
+ Returns:
928
+
929
+ """
930
+ # When msg explicitly addressed to user, this means an actual human response
931
+ # is being sought.
932
+ need_human_response = (
933
+ isinstance(msg, ChatDocument) and msg.metadata.recipient == Entity.USER
934
+ )
935
+
936
+ if not self.interactive and not need_human_response:
937
+ return False
938
+
939
+ return True
940
+
941
+ def _user_response_final(
942
+ self, msg: Optional[str | ChatDocument], user_msg: str
943
+ ) -> Optional[ChatDocument]:
944
+ """
945
+ Convert user_msg to final response.
946
+ """
947
+ if not user_msg:
948
+ need_human_response = (
949
+ isinstance(msg, ChatDocument) and msg.metadata.recipient == Entity.USER
950
+ )
951
+ user_msg = (
952
+ (self.default_human_response or "null") if need_human_response else ""
953
+ )
954
+ user_msg = user_msg.strip()
955
+
956
+ tool_ids = []
957
+ if msg is not None and isinstance(msg, ChatDocument):
958
+ tool_ids = msg.metadata.tool_ids
959
+
960
+ # only return non-None result if user_msg not empty
961
+ if not user_msg:
962
+ return None
963
+ else:
964
+ if user_msg.startswith("SYSTEM"):
965
+ user_msg = user_msg.replace("SYSTEM", "").strip()
966
+ source = Entity.SYSTEM
967
+ sender = Entity.SYSTEM
968
+ else:
969
+ source = Entity.USER
970
+ sender = Entity.USER
971
+ return ChatDocument(
972
+ content=user_msg,
973
+ metadata=ChatDocMetaData(
974
+ agent_id=self.id,
975
+ source=source,
976
+ sender=sender,
977
+ # preserve trail of tool_ids for OpenAI Assistant fn-calls
978
+ tool_ids=tool_ids,
979
+ ),
980
+ )
981
+
982
+ async def user_response_async(
983
+ self,
984
+ msg: Optional[str | ChatDocument] = None,
985
+ ) -> Optional[ChatDocument]:
986
+ """
987
+ Asynch version of `user_response`. See there for details.
988
+ """
989
+ if not self.user_can_respond(msg):
990
+ return None
991
+
992
+ if self.default_human_response is not None:
993
+ user_msg = self.default_human_response
994
+ else:
995
+ if (
996
+ self.callbacks.get_user_response_async is not None
997
+ and self.callbacks.get_user_response_async is not async_noop_fn
998
+ ):
999
+ user_msg = await self.callbacks.get_user_response_async(prompt="")
1000
+ elif self.callbacks.get_user_response is not None:
1001
+ user_msg = self.callbacks.get_user_response(prompt="")
1002
+ else:
1003
+ user_msg = Prompt.ask(
1004
+ f"[blue]{self.indent}"
1005
+ + self.config.human_prompt
1006
+ + f"\n{self.indent}"
1007
+ )
1008
+
1009
+ return self._user_response_final(msg, user_msg)
1010
+
1011
+ def user_response(
1012
+ self,
1013
+ msg: Optional[str | ChatDocument] = None,
1014
+ ) -> Optional[ChatDocument]:
1015
+ """
1016
+ Get user response to current message. Could allow (human) user to intervene
1017
+ with an actual answer, or quit using "q" or "x"
1018
+
1019
+ Args:
1020
+ msg (str|ChatDocument): the string to respond to.
1021
+
1022
+ Returns:
1023
+ (str) User response, packaged as a ChatDocument
1024
+
1025
+ """
1026
+
1027
+ if not self.user_can_respond(msg):
1028
+ return None
1029
+
1030
+ if self.default_human_response is not None:
1031
+ user_msg = self.default_human_response
1032
+ else:
1033
+ if self.callbacks.get_user_response is not None:
1034
+ # ask user with empty prompt: no need for prompt
1035
+ # since user has seen the conversation so far.
1036
+ # But non-empty prompt can be useful when Agent
1037
+ # uses a tool that requires user input, or in other scenarios.
1038
+ user_msg = self.callbacks.get_user_response(prompt="")
1039
+ else:
1040
+ user_msg = Prompt.ask(
1041
+ f"[blue]{self.indent}"
1042
+ + self.config.human_prompt
1043
+ + f"\n{self.indent}"
1044
+ )
1045
+
1046
+ return self._user_response_final(msg, user_msg)
1047
+
1048
+ @no_type_check
1049
+ def llm_can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
1050
+ """
1051
+ Whether the LLM can respond to a message.
1052
+ Args:
1053
+ message (str|ChatDocument): message or ChatDocument object to respond to.
1054
+
1055
+ Returns:
1056
+
1057
+ """
1058
+ if self.llm is None:
1059
+ return False
1060
+
1061
+ if message is not None and len(self.try_get_tool_messages(message)) > 0:
1062
+ # if there is a valid "tool" message (either JSON or via `function_call`)
1063
+ # then LLM cannot respond to it
1064
+ return False
1065
+
1066
+ return True
1067
+
1068
+ def can_respond(self, message: Optional[str | ChatDocument] = None) -> bool:
1069
+ """
1070
+ Whether the agent can respond to a message.
1071
+ Used in Task.py to skip a sub-task when we know it would not respond.
1072
+ Args:
1073
+ message (str|ChatDocument): message or ChatDocument object to respond to.
1074
+ """
1075
+ tools = self.try_get_tool_messages(message)
1076
+ if len(tools) == 0 and self.config.respond_tools_only:
1077
+ return False
1078
+ if message is not None and self.has_only_unhandled_tools(message):
1079
+ # The message has tools that are NOT enabled to be handled by this agent,
1080
+ # which means the agent cannot respond to it.
1081
+ return False
1082
+ return True
1083
+
1084
+ def create_llm_response(
1085
+ self,
1086
+ content: str | None = None,
1087
+ content_any: Any = None,
1088
+ tool_messages: List[ToolMessage] = [],
1089
+ oai_tool_calls: None | List[OpenAIToolCall] = None,
1090
+ oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
1091
+ oai_tool_id2result: OrderedDict[str, str] | None = None,
1092
+ function_call: LLMFunctionCall | None = None,
1093
+ recipient: str = "",
1094
+ ) -> ChatDocument:
1095
+ """Template for llm_response."""
1096
+ return self.response_template(
1097
+ Entity.LLM,
1098
+ content=content,
1099
+ content_any=content_any,
1100
+ tool_messages=tool_messages,
1101
+ oai_tool_calls=oai_tool_calls,
1102
+ oai_tool_choice=oai_tool_choice,
1103
+ oai_tool_id2result=oai_tool_id2result,
1104
+ function_call=function_call,
1105
+ recipient=recipient,
1106
+ )
1107
+
1108
+ @no_type_check
1109
+ async def llm_response_async(
1110
+ self,
1111
+ message: Optional[str | ChatDocument] = None,
1112
+ ) -> Optional[ChatDocument]:
1113
+ """
1114
+ Asynch version of `llm_response`. See there for details.
1115
+ """
1116
+ if message is None or not self.llm_can_respond(message):
1117
+ return None
1118
+
1119
+ if isinstance(message, ChatDocument):
1120
+ prompt = message.content
1121
+ else:
1122
+ prompt = message
1123
+
1124
+ output_len = self.config.llm.model_max_output_tokens
1125
+ if self.num_tokens(prompt) + output_len > self.llm.completion_context_length():
1126
+ output_len = self.llm.completion_context_length() - self.num_tokens(prompt)
1127
+ if output_len < self.config.llm.min_output_tokens:
1128
+ raise ValueError(
1129
+ """
1130
+ Token-length of Prompt + Output is longer than the
1131
+ completion context length of the LLM!
1132
+ """
1133
+ )
1134
+ else:
1135
+ logger.warning(
1136
+ f"""
1137
+ Requested output length has been shortened to {output_len}
1138
+ so that the total length of Prompt + Output is less than
1139
+ the completion context length of the LLM.
1140
+ """
1141
+ )
1142
+
1143
+ with StreamingIfAllowed(self.llm, self.llm.get_stream()):
1144
+ response = await self.llm.agenerate(prompt, output_len)
1145
+
1146
+ if not self.llm.get_stream() or response.cached and not settings.quiet:
1147
+ # We would have already displayed the msg "live" ONLY if
1148
+ # streaming was enabled, AND we did not find a cached response.
1149
+ # If we are here, it means the response has not yet been displayed.
1150
+ cached = f"[red]{self.indent}(cached)[/red]" if response.cached else ""
1151
+ print(cached + "[green]" + escape(response.message))
1152
+ async with self.lock:
1153
+ self.update_token_usage(
1154
+ response,
1155
+ prompt,
1156
+ self.llm.get_stream(),
1157
+ chat=False, # i.e. it's a completion model not chat model
1158
+ print_response_stats=self.config.show_stats and not settings.quiet,
1159
+ )
1160
+ cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
1161
+ # Preserve trail of tool_ids for OpenAI Assistant fn-calls
1162
+ cdoc.metadata.tool_ids = (
1163
+ [] if isinstance(message, str) else message.metadata.tool_ids
1164
+ )
1165
+ return cdoc
1166
+
1167
+ @no_type_check
1168
+ def llm_response(
1169
+ self,
1170
+ message: Optional[str | ChatDocument] = None,
1171
+ ) -> Optional[ChatDocument]:
1172
+ """
1173
+ LLM response to a prompt.
1174
+ Args:
1175
+ message (str|ChatDocument): prompt string, or ChatDocument object
1176
+
1177
+ Returns:
1178
+ Response from LLM, packaged as a ChatDocument
1179
+ """
1180
+ if message is None or not self.llm_can_respond(message):
1181
+ return None
1182
+
1183
+ if isinstance(message, ChatDocument):
1184
+ prompt = message.content
1185
+ else:
1186
+ prompt = message
1187
+
1188
+ with ExitStack() as stack: # for conditionally using rich spinner
1189
+ if not self.llm.get_stream():
1190
+ # show rich spinner only if not streaming!
1191
+ cm = status("LLM responding to message...")
1192
+ stack.enter_context(cm)
1193
+ output_len = self.config.llm.model_max_output_tokens
1194
+ if (
1195
+ self.num_tokens(prompt) + output_len
1196
+ > self.llm.completion_context_length()
1197
+ ):
1198
+ output_len = self.llm.completion_context_length() - self.num_tokens(
1199
+ prompt
1200
+ )
1201
+ if output_len < self.config.llm.min_output_tokens:
1202
+ raise ValueError(
1203
+ """
1204
+ Token-length of Prompt + Output is longer than the
1205
+ completion context length of the LLM!
1206
+ """
1207
+ )
1208
+ else:
1209
+ logger.warning(
1210
+ f"""
1211
+ Requested output length has been shortened to {output_len}
1212
+ so that the total length of Prompt + Output is less than
1213
+ the completion context length of the LLM.
1214
+ """
1215
+ )
1216
+ if self.llm.get_stream() and not settings.quiet:
1217
+ console.print(f"[green]{self.indent}", end="")
1218
+ response = self.llm.generate(prompt, output_len)
1219
+
1220
+ if not self.llm.get_stream() or response.cached and not settings.quiet:
1221
+ # we would have already displayed the msg "live" ONLY if
1222
+ # streaming was enabled, AND we did not find a cached response
1223
+ # If we are here, it means the response has not yet been displayed.
1224
+ cached = "[red](cached)[/red]" if response.cached else ""
1225
+ console.print(f"[green]{self.indent}", end="")
1226
+ print(cached + "[green]" + escape(response.message))
1227
+ self.update_token_usage(
1228
+ response,
1229
+ prompt,
1230
+ self.llm.get_stream(),
1231
+ chat=False, # i.e. it's a completion model not chat model
1232
+ print_response_stats=self.config.show_stats and not settings.quiet,
1233
+ )
1234
+ cdoc = ChatDocument.from_LLMResponse(response, displayed=True)
1235
+ # Preserve trail of tool_ids for OpenAI Assistant fn-calls
1236
+ cdoc.metadata.tool_ids = (
1237
+ [] if isinstance(message, str) else message.metadata.tool_ids
1238
+ )
1239
+ return cdoc
1240
+
1241
+ def has_tool_message_attempt(self, msg: str | ChatDocument | None) -> bool:
1242
+ """
1243
+ Check whether msg contains a Tool/fn-call attempt (by the LLM).
1244
+
1245
+ CAUTION: This uses self.get_tool_messages(msg) which as a side-effect
1246
+ may update msg.tool_messages when msg is a ChatDocument, if there are
1247
+ any tools in msg.
1248
+ """
1249
+ if msg is None:
1250
+ return False
1251
+ if isinstance(msg, ChatDocument):
1252
+ if len(msg.tool_messages) > 0:
1253
+ return True
1254
+ if msg.metadata.sender != Entity.LLM:
1255
+ return False
1256
+ try:
1257
+ tools = self.get_tool_messages(msg)
1258
+ return len(tools) > 0
1259
+ except (ValidationError, XMLException):
1260
+ # there is a tool/fn-call attempt but had a validation error,
1261
+ # so we still consider this a tool message "attempt"
1262
+ return True
1263
+ return False
1264
+
1265
+ def _tool_recipient_match(self, tool: ToolMessage) -> bool:
1266
+ """Is tool enabled for handling by this agent and intended for this
1267
+ agent to handle (i.e. if there's any explicit `recipient` field exists in
1268
+ tool, then it matches this agent's name)?
1269
+ """
1270
+ if tool.default_value("request") not in self.llm_tools_handled:
1271
+ return False
1272
+ if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
1273
+ return tool.recipient == "" or tool.recipient == self.config.name
1274
+ return True
1275
+
1276
+ def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
1277
+ """
1278
+ Does the msg have at least one tool, and none of the tools in the msg are
1279
+ handleable by this agent?
1280
+ """
1281
+ if msg is None:
1282
+ return False
1283
+ tools = self.try_get_tool_messages(msg, all_tools=True)
1284
+ if len(tools) == 0:
1285
+ return False
1286
+ return all(not self._tool_recipient_match(t) for t in tools)
1287
+
1288
+ def try_get_tool_messages(
1289
+ self,
1290
+ msg: str | ChatDocument | None,
1291
+ all_tools: bool = False,
1292
+ ) -> List[ToolMessage]:
1293
+ try:
1294
+ return self.get_tool_messages(msg, all_tools)
1295
+ except (ValidationError, XMLException):
1296
+ return []
1297
+
1298
+ def get_tool_messages(
1299
+ self,
1300
+ msg: str | ChatDocument | None,
1301
+ all_tools: bool = False,
1302
+ ) -> List[ToolMessage]:
1303
+ """
1304
+ Get ToolMessages recognized in msg, handle-able by this agent.
1305
+ NOTE: as a side-effect, this will update msg.tool_messages
1306
+ when msg is a ChatDocument and msg contains tool messages.
1307
+ The intent here is that update=True should be set ONLY within agent_response()
1308
+ or agent_response_async() methods. In other words, we want to persist the
1309
+ msg.tool_messages only AFTER the agent has had a chance to handle the tools.
1310
+
1311
+ Args:
1312
+ msg (str|ChatDocument): the message to extract tools from.
1313
+ all_tools (bool):
1314
+ - if True, return all tools,
1315
+ i.e. any recognized tool in self.llm_tools_known,
1316
+ whether it is handled by this agent or not;
1317
+ - otherwise, return only the tools handled by this agent.
1318
+
1319
+ Returns:
1320
+ List[ToolMessage]: list of ToolMessage objects
1321
+ """
1322
+
1323
+ if msg is None:
1324
+ return []
1325
+
1326
+ if isinstance(msg, str):
1327
+ json_tools = self.get_formatted_tool_messages(msg)
1328
+ if all_tools:
1329
+ return json_tools
1330
+ else:
1331
+ return [
1332
+ t
1333
+ for t in json_tools
1334
+ if self._tool_recipient_match(t) and t.default_value("request")
1335
+ ]
1336
+
1337
+ if all_tools and len(msg.all_tool_messages) > 0:
1338
+ # We've already identified all_tool_messages in the msg;
1339
+ # return the corresponding ToolMessage objects
1340
+ return msg.all_tool_messages
1341
+ if len(msg.tool_messages) > 0:
1342
+ # We've already found tool_messages,
1343
+ # (either via OpenAI Fn-call or Langroid-native ToolMessage);
1344
+ # or they were added by an agent_response.
1345
+ # note these could be from a forwarded msg from another agent,
1346
+ # so return ONLY the messages THIS agent to enabled to handle.
1347
+ if all_tools:
1348
+ return msg.tool_messages
1349
+ return [t for t in msg.tool_messages if self._tool_recipient_match(t)]
1350
+ assert isinstance(msg, ChatDocument)
1351
+ if (
1352
+ msg.content != ""
1353
+ and msg.oai_tool_calls is None
1354
+ and msg.function_call is None
1355
+ ):
1356
+
1357
+ tools = self.get_formatted_tool_messages(
1358
+ msg.content, from_llm=msg.metadata.sender == Entity.LLM
1359
+ )
1360
+ msg.all_tool_messages = tools
1361
+ # filter for actually handle-able tools, and recipient is this agent
1362
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
1363
+ msg.tool_messages = my_tools
1364
+
1365
+ if all_tools:
1366
+ return tools
1367
+ else:
1368
+ return my_tools
1369
+
1370
+ # otherwise, we look for `tool_calls` (possibly multiple)
1371
+ tools = self.get_oai_tool_calls_classes(msg)
1372
+ msg.all_tool_messages = tools
1373
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
1374
+ msg.tool_messages = my_tools
1375
+
1376
+ if len(tools) == 0:
1377
+ # otherwise, we look for a `function_call`
1378
+ fun_call_cls = self.get_function_call_class(msg)
1379
+ tools = [fun_call_cls] if fun_call_cls is not None else []
1380
+ msg.all_tool_messages = tools
1381
+ my_tools = [t for t in tools if self._tool_recipient_match(t)]
1382
+ msg.tool_messages = my_tools
1383
+ if all_tools:
1384
+ return tools
1385
+ else:
1386
+ return my_tools
1387
+
1388
+ def get_formatted_tool_messages(
1389
+ self, input_str: str, from_llm: bool = True
1390
+ ) -> List[ToolMessage]:
1391
+ """
1392
+ Returns ToolMessage objects (tools) corresponding to
1393
+ tool-formatted substrings, if any.
1394
+ ASSUMPTION - These tools are either ALL JSON-based, or ALL XML-based
1395
+ (i.e. not a mix of both).
1396
+ Terminology: a "formatted tool msg" is one which the LLM generates as
1397
+ part of its raw string output, rather than within a JSON object
1398
+ in the API response (i.e. this method does not extract tools/fns returned
1399
+ by OpenAI's tools/fns API or similar APIs).
1400
+
1401
+ Args:
1402
+ input_str (str): input string, typically a message sent by an LLM
1403
+ from_llm (bool): whether the input was generated by the LLM. If so,
1404
+ we track malformed tool calls.
1405
+
1406
+ Returns:
1407
+ List[ToolMessage]: list of ToolMessage objects
1408
+ """
1409
+ self.tool_error = False
1410
+ substrings = XMLToolMessage.find_candidates(input_str)
1411
+ is_json = False
1412
+ if len(substrings) == 0:
1413
+ substrings = extract_top_level_json(input_str)
1414
+ is_json = len(substrings) > 0
1415
+ if not is_json:
1416
+ return []
1417
+
1418
+ results = [self._get_one_tool_message(j, is_json, from_llm) for j in substrings]
1419
+ valid_results = [r for r in results if r is not None]
1420
+ # If any tool is correctly formed we do not set the flag
1421
+ if len(valid_results) > 0:
1422
+ self.tool_error = False
1423
+ return valid_results
1424
+
1425
+ def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
1426
+ """
1427
+ From ChatDocument (constructed from an LLM Response), get the `ToolMessage`
1428
+ corresponding to the `function_call` if it exists.
1429
+ """
1430
+ if msg.function_call is None:
1431
+ return None
1432
+ tool_name = msg.function_call.name
1433
+ tool_msg = msg.function_call.arguments or {}
1434
+ self.tool_error = False
1435
+ if tool_name not in self.llm_tools_handled:
1436
+ logger.warning(
1437
+ f"""
1438
+ The function_call '{tool_name}' is not handled
1439
+ by the agent named '{self.config.name}'!
1440
+ If you intended this agent to handle this function_call,
1441
+ either the fn-call name is incorrectly generated by the LLM,
1442
+ (in which case you may need to adjust your LLM instructions),
1443
+ or you need to enable this agent to handle this fn-call.
1444
+ """
1445
+ )
1446
+ if (
1447
+ tool_name not in self.all_llm_tools_known
1448
+ and msg.metadata.sender == Entity.LLM
1449
+ ):
1450
+ self.tool_error = True
1451
+ return None
1452
+ tool_class = self.llm_tools_map[tool_name]
1453
+ tool_msg.update(dict(request=tool_name))
1454
+ tool = tool_class.model_validate(tool_msg)
1455
+ return tool
1456
+
1457
+ def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
+     """
+     From ChatDocument (constructed from an LLM Response), get
+     a list of ToolMessages corresponding to the `tool_calls`, if any.
+     """
+
+     if msg.oai_tool_calls is None:
+         return []
+     tools = []
+     all_errors = True
+     for tc in msg.oai_tool_calls:
+         if tc.function is None:
+             continue
+         tool_name = tc.function.name
+         tool_msg = tc.function.arguments or {}
+         if tool_name not in self.llm_tools_handled:
+             logger.warning(
+                 f"""
+                 The tool_call '{tool_name}' is not handled
+                 by the agent named '{self.config.name}'!
+                 If you intended this agent to handle this tool_call,
+                 either the tool name is incorrectly generated by the LLM
+                 (in which case you may need to adjust your LLM instructions),
+                 or you need to enable this agent to handle this tool.
+                 """
+             )
+             continue
+         all_errors = False
+         tool_class = self.llm_tools_map[tool_name]
+         tool_msg.update(dict(request=tool_name))
+         tool = tool_class.model_validate(tool_msg)
+         tool.id = tc.id or ""
+         tools.append(tool)
+     # When no tool is valid and the message was produced
+     # by the LLM, set the recovery flag
+     self.tool_error = all_errors and msg.metadata.sender == Entity.LLM
+     return tools
+
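# A hedged sketch of the tool_calls path (OpenAIToolCall and LLMFunctionCall
# are real langroid types, but this construction is illustrative only; the id
# is a hypothetical API-assigned value):
from langroid.language_models.base import LLMFunctionCall, OpenAIToolCall

tc = OpenAIToolCall(
    id="call_1",
    function=LLMFunctionCall(name="square", arguments={"num": 4}),
)
# given `doc`, a ChatDocument from an LLM response with doc.oai_tool_calls == [tc]:
#   tools = agent.get_oai_tool_calls_classes(doc)
#   -> [SquareTool(num=4)] with tools[0].id == "call_1"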
+ def tool_validation_error(self, ve: ValidationError) -> str:
+     """
+     Handle a validation error raised when parsing a tool message,
+     when there is a legit tool name used, but it has missing/bad fields.
+     Args:
+         ve (ValidationError): The exception raised
+
+     Returns:
+         str: The error message to send back to the LLM
+     """
+     tool_name = cast(ToolMessage, ve.model).default_value("request")
+     bad_field_errors = "\n".join(
+         [f"{e['loc']}: {e['msg']}" for e in ve.errors() if "loc" in e]
+     )
+     return f"""
+     There were one or more errors in your attempt to use the
+     TOOL or function_call named '{tool_name}':
+     {bad_field_errors}
+     Please write your message again, correcting the errors.
+     """
+
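# Sketch of the validation-error path (uses the SquareTool sketch above):
# a correct tool name with a bad field raises a pydantic ValidationError,
# which handle_message converts into a retry prompt for the LLM.
bad = '{"request": "square", "num": "not-a-number"}'
err = agent.handle_message(bad)
# err is a string like:
#   There were one or more errors in your attempt to use the
#   TOOL or function_call named 'square': ...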
+ def _get_multiple_orch_tool_errs(
+     self, tools: List[ToolMessage]
+ ) -> List[str | ChatDocument | None]:
+     """
+     Return per-tool error strings if the message contains multiple tools
+     including at least one orchestration tool; else an empty list.
+     """
+     # check whether there are multiple orchestration-tools (e.g. DoneTool etc),
+     # in which case set result to error-string since we don't yet support
+     # multi-tools with one or more orch tools.
+     from langroid.agent.tools.orchestration import (
+         AgentDoneTool,
+         AgentSendTool,
+         DonePassTool,
+         DoneTool,
+         ForwardTool,
+         PassTool,
+         SendTool,
+     )
+     from langroid.agent.tools.recipient_tool import RecipientTool
+
+     ORCHESTRATION_TOOLS = (
+         AgentDoneTool,
+         DoneTool,
+         PassTool,
+         DonePassTool,
+         ForwardTool,
+         RecipientTool,
+         SendTool,
+         AgentSendTool,
+     )
+
+     has_orch = any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools)
+     if has_orch and len(tools) > 1:
+         return ["ERROR: Use ONE tool at a time!"] * len(tools)
+
+     return []
+
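# Sketch: two tools where one is an orchestration tool -> per-tool error
# strings, forcing the LLM to retry with a single tool. DoneTool is a real
# langroid orchestration tool; the direct call and field value are
# illustrative only (SquareTool/agent from the sketch above).
from langroid.agent.tools.orchestration import DoneTool

errs = agent._get_multiple_orch_tool_errs(
    [DoneTool(content="42"), SquareTool(num=3)]
)
# errs == ["ERROR: Use ONE tool at a time!"] * 2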
+ def _handle_message_final(
+     self, tools: List[ToolMessage], results: List[str | ChatDocument | None]
+ ) -> None | str | OrderedDict[str, str] | ChatDocument:
+     """
+     Convert results to final response
+     """
+     # extract content from ChatDocument results so we have all str|None
+     results = [r.content if isinstance(r, ChatDocument) else r for r in results]
+
+     tool_names = [t.default_value("request") for t in tools]
+
+     has_ids = all([t.id != "" for t in tools])
+     if has_ids:
+         id2result = OrderedDict(
+             (t.id, r)
+             for t, r in zip(tools, results)
+             if r is not None and isinstance(r, str)
+         )
+         result_values = list(id2result.values())
+         if len(id2result) > 1 and any(
+             orch_str in r
+             for r in result_values
+             for orch_str in ORCHESTRATION_STRINGS
+         ):
+             # Cannot support multi-tool results containing orchestration strings!
+             # Replace results with err string to force LLM to retry
+             err_str = "ERROR: Please use ONE tool at a time!"
+             id2result = OrderedDict((id, err_str) for id in id2result.keys())
+
+     name_results_list = [
+         (name, r) for name, r in zip(tool_names, results) if r is not None
+     ]
+     if len(name_results_list) == 0:
+         return None
+
+     # there was a non-None result
+
+     if has_ids and len(id2result) > 1:
+         # if there are multiple OpenAI Tool results, return them as a dict
+         return id2result
+
+     # multi-results: prepend the tool name to each result
+     str_results = [f"Result from {name}: {r}" for name, r in name_results_list]
+     final = "\n\n".join(str_results)
+     return final
+
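# Sketch of the two multi-result shapes this method produces (ids and
# values are hypothetical):
from collections import OrderedDict
# multiple OpenAI tool-calls (all tools carry non-empty ids) -> id->result map
multi_shape = OrderedDict([("call_1", "9"), ("call_2", "16")])
# langroid-native tools (no ids) -> one string with tool names prepended
concat_shape = "Result from square: 9\n\nResult from cube: 27"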
+ async def handle_message_async(
+     self, msg: str | ChatDocument
+ ) -> None | str | OrderedDict[str, str] | ChatDocument:
+     """
+     Async version of `handle_message`. See there for details.
+     """
+     try:
+         tools = self.get_tool_messages(msg)
+         tools = [t for t in tools if self._tool_recipient_match(t)]
+     except ValidationError as ve:
+         # correct tool name but bad fields
+         return self.tool_validation_error(ve)
+     except XMLException as xe:  # from XMLToolMessage parsing
+         return str(xe)
+     except ValueError:
+         # invalid tool name
+         # We return None since returning "invalid tool name" would
+         # be considered a valid result in the task loop, and would be treated
+         # as a response to the tool message even though the tool was not
+         # intended for this agent.
+         return None
+     if len(tools) > 1 and not self.config.allow_multiple_tools:
+         return self.to_ChatDocument("ERROR: Use ONE tool at a time!")
+     if len(tools) == 0:
+         fallback_result = self.handle_message_fallback(msg)
+         if fallback_result is None:
+             return None
+         return self.to_ChatDocument(
+             fallback_result,
+             chat_doc=msg if isinstance(msg, ChatDocument) else None,
+         )
+     chat_doc = msg if isinstance(msg, ChatDocument) else None
+
+     results = self._get_multiple_orch_tool_errs(tools)
+     if not results:
+         results = [
+             await self.handle_tool_message_async(t, chat_doc=chat_doc)
+             for t in tools
+         ]
+     # if there's a solitary ChatDocument|str result, return it as is
+     if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
+         return results[0]
+
+     return self._handle_message_final(tools, results)
+
+ def handle_message(
+     self, msg: str | ChatDocument
+ ) -> None | str | OrderedDict[str, str] | ChatDocument:
+     """
+     Handle a "tool" message: either a string containing one or more
+     valid "tool" JSON substrings, or a ChatDocument containing a
+     `function_call` attribute.
+     Handle with the corresponding handler method, and return
+     the results as a combined string.
+
+     Args:
+         msg (str | ChatDocument): The string or ChatDocument to handle
+
+     Returns:
+         The result of the handler method, which can be:
+         - None if no tools were successfully handled, or no tools are present
+         - str if langroid-native JSON tools were handled, with results
+           concatenated, OR there's a SINGLE OpenAI tool-call.
+           (We do this so the common scenario of a single tool/fn-call
+           has a simple behavior.)
+         - Dict[str, str] if multiple OpenAI tool-calls were handled
+           (the dict is an id->result map)
+         - ChatDocument if a handler returned a ChatDocument, intended to be
+           the final response of the `agent_response` method.
+     """
+     try:
+         tools = self.get_tool_messages(msg)
+         tools = [t for t in tools if self._tool_recipient_match(t)]
+     except ValidationError as ve:
+         # correct tool name but bad fields
+         return self.tool_validation_error(ve)
+     except XMLException as xe:  # from XMLToolMessage parsing
+         return str(xe)
+     except ValueError:
+         # invalid tool name
+         # We return None since returning "invalid tool name" would
+         # be considered a valid result in the task loop, and would be treated
+         # as a response to the tool message even though the tool was not
+         # intended for this agent.
+         return None
+     if len(tools) == 0:
+         fallback_result = self.handle_message_fallback(msg)
+         if fallback_result is None:
+             return None
+         return self.to_ChatDocument(
+             fallback_result,
+             chat_doc=msg if isinstance(msg, ChatDocument) else None,
+         )
+
+     results: List[str | ChatDocument | None] = []
+     if len(tools) > 1 and not self.config.allow_multiple_tools:
+         results = ["ERROR: Use ONE tool at a time!"] * len(tools)
+     if not results:
+         results = self._get_multiple_orch_tool_errs(tools)
+     if not results:
+         chat_doc = msg if isinstance(msg, ChatDocument) else None
+         results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
+     # if there's a solitary ChatDocument|str result, return it as is
+     if len(results) == 1 and isinstance(results[0], (str, ChatDocument)):
+         return results[0]
+
+     return self._handle_message_final(tools, results)
+
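# Continuing the SquareTool sketch above: a handler method named after the
# tool's `request` value, so handle_message routes the parsed tool to it.
class SquareAgent(ChatAgent):
    def square(self, tool: SquareTool) -> str:
        # handler name matches the tool's `request` field
        return str(tool.num ** 2)

sq_agent = SquareAgent(ChatAgentConfig(name="Squarer"))
sq_agent.enable_message(SquareTool)
result = sq_agent.handle_message('{"request": "square", "num": 9}')
# result carries "81", possibly wrapped in a ChatDocument with content == "81"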
+ @property
+ def all_llm_tools_known(self) -> set[str]:
+     """All known tools; this may extend self.llm_tools_known."""
+     return self.llm_tools_known
+
+ def handle_message_fallback(self, msg: str | ChatDocument) -> Any:
+     """
+     Fallback method for the case where the msg has no tools that
+     can be handled by this agent.
+     This method can be overridden by subclasses, e.g.,
+     to create a "reminder" message when a tool is expected but the LLM
+     "forgot" to generate one.
+
+     Args:
+         msg (str | ChatDocument): The input msg to handle
+     Returns:
+         Any: The result of the handler method
+     """
+     return None
+
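# Sketch of overriding the fallback (builds on the SquareAgent sketch above):
# remind the LLM to use a tool whenever it replies in plain text.
from langroid.agent.chat_document import ChatDocument
from langroid.mytypes import Entity

class RemindingAgent(SquareAgent):
    def handle_message_fallback(self, msg: str | ChatDocument):
        # only nudge the LLM, not the user or other agents
        if isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM:
            return "You forgot to use a TOOL; please retry with the `square` tool."
        return None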
+ def _get_one_tool_message(
+     self, tool_candidate_str: str, is_json: bool = True, from_llm: bool = True
+ ) -> Optional[ToolMessage]:
+     """
+     Parse the tool_candidate_str into ANY ToolMessage KNOWN to the agent --
+     this includes non-used/handled tools, i.e. any tool in
+     self.all_llm_tools_known.
+     The exception is below, where we try our best to infer the tool
+     when the LLM has "forgotten" to include the "request" field in the
+     tool string --- in this case we ONLY look at the possible set of
+     HANDLED tools, i.e. self.llm_tools_handled.
+     """
+     if is_json:
+         maybe_tool_dict = json.loads(tool_candidate_str)
+     else:
+         try:
+             maybe_tool_dict = XMLToolMessage.extract_field_values(
+                 tool_candidate_str
+             )
+         except Exception as e:
+             from langroid.exceptions import XMLException
+
+             raise XMLException(f"Error extracting XML fields:\n {str(e)}")
+     # check if the maybe_tool_dict contains a "properties" field
+     # which further contains the actual tool-call
+     # (some weak LLMs do this). E.g. gpt-4o sometimes generates this:
+     # TOOL: {
+     #     "type": "object",
+     #     "properties": {
+     #         "request": "square",
+     #         "number": 9
+     #     },
+     #     "required": [
+     #         "number",
+     #         "request"
+     #     ]
+     # }
+
+     if not isinstance(maybe_tool_dict, dict):
+         self.tool_error = from_llm
+         return None
+
+     properties = maybe_tool_dict.get("properties")
+     if isinstance(properties, dict):
+         maybe_tool_dict = properties
+     request = maybe_tool_dict.get("request")
+     if request is None:
+         if self.enabled_requests_for_inference is None:
+             possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
+         else:
+             allowable = self.enabled_requests_for_inference.intersection(
+                 self.llm_tools_handled
+             )
+             possible = [self.llm_tools_map[r] for r in allowable]
+
+         default_keys = set(ToolMessage.model_fields.keys())
+         request_keys = set(maybe_tool_dict.keys())
+
+         def maybe_parse(tool: type[ToolMessage]) -> Optional[ToolMessage]:
+             all_keys = set(tool.model_fields.keys())
+             non_inherited_keys = all_keys.difference(default_keys)
+             # If the request has any keys not valid for the tool, or
+             # does not specify some key specific to the type
+             # (e.g. not just `purpose`), the LLM must explicitly
+             # specify `request`
+             if not (
+                 request_keys.issubset(all_keys)
+                 and len(request_keys.intersection(non_inherited_keys)) > 0
+             ):
+                 return None
+
+             try:
+                 return tool.model_validate(maybe_tool_dict)
+             except ValidationError:
+                 return None
+
+         candidate_tools = list(
+             filter(
+                 lambda t: t is not None,
+                 map(maybe_parse, possible),
+             )
+         )
+
+         # If only one valid candidate exists, we infer
+         # "request" to be the only possible value
+         if len(candidate_tools) == 1:
+             return candidate_tools[0]
+         else:
+             self.tool_error = from_llm
+             return None
+
+     if not isinstance(request, str) or request not in self.all_llm_tools_known:
+         self.tool_error = from_llm
+         return None
+
+     message_class = self.llm_tools_map.get(request)
+     if message_class is None:
+         logger.warning(f"No message class found for request '{request}'")
+         self.tool_error = from_llm
+         return None
+
+     try:
+         message = message_class.model_validate(maybe_tool_dict)
+     except ValidationError as ve:
+         self.tool_error = from_llm
+         raise ve
+     return message
+
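# Sketch of "request"-field inference (SquareTool/sq_agent from the sketches
# above): the JSON omits "request", but `num` is specific to exactly one
# handled tool, so the tool class can still be inferred; with zero or multiple
# matching tools, None is returned and tool_error is set.
tool = sq_agent._get_one_tool_message('{"num": 5}')
# -> SquareTool(num=5), provided SquareTool is the only enabled tool
#    with a `num` field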
+ def to_ChatDocument(
+     self,
+     msg: Any,
+     orig_tool_name: str | None = None,
+     chat_doc: Optional[ChatDocument] = None,
+     author_entity: Entity = Entity.AGENT,
+ ) -> Optional[ChatDocument]:
+     """
+     Convert result of a responder (agent_response or llm_response, or
+     task.run()), or tool handler, or handle_message_fallback,
+     to a ChatDocument, to enable handling by other
+     responders/tasks in a task loop possibly involving multiple agents.
+
+     Args:
+         msg (Any): The result of a responder or tool handler or task.run()
+         orig_tool_name (str): The original tool name that generated the
+             response, if any.
+         chat_doc (ChatDocument): The original ChatDocument object that `msg`
+             is a response to.
+         author_entity (Entity): The intended author of the result ChatDocument
+     """
+     if msg is None or isinstance(msg, ChatDocument):
+         return msg
+
+     is_agent_author = author_entity == Entity.AGENT
+
+     if isinstance(msg, str):
+         return self.response_template(author_entity, content=msg, content_any=msg)
+     elif isinstance(msg, ToolMessage):
+         # result is a ToolMessage, so...
+         result_tool_name = msg.default_value("request")
+         if (
+             is_agent_author
+             and result_tool_name in self.llm_tools_handled
+             and (orig_tool_name is None or orig_tool_name != result_tool_name)
+         ):
+             # TODO: do we need to remove the tool message from the chat_doc?
+             # if (chat_doc is not None and
+             #         msg in chat_doc.tool_messages):
+             #     chat_doc.tool_messages.remove(msg)
+             # if we can handle it, do so
+             result = self.handle_tool_message(msg, chat_doc=chat_doc)
+             if result is not None and isinstance(result, ChatDocument):
+                 return result
+         else:
+             # else wrap it in an agent response and return it so
+             # orchestrator can find a respondent
+             return self.response_template(author_entity, tool_messages=[msg])
+     else:
+         result = to_string(msg)
+
+     return (
+         None
+         if result is None
+         else self.response_template(author_entity, content=result, content_any=msg)
+     )
+
+ def from_ChatDocument(self, msg: ChatDocument, output_type: Type[T]) -> Optional[T]:
+     """
+     Extract a desired output_type from a ChatDocument object.
+     We use this fallback order:
+     - if `msg.content_any` exists and matches the output_type, return it
+     - if `msg.content` exists and output_type is str, return it
+     - if output_type is a ToolMessage, return the first tool in
+       `msg.tool_messages`
+     - if output_type is a list of ToolMessage,
+       return all tools in `msg.tool_messages`
+     - search for a tool in `msg.tool_messages` that has a field of
+       output_type, and if found, return that field value
+     - return None if all the above fail
+     """
+     content = msg.content
+     if output_type is str and content != "":
+         return cast(T, content)
+     content_any = msg.content_any
+     if content_any is not None and isinstance(content_any, output_type):
+         return cast(T, content_any)
+
+     tools = self.try_get_tool_messages(msg, all_tools=True)
+
+     if get_origin(output_type) is list:
+         list_element_type = get_args(output_type)[0]
+         if issubclass(list_element_type, ToolMessage):
+             # list_element_type is a subclass of ToolMessage:
+             # We output a list of objects derived from list_element_type
+             return cast(
+                 T,
+                 [t for t in tools if isinstance(t, list_element_type)],
+             )
+     elif get_origin(output_type) is None and issubclass(output_type, ToolMessage):
+         # output_type is a subclass of ToolMessage:
+         # return the first tool that has this specific output_type
+         for tool in tools:
+             if isinstance(tool, output_type):
+                 return cast(T, tool)
+         return None
+     elif get_origin(output_type) is None and output_type in (str, int, float, bool):
+         # attempt to get the output_type from the content,
+         # if it's a primitive type
+         primitive_value = from_string(content, output_type)  # type: ignore
+         if primitive_value is not None:
+             return cast(T, primitive_value)
+
+     # then search for output_type as a field in a tool
+     for tool in tools:
+         value = tool.get_value_of_type(output_type)
+         if value is not None:
+             return cast(T, value)
+     return None
+
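# Sketch of typed extraction from a ChatDocument (names from the sketches
# above; the LLM's actual response is of course not guaranteed):
doc = sq_agent.llm_response("Use the `square` tool on 7")
maybe_tool = sq_agent.from_ChatDocument(doc, SquareTool)  # first SquareTool, or None
maybe_int = sq_agent.from_ChatDocument(doc, int)  # int field/content, or None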
+ def _maybe_truncate_result(
+     self, result: str | ChatDocument | None, max_tokens: int | None
+ ) -> str | ChatDocument | None:
+     """
+     Truncate the result string to `max_tokens` tokens.
+     """
+     if result is None or max_tokens is None:
+         return result
+     result_str = result.content if isinstance(result, ChatDocument) else result
+     num_tokens = (
+         self.parser.num_tokens(result_str)
+         if self.parser is not None
+         else len(result_str) / 4.0
+     )
+     if num_tokens <= max_tokens:
+         return result
+     truncate_warning = f"""
+     The TOOL result was large, so it was truncated to {max_tokens} tokens.
+     To get the full result, the TOOL must be called again.
+     """
+     if isinstance(result, str):
+         return (
+             self.parser.truncate_tokens(result, max_tokens)
+             if self.parser is not None
+             else result[: max_tokens * 4]  # approx truncate
+         ) + truncate_warning
+     elif isinstance(result, ChatDocument):
+         result.content = (
+             self.parser.truncate_tokens(result.content, max_tokens)
+             if self.parser is not None
+             else result.content[: max_tokens * 4]  # approx truncate
+         ) + truncate_warning
+     return result
+
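# The no-parser fallback assumes roughly 4 characters per token, so e.g. a
# 10_000-char result with max_tokens=100 is estimated at 2500 tokens and cut
# to the first 400 characters before the warning is appended:
approx_tokens = 10_000 / 4.0   # 2500.0 > 100 -> truncate
truncated_len = 100 * 4        # keep 400 chars, i.e. ~100 tokens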
+ async def handle_tool_message_async(
+     self,
+     tool: ToolMessage,
+     chat_doc: Optional[ChatDocument] = None,
+ ) -> None | str | ChatDocument:
+     """
+     Async version of `handle_tool_message`. See there for details.
+     """
+     tool_name = tool.default_value("request")
+     if hasattr(tool, "_handler"):
+         handler_name = getattr(tool, "_handler", tool_name)
+     else:
+         handler_name = tool_name
+     handler_method = getattr(self, handler_name + "_async", None)
+     if handler_method is None:
+         return self.handle_tool_message(tool, chat_doc=chat_doc)
+     has_chat_doc_arg = (
+         chat_doc is not None
+         and "chat_doc" in inspect.signature(handler_method).parameters
+     )
+     try:
+         if has_chat_doc_arg:
+             maybe_result = await handler_method(tool, chat_doc=chat_doc)
+         else:
+             maybe_result = await handler_method(tool)
+         result = self.to_ChatDocument(maybe_result, tool_name, chat_doc)
+     except Exception as e:
+         # raise the error here since we are sure it's
+         # not a pydantic validation error,
+         # which we check in `handle_message`
+         raise e
+     return self._maybe_truncate_result(
+         result, tool._max_result_tokens
+     )  # type: ignore
+
+ def handle_tool_message(
+     self,
+     tool: ToolMessage,
+     chat_doc: Optional[ChatDocument] = None,
+ ) -> None | str | ChatDocument:
+     """
+     Respond to a tool request from the LLM, in the form of a ToolMessage
+     object.
+     Args:
+         tool: ToolMessage object representing the tool request.
+         chat_doc: Optional ChatDocument object containing the tool request.
+             This is passed to the tool-handler method only if it has a
+             `chat_doc` argument.
+
+     Returns:
+         The result of the tool-handler method (None if no handler method
+         exists), converted via `to_ChatDocument`, and possibly truncated
+         to the tool's `_max_result_tokens`.
+     """
+     tool_name = tool.default_value("request")
+     if hasattr(tool, "_handler"):
+         handler_name = getattr(tool, "_handler", tool_name)
+     else:
+         handler_name = tool_name
+     handler_method = getattr(self, handler_name, None)
+     if handler_method is None:
+         return None
+     has_chat_doc_arg = (
+         chat_doc is not None
+         and "chat_doc" in inspect.signature(handler_method).parameters
+     )
+     try:
+         if has_chat_doc_arg:
+             maybe_result = handler_method(tool, chat_doc=chat_doc)
+         else:
+             maybe_result = handler_method(tool)
+         result = self.to_ChatDocument(maybe_result, tool_name, chat_doc)
+     except Exception as e:
+         # raise the error here since we are sure it's
+         # not a pydantic validation error,
+         # which we check in `handle_message`
+         raise e
+     return self._maybe_truncate_result(
+         result, tool._max_result_tokens
+     )  # type: ignore
+
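# Sketch of a handler that also receives the originating ChatDocument
# (hypothetical LookupTool): chat_doc is passed only because the handler's
# signature declares a `chat_doc` parameter.
from langroid.agent.chat_agent import ChatAgent
from langroid.agent.chat_document import ChatDocument
from langroid.agent.tool_message import ToolMessage

class LookupTool(ToolMessage):
    request: str = "lookup"
    purpose: str = "Look up <key>"
    key: str

class LookupAgent(ChatAgent):
    def lookup(self, tool: LookupTool, chat_doc: ChatDocument) -> str:
        # the handler can inspect the originating message, e.g. its sender
        return f"looked up {tool.key} (sender: {chat_doc.metadata.sender})"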
+ def num_tokens(self, prompt: str | List[LLMMessage]) -> int:
+     if self.parser is None:
+         raise ValueError("Parser must be set, to count tokens")
+     if isinstance(prompt, str):
+         return self.parser.num_tokens(prompt)
+     else:
+         return sum(
+             [
+                 self.parser.num_tokens(m.content)
+                 + self.parser.num_tokens(str(m.function_call or ""))
+                 for m in prompt
+             ]
+         )
+
+ def _get_response_stats(
+     self, chat_length: int, tot_cost: float, response: LLMResponse
+ ) -> str:
+     """
+     Get LLM response stats as a string
+
+     Args:
+         chat_length (int): number of messages in the chat
+         tot_cost (float): total cost of the chat so far
+         response (LLMResponse): LLMResponse object
+     """
+
+     if self.config.llm is None:
+         logger.warning("LLM config is None, cannot get response stats")
+         return ""
+     if response.usage:
+         in_tokens = response.usage.prompt_tokens
+         out_tokens = response.usage.completion_tokens
+         llm_response_cost = format(response.usage.cost, ".4f")
+         cumul_cost = format(tot_cost, ".4f")
+         assert isinstance(self.llm, LanguageModel)
+         context_length = self.llm.chat_context_length()
+         max_out = self.config.llm.model_max_output_tokens
+
+         llm_model = (
+             "no-LLM" if self.config.llm is None else self.llm.config.chat_model
+         )
+         # tot cost across all LLMs, agents
+         all_cost = format(self.llm.tot_tokens_cost()[1], ".4f")
+         return (
+             f"[bold]Stats:[/bold] [magenta]N_MSG={chat_length}, "
+             f"TOKENS: in={in_tokens}, out={out_tokens}, "
+             f"max={max_out}, ctx={context_length}, "
+             f"COST: now=${llm_response_cost}, cumul=${cumul_cost}, "
+             f"tot=${all_cost} "
+             f"[bold]({llm_model})[/bold][/magenta]"
+         )
+     return ""
+
+ def update_token_usage(
+     self,
+     response: LLMResponse,
+     prompt: str | List[LLMMessage],
+     stream: bool,
+     chat: bool = True,
+     print_response_stats: bool = True,
+ ) -> None:
+     """
+     Updates `response.usage` obj (token usage and cost fields) if needed.
+     An update is needed only if:
+     - stream is True (i.e. streaming was enabled), and
+     - the response was NOT obtained from cache, and
+     - the API did NOT provide the usage/cost fields during streaming.
+       (As of Sep 2024, the OpenAI API started providing these; for other APIs
+       this may not necessarily be the case.)
+
+     Args:
+         response (LLMResponse): LLMResponse object
+         prompt (str | List[LLMMessage]): prompt or list of LLMMessage objects
+         stream (bool): whether the response was streamed (in which case the
+             usage may need to be computed locally, unless cached).
+         chat (bool): whether this is a chat model or a completion model
+         print_response_stats (bool): whether to print the response stats
+     """
+     if response is None or self.llm is None:
+         return
+
+     no_usage_info = response.usage is None or response.usage.prompt_tokens == 0
+     # Note: If response was not streamed, then
+     # `response.usage` would already have been set by the API,
+     # so we only need to update in the stream case.
+     if stream and no_usage_info:
+         # usage, cost = 0 when response is from cache
+         prompt_tokens = 0
+         completion_tokens = 0
+         cost = 0.0
+         if not response.cached:
+             prompt_tokens = self.num_tokens(prompt)
+             completion_tokens = self.num_tokens(response.message)
+             if response.function_call is not None:
+                 completion_tokens += self.num_tokens(str(response.function_call))
+             cost = self.compute_token_cost(prompt_tokens, 0, completion_tokens)
+         response.usage = LLMTokenUsage(
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             cost=cost,
+         )
+
+     # update total counters
+     if response.usage is not None:
+         self.total_llm_token_cost += response.usage.cost
+         self.total_llm_token_usage += response.usage.total_tokens
+         self.llm.update_usage_cost(
+             chat,
+             response.usage.prompt_tokens,
+             response.usage.completion_tokens,
+             response.usage.cost,
+         )
+         chat_length = 1 if isinstance(prompt, str) else len(prompt)
+         self.token_stats_str = self._get_response_stats(
+             chat_length, self.total_llm_token_cost, response
+         )
+         if print_response_stats:
+             print(self.indent + self.token_stats_str)
+
+ def compute_token_cost(self, prompt: int, cached: int, completion: int) -> float:
+     price = cast(LanguageModel, self.llm).chat_cost()
+     return (
+         price[0] * (prompt - cached) + price[1] * cached + price[2] * completion
+     ) / 1000
+
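# A worked example of the cost formula (hypothetical per-1K-token prices):
# chat_cost() returns (input, cached-input, output) prices, and the /1000
# converts token counts to per-1K units. With prices (0.0025, 0.00125, 0.01),
# 1000 prompt tokens (200 of them cached) plus 500 completion tokens cost:
cost = (0.0025 * (1000 - 200) + 0.00125 * 200 + 0.01 * 500) / 1000
# = (2.0 + 0.25 + 5.0) / 1000 = 0.00725 dollars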
+ def ask_agent(
+     self,
+     agent: "Agent",
+     request: str,
+     no_answer: str = NO_ANSWER,
+     user_confirm: bool = True,
+ ) -> Optional[str]:
+     """
+     Send a request to another agent, possibly after confirming with the user.
+     This is not currently used, since we rely on the task loop and
+     `RecipientTool` to address requests to other agents. It is generally
+     best to avoid using this method.
+
+     Args:
+         agent (Agent): agent to ask
+         request (str): request to send
+         no_answer (str): expected response when the agent does not know
+             the answer
+         user_confirm (bool): whether to gate the request with a human
+             confirmation
+
+     Returns:
+         str: response from the agent
+     """
+     agent_type = type(agent).__name__
+     if user_confirm:
+         user_response = Prompt.ask(
+             f"""[magenta]Here is the request or message:
+             {request}
+             Should I forward this to {agent_type}?""",
+             default="y",
+             choices=["y", "n"],
+         )
+         if user_response not in ["y", "yes"]:
+             return None
+     answer = agent.llm_response(request)
+     if answer != no_answer:
+         return (f"{agent_type} says: " + str(answer)).strip()
+     return None